use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::Weak;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use sha2::Digest;
use sha2::Sha256;
use tokio::sync::Mutex as AsyncMutex;
use crate::error::AuthError;
use crate::error::LoginError;
use crate::io::remote::client::HttpClient;
use crate::io::storage::auth::AuthIo;
use crate::io::storage::auth::Credentials;
use crate::io::storage::auth::OAuthClient;
use crate::io::storage::auth::Tokens;
use crate::io::storage::LocalStorage;
use crate::io::storage::Storage;
use crate::paths::DomainPaths;
use crate::uri::Host;
use crate::Error;
use crate::Res;
use chrono::serde::ts_seconds;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::warn;
/// Inputs needed to finish an OAuth authorization-code (PKCE) exchange.
pub struct OAuthParams {
    /// Authorization code received on the redirect callback.
    pub code: String,
    /// PKCE verifier whose S256 challenge was sent with the authorize request.
    pub code_verifier: String,
    /// Redirect URI used in the authorize request; sent again on exchange.
    pub redirect_uri: String,
    /// Client ID of the registered OAuth client.
    pub client_id: String,
}
/// A PKCE verifier/challenge pair, as produced by [`pkce_challenge`].
pub struct PkceChallenge {
    /// Secret kept locally and sent only on the token exchange.
    pub code_verifier: String,
    /// Base64url (unpadded) SHA-256 digest of `code_verifier`, sent with the authorize request.
    pub code_challenge: String,
}
/// Generates a fresh PKCE pair using the S256 method.
///
/// The verifier is 64 random bytes, base64url-encoded without padding (86 characters,
/// all drawn from the unreserved set). The challenge is the base64url (unpadded)
/// SHA-256 digest of the verifier's ASCII bytes.
///
/// # Panics
/// Panics if the OS random source fails.
pub fn pkce_challenge() -> PkceChallenge {
    let verifier_bytes = {
        let mut buf = [0u8; 64];
        getrandom::fill(&mut buf).expect("failed to generate random bytes");
        buf
    };
    let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
    let digest = Sha256::digest(code_verifier.as_bytes());
    let code_challenge = URL_SAFE_NO_PAD.encode(digest);
    PkceChallenge {
        code_verifier,
        code_challenge,
    }
}
/// Returns an unguessable OAuth `state` value: 16 random bytes, base64url without padding.
///
/// # Panics
/// Panics if the OS random source fails.
pub fn random_state() -> String {
    let entropy: [u8; 16] = {
        let mut buf = [0u8; 16];
        getrandom::fill(&mut buf).expect("failed to generate random bytes");
        buf
    };
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Authorization endpoint served by the catalog itself: `https://{host}/connect/authorize`.
pub fn catalog_authorize_url(host: &Host) -> String {
    format!("https://{}/connect/authorize", host)
}
/// Derives the "connect" host for a catalog host by suffixing its first label.
///
/// `stack.example.com` becomes `stack-connect.example.com`; a single-label host `stack`
/// becomes `stack-connect`.
pub fn connect_host(host: &Host) -> String {
    let rendered = host.to_string();
    if let Some((stack, domain)) = rendered.split_once('.') {
        format!("{stack}-connect.{domain}")
    } else {
        format!("{rendered}-connect")
    }
}
/// OAuth token endpoint on the connect host (see [`connect_host`]).
fn connect_token_url(host: &Host) -> String {
    let connect = connect_host(host);
    format!("https://{connect}/auth/token")
}
/// Dynamic client registration endpoint on the connect host (see [`connect_host`]).
fn connect_register_url(host: &Host) -> String {
    let connect = connect_host(host);
    format!("https://{connect}/auth/register")
}
/// JSON body posted to the connect host's `/auth/register` endpoint to register a client.
#[derive(Serialize)]
struct DcrRequest {
    /// Human-readable client name.
    client_name: String,
    /// Redirect URIs the client is allowed to use.
    redirect_uris: Vec<String>,
    /// How the client authenticates at the token endpoint (`"none"` for a public client).
    token_endpoint_auth_method: String,
}
/// Subset of the registration response this module uses; other fields are ignored.
#[derive(Deserialize)]
struct DcrResponse {
    /// Identifier assigned to the newly registered client.
    client_id: String,
}
/// Registers a new public OAuth client for `host` via dynamic client registration.
///
/// Posts a [`DcrRequest`] naming `redirect_uri` as the only allowed redirect, and returns
/// the assigned `client_id` paired with that `redirect_uri`. Nothing is persisted here.
///
/// # Errors
/// Propagates HTTP and decoding errors from `post_json`.
async fn register_client(
    http_client: &impl HttpClient,
    host: &Host,
    redirect_uri: &str,
) -> Res<OAuthClient> {
    let register_url = connect_register_url(host);
    let request = DcrRequest {
        client_name: "QuiltSync".to_string(),
        redirect_uris: vec![redirect_uri.to_string()],
        // Public client: no secret, so no client authentication at the token endpoint.
        token_endpoint_auth_method: "none".to_string(),
    };
    let response: DcrResponse = http_client.post_json(&register_url, &request).await?;
    Ok(OAuthClient {
        client_id: response.client_id,
        redirect_uri: redirect_uri.to_string(),
    })
}
/// Token payload returned by the registry's `/api/token` endpoint (refresh-token login).
#[derive(Deserialize, Serialize)]
pub struct RemoteTokens {
    /// Bearer token used for registry API calls.
    pub access_token: String,
    /// Token used to obtain a new access token.
    pub refresh_token: String,
    /// Access-token expiry, (de)serialized as Unix seconds.
    #[serde(with = "ts_seconds")]
    pub expires_at: chrono::DateTime<chrono::Utc>,
}
impl fmt::Debug for RemoteTokens {
    /// Shows only the expiry; both token strings are replaced with a fixed placeholder
    /// so they cannot leak into logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        let mut out = f.debug_struct("RemoteTokens");
        out.field("expires_at", &self.expires_at);
        out.field("access_token", &HIDDEN);
        out.field("refresh_token", &HIDDEN);
        out.finish_non_exhaustive()
    }
}
impl From<RemoteTokens> for Tokens {
fn from(raw: RemoteTokens) -> Self {
Tokens {
access_token: raw.access_token,
refresh_token: raw.refresh_token,
expires_at: raw.expires_at,
}
}
}
/// Access-token lifetime in seconds (one hour) assumed when the token endpoint omits
/// `expires_in`.
const DEFAULT_EXPIRES_IN: i64 = 60 * 60;

