use std::sync::Arc;
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{
protocols::{
common::{self, timing::RequestTracker},
openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponse, TimingInfo},
},
},
types::TokenIdType,
};
impl NvCreateCompletionRequest {
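    /// Enable usage reporting for requests that were not originally streaming,
    /// so token counts can still be attached to the aggregated response. No-op
    /// when the client itself asked for streaming.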
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if original_stream_flag {
            return;
        }
        match self.inner.stream_options.as_mut() {
            Some(opts) => opts.include_usage = true,
            None => {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                        continuous_usage_stats: false,
                    });
            }
        }
    }
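    /// Build a [`DeltaGenerator`] for this request. Usage and logprobs options
    /// come from the request itself; timing/instance tracking is enabled when
    /// the nvext extension requests the "timing" extra field or carries a
    /// "query_instance_id" annotation.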
pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
let enable_tracking = self
.nvext()
.map(|nv| {
nv.extra_fields
.as_ref()
.is_some_and(|fields| fields.iter().any(|f| f == "timing"))
|| nv.annotations.as_ref().is_some_and(|annots| {
annots.iter().any(|a| a.starts_with("query_instance_id"))
})
})
.unwrap_or(false);
let options = DeltaGeneratorOptions {
enable_usage: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_tracking,
};
DeltaGenerator::new(self.inner.model.clone(), options, request_id)
}
}
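/// Flags controlling what a [`DeltaGenerator`] emits alongside the completion text.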
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
pub enable_usage: bool,
pub continuous_usage_stats: bool,
pub enable_logprobs: bool,
pub enable_tracking: bool,
}
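/// Incrementally converts backend outputs into OpenAI-compatible completion
/// response chunks, accumulating usage and optional timing metadata as it goes.
///
/// A minimal usage sketch; the stream plumbing around it is assumed, not part
/// of this module:
///
/// ```ignore
/// use crate::protocols::openai::DeltaGeneratorExt;
///
/// let mut generator = request.response_generator("req-123".to_string());
/// for output in backend_outputs {
///     // `output` is a `common::llm_backend::BackendOutput` from the engine.
///     let chunk = generator.choice_from_postprocessor(output)?;
///     // forward `chunk` to the client as a stream event
/// }
/// ```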
pub struct DeltaGenerator {
id: String,
object: String,
created: u32,
model: String,
system_fingerprint: Option<String>,
usage: dynamo_async_openai::types::CompletionUsage,
options: DeltaGeneratorOptions,
tracker: Option<Arc<RequestTracker>>,
}
impl DeltaGenerator {
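    /// Create a generator for `model`. The response id is stamped as
    /// `cmpl-{request_id}` and `created` as seconds since the Unix epoch.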
pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
let usage = dynamo_async_openai::types::CompletionUsage {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
prompt_tokens_details: None,
};
let completion_id = format!("cmpl-{request_id}");
let tracker = Some(Arc::new(RequestTracker::new()));
Self {
id: completion_id,
object: "text_completion".to_string(),
created: now,
model,
system_fingerprint: None,
usage,
options,
tracker,
}
}
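    /// Shared handle to the per-request timing tracker.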
pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
self.tracker.clone()
}
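    /// Record the input sequence length (prompt token count) for usage reporting.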
pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl;
}
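    /// Assemble OpenAI-style logprobs from the backend's tokens, token ids, and
    /// (top-)logprobs. Returns `None` if logprobs are disabled for this request
    /// or the backend supplied none.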
pub fn create_logprobs(
&self,
tokens: Vec<common::llm_backend::TokenType>,
token_ids: Vec<TokenIdType>,
logprobs: Option<common::llm_backend::LogProbs>,
top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<dynamo_async_openai::types::Logprobs> {
        if !self.options.enable_logprobs {
            return None;
        }
        let logprobs = logprobs?;
        // Pair each token string with its id; missing token strings become "".
        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        // Truncate logprobs to the token count and downcast to f32.
        let tok_lps = toks
            .iter()
            .zip(logprobs)
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();
let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
toks.iter()
.zip(tok_lps.iter())
.zip(top_logprobs.iter())
.map(|(((t, tid), lp), top_lps)| {
let converted = convert_backend_top_logprobs(top_lps, t, *tid, *lp);
serde_json::to_value(converted).unwrap()
})
.collect()
});
Some(dynamo_async_openai::types::Logprobs {
tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
token_logprobs: tok_lps.into_iter().map(Some).collect(),
text_offset: vec![],
top_logprobs: top_lps,
})
}
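    /// Wrap a single choice delta in a completion response chunk. Usage is
    /// attached only when both usage reporting and continuous usage stats are
    /// enabled.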
pub fn create_choice(
&self,
index: u32,
text: Option<String>,
finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
logprobs: Option<dynamo_async_openai::types::Logprobs>,
) -> NvCreateCompletionResponse {
let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![dynamo_async_openai::types::Choice {
text: text.unwrap_or_default(),
index,
finish_reason,
logprobs,
}],
usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
            nvext: None,
        };
NvCreateCompletionResponse { inner }
}
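    /// Build the final, choice-less chunk that carries only the usage totals.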
pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
let usage = self.get_usage();
let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![],
            usage: Some(usage),
            nvext: None,
        };
NvCreateCompletionResponse { inner }
}
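    /// Whether usage reporting was requested (`stream_options.include_usage`).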
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
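    /// Whether per-chunk (continuous) usage stats were requested.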
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
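    /// Current usage snapshot, with `total_tokens` recomputed from the prompt
    /// and completion counters.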
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
usage
}
}
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
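    /// Convert one backend delta into a completion chunk: update usage
    /// counters, build logprobs, and attach NVIDIA extension metadata (worker
    /// ids, token ids, routed experts, timing) when any of it is present.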
fn choice_from_postprocessor(
&mut self,
delta: common::llm_backend::BackendOutput,
) -> anyhow::Result<NvCreateCompletionResponse> {
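        // Fold this delta's tokens into the running completion-token count.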
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
if let Some(completion_usage) = delta.completion_usage.as_ref() {
self.usage.prompt_tokens = completion_usage.prompt_tokens;
if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
}
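        // Build OpenAI-style logprobs for this delta, if the request asked for them.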
let logprobs = self.create_logprobs(
delta.tokens,
delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
let finish_reason = delta.finish_reason.map(Into::into);
let index = delta.index.unwrap_or(0);
let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
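        // Gather optional NVIDIA extension data: worker routing info from the
        // tracker, plus token ids and routed experts from disaggregated params.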
let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
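        // Timing info is recorded and attached only on the final chunk.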
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.tracker.as_ref().map(|tracker| {
tracker.record_finish();
tracker.get_timing_info()
})
} else {
None
};
if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
{
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
routed_experts,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
response.inner.nvext = Some(nvext_json);
if let Some(ref info) = worker_id_info {
tracing::debug!(
"Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
info.prefill_worker_id,
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into completions nvext: {} tokens",
tokens.len()
);
}
}
}
Ok(response)
}
fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens)
}
fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
DeltaGenerator::create_usage_chunk(self)
}
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
fn tracker(&self) -> Option<std::sync::Arc<crate::protocols::common::timing::RequestTracker>> {
self.tracker.clone()
}
}