1use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21use std::sync::Mutex as StdMutex;
22use std::sync::Weak;
23
24use base64::Engine;
25use base64::engine::general_purpose::URL_SAFE_NO_PAD;
26use sha2::Digest;
27use sha2::Sha256;
28use tokio::sync::Mutex as AsyncMutex;
29
30use crate::Error;
31use crate::Res;
32use crate::error::AuthError;
33use crate::error::LoginError;
34use crate::io::remote::client::HttpClient;
35use crate::io::storage::LocalStorage;
36use crate::io::storage::Storage;
37use crate::io::storage::auth::AuthIo;
38use crate::io::storage::auth::Credentials;
39use crate::io::storage::auth::OAuthClient;
40use crate::io::storage::auth::Tokens;
41use crate::paths::DomainPaths;
42use chrono::serde::ts_seconds;
43use quilt_uri::Host;
44use serde::Deserialize;
45use serde::Deserializer;
46use serde::Serialize;
47use tracing::debug;
48use tracing::error;
49use tracing::info;
50use tracing::warn;
51
/// Inputs for redeeming an OAuth authorization code at the connect token
/// endpoint. `exchange_oauth_code` sends every field as a form parameter of an
/// `authorization_code` grant.
pub struct OAuthParams {
    /// Authorization code to redeem; sent as `code`.
    pub code: String,
    /// PKCE code verifier (presumably the one from [`pkce_challenge`] whose
    /// challenge was used when authorizing); sent as `code_verifier`.
    pub code_verifier: String,
    /// Redirect URI; sent as `redirect_uri`.
    pub redirect_uri: String,
    /// OAuth client id; sent as `client_id`.
    pub client_id: String,
}
67
/// A PKCE verifier/challenge pair produced by [`pkce_challenge`].
pub struct PkceChallenge {
    /// Random secret, encoded as unpadded base64url.
    pub code_verifier: String,
    /// Unpadded base64url encoding of the SHA-256 digest of `code_verifier`.
    pub code_challenge: String,
}
75
76pub fn pkce_challenge() -> PkceChallenge {
81 let mut random_bytes = [0u8; 64];
82 getrandom::fill(&mut random_bytes).expect("failed to generate random bytes");
83
84 let code_verifier = URL_SAFE_NO_PAD.encode(random_bytes);
85 let code_challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
86
87 PkceChallenge {
88 code_verifier,
89 code_challenge,
90 }
91}
92
93pub fn random_state() -> String {
95 let mut bytes = [0u8; 16];
96 getrandom::fill(&mut bytes).expect("failed to generate random bytes");
97 URL_SAFE_NO_PAD.encode(bytes)
98}
99
100pub fn catalog_authorize_url(host: &Host) -> String {
114 format!("https://{host}/connect/authorize")
115}
116
117pub fn connect_host(host: &Host) -> String {
128 let s = host.to_string();
129 match s.split_once('.') {
130 Some((stack, domain)) => format!("{stack}-connect.{domain}"),
131 None => format!("{s}-connect"),
132 }
133}
134
135fn connect_token_url(host: &Host) -> String {
139 format!("https://{}/auth/token", connect_host(host))
140}
141
142fn connect_register_url(host: &Host) -> String {
144 format!("https://{}/auth/register", connect_host(host))
145}
146
/// JSON body for OAuth dynamic client registration, POSTed by
/// [`register_client`] to the connect `/auth/register` endpoint.
#[derive(Serialize)]
struct DcrRequest {
    /// Human-readable client name.
    client_name: String,
    /// Redirect URIs to register for the client.
    redirect_uris: Vec<String>,
    /// Token endpoint authentication method. `register_client` sends `"none"`,
    /// i.e. no client authentication at the token endpoint.
    token_endpoint_auth_method: String,
}
154
/// The part of the dynamic client registration response this module reads.
#[derive(Deserialize)]
struct DcrResponse {
    /// Client id assigned by the registration endpoint.
    client_id: String,
}
160
161async fn register_client(
163 http_client: &impl HttpClient,
164 host: &Host,
165 redirect_uri: &str,
166) -> Res<OAuthClient> {
167 let register_url = connect_register_url(host);
168
169 let request = DcrRequest {
170 client_name: "QuiltSync".to_string(),
171 redirect_uris: vec![redirect_uri.to_string()],
172 token_endpoint_auth_method: "none".to_string(),
173 };
174
175 let response: DcrResponse = http_client.post_json(®ister_url, &request).await?;
176
177 Ok(OAuthClient {
178 client_id: response.client_id,
179 redirect_uri: redirect_uri.to_string(),
180 })
181}
182
/// Token payload returned by the registry's `/api/token` endpoint (see
/// `get_auth_tokens`). `expires_at` is (de)serialized as Unix seconds.
#[derive(Deserialize, Serialize)]
pub struct RemoteTokens {
    pub access_token: String,
    pub refresh_token: String,
    #[serde(with = "ts_seconds")]
    pub expires_at: chrono::DateTime<chrono::Utc>,
}
190
191impl fmt::Debug for RemoteTokens {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 f.debug_struct("RemoteTokens")
194 .field("expires_at", &self.expires_at)
195 .field("access_token", &"[REDACTED]")
196 .field("refresh_token", &"[REDACTED]")
197 .finish_non_exhaustive()
198 }
199}
200
201impl From<RemoteTokens> for Tokens {
202 fn from(raw: RemoteTokens) -> Self {
203 Tokens {
204 access_token: raw.access_token,
205 refresh_token: raw.refresh_token,
206 expires_at: raw.expires_at,
207 }
208 }
209}
210
/// Access-token lifetime, in seconds (one hour), assumed when a token
/// response omits `expires_in`.
const DEFAULT_EXPIRES_IN: i64 = 60 * 60;

