// composable_tower_http/extension/modify/service.rs

use std::{
    fmt,
    future::Future,
    marker::PhantomData,
    mem,
    pin::Pin,
    task::{Context, Poll},
};

use http::Request;
use tower::Service;

use crate::{extract::SealedExtracted, modify::Modifier};
12
/// Tower [`Service`] that rewrites a previously extracted value before the request
/// reaches the inner service `S`.
///
/// On each call it removes the `SealedExtracted<T>` extension from the request, passes the
/// value to the modifier `M`, re-inserts the result and forwards the request to `S`.
/// `T` is only a type-level marker; no `T` is stored in the service itself.
pub struct ModificationService<S, M, T> {
    service: S,
    modifier: M,
    _phantom: PhantomData<T>,
}

impl<S, M, T> ModificationService<S, M, T> {
    /// Wraps `service` so that `modifier` runs on the extracted value before every call.
    pub const fn new(service: S, modifier: M) -> Self {
        Self {
            service,
            modifier,
            _phantom: PhantomData,
        }
    }
}

// Hand-written instead of `#[derive(Clone, Debug)]`: the derives would require `T: Clone`
// and `T: Debug` even though only a `PhantomData<T>` is stored, which needlessly restricts
// the extracted types for which the service can be cloned or debug-printed.
impl<S: Clone, M: Clone, T> Clone for ModificationService<S, M, T> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            modifier: self.modifier.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<S: fmt::Debug, M: fmt::Debug, T> fmt::Debug for ModificationService<S, M, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ModificationService")
            .field("service", &self.service)
            .field("modifier", &self.modifier)
            .finish_non_exhaustive()
    }
}
29
30impl<S, M, B, T> Service<Request<B>> for ModificationService<S, M, T>
31where
32    M: Modifier<T> + Clone + Send + 'static,
33    T: Send + Sync + 'static,
34    S: Service<Request<B>> + Clone + Send + 'static,
35    S::Future: Send,
36    S::Response: From<ModificationError<M::Error>>,
37    B: Send + 'static,
38{
39    type Response = S::Response;
40    type Error = S::Error;
41    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
42
43    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
44        self.service.poll_ready(cx)
45    }
46
47    fn call(&mut self, mut request: Request<B>) -> Self::Future {
48        let mut service = self.service.clone();
49        let modifier = self.modifier.clone();
50
51        Box::pin(async move {
52            match request.extensions_mut().remove::<SealedExtracted<T>>() {
53                Some(SealedExtracted(extracted)) => {
54                    match modifier.modify(extracted).await {
55                        Ok(modified) => request.extensions_mut().insert(SealedExtracted(modified)),
56                        Err(err) => return Ok(From::from(ModificationError::Modification(err))),
57                    };
58                }
59                None => return Ok(From::from(ModificationError::Extract)),
60            }
61
62            service.call(request).await
63        })
64    }
65}
66
67#[derive(Debug, thiserror::Error)]
68pub enum ModificationError<E> {
69    #[error("Extraction error")]
70    Extract,
71    #[error("Modification error: {0}")]
72    Modification(#[source] E),
73}
74
#[cfg(feature = "axum")]
mod axum {
    // `axum` integration: lets a `ModificationError` be turned directly into an HTTP response.

    use axum::response::{IntoResponse, Response};
    use http::StatusCode;

    use super::ModificationError;

    /// A missing extension becomes a bare `500 Internal Server Error`; a modifier failure
    /// is rendered by the modifier error's own `IntoResponse` impl.
    impl<E: IntoResponse> IntoResponse for ModificationError<E> {
        fn into_response(self) -> Response {
            match self {
                Self::Modification(inner) => inner.into_response(),
                Self::Extract => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
            }
        }
    }

    /// Same as `IntoResponse::into_response`. This is what satisfies an
    /// `S::Response: From<ModificationError<_>>` bound when `S::Response` is axum's `Response`.
    impl<E: IntoResponse> From<ModificationError<E>> for Response {
        fn from(error: ModificationError<E>) -> Self {
            IntoResponse::into_response(error)
        }
    }
}