http_endpoint_server_harness/adapters/gateways/axum/
server.rs

1use async_trait::async_trait;
2use axum::{
3    body::Body,
4    extract::State,
5    http::{Request as AxumRequest, StatusCode},
6    response::IntoResponse,
7    routing::MethodRouter,
8    Router,
9};
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::{
13    atomic::{AtomicUsize, Ordering},
14    Arc,
15};
16use tokio::sync::{oneshot, Mutex};
17
18use crate::entities::{Endpoint, Handler, Method, Request};
19use crate::error::HarnessError;
20use crate::use_cases::ports::{Collector, Server};
21
/// HTTP server gateway backed by Axum.
///
/// Holds only the socket address to bind; the listener itself is opened
/// when the server is run.
#[derive(Clone)]
pub struct Axum {
    /// Address the listener will bind to.
    addr: SocketAddr,
}

impl Axum {
    /// Creates a server that will bind to `addr`.
    pub fn new(addr: SocketAddr) -> Self {
        Axum { addr }
    }

    /// Creates a server from anything convertible into a [`SocketAddr`],
    /// such as an `([u8; 4], u16)` tuple.
    pub fn bind(addr: impl Into<SocketAddr>) -> Self {
        let addr: SocketAddr = addr.into();
        Axum { addr }
    }
}

