Skip to main content

dynamo_runtime/pipeline/network/ingress/
http_endpoint.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! HTTP endpoint for receiving requests via Axum/HTTP/2
5
6use super::*;
7use crate::SystemHealth;
8use crate::config::HealthStatus;
9use crate::logging::TraceParent;
10use anyhow::Result;
11use axum::{
12    Router,
13    body::Bytes,
14    extract::{Path, State as AxumState},
15    http::{HeaderMap, StatusCode},
16    response::IntoResponse,
17    routing::post,
18};
19use dashmap::DashMap;
20use hyper_util::rt::{TokioExecutor, TokioIo};
21use hyper_util::server::conn::auto::Builder as Http2Builder;
22use hyper_util::service::TowerToHyperService;
23use parking_lot::Mutex;
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tokio::sync::{Notify, RwLock};
27use tokio_util::sync::CancellationToken;
28use tower_http::trace::TraceLayer;
29use tracing::Instrument;
30
/// Default root path under which dynamo RPC endpoints are routed.
///
/// `SharedHttpServer::bind_and_start` uses this value when the
/// `DYN_HTTP_RPC_ROOT_PATH` environment variable cannot be read.
const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";

/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
36
/// Shared HTTP server that handles multiple endpoints on a single port
pub struct SharedHttpServer {
    /// Registered endpoint handlers, keyed by subject. Incoming requests are
    /// matched against this map using the path segment captured after the RPC
    /// root.
    handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
    /// Address requested at construction time (may use port 0).
    bind_addr: SocketAddr,
    /// Address the listener was actually bound to; `None` until
    /// `bind_and_start` has stored it.
    actual_addr: RwLock<Option<SocketAddr>>,
    /// Cancelling this token stops the accept loop and the per-connection tasks.
    cancellation_token: CancellationToken,
}
44
/// Handler for a specific endpoint
struct EndpointHandler {
    /// Receives the raw request body via `handle_payload`.
    service_handler: Arc<dyn PushWorkHandler>,
    /// Instance identifier, recorded on the per-request tracing span.
    instance_id: u64,
    /// Namespace, recorded on the per-request tracing span.
    namespace: Arc<String>,
    /// Component name, recorded on the per-request tracing span.
    component_name: Arc<String>,
    /// Endpoint name, recorded on the per-request tracing span.
    endpoint_name: Arc<String>,
    /// Health registry updated to `NotReady` when the endpoint is unregistered.
    system_health: Arc<Mutex<SystemHealth>>,
    /// Number of requests that have been accepted but have not finished yet.
    inflight: Arc<AtomicU64>,
    /// Signalled each time a request finishes, so that `unregister_endpoint`
    /// can re-check `inflight`.
    notify: Arc<Notify>,
}
56
57impl SharedHttpServer {
58    pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
59        Arc::new(Self {
60            handlers: Arc::new(DashMap::new()),
61            bind_addr,
62            actual_addr: RwLock::new(None),
63            cancellation_token,
64        })
65    }
66
67    /// Get the actual bound address (after `bind_and_start` resolves).
68    pub fn actual_address(&self) -> Option<SocketAddr> {
69        self.actual_addr.try_read().ok().and_then(|g| *g)
70    }
71
72    /// Register an endpoint handler with this server
73    #[allow(clippy::too_many_arguments)]
74    pub async fn register_endpoint(
75        &self,
76        subject: String,
77        service_handler: Arc<dyn PushWorkHandler>,
78        instance_id: u64,
79        namespace: String,
80        component_name: String,
81        endpoint_name: String,
82        system_health: Arc<Mutex<SystemHealth>>,
83    ) -> Result<()> {
84        let handler = Arc::new(EndpointHandler {
85            service_handler,
86            instance_id,
87            namespace: Arc::new(namespace),
88            component_name: Arc::new(component_name),
89            endpoint_name: Arc::new(endpoint_name.clone()),
90            system_health: system_health.clone(),
91            inflight: Arc::new(AtomicU64::new(0)),
92            notify: Arc::new(Notify::new()),
93        });
94
95        // Insert handler FIRST to ensure it's ready to receive requests
96        let subject_clone = subject.clone();
97        self.handlers.insert(subject, handler);
98
99        // THEN set health status to Ready (after handler is registered and ready)
100        system_health
101            .lock()
102            .set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
103
104        tracing::debug!("Registered endpoint handler for subject: {subject_clone}");
105        Ok(())
106    }
107
108    /// Unregister an endpoint handler
109    pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
110        if let Some((_, handler)) = self.handlers.remove(subject) {
111            handler
112                .system_health
113                .lock()
114                .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
115            tracing::debug!(
116                endpoint_name = %endpoint_name,
117                subject = %subject,
118                "Unregistered HTTP endpoint handler"
119            );
120
121            let inflight_count = handler.inflight.load(Ordering::SeqCst);
122            if inflight_count > 0 {
123                tracing::info!(
124                    endpoint_name = %endpoint_name,
125                    inflight_count = inflight_count,
126                    "Waiting for inflight HTTP requests to complete"
127                );
128                while handler.inflight.load(Ordering::SeqCst) > 0 {
129                    handler.notify.notified().await;
130                }
131                tracing::info!(
132                    endpoint_name = %endpoint_name,
133                    "All inflight HTTP requests completed"
134                );
135            }
136        }
137    }
138
139    /// Bind the TCP listener and start the accept loop.
140    ///
141    /// Returns the actual bound `SocketAddr` (important when binding to port 0).
142    pub async fn bind_and_start(self: Arc<Self>) -> Result<SocketAddr> {
143        let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
144            .unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
145        let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
146
147        let app = Router::new()
148            .route(&route_pattern, post(handle_shared_request))
149            .layer(TraceLayer::new_for_http())
150            .with_state(self.clone());
151
152        let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
153        let actual_addr = listener.local_addr()?;
154
155        tracing::info!(
156            requested = %self.bind_addr,
157            actual = %actual_addr,
158            rpc_root = %rpc_root_path,
159            "HTTP/2 endpoint server bound"
160        );
161
162        // Store the actual address so `address()` returns the real port.
163        *self.actual_addr.write().await = Some(actual_addr);
164
165        let cancellation_token = self.cancellation_token.clone();
166
167        // Spawn the accept loop in the background.
168        tokio::spawn(async move {
169            loop {
170                tokio::select! {
171                    accept_result = listener.accept() => {
172                        match accept_result {
173                            Ok((stream, _addr)) => {
174                                let app_clone = app.clone();
175                                let cancel_clone = cancellation_token.clone();
176
177                                tokio::spawn(async move {
178                                    let http2_builder = Http2Builder::new(TokioExecutor::new());
179
180                                    let io = TokioIo::new(stream);
181                                    let tower_service = app_clone.into_service();
182                                    let hyper_service = TowerToHyperService::new(tower_service);
183
184                                    tokio::select! {
185                                        result = http2_builder.serve_connection(io, hyper_service) => {
186                                            if let Err(e) = result {
187                                                tracing::debug!("HTTP/2 connection error: {e}");
188                                            }
189                                        }
190                                        _ = cancel_clone.cancelled() => {
191                                            tracing::trace!("Connection cancelled");
192                                        }
193                                    }
194                                });
195                            }
196                            Err(e) => {
197                                tracing::error!("Failed to accept connection: {e}");
198                            }
199                        }
200                    }
201                    _ = cancellation_token.cancelled() => {
202                        tracing::info!("SharedHttpServer received cancellation signal, shutting down");
203                        return;
204                    }
205                }
206            }
207        });
208
209        Ok(actual_addr)
210    }
211
212    /// Wait for all inflight requests across all endpoints
213    pub async fn wait_for_inflight(&self) {
214        for handler in self.handlers.iter() {
215            while handler.value().inflight.load(Ordering::SeqCst) > 0 {
216                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
217            }
218        }
219    }
220}
221
222/// HTTP handler for the shared server
223async fn handle_shared_request(
224    AxumState(server): AxumState<Arc<SharedHttpServer>>,
225    Path(endpoint_path): Path<String>,
226    headers: HeaderMap,
227    body: Bytes,
228) -> impl IntoResponse {
229    // Look up the handler for this endpoint (lock-free read with DashMap)
230    let handler = match server.handlers.get(&endpoint_path) {
231        Some(h) => h.clone(),
232        None => {
233            tracing::warn!("No handler found for endpoint: {endpoint_path}");
234            return (StatusCode::NOT_FOUND, "Endpoint not found");
235        }
236    };
237
238    // Increment inflight counter
239    handler.inflight.fetch_add(1, Ordering::SeqCst);
240
241    // Extract tracing headers
242    let traceparent = TraceParent::from_axum_headers(&headers);
243
244    // Spawn async handler
245    let service_handler = handler.service_handler.clone();
246    let inflight = handler.inflight.clone();
247    let notify = handler.notify.clone();
248    let namespace = handler.namespace.clone();
249    let component_name = handler.component_name.clone();
250    let endpoint_name = handler.endpoint_name.clone();
251    let instance_id = handler.instance_id;
252
253    tokio::spawn(async move {
254        tracing::trace!(instance_id, "handling new HTTP request");
255        let result = service_handler
256            .handle_payload(body, traceparent.request_id.clone())
257            .instrument(tracing::info_span!(
258                "handle_payload",
259                component = component_name.as_ref(),
260                endpoint = endpoint_name.as_ref(),
261                namespace = namespace.as_ref(),
262                instance_id = instance_id,
263                trace_id = traceparent.trace_id,
264                parent_id = traceparent.parent_id,
265                x_request_id = traceparent.x_request_id,
266                request_id = traceparent.request_id,
267                tracestate = traceparent.tracestate
268            ))
269            .await;
270        match result {
271            Ok(_) => {
272                tracing::trace!(instance_id, "request handled successfully");
273            }
274            Err(e) => {
275                tracing::warn!("Failed to handle request: {}", e.to_string());
276            }
277        }
278
279        // Decrease inflight counter
280        inflight.fetch_sub(1, Ordering::SeqCst);
281        notify.notify_one();
282    });
283
284    // Return 202 Accepted immediately (like NATS ack)
285    (StatusCode::ACCEPTED, "")
286}
287
288/// Extension trait for TraceParent to support Axum headers
289impl TraceParent {
290    pub fn from_axum_headers(headers: &HeaderMap) -> Self {
291        let mut traceparent = TraceParent::default();
292
293        if let Some(value) = headers.get("traceparent")
294            && let Ok(s) = value.to_str()
295        {
296            traceparent.trace_id = Some(s.to_string());
297        }
298
299        if let Some(value) = headers.get("tracestate")
300            && let Ok(s) = value.to_str()
301        {
302            traceparent.tracestate = Some(s.to_string());
303        }
304
305        if let Some(value) = headers.get("x-request-id")
306            && let Ok(s) = value.to_str()
307        {
308            traceparent.x_request_id = Some(s.to_string());
309        }
310
311        // Read request-id from internal headers, with fallback to deprecated x-dynamo-request-id
312        if let Some(s) = headers
313            .get("request-id")
314            .and_then(|v| v.to_str().ok())
315            .filter(|s| uuid::Uuid::parse_str(s).is_ok())
316        {
317            traceparent.request_id = Some(s.to_string());
318        } else if let Some(s) = headers
319            .get("x-dynamo-request-id")
320            .and_then(|v| v.to_str().ok())
321            .filter(|s| uuid::Uuid::parse_str(s).is_ok())
322        {
323            traceparent.request_id = Some(s.to_string());
324        }
325
326        traceparent
327    }
328}
329
330// Implement RequestPlaneServer trait for SharedHttpServer
331#[async_trait::async_trait]
332impl super::unified_server::RequestPlaneServer for SharedHttpServer {
333    async fn register_endpoint(
334        &self,
335        endpoint_name: String,
336        service_handler: Arc<dyn PushWorkHandler>,
337        instance_id: u64,
338        namespace: String,
339        component_name: String,
340        system_health: Arc<Mutex<SystemHealth>>,
341    ) -> Result<()> {
342        // For HTTP, we use endpoint_name as both the subject (routing key) and endpoint_name
343        self.register_endpoint(
344            endpoint_name.clone(),
345            service_handler,
346            instance_id,
347            namespace,
348            component_name,
349            endpoint_name,
350            system_health,
351        )
352        .await
353    }
354
355    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
356        self.unregister_endpoint(endpoint_name, endpoint_name).await;
357        Ok(())
358    }
359
360    fn address(&self) -> String {
361        let addr = self.actual_address().unwrap_or(self.bind_addr);
362        format!("http://{}:{}", addr.ip(), addr.port())
363    }
364
365    fn transport_name(&self) -> &'static str {
366        "http"
367    }
368
369    fn is_healthy(&self) -> bool {
370        // Server is healthy if it has been created
371        // TODO: Add more sophisticated health checks (e.g., check if listener is active)
372        true
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback socket address with the given port (0 lets the OS pick one).
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port)
    }

    /// Header values are copied into the matching `TraceParent` fields, and the
    /// deprecated `x-dynamo-request-id` is used when `request-id` is absent.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let request_uuid = "550e8400-e29b-41d4-a716-446655440000";

        let mut headers = HeaderMap::new();
        headers.insert("traceparent", "test-trace-id".parse().unwrap());
        headers.insert("tracestate", "test-state".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());
        headers.insert("x-dynamo-request-id", request_uuid.parse().unwrap());

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id, Some("test-trace-id".to_string()));
        assert_eq!(parsed.tracestate, Some("test-state".to_string()));
        assert_eq!(parsed.x_request_id, Some("req-123".to_string()));
        assert_eq!(parsed.request_id, Some(request_uuid.to_string()));
    }

    /// A freshly created server has no registered handlers.
    #[test]
    fn test_shared_http_server_creation() {
        let server = SharedHttpServer::new(loopback(0), CancellationToken::new());
        assert!(server.handlers.is_empty());
    }

    /// Binding to port 0 yields a real port that is visible through both
    /// `actual_address()` and the trait's `address()`.
    #[tokio::test]
    async fn test_bind_and_start_assigns_os_port() {
        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(0), token.clone());

        let bound = server.clone().bind_and_start().await.unwrap();

        // OS should assign a non-zero port
        assert_ne!(bound.port(), 0);

        // actual_address() should return the real bound address
        assert_eq!(server.actual_address(), Some(bound));

        // address() should contain the real port
        let url =
            <SharedHttpServer as super::unified_server::RequestPlaneServer>::address(&*server);
        assert!(url.contains(&bound.port().to_string()));

        token.cancel();
    }

    /// Two servers that both ask for port 0 end up on distinct ports.
    #[tokio::test]
    async fn test_two_servers_get_different_ports() {
        let first_token = CancellationToken::new();
        let second_token = CancellationToken::new();

        let first = SharedHttpServer::new(loopback(0), first_token.clone());
        let second = SharedHttpServer::new(loopback(0), second_token.clone());

        let first_addr = first.clone().bind_and_start().await.unwrap();
        let second_addr = second.clone().bind_and_start().await.unwrap();

        // Two servers binding to port 0 must get different ports
        assert_ne!(first_addr.port(), second_addr.port());

        first_token.cancel();
        second_token.cancel();
    }

    /// Binding to an explicit port reports exactly that port.
    #[tokio::test]
    async fn test_bind_and_start_with_explicit_port() {
        // First bind to port 0 to get a free port, then release it again.
        let free_port = {
            let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            probe.local_addr().unwrap().port()
        };

        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(free_port), token.clone());
        let bound = server.clone().bind_and_start().await.unwrap();

        // When binding to an explicit port, actual should match
        assert_eq!(bound.port(), free_port);

        token.cancel();
    }
}