use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use tracing::debug;
use crate::error::{Result, SchemaRegError};
/// Bookkeeping for a fetch that is currently running for one key.
struct InFlightEntry<V> {
    /// Identifies the leader that registered this entry; leaders only publish
    /// results into the cache or remove this entry while the token still matches.
    token: u64,
    /// Callers that arrived while the fetch was running; each receives the
    /// leader's result, or a cancellation error, through its channel.
    waiters: Vec<oneshot::Sender<Result<Arc<V>>>>,
}
/// Concurrent in-memory cache with an optional FIFO entry limit and
/// single-flight fetching: concurrent misses for the same key share one fetch.
pub(crate) struct InMemoryCache<K, V> {
    /// Cached values, handed out as shared `Arc`s.
    entries: RwLock<HashMap<K, Arc<V>>>,
    /// Keys in insertion order; only maintained when `max_entries` is set, and
    /// used to evict the oldest key when the cache is full.
    insertion_order: RwLock<VecDeque<K>>,
    /// Optional entry limit enforced by evicting the oldest key on insert;
    /// `None` means unbounded.
    max_entries: Option<usize>,
    /// Counter used to mint a unique token for each fetch leader.
    in_flight_token: AtomicU64,
    /// Bumped by every `invalidate`/`clear`; a fetch that observes a change
    /// skips inserting its (possibly stale) result.
    invalidation_generation: AtomicU64,
    /// Fetches currently running, keyed by cache key.
    in_flight: Mutex<HashMap<K, InFlightEntry<V>>>,
    /// Builds the error delivered to waiters whose fetch was cancelled,
    /// invalidated, or cleared.
    make_cancelled_error: fn(K) -> SchemaRegError,
}
impl<K, V> InMemoryCache<K, V>
where
    K: Hash + Eq + Copy + fmt::Debug + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Creates an empty cache.
    ///
    /// * `max_entries` — optional entry limit, enforced by evicting the oldest
    ///   inserted key. `None` means unbounded; `Some(0)` disables storage
    ///   (fetches still de-duplicate concurrent callers but nothing is kept).
    /// * `make_cancelled_error` — builds the error handed to callers waiting on
    ///   a fetch that was invalidated, cleared, or abandoned by its leader.
    pub(crate) fn new(
        max_entries: Option<usize>,
        make_cancelled_error: fn(K) -> SchemaRegError,
    ) -> Self {
        // Pre-size for small limits only. Reserving the full limit up front
        // wastes memory for large bounds and panics ("capacity overflow") for
        // values such as `usize::MAX`; both collections grow on demand anyway.
        const MAX_PREALLOCATED_ENTRIES: usize = 1024;
        let capacity = max_entries.map_or(0, |max| max.min(MAX_PREALLOCATED_ENTRIES));
        Self {
            entries: RwLock::new(HashMap::with_capacity(capacity)),
            insertion_order: RwLock::new(VecDeque::with_capacity(capacity)),
            max_entries,
            in_flight_token: AtomicU64::new(0),
            invalidation_generation: AtomicU64::new(0),
            in_flight: Mutex::new(HashMap::new()),
            make_cancelled_error,
        }
    }

    /// Number of cached entries.
    pub(crate) fn len(&self) -> usize {
        self.entries.read().len()
    }

    /// Whether the cache currently holds no entries.
    pub(crate) fn is_empty(&self) -> bool {
        self.entries.read().is_empty()
    }

    /// Current invalidation generation. Read this before an external fetch and
    /// pass it to [`Self::insert_if_current`] so stale results are dropped.
    pub(crate) fn generation(&self) -> u64 {
        self.invalidation_generation.load(Ordering::SeqCst)
    }

    /// Removes `key` from the cache and cancels any in-flight fetch for it.
    ///
    /// Waiters on that fetch receive the `make_cancelled_error` error; the
    /// leader keeps running but will not insert its result.
    pub(crate) fn invalidate(&self, key: K) {
        // Bump the generation first so a leader that has not yet taken the
        // `entries` write lock is guaranteed to see the change and skip insert.
        self.invalidation_generation.fetch_add(1, Ordering::SeqCst);
        let waiters = self
            .in_flight
            .lock()
            .remove(&key)
            .map(|e| e.waiters)
            .unwrap_or_default();
        self.entries.write().remove(&key);
        self.insertion_order.write().retain(|cached| *cached != key);
        self.cancel_waiters(key, waiters);
    }

    /// Removes every entry and cancels every in-flight fetch.
    pub(crate) fn clear(&self) {
        self.invalidation_generation.fetch_add(1, Ordering::SeqCst);
        let cancelled: Vec<(K, InFlightEntry<V>)> = self.in_flight.lock().drain().collect();
        self.entries.write().clear();
        self.insertion_order.write().clear();
        for (key, entry) in cancelled {
            self.cancel_waiters(key, entry.waiters);
        }
    }

    /// Returns the keys whose cached value satisfies `predicate`.
    ///
    /// The `entries` read lock is held while `predicate` runs, so the
    /// predicate must not call back into this cache.
    pub(crate) fn keys_matching<P>(&self, predicate: P) -> Vec<K>
    where
        P: Fn(&V) -> bool,
    {
        self.entries
            .read()
            .iter()
            .filter(|(_, v)| predicate(v.as_ref()))
            .map(|(k, _)| *k)
            .collect()
    }

    /// Stores `value` for `key` unless the cache was invalidated or cleared
    /// since `observed_generation` was read via [`Self::generation`].
    ///
    /// An existing entry is overwritten in place (its eviction position is
    /// kept); a new entry may evict the oldest one to respect `max_entries`.
    pub(crate) fn insert_if_current(&self, key: K, value: Arc<V>, observed_generation: u64) {
        let mut entries = self.entries.write();
        // Checked under the write lock: `invalidate`/`clear` bump the
        // generation before taking this lock, so no stale insert can slip in.
        if self.invalidation_generation.load(Ordering::SeqCst) != observed_generation {
            debug!(
                ?key,
                "fetch completed after invalidation; skipping cache insert"
            );
            return;
        }
        if let Some(existing) = entries.get_mut(&key) {
            *existing = value;
            return;
        }
        self.insert_absent_locked(&mut entries, key, value);
    }

    /// Returns the cached value for `key`, or runs `fetch` to obtain it.
    ///
    /// Concurrent callers for the same missing key share a single `fetch`: the
    /// first becomes the leader, the rest wait for the leader's result (success
    /// or error). If the leader is dropped before finishing, or the key is
    /// invalidated/cleared meanwhile, waiters get the `make_cancelled_error`
    /// error.
    ///
    /// # Errors
    /// Propagates the error returned by `fetch`, or the cancellation error
    /// described above for waiters.
    pub(crate) async fn get_or_fetch<F, Fut>(&self, key: K, fetch: F) -> Result<Arc<V>>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<Arc<V>>>,
    {
        /// What this caller does for a miss: run the fetch or wait for one.
        enum Role<R> {
            Leader(u64),
            Follower(oneshot::Receiver<R>),
        }

        /// Cancels waiters if the leader future is dropped (or panics) before
        /// it broadcasts a result.
        struct FetchGuard<'a, K, V>
        where
            K: Hash + Eq + Copy + fmt::Debug + Send + Sync + 'static,
            V: Send + Sync + 'static,
        {
            cache: &'a InMemoryCache<K, V>,
            key: K,
            token: u64,
            completed: bool,
        }
        impl<K, V> Drop for FetchGuard<'_, K, V>
        where
            K: Hash + Eq + Copy + fmt::Debug + Send + Sync + 'static,
            V: Send + Sync + 'static,
        {
            fn drop(&mut self) {
                if self.completed {
                    return;
                }
                let waiters = self.cache.take_waiters_if_leader(self.key, self.token);
                self.cache.cancel_waiters(self.key, waiters);
            }
        }

        // Fast path: a shared read lock is enough for a hit.
        if let Some(v) = self.entries.read().get(&key) {
            return Ok(Arc::clone(v));
        }

        let role: Role<Result<Arc<V>>> = {
            let mut in_flight = self.in_flight.lock();
            // Re-check under the in-flight lock: a leader may have published the
            // value between the fast-path read and acquiring this lock.
            if let Some(v) = self.entries.read().get(&key) {
                return Ok(Arc::clone(v));
            }
            if let Some(entry) = in_flight.get_mut(&key) {
                let (tx, rx) = oneshot::channel();
                entry.waiters.push(tx);
                Role::Follower(rx)
            } else {
                let token = self.in_flight_token.fetch_add(1, Ordering::SeqCst) + 1;
                in_flight.insert(
                    key,
                    InFlightEntry {
                        token,
                        waiters: Vec::new(),
                    },
                );
                Role::Leader(token)
            }
        };

        let leader_token = match role {
            Role::Follower(rx) => {
                // A closed channel means the sender was dropped without a
                // result; report it as a cancellation.
                return rx
                    .await
                    .unwrap_or_else(|_| Err((self.make_cancelled_error)(key)));
            }
            Role::Leader(token) => token,
        };

        let mut guard = FetchGuard {
            cache: self,
            key,
            token: leader_token,
            completed: false,
        };
        let gen_before = self.invalidation_generation.load(Ordering::SeqCst);
        let arc_result: Result<Arc<V>> = fetch()
            .await
            .map(|value| self.store_fetched(key, value, leader_token, gen_before));

        let waiters = self.take_waiters_if_leader(key, leader_token);
        for waiter in waiters {
            let _ = waiter.send(arc_result.as_ref().map(Arc::clone).map_err(Clone::clone));
        }
        guard.completed = true;
        arc_result
    }

    /// Publishes a freshly fetched value unless this leader was superseded or
    /// an invalidation happened during the fetch. Returns the value callers
    /// should see: an entry that raced in first wins, so readers share one `Arc`.
    fn store_fetched(&self, key: K, value: Arc<V>, leader_token: u64, gen_before: u64) -> Arc<V> {
        if !self.is_current_leader(key, leader_token) {
            debug!(
                ?key,
                "fetch completed after invalidation; skipping cache insert"
            );
            return value;
        }
        debug!(?key, "cache miss — fetched from backend");
        let mut entries = self.entries.write();
        // Generation is re-checked under the write lock; see `insert_if_current`.
        if self.invalidation_generation.load(Ordering::SeqCst) != gen_before {
            debug!(
                ?key,
                "fetch completed after invalidation; skipping cache insert"
            );
            return value;
        }
        if let Some(existing) = entries.get(&key) {
            return Arc::clone(existing);
        }
        self.insert_absent_locked(&mut entries, key, Arc::clone(&value));
        value
    }

    /// Inserts a key known to be absent, evicting the oldest keys first so the
    /// `max_entries` limit holds. The caller passes in the `entries` map from
    /// a held write guard. With `Some(0)` nothing is stored.
    fn insert_absent_locked(&self, entries: &mut HashMap<K, Arc<V>>, key: K, value: Arc<V>) {
        match self.max_entries {
            None => {
                entries.insert(key, value);
            }
            // A zero limit means "cache nothing"; evicting cannot make room.
            Some(0) => {}
            Some(max) => {
                let mut order = self.insertion_order.write();
                while entries.len() >= max {
                    let Some(oldest) = order.pop_front() else {
                        break;
                    };
                    entries.remove(&oldest);
                }
                order.push_back(key);
                entries.insert(key, value);
            }
        }
    }

    /// Whether the in-flight entry for `key` still belongs to leader `token`.
    fn is_current_leader(&self, key: K, token: u64) -> bool {
        matches!(self.in_flight.lock().get(&key), Some(e) if e.token == token)
    }

    /// Removes and returns the waiters for `key` if leader `token` still owns
    /// the in-flight entry; otherwise (superseded or cancelled) returns none.
    fn take_waiters_if_leader(&self, key: K, token: u64) -> Vec<oneshot::Sender<Result<Arc<V>>>> {
        let mut in_flight = self.in_flight.lock();
        if matches!(in_flight.get(&key), Some(e) if e.token == token) {
            in_flight
                .remove(&key)
                .map(|e| e.waiters)
                .unwrap_or_default()
        } else {
            Vec::new()
        }
    }

    /// Sends the cancellation error for `key` to every waiter.
    fn cancel_waiters(&self, key: K, waiters: Vec<oneshot::Sender<Result<Arc<V>>>>) {
        if waiters.is_empty() {
            return;
        }
        let err = (self.make_cancelled_error)(key);
        for waiter in waiters {
            // A waiter whose receiver was dropped no longer cares; ignore it.
            let _ = waiter.send(Err(err.clone()));
        }
    }
}
impl<K, V> fmt::Debug for InMemoryCache<K, V>
where
    K: fmt::Debug + Hash + Eq + 'static,
{
    // Prints a summary (current size and configured limit) rather than the
    // cached values, so `V` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current_len = self.entries.read().len();
        let mut summary = f.debug_struct("InMemoryCache");
        summary.field("len", &current_len);
        summary.field("max_entries", &self.max_entries);
        summary.finish()
    }
}