use url::Url;
#[cfg(not(target_arch = "wasm32"))]
use stack_profile::{FileLockGuard, ProfileData, ProfileStore};
use crate::refresher::Refresher;
use crate::{AuthError, SecretToken, Token};
/// Refresher for device-session tokens.
///
/// Carries the values needed to call `Token::refresh` and to stamp identity
/// fields onto the refreshed token. On native targets it also carries an
/// optional profile store used for the refresh lock and for reading/writing
/// the saved token.
pub(crate) struct DeviceSessionRefresher {
    /// Base URL passed to `Token::refresh`.
    base_url: Url,
    /// Client id passed to `Token::refresh` and recorded on the new token.
    client_id: String,
    /// Region recorded on every refreshed token.
    region: String,
    /// Optional device instance id, forwarded to `Token::refresh` and
    /// recorded on the new token when present.
    device_instance_id: Option<String>,
    /// Profile store backing the refresh lock and on-disk token; `None`
    /// disables both.
    #[cfg(not(target_arch = "wasm32"))]
    store: Option<ProfileStore>,
}
impl DeviceSessionRefresher {
    /// Builds a refresher.
    ///
    /// Pass `None` for `store` to skip the refresh lock and on-disk
    /// persistence; `client_id` and `region` accept anything convertible
    /// into an owned `String`.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) fn new(
        store: Option<ProfileStore>,
        base_url: Url,
        client_id: impl Into<String>,
        region: impl Into<String>,
        device_instance_id: Option<String>,
    ) -> Self {
        let client_id: String = client_id.into();
        let region: String = region.into();
        Self {
            base_url,
            client_id,
            region,
            device_instance_id,
            store,
        }
    }

    /// Builds a refresher on wasm32.
    ///
    /// `_store` is ignored: wasm32 builds carry no profile store, so there is
    /// no refresh lock and no on-disk persistence.
    #[cfg(target_arch = "wasm32")]
    pub(crate) fn new(
        _store: Option<()>,
        base_url: Url,
        client_id: impl Into<String>,
        region: impl Into<String>,
        device_instance_id: Option<String>,
    ) -> Self {
        let client_id: String = client_id.into();
        let region: String = region.into();
        Self {
            base_url,
            client_id,
            region,
            device_instance_id,
        }
    }
}
impl Refresher for DeviceSessionRefresher {
    type Credential = SecretToken;

    /// Does nothing. On native targets the refreshed token is written to the
    /// store by `refresh` itself, while the refresh lock is still held.
    fn save(&self, _token: &Token) {}

    /// Moves the refresh token out of `token`, if there is one, for use as
    /// the refresh credential.
    fn try_credential(&self, token: Option<&mut Token>) -> Option<Self::Credential> {
        token.and_then(|t| t.take_refresh_token())
    }

    /// Puts `credential` back as `token`'s refresh token.
    fn restore(&self, token: &mut Token, credential: Self::Credential) {
        token.refresh_token = Some(credential);
    }

