1use std::future::Future;
2use std::sync::Arc;
3
4use cts_common::WorkspaceId;
5use url::Url;
6use web_time::{SystemTime, UNIX_EPOCH};
7
8use crate::authorize_dto::AuthoriseResponse;
9use crate::refresher::Refresher;
10use crate::{http_client, AuthError, SecretToken, Token};
11
12#[cfg(not(target_arch = "wasm32"))]
25pub trait OidcProvider: Send + Sync {
26 fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send;
28}
29
#[cfg(target_arch = "wasm32")]
/// Supplies OIDC tokens that [`OidcRefresher`] exchanges for access tokens.
///
/// wasm32 variant: identical to the native trait except that neither the
/// provider nor the future returned by `fetch` is required to be `Send`/`Sync`.
pub trait OidcProvider
where
    Self: Sized,
    Self: ?Sized,
{
    /// Produce an OIDC token, or an [`AuthError`] if none can be obtained.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>>;
}
36
/// Adapts a closure that returns a future into an [`OidcProvider`].
///
/// The wrapped closure is invoked afresh on every `fetch` call, so each call
/// gets a newly produced future.
///
/// `Clone` is available whenever the closure itself is `Clone`. `Debug` is
/// implemented by hand because closures generally are not `Debug`.
#[derive(Clone)]
pub struct OidcProviderFn<F> {
    fetch: F,
}

impl<F> std::fmt::Debug for OidcProviderFn<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The closure is opaque, so only the type name is shown.
        f.debug_struct("OidcProviderFn").finish_non_exhaustive()
    }
}

