mas_http/layers/
catch_http_codes.rs1use std::ops::{Bound, RangeBounds};
16
17use futures_util::FutureExt;
18use http::{Request, Response, StatusCode};
19use thiserror::Error;
20use tower::{Layer, Service};
21
22#[derive(Debug, Error)]
23pub enum Error<S, E> {
24 #[error(transparent)]
25 Service { inner: S },
26
27 #[error("request failed with status {status_code}: {inner}")]
28 HttpError { status_code: StatusCode, inner: E },
29}
30
31impl<S, E> Error<S, E> {
32 fn service(inner: S) -> Self {
33 Self::Service { inner }
34 }
35
36 pub fn status_code(&self) -> Option<StatusCode> {
37 match self {
38 Self::Service { .. } => None,
39 Self::HttpError { status_code, .. } => Some(*status_code),
40 }
41 }
42}
43
44#[derive(Clone)]
47pub struct CatchHttpCodes<S, M> {
48 inner: S,
50 bounds: (Bound<StatusCode>, Bound<StatusCode>),
52 mapper: M,
55}
56
57impl<S, M> CatchHttpCodes<S, M> {
58 pub fn new<B>(inner: S, bounds: B, mapper: M) -> Self
59 where
60 B: RangeBounds<StatusCode>,
61 M: Clone,
62 {
63 let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
64 Self {
65 inner,
66 bounds,
67 mapper,
68 }
69 }
70}
71
72impl<S, M, E, ReqBody, ResBody> Service<Request<ReqBody>> for CatchHttpCodes<S, M>
73where
74 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
75 S::Future: Send + 'static,
76 M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
77{
78 type Error = Error<S::Error, E>;
79 type Response = Response<ResBody>;
80 type Future = futures_util::future::Map<
81 S::Future,
82 Box<
83 dyn Fn(Result<S::Response, S::Error>) -> Result<Self::Response, Self::Error>
84 + Send
85 + 'static,
86 >,
87 >;
88
89 fn poll_ready(
90 &mut self,
91 cx: &mut std::task::Context<'_>,
92 ) -> std::task::Poll<Result<(), Self::Error>> {
93 self.inner.poll_ready(cx).map_err(Error::service)
94 }
95
96 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
97 let fut = self.inner.call(request);
98 let bounds = self.bounds;
99 let mapper = self.mapper.clone();
100
101 fut.map(Box::new(move |res: Result<S::Response, S::Error>| {
102 let response = res.map_err(Error::service)?;
103 let status_code = response.status();
104
105 if bounds.contains(&status_code) {
106 let inner = mapper(response);
107 Err(Error::HttpError { status_code, inner })
108 } else {
109 Ok(response)
110 }
111 }))
112 }
113}
114
115#[derive(Clone)]
116pub struct CatchHttpCodesLayer<M> {
117 bounds: (Bound<StatusCode>, Bound<StatusCode>),
118 mapper: M,
119}
120
121impl<M> CatchHttpCodesLayer<M>
122where
123 M: Clone,
124{
125 pub fn new<B>(bounds: B, mapper: M) -> Self
126 where
127 B: RangeBounds<StatusCode>,
128 {
129 let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
130 Self { bounds, mapper }
131 }
132
133 pub fn exact(status_code: StatusCode, mapper: M) -> Self {
134 Self::new(status_code..=status_code, mapper)
135 }
136}
137
138impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
139where
140 M: Clone,
141{
142 type Service = CatchHttpCodes<S, M>;
143
144 fn layer(&self, inner: S) -> Self::Service {
145 CatchHttpCodes::new(inner, self.bounds, self.mapper.clone())
146 }
147}