1use async_trait::async_trait;
2use serde::Deserialize;
3use std::time::{Duration, Instant};
4use tokio::sync::{Mutex, RwLock};
5
6use crate::types::AuthError;
7
/// Refresh margin passed to `CachedToken::new` by `fetch_token`: a cached token is
/// considered due for refresh roughly this long before its reported expiry.
const DEFAULT_SKEW: Duration = Duration::from_secs(30);
9
/// Asynchronous source of access-token strings.
///
/// The supertrait bounds require implementors to be shareable across threads
/// (`Send + Sync`) and printable for diagnostics (`Debug`).
#[async_trait]
pub trait TokenProvider: Send + Sync + std::fmt::Debug {
    /// Returns an access token, or an [`AuthError`] if none can be obtained.
    async fn get_token(&self) -> Result<String, AuthError>;
}
14
15#[derive(Debug, Deserialize)]
16struct TokenResponse {
17 access_token: String,
18 #[allow(dead_code)]
19 token_type: String,
20 expires_in: u64,
21}
22
/// A fetched access token together with the instants that govern its reuse.
struct CachedToken {
    /// The raw token string returned to callers.
    access_token: String,
    /// Hard expiry: fetch time plus the server-reported lifetime (clamped if that
    /// cannot be represented as an `Instant`).
    // Not read by `is_usable`; kept alongside `refresh_at` so the hard expiry is
    // available to code that needs it.
    #[allow(dead_code)]
    expires_at: Instant,
    /// Instant after which the token should be refreshed instead of reused.
    refresh_at: Instant,
}

impl CachedToken {
    /// Lifetime used when `now + expires_in` overflows `Instant`, so that an absurd
    /// server-supplied `expires_in` still yields a token that is re-fetched periodically.
    const OVERFLOW_FALLBACK_LIFETIME: Duration = Duration::from_secs(24 * 60 * 60);

    /// Wraps `access_token`, which expires `expires_in` from now, and schedules its
    /// refresh `skew` before that expiry.
    ///
    /// The effective skew is capped at half of the lifetime, so short-lived tokens
    /// (lifetime at or below `skew`) are still reused for part of their validity
    /// instead of forcing a fetch on every call. A zero lifetime yields a token that
    /// is immediately due for refresh.
    fn new(access_token: String, expires_in: Duration, skew: Duration) -> Self {
        let now = Instant::now();
        // `Instant + Duration` panics on overflow and `expires_in` comes from the
        // token endpoint, so clamp unrepresentable lifetimes instead of panicking.
        let expires_at = now
            .checked_add(expires_in)
            .or_else(|| now.checked_add(Self::OVERFLOW_FALLBACK_LIFETIME))
            .unwrap_or(now);
        let lifetime = expires_at.saturating_duration_since(now);
        let skew = skew.min(lifetime / 2);
        // Cannot underflow: `skew <= lifetime`, so `expires_at - skew >= now`.
        let refresh_at = expires_at.checked_sub(skew).unwrap_or(now);
        Self {
            access_token,
            expires_at,
            refresh_at,
        }
    }

