pub mod prompt;
pub mod tools;

use anyhow::Result;
use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
use futures::Stream;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;

use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        DeltaGeneratorExt,
        chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, jail::JailedStream,
        },
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
        embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        nvext::NvExtProvider,
    },
};
use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};

use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};

pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
pub use crate::protocols::common::preprocessor::PreprocessedEmbeddingRequest;

use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;

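// Annotation keys attached to response streams: the template-rendered prompt
// and the prompt's token IDs are returned only when the request asks for them
// (via `AnnotationsProvider`), while per-chunk LLM metrics are attached by the
// postprocessor stream.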
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";

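/// Token accounting for one streamed chunk: the prompt size (`input_tokens`),
/// the cumulative output so far (`output_tokens`), and the number of tokens
/// carried by this chunk (`chunk_tokens`).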
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub chunk_tokens: usize,
}

impl LLMMetricAnnotation {
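    /// Converts these metrics into an [`Annotated`] event.
    ///
    /// A minimal round-trip sketch (illustrative only, not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let metrics = LLMMetricAnnotation { input_tokens: 12, output_tokens: 3, chunk_tokens: 1 };
    /// let annotated: Annotated<()> = metrics.to_annotation()?;
    /// let parsed = LLMMetricAnnotation::from_annotation(&annotated)?;
    /// assert_eq!(parsed.unwrap().chunk_tokens, 1);
    /// ```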
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
    }

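    /// Extracts LLM metrics from an annotated event, if present.
    ///
    /// Returns `Ok(None)` when the event is missing or is not an
    /// [`ANNOTATION_LLM_METRICS`] event, and an error when the comments block
    /// is malformed.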
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
        if annotation.event.is_none() {
            return Ok(None);
        }
        if annotation.event.as_ref().unwrap() != ANNOTATION_LLM_METRICS {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        if comments.len() != 1 {
            return Err("malformed comments block - expected exactly 1 comment".into());
        }
        let metrics: LLMMetricAnnotation = serde_json::from_str(&comments[0])?;
        Ok(Some(metrics))
    }
}

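/// Preprocessor for OpenAI-style requests: applies the model's chat template,
/// tokenizes the prompt, merges stop/EOS conditions, and post-processes the
/// backend's token stream back into OpenAI response chunks.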
pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
    runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
    tool_call_parser: Option<String>,
}

impl OpenAIPreprocessor {
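    /// Builds a preprocessor from a [`ModelDeploymentCard`], deriving the
    /// prompt formatter and HuggingFace tokenizer from the card.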
    pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let formatter = PromptFormatter::from_mdc(&mdc)?;
        let tokenizer = mdc.tokenizer_hf()?;
        match formatter {
            PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
        }
    }

    pub fn new_with_parts(
        mdc: ModelDeploymentCard,
        formatter: Arc<dyn OAIPromptFormatter>,
        hf_tokenizer: tokenizers::Tokenizer,
    ) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info()?;
        let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();

        let runtime_config = mdc.runtime_config.clone();

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
            runtime_config,
            tool_call_parser,
        }))
    }

    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

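    /// Converts an OpenAI-style request into a [`PreprocessedRequest`]: seeds
    /// the builder, applies the chat template (for single-text input), and
    /// tokenizes the prompt. Returns the preprocessed request together with
    /// any annotations the caller asked for.
    ///
    /// A usage sketch, assuming `preprocessor` and `chat_request` are in
    /// scope:
    ///
    /// ```ignore
    /// let (preprocessed, annotations) = preprocessor.preprocess_request(&chat_request)?;
    /// tracing::debug!("prompt has {} tokens", preprocessed.token_ids.len());
    /// ```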
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
        let mut builder = self.builder(request)?;
        let formatted_prompt = self.apply_template(request)?;
        let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;

        Ok((builder.build()?, annotations))
    }

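    /// Seeds a [`PreprocessedRequestBuilder`] with everything except token
    /// IDs: sampling/stop/output options, annotations, and the model's EOS
    /// token IDs (merged into the hidden stop tokens, and omitted from
    /// `eos_token_ids` when `ignore_eos` is set).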
    pub fn builder<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<PreprocessedRequestBuilder> {
        let mut builder = PreprocessedRequest::builder();
        builder.model(request.model());

        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.stop_conditions(stop_conditions);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.output_options(request.extract_output_options()?);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));
        builder.estimated_prefix_hit_num_blocks(None);
        if let Some(nvext) = request.nvext() {
            builder.backend_instance_id(nvext.backend_instance_id);
        }

        Ok(builder)
    }

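    /// Renders the chat template for single-text requests, returning `None`
    /// for token or batch inputs. When `nvext.use_raw_prompt` is set, the raw
    /// prompt is passed through unrendered, falling back to the template if
    /// no raw prompt is available.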
    pub fn apply_template<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<Option<String>> {
        if let PromptInput::Text(_) = request.prompt_input_type()
            && let Some(TextInput::Single(_)) = request.extract_text()
        {
            let use_raw_prompt = request
                .nvext()
                .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

            let formatted_prompt = if use_raw_prompt {
                match request.raw_prompt() {
                    Some(prompt) => prompt,
                    None => {
                        tracing::warn!("Raw prompt requested but not available");
                        self.formatter.render(request)?
                    }
                }
            } else {
                self.formatter.render(request)?
            };
            Ok(Some(formatted_prompt))
        } else {
            Ok(None)
        }
    }

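    /// Fills in the builder's token IDs from the request: pre-tokenized
    /// inputs are passed through (batches of one are unwrapped), single text
    /// is tokenized (or taken from `nvext.token_data` when a backend instance
    /// is pinned), and text batches are tokenized in parallel via rayon.
    /// Returns the annotations requested by the caller.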
    pub fn gather_tokens<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
        builder: &mut PreprocessedRequestBuilder,
        formatted_prompt: Option<String>,
    ) -> Result<HashMap<String, String>> {
        let mut annotations = HashMap::new();
        match request.prompt_input_type() {
            PromptInput::Tokens(_) => {
                if let Some(token_input) = request.extract_tokens() {
                    match token_input {
                        TokenInput::Single(tokens) => {
                            builder.token_ids(tokens);
                        }
                        TokenInput::Batch(token_batches) => {
                            if token_batches.len() == 1 {
                                builder.token_ids(token_batches[0].clone());
                            } else {
                                builder.batch_token_ids(Some(token_batches));
                                builder.token_ids(vec![]);
                            }
                        }
                    }
                }
            }
            PromptInput::Text(_) => {
                if let Some(text_input) = request.extract_text() {
                    match text_input {
                        TextInput::Single(raw_prompt) => {
                            if let Some(f) = formatted_prompt.as_ref()
                                && request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
                            {
                                annotations
                                    .insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
                            }

                            let prompt = formatted_prompt.unwrap_or(raw_prompt);

                            let has_backend_instance_id = request
                                .nvext()
                                .and_then(|ext| ext.backend_instance_id)
                                .is_some();

                            let token_data =
                                request.nvext().and_then(|ext| ext.token_data.as_ref());

                            let (tokens_vec, skip_token_annotation) = if has_backend_instance_id {
                                if let Some(tokens) = token_data {
                                    tracing::trace!(
                                        "Using provided tokens from EPP: {} ids",
                                        tokens.len()
                                    );
                                    (tokens.clone(), true)
                                } else {
                                    tracing::warn!(
                                        "backend_instance_id provided but no token_data; tokenizing prompt"
                                    );
                                    let encoding = self.tokenizer.encode(&prompt)?;
                                    (encoding.token_ids().to_vec(), false)
                                }
                            } else {
                                let encoding = self.tokenizer.encode(&prompt)?;
                                (encoding.token_ids().to_vec(), false)
                            };

                            if request.has_annotation(ANNOTATION_TOKEN_IDS)
                                && !skip_token_annotation
                            {
                                annotations.insert(
                                    ANNOTATION_TOKEN_IDS.to_string(),
                                    serde_json::to_string(&tokens_vec)?,
                                );
                            }

                            builder.token_ids(tokens_vec);
                        }
                        TextInput::Batch(texts) => {
                            let token_batches: Vec<Vec<u32>> = texts
                                .par_iter()
                                .map(|text| {
                                    self.tokenizer
                                        .encode(text)
                                        .map(|encoded| encoded.token_ids().to_vec())
                                })
                                .collect::<Result<Vec<_>>>()?;
                            builder.batch_token_ids(Some(token_batches));
                            builder.token_ids(vec![]);
                        }
                    }
                }
            }
        }
        Ok(annotations)
    }

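    /// Tokenizes an embedding request's input (a string, a string array, or
    /// pre-tokenized IDs) into batches of token IDs, carrying over the model
    /// name, encoding format, and dimensions. String arrays are encoded on a
    /// blocking thread so batch tokenization does not stall the async
    /// executor.
    ///
    /// A usage sketch, assuming `preprocessor` and `embedding_request` are in
    /// scope:
    ///
    /// ```ignore
    /// let (preprocessed, annotations) = preprocessor
    ///     .preprocess_embedding_request(&embedding_request)
    ///     .await?;
    /// ```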
    pub async fn preprocess_embedding_request(
        &self,
        request: &NvCreateEmbeddingRequest,
    ) -> Result<(PreprocessedEmbeddingRequest, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = PreprocessedEmbeddingRequest::builder();

        let all_token_ids = match &request.inner.input {
            dynamo_async_openai::types::EmbeddingInput::String(s) => {
                let encoding = self.tokenizer.encode(s)?;
                vec![encoding.token_ids().to_vec()]
            }
            dynamo_async_openai::types::EmbeddingInput::StringArray(arr) => {
                let input_strs: Vec<String> = arr.to_vec();
                let encodings = tokio::task::spawn_blocking({
                    let tokenizer = self.tokenizer.clone();
                    let strs = input_strs.clone();
                    move || {
                        tokenizer.encode_batch(&strs.iter().map(|s| s.as_str()).collect::<Vec<_>>())
                    }
                })
                .await??;
                let token_arrays: Vec<Vec<u32>> = encodings
                    .into_iter()
                    .map(|encoding| encoding.token_ids().to_vec())
                    .collect();
                token_arrays
            }
            dynamo_async_openai::types::EmbeddingInput::IntegerArray(token_ids) => {
                vec![token_ids.clone()]
            }
            dynamo_async_openai::types::EmbeddingInput::ArrayOfIntegerArray(token_arrays) => {
                token_arrays.clone()
            }
        };

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&all_token_ids)?,
            );
        }

        builder.token_ids(all_token_ids);
        builder.model(request.inner.model.clone());
        builder.encoding_format(request.inner.encoding_format.as_ref().map(|f| match f {
            EncodingFormat::Float => "float".to_string(),
            EncodingFormat::Base64 => "base64".to_string(),
        }));
        builder.dimensions(request.inner.dimensions);

        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

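    /// Transforms a stream of [`BackendOutput`] into OpenAI response chunks,
    /// attaching a per-chunk [`LLMMetricAnnotation`], stopping generation on
    /// postprocessing errors, and (when usage reporting is enabled) emitting
    /// a final usage-only chunk after the finish reason, per the OpenAI
    /// streaming contract.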
    pub fn transform_postprocessor_stream<S, Resp>(
        stream: S,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
        context: Arc<dyn AsyncEngineContext>,
    ) -> impl Stream<Item = Annotated<Resp>> + Send
    where
        S: Stream<Item = Annotated<BackendOutput>> + Send + 'static,
        Resp: Send + Sync + 'static + std::fmt::Debug,
    {
        struct State<Resp>
        where
            Resp: Send + Sync + 'static + std::fmt::Debug,
        {
            response_stream: Pin<Box<dyn Stream<Item = Annotated<BackendOutput>> + Send>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
            cumulative_output_tokens: usize,
            finish_reason_sent: bool,
            usage_chunk_sent: bool,
            finished: bool,
        }

        let state = State {
            response_stream: Box::pin(stream),
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
            cumulative_output_tokens: 0,
            finish_reason_sent: false,
            usage_chunk_sent: false,
            finished: false,
        };

        stream::unfold(state, |mut inner| {
            async move {
                if inner.finished {
                    return None;
                }

                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        inner.finished = true;
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    let has_finish_reason = response
                        .data
                        .as_ref()
                        .map(|d| d.finish_reason.is_some())
                        .unwrap_or(false);

                    let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
                        let chunk_tokens = backend_output.token_ids.len();
                        inner.cumulative_output_tokens += chunk_tokens;

                        let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;

                        (chunk_tokens, isl)
                    } else {
                        (0, 0)
                    };

                    let current_osl = inner.cumulative_output_tokens;

                    let mut response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    let llm_metrics = LLMMetricAnnotation {
                        input_tokens: isl,
                        output_tokens: current_osl,
                        chunk_tokens,
                    };

                    if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
                        if response.event.is_none() {
                            response.event = metrics_annotated.event;
                            response.comment = metrics_annotated.comment;
                        }
                    }

                    if has_finish_reason {
                        inner.finish_reason_sent = true;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    inner.finished = true;

                    if inner.response_generator.is_usage_enabled()
                        && inner.finish_reason_sent
                        && !inner.usage_chunk_sent
                    {
                        inner.usage_chunk_sent = true;

                        let usage_chunk = inner.response_generator.create_usage_chunk();
                        let annotated_usage = Annotated::<Resp> {
                            id: None,
                            data: Some(usage_chunk),
                            event: Some(ANNOTATION_LLM_METRICS.to_string()),
                            comment: None,
                        };

                        tracing::trace!(
                            request_id = inner.context.id(),
                            "Sending final usage chunk for OpenAI compliance"
                        );

                        Some((annotated_usage, inner))
                    } else {
                        None
                    }
                }
            }
        })
    }

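    /// Converts engine embedding outputs into OpenAI
    /// `CreateEmbeddingResponse`s, indexing each embedding and echoing back
    /// the requested model name and token usage.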
    pub fn transform_embedding_postprocessor_stream<S>(
        stream: S,
        original_request: NvCreateEmbeddingRequest,
    ) -> impl Stream<Item = Annotated<NvCreateEmbeddingResponse>> + Send
    where
        S: Stream<Item = Annotated<EmbeddingsEngineOutput>> + Send + 'static,
    {
        stream.map(move |output| {
            output.map_data(|engine_output| {
                let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output
                    .embeddings
                    .into_iter()
                    .enumerate()
                    .map(|(index, embedding)| dynamo_async_openai::types::Embedding {
                        index: index as u32,
                        object: "embedding".to_string(),
                        embedding: embedding.into_iter().map(|f| f as f32).collect(),
                    })
                    .collect();

                let response = NvCreateEmbeddingResponse {
                    inner: dynamo_async_openai::types::CreateEmbeddingResponse {
                        object: "list".to_string(),
                        model: original_request.inner.model.clone(),
                        data: embeddings,
                        usage: dynamo_async_openai::types::EmbeddingUsage {
                            prompt_tokens: engine_output.prompt_tokens,
                            total_tokens: engine_output.total_tokens,
                        },
                    },
                };

                Ok(response)
            })
        })
    }

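    /// Decides whether the response stream should be jailed for tool-call
    /// parsing. Jailing requires a configured parser, tools on the request,
    /// and a tool choice other than `none`; requesting tool calls without a
    /// configured parser logs a warning and proceeds unjailed.
    ///
    /// Two cases as a sketch (illustrative only; the parser name is an
    /// arbitrary placeholder):
    ///
    /// ```ignore
    /// // Tools present but no parser configured: never jail.
    /// assert!(!OpenAIPreprocessor::should_apply_tool_jail(None, None, true)?);
    /// // Parser configured, tools present, no explicit tool choice: jail.
    /// let parser = Some("hermes".to_string());
    /// assert!(OpenAIPreprocessor::should_apply_tool_jail(parser.as_ref(), None, true)?);
    /// ```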
    pub fn should_apply_tool_jail(
        tool_call_parser: Option<&String>,
        tool_choice: Option<&ChatCompletionToolChoiceOption>,
        has_tools: bool,
    ) -> std::result::Result<bool, Error> {
        match (tool_call_parser, tool_choice, has_tools) {
            (None, Some(ChatCompletionToolChoiceOption::Required), true) => {
                tracing::warn!(
                    "Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }
            (None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
                tracing::warn!(
                    "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }
            (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
                tracing::warn!(
                    "Named tool choice specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }

            (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => Ok(false),
            (Some(_), Some(_), true) => Ok(true),
            (Some(_), None, true) => Ok(true),
            _ => Ok(false),
        }
    }

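    /// Wraps the stream in a [`JailedStream`] that buffers output while the
    /// named parser detects and extracts tool calls.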
    pub fn apply_tool_calling_jail<S>(
        tool_call_parser: String,
        stream: S,
    ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
    where
        S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
    {
        let jail = JailedStream::builder()
            .tool_call_parser(tool_call_parser)
            .build();
        jail.apply(stream)
    }
}

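// Chat completions: preprocess the request, run the backend engine, then
// postprocess the token stream (optionally jailed for tool-call parsing).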
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator(context.id().to_string());

        let (common_request, annotations) = self.preprocess_request(&request)?;

        let mut response_generator = Box::new(response_generator);

        response_generator.set_reasoning_parser(self.runtime_config.clone());

        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let context = response_stream.context();

        let stream = Self::transform_postprocessor_stream(
            response_stream,
            response_generator,
            context.clone(),
        );

        let has_tools = request
            .inner
            .tools
            .as_ref()
            .is_some_and(|tools| !tools.is_empty());

        let should_jail = Self::should_apply_tool_jail(
            self.tool_call_parser.as_ref(),
            request.inner.tool_choice.as_ref(),
            has_tools,
        )?;

        let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
            if let Some(parser) = self.tool_call_parser.clone() {
                Box::pin(Self::apply_tool_calling_jail(parser, stream))
            } else {
                Box::pin(stream)
            }
        } else {
            Box::pin(stream)
        };
        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

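// Text completions: same pipeline as chat, but without templating or
// tool-call jailing.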
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator(context.id().to_string());
        let mut response_generator = Box::new(response_generator);
        let mut builder = self.builder(&request)?;
        let annotations = self.gather_tokens(&request, &mut builder, None)?;
        let common_request = builder.build()?;

        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let context = response_stream.context();

        let stream = Self::transform_postprocessor_stream(
            response_stream,
            response_generator,
            context.clone(),
        );

        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

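// Embeddings: tokenize the input, run the backend engine, and map engine
// outputs back into OpenAI embedding responses.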
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        SingleIn<PreprocessedEmbeddingRequest>,
        ManyOut<Annotated<EmbeddingsEngineOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateEmbeddingRequest>,
        next: Arc<
            dyn AsyncEngine<
                SingleIn<PreprocessedEmbeddingRequest>,
                ManyOut<Annotated<EmbeddingsEngineOutput>>,
                Error,
            >,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        let (request, context) = request.into_parts();

        let (preprocessed_request, annotations) =
            self.preprocess_embedding_request(&request).await?;

        let preprocessed_request = context.map(|_| preprocessed_request);
        let response_stream = next.generate(preprocessed_request).await?;

        let context = response_stream.context();

        let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);

        let annotations_stream = stream::iter(
            annotations
                .into_iter()
                .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
                .collect::<Vec<_>>(),
        );

        let combined_stream = annotations_stream.chain(stream);
        Ok(ResponseStream::new(Box::pin(combined_stream), context))
    }
}