dynamo_llm/preprocessor.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! The Preprocessor consists of the following modules:
//!
//! - `translation`: This module converts the allowed Ingress message types to the corresponding
//!   internal representation.
//! - `apply`: This module applies ModelConfig defaults to any optional fields left unset.
//! - `prompt`: This module applies any prompt template logic to the internal Request object.
//! - `tokenize`: This module tokenizes the formatted prompt string and returns the token ids.
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.
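//!
//! For example, building a preprocessor and preprocessing a chat completion request might look
//! like the following sketch (`mdc` and `chat_request` are placeholders for a loaded
//! [`ModelDeploymentCard`] and an incoming request):
//!
//! ```ignore
//! let preprocessor = OpenAIPreprocessor::new(mdc).await?;
//! let (backend_input, annotations) = preprocessor.preprocess_request(&chat_request)?;
//! println!("prompt token count: {}", backend_input.token_ids.len());
//! ```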

pub mod prompt;
pub mod tools;

use anyhow::Result;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::{collections::HashMap, sync::Arc};
use tracing;

use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{CompletionRequest, CompletionResponse},
        nvext::NvExtProvider,
        DeltaGeneratorExt,
    },
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};

use crate::preprocessor::prompt::PromptFormatter;

pub use crate::protocols::common::llm_backend::{BackendInput, BackendOutput};

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";

pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
}

impl OpenAIPreprocessor {
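    /// Build an [`OpenAIPreprocessor`] from a [`ModelDeploymentCard`], resolving its prompt
    /// formatter, tokenizer, and model info. Fails if the card carries no tokenizer or no
    /// model info.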
    pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
        let PromptFormatter::OAI(formatter) = formatter;

        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
            Some(TokenizerKind::GGUF(tokenizer)) => {
                HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
            }
            None => {
                anyhow::bail!(
                    "Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
                );
            }
        };
        let tokenizer = Arc::new(tokenizer);

        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info().await?;

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
        }))
    }

    /// Encode a string to its tokens.
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

    /// Translate an OpenAI-style request (e.g. [`NvCreateChatCompletionRequest`]) into a
    /// common [`BackendInput`]. Returns both the backend input and a hashmap of annotations.
    ///
    /// Annotations evaluated by this method include:
    /// - `formatted_prompt`
    /// - `token_ids`
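    ///
    /// Each annotation is only included in the returned map when the incoming request asks
    /// for it via `has_annotation`.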
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(BackendInput, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = BackendInput::builder();

        let use_raw_prompt = request
            .nvext()
            .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

        let formatted_prompt = if use_raw_prompt {
            match request.raw_prompt() {
                Some(prompt) => prompt,
                None => {
                    tracing::warn!("Raw prompt requested but not available");
                    self.formatter.render(request)?
                }
            }
        } else {
            self.formatter.render(request)?
        };

        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;

        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
        }

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&encoding.token_ids)?,
            );
        }

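        // Merge the model's EOS token ids into the request's hidden stop tokens so the
        // backend also treats them as stop conditions.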
        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        // apply ignore eos if not already set
        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.token_ids(encoding.token_ids);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.stop_conditions(stop_conditions);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

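    /// Transform a stream of common [`BackendOutput`] responses into a stream of typed
    /// responses (chat or completion deltas) by driving the supplied [`DeltaGeneratorExt`].
    /// If the delta generator returns an error, the error is forwarded, generation is
    /// stopped via the stream's [`AsyncEngineContext`], and the stream is closed.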
    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
        };

        // transform the common response stream into the typed response stream
        let stream = stream::unfold(state, |mut inner| {
            async move {
                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    let response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    // stream closed without a graceful closure;
                    // we did not detect an is_finished/completed message
                    None
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
}

// for pals, we do not want to add the generation prompt to the formatted prompt
// we also need to know if the template supports the add_generation_prompt bool
// any prompt template that does not support this should return an error
// oob - we should update any prompt template that does not support this to support it

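// Chat completions path: preprocess an OpenAI chat completion request into a `BackendInput`,
// forward it to the next engine, then post-process the backend stream back into chat
// completion stream responses, with any requested annotations prepended.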
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as u32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; this will be prepended to the response stream
        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

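// Completions path: the same preprocess, forward, and post-process flow as the chat
// completions operator above, but for OpenAI completion requests and responses.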
#[async_trait]
impl
    Operator<
        SingleIn<CompletionRequest>,
        ManyOut<Annotated<CompletionResponse>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<CompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);
        // convert the completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as i32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations; this will be prepended to the response stream
        let annotations: Vec<Annotated<CompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}