pub mod prompt;
pub mod tools;

use anyhow::Result;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::{collections::HashMap, sync::Arc};
use tracing;

use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{CompletionRequest, CompletionResponse},
        nvext::NvExtProvider,
        DeltaGeneratorExt,
    },
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};

use crate::preprocessor::prompt::PromptFormatter;

pub use crate::protocols::common::llm_backend::{BackendInput, BackendOutput};

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";

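/// Adapts OpenAI-style requests into the common backend representation: it
/// renders the prompt with the model's prompt formatter, tokenizes it, and
/// carries annotations and stop conditions alongside the token ids.
///
/// A minimal usage sketch (hypothetical setup; assumes a populated
/// `ModelDeploymentCard` and a request value are already at hand):
///
/// ```ignore
/// let preprocessor = OpenAIPreprocessor::new(mdc).await?;
/// let (backend_input, annotations) = preprocessor.preprocess_request(&request)?;
/// ```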
pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
}

impl OpenAIPreprocessor {
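    /// Builds a preprocessor from a `ModelDeploymentCard`, loading the prompt
    /// formatter, tokenizer, and model info it describes. Fails if the card
    /// carries no tokenizer or no model info.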
    pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
        let PromptFormatter::OAI(formatter) = formatter;

        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
            Some(TokenizerKind::GGUF(tokenizer)) => {
                HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
            }
            None => {
                anyhow::bail!(
                    "Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
                );
            }
        };
        let tokenizer = Arc::new(tokenizer);

        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info().await?;

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
        }))
    }

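    /// Encodes a string into token ids using the model's tokenizer.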
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

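    /// Converts an OpenAI-style request into a `BackendInput` plus any
    /// requested annotations: renders the prompt (or uses the raw prompt when
    /// `nvext.use_raw_prompt` is set), tokenizes it, merges the model's EOS
    /// tokens into the hidden stop tokens, and copies sampling options and
    /// stop conditions from the request.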
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(BackendInput, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = BackendInput::builder();

        // Honor nvext.use_raw_prompt: pass the caller's prompt through unrendered when requested.
        let use_raw_prompt = request
            .nvext()
            .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

        let formatted_prompt = if use_raw_prompt {
            match request.raw_prompt() {
                Some(prompt) => prompt,
                None => {
                    tracing::warn!("Raw prompt requested but not available");
                    self.formatter.render(request)?
                }
            }
        } else {
            self.formatter.render(request)?
        };

        // Tokenization is CPU-bound, so run it via block_in_place rather than blocking the async executor.
        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;

        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
        }

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&encoding.token_ids)?,
            );
        }

        // Merge the model's EOS token ids into the hidden stop tokens, avoiding duplicates.
        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.token_ids(encoding.token_ids);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.stop_conditions(stop_conditions);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

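    /// Wraps a stream of `BackendOutput` items in a post-processor that maps
    /// each item into the target response type via the provided delta
    /// generator. If the generator returns an error, generation is cancelled
    /// and the stream is closed after the current message.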
    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
        };

        let stream = stream::unfold(state, |mut inner| {
            async move {
                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
211 "Cancellation issued last message; closing stream"
                        );
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    let response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    None
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
}

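// Operator adapter for chat completions: preprocesses the request, forwards the
// backend input to `next`, and post-processes the output stream into chat
// completion stream responses, prepending any annotation messages.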
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        let (common_request, annotations) = self.preprocess_request(&request)?;

        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

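// Operator adapter for legacy completions: same flow as the chat adapter, but
// producing `CompletionResponse` deltas.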
#[async_trait]
impl
    Operator<
        SingleIn<CompletionRequest>,
        ManyOut<Annotated<CompletionResponse>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<CompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);
        let (common_request, annotations) = self.preprocess_request(&request)?;

        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        let annotations: Vec<Annotated<CompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}