1use std::fmt;
2use std::sync::Arc;
3use std::time::Duration;
4
5use aliri_clock::DurationSecs;
6use aliri_tokens::backoff::ErrorBackoffConfig;
7use aliri_tokens::jitter::RandomEarlyJitter;
8use aliri_tokens::{TokenStatus, TokenWatcher};
9use arc_swap::ArcSwap;
10
11use super::config::OAuthClientConfig;
12use super::error::TokenError;
13use super::source::OAuthTokenSource;
14use modkit_utils::SecretString;
15
16struct TokenInner {
21 watcher: TokenWatcher,
22}
23
/// Watcher tuning copied out of [`OAuthClientConfig`] when the token is built,
/// so replacement watchers spawned by `Token::invalidate` get the same settings.
struct WatcherConfig {
    /// Maximum refresh jitter; `spawn_watcher` applies it in whole seconds.
    jitter_max: Duration,
    /// Minimum refresh period; `spawn_watcher` also derives the error-backoff
    /// bounds from it.
    min_refresh_period: Duration,
}
29
30#[derive(Clone)]
38pub struct Token {
39 inner: Arc<ArcSwap<TokenInner>>,
40 source_factory: Arc<dyn Fn() -> Result<OAuthTokenSource, TokenError> + Send + Sync>,
41 watcher_config: Arc<WatcherConfig>,
42}
43
44impl fmt::Debug for Token {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_struct("Token").finish_non_exhaustive()
47 }
48}
49
50impl Token {
51 pub async fn new(mut config: OAuthClientConfig) -> Result<Self, TokenError> {
61 config.validate()?;
62
63 if let Some(issuer_url) = config.issuer_url.take() {
65 let http_config = config
66 .http_config
67 .clone()
68 .unwrap_or_else(modkit_http::HttpClientConfig::token_endpoint);
69 let client = modkit_http::HttpClientBuilder::with_config(http_config)
70 .build()
71 .map_err(|e| {
72 TokenError::Http(crate::http_error::format_http_error(&e, "OIDC discovery"))
73 })?;
74 let resolved = super::discovery::discover_token_endpoint(&client, &issuer_url).await?;
75 config.token_endpoint = Some(resolved);
76 }
77
78 let watcher_config = Arc::new(WatcherConfig {
79 jitter_max: config.jitter_max,
80 min_refresh_period: config.min_refresh_period,
81 });
82
83 let source = OAuthTokenSource::new(&config)?;
84 let watcher = spawn_watcher(source, &watcher_config).await?;
85
86 let source_factory: Arc<dyn Fn() -> Result<OAuthTokenSource, TokenError> + Send + Sync> =
87 Arc::new(move || OAuthTokenSource::new(&config));
88
89 Ok(Self {
90 inner: Arc::new(ArcSwap::from_pointee(TokenInner { watcher })),
91 source_factory,
92 watcher_config,
93 })
94 }
95
96 pub fn get(&self) -> Result<SecretString, TokenError> {
110 let guard = self.inner.load();
111 let borrowed = guard.watcher.token();
112 if matches!(borrowed.token_status(), TokenStatus::Expired) {
113 return Err(TokenError::Unavailable(
114 "token expired, refresh pending".into(),
115 ));
116 }
117 let raw = borrowed.access_token().as_str();
118 Ok(SecretString::new(raw))
119 }
120
121 pub async fn invalidate(&self) {
129 let source = match (self.source_factory)() {
130 Ok(s) => s,
131 Err(e) => {
132 tracing::warn!("OAuth2 token invalidation: failed to create source: {e}");
133 return;
134 }
135 };
136
137 let watcher = match spawn_watcher(source, &self.watcher_config).await {
138 Ok(w) => w,
139 Err(e) => {
140 tracing::warn!("OAuth2 token invalidation: initial fetch failed: {e}");
141 return;
142 }
143 };
144
145 self.inner.store(Arc::new(TokenInner { watcher }));
146 }
147}
148
149async fn spawn_watcher(
151 source: OAuthTokenSource,
152 config: &WatcherConfig,
153) -> Result<TokenWatcher, TokenError> {
154 let jitter = RandomEarlyJitter::new(DurationSecs(config.jitter_max.as_secs()));
155 let backoff =
156 ErrorBackoffConfig::new(config.min_refresh_period, config.min_refresh_period * 30, 2);
157
158 TokenWatcher::spawn_from_token_source(source, jitter, backoff).await
159}
160
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use url::Url;

    /// `http://localhost:<port>` for the given mock server.
    fn server_root(server: &MockServer) -> String {
        format!("http://localhost:{}", server.port())
    }

    /// Fields common to every mock-backed config.
    ///
    /// Sets fixed client credentials, the testing HTTP client profile, no jitter
    /// and a 100 ms minimum refresh period.
    fn base_config() -> OAuthClientConfig {
        OAuthClientConfig {
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// Config with an explicit token endpoint at `<server>/token`.
    fn test_config(server: &MockServer) -> OAuthClientConfig {
        let endpoint = format!("{}/token", server_root(server));
        OAuthClientConfig {
            token_endpoint: Some(Url::parse(&endpoint).unwrap()),
            ..base_config()
        }
    }

    /// Config that relies on OIDC discovery against the mock server root.
    fn issuer_config(server: &MockServer) -> OAuthClientConfig {
        OAuthClientConfig {
            issuer_url: Some(Url::parse(&server_root(server)).unwrap()),
            ..base_config()
        }
    }

    /// JSON body of a successful Bearer token response.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    #[test]
    fn token_is_send_sync_clone() {
        fn require<T: Send + Sync + Clone>() {}
        require::<Token>();
    }

    #[tokio::test]
    async fn new_with_valid_config() {
        let server = MockServer::start();
        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-new", 3600));
        });

        let outcome = Token::new(test_config(&server)).await;
        assert!(
            outcome.is_ok(),
            "Token::new() should succeed: {:?}",
            outcome.err()
        );
    }

    #[tokio::test]
    async fn new_validates_config() {
        // Both endpoint sources set at once must be rejected before any I/O.
        let conflicting = OAuthClientConfig {
            token_endpoint: Some(Url::parse("https://a.example.com/token").unwrap()),
            issuer_url: Some(Url::parse("https://b.example.com").unwrap()),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            ..Default::default()
        };
        let err = Token::new(conflicting).await.unwrap_err();
        assert!(
            matches!(err, TokenError::ConfigError(ref msg) if msg.contains("mutually exclusive")),
            "expected ConfigError, got: {err}"
        );
    }

    #[tokio::test]
    async fn get_returns_secret_string() {
        let server = MockServer::start();
        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-get-test", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        let value = token.get().unwrap();

        assert_eq!(value.expose(), "tok-get-test");
    }

    #[tokio::test]
    async fn invalidate_creates_new_watcher() {
        let server = MockServer::start();
        let token_mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-inv", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        assert_eq!(token_mock.calls(), 1, "initial fetch");

        token.invalidate().await;
        assert_eq!(token_mock.calls(), 2, "after invalidate");
    }

    #[tokio::test]
    async fn concurrent_get_no_deadlock() {
        let server = MockServer::start();
        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-conc", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();

        // Spawn both readers before awaiting either, so they run concurrently.
        let readers: Vec<_> = (0..2)
            .map(|_| {
                let handle = token.clone();
                tokio::spawn(async move { handle.get() })
            })
            .collect();

        for reader in readers {
            assert!(reader.await.unwrap().is_ok());
        }
    }

    #[tokio::test]
    async fn new_with_issuer_url_discovery() {
        let server = MockServer::start();
        let token_ep = format!("{}/oauth/token", server_root(&server));

        let _discovery_mock = server.mock(|when, then| {
            when.method(GET).path("/.well-known/openid-configuration");
            then.status(200)
                .header("content-type", "application/json")
                .body(format!(r#"{{"token_endpoint":"{token_ep}"}}"#));
        });
        let _token_mock = server.mock(|when, then| {
            when.method(POST).path("/oauth/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-discovered", 3600));
        });

        let token = Token::new(issuer_config(&server)).await.unwrap();
        let value = token.get().unwrap();
        assert_eq!(value.expose(), "tok-discovered");
    }

    #[tokio::test]
    async fn discovery_not_repeated_on_invalidate() {
        let server = MockServer::start();
        let token_ep = format!("{}/oauth/token", server_root(&server));

        let discovery_mock = server.mock(|when, then| {
            when.method(GET).path("/.well-known/openid-configuration");
            then.status(200)
                .header("content-type", "application/json")
                .body(format!(r#"{{"token_endpoint":"{token_ep}"}}"#));
        });
        let token_mock = server.mock(|when, then| {
            when.method(POST).path("/oauth/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-disc-inv", 3600));
        });

        let token = Token::new(issuer_config(&server)).await.unwrap();
        assert_eq!(discovery_mock.calls(), 1, "discovery: initial");
        assert_eq!(token_mock.calls(), 1, "token: initial");

        token.invalidate().await;

        assert_eq!(
            discovery_mock.calls(),
            1,
            "discovery must NOT be repeated on invalidate"
        );
        assert_eq!(token_mock.calls(), 2, "token: after invalidate");
    }

    #[tokio::test]
    async fn debug_does_not_reveal_tokens() {
        let server = MockServer::start();
        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("super-secret-tok", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        let rendered = format!("{token:?}");
        assert!(
            !rendered.contains("super-secret-tok"),
            "Debug must not reveal token value: {rendered}"
        );
    }
}