1use crate::headers::HeaderMap as ArmatureHeaderMap;
31use crate::http::{HttpRequest, HttpResponse};
32use bytes::Bytes;
33use http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode};
34use http_body_util::Full;
35use std::collections::HashMap;
36use std::convert::Infallible;
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::sync::atomic::{AtomicU64, Ordering};
41use std::task::{Context, Poll};
42use tower_service::Service;
43
44pub trait IntoHttpRequest {
50 fn into_http_request(self) -> Request<Bytes>;
52}
53
54impl IntoHttpRequest for HttpRequest {
55 fn into_http_request(self) -> Request<Bytes> {
56 let mut builder = Request::builder()
57 .method(self.method.as_str())
58 .uri(&self.path);
59
60 if let Some(headers) = builder.headers_mut() {
62 for (name, value) in &self.headers {
63 if let (Ok(name), Ok(value)) = (
64 HeaderName::try_from(name.as_str()),
65 HeaderValue::try_from(value.as_str()),
66 ) {
67 headers.insert(name, value);
68 }
69 }
70 }
71
72 builder
73 .body(self.body_bytes())
74 .unwrap_or_else(|_| Request::new(Bytes::new()))
75 }
76}
77
78pub trait FromHttpRequest {
80 fn from_http_request(req: Request<Bytes>) -> Self;
82}
83
84impl FromHttpRequest for HttpRequest {
85 fn from_http_request(req: Request<Bytes>) -> Self {
86 let method = req.method().as_str().to_string();
87 let path = req.uri().path().to_string();
88
89 let query_params: HashMap<String, String> = req
91 .uri()
92 .query()
93 .map(|q| serde_urlencoded::from_str(q).unwrap_or_default())
94 .unwrap_or_default();
95
96 let mut headers = HashMap::new();
98 for (name, value) in req.headers() {
99 if let Ok(v) = value.to_str() {
100 headers.insert(name.as_str().to_string(), v.to_string());
101 }
102 }
103
104 let body = req.into_body();
105
106 HttpRequest::with_bytes_body(method, path, body)
107 .with_headers_map(headers)
108 .with_query_params(query_params)
109 }
110}
111
112trait HttpRequestBuilderExt {
114 fn with_headers_map(self, headers: HashMap<String, String>) -> Self;
115 fn with_query_params(self, params: HashMap<String, String>) -> Self;
116}
117
118impl HttpRequestBuilderExt for HttpRequest {
119 fn with_headers_map(mut self, headers: HashMap<String, String>) -> Self {
120 self.headers = headers;
121 self
122 }
123
124 fn with_query_params(mut self, params: HashMap<String, String>) -> Self {
125 self.query_params = params;
126 self
127 }
128}
129
130pub trait IntoHttpResponse {
132 fn into_http_response(self) -> Response<Full<Bytes>>;
134}
135
136impl IntoHttpResponse for HttpResponse {
137 fn into_http_response(self) -> Response<Full<Bytes>> {
138 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
139
140 let mut builder = Response::builder().status(status);
141
142 if let Some(headers) = builder.headers_mut() {
144 for (name, value) in &self.headers {
145 if let (Ok(name), Ok(value)) = (
146 HeaderName::try_from(name.as_str()),
147 HeaderValue::try_from(value.as_str()),
148 ) {
149 headers.insert(name, value);
150 }
151 }
152 }
153
154 builder
155 .body(Full::new(self.into_body_bytes()))
156 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
157 }
158}
159
160pub trait HttpResponseFromHttp {
162 fn from_http_response(resp: Response<Bytes>) -> Self;
164}
165
166impl HttpResponseFromHttp for HttpResponse {
167 fn from_http_response(resp: Response<Bytes>) -> Self {
168 let status = resp.status().as_u16();
169
170 let mut headers = HashMap::new();
171 for (name, value) in resp.headers() {
172 if let Ok(v) = value.to_str() {
173 headers.insert(name.as_str().to_string(), v.to_string());
174 }
175 }
176
177 let body_bytes = resp.into_body();
178
179 HttpResponse::new(status)
180 .with_headers(headers)
181 .with_bytes_body(body_bytes)
182 }
183}
184
185pub trait HeaderMapExt {
187 fn to_armature_headers(&self) -> ArmatureHeaderMap;
189}
190
191impl HeaderMapExt for HeaderMap {
192 fn to_armature_headers(&self) -> ArmatureHeaderMap {
193 let mut result = ArmatureHeaderMap::new();
194 for (name, value) in self {
195 if let Ok(v) = value.to_str() {
196 result.insert(name.as_str(), v);
197 }
198 }
199 result
200 }
201}
202
203pub trait ArmatureHeaderMapExt {
205 fn to_http_headers(&self) -> HeaderMap;
207}
208
209impl ArmatureHeaderMapExt for ArmatureHeaderMap {
210 fn to_http_headers(&self) -> HeaderMap {
211 let mut result = HeaderMap::new();
212 for (name, value) in self.iter() {
213 if let (Ok(name), Ok(value)) = (
214 HeaderName::try_from(name.as_str()),
215 HeaderValue::try_from(value.as_str()),
216 ) {
217 result.insert(name, value);
218 }
219 }
220 result
221 }
222}
223
224pub type BoxedHandler =
230 Box<dyn Fn(HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> + Send + Sync>;
231
/// Tower service that runs an Armature handler against buffered `http` requests.
///
/// The handler sits behind an [`Arc`], so cloning the service never clones `H`.
pub struct ArmatureService<H> {
    /// Shared request handler.
    handler: Arc<H>,
}

impl<H> ArmatureService<H> {
    /// Wraps `handler` in a new service.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        ArmatureService { handler }
    }
}

