use axum::{
extract::State,
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
},
routing::{get, post},
Json, Router,
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
pin::Pin,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use tokio_stream::wrappers::ReceiverStream;
use super::DeploymentState;
use super::{
error::HttpError,
metrics::{Endpoint, InflightGuard},
RouteDoc,
};
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
};
use crate::request_template::RequestTemplate;
use crate::types::{
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
Annotated,
};
use dynamo_runtime::pipeline::{AsyncEngineContext, Context};
/// JSON error body returned by the HTTP handlers in this file, serialized as
/// `{"error": "<message>"}`.
#[derive(Deserialize, Serialize)]
pub(crate) struct ErrorResponse {
    /// Human-readable description of the failure.
    error: String,
}
impl ErrorResponse {
pub fn model_not_found() -> (StatusCode, Json<ErrorResponse>) {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "Model not found".to_string(),
}),
)
}
pub fn _service_unavailable() -> (StatusCode, Json<ErrorResponse>) {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "Service is not ready".to_string(),
}),
)
}
pub fn internal_server_error(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
tracing::error!("Internal server error: {msg}");
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: msg.to_string(),
}),
)
}
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> (StatusCode, Json<ErrorResponse>) {
match err.downcast::<HttpError>() {
Ok(http_error) => ErrorResponse::from_http_error(http_error),
Err(err) => ErrorResponse::internal_server_error(&format!("{alt_msg}: {err}")),
}
}
pub fn from_http_error(err: HttpError) -> (StatusCode, Json<ErrorResponse>) {
if err.code < 400 || err.code >= 500 {
return ErrorResponse::internal_server_error(&err.message);
}
match StatusCode::from_u16(err.code) {
Ok(code) => (code, Json(ErrorResponse { error: err.message })),
Err(_) => ErrorResponse::internal_server_error(&err.message),
}
}
}
impl From<HttpError> for ErrorResponse {
fn from(err: HttpError) -> Self {
ErrorResponse { error: err.message }
}
}
/// OpenAI-compatible text completions handler.
///
/// The request forwarded to the engine always has `stream = Some(true)`. The
/// client's own `stream` flag (default `false`) only selects how the result is
/// delivered:
/// - `true`: relayed as Server-Sent Events through `monitor_for_disconnects`,
///   with keep-alive pings when `state.sse_keep_alive` is set;
/// - `false`: folded into a single [`CompletionResponse`] JSON body.
///
/// # Errors
/// - `404` when no completions engine is registered for the requested model;
/// - whatever [`ErrorResponse::from_anyhow`] maps a `generate` failure to;
/// - `500` when folding the non-streaming response fails.
#[tracing::instrument(skip_all)]
async fn completions(
    State(state): State<Arc<DeploymentState>>,
    Json(request): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    check_ready(&state)?;

    let request_id = uuid::Uuid::new_v4().to_string();

    // Capture the client's delivery preference before streaming is forced on.
    let client_wants_stream = request.inner.stream.unwrap_or(false);
    let CompletionRequest { inner, nvext } = request;
    let forwarded = CompletionRequest {
        inner: async_openai::types::CreateCompletionRequest {
            stream: Some(true),
            ..inner
        },
        nvext,
    };

    let model = &forwarded.inner.model;
    let engine = state
        .get_completions_engine(model)
        .map_err(|_| ErrorResponse::model_not_found())?;
    let mut inflight =
        state.create_inflight_guard(model, Endpoint::Completions, client_wants_stream);

    let stream = engine
        .generate(Context::with_id(forwarded, request_id.clone()))
        .await
        .map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?;
    let ctx = stream.context();

    if !client_wants_stream {
        let folded = CompletionResponse::from_annotated_stream(stream.into())
            .await
            .map_err(|e| {
                tracing::error!(
                    "Failed to fold completions stream for {}: {:?}",
                    request_id,
                    e
                );
                ErrorResponse::internal_server_error("Failed to fold completions stream")
            })?;
        inflight.mark_ok();
        return Ok(Json(folded).into_response());
    }

    let events = stream.map(|response| Event::try_from(EventConverter::from(response)));
    let events = monitor_for_disconnects(events.boxed(), ctx, inflight).await;
    let sse = match state.sse_keep_alive {
        Some(interval) => Sse::new(events).keep_alive(KeepAlive::default().interval(interval)),
        None => Sse::new(events),
    };
    Ok(sse.into_response())
}
/// OpenAI-compatible chat completions handler.
///
/// When a [`RequestTemplate`] is configured, it supplies values for fields the
/// client left unset:
/// - an empty `model`;
/// - an absent `temperature`;
/// - an absent or zero `max_completion_tokens`.
///
/// The request forwarded to the engine always has `stream = Some(true)`. The
/// client's own `stream` flag (default `false`) chooses between an SSE relay
/// and a single folded [`NvCreateChatCompletionResponse`] JSON body.
///
/// # Errors
/// - `404` when no chat completions engine is registered for the model;
/// - whatever [`ErrorResponse::from_anyhow`] maps a `generate` failure to;
/// - `500` when folding the non-streaming response fails.
#[tracing::instrument(skip_all)]
async fn chat_completions(
    State((state, template)): State<(Arc<DeploymentState>, Option<RequestTemplate>)>,
    Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    check_ready(&state)?;
    if let Some(template) = template {
        if request.inner.model.is_empty() {
            request.inner.model = template.model.clone();
        }
        // Only fill in temperature when the client omitted it. An explicit 0.0
        // (greedy decoding) is a deliberate choice and must not be overridden.
        if request.inner.temperature.is_none() {
            request.inner.temperature = Some(template.temperature);
        }
        // A zero token budget is treated the same as "not provided".
        if request.inner.max_completion_tokens.unwrap_or(0) == 0 {
            request.inner.max_completion_tokens = Some(template.max_completion_tokens);
        }
    }
    tracing::trace!("Received chat completions request: {:?}", request.inner);
    let request_id = uuid::Uuid::new_v4().to_string();
    // Remember the client's delivery preference, then force streaming upstream.
    let streaming = request.inner.stream.unwrap_or(false);
    let inner_request = async_openai::types::CreateChatCompletionRequest {
        stream: Some(true),
        ..request.inner
    };
    let request = NvCreateChatCompletionRequest {
        inner: inner_request,
        nvext: request.nvext,
    };
    let model = &request.inner.model;
    tracing::trace!("Getting chat completions engine for model: {}", model);
    let engine = state
        .get_chat_completions_engine(model)
        .map_err(|_| ErrorResponse::model_not_found())?;
    let mut inflight = state.create_inflight_guard(model, Endpoint::ChatCompletions, streaming);
    let request = Context::with_id(request, request_id.clone());
    tracing::trace!("Issuing generate call for chat completions");
    let stream = engine
        .generate(request)
        .await
        .map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?;
    let ctx = stream.context();
    if streaming {
        let stream = stream.map(|response| Event::try_from(EventConverter::from(response)));
        let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight).await;
        let mut sse_stream = Sse::new(stream);
        if let Some(keep_alive) = state.sse_keep_alive {
            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
        }
        Ok(sse_stream.into_response())
    } else {
        let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into())
            .await
            .map_err(|e| {
                tracing::error!(
                    request_id,
                    "Failed to fold chat completions stream for: {:?}",
                    e
                );
                ErrorResponse::internal_server_error(&format!(
                    "Failed to fold chat completions stream: {}",
                    e
                ))
            })?;
        inflight.mark_ok();
        Ok(Json(response).into_response())
    }
}
/// Readiness gate called at the start of every handler in this file.
///
/// It currently accepts every request unconditionally, and `_state` is unused.
/// NOTE(review): presumably a placeholder for a real readiness check;
/// `ErrorResponse::_service_unavailable` exists but nothing here returns it yet.
fn check_ready(_state: &Arc<DeploymentState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // No readiness condition is evaluated yet.
    Result::Ok(())
}
/// Lists registered model names grouped by endpoint kind.
///
/// Responds with a JSON object with the keys `chat_completion_models` and
/// `completion_models`, each holding the names registered in the corresponding
/// engine registry. Each registry is locked only long enough to copy its keys.
async fn list_models_custom(
    State(state): State<Arc<DeploymentState>>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    check_ready(&state)?;
    let chat_completion_models: Vec<String> = {
        let registry = state.chat_completion_engines.lock().unwrap();
        registry.engines.keys().cloned().collect()
    };
    let completion_models: Vec<String> = {
        let registry = state.completion_engines.lock().unwrap();
        registry.engines.keys().cloned().collect()
    };
    let grouped = HashMap::from([
        ("chat_completion_models", chat_completion_models),
        ("completion_models", completion_models),
    ]);
    Ok(Json(grouped).into_response())
}
async fn list_models_openai(
State(state): State<Arc<DeploymentState>>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
check_ready(&state)?;
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let mut data = Vec::new();
let models: HashSet<String> = state
.chat_completion_engines
.lock()
.unwrap()
.engines
.keys()
.chain(state.completion_engines.lock().unwrap().engines.keys())
.cloned()
.collect();
for model_id in models {
data.push(ModelListing {
id: model_id.clone(),
object: "object",
created, owned_by: "nvidia".to_string(), });
}
let out = ListModelOpenAI {
object: "list",
data,
};
Ok(Json(out).into_response())
}
/// Envelope of the `/v1/models` response in the OpenAI list format.
#[derive(Serialize)]
struct ListModelOpenAI {
    /// Object type tag of the envelope.
    object: &'static str,
    /// One entry per advertised model.
    data: Vec<ModelListing>,
}
/// A single model entry inside [`ListModelOpenAI::data`].
#[derive(Serialize)]
struct ModelListing {
    /// Model name as registered with the deployment.
    id: String,
    /// Object type tag of this entry.
    object: &'static str,
    /// Creation timestamp, in seconds since the Unix epoch.
    created: u64,
    /// Owning organization reported to clients.
    owned_by: String,
}
async fn monitor_for_disconnects(
stream: Pin<
Box<dyn Stream<Item = Result<axum::response::sse::Event, axum::Error>> + std::marker::Send>,
>,
context: Arc<dyn AsyncEngineContext>,
inflight: InflightGuard,
) -> ReceiverStream<Result<Event, axum::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move {
let mut inflight = inflight;
let mut stream = stream;
while let Some(event) = stream.next().await {
let event = match event {
Ok(event) => Ok(event),
Err(err) => Ok(Event::default().event("error").comment(err.to_string())),
};
if (tx.send(event).await).is_err() {
tracing::trace!("Forwarding SSE stream was dropped; breaking loop");
context.stop_generating();
break;
}
}
if tx.send(Ok(Event::default().data("[DONE]"))).await.is_ok() {
inflight.mark_ok();
}
});
ReceiverStream::new(rx)
}
/// Newtype wrapper around an [`Annotated`] engine response, used as the source
/// type for the `TryFrom` conversion into an SSE [`Event`].
struct EventConverter<T>(Annotated<T>);

