use async_trait::async_trait;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
use dynamo_llm::protocols::openai::ParsingOptions;
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator,
};
use dynamo_llm::protocols::openai::completions::NvCreateCompletionRequest;
use dynamo_protocols::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
    CompletionUsage as AoaiCompletionUsage, CreateChatCompletionRequest,
    CreateCompletionRequestArgs, Prompt, PromptTokensDetails,
};
use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream, ResponseStream};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt;
use futures::stream;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
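/// Minimal `AsyncEngineContext` implementation backing the mock streams in these tests.
/// Stop/kill state is tracked with atomics so the trait methods remain observable.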
#[derive(Debug)]
struct MockContext {
id: String,
stopped: AtomicBool,
killed: AtomicBool,
}
impl MockContext {
fn new() -> Self {
Self {
id: "test-request-123".to_string(),
stopped: AtomicBool::new(false),
killed: AtomicBool::new(false),
}
}
}
#[async_trait]
impl AsyncEngineContext for MockContext {
fn id(&self) -> &str {
&self.id
}
fn stop_generating(&self) {
self.stopped.store(true, Ordering::SeqCst);
}
fn is_stopped(&self) -> bool {
self.stopped.load(Ordering::SeqCst)
}
fn is_killed(&self) -> bool {
self.killed.load(Ordering::SeqCst)
}
    // No-op in the mock: these tests never await stop/kill notifications.
    async fn stopped(&self) {}
    async fn killed(&self) {}
fn stop(&self) {
self.stopped.store(true, Ordering::SeqCst);
}
fn kill(&self) {
self.killed.store(true, Ordering::SeqCst);
}
    fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {}
}
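/// Wraps the three-token mock outputs ("Hello", " world", "!") in a `ResponseStream`,
/// with no backend-reported usage attached.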
fn create_mock_backend_stream(
ctx: Arc<dyn AsyncEngineContext>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
let outputs = build_backend_outputs_with_cached_tokens(None);
let stream = stream::iter(outputs.into_iter().map(Annotated::from_data));
ResponseStream::new(Box::pin(stream), ctx)
}
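/// Builds the canonical three `BackendOutput` chunks. When `cached_tokens` is `Some`,
/// the final chunk carries a `completion_usage` whose `prompt_tokens_details` reports it.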
fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<BackendOutput> {
vec![
BackendOutput {
token_ids: vec![15339],
tokens: vec![Some("Hello".to_string())],
text: Some("Hello".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
},
BackendOutput {
token_ids: vec![1917],
tokens: vec![Some(" world".to_string())],
text: Some(" world".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
},
BackendOutput {
token_ids: vec![0],
tokens: vec![Some("!".to_string())],
text: Some("!".to_string()),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
index: Some(0),
completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: Some(PromptTokensDetails {
audio_tokens: None,
cached_tokens: Some(ct),
}),
completion_tokens_details: None,
}),
disaggregated_params: None,
},
]
}
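/// Like `create_mock_backend_stream`, but lets a test inject backend-reported
/// `cached_tokens` on the final chunk.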
fn create_backend_stream_with_cached_tokens(
ctx: Arc<dyn AsyncEngineContext>,
cached_tokens: Option<u32>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
let outputs = build_backend_outputs_with_cached_tokens(cached_tokens);
let stream = stream::iter(outputs.into_iter().map(Annotated::from_data));
ResponseStream::new(Box::pin(stream), ctx)
}
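/// Builds a streaming chat completion request; `include_usage` and `continuous_usage`
/// control the request's `stream_options`.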
fn create_chat_request(
include_usage: Option<bool>,
continuous_usage: Option<bool>,
) -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)];
let stream_options = include_usage.map(|include| ChatCompletionStreamOptions {
include_usage: include,
continuous_usage_stats: continuous_usage.unwrap_or(false),
});
let inner = CreateChatCompletionRequest {
model: "test-model".to_string(),
messages,
stream: Some(true),
stream_options,
..Default::default()
};
NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
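/// Without `stream_options`, every data chunk must report `usage: None` and no trailing
/// usage-only chunk is emitted (metrics-only annotations are filtered out first).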
#[tokio::test]
async fn test_streaming_without_usage() {
let request = create_chat_request(None, None);
let request_id = "test-123".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
let content_chunks: Vec<_> = chunks
.into_iter()
.filter(|chunk| {
!(chunk
.event
.as_ref()
.map(|e| e == "llm_metrics")
.unwrap_or(false)
&& chunk.data.is_none())
})
.collect();
assert_eq!(
content_chunks.len(),
3,
"Should have exactly 3 content chunks"
);
for (i, chunk) in content_chunks.iter().enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.inner.usage.is_none(),
"Chunk {} should have usage: None when stream_options not set",
i
);
assert!(
!response.inner.choices.is_empty(),
"Chunk {} should have choices",
i
);
}
}
}
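/// With `include_usage: true`, content chunks omit usage and a final usage-only chunk
/// with an empty `choices` array reports the aggregated token counts (OpenAI semantics).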
#[tokio::test]
async fn test_streaming_with_usage_compliance() {
let request = create_chat_request(Some(true), None);
let request_id = "test-456".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(
chunks.len(),
4,
"Should have 3 content chunks + 1 usage chunk"
);
for (i, chunk) in chunks.iter().take(3).enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.inner.usage.is_none(),
"Content chunk {} should have usage: None",
i
);
assert!(
!response.inner.choices.is_empty(),
"Content chunk {} should have choices",
i
);
}
}
if let Some(final_response) = &chunks[3].data {
assert!(
final_response.inner.choices.is_empty(),
"Final usage chunk should have empty choices array"
);
assert!(
final_response.inner.usage.is_some(),
"Final usage chunk should have usage statistics"
);
let usage = final_response.inner.usage.as_ref().unwrap();
assert_eq!(
usage.completion_tokens, 3,
"Should have 3 completion tokens"
);
assert_eq!(
usage.prompt_tokens, 0,
"Should have 0 prompt tokens (not set in test)"
);
assert_eq!(
usage.total_tokens, 3,
"Total tokens should be prompt + completion"
);
} else {
panic!("Final chunk should be a valid response");
}
}
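/// With `continuous_usage_stats: true`, each content chunk carries a running usage total
/// in addition to the final usage-only chunk.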
#[tokio::test]
async fn test_streaming_with_continuous_usage() {
let request = create_chat_request(Some(true), Some(true));
let request_id = "test-456".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(
chunks.len(),
4,
"Should have 3 content chunks + 1 usage chunk"
);
for (i, chunk) in chunks.iter().take(3).enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.inner.usage.is_some(),
"Content chunk {} should have usage: Some",
i
);
assert!(
!response.inner.choices.is_empty(),
"Content chunk {} should have choices",
i
);
let usage = response.inner.usage.as_ref().unwrap();
assert_eq!(
usage.completion_tokens,
i as u32 + 1,
"Should have {} completion tokens",
i + 1
);
assert_eq!(
usage.prompt_tokens, 0,
"Should have 0 prompt tokens (not set in test)"
);
assert_eq!(
usage.total_tokens,
i as u32 + 1,
"Total tokens should be prompt + completion"
);
}
}
if let Some(final_response) = &chunks[3].data {
assert!(
final_response.inner.choices.is_empty(),
"Final usage chunk should have empty choices array"
);
assert!(
final_response.inner.usage.is_some(),
"Final usage chunk should have usage statistics"
);
let usage = final_response.inner.usage.as_ref().unwrap();
assert_eq!(
usage.completion_tokens, 3,
"Should have 3 completion tokens"
);
assert_eq!(
usage.prompt_tokens, 0,
"Should have 0 prompt tokens (not set in test)"
);
assert_eq!(
usage.total_tokens, 3,
"Total tokens should be prompt + completion"
);
} else {
panic!("Final chunk should be a valid response");
}
}
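/// With `include_usage: false`, no chunk in the stream reports usage.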
#[tokio::test]
async fn test_streaming_with_usage_false() {
let request = create_chat_request(Some(false), None);
let request_id = "test-789".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
let content_chunks: Vec<_> = chunks
.into_iter()
.filter(|chunk| {
!(chunk
.event
.as_ref()
.map(|e| e == "llm_metrics")
.unwrap_or(false)
&& chunk.data.is_none())
})
.collect();
assert_eq!(
content_chunks.len(),
3,
"Should have exactly 3 content chunks when include_usage is false"
);
for (i, chunk) in content_chunks.iter().enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.inner.usage.is_none(),
"Chunk {} should have usage: None when include_usage is false",
i
);
}
}
}
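/// Builds a completions request (streaming or not), optionally setting
/// `stream_options.include_usage`.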
fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCompletionRequest {
let inner = {
let mut builder = CreateCompletionRequestArgs::default();
builder
.model("test-model")
.prompt(Prompt::String("Hello".to_string()))
.stream(stream);
if let Some(include) = include_usage {
            builder.stream_options(ChatCompletionStreamOptions {
include_usage: include,
continuous_usage_stats: false,
});
}
builder.build().unwrap()
};
NvCreateCompletionRequest {
inner,
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
}
}
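/// Builds a non-streaming chat completion request with no `stream_options`.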
fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)];
let inner = CreateChatCompletionRequest {
model: "test-model".to_string(),
messages,
stream: Some(false),
stream_options: None,
..Default::default()
};
NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
}
}
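/// Non-streaming chat responses must always carry a populated `usage` field after
/// aggregation, which `enable_usage_for_nonstreaming` guarantees.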
#[tokio::test]
async fn test_nonstreaming_has_usage_field() {
let mut request = create_nonstreaming_chat_request();
assert_eq!(
request.inner.stream,
Some(false),
"Request should be non-streaming"
);
assert!(
request.inner.stream_options.is_none(),
"stream_options should not be set initially"
);
let original_stream_flag = request.inner.stream.unwrap_or(false);
request.enable_usage_for_nonstreaming(original_stream_flag);
let request_id = "test-nonstream-123".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let result = dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionResponse::from_annotated_stream(
transformed_stream,
ParsingOptions::default(),
)
.await;
assert!(result.is_ok(), "Aggregation should succeed");
let response = result.unwrap();
assert!(
response.inner.usage.is_some(),
"Non-streaming chat completion response MUST have a usage field populated. \
This is required for OpenAI API compliance."
);
let usage = response.inner.usage.unwrap();
assert_eq!(
usage.completion_tokens, 3,
"Completion tokens should match the number of tokens generated"
);
assert!(
usage.total_tokens > 0,
"Total tokens should be greater than 0"
);
assert_eq!(
usage.total_tokens,
usage.prompt_tokens + usage.completion_tokens,
"Total tokens should equal prompt_tokens + completion_tokens"
);
}
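/// Completions streaming with `include_usage: true` but no backend-reported usage:
/// the final usage-only chunk still aggregates token counts, and `prompt_tokens_details`
/// stays `None`.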
#[tokio::test]
async fn test_cmpl_streaming_with_usage_true_no_backend_usage() {
let request = create_cmpl_request(Some(true), true);
let request_id = "cmpl-usage-none-1".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
for (i, chunk) in chunks.iter().take(3).enumerate() {
if let Some(resp) = &chunk.data {
assert!(
resp.inner.usage.is_none(),
"Content chunk {} should have usage: None",
i
);
assert!(
!resp.inner.choices.is_empty(),
"Content chunk {} should have choices",
i
);
}
}
if let Some(final_resp) = &chunks[3].data {
assert!(
final_resp.inner.choices.is_empty(),
"Usage-only chunk must have empty choices"
);
let usage = final_resp
.inner
.usage
.as_ref()
.expect("Usage must be present");
assert_eq!(
usage.completion_tokens, 3,
"Aggregated completion tokens should be 3"
);
assert!(
usage.prompt_tokens_details.is_none(),
"prompt_tokens_details should be None when backend does not send usage"
);
} else {
panic!("Final chunk should be present");
}
}
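/// Backend-reported `cached_tokens` must propagate into the final usage chunk of a
/// streaming completions response.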
#[tokio::test]
async fn test_cmpl_streaming_with_cached_tokens_propagation() {
let request = create_cmpl_request(Some(true), true);
let request_id = "cmpl-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(7));
response_generator.update_isl(0);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
if let Some(final_resp) = &chunks[3].data {
let usage = final_resp
.inner
.usage
.as_ref()
.expect("Usage must be present on final chunk");
let cached = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(7),
"cached_tokens must propagate to final usage chunk"
);
} else {
panic!("Final chunk should be present");
}
}
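/// Backend-reported `cached_tokens` must propagate the same way for streaming chat
/// completions.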
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
let request = create_chat_request(Some(true), Some(true));
let request_id = "chat-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(5));
response_generator.update_isl(0);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let chunks: Vec<_> = transformed_stream.collect().await;
assert_eq!(chunks.len(), 4, "Should have 3 content + 1 usage chunk");
if let Some(final_resp) = &chunks[3].data {
let usage = final_resp
.inner
.usage
.as_ref()
.expect("Usage must be present");
let cached = usage
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(5),
"cached_tokens must propagate for chat completions"
);
} else {
panic!("Final chunk should be present");
}
}
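/// Non-streaming completions responses must include usage after aggregation and
/// propagate backend-reported `cached_tokens`.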
#[tokio::test]
async fn test_cmpl_nonstreaming_has_usage_and_cached_tokens() {
let mut request = create_cmpl_request(None, false);
let original_stream_flag = request.inner.stream.unwrap_or(false);
request.enable_usage_for_nonstreaming(original_stream_flag);
let request_id = "cmpl-nonstream-usage".to_string();
let response_generator = Box::new(request.response_generator(request_id));
let ctx = Arc::new(MockContext::new());
let backend_stream = create_backend_stream_with_cached_tokens(ctx.clone(), Some(9));
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
let parsing = ParsingOptions::default();
let result =
dynamo_llm::protocols::openai::completions::NvCreateCompletionResponse::from_annotated_stream(
transformed_stream,
parsing,
)
.await;
assert!(result.is_ok(), "Aggregation should succeed");
let resp = result.unwrap();
let usage = resp
.inner
.usage
.expect("usage must be present for non-streaming");
assert_eq!(
usage.completion_tokens, 3,
"completion_tokens must aggregate"
);
let cached = usage.prompt_tokens_details.and_then(|d| d.cached_tokens);
assert_eq!(
cached,
Some(9),
"cached_tokens must propagate to non-streaming response"
);
}