use std::{
collections::HashSet,
fmt::Display,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{
Json, Router,
body::Body,
extract::State,
http::Request,
http::{HeaderMap, StatusCode},
middleware::{self, Next},
response::{
IntoResponse, Response,
sse::{KeepAlive, Sse},
},
routing::{get, post},
};
use base64::Engine as _;
use bytes::Bytes;
use dynamo_runtime::config::environment_names::llm as env_llm;
use dynamo_runtime::{
pipeline::{AsyncEngineContextProvider, Context},
protocols::annotated::AnnotationsProvider,
};
use futures::{StreamExt, stream};
use serde::{Deserialize, Serialize};
use super::{
RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
error::HttpError,
metrics::{
Endpoint, ErrorType, EventConverter, process_response_and_observe_metrics,
process_response_using_event_converter_and_observe_metrics,
},
service_v2,
};
use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::nvext::apply_header_routing_overrides;
use crate::protocols::openai::{
chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
NvCreateChatCompletionStreamResponse,
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse},
responses::{NvCreateResponse, NvResponse, ResponseParams, chat_completion_to_response},
videos::{NvCreateVideoRequest, NvVideosResponse},
};
use crate::request_template::RequestTemplate;
use crate::types::Annotated;
use dynamo_runtime::logging::get_distributed_tracing_context;
use tracing::Instrument;
pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id";
pub const ANNOTATION_REQUEST_ID: &str = "request_id";
const VALIDATION_PREFIX: &str = "Validation: ";
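/// Maximum accepted request-body size in bytes. Read from the body-limit
/// environment variable (`env_llm::DYN_HTTP_BODY_LIMIT_MB`, interpreted as MiB);
/// defaults to 45 MiB when unset or unparsable.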
pub(super) fn get_body_limit() -> usize {
std::env::var(env_llm::DYN_HTTP_BODY_LIMIT_MB)
.ok()
.and_then(|s| s.parse::<usize>().ok())
.map(|mb| mb * 1024 * 1024)
.unwrap_or(45 * 1024 * 1024)
}
pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ErrorMessage {
message: String,
#[serde(rename = "type")]
error_type: String,
code: u16,
}
fn map_error_code_to_error_type(code: StatusCode) -> String {
match code.canonical_reason() {
Some(reason) => reason.to_string(),
None => "UnknownError".to_string(),
}
}
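/// Maps an HTTP status code (and, for 400s, the message prefix) onto the coarse
/// `ErrorType` buckets used for metrics.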
fn classify_error_for_metrics(code: StatusCode, message: &str) -> ErrorType {
match code {
StatusCode::BAD_REQUEST => {
if message.starts_with("Validation:") {
ErrorType::Validation
} else {
ErrorType::Internal
}
}
StatusCode::NOT_FOUND => ErrorType::NotFound,
StatusCode::NOT_IMPLEMENTED => ErrorType::NotImplemented,
StatusCode::TOO_MANY_REQUESTS => ErrorType::Overload,
StatusCode::SERVICE_UNAVAILABLE => ErrorType::Overload,
StatusCode::INTERNAL_SERVER_ERROR => ErrorType::Internal,
_ if code.is_client_error() => ErrorType::Validation,
_ => ErrorType::Internal,
}
}
fn extract_error_type_from_response(response: &ErrorResponse) -> ErrorType {
classify_error_for_metrics(response.0, &response.1.message)
}
impl ErrorMessage {
pub fn model_not_found() -> ErrorResponse {
let code = StatusCode::NOT_FOUND;
let error_type = map_error_code_to_error_type(code);
(
code,
Json(ErrorMessage {
message: "Model not found".to_string(),
error_type,
code: code.as_u16(),
}),
)
}
pub fn _service_unavailable() -> ErrorResponse {
let code = StatusCode::SERVICE_UNAVAILABLE;
let error_type = map_error_code_to_error_type(code);
(
code,
Json(ErrorMessage {
message: "Service is not ready".to_string(),
error_type,
code: code.as_u16(),
}),
)
}
pub fn internal_server_error(msg: &str) -> ErrorResponse {
tracing::error!("Internal server error: {msg}");
let code = StatusCode::INTERNAL_SERVER_ERROR;
let error_type = map_error_code_to_error_type(code);
(
code,
Json(ErrorMessage {
message: msg.to_string(),
error_type,
code: code.as_u16(),
}),
)
}
pub fn not_implemented_error<T: Display>(msg: T) -> ErrorResponse {
tracing::error!("Not Implemented error: {msg}");
let code = StatusCode::NOT_IMPLEMENTED;
let error_type = map_error_code_to_error_type(code);
(
code,
Json(ErrorMessage {
message: msg.to_string(),
error_type,
code: code.as_u16(),
}),
)
}
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse {
if let Some(pipeline_err) =
err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
&& matches!(
pipeline_err,
dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
)
{
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
message: pipeline_err.to_string(),
error_type: map_error_code_to_error_type(StatusCode::SERVICE_UNAVAILABLE),
code: StatusCode::SERVICE_UNAVAILABLE.as_u16(),
}),
);
}
if let Some(dynamo_err) = err.downcast_ref::<dynamo_runtime::error::DynamoError>()
&& dynamo_err.error_type() == dynamo_runtime::error::ErrorType::InvalidArgument
{
return (
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
message: dynamo_err.message().to_string(),
error_type: map_error_code_to_error_type(StatusCode::BAD_REQUEST),
code: StatusCode::BAD_REQUEST.as_u16(),
}),
);
}
match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err:#}")),
}
}
pub fn from_http_error(err: HttpError) -> ErrorResponse {
if err.code < 400 || err.code >= 500 {
return ErrorMessage::internal_server_error(&err.message);
}
match StatusCode::from_u16(err.code) {
Ok(code) => (
code,
Json(ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(code),
code: code.as_u16(),
}),
),
Err(_) => ErrorMessage::internal_server_error(&err.message),
}
}
}
impl From<HttpError> for ErrorMessage {
fn from(err: HttpError) -> Self {
ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(
StatusCode::from_u16(err.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
),
code: err.code,
}
}
}
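/// Middleware that rewrites axum's 422 (Unprocessable Entity) JSON-deserialization
/// failures into OpenAI-style 400 responses with an `ErrorMessage` body.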
pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, get_body_limit())
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
message: error_message,
error_type: map_error_code_to_error_type(StatusCode::BAD_REQUEST),
code: StatusCode::BAD_REQUEST.as_u16(),
}),
)
.into_response()
} else {
response
}
}
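/// Resolves the request ID, in order of precedence: the distributed tracing
/// context, a UUID supplied in the `user` field, the `x-dynamo-request-id`
/// header, and finally a freshly generated UUID v4.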
pub(super) fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
if let Some(trace_context) = get_distributed_tracing_context()
&& let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id
{
return x_dynamo_request_id;
}
if let Some(primary) = primary
&& let Ok(uuid) = uuid::Uuid::parse_str(primary)
{
return uuid.to_string();
}
let request_id_opt = headers
.get(DYNAMO_REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok());
let uuid = match request_id_opt {
Some(request_id) => {
uuid::Uuid::parse_str(request_id).unwrap_or_else(|_| uuid::Uuid::new_v4())
}
None => uuid::Uuid::new_v4(),
};
uuid.to_string()
}
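/// Axum handler for `/v1/completions`. Runs the request on a spawned task so
/// that client disconnects can be detected via the connection monitor and
/// propagated to the engine context.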
async fn handler_completions(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(mut request): Json<NvCreateCompletionRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let context = request.context();
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let response = tokio::spawn(completions(state, request, stream_handle).in_current_span())
.await
.map_err(|e| {
ErrorMessage::internal_server_error(&format!(
"Failed to await chat completions task: {:?}",
e,
))
})?;
connection_handle.disarm();
response
}
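/// Validates a completions request and dispatches it to the single-prompt or
/// batched path depending on the prompt batch size.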
#[tracing::instrument(skip_all)]
async fn completions(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::get_prompt_batch_size;
check_ready(&state)?;
validate_completion_stream_options(&request)?;
validate_completion_fields_generic(&request)?;
let batch_size = get_prompt_batch_size(&request.inner.prompt);
let n = request.inner.n.unwrap_or(1);
if batch_size == 1 {
return completions_single(state, request, stream_handle).await;
}
completions_batch(state, request, stream_handle, batch_size, n).await
}
#[tracing::instrument(skip_all)]
async fn completions_single(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
let request_id = request.id().to_string();
let streaming = request.inner.stream.unwrap_or(false);
let model = request.inner.model.clone();
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Completions, streaming);
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let annotations = request.annotations();
let stream = engine.generate(request).await.map_err(|e| {
let err_response = ErrorMessage::from_anyhow(e, "Failed to generate completions");
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let ctx = stream.context();
let annotations = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
let stream = stream::iter(annotations).chain(stream);
if streaming {
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream
.map(move |response| {
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
} else {
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
"Failed to fold completions stream for {}: {:?}",
request_id,
e
);
let err_response = ErrorMessage::internal_server_error(&format!(
"Failed to fold completions stream for {}: {:?}",
request_id, e
));
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
inflight_guard.mark_ok();
if ctx.is_killed() {
inflight_guard.mark_error(ErrorType::Cancelled);
}
Ok(Json(response).into_response())
}
}
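/// Fans a multi-prompt completion request out into one engine call per prompt
/// and merges the resulting streams, remapping each choice index by
/// `prompt_idx * n` so the combined output reads like a single batched response.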
#[tracing::instrument(skip_all)]
async fn completions_batch(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
batch_size: usize,
n: u8,
) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::extract_single_prompt;
use futures::stream::{self, StreamExt};
let request_id = request.id().to_string();
let streaming = request.inner.stream.unwrap_or(false);
let model = request.inner.model.clone();
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Completions, streaming);
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let annotations = request.annotations();
let mut all_streams = Vec::new();
let mut first_ctx = None;
for prompt_idx in 0..batch_size {
let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);
let mut single_request = request.content().clone();
single_request.inner.prompt = single_prompt;
let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
let single_request_context = Context::with_id(single_request, unique_request_id);
let stream = engine.generate(single_request_context).await.map_err(|e| {
let err_response = ErrorMessage::from_anyhow(e, "Failed to generate completions");
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
if first_ctx.is_none() {
first_ctx = Some(stream.context());
}
let prompt_idx_u32 = prompt_idx as u32;
let n_u32 = n as u32;
let remapped_stream = stream.map(move |mut response| {
if let Some(ref mut data) = response.data {
for choice in &mut data.inner.choices {
choice.index += prompt_idx_u32 * n_u32;
}
}
response
});
all_streams.push(remapped_stream);
}
let merged_stream = stream::select_all(all_streams);
let ctx = first_ctx.expect("At least one stream should be generated");
let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
let merged_stream = stream::iter(annotations_vec).chain(merged_stream);
if streaming {
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream
.map(move |response| {
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
} else {
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
"Failed to fold completions stream for {}: {:?}",
request_id,
e
);
let err_response = ErrorMessage::internal_server_error(&format!(
"Failed to fold completions stream for {}: {:?}",
request_id, e
));
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
inflight_guard.mark_ok();
if ctx.is_killed() {
inflight_guard.mark_error(ErrorType::Cancelled);
}
Ok(Json(response).into_response())
}
}
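/// Axum handler for `/v1/embeddings`. Embeddings are always non-streaming: the
/// annotated stream is folded into a single `NvCreateEmbeddingResponse`.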
#[tracing::instrument(skip_all)]
async fn embeddings(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
let streaming = false;
let model = &request.inner.model;
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Embeddings, streaming);
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
let engine = state.manager().get_embeddings_engine(model).map_err(|_| {
let err_response = ErrorMessage::model_not_found();
inflight.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut response_collector = state.metrics_clone().create_response_collector(model);
let stream = engine.generate(request).await.map_err(|e| {
let err_response = ErrorMessage::from_anyhow(e, "Failed to generate embeddings");
inflight.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvCreateEmbeddingResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!(
"Failed to fold embeddings stream for {}: {:?}",
request_id,
e
);
let err_response =
ErrorMessage::internal_server_error("Failed to fold embeddings stream");
inflight.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
}
async fn handler_chat_completions(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap,
Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let context = request.context();
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let response =
tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span())
.await
.map_err(|e| {
ErrorMessage::internal_server_error(&format!(
"Failed to await chat completions task: {:?}",
e,
))
})?;
connection_handle.disarm();
response
}
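/// Inspects an annotated event for backend error signals (an `error` event, an
/// error-shaped data payload, or error comments) and, if found, returns the
/// message and HTTP status code to surface to the client.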
fn extract_backend_error_if_present<T: serde::Serialize>(
event: &Annotated<T>,
) -> Option<(String, StatusCode)> {
#[derive(serde::Deserialize)]
struct ErrorPayload {
message: Option<String>,
code: Option<u16>,
}
if let Some(event_type) = &event.event
&& event_type == "error"
{
let error_str = if let Some(ref dynamo_err) = event.error {
let mut parts = Vec::new();
let mut current: Option<&dyn std::error::Error> = Some(dynamo_err);
while let Some(e) = current {
if let Some(de) = e.downcast_ref::<dynamo_runtime::error::DynamoError>() {
parts.push(de.message().to_string());
} else {
parts.push(e.to_string());
}
current = e.source();
}
parts.join(", ")
} else {
event
.comment
.as_ref()
.map(|c| c.join(", "))
.unwrap_or_else(|| "Unknown error".to_string())
};
if let Ok(error_payload) = serde_json::from_str::<ErrorPayload>(&error_str) {
let code = error_payload
.code
.and_then(|c| StatusCode::from_u16(c).ok())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let message = error_payload.message.unwrap_or(error_str);
return Some((message, code));
}
return Some((error_str, StatusCode::INTERNAL_SERVER_ERROR));
}
if let Some(data) = &event.data
&& let Ok(json_value) = serde_json::to_value(data)
&& let Ok(error_payload) = serde_json::from_value::<ErrorPayload>(json_value.clone())
&& let Some(code_num) = error_payload.code
&& code_num >= 400
{
let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let message = error_payload
.message
.unwrap_or_else(|| json_value.to_string());
return Some((message, code));
}
if let Some(comments) = &event.comment
&& !comments.is_empty()
{
let comment_str = comments.join(", ");
if let Ok(error_payload) = serde_json::from_str::<ErrorPayload>(&comment_str)
&& let Some(code_num) = error_payload.code
&& code_num >= 400
{
let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let message = error_payload.message.unwrap_or(comment_str);
return Some((message, code));
}
if event.data.is_none() && event.event.is_none() {
return Some((comment_str, StatusCode::INTERNAL_SERVER_ERROR));
}
}
None
}
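/// Peeks at the first event of the stream. If it carries a backend error, the
/// request fails immediately with that error; otherwise the event is pushed
/// back and the stream is returned unchanged.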
pub(super) async fn check_for_backend_error(
mut stream: impl futures::Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>
+ Send
+ Unpin
+ 'static,
) -> Result<
impl futures::Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send,
ErrorResponse,
> {
use futures::stream::StreamExt;
if let Some(first_event) = stream.next().await {
if let Some((error_msg, status_code)) = extract_backend_error_if_present(&first_event) {
return Err((
status_code,
Json(ErrorMessage {
message: error_msg,
error_type: map_error_code_to_error_type(status_code),
code: status_code.as_u16(),
}),
));
}
let reconstructed_stream = futures::stream::iter(vec![first_event]).chain(stream);
Ok(reconstructed_stream)
} else {
Ok(futures::stream::iter(vec![]).chain(stream))
}
}
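/// Core chat-completions flow: applies template defaults, runs the validators,
/// resolves the engine for the model, then either streams SSE events or folds
/// the annotated stream into a single JSON response.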
async fn chat_completions(
state: Arc<service_v2::State>,
template: Option<RequestTemplate>,
mut request: Context<NvCreateChatCompletionRequest>,
mut stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let request_id = request.id().to_string();
let streaming = request.inner.stream.unwrap_or(false);
if let Some(template) = template {
if request.inner.model.is_empty() {
request.inner.model = template.model.clone();
}
if request.inner.temperature.unwrap_or(0.0) == 0.0 {
request.inner.temperature = Some(template.temperature);
}
if request.inner.max_completion_tokens.unwrap_or(0) == 0 {
request.inner.max_completion_tokens = Some(template.max_completion_tokens);
}
}
let model = request.inner.model.clone();
tracing::trace!("Received chat completions request: {:?}", request.content());
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::ChatCompletions, streaming);
if let Err(err_response) = validate_chat_completion_unsupported_fields(&request) {
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
return Err(err_response);
}
if let Err(err_response) = validate_chat_completion_required_fields(&request) {
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
return Err(err_response);
}
if let Err(err_response) = validate_chat_completion_stream_options(&request) {
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
return Err(err_response);
}
if let Err(err_response) = validate_chat_completion_fields_generic(&request) {
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
return Err(err_response);
}
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
tracing::trace!("Getting chat completions engine for model: {}", model);
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let annotations = request.annotations();
let stream = engine.generate(request).await.map_err(|e| {
let err_response = ErrorMessage::from_anyhow(e, "Failed to generate completions");
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let ctx = stream.context();
let annotations = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::from_annotation(ANNOTATION_REQUEST_ID, &request_id).ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
let stream = stream::iter(annotations).chain(stream);
if streaming {
stream_handle.arm();
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream
.map(move |response| {
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
})
.filter_map(|result| {
use futures::future;
future::ready(result.transpose())
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
} else {
let stream_with_check =
check_for_backend_error(stream)
.await
.map_err(|error_response| {
tracing::error!(request_id, "Backend error detected: {:?}", error_response);
inflight_guard.mark_error(extract_error_type_from_response(&error_response));
error_response
})?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream_with_check.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to parse chat completion response: {:?}",
e
);
let err_response = ErrorMessage::internal_server_error(&format!(
"Failed to parse chat completion response: {}",
e
));
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
inflight_guard.mark_ok();
if ctx.is_killed() {
inflight_guard.mark_error(ErrorType::Cancelled);
}
Ok(Json(response).into_response())
}
}
#[allow(deprecated)]
pub fn validate_chat_completion_unsupported_fields(
request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
let inner = &request.inner;
if inner.function_call.is_some() {
return Err(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string()
+ "`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
));
}
if inner.functions.is_some() {
return Err(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string()
+ "`functions` is deprecated. Please migrate to use `tools` instead.",
));
}
Ok(())
}
pub fn validate_chat_completion_required_fields(
request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
let inner = &request.inner;
if inner.messages.is_empty() {
return Err(ErrorMessage::from_http_error(HttpError {
code: 400,
message: VALIDATION_PREFIX.to_string()
+ "The 'messages' field cannot be empty. At least one message is required.",
}));
}
Ok(())
}
pub fn validate_chat_completion_stream_options(
request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
let inner = &request.inner;
let streaming = inner.stream.unwrap_or(false);
if !streaming && inner.stream_options.is_some() {
return Err(ErrorMessage::from_http_error(HttpError {
code: 400,
message: VALIDATION_PREFIX.to_string()
+ "The 'stream_options' field is only allowed when 'stream' is set to true.",
}));
}
Ok(())
}
pub fn validate_chat_completion_fields_generic(
request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
request.validate().map_err(|e| {
ErrorMessage::from_http_error(HttpError {
code: 400,
message: VALIDATION_PREFIX.to_string() + &e.to_string(),
})
})
}
pub fn validate_completion_stream_options(
request: &NvCreateCompletionRequest,
) -> Result<(), ErrorResponse> {
let inner = &request.inner;
let streaming = inner.stream.unwrap_or(false);
if !streaming && inner.stream_options.is_some() {
return Err(ErrorMessage::from_http_error(HttpError {
code: 400,
message: VALIDATION_PREFIX.to_string()
+ "The 'stream_options' field is only allowed when 'stream' is set to true.",
}));
}
Ok(())
}
pub fn validate_completion_fields_generic(
request: &NvCreateCompletionRequest,
) -> Result<(), ErrorResponse> {
request.validate().map_err(|e| {
ErrorMessage::from_http_error(HttpError {
code: 400,
message: VALIDATION_PREFIX.to_string() + &e.to_string(),
})
})
}
async fn handler_responses(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap,
Json(mut request): Json<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
let request_id = get_or_create_request_id(None, &headers);
let request = Context::with_id(request, request_id);
let context = request.context();
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
let response =
tokio::spawn(responses(state, template, request, stream_handle).in_current_span())
.await
.map_err(|e| {
ErrorMessage::internal_server_error(&format!(
"Failed to await responses task: {:?}",
e,
))
})?;
connection_handle.disarm();
response
}
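/// Core `/v1/responses` flow: converts the Responses request into a streaming
/// chat-completions request, then either re-emits the chunks as Responses SSE
/// events or folds them back into a single `NvResponse`.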
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn responses(
state: Arc<service_v2::State>,
template: Option<RequestTemplate>,
mut request: Context<NvCreateResponse>,
mut stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;
if let Some(template) = template {
if request.inner.model.as_deref().unwrap_or("").is_empty() {
request.inner.model = Some(template.model.clone());
}
if request.inner.temperature.is_none() {
request.inner.temperature = Some(template.temperature);
}
if request.inner.max_output_tokens.is_none() {
request.inner.max_output_tokens = Some(template.max_completion_tokens);
}
} else if request.inner.max_output_tokens.is_none() {
request.inner.max_output_tokens = Some(DEFAULT_MAX_OUTPUT_TOKENS);
}
tracing::trace!("Received responses request: {:?}", request.inner);
let model = request.inner.model.clone().unwrap_or_default();
let streaming = request.inner.stream.unwrap_or(false);
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Responses, streaming);
if let Some(resp) = validate_response_unsupported_fields(&request) {
inflight_guard.mark_error(ErrorType::NotImplemented);
return Ok(resp.into_response());
}
let response_params = ResponseParams {
model: request.inner.model.clone(),
temperature: request.inner.temperature,
top_p: request.inner.top_p,
max_output_tokens: request.inner.max_output_tokens,
store: request.inner.store,
tools: request.inner.tools.clone(),
tool_choice: request.inner.tool_choice.clone(),
instructions: request.inner.instructions.clone(),
reasoning: request.inner.reasoning.clone(),
text: request.inner.text.clone(),
service_tier: request.inner.service_tier,
include: request.inner.include.clone(),
truncation: request.inner.truncation,
};
let request_id = request.id().to_string();
let (orig_request, context) = request.into_parts();
let mut chat_request: NvCreateChatCompletionRequest =
orig_request.try_into().map_err(|e: anyhow::Error| {
tracing::error!(
request_id,
error = %e,
"Failed to convert NvCreateResponse to NvCreateChatCompletionRequest",
);
let err_response = ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string()
+ "Failed to convert responses request: "
+ &e.to_string(),
);
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
chat_request.inner.stream = Some(true);
chat_request.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
continuous_usage_stats: false,
});
let request = context.map(|_| chat_request);
tracing::trace!("Getting chat completions engine for model: {}", model);
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for responses");
let engine_stream = engine.generate(request).await.map_err(|e| {
let err_response = ErrorMessage::from_anyhow(e, "Failed to generate completions");
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let ctx = engine_stream.context();
if streaming {
stream_handle.arm();
use crate::protocols::openai::responses::stream_converter::ResponseStreamConverter;
use std::sync::atomic::{AtomicBool, Ordering};
let mut converter = ResponseStreamConverter::new(model.clone(), response_params);
let start_events = converter.emit_start_events();
let converter = std::sync::Arc::new(std::sync::Mutex::new(converter));
let converter_end = converter.clone();
let saw_error = std::sync::Arc::new(AtomicBool::new(false));
let saw_error_end = saw_error.clone();
let mut http_queue_guard = Some(http_queue_guard);
let event_stream = engine_stream
.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
})
.filter_map(move |annotated_chunk| {
let converter = converter.clone();
let saw_error = saw_error.clone();
async move {
if annotated_chunk.data.is_none() {
if annotated_chunk.event.as_deref() == Some("error") {
saw_error.store(true, Ordering::Release);
}
return None;
}
let stream_resp = annotated_chunk.data?;
let mut conv = converter.lock().expect("converter lock poisoned");
let events = conv.process_chunk(&stream_resp);
Some(stream::iter(events))
}
})
.flatten();
let start_stream = stream::iter(start_events);
let done_stream = stream::once(async move {
let mut conv = converter_end.lock().expect("converter lock poisoned");
let end_events = if saw_error_end.load(Ordering::Acquire) {
conv.emit_error_events()
} else {
conv.emit_end_events()
};
stream::iter(end_events)
})
.flatten();
let full_stream = start_stream.chain(event_stream).chain(done_stream);
let full_stream = full_stream.map(|result| result.map_err(axum::Error::new));
let stream = monitor_for_disconnects(full_stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
} else {
let stream_with_check =
check_for_backend_error(engine_stream)
.await
.map_err(|error_response| {
tracing::error!(request_id, "Backend error detected: {:?}", error_response);
inflight_guard.mark_error(extract_error_type_from_response(&error_response));
error_response
})?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream_with_check.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(request_id, "Failed to fold responses stream: {:?}", e);
let err_response = ErrorMessage::internal_server_error(&format!(
"Failed to fold responses stream: {}",
e
));
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
let response: NvResponse = chat_completion_to_response(response, &response_params)
.map_err(|e| {
tracing::error!(
request_id,
"Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}",
e
);
let err_response =
ErrorMessage::internal_server_error("Failed to convert internal response");
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
inflight_guard.mark_ok();
if ctx.is_killed() {
inflight_guard.mark_error(ErrorType::Cancelled);
}
Ok(Json(response).into_response())
}
}
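/// Rejects Responses API fields that are not yet implemented
/// (`background: true`, `previous_response_id`, `prompt`, `store: true`).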
pub fn validate_response_unsupported_fields(
request: &NvCreateResponse,
) -> Option<impl IntoResponse> {
let inner = &request.inner;
if inner.background == Some(true) {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`background: true` is not supported.",
));
}
if inner.previous_response_id.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`previous_response_id` is not supported.",
));
}
if inner.prompt.is_some() {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`prompt` is not supported.",
));
}
if inner.store == Some(true) {
return Some(ErrorMessage::not_implemented_error(
VALIDATION_PREFIX.to_string() + "`store: true` is not supported.",
));
}
None
}
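/// Readiness gate shared by all handlers; currently a no-op that always
/// reports ready.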
fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), ErrorResponse> {
Ok(())
}
async fn list_models_openai(
State(state): State<Arc<service_v2::State>>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let mut data = Vec::new();
let models: HashSet<String> = state.manager().model_display_names();
for model_name in models {
data.push(ModelListing {
id: model_name.clone(),
object: "model", created,
owned_by: "nvidia".to_string(),
});
}
let out = ListModelOpenAI {
object: "list",
data,
};
Ok(Json(out).into_response())
}
#[derive(Serialize)]
struct ListModelOpenAI {
object: &'static str,
data: Vec<ModelListing>,
}
#[derive(Serialize)]
struct ModelListing {
id: String,
object: &'static str,
created: u64,
owned_by: String,
}
pub fn completions_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
pub fn chat_completions_router(
state: Arc<service_v2::State>,
template: Option<RequestTemplate>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/chat/completions".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_chat_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state((state, template));
(vec![doc], router)
}
pub fn embeddings_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/embeddings".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(embeddings))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
pub fn list_models_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let openai_path = path.unwrap_or("/v1/models".to_string());
let doc_for_openai = RouteDoc::new(axum::http::Method::GET, &openai_path);
let router = Router::new()
.route(&openai_path, get(list_models_openai))
.with_state(state);
(vec![doc_for_openai], router)
}
pub fn responses_router(
state: Arc<service_v2::State>,
template: Option<RequestTemplate>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/responses".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_responses))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state((state, template));
(vec![doc], router)
}
async fn images(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateImageRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
let streaming = false;
let model = request
.inner
.model
.as_ref()
.map(|m| match m {
dynamo_async_openai::types::ImageModel::DallE2 => "dall-e-2".to_string(),
dynamo_async_openai::types::ImageModel::DallE3 => "dall-e-3".to_string(),
dynamo_async_openai::types::ImageModel::Other(s) => s.clone(),
})
.unwrap_or_else(|| "diffusion".to_string());
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
.manager()
.get_images_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Images, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate images"))?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvImagesResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold images stream for {}: {:?}", request_id, e);
ErrorMessage::internal_server_error("Failed to fold images stream")
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
}
pub fn images_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/images/generations".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(images))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
async fn videos(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateVideoRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let request_id = get_or_create_request_id(request.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
let streaming = false;
let model = request.model.clone();
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
.manager()
.get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Videos, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate videos"))?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvVideosResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold videos stream for {}: {:?}", request_id, e);
ErrorMessage::internal_server_error("Failed to fold videos stream")
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
}
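/// Axum handler for the video streaming route: decodes each frame's
/// `b64_json` payload into JPEG bytes and emits them as an MJPEG
/// `multipart/x-mixed-replace; boundary=frame` response, one part per frame.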
async fn video_stream(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateVideoRequest>,
) -> Result<Response, ErrorResponse> {
check_ready(&state)?;
let request_id = get_or_create_request_id(request.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let model = request.model.clone();
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
.manager()
.get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight = state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Videos, true);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to start video stream"))?;
let ctx = stream.context();
let (mut connection_handle, mut stream_handle) =
create_connection_monitor(ctx.clone(), Some(state.metrics_clone())).await;
connection_handle.disarm();
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let mjpeg_stream = stream.filter_map(|annotated| async move {
let ann = match annotated.ok() {
Ok(a) => a,
Err(e) => {
tracing::error!("Video stream error: {e}");
return None;
}
};
let response = ann.data?;
let frame = response.data.into_iter().next()?;
let b64 = frame.b64_json?;
let jpeg_bytes = match base64::prelude::BASE64_STANDARD.decode(&b64) {
Ok(b) => b,
Err(e) => {
tracing::warn!("Failed to decode frame base64: {e}");
return None;
}
};
let header = format!(
"--frame\r\nContent-Type: image/jpeg\r\nContent-Length: {}\r\n\r\n",
jpeg_bytes.len()
);
let mut chunk = Vec::with_capacity(header.len() + jpeg_bytes.len() + 2);
chunk.extend_from_slice(header.as_bytes());
chunk.extend_from_slice(&jpeg_bytes);
chunk.extend_from_slice(b"\r\n");
Some(Ok::<Bytes, std::convert::Infallible>(Bytes::from(chunk)))
});
stream_handle.arm();
let monitored_stream = async_stream::stream! {
tokio::pin!(mjpeg_stream);
loop {
tokio::select! {
frame = mjpeg_stream.next() => {
match frame {
Some(item) => yield item,
None => {
inflight.mark_ok();
stream_handle.disarm();
break;
}
}
}
_ = ctx.stopped() => {
tracing::trace!("Context stopped; breaking MJPEG stream");
break;
}
}
}
};
axum::http::Response::builder()
.status(axum::http::StatusCode::OK)
.header(
axum::http::header::CONTENT_TYPE,
"multipart/x-mixed-replace; boundary=frame",
)
.body(Body::from_stream(monitored_stream))
.map(|r| r.into_response())
.map_err(|e| {
ErrorMessage::internal_server_error(&format!("Failed to build MJPEG response: {e}"))
})
}
pub fn videos_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/videos".to_string());
let stream_path = format!("{}/stream", path);
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let stream_doc = RouteDoc::new(axum::http::Method::POST, &stream_path);
let router = Router::new()
.route(&path, post(videos))
.route(&stream_path, post(video_stream))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc, stream_doc], router)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::ModelManagerError;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use crate::protocols::openai::common_ext::CommonExt;
use crate::protocols::openai::completions::NvCreateCompletionRequest;
use crate::protocols::openai::responses::NvCreateResponse;
use dynamo_async_openai::types::responses::{CreateResponse, Input, PromptConfig};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
CreateCompletionRequest,
};
const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions";
fn http_error_from_engine(code: u16) -> Result<(), anyhow::Error> {
Err(HttpError {
code,
message: "custom error message".to_string(),
})?
}
fn other_error_from_engine() -> Result<(), anyhow::Error> {
Err(ModelManagerError::ModelNotFound("foo".to_string()))?
}
fn make_base_request() -> NvCreateResponse {
NvCreateResponse {
inner: CreateResponse {
input: Input::Text("hello".into()),
model: Some("test-model".into()),
..Default::default()
},
nvext: None,
}
}
#[test]
fn test_http_error_response_from_anyhow() {
let err = http_error_from_engine(400).unwrap_err();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::BAD_REQUEST);
assert_eq!(response.1.message, "custom error message");
}
#[test]
fn test_error_response_from_anyhow_out_of_range() {
let err = http_error_from_engine(399).unwrap_err();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(500).unwrap_err();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(501).unwrap_err();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
}
#[test]
fn test_other_error_response_from_anyhow() {
let err = other_error_from_engine().unwrap_err();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
response.1.message,
format!(
"{}: {}",
BACKUP_ERROR_MESSAGE,
other_error_from_engine().unwrap_err()
)
);
}
#[test]
fn test_service_overloaded_error_response_from_anyhow() {
use dynamo_runtime::pipeline::error::PipelineError;
let err: anyhow::Error = PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into();
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(
response.1.message,
"Service temporarily unavailable: All workers are busy, please retry later"
);
}
#[test]
fn test_validate_unsupported_fields_accepts_clean_request() {
let request = make_base_request();
let result = validate_response_unsupported_fields(&request);
assert!(result.is_none());
}
#[test]
fn test_validate_unsupported_fields_accepts_parallel_tool_calls() {
let mut request = make_base_request();
request.inner.parallel_tool_calls = Some(true);
let result = validate_response_unsupported_fields(&request);
assert!(result.is_none(), "parallel_tool_calls should be supported");
}
#[test]
fn test_validate_unsupported_fields_detects_flags() {
#[allow(clippy::type_complexity)]
let unsupported_cases: Vec<(&str, Box<dyn FnOnce(&mut CreateResponse)>)> = vec![
("background", Box::new(|r| r.background = Some(true))),
(
"previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())),
),
(
"prompt",
Box::new(|r| {
r.prompt = Some(PromptConfig {
id: "template-id".into(),
version: None,
variables: None,
})
}),
),
("store", Box::new(|r| r.store = Some(true))),
];
for (field, set_field) in unsupported_cases {
let mut req = make_base_request();
(set_field)(&mut req.inner);
let result = validate_response_unsupported_fields(&req);
assert!(result.is_some(), "Expected rejection for `{field}`");
}
}
#[test]
fn test_validate_chat_completion_required_fields_empty_messages() {
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![],
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!(
"{VALIDATION_PREFIX}The 'messages' field cannot be empty. At least one message is required."
)
);
}
}
#[test]
fn test_validate_chat_completion_required_fields_with_messages() {
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_ok());
}
#[test]
fn test_bad_base_request_for_completion() {
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
frequency_penalty: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Frequency penalty must be between -2 and 2, got -3")
);
}
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
presence_penalty: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Presence penalty must be between -2 and 2, got -3")
);
}
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
temperature: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Temperature must be between 0 and 2, got -3")
);
}
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
top_p: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Top_p must be between 0 and 1, got -3")
);
}
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
..Default::default()
},
common: CommonExt::builder()
.repetition_penalty(-3.0)
.build()
.unwrap(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Repetition penalty must be between 0 and 2, got -3")
);
}
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
logprobs: Some(6),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Logprobs must be between 0 and 5, got 6")
);
}
}
#[test]
fn test_metadata_field_nested() {
use serde_json::json;
let request = NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: "test-model".to_string(),
prompt: "Hello".into(),
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: json!({
"user": {"id": 1, "name": "user-1"},
"session": {"id": "session-1", "timestamp": 1640995200}
})
.into(),
unsupported_fields: Default::default(),
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_ok());
assert!(request.metadata.is_some());
assert_eq!(request.metadata.as_ref().unwrap()["user"]["id"], 1);
}
#[test]
fn test_bad_base_request_for_chatcompletion() {
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
frequency_penalty: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Frequency penalty must be between -2 and 2, got -3")
);
}
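// presence_penalty outside the allowed [-2, 2] range should be rejected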
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
presence_penalty: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Presence penalty must be between -2 and 2, got -3")
);
}
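// temperature outside the allowed [0, 2] range should be rejected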
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
temperature: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Temperature must be between 0 and 2, got -3")
);
}
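// top_p outside the allowed [0, 1] range should be rejected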
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
top_p: Some(-3.0),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Top_p must be between 0 and 1, got -3")
);
}
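// repetition_penalty (from CommonExt) outside the allowed [0, 2] range should be rejected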
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
..Default::default()
},
common: CommonExt::builder()
.repetition_penalty(-3.0)
.build()
.unwrap(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Repetition penalty must be between 0 and 2, got -3")
);
}
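// top_logprobs above the allowed maximum of 20 should be rejected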
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
name: None,
},
)],
top_logprobs: Some(25),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.1.message,
format!("{VALIDATION_PREFIX}Top_logprobs must be between 0 and 20, got 25")
);
}
}
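// Unknown request fields are captured in `unsupported_fields` during deserialization
// and cause validation to fail with a message naming each offending parameter.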
#[test]
fn test_chat_completions_unknown_fields_rejected() {
let json = r#"{
"messages": [{"role": "user", "content": "Hello"}],
"model": "test-model",
"add_special_tokens": true,
"documents": ["doc1"],
"chat_template": "custom"
}"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json).unwrap();
assert!(
request
.unsupported_fields
.contains_key("add_special_tokens")
);
assert!(request.unsupported_fields.contains_key("documents"));
assert!(request.unsupported_fields.contains_key("chat_template"));
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
let msg = &error_response.1.message;
assert!(msg.contains("Unsupported parameter"));
assert!(msg.contains("add_special_tokens"));
assert!(msg.contains("documents"));
assert!(msg.contains("chat_template"));
}
}
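// Completions requests reject unsupported fields the same way as chat completions.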
#[test]
fn test_completions_unsupported_fields_rejected() {
let json = r#"{
"model": "test-model",
"prompt": "Hello",
"add_special_tokens": true,
"response_format": {"type": "json_object"}
}"#;
let request: NvCreateCompletionRequest = serde_json::from_str(json).unwrap();
assert!(
request
.unsupported_fields
.contains_key("add_special_tokens")
);
assert!(request.unsupported_fields.contains_key("response_format"));
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
let msg = &error_response.1.message;
assert!(msg.contains("Unsupported parameter"));
assert!(msg.contains("add_special_tokens"));
assert!(msg.contains("response_format"));
}
}
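// An "error" event at the head of the stream should surface as a 500 whose message is the comment text.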
#[tokio::test]
async fn test_check_for_backend_error_with_error_event() {
use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use futures::stream;
let error_event = Annotated::<NvCreateChatCompletionStreamResponse> {
data: None,
id: None,
event: Some("error".to_string()),
comment: Some(vec!["Backend service unavailable".to_string()]),
error: None,
};
let test_stream = stream::iter(vec![error_event]);
let result = check_for_backend_error(test_stream).await;
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(error_response.1.message, "Backend service unavailable");
}
}
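// When the error comment is a JSON ErrorMessage, its message and code should be propagated verbatim.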
#[tokio::test]
async fn test_check_for_backend_error_with_json_error_and_code() {
use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use futures::stream;
let error_json =
r#"{"message":"prompt > max_seq_len","type":"Internal Server Error","code":500}"#;
let error_event = Annotated::<NvCreateChatCompletionStreamResponse> {
data: None,
id: None,
event: Some("error".to_string()),
comment: Some(vec![error_json.to_string()]),
error: None,
};
let test_stream = stream::iter(vec![error_event]);
let result = check_for_backend_error(test_stream).await;
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(error_response.1.message, "prompt > max_seq_len");
assert_eq!(error_response.1.code, 500);
}
}
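// A normal first chunk should pass through untouched and the stream should remain consumable.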
#[tokio::test]
async fn test_check_for_backend_error_with_normal_event() {
use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_async_openai::types::CreateChatCompletionStreamResponse;
use futures::stream::{self, StreamExt};
let normal_event = Annotated::<NvCreateChatCompletionStreamResponse> {
data: Some(CreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![],
created: 0,
model: "test-model".to_string(),
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
service_tier: None,
usage: None,
nvext: None,
}),
id: Some("msg-1".to_string()),
event: None,
comment: None,
error: None,
};
let test_stream = stream::iter(vec![normal_event.clone()]);
let result = check_for_backend_error(test_stream).await;
assert!(result.is_ok());
let mut returned_stream = result.unwrap();
let first = returned_stream.next().await;
assert!(first.is_some());
let first_event = first.unwrap();
assert_eq!(first_event.id, Some("msg-1".to_string()));
}
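// An empty stream is not an error; the returned stream simply yields nothing.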
#[tokio::test]
async fn test_check_for_backend_error_with_empty_stream() {
use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use futures::stream::{self, StreamExt};
let test_stream =
stream::iter::<Vec<Annotated<NvCreateChatCompletionStreamResponse>>>(vec![]);
let result = check_for_backend_error(test_stream).await;
assert!(result.is_ok());
let mut returned_stream = result.unwrap();
let first = returned_stream.next().await;
assert!(first.is_none());
}
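// A comment without an explicit "error" event type is still treated as a backend error.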
#[tokio::test]
async fn test_check_for_backend_error_with_comment_but_no_event_type() {
use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use futures::stream;
let error_event = Annotated::<NvCreateChatCompletionStreamResponse> {
data: None,
id: None,
event: None,
comment: Some(vec!["Connection timeout".to_string()]),
error: None,
};
let test_stream = stream::iter(vec![error_event]);
let result = check_for_backend_error(test_stream).await;
assert!(result.is_err());
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(error_response.1.message, "Connection timeout");
}
}
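// BAD_REQUEST is classified by message prefix: "Validation:" maps to Validation, anything else to Internal.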
#[test]
fn test_classify_error_for_metrics_validation() {
let error_type =
classify_error_for_metrics(StatusCode::BAD_REQUEST, "Validation: Invalid parameter");
assert_eq!(error_type, ErrorType::Validation);
let error_type = classify_error_for_metrics(StatusCode::BAD_REQUEST, "Some other error");
assert_eq!(error_type, ErrorType::Internal);
}
#[test]
fn test_classify_error_for_metrics_status_codes() {
assert_eq!(
classify_error_for_metrics(StatusCode::NOT_FOUND, "Model not found"),
ErrorType::NotFound
);
assert_eq!(
classify_error_for_metrics(StatusCode::NOT_IMPLEMENTED, "Feature not supported"),
ErrorType::NotImplemented
);
assert_eq!(
classify_error_for_metrics(StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded"),
ErrorType::Overload
);
assert_eq!(
classify_error_for_metrics(StatusCode::SERVICE_UNAVAILABLE, "Overloaded"),
ErrorType::Overload
);
assert_eq!(
classify_error_for_metrics(StatusCode::INTERNAL_SERVER_ERROR, "Panic"),
ErrorType::Internal
);
}
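// Other 4xx statuses without a dedicated mapping default to Validation.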
#[test]
fn test_classify_error_for_metrics_client_errors() {
assert_eq!(
classify_error_for_metrics(StatusCode::UNAUTHORIZED, "Unauthorized"),
ErrorType::Validation
);
assert_eq!(
classify_error_for_metrics(StatusCode::FORBIDDEN, "Forbidden"),
ErrorType::Validation
);
}
#[test]
fn test_extract_error_type_from_response_validation() {
let response = ErrorMessage::from_http_error(HttpError {
code: 400,
message: "Validation: bad input".to_string(),
});
assert_eq!(
extract_error_type_from_response(&response),
ErrorType::Validation
);
}
#[test]
fn test_extract_error_type_from_response_not_found() {
let response = ErrorMessage::model_not_found();
assert_eq!(
extract_error_type_from_response(&response),
ErrorType::NotFound
);
}
#[test]
fn test_extract_error_type_from_response_internal() {
let response = ErrorMessage::internal_server_error("Something went wrong");
assert_eq!(
extract_error_type_from_response(&response),
ErrorType::Internal
);
}
#[test]
fn test_extract_error_type_from_response_not_implemented() {
let response = ErrorMessage::not_implemented_error("Feature not available");
assert_eq!(
extract_error_type_from_response(&response),
ErrorType::NotImplemented
);
}
}