Skip to main content

faucet_auth/
oauth2.rs

1//! OAuth2 providers: `client_credentials` and `refresh_token` (with rotation).
2//!
3//! Both hold a single [`Mutex`]-guarded cache and perform the token-endpoint
4//! call **with the lock held**, so concurrent callers during a refresh await the
5//! one in-flight fetch (single-flight). The refresh provider captures a rotated
6//! `refresh_token` from each response in place, so a single active access token
7//! plus a rotating refresh token can be shared across many connectors without
8//! racing.
9
10use async_trait::async_trait;
11use faucet_core::{AuthProvider, Credential, FaucetError};
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::Value;
15use tokio::sync::Mutex;
16use tokio::time::Instant;
17
18use crate::expiry_instant;
19
20#[derive(Deserialize)]
21struct TokenResponse {
22    access_token: String,
23    #[serde(default)]
24    expires_in: Option<u64>,
25    #[serde(default)]
26    refresh_token: Option<String>,
27    #[allow(dead_code)]
28    #[serde(default)]
29    token_type: Option<String>,
30}
31
/// Access-token cache used by [`OAuth2ClientCredentialsProvider`].
#[derive(Default)]
struct CachedToken {
    /// Most recently fetched access token, if any.
    access_token: Option<String>,
    /// Deadline after which `access_token` is no longer handed out; `None`
    /// means no expiry was recorded for it.
    expires_at: Option<Instant>,
}

