use super::*;
use crate::SystemHealth;
use crate::config::HealthStatus;
use crate::logging::TraceParent;
use anyhow::Result;
use axum::{
Router,
body::Bytes,
extract::{Path, State as AxumState},
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
};
use dashmap::DashMap;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder as Http2Builder;
use hyper_util::service::TowerToHyperService;
use parking_lot::Mutex;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Notify, RwLock};
use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
use tracing::Instrument;
/// Root path under which endpoints are routed when the
/// `DYN_HTTP_RPC_ROOT_PATH` environment variable is not set.
const DEFAULT_RPC_ROOT_PATH: &str = "/v1/rpc";
/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// HTTP server that serves many endpoints from a single listener.
///
/// Incoming `POST` requests are dispatched by their path (below the RPC root)
/// to the handler registered under the matching subject.
pub struct SharedHttpServer {
/// Registered endpoint handlers, keyed by subject.
handlers: Arc<DashMap<String, Arc<EndpointHandler>>>,
/// Address requested at construction time (its port may be 0).
bind_addr: SocketAddr,
/// Address of the bound listener; `None` until `bind_and_start` has bound it.
actual_addr: RwLock<Option<SocketAddr>>,
/// Cancelling this token stops the accept loop and the per-connection tasks.
cancellation_token: CancellationToken,
}
/// Everything the request handler needs to serve one registered endpoint.
struct EndpointHandler {
/// Receives the raw request body through `handle_payload`.
service_handler: Arc<dyn PushWorkHandler>,
/// Recorded on the tracing span and trace events of each request.
instance_id: u64,
/// Recorded on the tracing span of each request.
namespace: Arc<String>,
/// Recorded on the tracing span of each request.
component_name: Arc<String>,
/// Recorded on the tracing span of each request.
endpoint_name: Arc<String>,
/// Health registry updated to `NotReady` when the endpoint is unregistered.
system_health: Arc<Mutex<SystemHealth>>,
/// Number of accepted requests whose processing has not finished yet.
inflight: Arc<AtomicU64>,
/// Signalled each time a request finishes; awaited while unregistering.
notify: Arc<Notify>,
}
impl SharedHttpServer {
pub fn new(bind_addr: SocketAddr, cancellation_token: CancellationToken) -> Arc<Self> {
Arc::new(Self {
handlers: Arc::new(DashMap::new()),
bind_addr,
actual_addr: RwLock::new(None),
cancellation_token,
})
}
pub fn actual_address(&self) -> Option<SocketAddr> {
self.actual_addr.try_read().ok().and_then(|g| *g)
}
#[allow(clippy::too_many_arguments)]
pub async fn register_endpoint(
&self,
subject: String,
service_handler: Arc<dyn PushWorkHandler>,
instance_id: u64,
namespace: String,
component_name: String,
endpoint_name: String,
system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> {
let handler = Arc::new(EndpointHandler {
service_handler,
instance_id,
namespace: Arc::new(namespace),
component_name: Arc::new(component_name),
endpoint_name: Arc::new(endpoint_name.clone()),
system_health: system_health.clone(),
inflight: Arc::new(AtomicU64::new(0)),
notify: Arc::new(Notify::new()),
});
let subject_clone = subject.clone();
self.handlers.insert(subject, handler);
system_health
.lock()
.set_endpoint_health_status(&endpoint_name, HealthStatus::Ready);
tracing::debug!("Registered endpoint handler for subject: {}", subject_clone);
Ok(())
}
pub async fn unregister_endpoint(&self, subject: &str, endpoint_name: &str) {
if let Some((_, handler)) = self.handlers.remove(subject) {
handler
.system_health
.lock()
.set_endpoint_health_status(endpoint_name, HealthStatus::NotReady);
tracing::debug!(
endpoint_name = %endpoint_name,
subject = %subject,
"Unregistered HTTP endpoint handler"
);
let inflight_count = handler.inflight.load(Ordering::SeqCst);
if inflight_count > 0 {
tracing::info!(
endpoint_name = %endpoint_name,
inflight_count = inflight_count,
"Waiting for inflight HTTP requests to complete"
);
while handler.inflight.load(Ordering::SeqCst) > 0 {
handler.notify.notified().await;
}
tracing::info!(
endpoint_name = %endpoint_name,
"All inflight HTTP requests completed"
);
}
}
}
pub async fn bind_and_start(self: Arc<Self>) -> Result<SocketAddr> {
let rpc_root_path = std::env::var("DYN_HTTP_RPC_ROOT_PATH")
.unwrap_or_else(|_| DEFAULT_RPC_ROOT_PATH.to_string());
let route_pattern = format!("{}/{{*endpoint}}", rpc_root_path);
let app = Router::new()
.route(&route_pattern, post(handle_shared_request))
.layer(TraceLayer::new_for_http())
.with_state(self.clone());
let listener = tokio::net::TcpListener::bind(&self.bind_addr).await?;
let actual_addr = listener.local_addr()?;
tracing::info!(
requested = %self.bind_addr,
actual = %actual_addr,
rpc_root = %rpc_root_path,
"HTTP/2 endpoint server bound"
);
*self.actual_addr.write().await = Some(actual_addr);
let cancellation_token = self.cancellation_token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, _addr)) => {
let app_clone = app.clone();
let cancel_clone = cancellation_token.clone();
tokio::spawn(async move {
let http2_builder = Http2Builder::new(TokioExecutor::new());
let io = TokioIo::new(stream);
let tower_service = app_clone.into_service();
let hyper_service = TowerToHyperService::new(tower_service);
tokio::select! {
result = http2_builder.serve_connection(io, hyper_service) => {
if let Err(e) = result {
tracing::debug!("HTTP/2 connection error: {e}");
}
}
_ = cancel_clone.cancelled() => {
tracing::trace!("Connection cancelled");
}
}
});
}
Err(e) => {
tracing::error!("Failed to accept connection: {e}");
}
}
}
_ = cancellation_token.cancelled() => {
tracing::info!("SharedHttpServer received cancellation signal, shutting down");
return;
}
}
}
});
Ok(actual_addr)
}
pub async fn wait_for_inflight(&self) {
for handler in self.handlers.iter() {
while handler.value().inflight.load(Ordering::SeqCst) > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
}
}
async fn handle_shared_request(
AxumState(server): AxumState<Arc<SharedHttpServer>>,
Path(endpoint_path): Path<String>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let handler = match server.handlers.get(&endpoint_path) {
Some(h) => h.clone(),
None => {
tracing::warn!("No handler found for endpoint: {}", endpoint_path);
return (StatusCode::NOT_FOUND, "Endpoint not found");
}
};
handler.inflight.fetch_add(1, Ordering::SeqCst);
let traceparent = TraceParent::from_axum_headers(&headers);
let service_handler = handler.service_handler.clone();
let inflight = handler.inflight.clone();
let notify = handler.notify.clone();
let namespace = handler.namespace.clone();
let component_name = handler.component_name.clone();
let endpoint_name = handler.endpoint_name.clone();
let instance_id = handler.instance_id;
tokio::spawn(async move {
tracing::trace!(instance_id, "handling new HTTP request");
let result = service_handler
.handle_payload(body)
.instrument(tracing::info_span!(
"handle_payload",
component = component_name.as_ref(),
endpoint = endpoint_name.as_ref(),
namespace = namespace.as_ref(),
instance_id = instance_id,
trace_id = traceparent.trace_id,
parent_id = traceparent.parent_id,
x_request_id = traceparent.x_request_id,
x_dynamo_request_id = traceparent.x_dynamo_request_id,
tracestate = traceparent.tracestate
))
.await;
match result {
Ok(_) => {
tracing::trace!(instance_id, "request handled successfully");
}
Err(e) => {
tracing::warn!("Failed to handle request: {}", e.to_string());
}
}
inflight.fetch_sub(1, Ordering::SeqCst);
notify.notify_one();
});
(StatusCode::ACCEPTED, "")
}
impl TraceParent {
    /// Builds a `TraceParent` from the tracing-related headers of an axum
    /// request.
    ///
    /// The `traceparent`, `tracestate`, `x-request-id` and
    /// `x-dynamo-request-id` headers are copied verbatim into `trace_id`,
    /// `tracestate`, `x_request_id` and `x_dynamo_request_id` respectively.
    /// Headers that are missing, or whose value `to_str()` rejects, leave the
    /// corresponding field at its default.
    ///
    /// NOTE(review): the whole `traceparent` value is stored in `trace_id`
    /// and `parent_id` is never populated here — confirm this is intended.
    pub fn from_axum_headers(headers: &HeaderMap) -> Self {
        // Text of a header, if it is present and convertible to `&str`.
        let text_of = |name: &str| -> Option<String> {
            headers.get(name)?.to_str().ok().map(str::to_owned)
        };

        let mut parsed = TraceParent::default();
        if let Some(raw) = text_of("traceparent") {
            parsed.trace_id = Some(raw);
        }
        if let Some(state) = text_of("tracestate") {
            parsed.tracestate = Some(state);
        }
        if let Some(request_id) = text_of("x-request-id") {
            parsed.x_request_id = Some(request_id);
        }
        if let Some(dynamo_request_id) = text_of("x-dynamo-request-id") {
            parsed.x_dynamo_request_id = Some(dynamo_request_id);
        }
        parsed
    }
}
#[async_trait::async_trait]
impl super::unified_server::RequestPlaneServer for SharedHttpServer {
    /// Registers `service_handler` for `endpoint_name`.
    ///
    /// The endpoint name is used both as the routing subject and as the name
    /// reported to `system_health`; the work is delegated to the inherent
    /// `SharedHttpServer::register_endpoint`.
    async fn register_endpoint(
        &self,
        endpoint_name: String,
        service_handler: Arc<dyn PushWorkHandler>,
        instance_id: u64,
        namespace: String,
        component_name: String,
        system_health: Arc<Mutex<SystemHealth>>,
    ) -> Result<()> {
        // Resolves to the inherent seven-argument method, not this trait method.
        self.register_endpoint(
            endpoint_name.clone(),
            service_handler,
            instance_id,
            namespace,
            component_name,
            endpoint_name,
            system_health,
        )
        .await
    }

