Skip to main content

dynamo_runtime/pipeline/network/ingress/
http_endpoint.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! HTTP endpoint for receiving requests via Axum/HTTP/2
5
6use super::*;
7use crate::SystemHealth;
8use crate::config::HealthStatus;
9use crate::logging::TraceParent;
10use anyhow::Result;
11use axum::{
12    Router,
13    body::Bytes,
14    extract::{Path, State as AxumState},
15    http::{HeaderMap, StatusCode},
16    response::IntoResponse,
17    routing::post,
18};
19use dashmap::DashMap;
20use hyper_util::rt::{TokioExecutor, TokioIo};
21use hyper_util::server::conn::auto::Builder as Http2Builder;
22use hyper_util::service::TowerToHyperService;
23use parking_lot::Mutex;
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tokio::sync::Notify;
27use tokio_util::sync::CancellationToken;
28use tower_http::trace::TraceLayer;
29use tracing::Instrument;
30
/// Default root path for dynamo RPC endpoints.
///
/// `SharedHttpServer::start` falls back to this value when the
/// `DYN_HTTP_RPC_ROOT_PATH` environment variable cannot be read.
const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";

/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
36
/// Shared HTTP server that handles multiple endpoints on a single port.
///
/// Endpoints are added and removed at runtime with `register_endpoint` and
/// `unregister_endpoint`; `start` runs the accept loop.
pub struct SharedHttpServer {
    /// Registered handlers, keyed by subject. Incoming requests are matched
    /// against this key using the path captured below the RPC root.
    handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
    /// Address the TCP listener is bound to when `start` runs.
    bind_addr: SocketAddr,
    /// Cancelling this token stops the accept loop and the per-connection tasks.
    cancellation_token: CancellationToken,
}
43
/// Handler for a specific endpoint.
///
/// One instance is stored per registered subject and shared (via `Arc`) with
/// every request task dispatched to that subject.
struct EndpointHandler {
    /// Receives the raw request body through `handle_payload`.
    service_handler: Arc<dyn PushWorkHandler>,
    /// Recorded on the per-request tracing span and trace logs.
    instance_id: u64,
    /// Recorded on the per-request tracing span.
    namespace: Arc<String>,
    /// Recorded on the per-request tracing span.
    component_name: Arc<String>,
    /// Recorded on the per-request tracing span.
    endpoint_name: Arc<String>,
    /// Health registry updated to `NotReady` when the endpoint is unregistered.
    system_health: Arc<Mutex<SystemHealth>>,
    /// Number of requests accepted for this endpoint that have not finished yet.
    inflight: Arc<AtomicU64>,
    /// Signalled after each request finishes; awaited while draining on unregister.
    notify: Arc<Notify>,
}
55
56impl SharedHttpServer {
57    pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
58        Arc::new(Self {
59            handlers: Arc::new(DashMap::new()),
60            bind_addr,
61            cancellation_token,
62        })
63    }
64
65    /// Register an endpoint handler with this server
66    #[allow(clippy::too_many_arguments)]
67    pub async fn register_endpoint(
68        &self,
69        subject: String,
70        service_handler: Arc<dyn PushWorkHandler>,
71        instance_id: u64,
72        namespace: String,
73        component_name: String,
74        endpoint_name: String,
75        system_health: Arc<Mutex<SystemHealth>>,
76    ) -> Result<()> {
77        let handler = Arc::new(EndpointHandler {
78            service_handler,
79            instance_id,
80            namespace: Arc::new(namespace),
81            component_name: Arc::new(component_name),
82            endpoint_name: Arc::new(endpoint_name.clone()),
83            system_health: system_health.clone(),
84            inflight: Arc::new(AtomicU64::new(0)),
85            notify: Arc::new(Notify::new()),
86        });
87
88        // Insert handler FIRST to ensure it's ready to receive requests
89        let subject_clone = subject.clone();
90        self.handlers.insert(subject, handler);
91
92        // THEN set health status to Ready (after handler is registered and ready)
93        system_health
94            .lock()
95            .set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
96
97        tracing::debug!("Registered endpoint handler for subject: {}", subject_clone);
98        Ok(())
99    }
100
101    /// Unregister an endpoint handler
102    pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
103        if let Some((_, handler)) = self.handlers.remove(subject) {
104            handler
105                .system_health
106                .lock()
107                .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
108            tracing::debug!(
109                endpoint_name = %endpoint_name,
110                subject = %subject,
111                "Unregistered HTTP endpoint handler"
112            );
113
114            let inflight_count = handler.inflight.load(Ordering::SeqCst);
115            if inflight_count > 0 {
116                tracing::info!(
117                    endpoint_name = %endpoint_name,
118                    inflight_count = inflight_count,
119                    "Waiting for inflight HTTP requests to complete"
120                );
121                while handler.inflight.load(Ordering::SeqCst) > 0 {
122                    handler.notify.notified().await;
123                }
124                tracing::info!(
125                    endpoint_name = %endpoint_name,
126                    "All inflight HTTP requests completed"
127                );
128            }
129        }
130    }
131
132    /// Start the shared HTTP server
133    pub async fn start(self: Arc<Self>) -> Result<()> {
134        let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
135            .unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
136        let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
137
138        let app = Router::new()
139            .route(&route_pattern, post(handle_shared_request))
140            .layer(TraceLayer::new_for_http())
141            .with_state(self.clone());
142
143        tracing::info!(
144            "Starting shared HTTP/2 endpoint server on {} at path {}/:endpoint",
145            self.bind_addr,
146            rpc_root_path
147        );
148
149        let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
150        let cancellation_token = self.cancellation_token.clone();
151
152        loop {
153            tokio::select! {
154                accept_result = listener.accept() => {
155                    match accept_result {
156                        Ok((stream, _addr)) => {
157                            let app_clone = app.clone();
158                            let cancel_clone = cancellation_token.clone();
159
160                            tokio::spawn(async move {
161                                // Create HTTP/2 connection builder with prior knowledge
162                                let http2_builder = Http2Builder::new(TokioExecutor::new());
163
164                                let io = TokioIo::new(stream);
165                                let tower_service = app_clone.into_service();
166
167                                // Wrap Tower service for Hyper compatibility
168                                let hyper_service = TowerToHyperService::new(tower_service);
169
170                                tokio::select! {
171                                    result = http2_builder.serve_connection(io, hyper_service) => {
172                                        if let Err(e) = result {
173                                            tracing::debug!("HTTP/2 connection error: {}", e);
174                                        }
175                                    }
176                                    _ = cancel_clone.cancelled() => {
177                                        tracing::trace!("Connection cancelled");
178                                    }
179                                }
180                            });
181                        }
182                        Err(e) => {
183                            tracing::error!("Failed to accept connection: {}", e);
184                        }
185                    }
186                }
187                _ = cancellation_token.cancelled() => {
188                    tracing::info!("SharedHttpServer received cancellation signal, shutting down");
189                    return Ok(());
190                }
191            }
192        }
193    }
194
195    /// Wait for all inflight requests across all endpoints
196    pub async fn wait_for_inflight(&self) {
197        for handler in self.handlers.iter() {
198            while handler.value().inflight.load(Ordering::SeqCst) > 0 {
199                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
200            }
201        }
202    }
203}
204
205/// HTTP handler for the shared server
206async fn handle_shared_request(
207    AxumState(server): AxumState<Arc<SharedHttpServer>>,
208    Path(endpoint_path): Path<String>,
209    headers: HeaderMap,
210    body: Bytes,
211) -> impl IntoResponse {
212    // Look up the handler for this endpoint (lock-free read with DashMap)
213    let handler = match server.handlers.get(&endpoint_path) {
214        Some(h) => h.clone(),
215        None => {
216            tracing::warn!("No handler found for endpoint: {}", endpoint_path);
217            return (StatusCode::NOT_FOUND, "Endpoint not found");
218        }
219    };
220
221    // Increment inflight counter
222    handler.inflight.fetch_add(1, Ordering::SeqCst);
223
224    // Extract tracing headers
225    let traceparent = TraceParent::from_axum_headers(&headers);
226
227    // Spawn async handler
228    let service_handler = handler.service_handler.clone();
229    let inflight = handler.inflight.clone();
230    let notify = handler.notify.clone();
231    let namespace = handler.namespace.clone();
232    let component_name = handler.component_name.clone();
233    let endpoint_name = handler.endpoint_name.clone();
234    let instance_id = handler.instance_id;
235
236    tokio::spawn(async move {
237        tracing::trace!(instance_id, "handling new HTTP request");
238        let result = service_handler
239            .handle_payload(body)
240            .instrument(tracing::info_span!(
241                "handle_payload",
242                component = component_name.as_ref(),
243                endpoint = endpoint_name.as_ref(),
244                namespace = namespace.as_ref(),
245                instance_id = instance_id,
246                trace_id = traceparent.trace_id,
247                parent_id = traceparent.parent_id,
248                x_request_id = traceparent.x_request_id,
249                x_dynamo_request_id = traceparent.x_dynamo_request_id,
250                tracestate = traceparent.tracestate
251            ))
252            .await;
253        match result {
254            Ok(_) => {
255                tracing::trace!(instance_id, "request handled successfully");
256            }
257            Err(e) => {
258                tracing::warn!("Failed to handle request: {}", e.to_string());
259            }
260        }
261
262        // Decrease inflight counter
263        inflight.fetch_sub(1, Ordering::SeqCst);
264        notify.notify_one();
265    });
266
267    // Return 202 Accepted immediately (like NATS ack)
268    (StatusCode::ACCEPTED, "")
269}
270
271/// Extension trait for TraceParent to support Axum headers
272impl TraceParent {
273    pub fn from_axum_headers(headers: &HeaderMap) -> Self {
274        let mut traceparent = TraceParent::default();
275
276        if let Some(value) = headers.get("traceparent")
277            && let Ok(s) = value.to_str()
278        {
279            traceparent.trace_id = Some(s.to_string());
280        }
281
282        if let Some(value) = headers.get("tracestate")
283            && let Ok(s) = value.to_str()
284        {
285            traceparent.tracestate = Some(s.to_string());
286        }
287
288        if let Some(value) = headers.get("x-request-id")
289            && let Ok(s) = value.to_str()
290        {
291            traceparent.x_request_id = Some(s.to_string());
292        }
293
294        if let Some(value) = headers.get("x-dynamo-request-id")
295            && let Ok(s) = value.to_str()
296        {
297            traceparent.x_dynamo_request_id = Some(s.to_string());
298        }
299
300        traceparent
301    }
302}
303
304// Implement RequestPlaneServer trait for SharedHttpServer
305#[async_trait::async_trait]
306impl super::unified_server::RequestPlaneServer for SharedHttpServer {
307    async fn register_endpoint(
308        &self,
309        endpoint_name: String,
310        service_handler: Arc<dyn PushWorkHandler>,
311        instance_id: u64,
312        namespace: String,
313        component_name: String,
314        system_health: Arc<Mutex<SystemHealth>>,
315    ) -> Result<()> {
316        // For HTTP, we use endpoint_name as both the subject (routing key) and endpoint_name
317        self.register_endpoint(
318            endpoint_name.clone(),
319            service_handler,
320            instance_id,
321            namespace,
322            component_name,
323            endpoint_name,
324            system_health,
325        )
326        .await
327    }
328
329    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
330        self.unregister_endpoint(endpoint_name, endpoint_name).await;
331        Ok(())
332    }
333
334    fn address(&self) -> String {
335        format!("http://{}:{}", self.bind_addr.ip(), self.bind_addr.port())
336    }
337
338    fn transport_name(&self) -> &'static str {
339        "http"
340    }
341
342    fn is_healthy(&self) -> bool {
343        // Server is healthy if it has been created
344        // TODO: Add more sophisticated health checks (e.g., check if listener is active)
345        true
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Each supported header ends up in its matching `TraceParent` field.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let header_pairs = [
            ("traceparent", "test-trace-id"),
            ("tracestate", "test-state"),
            ("x-request-id", "req-123"),
            ("x-dynamo-request-id", "dyn-456"),
        ];

        let mut headers = HeaderMap::new();
        for (name, value) in header_pairs {
            headers.insert(name, value.parse().unwrap());
        }

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id.as_deref(), Some("test-trace-id"));
        assert_eq!(parsed.tracestate.as_deref(), Some("test-state"));
        assert_eq!(parsed.x_request_id.as_deref(), Some("req-123"));
        assert_eq!(parsed.x_dynamo_request_id.as_deref(), Some("dyn-456"));
    }

    /// A freshly constructed server has no registered handlers.
    #[test]
    fn test_shared_http_server_creation() {
        let loopback_any_port = SocketAddr::from(([127, 0, 0, 1], 0));

        let server = SharedHttpServer::new(loopback_any_port, CancellationToken::new());

        assert!(server.handlers.is_empty());
    }
}