impl<T> From<Annotated<T>> for EventConverter<T> {
    fn from(inner: Annotated<T>) -> Self {
        Self(inner)
    }
}
impl<T: Serialize> TryFrom<EventConverter<T>> for Event {
type Error = axum::Error;
fn try_from(annotated: EventConverter<T>) -> Result<Self, Self::Error> {
let annotated = annotated.0;
let mut event = Event::default();
if let Some(data) = annotated.data {
event = event.json_data(data)?;
}
if let Some(msg) = annotated.event {
if msg == "error" {
let msgs = annotated
.comment
.unwrap_or_else(|| vec!["unspecified error".to_string()]);
return Err(axum::Error::new(msgs.join(" -- ")));
}
event = event.event(msg);
}
if let Some(comments) = annotated.comment {
for comment in comments {
event = event.comment(comment);
}
}
Ok(event)
}
}
/// Builds the router for the OpenAI-compatible text completions endpoint.
///
/// `path` overrides the default mount point `/v1/completions`. Returns the
/// route's documentation entry together with the router.
pub fn completions_router(
    state: Arc<DeploymentState>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let route_path = path.unwrap_or_else(|| String::from("/v1/completions"));
    let docs = vec![RouteDoc::new(axum::http::Method::POST, &route_path)];
    let router = Router::new()
        .route(&route_path, post(completions))
        .with_state(state);
    (docs, router)
}
/// Builds the router for the OpenAI-compatible chat completions endpoint.
///
/// `path` overrides the default mount point `/v1/chat/completions`. The
/// optional `template` is handed to the handler as state alongside `state`.
pub fn chat_completions_router(
    state: Arc<DeploymentState>,
    template: Option<RequestTemplate>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let route_path = path.unwrap_or_else(|| String::from("/v1/chat/completions"));
    let docs = vec![RouteDoc::new(axum::http::Method::POST, &route_path)];
    let handler_state = (state, template);
    let router = Router::new()
        .route(&route_path, post(chat_completions))
        .with_state(handler_state);
    (docs, router)
}
/// Builds the router exposing two model-listing endpoints:
/// - the custom grouped listing, at `path` (default `/dynamo/alpha/list-models`);
/// - the OpenAI-compatible listing, always at `/v1/models`.
pub fn list_models_router(
    state: Arc<DeploymentState>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let custom_path = path.unwrap_or_else(|| String::from("/dynamo/alpha/list-models"));
    let openai_path = String::from("/v1/models");
    let docs = vec![
        RouteDoc::new(axum::http::Method::GET, &custom_path),
        RouteDoc::new(axum::http::Method::GET, &openai_path),
    ];
    let router = Router::new()
        .route(&custom_path, get(list_models_custom))
        .route(&openai_path, get(list_models_openai))
        .with_state(state);
    (docs, router)
}
#[cfg(test)]
mod tests {
    use super::super::ServiceHttpError;
    use super::*;

