camel_auth/
bearer_token_layer.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use tower::{Layer, Service};
7
8use crate::oauth2::TokenProvider;
9use camel_api::{CamelError, Exchange};
10
11pub struct BearerTokenLayer {
12 provider: Arc<dyn TokenProvider>,
13}
14
15impl BearerTokenLayer {
16 pub fn new(provider: Arc<dyn TokenProvider>) -> Self {
17 Self { provider }
18 }
19}
20
21impl<S> Layer<S> for BearerTokenLayer {
22 type Service = BearerTokenService<S>;
23
24 fn layer(&self, inner: S) -> Self::Service {
25 BearerTokenService {
26 inner,
27 provider: Arc::clone(&self.provider),
28 }
29 }
30}
31
32pub struct BearerTokenService<S> {
33 inner: S,
34 provider: Arc<dyn TokenProvider>,
35}
36
37impl<S: Clone> Clone for BearerTokenService<S> {
38 fn clone(&self) -> Self {
39 Self {
40 inner: self.inner.clone(),
41 provider: Arc::clone(&self.provider),
42 }
43 }
44}
45
46impl<S> Service<Exchange> for BearerTokenService<S>
47where
48 S: Service<Exchange, Response = Exchange, Error = CamelError> + Clone + Send + 'static,
49 S::Future: Send,
50{
51 type Response = Exchange;
52 type Error = CamelError;
53 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
54
55 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 self.inner.poll_ready(cx)
57 }
58
59 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
60 let provider = Arc::clone(&self.provider);
61 let clone = self.inner.clone();
62 let mut inner = std::mem::replace(&mut self.inner, clone);
63
64 Box::pin(async move {
65 let token = provider.get_token().await.map_err(CamelError::from)?;
66 exchange
67 .input
68 .set_header("Authorization", format!("Bearer {token}")); inner.call(exchange).await
70 })
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::AuthError;
    use async_trait::async_trait;
    use camel_api::{BoxProcessor, BoxProcessorExt, Message};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Provider that always succeeds with a fixed token.
    #[derive(Debug)]
    struct StaticTokenProvider {
        token: String,
    }

    #[async_trait]
    impl TokenProvider for StaticTokenProvider {
        async fn get_token(&self) -> Result<String, AuthError> {
            Ok(self.token.clone())
        }
    }

    /// Provider that always fails with `ProviderUnavailable`.
    #[derive(Debug)]
    struct FailingTokenProvider;

    #[async_trait]
    impl TokenProvider for FailingTokenProvider {
        async fn get_token(&self) -> Result<String, AuthError> {
            Err(AuthError::ProviderUnavailable("token endpoint down".into()))
        }
    }

    /// Builds a minimal exchange to send through the service under test.
    fn make_exchange() -> Exchange {
        Exchange::new(Message::new("test"))
    }

    /// Inner processor that returns the exchange it receives unchanged.
    fn ok_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Wraps `ok_processor()` in a bearer-token service backed by `provider`.
    fn service_with(provider: Arc<dyn TokenProvider>) -> BearerTokenService<BoxProcessor> {
        BearerTokenLayer::new(provider).layer(ok_processor())
    }

    /// Reads header `name` from the exchange's input message as a string slice.
    fn header_str<'a>(exchange: &'a Exchange, name: &str) -> Option<&'a str> {
        exchange.input.header(name).and_then(|v| v.as_str())
    }

    #[tokio::test]
    async fn test_injects_authorization_header() {
        let mut service = service_with(Arc::new(StaticTokenProvider {
            token: "my-token".into(),
        }));

        let outcome = service.ready().await.unwrap().call(make_exchange()).await;
        assert!(outcome.is_ok());

        let exchange = outcome.unwrap();
        assert_eq!(
            header_str(&exchange, "Authorization"),
            Some("Bearer my-token")
        );
    }

    #[tokio::test]
    async fn test_preserves_existing_headers() {
        let mut service = service_with(Arc::new(StaticTokenProvider { token: "t".into() }));

        let mut request = make_exchange();
        request.input.set_header("X-Custom", "value");

        let exchange = service
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();

        assert_eq!(header_str(&exchange, "X-Custom"), Some("value"));
        assert_eq!(header_str(&exchange, "Authorization"), Some("Bearer t"));
    }

    #[tokio::test]
    async fn test_overwrites_existing_auth_header() {
        let mut service = service_with(Arc::new(StaticTokenProvider {
            token: "fresh".into(),
        }));

        // Pre-populate a stale credential that the service must replace.
        let mut request = make_exchange();
        request.input.set_header("Authorization", "Bearer stale");

        let exchange = service
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap();

        assert_eq!(
            header_str(&exchange, "Authorization"),
            Some("Bearer fresh")
        );
    }

    #[tokio::test]
    async fn test_provider_error_propagates() {
        let mut service = service_with(Arc::new(FailingTokenProvider));

        let outcome = service.ready().await.unwrap().call(make_exchange()).await;
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            CamelError::ProcessorError(msg) => {
                assert!(msg.contains("token endpoint down"));
            }
            other => panic!("expected ProcessorError, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_calls_provider_each_time() {
        let count = Arc::new(AtomicUsize::new(0));

        /// Provider that counts invocations and embeds the count in the token.
        #[derive(Debug)]
        struct CountingProvider {
            count: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl TokenProvider for CountingProvider {
            async fn get_token(&self) -> Result<String, AuthError> {
                let n = self.count.fetch_add(1, Ordering::SeqCst);
                Ok(format!("token-{n}"))
            }
        }

        let mut service = service_with(Arc::new(CountingProvider {
            count: Arc::clone(&count),
        }));

        for i in 0..3 {
            let exchange = service
                .ready()
                .await
                .unwrap()
                .call(make_exchange())
                .await
                .unwrap();

            let auth = header_str(&exchange, "Authorization").unwrap();
            assert!(auth.contains(&format!("token-{i}")));
        }

        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_clone_produces_working_service() {
        let mut first = service_with(Arc::new(StaticTokenProvider { token: "t".into() }));
        let mut second = first.clone();

        let first_outcome = first.ready().await.unwrap().call(make_exchange()).await;
        let second_outcome = second.ready().await.unwrap().call(make_exchange()).await;

        assert!(first_outcome.is_ok());
        assert!(second_outcome.is_ok());
    }
}