use std::sync::Arc;
use async_trait::async_trait;
use ppoppo_token::access_token::{EpochRevocation, EpochRevocationError};
use ppoppo_token::{SV_CACHE_TTL, sv_cache_key};
use super::{Cache, Fetcher};
/// [`EpochRevocation`] implementation that puts a [`Cache`] in front of a [`Fetcher`].
///
/// A lookup consults the cache first. Only on a miss is the fetcher queried,
/// and a successfully fetched value is then written back to the cache.
pub struct CompositeEpochRevocation {
    /// Checked before the fetcher; populated after a successful fetch.
    cache: Arc<dyn Cache>,
    /// Queried only when the cache has no entry for the subject.
    fetcher: Arc<dyn Fetcher>,
}

impl CompositeEpochRevocation {
    /// Builds a composite from a shared cache and a shared fetcher.
    ///
    /// Both collaborators are taken as `Arc` trait objects, so the same
    /// instances can also be handed to other components.
    #[must_use]
    pub fn new(cache: Arc<dyn Cache>, fetcher: Arc<dyn Fetcher>) -> Self {
        Self { fetcher, cache }
    }
}
// Opaque `Debug`: the cache and fetcher trait objects are deliberately left
// out, so the rendering is just `CompositeEpochRevocation { .. }`.
impl std::fmt::Debug for CompositeEpochRevocation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CompositeEpochRevocation");
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl EpochRevocation for CompositeEpochRevocation {
    /// Returns the current epoch value for `sub`.
    ///
    /// The cache is consulted under `sv_cache_key(sub)` first. On a hit the
    /// cached value is returned and the fetcher is never called. On a miss the
    /// fetcher is queried. Its value is stored back under the same key with
    /// `SV_CACHE_TTL` and then returned. `Cache::set` returns nothing, so any
    /// write-back failure is the cache implementation's concern.
    ///
    /// # Errors
    ///
    /// A fetcher failure is reported as [`EpochRevocationError::Transient`]
    /// carrying the fetch error's `Display` text. Nothing is written to the
    /// cache in that case.
    async fn current(&self, sub: &str) -> Result<i64, EpochRevocationError> {
        let cache_key = sv_cache_key(sub);

        let cached = self.cache.get(&cache_key).await;
        match cached {
            // Fast path: a cached epoch short-circuits the fetcher.
            Some(epoch) => Ok(epoch),
            None => {
                let epoch = match self.fetcher.fetch(sub).await {
                    Ok(epoch) => epoch,
                    Err(err) => return Err(EpochRevocationError::Transient(err.to_string())),
                };
                self.cache.set(&cache_key, epoch, SV_CACHE_TTL).await;
                Ok(epoch)
            }
        }
    }
}
// Test-only: panicking on a poisoned mutex or an unexpected `Err` is exactly
// the failure signal these tests want, so `unwrap` is permitted here.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::time::Duration;

    use super::super::FetchError;
    use super::*;

    /// In-memory `Cache` that also records how often it is written and with
    /// which TTL, so tests can pin the write-back behavior.
    #[derive(Default)]
    struct TestCache {
        store: Mutex<HashMap<String, i64>>,
        set_count: Mutex<u32>,
        last_ttl: Mutex<Option<Duration>>,
    }

    #[async_trait]
    impl Cache for TestCache {
        async fn get(&self, key: &str) -> Option<i64> {
            self.store.lock().unwrap().get(key).copied()
        }

        async fn set(&self, key: &str, sv: i64, ttl: Duration) {
            *self.set_count.lock().unwrap() += 1;
            *self.last_ttl.lock().unwrap() = Some(ttl);
            self.store.lock().unwrap().insert(key.to_string(), sv);
        }
    }

    /// `Fetcher` returning a fixed result and counting how often it is called.
    struct TestFetcher {
        value: Result<i64, &'static str>,
        fetch_count: Mutex<u32>,
    }

    impl TestFetcher {
        fn ok(v: i64) -> Self {
            Self {
                value: Ok(v),
                fetch_count: Mutex::new(0),
            }
        }

        fn failing() -> Self {
            Self {
                value: Err("substrate down"),
                fetch_count: Mutex::new(0),
            }
        }
    }

    #[async_trait]
    impl Fetcher for TestFetcher {
        async fn fetch(&self, _sub: &str) -> Result<i64, FetchError> {
            *self.fetch_count.lock().unwrap() += 1;
            self.value.map_err(|e| FetchError::Other(e.to_string()))
        }
    }

    const SUB: &str = "01HSAB00000000000000000000";

    #[tokio::test]
    async fn cache_hit_short_circuits_fetcher() {
        let cache = Arc::new(TestCache::default());
        cache.store.lock().unwrap().insert(sv_cache_key(SUB), 7);
        let fetcher = Arc::new(TestFetcher::ok(99));
        let composer = CompositeEpochRevocation::new(cache.clone(), fetcher.clone());
        assert_eq!(composer.current(SUB).await.unwrap(), 7);
        assert_eq!(*fetcher.fetch_count.lock().unwrap(), 0);
        // A hit must not rewrite the entry (which would extend its TTL).
        assert_eq!(*cache.set_count.lock().unwrap(), 0);
    }

    #[tokio::test]
    async fn cache_miss_fetches_then_writes_back() {
        let cache = Arc::new(TestCache::default());
        let fetcher = Arc::new(TestFetcher::ok(42));
        let composer = CompositeEpochRevocation::new(cache.clone(), fetcher.clone());
        assert_eq!(composer.current(SUB).await.unwrap(), 42);
        assert_eq!(*fetcher.fetch_count.lock().unwrap(), 1);
        assert_eq!(*cache.set_count.lock().unwrap(), 1);
        assert_eq!(
            cache.store.lock().unwrap().get(&sv_cache_key(SUB)).copied(),
            Some(42),
        );
        assert_eq!(*cache.last_ttl.lock().unwrap(), Some(SV_CACHE_TTL));
    }

    #[tokio::test]
    async fn second_call_after_miss_hits_cache() {
        let cache = Arc::new(TestCache::default());
        let fetcher = Arc::new(TestFetcher::ok(11));
        let composer = CompositeEpochRevocation::new(cache, fetcher.clone());
        let _ = composer.current(SUB).await.unwrap();
        let _ = composer.current(SUB).await.unwrap();
        assert_eq!(*fetcher.fetch_count.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn fetcher_transient_maps_to_epoch_revocation_transient() {
        let cache = Arc::new(TestCache::default());
        let fetcher = Arc::new(TestFetcher::failing());
        let composer = CompositeEpochRevocation::new(cache, fetcher);
        let err = composer.current(SUB).await.unwrap_err();
        match err {
            EpochRevocationError::Transient(detail) => {
                assert!(detail.contains("substrate down"), "{detail}");
            }
        }
    }

    #[tokio::test]
    async fn fetcher_failure_is_not_cached() {
        let cache = Arc::new(TestCache::default());
        let fetcher = Arc::new(TestFetcher::failing());
        let composer = CompositeEpochRevocation::new(cache.clone(), fetcher.clone());
        assert!(composer.current(SUB).await.is_err());
        assert!(composer.current(SUB).await.is_err());
        // Each call must retry the fetcher; a failure never populates the cache.
        assert_eq!(*fetcher.fetch_count.lock().unwrap(), 2);
        assert_eq!(*cache.set_count.lock().unwrap(), 0);
        assert!(cache.store.lock().unwrap().is_empty());
    }
}