/// Serde default hook for [`OAuthTokenResponse::expires_in`].
fn default_expires_in() -> i64 {
    DEFAULT_EXPIRES_IN
}
/// Response body of the connect host's OAuth token endpoint.
#[derive(Deserialize, Serialize)]
struct OAuthTokenResponse {
    /// Newly issued access token.
    access_token: String,
    /// Refresh token; optional because the server may not issue or rotate one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Access-token lifetime in seconds; `DEFAULT_EXPIRES_IN` when absent.
    #[serde(default = "default_expires_in")]
    expires_in: i64,
}
impl fmt::Debug for OAuthTokenResponse {
    /// Shows the lifetime and whether a refresh token is present, never the token values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        // Keep `None` visible so logs still reveal a missing refresh token.
        let refresh_token = self.refresh_token.as_ref().map(|_| HIDDEN);
        let mut out = f.debug_struct("OAuthTokenResponse");
        out.field("expires_in", &self.expires_in);
        out.field("access_token", &HIDDEN);
        out.field("refresh_token", &refresh_token);
        out.finish_non_exhaustive()
    }
}
/// Credentials payload returned by the registry's `/api/auth/get_credentials` endpoint.
///
/// Keys are PascalCase on the wire (`AccessKeyId`, `Expiration`, …).
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "PascalCase")]
struct RemoteCredentials {
    access_key_id: String,
    /// Expiry, parsed from an RFC 3339 string.
    #[serde(deserialize_with = "date_from_rfc3339")]
    expiration: chrono::DateTime<chrono::Utc>,
    secret_access_key: String,
    session_token: String,
}
impl fmt::Debug for RemoteCredentials {
    /// Shows only the expiry; every key/token field is replaced with a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        let mut out = f.debug_struct("RemoteCredentials");
        out.field("expiration", &self.expiration);
        for name in ["access_key_id", "secret_access_key", "session_token"] {
            out.field(name, &HIDDEN);
        }
        out.finish_non_exhaustive()
    }
}
impl From<RemoteCredentials> for Credentials {
fn from(raw: RemoteCredentials) -> Self {
Credentials {
access_key: raw.access_key_id,
secret_key: raw.secret_access_key,
token: raw.session_token,
expires_at: raw.expiration,
}
}
}
/// Serde helper: reads a string and parses it as an RFC 3339 timestamp, normalized to UTC.
///
/// # Errors
/// Returns a custom deserializer error prefixed with `Invalid RFC3339 date:` when the
/// string does not parse.
fn date_from_rfc3339<'de, D: Deserializer<'de>>(
    deserializer: D,
) -> Result<chrono::DateTime<chrono::Utc>, D::Error> {
    use serde::de::Error;
    let raw = String::deserialize(deserializer)?;
    match chrono::DateTime::parse_from_rfc3339(&raw) {
        Ok(parsed) => Ok(parsed.with_timezone(&chrono::Utc)),
        Err(e) => Err(Error::custom(format!("Invalid RFC3339 date: {e}"))),
    }
}
/// Subset of the catalog's `config.json` needed to locate the registry.
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct QuiltStackConfig {
    /// Base URL of the registry API (`registryUrl` on the wire).
    registry_url: url::Url,
}
/// Resolves the registry host for catalog `host`.
///
/// Fetches `https://{host}/config.json` and returns the domain part of its `registryUrl`.
///
/// # Errors
/// - Propagates HTTP / decoding errors from the config fetch.
/// - `LoginError::RequiredRegistryUrl` when `registryUrl` has no domain (for example
///   an IP-address host).
///
/// NOTE(review): only the domain is kept. Any port or path in `registryUrl` is dropped,
/// so callers that build `https://{registry}/...` always use the default HTTPS port.
async fn get_registry_url(http_client: &impl HttpClient, host: &Host) -> Res<url::Host> {
    let QuiltStackConfig { registry_url } = http_client
        .get(&format!("https://{host}/config.json"), None)
        .await?;
    // Build the error lazily so `host` is only cloned on the failure path.
    let domain = registry_url
        .domain()
        .ok_or_else(|| LoginError::RequiredRegistryUrl(host.to_owned()))?;
    Ok(url::Host::Domain(domain.to_string()))
}
/// Exchanges a registry-issued refresh token for fresh tokens (non-OAuth login path).
///
/// Resolves the registry via [`get_registry_url`], then form-posts `refresh_token` to
/// `https://{registry}/api/token`.
///
/// # Errors
/// Propagates registry-lookup, HTTP and decoding errors.
async fn get_auth_tokens(
    http_client: &impl HttpClient,
    host: &Host,
    refresh_token: &str,
) -> Res<Tokens> {
    let registry = get_registry_url(http_client, host).await?;
    let token_url = format!("https://{registry}/api/token");
    let form_data: HashMap<String, String> =
        HashMap::from([("refresh_token".to_string(), refresh_token.to_string())]);
    let remote: RemoteTokens = http_client.post(&token_url, &form_data).await?;
    Ok(Tokens::from(remote))
}
/// Exchanges an OAuth authorization code (with its PKCE verifier) for tokens.
///
/// Form-posts an `authorization_code` grant to the connect host's token endpoint and
/// computes `expires_at` from the returned `expires_in`.
///
/// # Errors
/// - Propagates HTTP and decoding errors from the token endpoint.
/// - `AuthError::TokensExchange` if the response carries no refresh token, because
///   this module has no way to renew access later without one.
async fn exchange_oauth_code(
    http_client: &impl HttpClient,
    host: &Host,
    params: &OAuthParams,
) -> Res<Tokens> {
    // Upper bound for a server-supplied lifetime. `chrono::Duration::seconds` and
    // `DateTime + Duration` both panic when out of range, so a bogus `expires_in` must
    // not reach them unclamped. Negative values are treated as "already expired".
    const MAX_EXPIRES_IN: i64 = 10 * 365 * 24 * 60 * 60;

    let token_url = connect_token_url(host);
    let mut form_data: HashMap<String, String> = HashMap::new();
    form_data.insert("grant_type".to_string(), "authorization_code".to_string());
    form_data.insert("code".to_string(), params.code.clone());
    form_data.insert("code_verifier".to_string(), params.code_verifier.clone());
    form_data.insert("redirect_uri".to_string(), params.redirect_uri.clone());
    form_data.insert("client_id".to_string(), params.client_id.clone());
    let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
    let expires_in = response.expires_in.clamp(0, MAX_EXPIRES_IN);
    let expires_at = chrono::Utc::now() + chrono::Duration::seconds(expires_in);
    Ok(Tokens {
        access_token: response.access_token,
        refresh_token: response.refresh_token.ok_or_else(|| {
            Error::Auth(
                host.to_owned(),
                AuthError::TokensExchange("server did not return a refresh token".to_string()),
            )
        })?,
        expires_at,
    })
}
/// Uses an OAuth refresh token to obtain a new access token from the connect host.
///
/// If the server does not return a new refresh token, the one passed in is kept, so the
/// stored refresh token stays usable when the server does not rotate it.
///
/// # Errors
/// Propagates HTTP and decoding errors from the token endpoint.
async fn refresh_oauth_tokens(
    http_client: &impl HttpClient,
    host: &Host,
    refresh_token: &str,
    client_id: &str,
) -> Res<Tokens> {
    // Upper bound for a server-supplied lifetime. `chrono::Duration::seconds` and
    // `DateTime + Duration` both panic when out of range, so a bogus `expires_in` must
    // not reach them unclamped. Negative values are treated as "already expired".
    const MAX_EXPIRES_IN: i64 = 10 * 365 * 24 * 60 * 60;

    let token_url = connect_token_url(host);
    let mut form_data: HashMap<String, String> = HashMap::new();
    form_data.insert("grant_type".to_string(), "refresh_token".to_string());
    form_data.insert("refresh_token".to_string(), refresh_token.to_string());
    form_data.insert("client_id".to_string(), client_id.to_string());
    let response: OAuthTokenResponse = http_client.post(&token_url, &form_data).await?;
    let expires_in = response.expires_in.clamp(0, MAX_EXPIRES_IN);
    let expires_at = chrono::Utc::now() + chrono::Duration::seconds(expires_in);
    Ok(Tokens {
        access_token: response.access_token,
        refresh_token: response
            .refresh_token
            .unwrap_or_else(|| refresh_token.to_string()),
        expires_at,
    })
}
/// Fetches credentials from the registry using `access_token` as the bearer token.
///
/// # Errors
/// Propagates registry-lookup, HTTP and decoding errors.
async fn refresh_credentials(
    http_client: &impl HttpClient,
    host: &Host,
    access_token: &str,
) -> Res<Credentials> {
    let registry = get_registry_url(http_client, host).await?;
    let url = format!("https://{registry}/api/auth/get_credentials");
    let raw: RemoteCredentials = http_client.get(&url, Some(access_token)).await?;
    Ok(raw.into())
}
fn is_token_auth_error(e: &Error) -> bool {
matches!(
e,
Error::Reqwest(re) if re.status().is_some_and(|s| s == 400 || s == 401 || s == 403)
)
}
fn is_credentials_auth_error(e: &Error) -> bool {
matches!(
e,
Error::Reqwest(re) if re.status().is_some_and(|s| s == 401 || s == 403)
)
}
fn http_status(e: &Error) -> Option<u16> {
match e {
Error::Reqwest(re) => re.status().map(|s| s.as_u16()),
_ => None,
}
}
/// Logs and maps the result of a single retry after an auth-type failure.
///
/// - Success is logged as a recovery and returned.
/// - An error that `is_auth_error` still classifies as auth becomes
///   `LoginError::Required(Some(host))`.
/// - Any other error is logged and returned unchanged.
fn classify_retry_outcome<T>(
    result: Res<T>,
    is_auth_error: fn(&Error) -> bool,
    endpoint: &str,
    host: &Host,
) -> Res<T> {
    let err = match result {
        Ok(value) => {
            info!(
                "✔️ Recovered from transient auth error on {} for {}",
                endpoint, host
            );
            return Ok(value);
        }
        Err(err) => err,
    };
    if !is_auth_error(&err) {
        warn!(
            status = ?http_status(&err),
            "❌ Failed to refresh via {} for {} on retry: {}",
            endpoint, host, err
        );
        return Err(err);
    }
    warn!(
        status = ?http_status(&err),
        "❌ Auth error on {} for {} persisted after retry, login required: {}",
        endpoint, host, err
    );
    Err(LoginError::Required(Some(host.to_owned())).into())
}
/// Per-host async refresh locks, shared between clones of [`Auth`].
///
/// Entries are `Weak`, so a host's lock only stays alive while some task holds it.
type RefreshLocks = Arc<StdMutex<HashMap<Host, Weak<AsyncMutex<()>>>>>;

