dynamo_llm/http/service/service_v2.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::env::var;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::Ordering;
10use std::time::Duration;
11
12use super::Metrics;
13use super::RouteDoc;
14use super::metrics;
15use crate::discovery::ModelManager;
16use crate::endpoint_type::EndpointType;
17use crate::request_template::RequestTemplate;
18use anyhow::Result;
19use axum_server::tls_rustls::RustlsConfig;
20use derive_builder::Builder;
21use dynamo_runtime::logging::make_request_span;
22use dynamo_runtime::transports::etcd;
23use std::net::SocketAddr;
24use tokio::task::JoinHandle;
25use tokio_util::sync::CancellationToken;
26use tower_http::trace::TraceLayer;
27
28/// HTTP service shared state
29#[derive(Default)]
30pub struct State {
31    metrics: Arc<Metrics>,
32    manager: Arc<ModelManager>,
33    etcd_client: Option<etcd::Client>,
34    flags: StateFlags,
35}
36
/// Runtime enable switches, one per endpoint family.
///
/// Each switch is an independent [`AtomicBool`], so it can be flipped through a
/// shared reference (the flags live inside the `Arc<State>`). The derived
/// `Default` leaves every switch at `false`.
#[derive(Debug, Default)]
struct StateFlags {
    /// Gates the `EndpointType::Chat` routes.
    chat_endpoints_enabled: AtomicBool,
    /// Gates the `EndpointType::Completion` routes.
    cmpl_endpoints_enabled: AtomicBool,
    /// Gates the `EndpointType::Embedding` routes.
    embeddings_endpoints_enabled: AtomicBool,
    /// Gates the `EndpointType::Responses` routes.
    responses_endpoints_enabled: AtomicBool,
}
44
45impl StateFlags {
46    pub fn get(&self, endpoint_type: &EndpointType) -> bool {
47        match endpoint_type {
48            EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
49            EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
50            EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
51            EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
52        }
53    }
54
55    pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
56        match endpoint_type {
57            EndpointType::Chat => self
58                .chat_endpoints_enabled
59                .store(enabled, Ordering::Relaxed),
60            EndpointType::Completion => self
61                .cmpl_endpoints_enabled
62                .store(enabled, Ordering::Relaxed),
63            EndpointType::Embedding => self
64                .embeddings_endpoints_enabled
65                .store(enabled, Ordering::Relaxed),
66            EndpointType::Responses => self
67                .responses_endpoints_enabled
68                .store(enabled, Ordering::Relaxed),
69        }
70    }
71}
72
73impl State {
74    pub fn new(manager: Arc<ModelManager>) -> Self {
75        Self {
76            manager,
77            metrics: Arc::new(Metrics::default()),
78            etcd_client: None,
79            flags: StateFlags {
80                chat_endpoints_enabled: AtomicBool::new(false),
81                cmpl_endpoints_enabled: AtomicBool::new(false),
82                embeddings_endpoints_enabled: AtomicBool::new(false),
83                responses_endpoints_enabled: AtomicBool::new(false),
84            },
85        }
86    }
87
88    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
89        Self {
90            manager,
91            metrics: Arc::new(Metrics::default()),
92            etcd_client,
93            flags: StateFlags {
94                chat_endpoints_enabled: AtomicBool::new(false),
95                cmpl_endpoints_enabled: AtomicBool::new(false),
96                embeddings_endpoints_enabled: AtomicBool::new(false),
97                responses_endpoints_enabled: AtomicBool::new(false),
98            },
99        }
100    }
101    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
102    pub fn metrics_clone(&self) -> Arc<Metrics> {
103        self.metrics.clone()
104    }
105
106    pub fn manager(&self) -> &ModelManager {
107        Arc::as_ref(&self.manager)
108    }
109
110    pub fn manager_clone(&self) -> Arc<ModelManager> {
111        self.manager.clone()
112    }
113
114    pub fn etcd_client(&self) -> Option<&etcd::Client> {
115        self.etcd_client.as_ref()
116    }
117
118    // TODO
119    pub fn sse_keep_alive(&self) -> Option<Duration> {
120        None
121    }
122}
123
124#[derive(Clone)]
125pub struct HttpService {
126    // The state we share with every request handler
127    state: Arc<State>,
128
129    router: axum::Router,
130    port: u16,
131    host: String,
132    enable_tls: bool,
133    tls_cert_path: Option<PathBuf>,
134    tls_key_path: Option<PathBuf>,
135    route_docs: Vec<RouteDoc>,
136
137    // Metrics polling configuration
138    etcd_client: Option<dynamo_runtime::transports::etcd::Client>,
139}
140
141#[derive(Clone, Builder)]
142#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
143pub struct HttpServiceConfig {
144    #[builder(default = "8787")]
145    port: u16,
146
147    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
148    host: String,
149
150    #[builder(default = "false")]
151    enable_tls: bool,
152
153    #[builder(default = "None")]
154    tls_cert_path: Option<PathBuf>,
155
156    #[builder(default = "None")]
157    tls_key_path: Option<PathBuf>,
158
159    // #[builder(default)]
160    // custom: Vec<axum::Router>
161    #[builder(default = "false")]
162    enable_chat_endpoints: bool,
163
164    #[builder(default = "false")]
165    enable_cmpl_endpoints: bool,
166
167    #[builder(default = "true")]
168    enable_embeddings_endpoints: bool,
169
170    #[builder(default = "true")]
171    enable_responses_endpoints: bool,
172
173    #[builder(default = "None")]
174    request_template: Option<RequestTemplate>,
175
176    #[builder(default = "None")]
177    etcd_client: Option<etcd::Client>,
178}
179
180impl HttpService {
181    pub fn builder() -> HttpServiceConfigBuilder {
182        HttpServiceConfigBuilder::default()
183    }
184
185    pub fn state_clone(&self) -> Arc<State> {
186        self.state.clone()
187    }
188
189    pub fn state(&self) -> &State {
190        Arc::as_ref(&self.state)
191    }
192
193    pub fn model_manager(&self) -> &ModelManager {
194        self.state().manager()
195    }
196
197    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
198        let this = self.clone();
199        tokio::spawn(async move { this.run(cancel_token).await })
200    }
201
202    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
203        let address = format!("{}:{}", self.host, self.port);
204        let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" };
205        tracing::info!(protocol, address, "Starting HTTP(S) service");
206
207        // Start background task to poll runtime config metrics with proper cancellation
208        let poll_interval_secs = std::env::var("DYN_HTTP_SVC_CONFIG_METRICS_POLL_INTERVAL_SECS")
209            .ok()
210            .and_then(|s| s.parse::<f64>().ok())
211            .filter(|&secs| secs > 0.0) // Guard against zero or negative values
212            .unwrap_or(8.0);
213        let poll_interval = Duration::from_secs_f64(poll_interval_secs);
214
215        let _polling_task = super::metrics::Metrics::start_runtime_config_polling_task(
216            self.state.metrics_clone(),
217            self.state.manager_clone(),
218            self.etcd_client.clone(),
219            poll_interval,
220            cancel_token.child_token(),
221        );
222
223        let router = self.router.clone();
224        let observer = cancel_token.child_token();
225
226        let addr: SocketAddr = address
227            .parse()
228            .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?;
229
230        if self.enable_tls {
231            let cert_path = self
232                .tls_cert_path
233                .as_ref()
234                .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?;
235            let key_path = self
236                .tls_key_path
237                .as_ref()
238                .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?;
239
240            // aws_lc_rs is the default but other crates pull in `ring` also,
241            // so rustls doesn't know which one to use. Tell it.
242            if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() {
243                tracing::debug!("TLS crypto provider already installed: {e:?}");
244            }
245
246            let config = RustlsConfig::from_pem_file(cert_path, key_path)
247                .await
248                .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?;
249
250            let handle = axum_server::Handle::new();
251            let server = axum_server::bind_rustls(addr, config)
252                .handle(handle.clone())
253                .serve(router.into_make_service());
254
255            tokio::select! {
256                result = server => {
257                    result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e))?;
258                }
259                _ = observer.cancelled() => {
260                    tracing::info!("HTTPS server shutdown requested");
261                    handle.graceful_shutdown(Some(Duration::from_secs(5)));
262                    // TODO: Do we need to wait?
263                }
264            }
265        } else {
266            let listener = tokio::net::TcpListener::bind(addr)
267                .await
268                .unwrap_or_else(|_| panic!("could not bind to address: {address}"));
269
270            axum::serve(listener, router)
271                .with_graceful_shutdown(observer.cancelled_owned())
272                .await
273                .inspect_err(|_| cancel_token.cancel())?;
274        }
275
276        Ok(())
277    }
278
279    /// Documentation of exposed HTTP endpoints
280    pub fn route_docs(&self) -> &[RouteDoc] {
281        &self.route_docs
282    }
283
284    pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
285        self.state.flags.set(&endpoint_type, enable);
286        tracing::info!(
287            "{} endpoints {}",
288            endpoint_type.as_str(),
289            if enable { "enabled" } else { "disabled" }
290        );
291    }
292}
293
// Names of the environment variables that override individual route paths.
// Plain compile-time string constants, so `const` rather than `static`:
// nothing here needs a fixed memory address.

