pub mod prompt;
pub mod tools;

use anyhow::Result;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::{collections::HashMap, sync::Arc};
use tracing;

use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::tokenizers::Encoding;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
    async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{CompletionResponse, NvCreateCompletionRequest},
        nvext::NvExtProvider,
        DeltaGeneratorExt,
    },
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};

use crate::preprocessor::prompt::PromptFormatter;

pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";
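
/// Per-chunk token accounting emitted alongside responses as an
/// `llm_metrics` annotation. The struct is serialized to JSON and carried in
/// the annotation's single `comment` entry.
///
/// A sketch of the intended round-trip (field values are illustrative):
///
/// ```ignore
/// let metrics = LLMMetricAnnotation {
///     input_tokens: 16,
///     output_tokens: 4,
///     chunk_tokens: 1,
/// };
/// let annotated: Annotated<()> = metrics.to_annotation()?;
/// let recovered = LLMMetricAnnotation::from_annotation(&annotated)?;
/// assert_eq!(recovered.unwrap().output_tokens, 4);
/// ```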
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub chunk_tokens: usize,
}

impl LLMMetricAnnotation {
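    /// Convert these metrics into an [`Annotated`] event tagged
    /// `llm_metrics`, with the JSON-serialized struct as its comment.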
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
    }

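    /// Extract metrics from an annotated event, if it is an `llm_metrics`
    /// annotation. Returns `Ok(None)` for other events, and an error if the
    /// comments block is missing or malformed.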
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
        let Some(event) = annotation.event.as_ref() else {
            return Ok(None);
        };
        if event != ANNOTATION_LLM_METRICS {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        if comments.len() != 1 {
            return Err("malformed comments block - expected exactly 1 comment".into());
        }
        let metrics: LLMMetricAnnotation = serde_json::from_str(&comments[0])?;
        Ok(Some(metrics))
    }
}

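/// Pre- and post-processor sitting between OpenAI-style requests and the
/// token-level backend: it renders the chat template, tokenizes the prompt,
/// and turns backend outputs back into OpenAI response deltas.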
pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
}

impl OpenAIPreprocessor {
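    /// Build a preprocessor from a [`ModelDeploymentCard`]. Fails if the
    /// card carries no tokenizer or no `model_info`, since both are required
    /// for pre-processing.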
    pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
        let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
        let PromptFormatter::OAI(formatter) = formatter;

        let tokenizer = match &mdc.tokenizer {
            Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
            Some(TokenizerKind::GGUF(tokenizer)) => {
                HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
            }
            None => {
                anyhow::bail!(
                    "Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
                );
            }
        };
        let tokenizer = Arc::new(tokenizer);

        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
        let model_info = model_info.get_model_info().await?;

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
        }))
    }

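    /// Encode a string into tokens with the model's tokenizer.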
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

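    /// Translate an OpenAI-style request into a token-level
    /// [`PreprocessedRequest`], returning it together with any annotations
    /// the caller asked for (`formatted_prompt`, `token_ids`).
    ///
    /// A usage sketch (the `preprocessor` binding is illustrative):
    ///
    /// ```ignore
    /// let (preprocessed, annotations) = preprocessor.preprocess_request(&request)?;
    /// // preprocessed.token_ids holds the encoded prompt; annotations is a
    /// // map keyed by annotation name, e.g. "formatted_prompt".
    /// ```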
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = PreprocessedRequest::builder();

        // Honor `nvext.use_raw_prompt`, falling back to the chat template
        // when no raw prompt was supplied.
        let use_raw_prompt = request
            .nvext()
            .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

        let formatted_prompt = if use_raw_prompt {
            match request.raw_prompt() {
                Some(prompt) => prompt,
                None => {
                    tracing::warn!("Raw prompt requested but not available");
                    self.formatter.render(request)?
                }
            }
        } else {
            self.formatter.render(request)?
        };

        // Tokenization is CPU-bound, so keep it off the async executor.
        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;

        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
        }

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&encoding.token_ids)?,
            );
        }

        // Merge the model's EOS tokens into the hidden stop-token list,
        // avoiding duplicates.
        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.token_ids(encoding.token_ids);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.stop_conditions(stop_conditions);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));
        builder.estimated_prefix_hit_num_blocks(None);

        Ok((builder.build()?, annotations))
    }

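    /// Map a stream of [`BackendOutput`]s into OpenAI-style response deltas,
    /// attaching an `llm_metrics` annotation (ISL, cumulative OSL, and
    /// per-chunk token count) to every item. On a post-processing error the
    /// stream is cancelled via the engine context and then closed.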
    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
            cumulative_output_tokens: usize,
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
            cumulative_output_tokens: 0,
        };

        let stream = stream::unfold(state, |mut inner| {
            async move {
                if let Some(response) = inner.response_stream.next().await {
                    // A cancellation may have been issued while the previous
                    // message was in flight; if so, close the stream now.
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    // Track token counts for the llm_metrics annotation.
                    let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
                        let chunk_tokens = backend_output.token_ids.len();
                        inner.cumulative_output_tokens += chunk_tokens;

                        let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;

                        (chunk_tokens, isl)
                    } else {
                        (0, 0)
                    };

                    let current_osl = inner.cumulative_output_tokens;

                    // Convert the backend output into an OpenAI-style delta;
                    // on error, stop generation and mark the stream cancelled.
                    let mut response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    let llm_metrics = LLMMetricAnnotation {
                        input_tokens: isl,
                        output_tokens: current_osl,
                        chunk_tokens,
                    };

                    // Attach the metrics without clobbering an existing event.
                    if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
                        if response.event.is_none() {
                            response.event = metrics_annotated.event;
                        }
                        response.comment = metrics_annotated.comment;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    None
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
}

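/// Pipeline operator: preprocess an OpenAI chat completion request, forward
/// the token-level request to the next engine, and post-process the backend
/// stream into chat completion deltas.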
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<
                SingleIn<PreprocessedRequest>,
                ManyOut<Annotated<BackendOutput>>,
                Error,
            >,
        >,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        // Unpack the request, keeping the context so the preprocessed
        // request inherits the same request id and cancellation token.
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        let (common_request, annotations) = self.preprocess_request(&request)?;

        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        // Emit any requested annotations ahead of the response stream.
        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

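/// Pipeline operator: the same flow as the chat completion operator above,
/// but for legacy OpenAI completion requests.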
#[async_trait]
impl
    Operator<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<CompletionResponse>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<
                SingleIn<PreprocessedRequest>,
                ManyOut<Annotated<BackendOutput>>,
                Error,
            >,
        >,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        let (request, context) = request.into_parts();

        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // ISL is a token count; cast to u32 to match the chat completion path.
        response_generator.update_isl(common_request.token_ids.len() as u32);

        let common_request = context.map(|_| common_request);

        let annotations: Vec<Annotated<CompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        let response_stream = next.generate(common_request).await?;

        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        let stream = annotations_stream.chain(stream);

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}