/// Manages login state (tokens, OAuth client, credentials) for catalog hosts.
///
/// State for a host lives under `paths.auth_host(host)` and is read and written through
/// `storage`.
#[derive(Debug)]
pub struct Auth<S: Storage = LocalStorage> {
    pub paths: DomainPaths,
    pub storage: Arc<S>,
    /// Serializes concurrent credential refreshes for the same host.
    refresh_locks: RefreshLocks,
}
/// Hand-written so cloning does not require `S: Clone`: storage and the refresh-lock
/// table are shared through their `Arc`s, so clones use the same locks.
impl<S: Storage> Clone for Auth<S> {
    fn clone(&self) -> Self {
        let paths = self.paths.clone();
        let storage = Arc::clone(&self.storage);
        let refresh_locks = Arc::clone(&self.refresh_locks);
        Self {
            paths,
            storage,
            refresh_locks,
        }
    }
}
impl<S: Storage + Send + Sync> Auth<S> {
    /// Creates an auth manager that persists per-host state under `paths` via `storage`.
    ///
    /// The refresh-lock table starts empty and is shared by every clone of the result.
    pub fn new(paths: DomainPaths, storage: Arc<S>) -> Self {
        Self {
            paths,
            storage,
            refresh_locks: Arc::new(StdMutex::new(HashMap::new())),
        }
    }

    /// Returns the async mutex that serializes credential refreshes for `host`.
    ///
    /// The table holds only `Weak` references, so a lock lives as long as some caller holds
    /// the returned `Arc`. Dead entries are pruned on every call. A poisoned table mutex is
    /// recovered with `into_inner` instead of panicking.
    fn refresh_lock_for(&self, host: &Host) -> Arc<AsyncMutex<()>> {
        let mut locks = self.refresh_locks.lock().unwrap_or_else(|e| e.into_inner());
        locks.retain(|_, weak| weak.strong_count() > 0);
        if let Some(arc) = locks.get(host).and_then(Weak::upgrade) {
            return arc;
        }
        let arc = Arc::new(AsyncMutex::new(()));
        locks.insert(host.clone(), Arc::downgrade(&arc));
        arc
    }

