Skip to main content

faucet_stream/auth/
oauth2.rs

1//! OAuth2 client credentials flow with token caching.
2
3use crate::error::FaucetError;
4use reqwest::Client;
5use serde::Deserialize;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8
9#[derive(Debug, Deserialize)]
10struct TokenResponse {
11    access_token: String,
12    #[serde(default)]
13    expires_in: Option<u64>,
14    #[allow(dead_code)]
15    #[serde(default)]
16    token_type: Option<String>,
17}
18
/// Default `expiry_ratio`: a cached token is refreshed once 90 % of its
/// server-reported `expires_in` lifetime has elapsed.
pub const DEFAULT_EXPIRY_RATIO: f64 = 0.9;
21
22/// Cached OAuth2 token with expiration tracking.
23#[derive(Debug, Clone)]
24struct CachedToken {
25    access_token: String,
26    /// Instant at which the token should be considered expired. Computed as
27    /// `now + expires_in * expiry_ratio` at fetch time.  `None` means no
28    /// expiry info was provided by the server — the token is treated as valid
29    /// indefinitely (until a 401 forces a refresh).
30    expires_at: Option<tokio::time::Instant>,
31}
32
33impl CachedToken {
34    fn is_valid(&self) -> bool {
35        match self.expires_at {
36            Some(exp) => tokio::time::Instant::now() < exp,
37            None => true,
38        }
39    }
40}
41
42/// Thread-safe token cache shared across requests within a single `RestStream`.
43#[derive(Debug, Clone, Default)]
44pub struct TokenCache(Arc<Mutex<Option<CachedToken>>>);
45
46impl TokenCache {
47    pub fn new() -> Self {
48        Self(Arc::new(Mutex::new(None)))
49    }
50
51    /// Return a valid cached token or fetch a new one.
52    ///
53    /// `expiry_ratio` is the fraction of the server-reported `expires_in`
54    /// lifetime after which the token is proactively refreshed. For example,
55    /// `0.9` means a token with `expires_in = 3600` is refreshed after 3240 s.
56    pub async fn get_or_refresh(
57        &self,
58        client: &Client,
59        token_url: &str,
60        client_id: &str,
61        client_secret: &str,
62        scopes: &[String],
63        expiry_ratio: f64,
64    ) -> Result<String, FaucetError> {
65        let mut guard = self.0.lock().await;
66        if let Some(cached) = guard.as_ref() {
67            if cached.is_valid() {
68                return Ok(cached.access_token.clone());
69            }
70            tracing::debug!("OAuth2 token expired; refreshing");
71        }
72
73        let (token, expires_in) = fetch_oauth2_token_inner_with_client(
74            client,
75            token_url,
76            client_id,
77            client_secret,
78            scopes,
79        )
80        .await?;
81
82        let expires_at = expires_in.map(|secs| {
83            let effective = (secs as f64 * expiry_ratio) as u64;
84            tokio::time::Instant::now() + std::time::Duration::from_secs(effective)
85        });
86
87        *guard = Some(CachedToken {
88            access_token: token.clone(),
89            expires_at,
90        });
91
92        Ok(token)
93    }
94}
95
96/// Fetch an OAuth2 token using the client credentials grant.
97///
98/// Prefer using [`TokenCache::get_or_refresh`] to avoid fetching a new token
99/// on every request.
100pub async fn fetch_oauth2_token(
101    token_url: &str,
102    client_id: &str,
103    client_secret: &str,
104    scopes: &[String],
105) -> Result<String, FaucetError> {
106    let (token, _) = fetch_oauth2_token_inner(token_url, client_id, client_secret, scopes).await?;
107    Ok(token)
108}
109
110async fn fetch_oauth2_token_inner(
111    token_url: &str,
112    client_id: &str,
113    client_secret: &str,
114    scopes: &[String],
115) -> Result<(String, Option<u64>), FaucetError> {
116    let client = Client::new();
117    fetch_oauth2_token_inner_with_client(&client, token_url, client_id, client_secret, scopes).await
118}
119
120async fn fetch_oauth2_token_inner_with_client(
121    client: &Client,
122    token_url: &str,
123    client_id: &str,
124    client_secret: &str,
125    scopes: &[String],
126) -> Result<(String, Option<u64>), FaucetError> {
127    let resp = client
128        .post(token_url)
129        .form(&[
130            ("grant_type", "client_credentials"),
131            ("client_id", client_id),
132            ("client_secret", client_secret),
133            ("scope", &scopes.join(" ")),
134        ])
135        .send()
136        .await?;
137
138    if !resp.status().is_success() {
139        let status = resp.status().as_u16();
140        let body = resp.text().await.unwrap_or_default();
141        return Err(FaucetError::Auth(format!(
142            "OAuth2 token request failed (HTTP {status}): {body}"
143        )));
144    }
145
146    let token_resp: TokenResponse = resp.json().await?;
147    Ok((token_resp.access_token, token_resp.expires_in))
148}