use std::collections::HashMap;
use std::env::var;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use axum::body::Body;
use axum::http::Response;
use super::Metrics;
use super::RouteDoc;
use super::metrics;
use super::metrics::register_worker_timing_metrics;
use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType;
use crate::kv_router::metrics::{RoutingOverheadMetrics, register_worker_load_metrics};
use crate::request_template::RequestTemplate;
use anyhow::Result;
use axum_server::tls_rustls::RustlsConfig;
use derive_builder::Builder;
use dynamo_runtime::config::env_is_truthy;
use dynamo_runtime::config::environment_names::llm as env_llm;
use dynamo_runtime::discovery::Discovery;
use dynamo_runtime::logging::make_request_span;
use std::net::SocketAddr;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
/// Shared state handed to every HTTP route built by this service.
///
/// Bundles the service metrics, the model manager, the discovery client, the
/// per-family endpoint switches and the cancellation token that `HttpService::run`
/// cancels when shutdown begins.
pub struct State {
    /// HTTP service metrics; registered on the metrics registry during
    /// `HttpServiceConfigBuilder::build`.
    metrics: Arc<Metrics>,
    /// Model manager shared with the route handlers.
    manager: Arc<ModelManager>,
    /// Discovery client; `build` falls back to an in-memory `KVStoreDiscovery`
    /// when none is configured.
    discovery_client: Arc<dyn Discovery>,
    /// Runtime enable/disable switches for each endpoint family.
    flags: StateFlags,
    /// Cancelled by `HttpService::run` when shutdown is requested, so handlers
    /// (e.g. the liveness check exercised by the test in this file) can observe it.
    cancel_token: CancellationToken,
}
/// Runtime on/off switches, one per endpoint family.
///
/// The routes of each family are wrapped (in
/// `HttpServiceConfigBuilder::get_endpoints_router`) by a middleware that reads
/// the matching switch on every request and answers `404 Not Found` while it is
/// off. The default value has every switch off.
#[derive(Default, Debug)]
struct StateFlags {
    /// Chat completions routes.
    chat_endpoints_enabled: AtomicBool,
    /// Completions routes.
    cmpl_endpoints_enabled: AtomicBool,
    /// Embeddings routes.
    embeddings_endpoints_enabled: AtomicBool,
    /// Image routes.
    images_endpoints_enabled: AtomicBool,
    /// Video routes.
    videos_endpoints_enabled: AtomicBool,
    /// Responses routes.
    responses_endpoints_enabled: AtomicBool,
    /// Anthropic Messages routes.
    anthropic_endpoints_enabled: AtomicBool,
}
impl StateFlags {
pub fn get(&self, endpoint_type: &EndpointType) -> bool {
match endpoint_type {
EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Audios => false,
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::AnthropicMessages => {
self.anthropic_endpoints_enabled.load(Ordering::Relaxed)
}
}
}
pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
match endpoint_type {
EndpointType::Chat => self
.chat_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Completion => self
.cmpl_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Embedding => self
.embeddings_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Images => self
.images_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Videos => self
.videos_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Audios => {}
EndpointType::Responses => self
.responses_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::AnthropicMessages => self
.anthropic_endpoints_enabled
.store(enabled, Ordering::Relaxed),
}
}
}
impl State {
pub fn new(
manager: Arc<ModelManager>,
discovery_client: Arc<dyn Discovery>,
cancel_token: CancellationToken,
) -> Self {
Self {
manager,
metrics: Arc::new(Metrics::default()),
discovery_client,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
images_endpoints_enabled: AtomicBool::new(false),
videos_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
anthropic_endpoints_enabled: AtomicBool::new(false),
},
cancel_token,
}
}
pub fn metrics_clone(&self) -> Arc<Metrics> {
self.metrics.clone()
}
pub fn manager(&self) -> &ModelManager {
Arc::as_ref(&self.manager)
}
pub fn manager_clone(&self) -> Arc<ModelManager> {
self.manager.clone()
}
pub fn discovery(&self) -> Arc<dyn Discovery> {
self.discovery_client.clone()
}
pub fn is_cancelled(&self) -> bool {
self.cancel_token.is_cancelled()
}
pub fn cancel_token(&self) -> &CancellationToken {
&self.cancel_token
}
pub fn sse_keep_alive(&self) -> Option<Duration> {
None
}
}
/// A configured HTTP(S) front end, produced by [`HttpServiceConfigBuilder::build`].
///
/// Cloning is how [`HttpService::spawn`] hands a copy of the service to a new task.
#[derive(Clone)]
pub struct HttpService {
    /// State shared with every route handler.
    state: Arc<State>,
    /// Fully assembled router (all routes plus the request-tracing layer).
    router: axum::Router,
    /// Port to listen on.
    port: u16,
    /// Host/interface to bind; combined with `port` into a socket address in `run`.
    host: String,
    /// Serve HTTPS instead of HTTP.
    enable_tls: bool,
    /// PEM certificate file; required when `enable_tls` is set.
    tls_cert_path: Option<PathBuf>,
    /// PEM private key file; required when `enable_tls` is set.
    tls_key_path: Option<PathBuf>,
    /// Documentation for every registered route, including the OpenAPI route itself.
    route_docs: Vec<RouteDoc>,
}
/// Configuration for an [`HttpService`].
///
/// Built through the generated `HttpServiceConfigBuilder`, whose public `build`
/// turns the configuration into a ready-to-run service.
#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct HttpServiceConfig {
    /// Port to listen on (default `8787`).
    #[builder(default = "8787")]
    port: u16,
    /// Host/interface to bind (default `"0.0.0.0"`).
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,
    /// Serve HTTPS instead of HTTP (default `false`).
    #[builder(default = "false")]
    enable_tls: bool,
    /// PEM certificate file; `run` fails if TLS is enabled and this is `None`.
    #[builder(default = "None")]
    tls_cert_path: Option<PathBuf>,
    /// PEM private key file; `run` fails if TLS is enabled and this is `None`.
    #[builder(default = "None")]
    tls_key_path: Option<PathBuf>,
    /// Initial switch for the chat completions routes (default off).
    #[builder(default = "false")]
    enable_chat_endpoints: bool,
    /// Initial switch for the completions routes (default off).
    #[builder(default = "false")]
    enable_cmpl_endpoints: bool,
    /// Initial switch for the embeddings routes (default on).
    #[builder(default = "true")]
    enable_embeddings_endpoints: bool,
    /// Initial switch for the responses routes (default on).
    #[builder(default = "true")]
    enable_responses_endpoints: bool,
    /// Initial switch for the Anthropic Messages routes (default off). The
    /// routes only exist when the `DYN_ENABLE_ANTHROPIC_API` setting is truthy.
    #[builder(default = "false")]
    enable_anthropic_endpoints: bool,
    /// Template handed to the chat, responses and Anthropic routers.
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,
    /// Discovery client for the service state; an in-memory
    /// `KVStoreDiscovery` is created when `None`.
    #[builder(default = "None")]
    discovery: Option<Arc<dyn Discovery>>,
    /// Token that becomes the state's cancellation token; a fresh token is
    /// created when `None`.
    #[builder(default = "None")]
    cancel_token: Option<CancellationToken>,
    /// Runtime metrics registry forwarded to the metrics router.
    #[builder(default = "None")]
    drt_metrics: Option<dynamo_runtime::metrics::MetricsRegistry>,
    /// Runtime discovery client; its instance id is used to register the
    /// routing-overhead metrics, which are skipped when this is `None`.
    #[builder(default = "None")]
    drt_discovery: Option<Arc<dyn Discovery>>,
}
impl HttpService {
pub fn builder() -> HttpServiceConfigBuilder {
HttpServiceConfigBuilder::default()
}
pub fn state_clone(&self) -> Arc<State> {
self.state.clone()
}
pub fn state(&self) -> &State {
Arc::as_ref(&self.state)
}
pub fn model_manager(&self) -> &ModelManager {
self.state().manager()
}
pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
let this = self.clone();
tokio::spawn(async move { this.run(cancel_token).await })
}
pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
let address = format!("{}:{}", self.host, self.port);
let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" };
tracing::info!(protocol, address, "Starting HTTP(S) service");
let router = self.router.clone();
let observer = cancel_token.child_token();
let state_cancel = self.state.cancel_token().clone();
let addr: SocketAddr = address
.parse()
.map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?;
if self.enable_tls {
let cert_path = self
.tls_cert_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?;
let key_path = self
.tls_key_path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?;
if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() {
tracing::debug!("TLS crypto provider already installed: {e:?}");
}
let config = RustlsConfig::from_pem_file(cert_path, key_path)
.await
.map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?;
let handle = axum_server::Handle::new();
let server = axum_server::bind_rustls(addr, config)
.handle(handle.clone())
.serve(router.into_make_service());
tokio::select! {
result = server => {
result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e))?;
}
_ = observer.cancelled() => {
state_cancel.cancel();
tracing::info!("HTTPS server shutdown requested");
handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64)));
}
}
} else {
let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
tracing::error!(
protocol = %protocol,
address = %address,
error = %e,
"Failed to bind server to address"
);
match e.kind() {
std::io::ErrorKind::AddrInUse => anyhow::anyhow!(
"Failed to start {} server: port {} already in use. Use --http-port to specify a different port.",
protocol,
self.port
),
_ => anyhow::anyhow!(
"Failed to start {} server on {}: {}",
protocol,
address,
e
),
}
})?;
axum::serve(listener, router)
.with_graceful_shutdown(async move {
observer.cancelled_owned().await;
state_cancel.cancel();
tracing::info!("HTTP server shutdown requested");
tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64))
.await;
})
.await
.inspect_err(|_| cancel_token.cancel())?;
}
Ok(())
}
pub fn route_docs(&self) -> &[RouteDoc] {
&self.route_docs
}
pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
self.state.flags.set(&endpoint_type, enable);
tracing::info!(
"{} endpoints {}",
endpoint_type.as_str(),
if enable { "enabled" } else { "disabled" }
);
}
}
/// Graceful-shutdown window, in seconds.
///
/// Read from the environment variable named by
/// `env_llm::DYN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT_SECS`. Falls back to 5 when the
/// variable is unset or does not parse as a non-negative integer.
fn get_graceful_shutdown_timeout() -> usize {
    const DEFAULT_TIMEOUT_SECS: usize = 5;
    match std::env::var(env_llm::DYN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT_SECS) {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_TIMEOUT_SECS),
        Err(_) => DEFAULT_TIMEOUT_SECS,
    }
}
// Names of the environment variables read (via `std::env::var`) when building
// the router. When set, each value is passed to the matching router
// constructor as its optional path argument.

