http_endpoint_server_harness/adapters/gateways/axum/
server.rs1use async_trait::async_trait;
2use axum::{
3 body::Body,
4 extract::State,
5 http::{Request as AxumRequest, StatusCode},
6 response::IntoResponse,
7 routing::MethodRouter,
8 Router,
9};
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::{
13 atomic::{AtomicUsize, Ordering},
14 Arc,
15};
16use tokio::sync::{oneshot, Mutex};
17
18use crate::entities::{Endpoint, Handler, Method, Request};
19use crate::error::HarnessError;
20use crate::use_cases::ports::{Collector, Server};
21
/// HTTP server adapter backed by axum.
///
/// Holds only the socket address to bind; the TCP listener itself is created
/// each time `Server::run` is invoked.
#[derive(Clone, Debug)]
pub struct Axum {
    /// Address handed to `TcpListener::bind`. With port 0 the OS picks the port;
    /// `run` reports the actually bound address to its `on_ready` callback.
    addr: SocketAddr,
}
27
28impl Axum {
29 pub fn new(addr: SocketAddr) -> Self {
30 Self { addr }
31 }
32
33 pub fn bind(addr: impl Into<SocketAddr>) -> Self {
34 Self::new(addr.into())
35 }
36}
37
38impl Default for Axum {
39 fn default() -> Self {
40 Self::new(([127, 0, 0, 1], 0).into())
41 }
42}
43
44#[derive(Clone)]
46struct CompletionTracker {
47 total_handlers: usize,
49 handlers_called: Arc<AtomicUsize>,
51 shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
53}
54
55impl CompletionTracker {
56 fn new(total_handlers: usize, shutdown_tx: oneshot::Sender<()>) -> Self {
57 Self {
58 total_handlers,
59 handlers_called: Arc::new(AtomicUsize::new(0)),
60 shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
61 }
62 }
63
64 async fn handler_called(&self) {
66 let called = self.handlers_called.fetch_add(1, Ordering::SeqCst) + 1;
67 if called >= self.total_handlers {
68 if let Some(tx) = self.shutdown_tx.lock().await.take() {
70 let _ = tx.send(());
71 }
72 }
73 }
74}
75
76trait ErasedCollector: Send + Sync {
78 fn collect(&self, request: Request);
79}
80
81impl<C: Collector> ErasedCollector for std::sync::Mutex<Option<C>> {
82 fn collect(&self, request: Request) {
83 if let Ok(guard) = self.lock() {
84 if let Some(ref collector) = *guard {
85 collector.collect(request);
86 }
87 }
88 }
89}
90
91#[derive(Clone)]
93struct EndpointState {
94 handlers: Arc<Vec<Handler>>,
95 call_count: Arc<AtomicUsize>,
96 collector: Arc<dyn ErasedCollector>,
97 completion_tracker: CompletionTracker,
98}
99
100async fn handle_request(
101 State(state): State<EndpointState>,
102 request: AxumRequest<Body>,
103) -> impl IntoResponse {
104 let method = match request.method().as_str() {
106 "GET" => Method::Get,
107 "POST" => Method::Post,
108 "PUT" => Method::Put,
109 "PATCH" => Method::Patch,
110 "DELETE" => Method::Delete,
111 "HEAD" => Method::Head,
112 "OPTIONS" => Method::Options,
113 _ => Method::Get,
114 };
115
116 let path = request.uri().path().to_string();
117 let headers: HashMap<String, String> = request
118 .headers()
119 .iter()
120 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
121 .collect();
122
123 let body = axum::body::to_bytes(request.into_body(), usize::MAX)
124 .await
125 .map(|b| b.to_vec())
126 .unwrap_or_default();
127
128 let collected_request = Request {
130 method,
131 path,
132 headers,
133 body,
134 };
135 state.collector.collect(collected_request.clone());
136
137 let call_index = state.call_count.fetch_add(1, Ordering::SeqCst);
139 let handler_count = state.handlers.len();
140 let handler_index = call_index.min(handler_count.saturating_sub(1));
141
142 if call_index < handler_count {
144 state.completion_tracker.handler_called().await;
145 }
146
147 if let Some(handler) = state.handlers.get(handler_index) {
148 let response = handler.respond(&collected_request);
149 let status = StatusCode::from_u16(response.status).unwrap_or(StatusCode::OK);
150 let mut builder = axum::http::Response::builder().status(status);
151
152 for (key, value) in &response.headers {
153 builder = builder.header(key.as_str(), value.as_str());
154 }
155
156 builder
157 .body(Body::from(response.body.clone()))
158 .unwrap_or_else(|_| {
159 axum::http::Response::builder()
160 .status(StatusCode::INTERNAL_SERVER_ERROR)
161 .body(Body::empty())
162 .unwrap()
163 })
164 } else {
165 axum::http::Response::builder()
166 .status(StatusCode::NOT_FOUND)
167 .body(Body::from("No handler configured"))
168 .unwrap()
169 }
170}
171
172fn create_method_router(method: Method) -> MethodRouter<EndpointState> {
173 match method {
174 Method::Get => axum::routing::get(handle_request),
175 Method::Post => axum::routing::post(handle_request),
176 Method::Put => axum::routing::put(handle_request),
177 Method::Patch => axum::routing::patch(handle_request),
178 Method::Delete => axum::routing::delete(handle_request),
179 Method::Head => axum::routing::head(handle_request),
180 Method::Options => axum::routing::options(handle_request),
181 }
182}
183
184#[async_trait]
185impl Server for Axum {
186 async fn run<C, F>(
187 &self,
188 endpoints: Vec<Endpoint>,
189 collector: C,
190 on_ready: Option<F>,
191 ) -> Result<C::Output, HarnessError>
192 where
193 C: Collector + 'static,
194 F: FnOnce(SocketAddr) + Send + 'static,
195 {
196 let collector_holder: Arc<std::sync::Mutex<Option<C>>> =
198 Arc::new(std::sync::Mutex::new(Some(collector)));
199 let erased_collector: Arc<dyn ErasedCollector> = collector_holder.clone();
200
201 let total_handlers: usize = endpoints.iter().map(|e| e.handlers.len().max(1)).sum();
203
204 let (auto_shutdown_tx, auto_shutdown_rx) = oneshot::channel();
206 let completion_tracker = CompletionTracker::new(total_handlers, auto_shutdown_tx);
207
208 let mut router: Router<EndpointState> = Router::new();
209
210 for endpoint in endpoints {
211 let state = EndpointState {
212 handlers: Arc::new(endpoint.handlers),
213 call_count: Arc::new(AtomicUsize::new(0)),
214 collector: erased_collector.clone(),
215 completion_tracker: completion_tracker.clone(),
216 };
217
218 let method_router = create_method_router(endpoint.method);
219 router = router.route(&endpoint.path, method_router).with_state(state);
220 }
221
222 let router = router.with_state(EndpointState {
224 handlers: Arc::new(vec![]),
225 call_count: Arc::new(AtomicUsize::new(0)),
226 collector: erased_collector.clone(),
227 completion_tracker: completion_tracker.clone(),
228 });
229
230 let listener = tokio::net::TcpListener::bind(self.addr)
231 .await
232 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
233
234 let addr = listener
235 .local_addr()
236 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
237
238 if let Some(callback) = on_ready {
240 callback(addr);
241 }
242
243 axum::serve(listener, router)
245 .with_graceful_shutdown(async {
246 auto_shutdown_rx.await.ok();
247 })
248 .await
249 .map_err(|e| HarnessError::ServerError(e.to_string()))?;
250
251 let collector = collector_holder
253 .lock()
254 .map_err(|e| HarnessError::ServerError(e.to_string()))?
255 .take()
256 .ok_or_else(|| HarnessError::ServerError("Collector already taken".to_string()))?;
257
258 Ok(collector.into_output())
259 }
260}
261