    /// Logs in to `host` with a registry-issued refresh token (non-OAuth flow).
    ///
    /// Exchanges the refresh token for tokens, persists them, then fetches and persists
    /// credentials using the new access token. Every failure is logged and returned as-is.
    pub async fn login<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        refresh_token: String,
    ) -> Res {
        info!("⏳ Logging in to host {} with refresh token", host);
        let tokens = self
            .get_auth_tokens(http_client, host, &refresh_token)
            .await
            .map_err(|e| {
                warn!("❌ Failed to get auth tokens for {}: {}", host, e);
                e
            })?;
        self.save_tokens(host, &tokens).await.map_err(|e| {
            warn!("❌ Failed to save tokens for {}: {}", host, e);
            e
        })?;
        self.refresh_credentials(http_client, host, &tokens.access_token)
            .await
            .map_err(|e| {
                warn!("❌ Failed to refresh credentials for {}: {}", host, e);
                e
            })?;
        info!("✔️ Successfully logged in and authenticated to {}", host);
        Ok(())
    }

    /// Returns the stored OAuth client for `host`, registering a new one when needed.
    ///
    /// A cached client is reused only if its `redirect_uri` equals `redirect_uri`.
    /// Otherwise a new client is registered with the connect host and overwrites the cache.
    ///
    /// # Errors
    /// Propagates storage read/write errors and registration errors.
    pub async fn get_or_register_client<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        redirect_uri: &str,
    ) -> Res<OAuthClient> {
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        if let Some(client) = auth_io.read_client().await? {
            if client.redirect_uri == redirect_uri {
                info!("✔️ Found existing OAuth client for {}", host);
                return Ok(client);
            }
            info!(
                "⚠️ Cached client has stale redirect_uri, re-registering for {}",
                host
            );
        }
        info!("⏳ Registering new OAuth client for {}", host);
        let client = register_client(http_client, host, redirect_uri).await?;
        auth_io.write_client(&client).await?;
        info!(
            "✔️ Registered OAuth client for {}: {}",
            host, client.client_id
        );
        Ok(client)
    }

    /// Completes an OAuth authorization-code (PKCE) login for `host`.
    ///
    /// Exchanges `params.code` at the connect host, persists the tokens, then fetches and
    /// persists credentials using the new access token. Every failure is logged and
    /// returned as-is.
    pub async fn login_oauth<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        params: OAuthParams,
    ) -> Res {
        info!("⏳ OAuth login for host {}", host);
        let tokens = exchange_oauth_code(http_client, host, &params)
            .await
            .map_err(|e| {
                warn!("❌ Failed to exchange OAuth code for {}: {}", host, e);
                e
            })?;
        self.save_tokens(host, &tokens).await.map_err(|e| {
            warn!("❌ Failed to save tokens for {}: {}", host, e);
            e
        })?;
        self.refresh_credentials(http_client, host, &tokens.access_token)
            .await
            .map_err(|e| {
                warn!("❌ Failed to refresh credentials for {}: {}", host, e);
                e
            })?;
        info!("✔️ OAuth login successful for {}", host);
        Ok(())
    }

    /// Debug-logging wrapper around the module-level `get_auth_tokens` (registry flow).
    async fn get_auth_tokens<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        refresh_token: &str,
    ) -> Res<Tokens> {
        debug!("⏳ Getting auth tokens for host {:?}", host);
        let tokens = get_auth_tokens(http_client, host, refresh_token).await?;
        debug!("✔️ Successfully retrieved auth tokens");
        Ok(tokens)
    }

    /// Persists `tokens` for `host` under `paths.auth_host(host)`.
    async fn save_tokens(&self, host: &Host, tokens: &Tokens) -> Res<()> {
        debug!("⏳ Saving tokens for host {:?}", host);
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        auth_io.write_tokens(tokens).await?;
        debug!(
            "✔️ Successfully saved tokens to the {:?}",
            self.paths.auth_host(host)
        );
        Ok(())
    }

    /// Refreshes `tokens` through the OAuth token endpoint and persists the result.
    ///
    /// # Errors
    /// - `LoginError::Required` if no OAuth client is stored for `host`, because the
    ///   refresh grant needs its `client_id`.
    /// - Propagates HTTP and storage errors.
    async fn refresh_tokens<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        tokens: &Tokens,
    ) -> Res<Tokens> {
        let client = auth_io
            .read_client()
            .await?
            .ok_or(LoginError::Required(Some(host.to_owned())))?;
        let new_tokens =
            refresh_oauth_tokens(http_client, host, &tokens.refresh_token, &client.client_id)
                .await?;
        auth_io.write_tokens(&new_tokens).await?;
        info!("✔️ Successfully refreshed tokens for {}", host);
        Ok(new_tokens)
    }

    /// [`Self::refresh_tokens`] with a single retry on auth-type HTTP errors (400/401/403).
    ///
    /// A missing OAuth client and non-auth errors are returned immediately. If the retry
    /// also fails with an auth error, the result is `LoginError::Required`.
    async fn refresh_tokens_with_retry<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        tokens: &Tokens,
    ) -> Res<Tokens> {
        let first_err = match self
            .refresh_tokens(http_client, auth_io, host, tokens)
            .await
        {
            Ok(t) => return Ok(t),
            Err(e) => e,
        };
        if matches!(first_err, Error::Login(LoginError::Required(_))) {
            warn!("❌ No OAuth client registered for {}, login required", host);
            return Err(first_err);
        }
        if !is_token_auth_error(&first_err) {
            warn!(
                status = ?http_status(&first_err),
                "❌ Failed to refresh tokens for {}: {}", host, first_err
            );
            return Err(first_err);
        }
        info!(
            status = ?http_status(&first_err),
            "⚠️ Auth error refreshing tokens for {}, retrying once: {}", host, first_err
        );
        classify_retry_outcome(
            self.refresh_tokens(http_client, auth_io, host, tokens)
                .await,
            is_token_auth_error,
            "token endpoint",
            host,
        )
    }

    /// [`Self::refresh_credentials`] with one recovery attempt on 401/403.
    ///
    /// On an auth error it re-reads the stored tokens, force-refreshes them (with retry),
    /// and tries the credentials endpoint once more with the new access token.
    ///
    /// # Errors
    /// - `LoginError::Required` if no tokens are stored or the auth error persists.
    /// - Non-auth errors are returned unchanged.
    async fn refresh_credentials_with_retry<T: HttpClient>(
        &self,
        http_client: &T,
        auth_io: &AuthIo<Arc<S>>,
        host: &Host,
        access_token: &str,
    ) -> Res<Credentials> {
        let first_err = match self
            .refresh_credentials(http_client, host, access_token)
            .await
        {
            Ok(c) => return Ok(c),
            Err(e) => e,
        };
        if !is_credentials_auth_error(&first_err) {
            warn!(
                status = ?http_status(&first_err),
                "❌ Failed to refresh credentials for {}: {}", host, first_err
            );
            return Err(first_err);
        }
        info!(
            status = ?http_status(&first_err),
            "⚠️ Auth error refreshing credentials for {}, \
             force-refreshing token and retrying: {}",
            host, first_err
        );
        let tokens = auth_io
            .read_tokens()
            .await?
            .ok_or_else(|| LoginError::Required(Some(host.to_owned())))?;
        let new_tokens = self
            .refresh_tokens_with_retry(http_client, auth_io, host, &tokens)
            .await?;
        classify_retry_outcome(
            self.refresh_credentials(http_client, host, &new_tokens.access_token)
                .await,
            is_credentials_auth_error,
            "credentials endpoint",
            host,
        )
    }

    /// Fetches credentials with `access_token` and persists them for `host`.
    async fn refresh_credentials<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
        access_token: &str,
    ) -> Res<Credentials> {
        debug!("⏳ Refreshing credentials for host {:?}", host);
        let credentials = refresh_credentials(http_client, host, access_token).await?;
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        auth_io.write_credentials(&credentials).await?;
        debug!(
            "✔️ Successfully refreshed credentials in {:?}",
            self.paths.auth_host(host)
        );
        Ok(credentials)
    }

    /// Returns usable credentials for `host`, refreshing them when none are stored.
    ///
    /// Flow:
    /// 1. Return stored credentials if `AuthIo::read_credentials` yields some.
    ///    NOTE(review): expiry filtering is assumed to happen inside `AuthIo`; verify this.
    /// 2. Otherwise take the per-host refresh lock and re-read storage, since another task
    ///    may have refreshed while this one waited.
    /// 3. Load the stored tokens. If the access token expires within 60 seconds, refresh
    ///    it first.
    /// 4. Fetch and persist new credentials, with one recovery retry on auth errors.
    ///
    /// # Errors
    /// - `AuthError::CredentialsRead` / `AuthError::TokensRead` when storage reads fail.
    /// - `LoginError::Required` when no tokens or OAuth client are stored, or when an
    ///   auth error persists after retry.
    /// - Other HTTP and storage errors are propagated.
    pub async fn get_credentials_or_refresh<T: HttpClient>(
        &self,
        http_client: &T,
        host: &Host,
    ) -> Res<Credentials> {
        info!("⏳ Getting or refreshing credentials for {}", host);
        let auth_io = AuthIo::new(self.storage.clone(), self.paths.auth_host(host));
        match auth_io.read_credentials().await {
            Ok(Some(creds)) => {
                debug!("✔️ Found valid credentials for {}", host);
                return Ok(creds);
            }
            Ok(None) => {
                info!("❌ No existing credentials found for {}", host);
            }
            Err(e) => {
                error!("❌ Failed to read credentials for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::CredentialsRead(e.to_string()),
                ));
            }
        }
        // Only one task per host performs the network refresh; the others wait here and
        // then pick up the freshly stored credentials on the re-read below.
        let lock = self.refresh_lock_for(host);
        let _guard = lock.lock().await;
        match auth_io.read_credentials().await {
            Ok(Some(creds)) => {
                debug!("✔️ Another task refreshed credentials for {}", host);
                return Ok(creds);
            }
            Ok(None) => {}
            Err(e) => {
                error!("❌ Failed to re-read credentials for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::CredentialsRead(e.to_string()),
                ));
            }
        }
        let tokens = match auth_io.read_tokens().await {
            Ok(Some(tokens)) => tokens,
            Ok(None) => {
                warn!("❌ No tokens found for {}, login required", host);
                return Err(LoginError::Required(Some(host.to_owned())).into());
            }
            Err(e) => {
                error!("❌ Failed to read tokens for {}: {}", host, e);
                return Err(Error::Auth(
                    host.to_owned(),
                    AuthError::TokensRead(e.to_string()),
                ));
            }
        };
        // 60-second safety margin so the token does not expire mid-request.
        let access_token =
            if tokens.expires_at <= chrono::Utc::now() + chrono::Duration::seconds(60) {
                info!(
                    "⏳ Access token expired for {}, refreshing via refresh token",
                    host
                );
                self.refresh_tokens_with_retry(http_client, &auth_io, host, &tokens)
                    .await?
                    .access_token
            } else {
                tokens.access_token
            };
        info!("⏳ Refreshing credentials using access token for {}", host);
        let creds = self
            .refresh_credentials_with_retry(http_client, &auth_io, host, &access_token)
            .await?;
        info!("✔️ Successfully refreshed credentials for {}", host);
        Ok(creds)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use test_log::test;
