use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use super::ClockFn;
use super::client::ScittClient;
use super::error::ScittError;
use super::root_keys::ScittKeyStore;
use super::system_clock;
/// Default period of the background refresh loop (`start_background_refresh_default`): one day.
const DEFAULT_KEY_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
/// Default minimum spacing between on-demand refreshes: five minutes.
const DEFAULT_ON_DEMAND_COOLDOWN: Duration = Duration::from_secs(60 * 5);
/// Mutable state kept behind `Inner::state`'s `RwLock`.
struct KeyStoreState {
    /// Current key set. A refresh replaces this `Arc` with a new one rather than mutating
    /// the store in place, so readers already holding an older snapshot are unaffected.
    snapshot: Arc<ScittKeyStore>,
    /// Clock reading (seconds, from `Inner::clock`) of the last successful refresh;
    /// `None` until a refresh has succeeded.
    last_refreshed: Option<i64>,
}
/// State shared by every clone of a `RefreshableKeyStore`.
struct Inner {
    /// Source of fresh root keys; `None` for stores built with `from_static`, whose
    /// refresh methods are no-ops.
    client: Option<Arc<dyn ScittClient>>,
    /// Minimum number of clock seconds between refreshes triggered through
    /// `refresh_if_cooldown_elapsed`.
    on_demand_cooldown_secs: i64,
    /// Clock used for cooldown and age computations; its readings are compared against
    /// `on_demand_cooldown_secs`, so it must tick in seconds.
    clock: ClockFn,
    /// Current snapshot plus the time of the last successful refresh.
    state: RwLock<KeyStoreState>,
    /// Held across refreshes so concurrent on-demand callers don't all fetch at once.
    refresh_gate: Mutex<()>,
}
impl std::fmt::Debug for Inner {
    /// Prints whether a client is configured and the cooldown; the key material, clock and
    /// locks are deliberately omitted (hence `finish_non_exhaustive`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RefreshableKeyStoreInner");
        builder.field("has_client", &self.client.is_some());
        builder.field("cooldown_secs", &self.on_demand_cooldown_secs);
        builder.finish_non_exhaustive()
    }
}
/// SCITT root key store whose key set can be refreshed from a [`ScittClient`] while
/// readers keep using the current snapshot.
///
/// Cloning is cheap (an `Arc` bump) and all clones share the same underlying state.
#[derive(Clone, Debug)]
pub struct RefreshableKeyStore {
    inner: Arc<Inner>,
}
impl RefreshableKeyStore {
pub fn new(initial: ScittKeyStore, client: Arc<dyn ScittClient>) -> Self {
Self {
inner: Arc::new(Inner {
client: Some(client),
on_demand_cooldown_secs: DEFAULT_ON_DEMAND_COOLDOWN
.as_secs()
.try_into()
.unwrap_or(i64::MAX),
clock: system_clock(),
state: RwLock::new(KeyStoreState {
snapshot: Arc::new(initial),
last_refreshed: None,
}),
refresh_gate: Mutex::new(()),
}),
}
}
pub fn with_cooldown(
initial: ScittKeyStore,
client: Arc<dyn ScittClient>,
cooldown: Duration,
) -> Self {
Self {
inner: Arc::new(Inner {
client: Some(client),
on_demand_cooldown_secs: cooldown.as_secs().try_into().unwrap_or(i64::MAX),
clock: system_clock(),
state: RwLock::new(KeyStoreState {
snapshot: Arc::new(initial),
last_refreshed: None,
}),
refresh_gate: Mutex::new(()),
}),
}
}
pub fn from_static(initial: ScittKeyStore) -> Self {
Self {
inner: Arc::new(Inner {
client: None,
on_demand_cooldown_secs: DEFAULT_ON_DEMAND_COOLDOWN
.as_secs()
.try_into()
.unwrap_or(i64::MAX),
clock: system_clock(),
state: RwLock::new(KeyStoreState {
snapshot: Arc::new(initial),
last_refreshed: None,
}),
refresh_gate: Mutex::new(()),
}),
}
}
#[allow(clippy::expect_used)] pub fn with_clock(mut self, clock: ClockFn) -> Self {
Arc::get_mut(&mut self.inner)
.expect("with_clock must be called before cloning")
.clock = clock;
self
}
pub async fn current_snapshot(&self) -> Arc<ScittKeyStore> {
self.inner.state.read().await.snapshot.clone()
}
pub async fn refresh_if_cooldown_elapsed(&self) -> Result<bool, ScittError> {
if self.inner.client.is_none() {
return Ok(false);
}
let _guard = self.inner.refresh_gate.lock().await;
let should_refresh = {
let state = self.inner.state.read().await;
match state.last_refreshed {
None => true,
Some(ts) => {
let now = (self.inner.clock)();
(now - ts) >= self.inner.on_demand_cooldown_secs
}
}
};
if !should_refresh {
return Ok(false);
}
self.do_refresh().await?;
Ok(true)
}
pub async fn do_refresh(&self) -> Result<(), ScittError> {
let Some(client) = &self.inner.client else {
tracing::debug!("Static RefreshableKeyStore — refresh is a no-op");
return Ok(());
};
let key_strings = client.fetch_root_keys().await?;
let now = (self.inner.clock)();
let mut state = self.inner.state.write().await;
let merged = state.snapshot.merge_from(&key_strings);
state.snapshot = Arc::new(merged);
state.last_refreshed = Some(now);
Ok(())
}
pub fn start_background_refresh(&self, interval: Duration) -> KeyRefreshHandle {
let cancel = CancellationToken::new();
let store = self.clone();
let cancel_clone = cancel.clone();
let task = tokio::spawn(async move {
tracing::info!(
interval_secs = interval.as_secs(),
"SCITT root key background refresh started"
);
let mut consecutive_failures: u32 = 0;
loop {
tokio::select! {
() = tokio::time::sleep(interval) => {
match store.do_refresh().await {
Ok(()) => {
consecutive_failures = 0;
let count = store.current_snapshot().await.len();
tracing::debug!(key_count = count, "SCITT root keys refreshed");
}
Err(e) => {
consecutive_failures = consecutive_failures.saturating_add(1);
tracing::warn!(
error = %e,
consecutive_failures,
"Background SCITT key refresh failed"
);
}
}
}
() = cancel_clone.cancelled() => {
tracing::debug!("SCITT key background refresh cancelled");
break;
}
}
}
});
KeyRefreshHandle { cancel, task }
}
pub fn start_background_refresh_default(&self) -> KeyRefreshHandle {
self.start_background_refresh(DEFAULT_KEY_REFRESH_INTERVAL)
}
pub async fn len(&self) -> usize {
self.inner.state.read().await.snapshot.len()
}
pub async fn is_empty(&self) -> bool {
self.inner.state.read().await.snapshot.is_empty()
}
pub async fn last_refreshed(&self) -> Option<i64> {
self.inner.state.read().await.last_refreshed
}
pub async fn last_refreshed_age_secs(&self) -> Option<i64> {
let last = self.inner.state.read().await.last_refreshed?;
let now = (self.inner.clock)();
Some(now - last)
}
}
/// Handle to the task spawned by `RefreshableKeyStore::start_background_refresh`.
///
/// Dropping the handle cancels and aborts the task, so keep it alive for as long as
/// periodic refresh is wanted.
#[must_use = "dropping a KeyRefreshHandle immediately cancels the background refresh task"]
pub struct KeyRefreshHandle {
    /// Token the refresh loop selects on; cancelled when the handle is dropped.
    cancel: CancellationToken,
    /// The spawned refresh loop; aborted when the handle is dropped.
    task: tokio::task::JoinHandle<()>,
}
impl Drop for KeyRefreshHandle {
    /// Signals the refresh loop to stop, then aborts the task so it also stops if it is in
    /// the middle of a refresh rather than waiting on the token.
    fn drop(&mut self) {
        let Self { cancel, task } = self;
        cancel.cancel();
        task.abort();
    }
}
impl std::fmt::Debug for KeyRefreshHandle {
    /// Reports whether cancellation was requested and whether the task has finished.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.cancel.is_cancelled();
        let task_finished = self.task.is_finished();
        let mut builder = f.debug_struct("KeyRefreshHandle");
        builder.field("cancelled", &cancelled);
        builder.field("task_finished", &task_finished);
        builder.finish()
    }
}
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use base64::Engine as _;
    use base64::prelude::BASE64_STANDARD;
    use p256::ecdsa::SigningKey;
    use p256::pkcs8::EncodePublicKey as _;
    use sha2::{Digest, Sha256};
    use super::*;
    use crate::scitt::client::MockScittClient;

    /// DER-encoded SPKI of the deterministic P-256 key whose scalar is `[seed; 32]`.
    fn spki_der_for_seed(seed: u8) -> Vec<u8> {
        let signing_key = SigningKey::from_slice(&[seed; 32]).unwrap();
        signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .as_bytes()
            .to_vec()
    }

    /// Key ID used by these tests: the first four bytes of SHA-256 over the SPKI DER.
    fn kid_for_spki(spki_der: &[u8]) -> [u8; 4] {
        let digest = Sha256::digest(spki_der);
        [digest[0], digest[1], digest[2], digest[3]]
    }

    /// Builds a `name+hexkid+base64spki` key string for the key derived from `seed`.
    fn make_c2sp_key_string(seed: u8, name: &str) -> String {
        let spki_der = spki_der_for_seed(seed);
        let key_hash_hex = hex::encode(kid_for_spki(&spki_der));
        let spki_b64 = BASE64_STANDARD.encode(&spki_der);
        format!("{name}+{key_hash_hex}+{spki_b64}")
    }

    /// Builds a one-key store for `seed` and returns it with that key's ID.
    fn make_store_with_key(seed: u8) -> (ScittKeyStore, [u8; 4]) {
        let key_string = make_c2sp_key_string(seed, "tl.example.com");
        let store = ScittKeyStore::from_c2sp_keys(&[key_string]).unwrap();
        (store, kid_for_spki(&spki_der_for_seed(seed)))
    }

    #[tokio::test]
    async fn new_contains_initial_keys() {
        let (store, kid) = make_store_with_key(1);
        let client: Arc<dyn ScittClient> = Arc::new(MockScittClient::new());
        let refreshable = RefreshableKeyStore::new(store, client);
        let snapshot = refreshable.current_snapshot().await;
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.get(kid).is_ok());
    }

    #[tokio::test]
    async fn from_static_refresh_is_noop() {
        let (store, kid) = make_store_with_key(1);
        let refreshable = RefreshableKeyStore::from_static(store);
        refreshable.do_refresh().await.unwrap();
        assert!(refreshable.last_refreshed().await.is_none());
        assert!(!refreshable.refresh_if_cooldown_elapsed().await.unwrap());
        let snapshot = refreshable.current_snapshot().await;
        assert!(snapshot.get(kid).is_ok());
    }

    #[tokio::test]
    async fn do_refresh_merges_new_keys_preserves_existing() {
        let (store, kid1) = make_store_with_key(1);
        let key2_string = make_c2sp_key_string(2, "tl2.example.com");
        let (_, kid2) = make_store_with_key(2);
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key2_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        refreshable.do_refresh().await.unwrap();
        let snapshot = refreshable.current_snapshot().await;
        assert_eq!(snapshot.len(), 2);
        assert!(snapshot.get(kid1).is_ok());
        assert!(snapshot.get(kid2).is_ok());
    }

    #[tokio::test]
    async fn do_refresh_does_not_overwrite_existing_kid() {
        let (store, kid) = make_store_with_key(1);
        let key1_string = make_c2sp_key_string(1, "tl.example.com");
        let original_name = store.get(kid).unwrap().name.clone();
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key1_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        refreshable.do_refresh().await.unwrap();
        let snapshot = refreshable.current_snapshot().await;
        assert_eq!(snapshot.len(), 1);
        assert_eq!(snapshot.get(kid).unwrap().name, original_name);
    }

    #[tokio::test]
    async fn do_refresh_fails_gracefully_preserves_snapshot() {
        let (store, kid) = make_store_with_key(1);
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_error("root_keys", || ScittError::NotACoseSign1));
        let refreshable = RefreshableKeyStore::new(store, client);
        let result = refreshable.do_refresh().await;
        assert!(result.is_err());
        let snapshot = refreshable.current_snapshot().await;
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.get(kid).is_ok());
        assert!(refreshable.last_refreshed().await.is_none());
    }

    #[tokio::test]
    async fn refresh_if_cooldown_elapsed_returns_true_when_never_refreshed() {
        let (store, _) = make_store_with_key(1);
        let key2_string = make_c2sp_key_string(2, "tl2.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key2_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        assert!(refreshable.last_refreshed().await.is_none());
        let result = refreshable.refresh_if_cooldown_elapsed().await.unwrap();
        assert!(result);
        assert!(refreshable.last_refreshed().await.is_some());
    }

    #[tokio::test]
    async fn refresh_if_cooldown_elapsed_returns_false_within_cooldown() {
        let key_string = make_c2sp_key_string(1, "tl.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key_string.clone()]));
        let refreshable = RefreshableKeyStore::with_cooldown(
            ScittKeyStore::from_c2sp_keys(&[key_string]).unwrap(),
            client,
            Duration::from_secs(3600),
        );
        let first = refreshable.refresh_if_cooldown_elapsed().await.unwrap();
        assert!(first);
        let second = refreshable.refresh_if_cooldown_elapsed().await.unwrap();
        assert!(!second);
    }

    #[tokio::test]
    async fn refresh_if_cooldown_elapsed_returns_true_after_cooldown() {
        let (store, _) = make_store_with_key(1);
        let key_string = make_c2sp_key_string(1, "tl.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key_string]));
        let refreshable = RefreshableKeyStore::with_cooldown(store, client, Duration::ZERO);
        let first = refreshable.refresh_if_cooldown_elapsed().await.unwrap();
        assert!(first);
        let second = refreshable.refresh_if_cooldown_elapsed().await.unwrap();
        assert!(second);
    }

    #[tokio::test]
    async fn custom_cooldown_is_respected() {
        let (store, _) = make_store_with_key(1);
        let key_string = make_c2sp_key_string(1, "tl.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key_string]));
        let refreshable =
            RefreshableKeyStore::with_cooldown(store, client, Duration::from_secs(9999));
        assert!(refreshable.refresh_if_cooldown_elapsed().await.unwrap());
        assert!(!refreshable.refresh_if_cooldown_elapsed().await.unwrap());
    }

    #[tokio::test]
    async fn last_refreshed_updates_on_success() {
        let (store, _) = make_store_with_key(1);
        let key_string = make_c2sp_key_string(1, "tl.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        assert!(refreshable.last_refreshed().await.is_none());
        refreshable.do_refresh().await.unwrap();
        let ts = refreshable.last_refreshed().await.unwrap();
        let now = chrono::Utc::now().timestamp();
        assert!((now - ts).abs() < 5);
    }

    #[tokio::test]
    async fn last_refreshed_unchanged_on_failure() {
        let (store, _) = make_store_with_key(1);
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_error("root_keys", || ScittError::NotACoseSign1));
        let refreshable = RefreshableKeyStore::new(store, client);
        let _ = refreshable.do_refresh().await;
        assert!(refreshable.last_refreshed().await.is_none());
    }

    #[tokio::test]
    async fn len_is_empty_and_age_track_state() {
        let (store, _) = make_store_with_key(1);
        let key2_string = make_c2sp_key_string(2, "tl2.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key2_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        assert_eq!(refreshable.len().await, 1);
        assert!(!refreshable.is_empty().await);
        assert!(refreshable.last_refreshed_age_secs().await.is_none());
        refreshable.do_refresh().await.unwrap();
        assert_eq!(refreshable.len().await, 2);
        let age = refreshable.last_refreshed_age_secs().await.unwrap();
        assert!((0..5).contains(&age));
    }

    #[tokio::test]
    async fn concurrent_refresh_paths_both_complete() {
        // Direct and on-demand refreshes racing each other must both finish (no deadlock)
        // and leave a recorded refresh behind.
        let (store, _) = make_store_with_key(1);
        let key2_string = make_c2sp_key_string(2, "tl2.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key2_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        let (direct, on_demand) = tokio::join!(
            refreshable.do_refresh(),
            refreshable.refresh_if_cooldown_elapsed()
        );
        direct.unwrap();
        on_demand.unwrap();
        assert!(refreshable.last_refreshed().await.is_some());
        assert_eq!(refreshable.len().await, 2);
    }

    #[tokio::test]
    async fn background_refresh_handle_cancels_on_drop() {
        let (store, _) = make_store_with_key(1);
        let client: Arc<dyn ScittClient> = Arc::new(MockScittClient::new());
        let refreshable = RefreshableKeyStore::new(store, client);
        let cancel = {
            let handle = refreshable.start_background_refresh(Duration::from_secs(3600));
            let cancel = handle.cancel.clone();
            assert!(!cancel.is_cancelled());
            drop(handle);
            cancel
        };
        assert!(cancel.is_cancelled());
    }

    #[tokio::test]
    async fn concurrent_readers_during_refresh() {
        let (store, kid) = make_store_with_key(1);
        let key2_string = make_c2sp_key_string(2, "tl2.example.com");
        let client: Arc<dyn ScittClient> =
            Arc::new(MockScittClient::new().with_root_keys(vec![key2_string]));
        let refreshable = RefreshableKeyStore::new(store, client);
        let mut handles = Vec::new();
        for _ in 0..10 {
            let store_clone = refreshable.clone();
            handles.push(tokio::spawn(async move {
                let snapshot = store_clone.current_snapshot().await;
                assert!(snapshot.get(kid).is_ok());
            }));
        }
        refreshable.do_refresh().await.unwrap();
        for handle in handles {
            handle.await.unwrap();
        }
        let snapshot = refreshable.current_snapshot().await;
        assert_eq!(snapshot.len(), 2);
    }

    const fn _assert_send_sync<T: Send + Sync>() {}
    const _: () = _assert_send_sync::<RefreshableKeyStore>();
    const _: () = _assert_send_sync::<KeyRefreshHandle>();
}