Skip to main content

modkit_auth/oauth2/
fetch.rs

1//! One-shot `OAuth2` client credentials token fetch.
2//!
3//! Use [`fetch_token`] when you need a single token exchange without spawning a
4//! background refresh watcher.  This is the right choice for callers that manage
5//! their own cache (e.g. an auth plugin with a TTL-based token cache).
6//!
7//! For long-lived service singletons that benefit from automatic background
8//! refresh, use [`Token`](super::Token) instead.
9
10use std::fmt;
11use std::time::Duration;
12
13use aliri_tokens::sources::AsyncTokenSource;
14
15use super::config::OAuthClientConfig;
16use super::error::TokenError;
17use super::source::OAuthTokenSource;
18use modkit_utils::SecretString;
19
20/// Result of a one-shot `OAuth2` client credentials token exchange.
21///
22/// Contains the bearer token and the server-reported lifetime so callers can
23/// set per-entry cache TTLs.
24///
25/// `Debug` is manually implemented to redact [`bearer`](Self::bearer).
26pub struct FetchedToken {
27    /// The access token, wrapped in [`SecretString`] for safe handling.
28    pub bearer: SecretString,
29
30    /// Token lifetime as reported by the authorization server (`expires_in`),
31    /// or the configured [`default_ttl`](OAuthClientConfig::default_ttl) when
32    /// the server omits it.
33    pub expires_in: Duration,
34}
35
36/// `Debug` redacts the bearer value to prevent accidental exposure in logs.
37impl fmt::Debug for FetchedToken {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        f.debug_struct("FetchedToken")
40            .field("bearer", &"[REDACTED]")
41            .field("expires_in", &self.expires_in)
42            .finish()
43    }
44}
45
46/// Perform a single `OAuth2` client credentials token exchange.
47///
48/// This function validates the configuration, optionally resolves the token
49/// endpoint via OIDC discovery, fetches a token, and returns the bearer value
50/// alongside `expires_in` — all without spawning background tasks.
51///
52/// # Errors
53///
54/// Returns [`TokenError::ConfigError`] if the configuration is invalid.
55/// Returns [`TokenError::Http`] if the token (or discovery) request fails.
56/// Returns [`TokenError::UnsupportedTokenType`] if the server returns a
57/// non-Bearer token type.
58pub async fn fetch_token(mut config: OAuthClientConfig) -> Result<FetchedToken, TokenError> {
59    config.validate()?;
60
61    // Resolve issuer_url → token_endpoint via OIDC discovery (one-time).
62    if let Some(issuer_url) = config.issuer_url.take() {
63        let http_config = config
64            .http_config
65            .clone()
66            .unwrap_or_else(modkit_http::HttpClientConfig::token_endpoint);
67        let client = modkit_http::HttpClientBuilder::with_config(http_config)
68            .build()
69            .map_err(|e| {
70                TokenError::Http(crate::http_error::format_http_error(&e, "OIDC discovery"))
71            })?;
72        let resolved = super::discovery::discover_token_endpoint(&client, &issuer_url).await?;
73        config.token_endpoint = Some(resolved);
74    }
75
76    let mut source = OAuthTokenSource::new(&config)?;
77    let token = source.request_token().await?;
78
79    Ok(FetchedToken {
80        bearer: SecretString::new(token.access_token().as_str()),
81        expires_in: Duration::from_secs(token.lifetime().0),
82    })
83}
84
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use url::Url;

    use super::super::types::ClientAuthMethod;

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// Config pointing directly at the mock server's `/token` endpoint, with
    /// test HTTP settings and no refresh jitter.
    fn test_config(server: &MockServer) -> OAuthClientConfig {
        OAuthClientConfig {
            token_endpoint: Some(
                Url::parse(&format!("http://localhost:{}/token", server.port())).unwrap(),
            ),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// Like [`test_config`], but the token endpoint is resolved through OIDC
    /// discovery against the mock server instead of being set explicitly.
    fn issuer_config(server: &MockServer) -> OAuthClientConfig {
        OAuthClientConfig {
            token_endpoint: None,
            issuer_url: Some(Url::parse(&format!("http://localhost:{}", server.port())).unwrap()),
            ..test_config(server)
        }
    }

    /// JSON body of a successful Bearer token response.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    // -----------------------------------------------------------------------
    // Config validation
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn config_validated_before_fetch() {
        let cfg = OAuthClientConfig {
            token_endpoint: Some(Url::parse("https://a.example.com/token").unwrap()),
            issuer_url: Some(Url::parse("https://b.example.com").unwrap()),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            ..Default::default()
        };

        let err = fetch_token(cfg).await.unwrap_err();
        assert!(
            matches!(err, TokenError::ConfigError(ref msg) if msg.contains("mutually exclusive")),
            "expected ConfigError, got: {err}"
        );
    }

    // -----------------------------------------------------------------------
    // OIDC discovery
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn fetch_with_issuer_url_discovery() {
        let server = MockServer::start();

        let token_ep = format!("http://localhost:{}/oauth/token", server.port());
        let discovery_mock = server.mock(|when, then| {
            when.method(GET).path("/.well-known/openid-configuration");
            then.status(200)
                .header("content-type", "application/json")
                .body(format!(r#"{{"token_endpoint":"{token_ep}"}}"#));
        });

        let token_mock = server.mock(|when, then| {
            when.method(POST).path("/oauth/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-discovered", 1800));
        });

        let fetched = fetch_token(issuer_config(&server)).await.unwrap();
        assert_eq!(fetched.bearer.expose(), "tok-discovered");
        assert_eq!(fetched.expires_in, Duration::from_secs(1800));

        // Discovery must run exactly once, and the token request must go to
        // the endpoint that discovery advertised.
        discovery_mock.assert();
        token_mock.assert();
    }

    #[tokio::test]
    async fn discovery_failure_returns_error() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(GET).path("/.well-known/openid-configuration");
            then.status(500).body("internal server error");
        });

        let err = fetch_token(issuer_config(&server)).await.unwrap_err();
        assert!(
            matches!(
                err,
                TokenError::Http(ref msg) if msg.contains("OIDC discovery") && msg.contains("500")
            ),
            "expected Http error with OIDC discovery prefix, got: {err}"
        );
    }

    // -----------------------------------------------------------------------
    // Token fetch
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn fetch_returns_bearer_and_expires_in() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-happy", 3600));
        });

        let fetched = fetch_token(test_config(&server)).await.unwrap();
        assert_eq!(fetched.bearer.expose(), "tok-happy");
        assert_eq!(fetched.expires_in, Duration::from_secs(3600));
    }

    #[tokio::test]
    async fn missing_expires_in_uses_default_ttl() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"tok-default"}"#);
        });

        let fetched = fetch_token(test_config(&server)).await.unwrap();
        assert_eq!(fetched.bearer.expose(), "tok-default");
        // default_ttl from OAuthClientConfig::default() is 5 min = 300s
        assert_eq!(fetched.expires_in, Duration::from_secs(300));
    }

    #[tokio::test]
    async fn expires_in_zero_returns_zero_duration() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"tok-zero","expires_in":0}"#);
        });

        let fetched = fetch_token(test_config(&server)).await.unwrap();
        assert_eq!(fetched.expires_in, Duration::ZERO);
    }

    #[tokio::test]
    async fn http_error_returns_token_error() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(500).body("internal server error");
        });

        let err = fetch_token(test_config(&server)).await.unwrap_err();
        assert!(
            matches!(
                err,
                TokenError::Http(ref msg) if msg.contains("OAuth2 token") && msg.contains("500")
            ),
            "expected Http error, got: {err}"
        );
    }

    #[tokio::test]
    async fn unsupported_token_type_returns_error() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"tok","token_type":"mac"}"#);
        });

        let err = fetch_token(test_config(&server)).await.unwrap_err();
        assert!(
            matches!(err, TokenError::UnsupportedTokenType(ref t) if t == "mac"),
            "expected UnsupportedTokenType(\"mac\"), got: {err}"
        );
    }

    // -----------------------------------------------------------------------
    // Security
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn debug_does_not_reveal_bearer() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("super-secret-bearer", 3600));
        });

        let fetched = fetch_token(test_config(&server)).await.unwrap();
        let dbg = format!("{fetched:?}");
        assert!(
            !dbg.contains("super-secret-bearer"),
            "Debug must not reveal bearer value: {dbg}"
        );
        assert!(dbg.contains("[REDACTED]"), "Debug must contain [REDACTED]");
    }

    // -----------------------------------------------------------------------
    // Auth methods
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn form_auth_sends_credentials_in_body() {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/token")
                .body_includes("client_id=test-client")
                .body_includes("client_secret=test-secret");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-form", 3600));
        });

        let mut cfg = test_config(&server);
        cfg.auth_method = ClientAuthMethod::Form;
        fetch_token(cfg).await.unwrap();
        mock.assert();
    }

    #[tokio::test]
    async fn basic_auth_sends_credentials_in_header() {
        let server = MockServer::start();

        // base64("test-client:test-secret") = "dGVzdC1jbGllbnQ6dGVzdC1zZWNyZXQ="
        let mock = server.mock(|when, then| {
            when.method(POST)
                .path("/token")
                .header("authorization", "Basic dGVzdC1jbGllbnQ6dGVzdC1zZWNyZXQ=");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-basic", 3600));
        });

        let cfg = test_config(&server);
        // Default auth_method is Basic.
        fetch_token(cfg).await.unwrap();
        mock.assert();
    }
}