/// Path argument for the metrics router.
const HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Path argument for the model-listing router.
const HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Path argument for the health-check router.
const HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Path argument for the liveness router.
const HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Path argument for the chat completions router.
const HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Path argument for the completions router.
const HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Path argument for the embeddings router.
const HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Path argument for the responses router.
const HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";
/// Path argument for the Anthropic Messages router.
const HTTP_SVC_ANTHROPIC_PATH_ENV: &str = "DYN_HTTP_SVC_ANTHROPIC_PATH";
impl HttpServiceConfigBuilder {
    /// Builds an [`HttpService`] from this configuration.
    ///
    /// Steps:
    /// * Creates the shared [`State`] with a new [`ModelManager`] and applies
    ///   the initial endpoint switches.
    /// * Registers metrics.
    /// * Assembles the router: metrics, model listing, health, liveness,
    ///   busy-threshold, the model endpoint families, and an OpenAPI route
    ///   documenting all of them.
    /// * Wraps every request in a tracing span; completions are logged by
    ///   [`Self::log_response`].
    ///
    /// # Errors
    ///
    /// Fails if the builder configuration is invalid or the HTTP service
    /// metrics cannot be registered. Failures to register the worker load,
    /// worker timing or routing-overhead metrics are logged and ignored.
    pub fn build(self) -> Result<HttpService, anyhow::Error> {
        let config: HttpServiceConfig = self.build_internal()?;

        let model_manager = Arc::new(ModelManager::new());
        let cancel_token = config.cancel_token.unwrap_or_default();
        // No discovery backend supplied: fall back to an in-memory KV store
        // driven by a child of the service's cancellation token.
        let discovery_client = config.discovery.unwrap_or_else(|| {
            use dynamo_runtime::discovery::KVStoreDiscovery;
            Arc::new(KVStoreDiscovery::new(
                dynamo_runtime::storage::kv::Manager::memory(),
                cancel_token.child_token(),
            )) as Arc<dyn Discovery>
        });
        let state = Arc::new(State::new(model_manager, discovery_client, cancel_token));

        // Initial switches. Images and videos are not configurable here; they
        // stay off until `HttpService::enable_model_endpoint` turns them on.
        let initial_flags = [
            (EndpointType::Chat, config.enable_chat_endpoints),
            (EndpointType::Completion, config.enable_cmpl_endpoints),
            (EndpointType::Embedding, config.enable_embeddings_endpoints),
            (EndpointType::Responses, config.enable_responses_endpoints),
            (
                EndpointType::AnthropicMessages,
                config.enable_anthropic_endpoints,
            ),
        ];
        for (endpoint_type, enabled) in initial_flags {
            state.flags.set(&endpoint_type, enabled);
        }

        let registry = metrics::Registry::new();
        state.metrics_clone().register(&registry)?;
        if let Err(e) = register_worker_load_metrics(&registry) {
            tracing::warn!("Failed to register worker load metrics: {}", e);
        }
        if let Err(e) = register_worker_timing_metrics(&registry) {
            tracing::warn!("Failed to register worker timing metrics: {}", e);
        }
        // Routing-overhead metrics need the runtime instance id, so they are
        // only registered when a runtime discovery client was provided.
        if let Some(discovery) = config.drt_discovery.as_ref() {
            let instance_id = discovery.instance_id();
            if let Err(e) = RoutingOverheadMetrics::register(&registry, instance_id) {
                tracing::warn!("Failed to register routing overhead metrics: {}", e);
            }
        }

        let mut routes = vec![
            metrics::router(
                registry,
                var(HTTP_SVC_METRICS_PATH_ENV).ok(),
                config.drt_metrics,
            ),
            super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
            super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
            super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
            super::busy_threshold::busy_threshold_router(state.clone(), None),
        ];
        routes.extend(Self::get_endpoints_router(
            state.clone(),
            &config.request_template,
        ));

        let mut router = axum::Router::new();
        let mut all_docs = Vec::new();
        for (route_docs, route) in routes {
            router = router.merge(route);
            all_docs.extend(route_docs);
        }

        // The OpenAPI route documents every route registered before it.
        let (openapi_docs, openapi_route) =
            super::openapi_docs::openapi_router(all_docs.clone(), None);
        router = router.merge(openapi_route);
        all_docs.extend(openapi_docs);

        let router = router.layer(
            TraceLayer::new_for_http()
                .make_span_with(make_request_span)
                .on_response(Self::log_response),
        );

        Ok(HttpService {
            state,
            router,
            port: config.port,
            host: config.host,
            enable_tls: config.enable_tls,
            tls_cert_path: config.tls_cert_path,
            tls_key_path: config.tls_key_path,
            route_docs: all_docs,
        })
    }