impl CachedToken {
    /// Return the cached access token if it may still be used.
    ///
    /// No token → `None`. A token whose deadline has been reached → `None`.
    /// A token with no recorded deadline is treated as usable indefinitely.
    fn valid(&self) -> Option<&str> {
        let token = self.access_token.as_deref()?;
        let expired = self
            .expires_at
            .is_some_and(|deadline| Instant::now() >= deadline);
        if expired { None } else { Some(token) }
    }
}
47
48/// OAuth2 `client_credentials` grant provider.
49pub struct OAuth2ClientCredentialsProvider {
50    http: Client,
51    token_url: String,
52    client_id: String,
53    client_secret: String,
54    scopes: Vec<String>,
55    expiry_ratio: f64,
56    state: Mutex<CachedToken>,
57}
58
59// Hand-written so `{:?}` (the trait requires `AuthProvider: Debug`, and providers
60// are shared as `Arc<dyn AuthProvider>`) never prints the `client_secret` or the
61// cached access token in `state`.
62impl std::fmt::Debug for OAuth2ClientCredentialsProvider {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("OAuth2ClientCredentialsProvider")
65            .field("token_url", &self.token_url)
66            .field("client_id", &self.client_id)
67            .field("client_secret", &"***")
68            .field("scopes", &self.scopes)
69            .field("expiry_ratio", &self.expiry_ratio)
70            .finish_non_exhaustive()
71    }
72}
73
74impl OAuth2ClientCredentialsProvider {
75    /// Build from a config object with `token_url`, `client_id`,
76    /// `client_secret`, optional `scopes` and `expiry_ratio`.
77    pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
78        Ok(Self {
79            http: crate::auth_http_client(),
80            token_url: required_str(config, "token_url")?,
81            client_id: required_str(config, "client_id")?,
82            client_secret: required_str(config, "client_secret")?,
83            scopes: string_array(config, "scopes"),
84            expiry_ratio: crate::parse_expiry_ratio(config)?,
85            state: Mutex::new(CachedToken::default()),
86        })
87    }
88
89    async fn fetch(&self) -> Result<TokenResponse, FaucetError> {
90        let resp = self
91            .http
92            .post(&self.token_url)
93            .form(&[
94                ("grant_type", "client_credentials"),
95                ("client_id", &self.client_id),
96                ("client_secret", &self.client_secret),
97                ("scope", &self.scopes.join(" ")),
98            ])
99            .send()
100            .await?;
101        parse_token_response(resp).await
102    }
103}
104
105#[async_trait]
106impl AuthProvider for OAuth2ClientCredentialsProvider {
107    async fn credential(&self) -> Result<Credential, FaucetError> {
108        let mut state = self.state.lock().await;
109        if let Some(tok) = state.valid() {
110            return Ok(Credential::Bearer(tok.to_string()));
111        }
112        let body = self.fetch().await?;
113        state.access_token = Some(body.access_token.clone());
114        state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
115        Ok(Credential::Bearer(body.access_token))
116    }
117
118    async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
119        let mut state = self.state.lock().await;
120        // CAS: only refresh if the cache still holds the stale token.
121        if let (Some(cur), Credential::Bearer(stale_tok)) = (state.valid(), stale)
122            && cur != stale_tok
123        {
124            return Ok(Credential::Bearer(cur.to_string()));
125        }
126        let body = self.fetch().await?;
127        state.access_token = Some(body.access_token.clone());
128        state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
129        Ok(Credential::Bearer(body.access_token))
130    }
131
132    fn provider_name(&self) -> &'static str {
133        "oauth2"
134    }
135}
136
/// Mutable state of [`OAuth2RefreshProvider`], kept behind its `Mutex`.
#[derive(Default)]
struct RefreshState {
    /// Most recently issued access token; `None` until the first refresh.
    access_token: Option<String>,
    /// Deadline for `access_token`, derived via `expiry_instant` from the
    /// response's `expires_in`.
    expires_at: Option<Instant>,
    /// Refresh token presented on the next refresh; overwritten in place
    /// whenever a response carries a rotated one.
    refresh_token: String,
}
143
144/// OAuth2 `refresh_token` grant provider with refresh-token rotation capture.
145pub struct OAuth2RefreshProvider {
146    http: Client,
147    token_url: String,
148    client_id: String,
149    client_secret: String,
150    expiry_ratio: f64,
151    state: Mutex<RefreshState>,
152}
153
154// Hand-written so `{:?}` never prints the `client_secret` or the `refresh_token`
155// / cached access token held in `state`.
156impl std::fmt::Debug for OAuth2RefreshProvider {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("OAuth2RefreshProvider")
159            .field("token_url", &self.token_url)
160            .field("client_id", &self.client_id)
161            .field("client_secret", &"***")
162            .field("expiry_ratio", &self.expiry_ratio)
163            .finish_non_exhaustive()
164    }
165}
166
167impl OAuth2RefreshProvider {
168    /// Build from a config object with `token_url`, `client_id`,
169    /// `client_secret`, `refresh_token`, and optional `expiry_ratio`.
170    pub fn from_config(config: &Value) -> Result<Self, FaucetError> {
171        let refresh_token = required_str(config, "refresh_token")?;
172        Ok(Self {
173            http: crate::auth_http_client(),
174            token_url: required_str(config, "token_url")?,
175            client_id: required_str(config, "client_id")?,
176            client_secret: required_str(config, "client_secret")?,
177            expiry_ratio: crate::parse_expiry_ratio(config)?,
178            state: Mutex::new(RefreshState {
179                refresh_token,
180                ..Default::default()
181            }),
182        })
183    }
184
185    /// Refresh using the *current* refresh token and capture rotation in place.
186    async fn refresh(&self, state: &mut RefreshState) -> Result<String, FaucetError> {
187        let resp = self
188            .http
189            .post(&self.token_url)
190            .form(&[
191                ("grant_type", "refresh_token"),
192                ("refresh_token", &state.refresh_token),
193                ("client_id", &self.client_id),
194                ("client_secret", &self.client_secret),
195            ])
196            .send()
197            .await?;
198        let body = parse_token_response(resp).await?;
199        state.access_token = Some(body.access_token.clone());
200        state.expires_at = expiry_instant(body.expires_in, self.expiry_ratio);
201        if let Some(rotated) = body.refresh_token {
202            state.refresh_token = rotated; // capture rotation centrally
203        }
204        Ok(body.access_token)
205    }
206}
207
208#[async_trait]
209impl AuthProvider for OAuth2RefreshProvider {
210    async fn credential(&self) -> Result<Credential, FaucetError> {
211        let mut state = self.state.lock().await;
212        if let (Some(tok), Some(exp)) = (&state.access_token, state.expires_at)
213            && Instant::now() < exp
214        {
215            return Ok(Credential::Bearer(tok.clone()));
216        }
217        let token = self.refresh(&mut state).await?;
218        Ok(Credential::Bearer(token))
219    }
220
221    async fn invalidate(&self, stale: &Credential) -> Result<Credential, FaucetError> {
222        let mut state = self.state.lock().await;
223        // CAS: another connector may have already refreshed; if the cached token
224        // no longer equals the stale one, hand back the fresh token.
225        if let (Some(cur), Credential::Bearer(stale_tok)) = (&state.access_token, stale)
226            && cur != stale_tok
227        {
228            return Ok(Credential::Bearer(cur.clone()));
229        }
230        let token = self.refresh(&mut state).await?;
231        Ok(Credential::Bearer(token))
232    }
233
234    fn provider_name(&self) -> &'static str {
235        "oauth2_refresh"
236    }
237}
238
239fn required_str(config: &Value, key: &str) -> Result<String, FaucetError> {
240    config
241        .get(key)
242        .and_then(Value::as_str)
243        .map(str::to_string)
244        .ok_or_else(|| FaucetError::Config(format!("oauth2 auth provider: missing `{key}`")))
245}
246
247fn string_array(config: &Value, key: &str) -> Vec<String> {
248    config
249        .get(key)
250        .and_then(Value::as_array)
251        .map(|a| {
252            a.iter()
253                .filter_map(|v| v.as_str().map(str::to_string))
254                .collect()
255        })
256        .unwrap_or_default()
257}
258
259async fn parse_token_response(resp: reqwest::Response) -> Result<TokenResponse, FaucetError> {
260    if !resp.status().is_success() {
261        let status = resp.status().as_u16();
262        let body = resp.text().await.unwrap_or_default();
263        return Err(FaucetError::Auth(format!(
264            "OAuth2 token request failed (HTTP {status}): {body}"
265        )));
266    }
267    resp.json::<TokenResponse>().await.map_err(Into::into)
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    /// Token-endpoint responder that numbers its responses: the n-th request
    /// gets access token `{token_prefix}{n}` and refresh token `rt{n}`, and
    /// `hits` counts requests served.
    struct CountingToken {
        hits: Arc<AtomicUsize>,
        token_prefix: &'static str,
    }
    impl Respond for CountingToken {
        fn respond(&self, _: &wiremock::Request) -> ResponseTemplate {
            let n = self.hits.fetch_add(1, Ordering::SeqCst) + 1;
            ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": format!("{}{n}", self.token_prefix),
                "expires_in": 3600,
                "refresh_token": format!("rt{n}"),
            }))
        }
    }

    /// Form bodies of every request the mock server has recorded, in order.
    async fn recorded_bodies(server: &MockServer) -> Vec<String> {
        server
            .received_requests()
            .await
            .expect("wiremock records requests by default")
            .into_iter()
            .map(|r| String::from_utf8_lossy(&r.body).into_owned())
            .collect()
    }

    #[tokio::test]
    async fn refresh_provider_single_flight_one_fetch_for_concurrent_calls() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(CountingToken {
                hits: hits.clone(),
                token_prefix: "A",
            })
            .mount(&server)
            .await;

        let provider = OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "refresh_token": "rt0",
        }))
        .unwrap();

        let results = futures::future::join_all((0..4).map(|_| provider.credential())).await;
        for r in &results {
            assert_eq!(r.as_ref().unwrap(), &Credential::Bearer("A1".into()));
        }
        assert_eq!(
            hits.load(Ordering::SeqCst),
            1,
            "expected exactly one token fetch"
        );
    }

    #[tokio::test]
    async fn refresh_provider_invalidate_cas_refetches_once() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(CountingToken {
                hits: hits.clone(),
                token_prefix: "A",
            })
            .mount(&server)
            .await;
        let provider = OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "refresh_token": "rt0",
        }))
        .unwrap();

        let first = provider.credential().await.unwrap();
        assert_eq!(first, Credential::Bearer("A1".into()));
        // Invalidate the token we hold → one more fetch, rotated refresh token used.
        let second = provider.invalidate(&first).await.unwrap();
        assert_eq!(second, Credential::Bearer("A2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2);
        // Invalidating a *stale* token that no longer matches → no fetch.
        let again = provider.invalidate(&first).await.unwrap();
        assert_eq!(again, Credential::Bearer("A2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2, "stale CAS must not refetch");

        // Rotation capture: the first refresh presents the configured `rt0`;
        // the second must present `rt1`, rotated in by the first response.
        let bodies = recorded_bodies(&server).await;
        assert_eq!(bodies.len(), 2);
        assert!(
            bodies[0].contains("refresh_token=rt0"),
            "first refresh should use the configured token: {}",
            bodies[0]
        );
        assert!(
            bodies[1].contains("refresh_token=rt1"),
            "rotated refresh token was not used: {}",
            bodies[1]
        );
    }

    #[test]
    fn provider_debug_does_not_leak_secrets() {
        // `AuthProvider: Debug`, and providers are held as `Arc<dyn AuthProvider>`,
        // so a stray `{:?}` must never print the client secret / refresh token.
        let cc = OAuth2ClientCredentialsProvider::from_config(&serde_json::json!({
            "token_url": "https://idp.example/token",
            "client_id": "id",
            "client_secret": "topsecretclient",
        }))
        .unwrap();
        let s = format!("{cc:?}");
        assert!(!s.contains("topsecretclient"), "client_secret leaked: {s}");
        assert!(
            s.contains("client_id"),
            "non-secret fields should remain: {s}"
        );

        let rf = OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": "https://idp.example/token",
            "client_id": "id",
            "client_secret": "topsecretclient",
            "refresh_token": "topsecretrefresh",
        }))
        .unwrap();
        let s = format!("{rf:?}");
        assert!(!s.contains("topsecretclient"), "client_secret leaked: {s}");
        assert!(!s.contains("topsecretrefresh"), "refresh_token leaked: {s}");
    }

    #[tokio::test]
    async fn client_credentials_single_flight() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(CountingToken {
                hits: hits.clone(),
                token_prefix: "C",
            })
            .mount(&server)
            .await;
        let provider = OAuth2ClientCredentialsProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "scopes": ["read"],
        }))
        .unwrap();
        let results = futures::future::join_all((0..4).map(|_| provider.credential())).await;
        for r in &results {
            assert_eq!(r.as_ref().unwrap(), &Credential::Bearer("C1".into()));
        }
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn client_credentials_invalidate_cas_refetches_once() {
        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(CountingToken {
                hits: hits.clone(),
                token_prefix: "C",
            })
            .mount(&server)
            .await;
        let provider = OAuth2ClientCredentialsProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "scopes": ["read", "write"],
        }))
        .unwrap();

        let first = provider.credential().await.unwrap();
        assert_eq!(first, Credential::Bearer("C1".into()));
        // Invalidating the token currently cached forces exactly one refetch.
        let second = provider.invalidate(&first).await.unwrap();
        assert_eq!(second, Credential::Bearer("C2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2);
        // A stale token no longer matches the cache → cached token, no fetch.
        let again = provider.invalidate(&first).await.unwrap();
        assert_eq!(again, Credential::Bearer("C2".into()));
        assert_eq!(hits.load(Ordering::SeqCst), 2, "stale CAS must not refetch");

        // Grant type and space-delimited scopes (form-encoded `+`) are sent.
        let bodies = recorded_bodies(&server).await;
        assert!(
            bodies[0].contains("grant_type=client_credentials"),
            "{}",
            bodies[0]
        );
        assert!(bodies[0].contains("scope=read+write"), "{}", bodies[0]);
    }

    #[tokio::test]
    async fn token_endpoint_failure_surfaces_auth_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(401).set_body_string("nope"))
            .mount(&server)
            .await;
        let provider = OAuth2RefreshProvider::from_config(&serde_json::json!({
            "token_url": server.uri(),
            "client_id": "id",
            "client_secret": "secret",
            "refresh_token": "rt0",
        }))
        .unwrap();
        assert!(matches!(
            provider.credential().await,
            Err(FaucetError::Auth(_))
        ));
    }
}