1use std::collections::HashMap;
5use std::env::var;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::Ordering;
10use std::time::Duration;
11
12use super::Metrics;
13use super::RouteDoc;
14use super::metrics;
15use crate::discovery::ModelManager;
16use crate::endpoint_type::EndpointType;
17use crate::request_template::RequestTemplate;
18use anyhow::Result;
19use axum_server::tls_rustls::RustlsConfig;
20use derive_builder::Builder;
21use dynamo_runtime::logging::make_request_span;
22use dynamo_runtime::transports::etcd;
23use std::net::SocketAddr;
24use tokio::task::JoinHandle;
25use tokio_util::sync::CancellationToken;
26use tower_http::trace::TraceLayer;
27
/// State shared by the HTTP service and its route handlers.
///
/// Bundles the metrics handle, the model manager, an optional etcd client and
/// the per-endpoint enable flags.
#[derive(Default)]
pub struct State {
    /// Metrics handle; a new `Arc` reference is handed out by `metrics_clone`.
    metrics: Arc<Metrics>,
    /// Model manager; exposed by `manager` (borrow) and `manager_clone` (new `Arc` reference).
    manager: Arc<ModelManager>,
    /// etcd client, if one was supplied at construction; `None` otherwise.
    etcd_client: Option<etcd::Client>,
    /// Atomic on/off switches, one per endpoint family.
    flags: StateFlags,
}
36
37#[derive(Default, Debug)]
38struct StateFlags {
39 chat_endpoints_enabled: AtomicBool,
40 cmpl_endpoints_enabled: AtomicBool,
41 embeddings_endpoints_enabled: AtomicBool,
42 responses_endpoints_enabled: AtomicBool,
43}
44
45impl StateFlags {
46 pub fn get(&self, endpoint_type: &EndpointType) -> bool {
47 match endpoint_type {
48 EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
49 EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
50 EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
51 EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
52 }
53 }
54
55 pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
56 match endpoint_type {
57 EndpointType::Chat => self
58 .chat_endpoints_enabled
59 .store(enabled, Ordering::Relaxed),
60 EndpointType::Completion => self
61 .cmpl_endpoints_enabled
62 .store(enabled, Ordering::Relaxed),
63 EndpointType::Embedding => self
64 .embeddings_endpoints_enabled
65 .store(enabled, Ordering::Relaxed),
66 EndpointType::Responses => self
67 .responses_endpoints_enabled
68 .store(enabled, Ordering::Relaxed),
69 }
70 }
71}
72
73impl State {
74 pub fn new(manager: Arc<ModelManager>) -> Self {
75 Self {
76 manager,
77 metrics: Arc::new(Metrics::default()),
78 etcd_client: None,
79 flags: StateFlags {
80 chat_endpoints_enabled: AtomicBool::new(false),
81 cmpl_endpoints_enabled: AtomicBool::new(false),
82 embeddings_endpoints_enabled: AtomicBool::new(false),
83 responses_endpoints_enabled: AtomicBool::new(false),
84 },
85 }
86 }
87
88 pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
89 Self {
90 manager,
91 metrics: Arc::new(Metrics::default()),
92 etcd_client,
93 flags: StateFlags {
94 chat_endpoints_enabled: AtomicBool::new(false),
95 cmpl_endpoints_enabled: AtomicBool::new(false),
96 embeddings_endpoints_enabled: AtomicBool::new(false),
97 responses_endpoints_enabled: AtomicBool::new(false),
98 },
99 }
100 }
101 pub fn metrics_clone(&self) -> Arc<Metrics> {
103 self.metrics.clone()
104 }
105
106 pub fn manager(&self) -> &ModelManager {
107 Arc::as_ref(&self.manager)
108 }
109
110 pub fn manager_clone(&self) -> Arc<ModelManager> {
111 self.manager.clone()
112 }
113
114 pub fn etcd_client(&self) -> Option<&etcd::Client> {
115 self.etcd_client.as_ref()
116 }
117
118 pub fn sse_keep_alive(&self) -> Option<Duration> {
120 None
121 }
122}
123
/// A configured HTTP(S) service, ready to be served with `run` or `spawn`.
///
/// Instances are produced by `HttpServiceConfigBuilder::build`.
#[derive(Clone)]
pub struct HttpService {
    /// Shared state, exposed through `state` and `state_clone`.
    state: Arc<State>,

