pub mod media;
pub mod prompt;
pub mod speculative_prefill;
pub mod tools;
use anyhow::Context;
use anyhow::{Result, bail};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat,
};
use dynamo_runtime::error::{DynamoError, ErrorType};
use futures::Stream;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::time::{Duration, Instant};
use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::media::MediaLoader;
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder, RoutingHints,
};
use crate::protocols::common::timing::RequestTracker;
use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
use crate::protocols::{
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{
DeltaGeneratorExt,
chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, jail::JailedStream,
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
nvext::NvExtProvider,
},
};
use crate::tokenizers::traits::Tokenizer;
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
pub use crate::protocols::common::preprocessor::PreprocessedEmbeddingRequest;
use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";
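/// Per-request token counts, worker identifiers, and timing metrics carried on the response
/// stream as an `llm_metrics` annotation. The struct is serialized to JSON into the
/// annotation's comment field; `None` fields are omitted.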
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
pub input_tokens: usize,
pub output_tokens: usize,
pub chunk_tokens: usize,
pub cached_tokens: Option<usize>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_type: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_dp_rank: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_type: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tokenize_latency: Option<Duration>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub detokenize_total_latency: Option<Duration>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub detokenize_count: Option<u64>,
}
impl LLMMetricAnnotation {
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
}
pub fn from_annotation<T>(
annotation: &Annotated<T>,
) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
if annotation.event.is_none() {
return Ok(None);
}
if annotation.event.as_ref().unwrap() != ANNOTATION_LLM_METRICS {
return Ok(None);
}
let comments = annotation
.comment
.as_ref()
.ok_or("missing comments block")?;
if comments.len() != 1 {
return Err("malformed comments block - expected exactly 1 comment".into());
}
let metrics: LLMMetricAnnotation = serde_json::from_str(&comments[0])?;
Ok(Some(metrics))
}
}
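/// State threaded through the `stream::unfold` loop in `parse_reasoning_content_from_stream`.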
struct ReasoningState {
stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
reasoning_parser: Option<Box<dyn ReasoningParser>>,
}
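/// Translates OpenAI-style chat, completion, and embedding requests into backend-ready
/// preprocessed requests (prompt templating, tokenization, multimodal loading) and turns
/// backend outputs back into OpenAI response streams with metrics annotations.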
pub struct OpenAIPreprocessor {
mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>,
model_info: Arc<dyn ModelInfo>,
lora_name: Option<String>,
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>,
media_loader: Option<MediaLoader>,
context_length: u32,
}
impl OpenAIPreprocessor {
pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let formatter = PromptFormatter::from_mdc(&mdc)?;
let tokenizer = mdc.tokenizer()?;
match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
}
}
pub fn new_with_parts(
mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: crate::tokenizers::Tokenizer,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum().to_string();
let tokenizer: Arc<dyn Tokenizer> = (*tokenizer).clone();
let lora_name = mdc.lora.as_ref().map(|l| l.name.clone());
let Some(ref model_info) = mdc.model_info else {
anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
);
};
let model_info = model_info.get_model_info()?;
let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();
if let Some(ref lora_name) = lora_name {
tracing::info!(model = %mdc.display_name, lora_name, "LoRA adapter detected in MDC");
}
let runtime_config = mdc.runtime_config.clone();
let media_loader = match mdc.media_decoder {
Some(media_decoder) => Some(MediaLoader::new(media_decoder, mdc.media_fetcher)?),
None => None,
};
let context_length = mdc.context_length;
Ok(Arc::new(Self {
formatter,
tokenizer,
model_info,
mdcsum,
lora_name,
runtime_config,
tool_call_parser,
media_loader,
context_length,
}))
}
pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
self.tokenizer.encode(s)
}
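    /// Runs the full preprocessing pipeline for a chat-like request: builds the common request,
    /// applies the prompt template, gathers token ids, and loads any multimodal data.
    /// Returns the `PreprocessedRequest` together with any requested annotations
    /// (e.g. `formatted_prompt`, `token_ids`).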
pub async fn preprocess_request<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
tracker: Option<&RequestTracker>,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut builder = self.builder(request)?;
let formatted_prompt = self
.apply_template(request)
.with_context(|| "Failed to apply prompt template")?;
let annotations = self
.gather_tokens(request, &mut builder, formatted_prompt.clone(), tracker)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder, formatted_prompt)
.await
.with_context(|| "Failed to gather multimodal data")?;
Ok((builder.build()?, annotations))
}
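    /// Creates a `PreprocessedRequestBuilder` populated with everything except tokens and
    /// multimodal data: stop conditions (with the model's EOS token ids appended), sampling
    /// and output options, annotations, the MDC checksum, and routing hints derived from
    /// `nvext` and/or the LoRA adapter name.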
pub fn builder<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<PreprocessedRequestBuilder> {
let mut builder = PreprocessedRequest::builder();
builder.model(request.model());
let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}
stop_conditions.apply_ignore_eos();
if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
let lora_name = self.lora_name.clone();
if let Some(nvext) = request.nvext() {
let hints = nvext.agent_hints.as_ref();
let routing = RoutingHints {
backend_instance_id: nvext.backend_instance_id,
prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id,
                dp_rank: None,
                expected_output_tokens: hints.and_then(|h| h.osl),
priority_jump: hints.and_then(|h| h.latency_sensitivity),
priority: hints.and_then(|h| h.priority),
lora_name,
allowed_worker_ids: None,
};
builder.routing(Some(routing));
} else if lora_name.is_some() {
builder.routing(Some(RoutingHints {
lora_name,
..Default::default()
}));
}
Ok(builder)
}
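    /// Renders the prompt template for single-text requests, honoring `nvext.use_raw_prompt`
    /// when a raw prompt is available. Returns `None` for token or batch inputs.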
pub fn apply_template<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<Option<String>> {
if let PromptInput::Text(_) = request.prompt_input_type()
&& let Some(TextInput::Single(_)) = request.extract_text()
{
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
Ok(Some(formatted_prompt))
} else {
Ok(None)
}
}
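    /// Collects image, video, and audio parts from user messages. When a `MediaLoader` is
    /// configured the parts are fetched and decoded concurrently; otherwise their URLs are
    /// passed through unchanged. If any media is present, the original messages (and the
    /// formatted prompt, if available) are also forwarded to the backend via `extra_args`.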
pub async fn gather_multi_modal_data<R: OAIChatLikeRequest>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<()> {
let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks: Vec<(String, ChatCompletionRequestUserMessageContentPart)> =
Vec::new();
let Some(messages) = request.typed_messages() else {
return Ok(());
};
for message in messages.iter() {
let content_parts = match message {
ChatCompletionRequestMessage::User(u) => match &u.content {
ChatCompletionRequestUserMessageContent::Array(parts) => parts,
_ => continue,
},
_ => continue,
};
for content_part in content_parts.iter() {
let (type_str, url) = match content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
("image_url".to_string(), image_part.image_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
("video_url".to_string(), video_part.video_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(audio_part) => {
("audio_url".to_string(), audio_part.audio_url.url.clone())
}
_ => continue,
};
if self.media_loader.is_some() {
fetch_tasks.push((type_str, content_part.clone()));
continue;
}
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Url(url));
}
}
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let media_io_kwargs = request.media_io_kwargs();
let results = futures::future::join_all(fetch_tasks.iter().map(|(_, content_part)| {
loader.fetch_and_decode_media_part(content_part, media_io_kwargs)
}))
.await;
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
let rdma_descriptor = result?;
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Decoded(rdma_descriptor));
}
}
if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
let messages_json = serde_json::to_value(request.messages())?;
let mut extra_args = serde_json::json!({
"messages": messages_json
});
if let Some(ref prompt) = formatted_prompt {
extra_args["formatted_prompt"] = serde_json::Value::String(prompt.clone());
}
builder.extra_args(Some(extra_args));
}
Ok(())
}
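    /// Resolves the request input into token ids on the builder: pre-tokenized inputs are used
    /// directly, text inputs are tokenized (or taken from `nvext.token_data` when a
    /// `backend_instance_id` is supplied), and the resulting count is validated against the
    /// model's context length. Returns any requested annotations.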
pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
tracker: Option<&RequestTracker>,
) -> Result<HashMap<String, String>> {
let mut annotations = HashMap::new();
let mut token_count: Option<usize> = None;
match request.prompt_input_type() {
PromptInput::Tokens(_) => {
if let Some(token_input) = request.extract_tokens() {
match token_input {
TokenInput::Single(tokens) => {
token_count = Some(tokens.len());
builder.token_ids(tokens);
}
TokenInput::Batch(token_batches) => {
if token_batches.len() == 1 {
token_count = Some(token_batches[0].len());
builder.token_ids(token_batches[0].clone());
} else {
bail!(
"Batch token input not supported for more than one token in requests (got {})",
token_batches.len()
);
}
}
}
}
}
PromptInput::Text(_) => {
if let Some(text_input) = request.extract_text() {
match text_input {
TextInput::Single(raw_prompt) => {
if let Some(f) = formatted_prompt.as_ref()
&& request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
{
annotations
.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
}
let prompt = formatted_prompt.unwrap_or(raw_prompt);
let has_backend_instance_id = request
.nvext()
.and_then(|ext| ext.backend_instance_id)
.is_some();
let token_data =
request.nvext().and_then(|ext| ext.token_data.as_ref());
let (tokens_vec, skip_token_annotation) = if has_backend_instance_id {
if let Some(tokens) = token_data {
tracing::trace!(
"Using provided tokens from EPP: {} ids",
tokens.len()
);
(tokens.clone(), true)
} else {
tracing::warn!(
"backend_instance_id provided but no token_data; tokenizing prompt"
);
let encoding = self.encode_with_timing(&prompt, tracker)?;
(encoding.token_ids().to_vec(), false)
}
} else {
let encoding = self.encode_with_timing(&prompt, tracker)?;
(encoding.token_ids().to_vec(), false)
};
if request.has_annotation(ANNOTATION_TOKEN_IDS)
&& !skip_token_annotation
{
annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&tokens_vec)?,
);
}
token_count = Some(tokens_vec.len());
builder.token_ids(tokens_vec);
}
TextInput::Batch(texts) => {
if texts.len() == 1 {
let encoding = self.encode_with_timing(&texts[0], tracker)?;
let tokens = encoding.token_ids().to_vec();
token_count = Some(tokens.len());
builder.token_ids(tokens);
} else {
bail!(
"Batch text input not supported for more than one text in requests (got {})",
texts.len()
);
}
}
}
}
}
}
if let Some(count) = token_count {
Self::validate_token_count(count, self.context_length)?;
}
Ok(annotations)
}
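    /// Rejects prompts whose token count meets or exceeds the model's context length.
    /// A context length of 0 disables the check.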
fn validate_token_count(token_count: usize, context_length: u32) -> Result<()> {
let max_len = context_length as usize;
if max_len > 0 && token_count >= max_len {
return Err(DynamoError::builder()
.error_type(ErrorType::InvalidArgument)
.message(format!(
"This model's maximum context length is {} tokens. \
However, your messages resulted in {} tokens. \
Please reduce the length of the messages.",
max_len, token_count,
))
.build()
.into());
}
Ok(())
}
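    /// Tokenizes `prompt`, recording tokenization latency on the tracker when one is provided.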
fn encode_with_timing(
&self,
prompt: &str,
tracker: Option<&RequestTracker>,
) -> anyhow::Result<Encoding> {
let encode_start = Instant::now();
let encoding = self.tokenizer.encode(prompt)?;
if let Some(t) = tracker {
t.record_tokenize_latency(encode_start.elapsed());
}
Ok(encoding)
}
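    /// Tokenizes embedding inputs (single string, string batch, or pre-tokenized ids) into a
    /// `PreprocessedEmbeddingRequest`; string batches are encoded on a blocking task.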
pub async fn preprocess_embedding_request(
&self,
request: &NvCreateEmbeddingRequest,
) -> Result<(PreprocessedEmbeddingRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new();
let mut builder = PreprocessedEmbeddingRequest::builder();
let all_token_ids = match &request.inner.input {
dynamo_async_openai::types::EmbeddingInput::String(s) => {
let encoding = self.tokenizer.encode(s)?;
vec![encoding.token_ids().to_vec()]
}
dynamo_async_openai::types::EmbeddingInput::StringArray(arr) => {
let input_strs: Vec<String> = arr.to_vec();
let encodings = tokio::task::spawn_blocking({
let tokenizer = self.tokenizer.clone();
let strs = input_strs.clone();
move || {
tokenizer.encode_batch(&strs.iter().map(|s| s.as_str()).collect::<Vec<_>>())
}
})
.await??;
let token_arrays: Vec<Vec<u32>> = encodings
.into_iter()
.map(|encoding| encoding.token_ids().to_vec())
.collect();
token_arrays
}
dynamo_async_openai::types::EmbeddingInput::IntegerArray(token_ids) => {
vec![token_ids.clone()]
}
dynamo_async_openai::types::EmbeddingInput::ArrayOfIntegerArray(token_arrays) => {
token_arrays.clone()
}
};
if request.has_annotation(ANNOTATION_TOKEN_IDS) {
annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&all_token_ids)?,
);
}
builder.token_ids(all_token_ids);
builder.model(request.inner.model.clone());
builder.encoding_format(request.inner.encoding_format.as_ref().map(|f| match f {
EncodingFormat::Float => "float".to_string(),
EncodingFormat::Base64 => "base64".to_string(),
}));
builder.dimensions(request.inner.dimensions);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
Ok((builder.build()?, annotations))
}
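    /// Layers optional reasoning-content parsing and tool-call jailing on top of a chat
    /// completion stream, based on the runtime config and the request's tools and tool_choice.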
pub fn postprocessor_parsing_stream<S>(
&self,
stream: S,
request: &NvCreateChatCompletionRequest,
) -> anyhow::Result<
impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
>
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some()
&& !Self::is_reasoning_disabled_by_request(
self.runtime_config.reasoning_parser.as_deref(),
request.chat_template_args.as_ref(),
);
let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_parse_reasoning {
Box::pin(Self::parse_reasoning_content_from_stream(
stream,
                self.runtime_config.reasoning_parser.clone().unwrap(),
            ))
} else {
Box::pin(stream)
};
let has_tools = request
.inner
.tools
.as_ref()
.is_some_and(|tools| !tools.is_empty());
let should_jail = Self::should_apply_tool_jail(
self.tool_call_parser.as_ref(),
request.inner.tool_choice.as_ref(),
has_tools,
)?;
let tool_definitions = request.inner.tools.as_ref().map(|tools| {
tools
.iter()
.map(|tool| dynamo_parsers::tool_calling::ToolDefinition {
name: tool.function.name.clone(),
parameters: tool.function.parameters.clone(),
})
.collect()
});
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
Box::pin(Self::apply_tool_calling_jail(
self.tool_call_parser.clone(),
request.inner.tool_choice.clone(),
tool_definitions,
stream,
))
} else {
Box::pin(stream)
};
Ok(transformed_stream)
}
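    /// Converts a stream of `BackendOutput` into OpenAI-style deltas using the delta generator,
    /// attaching an `llm_metrics` annotation to each chunk and emitting a final usage/metrics
    /// chunk once the backend stream ends after a finish reason.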
pub fn transform_postprocessor_stream<S, Resp>(
stream: S,
generator: Box<dyn DeltaGeneratorExt<Resp>>,
context: Arc<dyn AsyncEngineContext>,
) -> impl Stream<Item = Annotated<Resp>> + Send
where
S: Stream<Item = Annotated<BackendOutput>> + Send + 'static,
Resp: Send + Sync + 'static + std::fmt::Debug,
{
struct State<Resp>
where
Resp: Send + Sync + 'static + std::fmt::Debug,
{
response_stream: Pin<Box<dyn Stream<Item = Annotated<BackendOutput>> + Send>>,
response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
context: Arc<dyn AsyncEngineContext>,
cancelled: bool,
cumulative_output_tokens: usize,
finish_reason_sent: bool,
usage_chunk_sent: bool,
finished: bool,
}
let state = State {
response_stream: Box::pin(stream),
response_generator: generator,
context: context.clone(),
cancelled: false,
cumulative_output_tokens: 0,
finish_reason_sent: false,
usage_chunk_sent: false,
finished: false,
};
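        // Unfold the backend stream: map each chunk through the delta generator, attach
        // per-chunk metrics, then emit a single usage/metrics chunk after the stream ends.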
stream::unfold(state, |mut inner| {
async move {
if inner.finished {
return None;
}
if let Some(response) = inner.response_stream.next().await {
if inner.cancelled {
tracing::debug!(
request_id = inner.context.id(),
"Cancellation issued last message; closing stream"
);
                        inner.finished = true;
                        return None;
}
tracing::trace!(
request_id = inner.context.id(),
"Processing common response: {:?}",
response
);
let has_finish_reason = response
.data
.as_ref()
.map(|d| d.finish_reason.is_some())
.unwrap_or(false);
let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
let chunk_tokens = backend_output.token_ids.len();
inner.cumulative_output_tokens += chunk_tokens;
let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;
(chunk_tokens, isl)
} else {
(0, 0)
};
let current_osl = inner.cumulative_output_tokens;
let mut response = response.map_data(|data| {
inner
.response_generator
.choice_from_postprocessor(data)
.inspect_err(|e| {
tracing::error!(
request_id = inner.context.id(),
"Error processing common response: {:?}",
e
);
inner.cancelled = true;
inner.context.stop_generating();
})
.map_err(|e| e.to_string())
});
let tracker = inner.response_generator.tracker();
let prefill_worker_id = tracker.as_ref().and_then(|t| t.prefill_worker_id());
let prefill_dp_rank = tracker.as_ref().and_then(|t| t.prefill_dp_rank());
let prefill_worker_type = tracker
.as_ref()
.and_then(|t| t.prefill_worker_type())
.map(String::from);
let decode_worker_id = tracker.as_ref().and_then(|t| t.decode_worker_id());
let decode_dp_rank = tracker.as_ref().and_then(|t| t.decode_dp_rank());
let decode_worker_type = tracker
.as_ref()
.and_then(|t| t.decode_worker_type())
.map(String::from);
let llm_metrics = LLMMetricAnnotation {
input_tokens: isl,
output_tokens: current_osl,
chunk_tokens,
cached_tokens: None,
prefill_worker_id,
prefill_dp_rank,
prefill_worker_type,
decode_worker_id,
decode_dp_rank,
decode_worker_type,
tokenize_latency: tracker.as_ref().and_then(|t| t.tokenize_latency()),
detokenize_total_latency: tracker.as_ref().and_then(|t| t.detokenize_total_latency()),
detokenize_count: tracker.as_ref().map(|t| t.detokenize_count()),
};
if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
if response.event.is_none() {
response.event = metrics_annotated.event;
response.comment = metrics_annotated.comment;
}
}
if has_finish_reason {
inner.finish_reason_sent = true;
}
tracing::trace!(
request_id = inner.context.id(),
"OpenAI NvCreateChatCompletionStreamResponse: {:?}",
response
);
Some((response, inner))
} else {
inner.finished = true;
if inner.finish_reason_sent && !inner.usage_chunk_sent {
inner.usage_chunk_sent = true;
let usage_chunk = inner.response_generator.create_usage_chunk();
let usage = inner.response_generator.get_usage();
let tracker = inner.response_generator.tracker();
let prefill_worker_id =
tracker.as_ref().and_then(|t| t.prefill_worker_id());
let prefill_dp_rank = tracker.as_ref().and_then(|t| t.prefill_dp_rank());
let prefill_worker_type = tracker
.as_ref()
.and_then(|t| t.prefill_worker_type())
.map(String::from);
let decode_worker_id = tracker.as_ref().and_then(|t| t.decode_worker_id());
let decode_dp_rank = tracker.as_ref().and_then(|t| t.decode_dp_rank());
let decode_worker_type = tracker
.as_ref()
.and_then(|t| t.decode_worker_type())
.map(String::from);
let llm_metrics = LLMMetricAnnotation {
input_tokens: usage.prompt_tokens as usize,
output_tokens: usage.completion_tokens as usize,
chunk_tokens: 0,
cached_tokens: usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens.map(|c| c as usize)),
prefill_worker_id,
prefill_dp_rank,
prefill_worker_type,
decode_worker_id,
decode_dp_rank,
decode_worker_type,
tokenize_latency: tracker.as_ref().and_then(|t| t.tokenize_latency()),
detokenize_total_latency: tracker
.as_ref()
.and_then(|t| t.detokenize_total_latency()),
detokenize_count: tracker.as_ref().map(|t| t.detokenize_count()),
};
let annotation = llm_metrics.to_annotation::<()>().unwrap_or_else(|e| {
tracing::warn!("Failed to serialize metrics: {}", e);
Annotated::<()>::from_data(())
});
let data = if inner.response_generator.is_usage_enabled() {
Some(usage_chunk)
} else {
None
};
let annotated_usage = Annotated::<Resp> {
id: None,
data,
event: Some(ANNOTATION_LLM_METRICS.to_string()),
comment: annotation.comment,
error: None,
};
tracing::trace!(
request_id = inner.context.id(),
"Sending final usage chunk for OpenAI compliance, annotated_usage: {:?}",
annotated_usage
);
Some((annotated_usage, inner))
} else {
None
}
}
}
})
.fuse()
}
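    /// Maps engine embedding outputs into OpenAI `CreateEmbeddingResponse` payloads,
    /// preserving input order and copying the usage counts reported by the engine.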
pub fn transform_embedding_postprocessor_stream<S>(
stream: S,
original_request: NvCreateEmbeddingRequest,
) -> impl Stream<Item = Annotated<NvCreateEmbeddingResponse>> + Send
where
S: Stream<Item = Annotated<EmbeddingsEngineOutput>> + Send + 'static,
{
stream.map(move |output| {
output.map_data(|engine_output| {
let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output
.embeddings
.into_iter()
.enumerate()
.map(|(index, embedding)| dynamo_async_openai::types::Embedding {
index: index as u32,
object: "embedding".to_string(),
embedding: embedding.into_iter().map(|f| f as f32).collect(),
})
.collect();
let response = NvCreateEmbeddingResponse {
inner: dynamo_async_openai::types::CreateEmbeddingResponse {
object: "list".to_string(),
model: original_request.inner.model.clone(),
data: embeddings,
usage: dynamo_async_openai::types::EmbeddingUsage {
prompt_tokens: engine_output.prompt_tokens,
total_tokens: engine_output.total_tokens,
},
},
};
Ok(response)
})
})
}
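    /// Decides whether the tool-calling jail should wrap the stream, based on whether a tool
    /// parser is configured, the request's `tool_choice`, and whether any tools were supplied.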
pub fn should_apply_tool_jail(
tool_call_parser: Option<&String>,
tool_choice: Option<&ChatCompletionToolChoiceOption>,
has_tools: bool,
) -> std::result::Result<bool, Error> {
match (tool_call_parser, tool_choice, has_tools) {
(None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true),
(None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true),
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
tracing::warn!(
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}
            (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => Ok(false),
            (Some(_), Some(_), true) => Ok(true),
            (Some(_), None, true) => Ok(true),
_ => Ok(false),
}
}
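    /// Wraps the stream in a `JailedStream` configured from the tool definitions, the requested
    /// `tool_choice`, and (for auto/none choices) the configured tool-call parser.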
pub fn apply_tool_calling_jail<S>(
tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
let mut builder = JailedStream::builder();
if let Some(tool_definitions) = tool_definitions
&& !tool_definitions.is_empty()
{
builder = builder.tool_definitions(tool_definitions);
}
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
builder = builder.tool_choice_named(named.function.name.clone());
}
Some(ChatCompletionToolChoiceOption::Required) => {
builder = builder.tool_choice_required();
}
Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None)
| None => {
if let Some(parser) = tool_call_parser {
builder = builder.tool_call_parser(parser);
}
}
}
let jail = builder.build();
jail.apply_with_finish_reason(stream)
}
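    /// Returns true when chat-template arguments explicitly disable reasoning for parsers that
    /// support it: `kimi_k25` via `thinking=false`, and `nemotron_nano`/`nemotron3` via
    /// `enable_thinking=false` or `force_nonempty_content=true`.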
fn is_reasoning_disabled_by_request(
reasoning_parser: Option<&str>,
chat_template_args: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> bool {
match reasoning_parser {
Some("kimi_k25") => {
if let Some(args) = chat_template_args
&& let Some(thinking) = args.get("thinking")
{
return thinking == &serde_json::Value::Bool(false);
}
false
}
Some("nemotron_nano") | Some("nemotron3") => {
if let Some(args) = chat_template_args {
if let Some(enable_thinking) = args.get("enable_thinking")
&& enable_thinking == &serde_json::Value::Bool(false)
{
return true;
}
if let Some(force_nonempty) = args.get("force_nonempty_content")
&& force_nonempty == &serde_json::Value::Bool(true)
{
return true;
}
}
false
}
_ => false,
}
}
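    /// Splits streamed delta text into `content` and `reasoning_content` by feeding each chunk
    /// incrementally through the named reasoning parser.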
pub fn parse_reasoning_content_from_stream<S>(
stream: S,
parser_name: String,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
let reasoning_parser = Box::new(ReasoningParserType::get_reasoning_parser_from_name(
parser_name.as_ref(),
)) as Box<dyn ReasoningParser>;
let state = ReasoningState {
stream: Box::pin(stream),
reasoning_parser: Some(reasoning_parser),
};
stream::unfold(state, |mut state| async move {
if let Some(response) = state.stream.next().await {
let processed_response = if let Some(ref mut parser) = state.reasoning_parser {
response.map_data(|mut data| {
for choice in data.choices.iter_mut() {
if let Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
text,
),
) = choice.delta.content.as_ref()
{
let parser_result =
parser.parse_reasoning_streaming_incremental(text, &[]);
choice.delta.content = parser_result.get_some_normal_text().map(
dynamo_async_openai::types::ChatCompletionMessageContent::Text,
);
choice.delta.reasoning_content = parser_result.get_some_reasoning();
}
}
Ok(data)
})
} else {
response
};
Some((processed_response, state))
} else {
None
}
})
.fuse()
}
}
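// Chat completions: preprocess the OpenAI request, forward it to the backend, and post-process
// the stream (delta generation, reasoning/tool parsing, optional audit and speculative prefill).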
#[async_trait]
impl
Operator<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
> for OpenAIPreprocessor
{
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
next: Arc<
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (mut request, context) = request.into_parts();
let request_id = context.id().to_string();
let original_stream_flag = request.inner.stream.unwrap_or(false);
let mut audit_handle = crate::audit::handle::create_handle(&request, &request_id);
if let Some(ref mut h) = audit_handle {
h.set_request(std::sync::Arc::new(request.clone()));
}
request.enable_usage_for_nonstreaming(original_stream_flag);
request.inner.stream = Some(true);
let response_generator = request.response_generator(context.id().to_string());
let tracker = response_generator.tracker();
let (mut common_request, annotations) = self
.preprocess_request(&request, tracker.as_deref())
.await?;
tracing::trace!(request = ?common_request, "Pre-processed request");
common_request.tracker = tracker;
let mut response_generator = Box::new(response_generator);
if common_request.prompt_embeds.is_none() {
let isl = common_request.token_ids.len() as u32;
response_generator.update_isl(isl);
}
let common_request = context.map(|_| common_request);
let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
.into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect();
let annotations_stream = stream::iter(annotations);
let response_stream = next.generate(common_request).await?;
let context = response_stream.context();
let stream = Self::transform_postprocessor_stream(
response_stream,
response_generator,
context.clone(),
);
let transformed_stream = self.postprocessor_parsing_stream(stream, &request)?;
let final_stream = if let Some(mut audit) = audit_handle {
let (stream, agg_fut) = if audit.streaming() {
crate::audit::stream::scan_aggregate_with_future(transformed_stream)
} else {
crate::audit::stream::fold_aggregate_with_future(transformed_stream)
};
tokio::spawn(async move {
let final_resp = agg_fut.await;
audit.set_response(Arc::new(final_resp));
audit.emit();
});
stream
} else {
Box::pin(transformed_stream)
};
let final_stream = speculative_prefill::maybe_wrap_stream(
final_stream,
&request,
&next,
&self.formatter,
&self.tokenizer,
);
let stream = annotations_stream.chain(final_stream);
Ok(ResponseStream::new(Box::pin(stream), context))
}
}
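// Legacy completions: same preprocess/backend/postprocess flow, accepting prompt embeddings in
// place of token ids when provided.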
#[async_trait]
impl
Operator<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
> for OpenAIPreprocessor
{
async fn generate(
&self,
request: SingleIn<NvCreateCompletionRequest>,
next: Arc<
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (mut request, context) = request.into_parts();
let original_stream_flag = request.inner.stream.unwrap_or(false);
request.enable_usage_for_nonstreaming(original_stream_flag);
request.inner.stream = Some(true);
let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator);
let tracker = response_generator.tracker();
let mut builder = self.builder(&request)?;
let annotations = if let Some(ref prompt_embeds) = request.inner.prompt_embeds {
            builder.token_ids(vec![]);
            builder.prompt_embeds(Some(prompt_embeds.clone()));
HashMap::new()
} else {
self.gather_tokens(&request, &mut builder, None, tracker.as_deref())?
};
self.gather_multi_modal_data(&request, &mut builder, None)
.await?;
let mut common_request = builder.build()?;
common_request.tracker = tracker;
if common_request.prompt_embeds.is_none() {
let isl = common_request.token_ids.len() as u32;
response_generator.update_isl(isl);
}
let common_request = context.map(|_| common_request);
let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
.into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect();
let annotations_stream = stream::iter(annotations);
let response_stream = next.generate(common_request).await?;
let context = response_stream.context();
let stream = Self::transform_postprocessor_stream(
response_stream,
response_generator,
context.clone(),
);
let stream = annotations_stream.chain(stream);
Ok(ResponseStream::new(Box::pin(stream), context))
}
}
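// Embeddings: tokenize the input, run the embedding engine, and map the outputs back into
// OpenAI embedding responses.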
#[async_trait]
impl
Operator<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
> for OpenAIPreprocessor
{
async fn generate(
&self,
request: SingleIn<NvCreateEmbeddingRequest>,
next: Arc<
dyn AsyncEngine<
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
Error,
>,
>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
let (request, context) = request.into_parts();
let (preprocessed_request, annotations) =
self.preprocess_embedding_request(&request).await?;
let preprocessed_request = context.map(|_| preprocessed_request);
let response_stream = next.generate(preprocessed_request).await?;
let context = response_stream.context();
let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);
let annotations_stream = stream::iter(
annotations
.into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect::<Vec<_>>(),
);
let combined_stream = annotations_stream.chain(stream);
Ok(ResponseStream::new(Box::pin(combined_stream), context))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_reasoning_disabled_by_request() {
let thinking_true = {
let mut m = std::collections::HashMap::new();
m.insert("thinking".to_string(), serde_json::Value::Bool(true));
m
};
let thinking_false = {
let mut m = std::collections::HashMap::new();
m.insert("thinking".to_string(), serde_json::Value::Bool(false));
m
};
let enable_thinking_true = {
let mut m = std::collections::HashMap::new();
m.insert("enable_thinking".to_string(), serde_json::Value::Bool(true));
m
};
let enable_thinking_false = {
let mut m = std::collections::HashMap::new();
m.insert(
"enable_thinking".to_string(),
serde_json::Value::Bool(false),
);
m
};
let empty_args = std::collections::HashMap::new();
let cases = [
(
Some("kimi_k25"),
Some(&thinking_false),
true,
"kimi_k25 + thinking=false → disabled",
),
(
Some("kimi_k25"),
Some(&thinking_true),
false,
"kimi_k25 + thinking=true → enabled",
),
(
Some("kimi_k25"),
None,
false,
"kimi_k25 + no args → enabled",
),
(
Some("kimi_k25"),
Some(&empty_args),
false,
"kimi_k25 + empty args → enabled",
),
(
Some("deepseek_r1"),
Some(&thinking_false),
false,
"deepseek_r1 → never disabled",
),
(
Some("basic"),
Some(&thinking_false),
false,
"basic → never disabled",
),
(
None,
Some(&thinking_false),
false,
"no parser → never disabled",
),
(
Some("nemotron_nano"),
Some(&enable_thinking_false),
true,
"nemotron_nano + enable_thinking=false → disabled",
),
(
Some("nemotron_nano"),
Some(&enable_thinking_true),
false,
"nemotron_nano + enable_thinking=true → enabled",
),
(
Some("nemotron_nano"),
None,
false,
"nemotron_nano + no args → enabled",
),
(
Some("nemotron_nano"),
Some(&empty_args),
false,
"nemotron_nano + empty args → enabled",
),
];
for (parser, args, expected, desc) in cases {
assert_eq!(
OpenAIPreprocessor::is_reasoning_disabled_by_request(parser, args),
expected,
"FAILED: {desc}",
);
}
}
}