Skip to main content

quilt_rs/
auth.rs

1//! OAuth 2.1 Authorization Code flow with PKCE for Quilt catalog authentication.
2//!
3//! Implements the following RFCs:
4//! - **RFC 6749** — OAuth 2.0 Authorization Framework (core flow)
5//! - **RFC 7636** — Proof Key for Code Exchange (PKCE)
6//! - **RFC 7591** — OAuth 2.0 Dynamic Client Registration (DCR)
7//!
8//! Terminology mapping (RFC → code):
9//! - *Authorization Endpoint* (RFC 6749 §3.1) → [`catalog_authorize_url`]
10//! - *Token Endpoint* (RFC 6749 §3.2) → [`connect_token_url`]
11//! - *Authorization Code* (RFC 6749 §1.3.1) → `OAuthParams::code`
12//! - *Code Verifier* (RFC 7636 §4.1) → `PkceChallenge::code_verifier`
13//! - *Code Challenge* (RFC 7636 §4.2) → `PkceChallenge::code_challenge`
14//! - *State* (RFC 6749 §10.12) — CSRF protection token, generated by [`random_state`]
15//! - *Client Registration Endpoint* (RFC 7591 §3) → [`connect_register_url`]
16//! - *Redirect URI* (RFC 6749 §3.1.2) → `OAuthParams::redirect_uri`
17
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21use std::sync::Mutex as StdMutex;
22use std::sync::Weak;
23
24use base64::Engine;
25use base64::engine::general_purpose::URL_SAFE_NO_PAD;
26use sha2::Digest;
27use sha2::Sha256;
28use tokio::sync::Mutex as AsyncMutex;
29
30use crate::Error;
31use crate::Res;
32use crate::error::AuthError;
33use crate::error::LoginError;
34use crate::io::remote::client::HttpClient;
35use crate::io::storage::LocalStorage;
36use crate::io::storage::Storage;
37use crate::io::storage::auth::AuthIo;
38use crate::io::storage::auth::Credentials;
39use crate::io::storage::auth::OAuthClient;
40use crate::io::storage::auth::Tokens;
41use crate::paths::DomainPaths;
42use chrono::serde::ts_seconds;
43use quilt_uri::Host;
44use serde::Deserialize;
45use serde::Deserializer;
46use serde::Serialize;
47use tracing::debug;
48use tracing::error;
49use tracing::info;
50use tracing::warn;
51
/// Inputs to the authorization-code Token Request (RFC 6749 §4.1.3),
/// extended with the PKCE `code_verifier` (RFC 7636 §4.5).
pub struct OAuthParams {
    /// The authorization code handed back by the Authorization Endpoint
    /// (RFC 6749 §4.1.2).
    pub code: String,
    /// The PKCE verifier (RFC 7636 §4.1) whose S256 hash was sent as the
    /// `code_challenge`; the Token Endpoint uses it to verify the exchange.
    pub code_verifier: String,
    /// Redirect URI (RFC 6749 §3.1.2); it has to be identical to the value
    /// used in the Authorization Request.
    pub redirect_uri: String,
    /// Client identifier (RFC 6749 §2.2) issued by Dynamic Client Registration.
    ///
    /// Callers must make sure it agrees with the `client_id` persisted in the
    /// [`OAuthClient`] for the target host — the simplest way is to take it
    /// directly from the result of [`Auth::get_or_register_client`].
    pub client_id: String,
}
67
/// A PKCE (RFC 7636) verifier together with its derived challenge.
pub struct PkceChallenge {
    /// High-entropy secret kept by the client; presented to the Token
    /// Endpoint when exchanging the authorization code.
    pub code_verifier: String,
    /// `BASE64URL(SHA256(code_verifier))`; placed in the authorize URL.
    pub code_challenge: String,
}
75
76/// Generate a PKCE code verifier and its S256 challenge.
77///
78/// The verifier is 64 random bytes, base64url-encoded (86 characters),
79/// well within RFC 7636 §4.1's 43–128 character range.
80pub fn pkce_challenge() -> PkceChallenge {
81    let mut random_bytes = [0u8; 64];
82    getrandom::fill(&mut random_bytes).expect("failed to generate random bytes");
83
84    let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
85    let code_challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
86
87    PkceChallenge {
88        code_verifier,
89        code_challenge,
90    }
91}
92
93/// Generate a random `state` parameter for CSRF protection (RFC 6749 §10.12).
94pub fn random_state() -> String {
95    let mut bytes = [0u8; 16];
96    getrandom::fill(&mut bytes).expect("failed to generate random bytes");
97    URL_SAFE_NO_PAD.encode(bytes)
98}
99
100// --- OAuth endpoint URLs ---
101//
102// OAuth uses two different hostnames derived from the catalog host:
103//
104// 1. **Catalog host** (`test.quilt.dev`) — the authorize endpoint lives here
105//    because the user's browser session (cookies) is on the catalog.
106//
107// 2. **Connect host** (`test-connect.quilt.dev`) — the token exchange and
108//    client registration (DCR) endpoints live on a separate subdomain.
109
110/// Authorization Endpoint (RFC 6749 §3.1) on the catalog host.
111///
112/// E.g., `test.quilt.dev` → `https://test.quilt.dev/connect/authorize`
113pub fn catalog_authorize_url(host: &Host) -> String {
114    format!("https://{host}/connect/authorize")
115}
116
117/// Derive the connect server hostname from the catalog host.
118///
119/// E.g., `test.quilt.dev` → `test-connect.quilt.dev`
120///
121/// # Assumptions
122///
123/// The catalog hostname is assumed to have exactly one label before the first
124/// dot (e.g. `test` in `test.quilt.dev`). Multi-label prefixes such as
125/// `a.b.quilt.dev` are not supported and will produce an incorrect result
126/// (`a-connect.b.quilt.dev` instead of a well-defined connect hostname).
127pub fn connect_host(host: &Host) -> String {
128    let s = host.to_string();
129    match s.split_once('.') {
130        Some((stack, domain)) => format!("{stack}-connect.{domain}"),
131        None => format!("{s}-connect"),
132    }
133}
134
135/// Token Endpoint (RFC 6749 §3.2) on the connect host.
136///
137/// E.g., `test.quilt.dev` → `https://test-connect.quilt.dev/auth/token`
138fn connect_token_url(host: &Host) -> String {
139    format!("https://{}/auth/token", connect_host(host))
140}
141
142/// Client Registration Endpoint (RFC 7591 §3) on the connect host.
143fn connect_register_url(host: &Host) -> String {
144    format!("https://{}/auth/register", connect_host(host))
145}
146
147/// DCR request body (RFC 7591).
148#[derive(Serialize)]
149struct DcrRequest {
150    client_name: String,
151    redirect_uris: Vec<String>,
152    token_endpoint_auth_method: String,
153}
154
155/// DCR response body (subset of fields we need).
156#[derive(Deserialize)]
157struct DcrResponse {
158    client_id: String,
159}
160
161/// Register a public OAuth client via Dynamic Client Registration (RFC 7591 §3.1).
162async fn register_client(
163    http_client: &impl HttpClient,
164    host: &Host,
165    redirect_uri: &str,
166) -> Res<OAuthClient> {
167    let register_url = connect_register_url(host);
168
169    let request = DcrRequest {
170        client_name: "QuiltSync".to_string(),
171        redirect_uris: vec![redirect_uri.to_string()],
172        token_endpoint_auth_method: "none".to_string(),
173    };
174
175    let response: DcrResponse = http_client.post_json(&register_url, &request).await?;
176
177    Ok(OAuthClient {
178        client_id: response.client_id,
179        redirect_uri: redirect_uri.to_string(),
180    })
181}
182
183#[derive(Deserialize, Serialize)]
184pub struct RemoteTokens {
185    pub access_token: String,
186    pub refresh_token: String,
187    #[serde(with = "ts_seconds")]
188    pub expires_at: chrono::DateTime<chrono::Utc>,
189}
190
191impl fmt::Debug for RemoteTokens {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        f.debug_struct("RemoteTokens")
194            .field("expires_at", &self.expires_at)
195            .field("access_token", &"[REDACTED]")
196            .field("refresh_token", &"[REDACTED]")
197            .finish_non_exhaustive()
198    }
199}
200
201impl From<RemoteTokens> for Tokens {
202    fn from(raw: RemoteTokens) -> Self {
203        Tokens {
204            access_token: raw.access_token,
205            refresh_token: raw.refresh_token,
206            expires_at: raw.expires_at,
207        }
208    }
209}
210
/// Access-token lifetime (in seconds) assumed when the token endpoint does
/// not send `expires_in`.
///
/// RFC 6749 §5.1 lists `expires_in` as RECOMMENDED rather than REQUIRED. One
/// hour is a middle ground: long enough to avoid refreshing constantly, short
/// enough that a stale token is not kept around for long.
const DEFAULT_EXPIRES_IN: i64 = 60 * 60;

