use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
const SCHEMA_VERSION: u32 = 1;
pub type ChainId = String;
#[derive(Debug, thiserror::Error)]
pub enum RefreshError {
#[error("unknown refresh token")]
Unknown,
#[error("refresh token expired")]
Expired,
#[error("refresh token already consumed (replay attack signal); chain {0} revoked")]
Replayed(ChainId),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
struct AccessRecord {
token_hash: String,
expires_at: u64,
chain_id: ChainId,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
struct RefreshRecord {
token_hash: String,
expires_at: u64,
chain_id: ChainId,
consumed_at: Option<u64>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct Snapshot {
version: u32,
client_id_hash: String,
access: Vec<AccessRecord>,
refresh: Vec<RefreshRecord>,
revoked_chains: Vec<ChainId>,
}
#[derive(Default)]
struct Index {
access_by_hash: HashMap<String, AccessRecord>,
refresh_by_hash: HashMap<String, RefreshRecord>,
revoked: std::collections::HashSet<ChainId>,
}
pub struct TokenStore {
inner: Arc<Inner>,
}
struct Inner {
state: RwLock<Index>,
path: PathBuf,
access_ttl: Duration,
refresh_ttl: Duration,
client_id_hash: String,
}
fn unix_now() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
fn sha256_hex(input: &str) -> String {
let digest = Sha256::digest(input.as_bytes());
let mut hex = String::with_capacity(64);
for byte in digest {
use std::fmt::Write;
let _ = write!(&mut hex, "{:02x}", byte);
}
hex
}
#[derive(Debug, Clone)]
pub struct MintedPair {
pub access_token: String,
pub refresh_token: String,
pub access_ttl: Duration,
pub refresh_ttl: Duration,
pub chain_id: ChainId,
}
impl TokenStore {
pub fn load(
path: PathBuf,
client_id: &str,
access_ttl: Duration,
refresh_ttl: Duration,
) -> anyhow::Result<Self> {
let client_id_hash = sha256_hex(client_id);
let mut snapshot = match std::fs::read(&path) {
Ok(bytes) => match serde_json::from_slice::<Snapshot>(&bytes) {
Ok(snap) if snap.version == SCHEMA_VERSION => snap,
Ok(_) | Err(_) => {
let backup = path.with_extension(format!("json.broken-{}", unix_now()));
if let Err(e) = std::fs::rename(&path, &backup) {
tracing::warn!(
path = %path.display(),
backup = %backup.display(),
error = %e,
"tokens.json corrupt or wrong schema version; could not rename aside (continuing with empty store)"
);
} else {
tracing::warn!(
path = %path.display(),
backup = %backup.display(),
"tokens.json corrupt or wrong schema version; renamed aside, starting fresh"
);
}
Snapshot::default()
}
},
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::info!(path = %path.display(), "no tokens.json found; starting fresh");
Snapshot::default()
}
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"could not read tokens.json (transient I/O error?); starting fresh"
);
Snapshot::default()
}
};
if !snapshot.client_id_hash.is_empty() && snapshot.client_id_hash != client_id_hash {
tracing::warn!(
"tokens.json client_id_hash mismatch; wiping (oauth.toml was likely regenerated)"
);
snapshot = Snapshot::default();
}
snapshot.client_id_hash = client_id_hash.clone();
snapshot.version = SCHEMA_VERSION;
let now = unix_now();
snapshot.access.retain(|r| r.expires_at > now);
snapshot.refresh.retain(|r| r.expires_at > now);
let mut index = Index::default();
for r in &snapshot.access {
index.access_by_hash.insert(r.token_hash.clone(), r.clone());
}
for r in &snapshot.refresh {
index
.refresh_by_hash
.insert(r.token_hash.clone(), r.clone());
}
for c in &snapshot.revoked_chains {
index.revoked.insert(c.clone());
}
Ok(Self {
inner: Arc::new(Inner {
state: RwLock::new(index),
path,
access_ttl,
refresh_ttl,
client_id_hash,
}),
})
}
pub async fn mint_pair(&self, chain_id: Option<ChainId>) -> anyhow::Result<MintedPair> {
let chain_id = chain_id.unwrap_or_else(opaque_id);
let access_token = opaque_id();
let refresh_token = opaque_id();
let now = unix_now();
let access_record = AccessRecord {
token_hash: sha256_hex(&access_token),
expires_at: now + self.inner.access_ttl.as_secs(),
chain_id: chain_id.clone(),
};
let refresh_record = RefreshRecord {
token_hash: sha256_hex(&refresh_token),
expires_at: now + self.inner.refresh_ttl.as_secs(),
chain_id: chain_id.clone(),
consumed_at: None,
};
{
let mut idx = self.inner.state.write().await;
idx.access_by_hash
.insert(access_record.token_hash.clone(), access_record);
idx.refresh_by_hash
.insert(refresh_record.token_hash.clone(), refresh_record);
self.persist_locked(&idx);
}
Ok(MintedPair {
access_token,
refresh_token,
access_ttl: self.inner.access_ttl,
refresh_ttl: self.inner.refresh_ttl,
chain_id,
})
}
pub async fn validate_access(&self, raw: &str) -> bool {
let hash = sha256_hex(raw);
let idx = self.inner.state.read().await;
let Some(record) = idx.access_by_hash.get(&hash) else {
return false;
};
if record.expires_at <= unix_now() {
return false;
}
if idx.revoked.contains(&record.chain_id) {
return false;
}
true
}
pub async fn consume_refresh(&self, raw: &str) -> Result<ChainId, RefreshError> {
let hash = sha256_hex(raw);
let mut idx = self.inner.state.write().await;
let now = unix_now();
let (chain_id, expires_at, consumed_at) = match idx.refresh_by_hash.get(&hash) {
Some(r) => (r.chain_id.clone(), r.expires_at, r.consumed_at),
None => return Err(RefreshError::Unknown),
};
if idx.revoked.contains(&chain_id) {
return Err(RefreshError::Unknown);
}
if expires_at <= now {
return Err(RefreshError::Expired);
}
if consumed_at.is_some() {
return Err(RefreshError::Replayed(chain_id));
}
idx.refresh_by_hash.get_mut(&hash).unwrap().consumed_at = Some(now);
self.persist_locked(&idx);
Ok(chain_id)
}
pub async fn revoke_chain(&self, chain_id: ChainId) {
let mut idx = self.inner.state.write().await;
idx.revoked.insert(chain_id);
self.persist_locked(&idx);
}
fn persist_locked(&self, idx: &Index) {
let mut access: Vec<_> = idx.access_by_hash.values().cloned().collect();
let mut refresh: Vec<_> = idx.refresh_by_hash.values().cloned().collect();
access.sort_by(|a, b| a.token_hash.cmp(&b.token_hash));
refresh.sort_by(|a, b| a.token_hash.cmp(&b.token_hash));
let mut revoked: Vec<_> = idx.revoked.iter().cloned().collect();
revoked.sort();
let snap = Snapshot {
version: SCHEMA_VERSION,
client_id_hash: self.inner.client_id_hash.clone(),
access,
refresh,
revoked_chains: revoked,
};
let bytes = match serde_json::to_vec_pretty(&snap) {
Ok(b) => b,
Err(e) => {
tracing::error!(error = %e, "could not serialize token snapshot; in-memory state preserved");
return;
}
};
if let Err(e) = atomic_write_0600(&self.inner.path, &bytes) {
tracing::error!(
path = %self.inner.path.display(),
error = %e,
"could not persist tokens.json; in-memory state preserved"
);
}
}
}
fn opaque_id() -> String {
format!("{:032x}", rand::random::<u128>())
}
fn atomic_write_0600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let tmp = path.with_extension(format!("json.tmp.{:08x}", rand::random::<u32>()));
std::fs::write(&tmp, bytes)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600))?;
}
std::fs::rename(&tmp, path)?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn fresh(dir: &TempDir) -> TokenStore {
TokenStore::load(
dir.path().join("tokens.json"),
"client-id-1",
Duration::from_secs(60),
Duration::from_secs(600),
)
.unwrap()
}
#[test]
fn load_treats_missing_file_as_empty() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let _ = store;
}
#[test]
fn load_renames_corrupt_file_aside_and_starts_fresh() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
std::fs::write(&path, b"this is not valid json").unwrap();
let _ = TokenStore::load(
path.clone(),
"client-id-1",
Duration::from_secs(60),
Duration::from_secs(600),
)
.unwrap();
assert!(
!path.exists(),
"original corrupt file should have been moved aside"
);
let entries: Vec<_> = std::fs::read_dir(dir.path())
.unwrap()
.map(|e| e.unwrap().file_name().to_string_lossy().to_string())
.collect();
assert!(
entries
.iter()
.any(|name| name.starts_with("tokens.json.broken-")),
"expected backup file, got {entries:?}"
);
}
#[test]
fn load_wipes_store_on_client_id_hash_mismatch() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
let snap = Snapshot {
version: SCHEMA_VERSION,
client_id_hash: sha256_hex("OLD-CLIENT-ID"),
access: vec![AccessRecord {
token_hash: "deadbeef".into(),
expires_at: unix_now() + 9999,
chain_id: "chain-x".into(),
}],
refresh: vec![],
revoked_chains: vec![],
};
std::fs::write(&path, serde_json::to_vec(&snap).unwrap()).unwrap();
let store = TokenStore::load(
path,
"NEW-CLIENT-ID",
Duration::from_secs(60),
Duration::from_secs(600),
)
.unwrap();
let idx = store.inner.state.try_read().unwrap();
assert!(
idx.access_by_hash.is_empty(),
"tokens issued under old client_id must be wiped"
);
}
#[test]
fn load_drops_expired_access_and_refresh_records() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
let now = unix_now();
let snap = Snapshot {
version: SCHEMA_VERSION,
client_id_hash: sha256_hex("client-id-1"),
access: vec![
AccessRecord {
token_hash: "fresh-access".into(),
expires_at: now + 600,
chain_id: "c1".into(),
},
AccessRecord {
token_hash: "stale-access".into(),
expires_at: now - 1,
chain_id: "c1".into(),
},
],
refresh: vec![
RefreshRecord {
token_hash: "fresh-refresh".into(),
expires_at: now + 600,
chain_id: "c1".into(),
consumed_at: None,
},
RefreshRecord {
token_hash: "stale-refresh".into(),
expires_at: now - 1,
chain_id: "c1".into(),
consumed_at: None,
},
],
revoked_chains: vec![],
};
std::fs::write(&path, serde_json::to_vec(&snap).unwrap()).unwrap();
let store = fresh_with_path(&path);
let idx = store.inner.state.try_read().unwrap();
assert!(idx.access_by_hash.contains_key("fresh-access"));
assert!(!idx.access_by_hash.contains_key("stale-access"));
assert!(idx.refresh_by_hash.contains_key("fresh-refresh"));
assert!(!idx.refresh_by_hash.contains_key("stale-refresh"));
}
fn fresh_with_path(path: &Path) -> TokenStore {
TokenStore::load(
path.to_path_buf(),
"client-id-1",
Duration::from_secs(60),
Duration::from_secs(600),
)
.unwrap()
}
#[tokio::test]
async fn mint_pair_returns_two_distinct_tokens() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
assert_ne!(pair.access_token, pair.refresh_token);
assert!(pair.access_token.len() >= 32);
assert!(pair.refresh_token.len() >= 32);
assert!(!pair.chain_id.is_empty());
}
#[tokio::test]
async fn mint_pair_persists_to_disk_with_mode_0600() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
let store = fresh_with_path(&path);
let _ = store.mint_pair(None).await.unwrap();
assert!(path.exists(), "mint_pair must persist to disk");
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777;
assert_eq!(mode, 0o600);
}
}
#[tokio::test]
async fn tokens_at_rest_are_hashed_not_plaintext() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
let store = fresh_with_path(&path);
let pair = store.mint_pair(None).await.unwrap();
let bytes = std::fs::read(&path).unwrap();
let body = String::from_utf8(bytes).unwrap();
assert!(
!body.contains(&pair.access_token),
"raw access token must not appear on disk"
);
assert!(
!body.contains(&pair.refresh_token),
"raw refresh token must not appear on disk"
);
assert!(
body.contains(&sha256_hex(&pair.access_token)),
"expected access-token hash in file"
);
}
#[tokio::test]
async fn mint_pair_with_existing_chain_id_keeps_chain() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let first = store.mint_pair(None).await.unwrap();
let second = store.mint_pair(Some(first.chain_id.clone())).await.unwrap();
assert_eq!(first.chain_id, second.chain_id);
assert_ne!(first.access_token, second.access_token);
}
#[tokio::test]
async fn validate_access_returns_true_for_freshly_minted_token() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
assert!(store.validate_access(&pair.access_token).await);
}
#[tokio::test]
async fn validate_access_returns_false_for_unknown_token() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
assert!(!store.validate_access("not-a-real-token").await);
}
#[tokio::test]
async fn validate_access_returns_false_after_expiry() {
let dir = TempDir::new().unwrap();
let store = TokenStore::load(
dir.path().join("tokens.json"),
"client-id-1",
Duration::from_secs(0), Duration::from_secs(600),
)
.unwrap();
let pair = store.mint_pair(None).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
assert!(!store.validate_access(&pair.access_token).await);
}
#[tokio::test]
async fn consume_refresh_returns_chain_id_on_first_use() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
let chain = store.consume_refresh(&pair.refresh_token).await.unwrap();
assert_eq!(chain, pair.chain_id);
}
#[tokio::test]
async fn consume_refresh_replay_returns_replayed_with_chain_id() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
let _first = store.consume_refresh(&pair.refresh_token).await.unwrap();
let err = store
.consume_refresh(&pair.refresh_token)
.await
.unwrap_err();
match err {
RefreshError::Replayed(chain) => assert_eq!(chain, pair.chain_id),
other => panic!("expected Replayed, got {other:?}"),
}
}
#[tokio::test]
async fn consume_refresh_unknown_returns_unknown() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let err = store.consume_refresh("never-issued").await.unwrap_err();
assert!(matches!(err, RefreshError::Unknown), "got {err:?}");
}
#[tokio::test]
async fn consume_refresh_expired_returns_expired() {
let dir = TempDir::new().unwrap();
let store = TokenStore::load(
dir.path().join("tokens.json"),
"client-id-1",
Duration::from_secs(60),
Duration::from_secs(0), )
.unwrap();
let pair = store.mint_pair(None).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
let err = store
.consume_refresh(&pair.refresh_token)
.await
.unwrap_err();
assert!(matches!(err, RefreshError::Expired), "got {err:?}");
}
#[tokio::test]
async fn revoke_chain_invalidates_all_access_tokens_in_chain() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
assert!(store.validate_access(&pair.access_token).await);
store.revoke_chain(pair.chain_id.clone()).await;
assert!(!store.validate_access(&pair.access_token).await);
}
#[tokio::test]
async fn revoke_chain_invalidates_subsequent_refresh_consumption() {
let dir = TempDir::new().unwrap();
let store = fresh(&dir);
let pair = store.mint_pair(None).await.unwrap();
store.revoke_chain(pair.chain_id.clone()).await;
let err = store
.consume_refresh(&pair.refresh_token)
.await
.unwrap_err();
assert!(
matches!(err, RefreshError::Unknown),
"revoked-chain refresh should look Unknown to callers; got {err:?}"
);
}
#[tokio::test]
async fn revoke_chain_persists_to_disk() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("tokens.json");
let store = fresh_with_path(&path);
let pair = store.mint_pair(None).await.unwrap();
store.revoke_chain(pair.chain_id.clone()).await;
drop(store);
let store2 = fresh_with_path(&path);
assert!(!store2.validate_access(&pair.access_token).await);
}
#[tokio::test]
async fn tokens_survive_oauth_state_recreation() {
use crate::oauth::{OAuthConfig, OAuthState};
let dir = tempfile::TempDir::new().unwrap();
let tokens_path = dir.path().join("tokens.json");
let config = OAuthConfig {
client_id: "test-id".into(),
client_secret: "test-secret".into(),
issuer: "https://example.test".into(),
access_token_ttl_secs: None,
refresh_token_ttl_secs: None,
};
let access_token = {
let state_a =
OAuthState::with_tokens_path(config.clone(), tokens_path.clone()).unwrap();
let pair = state_a.token_store().mint_pair(None).await.unwrap();
assert!(state_a.validate_token(&pair.access_token).await);
pair.access_token
};
let state_b = OAuthState::with_tokens_path(config, tokens_path).unwrap();
assert!(
state_b.validate_token(&access_token).await,
"access token issued before restart must still validate after restart"
);
}
}