use anyhow::Error;
use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, ErrorType, RequestType, Status},
service_v2::HttpService,
},
},
model_card::ModelDeploymentCard,
};
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{
CancellationToken,
engine::AsyncEngineContext,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
},
};
use futures::StreamExt;
use prometheus::{Registry, proto::MetricType};
use reqwest::StatusCode;
use std::{io::Cursor, sync::Arc};
use tokio::time::timeout;
use tokio_util::codec::FramedRead;
#[path = "common/ports.rs"]
mod ports;
use ports::get_random_port;
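/// Chat engine that sleeps for `max_tokens` milliseconds, then streams ten numbered choices.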
struct CounterEngine {}
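/// Chat engine that stalls for `delay_ms`, recording whether the request was
/// cancelled before the delay elapsed.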
struct LongRunningEngine {
delay_ms: u64,
cancelled: Arc<std::sync::atomic::AtomicBool>,
}
impl LongRunningEngine {
fn new(delay_ms: u64) -> Self {
Self {
delay_ms,
cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
fn was_cancelled(&self) -> bool {
self.cancelled.load(std::sync::atomic::Ordering::Acquire)
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for CounterEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
#[allow(deprecated)]
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
let mut generator = request.response_generator(ctx.id().to_string());
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 {
let output = generator.create_choice(i, Some(format!("choice {i}")), None, None, None);
yield Annotated::from_data(output);
}
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for LongRunningEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (_request, context) = request.transfer(());
let ctx = context.context();
tracing::info!(
"LongRunningEngine: Starting generation with {}ms delay",
self.delay_ms
);
let cancelled_flag = self.cancelled.clone();
let delay_ms = self.delay_ms;
let ctx_clone = ctx.clone();
let stream = async_stream::stream! {
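// Assume cancellation until the full delay elapses: if the client
// disconnects, the stream is dropped (or `stopped()` fires) before the
// sleep completes, leaving the flag set.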
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(delay_ms)) => {
cancelled_flag.store(false, std::sync::atomic::Ordering::SeqCst);
}
_ = ctx_clone.stopped() => {
cancelled_flag.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
yield Annotated::<NvCreateChatCompletionStreamResponse>::from_annotation("event.dynamo.test.sentinel", &"DONE".to_string()).expect("Failed to create annotated response");
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
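/// Engine whose chat endpoint fails with 403 and whose completions endpoint fails with 401.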
struct AlwaysFailEngine {}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for AlwaysFailEngine
{
async fn generate(
&self,
_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
Err(HttpError {
code: 403,
message: "Always fail".to_string(),
})?
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for AlwaysFailEngine
{
async fn generate(
&self,
_request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
Err(HttpError {
code: 401,
message: "Always fail".to_string(),
})?
}
}
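/// Asserts that a single request counter has the expected value, with a descriptive failure message.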
fn compare_counter(
metrics: &Metrics,
model: &str,
endpoint: &Endpoint,
request_type: &RequestType,
status: &Status,
error_type: &ErrorType,
expected: u64,
) {
assert_eq!(
metrics.get_request_counter(model, endpoint, request_type, status, error_type),
expected,
"model: {}, endpoint: {:?}, request_type: {:?}, status: {:?}, error_type: {:?}",
model,
endpoint.as_str(),
request_type.as_str(),
status.as_str(),
error_type.as_str()
);
}
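/// Maps (endpoint, request_type, status) onto an index into the 8-slot
/// expected-counter arrays: `endpoint * 4 + request_type * 2 + status`.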
fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Status) -> usize {
let endpoint = match endpoint {
Endpoint::Completions => 0,
Endpoint::ChatCompletions => 1,
Endpoint::Embeddings => todo!(),
Endpoint::Responses => todo!(),
Endpoint::AnthropicMessages => todo!(),
Endpoint::Tensor => todo!(),
Endpoint::Images => todo!(),
Endpoint::Videos => todo!(),
};
let request_type = match request_type {
RequestType::Unary => 0,
RequestType::Stream => 1,
};
let status = match status {
Status::Success => 0,
Status::Error => 1,
};
endpoint * 4 + request_type * 2 + status
}
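/// Checks all eight (endpoint, request type, status) counter combinations for a
/// model; errors are compared under `ErrorType::Validation`, successes under
/// `ErrorType::None`.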
fn compare_counters(metrics: &Metrics, model: &str, expected: &[u64; 8]) {
for endpoint in &[Endpoint::Completions, Endpoint::ChatCompletions] {
for request_type in &[RequestType::Unary, RequestType::Stream] {
for status in &[Status::Success, Status::Error] {
let index = compute_index(endpoint, request_type, status);
let error_type = match status {
Status::Success => &ErrorType::None,
Status::Error => &ErrorType::Validation,
};
compare_counter(
metrics,
model,
endpoint,
request_type,
status,
error_type,
expected[index],
);
}
}
}
}
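/// Bumps the expected counter for the given (endpoint, request type, status) combination.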
fn inc_counter(
endpoint: Endpoint,
request_type: RequestType,
status: Status,
expected: &mut [u64; 8],
) {
let index = compute_index(&endpoint, &request_type, &status);
expected[index] += 1;
}
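/// End-to-end test: registers a working model ("foo") and a failing model
/// ("bar"), then verifies per-model request counters, the request-duration
/// histogram, error status codes, and the /metrics endpoint.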
#[allow(deprecated)]
#[tokio::test]
async fn test_http_service() {
let port = get_random_port().await;
let service = HttpService::builder()
.port(port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let registry = Registry::new();
let card = ModelDeploymentCard::with_name_only("foo");
let counter = Arc::new(CounterEngine {});
let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
assert!(result.is_ok());
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("bar");
let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
assert!(result.is_ok());
let result = manager.add_completions_model("bar", card.mdcsum(), failure);
assert!(result.is_ok());
let metrics = state.metrics_clone();
metrics.register(&registry).unwrap();
let mut foo_counters = [0u64; 8];
let mut bar_counters = [0u64; 8];
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
let client = reqwest::Client::new();
let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"hi".to_string(),
),
name: None,
},
);
let mut request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![message])
.build()
.expect("Failed to build request");
request.stream = Some(true);
request.max_tokens = Some(3000);
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert!(response.status().is_success(), "{:?}", response);
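// Give the stream a moment to start so the inflight gauge registers the request.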
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
assert_eq!(metrics.get_inflight_count("foo"), 1);
let _ = response.bytes().await.unwrap();
inc_counter(
Endpoint::ChatCompletions,
RequestType::Stream,
Status::Success,
&mut foo_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
let families = registry.gather();
let histogram_metric_family = families
.into_iter()
.find(|m| {
m.get_name()
== format!(
"{}_{}",
name_prefix::FRONTEND,
frontend_service::REQUEST_DURATION_SECONDS
)
})
.expect("Histogram metric not found");
assert_eq!(
histogram_metric_family.get_field_type(),
MetricType::HISTOGRAM
);
let histogram_metric = histogram_metric_family.get_metric();
assert_eq!(histogram_metric.len(), 1);
let metric = &histogram_metric[0];
let histogram = metric.get_histogram();
let buckets = histogram.get_bucket();
let mut found = false;
let mut expected_count = 0;
for bucket_idx in 1..buckets.len() {
if buckets[bucket_idx].get_upper_bound() >= 2.5
&& buckets[bucket_idx - 1].get_upper_bound() < 2.5
{
found = true;
assert_eq!(
buckets[bucket_idx].get_cumulative_count(),
1,
"Observation should be counted in the bucket containing 2.5"
);
expected_count = 1;
} else {
assert_eq!(
buckets[bucket_idx].get_cumulative_count(),
expected_count,
"No observations should be in this bucket"
);
}
}
assert!(found, "The expected bucket was not found");
request.stream = Some(false);
request.max_tokens = Some(0);
let future = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send();
let response = future.await.unwrap();
assert!(response.status().is_success(), "{:?}", response);
inc_counter(
Endpoint::ChatCompletions,
RequestType::Unary,
Status::Success,
&mut foo_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
request.model = "bar".to_string();
request.max_tokens = Some(0);
request.stream = Some(true);
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
inc_counter(
Endpoint::ChatCompletions,
RequestType::Stream,
Status::Error,
&mut bar_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
request.stream = Some(false);
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
inc_counter(
Endpoint::ChatCompletions,
RequestType::Unary,
Status::Error,
&mut bar_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
let mut request = dynamo_async_openai::types::CreateCompletionRequestArgs::default()
.model("bar")
.prompt("hi")
.build()
.unwrap();
let response = client
.post(format!("http://localhost:{}/v1/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
inc_counter(
Endpoint::Completions,
RequestType::Unary,
Status::Error,
&mut bar_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
request.stream = Some(true);
let response = client
.post(format!("http://localhost:{}/v1/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
inc_counter(
Endpoint::Completions,
RequestType::Stream,
Status::Error,
&mut bar_counters,
);
compare_counters(&metrics, "foo", &foo_counters);
compare_counters(&metrics, "bar", &bar_counters);
request.stream = Some(false);
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{:?}", response);
let response = client
.get(format!("http://localhost:{}/metrics", port))
.send()
.await
.unwrap();
assert!(response.status().is_success(), "{:?}", response);
println!("{}", response.text().await.unwrap());
cancel_token.cancel();
task.await.unwrap().unwrap();
}
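/// Polls the /health endpoint until the service answers, panicking after 5 seconds.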
async fn wait_for_service_ready(port: u16) {
let start = tokio::time::Instant::now();
let timeout = tokio::time::Duration::from_secs(5);
loop {
match reqwest::get(&format!("http://localhost:{}/health", port)).await {
Ok(_) => break,
Err(_) if start.elapsed() < timeout => {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
Err(e) => panic!("Service failed to start within timeout: {}", e),
}
}
}
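/// Builds an HttpService on a random port with "foo" backed by `CounterEngine`
/// and "bar" backed by `AlwaysFailEngine` for both chat and completions.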
async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>, u16) {
let port = get_random_port().await;
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let manager = service.model_manager();
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager
.add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager
.add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
(service, counter, failure, port)
}
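// Client constructors pointing at the local test service's /v1 API base.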
fn pure_openai_client(port: u16) -> PureOpenAIClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
PureOpenAIClient::new(config)
}
fn nv_custom_client(port: u16) -> NvCustomClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
NvCustomClient::new(config)
}
fn generic_byot_client(port: u16) -> GenericBYOTClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
GenericBYOTClient::new(config)
}
#[tokio::test]
async fn test_pure_openai_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let pure_openai_client = pure_openai_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(result.is_ok(), "PureOpenAI client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break;
}
}
assert!(count > 0, "Should receive at least one response");
let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("bar")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
let ctx = HttpRequestContext::new();
let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_nv_custom_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let nv_custom_client = nv_custom_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let inner_request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(result.is_ok(), "NvCustom client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break;
}
}
assert!(count > 0, "Should receive at least one response");
let inner_request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("bar")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
let ctx = HttpRequestContext::new();
let inner_request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content:
dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_generic_byot_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let generic_byot_client = generic_byot_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(result.is_ok(), "GenericBYOT client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
println!("Response: {:?}", response);
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break;
}
}
assert!(count > 0, "Should receive at least one response");
let request = serde_json::json!({
"model": "bar",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
let ctx = HttpRequestContext::new();
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
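/// Verifies that dropping a unary request (via client-side timeout) propagates
/// cancellation to the engine well before its 10s delay elapses.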
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
let port = get_random_port().await;
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let card = ModelDeploymentCard::with_name_only("slow-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"This will take a long time".to_string(),
),
name: None,
},
);
let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("slow-model")
.messages(vec![message])
.stream(false)
.build()
.expect("Failed to build request");
let start_time = std::time::Instant::now();
let request_future = async {
client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
};
let result = timeout(std::time::Duration::from_millis(1000), request_future).await;
let elapsed = start_time.elapsed();
assert!(result.is_err(), "Request should have timed out");
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
assert!(
long_running_engine.was_cancelled(),
"Engine should have been cancelled due to client disconnect"
);
assert!(
elapsed < std::time::Duration::from_secs(2),
"Cancellation should have propagated quickly, took {:?}",
elapsed
);
tracing::info!(
"✅ Client disconnect test passed! Request cancelled in {:?}, engine detected cancellation",
elapsed
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
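/// Same as the unary disconnect test, but the client disconnects mid-stream
/// after reading at most one chunk.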
#[tokio::test]
async fn test_client_disconnect_cancellation_streaming() {
dynamo_runtime::logging::init();
let port = get_random_port().await;
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let card = ModelDeploymentCard::with_name_only("slow-stream-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model(
"slow-stream-model",
card.mdcsum(),
long_running_engine.clone(),
)
.unwrap();
let client = reqwest::Client::new();
let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"This will stream for a long time".to_string(),
),
name: None,
},
);
let request = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("slow-stream-model")
.messages(vec![message])
.stream(true)
.build()
.expect("Failed to build request");
let start_time = std::time::Instant::now();
let request_future = async {
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.json(&request)
.send()
.await
.unwrap();
let mut stream = response.bytes_stream();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let _ = StreamExt::next(&mut stream).await;
};
let _result = timeout(std::time::Duration::from_millis(1500), request_future).await;
let elapsed = start_time.elapsed();
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
assert!(
long_running_engine.was_cancelled(),
"Engine should have been cancelled due to streaming client disconnect"
);
assert!(
elapsed < std::time::Duration::from_secs(3),
"Stream cancellation should have propagated reasonably quickly, took {:?}",
elapsed
);
tracing::info!(
"✅ Streaming client disconnect test passed! Stream cancelled in {:?}, engine detected cancellation",
elapsed
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}
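/// Sends an `x-dynamo-request-id` header together with a `request_id`
/// annotation request and verifies the same UUID comes back as an SSE
/// annotation event.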
#[tokio::test]
async fn test_request_id_annotation() {
dynamo_runtime::logging::init();
let port = get_random_port().await;
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new();
let cancel_token = token.clone();
let task = tokio::spawn(async move { service.run(token).await });
wait_for_service_ready(port).await;
let card = ModelDeploymentCard::with_name_only("test-model");
let counter_engine = Arc::new(CounterEngine {});
manager
.add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
.unwrap();
let client = reqwest::Client::new();
let request_uuid = uuid::Uuid::new_v4();
let request_json = serde_json::json!({
"model": "test-model",
"messages": [
{
"role": "user",
"content": "Test request with annotation"
}
],
"stream": true,
"max_tokens": 50,
"nvext": {
"annotations": ["request_id"]
}
});
let response = client
.post(format!("http://localhost:{}/v1/chat/completions", port))
.header("x-dynamo-request-id", request_uuid.to_string())
.json(&request_json)
.send()
.await
.expect("Request should succeed");
assert!(
response.status().is_success(),
"Response should be successful"
);
let body_bytes = response
.bytes()
.await
.expect("Failed to read response body");
let body_text = String::from_utf8_lossy(&body_bytes);
let cursor = Cursor::new(body_text.to_string());
let framed = FramedRead::new(cursor, SseLineCodec::new());
let annotated_stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(framed);
let mut found_request_id_annotation = false;
let mut received_request_id = None;
let mut annotated_stream = std::pin::pin!(annotated_stream);
while let Some(annotated_response) = annotated_stream.next().await {
if let Some(event) = &annotated_response.event
&& event == "request_id"
{
found_request_id_annotation = true;
if let Some(comments) = &annotated_response.comment
&& let Some(comment) = comments.first()
{
if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
received_request_id = Some(parsed_value);
} else {
received_request_id = Some(comment.trim_matches('"').to_string());
}
}
break;
}
}
assert!(
found_request_id_annotation,
"Should have received request_id annotation in the stream"
);
assert!(
received_request_id.is_some(),
"Should have received the request ID in the annotation"
);
let received_uuid_str = received_request_id.unwrap();
assert_eq!(
received_uuid_str,
request_uuid.to_string(),
"Received request ID should match the one we sent: expected {}, got {}",
request_uuid,
received_uuid_str
);
tracing::info!(
"✅ Request ID annotation test passed! Sent UUID: {}, Received: {}",
request_uuid,
received_uuid_str
);
cancel_token.cancel();
task.await.unwrap().unwrap();
}