use std::collections::HashSet;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::{
Json, Router,
body::Body,
extract::State,
http::{HeaderMap, Request, StatusCode},
middleware::{self, Next},
response::{
IntoResponse, Response,
sse::{KeepAlive, Sse},
},
routing::{get, post},
};
use dynamo_runtime::config::{env_is_truthy, environment_names::llm as env_llm};
use dynamo_runtime::pipeline::{AsyncEngineContextProvider, Context};
use futures::{StreamExt, stream};
use tracing::Instrument;
use super::{
RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
metrics::{CancellationLabels, Endpoint, process_response_and_observe_metrics},
service_v2,
};
use crate::protocols::anthropic::stream_converter::AnthropicStreamConverter;
use crate::protocols::anthropic::types::{
AnthropicCountTokensRequest, AnthropicCountTokensResponse, AnthropicCreateMessageRequest,
AnthropicErrorBody, AnthropicErrorResponse, SystemContent,
chat_completion_to_anthropic_response,
};
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse,
aggregator::ChatCompletionAggregator,
};
use crate::protocols::unified::UnifiedRequest;
use crate::request_template::RequestTemplate;
use crate::types::Annotated;
use super::openai::{get_body_limit, get_or_create_request_id};
/// Builds the router for the Anthropic Messages API (`/v1/messages`) and its
/// `count_tokens` sibling route.
///
/// `path` overrides the default `/v1/messages` mount point. The error
/// middleware rewrites axum's 422 extraction failures into Anthropic-shaped
/// 400 errors, and the body limit mirrors the OpenAI routes.
pub fn anthropic_messages_router(
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let messages_path = path.unwrap_or("/v1/messages".to_string());
    let count_path = format!("{}/count_tokens", &messages_path);
    let docs = vec![
        RouteDoc::new(axum::http::Method::POST, &messages_path),
        RouteDoc::new(axum::http::Method::POST, &count_path),
    ];
    let router = Router::new()
        .route(&messages_path, post(handler_anthropic_messages))
        .route(&count_path, post(handler_count_tokens))
        .layer(middleware::from_fn(anthropic_error_middleware))
        .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
        .with_state((state, template));
    (docs, router)
}
/// Builds the router for the model listing/retrieval endpoints.
///
/// Registers `GET {path}` for the list and `GET {path}/{{*model_id}}` (a
/// wildcard capture, so model ids containing `/` resolve) for retrieval.
pub fn anthropic_models_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let list_path = path.unwrap_or("/v1/models".to_string());
    let item_path = format!("{}/{{*model_id}}", list_path);
    let docs = vec![
        RouteDoc::new(axum::http::Method::GET, &list_path),
        RouteDoc::new(axum::http::Method::GET, &item_path),
    ];
    let router = Router::new()
        .route(&list_path, get(list_models))
        .route(&item_path, get(get_model))
        .with_state(state);
    (docs, router)
}
/// Middleware that rewrites axum's 422 (extraction/validation) responses into
/// Anthropic-shaped 400 errors, preserving the rejection message as the
/// error body. All other responses pass through untouched.
async fn anthropic_error_middleware(request: Request<Body>, next: Next) -> Response {
    let response = next.run(request).await;
    if response.status() != StatusCode::UNPROCESSABLE_ENTITY {
        return response;
    }
    let (_parts, body) = response.into_parts();
    // Best-effort read of the rejection text; an unreadable body yields "".
    let bytes = axum::body::to_bytes(body, get_body_limit())
        .await
        .unwrap_or_default();
    let message = String::from_utf8_lossy(&bytes).to_string();
    anthropic_error(StatusCode::BAD_REQUEST, "invalid_request_error", &message)
}
/// `POST /v1/messages` — Anthropic Messages API entry point.
///
/// Validates the request, attaches a request id, wires up disconnect
/// monitoring, and runs the actual work in a spawned task so the engine call
/// is not dropped just because the HTTP future is.
async fn handler_anthropic_messages(
    State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
    headers: HeaderMap,
    Json(request): Json<AnthropicCreateMessageRequest>,
) -> Result<Response, Response> {
    // Cheap validation before any engine work.
    if request.messages.is_empty() {
        return Err(anthropic_error(
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "messages: field required",
        ));
    }
    if request.max_tokens == 0 {
        return Err(anthropic_error(
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            "max_tokens: must be greater than 0",
        ));
    }
    let request_id = get_or_create_request_id(&headers);
    let streaming = request.stream;
    // Labels recorded by the monitor if the client disconnects and the
    // request is cancelled.
    let cancellation_labels = CancellationLabels {
        model: request.model.clone(),
        endpoint: Endpoint::AnthropicMessages.to_string(),
        request_type: if streaming { "stream" } else { "unary" }.to_string(),
    };
    let request = Context::with_id(request, request_id);
    let context = request.context();
    let (mut connection_handle, stream_handle) = create_connection_monitor(
        context.clone(),
        Some(state.metrics_clone()),
        cancellation_labels,
    )
    .await;
    // Spawned so the generate pipeline survives the HTTP future being dropped;
    // in_current_span keeps the tracing context across the spawn boundary.
    let response =
        tokio::spawn(anthropic_messages(state, template, request, stream_handle).in_current_span())
            .await
            .map_err(|e| {
                anthropic_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "api_error",
                    &format!("Failed to await messages task: {:?}", e),
                )
            })?;
    // The task finished; connection loss is no longer a cancellation event.
    connection_handle.disarm();
    response
}
/// Core implementation for `POST /v1/messages`: applies the server-side
/// template, converts the Anthropic request into the internal chat-completion
/// form, calls the engine, and renders either an SSE stream of Anthropic
/// events or a unary JSON response.
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn anthropic_messages(
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    mut request: Context<AnthropicCreateMessageRequest>,
    mut stream_handle: ConnectionHandle,
) -> Result<Response, Response> {
    let streaming = request.stream;
    let request_id = request.id().to_string();
    // Fill unset fields from the optional server-side request template.
    if let Some(template) = template {
        if request.model.is_empty() {
            request.model = template.model.clone();
        }
        if request.temperature.is_none() {
            request.temperature = Some(template.temperature);
        }
        // NOTE(review): the handler already rejects max_tokens == 0, so this
        // branch looks unreachable via the HTTP path — confirm before relying
        // on template.max_completion_tokens being applied.
        if request.max_tokens == 0 {
            request.max_tokens = template.max_completion_tokens;
        }
    }
    // Optionally strip the injected billing preamble from the system prompt.
    if env_is_truthy(env_llm::DYN_STRIP_ANTHROPIC_PREAMBLE) {
        strip_billing_preamble(&mut request.system);
    }
    let model = request.model.clone();
    // Guard measuring time spent queued in the HTTP layer; released by
    // process_response_and_observe_metrics once responses start flowing.
    let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
    tracing::trace!("Received Anthropic messages request: {:?}", &*request);
    let (engine, parsing_options) = state
        .manager()
        .get_chat_completions_engine_with_parsing(&model)
        .map_err(|_| {
            anthropic_error(
                StatusCode::NOT_FOUND,
                "not_found_error",
                &format!("Model '{}' not found", model),
            )
        })?;
    let (orig_request, context) = request.into_parts();
    let model_for_resp = orig_request.model.clone();
    let thinking_explicitly_disabled = orig_request
        .thinking
        .as_ref()
        .is_some_and(|t| t.thinking_type == "disabled");
    // Streaming must report input tokens before real usage arrives from the
    // engine, so pre-compute an estimate; unary responses don't need one.
    let estimated_input_tokens = if streaming {
        estimate_input_tokens(&orig_request)
    } else {
        0
    };
    // Convert to the internal unified request; failures are treated as
    // invalid input from the caller.
    let unified_request: UnifiedRequest = orig_request.try_into().map_err(|e: anyhow::Error| {
        tracing::error!(
            request_id,
            error = %e,
            "Failed to convert AnthropicCreateMessageRequest to UnifiedRequest",
        );
        anthropic_error(
            StatusCode::BAD_REQUEST,
            "invalid_request_error",
            &format!("Failed to convert request: {}", e),
        )
    })?;
    // Anthropic-specific context, later used to shape responses back into
    // Anthropic form.
    let anthropic_ctx = unified_request.anthropic_context().cloned();
    let mut chat_request = unified_request.into_inner();
    // Inject reasoning flags into the chat template when the model has a
    // reasoning parser and the caller did not explicitly disable thinking.
    let prompt_injected_reasoning =
        parsing_options.reasoning_parser.is_some() && !thinking_explicitly_disabled;
    if prompt_injected_reasoning {
        let args = chat_request
            .chat_template_args
            .get_or_insert_with(Default::default);
        // entry().or_insert: caller-supplied template args take precedence.
        args.entry("enable_thinking".to_string())
            .or_insert(serde_json::Value::Bool(true));
        args.entry("truncate_history_thinking".to_string())
            .or_insert(serde_json::Value::Bool(false));
    }
    // Re-attach the original request context (id, cancellation token) to the
    // converted chat request.
    let request = context.map(|_req| chat_request);
    let mut response_collector = state.metrics_clone().create_response_collector(&model);
    let mut inflight_guard = state.metrics_clone().create_inflight_guard(
        &model,
        Endpoint::AnthropicMessages,
        streaming,
        request.id(),
    );
    tracing::trace!("Issuing generate call for Anthropic messages");
    let engine_stream = engine.generate(request).await.map_err(|e| {
        // Rejections (e.g. admission control) are counted separately.
        if super::metrics::request_was_rejected(e.as_ref()) {
            state
                .metrics_clone()
                .inc_rejection(&model, super::metrics::Endpoint::AnthropicMessages);
        }
        if super::metrics::request_was_cancelled(e.as_ref()) {
            inflight_guard.mark_error(super::metrics::ErrorType::Cancelled);
            // 499 "client closed request" (nginx convention) — not among the
            // StatusCode constants, hence from_u16.
            return anthropic_error(
                StatusCode::from_u16(499).unwrap(),
                "request_cancelled",
                &format!("Request cancelled: {}", e),
            );
        }
        inflight_guard.mark_error(super::metrics::ErrorType::Internal);
        anthropic_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "api_error",
            &format!("Failed to generate completions: {}", e),
        )
    })?;
    // Keep the engine context so the disconnect monitor can cancel it.
    let ctx = engine_stream.context();
    let engine_stream: Pin<
        Box<dyn futures::Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>,
    > = Box::pin(engine_stream);
    if streaming {
        // From here on, a client disconnect should cancel the engine stream.
        stream_handle.arm();
        use std::sync::atomic::{AtomicBool, Ordering};
        // Converts OpenAI-style stream chunks into Anthropic SSE events.
        let mut converter = match anthropic_ctx {
            Some(ctx) => {
                AnthropicStreamConverter::with_context(model_for_resp, estimated_input_tokens, ctx)
            }
            None => AnthropicStreamConverter::new(model_for_resp, estimated_input_tokens),
        };
        let start_events = converter.emit_start_events();
        // Shared between the per-chunk stage and the end-of-stream stage.
        let converter = std::sync::Arc::new(std::sync::Mutex::new(converter));
        let converter_end = converter.clone();
        let saw_error = std::sync::Arc::new(AtomicBool::new(false));
        let saw_error_end = saw_error.clone();
        let mut http_queue_guard = Some(http_queue_guard);
        let event_stream = engine_stream
            .inspect(move |response| {
                // Per-chunk metrics; also releases the HTTP queue guard.
                process_response_and_observe_metrics(
                    response,
                    &mut response_collector,
                    &mut http_queue_guard,
                );
            })
            .filter_map(move |annotated_chunk| {
                let converter = converter.clone();
                let saw_error = saw_error.clone();
                async move {
                    if annotated_chunk.data.is_none() {
                        // Data-less annotation: remember error events so the
                        // stream can end with error events rather than a
                        // normal message_stop.
                        if annotated_chunk.event.as_deref() == Some("error") {
                            saw_error.store(true, Ordering::Release);
                        }
                        return None;
                    }
                    let stream_resp = annotated_chunk.data?;
                    let mut conv = converter.lock().expect("converter lock poisoned");
                    let events = conv.process_chunk(&stream_resp);
                    Some(stream::iter(events))
                }
            })
            .flatten();
        let start_stream = stream::iter(start_events);
        // Runs once the engine stream is exhausted: terminal events, chosen by
        // whether an error annotation was observed mid-stream.
        let done_stream = stream::once(async move {
            let mut conv = converter_end.lock().expect("converter lock poisoned");
            let end_events = if saw_error_end.load(Ordering::Acquire) {
                conv.emit_error_events()
            } else {
                conv.emit_end_events()
            };
            stream::iter(end_events)
        })
        .flatten();
        let full_stream = start_stream.chain(event_stream).chain(done_stream);
        let full_stream = full_stream.map(|result| result.map_err(axum::Error::new));
        // Cancels the engine context if the client goes away mid-stream.
        let stream = monitor_for_disconnects(full_stream, ctx, inflight_guard, stream_handle);
        let mut sse_stream = Sse::new(stream);
        if let Some(keep_alive) = state.sse_keep_alive() {
            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
        }
        Ok(sse_stream.into_response())
    } else {
        // Unary: surface an early backend error before folding the stream.
        let stream_with_check = super::openai::check_for_backend_error(engine_stream)
            .await
            .map_err(|(status, json_err)| {
                tracing::error!(request_id, %status, ?json_err, "Backend error detected");
                anthropic_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "api_error",
                    "Backend error during generation",
                )
            })?;
        let mut http_queue_guard = Some(http_queue_guard);
        let stream = stream_with_check.inspect(move |response| {
            process_response_and_observe_metrics(
                response,
                &mut response_collector,
                &mut http_queue_guard,
            );
        });
        // Aggregate the full chat completion, then translate it into the
        // Anthropic response shape.
        let chat_response =
            NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
                .await
                .map_err(|e| {
                    tracing::error!(request_id, "Failed to fold messages stream: {:?}", e);
                    anthropic_error(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "api_error",
                        &format!("Failed to fold messages stream: {}", e),
                    )
                })?;
        let response = chat_completion_to_anthropic_response(
            chat_response,
            &model_for_resp,
            anthropic_ctx.as_ref(),
        );
        inflight_guard.mark_ok();
        Ok(Json(response).into_response())
    }
}
/// `POST /v1/messages/count_tokens` — estimates input token usage without
/// running inference.
async fn handler_count_tokens(
    State((_state, _template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
    Json(mut request): Json<AnthropicCountTokensRequest>,
) -> Result<Response, Response> {
    // Mirror the messages path: optionally drop the billing preamble so it
    // is not counted against the caller.
    if env_is_truthy(env_llm::DYN_STRIP_ANTHROPIC_PREAMBLE) {
        strip_billing_preamble(&mut request.system);
    }
    let response = AnthropicCountTokensResponse {
        input_tokens: request.estimate_tokens(),
    };
    Ok(Json(response).into_response())
}
/// Collects a `display_name -> context_length` map over every registered
/// model card, used to advertise per-model context windows.
fn build_model_context_map(state: &service_v2::State) -> std::collections::HashMap<String, u32> {
    let mut map = std::collections::HashMap::new();
    for card in state.manager().get_model_cards().iter() {
        map.insert(card.display_name.clone(), card.context_length);
    }
    map
}
fn model_env_overrides() -> (Option<u64>, Option<u64>) {
let context_window = match std::env::var("DYN_CONTEXT_WINDOW") {
Ok(v) => match v.parse::<u64>() {
Ok(val) => Some(val),
Err(_) => {
tracing::warn!("Invalid DYN_CONTEXT_WINDOW value '{}', ignoring", v);
None
}
},
Err(_) => None,
};
let max_output_tokens = match std::env::var("DYN_MAX_OUTPUT_TOKENS") {
Ok(v) => match v.parse::<u64>() {
Ok(val) => Some(val),
Err(_) => {
tracing::warn!("Invalid DYN_MAX_OUTPUT_TOKENS value '{}', ignoring", v);
None
}
},
Err(_) => None,
};
(context_window, max_output_tokens)
}
/// Determines the context window advertised for `model_name`.
///
/// An explicit environment override wins; otherwise the value is looked up
/// in the model-card map (widened from `u32` to `u64`). Returns `None` when
/// neither source provides a value.
fn resolve_context_window(
    model_name: &str,
    card_map: &std::collections::HashMap<String, u32>,
    env_override: Option<u64>,
) -> Option<u64> {
    match env_override {
        Some(cw) => Some(cw),
        None => card_map.get(model_name).copied().map(u64::from),
    }
}
/// `GET /v1/models` — lists all served models.
///
/// Returns the Anthropic-shaped list (`data`/`has_more`/`first_id`/`last_id`)
/// when the request carries an `anthropic-version` header, otherwise the
/// OpenAI-shaped list. Model names are sorted so the listing — and in
/// particular the `first_id`/`last_id` pagination markers — is deterministic
/// (`HashSet` iteration order is unspecified and varies between calls).
async fn list_models(
    State(state): State<Arc<service_v2::State>>,
    headers: HeaderMap,
) -> Result<Response, super::openai::ErrorResponse> {
    super::openai::check_ready(&state)?;
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    // Sort for stable output; the manager hands back an unordered HashSet.
    let mut models: Vec<String> = state
        .manager()
        .model_display_names()
        .into_iter()
        .collect();
    models.sort();
    let card_map = build_model_context_map(&state);
    let (cw_override, mot_override) = model_env_overrides();
    if headers.contains_key("anthropic-version") {
        // Anthropic clients expect an RFC3339-style `created_at` string.
        let created_at = chrono::DateTime::from_timestamp(created as i64, 0)
            .unwrap_or_default()
            .format("%Y-%m-%dT%H:%M:%SZ")
            .to_string();
        let data: Vec<serde_json::Value> = models
            .iter()
            .map(|name| {
                let mut obj = serde_json::json!({
                    "id": name,
                    "display_name": name,
                    "type": "model",
                    "created_at": created_at,
                });
                if let Some(cw) = resolve_context_window(name, &card_map, cw_override) {
                    obj["max_input_tokens"] = serde_json::json!(cw);
                }
                if let Some(mot) = mot_override {
                    obj["max_tokens"] = serde_json::json!(mot);
                }
                obj
            })
            .collect();
        // Pagination markers; the full list is always returned in one page.
        let first_id = data
            .first()
            .and_then(|d| d["id"].as_str().map(String::from));
        let last_id = data.last().and_then(|d| d["id"].as_str().map(String::from));
        return Ok(Json(serde_json::json!({
            "data": data,
            "has_more": false,
            "first_id": first_id,
            "last_id": last_id,
        }))
        .into_response());
    }
    // OpenAI-compatible shape.
    let data: Vec<serde_json::Value> = models
        .iter()
        .map(|name| {
            let mut obj = serde_json::json!({
                "id": name,
                "object": "model",
                "created": created,
                "owned_by": "nvidia",
            });
            if let Some(cw) = resolve_context_window(name, &card_map, cw_override) {
                obj["context_window"] = serde_json::json!(cw);
            }
            if let Some(mot) = mot_override {
                obj["max_output_tokens"] = serde_json::json!(mot);
            }
            obj
        })
        .collect();
    Ok(Json(serde_json::json!({
        "object": "list",
        "data": data,
    }))
    .into_response())
}
/// `GET /v1/models/{model_id}` — retrieves a single served model.
///
/// Responds in the Anthropic shape when the `anthropic-version` header is
/// present, otherwise in the OpenAI shape; unknown ids produce a
/// model-not-found error.
async fn get_model(
    State(state): State<Arc<service_v2::State>>,
    headers: HeaderMap,
    axum::extract::Path(model_id): axum::extract::Path<String>,
) -> Result<Response, super::openai::ErrorResponse> {
    super::openai::check_ready(&state)?;
    // The wildcard route capture keeps a leading slash; drop it for lookup.
    let model_id = model_id.strip_prefix('/').unwrap_or(&model_id);
    let known: HashSet<String> = state.manager().model_display_names();
    if !known.contains(model_id) {
        return Err(super::openai::ErrorMessage::model_not_found());
    }
    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let card_map = build_model_context_map(&state);
    let (cw_override, mot_override) = model_env_overrides();
    let context_window = resolve_context_window(model_id, &card_map, cw_override);
    let body = if headers.contains_key("anthropic-version") {
        // Anthropic clients expect an RFC3339-style `created_at` string.
        let created_at = chrono::DateTime::from_timestamp(created as i64, 0)
            .unwrap_or_default()
            .format("%Y-%m-%dT%H:%M:%SZ")
            .to_string();
        let mut obj = serde_json::json!({
            "id": model_id,
            "display_name": model_id,
            "type": "model",
            "created_at": created_at,
        });
        if let Some(cw) = context_window {
            obj["max_input_tokens"] = serde_json::json!(cw);
        }
        if let Some(mot) = mot_override {
            obj["max_tokens"] = serde_json::json!(mot);
        }
        obj
    } else {
        // OpenAI-compatible shape.
        let mut obj = serde_json::json!({
            "id": model_id,
            "object": "model",
            "created": created,
            "owned_by": "nvidia",
        });
        if let Some(cw) = context_window {
            obj["context_window"] = serde_json::json!(cw);
        }
        if let Some(mot) = mot_override {
            obj["max_output_tokens"] = serde_json::json!(mot);
        }
        obj
    };
    Ok(Json(body).into_response())
}
fn strip_billing_preamble(system: &mut Option<SystemContent>) {
if let Some(content) = system {
let trimmed = content.text.trim_start();
if trimmed.starts_with("x-anthropic-billing-header:")
&& let Some(newline_pos) = trimmed.find('\n')
{
content.text = trimmed[newline_pos + 1..].to_string();
}
}
}
/// Produces a heuristic input-token estimate for a messages request by
/// reusing the `count_tokens` estimator. Used to pre-fill usage in streaming
/// responses before the engine reports real prompt token counts.
fn estimate_input_tokens(req: &AnthropicCreateMessageRequest) -> u32 {
    AnthropicCountTokensRequest {
        model: req.model.clone(),
        messages: req.messages.clone(),
        system: req.system.clone(),
        tools: req.tools.clone(),
    }
    .estimate_tokens()
}
/// Builds an Anthropic-style JSON error response.
///
/// The `error.type` field is normalized from the HTTP status per Anthropic's
/// error taxonomy; the caller-supplied `error_type` is only used verbatim for
/// statuses without a canonical mapping (e.g. 499 / 500).
fn anthropic_error(status: StatusCode, error_type: &str, message: &str) -> Response {
    let normalized = match status.as_u16() {
        400 => "invalid_request_error",
        401 => "authentication_error",
        403 => "permission_error",
        404 => "not_found_error",
        429 => "rate_limit_error",
        503 | 529 => "overloaded_error",
        _ => error_type,
    };
    let body = AnthropicErrorResponse {
        object_type: "error".to_string(),
        error: AnthropicErrorBody {
            error_type: normalized.to_string(),
            message: message.to_string(),
        },
    };
    (status, Json(body)).into_response()
}