dynamo_runtime/pipeline/network/ingress/
http_endpoint.rs1use super::*;
7use crate::SystemHealth;
8use crate::config::HealthStatus;
9use crate::logging::TraceParent;
10use anyhow::Result;
11use axum::{
12 Router,
13 body::Bytes,
14 extract::{Path, State as AxumState},
15 http::{HeaderMap, StatusCode},
16 response::IntoResponse,
17 routing::post,
18};
19use dashmap::DashMap;
20use hyper_util::rt::{TokioExecutor, TokioIo};
21use hyper_util::server::conn::auto::Builder as Http2Builder;
22use hyper_util::service::TowerToHyperService;
23use parking_lot::Mutex;
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tokio::sync::{Notify, RwLock};
27use tokio_util::sync::CancellationToken;
28use tower_http::trace::TraceLayer;
29use tracing::Instrument;
30
/// Root path under which RPC endpoints are routed when the
/// `DYN_HTTP_RPC_ROOT_PATH` environment variable is not set.
const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";

/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
36
/// HTTP server shared by multiple endpoints.
///
/// A single listener serves every registered endpoint; incoming requests are
/// dispatched to the matching handler by looking up the request path segment
/// captured after the RPC root in `handlers`.
pub struct SharedHttpServer {
    /// Registered endpoint handlers, keyed by subject.
    handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
    /// Address the server was asked to bind to (the port may be 0).
    bind_addr: SocketAddr,
    /// Address reported by the listener after binding; `None` until
    /// `bind_and_start` has bound the socket.
    actual_addr: RwLock<Option<SocketAddr>>,
    /// Cancelling this token stops the accept loop and open connections.
    cancellation_token: CancellationToken,
}
44
/// Per-endpoint state stored in the server's handler map.
struct EndpointHandler {
    /// Handler that receives the request payload via `handle_payload`.
    service_handler: Arc<dyn PushWorkHandler>,
    /// Instance identifier; recorded on the tracing span of each request.
    instance_id: u64,
    /// Namespace recorded on the tracing span of each request.
    namespace: Arc<String>,
    /// Component name recorded on the tracing span of each request.
    component_name: Arc<String>,
    /// Endpoint name recorded on the tracing span of each request.
    endpoint_name: Arc<String>,
    /// Health registry updated when the endpoint is unregistered.
    system_health: Arc<Mutex<SystemHealth>>,
    /// Number of requests that have been accepted but have not finished yet.
    inflight: Arc<AtomicU64>,
    /// Signalled each time a request finishes, so that unregistration can
    /// wait for `inflight` to drain.
    notify: Arc<Notify>,
}
56
57impl SharedHttpServer {
58 pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
59 Arc::new(Self {
60 handlers: Arc::new(DashMap::new()),
61 bind_addr,
62 actual_addr: RwLock::new(None),
63 cancellation_token,
64 })
65 }
66
67 pub fn actual_address(&self) -> Option<SocketAddr> {
69 self.actual_addr.try_read().ok().and_then(|g| *g)
70 }
71
72 #[allow(clippy::too_many_arguments)]
74 pub async fn register_endpoint(
75 &self,
76 subject: String,
77 service_handler: Arc<dyn PushWorkHandler>,
78 instance_id: u64,
79 namespace: String,
80 component_name: String,
81 endpoint_name: String,
82 system_health: Arc<Mutex<SystemHealth>>,
83 ) -> Result<()> {
84 let handler = Arc::new(EndpointHandler {
85 service_handler,
86 instance_id,
87 namespace: Arc::new(namespace),
88 component_name: Arc::new(component_name),
89 endpoint_name: Arc::new(endpoint_name.clone()),
90 system_health: system_health.clone(),
91 inflight: Arc::new(AtomicU64::new(0)),
92 notify: Arc::new(Notify::new()),
93 });
94
95 let subject_clone = subject.clone();
97 self.handlers.insert(subject, handler);
98
99 system_health
101 .lock()
102 .set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
103
104 tracing::debug!("Registered endpoint handler for subject: {subject_clone}");
105 Ok(())
106 }
107
108 pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
110 if let Some((_, handler)) = self.handlers.remove(subject) {
111 handler
112 .system_health
113 .lock()
114 .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
115 tracing::debug!(
116 endpoint_name = %endpoint_name,
117 subject = %subject,
118 "Unregistered HTTP endpoint handler"
119 );
120
121 let inflight_count = handler.inflight.load(Ordering::SeqCst);
122 if inflight_count > 0 {
123 tracing::info!(
124 endpoint_name = %endpoint_name,
125 inflight_count = inflight_count,
126 "Waiting for inflight HTTP requests to complete"
127 );
128 while handler.inflight.load(Ordering::SeqCst) > 0 {
129 handler.notify.notified().await;
130 }
131 tracing::info!(
132 endpoint_name = %endpoint_name,
133 "All inflight HTTP requests completed"
134 );
135 }
136 }
137 }
138
139 pub async fn bind_and_start(self: Arc<Self>) -> Result<SocketAddr> {
143 let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
144 .unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
145 let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
146
147 let app = Router::new()
148 .route(&route_pattern, post(handle_shared_request))
149 .layer(TraceLayer::new_for_http())
150 .with_state(self.clone());
151
152 let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
153 let actual_addr = listener.local_addr()?;
154
155 tracing::info!(
156 requested = %self.bind_addr,
157 actual = %actual_addr,
158 rpc_root = %rpc_root_path,
159 "HTTP/2 endpoint server bound"
160 );
161
162 *self.actual_addr.write().await = Some(actual_addr);
164
165 let cancellation_token = self.cancellation_token.clone();
166
167 tokio::spawn(async move {
169 loop {
170 tokio::select! {
171 accept_result = listener.accept() => {
172 match accept_result {
173 Ok((stream, _addr)) => {
174 let app_clone = app.clone();
175 let cancel_clone = cancellation_token.clone();
176
177 tokio::spawn(async move {
178 let http2_builder = Http2Builder::new(TokioExecutor::new());
179
180 let io = TokioIo::new(stream);
181 let tower_service = app_clone.into_service();
182 let hyper_service = TowerToHyperService::new(tower_service);
183
184 tokio::select! {
185 result = http2_builder.serve_connection(io, hyper_service) => {
186 if let Err(e) = result {
187 tracing::debug!("HTTP/2 connection error: {e}");
188 }
189 }
190 _ = cancel_clone.cancelled() => {
191 tracing::trace!("Connection cancelled");
192 }
193 }
194 });
195 }
196 Err(e) => {
197 tracing::error!("Failed to accept connection: {e}");
198 }
199 }
200 }
201 _ = cancellation_token.cancelled() => {
202 tracing::info!("SharedHttpServer received cancellation signal, shutting down");
203 return;
204 }
205 }
206 }
207 });
208
209 Ok(actual_addr)
210 }
211
212 pub async fn wait_for_inflight(&self) {
214 for handler in self.handlers.iter() {
215 while handler.value().inflight.load(Ordering::SeqCst) > 0 {
216 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
217 }
218 }
219 }
220}
221
222async fn handle_shared_request(
224 AxumState(server): AxumState<Arc<SharedHttpServer>>,
225 Path(endpoint_path): Path<String>,
226 headers: HeaderMap,
227 body: Bytes,
228) -> impl IntoResponse {
229 let handler = match server.handlers.get(&endpoint_path) {
231 Some(h) => h.clone(),
232 None => {
233 tracing::warn!("No handler found for endpoint: {endpoint_path}");
234 return (StatusCode::NOT_FOUND, "Endpoint not found");
235 }
236 };
237
238 handler.inflight.fetch_add(1, Ordering::SeqCst);
240
241 let traceparent = TraceParent::from_axum_headers(&headers);
243
244 let service_handler = handler.service_handler.clone();
246 let inflight = handler.inflight.clone();
247 let notify = handler.notify.clone();
248 let namespace = handler.namespace.clone();
249 let component_name = handler.component_name.clone();
250 let endpoint_name = handler.endpoint_name.clone();
251 let instance_id = handler.instance_id;
252
253 tokio::spawn(async move {
254 tracing::trace!(instance_id, "handling new HTTP request");
255 let result = service_handler
256 .handle_payload(body, traceparent.request_id.clone())
257 .instrument(tracing::info_span!(
258 "handle_payload",
259 component = component_name.as_ref(),
260 endpoint = endpoint_name.as_ref(),
261 namespace = namespace.as_ref(),
262 instance_id = instance_id,
263 trace_id = traceparent.trace_id,
264 parent_id = traceparent.parent_id,
265 x_request_id = traceparent.x_request_id,
266 request_id = traceparent.request_id,
267 tracestate = traceparent.tracestate
268 ))
269 .await;
270 match result {
271 Ok(_) => {
272 tracing::trace!(instance_id, "request handled successfully");
273 }
274 Err(e) => {
275 tracing::warn!("Failed to handle request: {}", e.to_string());
276 }
277 }
278
279 inflight.fetch_sub(1, Ordering::SeqCst);
281 notify.notify_one();
282 });
283
284 (StatusCode::ACCEPTED, "")
286}
287
288impl TraceParent {
290 pub fn from_axum_headers(headers: &HeaderMap) -> Self {
291 let mut traceparent = TraceParent::default();
292
293 if let Some(value) = headers.get("traceparent")
294 && let Ok(s) = value.to_str()
295 {
296 traceparent.trace_id = Some(s.to_string());
297 }
298
299 if let Some(value) = headers.get("tracestate")
300 && let Ok(s) = value.to_str()
301 {
302 traceparent.tracestate = Some(s.to_string());
303 }
304
305 if let Some(value) = headers.get("x-request-id")
306 && let Ok(s) = value.to_str()
307 {
308 traceparent.x_request_id = Some(s.to_string());
309 }
310
311 if let Some(s) = headers
313 .get("request-id")
314 .and_then(|v| v.to_str().ok())
315 .filter(|s| uuid::Uuid::parse_str(s).is_ok())
316 {
317 traceparent.request_id = Some(s.to_string());
318 } else if let Some(s) = headers
319 .get("x-dynamo-request-id")
320 .and_then(|v| v.to_str().ok())
321 .filter(|s| uuid::Uuid::parse_str(s).is_ok())
322 {
323 traceparent.request_id = Some(s.to_string());
324 }
325
326 traceparent
327 }
328}
329
330#[async_trait::async_trait]
332impl super::unified_server::RequestPlaneServer for SharedHttpServer {
333 async fn register_endpoint(
334 &self,
335 endpoint_name: String,
336 service_handler: Arc<dyn PushWorkHandler>,
337 instance_id: u64,
338 namespace: String,
339 component_name: String,
340 system_health: Arc<Mutex<SystemHealth>>,
341 ) -> Result<()> {
342 self.register_endpoint(
344 endpoint_name.clone(),
345 service_handler,
346 instance_id,
347 namespace,
348 component_name,
349 endpoint_name,
350 system_health,
351 )
352 .await
353 }
354
355 async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
356 self.unregister_endpoint(endpoint_name, endpoint_name).await;
357 Ok(())
358 }
359
360 fn address(&self) -> String {
361 let addr = self.actual_address().unwrap_or(self.bind_addr);
362 format!("http://{}:{}", addr.ip(), addr.port())
363 }
364
365 fn transport_name(&self) -> &'static str {
366 "http"
367 }
368
369 fn is_healthy(&self) -> bool {
370 true
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback socket address with the given port (0 lets the OS choose).
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port)
    }

    /// Each tracing header ends up in its corresponding field.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let request_uuid = "550e8400-e29b-41d4-a716-446655440000";

        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("traceparent", "test-trace-id"),
            ("tracestate", "test-state"),
            ("x-request-id", "req-123"),
            ("x-dynamo-request-id", request_uuid),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id, Some("test-trace-id".to_string()));
        assert_eq!(parsed.tracestate, Some("test-state".to_string()));
        assert_eq!(parsed.x_request_id, Some("req-123".to_string()));
        assert_eq!(parsed.request_id, Some(request_uuid.to_string()));
    }

    /// A freshly created server has no registered handlers.
    #[test]
    fn test_shared_http_server_creation() {
        let server = SharedHttpServer::new(loopback(0), CancellationToken::new());
        assert!(server.handlers.is_empty());
    }

    /// Binding to port 0 yields a concrete port that is visible through both
    /// `actual_address` and the trait's `address`.
    #[tokio::test]
    async fn test_bind_and_start_assigns_os_port() {
        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(0), token.clone());

        let bound = server.clone().bind_and_start().await.unwrap();

        assert_ne!(bound.port(), 0);
        assert_eq!(server.actual_address(), Some(bound));

        let url =
            <SharedHttpServer as super::unified_server::RequestPlaneServer>::address(&*server);
        assert!(url.contains(&bound.port().to_string()));

        token.cancel();
    }

    /// Two servers bound to port 0 receive distinct ports.
    #[tokio::test]
    async fn test_two_servers_get_different_ports() {
        let first_token = CancellationToken::new();
        let second_token = CancellationToken::new();

        let first = SharedHttpServer::new(loopback(0), first_token.clone());
        let second = SharedHttpServer::new(loopback(0), second_token.clone());

        let first_addr = first.clone().bind_and_start().await.unwrap();
        let second_addr = second.clone().bind_and_start().await.unwrap();

        assert_ne!(first_addr.port(), second_addr.port());

        first_token.cancel();
        second_token.cancel();
    }

    /// Binding to an explicit port reports that same port.
    #[tokio::test]
    async fn test_bind_and_start_with_explicit_port() {
        // Ask the OS for a free port, then release it so the server can bind it.
        let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let free_port = probe.local_addr().unwrap().port();
        drop(probe);

        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(free_port), token.clone());

        let bound = server.clone().bind_and_start().await.unwrap();
        assert_eq!(bound.port(), free_port);

        token.cancel();
    }
}