Skip to main content

camel_auth/
oauth2.rs

1use async_trait::async_trait;
2use serde::Deserialize;
3use std::time::{Duration, Instant};
4use tokio::sync::{Mutex, RwLock};
5
6use crate::types::AuthError;
7
/// Refresh margin handed to `CachedToken::new` as `skew`: a cached token stops
/// being served (at most) this long before its reported expiry.
const DEFAULT_SKEW: Duration = Duration::from_millis(30_000);
9
10#[async_trait]
11pub trait TokenProvider: Send + Sync + std::fmt::Debug {
12    async fn get_token(&self) -> Result<String, AuthError>;
13}
14
15#[derive(Debug, Deserialize)]
16struct TokenResponse {
17    access_token: String,
18    #[allow(dead_code)]
19    token_type: String,
20    expires_in: u64,
21}
22
/// An access token together with the instants that govern its reuse.
struct CachedToken {
    /// The token string handed back to callers.
    access_token: String,
    /// Reported expiry (after clamping to `CachedToken::MAX_LIFETIME`).
    /// Currently never read; only `refresh_at` drives cache decisions.
    #[allow(dead_code, reason = "recorded alongside refresh_at but not consulted yet")]
    expires_at: Instant,
    /// From this instant on, the token is no longer served from the cache.
    refresh_at: Instant,
}

impl CachedToken {
    /// Upper bound on the lifetime accepted from the token endpoint.
    ///
    /// `Instant + Duration` panics on overflow and `expires_in` is remote
    /// input, so an absurd value (e.g. `u64::MAX` seconds) must be clamped
    /// before it reaches the addition.
    const MAX_LIFETIME: Duration = Duration::from_secs(365 * 24 * 60 * 60);

    /// Builds a cache entry for a token that expires `expires_in` from now.
    ///
    /// The token stops being usable `skew` before its expiry. The skew is
    /// capped at half the lifetime. Otherwise any token whose lifetime is
    /// `<= skew` would be stale on arrival, and every lookup would trigger a
    /// new fetch.
    fn new(access_token: String, expires_in: Duration, skew: Duration) -> Self {
        let now = Instant::now();
        let lifetime = expires_in.min(Self::MAX_LIFETIME);
        // Panic-free even if the platform's `Instant` range is narrow; falling
        // back to `now` merely marks the token as immediately stale.
        let expires_at = now.checked_add(lifetime).unwrap_or(now);
        let skew = skew.min(lifetime / 2);
        Self {
            access_token,
            refresh_at: expires_at.checked_sub(skew).unwrap_or(expires_at),
            expires_at,
        }
    }

