use std::future::Future;
use std::sync::Arc;
use tokio::sync::Mutex;
use zeroize::Zeroizing;
use crate::Token;
/// Pluggable asynchronous storage for a [`Token`].
///
/// Native build: the store must be `Send + Sync` and the futures returned by
/// both methods must be `Send`, so stores and their in-flight operations can
/// cross thread boundaries.
#[cfg(not(target_arch = "wasm32"))]
pub trait TokenStore: Send + Sync {
    /// Fetches the stored token, yielding `None` when none is available.
    fn load(&self) -> impl Future<Output = Option<Token>> + Send;

    /// Stores `token`.
    ///
    /// The future resolves to `()`, so implementations must deal with their
    /// own failures (for example by logging) instead of reporting them.
    fn save(&self, token: &Token) -> impl Future<Output = ()> + Send;
}
/// Pluggable asynchronous storage for a [`Token`] (`wasm32` build).
///
/// Same contract as the native trait, but neither the store nor the futures
/// it returns are required to be `Send`/`Sync`.
#[cfg(target_arch = "wasm32")]
pub trait TokenStore {
    /// Fetches the stored token, yielding `None` when none is available.
    fn load(&self) -> impl Future<Output = Option<Token>>;

    /// Stores `token`.
    ///
    /// The future resolves to `()`, so implementations must deal with their
    /// own failures (for example by logging) instead of reporting them.
    fn save(&self, token: &Token) -> impl Future<Output = ()>;
}
/// Lets a store shared through `Arc` be used wherever a [`TokenStore`] is
/// expected; every call is forwarded to the pointee.
#[cfg(not(target_arch = "wasm32"))]
impl<T: TokenStore + ?Sized> TokenStore for Arc<T> {
    fn load(&self) -> impl Future<Output = Option<Token>> + Send {
        // `&Arc<T>` deref-coerces to `&T`, so this dispatches to the inner impl.
        T::load(self)
    }

    fn save(&self, token: &Token) -> impl Future<Output = ()> + Send {
        T::save(self, token)
    }
}
/// Lets a store shared through `Arc` be used wherever a [`TokenStore`] is
/// expected (`wasm32` build); every call is forwarded to the pointee.
#[cfg(target_arch = "wasm32")]
impl<T: TokenStore + ?Sized> TokenStore for Arc<T> {
    fn load(&self) -> impl Future<Output = Option<Token>> {
        // `&Arc<T>` deref-coerces to `&T`, so this dispatches to the inner impl.
        T::load(self)
    }

    fn save(&self, token: &Token) -> impl Future<Output = ()> {
        T::save(self, token)
    }
}
/// A [`TokenStore`] that keeps nothing.
///
/// `save` discards its argument and `load` always reports a miss.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoStore;

impl TokenStore for NoStore {
    async fn save(&self, _: &Token) {}

    async fn load(&self) -> Option<Token> {
        None
    }
}
pub struct InMemoryTokenStore {
state: Mutex<Option<Zeroizing<String>>>,
}
impl InMemoryTokenStore {
pub fn new() -> Self {
Self {
state: Mutex::new(None),
}
}
}
impl Default for InMemoryTokenStore {
fn default() -> Self {
Self::new()
}
}
impl TokenStore for InMemoryTokenStore {
    /// Returns the most recently saved token.
    ///
    /// Yields `None` if nothing has been saved yet or, after logging a
    /// warning, if the stored JSON cannot be deserialised.
    async fn load(&self) -> Option<Token> {
        let guard = self.state.lock().await;
        let json = guard.as_ref()?;
        match serde_json::from_str(json) {
            Ok(token) => Some(token),
            Err(_) => {
                // The serde error is intentionally not logged: its message can
                // quote parts of the input, which here is secret token material.
                tracing::warn!("InMemoryTokenStore: failed to deserialise stored token");
                None
            }
        }
    }

    /// Serialises `token` to JSON and replaces any previously stored value.
    ///
    /// A serialisation failure is logged and leaves the stored value unchanged.
    async fn save(&self, token: &Token) {
        // Serialise before taking the lock so the mutex is held only for the swap.
        // Wrapping in `Zeroizing` *before* the `.await` below means the plaintext
        // is still wiped if this future is dropped while waiting for the lock.
        let Ok(json) = serde_json::to_string(token).map(Zeroizing::new) else {
            tracing::warn!("InMemoryTokenStore: failed to serialise token");
            return;
        };
        // The previous value (if any) is dropped here and thus zeroised.
        *self.state.lock().await = Some(json);
    }
}
/// A [`TokenStore`] backed by a pair of caller-supplied async closures.
///
/// The closures only ever deal in JSON strings: `load` should yield the
/// previously saved string (or `None`), and `save` receives the JSON
/// serialisation of each token to persist. Conversion to and from [`Token`]
/// is done by this type's `TokenStore` impl.
pub struct TokenStoreFn<L, S> {
    // Produces the stored JSON, if any.
    load: L,
    // Persists a freshly serialised token.
    save: S,
}

