composable_tower_http/extract/chain/
chain_extractor.rs1use std::{ops::Deref, sync::Arc};
2
3use crate::extract::extractor::Extractor;
4
5use super::chainer::Chainer;
6
/// Pairing of an extractor with the chainer that post-processes its output.
///
/// [`ChainExtractor`] keeps one of these behind an `Arc` and dereferences to it.
#[derive(Debug)]
pub struct ChainExtractorInner<Ex, C> {
    /// First stage: run against the request headers.
    extractor: Ex,
    /// Second stage: receives whatever the extractor produced.
    chainer: C,
}

impl<Ex, C> ChainExtractorInner<Ex, C> {
    /// Bundles `extractor` and `chainer` into a single value.
    ///
    /// Usable in `const` contexts.
    pub const fn new(extractor: Ex, chainer: C) -> Self {
        ChainExtractorInner { extractor, chainer }
    }
}
18
19#[derive(Debug)]
20pub struct ChainExtractor<Ex, C> {
21 inner: Arc<ChainExtractorInner<Ex, C>>,
22}
23
24impl<Ex, C> ChainExtractor<Ex, C> {
25 pub fn new(extractor: Ex, chainer: C) -> Self {
26 Self {
27 inner: Arc::new(ChainExtractorInner::new(extractor, chainer)),
28 }
29 }
30}
31
32impl<Ex, C> Clone for ChainExtractor<Ex, C> {
33 fn clone(&self) -> Self {
34 Self {
35 inner: self.inner.clone(),
36 }
37 }
38}
39
40impl<Ex, C> Deref for ChainExtractor<Ex, C> {
41 type Target = ChainExtractorInner<Ex, C>;
42
43 fn deref(&self) -> &Self::Target {
44 &self.inner
45 }
46}
47
48impl<Ex, C> Extractor for ChainExtractor<Ex, C>
49where
50 Ex: Extractor + Send + Sync,
51 C: Chainer<Ex::Extracted> + Send + Sync,
52{
53 type Extracted = C::Chained;
54
55 type Error = ChainError<Ex::Error, C::Error>;
56
57 async fn extract(&self, headers: &http::HeaderMap) -> Result<Self::Extracted, Self::Error> {
58 let extracted = self
59 .extractor
60 .extract(headers)
61 .await
62 .map_err(ChainError::Extract)?;
63
64 let chained = self
65 .chainer
66 .chain(extracted)
67 .await
68 .map_err(ChainError::Chain)?;
69
70 Ok(chained)
71 }
72}
73
74#[derive(Debug, thiserror::Error)]
75pub enum ChainError<Ex, E> {
76 #[error("Extraction error: {0}")]
77 Extract(#[source] Ex),
78 #[error("Chain error: {0}")]
79 Chain(#[source] E),
80}
81
/// `axum` integration for [`ChainError`], compiled only with the `axum` feature.
#[cfg(feature = "axum")]
mod axum {
    use axum::response::{IntoResponse, Response};

    use super::ChainError;

    /// A `ChainError` renders as the response of whichever inner error it holds.
    impl<Ex, E> IntoResponse for ChainError<Ex, E>
    where
        Ex: IntoResponse,
        E: IntoResponse,
    {
        fn into_response(self) -> Response {
            match self {
                Self::Extract(extract_err) => extract_err.into_response(),
                Self::Chain(chain_err) => chain_err.into_response(),
            }
        }
    }

    /// Conversion form of the `IntoResponse` impl above, for call sites that
    /// use `Response::from` / `.into()`.
    impl<Ex, E> From<ChainError<Ex, E>> for Response
    where
        Ex: IntoResponse,
        E: IntoResponse,
    {
        fn from(value: ChainError<Ex, E>) -> Self {
            IntoResponse::into_response(value)
        }
    }
}