mas_http/layers/json_response.rs

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{marker::PhantomData, task::Poll};
16
17use bytes::Buf;
18use futures_util::FutureExt;
19use http::{header::ACCEPT, HeaderValue, Request, Response};
20use serde::de::DeserializeOwned;
21use thiserror::Error;
22use tower::{Layer, Service};
23
24#[derive(Debug, Error)]
25pub enum Error<Service> {
26    /// An error from the inner service.
27    #[error(transparent)]
28    Service { inner: Service },
29
30    #[error("could not parse JSON payload")]
31    Deserialize {
32        #[source]
33        inner: serde_json::Error,
34    },
35}
36
37impl<S> Error<S> {
38    fn service(source: S) -> Self {
39        Self::Service { inner: source }
40    }
41
42    fn deserialize(source: serde_json::Error) -> Self {
43        Self::Deserialize { inner: source }
44    }
45}
46
/// A [`Service`] wrapper that asks the inner service for JSON and deserializes
/// the response body into `T`.
///
/// `T` is only a type-level marker. `Clone`, `Debug`, `Send` and `Sync` for
/// this type therefore depend on `S` alone, not on the response type.
pub struct JsonResponse<S, T> {
    inner: S,
    // `fn() -> T` records that this service *produces* a `T` without owning
    // one. A bare `PhantomData<T>` would make `Send`/`Sync` require `T` to be
    // `Send`/`Sync` as well.
    _t: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would add an unnecessary
// `T: Clone` bound because of the marker field.
impl<S: Clone, T> Clone for JsonResponse<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _t: PhantomData,
        }
    }
}

// Implemented by hand so that `T` is not required to be `Debug`.
impl<S: std::fmt::Debug, T> std::fmt::Debug for JsonResponse<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonResponse")
            .field("inner", &self.inner)
            .finish()
    }
}
52
53impl<S, T> JsonResponse<S, T> {
54    pub const fn new(inner: S) -> Self {
55        Self {
56            inner,
57            _t: PhantomData,
58        }
59    }
60}
61
62impl<S, T, B, C> Service<Request<B>> for JsonResponse<S, T>
63where
64    S: Service<Request<B>, Response = Response<C>>,
65    S::Future: Send + 'static,
66    C: Buf,
67    T: DeserializeOwned,
68{
69    type Error = Error<S::Error>;
70    type Response = Response<T>;
71    type Future = futures_util::future::Map<
72        S::Future,
73        fn(Result<Response<C>, S::Error>) -> Result<Self::Response, Self::Error>,
74    >;
75
76    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
77        self.inner.poll_ready(cx).map_err(Error::service)
78    }
79
80    fn call(&mut self, mut request: Request<B>) -> Self::Future {
81        fn mapper<C, T, E>(res: Result<Response<C>, E>) -> Result<Response<T>, Error<E>>
82        where
83            C: Buf,
84            T: DeserializeOwned,
85        {
86            let response = res.map_err(Error::service)?;
87            let (parts, body) = response.into_parts();
88
89            let body = serde_json::from_reader(body.reader()).map_err(Error::deserialize)?;
90
91            let res = Response::from_parts(parts, body);
92            Ok(res)
93        }
94
95        request
96            .headers_mut()
97            .insert(ACCEPT, HeaderValue::from_static("application/json"));
98
99        self.inner.call(request).map(mapper::<C, T, S::Error>)
100    }
101}
102
/// A [`Layer`] that wraps services in [`JsonResponse`], deserializing their
/// response bodies into `T`.
///
/// The layer holds no data, so it is `Clone`, `Copy`, `Debug`, `Send` and
/// `Sync` for every `T`.
pub struct JsonResponseLayer<T> {
    // `fn() -> T` keeps auto traits independent of `T` (see the type docs).
    _t: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require
// `T: Clone + Copy`, which typical response structs do not satisfy.
impl<T> Clone for JsonResponseLayer<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonResponseLayer<T> {}

// Implemented by hand so that `T` is not required to be `Debug`.
impl<T> std::fmt::Debug for JsonResponseLayer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonResponseLayer").finish()
    }
}
107
108impl<T> Default for JsonResponseLayer<T> {
109    fn default() -> Self {
110        Self { _t: PhantomData }
111    }
112}
113
114impl<S, T> Layer<S> for JsonResponseLayer<T> {
115    type Service = JsonResponse<S, T>;
116
117    fn layer(&self, inner: S) -> Self::Service {
118        JsonResponse::new(inner)
119    }
120}