// composable_tower_http/extract/chain/chain_extractor.rs

1use std::{ops::Deref, sync::Arc};
2
3use crate::extract::extractor::Extractor;
4
5use super::chainer::Chainer;
6
/// Shared state behind a [`ChainExtractor`]: the first-stage extractor plus the
/// chainer that receives whatever that extractor produces.
#[derive(Debug)]
pub struct ChainExtractorInner<Ex, C> {
    /// Extractor that is run first against the request headers.
    extractor: Ex,
    /// Chainer that is handed the value produced by `extractor`.
    chainer: C,
}

impl<Ex, C> ChainExtractorInner<Ex, C> {
    /// Bundles `extractor` and `chainer` together without running either.
    ///
    /// Being a `const fn`, this can also be used to build the state in a
    /// constant context.
    pub const fn new(extractor: Ex, chainer: C) -> Self {
        ChainExtractorInner { chainer, extractor }
    }
}
18
19#[derive(Debug)]
20pub struct ChainExtractor<Ex, C> {
21    inner: Arc<ChainExtractorInner<Ex, C>>,
22}
23
24impl<Ex, C> ChainExtractor<Ex, C> {
25    pub fn new(extractor: Ex, chainer: C) -> Self {
26        Self {
27            inner: Arc::new(ChainExtractorInner::new(extractor, chainer)),
28        }
29    }
30}
31
32impl<Ex, C> Clone for ChainExtractor<Ex, C> {
33    fn clone(&self) -> Self {
34        Self {
35            inner: self.inner.clone(),
36        }
37    }
38}
39
40impl<Ex, C> Deref for ChainExtractor<Ex, C> {
41    type Target = ChainExtractorInner<Ex, C>;
42
43    fn deref(&self) -> &Self::Target {
44        &self.inner
45    }
46}
47
48impl<Ex, C> Extractor for ChainExtractor<Ex, C>
49where
50    Ex: Extractor + Send + Sync,
51    C: Chainer<Ex::Extracted> + Send + Sync,
52{
53    type Extracted = C::Chained;
54
55    type Error = ChainError<Ex::Error, C::Error>;
56
57    async fn extract(&self, headers: &http::HeaderMap) -> Result<Self::Extracted, Self::Error> {
58        let extracted = self
59            .extractor
60            .extract(headers)
61            .await
62            .map_err(ChainError::Extract)?;
63
64        let chained = self
65            .chainer
66            .chain(extracted)
67            .await
68            .map_err(ChainError::Chain)?;
69
70        Ok(chained)
71    }
72}
73
74#[derive(Debug, thiserror::Error)]
75pub enum ChainError<Ex, E> {
76    #[error("Extraction error: {0}")]
77    Extract(#[source] Ex),
78    #[error("Chain error: {0}")]
79    Chain(#[source] E),
80}
81
#[cfg(feature = "axum")]
mod axum {
    //! `axum` integration: a [`ChainError`] can be used wherever `axum`
    //! expects an `IntoResponse`. It delegates to the response of the stage
    //! that failed.

    use axum::response::{IntoResponse, Response};

    use super::ChainError;

    /// Produces the response of whichever stage failed. No response of this
    /// wrapper's own is generated.
    impl<Ex, E> IntoResponse for ChainError<Ex, E>
    where
        Ex: IntoResponse,
        E: IntoResponse,
    {
        fn into_response(self) -> Response {
            match self {
                Self::Extract(extract_err) => extract_err.into_response(),
                Self::Chain(chain_err) => chain_err.into_response(),
            }
        }
    }

    /// Same conversion as the `IntoResponse` impl, exposed through `From` so
    /// `.into()` or `?` (in a function returning `Result<_, Response>`) can
    /// produce a [`Response`] directly.
    impl<Ex, E> From<ChainError<Ex, E>> for Response
    where
        Ex: IntoResponse,
        E: IntoResponse,
    {
        fn from(error: ChainError<Ex, E>) -> Self {
            IntoResponse::into_response(error)
        }
    }
}