    /// True while the token has not yet reached its refresh point.
    fn is_usable(&self) -> bool {
        Instant::now() < self.refresh_at
    }
}
44
45pub struct ClientCredentialsProvider {
46 token_endpoint: String,
47 client_id: String,
48 client_secret: String,
49 scope: Option<String>,
50 audience: Option<Vec<String>>,
51 cache: RwLock<Option<CachedToken>>,
52 refresh_lock: Mutex<()>,
53 http: reqwest::Client,
54}
55
56impl std::fmt::Debug for ClientCredentialsProvider {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("ClientCredentialsProvider")
59 .field("token_endpoint", &self.token_endpoint)
60 .field("client_id", &self.client_id)
61 .field("scope", &self.scope)
62 .field("audience", &self.audience)
63 .finish_non_exhaustive()
64 }
65}
66
67impl ClientCredentialsProvider {
68 pub fn new(
69 token_endpoint: String,
70 client_id: String,
71 client_secret: String,
72 scope: Option<String>,
73 audience: Option<Vec<String>>,
74 ) -> Self {
75 Self {
76 token_endpoint,
77 client_id,
78 client_secret,
79 scope,
80 audience,
81 cache: RwLock::new(None),
82 refresh_lock: Mutex::new(()),
83 http: reqwest::Client::new(),
84 }
85 }
86
87 pub fn new_unchecked_for_test(
92 token_endpoint: String,
93 client_id: String,
94 client_secret: String,
95 scope: Option<String>,
96 audience: Option<Vec<String>>,
97 http: reqwest::Client,
98 ) -> Self {
99 Self {
100 token_endpoint,
101 client_id,
102 client_secret,
103 scope,
104 audience,
105 cache: RwLock::new(None),
106 refresh_lock: Mutex::new(()),
107 http,
108 }
109 }
110
111 async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
112 let mut params = vec![
113 ("grant_type", "client_credentials".to_string()),
114 ("client_id", self.client_id.clone()),
115 ("client_secret", self.client_secret.clone()),
116 ];
117 if let Some(ref scope) = self.scope {
118 params.push(("scope", scope.clone()));
119 }
120 if let Some(ref audience) = self.audience {
121 for aud in audience {
122 params.push(("resource", aud.clone()));
123 }
124 }
125
126 let resp = self
127 .http
128 .post(&self.token_endpoint)
129 .form(¶ms)
130 .send()
131 .await
132 .map_err(|e| AuthError::ProviderUnavailable(format!("OAuth2 request failed: {e}")))?;
133
134 if !resp.status().is_success() {
135 let status = resp.status();
136 let body = resp.text().await.unwrap_or_default();
137 let sanitized = if body.len() > 128 {
138 format!("{}...(truncated)", &body[..128])
139 } else {
140 body
141 };
142 let message = format!("token endpoint returned {status}: {sanitized}"); return Err(AuthError::ProviderUnavailable(message));
144 }
145
146 let token_resp: TokenResponse = resp
147 .json()
148 .await
149 .map_err(|e| AuthError::ProviderUnavailable(format!("invalid OAuth2 response: {e}")))?;
150
151 Ok(CachedToken::new(
152 token_resp.access_token,
153 Duration::from_secs(token_resp.expires_in),
154 DEFAULT_SKEW,
155 ))
156 }
157}
158
159#[async_trait]
160impl TokenProvider for ClientCredentialsProvider {
161 async fn get_token(&self) -> Result<String, AuthError> {
162 {
163 let cache = self.cache.read().await;
164 if let Some(ref cached) = *cache
165 && cached.is_usable()
166 {
167 return Ok(cached.access_token.clone());
168 }
169 }
170
171 let _guard = self.refresh_lock.lock().await;
172
173 {
174 let cache = self.cache.read().await;
175 if let Some(ref cached) = *cache
176 && cached.is_usable()
177 {
178 return Ok(cached.access_token.clone());
179 }
180 }
181
182 let cached = self.fetch_token().await?;
183 let token = cached.access_token.clone();
184 {
185 let mut cache = self.cache.write().await;
186 *cache = Some(cached);
187 }
188 Ok(token)
189 }
190}
191
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path the mock token endpoint is served on.
    const TOKEN_PATH: &str = "/protocol/openid-connect/token";

    /// Builds a successful token-endpoint response body.
    fn token_response(access_token: &str, expires_in: u64) -> serde_json::Value {
        serde_json::json!({
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": expires_in,
        })
    }

    /// Absolute token endpoint URL on `server`.
    fn token_url(server: &MockServer) -> String {
        format!("{}{}", server.uri(), TOKEN_PATH)
    }

    /// Provider with throwaway credentials pointed at `server`.
    fn provider_for(
        server: &MockServer,
        scope: Option<String>,
        audience: Option<Vec<String>>,
    ) -> ClientCredentialsProvider {
        ClientCredentialsProvider::new(token_url(server), "c".into(), "s".into(), scope, audience)
    }

    #[tokio::test]
    async fn test_get_token_fresh() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("abc123", 300)))
            .mount(&server)
            .await;

        let provider = ClientCredentialsProvider::new(
            token_url(&server),
            "test-client".into(),
            "test-secret".into(),
            None,
            None,
        );
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "abc123");
    }

    #[tokio::test]
    async fn test_get_token_uses_cache() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("cached", 300)))
            .expect(1)
            .mount(&server)
            .await;

        let provider = provider_for(&server, None, None);
        let t1 = provider.get_token().await.unwrap();
        let t2 = provider.get_token().await.unwrap();
        assert_eq!(t1, "cached");
        assert_eq!(t2, "cached");
    }

    #[tokio::test]
    async fn test_get_token_refreshes_when_stale() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("first", 1)))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("second", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, None, None);
        let t1 = provider.get_token().await.unwrap();
        assert_eq!(t1, "first");
        // Wait past the 1 s hard expiry of the first token.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        let t2 = provider.get_token().await.unwrap();
        assert_eq!(t2, "second");
    }

    #[tokio::test]
    async fn test_get_token_server_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let provider = provider_for(&server, None, None);
        let err = provider.get_token().await.unwrap_err();
        assert!(matches!(err, AuthError::ProviderUnavailable(_)));
    }

    #[tokio::test]
    async fn test_get_token_invalid_response() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"error": "invalid_grant"})),
            )
            .mount(&server)
            .await;

        let provider = provider_for(&server, None, None);
        let err = provider.get_token().await.unwrap_err();
        assert!(matches!(err, AuthError::ProviderUnavailable(_)));
    }

    #[tokio::test]
    async fn test_get_token_sends_audience_as_resource() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("resource=https%3A%2F%2Fapi.example.com"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("aud-token", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, None, Some(vec!["https://api.example.com".into()]));
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "aud-token");
    }

    #[tokio::test]
    async fn test_get_token_sends_grant_type_and_scope() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=client_credentials"))
            .and(body_string_contains("scope=api.read"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("scoped", 300)))
            .expect(1)
            .mount(&server)
            .await;

        let provider = provider_for(&server, Some("api.read".into()), None);
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "scoped");
    }

    #[tokio::test]
    async fn test_debug_redacts_client_secret() {
        let provider = ClientCredentialsProvider::new(
            "https://idp.example.com/token".into(),
            "visible-client".into(),
            "super-secret-value".into(),
            None,
            None,
        );
        let rendered = format!("{provider:?}");
        assert!(rendered.contains("visible-client"));
        assert!(!rendered.contains("super-secret-value"));
    }

    #[tokio::test]
    async fn test_single_flight_concurrent_callers() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(token_response("single-flight-token", 300)),
            )
            .expect(1)
            .mount(&server)
            .await;

        let provider = Arc::new(provider_for(&server, None, None));

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let p = Arc::clone(&provider);
                tokio::spawn(async move { p.get_token().await })
            })
            .collect();
        for h in handles {
            let token = h.await.unwrap().unwrap();
            assert_eq!(token, "single-flight-token");
        }
    }
}