/// Environment variable to set the metrics endpoint path (default: `/metrics`)
const HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Environment variable to set the models endpoint path (default: `/v1/models`)
const HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Environment variable to set the health endpoint path (default: `/health`)
const HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Environment variable to set the live endpoint path (default: `/live`)
const HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`)
const HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Environment variable to set the completions endpoint path (default: `/v1/completions`)
const HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Environment variable to set the embeddings endpoint path (default: `/v1/embeddings`)
const HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Environment variable to set the responses endpoint path (default: `/v1/responses`)
const HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";
310
311impl HttpServiceConfigBuilder {
312    pub fn build(self) -> Result<HttpService, anyhow::Error> {
313        let config: HttpServiceConfig = self.build_internal()?;
314
315        let model_manager = Arc::new(ModelManager::new());
316        let etcd_client = config.etcd_client.clone();
317        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
318
319        state
320            .flags
321            .set(&EndpointType::Chat, config.enable_chat_endpoints);
322        state
323            .flags
324            .set(&EndpointType::Completion, config.enable_cmpl_endpoints);
325        state
326            .flags
327            .set(&EndpointType::Embedding, config.enable_embeddings_endpoints);
328        state
329            .flags
330            .set(&EndpointType::Responses, config.enable_responses_endpoints);
331
332        // enable prometheus metrics
333        let registry = metrics::Registry::new();
334        state.metrics_clone().register(&registry)?;
335
336        // Note: Metrics polling task will be started in run() method to have access to cancellation token
337
338        let mut router = axum::Router::new();
339
340        let mut all_docs = Vec::new();
341
342        let mut routes = vec![
343            metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()),
344            super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
345            super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
346            super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
347        ];
348
349        let endpoint_routes =
350            HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template);
351        routes.extend(endpoint_routes);
352        for (route_docs, route) in routes {
353            router = router.merge(route);
354            all_docs.extend(route_docs);
355        }
356
357        // Add span for tracing
358        router = router.layer(TraceLayer::new_for_http().make_span_with(make_request_span));
359
360        Ok(HttpService {
361            state,
362            router,
363            port: config.port,
364            host: config.host,
365            enable_tls: config.enable_tls,
366            tls_cert_path: config.tls_cert_path,
367            tls_key_path: config.tls_key_path,
368            route_docs: all_docs,
369            etcd_client,
370        })
371    }
372
373    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
374        self.request_template = Some(request_template);
375        self
376    }
377
378    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
379        self.etcd_client = Some(etcd_client);
380        self
381    }
382
383    fn get_endpoints_router(
384        state: Arc<State>,
385        request_template: &Option<RequestTemplate>,
386    ) -> Vec<(Vec<RouteDoc>, axum::Router)> {
387        let mut routes = Vec::new();
388        // Add chat completions route with conditional middleware
389        let (chat_docs, chat_route) = super::openai::chat_completions_router(
390            state.clone(),
391            request_template.clone(),
392            var(HTTP_SVC_CHAT_PATH_ENV).ok(),
393        );
394        let (cmpl_docs, cmpl_route) =
395            super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
396        let (embed_docs, embed_route) =
397            super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
398        let (responses_docs, responses_route) = super::openai::responses_router(
399            state.clone(),
400            request_template.clone(),
401            var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
402        );
403
404        let mut endpoint_routes = HashMap::new();
405        endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
406        endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
407        endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
408        endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
409
410        for endpoint_type in EndpointType::all() {
411            let state_route = state.clone();
412            if !endpoint_routes.contains_key(&endpoint_type) {
413                tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
414                continue;
415            }
416            let (docs, route) = endpoint_routes.get(&endpoint_type).cloned().unwrap();
417            let route = route.route_layer(axum::middleware::from_fn(
418                move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
419                    let state: Arc<State> = state_route.clone();
420                    async move {
421                        // Check if the endpoint is enabled
422                        let enabled = state.flags.get(&endpoint_type);
423                        if enabled {
424                            Ok(next.run(req).await)
425                        } else {
426                            tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
427                            Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
428                        }
429                    }
430                },
431            ));
432            routes.push((docs, route));
433        }
434        routes
435    }
436}