// Hand-written rather than derived so that `H: Clone` is not required.
impl<H> Clone for ArmatureService<H> {
    fn clone(&self) -> Self {
        let handler = Arc::clone(&self.handler);
        ArmatureService { handler }
    }
}
253
254impl<H, Fut> Service<Request<Bytes>> for ArmatureService<H>
256where
257 H: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
258 Fut: Future<Output = HttpResponse> + Send + 'static,
259{
260 type Response = Response<Full<Bytes>>;
261 type Error = Infallible;
262 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
263
264 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
265 TOWER_STATS.record_poll_ready();
266 Poll::Ready(Ok(()))
267 }
268
269 fn call(&mut self, req: Request<Bytes>) -> Self::Future {
270 let handler = Arc::clone(&self.handler);
271 TOWER_STATS.record_call();
272
273 Box::pin(async move {
274 let armature_req = <HttpRequest as FromHttpRequest>::from_http_request(req);
275 let armature_resp = handler(armature_req).await;
276 Ok(armature_resp.into_http_response())
277 })
278 }
279}
280
/// Hyper-facing service: buffers the streaming request body, then runs an Armature handler.
///
/// The handler sits behind an [`Arc`], so cloning the adapter never clones `H`.
pub struct HyperServiceAdapter<H> {
    /// Shared request handler.
    handler: Arc<H>,
}

impl<H> HyperServiceAdapter<H> {
    /// Wraps `handler` in a new adapter.
    pub fn new(handler: H) -> Self {
        let handler = Arc::new(handler);
        HyperServiceAdapter { handler }
    }
}

// Hand-written rather than derived so that `H: Clone` is not required.
impl<H> Clone for HyperServiceAdapter<H> {
    fn clone(&self) -> Self {
        let handler = Arc::clone(&self.handler);
        HyperServiceAdapter { handler }
    }
}
306
307impl<H, Fut> Service<hyper::Request<hyper::body::Incoming>> for HyperServiceAdapter<H>
308where
309 H: Fn(HttpRequest) -> Fut + Send + Sync + 'static,
310 Fut: Future<Output = HttpResponse> + Send + 'static,
311{
312 type Response = hyper::Response<Full<Bytes>>;
313 type Error = Infallible;
314 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
315
316 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
317 Poll::Ready(Ok(()))
318 }
319
320 fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
321 let handler = Arc::clone(&self.handler);
322 TOWER_STATS.record_hyper_call();
323
324 Box::pin(async move {
325 let (parts, body) = req.into_parts();
327 let body_bytes = match http_body_util::BodyExt::collect(body).await {
328 Ok(collected) => collected.to_bytes(),
329 Err(_) => Bytes::new(),
330 };
331
332 let http_req = Request::from_parts(parts, body_bytes);
333 let armature_req = <HttpRequest as FromHttpRequest>::from_http_request(http_req);
334 let armature_resp = handler(armature_req).await;
335
336 Ok(armature_resp.into_http_response())
337 })
338 }
339}
340
/// Tower "make service" that produces a [`HyperServiceAdapter`] for each connection.
pub struct ServiceFactory<H> {
    /// Shared request handler.
    handler: Arc<H>,
}