    /// Exchanges `credential` for a new token and stamps region, client id
    /// and (if configured) device instance id onto it.
    ///
    /// On native targets the whole exchange runs under the refresh lock.
    /// After the lock is taken, the saved token is re-read. If its refresh
    /// token differs from `credential`, another process has already rotated
    /// it:
    /// - if that saved token's access token has not yet expired, it is
    ///   returned without calling upstream;
    /// - otherwise the saved (rotated) refresh token is used for the upstream
    ///   call instead of the stale `credential`, which must not be replayed.
    ///
    /// # Errors
    ///
    /// Returns `AuthError` if the lock cannot be acquired or the upstream
    /// refresh fails.
    async fn refresh(&self, credential: &Self::Credential) -> Result<Token, AuthError> {
        // Bound (not `_`) so the guard lives until the end of this function,
        // covering both the upstream call and the save below.
        #[cfg(not(target_arch = "wasm32"))]
        let _lock = self.acquire_refresh_lock().await?;

        // Set when another process rotated the token but its access token has
        // since expired: refresh with that rotated refresh token instead.
        #[cfg(not(target_arch = "wasm32"))]
        let mut rotated: Option<SecretToken> = None;
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(mut disk_token) = self.load_freshly_refreshed_token(credential) {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            if disk_token.expires_at > now {
                tracing::debug!(
                    "refresh skipped: another process rotated the token while we waited on the lock"
                );
                return Ok(disk_token);
            }
            tracing::debug!(
                "disk token was rotated by another process but its access token has expired; \
                 refreshing with the rotated refresh token"
            );
            rotated = disk_token.take_refresh_token();
        }
        #[cfg(not(target_arch = "wasm32"))]
        let credential = rotated.as_ref().unwrap_or(credential);

        let mut token = Token::refresh(
            credential,
            &self.base_url,
            &self.client_id,
            self.device_instance_id.as_deref(),
        )
        .await?;
        token.set_region(&self.region);
        token.set_client_id(&self.client_id);
        if let Some(ref id) = self.device_instance_id {
            token.set_device_instance_id(id);
        }

        // Persistence failures are logged inside and do not fail the refresh.
        #[cfg(not(target_arch = "wasm32"))]
        self.persist_refreshed(&token);
        Ok(token)
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl DeviceSessionRefresher {
async fn acquire_refresh_lock(&self) -> Result<Option<FileLockGuard>, AuthError> {
let Some(store) = self.store.clone() else {
return Ok(None);
};
let lock = tokio::task::spawn_blocking(move || store.lock_exclusive(Token::FILENAME))
.await
.map_err(|e| AuthError::Server(format!("refresh lock task join failed: {e}")))?
.map_err(|e| AuthError::Server(format!("failed to acquire refresh lock: {e}")))?;
Ok(Some(lock))
}
fn load_freshly_refreshed_token(&self, credential: &SecretToken) -> Option<Token> {
let store = self.store.as_ref()?;
let disk_token: Token = store.load_profile().ok()?;
let disk_refresh = disk_token.refresh_token()?;
if disk_refresh.as_str() != credential.as_str() {
Some(disk_token)
} else {
None
}
}
fn persist_refreshed(&self, token: &Token) {
let Some(store) = &self.store else { return };
match store.save_profile(token) {
Ok(()) => tracing::debug!("refreshed token saved to disk"),
Err(err) => tracing::error!(
%err,
"failed to persist refreshed token to disk — a subsequent process \
will replay the prior refresh token and Clerk will revoke the chain"
),
}
}
}
// Native-only tests for `DeviceSessionRefresher::refresh`. Each test seeds a
// temp-dir `ProfileStore` workspace with a saved token and points the
// refresher at a local HTTP server serving `/oauth/token`.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use super::*;
use mocktail::prelude::*;
use std::time::{SystemTime, UNIX_EPOCH};
// Workspace id used to initialise the temp profile store in every test.
const WORKSPACE_ID: &str = "ZVATKW3VHMFG27DY";
/// Current Unix time in whole seconds.
fn now() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}
/// Builds a bearer `Token` with the given access/refresh secrets that
/// expires one hour from now.
fn token_on_disk(access: &str, refresh: &str) -> Token {
Token {
access_token: SecretToken::new(access),
refresh_token: Some(SecretToken::new(refresh)),
token_type: "Bearer".to_string(),
expires_at: now() + 3600,
region: Some("ap-southeast-2.aws".to_string()),
client_id: Some("cli".to_string()),
device_instance_id: None,
}
}
/// Starts a mocktail HTTP server serving `mocks`.
async fn start_server(mocks: MockSet) -> MockServer {
let server = MockServer::new_http("oauth-refresher-lock-test").with_mocks(mocks);
server.start().await.unwrap();
server
}
/// Creates workspace `WORKSPACE_ID` under `dir`, saves `on_disk` as its
/// profile, and returns a refresher bound to that workspace store.
fn refresher_with_disk_token(
dir: &tempfile::TempDir,
base_url: Url,
on_disk: Token,
) -> DeviceSessionRefresher {
let store = ProfileStore::new(dir.path());
store.init_workspace(WORKSPACE_ID).unwrap();
let ws_store = store.current_workspace_store().unwrap();
ws_store.save_profile(&on_disk).unwrap();
DeviceSessionRefresher::new(Some(ws_store), base_url, "cli", "ap-southeast-2.aws", None)
}
/// The caller's credential differs from the refresh token on disk, so
/// `refresh` should return the disk token. The upstream mock answers 400
/// `invalid_grant`; an upstream call would presumably make `refresh` return
/// an error, and the `unwrap` below would fail.
#[tokio::test]
async fn refresh_returns_disk_token_when_sibling_already_rotated() {
let mut mocks = MockSet::new();
mocks.mock(|when, then| {
when.post().path("/oauth/token");
then.bad_request().json(serde_json::json!({
"error": "invalid_grant",
"error_description": "must not be called"
}));
});
let server = start_server(mocks).await;
let dir = tempfile::tempdir().unwrap();
let disk = token_on_disk("rotated-access", "rotated-refresh");
let refresher = refresher_with_disk_token(&dir, server.url(""), disk);
let stale_credential = SecretToken::new("stale-refresh");
let result = refresher.refresh(&stale_credential).await.unwrap();
assert_eq!(
result.access_token().as_str(),
"rotated-access",
"refresh should return the disk-cached rotated token"
);
assert_eq!(
result.refresh_token().unwrap().as_str(),
"rotated-refresh",
"rotated refresh token from disk should flow through"
);
}
/// The caller's credential matches the disk refresh token, so `refresh`
/// must call upstream and write the new token back to the workspace store.
#[tokio::test]
async fn refresh_calls_upstream_and_persists_when_disk_matches() {
let mut mocks = MockSet::new();
mocks.mock(|when, then| {
when.post().path("/oauth/token");
then.json(serde_json::json!({
"access_token": "new-access",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh"
}));
});
let server = start_server(mocks).await;
let dir = tempfile::tempdir().unwrap();
let disk = token_on_disk("old-access", "matching-refresh");
let refresher = refresher_with_disk_token(&dir, server.url(""), disk);
let credential = SecretToken::new("matching-refresh");
let result = refresher.refresh(&credential).await.unwrap();
assert_eq!(result.access_token().as_str(), "new-access");
assert_eq!(result.refresh_token().unwrap().as_str(), "new-refresh");
// Re-open the store from its path to read what was actually persisted.
let on_disk: Token = ProfileStore::new(dir.path())
.workspace_store(WORKSPACE_ID)
.unwrap()
.load_profile()
.unwrap();
assert_eq!(on_disk.access_token().as_str(), "new-access");
assert_eq!(on_disk.refresh_token().unwrap().as_str(), "new-refresh");
}
/// Two refreshers sharing one workspace store refresh the same credential
/// concurrently. The axum handler counts calls and sleeps 100 ms so the two
/// tasks overlap. Exactly one upstream call is expected; the other caller
/// should find the rotated token on disk once it acquires the lock.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_refreshes_only_call_upstream_once() {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
// Number of requests that reached the upstream token endpoint.
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let app = axum::Router::new().route(
"/oauth/token",
axum::routing::post(move || {
let counter = Arc::clone(&counter_clone);
async move {
counter.fetch_add(1, Ordering::SeqCst);
// Widen the window in which both callers would hit upstream
// if the lock did not serialise them.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
axum::Json(serde_json::json!({
"access_token": "rotated-access",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "rotated-refresh"
}))
}
}),
);
// Port 0: let the OS pick a free port, then read it back.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
let base_url = Url::parse(&format!("http://{addr}")).unwrap();
let dir = tempfile::tempdir().unwrap();
let store = ProfileStore::new(dir.path());
store.init_workspace(WORKSPACE_ID).unwrap();
let ws_store = store.current_workspace_store().unwrap();
ws_store
.save_profile(&token_on_disk("old-access", "shared-refresh"))
.unwrap();
// Two independent refreshers over the same workspace store.
let r1 = Arc::new(DeviceSessionRefresher::new(
Some(ws_store.clone()),
base_url.clone(),
"cli",
"ap-southeast-2.aws",
None,
));
let r2 = Arc::new(DeviceSessionRefresher::new(
Some(ws_store),
base_url,
"cli",
"ap-southeast-2.aws",
None,
));
// Both callers hold the same (soon to be stale) refresh token.
let cred1 = SecretToken::new("shared-refresh");
let cred2 = SecretToken::new("shared-refresh");
let r1c = Arc::clone(&r1);
let h1 = tokio::spawn(async move { r1c.refresh(&cred1).await });
let r2c = Arc::clone(&r2);
let h2 = tokio::spawn(async move { r2c.refresh(&cred2).await });
let (a, b) = tokio::join!(h1, h2);
let a = a.unwrap().unwrap();
let b = b.unwrap().unwrap();
assert_eq!(a.access_token().as_str(), "rotated-access");
assert_eq!(b.access_token().as_str(), "rotated-access");
assert_eq!(
counter.load(Ordering::SeqCst),
1,
"exactly one upstream refresh — the second caller must take the lock-and-reload fast path"
);
}
}