    /// Unregisters `endpoint_name` (used as both subject and endpoint name)
    /// and waits for its inflight requests. Always returns `Ok(())`.
    async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
        self.unregister_endpoint(endpoint_name, endpoint_name).await;
        Ok(())
    }

    /// Returns the base URL of the server, preferring the bound address and
    /// falling back to the requested one when the server is not bound yet.
    fn address(&self) -> String {
        let addr = self.actual_address().unwrap_or(self.bind_addr);
        // `SocketAddr`'s `Display` brackets IPv6 hosts (`[::1]:8080`), which a
        // URL authority requires. Joining `ip()` and `port()` by hand would
        // yield the unparsable `http://::1:8080`. IPv4 output is unchanged.
        format!("http://{addr}")
    }

    /// Name of the transport implemented by this server.
    fn transport_name(&self) -> &'static str {
        "http"
    }

    /// Always reports healthy; no liveness check is performed.
    fn is_healthy(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback socket address on `port`; port 0 lets the OS pick one.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Each tracing header is copied verbatim into its `TraceParent` field.
    #[test]
    fn test_traceparent_from_axum_headers() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("traceparent", "test-trace-id"),
            ("tracestate", "test-state"),
            ("x-request-id", "req-123"),
            ("x-dynamo-request-id", "dyn-456"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let parsed = TraceParent::from_axum_headers(&headers);

        assert_eq!(parsed.trace_id, Some("test-trace-id".to_string()));
        assert_eq!(parsed.tracestate, Some("test-state".to_string()));
        assert_eq!(parsed.x_request_id, Some("req-123".to_string()));
        assert_eq!(parsed.x_dynamo_request_id, Some("dyn-456".to_string()));
    }

    /// A freshly created server has no registered handlers.
    #[test]
    fn test_shared_http_server_creation() {
        let server = SharedHttpServer::new(loopback(0), CancellationToken::new());
        assert!(server.handlers.is_empty());
    }

    /// Binding to port 0 yields a real port that the accessors report.
    #[tokio::test]
    async fn test_bind_and_start_assigns_os_port() {
        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(0), token.clone());

        let bound = server.clone().bind_and_start().await.unwrap();

        assert_ne!(bound.port(), 0);
        assert_eq!(server.actual_address(), Some(bound));
        let addr_str =
            <SharedHttpServer as super::unified_server::RequestPlaneServer>::address(&*server);
        assert!(addr_str.contains(&bound.port().to_string()));
        token.cancel();
    }

    /// Two servers that both request port 0 end up on distinct ports.
    #[tokio::test]
    async fn test_two_servers_get_different_ports() {
        let first_token = CancellationToken::new();
        let second_token = CancellationToken::new();
        let first = SharedHttpServer::new(loopback(0), first_token.clone());
        let second = SharedHttpServer::new(loopback(0), second_token.clone());

        let first_bound = first.clone().bind_and_start().await.unwrap();
        let second_bound = second.clone().bind_and_start().await.unwrap();

        assert_ne!(first_bound.port(), second_bound.port());
        first_token.cancel();
        second_token.cancel();
    }

    /// An explicitly requested port is the one that gets bound.
    #[tokio::test]
    async fn test_bind_and_start_with_explicit_port() {
        // Borrow a free port from the OS, then release it for the server.
        let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let free_port = probe.local_addr().unwrap().port();
        drop(probe);

        let token = CancellationToken::new();
        let server = SharedHttpServer::new(loopback(free_port), token.clone());

        let bound = server.clone().bind_and_start().await.unwrap();

        assert_eq!(bound.port(), free_port);
        token.cancel();
    }
}