/// Serde `default` hook for `OAuthTokenResponse::expires_in`; always returns
/// [`DEFAULT_EXPIRES_IN`].
fn default_expires_in() -> i64 {
    DEFAULT_EXPIRES_IN
}
221
/// Response body of the connect `/auth/token` endpoint, for both the
/// `authorization_code` and `refresh_token` grants.
#[derive(Deserialize, Serialize)]
struct OAuthTokenResponse {
    access_token: String,
    /// May be absent; callers decide whether that is an error (code exchange)
    /// or means "keep the old one" (refresh).
    #[serde(default)]
    refresh_token: Option<String>,
    /// Lifetime in seconds; falls back to `DEFAULT_EXPIRES_IN` when absent.
    #[serde(default = "default_expires_in")]
    expires_in: i64,
}
241
242impl fmt::Debug for OAuthTokenResponse {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 f.debug_struct("OAuthTokenResponse")
245 .field("expires_in", &self.expires_in)
246 .field("access_token", &"[REDACTED]")
247 .field(
248 "refresh_token",
249 &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
250 )
251 .finish_non_exhaustive()
252 }
253}
254
/// Credentials payload returned by the registry's `/api/auth/get_credentials`
/// endpoint. JSON keys are PascalCase (`AccessKeyId`, `Expiration`, ...), and
/// `Expiration` is parsed from an RFC 3339 string by `date_from_rfc3339`.
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "PascalCase")]
struct RemoteCredentials {
    access_key_id: String,
    #[serde(deserialize_with = "date_from_rfc3339")]
    expiration: chrono::DateTime<chrono::Utc>,
    secret_access_key: String,
    session_token: String,
}
264
265impl fmt::Debug for RemoteCredentials {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 f.debug_struct("RemoteCredentials")
268 .field("expiration", &self.expiration)
269 .field("access_key_id", &"[REDACTED]")
270 .field("secret_access_key", &"[REDACTED]")
271 .field("session_token", &"[REDACTED]")
272 .finish_non_exhaustive()
273 }
274}
275
276impl From<RemoteCredentials> for Credentials {
277 fn from(raw: RemoteCredentials) -> Self {
278 Credentials {
279 access_key: raw.access_key_id,
280 secret_key: raw.secret_access_key,
281 token: raw.session_token,
282 expires_at: raw.expiration,
283 }
284 }
285}
286
287fn date_from_rfc3339<'de, D: Deserializer<'de>>(
288 deserializer: D,
289) -> Result<chrono::DateTime<chrono::Utc>, D::Error> {
290 use serde::de::Error;
291 String::deserialize(deserializer).and_then(|s| {
292 chrono::DateTime::parse_from_rfc3339(&s)
293 .map_err(|e| Error::custom(format!("Invalid RFC3339 date: {e}")))
294 .map(|dt| dt.with_timezone(&chrono::Utc))
295 })
296}
297
/// The subset of a catalog's `/config.json` read by `get_registry_url`
/// (JSON key `registryUrl`).
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct QuiltStackConfig {
    registry_url: url::Url,
}
303
304async fn get_registry_url(http_client: &impl HttpClient, host: &Host) -> Res<url::Host> {
305 let QuiltStackConfig { registry_url } = http_client
306 .get(&format!("https://{host}/config.json"), None)
307 .await?;
308 Ok(url::Host::Domain(
309 registry_url
310 .domain()
311 .ok_or(LoginError::RequiredRegistryUrl(host.to_owned()))?
312 .to_string(),
313 ))
314}
315
316async fn get_auth_tokens(
317 http_client: &impl HttpClient,
318 host: &Host,
319 refresh_token: &str,
320) -> Res<Tokens> {
321 let registry = get_registry_url(http_client, host).await?;
322
323 let mut form_data: HashMap<String, String> = HashMap::new();
324 form_data.insert("refresh_token".to_string(), refresh_token.to_string());
325 let tokens_json: RemoteTokens = http_client
326 .post(&format!("https://{registry}/api/token"), &form_data)
327 .await?;
328 let tokens = Tokens::from(tokens_json);
329
330 Ok(tokens)
331}
332
333async fn exchange_oauth_code(
335 http_client: &impl HttpClient,
336 host: &Host,
337 params: &OAuthParams,
338) -> Res<Tokens> {
339 let token_url = connect_token_url(host);
340
341 let mut form_data: HashMap<String, String> = HashMap::new();
342 form_data.insert("grant_type".to_string(), "authorization_code".to_string());
343 form_data.insert("code".to_string(), params.code.clone());
344 form_data.insert("code_verifier".to_string(), params.code_verifier.clone());
345 form_data.insert("redirect_uri".to_string(), params.redirect_uri.clone());
346 form_data.insert("client_id".to_string(), params.client_id.clone());
347
348 let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
349 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(response.expires_in);
350 Ok(Tokens {
351 access_token: response.access_token,
352 refresh_token: response.refresh_token.ok_or_else(|| {
353 Error::Auth(
354 host.to_owned(),
355 AuthError::TokensExchange("server did not return a refresh token".to_string()),
356 )
357 })?,
358 expires_at,
359 })
360}
361
362async fn refresh_oauth_tokens(
364 http_client: &impl HttpClient,
365 host: &Host,
366 refresh_token: &str,
367 client_id: &str,
368) -> Res<Tokens> {
369 let token_url = connect_token_url(host);
370
371 let mut form_data: HashMap<String, String> = HashMap::new();
372 form_data.insert("grant_type".to_string(), "refresh_token".to_string());
373 form_data.insert("refresh_token".to_string(), refresh_token.to_string());
374 form_data.insert("client_id".to_string(), client_id.to_string());
375
376 let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
377 let expires_at = chrono::Utc::now() + chrono::Duration::seconds(response.expires_in);
378 Ok(Tokens {
379 access_token: response.access_token,
380 refresh_token: response
382 .refresh_token
383 .unwrap_or_else(|| refresh_token.to_string()),
384 expires_at,
385 })
386}
387
388async fn refresh_credentials(
389 http_client: &impl HttpClient,
390 host: &Host,
391 access_token: &str,
392) -> Res<Credentials> {
393 let registry = get_registry_url(http_client, host).await?;
394
395 let creds_json: RemoteCredentials = http_client
396 .get(
397 &format!("https://{registry}/api/auth/get_credentials"),
398 Some(access_token),
399 )
400 .await?;
401
402 let credentials = Credentials::from(creds_json);
403
404 Ok(credentials)
405}
406
407fn is_token_auth_error(e: &Error) -> bool {
413 matches!(
414 e,
415 Error::Reqwest(re) if re.status().is_some_and(|s| s == 400 || s == 401 || s == 403)
416 )
417}
418
419fn is_credentials_auth_error(e: &Error) -> bool {
425 matches!(
426 e,
427 Error::Reqwest(re) if re.status().is_some_and(|s| s == 401 || s == 403)
428 )
429}
430
431fn http_status(e: &Error) -> Option<u16> {
435 match e {
436 Error::Reqwest(re) => re.status().map(|s| s.as_u16()),
437 _ => None,
438 }
439}
440
441fn classify_retry_outcome<T>(
448 result: Res<T>,
449 is_auth_error: fn(&Error) -> bool,
450 endpoint: &str,
451 host: &Host,
452) -> Res<T> {
453 match result {
454 Ok(v) => {
455 info!(
456 "✔️ Recovered from transient auth error on {} for {}",
457 endpoint, host
458 );
459 Ok(v)
460 }
461 Err(e) if is_auth_error(&e) => {
462 warn!(
463 status = ?http_status(&e),
464 "❌ Auth error on {} for {} persisted after retry, login required: {}",
465 endpoint, host, e
466 );
467 Err(LoginError::Required(Some(host.to_owned())).into())
468 }
469 Err(e) => {
470 warn!(
471 status = ?http_status(&e),
472 "❌ Failed to refresh via {} for {} on retry: {}",
473 endpoint, host, e
474 );
475 Err(e)
476 }
477 }
478}
479
/// Per-host table of async refresh locks, shared by all clones of an `Auth`.
/// Entries are `Weak` so `Auth::refresh_lock_for` can prune locks that no
/// caller currently holds.
type RefreshLocks = Arc<StdMutex<HashMap<Host, Weak<AsyncMutex<()>>>>>;
493
/// Per-host authentication manager.
///
/// Persists OAuth clients, tokens and temporary credentials (access key,
/// secret key and session token) through `storage`. Concurrent credential
/// refreshes for the same host are serialized within this process.
#[derive(Debug)]
pub struct Auth<S: Storage = LocalStorage> {
    /// Path layout; `paths.auth_host(host)` locates a host's auth data.
    pub paths: DomainPaths,
    /// Storage backend, shared with clones.
    pub storage: Arc<S>,
    /// Per-host refresh locks, shared with clones.
    refresh_locks: RefreshLocks,
}
500
501impl<S: Storage> Clone for Auth<S> {
502 fn clone(&self) -> Self {
503 Self {
504 paths: self.paths.clone(),
505 storage: Arc::clone(&self.storage),
506 refresh_locks: Arc::clone(&self.refresh_locks),
507 }
508 }
509}
510
511impl<S: Storage + Send + Sync> Auth<S> {
512 pub fn new(paths: DomainPaths, storage: Arc<S>) -> Self {
513 Self {
514 paths,
515 storage,
516 refresh_locks: Arc::new(StdMutex::new(HashMap::new())),
517 }
518 }
519
520 fn refresh_lock_for(&self, host: &Host) -> Arc<AsyncMutex<()>> {
526 let mut locks = self
527 .refresh_locks
528 .lock()
529 .unwrap_or_else(std::sync::PoisonError::into_inner);
530 locks.retain(|_, weak| weak.strong_count() > 0);
531 if let Some(arc) = locks.get(host).and_then(Weak::upgrade) {
532 return arc;
533 }
534 let arc = Arc::new(AsyncMutex::new(()));
535 locks.insert(host.clone(), Arc::downgrade(&arc));
536 arc
537 }
538
539 pub async fn login<T: HttpClient>(
540 &self,
541 http_client: &T,
542 host: &Host,
543 refresh_token: String,
544 ) -> Res {
545 info!("⏳ Logging in to host {} with refresh token", host);
546
547 let tokens = match self
548 .get_auth_tokens(http_client, host, &refresh_token)
549 .await
550 {
551 Ok(t) => t,
552 Err(e) => {
553 warn!("❌ Failed to get auth tokens for {}: {}", host, e);
554 return Err(e);
555 }
556 };
557
558 if let Err(e) = self.save_tokens(host, &tokens).await {
559 warn!("❌ Failed to save tokens for {}: {}", host, e);
560 return Err(e);
561 }
562
563 if let Err(e) = self
564 .refresh_credentials(http_client, host, &tokens.access_token)
565 .await
566 {
567 warn!("❌ Failed to refresh credentials for {}: {}", host, e);
568 return Err(e);
569 }
570
571 info!("✔️ Successfully logged in and authenticated to {}", host);
572 Ok(())
573 }
574
575 pub async fn get_or_register_client<T: HttpClient>(
577 &self,
578 http_client: &T,
579 host: &Host,
580 redirect_uri: &str,
581 ) -> Res<OAuthClient> {
582 let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
583
584 if let Some(client) = auth_io.read_client().await? {
585 if client.redirect_uri == redirect_uri {
586 info!("✔️ Found existing OAuth client for {}", host);
587 return Ok(client);
588 }
589 info!(
590 "⚠️ Cached client has stale redirect_uri, re-registering for {}",
591 host
592 );
593 }
594
595 info!("⏳ Registering new OAuth client for {}", host);
596 let client = register_client(http_client, host, redirect_uri).await?;
597 auth_io.write_client(&client).await?;
598 info!(
599 "✔️ Registered OAuth client for {}: {}",
600 host, client.client_id
601 );
602
603 Ok(client)
604 }
605
606 pub async fn login_oauth<T: HttpClient>(
617 &self,
618 http_client: &T,
619 host: &Host,
620 params: OAuthParams,
621 ) -> Res {
622 info!("⏳ OAuth login for host {}", host);
623
624 let tokens = exchange_oauth_code(http_client, host, ¶ms)
625 .await
626 .map_err(|e| {
627 warn!("❌ Failed to exchange OAuth code for {}: {}", host, e);
628 e
629 })?;
630
631 self.save_tokens(host, &tokens).await.map_err(|e| {
632 warn!("❌ Failed to save tokens for {}: {}", host, e);
633 e
634 })?;
635
636 self.refresh_credentials(http_client, host, &tokens.access_token)
637 .await
638 .map_err(|e| {
639 warn!("❌ Failed to refresh credentials for {}: {}", host, e);
640 e
641 })?;
642
643 info!("✔️ OAuth login successful for {}", host);
644 Ok(())
645 }
646
647 async fn get_auth_tokens<T: HttpClient>(
648 &self,
649 http_client: &T,
650 host: &Host,
651 refresh_token: &str,
652 ) -> Res<Tokens> {
653 debug!("⏳ Getting auth tokens for host {:?}", host);
654 let tokens = get_auth_tokens(http_client, host, refresh_token).await?;
655 debug!("✔️ Successfully retrieved auth tokens");
656 Ok(tokens)
657 }
658
659 async fn save_tokens(&self, host: &Host, tokens: &Tokens) -> Res<()> {
660 debug!("⏳ Saving tokens for host {:?}", host);
661 let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
662 auth_io.write_tokens(tokens).await?;
663 debug!(
664 "✔️ Successfully saved tokens to the {:?}",
665 self.paths.auth_host(host)
666 );
667 Ok(())
668 }
669
670 async fn refresh_tokens<T: HttpClient>(
673 &self,
674 http_client: &T,
675 auth_io: &AuthIo<Arc<S>>,
676 host: &Host,
677 tokens: &Tokens,
678 ) -> Res<Tokens> {
679 let client = auth_io
680 .read_client()
681 .await?
682 .ok_or(LoginError::Required(Some(host.to_owned())))?;
683
684 let new_tokens =
685 refresh_oauth_tokens(http_client, host, &tokens.refresh_token, &client.client_id)
686 .await?;
687
688 auth_io.write_tokens(&new_tokens).await?;
689 info!("✔️ Successfully refreshed tokens for {}", host);
690
691 Ok(new_tokens)
692 }
693
694 async fn refresh_tokens_with_retry<T: HttpClient>(
703 &self,
704 http_client: &T,
705 auth_io: &AuthIo<Arc<S>>,
706 host: &Host,
707 tokens: &Tokens,
708 ) -> Res<Tokens> {
709 let first_err = match self
710 .refresh_tokens(http_client, auth_io, host, tokens)
711 .await
712 {
713 Ok(t) => return Ok(t),
714 Err(e) => e,
715 };
716
717 if matches!(first_err, Error::Login(LoginError::Required(_))) {
718 warn!("❌ No OAuth client registered for {}, login required", host);
719 return Err(first_err);
720 }
721 if !is_token_auth_error(&first_err) {
722 warn!(
723 status = ?http_status(&first_err),
724 "❌ Failed to refresh tokens for {}: {}", host, first_err
725 );
726 return Err(first_err);
727 }
728
729 info!(
730 status = ?http_status(&first_err),
731 "⚠️ Auth error refreshing tokens for {}, retrying once: {}", host, first_err
732 );
733 classify_retry_outcome(
734 self.refresh_tokens(http_client, auth_io, host, tokens)
735 .await,
736 is_token_auth_error,
737 "token endpoint",
738 host,
739 )
740 }
741
742 async fn refresh_credentials_with_retry<T: HttpClient>(
751 &self,
752 http_client: &T,
753 auth_io: &AuthIo<Arc<S>>,
754 host: &Host,
755 access_token: &str,
756 ) -> Res<Credentials> {
757 let first_err = match self
758 .refresh_credentials(http_client, host, access_token)
759 .await
760 {
761 Ok(c) => return Ok(c),
762 Err(e) => e,
763 };
764
765 if !is_credentials_auth_error(&first_err) {
766 warn!(
767 status = ?http_status(&first_err),
768 "❌ Failed to refresh credentials for {}: {}", host, first_err
769 );
770 return Err(first_err);
771 }
772
773 info!(
774 status = ?http_status(&first_err),
775 "⚠️ Auth error refreshing credentials for {}, \
776 force-refreshing token and retrying: {}",
777 host, first_err
778 );
779
780 let tokens = auth_io
782 .read_tokens()
783 .await?
784 .ok_or_else(|| LoginError::Required(Some(host.to_owned())))?;
785 let new_tokens = self
786 .refresh_tokens_with_retry(http_client, auth_io, host, &tokens)
787 .await?;
788
789 classify_retry_outcome(
790 self.refresh_credentials(http_client, host, &new_tokens.access_token)
791 .await,
792 is_credentials_auth_error,
793 "credentials endpoint",
794 host,
795 )
796 }
797
798 async fn refresh_credentials<T: HttpClient>(
799 &self,
800 http_client: &T,
801 host: &Host,
802 access_token: &str,
803 ) -> Res<Credentials> {
804 debug!("⏳ Refreshing credentials for host {:?}", host);
805 let credentials = refresh_credentials(http_client, host, access_token).await?;
806
807 let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
808 auth_io.write_credentials(&credentials).await?;
809
810 debug!(
811 "✔️ Successfully refreshed credentials in {:?}",
812 self.paths.auth_host(host)
813 );
814 Ok(credentials)
815 }
816
817 pub async fn get_credentials_or_refresh<T: HttpClient>(
818 &self,
819 http_client: &T,
820 host: &Host,
821 ) -> Res<Credentials> {
822 info!("⏳ Getting or refreshing credentials for {}", host);
823 let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
824
825 match auth_io.read_credentials().await {
826 Ok(Some(creds)) => {
827 debug!("✔️ Found valid credentials for {}", host);
828 return Ok(creds);
829 }
830 Ok(None) => {
831 info!("❌ No existing credentials found for {}", host);
832 }
833 Err(e) => {
834 error!("❌ Failed to read credentials for {}: {}", host, e);
835 return Err(Error::Auth(
836 host.to_owned(),
837 AuthError::CredentialsRead(e.to_string()),
838 ));
839 }
840 }
841
842 let lock = self.refresh_lock_for(host);
847 let _guard = lock.lock().await;
848
849 match auth_io.read_credentials().await {
850 Ok(Some(creds)) => {
851 debug!("✔️ Another task refreshed credentials for {}", host);
852 return Ok(creds);
853 }
854 Ok(None) => {}
855 Err(e) => {
856 error!("❌ Failed to re-read credentials for {}: {}", host, e);
857 return Err(Error::Auth(
858 host.to_owned(),
859 AuthError::CredentialsRead(e.to_string()),
860 ));
861 }
862 }
863
864 let tokens = match auth_io.read_tokens().await {
865 Ok(Some(tokens)) => tokens,
866 Ok(None) => {
867 warn!("❌ No tokens found for {}, login required", host);
868 return Err(LoginError::Required(Some(host.to_owned())).into());
869 }
870 Err(e) => {
871 error!("❌ Failed to read tokens for {}: {}", host, e);
872 return Err(Error::Auth(
873 host.to_owned(),
874 AuthError::TokensRead(e.to_string()),
875 ));
876 }
877 };
878
879 let access_token =
881 if tokens.expires_at <= chrono::Utc::now() + chrono::Duration::seconds(60) {
882 info!(
883 "⏳ Access token expired for {}, refreshing via refresh token",
884 host
885 );
886 self.refresh_tokens_with_retry(http_client, &auth_io, host, &tokens)
887 .await?
888 .access_token
889 } else {
890 tokens.access_token
891 };
892
893 info!("⏳ Refreshing credentials using access token for {}", host);
894 let creds = self
895 .refresh_credentials_with_retry(http_client, &auth_io, host, &access_token)
896 .await?;
897 info!("✔️ Successfully refreshed credentials for {}", host);
898 Ok(creds)
899 }
900}
901
902#[cfg(test)]
903mod tests {
904 use super::*;
905
906 use async_trait::async_trait;
907 use reqwest::header::HeaderMap;
908 use test_log::test;
909
910 use crate::io::storage::mocks::MockStorage;
911 use crate::paths::DomainPaths;
912
913 const ACCESS_TOKEN: &str = "test-access-token";
914 const REFRESH_TOKEN: &str = "test-refresh-token";
915 const TIMESTAMP: i64 = 1_708_444_800;
916
917 fn get_host() -> Host {
918 "test.quilt.dev".parse().unwrap()
919 }
920
921 fn get_registry() -> String {
922 "registry-test.quilt.dev".to_string()
923 }
924
925 struct TestHttpClient;
926
927 #[async_trait]
928 impl HttpClient for TestHttpClient {
929 async fn get<T: serde::de::DeserializeOwned>(
930 &self,
931 url: &str,
932 auth_token: Option<&str>,
933 ) -> Res<T> {
934 let registry = get_registry();
935
936 match url {
937 u if u == format!("https://{}/config.json", get_host()) => {
938 let config = QuiltStackConfig {
939 registry_url: format!("https://{registry}").parse()?,
940 };
941 Ok(serde_json::from_value(serde_json::to_value(config)?)?)
942 }
943 u if u == format!("https://{registry}/api/auth/get_credentials") => {
944 assert_eq!(auth_token, Some(ACCESS_TOKEN));
945 let creds = RemoteCredentials {
946 access_key_id: "test-access-key".to_string(),
947 secret_access_key: "test-secret-key".to_string(),
948 session_token: "test-session-token".to_string(),
949 expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
950 };
951 Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
952 }
953 _ => panic!("Unexpected URL: {url}"),
954 }
955 }
956
957 async fn head(&self, _url: &str) -> Res<HeaderMap> {
958 unimplemented!("head is not used in this test")
959 }
960
961 async fn post<T: serde::de::DeserializeOwned>(
962 &self,
963 url: &str,
964 form_data: &HashMap<String, String>,
965 ) -> Res<T> {
966 assert_eq!(url, format!("https://{}/api/token", get_registry()));
967
968 assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
970
971 let tokens = RemoteTokens {
972 access_token: ACCESS_TOKEN.to_string(),
973 refresh_token: "new-refresh-token".to_string(),
974 expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
975 };
976 Ok(serde_json::from_value(serde_json::to_value(tokens)?)?)
977 }
978
979 async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
980 &self,
981 _url: &str,
982 _body: &B,
983 ) -> Res<T> {
984 unimplemented!("post_json is not used in this test")
985 }
986 }
987
988 #[test(tokio::test)]
989 async fn test_get_registry_url() {
990 let client = TestHttpClient;
991 let result = get_registry_url(&client, &get_host()).await.unwrap();
992 assert_eq!(
993 result,
994 url::Host::Domain("registry-test.quilt.dev".to_string())
995 );
996 }
997
998 #[test(tokio::test)]
999 async fn test_get_auth_tokens() {
1000 let client = TestHttpClient;
1001 let tokens = get_auth_tokens(&client, &get_host(), REFRESH_TOKEN)
1002 .await
1003 .unwrap();
1004 assert_eq!(tokens.access_token, ACCESS_TOKEN);
1005 assert_eq!(tokens.refresh_token, "new-refresh-token");
1006 assert_eq!(
1007 tokens.expires_at,
1008 chrono::DateTime::from_timestamp(1_708_444_800, 0).unwrap()
1009 );
1010 }
1011
1012 #[test(tokio::test)]
1013 async fn test_refresh_credentials() {
1014 let client = TestHttpClient;
1015 let credentials = refresh_credentials(&client, &get_host(), ACCESS_TOKEN)
1016 .await
1017 .unwrap();
1018 assert_eq!(credentials.access_key, "test-access-key");
1019 assert_eq!(credentials.secret_key, "test-secret-key");
1020 assert_eq!(credentials.token, "test-session-token");
1021 assert_eq!(
1022 credentials.expires_at,
1023 chrono::DateTime::from_timestamp(1_708_444_800, 0).unwrap()
1024 );
1025 }
1026
1027 #[test(tokio::test)]
1028 async fn test_auth_refresh_credentials() -> Res {
1029 let storage = Arc::new(MockStorage::default());
1030 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1031 let auth = Auth::new(paths.clone(), storage.clone());
1032 let host = get_host();
1033
1034 let credentials = auth
1035 .refresh_credentials(&TestHttpClient, &host, ACCESS_TOKEN)
1036 .await?;
1037
1038 assert_eq!(credentials.access_key, "test-access-key");
1040 assert_eq!(credentials.secret_key, "test-secret-key");
1041 assert_eq!(credentials.token, "test-session-token");
1042 assert_eq!(
1043 credentials.expires_at,
1044 chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap()
1045 );
1046
1047 use crate::io::storage::StorageExt;
1050 let creds_path = paths.auth_host(&host).join(crate::paths::AUTH_CREDENTIALS);
1051 let bytes = storage.read_bytes(&creds_path).await?;
1052 let read_creds: Credentials = serde_json::from_slice(&bytes)?;
1053 assert_eq!(read_creds.access_key, credentials.access_key);
1054 assert_eq!(read_creds.secret_key, credentials.secret_key);
1055 assert_eq!(read_creds.token, credentials.token);
1056 assert_eq!(read_creds.expires_at, credentials.expires_at);
1057
1058 Ok(())
1059 }
1060
1061 #[test]
1062 fn test_remote_credentials_deserialization() {
1063 let valid_json = r#"{
1065 "AccessKeyId": "test-key",
1066 "Expiration": "2024-02-20T15:00:00Z",
1067 "SecretAccessKey": "test-secret",
1068 "SessionToken": "test-token"
1069 }"#;
1070
1071 let creds: RemoteCredentials = serde_json::from_str(valid_json).unwrap();
1072 assert_eq!(creds.access_key_id, "test-key");
1073 assert_eq!(creds.secret_access_key, "test-secret");
1074 assert_eq!(creds.session_token, "test-token");
1075 assert_eq!(
1076 creds.expiration,
1077 chrono::DateTime::parse_from_rfc3339("2024-02-20T15:00:00Z")
1078 .unwrap()
1079 .with_timezone(&chrono::Utc)
1080 );
1081
1082 let invalid_json = r#"{
1084 "AccessKeyId": "test-key",
1085 "Expiration": "2024-02-20 15:00:00",
1086 "SecretAccessKey": "test-secret",
1087 "SessionToken": "test-token"
1088 }"#;
1089
1090 let error = serde_json::from_str::<RemoteCredentials>(invalid_json).unwrap_err();
1091 assert!(error.to_string().contains("Invalid RFC3339 date"));
1092 }
1093
1094 const AUTH_CODE: &str = "test-auth-code";
1095 const CODE_VERIFIER: &str = "test-code-verifier-that-is-at-least-43-characters-long";
1096 const CLIENT_ID: &str = "test-client-id";
1097 const REDIRECT_URI: &str = "quilt://auth/callback?host=test.quilt.dev";
1098
1099 struct OAuthTestHttpClient {
1100 expected_credentials_token: &'static str,
1102 }
1103
1104 impl Default for OAuthTestHttpClient {
1105 fn default() -> Self {
1106 Self {
1107 expected_credentials_token: ACCESS_TOKEN,
1108 }
1109 }
1110 }
1111
1112 #[async_trait]
1113 impl HttpClient for OAuthTestHttpClient {
1114 async fn get<T: serde::de::DeserializeOwned>(
1115 &self,
1116 url: &str,
1117 auth_token: Option<&str>,
1118 ) -> Res<T> {
1119 let registry = get_registry();
1120
1121 match url {
1122 u if u == format!("https://{}/config.json", get_host()) => {
1123 let config = QuiltStackConfig {
1124 registry_url: format!("https://{registry}").parse()?,
1125 };
1126 Ok(serde_json::from_value(serde_json::to_value(config)?)?)
1127 }
1128 u if u == format!("https://{registry}/api/auth/get_credentials") => {
1129 assert_eq!(auth_token, Some(self.expected_credentials_token));
1130 let creds = RemoteCredentials {
1131 access_key_id: "oauth-access-key".to_string(),
1132 secret_access_key: "oauth-secret-key".to_string(),
1133 session_token: "oauth-session-token".to_string(),
1134 expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1135 };
1136 Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
1137 }
1138 _ => panic!("Unexpected GET URL: {url}"),
1139 }
1140 }
1141
1142 async fn head(&self, _url: &str) -> Res<HeaderMap> {
1143 unimplemented!()
1144 }
1145
1146 async fn post<T: serde::de::DeserializeOwned>(
1147 &self,
1148 url: &str,
1149 form_data: &HashMap<String, String>,
1150 ) -> Res<T> {
1151 assert_eq!(url, connect_token_url(&get_host()));
1152
1153 let tokens = match form_data.get("grant_type").map(String::as_str) {
1154 Some("authorization_code") => {
1155 assert_eq!(form_data.get("code").unwrap(), AUTH_CODE);
1156 assert_eq!(form_data.get("code_verifier").unwrap(), CODE_VERIFIER);
1157 assert_eq!(form_data.get("redirect_uri").unwrap(), REDIRECT_URI);
1158 assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
1159 OAuthTokenResponse {
1160 access_token: ACCESS_TOKEN.to_string(),
1161 refresh_token: Some("oauth-refresh-token".to_string()),
1162 expires_in: 3600,
1163 }
1164 }
1165 Some("refresh_token") => {
1166 assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
1167 assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
1168 OAuthTokenResponse {
1169 access_token: "refreshed-access-token".to_string(),
1170 refresh_token: Some("new-refresh-token".to_string()),
1171 expires_in: 3600,
1172 }
1173 }
1174 other => panic!("Unexpected grant_type: {other:?}"),
1175 };
1176 Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
1177 }
1178
1179 async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1180 &self,
1181 url: &str,
1182 body: &B,
1183 ) -> Res<T> {
1184 assert_eq!(url, connect_register_url(&get_host()));
1185 let json = serde_json::to_value(body)?;
1186 assert_eq!(json["client_name"], "QuiltSync");
1187 assert_eq!(json["token_endpoint_auth_method"], "none");
1188 let redirect_uris = json["redirect_uris"].as_array().expect("redirect_uris");
1189 assert_eq!(redirect_uris.len(), 1);
1190 assert!(
1191 redirect_uris[0]
1192 .as_str()
1193 .unwrap()
1194 .starts_with("quilt://auth/callback?host=")
1195 );
1196 Ok(serde_json::from_value(serde_json::json!({
1197 "client_id": "test-dcr-client-id"
1198 }))?)
1199 }
1200 }
1201
1202 #[test]
1203 fn test_connect_host() {
1204 let host: Host = "test.quilt.dev".parse().unwrap();
1205 assert_eq!(connect_host(&host), "test-connect.quilt.dev");
1206 }
1207
1208 #[test]
1209 fn test_connect_token_url() {
1210 let host: Host = "test.quilt.dev".parse().unwrap();
1211 assert_eq!(
1212 connect_token_url(&host),
1213 "https://test-connect.quilt.dev/auth/token"
1214 );
1215 }
1216
1217 #[test(tokio::test)]
1218 async fn test_exchange_oauth_code() {
1219 let client = OAuthTestHttpClient::default();
1220 let params = OAuthParams {
1221 code: AUTH_CODE.to_string(),
1222 code_verifier: CODE_VERIFIER.to_string(),
1223 redirect_uri: REDIRECT_URI.to_string(),
1224 client_id: CLIENT_ID.to_string(),
1225 };
1226 let tokens = exchange_oauth_code(&client, &get_host(), ¶ms)
1227 .await
1228 .unwrap();
1229 assert_eq!(tokens.access_token, ACCESS_TOKEN);
1230 assert_eq!(tokens.refresh_token, "oauth-refresh-token");
1231 }
1232
1233 #[test]
1234 fn test_pkce_challenge() {
1235 let pkce = pkce_challenge();
1236
1237 assert_eq!(pkce.code_verifier.len(), 86);
1239
1240 assert_eq!(pkce.code_challenge.len(), 43);
1242
1243 let expected_challenge =
1245 URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.code_verifier.as_bytes()));
1246 assert_eq!(pkce.code_challenge, expected_challenge);
1247
1248 let pkce2 = pkce_challenge();
1250 assert_ne!(pkce.code_verifier, pkce2.code_verifier);
1251 }
1252
1253 #[test]
1255 fn test_pkce_verifier_charset_rfc7636() {
1256 let pkce = pkce_challenge();
1257 for ch in pkce.code_verifier.chars() {
1258 assert!(
1259 ch.is_ascii_alphanumeric() || matches!(ch, '-' | '.' | '_' | '~'),
1260 "code_verifier contains char '{ch}' not allowed by RFC 7636 §4.1"
1261 );
1262 }
1263 }
1264
1265 #[test(tokio::test)]
1266 async fn test_login_oauth() -> Res {
1267 let storage = Arc::new(MockStorage::default());
1268 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1269 let auth = Auth::new(paths, storage);
1270 let host = get_host();
1271
1272 let params = OAuthParams {
1273 code: AUTH_CODE.to_string(),
1274 code_verifier: CODE_VERIFIER.to_string(),
1275 redirect_uri: REDIRECT_URI.to_string(),
1276 client_id: CLIENT_ID.to_string(),
1277 };
1278
1279 auth.login_oauth(&OAuthTestHttpClient::default(), &host, params)
1280 .await?;
1281 Ok(())
1282 }
1283
1284 #[test(tokio::test)]
1285 async fn test_refresh_oauth_tokens() -> Res {
1286 let tokens = refresh_oauth_tokens(
1287 &OAuthTestHttpClient::default(),
1288 &get_host(),
1289 REFRESH_TOKEN,
1290 CLIENT_ID,
1291 )
1292 .await?;
1293 assert_eq!(tokens.access_token, "refreshed-access-token");
1294 assert_eq!(tokens.refresh_token, "new-refresh-token");
1295 Ok(())
1296 }
1297
1298 #[test(tokio::test)]
1301 async fn test_refresh_oauth_tokens_retains_old_when_omitted() -> Res {
1302 struct NoRefreshTokenClient;
1303
1304 #[async_trait]
1305 impl HttpClient for NoRefreshTokenClient {
1306 async fn get<T: serde::de::DeserializeOwned>(
1307 &self,
1308 _: &str,
1309 _: Option<&str>,
1310 ) -> Res<T> {
1311 unimplemented!()
1312 }
1313 async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
1314 unimplemented!()
1315 }
1316 async fn post<T: serde::de::DeserializeOwned>(
1317 &self,
1318 _: &str,
1319 _: &HashMap<String, String>,
1320 ) -> Res<T> {
1321 let resp = OAuthTokenResponse {
1322 access_token: "new-access-token".to_string(),
1323 refresh_token: None, expires_in: DEFAULT_EXPIRES_IN,
1325 };
1326 Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
1327 }
1328 async fn post_json<
1329 T: serde::de::DeserializeOwned,
1330 B: serde::Serialize + Send + Sync,
1331 >(
1332 &self,
1333 _: &str,
1334 _: &B,
1335 ) -> Res<T> {
1336 unimplemented!()
1337 }
1338 }
1339
1340 let tokens =
1341 refresh_oauth_tokens(&NoRefreshTokenClient, &get_host(), REFRESH_TOKEN, CLIENT_ID)
1342 .await?;
1343 assert_eq!(tokens.access_token, "new-access-token");
1344 assert_eq!(tokens.refresh_token, REFRESH_TOKEN);
1346 Ok(())
1347 }
1348
1349 #[test(tokio::test)]
1352 async fn test_exchange_oauth_code_errors_when_refresh_token_missing() {
1353 struct NoRefreshTokenClient;
1354
1355 #[async_trait]
1356 impl HttpClient for NoRefreshTokenClient {
1357 async fn get<T: serde::de::DeserializeOwned>(
1358 &self,
1359 _: &str,
1360 _: Option<&str>,
1361 ) -> Res<T> {
1362 unimplemented!()
1363 }
1364 async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
1365 unimplemented!()
1366 }
1367 async fn post<T: serde::de::DeserializeOwned>(
1368 &self,
1369 _: &str,
1370 _: &HashMap<String, String>,
1371 ) -> Res<T> {
1372 let resp = OAuthTokenResponse {
1373 access_token: ACCESS_TOKEN.to_string(),
1374 refresh_token: None,
1375 expires_in: DEFAULT_EXPIRES_IN,
1376 };
1377 Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
1378 }
1379 async fn post_json<
1380 T: serde::de::DeserializeOwned,
1381 B: serde::Serialize + Send + Sync,
1382 >(
1383 &self,
1384 _: &str,
1385 _: &B,
1386 ) -> Res<T> {
1387 unimplemented!()
1388 }
1389 }
1390
1391 let params = OAuthParams {
1392 code: AUTH_CODE.to_string(),
1393 code_verifier: CODE_VERIFIER.to_string(),
1394 redirect_uri: REDIRECT_URI.to_string(),
1395 client_id: CLIENT_ID.to_string(),
1396 };
1397 let result = exchange_oauth_code(&NoRefreshTokenClient, &get_host(), ¶ms).await;
1398 assert!(
1399 matches!(result, Err(Error::Auth(_, AuthError::TokensExchange(_)))),
1400 "expected TokensExchange error, got: {result:?}"
1401 );
1402 }
1403
#[test]
fn test_oauth_token_response_missing_expires_in() {
    // A token response without `expires_in` still deserializes and takes the default lifetime.
    let parsed = serde_json::from_str::<OAuthTokenResponse>(
        r#"{"access_token":"tok","refresh_token":"ref"}"#,
    )
    .unwrap();
    assert_eq!(parsed.expires_in, DEFAULT_EXPIRES_IN);
}
1412
/// Access token that the test HTTP mocks hand back from a successful
/// `refresh_token` grant; tests compare persisted/returned tokens against it.
const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token";
1414
1415 #[test(tokio::test)]
1416 async fn test_get_credentials_or_refresh_with_expired_token() -> Res {
1417 let storage = Arc::new(MockStorage::default());
1418 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1419 let auth = Auth::new(paths.clone(), storage.clone());
1420 let host = get_host();
1421
1422 let auth_io = AuthIo::new(storage, paths.auth_host(&host));
1424 auth_io
1425 .write_tokens(&Tokens {
1426 access_token: "expired-access-token".to_string(),
1427 refresh_token: REFRESH_TOKEN.to_string(),
1428 expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
1429 })
1430 .await?;
1431 auth_io
1432 .write_client(&OAuthClient {
1433 client_id: CLIENT_ID.to_string(),
1434 redirect_uri: REDIRECT_URI.to_string(),
1435 })
1436 .await?;
1437
1438 let client = OAuthTestHttpClient {
1439 expected_credentials_token: REFRESHED_ACCESS_TOKEN,
1440 };
1441 let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1442
1443 assert_eq!(creds.access_key, "oauth-access-key");
1445
1446 let persisted = auth_io
1448 .read_tokens()
1449 .await?
1450 .expect("tokens should be persisted");
1451 assert_eq!(persisted.access_token, REFRESHED_ACCESS_TOKEN);
1452 assert_eq!(persisted.refresh_token, "new-refresh-token");
1453
1454 Ok(())
1455 }
1456
1457 #[test(tokio::test)]
1458 async fn test_get_or_register_client() -> Res {
1459 let storage = Arc::new(MockStorage::default());
1460 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1461 let auth = Auth::new(paths, storage);
1462 let host = get_host();
1463
1464 let client = auth
1466 .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
1467 .await?;
1468 assert_eq!(client.client_id, "test-dcr-client-id");
1469 assert_eq!(client.redirect_uri, REDIRECT_URI);
1470
1471 let client2 = auth
1473 .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
1474 .await?;
1475 assert_eq!(client2.client_id, "test-dcr-client-id");
1476
1477 let new_redirect = "quilt://auth/callback?host=other.quilt.dev";
1479 let client3 = auth
1480 .get_or_register_client(&OAuthTestHttpClient::default(), &host, new_redirect)
1481 .await?;
1482 assert_eq!(client3.client_id, "test-dcr-client-id");
1483 assert_eq!(client3.redirect_uri, new_redirect);
1484
1485 Ok(())
1486 }
1487
#[test]
fn remote_tokens_debug_redacts_secrets() {
    let rendered = format!(
        "{:?}",
        RemoteTokens {
            access_token: "secret-access".to_string(),
            refresh_token: "secret-refresh".to_string(),
            expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
        }
    );
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-access", "secret-refresh"] {
        assert!(!rendered.contains(secret));
    }
}
1500
#[test]
fn oauth_token_response_debug_redacts_secrets() {
    let rendered = format!(
        "{:?}",
        OAuthTokenResponse {
            access_token: "secret-access".to_string(),
            refresh_token: Some("secret-refresh".to_string()),
            expires_in: 3600,
        }
    );
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-access", "secret-refresh"] {
        assert!(!rendered.contains(secret));
    }
}
1513
#[test]
fn oauth_token_response_debug_none_refresh_token() {
    // An absent refresh token is shown as `None`; the access token stays hidden.
    let rendered = format!(
        "{:?}",
        OAuthTokenResponse {
            access_token: "secret-access".to_string(),
            refresh_token: None,
            expires_in: 3600,
        }
    );
    assert!(rendered.contains("refresh_token: None"));
    assert!(!rendered.contains("secret-access"));
}
1525
#[test]
fn remote_credentials_debug_redacts_secrets() {
    let rendered = format!(
        "{:?}",
        RemoteCredentials {
            access_key_id: "secret-key-id".to_string(),
            expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
            secret_access_key: "secret-access-key".to_string(),
            session_token: "secret-session-token".to_string(),
        }
    );
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-key-id", "secret-access-key", "secret-session-token"] {
        assert!(!rendered.contains(secret));
    }
}
1540
1541 use std::sync::atomic::AtomicUsize;
1544 use std::sync::atomic::Ordering;
1545 use tokio::io::AsyncReadExt;
1546 use tokio::io::AsyncWriteExt;
1547
1548 async fn spawn_one_shot(response: Vec<u8>) -> std::net::SocketAddr {
1551 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1552 let addr = listener.local_addr().unwrap();
1553 tokio::spawn(async move {
1554 if let Ok((mut stream, _)) = listener.accept().await {
1555 let mut buf = [0u8; 4096];
1556 let _ = stream.read(&mut buf).await;
1557 let _ = stream.write_all(&response).await;
1558 let _ = stream.shutdown().await;
1559 }
1560 });
1561 addr
1562 }
1563
1564 async fn reqwest_error_with_status(status: u16) -> Error {
1568 let body = format!("HTTP/1.1 {status} X\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
1569 .into_bytes();
1570 let addr = spawn_one_shot(body).await;
1571 reqwest::Client::new()
1572 .get(format!("http://{addr}/"))
1573 .send()
1574 .await
1575 .unwrap()
1576 .error_for_status()
1577 .unwrap_err()
1578 .into()
1579 }
1580
/// HTTP mock that fails its first N credential and/or token requests, used
/// to exercise the bounded-retry paths.
struct RetryMockClient {
    /// How many leading credentials calls should fail.
    cred_fail_first_n: usize,
    /// Number of credentials calls observed so far.
    cred_calls: AtomicUsize,
    /// How many leading token-endpoint calls should fail.
    token_fail_first_n: usize,
    /// Number of token-endpoint calls observed so far.
    token_calls: AtomicUsize,
}
1589
1590 impl RetryMockClient {
1591 fn new(cred_fail: usize, token_fail: usize) -> Self {
1592 Self {
1593 cred_fail_first_n: cred_fail,
1594 token_fail_first_n: token_fail,
1595 cred_calls: AtomicUsize::new(0),
1596 token_calls: AtomicUsize::new(0),
1597 }
1598 }
1599 }
1600
1601 #[async_trait]
1602 impl HttpClient for RetryMockClient {
1603 async fn get<T: serde::de::DeserializeOwned>(
1604 &self,
1605 url: &str,
1606 _auth_token: Option<&str>,
1607 ) -> Res<T> {
1608 let registry = get_registry();
1609 if url == format!("https://{}/config.json", get_host()) {
1610 let config = QuiltStackConfig {
1611 registry_url: format!("https://{registry}").parse()?,
1612 };
1613 return Ok(serde_json::from_value(serde_json::to_value(config)?)?);
1614 }
1615 if url == format!("https://{registry}/api/auth/get_credentials") {
1616 let n = self.cred_calls.fetch_add(1, Ordering::SeqCst);
1617 if n < self.cred_fail_first_n {
1618 return Err(reqwest_error_with_status(401).await);
1619 }
1620 let creds = RemoteCredentials {
1621 access_key_id: "oauth-access-key".to_string(),
1622 secret_access_key: "oauth-secret-key".to_string(),
1623 session_token: "oauth-session-token".to_string(),
1624 expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
1625 };
1626 return Ok(serde_json::from_value(serde_json::to_value(creds)?)?);
1627 }
1628 panic!("Unexpected GET URL: {url}")
1629 }
1630
1631 async fn head(&self, _url: &str) -> Res<HeaderMap> {
1632 unimplemented!()
1633 }
1634
1635 async fn post<T: serde::de::DeserializeOwned>(
1636 &self,
1637 url: &str,
1638 form_data: &HashMap<String, String>,
1639 ) -> Res<T> {
1640 assert_eq!(url, connect_token_url(&get_host()));
1641 let n = self.token_calls.fetch_add(1, Ordering::SeqCst);
1642 if n < self.token_fail_first_n {
1643 return Err(reqwest_error_with_status(401).await);
1644 }
1645 assert_eq!(
1646 form_data.get("grant_type").map(String::as_str),
1647 Some("refresh_token")
1648 );
1649 let tokens = OAuthTokenResponse {
1650 access_token: REFRESHED_ACCESS_TOKEN.to_string(),
1651 refresh_token: Some("new-refresh-token".to_string()),
1652 expires_in: 3600,
1653 };
1654 Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
1655 }
1656
1657 async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1658 &self,
1659 _url: &str,
1660 _body: &B,
1661 ) -> Res<T> {
1662 unimplemented!()
1663 }
1664 }
1665
1666 async fn seed_fresh_tokens(storage: &Arc<MockStorage>, paths: &DomainPaths, host: &Host) {
1667 let auth_io = AuthIo::new(storage.clone(), paths.auth_host(host));
1668 auth_io
1669 .write_tokens(&Tokens {
1670 access_token: ACCESS_TOKEN.to_string(),
1671 refresh_token: REFRESH_TOKEN.to_string(),
1672 expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
1674 })
1675 .await
1676 .unwrap();
1677 auth_io
1678 .write_client(&OAuthClient {
1679 client_id: CLIENT_ID.to_string(),
1680 redirect_uri: REDIRECT_URI.to_string(),
1681 })
1682 .await
1683 .unwrap();
1684 }
1685
1686 #[test(tokio::test)]
1690 async fn test_credentials_transient_401_recovers_via_force_token_refresh() -> Res {
1691 let storage = Arc::new(MockStorage::default());
1692 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1693 let auth = Auth::new(paths.clone(), storage.clone());
1694 let host = get_host();
1695 seed_fresh_tokens(&storage, &paths, &host).await;
1696
1697 let client = RetryMockClient::new(1, 0);
1698 let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1699
1700 assert_eq!(creds.access_key, "oauth-access-key");
1701 assert_eq!(
1702 client.cred_calls.load(Ordering::SeqCst),
1703 2,
1704 "credentials endpoint should be called twice: initial + retry"
1705 );
1706 assert_eq!(
1707 client.token_calls.load(Ordering::SeqCst),
1708 1,
1709 "token endpoint should be called once to force-refresh"
1710 );
1711 Ok(())
1712 }
1713
1714 #[test(tokio::test)]
1717 async fn test_credentials_persistent_401_maps_to_login_required() -> Res {
1718 let storage = Arc::new(MockStorage::default());
1719 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1720 let auth = Auth::new(paths.clone(), storage.clone());
1721 let host = get_host();
1722 seed_fresh_tokens(&storage, &paths, &host).await;
1723
1724 let client = RetryMockClient::new(usize::MAX, 0);
1725 let result = auth.get_credentials_or_refresh(&client, &host).await;
1726
1727 assert!(
1728 matches!(result, Err(Error::Login(LoginError::Required(_)))),
1729 "expected LoginRequired after persistent 4xx, got: {result:?}"
1730 );
1731 assert_eq!(
1732 client.cred_calls.load(Ordering::SeqCst),
1733 2,
1734 "retry must be bounded to one extra attempt"
1735 );
1736 Ok(())
1737 }
1738
1739 #[test(tokio::test)]
1743 async fn test_token_refresh_transient_401_recovers() -> Res {
1744 let storage = Arc::new(MockStorage::default());
1745 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1746 let auth = Auth::new(paths.clone(), storage.clone());
1747 let host = get_host();
1748
1749 let auth_io = AuthIo::new(storage.clone(), paths.auth_host(&host));
1751 auth_io
1752 .write_tokens(&Tokens {
1753 access_token: "expired-access-token".to_string(),
1754 refresh_token: REFRESH_TOKEN.to_string(),
1755 expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
1756 })
1757 .await?;
1758 auth_io
1759 .write_client(&OAuthClient {
1760 client_id: CLIENT_ID.to_string(),
1761 redirect_uri: REDIRECT_URI.to_string(),
1762 })
1763 .await?;
1764
1765 let client = RetryMockClient::new(0, 1);
1766 let creds = auth.get_credentials_or_refresh(&client, &host).await?;
1767
1768 assert_eq!(creds.access_key, "oauth-access-key");
1769 assert_eq!(
1770 client.token_calls.load(Ordering::SeqCst),
1771 2,
1772 "token endpoint should be called twice: initial + retry"
1773 );
1774 assert_eq!(
1775 client.cred_calls.load(Ordering::SeqCst),
1776 1,
1777 "credentials endpoint should only be called once after successful retry"
1778 );
1779 Ok(())
1780 }
1781
1782 #[derive(Default)]
1787 struct Gate {
1788 entered: tokio::sync::Notify,
1789 release: tokio::sync::Notify,
1790 }
1791
1792 #[derive(Clone)]
1797 struct CountingCredsClient {
1798 cred_calls: Arc<std::sync::atomic::AtomicUsize>,
1799 sleep_ms: u64,
1800 gate: Option<Arc<Gate>>,
1801 }
1802
1803 #[async_trait]
1804 impl HttpClient for CountingCredsClient {
1805 async fn get<T: serde::de::DeserializeOwned>(
1806 &self,
1807 url: &str,
1808 _auth_token: Option<&str>,
1809 ) -> Res<T> {
1810 if url.ends_with("/config.json") {
1811 let body = serde_json::json!({
1812 "registryUrl": format!("https://{}", get_registry()),
1813 });
1814 return Ok(serde_json::from_value(body)?);
1815 }
1816 if url.contains("/api/auth/get_credentials") {
1817 self.cred_calls
1818 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
1819 if let Some(gate) = &self.gate {
1820 gate.entered.notify_one();
1821 gate.release.notified().await;
1822 } else if self.sleep_ms > 0 {
1823 tokio::time::sleep(std::time::Duration::from_millis(self.sleep_ms)).await;
1824 }
1825 let body = serde_json::json!({
1826 "AccessKeyId": "refreshed-key",
1827 "SecretAccessKey": "refreshed-secret",
1828 "SessionToken": "refreshed-session",
1829 "Expiration": (chrono::Utc::now() + chrono::Duration::hours(1))
1830 .to_rfc3339(),
1831 });
1832 return Ok(serde_json::from_value(body)?);
1833 }
1834 panic!("Unexpected GET: {url}");
1835 }
1836 async fn head(&self, _: &str) -> Res<HeaderMap> {
1837 unimplemented!()
1838 }
1839 async fn post<T: serde::de::DeserializeOwned>(
1840 &self,
1841 _: &str,
1842 _: &HashMap<String, String>,
1843 ) -> Res<T> {
1844 unimplemented!("fresh tokens → no OAuth leg fires")
1845 }
1846 async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
1847 &self,
1848 _: &str,
1849 _: &B,
1850 ) -> Res<T> {
1851 unimplemented!()
1852 }
1853 }
1854
1855 async fn seed_expired_creds_fresh_tokens(auth_io: &AuthIo<Arc<MockStorage>>) -> Res {
1856 auth_io
1857 .write_credentials(&Credentials {
1858 access_key: "stale".to_string(),
1859 secret_key: "stale-secret".to_string(),
1860 token: "stale-session".to_string(),
1861 expires_at: chrono::Utc::now() - chrono::Duration::hours(1),
1862 })
1863 .await?;
1864 auth_io
1865 .write_tokens(&Tokens {
1866 access_token: ACCESS_TOKEN.to_string(),
1867 refresh_token: REFRESH_TOKEN.to_string(),
1868 expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
1869 })
1870 .await?;
1871 Ok(())
1872 }
1873
1874 #[test(tokio::test)]
1875 async fn test_auth_refresh_is_single_flight_across_concurrent_callers() -> Res {
1876 let storage = Arc::new(MockStorage::default());
1877 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1878 let auth = Auth::new(paths.clone(), storage.clone());
1879 let host = get_host();
1880
1881 let auth_io = AuthIo::new(storage, paths.auth_host(&host));
1882 seed_expired_creds_fresh_tokens(&auth_io).await?;
1883
1884 let client = CountingCredsClient {
1885 cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1886 sleep_ms: 50,
1887 gate: None,
1888 };
1889
1890 let mut handles = Vec::new();
1891 for _ in 0..10 {
1892 let auth = auth.clone();
1893 let client = client.clone();
1894 let host = host.clone();
1895 handles.push(tokio::spawn(async move {
1896 auth.get_credentials_or_refresh(&client, &host).await
1897 }));
1898 }
1899
1900 let mut creds_seen = Vec::new();
1901 for h in handles {
1902 creds_seen.push(h.await.unwrap()?);
1903 }
1904
1905 assert_eq!(
1906 client.cred_calls.load(std::sync::atomic::Ordering::SeqCst),
1907 1,
1908 "single-flight: 10 concurrent callers must produce exactly one refresh",
1909 );
1910 let first = &creds_seen[0];
1911 for creds in &creds_seen {
1912 assert_eq!(creds.access_key, first.access_key);
1913 assert_eq!(creds.expires_at, first.expires_at);
1914 }
1915 assert_eq!(first.access_key, "refreshed-key");
1916 Ok(())
1917 }
1918
1919 #[test(tokio::test)]
1920 async fn test_auth_refresh_lock_is_per_host() -> Res {
1921 let storage = Arc::new(MockStorage::default());
1922 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1923 let auth = Auth::new(paths.clone(), storage.clone());
1924
1925 let host_a: Host = "a.quilt.dev".parse().unwrap();
1926 let host_b: Host = "b.quilt.dev".parse().unwrap();
1927
1928 seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_a)))
1930 .await?;
1931 seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_b)))
1932 .await?;
1933
1934 let gate = Arc::new(Gate::default());
1938 let gated_client = CountingCredsClient {
1939 cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1940 sleep_ms: 0,
1941 gate: Some(gate.clone()),
1942 };
1943 let fast_client = CountingCredsClient {
1944 cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1945 sleep_ms: 0,
1946 gate: None,
1947 };
1948
1949 let auth_clone = auth.clone();
1950 let client_a = gated_client.clone();
1951 let host_a_clone = host_a.clone();
1952 let a_task = tokio::spawn(async move {
1953 auth_clone
1954 .get_credentials_or_refresh(&client_a, &host_a_clone)
1955 .await
1956 });
1957
1958 gate.entered.notified().await;
1961
1962 tokio::time::timeout(
1966 std::time::Duration::from_secs(5),
1967 auth.get_credentials_or_refresh(&fast_client, &host_b),
1968 )
1969 .await
1970 .expect("host_b refresh must not wait behind host_a's lock")?;
1971
1972 assert!(
1975 !a_task.is_finished(),
1976 "host_a must still be blocked in its handler while host_b completes",
1977 );
1978
1979 gate.release.notify_one();
1981 a_task.await.unwrap()?;
1982 Ok(())
1983 }
1984
1985 #[test(tokio::test)]
1986 async fn test_refresh_lock_map_sweeps_dead_entries() -> Res {
1987 let storage = Arc::new(MockStorage::default());
1988 let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
1989 let auth = Auth::new(paths, storage);
1990
1991 let host: Host = "x.quilt.dev".parse().unwrap();
1992
1993 let arc1 = auth.refresh_lock_for(&host);
1995 assert_eq!(
1996 auth.refresh_locks
1997 .lock()
1998 .unwrap_or_else(std::sync::PoisonError::into_inner)
1999 .len(),
2000 1,
2001 );
2002
2003 drop(arc1);
2005 assert!(
2006 auth.refresh_locks
2007 .lock()
2008 .unwrap_or_else(std::sync::PoisonError::into_inner)
2009 .get(&host)
2010 .expect("entry still present before sweep")
2011 .upgrade()
2012 .is_none(),
2013 );
2014
2015 let _arc2 = auth.refresh_lock_for(&host);
2018 assert_eq!(
2019 auth.refresh_locks
2020 .lock()
2021 .unwrap_or_else(std::sync::PoisonError::into_inner)
2022 .len(),
2023 1,
2024 );
2025 Ok(())
2026 }
2027}