1use async_trait::async_trait;
11use faucet_core::{AuthProvider, Credential, FaucetError};
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::Value;
15use tokio::sync::Mutex;
16use tokio::time::Instant;
17
18use crate::expiry_instant;
19
20#[derive(Deserialize)]
21struct TokenResponse {
22 access_token: String,
23 #[serde(default)]
24 expires_in: Option<u64>,
25 #[serde(default)]
26 refresh_token: Option<String>,
27 #[allow(dead_code)]
28 #[serde(default)]
29 token_type: Option<String>,
30}
31
/// Access token cached by the client-credentials provider between calls.
#[derive(Default)]
struct CachedToken {
    /// Most recently fetched access token, if any.
    access_token: Option<String>,
    /// Instant from which `access_token` is treated as stale; `None` means no
    /// local expiry was recorded.
    expires_at: Option<Instant>,
}

impl CachedToken {
    /// Returns the cached token while it is still usable.
    ///
    /// A token with no recorded expiry stays usable until it is replaced; a
    /// token whose expiry instant has been reached is not returned.
    fn valid(&self) -> Option<&str> {
        let token = self.access_token.as_deref()?;
        match self.expires_at {
            Some(deadline) if Instant::now() >= deadline => None,
            _ => Some(token),
        }
    }
}
47
48pub struct OAuth2ClientCredentialsProvider {
50 http: Client,
51 token_url: String,
52 client_id: String,
53 client_secret: String,
54 scopes: Vec<String>,
55 expiry_ratio: f64,
56 state: Mutex<CachedToken>,
57}
58
59impl std::fmt::Debug for OAuth2ClientCredentialsProvider {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("OAuth2ClientCredentialsProvider")
65 .field("token_url", &self.token_url)
66 .field("client_id", &self.client_id)
67 .field("client_secret", &"***")
68 .field("scopes", &self.scopes)
69 .field("expiry_ratio", &self.expiry_ratio)
70 .finish_non_exhaustive()
71 }
72}
73
74impl OAuth2ClientCredentialsProvider {
75 pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
78 Ok(Self {
79 http: crate::auth_http_client(),
80 token_url: required_str(config, "token_url")?,
81 client_id: required_str(config, "client_id")?,
82 client_secret: required_str(config, "client_secret")?,
83 scopes: string_array(config, "scopes"),
84 expiry_ratio: crate::parse_expiry_ratio(config)?,
85 state: Mutex::new(CachedToken::default()),
86 })
87 }
88
89 async fn fetch(&self) -> Result<TokenResponse, FaucetError> {
90 let resp = self
91 .http
92 .post(&self.token_url)
93 .form(&[
94 ("grant_type", "client_credentials"),
95 ("client_id", &self.client_id),
96 ("client_secret", &self.client_secret),
97 ("scope", &self.scopes.join(" ")),
98 ])
99 .send()
100 .await?;
101 parse_token_response(resp).await
102 }
103}
104
105#[async_trait]
106impl AuthProvider for OAuth2ClientCredentialsProvider {
107 async fn credential(&self) -> Result<Credential, FaucetError> {
108 let mut state = self.state.lock().await;
109 if let Some(tok) = state.valid() {
110 return Ok(Credential::Bearer(tok.to_string()));
111 }
112 let body = self.fetch().await?;
113 state.access_token = Some(body.access_token.clone());
114 state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
115 Ok(Credential::Bearer(body.access_token))
116 }
117
118 async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
119 let mut state = self.state.lock().await;
120 if let (Some(cur), Credential::Bearer(stale_tok)) = (state.valid(), stale)
122 && cur != stale_tok
123 {
124 return Ok(Credential::Bearer(cur.to_string()));
125 }
126 let body = self.fetch().await?;
127 state.access_token = Some(body.access_token.clone());
128 state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
129 Ok(Credential::Bearer(body.access_token))
130 }
131
132 fn provider_name(&self) -> &'static str {
133 "oauth2"
134 }
135}
136
/// Mutable token state of the refresh-token provider, kept behind its mutex.
#[derive(Default)]
struct RefreshState {
    /// Most recently issued access token, if one has been obtained.
    access_token: Option<String>,
    /// Local expiry of `access_token`, if one was recorded.
    expires_at: Option<Instant>,
    /// Refresh token presented on the next refresh; overwritten when the token
    /// endpoint returns a rotated one.
    refresh_token: String,
}
143
144pub struct OAuth2RefreshProvider {
146 http: Client,
147 token_url: String,
148 client_id: String,
149 client_secret: String,
150 expiry_ratio: f64,
151 state: Mutex<RefreshState>,
152}
153
154impl std::fmt::Debug for OAuth2RefreshProvider {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("OAuth2RefreshProvider")
159 .field("token_url", &self.token_url)
160 .field("client_id", &self.client_id)
161 .field("client_secret", &"***")
162 .field("expiry_ratio", &self.expiry_ratio)
163 .finish_non_exhaustive()
164 }
165}
166
167impl OAuth2RefreshProvider {
168 pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
171 let refresh_token = required_str(config, "refresh_token")?;
172 Ok(Self {
173 http: crate::auth_http_client(),
174 token_url: required_str(config, "token_url")?,
175 client_id: required_str(config, "client_id")?,
176 client_secret: required_str(config, "client_secret")?,
177 expiry_ratio: crate::parse_expiry_ratio(config)?,
178 state: Mutex::new(RefreshState {
179 refresh_token,
180 ..Default::default()
181 }),
182 })
183 }
184
185 async fn refresh(&self, state: &mut RefreshState) -> Result<String, FaucetError> {
187 let resp = self
188 .http
189 .post(&self.token_url)
190 .form(&[
191 ("grant_type", "refresh_token"),
192 ("refresh_token", &state.refresh_token),
193 ("client_id", &self.client_id),
194 ("client_secret", &self.client_secret),
195 ])
196 .send()
197 .await?;
198 let body = parse_token_response(resp).await?;
199 state.access_token = Some(body.access_token.clone());
200 state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
201 if let Some(rotated) = body.refresh_token {
202 state.refresh_token = rotated; }
204 Ok(body.access_token)
205 }
206}
207
208#[async_trait]
209impl AuthProvider for OAuth2RefreshProvider {
210 async fn credential(&self) -> Result<Credential, FaucetError> {
211 let mut state = self.state.lock().await;
212 if let (Some(tok), Some(exp)) = (&state.access_token, state.expires_at)
213 && Instant::now() < exp
214 {
215 return Ok(Credential::Bearer(tok.clone()));
216 }
217 let token = self.refresh(&mut state).await?;
218 Ok(Credential::Bearer(token))
219 }
220
221 async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
222 let mut state = self.state.lock().await;
223 if let (Some(cur), Credential::Bearer(stale_tok)) = (&state.access_token, stale)
226 && cur != stale_tok
227 {
228 return Ok(Credential::Bearer(cur.clone()));
229 }
230 let token = self.refresh(&mut state).await?;
231 Ok(Credential::Bearer(token))
232 }
233
234 fn provider_name(&self) -> &'static str {
235 "oauth2_refresh"
236 }
237}
238
239fn required_str(config: &Value, key: &str) -> Result<String, FaucetError> {
240 config
241 .get(key)
242 .and_then(Value::as_str)
243 .map(str::to_string)
244 .ok_or_else(|| FaucetError::Config(format!("oauth2 auth provider: missing `{key}`")))
245}
246
247fn string_array(config: &Value, key: &str) -> Vec<String> {
248 config
249 .get(key)
250 .and_then(Value::as_array)
251 .map(|a| {
252 a.iter()
253 .filter_map(|v| v.as_str().map(str::to_string))
254 .collect()
255 })
256 .unwrap_or_default()
257}
258
259async fn parse_token_response(resp: reqwest::Response) -> Result<TokenResponse, FaucetError> {
260 if !resp.status().is_success() {
261 let status = resp.status().as_u16();
262 let body = resp.text().await.unwrap_or_default();
263 return Err(FaucetError::Auth(format!(
264 "OAuth2 token request failed (HTTP {status}): {body}"
265 )));
266 }
267 resp.json::<TokenResponse>().await.map_err(Into::into)
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    /// Token endpoint stub that counts requests and issues sequentially
    /// numbered tokens: `<prefix>1`, `<prefix>2`, ... with refresh tokens
    /// `rt1`, `rt2`, ... and a one-hour lifetime.
    struct CountingToken {
        hits: Arc<AtomicUsize>,
        token_prefix: &'static str,
    }

    impl Respond for CountingToken {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let serial = self.hits.fetch_add(1, Ordering::SeqCst) + 1;
            let payload = serde_json::json!({
                "access_token": format!("{}{serial}", self.token_prefix),
                "expires_in": 3600,
                "refresh_token": format!("rt{serial}"),
            });
            ResponseTemplate::new(200).set_body_json(payload)
        }
    }

    /// Starts a mock server answering every POST with a `CountingToken` using
    /// `token_prefix`; returns the server together with its hit counter.
    async fn counting_server(token_prefix: &'static str) -> (MockServer, Arc<AtomicUsize>) {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(CountingToken {
                hits: Arc::clone(&hits),
                token_prefix,
            })
            .mount(&server)
            .await;
        (server, hits)
    }

    /// Builds a refresh-token provider whose token endpoint is `server`.
    fn refresh_provider(server: &MockServer) -> OAuth2RefreshProvider {
        OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "refresh_token": "rt0",
        }))
        .unwrap()
    }

    #[tokio::test]
    async fn refresh_provider_single_flight_one_fetch_for_concurrent_calls() {
        let (server, hits) = counting_server("A").await;
        let provider = refresh_provider(&server);

        let calls = (0..4).map(|_| provider.credential());
        let results = futures::future::join_all(calls).await;
        for outcome in results {
            assert_eq!(outcome.unwrap(), Credential::Bearer("A1".into()));
        }
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "expected exactly one token fetch"
        );
    }

    #[tokio::test]
    async fn refresh_provider_invalidate_cas_refetches_once() {
        let (server, hits) = counting_server("A").await;
        let provider = refresh_provider(&server);

        let first = provider.credential().await.unwrap();
        assert_eq!(first, Credential::Bearer("A1".into()));

        // Invalidating the current token forces exactly one refresh.
        let second = provider.invalidate(&first).await.unwrap();
        assert_eq!(second, Credential::Bearer("A2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2);

        // Invalidating the already-replaced token reuses the newer one.
        let again = provider.invalidate(&first).await.unwrap();
        assert_eq!(again, Credential::Bearer("A2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2, "stale CAS must not refetch");
    }

    #[test]
    fn provider_debug_does_not_leak_secrets() {
        let cc = OAuth2ClientCredentialsProvider::from_config(&serde_json::json!({
            "token_url": "https://idp.example/token",
            "client_id": "id",
            "client_secret": "topsecretclient",
        }))
        .unwrap();
        let rendered = format!("{cc:?}");
        assert!(!rendered.contains("topsecretclient"), "client_secret leaked: {rendered}");
        assert!(
            rendered.contains("client_id"),
            "non-secret fields should remain: {rendered}"
        );

        let rf = OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": "https://idp.example/token",
            "client_id": "id",
            "client_secret": "topsecretclient",
            "refresh_token": "topsecretrefresh",
        }))
        .unwrap();
        let rendered = format!("{rf:?}");
        assert!(!rendered.contains("topsecretclient"), "client_secret leaked: {rendered}");
        assert!(!rendered.contains("topsecretrefresh"), "refresh_token leaked: {rendered}");
    }

    #[tokio::test]
    async fn client_credentials_single_flight() {
        let (server, hits) = counting_server("C").await;
        let provider = OAuth2ClientCredentialsProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "scopes": ["read"],
        }))
        .unwrap();

        let calls = (0..4).map(|_| provider.credential());
        let results = futures::future::join_all(calls).await;
        for outcome in results {
            assert_eq!(outcome.unwrap(), Credential::Bearer("C1".into()));
        }
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn token_endpoint_failure_surfaces_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(401).set_body_string("nope"))
            .mount(&server)
            .await;
        let provider = refresh_provider(&server);

        let outcome = provider.credential().await;
        assert!(matches!(outcome, Err(FaucetError::Auth(_))));
    }
}