mas_http/
ext.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{ops::RangeBounds, sync::OnceLock};
16
17use http::{header::HeaderName, Request, Response, StatusCode};
18use tower::Service;
19use tower_http::cors::CorsLayer;
20
21use crate::layers::{
22    body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
23    catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
24    json_request::JsonRequest, json_response::JsonResponse,
25};
26
27static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
28
29/// Notify the CORS layer what opentelemetry propagators are being used. This
30/// helps whitelisting headers in CORS requests.
31///
32/// # Panics
33///
34/// When called twice
35pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
36    let headers = propagator
37        .fields()
38        .map(|h| HeaderName::try_from(h).unwrap())
39        .collect();
40
41    tracing::debug!(
42        ?headers,
43        "Headers allowed in CORS requests for trace propagators set"
44    );
45    PROPAGATOR_HEADERS
46        .set(headers)
47        .expect(concat!(module_path!(), "::set_propagator was called twice"));
48}
49
50pub trait CorsLayerExt {
51    #[must_use]
52    fn allow_otel_headers<H>(self, headers: H) -> Self
53    where
54        H: IntoIterator<Item = HeaderName>;
55}
56
57impl CorsLayerExt for CorsLayer {
58    fn allow_otel_headers<H>(self, headers: H) -> Self
59    where
60        H: IntoIterator<Item = HeaderName>,
61    {
62        let base = PROPAGATOR_HEADERS.get().cloned().unwrap_or_default();
63        let headers: Vec<_> = headers.into_iter().chain(base).collect();
64        self.allow_headers(headers)
65    }
66}
67
68pub trait ServiceExt<Body>: Sized {
69    fn request_bytes_to_body(self) -> BytesToBodyRequest<Self> {
70        BytesToBodyRequest::new(self)
71    }
72
73    /// Adds a layer which collects all the response body into a contiguous
74    /// byte buffer.
75    /// This makes the response type `Response<Bytes>`.
76    fn response_body_to_bytes(self) -> BodyToBytesResponse<Self> {
77        BodyToBytesResponse::new(self)
78    }
79
80    fn json_response<T>(self) -> JsonResponse<Self, T> {
81        JsonResponse::new(self)
82    }
83
84    fn json_request<T>(self) -> JsonRequest<Self, T> {
85        JsonRequest::new(self)
86    }
87
88    fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
89        FormUrlencodedRequest::new(self)
90    }
91
92    /// Catches responses with the given status code and then maps those
93    /// responses to an error type using the provided `mapper` function.
94    fn catch_http_code<M, ResBody, E>(
95        self,
96        status_code: StatusCode,
97        mapper: M,
98    ) -> CatchHttpCodes<Self, M>
99    where
100        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
101    {
102        self.catch_http_codes(status_code..=status_code, mapper)
103    }
104
105    /// Catches responses with the given status codes and then maps those
106    /// responses to an error type using the provided `mapper` function.
107    fn catch_http_codes<B, M, ResBody, E>(self, bounds: B, mapper: M) -> CatchHttpCodes<Self, M>
108    where
109        B: RangeBounds<StatusCode>,
110        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
111    {
112        CatchHttpCodes::new(self, bounds, mapper)
113    }
114
115    /// Shorthand for [`Self::catch_http_codes`] which catches all client errors
116    /// (4xx) and server errors (5xx).
117    fn catch_http_errors<M, ResBody, E>(self, mapper: M) -> CatchHttpCodes<Self, M>
118    where
119        M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
120    {
121        self.catch_http_codes(
122            StatusCode::from_u16(400).unwrap()..StatusCode::from_u16(600).unwrap(),
123            mapper,
124        )
125    }
126}
127
128impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}