dynamo_llm/preprocessor.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! The Preprocessor consists of the following modules:
//!
//! - `translation`: This module converts the allowed Ingress message types to the corresponding
//!   internal representation.
//! - `apply`: This module applies ModelConfig defaults to any optional fields left unspecified.
//! - `prompt`: This module applies any prompt template logic to the internal Request object.
//! - `tokenize`: This module tokenizes the formatted prompt string and returns the token ids.
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.
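//!
//! A minimal end-to-end sketch (doctest-ignored; `mdc` and `chat_request` are hypothetical
//! values constructed by the caller):
//!
//! ```ignore
//! let preprocessor = OpenAIPreprocessor::new(mdc)?;
//! let (backend_request, annotations) = preprocessor.preprocess_request(&chat_request)?;
//! ```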

pub mod prompt;
pub mod tools;

use anyhow::Result;
use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
use futures::Stream;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;

use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        DeltaGeneratorExt,
        chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, jail::JailedStream,
        },
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
        embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
        nvext::NvExtProvider,
    },
};
use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};

use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};

pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
pub use crate::protocols::common::preprocessor::PreprocessedEmbeddingRequest;

use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";
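
/// Per-chunk token accounting attached to streamed responses via the `llm_metrics`
/// annotation (see [`ANNOTATION_LLM_METRICS`]).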
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub chunk_tokens: usize,
}

impl LLMMetricAnnotation {
    /// Convert this metrics struct to an Annotated event
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
    }

    /// Extract LLM metrics from an Annotated event, if present
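    ///
    /// A minimal round-trip sketch (doctest-ignored; the field values are illustrative):
    ///
    /// ```ignore
    /// let metrics = LLMMetricAnnotation { input_tokens: 12, output_tokens: 3, chunk_tokens: 1 };
    /// let annotated: Annotated<()> = metrics.to_annotation()?;
    /// assert!(LLMMetricAnnotation::from_annotation(&annotated)?.is_some());
    /// ```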
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
        if annotation.event.is_none() {
            return Ok(None);
        }
        if annotation.event.as_ref().unwrap() != ANNOTATION_LLM_METRICS {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        if comments.len() != 1 {
            return Err("malformed comments block - expected exactly 1 comment".into());
        }
        let metrics: LLMMetricAnnotation = serde_json::from_str(&comments[0])?;
        Ok(Some(metrics))
    }
}

pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
    /// Per-model runtime configuration propagated to the response generator (e.g., reasoning/tool parser)
    runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
    tool_call_parser: Option<String>,
}

impl OpenAIPreprocessor {
    pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let formatter = PromptFormatter::from_mdc(&mdc)?;
        let tokenizer = mdc.tokenizer_hf()?;
        match formatter {
            PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
        }
    }

    pub fn new_with_parts(
        mdc: ModelDeploymentCard,
        formatter: Arc<dyn OAIPromptFormatter>,
        hf_tokenizer: tokenizers::Tokenizer,
    ) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info()?;
        let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();

        // Initialize runtime config from the ModelDeploymentCard
        let runtime_config = mdc.runtime_config.clone();

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
            runtime_config,
            tool_call_parser,
        }))
    }

    /// Encode a string to its tokens
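    ///
    /// A minimal sketch (doctest-ignored; `preprocessor` is a constructed [`OpenAIPreprocessor`]):
    ///
    /// ```ignore
    /// let encoding = preprocessor.tokenize("Hello, world!")?;
    /// let token_ids = encoding.token_ids();
    /// ```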
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

    /// Translate an OpenAI chat-like request (e.g. [`NvCreateChatCompletionRequest`]) to a common
    /// completion request. Returns both the common completion request and a hashmap of annotations.
    ///
    /// Annotations evaluated by this method include:
    /// - `formatted_prompt`
    /// - `token_ids`
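    ///
    /// A hedged usage sketch (doctest-ignored; `chat_request` is a hypothetical
    /// [`NvCreateChatCompletionRequest`]):
    ///
    /// ```ignore
    /// let (backend_request, annotations) = preprocessor.preprocess_request(&chat_request)?;
    /// if let Some(prompt) = annotations.get(ANNOTATION_FORMATTED_PROMPT) {
    ///     tracing::debug!("formatted prompt: {prompt}");
    /// }
    /// ```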
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
        let mut builder = self.builder(request)?;
        let formatted_prompt = self.apply_template(request)?;
        let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;

        Ok((builder.build()?, annotations))
    }

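    /// Create a [`PreprocessedRequestBuilder`] pre-populated from the request: model name,
    /// stop conditions (with the model's EOS token ids merged into the hidden stop tokens),
    /// sampling and output options, annotations, the MDC checksum, and, when present in
    /// `nvext`, the `backend_instance_id`. EOS token ids are also forwarded to the builder
    /// unless `ignore_eos` is requested.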
    pub fn builder<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<PreprocessedRequestBuilder> {
        let mut builder = PreprocessedRequest::builder();
        builder.model(request.model());

        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        // apply ignore_eos if not already set
        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.stop_conditions(stop_conditions);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.output_options(request.extract_output_options()?);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));
        builder.estimated_prefix_hit_num_blocks(None);
        // Extract backend_instance_id from nvext if present
        if let Some(nvext) = request.nvext() {
            builder.backend_instance_id(nvext.backend_instance_id);
        }

        Ok(builder)
    }

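    /// Render the prompt template for single-text requests.
    ///
    /// Returns `Some(formatted_prompt)` when the request carries a single text input; the raw
    /// prompt is passed through untouched when `nvext.use_raw_prompt` is set (falling back to
    /// the template if no raw prompt is available). Returns `None` for token or batch inputs.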
    pub fn apply_template<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<Option<String>> {
        if let PromptInput::Text(_) = request.prompt_input_type()
            && let Some(TextInput::Single(_)) = request.extract_text()
        {
            let use_raw_prompt = request
                .nvext()
                .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

            let formatted_prompt = if use_raw_prompt {
                match request.raw_prompt() {
                    Some(prompt) => prompt,
                    None => {
                        tracing::warn!("Raw prompt requested but not available");
                        self.formatter.render(request)?
                    }
                }
            } else {
                self.formatter.render(request)?
            };
            Ok(Some(formatted_prompt))
        } else {
            Ok(None)
        }
    }

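    /// Populate the builder with token ids for the request and collect any requested
    /// annotations (`formatted_prompt`, `token_ids`).
    ///
    /// Token inputs are forwarded as-is (single or batched). Text inputs are tokenized here,
    /// except when `nvext.backend_instance_id` is set together with pre-tokenized
    /// `nvext.token_data`, in which case the provided token ids are used directly.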
    pub fn gather_tokens<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
        builder: &mut PreprocessedRequestBuilder,
        formatted_prompt: Option<String>,
    ) -> Result<HashMap<String, String>> {
        let mut annotations = HashMap::new();
        // match request type before any conversion/processing
        match request.prompt_input_type() {
            PromptInput::Tokens(_) => {
                if let Some(token_input) = request.extract_tokens() {
                    match token_input {
                        TokenInput::Single(tokens) => {
                            builder.token_ids(tokens);
                        }
                        TokenInput::Batch(token_batches) => {
                            if token_batches.len() == 1 {
                                builder.token_ids(token_batches[0].clone());
                            } else {
                                builder.batch_token_ids(Some(token_batches));
                                builder.token_ids(vec![]);
                            }
                        }
                    }
                }
            }
            PromptInput::Text(_) => {
                if let Some(text_input) = request.extract_text() {
                    match text_input {
                        TextInput::Single(raw_prompt) => {
                            if let Some(f) = formatted_prompt.as_ref()
                                && request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
                            {
                                annotations
                                    .insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
                            }

                            // Completions will use raw_prompt, no template
                            let prompt = formatted_prompt.unwrap_or(raw_prompt);

                            // Check if backend_instance_id is present and token_data is provided
                            let has_backend_instance_id = request
                                .nvext()
                                .and_then(|ext| ext.backend_instance_id)
                                .is_some();

                            let token_data =
                                request.nvext().and_then(|ext| ext.token_data.as_ref());

                            let (tokens_vec, skip_token_annotation) = if has_backend_instance_id {
                                if let Some(tokens) = token_data {
                                    tracing::trace!(
                                        "Using provided tokens from EPP: {} ids",
                                        tokens.len()
                                    );
                                    // need ownership for the builder, so clone.
                                    (tokens.clone(), true)
                                } else {
                                    tracing::warn!(
                                        "backend_instance_id provided but no token_data; tokenizing prompt"
                                    );
                                    let encoding = self.tokenizer.encode(&prompt)?;
                                    (encoding.token_ids().to_vec(), false)
                                }
                            } else {
                                // No backend_instance_id provided, continue the normal flow.
                                let encoding = self.tokenizer.encode(&prompt)?;
                                (encoding.token_ids().to_vec(), false)
                            };

                            if request.has_annotation(ANNOTATION_TOKEN_IDS)
                                && !skip_token_annotation
                            {
                                annotations.insert(
                                    ANNOTATION_TOKEN_IDS.to_string(),
                                    serde_json::to_string(&tokens_vec)?,
                                );
                            }

                            builder.token_ids(tokens_vec);
                        }
                        TextInput::Batch(texts) => {
                            let token_batches: Vec<Vec<u32>> = texts
                                .par_iter()
                                .map(|text| {
                                    self.tokenizer
                                        .encode(text)
                                        .map(|encoded| encoded.token_ids().to_vec())
                                })
                                .collect::<Result<Vec<_>>>()?;
                            builder.batch_token_ids(Some(token_batches));
                            builder.token_ids(vec![]);
                        }
                    }
                }
            }
        }
        Ok(annotations)
    }

    /// Preprocess an embedding request, handling both text and token ID inputs.
    ///
    /// For text inputs, tokenizes the text using the configured tokenizer.
    /// For token ID inputs, uses the provided token IDs directly and skips tokenization.
    ///
    /// Returns both the preprocessed request and a hashmap of annotations.
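    ///
    /// A hedged usage sketch (doctest-ignored; `embedding_request` is a hypothetical
    /// [`NvCreateEmbeddingRequest`]):
    ///
    /// ```ignore
    /// let (preprocessed, annotations) =
    ///     preprocessor.preprocess_embedding_request(&embedding_request).await?;
    /// ```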
    pub async fn preprocess_embedding_request(
        &self,
        request: &NvCreateEmbeddingRequest,
    ) -> Result<(PreprocessedEmbeddingRequest, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = PreprocessedEmbeddingRequest::builder();

        let all_token_ids = match &request.inner.input {
            dynamo_async_openai::types::EmbeddingInput::String(s) => {
                let encoding = self.tokenizer.encode(s)?;
                vec![encoding.token_ids().to_vec()]
            }
            dynamo_async_openai::types::EmbeddingInput::StringArray(arr) => {
                let input_strs: Vec<String> = arr.to_vec();
                let encodings = tokio::task::spawn_blocking({
                    let tokenizer = self.tokenizer.clone();
                    let strs = input_strs.clone();
                    move || {
                        tokenizer.encode_batch(&strs.iter().map(|s| s.as_str()).collect::<Vec<_>>())
                    }
                })
                .await??;
                let token_arrays: Vec<Vec<u32>> = encodings
                    .into_iter()
                    .map(|encoding| encoding.token_ids().to_vec())
                    .collect();
                token_arrays
            }
            dynamo_async_openai::types::EmbeddingInput::IntegerArray(token_ids) => {
                vec![token_ids.clone()]
            }
            dynamo_async_openai::types::EmbeddingInput::ArrayOfIntegerArray(token_arrays) => {
                token_arrays.clone()
            }
        };

        // Handle annotations
        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&all_token_ids)?,
            );
        }

        builder.token_ids(all_token_ids);
        builder.model(request.inner.model.clone());
        builder.encoding_format(request.inner.encoding_format.as_ref().map(|f| match f {
            EncodingFormat::Float => "float".to_string(),
            EncodingFormat::Base64 => "base64".to_string(),
        }));
        builder.dimensions(request.inner.dimensions);

        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

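    /// Transform a stream of [`BackendOutput`] into the OpenAI-facing response type `Resp`,
    /// using the provided delta generator. Each chunk is tagged with an [`LLMMetricAnnotation`]
    /// (unless an event such as an error is already present), and a final usage chunk is
    /// emitted after the backend stream ends when usage reporting is enabled.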
    pub fn transform_postprocessor_stream<S, Resp>(
        stream: S,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
        context: Arc<dyn AsyncEngineContext>,
    ) -> impl Stream<Item = Annotated<Resp>> + Send
    where
        S: Stream<Item = Annotated<BackendOutput>> + Send + 'static,
        Resp: Send + Sync + 'static + std::fmt::Debug,
    {
        struct State<Resp>
        where
            Resp: Send + Sync + 'static + std::fmt::Debug,
        {
            response_stream: Pin<Box<dyn Stream<Item = Annotated<BackendOutput>> + Send>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
            cumulative_output_tokens: usize,
            finish_reason_sent: bool,
            usage_chunk_sent: bool,
            finished: bool,
        }

        let state = State {
            response_stream: Box::pin(stream),
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
            cumulative_output_tokens: 0,
            finish_reason_sent: false,
            usage_chunk_sent: false,
            finished: false,
        };

        // transform the common response stream into the OpenAI response stream

        stream::unfold(state, |mut inner| {
            async move {
                // If already finished, return None immediately
                if inner.finished {
                    return None;
                }

                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        inner.finished = true; // Mark as finished
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    // Check if this response has a finish_reason
                    let has_finish_reason = response
                        .data
                        .as_ref()
                        .map(|d| d.finish_reason.is_some())
                        .unwrap_or(false);

                    let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
                        let chunk_tokens = backend_output.token_ids.len();
                        inner.cumulative_output_tokens += chunk_tokens;

                        let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;

                        (chunk_tokens, isl)
                    } else {
                        (0, 0)
                    };

                    let current_osl = inner.cumulative_output_tokens;

                    let mut response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    // Create LLM metrics annotation
                    let llm_metrics = LLMMetricAnnotation {
                        input_tokens: isl,
                        output_tokens: current_osl,
                        chunk_tokens,
                    };

                    if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
                        // Only set event if not already set to avoid overriding existing events (like errors)
                        if response.event.is_none() {
                            response.event = metrics_annotated.event;
                            response.comment = metrics_annotated.comment;
                        }
                    }

                    // Mark if we've seen a finish_reason
                    if has_finish_reason {
                        inner.finish_reason_sent = true;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    // Stream has ended - must set finished to true to prevent unfold from polling
                    // again. The stream is exhausted and will panic if polled after None.
                    inner.finished = true;

                    // Check if we need to send a usage chunk
                    if inner.response_generator.is_usage_enabled()
                        && inner.finish_reason_sent
                        && !inner.usage_chunk_sent
                    {
                        inner.usage_chunk_sent = true;

                        // Create the final usage chunk
                        let usage_chunk = inner.response_generator.create_usage_chunk();
                        let annotated_usage = Annotated::<Resp> {
                            id: None,
                            data: Some(usage_chunk),
                            event: Some(ANNOTATION_LLM_METRICS.to_string()),
                            comment: None,
                        };

                        tracing::trace!(
                            request_id = inner.context.id(),
                            "Sending final usage chunk for OpenAI compliance"
                        );

                        Some((annotated_usage, inner))
                    } else {
                        // stream closed
                        None
                    }
                }
            }
        })
    }

    /// Transform engine embedding output stream to OpenAI embedding response stream
    pub fn transform_embedding_postprocessor_stream<S>(
        stream: S,
        original_request: NvCreateEmbeddingRequest,
    ) -> impl Stream<Item = Annotated<NvCreateEmbeddingResponse>> + Send
    where
        S: Stream<Item = Annotated<EmbeddingsEngineOutput>> + Send + 'static,
    {
        stream.map(move |output| {
            output.map_data(|engine_output| {
                // Convert engine output to OpenAI response format
                let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output
                    .embeddings
                    .into_iter()
                    .enumerate()
                    .map(|(index, embedding)| dynamo_async_openai::types::Embedding {
                        index: index as u32,
                        object: "embedding".to_string(),
                        embedding: embedding.into_iter().map(|f| f as f32).collect(),
                    })
                    .collect();

                let response = NvCreateEmbeddingResponse {
                    inner: dynamo_async_openai::types::CreateEmbeddingResponse {
                        object: "list".to_string(),
                        model: original_request.inner.model.clone(),
                        data: embeddings,
                        usage: dynamo_async_openai::types::EmbeddingUsage {
                            prompt_tokens: engine_output.prompt_tokens,
                            total_tokens: engine_output.total_tokens,
                        },
                    },
                };

                Ok(response)
            })
        })
    }

    /// Determine whether the tool-calling jail should be applied, based on configuration.
    ///
    /// Returns `Ok(true)` if the jail should be applied and `Ok(false)` otherwise; requests
    /// that ask for tool calls without a configured tool parser currently log a warning and
    /// return `Ok(false)`.
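    ///
    /// A minimal decision sketch (doctest-ignored; the parser name is illustrative):
    ///
    /// ```ignore
    /// // Parser configured, tools present, default tool_choice => jail the stream.
    /// let parser = "hermes".to_string();
    /// assert!(OpenAIPreprocessor::should_apply_tool_jail(Some(&parser), None, true)?);
    /// // No tools in the request => never jail.
    /// assert!(!OpenAIPreprocessor::should_apply_tool_jail(None, None, false)?);
    /// ```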
    pub fn should_apply_tool_jail(
        tool_call_parser: Option<&String>,
        tool_choice: Option<&ChatCompletionToolChoiceOption>,
        has_tools: bool,
    ) -> std::result::Result<bool, Error> {
        match (tool_call_parser, tool_choice, has_tools) {
            // No parser but tools requested - warn and proceed without jailing
            (None, Some(ChatCompletionToolChoiceOption::Required), true) => {
                tracing::warn!(
                    "Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }
            (None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
                tracing::warn!(
                    "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }
            (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
                tracing::warn!(
                    "Named tool choice specified but no tool parser configured; proceeding without jailing"
                );
                Ok(false)
            }

            // Parser exists and tools might be called
            (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
                Ok(false) // Explicitly disabled
            }
            (Some(_), Some(_), true) => Ok(true), // Any other tool_choice with tools
            (Some(_), None, true) => Ok(true),    // Default behavior when tools present

            // No tools or no parser
            _ => Ok(false),
        }
    }

    /// Apply tool calling jail to the stream if needed
    pub fn apply_tool_calling_jail<S>(
        tool_call_parser: String,
        stream: S,
    ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
    where
        S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
    {
        let jail = JailedStream::builder()
            .tool_call_parser(tool_call_parser)
            .build();
        jail.apply(stream)
    }
}

// for pals, we do not want to add the generation prompt to the formatted prompt
// we also need to know if the template supports this add_generation_prompt bool
// any prompt template that does not support this should return an error
// oob - we should update any prompt template that does not support this to support it

#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator(context.id().to_string());

        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        let mut response_generator = Box::new(response_generator);

        // set the runtime configuration
        response_generator.set_reasoning_parser(self.runtime_config.clone());

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as u32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; this will be prepended to the response stream
        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // Extract context once
        let context = response_stream.context();

        // transform the postprocessor stream (no boxing yet)
        let stream = Self::transform_postprocessor_stream(
            response_stream,
            response_generator,
            context.clone(),
        );

        // Check if tools are present and if we should apply jail
        let has_tools =
            request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();

        // Context was already extracted above from response_stream

        // Determine if we should apply jail (do this before moving request)
        let should_jail = Self::should_apply_tool_jail(
            self.tool_call_parser.as_ref(),
            request.inner.tool_choice.as_ref(),
            has_tools,
        )?;

        // Apply jail conditionally
        let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
            if let Some(parser) = self.tool_call_parser.clone() {
                Box::pin(Self::apply_tool_calling_jail(parser, stream))
            } else {
                Box::pin(stream) // Should not happen due to should_jail check
            }
        } else {
            Box::pin(stream)
        };

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream - single boxing at the end
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator(context.id().to_string());
        let mut response_generator = Box::new(response_generator);

        // convert the completion request to a common completion request
        let mut builder = self.builder(&request)?;
        let annotations = self.gather_tokens(&request, &mut builder, None)?;
        let common_request = builder.build()?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as u32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; this will be prepended to the response stream
        let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // Extract context once
        let context = response_stream.context();

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(
            response_stream,
            response_generator,
            context.clone(),
        );

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        SingleIn<PreprocessedEmbeddingRequest>,
        ManyOut<Annotated<EmbeddingsEngineOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateEmbeddingRequest>,
        next: Arc<
            dyn AsyncEngine<
                    SingleIn<PreprocessedEmbeddingRequest>,
                    ManyOut<Annotated<EmbeddingsEngineOutput>>,
                    Error,
                >,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        // Unpack request
        let (request, context) = request.into_parts();

        // Preprocess the embedding request
        let (preprocessed_request, annotations) =
            self.preprocess_embedding_request(&request).await?;

        // Forward to next stage
        let preprocessed_request = context.map(|_| preprocessed_request);
        let response_stream = next.generate(preprocessed_request).await?;

        // Extract context once
        let context = response_stream.context();

        // Transform response stream back to OpenAI format
        let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);

        // Prepend annotations
        let annotations_stream = stream::iter(
            annotations
                .into_iter()
                .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
                .collect::<Vec<_>>(),
        );

        let combined_stream = annotations_stream.chain(stream);
        Ok(ResponseStream::new(Box::pin(combined_stream), context))
    }
}

// Note: tests for jailing and parser detection live in `lib/llm/tests/test_jail.rs`