Skip to main content

modkit_auth/oauth2/
token.rs

1use std::fmt;
2use std::sync::Arc;
3use std::time::Duration;
4
5use aliri_clock::DurationSecs;
6use aliri_tokens::backoff::ErrorBackoffConfig;
7use aliri_tokens::jitter::RandomEarlyJitter;
8use aliri_tokens::{TokenStatus, TokenWatcher};
9use arc_swap::ArcSwap;
10
11use super::config::OAuthClientConfig;
12use super::error::TokenError;
13use super::source::OAuthTokenSource;
14use modkit_utils::SecretString;
15
16/// Internal state holding the live watcher.
17///
18/// Wrapped in `Arc<ArcSwap<_>>` so that [`Token::invalidate`] can atomically
19/// swap in a replacement without blocking concurrent [`Token::get`] calls.
20struct TokenInner {
21    watcher: TokenWatcher,
22}
23
/// Settings reused every time a [`TokenWatcher`] is (re-)spawned.
///
/// Captured once when the [`Token`] is created so that
/// [`Token::invalidate`] can spawn a replacement watcher with the same
/// refresh behaviour as the original one.
struct WatcherConfig {
    /// Base period for the watcher's error backoff; the backoff cap is
    /// derived from it.
    min_refresh_period: Duration,
    /// Upper bound for the random early-refresh jitter (only whole seconds
    /// are used when the watcher is spawned).
    jitter_max: Duration,
}
29
30/// Handle for obtaining `OAuth2` bearer tokens.
31///
32/// Internally drives an `aliri_tokens::TokenWatcher` for background refresh and
33/// exposes lock-free reads via `ArcSwap` (same pattern as the JWKS key
34/// provider).
35///
36/// `Token` is [`Clone`] + [`Send`] + [`Sync`] — share freely across tasks.
37#[derive(Clone)]
38pub struct Token {
39    inner: Arc<ArcSwap<TokenInner>>,
40    source_factory: Arc<dyn Fn() -> Result<OAuthTokenSource, TokenError> + Send + Sync>,
41    watcher_config: Arc<WatcherConfig>,
42}
43
44impl fmt::Debug for Token {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_struct("Token").finish_non_exhaustive()
47    }
48}
49
50impl Token {
51    /// Create a new token handle and start background refresh.
52    ///
53    /// This performs an initial token fetch — if the token endpoint is
54    /// unreachable or returns an error, `new` will fail immediately.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`TokenError::ConfigError`] if the config is invalid.
59    /// Returns [`TokenError::Http`] if the initial token fetch fails.
60    pub async fn new(mut config: OAuthClientConfig) -> Result<Self, TokenError> {
61        config.validate()?;
62
63        // Resolve issuer_url → token_endpoint via OIDC discovery (one-time).
64        if let Some(issuer_url) = config.issuer_url.take() {
65            let http_config = config
66                .http_config
67                .clone()
68                .unwrap_or_else(modkit_http::HttpClientConfig::token_endpoint);
69            let client = modkit_http::HttpClientBuilder::with_config(http_config)
70                .build()
71                .map_err(|e| {
72                    TokenError::Http(crate::http_error::format_http_error(&e, "OIDC discovery"))
73                })?;
74            let resolved = super::discovery::discover_token_endpoint(&client, &issuer_url).await?;
75            config.token_endpoint = Some(resolved);
76        }
77
78        let watcher_config = Arc::new(WatcherConfig {
79            jitter_max: config.jitter_max,
80            min_refresh_period: config.min_refresh_period,
81        });
82
83        let source = OAuthTokenSource::new(&config)?;
84        let watcher = spawn_watcher(source, &watcher_config).await?;
85
86        let source_factory: Arc<dyn Fn() -> Result<OAuthTokenSource, TokenError> + Send + Sync> =
87            Arc::new(move || OAuthTokenSource::new(&config));
88
89        Ok(Self {
90            inner: Arc::new(ArcSwap::from_pointee(TokenInner { watcher })),
91            source_factory,
92            watcher_config,
93        })
94    }
95
96    /// Get the current bearer token.
97    ///
98    /// This is a lock-free read from the `ArcSwap`-cached watcher — it never
99    /// blocks on a network call.  The underlying watcher refreshes the token in
100    /// the background before it expires.
101    ///
102    /// The returned [`SecretString`] wraps the raw access-token value so it is
103    /// not accidentally logged.
104    ///
105    /// # Errors
106    ///
107    /// Returns [`TokenError::Unavailable`] if the cached token has expired
108    /// (the background watcher has not yet refreshed it).
109    pub fn get(&self) -> Result<SecretString, TokenError> {
110        let guard = self.inner.load();
111        let borrowed = guard.watcher.token();
112        if matches!(borrowed.token_status(), TokenStatus::Expired) {
113            return Err(TokenError::Unavailable(
114                "token expired, refresh pending".into(),
115            ));
116        }
117        let raw = borrowed.access_token().as_str();
118        Ok(SecretString::new(raw))
119    }
120
121    /// Force-replace the internal watcher with a freshly-spawned one.
122    ///
123    /// Use this after receiving a 401 from a downstream service to immediately
124    /// discard a potentially revoked token.
125    ///
126    /// If recreating the source or the initial token fetch fails, a warning is
127    /// logged and the existing watcher is left in place.
128    pub async fn invalidate(&self) {
129        let source = match (self.source_factory)() {
130            Ok(s) => s,
131            Err(e) => {
132                tracing::warn!("OAuth2 token invalidation: failed to create source: {e}");
133                return;
134            }
135        };
136
137        let watcher = match spawn_watcher(source, &self.watcher_config).await {
138            Ok(w) => w,
139            Err(e) => {
140                tracing::warn!("OAuth2 token invalidation: initial fetch failed: {e}");
141                return;
142            }
143        };
144
145        self.inner.store(Arc::new(TokenInner { watcher }));
146    }
147}
148
149/// Spawn a [`TokenWatcher`] from the given source and config.
150async fn spawn_watcher(
151    source: OAuthTokenSource,
152    config: &WatcherConfig,
153) -> Result<TokenWatcher, TokenError> {
154    let jitter = RandomEarlyJitter::new(DurationSecs(config.jitter_max.as_secs()));
155    let backoff =
156        ErrorBackoffConfig::new(config.min_refresh_period, config.min_refresh_period * 30, 2);
157
158    TokenWatcher::spawn_from_token_source(source, jitter, backoff).await
159}
160
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use url::Url;

    /// `expires_in` (seconds) used by every mocked token response. It is long
    /// enough that no token expires while a test runs.
    const EXPIRES_IN_SECS: u64 = 3600;

    /// `http://localhost:<port>` for the given mock server.
    fn server_base(server: &MockServer) -> String {
        format!("http://localhost:{}", server.port())
    }

    /// Build a test config whose token endpoint is `/token` on the mock server.
    ///
    /// Jitter is disabled and the refresh period is short to keep tests fast.
    fn test_config(server: &MockServer) -> OAuthClientConfig {
        let endpoint = format!("{}/token", server_base(server));
        OAuthClientConfig {
            token_endpoint: Some(Url::parse(&endpoint).unwrap()),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// Same settings as [`test_config`], except that the token endpoint is
    /// resolved through OIDC discovery against the mock server instead of
    /// being configured directly.
    fn discovery_config(server: &MockServer) -> OAuthClientConfig {
        OAuthClientConfig {
            token_endpoint: None,
            issuer_url: Some(Url::parse(&server_base(server)).unwrap()),
            ..test_config(server)
        }
    }

    /// JSON body of a successful token-endpoint response.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    /// Register a `POST <path>` mock that answers with `access_token`.
    fn mock_token_endpoint<'a>(
        server: &'a MockServer,
        path: &str,
        access_token: &str,
    ) -> httpmock::Mock<'a> {
        server.mock(|when, then| {
            when.method(POST).path(path);
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json(access_token, EXPIRES_IN_SECS));
        })
    }

    /// Register the OIDC discovery document. It advertises `/oauth/token` on
    /// the same mock server as the token endpoint.
    fn mock_discovery(server: &MockServer) -> httpmock::Mock<'_> {
        let token_ep = format!("{}/oauth/token", server_base(server));
        server.mock(|when, then| {
            when.method(GET).path("/.well-known/openid-configuration");
            then.status(200)
                .header("content-type", "application/json")
                .body(format!(r#"{{"token_endpoint":"{token_ep}"}}"#));
        })
    }

    // -- trait assertions -----------------------------------------------------

    #[test]
    fn token_is_send_sync_clone() {
        fn requires_send_sync_clone<T: Clone + Send + Sync>() {}
        requires_send_sync_clone::<Token>();
    }

    // -- new ------------------------------------------------------------------

    #[tokio::test]
    async fn new_with_valid_config() {
        let server = MockServer::start();
        let _token_mock = mock_token_endpoint(&server, "/token", "tok-new");

        let outcome = Token::new(test_config(&server)).await;
        assert!(
            outcome.is_ok(),
            "Token::new() should succeed: {:?}",
            outcome.err()
        );
    }

    #[tokio::test]
    async fn new_validates_config() {
        // Setting both `token_endpoint` and `issuer_url` must be rejected by
        // validation, which runs before any network access.
        let cfg = OAuthClientConfig {
            token_endpoint: Some(Url::parse("https://a.example.com/token").unwrap()),
            issuer_url: Some(Url::parse("https://b.example.com").unwrap()),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            ..Default::default()
        };

        let err = Token::new(cfg).await.unwrap_err();
        let is_expected =
            matches!(err, TokenError::ConfigError(ref msg) if msg.contains("mutually exclusive"));
        assert!(is_expected, "expected ConfigError, got: {err}");
    }

    // -- get ------------------------------------------------------------------

    #[tokio::test]
    async fn get_returns_secret_string() {
        let server = MockServer::start();
        let _token_mock = mock_token_endpoint(&server, "/token", "tok-get-test");

        let token = Token::new(test_config(&server)).await.unwrap();
        assert_eq!(token.get().unwrap().expose(), "tok-get-test");
    }

    // -- invalidate -----------------------------------------------------------

    #[tokio::test]
    async fn invalidate_creates_new_watcher() {
        let server = MockServer::start();
        let token_mock = mock_token_endpoint(&server, "/token", "tok-inv");

        let token = Token::new(test_config(&server)).await.unwrap();
        assert_eq!(token_mock.calls(), 1, "initial fetch");

        // The replacement watcher performs its own initial token fetch.
        token.invalidate().await;
        assert_eq!(token_mock.calls(), 2, "after invalidate");
    }

    // -- concurrency ----------------------------------------------------------

    #[tokio::test]
    async fn concurrent_get_no_deadlock() {
        let server = MockServer::start();
        let _token_mock = mock_token_endpoint(&server, "/token", "tok-conc");

        let token = Token::new(test_config(&server)).await.unwrap();

        let spawn_get = |handle: Token| tokio::spawn(async move { handle.get() });
        let (first, second) = tokio::join!(spawn_get(token.clone()), spawn_get(token.clone()));

        assert!(first.unwrap().is_ok());
        assert!(second.unwrap().is_ok());
    }

    // -- OIDC discovery -------------------------------------------------------

    #[tokio::test]
    async fn new_with_issuer_url_discovery() {
        let server = MockServer::start();
        let _discovery_mock = mock_discovery(&server);
        let _token_mock = mock_token_endpoint(&server, "/oauth/token", "tok-discovered");

        let token = Token::new(discovery_config(&server)).await.unwrap();
        assert_eq!(token.get().unwrap().expose(), "tok-discovered");
    }

    #[tokio::test]
    async fn discovery_not_repeated_on_invalidate() {
        let server = MockServer::start();
        let discovery_mock = mock_discovery(&server);
        let token_mock = mock_token_endpoint(&server, "/oauth/token", "tok-disc-inv");

        let token = Token::new(discovery_config(&server)).await.unwrap();
        assert_eq!(discovery_mock.calls(), 1, "discovery: initial");
        assert_eq!(token_mock.calls(), 1, "token: initial");

        // Invalidation re-fetches the token from the already-resolved endpoint;
        // it must not run discovery again.
        token.invalidate().await;

        assert_eq!(
            discovery_mock.calls(),
            1,
            "discovery must NOT be repeated on invalidate"
        );
        assert_eq!(token_mock.calls(), 2, "token: after invalidate");
    }

    // -- debug safety ---------------------------------------------------------

    #[tokio::test]
    async fn debug_does_not_reveal_tokens() {
        let server = MockServer::start();
        let _token_mock = mock_token_endpoint(&server, "/token", "super-secret-tok");

        let token = Token::new(test_config(&server)).await.unwrap();
        let dbg = format!("{token:?}");
        assert!(
            !dbg.contains("super-secret-tok"),
            "Debug must not reveal token value: {dbg}"
        );
    }
}