mas_http/layers/
catch_http_codes.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::{Bound, RangeBounds};
16
17use futures_util::FutureExt;
18use http::{Request, Response, StatusCode};
19use thiserror::Error;
20use tower::{Layer, Service};
21
22#[derive(Debug, Error)]
23pub enum Error<S, E> {
24    #[error(transparent)]
25    Service { inner: S },
26
27    #[error("request failed with status {status_code}: {inner}")]
28    HttpError { status_code: StatusCode, inner: E },
29}
30
31impl<S, E> Error<S, E> {
32    fn service(inner: S) -> Self {
33        Self::Service { inner }
34    }
35
36    pub fn status_code(&self) -> Option<StatusCode> {
37        match self {
38            Self::Service { .. } => None,
39            Self::HttpError { status_code, .. } => Some(*status_code),
40        }
41    }
42}
43
44/// A layer that catches responses with the HTTP status codes lying within
45/// `bounds` and then maps the requests into a custom error type using `mapper`.
46#[derive(Clone)]
47pub struct CatchHttpCodes<S, M> {
48    /// The inner service
49    inner: S,
50    /// Which HTTP status codes to catch
51    bounds: (Bound<StatusCode>, Bound<StatusCode>),
52    /// The function used to convert errors, which must be
53    /// `Fn(Response<ResBody>) -> E + Send + Clone + 'static`.
54    mapper: M,
55}
56
57impl<S, M> CatchHttpCodes<S, M> {
58    pub fn new<B>(inner: S, bounds: B, mapper: M) -> Self
59    where
60        B: RangeBounds<StatusCode>,
61        M: Clone,
62    {
63        let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
64        Self {
65            inner,
66            bounds,
67            mapper,
68        }
69    }
70}
71
72impl<S, M, E, ReqBody, ResBody> Service<Request<ReqBody>> for CatchHttpCodes<S, M>
73where
74    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
75    S::Future: Send + 'static,
76    M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
77{
78    type Error = Error<S::Error, E>;
79    type Response = Response<ResBody>;
80    type Future = futures_util::future::Map<
81        S::Future,
82        Box<
83            dyn Fn(Result<S::Response, S::Error>) -> Result<Self::Response, Self::Error>
84                + Send
85                + 'static,
86        >,
87    >;
88
89    fn poll_ready(
90        &mut self,
91        cx: &mut std::task::Context<'_>,
92    ) -> std::task::Poll<Result<(), Self::Error>> {
93        self.inner.poll_ready(cx).map_err(Error::service)
94    }
95
96    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
97        let fut = self.inner.call(request);
98        let bounds = self.bounds;
99        let mapper = self.mapper.clone();
100
101        fut.map(Box::new(move |res: Result<S::Response, S::Error>| {
102            let response = res.map_err(Error::service)?;
103            let status_code = response.status();
104
105            if bounds.contains(&status_code) {
106                let inner = mapper(response);
107                Err(Error::HttpError { status_code, inner })
108            } else {
109                Ok(response)
110            }
111        }))
112    }
113}
114
115#[derive(Clone)]
116pub struct CatchHttpCodesLayer<M> {
117    bounds: (Bound<StatusCode>, Bound<StatusCode>),
118    mapper: M,
119}
120
121impl<M> CatchHttpCodesLayer<M>
122where
123    M: Clone,
124{
125    pub fn new<B>(bounds: B, mapper: M) -> Self
126    where
127        B: RangeBounds<StatusCode>,
128    {
129        let bounds = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
130        Self { bounds, mapper }
131    }
132
133    pub fn exact(status_code: StatusCode, mapper: M) -> Self {
134        Self::new(status_code..=status_code, mapper)
135    }
136}
137
138impl<S, M> Layer<S> for CatchHttpCodesLayer<M>
139where
140    M: Clone,
141{
142    type Service = CatchHttpCodes<S, M>;
143
144    fn layer(&self, inner: S) -> Self::Service {
145        CatchHttpCodes::new(inner, self.bounds, self.mapper.clone())
146    }
147}