impl<H> ServiceFactory<H> {
    /// Wraps `handler` so it can be shared cheaply between clones of the factory.
    pub fn new(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
        }
    }
}

// Cloning only bumps the `Arc` refcount, so `H` itself does not need to be `Clone`
// (consistent with `ArmatureService` and `HyperServiceAdapter`).
impl<H> Clone for ServiceFactory<H> {
    fn clone(&self) -> Self {
        Self {
            handler: Arc::clone(&self.handler),
        }
    }
}
366
367impl<H, Fut> Service<()> for ServiceFactory<H>
368where
369 H: Fn(HttpRequest) -> Fut + Send + Sync + Clone + 'static,
370 Fut: Future<Output = HttpResponse> + Send + 'static,
371{
372 type Response = HyperServiceAdapter<H>;
373 type Error = Infallible;
374 type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
375
376 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
377 Poll::Ready(Ok(()))
378 }
379
380 fn call(&mut self, _: ()) -> Self::Future {
381 let handler = (*self.handler).clone();
382 std::future::ready(Ok(HyperServiceAdapter::new(handler)))
383 }
384}
385
/// Tower layer that wraps a service in [`ArmatureLayerService`].
///
/// The layer carries no configuration, so it is `Copy`. It can be reused, printed and
/// cloned freely when composing service stacks.
#[derive(Debug, Clone, Copy)]
pub struct ArmatureLayer;

impl ArmatureLayer {
    /// Creates the layer.
    pub fn new() -> Self {
        Self
    }
}

impl Default for ArmatureLayer {
    fn default() -> Self {
        Self::new()
    }
}
405
406impl<S> tower::Layer<S> for ArmatureLayer {
407 type Service = ArmatureLayerService<S>;
408
409 fn layer(&self, inner: S) -> Self::Service {
410 ArmatureLayerService { inner }
411 }
412}
413
/// Service produced by `ArmatureLayer`. It forwards readiness checks and calls to
/// `inner` unchanged.
pub struct ArmatureLayerService<S> {
    /// The wrapped service.
    inner: S,
}

// Cloning requires exactly `S: Clone`.
impl<S> Clone for ArmatureLayerService<S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        let inner = S::clone(&self.inner);
        ArmatureLayerService { inner }
    }
}
426
427impl<S, ReqBody> Service<Request<ReqBody>> for ArmatureLayerService<S>
428where
429 S: Service<Request<ReqBody>>,
430{
431 type Response = S::Response;
432 type Error = S::Error;
433 type Future = S::Future;
434
435 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 self.inner.poll_ready(cx)
437 }
438
439 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
440 self.inner.call(req)
441 }
442}
443
/// Process-wide counters describing how the tower/hyper adapters are exercised.
///
/// Every update and read uses `Ordering::Relaxed`. Each counter is independent and only
/// its running total is reported, so no ordering between counters is implied.
#[derive(Debug, Default)]
pub struct TowerStats {
    /// `poll_ready` invocations on `ArmatureService`.
    poll_ready_calls: AtomicU64,
    /// `call` invocations on `ArmatureService`.
    service_calls: AtomicU64,
    /// `call` invocations on `HyperServiceAdapter`.
    hyper_calls: AtomicU64,
    /// Request-conversion counter; its recorder is marked `dead_code` (not called here).
    from_http_request: AtomicU64,
    /// Response-conversion counter; its recorder is marked `dead_code` (not called here).
    to_http_response: AtomicU64,
}

impl TowerStats {
    /// Adds one to `counter` (wrapping on overflow, as `fetch_add` does).
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    fn record_poll_ready(&self) {
        Self::bump(&self.poll_ready_calls);
    }

    fn record_call(&self) {
        Self::bump(&self.service_calls);
    }

    fn record_hyper_call(&self) {
        Self::bump(&self.hyper_calls);
    }

    // Allowed to be unused: kept so the conversion counters can be wired in later.
    #[allow(dead_code)]
    fn record_from_http_request(&self) {
        Self::bump(&self.from_http_request);
    }

    // Allowed to be unused: kept so the conversion counters can be wired in later.
    #[allow(dead_code)]
    fn record_to_http_response(&self) {
        Self::bump(&self.to_http_response);
    }

    /// Number of readiness polls recorded so far.
    pub fn poll_ready_calls(&self) -> u64 {
        self.poll_ready_calls.load(Ordering::Relaxed)
    }