    /// Router served by `run`.
    router: axum::Router,
    /// Port part of the listen address (`"{host}:{port}"`).
    port: u16,
    /// Host part of the listen address; `run` parses `"{host}:{port}"` as a `SocketAddr`.
    host: String,
    /// When `true`, `run` serves over TLS (rustls) instead of plain TCP.
    enable_tls: bool,
    /// PEM certificate path; `run` returns an error if TLS is enabled and this is `None`.
    tls_cert_path: Option<PathBuf>,
    /// PEM private-key path; `run` returns an error if TLS is enabled and this is `None`.
    tls_key_path: Option<PathBuf>,
    /// Route documentation collected from every router merged at build time.
    route_docs: Vec<RouteDoc>,

    /// Optional etcd client handed to the runtime-config polling task started in `run`.
    etcd_client: Option<dynamo_runtime::transports::etcd::Client>,
}
140
/// Configuration consumed by `HttpServiceConfigBuilder::build` to assemble an
/// [`HttpService`].
///
/// The derive generates an owned-pattern `HttpServiceConfigBuilder`; its
/// generated build function is private and named `build_internal`, so callers
/// go through the hand-written `build` instead.
#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct HttpServiceConfig {
    /// Port to listen on. Builder default: `8787`.
    #[builder(default = "8787")]
    port: u16,

    /// Host to listen on. Builder default: `"0.0.0.0"`.
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

    /// Serve over TLS instead of plain TCP. Builder default: `false`.
    #[builder(default = "false")]
    enable_tls: bool,

    /// Path to the PEM certificate used when TLS is enabled. Builder default: `None`.
    #[builder(default = "None")]
    tls_cert_path: Option<PathBuf>,

    /// Path to the PEM private key used when TLS is enabled. Builder default: `None`.
    #[builder(default = "None")]
    tls_key_path: Option<PathBuf>,

    /// Initial value of the chat endpoint flag. Builder default: `false`.
    #[builder(default = "false")]
    enable_chat_endpoints: bool,

    /// Initial value of the completion endpoint flag. Builder default: `false`.
    #[builder(default = "false")]
    enable_cmpl_endpoints: bool,

    /// Initial value of the embeddings endpoint flag. Builder default: `true`.
    #[builder(default = "true")]
    enable_embeddings_endpoints: bool,

    /// Initial value of the responses endpoint flag. Builder default: `true`.
    #[builder(default = "true")]
    enable_responses_endpoints: bool,

    /// Optional request template forwarded to the chat-completions and
    /// responses routers. Builder default: `None`.
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,

    /// Optional etcd client stored in both the shared `State` and the
    /// resulting `HttpService`. Builder default: `None`.
    #[builder(default = "None")]
    etcd_client: Option<etcd::Client>,
}
179
180impl HttpService {
181 pub fn builder() -> HttpServiceConfigBuilder {
182 HttpServiceConfigBuilder::default()
183 }
184
185 pub fn state_clone(&self) -> Arc<State> {
186 self.state.clone()
187 }
188
189 pub fn state(&self) -> &State {
190 Arc::as_ref(&self.state)
191 }
192
193 pub fn model_manager(&self) -> &ModelManager {
194 self.state().manager()
195 }
196
197 pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
198 let this = self.clone();
199 tokio::spawn(async move { this.run(cancel_token).await })
200 }
201
202 pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
203 let address = format!("{}:{}", self.host, self.port);
204 let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" };
205 tracing::info!(protocol, address, "Starting HTTP(S) service");
206
207 let poll_interval_secs = std::env::var("DYN_HTTP_SVC_CONFIG_METRICS_POLL_INTERVAL_SECS")
209 .ok()
210 .and_then(|s| s.parse::<f64>().ok())
211 .filter(|&secs| secs > 0.0) .unwrap_or(8.0);
213 let poll_interval = Duration::from_secs_f64(poll_interval_secs);
214
215 let _polling_task = super::metrics::Metrics::start_runtime_config_polling_task(
216 self.state.metrics_clone(),
217 self.state.manager_clone(),
218 self.etcd_client.clone(),
219 poll_interval,
220 cancel_token.child_token(),
221 );
222
223 let router = self.router.clone();
224 let observer = cancel_token.child_token();
225
226 let addr: SocketAddr = address
227 .parse()
228 .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?;
229
230 if self.enable_tls {
231 let cert_path = self
232 .tls_cert_path
233 .as_ref()
234 .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?;
235 let key_path = self
236 .tls_key_path
237 .as_ref()
238 .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?;
239
240 if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() {
243 tracing::debug!("TLS crypto provider already installed: {e:?}");
244 }
245
246 let config = RustlsConfig::from_pem_file(cert_path, key_path)
247 .await
248 .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?;
249
250 let handle = axum_server::Handle::new();
251 let server = axum_server::bind_rustls(addr, config)
252 .handle(handle.clone())
253 .serve(router.into_make_service());
254
255 tokio::select! {
256 result = server => {
257 result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e))?;
258 }
259 _ = observer.cancelled() => {
260 tracing::info!("HTTPS server shutdown requested");
261 handle.graceful_shutdown(Some(Duration::from_secs(5)));
262 }
264 }
265 } else {
266 let listener = tokio::net::TcpListener::bind(addr)
267 .await
268 .unwrap_or_else(|_| panic!("could not bind to address: {address}"));
269
270 axum::serve(listener, router)
271 .with_graceful_shutdown(observer.cancelled_owned())
272 .await
273 .inspect_err(|_| cancel_token.cancel())?;
274 }
275
276 Ok(())
277 }
278
279 pub fn route_docs(&self) -> &[RouteDoc] {
281 &self.route_docs
282 }
283
284 pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
285 self.state.flags.set(&endpoint_type, enable);
286 tracing::info!(
287 "{} endpoints {}",
288 endpoint_type.as_str(),
289 if enable { "enabled" } else { "disabled" }
290 );
291 }
292}
293
// Names of the environment variables read while the service is built. Each
// value, when set, is forwarded to the matching router constructor
// (presumably as a custom route path).