    const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions";

    /// Simulates an engine failing with a typed [`HttpError`] carrying `code`.
    fn http_error_from_engine(code: u16) -> Result<(), anyhow::Error> {
        Err(HttpError {
            code,
            message: "custom error message".to_string(),
        })?
    }

    /// Simulates an engine failing with an error that is not an [`HttpError`].
    fn other_error_from_engine() -> Result<(), anyhow::Error> {
        Err(ServiceHttpError::ModelNotFound("foo".to_string()))?
    }

    #[test]
    fn test_http_error_response_from_anyhow() {
        let engine_err = http_error_from_engine(400).unwrap_err();
        let (status, response) = ErrorResponse::from_anyhow(engine_err, BACKUP_ERROR_MESSAGE);
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(response.error, "custom error message");
    }

    #[test]
    fn test_error_response_from_anyhow_out_of_range() {
        // Anything outside 400..500 collapses to a 500 but keeps the message.
        for code in [399, 500, 501] {
            let engine_err = http_error_from_engine(code).unwrap_err();
            let (status, response) =
                ErrorResponse::from_anyhow(engine_err, BACKUP_ERROR_MESSAGE);
            assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR, "code {code}");
            assert_eq!(response.error, "custom error message", "code {code}");
        }
    }

    #[test]
    fn test_other_error_response_from_anyhow() {
        let engine_err = other_error_from_engine().unwrap_err();
        let (status, response) = ErrorResponse::from_anyhow(engine_err, BACKUP_ERROR_MESSAGE);
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        let expected = format!(
            "{}: {}",
            BACKUP_ERROR_MESSAGE,
            other_error_from_engine().unwrap_err()
        );
        assert_eq!(response.error, expected);
    }
}