impl<L, S> TokenStoreFn<L, S> {
    /// Bundles the `load` and `save` closures into a store.
    ///
    /// The string handed to `save` is a plain `String`; if it must be wiped
    /// after use, the closure has to do that itself.
    pub fn new(load: L, save: S) -> Self {
        TokenStoreFn { load, save }
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl<L, LF, S, SF> TokenStore for TokenStoreFn<L, S>
where
    L: Fn() -> LF + Send + Sync,
    LF: Future<Output = Option<String>> + Send,
    S: Fn(String) -> SF + Send + Sync,
    SF: Future<Output = ()> + Send,
{
    /// Calls the `load` closure and deserialises its JSON into a [`Token`].
    ///
    /// Returns `None` if the closure yields `None`, or, after logging a
    /// warning, if the JSON does not deserialise (treated as a cache miss).
    async fn load(&self) -> Option<Token> {
        // Zeroised on drop so the plaintext JSON does not linger after parsing.
        let json = Zeroizing::new((self.load)().await?);
        match serde_json::from_str(&json) {
            Ok(token) => Some(token),
            Err(_) => {
                // The serde error is not logged because its message can quote
                // fragments of the input, which may be secret token material.
                tracing::warn!("TokenStoreFn: failed to deserialise stored token");
                None
            }
        }
    }

    /// Serialises `token` to JSON and passes the string to the `save` closure.
    ///
    /// On serialisation failure a warning is logged and the closure is not called.
    async fn save(&self, token: &Token) {
        let Ok(json) = serde_json::to_string(token) else {
            tracing::warn!("TokenStoreFn: failed to serialise token");
            return;
        };
        // Ownership of the plaintext moves into the closure.
        (self.save)(json).await;
    }
}

#[cfg(target_arch = "wasm32")]
impl<L, LF, S, SF> TokenStore for TokenStoreFn<L, S>
where
    L: Fn() -> LF,
    LF: Future<Output = Option<String>>,
    S: Fn(String) -> SF,
    SF: Future<Output = ()>,
{
    /// Calls the `load` closure and deserialises its JSON into a [`Token`].
    ///
    /// Returns `None` if the closure yields `None`, or, after logging a
    /// warning, if the JSON does not deserialise (treated as a cache miss).
    async fn load(&self) -> Option<Token> {
        // Zeroised on drop so the plaintext JSON does not linger after parsing.
        let json = Zeroizing::new((self.load)().await?);
        match serde_json::from_str(&json) {
            Ok(token) => Some(token),
            Err(_) => {
                // The serde error is not logged because its message can quote
                // fragments of the input, which may be secret token material.
                tracing::warn!("TokenStoreFn: failed to deserialise stored token");
                None
            }
        }
    }

    /// Serialises `token` to JSON and passes the string to the `save` closure.
    ///
    /// On serialisation failure a warning is logged and the closure is not called.
    async fn save(&self, token: &Token) {
        let Ok(json) = serde_json::to_string(token) else {
            tracing::warn!("TokenStoreFn: failed to serialise token");
            return;
        };
        // Ownership of the plaintext moves into the closure.
        (self.save)(json).await;
    }
}
#[cfg(test)]
// `unwrap` is only used to build JSON fixtures from a known-good token; a
// failure there is a broken test, so panicking is the desired outcome.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::SecretToken;

    use super::*;

    /// Builds a minimal bearer token whose only varying field is `expires_at`.
    fn dummy_token(expires_at: u64) -> Token {
        Token {
            access_token: SecretToken::new("dummy-access".to_string()),
            refresh_token: None,
            token_type: "Bearer".to_string(),
            expires_at,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    #[tokio::test]
    async fn no_store_never_returns_a_token() {
        let store = NoStore;
        store.save(&dummy_token(4_000_000_000)).await;
        assert!(
            store.load().await.is_none(),
            "NoStore should discard saved tokens"
        );
    }

    #[tokio::test]
    async fn in_memory_load_returns_none_when_empty() {
        let store = InMemoryTokenStore::new();
        assert!(
            store.load().await.is_none(),
            "freshly constructed store should hold no token"
        );
    }

    #[tokio::test]
    async fn in_memory_round_trip_preserves_expires_at() {
        let store = InMemoryTokenStore::new();
        store.save(&dummy_token(4_000_000_000)).await;
        let loaded = store
            .load()
            .await
            .expect("load should return the saved token");
        assert_eq!(
            loaded.expires_at(),
            4_000_000_000,
            "round-trip should preserve expires_at"
        );
        assert_eq!(
            loaded.token_type(),
            "Bearer",
            "round-trip should preserve token_type"
        );
    }

    #[tokio::test]
    async fn in_memory_save_overwrites_previous() {
        let store = InMemoryTokenStore::new();
        store.save(&dummy_token(1_000_000_000)).await;
        store.save(&dummy_token(2_000_000_000)).await;
        let loaded = store.load().await.expect("store should hold a token");
        assert_eq!(
            loaded.expires_at(),
            2_000_000_000,
            "second save should replace the first"
        );
    }

    #[tokio::test]
    async fn arc_store_forwards_to_inner_store() {
        let inner = Arc::new(InMemoryTokenStore::new());
        let shared = Arc::clone(&inner);
        // Method resolution picks the `Arc<T>` impl here...
        shared.save(&dummy_token(3_000_000_000)).await;
        // ...and the inner store's own impl here.
        let loaded = (*inner)
            .load()
            .await
            .expect("a save through the Arc should be visible to the inner store");
        assert_eq!(
            loaded.expires_at(),
            3_000_000_000,
            "Arc forwarding should reach the same underlying state"
        );
    }

    #[tokio::test]
    async fn callback_store_invokes_load_closure_each_call() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        let store = TokenStoreFn::new(
            move || {
                let calls = Arc::clone(&calls_clone);
                async move {
                    let n = calls.fetch_add(1, Ordering::SeqCst);
                    if n == 0 {
                        None
                    } else {
                        Some(serde_json::to_string(&dummy_token(4_000_000_000)).unwrap())
                    }
                }
            },
            |_json: String| async move {},
        );
        assert!(
            store.load().await.is_none(),
            "first load returns None because the closure does"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "first call should have invoked the load closure exactly once"
        );
        let loaded = store
            .load()
            .await
            .expect("second load should yield a token");
        assert_eq!(
            loaded.expires_at(),
            4_000_000_000,
            "deserialised token should preserve the JSON payload's expires_at"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "second call should have invoked the load closure a second time"
        );
    }

    #[tokio::test]
    async fn callback_store_forwards_serialised_token_to_save_closure() {
        let captured = Arc::new(Mutex::new(None::<String>));
        let captured_clone = Arc::clone(&captured);
        let store = TokenStoreFn::new(
            || async { None },
            move |json: String| {
                let captured = Arc::clone(&captured_clone);
                async move {
                    *captured.lock().await = Some(json);
                }
            },
        );
        store.save(&dummy_token(4_000_000_000)).await;
        let json = captured
            .lock()
            .await
            .clone()
            .expect("save closure should have captured the JSON");
        assert!(
            json.contains("\"expires_at\":4000000000"),
            "captured JSON should encode expires_at; got: {json}"
        );
        assert!(
            json.contains("\"token_type\":\"Bearer\""),
            "captured JSON should encode token_type; got: {json}"
        );
    }

    #[tokio::test]
    async fn callback_store_round_trips_through_closures() {
        let slot = Arc::new(Mutex::new(None::<String>));
        let load_slot = Arc::clone(&slot);
        let save_slot = Arc::clone(&slot);
        let store = TokenStoreFn::new(
            move || {
                let slot = Arc::clone(&load_slot);
                async move { slot.lock().await.clone() }
            },
            move |json: String| {
                let slot = Arc::clone(&save_slot);
                async move {
                    *slot.lock().await = Some(json);
                }
            },
        );
        store.save(&dummy_token(3_500_000_000)).await;
        let loaded = store
            .load()
            .await
            .expect("load should return what save handed to the closure");
        assert_eq!(
            loaded.expires_at(),
            3_500_000_000,
            "closure-backed round-trip should preserve expires_at"
        );
    }

    #[tokio::test]
    async fn callback_store_ignores_invalid_json_on_load() {
        let store = TokenStoreFn::new(
            || async { Some("not valid json".to_string()) },
            |_json: String| async move {},
        );
        assert!(
            store.load().await.is_none(),
            "invalid JSON from the load closure should be treated as cache miss"
        );
    }
}