dynamo_llm/preprocessor.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! The Preprocessor consists of the following modules:
//!
//! - `translation`: converts the allowed Ingress message types to the corresponding
//!   internal representation.
//! - `apply`: applies ModelConfig defaults to any optional fields left unspecified.
//! - `prompt`: applies any prompt template logic to the internal Request object.
//! - `tokenize`: tokenizes the formatted prompt string and returns the token ids.
//!
//! The Preprocessor accepts any IngressRequest and transforms it into a BackendRequest.
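//!
//! A minimal usage sketch (illustrative only; assumes a populated `ModelDeploymentCard`
//! named `mdc` and an OpenAI chat request named `chat_request` are already available,
//! error handling elided):
//!
//! ```ignore
//! let preprocessor = OpenAIPreprocessor::new(mdc).await?;
//! let (backend_request, annotations) = preprocessor.preprocess_request(&chat_request)?;
//! tracing::debug!("prompt tokens: {}", backend_request.token_ids.len());
//! ```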

pub mod prompt;
pub mod tools;

use anyhow::Result;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::{collections::HashMap, sync::Arc};
use tracing;

use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{CompletionResponse, NvCreateCompletionRequest},
        nvext::NvExtProvider,
        DeltaGeneratorExt,
    },
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};

use crate::preprocessor::prompt::PromptFormatter;

pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub chunk_tokens: usize,
}

impl LLMMetricAnnotation {
    /// Convert this metrics struct to an Annotated event
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
    }

    /// Extract LLM metrics from an Annotated event, if present
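    ///
    /// A small round-trip sketch (illustrative only; the `()` payload type and the
    /// token counts are arbitrary):
    ///
    /// ```ignore
    /// let metrics = LLMMetricAnnotation { input_tokens: 8, output_tokens: 4, chunk_tokens: 1 };
    /// let event: Annotated<()> = metrics.to_annotation()?;
    /// let parsed = LLMMetricAnnotation::from_annotation(&event)?;
    /// assert_eq!(parsed.map(|m| m.output_tokens), Some(4));
    /// ```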
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
        if annotation.event.is_none() {
            return Ok(None);
        }
        if annotation.event.as_ref().unwrap() != ANNOTATION_LLM_METRICS {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        if comments.len() != 1 {
            return Err("malformed comments block - expected exactly 1 comment".into());
        }
        let metrics: LLMMetricAnnotation = serde_json::from_str(&comments[0])?;
        Ok(Some(metrics))
    }
}

pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
}

impl OpenAIPreprocessor {
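    /// Build a preprocessor from a [`ModelDeploymentCard`]. The card must provide both a
    /// tokenizer and `model_info`; otherwise construction fails.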
    pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
        let PromptFormatter::OAI(formatter) = formatter;

        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
            Some(TokenizerKind::GGUF(tokenizer)) => {
                HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
            }
            None => {
                anyhow::bail!(
                    "Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
                );
            }
        };
        let tokenizer = Arc::new(tokenizer);

        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info().await?;

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
        }))
    }

    /// Encode a string to its tokens
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

    /// Translate an OpenAI-style request (e.g. [`NvCreateChatCompletionRequest`]) to a
    /// common completion request. Returns both the common completion request and a
    /// hashmap of annotations.
    ///
    /// Annotations evaluated by this method include:
    /// - `formatted_prompt`
    /// - `token_ids`
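    ///
    /// A usage sketch (illustrative only; assumes `preprocessor` and `chat_request`
    /// exist and that the request opted in to the `token_ids` annotation):
    ///
    /// ```ignore
    /// let (backend_request, annotations) = preprocessor.preprocess_request(&chat_request)?;
    /// if let Some(ids_json) = annotations.get(ANNOTATION_TOKEN_IDS) {
    ///     tracing::debug!("token ids: {ids_json}");
    /// }
    /// ```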
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = PreprocessedRequest::builder();

        let use_raw_prompt = request
            .nvext()
            .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

        let formatted_prompt = if use_raw_prompt {
            match request.raw_prompt() {
                Some(prompt) => prompt,
                None => {
                    tracing::warn!("Raw prompt requested but not available");
                    self.formatter.render(request)?
                }
            }
        } else {
            self.formatter.render(request)?
        };

        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;

        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
        }

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&encoding.token_ids)?,
            );
        }

        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        // apply ignore eos if not already set
        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.token_ids(encoding.token_ids);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.stop_conditions(stop_conditions);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));
        builder.estimated_prefix_hit_num_blocks(None);

        Ok((builder.build()?, annotations))
    }

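    /// Convert a stream of engine [`BackendOutput`]s into a stream of typed response
    /// deltas via the supplied [`DeltaGeneratorExt`]. Each emitted item carries an
    /// [`LLMMetricAnnotation`] (input, cumulative output, and per-chunk token counts);
    /// if delta construction fails, the request context is told to stop generating and
    /// the stream is closed.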
    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
            cumulative_output_tokens: usize,
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
            cumulative_output_tokens: 0,
        };

        // transform the common response stream into a chat response stream
        let stream = stream::unfold(state, |mut inner| {
            async move {
                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
                        let chunk_tokens = backend_output.token_ids.len();
                        inner.cumulative_output_tokens += chunk_tokens;

                        let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;

                        (chunk_tokens, isl)
                    } else {
                        (0, 0)
                    };

                    let current_osl = inner.cumulative_output_tokens;

                    let mut response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    // Create LLM metrics annotation
                    let llm_metrics = LLMMetricAnnotation {
                        input_tokens: isl,
                        output_tokens: current_osl,
                        chunk_tokens,
                    };

                    if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
                        // Only set event if not already set to avoid overriding existing events (like errors)
                        if response.event.is_none() {
                            response.event = metrics_annotated.event;
                        }
                        response.comment = metrics_annotated.comment;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    // The stream closed without a graceful closure; we did not detect an
                    // is_finished/completed message.
                    None
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
}

// For pals, we do not want to add the generation prompt to the formatted prompt.
// We also need to know whether the template supports the add_generation_prompt bool;
// any prompt template that does not support this should return an error.
// oob - we should update any prompt template that does not support this to support it.

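// Chat Completions path: preprocess the OpenAI chat request into a PreprocessedRequest,
// forward it to the next engine, then post-process the streamed BackendOutputs into
// chat completion stream responses, with any requested annotations prepended.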
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<
                SingleIn<PreprocessedRequest>,
                ManyOut<Annotated<BackendOutput>>,
                Error,
            >,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as u32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; these will be prepended to the response stream
        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

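// Completions path: same flow as the chat completions operator above, but for the
// OpenAI Completions API (CompletionResponse deltas instead of chat stream responses).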
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<CompletionResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<
                SingleIn<PreprocessedRequest>,
                ManyOut<Annotated<BackendOutput>>,
                Error,
            >,
        >,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        // convert the completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as u32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; these will be prepended to the response stream
        let annotations: Vec<Annotated<CompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}