use std::sync::Arc;
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{
local_model::runtime_config::ModelRuntimeConfig,
protocols::{
common::{self, timing::RequestTracker},
openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponseFieldSelection},
token_to_utf8_bytes,
},
},
types::TokenIdType,
};
impl NvCreateChatCompletionRequest {
pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
if !original_stream_flag {
if self.inner.stream_options.is_none() {
self.inner.stream_options =
Some(dynamo_protocols::types::ChatCompletionStreamOptions {
include_usage: true,
continuous_usage_stats: false,
});
} else if let Some(ref mut opts) = self.inner.stream_options {
opts.include_usage = true;
}
}
}
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext());
let options = DeltaGeneratorOptions {
enable_usage: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
response_fields,
return_tokens_as_token_ids: self.return_tokens_as_token_ids.unwrap_or(false),
runtime_config: ModelRuntimeConfig::default(),
};
DeltaGenerator::new(self.inner.model.clone(), options, request_id)
}
}
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
pub continuous_usage_stats: bool,
pub enable_logprobs: bool,
pub response_fields: NvExtResponseFieldSelection,
pub return_tokens_as_token_ids: bool,
pub runtime_config: ModelRuntimeConfig,
}
pub struct DeltaGenerator {
id: String,
object: String,
created: u32,
model: String,
system_fingerprint: Option<String>,
service_tier: Option<dynamo_protocols::types::ServiceTierResponse>,
usage: dynamo_protocols::types::CompletionUsage,
msg_counter: u64,
options: DeltaGeneratorOptions,
tracker: Option<Arc<RequestTracker>>,
}
impl DeltaGenerator {
pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
let usage = dynamo_protocols::types::CompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
};
let chatcmpl_id = format!("chatcmpl-{request_id}");
let tracker = Some(Arc::new(RequestTracker::new()));
Self {
id: chatcmpl_id,
object: "chat.completion.chunk".to_string(),
created: now,
model,
system_fingerprint: None,
service_tier: None,
usage,
msg_counter: 0,
options,
tracker,
}
}
pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
self.tracker.clone()
}
pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl;
}
pub fn create_logprobs(
&self,
tokens: Vec<common::llm_backend::TokenType>,
token_ids: &[TokenIdType],
logprobs: Option<common::llm_backend::LogProbs>,
top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<dynamo_protocols::types::ChatChoiceLogprobs> {
if !self.options.enable_logprobs || logprobs.is_none() {
return None;
}
let toks = tokens
.into_iter()
.zip(token_ids)
.map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
.collect::<Vec<(String, TokenIdType)>>();
let tok_lps = toks
.iter()
.zip(logprobs.unwrap())
.map(|(_, lp)| lp as f32)
.collect::<Vec<f32>>();
let return_as_ids = self.options.return_tokens_as_token_ids;
let content = top_logprobs.map(|top_logprobs| {
toks.iter()
.zip(tok_lps)
.zip(top_logprobs)
.map(|(((t, tid), lp), top_lps)| {
let token_str = if return_as_ids {
format!("token_id:{}", tid)
} else {
t.clone()
};
let converted =
convert_backend_top_logprobs(&top_lps, t, *tid, lp, return_as_ids);
dynamo_protocols::types::ChatCompletionTokenLogprob {
token: token_str.clone(),
logprob: lp,
bytes: token_to_utf8_bytes(&token_str),
top_logprobs: converted,
}
})
.collect()
});
Some(dynamo_protocols::types::ChatChoiceLogprobs {
content,
refusal: None,
})
}
pub fn create_choice(
&mut self,
index: u32,
text: Option<String>,
finish_reason: Option<dynamo_protocols::types::FinishReason>,
logprobs: Option<dynamo_protocols::types::ChatChoiceLogprobs>,
) -> NvCreateChatCompletionStreamResponse {
let delta = dynamo_protocols::types::ChatCompletionStreamResponseDelta {
content: text.map(dynamo_protocols::types::ChatCompletionMessageContent::Text),
function_call: None,
tool_calls: None,
role: if self.msg_counter == 0 {
Some(dynamo_protocols::types::Role::Assistant)
} else {
None
},
refusal: None,
reasoning_content: None,
};
let choice = dynamo_protocols::types::ChatChoiceStream {
index,
delta,
finish_reason,
logprobs,
};
let choices = vec![choice];
NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices,
usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
service_tier: self.service_tier.clone(),
},
nvext: None, }
}
pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
let usage = self.get_usage();
NvCreateChatCompletionStreamResponse {
inner: dynamo_protocols::types::CreateChatCompletionStreamResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![], usage: Some(usage),
service_tier: self.service_tier.clone(),
},
nvext: None,
}
}
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
pub fn get_usage(&self) -> dynamo_protocols::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
usage
}
}
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
for DeltaGenerator
{
fn choice_from_postprocessor(
&mut self,
delta: crate::protocols::common::llm_backend::BackendOutput,
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
if let Some(completion_usage) = delta.completion_usage.as_ref() {
self.usage.prompt_tokens = completion_usage.prompt_tokens;
if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
}
let logprobs = self.create_logprobs(
delta.tokens,
&delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
let finish_reason = match delta.finish_reason {
Some(common::FinishReason::EoS) => Some(dynamo_protocols::types::FinishReason::Stop),
Some(common::FinishReason::Stop) => Some(dynamo_protocols::types::FinishReason::Stop),
Some(common::FinishReason::Length) => {
Some(dynamo_protocols::types::FinishReason::Length)
}
Some(common::FinishReason::Cancelled) => {
Some(dynamo_protocols::types::FinishReason::Stop)
}
Some(common::FinishReason::ContentFilter) => {
Some(dynamo_protocols::types::FinishReason::ContentFilter)
}
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg));
}
None => None,
};
let stop_reason = delta.stop_reason.clone();
let index = delta.index.unwrap_or(0);
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
if finish_reason.is_some()
&& let Some(ref tracker) = self.tracker
{
tracker.record_finish();
}
if let Some(nvext_response) = self.options.response_fields.build_response_nvext(
self.tracker.as_ref(),
delta.disaggregated_params.as_ref(),
finish_reason.is_some(),
delta.engine_data,
stop_reason,
) && let Ok(nvext_json) = serde_json::to_value(&nvext_response)
{
stream_response.nvext = Some(nvext_json);
if let Some(ref info) = nvext_response.worker_id {
tracing::debug!(
"Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
if let Some(ref tokens) = nvext_response.token_ids {
tracing::debug!(
"Injected token_ids into chat completion nvext: {} tokens",
tokens.len()
);
}
}
Ok(stream_response)
}
fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens)
}
fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
DeltaGenerator::create_usage_chunk(self)
}
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_protocols::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
fn tracker(&self) -> Option<std::sync::Arc<crate::protocols::common::timing::RequestTracker>> {
self.tracker.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::{self, llm_backend::BackendOutput, timing::WORKER_TYPE_PREFILL};
use crate::protocols::openai::DeltaGeneratorExt;
use dynamo_protocols::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
};
fn create_test_request() -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
name: None,
},
)];
NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages,
stream: Some(false),
stream_options: None,
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
return_tokens_as_token_ids: None,
unsupported_fields: Default::default(),
}
}
#[test]
fn test_enable_usage_for_nonstreaming_enables_usage() {
let mut request = create_test_request();
assert!(request.inner.stream_options.is_none());
request.enable_usage_for_nonstreaming(false);
assert!(
request.inner.stream_options.is_some(),
"Non-streaming request should have stream_options created"
);
assert!(
request.inner.stream_options.unwrap().include_usage,
"Non-streaming request should have include_usage=true for OpenAI compliance"
);
assert!(
!request.inner.stream_options.unwrap().continuous_usage_stats,
"Non-streaming request should have continuous_usage_stats=false for OpenAI compliance"
);
}
#[test]
fn test_enable_usage_for_nonstreaming_ignores_streaming() {
let mut request = create_test_request();
assert!(request.inner.stream_options.is_none());
request.enable_usage_for_nonstreaming(true);
assert!(
request.inner.stream_options.is_none(),
"Streaming request should not have stream_options modified"
);
}
fn make_request_with_nvext(
nvext: crate::protocols::openai::nvext::NvExt,
) -> NvCreateChatCompletionRequest {
let mut request = create_test_request();
request.nvext = Some(nvext);
request
}
fn final_backend_output() -> BackendOutput {
BackendOutput {
token_ids: vec![1],
tokens: vec![Some("hello".to_string())],
text: Some("hello".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(common::FinishReason::Stop),
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: Some(serde_json::json!({
"token_ids": [11, 22, 33],
"routed_experts": {"layer_0": [1, 3]}
})),
engine_data: None,
}
}
fn create_test_request_with_extra_fields(fields: Vec<String>) -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
name: None,
},
)];
NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages,
stream: Some(true),
stream_options: None,
..Default::default()
},
common: Default::default(),
nvext: Some(
crate::protocols::openai::nvext::NvExt::builder()
.extra_fields(fields)
.build()
.unwrap(),
),
chat_template_args: None,
media_io_kwargs: None,
return_tokens_as_token_ids: None,
unsupported_fields: Default::default(),
}
}
fn make_backend_output_with_engine_data() -> crate::protocols::common::llm_backend::BackendOutput
{
crate::protocols::common::llm_backend::BackendOutput {
token_ids: vec![42],
tokens: vec![Some("hello".to_string())],
text: Some("hello".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(crate::protocols::common::FinishReason::Stop),
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
engine_data: Some(serde_json::json!({
"kv_transfer_time_ms": 12.3,
"disaggregated_kv_transfer_time_ms": 8.1,
"prefill_compute_time_ms": 45.6
})),
}
}
#[test]
fn test_plain_request_without_extra_fields_omits_nvext() {
let request = create_test_request();
let mut generator = request.response_generator("req-no-nvext".to_string());
let tracker = generator.tracker().expect("tracker");
tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
let response = generator
.choice_from_postprocessor(final_backend_output())
.expect("choice generation");
assert!(response.nvext.is_none());
}
#[test]
fn test_backend_choice_index_is_preserved() {
let request = create_test_request();
let mut generator = request.response_generator("req-choice-index".to_string());
let mut output = final_backend_output();
output.index = Some(2);
let response = generator
.choice_from_postprocessor(output)
.expect("choice generation");
assert_eq!(response.inner.choices[0].index, 2);
}
#[test]
fn test_stop_reason_emits_in_nvext_when_requested() {
let request = create_test_request_with_extra_fields(vec!["stop_reason".to_string()]);
let mut generator = request.response_generator("req-stop-reason-nvext".to_string());
let mut output = final_backend_output();
output.stop_reason = Some(dynamo_protocols::types::StopReason::String(
"END".to_string(),
));
let response = generator
.choice_from_postprocessor(output)
.expect("choice generation");
let response_json = serde_json::to_value(&response).expect("serialize response");
assert!(response_json["choices"][0].get("stop_reason").is_none());
assert_eq!(response_json["nvext"]["stop_reason"], "END");
}
#[test]
fn test_timing_extra_field_emits_timing_on_final_chunk() {
use crate::protocols::openai::nvext::NvExt;
let nvext = NvExt::builder()
.extra_fields(vec!["timing".to_string()])
.build()
.unwrap();
let mut generator =
make_request_with_nvext(nvext).response_generator("req-timing".to_string());
let response = generator
.choice_from_postprocessor(final_backend_output())
.expect("choice generation");
let nvext_json = response.nvext.expect("nvext present for timing request");
assert!(
nvext_json.get("timing").is_some(),
"timing should be emitted when extra_fields=[\"timing\"]"
);
assert!(nvext_json.get("worker_id").is_none());
assert!(nvext_json.get("token_ids").is_none());
assert!(nvext_json.get("routed_experts").is_none());
}
#[test]
fn test_query_instance_id_emits_worker_id_and_token_ids() {
use crate::protocols::openai::nvext::NvExt;
let nvext = NvExt::builder()
.annotations(vec!["query_instance_id:abc".to_string()])
.build()
.unwrap();
let mut generator =
make_request_with_nvext(nvext).response_generator("req-qid".to_string());
let tracker = generator.tracker().expect("tracker");
tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);
let response = generator
.choice_from_postprocessor(final_backend_output())
.expect("choice generation");
let nvext_json = response
.nvext
.expect("nvext present for query_instance_id flow");
assert!(nvext_json.get("worker_id").is_some());
assert_eq!(
nvext_json.get("token_ids"),
Some(&serde_json::json!([11, 22, 33]))
);
assert!(nvext_json.get("timing").is_none());
assert!(nvext_json.get("routed_experts").is_none());
}
#[test]
fn test_routed_experts_extra_field_emits_routed_experts() {
use crate::protocols::openai::nvext::NvExt;
let nvext = NvExt::builder()
.extra_fields(vec!["routed_experts".to_string()])
.build()
.unwrap();
let mut generator =
make_request_with_nvext(nvext).response_generator("req-experts".to_string());
let response = generator
.choice_from_postprocessor(final_backend_output())
.expect("choice generation");
let nvext_json = response
.nvext
.expect("nvext present for routed_experts request");
assert_eq!(
nvext_json.get("routed_experts"),
Some(&serde_json::json!({"layer_0": [1, 3]}))
);
assert!(nvext_json.get("worker_id").is_none());
assert!(nvext_json.get("timing").is_none());
assert!(nvext_json.get("token_ids").is_none());
}
#[test]
fn test_engine_data_included_when_requested_via_extra_fields() {
let request = create_test_request_with_extra_fields(vec!["engine_data".to_string()]);
let mut generator = request.response_generator("req-engine-1".to_string());
let backend_output = make_backend_output_with_engine_data();
let response = generator
.choice_from_postprocessor(backend_output)
.expect("should produce a response");
let nvext = response.nvext.expect("nvext should be present");
let engine_data = nvext
.get("engine_data")
.expect("engine_data should be present");
assert_eq!(engine_data["kv_transfer_time_ms"], 12.3);
assert_eq!(engine_data["prefill_compute_time_ms"], 45.6);
}
#[test]
fn test_engine_data_excluded_when_not_requested() {
let request = create_test_request();
let mut generator = request.response_generator("req-engine-2".to_string());
let backend_output = make_backend_output_with_engine_data();
let response = generator
.choice_from_postprocessor(backend_output)
.expect("should produce a response");
if let Some(nvext) = &response.nvext {
assert!(
nvext.get("engine_data").is_none() || nvext.get("engine_data").unwrap().is_null(),
"engine_data should not be present when not requested"
);
}
}
#[test]
fn test_engine_data_excluded_when_other_extra_fields_requested() {
let request = create_test_request_with_extra_fields(vec!["timing".to_string()]);
let mut generator = request.response_generator("req-engine-3".to_string());
let backend_output = make_backend_output_with_engine_data();
let response = generator
.choice_from_postprocessor(backend_output)
.expect("should produce a response");
if let Some(nvext) = &response.nvext {
assert!(
nvext.get("engine_data").is_none() || nvext.get("engine_data").unwrap().is_null(),
"engine_data should not be present when only timing is requested"
);
}
}
#[test]
fn test_engine_data_none_from_backend_no_nvext_noise() {
let request = create_test_request_with_extra_fields(vec!["engine_data".to_string()]);
let mut generator = request.response_generator("req-engine-4".to_string());
let backend_output = crate::protocols::common::llm_backend::BackendOutput {
token_ids: vec![42],
tokens: vec![Some("hello".to_string())],
text: Some("hello".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(crate::protocols::common::FinishReason::Stop),
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
engine_data: None, };
let response = generator
.choice_from_postprocessor(backend_output)
.expect("should produce a response");
if let Some(nvext) = &response.nvext {
assert!(
nvext.get("engine_data").is_none() || nvext.get("engine_data").unwrap().is_null(),
"engine_data should not appear when backend provides None"
);
}
}
}