    /// Returns `true` while the current time is still before `refresh_at`.
    fn is_usable(&self) -> bool {
        Instant::now() < self.refresh_at
    }
}
44
45pub struct ClientCredentialsProvider {
46    token_endpoint: String,
47    client_id: String,
48    client_secret: String,
49    scope: Option<String>,
50    audience: Option<Vec<String>>,
51    cache: RwLock<Option<CachedToken>>,
52    refresh_lock: Mutex<()>,
53    http: reqwest::Client,
54}
55
56impl std::fmt::Debug for ClientCredentialsProvider {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("ClientCredentialsProvider")
59            .field("token_endpoint", &self.token_endpoint)
60            .field("client_id", &self.client_id)
61            .field("scope", &self.scope)
62            .field("audience", &self.audience)
63            .finish_non_exhaustive()
64    }
65}
66
67impl ClientCredentialsProvider {
68    pub fn new(
69        token_endpoint: String,
70        client_id: String,
71        client_secret: String,
72        scope: Option<String>,
73        audience: Option<Vec<String>>,
74    ) -> Self {
75        Self {
76            token_endpoint,
77            client_id,
78            client_secret,
79            scope,
80            audience,
81            cache: RwLock::new(None),
82            refresh_lock: Mutex::new(()),
83            http: reqwest::Client::new(),
84        }
85    }
86
87    /// Test-only constructor that accepts a pre-built HTTP client.
88    ///
89    /// Skips SSRF validation and allows injecting a mock-capable client
90    /// (e.g. `wiremock`-configured). Do NOT use in production code.
91    pub fn new_unchecked_for_test(
92        token_endpoint: String,
93        client_id: String,
94        client_secret: String,
95        scope: Option<String>,
96        audience: Option<Vec<String>>,
97        http: reqwest::Client,
98    ) -> Self {
99        Self {
100            token_endpoint,
101            client_id,
102            client_secret,
103            scope,
104            audience,
105            cache: RwLock::new(None),
106            refresh_lock: Mutex::new(()),
107            http,
108        }
109    }
110
111    async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
112        let mut params = vec![
113            ("grant_type", "client_credentials".to_string()),
114            ("client_id", self.client_id.clone()),
115            ("client_secret", self.client_secret.clone()),
116        ];
117        if let Some(ref scope) = self.scope {
118            params.push(("scope", scope.clone()));
119        }
120        if let Some(ref audience) = self.audience {
121            for aud in audience {
122                params.push(("resource", aud.clone()));
123            }
124        }
125
126        let resp = self
127            .http
128            .post(&self.token_endpoint)
129            .form(&params)
130            .send()
131            .await
132            .map_err(|e| AuthError::ProviderUnavailable(format!("OAuth2 request failed: {e}")))?;
133
134        if !resp.status().is_success() {
135            let status = resp.status();
136            let body = resp.text().await.unwrap_or_default();
137            let sanitized = if body.len() > 128 {
138                format!("{}...(truncated)", &body[..128])
139            } else {
140                body
141            };
142            let message = format!("token endpoint returned {status}: {sanitized}"); // allow-secret
143            return Err(AuthError::ProviderUnavailable(message));
144        }
145
146        let token_resp: TokenResponse = resp
147            .json()
148            .await
149            .map_err(|e| AuthError::ProviderUnavailable(format!("invalid OAuth2 response: {e}")))?;
150
151        Ok(CachedToken::new(
152            token_resp.access_token,
153            Duration::from_secs(token_resp.expires_in),
154            DEFAULT_SKEW,
155        ))
156    }
157}
158
159#[async_trait]
160impl TokenProvider for ClientCredentialsProvider {
161    async fn get_token(&self) -> Result<String, AuthError> {
162        {
163            let cache = self.cache.read().await;
164            if let Some(ref cached) = *cache
165                && cached.is_usable()
166            {
167                return Ok(cached.access_token.clone());
168            }
169        }
170
171        let _guard = self.refresh_lock.lock().await;
172
173        {
174            let cache = self.cache.read().await;
175            if let Some(ref cached) = *cache
176                && cached.is_usable()
177            {
178                return Ok(cached.access_token.clone());
179            }
180        }
181
182        let cached = self.fetch_token().await?;
183        let token = cached.access_token.clone();
184        {
185            let mut cache = self.cache.write().await;
186            *cache = Some(cached);
187        }
188        Ok(token)
189    }
190}
191
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path the mock token endpoint is served under.
    const TOKEN_PATH: &str = "/protocol/openid-connect/token"; // allow-secret

    /// JSON body of a successful token response.
    fn token_response(access_token: &str, expires_in: u64) -> serde_json::Value {
        serde_json::json!({
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": expires_in,
        })
    }

    /// Builds a provider (no scope) aimed at `server`'s token endpoint.
    fn provider_for(
        server: &MockServer,
        client_id: &str,
        client_secret: &str,
        audience: Option<Vec<String>>,
    ) -> ClientCredentialsProvider {
        ClientCredentialsProvider::new(
            format!("{}{}", server.uri(), TOKEN_PATH), // allow-secret
            client_id.into(),
            client_secret.into(),
            None,
            audience,
        )
    }

    /// Answers every POST with `response` and asserts that `get_token` maps
    /// it to `AuthError::ProviderUnavailable`.
    async fn assert_unavailable_for(response: ResponseTemplate) {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(response)
            .mount(&server)
            .await;

        let err = provider_for(&server, "c", "s", None)
            .get_token()
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::ProviderUnavailable(_)));
    }

    #[tokio::test]
    async fn test_get_token_fresh() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("abc123", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, "test-client", "test-secret", None);
        assert_eq!(provider.get_token().await.unwrap(), "abc123");
    }

    #[tokio::test]
    async fn test_get_token_uses_cache() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("cached", 300)))
            .expect(1)
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        let first = provider.get_token().await.unwrap();
        let second = provider.get_token().await.unwrap();
        assert_eq!(first, "cached");
        assert_eq!(second, "cached");
    }

    #[tokio::test]
    async fn test_get_token_refreshes_when_stale() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("first", 1)))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("second", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        assert_eq!(provider.get_token().await.unwrap(), "first");
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert_eq!(provider.get_token().await.unwrap(), "second");
    }

    #[tokio::test]
    async fn test_get_token_server_error() {
        assert_unavailable_for(ResponseTemplate::new(500)).await;
    }

    #[tokio::test]
    async fn test_get_token_invalid_response() {
        assert_unavailable_for(
            ResponseTemplate::new(200)
                .set_body_json(serde_json::json!({"error": "invalid_grant"})),
        )
        .await;
    }

    #[tokio::test]
    async fn test_get_token_sends_audience_as_resource() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains(
                "resource=https%3A%2F%2Fapi.example.com",
            ))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(token_response("aud-token", 300)),
            )
            .mount(&server)
            .await;

        let provider = provider_for(
            &server,
            "c",
            "s",
            Some(vec!["https://api.example.com".into()]),
        );
        assert_eq!(provider.get_token().await.unwrap(), "aud-token");
    }

    #[tokio::test]
    async fn test_single_flight_concurrent_callers() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(token_response("single-flight-token", 300)),
            )
            .expect(1)
            .mount(&server)
            .await;

        let provider = Arc::new(provider_for(&server, "c", "s", None));
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let p = Arc::clone(&provider);
                tokio::spawn(async move { p.get_token().await })
            })
            .collect();
        for handle in handles {
            assert_eq!(handle.await.unwrap().unwrap(), "single-flight-token");
        }
    }
}