/// Serde `default` hook used for `OAuthTokenResponse::expires_in`.
fn default_expires_in() -> i64 {
    DEFAULT_EXPIRES_IN
}
221
222/// Token response from the Connect OAuth token endpoint.
223///
224/// Uses `expires_in` (seconds until expiry) per RFC 6749,
225/// unlike `RemoteTokens` which uses `expires_at` (Unix timestamp).
226///
227/// `refresh_token` is `Option` because RFC 6749 §6 allows the server to omit
228/// it when rotating tokens; callers are responsible for falling back to the
229/// previous refresh token in that case.
230///
231/// `expires_in` is optional per RFC 6749 §5.1 (RECOMMENDED, not required);
232/// defaults to [`DEFAULT_EXPIRES_IN`] when absent.
233#[derive(Deserialize, Serialize)]
234struct OAuthTokenResponse {
235    access_token: String,
236    #[serde(default)]
237    refresh_token: Option<String>,
238    #[serde(default = "default_expires_in")]
239    expires_in: i64,
240}
241
242impl fmt::Debug for OAuthTokenResponse {
243    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244        f.debug_struct("OAuthTokenResponse")
245            .field("expires_in", &self.expires_in)
246            .field("access_token", &"[REDACTED]")
247            .field(
248                "refresh_token",
249                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
250            )
251            .finish_non_exhaustive()
252    }
253}
254
255#[derive(Deserialize, Serialize)]
256#[serde(rename_all = "PascalCase")]
257struct RemoteCredentials {
258    access_key_id: String,
259    #[serde(deserialize_with = "date_from_rfc3339")]
260    expiration: chrono::DateTime<chrono::Utc>,
261    secret_access_key: String,
262    session_token: String,
263}
264
265impl fmt::Debug for RemoteCredentials {
266    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267        f.debug_struct("RemoteCredentials")
268            .field("expiration", &self.expiration)
269            .field("access_key_id", &"[REDACTED]")
270            .field("secret_access_key", &"[REDACTED]")
271            .field("session_token", &"[REDACTED]")
272            .finish_non_exhaustive()
273    }
274}
275
276impl From<RemoteCredentials> for Credentials {
277    fn from(raw: RemoteCredentials) -> Self {
278        Credentials {
279            access_key: raw.access_key_id,
280            secret_key: raw.secret_access_key,
281            token: raw.session_token,
282            expires_at: raw.expiration,
283        }
284    }
285}
286
287fn date_from_rfc3339<'de, D: Deserializer<'de>>(
288    deserializer: D,
289) -> Result<chrono::DateTime<chrono::Utc>, D::Error> {
290    use serde::de::Error;
291    String::deserialize(deserializer).and_then(|s| {
292        chrono::DateTime::parse_from_rfc3339(&s)
293            .map_err(|e| Error::custom(format!("Invalid RFC3339 date: {e}")))
294            .map(|dt| dt.with_timezone(&chrono::Utc))
295    })
296}
297
298#[derive(Deserialize, Serialize, Debug)]
299#[serde(rename_all = "camelCase")]
300struct QuiltStackConfig {
301    registry_url: url::Url,
302}
303
304async fn get_registry_url(http_client: &impl HttpClient, host: &Host) -> Res<url::Host> {
305    let QuiltStackConfig { registry_url } = http_client
306        .get(&format!("https://{host}/config.json"), None)
307        .await?;
308    Ok(url::Host::Domain(
309        registry_url
310            .domain()
311            .ok_or(LoginError::RequiredRegistryUrl(host.to_owned()))?
312            .to_string(),
313    ))
314}
315
316async fn get_auth_tokens(
317    http_client: &impl HttpClient,
318    host: &Host,
319    refresh_token: &str,
320) -> Res<Tokens> {
321    let registry = get_registry_url(http_client, host).await?;
322
323    let mut form_data: HashMap<String, String> = HashMap::new();
324    form_data.insert("refresh_token".to_string(), refresh_token.to_string());
325    let tokens_json: RemoteTokens = http_client
326        .post(&format!("https://{registry}/api/token"), &form_data)
327        .await?;
328    let tokens = Tokens::from(tokens_json);
329
330    Ok(tokens)
331}
332
333/// Token Request (RFC 6749 §4.1.3) with PKCE code verifier (RFC 7636 §4.5).
334async fn exchange_oauth_code(
335    http_client: &impl HttpClient,
336    host: &Host,
337    params: &OAuthParams,
338) -> Res<Tokens> {
339    let token_url = connect_token_url(host);
340
341    let mut form_data: HashMap<String, String> = HashMap::new();
342    form_data.insert("grant_type".to_string(), "authorization_code".to_string());
343    form_data.insert("code".to_string(), params.code.clone());
344    form_data.insert("code_verifier".to_string(), params.code_verifier.clone());
345    form_data.insert("redirect_uri".to_string(), params.redirect_uri.clone());
346    form_data.insert("client_id".to_string(), params.client_id.clone());
347
348    let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
349    let expires_at = chrono::Utc::now() + chrono::Duration::seconds(response.expires_in);
350    Ok(Tokens {
351        access_token: response.access_token,
352        refresh_token: response.refresh_token.ok_or_else(|| {
353            Error::Auth(
354                host.to_owned(),
355                AuthError::TokensExchange("server did not return a refresh token".to_string()),
356            )
357        })?,
358        expires_at,
359    })
360}
361
362/// Refresh Token Request (RFC 6749 §6) — exchange a refresh token for new tokens.
363async fn refresh_oauth_tokens(
364    http_client: &impl HttpClient,
365    host: &Host,
366    refresh_token: &str,
367    client_id: &str,
368) -> Res<Tokens> {
369    let token_url = connect_token_url(host);
370
371    let mut form_data: HashMap<String, String> = HashMap::new();
372    form_data.insert("grant_type".to_string(), "refresh_token".to_string());
373    form_data.insert("refresh_token".to_string(), refresh_token.to_string());
374    form_data.insert("client_id".to_string(), client_id.to_string());
375
376    let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
377    let expires_at = chrono::Utc::now() + chrono::Duration::seconds(response.expires_in);
378    Ok(Tokens {
379        access_token: response.access_token,
380        // RFC 6749 §6: server MAY omit the refresh token — retain the previous one if so.
381        refresh_token: response
382            .refresh_token
383            .unwrap_or_else(|| refresh_token.to_string()),
384        expires_at,
385    })
386}
387
388async fn refresh_credentials(
389    http_client: &impl HttpClient,
390    host: &Host,
391    access_token: &str,
392) -> Res<Credentials> {
393    let registry = get_registry_url(http_client, host).await?;
394
395    let creds_json: RemoteCredentials = http_client
396        .get(
397            &format!("https://{registry}/api/auth/get_credentials"),
398            Some(access_token),
399        )
400        .await?;
401
402    let credentials = Credentials::from(creds_json);
403
404    Ok(credentials)
405}
406
407/// Returns true when an error from the Connect **token endpoint** means the
408/// user must log in again.
409///
410/// Includes HTTP 400 because RFC 6749 §5.2 specifies that a revoked or
411/// expired refresh token produces `400 invalid_grant`, not 401.
412fn is_token_auth_error(e: &Error) -> bool {
413    matches!(
414        e,
415        Error::Reqwest(re) if re.status().is_some_and(|s| s == 400 || s == 401 || s == 403)
416    )
417}
418
419/// Returns true when an error from the registry **credentials endpoint** means
420/// the user must log in again.
421///
422/// Only 401/403 — a 400 from the registry means a malformed request (client
423/// bug), not an auth failure, so it should propagate rather than prompt login.
424fn is_credentials_auth_error(e: &Error) -> bool {
425    matches!(
426        e,
427        Error::Reqwest(re) if re.status().is_some_and(|s| s == 401 || s == 403)
428    )
429}
430
431/// Extracts the HTTP status code from an `Error::Reqwest`, if the wire-level
432/// error carried a response (network-level errors without a response return
433/// `None`). Used to include the status as a structured field in retry logs.
434fn http_status(e: &Error) -> Option<u16> {
435    match e {
436        Error::Reqwest(re) => re.status().map(|s| s.as_u16()),
437        _ => None,
438    }
439}
440
441/// Classifies the outcome of a retry attempt against an auth endpoint.
442///
443/// - `Ok(_)` → transient error recovered, log at `info!`.
444/// - `Err(e)` classified as auth → retry didn't help, upgrade to `LoginRequired`.
445/// - `Err(e)` otherwise → propagate as-is (includes nested `LoginRequired`
446///   from missing OAuth client state, IO errors, etc.).
447fn classify_retry_outcome<T>(
448    result: Res<T>,
449    is_auth_error: fn(&Error) -> bool,
450    endpoint: &str,
451    host: &Host,
452) -> Res<T> {
453    match result {
454        Ok(v) => {
455            info!(
456                "✔️ Recovered from transient auth error on {} for {}",
457                endpoint, host
458            );
459            Ok(v)
460        }
461        Err(e) if is_auth_error(&e) => {
462            warn!(
463                status = ?http_status(&e),
464                "❌ Auth error on {} for {} persisted after retry, login required: {}",
465                endpoint, host, e
466            );
467            Err(LoginError::Required(Some(host.to_owned())).into())
468        }
469        Err(e) => {
470            warn!(
471                status = ?http_status(&e),
472                "❌ Failed to refresh via {} for {} on retry: {}",
473                endpoint, host, e
474            );
475            Err(e)
476        }
477    }
478}
479
480/// Map of per-host refresh locks used to single-flight concurrent
481/// credential refreshes. The outer `StdMutex` is held only across the
482/// brief map lookup and is never held across an `.await`. The inner
483/// `AsyncMutex` is held across the HTTP refresh, serializing refreshes
484/// for a single host.
485///
486/// Entries are `Weak`, so the map size tracks *in-flight* refreshes
487/// rather than distinct hosts seen over the process lifetime. Racing
488/// callers upgrade the same `Weak` and share the mutex; once everyone
489/// drops their `Arc`, the entry becomes a dead `Weak` and is pruned
490/// on the next lookup. This matters for long-running server contexts
491/// that may authenticate against many distinct hosts.
492type RefreshLocks = Arc<StdMutex<HashMap<Host, Weak<AsyncMutex<()>>>>>;
493
494#[derive(Debug)]
495pub struct Auth<S: Storage = LocalStorage> {
496    pub paths: DomainPaths,
497    pub storage: Arc<S>,
498    refresh_locks: RefreshLocks,
499}
500
501impl<S: Storage> Clone for Auth<S> {
502    fn clone(&self) -> Self {
503        Self {
504            paths: self.paths.clone(),
505            storage: Arc::clone(&self.storage),
506            refresh_locks: Arc::clone(&self.refresh_locks),
507        }
508    }
509}
510
impl<S: Storage + Send + Sync> Auth<S> {
    /// Create an `Auth` that persists per-host state under `paths` through
    /// `storage`.
    ///
    /// Starts with an empty per-host refresh-lock map. The map sits behind an
    /// `Arc`, so clones of the returned value share it.
    pub fn new(paths: DomainPaths, storage: Arc<S>) -> Self {
        Self {
            paths,
            storage,
            refresh_locks: Arc::new(StdMutex::new(HashMap::new())),
        }
    }