    /// Number of `ArmatureService` calls recorded so far.
    pub fn service_calls(&self) -> u64 {
        self.service_calls.load(Ordering::Relaxed)
    }

    /// Number of `HyperServiceAdapter` calls recorded so far.
    pub fn hyper_calls(&self) -> u64 {
        self.hyper_calls.load(Ordering::Relaxed)
    }
}
501
502static TOWER_STATS: TowerStats = TowerStats {
504 poll_ready_calls: AtomicU64::new(0),
505 service_calls: AtomicU64::new(0),
506 hyper_calls: AtomicU64::new(0),
507 from_http_request: AtomicU64::new(0),
508 to_http_response: AtomicU64::new(0),
509};
510
511pub fn tower_stats() -> &'static TowerStats {
513 &TOWER_STATS
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler shared by the service tests: always answers `200 OK` with body `OK`.
    async fn ok_handler(_req: HttpRequest) -> HttpResponse {
        HttpResponse::ok().with_body(b"OK".to_vec())
    }

    /// A bodiless `GET /` request.
    fn get_root() -> Request<Bytes> {
        Request::builder()
            .method("GET")
            .uri("/")
            .body(Bytes::new())
            .unwrap()
    }

    #[test]
    fn test_http_request_conversion() {
        let http_req = Request::builder()
            .method("POST")
            .uri("/api/users?page=1")
            .header("Content-Type", "application/json")
            .body(Bytes::from(r#"{"name":"test"}"#))
            .unwrap();

        let armature_req = <HttpRequest as FromHttpRequest>::from_http_request(http_req);
        assert_eq!(armature_req.method, "POST");
        assert_eq!(armature_req.path, "/api/users");
        // The query string is split off the path and decoded separately.
        assert_eq!(
            armature_req.query_params.get("page").map(String::as_str),
            Some("1")
        );
        // Header names come back in the `http` crate's lowercase form.
        assert_eq!(
            armature_req.headers.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_http_response_conversion() {
        let armature_resp = HttpResponse::ok()
            .with_header("Content-Type".to_string(), "text/plain".to_string())
            .with_body("Hello, World!".as_bytes().to_vec());

        let http_resp = armature_resp.into_http_response();
        assert_eq!(http_resp.status(), StatusCode::OK);
        assert!(http_resp.headers().contains_key("content-type"));
    }

    #[test]
    fn test_round_trip_request() {
        let mut original = HttpRequest::new("GET".to_string(), "/test".to_string());
        original
            .headers
            .insert("x-request-id".to_string(), "abc123".to_string());

        let http_req = original.clone().into_http_request();
        let back = <HttpRequest as FromHttpRequest>::from_http_request(http_req);

        assert_eq!(original.method, back.method);
        assert_eq!(original.path, back.path);
        assert_eq!(
            back.headers.get("x-request-id").map(String::as_str),
            Some("abc123")
        );
    }

    #[test]
    fn test_header_map_conversion() {
        let mut http_headers = HeaderMap::new();
        http_headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("application/json"),
        );

        let armature_headers = http_headers.to_armature_headers();
        assert!(armature_headers.get("content-type").is_some());

        let back = armature_headers.to_http_headers();
        assert_eq!(
            back.get("content-type").and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
    }

    #[tokio::test]
    async fn test_armature_service() {
        let mut service = ArmatureService::new(ok_handler);

        let resp = service.call(get_root()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = http_body_util::BodyExt::collect(resp.into_body())
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(body, Bytes::from_static(b"OK"));

        // Counters are global and tests run concurrently: assert a lower bound only.
        assert!(tower_stats().service_calls() >= 1);
    }

    #[tokio::test]
    async fn test_armature_layer() {
        use tower::Layer;

        // The layer is a transparent pass-through: the wrapped service still answers.
        let mut service = ArmatureLayer::new().layer(ArmatureService::new(ok_handler));
        let resp = service.call(get_root()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_tower_stats() {
        let mut service = ArmatureService::new(ok_handler);
        std::future::poll_fn(|cx| Service::<Request<Bytes>>::poll_ready(&mut service, cx))
            .await
            .unwrap();

        let stats = tower_stats();
        // Counters are global and tests run concurrently: assert lower bounds only.
        assert!(stats.poll_ready_calls() >= 1);
        let _ = stats.service_calls();
        let _ = stats.hyper_calls();
    }
}