mas_http/layers/body_to_bytes_response.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::Bytes;
16use futures_util::future::BoxFuture;
17use http::{Request, Response};
18use http_body::Body;
19use http_body_util::BodyExt;
20use thiserror::Error;
21use tower::{Layer, Service};
22
23#[derive(Debug, Error)]
24pub enum Error<ServiceError, BodyError> {
25    #[error(transparent)]
26    Service { inner: ServiceError },
27
28    #[error(transparent)]
29    Body { inner: BodyError },
30}
31
32impl<S, B> Error<S, B> {
33    fn service(inner: S) -> Self {
34        Self::Service { inner }
35    }
36
37    fn body(inner: B) -> Self {
38        Self::Body { inner }
39    }
40}
41
42impl<E> Error<E, E> {
43    pub fn unify(self) -> E {
44        match self {
45            Self::Service { inner } | Self::Body { inner } => inner,
46        }
47    }
48}
49
/// A [`Service`] middleware that reads the whole response body of the inner
/// service into a single [`Bytes`] buffer.
///
/// Construct it with [`BodyToBytesResponse::new`] or through
/// [`BodyToBytesResponseLayer`].
#[derive(Debug, Clone)]
pub struct BodyToBytesResponse<S> {
    inner: S,
}
54
55impl<S> BodyToBytesResponse<S> {
56    pub const fn new(inner: S) -> Self {
57        Self { inner }
58    }
59}
60
61impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for BodyToBytesResponse<S>
62where
63    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
64    S::Future: Send + 'static,
65    ResBody: Body + Send,
66    ResBody::Data: Send,
67{
68    type Error = Error<S::Error, ResBody::Error>;
69    type Response = Response<Bytes>;
70    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
71
72    fn poll_ready(
73        &mut self,
74        cx: &mut std::task::Context<'_>,
75    ) -> std::task::Poll<Result<(), Self::Error>> {
76        self.inner.poll_ready(cx).map_err(Error::service)
77    }
78
79    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
80        let inner = self.inner.call(request);
81
82        let fut = async {
83            let response = inner.await.map_err(Error::service)?;
84            let (parts, body) = response.into_parts();
85
86            let body = body.collect().await.map_err(Error::body)?.to_bytes();
87
88            let response = Response::from_parts(parts, body);
89            Ok(response)
90        };
91
92        Box::pin(fut)
93    }
94}
95
/// [`Layer`] that wraps services in [`BodyToBytesResponse`].
///
/// It carries no configuration, which is why it is a unit struct that is
/// `Default` and `Copy`.
#[derive(Debug, Default, Clone, Copy)]
pub struct BodyToBytesResponseLayer;
98
99impl<S> Layer<S> for BodyToBytesResponseLayer {
100    type Service = BodyToBytesResponse<S>;
101
102    fn layer(&self, inner: S) -> Self::Service {
103        BodyToBytesResponse::new(inner)
104    }
105}