1use tower::{Service, ServiceExt};
4use vidi_core::{Body, BoxError, Bytes, Error, Handler, HttpBody, Request, Response, Result};
5
6mod service;
7pub use service::HandlerService;
8
9mod middleware;
10pub use middleware::Middleware;
11
12mod layer;
13pub use layer::Layered;
14
/// Adapts a [`tower::Service`] so it can be used where a vidi [`Handler`] is expected.
///
/// Each handled request drives a clone of the inner service, so cheaply-clonable
/// services are the best fit.
#[derive(Debug, Clone)]
pub struct ServiceHandler<S>(S);

impl<S> ServiceHandler<S> {
    /// Wraps `service` so it can be used as a [`Handler`].
    pub const fn new(service: S) -> Self {
        ServiceHandler(service)
    }
}
25
26#[vidi_core::async_trait]
27impl<O, S> Handler<Request> for ServiceHandler<S>
28where
29 O: HttpBody + Send + 'static,
30 O::Data: Into<Bytes>,
31 O::Error: Into<BoxError>,
32 S: Service<Request, Response = Response<O>> + Send + Sync + Clone + 'static,
33 S::Future: Send,
34 S::Error: Into<BoxError>,
35{
36 type Output = Result<Response>;
37
38 async fn call(&self, req: Request) -> Self::Output {
39 self.0
40 .clone()
41 .oneshot(req)
42 .await
43 .map_err(Error::boxed)
44 .map(|resp| resp.map(Body::wrap))
45 }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
        time::Duration,
    };
    use tower::util::{MapErrLayer, MapRequestLayer, MapResponseLayer};
    use tower::{ServiceBuilder, service_fn};
    use tower_http::{
        limit::RequestBodyLimitLayer,
        request_id::{MakeRequestId, RequestId, SetRequestIdLayer},
        timeout::TimeoutLayer,
    };
    use vidi_core::{
        Body, BoxHandler, Handler, HandlerExt, IntoResponse, Request, RequestExt, Response,
        StatusCode,
    };

    /// Produces request ids "0", "1", "2", ... from a counter shared between clones.
    #[derive(Clone, Debug, Default)]
    struct MyMakeRequestId {
        counter: Arc<AtomicU64>,
    }

    impl MakeRequestId for MyMakeRequestId {
        fn make_request_id<B>(&mut self, _: &Request<B>) -> Option<RequestId> {
            let next = self.counter.fetch_add(1, Ordering::SeqCst);
            // A string of decimal digits always parses as a header value.
            let header = next.to_string().parse().unwrap();
            Some(RequestId::new(header))
        }
    }

    /// Echo handler: responds with the bytes of the request body.
    async fn hello(mut req: Request) -> Result<Response> {
        Ok(req.bytes().await?.into_response())
    }

    /// Wraps a layered tower service in a `ServiceHandler` and checks both the direct
    /// and the boxed call paths against a one-byte body limit.
    #[tokio::test]
    async fn tower_service_into_handler() {
        // Layers added first wrap outermost, so the body limit sees the request first.
        let stack = ServiceBuilder::new()
            .layer(RequestBodyLimitLayer::new(1))
            .layer(MapErrLayer::new(Error::from))
            .layer(SetRequestIdLayer::x_request_id(MyMakeRequestId::default()))
            .layer(MapResponseLayer::new(IntoResponse::into_response))
            .layer(MapRequestLayer::new(|req: Request<_>| req.map(Body::wrap)))
            .layer(TimeoutLayer::with_status_code(
                StatusCode::REQUEST_TIMEOUT,
                Duration::from_secs(10),
            ))
            .service(service_fn(hello));

        let handler = ServiceHandler::new(stack);

        // Two bytes exceed the one-byte body limit, so this call is expected to fail.
        let oversized = Request::new(Body::Full("12".into()));
        assert!(handler.call(oversized).await.is_err());

        // A single byte fits within the limit; go through the boxed handler this time.
        let boxed: BoxHandler = handler.boxed();
        let within_limit = Request::new(Body::Full("1".into()));
        assert!(boxed.call(within_limit).await.is_ok());
    }
}