dynamo_runtime/pipeline/network/ingress/
http_endpoint.rs1use super::*;
7use crate::SystemHealth;
8use crate::config::HealthStatus;
9use crate::logging::TraceParent;
10use anyhow::Result;
11use axum::{
12 Router,
13 body::Bytes,
14 extract::{Path, State as AxumState},
15 http::{HeaderMap, StatusCode},
16 response::IntoResponse,
17 routing::post,
18};
19use dashmap::DashMap;
20use hyper_util::rt::{TokioExecutor, TokioIo};
21use hyper_util::server::conn::auto::Builder as Http2Builder;
22use hyper_util::service::TowerToHyperService;
23use parking_lot::Mutex;
24use std::net::SocketAddr;
25use std::sync::atomic::{AtomicU64, Ordering};
26use tokio::sync::Notify;
27use tokio_util::sync::CancellationToken;
28use tower_http::trace::TraceLayer;
29use tracing::Instrument;
30
31const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";
33
34pub const VERSION: &str = env!("CARGO_PKG_VERSION");
36
37pub struct SharedHttpServer {
39 handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
40 bind_addr: SocketAddr,
41 cancellation_token: CancellationToken,
42}
43
44struct EndpointHandler {
46 service_handler: Arc<dyn PushWorkHandler>,
47 instance_id: u64,
48 namespace: Arc<String>,
49 component_name: Arc<String>,
50 endpoint_name: Arc<String>,
51 system_health: Arc<Mutex<SystemHealth>>,
52 inflight: Arc<AtomicU64>,
53 notify: Arc<Notify>,
54}
55
56impl SharedHttpServer {
57 pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
58 Arc::new(Self {
59 handlers: Arc::new(DashMap::new()),
60 bind_addr,
61 cancellation_token,
62 })
63 }
64
65 #[allow(clippy::too_many_arguments)]
67 pub async fn register_endpoint(
68 &self,
69 subject: String,
70 service_handler: Arc<dyn PushWorkHandler>,
71 instance_id: u64,
72 namespace: String,
73 component_name: String,
74 endpoint_name: String,
75 system_health: Arc<Mutex<SystemHealth>>,
76 ) -> Result<()> {
77 let handler = Arc::new(EndpointHandler {
78 service_handler,
79 instance_id,
80 namespace: Arc::new(namespace),
81 component_name: Arc::new(component_name),
82 endpoint_name: Arc::new(endpoint_name.clone()),
83 system_health: system_health.clone(),
84 inflight: Arc::new(AtomicU64::new(0)),
85 notify: Arc::new(Notify::new()),
86 });
87
88 let subject_clone = subject.clone();
90 self.handlers.insert(subject, handler);
91
92 system_health
94 .lock()
95 .set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
96
97 tracing::debug!("Registered endpoint handler for subject: {}", subject_clone);
98 Ok(())
99 }
100
101 pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
103 if let Some((_, handler)) = self.handlers.remove(subject) {
104 handler
105 .system_health
106 .lock()
107 .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
108 tracing::debug!(
109 endpoint_name = %endpoint_name,
110 subject = %subject,
111 "Unregistered HTTP endpoint handler"
112 );
113
114 let inflight_count = handler.inflight.load(Ordering::SeqCst);
115 if inflight_count > 0 {
116 tracing::info!(
117 endpoint_name = %endpoint_name,
118 inflight_count = inflight_count,
119 "Waiting for inflight HTTP requests to complete"
120 );
121 while handler.inflight.load(Ordering::SeqCst) > 0 {
122 handler.notify.notified().await;
123 }
124 tracing::info!(
125 endpoint_name = %endpoint_name,
126 "All inflight HTTP requests completed"
127 );
128 }
129 }
130 }
131
132 pub async fn start(self: Arc<Self>) -> Result<()> {
134 let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
135 .unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
136 let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
137
138 let app = Router::new()
139 .route(&route_pattern, post(handle_shared_request))
140 .layer(TraceLayer::new_for_http())
141 .with_state(self.clone());
142
143 tracing::info!(
144 "Starting shared HTTP/2 endpoint server on {} at path {}/:endpoint",
145 self.bind_addr,
146 rpc_root_path
147 );
148
149 let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
150 let cancellation_token = self.cancellation_token.clone();
151
152 loop {
153 tokio::select! {
154 accept_result = listener.accept() => {
155 match accept_result {
156 Ok((stream, _addr)) => {
157 let app_clone = app.clone();
158 let cancel_clone = cancellation_token.clone();
159
160 tokio::spawn(async move {
161 let http2_builder = Http2Builder::new(TokioExecutor::new());
163
164 let io = TokioIo::new(stream);
165 let tower_service = app_clone.into_service();
166
167 let hyper_service = TowerToHyperService::new(tower_service);
169
170 tokio::select! {
171 result = http2_builder.serve_connection(io, hyper_service) => {
172 if let Err(e) = result {
173 tracing::debug!("HTTP/2 connection error: {}", e);
174 }
175 }
176 _ = cancel_clone.cancelled() => {
177 tracing::trace!("Connection cancelled");
178 }
179 }
180 });
181 }
182 Err(e) => {
183 tracing::error!("Failed to accept connection: {}", e);
184 }
185 }
186 }
187 _ = cancellation_token.cancelled() => {
188 tracing::info!("SharedHttpServer received cancellation signal, shutting down");
189 return Ok(());
190 }
191 }
192 }
193 }
194
195 pub async fn wait_for_inflight(&self) {
197 for handler in self.handlers.iter() {
198 while handler.value().inflight.load(Ordering::SeqCst) > 0 {
199 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
200 }
201 }
202 }
203}
204
205async fn handle_shared_request(
207 AxumState(server): AxumState<Arc<SharedHttpServer>>,
208 Path(endpoint_path): Path<String>,
209 headers: HeaderMap,
210 body: Bytes,
211) -> impl IntoResponse {
212 let handler = match server.handlers.get(&endpoint_path) {
214 Some(h) => h.clone(),
215 None => {
216 tracing::warn!("No handler found for endpoint: {}", endpoint_path);
217 return (StatusCode::NOT_FOUND, "Endpoint not found");
218 }
219 };
220
221 handler.inflight.fetch_add(1, Ordering::SeqCst);
223
224 let traceparent = TraceParent::from_axum_headers(&headers);
226
227 let service_handler = handler.service_handler.clone();
229 let inflight = handler.inflight.clone();
230 let notify = handler.notify.clone();
231 let namespace = handler.namespace.clone();
232 let component_name = handler.component_name.clone();
233 let endpoint_name = handler.endpoint_name.clone();
234 let instance_id = handler.instance_id;
235
236 tokio::spawn(async move {
237 tracing::trace!(instance_id, "handling new HTTP request");
238 let result = service_handler
239 .handle_payload(body)
240 .instrument(tracing::info_span!(
241 "handle_payload",
242 component = component_name.as_ref(),
243 endpoint = endpoint_name.as_ref(),
244 namespace = namespace.as_ref(),
245 instance_id = instance_id,
246 trace_id = traceparent.trace_id,
247 parent_id = traceparent.parent_id,
248 x_request_id = traceparent.x_request_id,
249 x_dynamo_request_id = traceparent.x_dynamo_request_id,
250 tracestate = traceparent.tracestate
251 ))
252 .await;
253 match result {
254 Ok(_) => {
255 tracing::trace!(instance_id, "request handled successfully");
256 }
257 Err(e) => {
258 tracing::warn!("Failed to handle request: {}", e.to_string());
259 }
260 }
261
262 inflight.fetch_sub(1, Ordering::SeqCst);
264 notify.notify_one();
265 });
266
267 (StatusCode::ACCEPTED, "")
269}
270
271impl TraceParent {
273 pub fn from_axum_headers(headers: &HeaderMap) -> Self {
274 let mut traceparent = TraceParent::default();
275
276 if let Some(value) = headers.get("traceparent")
277 && let Ok(s) = value.to_str()
278 {
279 traceparent.trace_id = Some(s.to_string());
280 }
281
282 if let Some(value) = headers.get("tracestate")
283 && let Ok(s) = value.to_str()
284 {
285 traceparent.tracestate = Some(s.to_string());
286 }
287
288 if let Some(value) = headers.get("x-request-id")
289 && let Ok(s) = value.to_str()
290 {
291 traceparent.x_request_id = Some(s.to_string());
292 }
293
294 if let Some(value) = headers.get("x-dynamo-request-id")
295 && let Ok(s) = value.to_str()
296 {
297 traceparent.x_dynamo_request_id = Some(s.to_string());
298 }
299
300 traceparent
301 }
302}
303
304#[async_trait::async_trait]
306impl super::unified_server::RequestPlaneServer for SharedHttpServer {
307 async fn register_endpoint(
308 &self,
309 endpoint_name: String,
310 service_handler: Arc<dyn PushWorkHandler>,
311 instance_id: u64,
312 namespace: String,
313 component_name: String,
314 system_health: Arc<Mutex<SystemHealth>>,
315 ) -> Result<()> {
316 self.register_endpoint(
318 endpoint_name.clone(),
319 service_handler,
320 instance_id,
321 namespace,
322 component_name,
323 endpoint_name,
324 system_health,
325 )
326 .await
327 }
328
329 async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
330 self.unregister_endpoint(endpoint_name, endpoint_name).await;
331 Ok(())
332 }
333
334 fn address(&self) -> String {
335 format!("http://{}:{}", self.bind_addr.ip(), self.bind_addr.port())
336 }
337
338 fn transport_name(&self) -> &'static str {
339 "http"
340 }
341
342 fn is_healthy(&self) -> bool {
343 true
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Each recognised header ends up in its matching `TraceParent` field.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("traceparent", "test-trace-id"),
            ("tracestate", "test-state"),
            ("x-request-id", "req-123"),
            ("x-dynamo-request-id", "dyn-456"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id.as_deref(), Some("test-trace-id"));
        assert_eq!(parsed.tracestate.as_deref(), Some("test-state"));
        assert_eq!(parsed.x_request_id.as_deref(), Some("req-123"));
        assert_eq!(parsed.x_dynamo_request_id.as_deref(), Some("dyn-456"));
    }

    /// A freshly constructed server has no handlers registered.
    #[test]
    fn test_shared_http_server_creation() {
        use std::net::{IpAddr, Ipv4Addr};

        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let server = SharedHttpServer::new(SocketAddr::new(loopback, 0), CancellationToken::new());

        assert!(server.handlers.is_empty());
    }
}