    /// Logs a completed request with its status code and latency in ms.
    ///
    /// Server errors (5xx) are logged at `error`, client errors (4xx) at
    /// `warn`, and everything else at `debug`.
    fn log_response(response: &Response<Body>, latency: Duration, _span: &tracing::Span) {
        let status = response.status();
        let latency_ms = latency.as_millis();
        if status.is_server_error() {
            tracing::error!(
                status = %status.as_u16(),
                latency_ms = %latency_ms,
                "request completed with server error"
            );
        } else if status.is_client_error() {
            tracing::warn!(
                status = %status.as_u16(),
                latency_ms = %latency_ms,
                "request completed with client request error"
            );
        } else {
            tracing::debug!(
                status = %status.as_u16(),
                latency_ms = %latency_ms,
                "request completed"
            );
        }
    }

    /// Sets the request template handed to the chat, responses and Anthropic routers.
    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }

    /// Builds the routers for every model endpoint family.
    ///
    /// Each router is wrapped in a middleware that checks the family's switch
    /// in `state.flags` on every request and answers `404 Not Found` while the
    /// switch is off. The Anthropic Messages router is only built when the
    /// `DYN_ENABLE_ANTHROPIC_API` setting is truthy. Families without a
    /// router get no routes at all.
    fn get_endpoints_router(
        state: Arc<State>,
        request_template: &Option<RequestTemplate>,
    ) -> Vec<(Vec<RouteDoc>, axum::Router)> {
        let mut routes = Vec::new();
        let (chat_docs, chat_route) = super::openai::chat_completions_router(
            state.clone(),
            request_template.clone(),
            var(HTTP_SVC_CHAT_PATH_ENV).ok(),
        );
        let (cmpl_docs, cmpl_route) =
            super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
        let (embed_docs, embed_route) =
            super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
        let (images_docs, images_route) = super::openai::images_router(state.clone(), None);
        let (videos_docs, videos_route) = super::openai::videos_router(state.clone(), None);
        let (responses_docs, responses_route) = super::openai::responses_router(
            state.clone(),
            request_template.clone(),
            var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
        );

        let mut endpoint_routes = HashMap::new();
        endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
        endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
        endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
        endpoint_routes.insert(EndpointType::Images, (images_docs, images_route));
        endpoint_routes.insert(EndpointType::Videos, (videos_docs, videos_route));
        endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));

        if env_is_truthy(env_llm::DYN_ENABLE_ANTHROPIC_API) {
            tracing::warn!("Anthropic Messages API (/v1/messages) is experimental.");
            let (anthropic_docs, anthropic_route) = super::anthropic::anthropic_messages_router(
                state.clone(),
                request_template.clone(),
                var(HTTP_SVC_ANTHROPIC_PATH_ENV).ok(),
            );
            endpoint_routes.insert(
                EndpointType::AnthropicMessages,
                (anthropic_docs, anthropic_route),
            );
        }

        for endpoint_type in EndpointType::all() {
            // Take ownership of the router instead of cloning it out of the map.
            let Some((docs, route)) = endpoint_routes.remove(&endpoint_type) else {
                tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                continue;
            };
            let state_route = state.clone();
            let route = route.route_layer(axum::middleware::from_fn(
                move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
                    let state: Arc<State> = state_route.clone();
                    async move {
                        // The switch is read per request, so runtime toggles
                        // take effect without rebuilding the router.
                        let enabled = state.flags.get(&endpoint_type);
                        if enabled {
                            Ok(next.run(req).await)
                        } else {
                            tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                            Err(axum::http::StatusCode::NOT_FOUND)
                        }
                    }
                },
            ));
            routes.push((docs, route));
        }
        routes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use tokio_util::sync::CancellationToken;

    /// Once the caller's token is cancelled, `/live` must answer 503 while the
    /// server is still inside its graceful-shutdown window.
    #[tokio::test]
    #[serial]
    async fn test_liveness_endpoint_reflects_cancellation() {
        // `CancellationToken` clones share one cancellation state, so no `Arc` is needed.
        let shutdown = CancellationToken::new();
        let service = HttpService::builder().build().unwrap();
        let live_url = format!("http://localhost:{}/live", service.port);

        let server_task = {
            let token = shutdown.clone();
            tokio::spawn(async move {
                service.run(token).await.unwrap();
            })
        };

        // Let the server task start, then request shutdown and give the
        // graceful-shutdown hook time to mark the state as cancelled.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        shutdown.cancel();
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        let response = reqwest::Client::new()
            .get(live_url)
            .send()
            .await
            .expect("Request failed");
        assert_eq!(response.status(), reqwest::StatusCode::SERVICE_UNAVAILABLE);

        server_task.abort();
    }
}