use crate::io::storage::mocks::MockStorage;
use crate::paths::DomainPaths;
const ACCESS_TOKEN: &str = "test-access-token";
const REFRESH_TOKEN: &str = "test-refresh-token";
const TIMESTAMP: i64 = 1708444800;
fn get_host() -> Host {
"test.quilt.dev".parse().unwrap()
}
fn get_registry() -> String {
"registry-test.quilt.dev".to_string()
}
struct TestHttpClient;
#[async_trait]
impl HttpClient for TestHttpClient {
async fn get<T: serde::de::DeserializeOwned>(
&self,
url: &str,
auth_token: Option<&str>,
) -> Res<T> {
let registry = get_registry();
match url {
u if u == format!("https://{}/config.json", get_host()) => {
let config = QuiltStackConfig {
registry_url: format!("https://{registry}").parse()?,
};
Ok(serde_json::from_value(serde_json::to_value(config)?)?)
}
u if u == format!("https://{registry}/api/auth/get_credentials") => {
assert_eq!(auth_token, Some(ACCESS_TOKEN));
let creds = RemoteCredentials {
access_key_id: "test-access-key".to_string(),
secret_access_key: "test-secret-key".to_string(),
session_token: "test-session-token".to_string(),
expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
};
Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
}
_ => panic!("Unexpected URL: {url}"),
}
}
async fn head(&self, _url: &str) -> Res<HeaderMap> {
unimplemented!("head is not used in this test")
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
url: &str,
form_data: &HashMap<String, String>,
) -> Res<T> {
assert_eq!(url, format!("https://{}/api/token", get_registry()));
assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
let tokens = RemoteTokens {
access_token: ACCESS_TOKEN.to_string(),
refresh_token: "new-refresh-token".to_string(),
expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
};
Ok(serde_json::from_value(serde_json::to_value(tokens)?)?)
}
async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
&self,
_url: &str,
_body: &B,
) -> Res<T> {
unimplemented!("post_json is not used in this test")
}
}
#[test(tokio::test)]
async fn test_get_registry_url() {
let client = TestHttpClient;
let result = get_registry_url(&client, &get_host()).await.unwrap();
assert_eq!(
result,
url::Host::Domain("registry-test.quilt.dev".to_string())
);
}
#[test(tokio::test)]
async fn test_get_auth_tokens() {
let client = TestHttpClient;
let tokens = get_auth_tokens(&client, &get_host(), REFRESH_TOKEN)
.await
.unwrap();
assert_eq!(tokens.access_token, ACCESS_TOKEN);
assert_eq!(tokens.refresh_token, "new-refresh-token");
assert_eq!(
tokens.expires_at,
chrono::DateTime::from_timestamp(1708444800, 0).unwrap()
);
}
#[test(tokio::test)]
async fn test_refresh_credentials() {
let client = TestHttpClient;
let credentials = refresh_credentials(&client, &get_host(), ACCESS_TOKEN)
.await
.unwrap();
assert_eq!(credentials.access_key, "test-access-key");
assert_eq!(credentials.secret_key, "test-secret-key");
assert_eq!(credentials.token, "test-session-token");
assert_eq!(
credentials.expires_at,
chrono::DateTime::from_timestamp(1708444800, 0).unwrap()
);
}
#[test(tokio::test)]
async fn test_auth_refresh_credentials() -> Res {
let storage = Arc::new(MockStorage::default());
let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
let auth = Auth::new(paths.clone(), storage.clone());
let host = get_host();
let credentials = auth
.refresh_credentials(&TestHttpClient, &host, ACCESS_TOKEN)
.await?;
assert_eq!(credentials.access_key, "test-access-key");
assert_eq!(credentials.secret_key, "test-secret-key");
assert_eq!(credentials.token, "test-session-token");
assert_eq!(
credentials.expires_at,
chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap()
);
use crate::io::storage::StorageExt;
let creds_path = paths.auth_host(&host).join(crate::paths::AUTH_CREDENTIALS);
let bytes = storage.read_bytes(&creds_path).await?;
let read_creds: Credentials = serde_json::from_slice(&bytes)?;
assert_eq!(read_creds.access_key, credentials.access_key);
assert_eq!(read_creds.secret_key, credentials.secret_key);
assert_eq!(read_creds.token, credentials.token);
assert_eq!(read_creds.expires_at, credentials.expires_at);
Ok(())
}
#[test]
fn test_remote_credentials_deserialization() {
let valid_json = r#"{
"AccessKeyId": "test-key",
"Expiration": "2024-02-20T15:00:00Z",
"SecretAccessKey": "test-secret",
"SessionToken": "test-token"
}"#;
let creds: RemoteCredentials = serde_json::from_str(valid_json).unwrap();
assert_eq!(creds.access_key_id, "test-key");
assert_eq!(creds.secret_access_key, "test-secret");
assert_eq!(creds.session_token, "test-token");
assert_eq!(
creds.expiration,
chrono::DateTime::parse_from_rfc3339("2024-02-20T15:00:00Z")
.unwrap()
.with_timezone(&chrono::Utc)
);
let invalid_json = r#"{
"AccessKeyId": "test-key",
"Expiration": "2024-02-20 15:00:00",
"SecretAccessKey": "test-secret",
"SessionToken": "test-token"
}"#;
let error = serde_json::from_str::<RemoteCredentials>(invalid_json).unwrap_err();
assert!(error.to_string().contains("Invalid RFC3339 date"));
}
const AUTH_CODE: &str = "test-auth-code";
const CODE_VERIFIER: &str = "test-code-verifier-that-is-at-least-43-characters-long";
const CLIENT_ID: &str = "test-client-id";
const REDIRECT_URI: &str = "quilt://auth/callback?host=test.quilt.dev";
struct OAuthTestHttpClient {
expected_credentials_token: &'static str,
}
impl Default for OAuthTestHttpClient {
fn default() -> Self {
Self {
expected_credentials_token: ACCESS_TOKEN,
}
}
}
#[async_trait]
impl HttpClient for OAuthTestHttpClient {
async fn get<T: serde::de::DeserializeOwned>(
&self,
url: &str,
auth_token: Option<&str>,
) -> Res<T> {
let registry = get_registry();
match url {
u if u == format!("https://{}/config.json", get_host()) => {
let config = QuiltStackConfig {
registry_url: format!("https://{registry}").parse()?,
};
Ok(serde_json::from_value(serde_json::to_value(config)?)?)
}
u if u == format!("https://{registry}/api/auth/get_credentials") => {
assert_eq!(auth_token, Some(self.expected_credentials_token));
let creds = RemoteCredentials {
access_key_id: "oauth-access-key".to_string(),
secret_access_key: "oauth-secret-key".to_string(),
session_token: "oauth-session-token".to_string(),
expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
};
Ok(serde_json::from_value(serde_json::to_value(creds)?)?)
}
_ => panic!("Unexpected GET URL: {url}"),
}
}
async fn head(&self, _url: &str) -> Res<HeaderMap> {
unimplemented!()
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
url: &str,
form_data: &HashMap<String, String>,
) -> Res<T> {
assert_eq!(url, connect_token_url(&get_host()));
let tokens = match form_data.get("grant_type").map(String::as_str) {
Some("authorization_code") => {
assert_eq!(form_data.get("code").unwrap(), AUTH_CODE);
assert_eq!(form_data.get("code_verifier").unwrap(), CODE_VERIFIER);
assert_eq!(form_data.get("redirect_uri").unwrap(), REDIRECT_URI);
assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
OAuthTokenResponse {
access_token: ACCESS_TOKEN.to_string(),
refresh_token: Some("oauth-refresh-token".to_string()),
expires_in: 3600,
}
}
Some("refresh_token") => {
assert_eq!(form_data.get("refresh_token").unwrap(), REFRESH_TOKEN);
assert_eq!(form_data.get("client_id").unwrap(), CLIENT_ID);
OAuthTokenResponse {
access_token: "refreshed-access-token".to_string(),
refresh_token: Some("new-refresh-token".to_string()),
expires_in: 3600,
}
}
other => panic!("Unexpected grant_type: {other:?}"),
};
Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
}
async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
&self,
url: &str,
body: &B,
) -> Res<T> {
assert_eq!(url, connect_register_url(&get_host()));
let json = serde_json::to_value(body)?;
assert_eq!(json["client_name"], "QuiltSync");
assert_eq!(json["token_endpoint_auth_method"], "none");
let redirect_uris = json["redirect_uris"].as_array().expect("redirect_uris");
assert_eq!(redirect_uris.len(), 1);
assert!(redirect_uris[0]
.as_str()
.unwrap()
.starts_with("quilt://auth/callback?host="));
Ok(serde_json::from_value(serde_json::json!({
"client_id": "test-dcr-client-id"
}))?)
}
}
#[test]
fn test_connect_host() {
let host: Host = "test.quilt.dev".parse().unwrap();
assert_eq!(connect_host(&host), "test-connect.quilt.dev");
}
#[test]
fn test_connect_token_url() {
let host: Host = "test.quilt.dev".parse().unwrap();
assert_eq!(
connect_token_url(&host),
"https://test-connect.quilt.dev/auth/token"
);
}
#[test(tokio::test)]
async fn test_exchange_oauth_code() {
let client = OAuthTestHttpClient::default();
let params = OAuthParams {
code: AUTH_CODE.to_string(),
code_verifier: CODE_VERIFIER.to_string(),
redirect_uri: REDIRECT_URI.to_string(),
client_id: CLIENT_ID.to_string(),
};
let tokens = exchange_oauth_code(&client, &get_host(), ¶ms)
.await
.unwrap();
assert_eq!(tokens.access_token, ACCESS_TOKEN);
assert_eq!(tokens.refresh_token, "oauth-refresh-token");
}
#[test]
fn test_pkce_challenge() {
let pkce = pkce_challenge();
assert_eq!(pkce.code_verifier.len(), 86);
assert_eq!(pkce.code_challenge.len(), 43);
let expected_challenge =
URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.code_verifier.as_bytes()));
assert_eq!(pkce.code_challenge, expected_challenge);
let pkce2 = pkce_challenge();
assert_ne!(pkce.code_verifier, pkce2.code_verifier);
}
#[test]
fn test_pkce_verifier_charset_rfc7636() {
let pkce = pkce_challenge();
for ch in pkce.code_verifier.chars() {
assert!(
ch.is_ascii_alphanumeric() || matches!(ch, '-' | '.' | '_' | '~'),
"code_verifier contains char '{ch}' not allowed by RFC 7636 §4.1"
);
}
}
#[test(tokio::test)]
async fn test_login_oauth() -> Res {
let storage = Arc::new(MockStorage::default());
let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
let auth = Auth::new(paths, storage);
let host = get_host();
let params = OAuthParams {
code: AUTH_CODE.to_string(),
code_verifier: CODE_VERIFIER.to_string(),
redirect_uri: REDIRECT_URI.to_string(),
client_id: CLIENT_ID.to_string(),
};
auth.login_oauth(&OAuthTestHttpClient::default(), &host, params)
.await?;
Ok(())
}
#[test(tokio::test)]
async fn test_refresh_oauth_tokens() -> Res {
let tokens = refresh_oauth_tokens(
&OAuthTestHttpClient::default(),
&get_host(),
REFRESH_TOKEN,
CLIENT_ID,
)
.await?;
assert_eq!(tokens.access_token, "refreshed-access-token");
assert_eq!(tokens.refresh_token, "new-refresh-token");
Ok(())
}
#[test(tokio::test)]
async fn test_refresh_oauth_tokens_retains_old_when_omitted() -> Res {
struct NoRefreshTokenClient;
#[async_trait]
impl HttpClient for NoRefreshTokenClient {
async fn get<T: serde::de::DeserializeOwned>(
&self,
_: &str,
_: Option<&str>,
) -> Res<T> {
unimplemented!()
}
async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
unimplemented!()
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
_: &str,
_: &HashMap<String, String>,
) -> Res<T> {
let resp = OAuthTokenResponse {
access_token: "new-access-token".to_string(),
refresh_token: None, expires_in: DEFAULT_EXPIRES_IN,
};
Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
}
async fn post_json<
T: serde::de::DeserializeOwned,
B: serde::Serialize + Send + Sync,
>(
&self,
_: &str,
_: &B,
) -> Res<T> {
unimplemented!()
}
}
let tokens =
refresh_oauth_tokens(&NoRefreshTokenClient, &get_host(), REFRESH_TOKEN, CLIENT_ID)
.await?;
assert_eq!(tokens.access_token, "new-access-token");
assert_eq!(tokens.refresh_token, REFRESH_TOKEN);
Ok(())
}
#[test(tokio::test)]
async fn test_exchange_oauth_code_errors_when_refresh_token_missing() {
struct NoRefreshTokenClient;
#[async_trait]
impl HttpClient for NoRefreshTokenClient {
async fn get<T: serde::de::DeserializeOwned>(
&self,
_: &str,
_: Option<&str>,
) -> Res<T> {
unimplemented!()
}
async fn head(&self, _: &str) -> Res<reqwest::header::HeaderMap> {
unimplemented!()
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
_: &str,
_: &HashMap<String, String>,
) -> Res<T> {
let resp = OAuthTokenResponse {
access_token: ACCESS_TOKEN.to_string(),
refresh_token: None,
expires_in: DEFAULT_EXPIRES_IN,
};
Ok(serde_json::from_value(serde_json::to_value(resp)?)?)
}
async fn post_json<
T: serde::de::DeserializeOwned,
B: serde::Serialize + Send + Sync,
>(
&self,
_: &str,
_: &B,
) -> Res<T> {
unimplemented!()
}
}
let params = OAuthParams {
code: AUTH_CODE.to_string(),
code_verifier: CODE_VERIFIER.to_string(),
redirect_uri: REDIRECT_URI.to_string(),
client_id: CLIENT_ID.to_string(),
};
let result = exchange_oauth_code(&NoRefreshTokenClient, &get_host(), ¶ms).await;
assert!(
matches!(result, Err(Error::Auth(_, AuthError::TokensExchange(_)))),
"expected TokensExchange error, got: {result:?}"
);
}
#[test]
fn test_oauth_token_response_missing_expires_in() {
let json = r#"{"access_token":"tok","refresh_token":"ref"}"#;
let resp: OAuthTokenResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.expires_in, DEFAULT_EXPIRES_IN);
}
const REFRESHED_ACCESS_TOKEN: &str = "refreshed-access-token";
#[test(tokio::test)]
async fn test_get_credentials_or_refresh_with_expired_token() -> Res {
let storage = Arc::new(MockStorage::default());
let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
let auth = Auth::new(paths.clone(), storage.clone());
let host = get_host();
let auth_io = AuthIo::new(storage, paths.auth_host(&host));
auth_io
.write_tokens(&Tokens {
access_token: "expired-access-token".to_string(),
refresh_token: REFRESH_TOKEN.to_string(),
expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
})
.await?;
auth_io
.write_client(&OAuthClient {
client_id: CLIENT_ID.to_string(),
redirect_uri: REDIRECT_URI.to_string(),
})
.await?;
let client = OAuthTestHttpClient {
expected_credentials_token: REFRESHED_ACCESS_TOKEN,
};
let creds = auth.get_credentials_or_refresh(&client, &host).await?;
assert_eq!(creds.access_key, "oauth-access-key");
let persisted = auth_io
.read_tokens()
.await?
.expect("tokens should be persisted");
assert_eq!(persisted.access_token, REFRESHED_ACCESS_TOKEN);
assert_eq!(persisted.refresh_token, "new-refresh-token");
Ok(())
}
/// Repeated lookups with the same redirect URI return the same client id. A lookup
/// with a different redirect URI returns a client carrying that new redirect URI.
#[test(tokio::test)]
async fn test_get_or_register_client() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths, storage);
    let host = get_host();

    let initial = auth
        .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
        .await?;
    assert_eq!(initial.client_id, "test-dcr-client-id");
    assert_eq!(initial.redirect_uri, REDIRECT_URI);

    let repeated = auth
        .get_or_register_client(&OAuthTestHttpClient::default(), &host, REDIRECT_URI)
        .await?;
    assert_eq!(repeated.client_id, "test-dcr-client-id");

    // A different redirect URI for the same host.
    let other_redirect = "quilt://auth/callback?host=other.quilt.dev";
    let redirected = auth
        .get_or_register_client(&OAuthTestHttpClient::default(), &host, other_redirect)
        .await?;
    assert_eq!(redirected.client_id, "test-dcr-client-id");
    assert_eq!(redirected.redirect_uri, other_redirect);
    Ok(())
}
/// The `Debug` output of `RemoteTokens` must show `[REDACTED]` and must leak neither
/// the access token nor the refresh token.
#[test]
fn remote_tokens_debug_redacts_secrets() {
    let rendered = format!(
        "{:?}",
        RemoteTokens {
            access_token: "secret-access".to_string(),
            refresh_token: "secret-refresh".to_string(),
            expires_at: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
        }
    );
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-access", "secret-refresh"] {
        assert!(!rendered.contains(secret));
    }
}
/// The `Debug` output of `OAuthTokenResponse` must show `[REDACTED]` and must leak
/// neither token when both are present.
#[test]
fn oauth_token_response_debug_redacts_secrets() {
    let with_refresh = OAuthTokenResponse {
        access_token: "secret-access".to_string(),
        refresh_token: Some("secret-refresh".to_string()),
        expires_in: 3600,
    };
    let rendered = format!("{with_refresh:?}");
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-access", "secret-refresh"] {
        assert!(!rendered.contains(secret));
    }
}
/// With no refresh token, `Debug` prints a plain `refresh_token: None`, and the access
/// token still does not appear in the output.
#[test]
fn oauth_token_response_debug_none_refresh_token() {
    let without_refresh = OAuthTokenResponse {
        access_token: "secret-access".to_string(),
        refresh_token: None,
        expires_in: 3600,
    };
    let rendered = format!("{without_refresh:?}");
    assert!(rendered.contains("refresh_token: None"));
    assert!(!rendered.contains("secret-access"));
}
/// The `Debug` output of `RemoteCredentials` must show `[REDACTED]` and must hide the
/// access key id, the secret access key and the session token.
#[test]
fn remote_credentials_debug_redacts_secrets() {
    let rendered = format!(
        "{:?}",
        RemoteCredentials {
            access_key_id: "secret-key-id".to_string(),
            expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
            secret_access_key: "secret-access-key".to_string(),
            session_token: "secret-session-token".to_string(),
        }
    );
    assert!(rendered.contains("[REDACTED]"));
    for secret in ["secret-key-id", "secret-access-key", "secret-session-token"] {
        assert!(!rendered.contains(secret));
    }
}
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
/// Bind an ephemeral `127.0.0.1` listener that serves exactly one connection with the
/// canned `response` bytes, and return its address.
///
/// The incoming request is drained up to the end of its header block (`\r\n\r\n`) before
/// the response is written. A single `read` may return only part of the request. Closing
/// a socket that still has unread data can send a TCP RST, and the client would then see
/// a connection error instead of the canned response.
async fn spawn_one_shot(response: Vec<u8>) -> std::net::SocketAddr {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        if let Ok((mut stream, _)) = listener.accept().await {
            let mut request = Vec::new();
            let mut buf = [0u8; 4096];
            // Keep reading until the header terminator arrives, the peer closes (0 bytes),
            // or the read fails. Errors are ignored: the reply is best-effort anyway.
            loop {
                match stream.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => {
                        request.extend_from_slice(&buf[..n]);
                        if request.windows(4).any(|w| w == b"\r\n\r\n") {
                            break;
                        }
                    }
                }
            }
            let _ = stream.write_all(&response).await;
            let _ = stream.shutdown().await;
        }
    });
    addr
}
/// Produce a genuine `reqwest` status error, converted into the crate `Error`. It sends a
/// GET to a one-shot local server that replies with the given HTTP `status` and an empty body.
async fn reqwest_error_with_status(status: u16) -> Error {
    let canned = format!("HTTP/1.1 {status} X\r\nContent-Length: 0\r\nConnection: close\r\n\r\n");
    let addr = spawn_one_shot(canned.into_bytes()).await;
    let response = reqwest::Client::new()
        .get(format!("http://{addr}/"))
        .send()
        .await
        .unwrap();
    let status_err = response.error_for_status().unwrap_err();
    status_err.into()
}
/// Mock `HttpClient` whose first N credential and token requests fail, with a count of
/// every request to each endpoint.
struct RetryMockClient {
    /// How many leading `get_credentials` requests should fail.
    cred_fail_first_n: usize,
    /// Number of `get_credentials` requests received so far.
    cred_calls: AtomicUsize,
    /// How many leading token-endpoint requests should fail.
    token_fail_first_n: usize,
    /// Number of token-endpoint requests received so far.
    token_calls: AtomicUsize,
}
impl RetryMockClient {
    /// Create a mock that fails the first `cred_fail` credential requests and the first
    /// `token_fail` token requests. Both call counters start at zero.
    fn new(cred_fail: usize, token_fail: usize) -> Self {
        let (cred_calls, token_calls) = (AtomicUsize::default(), AtomicUsize::default());
        Self {
            token_fail_first_n: token_fail,
            cred_fail_first_n: cred_fail,
            token_calls,
            cred_calls,
        }
    }
}
/// Scripted `HttpClient` for the retry tests.
///
/// `get` serves the stack `config.json` and the registry credentials endpoint. `post`
/// serves the Connect token endpoint. The first `cred_fail_first_n` credential requests
/// and the first `token_fail_first_n` token requests fail with a real HTTP 401 `reqwest`
/// error, and every request is counted. Any other URL panics.
#[async_trait]
impl HttpClient for RetryMockClient {
    async fn get<T: serde::de::DeserializeOwned>(
        &self,
        url: &str,
        _auth_token: Option<&str>,
    ) -> Res<T> {
        let registry = get_registry();
        if url == format!("https://{}/config.json", get_host()) {
            let config = QuiltStackConfig {
                registry_url: format!("https://{registry}").parse()?,
            };
            return Ok(serde_json::from_value(serde_json::to_value(config)?)?);
        }
        if url == format!("https://{registry}/api/auth/get_credentials") {
            // `fetch_add` returns the previous count, i.e. the zero-based index of this call.
            let n = self.cred_calls.fetch_add(1, Ordering::SeqCst);
            if n < self.cred_fail_first_n {
                return Err(reqwest_error_with_status(401).await);
            }
            let creds = RemoteCredentials {
                access_key_id: "oauth-access-key".to_string(),
                secret_access_key: "oauth-secret-key".to_string(),
                session_token: "oauth-session-token".to_string(),
                expiration: chrono::DateTime::from_timestamp(TIMESTAMP, 0).unwrap(),
            };
            return Ok(serde_json::from_value(serde_json::to_value(creds)?)?);
        }
        panic!("Unexpected GET URL: {url}")
    }