impl<F> OidcProviderFn<F> {
    /// Wrap `fetch` so it can be used wherever an [`OidcProvider`] is expected.
    pub const fn new(fetch: F) -> Self {
        Self { fetch }
    }
}
71
72#[cfg(not(target_arch = "wasm32"))]
73impl<F, Fut> OidcProvider for OidcProviderFn<F>
74where
75 F: Fn() -> Fut + Send + Sync,
76 Fut: Future<Output = Result<SecretToken, AuthError>> + Send,
77{
78 fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send {
79 (self.fetch)()
80 }
81}
82
#[cfg(target_arch = "wasm32")]
impl<F, Fut> OidcProvider for OidcProviderFn<F>
where
    Fut: Future<Output = Result<SecretToken, AuthError>>,
    F: Fn() -> Fut,
{
    /// Invoke the wrapped closure and return the future it produced, unawaited.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> {
        let make_token_future = &self.fetch;
        make_token_future()
    }
}
93
94pub(crate) struct OidcRefresher<P> {
112 oidc_provider: P,
113 workspace_id: WorkspaceId,
114 base_url: Url,
115 http_client: Arc<reqwest::Client>,
116}
117
118impl<P> OidcRefresher<P> {
119 pub(crate) fn new(oidc_provider: P, workspace_id: WorkspaceId, base_url: Url) -> Self {
120 Self {
121 oidc_provider,
122 workspace_id,
123 base_url,
124 http_client: Arc::new(http_client()),
125 }
126 }
127}
128
129impl<P: OidcProvider> Refresher for OidcRefresher<P> {
130 type Credential = ();
131
132 fn save(&self, _token: &Token) {
133 }
135
136 fn try_credential(&self, _token: Option<&mut Token>) -> Option<Self::Credential> {
137 Some(())
140 }
141
142 fn restore(&self, _token: &mut Token, _credential: Self::Credential) {
143 }
145
146 async fn refresh(&self, _credential: &Self::Credential) -> Result<Token, AuthError> {
147 let oidc_token = self.oidc_provider.fetch().await?;
148
149 let url = self.base_url.join("api/authorise")?;
150 tracing::debug!(url = %url, "federating OIDC token");
151
152 let resp = self
153 .http_client
154 .post(url)
155 .json(&OidcAuthoriseRequest {
156 oidc_token: oidc_token.as_str(),
157 workspace_id: self.workspace_id.as_str(),
158 })
159 .send()
160 .await?;
161
162 if !resp.status().is_success() {
163 let status = resp.status();
164 let body = resp.text().await.unwrap_or_default();
165 tracing::debug!(%status, %body, "OIDC federation failed");
166 return Err(AuthError::Server(format!("{status}: {body}")));
167 }
168
169 let auth_resp: AuthoriseResponse = resp.json().await?;
170 let now = SystemTime::now()
171 .duration_since(UNIX_EPOCH)
172 .unwrap_or_default()
173 .as_secs();
174
175 Ok(Token {
176 access_token: auth_resp.access_token,
177 token_type: "Bearer".to_string(),
178 expires_at: now + auth_resp.expiry,
179 refresh_token: None,
180 region: None,
181 client_id: None,
182 device_instance_id: None,
183 })
184 }
185}
186
187#[derive(serde::Serialize)]
188#[serde(rename_all = "camelCase")]
189struct OidcAuthoriseRequest<'a> {
190 oidc_token: &'a str,
191 workspace_id: &'a str,
192}
193
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Tests for `OidcRefresher`, driven through `AutoRefresh` against a
    //! mocktail HTTP server that stands in for `/api/authorise`.

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use mocktail::prelude::*;

    use super::*;
    use crate::auto_refresh::{AutoRefresh, AutoRefreshError};
    use crate::TokenStore;

    const WORKSPACE_ID: &str = "ZVATKW3VHMFG27DY";

    /// Parses `WORKSPACE_ID` into a `WorkspaceId`.
    fn workspace_id() -> WorkspaceId {
        WORKSPACE_ID.parse().unwrap()
    }

    /// Builds a successful `/api/authorise` response body.
    fn auth_response_json(access: &str, expiry: u64) -> serde_json::Value {
        serde_json::json!({ "accessToken": access, "expiry": expiry })
    }

    /// Starts a mock HTTP server that serves `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("oidc-refresher-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Returns a provider yielding `jwt-0`, `jwt-1`, ... together with a shared
    /// counter of how many times it has been invoked.
    fn counting_provider() -> (Arc<AtomicUsize>, impl OidcProvider) {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        let provider = OidcProviderFn::new(move || {
            let calls = Arc::clone(&calls_clone);
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst);
                Ok(SecretToken::new(format!("jwt-{n}")))
            }
        });
        (calls, provider)
    }

    /// Wraps `provider` in an `AutoRefresh` pointed at `server`, backed by
    /// `crate::NoStore`.
    fn make_strategy<P: OidcProvider>(
        server: &MockServer,
        provider: P,
    ) -> AutoRefresh<OidcRefresher<P>> {
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        AutoRefresh::with_store(refresher, crate::NoStore)
    }

    /// Builds a bearer token that expires `expires_in_secs` seconds from now.
    fn make_token(access: &str, expires_in_secs: u64) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in_secs,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// The first `get_token` federates once and returns the server's token.
    #[tokio::test]
    async fn test_initial_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("cts-token", 3600));
        });
        let server = start_server(mocks).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        let token = strategy.get_token().await.unwrap();

        assert_eq!(token.as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "initial federation should invoke the OIDC provider once"
        );
    }

    /// The request body uses camelCase keys and carries exactly two fields.
    #[test]
    fn test_request_serialization() {
        let body = serde_json::to_value(OidcAuthoriseRequest {
            oidc_token: "the-jwt",
            workspace_id: WORKSPACE_ID,
        })
        .unwrap();
        assert_eq!(
            body,
            serde_json::json!({ "oidcToken": "the-jwt", "workspaceId": WORKSPACE_ID }),
            "request body should carry exactly the OIDC token and workspace ID"
        );
    }

    /// A still-valid token is served from the cache without calling the
    /// provider or the server again.
    #[tokio::test]
    async fn test_caches_token_after_initial_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("cts-token", 3600));
        });
        let server = start_server(mocks).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");

        server.mocks().clear();
        // From here on, any federation attempt would receive a 500.
        server.mocks().mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "should not be called"}));
        });

        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "cached token should be returned without re-federating"
        );
    }

    /// An expired stored token triggers re-federation with a newly fetched
    /// OIDC token.
    #[tokio::test]
    async fn test_re_federates_on_expiry() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("re-federated-token", 3600));
        });
        let server = start_server(mocks).await;

        let (calls, provider) = counting_provider();
        let store = Arc::new(crate::InMemoryTokenStore::new());
        // Seed the store with a token whose expiry is "now", i.e. already stale.
        store.save(&make_token("stale-cts-token", 0)).await;

        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(
            token.as_str(),
            "re-federated-token",
            "expired cached token should trigger re-federation"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "re-federation should invoke the OIDC provider for a current JWT"
        );
    }

    /// An error from the OIDC provider surfaces as `AutoRefreshError::Auth`.
    #[tokio::test]
    async fn test_oidc_provider_failure_propagates() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("unreachable", 3600));
        });
        let server = start_server(mocks).await;

        let provider = OidcProviderFn::new(|| async {
            Err::<SecretToken, _>(AuthError::Server("provider exploded".to_string()))
        });
        let strategy = make_strategy(&server, provider);

        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "OIDC provider failure should surface as an auth error, got: {err:?}"
        );
    }

    /// A non-success status from `/api/authorise` surfaces as `AuthError::Server`.
    #[tokio::test]
    async fn test_server_rejection_propagates() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "workspace mismatch"}));
        });
        let server = start_server(mocks).await;
        let (_calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);

        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "a 500 from /api/authorise should surface as a server error, got: {err:?}"
        );
    }

    /// A fresh token already in the store is used without invoking the provider.
    #[tokio::test]
    async fn test_loads_token_from_store_on_cold_start_no_http() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "should not be called"}));
        });
        let server = start_server(mocks).await;

        let store = Arc::new(crate::InMemoryTokenStore::new());
        store.save(&make_token("from-store", 3600)).await;

        let (calls, provider) = counting_provider();
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "from-store");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "a fresh cached token should be used without invoking the OIDC provider"
        );
    }

    /// A freshly federated token is written back to the store.
    #[tokio::test]
    async fn test_persists_token_to_store_after_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("freshly-federated", 3600));
        });
        let server = start_server(mocks).await;

        let store = Arc::new(crate::InMemoryTokenStore::new());
        let (_calls, provider) = counting_provider();
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));

        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "freshly-federated");

        let saved = store
            .load()
            .await
            .expect("store should hold a token after federation");
        assert_eq!(saved.access_token().as_str(), "freshly-federated");
    }
}