Skip to main content

camel_auth/
bearer_token_layer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use tower::{Layer, Service};
7
8use crate::oauth2::TokenProvider;
9use camel_api::{CamelError, Exchange};
10
11pub struct BearerTokenLayer {
12    provider: Arc<dyn TokenProvider>,
13}
14
15impl BearerTokenLayer {
16    pub fn new(provider: Arc<dyn TokenProvider>) -> Self {
17        Self { provider }
18    }
19}
20
21impl<S> Layer<S> for BearerTokenLayer {
22    type Service = BearerTokenService<S>;
23
24    fn layer(&self, inner: S) -> Self::Service {
25        BearerTokenService {
26            inner,
27            provider: Arc::clone(&self.provider),
28        }
29    }
30}
31
32pub struct BearerTokenService<S> {
33    inner: S,
34    provider: Arc<dyn TokenProvider>,
35}
36
37impl<S: Clone> Clone for BearerTokenService<S> {
38    fn clone(&self) -> Self {
39        Self {
40            inner: self.inner.clone(),
41            provider: Arc::clone(&self.provider),
42        }
43    }
44}
45
46impl<S> Service<Exchange> for BearerTokenService<S>
47where
48    S: Service<Exchange, Response = Exchange, Error = CamelError> + Clone + Send + 'static,
49    S::Future: Send,
50{
51    type Response = Exchange;
52    type Error = CamelError;
53    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
54
55    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56        self.inner.poll_ready(cx)
57    }
58
59    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
60        let provider = Arc::clone(&self.provider);
61        let clone = self.inner.clone();
62        let mut inner = std::mem::replace(&mut self.inner, clone);
63
64        Box::pin(async move {
65            let token = provider.get_token().await.map_err(CamelError::from)?;
66            exchange
67                .input
68                .set_header("Authorization", format!("Bearer {token}")); // allow-secret
69            inner.call(exchange).await
70        })
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::AuthError;
    use async_trait::async_trait;
    use camel_api::{BoxProcessor, BoxProcessorExt, Message};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Provider that always hands out the same token.
    #[derive(Debug)]
    struct StaticTokenProvider {
        token: String,
    }

    #[async_trait]
    impl TokenProvider for StaticTokenProvider {
        async fn get_token(&self) -> Result<String, AuthError> {
            Ok(self.token.clone())
        }
    }

    /// Provider that always fails, simulating an unreachable token endpoint.
    #[derive(Debug)]
    struct FailingTokenProvider;

    #[async_trait]
    impl TokenProvider for FailingTokenProvider {
        async fn get_token(&self) -> Result<String, AuthError> {
            Err(AuthError::ProviderUnavailable("token endpoint down".into()))
        }
    }

    fn make_exchange() -> Exchange {
        Exchange::new(Message::new("test"))
    }

    /// Inner processor that returns the exchange unchanged.
    fn ok_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Reads the `Authorization` header of the input message as a string, if present.
    fn auth_header(ex: &Exchange) -> Option<&str> {
        ex.input.header("Authorization").and_then(|v| v.as_str())
    }

    #[tokio::test]
    async fn test_injects_authorization_header() {
        let provider = Arc::new(StaticTokenProvider {
            token: "my-token".into(),
        });
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(ok_processor());
        let ex = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange())
            .await
            .expect("static provider and echo processor should succeed");
        assert_eq!(auth_header(&ex), Some("Bearer my-token"));
    }

    #[tokio::test]
    async fn test_preserves_existing_headers() {
        let provider = Arc::new(StaticTokenProvider { token: "t".into() });
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(ok_processor());
        let mut ex = make_exchange();
        ex.input.set_header("X-Custom", "value");
        let ex = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert_eq!(
            ex.input.header("X-Custom").and_then(|v| v.as_str()),
            Some("value")
        );
        assert_eq!(auth_header(&ex), Some("Bearer t"));
    }

    #[tokio::test]
    async fn test_overwrites_existing_auth_header() {
        let provider = Arc::new(StaticTokenProvider {
            token: "fresh".into(),
        });
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(ok_processor());
        let mut ex = make_exchange();
        ex.input.set_header("Authorization", "Bearer stale");
        let ex = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert_eq!(auth_header(&ex), Some("Bearer fresh"));
    }

    #[tokio::test]
    async fn test_provider_error_propagates() {
        let provider = Arc::new(FailingTokenProvider);
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(ok_processor());
        let result = svc.ready().await.unwrap().call(make_exchange()).await;
        match result {
            Err(CamelError::ProcessorError(msg)) => {
                assert!(msg.contains("token endpoint down"));
            }
            other => panic!("expected ProcessorError, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_provider_error_skips_inner_service() {
        // The inner service must never run when no token could be obtained.
        let inner_calls = Arc::new(AtomicUsize::new(0));
        let seen = Arc::clone(&inner_calls);
        let inner = BoxProcessor::from_fn(move |ex| {
            seen.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(ex) })
        });
        let provider = Arc::new(FailingTokenProvider);
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(inner);
        let result = svc.ready().await.unwrap().call(make_exchange()).await;
        assert!(result.is_err());
        assert_eq!(inner_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_calls_provider_each_time() {
        #[derive(Debug)]
        struct CountingProvider {
            count: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl TokenProvider for CountingProvider {
            async fn get_token(&self) -> Result<String, AuthError> {
                let n = self.count.fetch_add(1, Ordering::SeqCst);
                Ok(format!("token-{n}")) // allow-secret
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let provider = Arc::new(CountingProvider {
            count: Arc::clone(&count),
        });
        let layer = BearerTokenLayer::new(provider);
        let mut svc = layer.layer(ok_processor());

        for i in 0..3 {
            let ex = svc
                .ready()
                .await
                .unwrap()
                .call(make_exchange())
                .await
                .unwrap();
            let expected = format!("Bearer token-{i}"); // allow-secret
            assert_eq!(auth_header(&ex), Some(expected.as_str()));
        }
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_clone_produces_working_service() {
        let provider = Arc::new(StaticTokenProvider { token: "t".into() });
        let layer = BearerTokenLayer::new(provider);
        let mut svc1 = layer.layer(ok_processor());
        let mut svc2 = svc1.clone();

        let ex1 = svc1
            .ready()
            .await
            .unwrap()
            .call(make_exchange())
            .await
            .unwrap();
        let ex2 = svc2
            .ready()
            .await
            .unwrap()
            .call(make_exchange())
            .await
            .unwrap();
        // Both the original and the clone must inject the header, not merely succeed.
        assert_eq!(auth_header(&ex1), Some("Bearer t"));
        assert_eq!(auth_header(&ex2), Some("Bearer t"));
    }
}