    /// Get the `Arc<Mutex>` for this host's refresh lock, creating it
    /// on first use. The outer lock is only held for the brief map
    /// lookup — never across `.await`. Dead `Weak` entries (mutex no
    /// longer referenced by any in-flight refresh) are swept before
    /// the lookup so the map stays bounded by active refreshes.
    fn refresh_lock_for(&self, host: &Host) -> Arc<AsyncMutex<()>> {
        // Recover the guard from a poisoned lock instead of panicking. The map
        // only holds `Weak` handles, and the prune/lookup/insert below copes
        // with whatever entries it contains.
        let mut locks = self
            .refresh_locks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        locks.retain(|_, weak| weak.strong_count() > 0);
        // Some caller still holds this host's mutex: share it.
        if let Some(arc) = locks.get(host).and_then(Weak::upgrade) {
            return arc;
        }
        // Only a `Weak` is stored; the returned `Arc` keeps the mutex alive
        // for as long as the caller needs it.
        let arc = Arc::new(AsyncMutex::new(()));
        locks.insert(host.clone(), Arc::downgrade(&arc));
        arc
    }

    /// Log in to `host` with a registry refresh token.
    ///
    /// Steps, each logged and short-circuited on failure:
    /// 1. Exchange `refresh_token` at the registry's `/api/token`.
    /// 2. Persist the resulting tokens.
    /// 3. Fetch and persist credentials using the new access token.
    ///
    /// NOTE(review): this path does not register an OAuth client, but
    /// `refresh_tokens` requires one. Once these tokens expire,
    /// `get_credentials_or_refresh` will therefore return
    /// `LoginError::Required` unless a client was registered for this host
    /// some other way. Confirm this is intended.
    pub async fn login<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        refresh_token: String,
    ) -> Res {
        info!("⏳ Logging in to host {} with refresh token", host);

        let tokens = match self
            .get_auth_tokens(http_client, host, &refresh_token)
            .await
        {
            Ok(t) => t,
            Err(e) => {
                warn!("❌ Failed to get auth tokens for {}: {}", host, e);
                return Err(e);
            }
        };

        if let Err(e) = self.save_tokens(host, &tokens).await {
            warn!("❌ Failed to save tokens for {}: {}", host, e);
            return Err(e);
        }

        if let Err(e) = self
            .refresh_credentials(http_client, host, &tokens.access_token)
            .await
        {
            warn!("❌ Failed to refresh credentials for {}: {}", host, e);
            return Err(e);
        }

        info!("✔️ Successfully logged in and authenticated to {}", host);
        Ok(())
    }

    /// Get a stored OAuth client_id for the host, or register a new one via DCR.
    ///
    /// A cached client is reused only if it was registered for the same
    /// `redirect_uri`. Otherwise a new public client is registered
    /// (RFC 7591) and saved with `write_client`.
    ///
    /// # Errors
    ///
    /// Storage read/write errors, or any error from the registration request.
    pub async fn get_or_register_client<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        redirect_uri: &str,
    ) -> Res<OAuthClient> {
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));

        if let Some(client) = auth_io.read_client().await? {
            if client.redirect_uri == redirect_uri {
                info!("✔️ Found existing OAuth client for {}", host);
                return Ok(client);
            }
            // Registered redirect URIs must match exactly, so a client bound
            // to a different URI cannot be reused; fall through and register.
            info!(
                "⚠️ Cached client has stale redirect_uri, re-registering for {}",
                host
            );
        }

        info!("⏳ Registering new OAuth client for {}", host);
        let client = register_client(http_client, host, redirect_uri).await?;
        auth_io.write_client(&client).await?;
        info!(
            "✔️ Registered OAuth client for {}: {}",
            host, client.client_id
        );

        Ok(client)
    }

    /// Login using OAuth 2.1 Authorization Code flow with PKCE.
    ///
    /// Exchanges the authorization code for tokens, then fetches S3 credentials.
    /// Tokens are persisted before credentials are fetched, so a credentials
    /// failure still leaves usable tokens on disk.
    ///
    /// # State / CSRF verification
    ///
    /// This method does not verify the `state` parameter returned by the
    /// Authorization Endpoint. The caller is responsible for comparing the
    /// `state` value in the callback against the value generated by
    /// [`random_state`] before calling this method (RFC 6749 §10.12).
    ///
    /// # Errors
    ///
    /// Token-exchange, storage, or credentials-fetch errors, each logged at
    /// `warn!` and returned unchanged.
    pub async fn login_oauth<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        params: OAuthParams,
    ) -> Res {
        info!("⏳ OAuth login for host {}", host);

        let tokens = exchange_oauth_code(http_client, host, &params)
            .await
            .map_err(|e| {
                warn!("❌ Failed to exchange OAuth code for {}: {}", host, e);
                e
            })?;

        self.save_tokens(host, &tokens).await.map_err(|e| {
            warn!("❌ Failed to save tokens for {}: {}", host, e);
            e
        })?;

        self.refresh_credentials(http_client, host, &tokens.access_token)
            .await
            .map_err(|e| {
                warn!("❌ Failed to refresh credentials for {}: {}", host, e);
                e
            })?;

        info!("✔️ OAuth login successful for {}", host);
        Ok(())
    }

    /// Debug-logging wrapper around the free function `get_auth_tokens`
    /// (registry `/api/token` exchange). Does not persist anything.
    async fn get_auth_tokens<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        refresh_token: &str,
    ) -> Res<Tokens> {
        debug!("⏳ Getting auth tokens for host {:?}", host);
        let tokens = get_auth_tokens(http_client, host, refresh_token).await?;
        debug!("✔️ Successfully retrieved auth tokens");
        Ok(tokens)
    }

    /// Persist `tokens` for `host` via `AuthIo::write_tokens` under
    /// `paths.auth_host(host)`.
    async fn save_tokens(&self, host: &Host, tokens: &Tokens) -> Res<()> {
        debug!("⏳ Saving tokens for host {:?}", host);
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        auth_io.write_tokens(tokens).await?;
        debug!(
            "✔️ Successfully saved tokens to the {:?}",
            self.paths.auth_host(host)
        );
        Ok(())
    }

    /// Use the refresh token to obtain new access + refresh tokens from the
    /// Connect token endpoint (RFC 6749 §6), then persist them.
    ///
    /// # Errors
    ///
    /// - `LoginError::Required` if no OAuth client is stored for `host`,
    ///   because the refresh grant needs its `client_id`.
    /// - Token-endpoint HTTP errors, returned unclassified (classification is
    ///   done by `refresh_tokens_with_retry`).
    /// - Storage errors.
    async fn refresh_tokens<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        tokens: &Tokens,
    ) -> Res<Tokens> {
        let client = auth_io
            .read_client()
            .await?
            .ok_or(LoginError::Required(Some(host.to_owned())))?;

        let new_tokens =
            refresh_oauth_tokens(http_client, host, &tokens.refresh_token, &client.client_id)
                .await?;

        auth_io.write_tokens(&new_tokens).await?;
        info!("✔️ Successfully refreshed tokens for {}", host);

        Ok(new_tokens)
    }

    /// `refresh_tokens` with a single transparent retry on auth-classified
    /// errors (HTTP 400/401/403 from the token endpoint).
    ///
    /// A single 4xx is not necessarily a revoked refresh token — it can also
    /// be a brief server-side token-validation hiccup (deploy, replica with
    /// stale state, JWKS rotation). Only when two consecutive attempts return
    /// a 4xx do we conclude the refresh token is actually bad and map to
    /// `LoginError::Required`.
    ///
    /// Errors that are not auth-classified (network failures, 5xx, storage)
    /// are returned after the first attempt without retrying.
    async fn refresh_tokens_with_retry<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        tokens: &Tokens,
    ) -> Res<Tokens> {
        let first_err = match self
            .refresh_tokens(http_client, auth_io, host, tokens)
            .await
        {
            Ok(t) => return Ok(t),
            Err(e) => e,
        };

        // `refresh_tokens` raises `LoginError::Required` only when no OAuth
        // client is stored, which a retry cannot fix.
        if matches!(first_err, Error::Login(LoginError::Required(_))) {
            warn!("❌ No OAuth client registered for {}, login required", host);
            return Err(first_err);
        }
        if !is_token_auth_error(&first_err) {
            warn!(
                status = ?http_status(&first_err),
                "❌ Failed to refresh tokens for {}: {}", host, first_err
            );
            return Err(first_err);
        }

        info!(
            status = ?http_status(&first_err),
            "⚠️ Auth error refreshing tokens for {}, retrying once: {}", host, first_err
        );
        // Second attempt reuses the same `tokens`; its outcome decides whether
        // the user must log in again.
        classify_retry_outcome(
            self.refresh_tokens(http_client, auth_io, host, tokens)
                .await,
            is_token_auth_error,
            "token endpoint",
            host,
        )
    }

    /// `refresh_credentials` with a single transparent retry on auth-classified
    /// errors (HTTP 401/403 from the credentials endpoint).
    ///
    /// A 4xx here usually means the server's view of the access token's
    /// validity has shifted (clock skew, session-store replication lag, etc.).
    /// Unlike the token-endpoint retry, this path **forces** a fresh access
    /// token between the two attempts — retrying with the same stale token
    /// would just reproduce the failure.
    async fn refresh_credentials_with_retry<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        access_token: &str,
    ) -> Res<Credentials> {
        let first_err = match self
            .refresh_credentials(http_client, host, access_token)
            .await
        {
            Ok(c) => return Ok(c),
            Err(e) => e,
        };

        if !is_credentials_auth_error(&first_err) {
            warn!(
                status = ?http_status(&first_err),
                "❌ Failed to refresh credentials for {}: {}", host, first_err
            );
            return Err(first_err);
        }

        info!(
            status = ?http_status(&first_err),
            "⚠️ Auth error refreshing credentials for {}, \
             force-refreshing token and retrying: {}",
            host, first_err
        );

        // Force-refresh the access token, bypassing the 60s proactive check.
        // Tokens are re-read from storage rather than taken from the caller,
        // and the refreshed set is persisted by `refresh_tokens`.
        let tokens = auth_io
            .read_tokens()
            .await?
            .ok_or_else(|| LoginError::Required(Some(host.to_owned())))?;
        let new_tokens = self
            .refresh_tokens_with_retry(http_client, auth_io, host, &tokens)
            .await?;

        classify_retry_outcome(
            self.refresh_credentials(http_client, host, &new_tokens.access_token)
                .await,
            is_credentials_auth_error,
            "credentials endpoint",
            host,
        )
    }

    /// Fetch credentials from the registry with `access_token` and persist them
    /// for `host` via `AuthIo::write_credentials`. No retry or error
    /// classification happens here.
    async fn refresh_credentials<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        access_token: &str,
    ) -> Res<Credentials> {
        debug!("⏳ Refreshing credentials for host {:?}", host);
        let credentials = refresh_credentials(http_client, host, access_token).await?;

        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        auth_io.write_credentials(&credentials).await?;

        debug!(
            "✔️ Successfully refreshed credentials in {:?}",
            self.paths.auth_host(host)
        );
        Ok(credentials)
    }

    /// Return stored credentials for `host`, refreshing them if none are
    /// available.
    ///
    /// Flow:
    /// 1. Fast path: return credentials from storage when `read_credentials`
    ///    yields `Some`. NOTE(review): presumably `AuthIo::read_credentials`
    ///    returns `None` for expired credentials as well as missing ones —
    ///    confirm in `AuthIo`.
    /// 2. Take this host's refresh lock and re-read, in case a concurrent
    ///    caller already refreshed.
    /// 3. Load tokens. If the access token expires within 60 seconds, refresh
    ///    it via the token endpoint (with one retry).
    /// 4. Fetch and persist credentials (with one retry that force-refreshes
    ///    the access token).
    ///
    /// # Errors
    ///
    /// - `Error::Auth(host, AuthError::CredentialsRead(..))` or
    ///   `Error::Auth(host, AuthError::TokensRead(..))` on storage read failures.
    /// - `LoginError::Required` when no tokens are stored, no OAuth client is
    ///   registered, or auth errors persist after a retry.
    /// - Other HTTP or storage errors unchanged.
    pub async fn get_credentials_or_refresh<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
    ) -> Res<Credentials> {
        info!("⏳ Getting or refreshing credentials for {}", host);
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));

        match auth_io.read_credentials().await {
            Ok(Some(creds)) => {
                debug!("✔️ Found valid credentials for {}", host);
                return Ok(creds);
            }
            Ok(None) => {
                info!("❌ No existing credentials found for {}", host);
            }
            Err(e) => {
                error!("❌ Failed to read credentials for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::CredentialsRead(e.to_string()),
                ));
            }
        }

        // Serialize refreshes for this host so N concurrent callers
        // fire one HTTP `/get_credentials` call instead of N. The
        // loser of the race re-reads the credentials the winner
        // wrote to disk and returns them without hitting the network.
        // `_guard` keeps the lock held until this function returns.
        let lock = self.refresh_lock_for(host);
        let _guard = lock.lock().await;

        match auth_io.read_credentials().await {
            Ok(Some(creds)) => {
                debug!("✔️ Another task refreshed credentials for {}", host);
                return Ok(creds);
            }
            Ok(None) => {}
            Err(e) => {
                error!("❌ Failed to re-read credentials for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::CredentialsRead(e.to_string()),
                ));
            }
        }

        let tokens = match auth_io.read_tokens().await {
            Ok(Some(tokens)) => tokens,
            Ok(None) => {
                warn!("❌ No tokens found for {}, login required", host);
                return Err(LoginError::Required(Some(host.to_owned())).into());
            }
            Err(e) => {
                error!("❌ Failed to read tokens for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::TokensRead(e.to_string()),
                ));
            }
        };

        // If the access token is expired, try to refresh it using the refresh token.
        // The 60-second margin also refreshes tokens that are about to expire,
        // so they do not lapse mid-request.
        let access_token =
            if tokens.expires_at <= chrono::Utc::now() + chrono::Duration::seconds(60) {
                info!(
                    "⏳ Access token expired for {}, refreshing via refresh token",
                    host
                );
                self.refresh_tokens_with_retry(http_client, &auth_io, host, &tokens)
                    .await?
                    .access_token
            } else {
                tokens.access_token
            };

        info!("⏳ Refreshing credentials using access token for {}", host);
        let creds = self
            .refresh_credentials_with_retry(http_client, &auth_io, host, &access_token)
            .await?;
        info!("✔️ Successfully refreshed credentials for {}", host);
        Ok(creds)
    }
}
901
902#[cfg(test)]
903mod tests {
904    use super::*;
905
906    use async_trait::async_trait;
907    use reqwest::header::HeaderMap;
908    use test_log::test;
909
910    use crate::io::storage::mocks::MockStorage;
911    use crate::paths::DomainPaths;
912
913    const ACCESS_TOKEN: &str = "test-access-token";
914    const REFRESH_TOKEN: &str = "test-refresh-token";
915    const TIMESTAMP: i64 = 1_708_444_800;
916
917    fn get_host() -> Host {
918        "test.quilt.dev".parse().unwrap()
919    }
920
921    fn get_registry() -> String {
922        "registry-test.quilt.dev".to_string()
923    }
924
925    struct TestHttpClient;
926
927    #[async_trait]
928    impl HttpClient for TestHttpClient {
929        async fn get<T: serde::de::DeserializeOwned>(
930            &self,
931            url: &str,
932            auth_token: Option<&str>,
933        ) -> Res<T> {
934            let registry = get_registry();
935
936            match url {
937                u if u == format!("https://{}/config.json", get_host()) => {
938                    let config = QuiltStackConfig {
939                        registry_url: format!("https://{registry}").parse()?,
940                    };
941                    Ok(serde_json::from_value(serde_json::to_value(config)?)?)
942                }
943                u if u == format!("https://{registry}/api/auth/get_credentials") => {
944                    assert_eq!(auth_token, Some(ACCESS_TOKEN));
945                    let creds = RemoteCredentials {
946                        access_key_id: "test-access-key".to_string(),
947                        secret_access_key: "test-secret-key".to_string(),
948                        session_token: "test-session-token".to_string(),
949                        expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
950                    };
951                    Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
952                }
953                _ => panic!("Unexpected URL: {url}"),
954            }
955        }
956
957        async fn head(&self, _url: &str) -> Res<HeaderMap> {
958            unimplemented!("head is not used in this test")
959        }
960
961        async fn post<T: serde::de::DeserializeOwned>(
962            &self,
963            url: &str,
964            form_data: &HashMap<String, String>,
965        ) -> Res<T> {
966            assert_eq!(url, format!("https://{}/api/token", get_registry()));
967
968            // Verify form data contains the refresh token
969            assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
970
971            let tokens = RemoteTokens {
972                access_token: ACCESS_TOKEN.to_string(),
973                refresh_token: "new-refresh-token".to_string(),
974                expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
975            };
976            Ok(serde_json::from_value(serde_json::to_value(tokens)?)?)
977        }
978
979        async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
980            &self,
981            _url: &str,
982            _body: &B,
983        ) -> Res<T> {
984            unimplemented!("post_json is not used in this test")
985        }
986    }
987
988    #[test(tokio::test)]
989    async fn test_get_registry_url() {
990        let client = TestHttpClient;
991        let result = get_registry_url(&client, &get_host()).await.unwrap();
992        assert_eq!(
993            result,
994            url::Host::Domain("registry-test.quilt.dev".to_string())
995        );
996    }
997
998    #[test(tokio::test)]
999    async fn test_get_auth_tokens() {
1000        let client = TestHttpClient;
1001        let tokens = get_auth_tokens(&client, &get_host(), REFRESH_TOKEN)
1002            .await
1003            .unwrap();
1004        assert_eq!(tokens.access_token, ACCESS_TOKEN);
1005        assert_eq!(tokens.refresh_token, "new-refresh-token");
1006        assert_eq!(
1007            tokens.expires_at,
1008            chrono::DateTime::from_timestamp(1_708_444_800, 0).unwrap()
1009        );
1010    }
1011
1012    #[test(tokio::test)]
1013    async fn test_refresh_credentials() {
1014        let client = TestHttpClient;
1015        let credentials = refresh_credentials(&client, &get_host(), ACCESS_TOKEN)
1016            .await
1017            .unwrap();
1018        assert_eq!(credentials.access_key, "test-access-key");
1019        assert_eq!(credentials.secret_key, "test-secret-key");
1020        assert_eq!(credentials.token, "test-session-token");
1021        assert_eq!(
1022            credentials.expires_at,
1023            chrono::DateTime::from_timestamp(1_708_444_800, 0).unwrap()
1024        );
1025    }
1026
1027    #[test(tokio::test)]
1028    async fn test_auth_refresh_credentials() -> Res {
1029        let storage = Arc::new(MockStorage::default());
1030        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1031        let auth = Auth::new(paths.clone(), storage.clone());
1032        let host = get_host();
1033
1034        let credentials = auth
1035            .refresh_credentials(&TestHttpClient, &host, ACCESS_TOKEN)
1036            .await?;
1037
1038        // Verify returned credentials
1039        assert_eq!(credentials.access_key, "test-access-key");
1040        assert_eq!(credentials.secret_key, "test-secret-key");
1041        assert_eq!(credentials.token, "test-session-token");
1042        assert_eq!(
1043            credentials.expires_at,
1044            chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap()
1045        );
1046
1047        // Verify credentials were persisted. Note: read_credentials() filters
1048        // expired credentials, so we deserialize directly from the raw bytes.
1049        use crate::io::storage::StorageExt;
1050        let creds_path = paths.auth_host(&host).join(crate::paths::AUTH_CREDENTIALS);
1051        let bytes = storage.read_bytes(&creds_path).await?;
1052        let read_creds: Credentials = serde_json::from_slice(&bytes)?;
1053        assert_eq!(read_creds.access_key, credentials.access_key);
1054        assert_eq!(read_creds.secret_key, credentials.secret_key);
1055        assert_eq!(read_creds.token, credentials.token);
1056        assert_eq!(read_creds.expires_at, credentials.expires_at);
1057
1058        Ok(())
1059    }
1060
1061    #[test]
1062    fn test_remote_credentials_deserialization() {
1063        // Test valid RFC3339 date
1064        let valid_json = r#"{
1065            "AccessKeyId": "test-key",
1066            "Expiration": "2024-02-20T15:00:00Z",
1067            "SecretAccessKey": "test-secret",
1068            "SessionToken": "test-token"
1069        }"#;
1070
1071        let creds: RemoteCredentials = serde_json::from_str(valid_json).unwrap();
1072        assert_eq!(creds.access_key_id, "test-key");
1073        assert_eq!(creds.secret_access_key, "test-secret");
1074        assert_eq!(creds.session_token, "test-token");
1075        assert_eq!(
1076            creds.expiration,
1077            chrono::DateTime::parse_from_rfc3339("2024-02-20T15:00:00Z")
1078                .unwrap()
1079                .with_timezone(&chrono::Utc)
1080        );
1081
1082        // Test invalid RFC3339 date
1083        let invalid_json = r#"{
1084            "AccessKeyId": "test-key",
1085            "Expiration": "2024-02-20 15:00:00",
1086            "SecretAccessKey": "test-secret",
1087            "SessionToken": "test-token"
1088        }"#;
1089
1090        let error = serde_json::from_str::<RemoteCredentials>(invalid_json).unwrap_err();
1091        assert!(error.to_string().contains("Invalid RFC3339 date"));
1092    }
1093
1094    const AUTH_CODE: &str = "test-auth-code";
1095    const CODE_VERIFIER: &str = "test-code-verifier-that-is-at-least-43-characters-long";
1096    const CLIENT_ID: &str = "test-client-id";
1097    const REDIRECT_URI: &str = "quilt://auth/callback?host=test.quilt.dev";
1098
1099    struct OAuthTestHttpClient {
1100        /// The access token expected when hitting the credentials endpoint.
1101        expected_credentials_token: &'static str,
1102    }
1103
1104    impl Default for OAuthTestHttpClient {
1105        fn default() -> Self {
1106            Self {
1107                expected_credentials_token: ACCESS_TOKEN,
1108            }
1109        }
1110    }
1111
1112    #[async_trait]
1113    impl HttpClient for OAuthTestHttpClient {
1114        async fn get<T: serde::de::DeserializeOwned>(
1115            &self,
1116            url: &str,
1117            auth_token: Option<&str>,
1118        ) -> Res<T> {
1119            let registry = get_registry();
1120
1121            match url {
1122                u if u == format!("https://{}/config.json", get_host()) => {
1123                    let config = QuiltStackConfig {
1124                        registry_url: format!("https://{registry}").parse()?,
1125                    };
1126                    Ok(serde_json::from_value(serde_json::to_value(config)?)?)
1127                }
1128                u if u == format!("https://{registry}/api/auth/get_credentials") => {
1129                    assert_eq!(auth_token, Some(self.expected_credentials_token));
1130                    let creds = RemoteCredentials {
1131                        access_key_id: "oauth-access-key".to_string(),
1132                        secret_access_key: "oauth-secret-key".to_string(),
1133                        session_token: "oauth-session-token".to_string(),
1134                        expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1135                    };
1136                    Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
1137                }
1138                _ => panic!("Unexpected GET URL: {url}"),
1139            }
1140        }
1141
1142        async fn head(&self, _url: &str) -> Res<HeaderMap> {
1143            unimplemented!()
1144        }
1145
1146        async fn post<T: serde::de::DeserializeOwned>(
1147            &self,
1148            url: &str,
1149            form_data: &HashMap<String, String>,
1150        ) -> Res<T> {
1151            assert_eq!(url, connect_token_url(&get_host()));
1152
1153            let tokens = match form_data.get("grant_type").map(String::as_str) {
1154                Some("authorization_code") => {
1155                    assert_eq!(form_data.get("code").unwrap(), AUTH_CODE);
1156                    assert_eq!(form_data.get("code_verifier").unwrap(), CODE_VERIFIER);
1157                    assert_eq!(form_data.get("redirect_uri").unwrap(), REDIRECT_URI);
1158                    assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
1159                    OAuthTokenResponse {
1160                        access_token: ACCESS_TOKEN.to_string(),
1161                        refresh_token: Some("oauth-refresh-token".to_string()),
1162                        expires_in: 3600,
1163                    }
1164                }
1165                Some("refresh_token") => {
1166                    assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
1167                    assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
1168                    OAuthTokenResponse {
1169                        access_token: "refreshed-access-token".to_string(),
1170                        refresh_token: Some("new-refresh-token".to_string()),
1171                        expires_in: 3600,
1172                    }
1173                }
1174                other => panic!("Unexpected grant_type: {other:?}"),
1175            };
1176            Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
1177        }
1178
1179        async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1180            &self,
1181            url: &str,
1182            body: &B,
1183        ) -> Res<T> {
1184            assert_eq!(url, connect_register_url(&get_host()));
1185            let json = serde_json::to_value(body)?;
1186            assert_eq!(json["client_name"], "QuiltSync");
1187            assert_eq!(json["token_endpoint_auth_method"], "none");
1188            let redirect_uris = json["redirect_uris"].as_array().expect("redirect_uris");
1189            assert_eq!(redirect_uris.len(), 1);
1190            assert!(
1191                redirect_uris[0]
1192                    .as_str()
1193                    .unwrap()
1194                    .starts_with("quilt://auth/callback?host=")
1195            );
1196            Ok(serde_json::from_value(serde_json::json!({
1197                "client_id": "test-dcr-client-id"
1198            }))?)
1199        }
1200    }
1201
1202    #[test]
1203    fn test_connect_host() {
1204        let host: Host = "test.quilt.dev".parse().unwrap();
1205        assert_eq!(connect_host(&host), "test-connect.quilt.dev");
1206    }
1207
1208    #[test]
1209    fn test_connect_token_url() {
1210        let host: Host = "test.quilt.dev".parse().unwrap();
1211        assert_eq!(
1212            connect_token_url(&host),
1213            "https://test-connect.quilt.dev/auth/token"
1214        );
1215    }
1216
1217    #[test(tokio::test)]
1218    async fn test_exchange_oauth_code() {
1219        let client = OAuthTestHttpClient::default();
1220        let params = OAuthParams {
1221            code: AUTH_CODE.to_string(),
1222            code_verifier: CODE_VERIFIER.to_string(),
1223            redirect_uri: REDIRECT_URI.to_string(),
1224            client_id: CLIENT_ID.to_string(),
1225        };
1226        let tokens = exchange_oauth_code(&client, &get_host(), &params)
1227            .await
1228            .unwrap();
1229        assert_eq!(tokens.access_token, ACCESS_TOKEN);
1230        assert_eq!(tokens.refresh_token, "oauth-refresh-token");
1231    }
1232
1233    #[test]
1234    fn test_pkce_challenge() {
1235        let pkce = pkce_challenge();
1236
1237        // Verifier should be 86 characters (64 bytes base64url-encoded without padding)
1238        assert_eq!(pkce.code_verifier.len(), 86);
1239
1240        // Challenge should be 43 characters (SHA-256 is 32 bytes, base64url-encoded)
1241        assert_eq!(pkce.code_challenge.len(), 43);
1242
1243        // Verify the challenge is the S256 hash of the verifier
1244        let expected_challenge =
1245            URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.code_verifier.as_bytes()));
1246        assert_eq!(pkce.code_challenge, expected_challenge);
1247
1248        // Two calls should produce different verifiers
1249        let pkce2 = pkce_challenge();
1250        assert_ne!(pkce.code_verifier, pkce2.code_verifier);
1251    }
1252
1253    // RFC 7636 §4.1: code verifier must use only unreserved chars: ALPHA / DIGIT / "-" / "." / "_" / "~"
1254    #[test]
1255    fn test_pkce_verifier_charset_rfc7636() {
1256        let pkce = pkce_challenge();
1257        for ch in pkce.code_verifier.chars() {
1258            assert!(
1259                ch.is_ascii_alphanumeric() || matches!(ch, '-' | '.' | '_' | '~'),
1260                "code_verifier contains char '{ch}' not allowed by RFC 7636 §4.1"
1261            );
1262        }
1263    }
1264
1265    #[test(tokio::test)]
1266    async fn test_login_oauth() -> Res {
1267        let storage = Arc::new(MockStorage::default());
1268        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1269        let auth = Auth::new(paths, storage);
1270        let host = get_host();
1271
1272        let params = OAuthParams {
1273            code: AUTH_CODE.to_string(),
1274            code_verifier: CODE_VERIFIER.to_string(),
1275            redirect_uri: REDIRECT_URI.to_string(),
1276            client_id: CLIENT_ID.to_string(),
1277        };
1278
1279        auth.login_oauth(&OAuthTestHttpClient::default(), &host, params)
1280            .await?;
1281        Ok(())
1282    }
1283
1284    #[test(tokio::test)]
1285    async fn test_refresh_oauth_tokens() -> Res {
1286        let tokens = refresh_oauth_tokens(
1287            &OAuthTestHttpClient::default(),
1288            &get_host(),
1289            REFRESH_TOKEN,
1290            CLIENT_ID,
1291        )
1292        .await?;
1293        assert_eq!(tokens.access_token, "refreshed-access-token");
1294        assert_eq!(tokens.refresh_token, "new-refresh-token");
1295        Ok(())
1296    }
1297
1298    // RFC 6749 §6: if the server omits `refresh_token` in the refresh response,
1299    // the client MUST retain the previous refresh token.
1300    #[test(tokio::test)]
1301    async fn test_refresh_oauth_tokens_retains_old_when_omitted() -> Res {
1302        struct NoRefreshTokenClient;
1303
1304        #[async_trait]
1305        impl HttpClient for NoRefreshTokenClient {
1306            async fn get<T: serde::de::DeserializeOwned>(
1307                &self,
1308                _: &str,
1309                _: Option<&str>,
1310            ) -> Res<T> {
1311                unimplemented!()
1312            }
1313            async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
1314                unimplemented!()
1315            }
1316            async fn post<T: serde::de::DeserializeOwned>(
1317                &self,
1318                _: &str,
1319                _: &HashMap<String, String>,
1320            ) -> Res<T> {
1321                let resp = OAuthTokenResponse {
1322                    access_token: "new-access-token".to_string(),
1323                    refresh_token: None, // server omits refresh_token
1324                    expires_in: DEFAULT_EXPIRES_IN,
1325                };
1326                Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
1327            }
1328            async fn post_json<
1329                T: serde::de::DeserializeOwned,
1330                B: serde::Serialize + Send + Sync,
1331            >(
1332                &self,
1333                _: &str,
1334                _: &B,
1335            ) -> Res<T> {
1336                unimplemented!()
1337            }
1338        }
1339
1340        let tokens =
1341            refresh_oauth_tokens(&NoRefreshTokenClient, &get_host(), REFRESH_TOKEN, CLIENT_ID)
1342                .await?;
1343        assert_eq!(tokens.access_token, "new-access-token");
1344        // Old refresh token must be retained
1345        assert_eq!(tokens.refresh_token, REFRESH_TOKEN);
1346        Ok(())
1347    }
1348
1349    // RFC 6749 §4.1.4 + §5.1: initial code exchange MUST return a refresh_token;
1350    // if the server omits it the client should surface an error (not silently proceed).
1351    #[test(tokio::test)]
1352    async fn test_exchange_oauth_code_errors_when_refresh_token_missing() {
1353        struct NoRefreshTokenClient;
1354
1355        #[async_trait]
1356        impl HttpClient for NoRefreshTokenClient {
1357            async fn get<T: serde::de::DeserializeOwned>(
1358                &self,
1359                _: &str,
1360                _: Option<&str>,
1361            ) -> Res<T> {
1362                unimplemented!()
1363            }
1364            async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
1365                unimplemented!()
1366            }
1367            async fn post<T: serde::de::DeserializeOwned>(
1368                &self,
1369                _: &str,
1370                _: &HashMap<String, String>,
1371            ) -> Res<T> {
1372                let resp = OAuthTokenResponse {
1373                    access_token: ACCESS_TOKEN.to_string(),
1374                    refresh_token: None,
1375                    expires_in: DEFAULT_EXPIRES_IN,
1376                };
1377                Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
1378            }
1379            async fn post_json<
1380                T: serde::de::DeserializeOwned,
1381                B: serde::Serialize + Send + Sync,
1382            >(
1383                &self,
1384                _: &str,
1385                _: &B,
1386            ) -> Res<T> {
1387                unimplemented!()
1388            }
1389        }
1390
1391        let params = OAuthParams {
1392            code: AUTH_CODE.to_string(),
1393            code_verifier: CODE_VERIFIER.to_string(),
1394            redirect_uri: REDIRECT_URI.to_string(),
1395            client_id: CLIENT_ID.to_string(),
1396        };
1397        let result = exchange_oauth_code(&NoRefreshTokenClient, &get_host(), &params).await;
1398        assert!(
1399            matches!(result, Err(Error::Auth(_, AuthError::TokensExchange(_)))),
1400            "expected TokensExchange error, got: {result:?}"
1401        );
1402    }
1403
1404    // RFC 6749 §5.1: `expires_in` is RECOMMENDED, not required. If omitted,
1405    // the client should fall back to a safe default rather than failing.
1406    #[test]
1407    fn test_oauth_token_response_missing_expires_in() {
1408        let json = r#"{"access_token":"tok","refresh_token":"ref"}"#;
1409        let resp: OAuthTokenResponse = serde_json::from_str(json).unwrap();
1410        assert_eq!(resp.expires_in, DEFAULT_EXPIRES_IN);
1411    }
1412
1413    const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token";
1414
1415    #[test(tokio::test)]
1416    async fn test_get_credentials_or_refresh_with_expired_token() -> Res {
1417        let storage = Arc::new(MockStorage::default());
1418        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1419        let auth = Auth::new(paths.clone(), storage.clone());
1420        let host = get_host();
1421
1422        // Seed an expired access token and a stored OAuth client.
1423        let auth_io = AuthIo::new(storage, paths.auth_host(&host));
1424        auth_io
1425            .write_tokens(&Tokens {
1426                access_token: "expired-access-token".to_string(),
1427                refresh_token: REFRESH_TOKEN.to_string(),
1428                expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
1429            })
1430            .await?;
1431        auth_io
1432            .write_client(&OAuthClient {
1433                client_id: CLIENT_ID.to_string(),
1434                redirect_uri: REDIRECT_URI.to_string(),
1435            })
1436            .await?;
1437
1438        let client = OAuthTestHttpClient {
1439            expected_credentials_token: REFRESHED_ACCESS_TOKEN,
1440        };
1441        let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1442
1443        // Credentials should come from the refreshed access token.
1444        assert_eq!(creds.access_key, "oauth-access-key");
1445
1446        // Verify the new tokens were persisted by reading them back.
1447        let persisted = auth_io
1448            .read_tokens()
1449            .await?
1450            .expect("tokens should be persisted");
1451        assert_eq!(persisted.access_token, REFRESHED_ACCESS_TOKEN);
1452        assert_eq!(persisted.refresh_token, "new-refresh-token");
1453
1454        Ok(())
1455    }
1456
1457    #[test(tokio::test)]
1458    async fn test_get_or_register_client() -> Res {
1459        let storage = Arc::new(MockStorage::default());
1460        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1461        let auth = Auth::new(paths, storage);
1462        let host = get_host();
1463
1464        // First call registers via DCR
1465        let client = auth
1466            .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
1467            .await?;
1468        assert_eq!(client.client_id, "test-dcr-client-id");
1469        assert_eq!(client.redirect_uri, REDIRECT_URI);
1470
1471        // Second call with same redirect_uri reads from storage (no DCR call)
1472        let client2 = auth
1473            .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
1474            .await?;
1475        assert_eq!(client2.client_id, "test-dcr-client-id");
1476
1477        // Third call with different redirect_uri re-registers
1478        let new_redirect = "quilt://auth/callback?host=other.quilt.dev";
1479        let client3 = auth
1480            .get_or_register_client(&OAuthTestHttpClient::default(), &host, new_redirect)
1481            .await?;
1482        assert_eq!(client3.client_id, "test-dcr-client-id");
1483        assert_eq!(client3.redirect_uri, new_redirect);
1484
1485        Ok(())
1486    }
1487
1488    #[test]
1489    fn remote_tokens_debug_redacts_secrets() {
1490        let tokens = RemoteTokens {
1491            access_token: "secret-access".to_string(),
1492            refresh_token: "secret-refresh".to_string(),
1493            expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1494        };
1495        let output = format!("{tokens:?}");
1496        assert!(output.contains("[REDACTED]"));
1497        assert!(!output.contains("secret-access"));
1498        assert!(!output.contains("secret-refresh"));
1499    }
1500
1501    #[test]
1502    fn oauth_token_response_debug_redacts_secrets() {
1503        let response = OAuthTokenResponse {
1504            access_token: "secret-access".to_string(),
1505            refresh_token: Some("secret-refresh".to_string()),
1506            expires_in: 3600,
1507        };
1508        let output = format!("{response:?}");
1509        assert!(output.contains("[REDACTED]"));
1510        assert!(!output.contains("secret-access"));
1511        assert!(!output.contains("secret-refresh"));
1512    }
1513
1514    #[test]
1515    fn oauth_token_response_debug_none_refresh_token() {
1516        let response = OAuthTokenResponse {
1517            access_token: "secret-access".to_string(),
1518            refresh_token: None,
1519            expires_in: 3600,
1520        };
1521        let output = format!("{response:?}");
1522        assert!(output.contains("refresh_token: None"));
1523        assert!(!output.contains("secret-access"));
1524    }
1525
1526    #[test]
1527    fn remote_credentials_debug_redacts_secrets() {
1528        let creds = RemoteCredentials {
1529            access_key_id: "secret-key-id".to_string(),
1530            expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1531            secret_access_key: "secret-access-key".to_string(),
1532            session_token: "secret-session-token".to_string(),
1533        };
1534        let output = format!("{creds:?}");
1535        assert!(output.contains("[REDACTED]"));
1536        assert!(!output.contains("secret-key-id"));
1537        assert!(!output.contains("secret-access-key"));
1538        assert!(!output.contains("secret-session-token"));
1539    }
1540
1541    // ── Retry-on-transient-4xx tests ──────────────────────────────────────
1542
1543    use std::sync::atomic::AtomicUsize;
1544    use std::sync::atomic::Ordering;
1545    use tokio::io::AsyncReadExt;
1546    use tokio::io::AsyncWriteExt;
1547
1548    /// Spawns a one-connection TCP responder that replies with `response` bytes.
1549    /// Used to produce real `reqwest::Error` values with a chosen HTTP status.
1550    async fn spawn_one_shot(response: Vec<u8>) -> std::net::SocketAddr {
1551        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1552        let addr = listener.local_addr().unwrap();
1553        tokio::spawn(async move {
1554            if let Ok((mut stream, _)) = listener.accept().await {
1555                let mut buf = [0u8; 4096];
1556                let _ = stream.read(&mut buf).await;
1557                let _ = stream.write_all(&response).await;
1558                let _ = stream.shutdown().await;
1559            }
1560        });
1561        addr
1562    }
1563
1564    /// Produce an `Error::Reqwest` whose `.status()` is the given code. There
1565    /// is no public constructor for `reqwest::Error`, so we round-trip through
1566    /// a real HTTP request against a canned local responder.
1567    async fn reqwest_error_with_status(status: u16) -> Error {
1568        let body = format!("HTTP/1.1 {status} X\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
1569            .into_bytes();
1570        let addr = spawn_one_shot(body).await;
1571        reqwest::Client::new()
1572            .get(format!("http://{addr}/"))
1573            .send()
1574            .await
1575            .unwrap()
1576            .error_for_status()
1577            .unwrap_err()
1578            .into()
1579    }
1580
1581    /// Mock that fails the first N calls against each endpoint with a real
1582    /// `Error::Reqwest` carrying HTTP 401, then starts succeeding.
1583    struct RetryMockClient {
1584        cred_fail_first_n: usize,
1585        token_fail_first_n: usize,
1586        cred_calls: AtomicUsize,
1587        token_calls: AtomicUsize,
1588    }
1589
1590    impl RetryMockClient {
1591        fn new(cred_fail: usize, token_fail: usize) -> Self {
1592            Self {
1593                cred_fail_first_n: cred_fail,
1594                token_fail_first_n: token_fail,
1595                cred_calls: AtomicUsize::new(0),
1596                token_calls: AtomicUsize::new(0),
1597            }
1598        }
1599    }
1600
1601    #[async_trait]
1602    impl HttpClient for RetryMockClient {
1603        async fn get<T: serde::de::DeserializeOwned>(
1604            &self,
1605            url: &str,
1606            _auth_token: Option<&str>,
1607        ) -> Res<T> {
1608            let registry = get_registry();
1609            if url == format!("https://{}/config.json", get_host()) {
1610                let config = QuiltStackConfig {
1611                    registry_url: format!("https://{registry}").parse()?,
1612                };
1613                return Ok(serde_json::from_value(serde_json::to_value(config)?)?);
1614            }
1615            if url == format!("https://{registry}/api/auth/get_credentials") {
1616                let n = self.cred_calls.fetch_add(1, Ordering::SeqCst);
1617                if n < self.cred_fail_first_n {
1618                    return Err(reqwest_error_with_status(401).await);
1619                }
1620                let creds = RemoteCredentials {
1621                    access_key_id: "oauth-access-key".to_string(),
1622                    secret_access_key: "oauth-secret-key".to_string(),
1623                    session_token: "oauth-session-token".to_string(),
1624                    expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1625                };
1626                return Ok(serde_json::from_value(serde_json::to_value(creds)?)?);
1627            }
1628            panic!("Unexpected GET URL: {url}")
1629        }
1630
1631        async fn head(&self, _url: &str) -> Res<HeaderMap> {
1632            unimplemented!()
1633        }
1634
1635        async fn post<T: serde::de::DeserializeOwned>(
1636            &self,
1637            url: &str,
1638            form_data: &HashMap<String, String>,
1639        ) -> Res<T> {
1640            assert_eq!(url, connect_token_url(&get_host()));
1641            let n = self.token_calls.fetch_add(1, Ordering::SeqCst);
1642            if n < self.token_fail_first_n {
1643                return Err(reqwest_error_with_status(401).await);
1644            }
1645            assert_eq!(
1646                form_data.get("grant_type").map(String::as_str),
1647                Some("refresh_token")
1648            );
1649            let tokens = OAuthTokenResponse {
1650                access_token: REFRESHED_ACCESS_TOKEN.to_string(),
1651                refresh_token: Some("new-refresh-token".to_string()),
1652                expires_in: 3600,
1653            };
1654            Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
1655        }
1656
1657        async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1658            &self,
1659            _url: &str,
1660            _body: &B,
1661        ) -> Res<T> {
1662            unimplemented!()
1663        }
1664    }
1665
1666    async fn seed_fresh_tokens(storage: &Arc<MockStorage>, paths: &DomainPaths, host: &Host) {
1667        let auth_io = AuthIo::new(storage.clone(), paths.auth_host(host));
1668        auth_io
1669            .write_tokens(&Tokens {
1670                access_token: ACCESS_TOKEN.to_string(),
1671                refresh_token: REFRESH_TOKEN.to_string(),
1672                // Well inside the 60-second buffer → proactive refresh skipped.
1673                expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
1674            })
1675            .await
1676            .unwrap();
1677        auth_io
1678            .write_client(&OAuthClient {
1679                client_id: CLIENT_ID.to_string(),
1680                redirect_uri: REDIRECT_URI.to_string(),
1681            })
1682            .await
1683            .unwrap();
1684    }
1685
1686    /// Credentials endpoint flaps 401 once, then succeeds. The retry path
1687    /// force-refreshes the access token and re-hits the credentials endpoint;
1688    /// user must not see `LoginRequired`.
1689    #[test(tokio::test)]
1690    async fn test_credentials_transient_401_recovers_via_force_token_refresh() -> Res {
1691        let storage = Arc::new(MockStorage::default());
1692        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1693        let auth = Auth::new(paths.clone(), storage.clone());
1694        let host = get_host();
1695        seed_fresh_tokens(&storage, &paths, &host).await;
1696
1697        let client = RetryMockClient::new(/*cred_fail=*/ 1, /*token_fail=*/ 0);
1698        let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1699
1700        assert_eq!(creds.access_key, "oauth-access-key");
1701        assert_eq!(
1702            client.cred_calls.load(Ordering::SeqCst),
1703            2,
1704            "credentials endpoint should be called twice: initial + retry"
1705        );
1706        assert_eq!(
1707            client.token_calls.load(Ordering::SeqCst),
1708            1,
1709            "token endpoint should be called once to force-refresh"
1710        );
1711        Ok(())
1712    }
1713
1714    /// Credentials endpoint fails 401 twice in a row. After the bounded retry
1715    /// the client must conclude login is really required.
1716    #[test(tokio::test)]
1717    async fn test_credentials_persistent_401_maps_to_login_required() -> Res {
1718        let storage = Arc::new(MockStorage::default());
1719        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1720        let auth = Auth::new(paths.clone(), storage.clone());
1721        let host = get_host();
1722        seed_fresh_tokens(&storage, &paths, &host).await;
1723
1724        let client = RetryMockClient::new(/*cred_fail=*/ usize::MAX, /*token_fail=*/ 0);
1725        let result = auth.get_credentials_or_refresh(&client, &host).await;
1726
1727        assert!(
1728            matches!(result, Err(Error::Login(LoginError::Required(_)))),
1729            "expected LoginRequired after persistent 4xx, got: {result:?}"
1730        );
1731        assert_eq!(
1732            client.cred_calls.load(Ordering::SeqCst),
1733            2,
1734            "retry must be bounded to one extra attempt"
1735        );
1736        Ok(())
1737    }
1738
1739    /// Token endpoint flaps 401 once during the proactive refresh path, then
1740    /// succeeds. The retry must kick in and `get_credentials_or_refresh` must
1741    /// return credentials without surfacing `LoginRequired`.
1742    #[test(tokio::test)]
1743    async fn test_token_refresh_transient_401_recovers() -> Res {
1744        let storage = Arc::new(MockStorage::default());
1745        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1746        let auth = Auth::new(paths.clone(), storage.clone());
1747        let host = get_host();
1748
1749        // Seed *expired* tokens so the proactive refresh path is taken.
1750        let auth_io = AuthIo::new(storage.clone(), paths.auth_host(&host));
1751        auth_io
1752            .write_tokens(&Tokens {
1753                access_token: "expired-access-token".to_string(),
1754                refresh_token: REFRESH_TOKEN.to_string(),
1755                expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
1756            })
1757            .await?;
1758        auth_io
1759            .write_client(&OAuthClient {
1760                client_id: CLIENT_ID.to_string(),
1761                redirect_uri: REDIRECT_URI.to_string(),
1762            })
1763            .await?;
1764
1765        let client = RetryMockClient::new(/*cred_fail=*/ 0, /*token_fail=*/ 1);
1766        let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1767
1768        assert_eq!(creds.access_key, "oauth-access-key");
1769        assert_eq!(
1770            client.token_calls.load(Ordering::SeqCst),
1771            2,
1772            "token endpoint should be called twice: initial + retry"
1773        );
1774        assert_eq!(
1775            client.cred_calls.load(Ordering::SeqCst),
1776            1,
1777            "credentials endpoint should only be called once after successful retry"
1778        );
1779        Ok(())
1780    }
1781
1782    /// Synchronization gate used by `CountingCredsClient` to park the
1783    /// `/api/auth/get_credentials` handler mid-call. `entered` signals
1784    /// the test that the handler has been reached; `release` holds the
1785    /// handler until the test lets it return.
1786    #[derive(Default)]
1787    struct Gate {
1788        entered: tokio::sync::Notify,
1789        release: tokio::sync::Notify,
1790    }
1791
1792    /// HTTP client that counts calls to `/api/auth/get_credentials`.
1793    /// Optionally sleeps inside the handler to widen the race window,
1794    /// or parks the handler on a `Gate` for deterministic coordination.
1795    /// Tokens are fresh so no OAuth leg fires.
1796    #[derive(Clone)]
1797    struct CountingCredsClient {
1798        cred_calls: Arc<std::sync::atomic::AtomicUsize>,
1799        sleep_ms: u64,
1800        gate: Option<Arc<Gate>>,
1801    }
1802
1803    #[async_trait]
1804    impl HttpClient for CountingCredsClient {
1805        async fn get<T: serde::de::DeserializeOwned>(
1806            &self,
1807            url: &str,
1808            _auth_token: Option<&str>,
1809        ) -> Res<T> {
1810            if url.ends_with("/config.json") {
1811                let body = serde_json::json!({
1812                    "registryUrl": format!("https://{}", get_registry()),
1813                });
1814                return Ok(serde_json::from_value(body)?);
1815            }
1816            if url.contains("/api/auth/get_credentials") {
1817                self.cred_calls
1818                    .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1819                if let Some(gate) = &self.gate {
1820                    gate.entered.notify_one();
1821                    gate.release.notified().await;
1822                } else if self.sleep_ms > 0 {
1823                    tokio::time::sleep(std::time::Duration::from_millis(self.sleep_ms)).await;
1824                }
1825                let body = serde_json::json!({
1826                    "AccessKeyId": "refreshed-key",
1827                    "SecretAccessKey": "refreshed-secret",
1828                    "SessionToken": "refreshed-session",
1829                    "Expiration": (chrono::Utc::now() + chrono::Duration::hours(1))
1830                        .to_rfc3339(),
1831                });
1832                return Ok(serde_json::from_value(body)?);
1833            }
1834            panic!("Unexpected GET: {url}");
1835        }
1836        async fn head(&self, _: &str) -> Res<HeaderMap> {
1837            unimplemented!()
1838        }
1839        async fn post<T: serde::de::DeserializeOwned>(
1840            &self,
1841            _: &str,
1842            _: &HashMap<String, String>,
1843        ) -> Res<T> {
1844            unimplemented!("fresh tokens → no OAuth leg fires")
1845        }
1846        async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1847            &self,
1848            _: &str,
1849            _: &B,
1850        ) -> Res<T> {
1851            unimplemented!()
1852        }
1853    }
1854
1855    async fn seed_expired_creds_fresh_tokens(auth_io: &AuthIo<Arc<MockStorage>>) -> Res {
1856        auth_io
1857            .write_credentials(&Credentials {
1858                access_key: "stale".to_string(),
1859                secret_key: "stale-secret".to_string(),
1860                token: "stale-session".to_string(),
1861                expires_at: chrono::Utc::now() - chrono::Duration::hours(1),
1862            })
1863            .await?;
1864        auth_io
1865            .write_tokens(&Tokens {
1866                access_token: ACCESS_TOKEN.to_string(),
1867                refresh_token: REFRESH_TOKEN.to_string(),
1868                expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
1869            })
1870            .await?;
1871        Ok(())
1872    }
1873
1874    #[test(tokio::test)]
1875    async fn test_auth_refresh_is_single_flight_across_concurrent_callers() -> Res {
1876        let storage = Arc::new(MockStorage::default());
1877        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1878        let auth = Auth::new(paths.clone(), storage.clone());
1879        let host = get_host();
1880
1881        let auth_io = AuthIo::new(storage, paths.auth_host(&host));
1882        seed_expired_creds_fresh_tokens(&auth_io).await?;
1883
1884        let client = CountingCredsClient {
1885            cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1886            sleep_ms: 50,
1887            gate: None,
1888        };
1889
1890        let mut handles = Vec::new();
1891        for _ in 0..10 {
1892            let auth = auth.clone();
1893            let client = client.clone();
1894            let host = host.clone();
1895            handles.push(tokio::spawn(async move {
1896                auth.get_credentials_or_refresh(&client, &host).await
1897            }));
1898        }
1899
1900        let mut creds_seen = Vec::new();
1901        for h in handles {
1902            creds_seen.push(h.await.unwrap()?);
1903        }
1904
1905        assert_eq!(
1906            client.cred_calls.load(std::sync::atomic::Ordering::SeqCst),
1907            1,
1908            "single-flight: 10 concurrent callers must produce exactly one refresh",
1909        );
1910        let first = &creds_seen[0];
1911        for creds in &creds_seen {
1912            assert_eq!(creds.access_key, first.access_key);
1913            assert_eq!(creds.expires_at, first.expires_at);
1914        }
1915        assert_eq!(first.access_key, "refreshed-key");
1916        Ok(())
1917    }
1918
1919    #[test(tokio::test)]
1920    async fn test_auth_refresh_lock_is_per_host() -> Res {
1921        let storage = Arc::new(MockStorage::default());
1922        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1923        let auth = Auth::new(paths.clone(), storage.clone());
1924
1925        let host_a: Host = "a.quilt.dev".parse().unwrap();
1926        let host_b: Host = "b.quilt.dev".parse().unwrap();
1927
1928        // Seed each host separately; they live under distinct paths.
1929        seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_a)))
1930            .await?;
1931        seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_b)))
1932            .await?;
1933
1934        // Park host_a's refresh inside the HTTP handler using a gate so
1935        // it deterministically holds host_a's lock while we exercise
1936        // host_b. No wall-clock budget — robust under CI load.
1937        let gate = Arc::new(Gate::default());
1938        let gated_client = CountingCredsClient {
1939            cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1940            sleep_ms: 0,
1941            gate: Some(gate.clone()),
1942        };
1943        let fast_client = CountingCredsClient {
1944            cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1945            sleep_ms: 0,
1946            gate: None,
1947        };
1948
1949        let auth_clone = auth.clone();
1950        let client_a = gated_client.clone();
1951        let host_a_clone = host_a.clone();
1952        let a_task = tokio::spawn(async move {
1953            auth_clone
1954                .get_credentials_or_refresh(&client_a, &host_a_clone)
1955                .await
1956        });
1957
1958        // Wait until host_a is confirmed inside the handler, holding
1959        // host_a's refresh lock.
1960        gate.entered.notified().await;
1961
1962        // Run host_b. If per-host locking works, this completes;
1963        // otherwise it would block forever on host_a's lock. The
1964        // timeout is a safety net to fail fast instead of hanging CI.
1965        tokio::time::timeout(
1966            std::time::Duration::from_secs(5),
1967            auth.get_credentials_or_refresh(&fast_client, &host_b),
1968        )
1969        .await
1970        .expect("host_b refresh must not wait behind host_a's lock")?;
1971
1972        // Positive assertion: host_a is still parked in its handler,
1973        // proving host_b completed without host_a making progress.
1974        assert!(
1975            !a_task.is_finished(),
1976            "host_a must still be blocked in its handler while host_b completes",
1977        );
1978
1979        // Release host_a so the spawned task can finish cleanly.
1980        gate.release.notify_one();
1981        a_task.await.unwrap()?;
1982        Ok(())
1983    }
1984
1985    #[test(tokio::test)]
1986    async fn test_refresh_lock_map_sweeps_dead_entries() -> Res {
1987        let storage = Arc::new(MockStorage::default());
1988        let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1989        let auth = Auth::new(paths, storage);
1990
1991        let host: Host = "x.quilt.dev".parse().unwrap();
1992
1993        // First lookup inserts a live Weak.
1994        let arc1 = auth.refresh_lock_for(&host);
1995        assert_eq!(
1996            auth.refresh_locks
1997                .lock()
1998                .unwrap_or_else(std::sync::PoisonError::into_inner)
1999                .len(),
2000            1,
2001        );
2002
2003        // Dropping all strong refs leaves a dead Weak behind.
2004        drop(arc1);
2005        assert!(
2006            auth.refresh_locks
2007                .lock()
2008                .unwrap_or_else(std::sync::PoisonError::into_inner)
2009                .get(&host)
2010                .expect("entry still present before sweep")
2011                .upgrade()
2012                .is_none(),
2013        );
2014
2015        // Next lookup sweeps the dead entry and inserts a fresh one;
2016        // map size stays at 1 instead of accumulating per refresh.
2017        let _arc2 = auth.refresh_lock_for(&host);
2018        assert_eq!(
2019            auth.refresh_locks
2020                .lock()
2021                .unwrap_or_else(std::sync::PoisonError::into_inner)
2022                .len(),
2023            1,
2024        );
2025        Ok(())
2026    }
2027}