/// Forwarded to the metrics router.
const HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Forwarded to the model-listing router.
const HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Forwarded to the health-check router.
const HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Forwarded to the liveness-check router.
const HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Forwarded to the chat-completions router.
const HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Forwarded to the completions router.
const HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Forwarded to the embeddings router.
const HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Forwarded to the responses router.
const HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";
310
311impl HttpServiceConfigBuilder {
312 pub fn build(self) -> Result<HttpService, anyhow::Error> {
313 let config: HttpServiceConfig = self.build_internal()?;
314
315 let model_manager = Arc::new(ModelManager::new());
316 let etcd_client = config.etcd_client.clone();
317 let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
318
319 state
320 .flags
321 .set(&EndpointType::Chat, config.enable_chat_endpoints);
322 state
323 .flags
324 .set(&EndpointType::Completion, config.enable_cmpl_endpoints);
325 state
326 .flags
327 .set(&EndpointType::Embedding, config.enable_embeddings_endpoints);
328 state
329 .flags
330 .set(&EndpointType::Responses, config.enable_responses_endpoints);
331
332 let registry = metrics::Registry::new();
334 state.metrics_clone().register(®istry)?;
335
336 let mut router = axum::Router::new();
339
340 let mut all_docs = Vec::new();
341
342 let mut routes = vec![
343 metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()),
344 super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
345 super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
346 super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
347 ];
348
349 let endpoint_routes =
350 HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template);
351 routes.extend(endpoint_routes);
352 for (route_docs, route) in routes {
353 router = router.merge(route);
354 all_docs.extend(route_docs);
355 }
356
357 router = router.layer(TraceLayer::new_for_http().make_span_with(make_request_span));
359
360 Ok(HttpService {
361 state,
362 router,
363 port: config.port,
364 host: config.host,
365 enable_tls: config.enable_tls,
366 tls_cert_path: config.tls_cert_path,
367 tls_key_path: config.tls_key_path,
368 route_docs: all_docs,
369 etcd_client,
370 })
371 }
372
373 pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
374 self.request_template = Some(request_template);
375 self
376 }
377
378 pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
379 self.etcd_client = Some(etcd_client);
380 self
381 }
382
383 fn get_endpoints_router(
384 state: Arc<State>,
385 request_template: &Option<RequestTemplate>,
386 ) -> Vec<(Vec<RouteDoc>, axum::Router)> {
387 let mut routes = Vec::new();
388 let (chat_docs, chat_route) = super::openai::chat_completions_router(
390 state.clone(),
391 request_template.clone(),
392 var(HTTP_SVC_CHAT_PATH_ENV).ok(),
393 );
394 let (cmpl_docs, cmpl_route) =
395 super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
396 let (embed_docs, embed_route) =
397 super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
398 let (responses_docs, responses_route) = super::openai::responses_router(
399 state.clone(),
400 request_template.clone(),
401 var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
402 );
403
404 let mut endpoint_routes = HashMap::new();
405 endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
406 endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
407 endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
408 endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
409
410 for endpoint_type in EndpointType::all() {
411 let state_route = state.clone();
412 if !endpoint_routes.contains_key(&endpoint_type) {
413 tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
414 continue;
415 }
416 let (docs, route) = endpoint_routes.get(&endpoint_type).cloned().unwrap();
417 let route = route.route_layer(axum::middleware::from_fn(
418 move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
419 let state: Arc<State> = state_route.clone();
420 async move {
421 let enabled = state.flags.get(&endpoint_type);
423 if enabled {
424 Ok(next.run(req).await)
425 } else {
426 tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
427 Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
428 }
429 }
430 },
431 ));
432 routes.push((docs, route));
433 }
434 routes
435 }
436}