impl Default for Axum {
    /// Binds to `127.0.0.1` with port `0`.
    fn default() -> Self {
        Axum::new(SocketAddr::from(([127, 0, 0, 1], 0)))
    }
}
43
44/// Shared state for tracking completion
45#[derive(Clone)]
46struct CompletionTracker {
47    /// Total number of handlers across all endpoints
48    total_handlers: usize,
49    /// Number of handlers that have been called at least once
50    handlers_called: Arc<AtomicUsize>,
51    /// Shutdown signal sender (wrapped in Mutex for Clone)
52    shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
53}
54
55impl CompletionTracker {
56    fn new(total_handlers: usize, shutdown_tx: oneshot::Sender<()>) -> Self {
57        Self {
58            total_handlers,
59            handlers_called: Arc::new(AtomicUsize::new(0)),
60            shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
61        }
62    }
63
64    /// Called when a handler is used for the first time
65    async fn handler_called(&self) {
66        let called = self.handlers_called.fetch_add(1, Ordering::SeqCst) + 1;
67        if called >= self.total_handlers {
68            // All handlers have been called, trigger shutdown
69            if let Some(tx) = self.shutdown_tx.lock().await.take() {
70                let _ = tx.send(());
71            }
72        }
73    }
74}
75
76/// Type-erased collector trait for internal use
77trait ErasedCollector: Send + Sync {
78    fn collect(&self, request: Request);
79}
80
81impl<C: Collector> ErasedCollector for std::sync::Mutex<Option<C>> {
82    fn collect(&self, request: Request) {
83        if let Ok(guard) = self.lock() {
84            if let Some(ref collector) = *guard {
85                collector.collect(request);
86            }
87        }
88    }
89}
90
91/// State shared with Axum handlers using type erasure
92#[derive(Clone)]
93struct EndpointState {
94    handlers: Arc<Vec<Handler>>,
95    call_count: Arc<AtomicUsize>,
96    collector: Arc<dyn ErasedCollector>,
97    completion_tracker: CompletionTracker,
98}
99
100async fn handle_request(
101    State(state): State<EndpointState>,
102    request: AxumRequest<Body>,
103) -> impl IntoResponse {
104    // Parse method
105    let method = match request.method().as_str() {
106        "GET" => Method::Get,
107        "POST" => Method::Post,
108        "PUT" => Method::Put,
109        "PATCH" => Method::Patch,
110        "DELETE" => Method::Delete,
111        "HEAD" => Method::Head,
112        "OPTIONS" => Method::Options,
113        _ => Method::Get,
114    };
115
116    let path = request.uri().path().to_string();
117    let headers: HashMap<String, String> = request
118        .headers()
119        .iter()
120        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
121        .collect();
122
123    let body = axum::body::to_bytes(request.into_body(), usize::MAX)
124        .await
125        .map(|b| b.to_vec())
126        .unwrap_or_default();
127
128    // Collect the request
129    let collected_request = Request {
130        method,
131        path,
132        headers,
133        body,
134    };
135    state.collector.collect(collected_request.clone());
136
137    // Get the response from the handler (sequential through handlers)
138    let call_index = state.call_count.fetch_add(1, Ordering::SeqCst);
139    let handler_count = state.handlers.len();
140    let handler_index = call_index.min(handler_count.saturating_sub(1));
141
142    // Check if this is a new handler being called for the first time
143    if call_index < handler_count {
144        state.completion_tracker.handler_called().await;
145    }
146
147    if let Some(handler) = state.handlers.get(handler_index) {
148        let response = handler.respond(&collected_request);
149        let status = StatusCode::from_u16(response.status).unwrap_or(StatusCode::OK);
150        let mut builder = axum::http::Response::builder().status(status);
151
152        for (key, value) in &response.headers {
153            builder = builder.header(key.as_str(), value.as_str());
154        }
155
156        builder
157            .body(Body::from(response.body.clone()))
158            .unwrap_or_else(|_| {
159                axum::http::Response::builder()
160                    .status(StatusCode::INTERNAL_SERVER_ERROR)
161                    .body(Body::empty())
162                    .unwrap()
163            })
164    } else {
165        axum::http::Response::builder()
166            .status(StatusCode::NOT_FOUND)
167            .body(Body::from("No handler configured"))
168            .unwrap()
169    }
170}
171
172fn create_method_router(method: Method) -> MethodRouter<EndpointState> {
173    match method {
174        Method::Get => axum::routing::get(handle_request),
175        Method::Post => axum::routing::post(handle_request),
176        Method::Put => axum::routing::put(handle_request),
177        Method::Patch => axum::routing::patch(handle_request),
178        Method::Delete => axum::routing::delete(handle_request),
179        Method::Head => axum::routing::head(handle_request),
180        Method::Options => axum::routing::options(handle_request),
181    }
182}
183
184#[async_trait]
185impl Server for Axum {
186    async fn run<C, F>(
187        &self,
188        endpoints: Vec<Endpoint>,
189        collector: C,
190        on_ready: Option<F>,
191    ) -> Result<C::Output, HarnessError>
192    where
193        C: Collector + 'static,
194        F: FnOnce(SocketAddr) + Send + 'static,
195    {
196        // Wrap collector in Mutex<Option<C>> so we can take it out at the end
197        let collector_holder: Arc<std::sync::Mutex<Option<C>>> =
198            Arc::new(std::sync::Mutex::new(Some(collector)));
199        let erased_collector: Arc<dyn ErasedCollector> = collector_holder.clone();
200
201        // Count total handlers
202        let total_handlers: usize = endpoints.iter().map(|e| e.handlers.len().max(1)).sum();
203
204        // Create shutdown channel for auto-shutdown
205        let (auto_shutdown_tx, auto_shutdown_rx) = oneshot::channel();
206        let completion_tracker = CompletionTracker::new(total_handlers, auto_shutdown_tx);
207
208        let mut router: Router<EndpointState> = Router::new();
209
210        for endpoint in endpoints {
211            let state = EndpointState {
212                handlers: Arc::new(endpoint.handlers),
213                call_count: Arc::new(AtomicUsize::new(0)),
214                collector: erased_collector.clone(),
215                completion_tracker: completion_tracker.clone(),
216            };
217
218            let method_router = create_method_router(endpoint.method);
219            router = router.route(&endpoint.path, method_router).with_state(state);
220        }
221
222        // Convert to Router<()> for serving
223        let router = router.with_state(EndpointState {
224            handlers: Arc::new(vec![]),
225            call_count: Arc::new(AtomicUsize::new(0)),
226            collector: erased_collector.clone(),
227            completion_tracker: completion_tracker.clone(),
228        });
229
230        let listener = tokio::net::TcpListener::bind(self.addr)
231            .await
232            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
233
234        let addr = listener
235            .local_addr()
236            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
237
238        // Call the on_ready callback if provided
239        if let Some(callback) = on_ready {
240            callback(addr);
241        }
242
243        // Serve and wait for auto-shutdown
244        axum::serve(listener, router)
245            .with_graceful_shutdown(async {
246                auto_shutdown_rx.await.ok();
247            })
248            .await
249            .map_err(|e| HarnessError::ServerError(e.to_string()))?;
250
251        // Extract the collector and return its output
252        let collector = collector_holder
253            .lock()
254            .map_err(|e| HarnessError::ServerError(e.to_string()))?
255            .take()
256            .ok_or_else(|| HarnessError::ServerError("Collector already taken".to_string()))?;
257
258        Ok(collector.into_output())
259    }
260}
261