    async fn head(&self, _url: &str) -> Res<HeaderMap> {
        unimplemented!()
    }

    async fn post<T: serde::de::DeserializeOwned>(
        &self,
        url: &str,
        form_data: &HashMap<String, String>,
    ) -> Res<T> {
        assert_eq!(url, connect_token_url(&get_host()));
        // Validate the grant on every attempt, including the ones about to be failed.
        // Otherwise a regression that sends the wrong grant type on the first try would
        // go unnoticed.
        assert_eq!(
            form_data.get("grant_type").map(String::as_str),
            Some("refresh_token")
        );
        let n = self.token_calls.fetch_add(1, Ordering::SeqCst);
        if n < self.token_fail_first_n {
            return Err(reqwest_error_with_status(401).await);
        }
        let tokens = OAuthTokenResponse {
            access_token: REFRESHED_ACCESS_TOKEN.to_string(),
            refresh_token: Some("new-refresh-token".to_string()),
            expires_in: 3600,
        };
        Ok(serde_json::from_value(serde_json::to_value(&tokens)?)?)
    }

    async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
        &self,
        _url: &str,
        _body: &B,
    ) -> Res<T> {
        unimplemented!()
    }
}
/// Store, for `host`, a token pair that stays valid for another hour and the registered
/// OAuth client. Panics if either write fails.
async fn seed_fresh_tokens(storage: &Arc<MockStorage>, paths: &DomainPaths, host: &Host) {
    let io = AuthIo::new(Arc::clone(storage), paths.auth_host(host));
    let tokens = Tokens {
        access_token: ACCESS_TOKEN.to_string(),
        refresh_token: REFRESH_TOKEN.to_string(),
        expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
    };
    io.write_tokens(&tokens).await.unwrap();
    let client = OAuthClient {
        client_id: CLIENT_ID.to_string(),
        redirect_uri: REDIRECT_URI.to_string(),
    };
    io.write_client(&client).await.unwrap();
}
/// The stored tokens look fresh, yet the credentials endpoint returns one 401. This must
/// trigger a single forced token refresh followed by a successful credentials retry.
#[test(tokio::test)]
async fn test_credentials_transient_401_recovers_via_force_token_refresh() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths.clone(), storage.clone());
    let host = get_host();
    seed_fresh_tokens(&storage, &paths, &host).await;

    let (cred_failures, token_failures) = (1, 0);
    let mock = RetryMockClient::new(cred_failures, token_failures);
    let creds = auth.get_credentials_or_refresh(&mock, &host).await?;
    assert_eq!(creds.access_key, "oauth-access-key");

    let cred_calls = mock.cred_calls.load(Ordering::SeqCst);
    assert_eq!(
        cred_calls, 2,
        "credentials endpoint should be called twice: initial + retry"
    );
    let token_calls = mock.token_calls.load(Ordering::SeqCst);
    assert_eq!(
        token_calls, 1,
        "token endpoint should be called once to force-refresh"
    );
    Ok(())
}
/// A credentials endpoint that always answers 401 must surface `LoginError::Required`
/// after exactly one retry, not loop.
#[test(tokio::test)]
async fn test_credentials_persistent_401_maps_to_login_required() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths.clone(), storage.clone());
    let host = get_host();
    seed_fresh_tokens(&storage, &paths, &host).await;

    let (cred_failures, token_failures) = (usize::MAX, 0);
    let mock = RetryMockClient::new(cred_failures, token_failures);
    let outcome = auth.get_credentials_or_refresh(&mock, &host).await;
    let login_required = matches!(outcome, Err(Error::Login(LoginError::Required(_))));
    assert!(
        login_required,
        "expected LoginRequired after persistent 4xx, got: {outcome:?}"
    );

    let cred_calls = mock.cred_calls.load(Ordering::SeqCst);
    assert_eq!(cred_calls, 2, "retry must be bounded to one extra attempt");
    Ok(())
}
/// Refreshing an expired access token hits one 401 from the token endpoint. The refresh
/// must be retried once, and the credentials fetch must then succeed on its first call.
#[test(tokio::test)]
async fn test_token_refresh_transient_401_recovers() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths.clone(), storage.clone());
    let host = get_host();

    // Seed tokens that expired five minutes ago, so a refresh is required up front.
    let auth_io = AuthIo::new(storage.clone(), paths.auth_host(&host));
    let stale_tokens = Tokens {
        access_token: "expired-access-token".to_string(),
        refresh_token: REFRESH_TOKEN.to_string(),
        expires_at: chrono::Utc::now() - chrono::Duration::seconds(300),
    };
    auth_io.write_tokens(&stale_tokens).await?;
    let registered = OAuthClient {
        client_id: CLIENT_ID.to_string(),
        redirect_uri: REDIRECT_URI.to_string(),
    };
    auth_io.write_client(&registered).await?;

    let (cred_failures, token_failures) = (0, 1);
    let mock = RetryMockClient::new(cred_failures, token_failures);
    let creds = auth.get_credentials_or_refresh(&mock, &host).await?;
    assert_eq!(creds.access_key, "oauth-access-key");

    let token_calls = mock.token_calls.load(Ordering::SeqCst);
    assert_eq!(
        token_calls, 2,
        "token endpoint should be called twice: initial + retry"
    );
    let cred_calls = mock.cred_calls.load(Ordering::SeqCst);
    assert_eq!(
        cred_calls, 1,
        "credentials endpoint should only be called once after successful retry"
    );
    Ok(())
}
/// Two-way rendezvous between a test and a mock handler. The handler announces that it
/// has started, then waits until the test lets it continue.
#[derive(Default)]
struct Gate {
    /// Notified by the test to let the parked handler proceed.
    release: tokio::sync::Notify,
    /// Notified by the handler once it has been entered.
    entered: tokio::sync::Notify,
}
#[derive(Clone)]
struct CountingCredsClient {
cred_calls: Arc<std::sync::atomic::AtomicUsize>,
sleep_ms: u64,
gate: Option<Arc<Gate>>,
}
#[async_trait]
impl HttpClient for CountingCredsClient {
async fn get<T: serde::de::DeserializeOwned>(
&self,
url: &str,
_auth_token: Option<&str>,
) -> Res<T> {
if url.ends_with("/config.json") {
let body = serde_json::json!({
"registryUrl": format!("https://{}", get_registry()),
});
return Ok(serde_json::from_value(body)?);
}
if url.contains("/api/auth/get_credentials") {
self.cred_calls
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if let Some(gate) = &self.gate {
gate.entered.notify_one();
gate.release.notified().await;
} else if self.sleep_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(self.sleep_ms)).await;
}
let body = serde_json::json!({
"AccessKeyId": "refreshed-key",
"SecretAccessKey": "refreshed-secret",
"SessionToken": "refreshed-session",
"Expiration": (chrono::Utc::now() + chrono::Duration::hours(1))
.to_rfc3339(),
});
return Ok(serde_json::from_value(body)?);
}
panic!("Unexpected GET: {url}");
}
async fn head(&self, _: &str) -> Res<HeaderMap> {
unimplemented!()
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
_: &str,
_: &HashMap<String, String>,
) -> Res<T> {
unimplemented!("fresh tokens → no OAuth leg fires")
}
async fn post_json<T: serde::de::DeserializeOwned, B: serde::Serialize + Send + Sync>(
&self,
_: &str,
_: &B,
) -> Res<T> {
unimplemented!()
}
}
/// Store credentials that expired an hour ago next to a token pair valid for another hour.
/// Presumably this forces a credentials refresh without a token refresh; the mock clients
/// used with this helper do not implement the token endpoint.
async fn seed_expired_creds_fresh_tokens(auth_io: &AuthIo<Arc<MockStorage>>) -> Res {
    let stale_creds = Credentials {
        access_key: "stale".to_string(),
        secret_key: "stale-secret".to_string(),
        token: "stale-session".to_string(),
        expires_at: chrono::Utc::now() - chrono::Duration::hours(1),
    };
    auth_io.write_credentials(&stale_creds).await?;

    let fresh_tokens = Tokens {
        access_token: ACCESS_TOKEN.to_string(),
        refresh_token: REFRESH_TOKEN.to_string(),
        expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
    };
    auth_io.write_tokens(&fresh_tokens).await?;
    Ok(())
}
/// Ten concurrent lookups for one host with expired credentials must share a single
/// credentials refresh, and all of them must observe the same refreshed credentials.
#[test(tokio::test)]
async fn test_auth_refresh_is_single_flight_across_concurrent_callers() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths.clone(), storage.clone());
    let host = get_host();
    let auth_io = AuthIo::new(storage, paths.auth_host(&host));
    seed_expired_creds_fresh_tokens(&auth_io).await?;

    // The 50 ms stall in the handler keeps the first refresh in flight while the other
    // callers arrive.
    let client = CountingCredsClient {
        cred_calls: Arc::new(AtomicUsize::new(0)),
        sleep_ms: 50,
        gate: None,
    };
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let auth = auth.clone();
            let client = client.clone();
            let host = host.clone();
            tokio::spawn(async move { auth.get_credentials_or_refresh(&client, &host).await })
        })
        .collect();

    let mut creds_seen = Vec::with_capacity(handles.len());
    for handle in handles {
        creds_seen.push(handle.await.unwrap()?);
    }

    assert_eq!(
        client.cred_calls.load(Ordering::SeqCst),
        1,
        "single-flight: 10 concurrent callers must produce exactly one refresh",
    );
    let first = &creds_seen[0];
    for other in &creds_seen[1..] {
        assert_eq!(other.access_key, first.access_key);
        assert_eq!(other.expires_at, first.expires_at);
    }
    assert_eq!(first.access_key, "refreshed-key");
    Ok(())
}
/// Refresh locks are keyed by host. While host_a's credentials refresh is parked inside
/// its handler, a refresh for host_b must finish without waiting behind host_a.
#[test(tokio::test)]
async fn test_auth_refresh_lock_is_per_host() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths.clone(), storage.clone());
    let host_a: Host = "a.quilt.dev".parse().unwrap();
    let host_b: Host = "b.quilt.dev".parse().unwrap();
    seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_a)))
        .await?;
    seed_expired_creds_fresh_tokens(&AuthIo::new(storage.clone(), paths.auth_host(&host_b)))
        .await?;

    let gate = Arc::new(Gate::default());
    let gated_client = CountingCredsClient {
        cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        sleep_ms: 0,
        gate: Some(gate.clone()),
    };
    let fast_client = CountingCredsClient {
        cred_calls: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        sleep_ms: 0,
        gate: None,
    };

    let auth_clone = auth.clone();
    let client_a = gated_client.clone();
    let host_a_clone = host_a.clone();
    let a_task = tokio::spawn(async move {
        auth_clone
            .get_credentials_or_refresh(&client_a, &host_a_clone)
            .await
    });

    // Wait until host_a is parked inside its handler, presumably holding host_a's refresh
    // lock. The wait is bounded: if a regression fails host_a before it reaches the
    // handler, `entered` is never signalled and the test would otherwise hang forever.
    tokio::time::timeout(std::time::Duration::from_secs(5), gate.entered.notified())
        .await
        .expect("host_a never reached its credentials handler");

    tokio::time::timeout(
        std::time::Duration::from_secs(5),
        auth.get_credentials_or_refresh(&fast_client, &host_b),
    )
    .await
    .expect("host_b refresh must not wait behind host_a's lock")?;
    assert!(
        !a_task.is_finished(),
        "host_a must still be blocked in its handler while host_b completes",
    );

    gate.release.notify_one();
    a_task.await.unwrap()?;
    Ok(())
}
/// `refresh_lock_for` keeps only a weak handle per host. Once the caller drops its lock,
/// the map entry no longer upgrades, and asking for the same host again still leaves
/// exactly one entry in the map.
#[test(tokio::test)]
async fn test_refresh_lock_map_sweeps_dead_entries() -> Res {
    let storage = Arc::new(MockStorage::default());
    let paths = DomainPaths::new(storage.temp_dir.path().to_path_buf());
    let auth = Auth::new(paths, storage);
    let host: Host = "x.quilt.dev".parse().unwrap();
    // Number of entries in the lock map; a poisoned mutex is tolerated.
    let lock_map_len = || {
        auth.refresh_locks
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .len()
    };

    let held = auth.refresh_lock_for(&host);
    assert_eq!(lock_map_len(), 1);
    drop(held);

    let entry_is_dead = auth
        .refresh_locks
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .get(&host)
        .expect("entry still present before sweep")
        .upgrade()
        .is_none();
    assert!(entry_is_dead);

    let _reacquired = auth.refresh_lock_for(&host);
    assert_eq!(lock_map_len(), 1);
    Ok(())
}
}