use crate::{FactKey, FactLoadResult};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
/// One stripe of the fact cache: finished load results plus the keys whose
/// load is still running, each with the waiters that joined it.
///
/// The stripe has no internal synchronization (every mutating method takes
/// `&mut self`); callers serialize access themselves, e.g. by wrapping it in a
/// `Mutex` as the loom tests do.
///
/// `W` is the waiter handle queued for each joiner and handed back by `finish`.
pub(super) struct FactStripeCore<K: FactKey, W> {
    /// Results recorded by `finish`, keyed by fact key, until `clear_cache`.
    cache: HashMap<K, FactLoadResult<K::Value>>,
    /// Keys with a load in progress, mapped to the waiters that joined it in
    /// join order. Presence of the key (even with an empty list) marks it in flight.
    in_flight: HashMap<K, Vec<W>>,
}
/// Outcome of `FactStripeCore::try_register`: the caller's role for one key.
///
/// Marked `#[must_use]` because ignoring a `Leading` result leaves the key
/// registered as in flight with nobody responsible for finishing it, so every
/// later joiner would be queued and never handed back.
#[derive(Debug)]
#[must_use = "ignoring a Leading registration leaves the key in flight with nobody to finish it"]
pub(super) enum Registration<V, R> {
    /// A result was already cached; no load is needed and nothing was queued.
    Cached(FactLoadResult<V>),
    /// No load was in flight: the caller now owns the load and must call
    /// `FactStripeCore::finish` for this key.
    Leading,
    /// A load was already in flight: the caller's waiter half was queued and
    /// this is the receiver half returned by the `make_pair` closure.
    Joined(R),
}
impl<K, W> FactStripeCore<K, W>
where
    K: FactKey,
{
    /// Creates a stripe with an empty cache and no loads in flight.
    pub(super) fn new() -> Self {
        Self {
            cache: HashMap::new(),
            in_flight: HashMap::new(),
        }
    }

    /// Returns a clone of the cached result for `key`, if `finish` stored one
    /// since the last `clear_cache`.
    ///
    /// Does not look at, or register anything in, the in-flight table.
    #[must_use]
    pub(super) fn peek_cache(&self, key: &K) -> Option<FactLoadResult<K::Value>> {
        self.cache.get(key).cloned()
    }

    /// Registers interest in `key` and reports the caller's role.
    ///
    /// - `Registration::Cached`: a result is cached; `make_pair` is not called.
    /// - `Registration::Leading`: nothing was in flight; `key` is now marked in
    ///   flight and the caller must eventually call [`finish`](Self::finish).
    ///   `make_pair` is not called.
    /// - `Registration::Joined`: a load is already in flight; `make_pair` is
    ///   called exactly once, its first half (the waiter) is queued for `finish`
    ///   to hand back, and its second half (the receiver) is returned.
    ///
    /// If `make_pair` panics, the stripe is left unchanged.
    pub(super) fn try_register<F, R>(&mut self, key: &K, make_pair: F) -> Registration<K::Value, R>
    where
        F: FnOnce() -> (W, R),
    {
        if let Some(cached) = self.cache.get(key) {
            return Registration::Cached(cached.clone());
        }
        // `entry` needs an owned key; the clone is only inserted on the Leading path.
        match self.in_flight.entry(key.clone()) {
            Entry::Occupied(mut existing) => {
                // Build the pair before touching the list so a panicking
                // `make_pair` leaves the waiter list as it was.
                let (waiter, receiver) = make_pair();
                existing.get_mut().push(waiter);
                Registration::Joined(receiver)
            }
            Entry::Vacant(slot) => {
                // An empty waiter list is what marks the key as in flight.
                slot.insert(Vec::new());
                Registration::Leading
            }
        }
    }

    /// Completes the load for `key`.
    ///
    /// Removes `key` from the in-flight table, caches `result`, and returns
    /// every waiter that joined, in join order. The caller must deliver the
    /// result to those waiters; this is the only place they leave the stripe.
    ///
    /// If `key` was not in flight (for example a second `finish` for the same
    /// key), the cache entry is still overwritten and an empty list is returned.
    #[must_use = "the returned waiters must be handed the result; dropping them abandons every joiner"]
    pub(super) fn finish(&mut self, key: K, result: FactLoadResult<K::Value>) -> Vec<W> {
        let waiters = self.in_flight.remove(&key).unwrap_or_default();
        self.cache.insert(key, result);
        waiters
    }

    /// Returns `true` when no key is in flight on this stripe.
    ///
    /// Cached entries do not count; only pending loads make a stripe busy.
    #[must_use]
    pub(super) fn is_idle(&self) -> bool {
        self.in_flight.is_empty()
    }

    /// Drops every cached result.
    ///
    /// In-flight loads and their waiters are untouched, so a load that finishes
    /// afterwards writes its result back into the cache.
    pub(super) fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
impl<K, W> Default for FactStripeCore<K, W>
where
K: FactKey,
{
fn default() -> Self {
Self::new()
}
}
/// Single-threaded unit tests for `FactStripeCore` bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FactLoadError;

    #[derive(Clone, Debug, Eq, Hash, PartialEq)]
    struct K(u32);

    impl FactKey for K {
        type Value = u32;
        const NAME: &'static str = "test";
    }

    /// Waiter ids: plain numbers, so tests can check which waiters `finish` returns.
    type Wid = u32;

    fn assert_found(result: &FactLoadResult<u32>, expected: u32) {
        match result {
            FactLoadResult::Found(value) => assert_eq!(*value, expected),
            other => panic!("expected Found({expected}), got {other:?}"),
        }
    }

    fn assert_cancelled(result: &FactLoadResult<u32>) {
        match result {
            FactLoadResult::Error(FactLoadError::LoaderCancelled { fact_name }) => {
                assert_eq!(*fact_name, K::NAME);
            }
            other => panic!("expected LoaderCancelled, got {other:?}"),
        }
    }

    fn expect_leading<R: std::fmt::Debug>(reg: Registration<u32, R>) {
        match reg {
            Registration::Leading => {}
            other => panic!("expected Leading, got {other:?}"),
        }
    }

    fn expect_joined<R>(reg: Registration<u32, R>) -> R {
        match reg {
            Registration::Joined(r) => r,
            Registration::Leading => panic!("expected Joined, got Leading"),
            Registration::Cached(_) => panic!("expected Joined, got Cached"),
        }
    }

    /// Registers `key` and asserts the caller became the leader; `make_pair`
    /// must not run on that path.
    fn lead<W>(core: &mut FactStripeCore<K, W>, key: K) {
        expect_leading(
            core.try_register::<_, ()>(&key, || panic!("make_pair should not run on Leading")),
        );
    }

    #[test]
    fn peek_cache_empty_is_none() {
        let core = FactStripeCore::<K, Wid>::new();
        assert!(core.peek_cache(&K(1)).is_none());
    }

    #[test]
    fn finish_caches_value_and_extracts_waiters_in_order() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(1));
        let r0 = expect_joined(core.try_register(&K(1), || (10u32, "r0")));
        let r1 = expect_joined(core.try_register(&K(1), || (11u32, "r1")));
        assert_eq!(r0, "r0");
        assert_eq!(r1, "r1");
        let waiters = core.finish(K(1), FactLoadResult::Found(42));
        assert_eq!(
            waiters,
            vec![10u32, 11u32],
            "waiters returned in join order"
        );
        assert_found(&core.peek_cache(&K(1)).expect("cached"), 42);
        assert!(core.is_idle(), "in-flight entry removed");
    }

    #[test]
    fn second_finish_for_same_key_returns_no_waiters() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(7));
        let _ = core.finish(K(7), FactLoadResult::Found(7));
        let waiters = core.finish(K(7), FactLoadResult::Found(8));
        assert!(waiters.is_empty(), "no waiters on a second finish");
        assert_found(&core.peek_cache(&K(7)).expect("cached"), 8);
    }

    #[test]
    fn try_register_returns_cached_when_present() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(3));
        let _ = core.finish(K(3), FactLoadResult::Found(30));
        match core.try_register::<_, ()>(&K(3), || panic!("closure must not fire on Cached")) {
            Registration::Cached(value) => assert_found(&value, 30),
            other => panic!("expected Cached, got {other:?}"),
        }
    }

    #[test]
    fn cancellation_caches_loader_cancelled_and_wakes_waiters() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(9));
        expect_joined(core.try_register(&K(9), || (1u32, ())));
        expect_joined(core.try_register(&K(9), || (2u32, ())));
        let waiters = core.finish(
            K(9),
            FactLoadResult::Error(FactLoadError::LoaderCancelled { fact_name: K::NAME }),
        );
        assert_eq!(waiters, vec![1u32, 2u32]);
        assert_cancelled(&core.peek_cache(&K(9)).expect("cached"));
        assert!(core.is_idle());
    }

    #[test]
    fn unrelated_keys_are_independent() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(1));
        lead(&mut core, K(2));
        assert!(!core.is_idle(), "two in-flight entries");
        let waiters_for_1 = core.finish(K(1), FactLoadResult::Found(100));
        assert!(waiters_for_1.is_empty());
        assert_found(&core.peek_cache(&K(1)).expect("cached"), 100);
        assert!(core.peek_cache(&K(2)).is_none(), "K(2) still loading");
        assert!(!core.is_idle(), "K(2) still in flight");
        let _ = core.finish(K(2), FactLoadResult::Found(200));
        assert!(core.is_idle());
    }

    #[test]
    fn is_idle_reflects_in_flight_state() {
        let mut core = FactStripeCore::<K, Wid>::new();
        assert!(core.is_idle());
        lead(&mut core, K(1));
        assert!(!core.is_idle());
        let _ = core.finish(K(1), FactLoadResult::Found(1));
        assert!(core.is_idle());
    }

    #[test]
    fn joined_but_unfinished_key_is_uncached_and_busy() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(5));
        expect_joined(core.try_register(&K(5), || (50u32, ())));
        assert!(
            core.peek_cache(&K(5)).is_none(),
            "nothing is cached until finish runs"
        );
        assert!(!core.is_idle(), "a joined key keeps the stripe busy");
    }

    #[test]
    fn clear_cache_drops_cached_entries() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(1));
        let _ = core.finish(K(1), FactLoadResult::Found(100));
        assert_found(&core.peek_cache(&K(1)).expect("cached"), 100);
        core.clear_cache();
        assert!(core.peek_cache(&K(1)).is_none());
        assert!(core.is_idle(), "clear_cache does not touch in-flight");
    }

    #[test]
    fn clear_cache_lets_next_registration_lead() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(4));
        let _ = core.finish(K(4), FactLoadResult::Found(40));
        core.clear_cache();
        // With the cached result gone and nothing in flight, the next caller
        // must start a fresh load instead of seeing the stale value.
        lead(&mut core, K(4));
        assert!(!core.is_idle(), "the new leader's load is in flight");
        let _ = core.finish(K(4), FactLoadResult::Found(41));
        assert_found(&core.peek_cache(&K(4)).expect("cached"), 41);
    }

    #[test]
    fn clear_cache_with_active_in_flight_keeps_in_flight() {
        let mut core = FactStripeCore::<K, Wid>::new();
        lead(&mut core, K(1));
        expect_joined(core.try_register(&K(1), || (5u32, ())));
        core.clear_cache();
        let waiters = core.finish(K(1), FactLoadResult::Found(11));
        assert_eq!(waiters, vec![5u32]);
        assert_found(&core.peek_cache(&K(1)).expect("cached"), 11);
    }
}
/// Loom model-checking tests. Each `loom::model` closure is re-run under the
/// thread interleavings loom explores, so every assertion must hold for all
/// of them. Only compiled when both the `test` and `loom` cfgs are set.
#[cfg(all(test, loom))]
mod loom_tests {
    use super::*;
    use crate::FactLoadError;
    use loom::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use loom::sync::{Arc, Mutex};
    use loom::thread;

    #[derive(Clone, Debug, Eq, Hash, PartialEq)]
    struct K(u32);

    impl FactKey for K {
        type Value = u32;
        const NAME: &'static str = "loom";
    }

    /// Waiter handle whose clones share one counter. The counter lets a test
    /// see how many times any clone of the same waiter was signaled.
    #[derive(Clone)]
    struct LoomWaiter {
        count: Arc<AtomicUsize>,
    }

    impl LoomWaiter {
        fn new() -> Self {
            Self {
                count: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn signal(&self) {
            self.count.fetch_add(1, Ordering::Release);
        }

        /// Number of `signal` calls observed so far. The tests read this only
        /// after joining the signaling thread.
        fn signal_count(&self) -> usize {
            self.count.load(Ordering::Acquire)
        }
    }

    fn signal_all(waiters: Vec<LoomWaiter>) {
        for w in waiters {
            w.signal();
        }
    }

    type Stripe = Arc<Mutex<FactStripeCore<K, LoomWaiter>>>;

    fn new_stripe() -> Stripe {
        Arc::new(Mutex::new(FactStripeCore::<K, LoomWaiter>::new()))
    }

    /// Two racing registrations for the same key: exactly one leads. Nothing is
    /// finished, so the other must join rather than see a cached result.
    #[test]
    fn loom_leader_election_is_unique() {
        loom::model(|| {
            let stripe = new_stripe();
            let s1 = stripe.clone();
            let t1 = thread::spawn(move || {
                let w = LoomWaiter::new();
                let stored = w.clone();
                let mut g = s1.lock().unwrap();
                g.try_register::<_, LoomWaiter>(&K(1), || (stored, w))
            });
            let s2 = stripe.clone();
            let t2 = thread::spawn(move || {
                let w = LoomWaiter::new();
                let stored = w.clone();
                let mut g = s2.lock().unwrap();
                g.try_register::<_, LoomWaiter>(&K(1), || (stored, w))
            });
            let r1 = t1.join().unwrap();
            let r2 = t2.join().unwrap();
            let leading = matches!(r1, Registration::Leading) as u8
                + matches!(r2, Registration::Leading) as u8;
            let joined = matches!(r1, Registration::Joined(_)) as u8
                + matches!(r2, Registration::Joined(_)) as u8;
            assert_eq!(
                leading, 1,
                "exactly one of the two registrations must become Leading"
            );
            assert_eq!(
                joined, 1,
                "nothing was finished, so the non-leader must join, not see Cached"
            );
        });
    }

    /// A joiner racing the leader's `finish` either joins (and is signaled
    /// exactly once) or arrives after `finish` (sees Cached, never signaled).
    #[test]
    fn loom_waiters_woken_exactly_once_on_finish() {
        loom::model(|| {
            let stripe = new_stripe();
            {
                let mut g = stripe.lock().unwrap();
                let outcome = g.try_register::<_, LoomWaiter>(&K(1), || {
                    unreachable!("first try_register must be Leading")
                });
                assert!(matches!(outcome, Registration::Leading));
            }
            let s1 = stripe.clone();
            let waiter = LoomWaiter::new();
            let waiter_for_thread = waiter.clone();
            let joined = Arc::new(AtomicBool::new(false));
            let joined_for_thread = joined.clone();
            let t1 = thread::spawn(move || {
                let stored = waiter_for_thread.clone();
                let receiver = waiter_for_thread.clone();
                let mut g = s1.lock().unwrap();
                let outcome = g.try_register::<_, LoomWaiter>(&K(1), || (stored, receiver));
                if matches!(outcome, Registration::Joined(_)) {
                    joined_for_thread.store(true, Ordering::Release);
                }
            });
            let s2 = stripe.clone();
            let t2 = thread::spawn(move || {
                // Extract waiters under the lock, signal after releasing it.
                let waiters = {
                    let mut g = s2.lock().unwrap();
                    g.finish(K(1), FactLoadResult::Found(7))
                };
                signal_all(waiters);
            });
            t1.join().unwrap();
            t2.join().unwrap();
            if joined.load(Ordering::Acquire) {
                assert_eq!(
                    waiter.signal_count(),
                    1,
                    "a waiter that registered before finish must be signaled exactly once \
                     (lost wake -> 0, double wake -> >1)"
                );
            } else {
                assert_eq!(
                    waiter.signal_count(),
                    0,
                    "a joiner that saw Cached must never receive a signal"
                );
                let cached = stripe
                    .lock()
                    .unwrap()
                    .peek_cache(&K(1))
                    .expect("finish writes cache");
                match cached {
                    FactLoadResult::Found(v) => assert_eq!(v, 7),
                    other => panic!("expected Found(7), got {other:?}"),
                }
            }
        });
    }

    /// Same race as above, but the leader finishes with `LoaderCancelled`:
    /// joiners are still woken exactly once and the cancellation is cached.
    #[test]
    fn loom_cancellation_is_fail_closed() {
        loom::model(|| {
            let stripe = new_stripe();
            {
                let mut g = stripe.lock().unwrap();
                let outcome = g.try_register::<_, LoomWaiter>(&K(1), || unreachable!());
                assert!(matches!(outcome, Registration::Leading));
            }
            let s1 = stripe.clone();
            let waiter = LoomWaiter::new();
            let waiter_for_thread = waiter.clone();
            let joined = Arc::new(AtomicBool::new(false));
            let joined_for_thread = joined.clone();
            let t1 = thread::spawn(move || {
                let stored = waiter_for_thread.clone();
                let receiver = waiter_for_thread.clone();
                let mut g = s1.lock().unwrap();
                let outcome = g.try_register::<_, LoomWaiter>(&K(1), || (stored, receiver));
                if matches!(outcome, Registration::Joined(_)) {
                    joined_for_thread.store(true, Ordering::Release);
                }
            });
            let s2 = stripe.clone();
            let t2 = thread::spawn(move || {
                let waiters = {
                    let mut g = s2.lock().unwrap();
                    g.finish(
                        K(1),
                        FactLoadResult::Error(FactLoadError::LoaderCancelled {
                            fact_name: K::NAME,
                        }),
                    )
                };
                signal_all(waiters);
            });
            t1.join().unwrap();
            t2.join().unwrap();
            if joined.load(Ordering::Acquire) {
                assert_eq!(
                    waiter.signal_count(),
                    1,
                    "a waiter that joined before cancellation must be signaled exactly once"
                );
            } else {
                assert_eq!(
                    waiter.signal_count(),
                    0,
                    "a joiner that saw Cached after cancellation must never receive a signal"
                );
            }
            let cached = stripe
                .lock()
                .unwrap()
                .peek_cache(&K(1))
                .expect("finish writes cache");
            match cached {
                FactLoadResult::Error(FactLoadError::LoaderCancelled { .. }) => {}
                other => panic!("expected LoaderCancelled, got {other:?}"),
            }
        });
    }

    /// A concurrent `peek_cache` sees either nothing or the fully written value.
    #[test]
    fn loom_cache_write_is_visible_to_concurrent_peek() {
        loom::model(|| {
            let stripe = new_stripe();
            let s1 = stripe.clone();
            let t1 = thread::spawn(move || {
                let waiters = {
                    let mut g = s1.lock().unwrap();
                    g.finish(K(1), FactLoadResult::Found(99))
                };
                signal_all(waiters);
            });
            let s2 = stripe.clone();
            let t2 = thread::spawn(move || -> Option<FactLoadResult<u32>> {
                let g = s2.lock().unwrap();
                g.peek_cache(&K(1))
            });
            t1.join().unwrap();
            let observed = t2.join().unwrap();
            match observed {
                None => {}
                Some(FactLoadResult::Found(v)) => assert_eq!(v, 99),
                other => panic!("unexpected cache observation: {other:?}"),
            }
        });
    }

    const LOOM_STRIPES: usize = 2;

    /// Minimal adapter modelling a striped cache backed by one replaceable
    /// source value. Lock order everywhere: `source` first, then stripes in
    /// index order. Methods that lock only a stripe never take `source`.
    struct LoomAdapter {
        source: Mutex<Option<u32>>,
        stripes: [Mutex<FactStripeCore<K, LoomWaiter>>; LOOM_STRIPES],
    }

    impl LoomAdapter {
        fn new(initial_source: Option<u32>) -> Self {
            Self {
                source: Mutex::new(initial_source),
                stripes: std::array::from_fn(|_| Mutex::new(FactStripeCore::new())),
            }
        }

        fn stripe_index(key: &K) -> usize {
            (key.0 as usize) % LOOM_STRIPES
        }

        /// Snapshots the source together with `key`'s cached entry, holding the
        /// source lock across both reads so the pair is consistent.
        fn plan(&self, key: &K) -> (Option<u32>, Option<FactLoadResult<u32>>) {
            let source_guard = self.source.lock().unwrap();
            let source = *source_guard;
            let stripe = self.stripes[Self::stripe_index(key)].lock().unwrap();
            let cached = stripe.peek_cache(key);
            (source, cached)
        }

        /// Swaps in `new_source` and clears every stripe's cache, all under the
        /// source lock plus every stripe lock. Refuses (returns `false`) while
        /// any stripe has a load in flight.
        fn replace(&self, new_source: u32) -> bool {
            let mut source_guard = self.source.lock().unwrap();
            let mut stripes: Vec<_> = self
                .stripes
                .iter()
                .map(|stripe| stripe.lock().unwrap())
                .collect();
            if stripes.iter().any(|stripe| !stripe.is_idle()) {
                return false;
            }
            *source_guard = Some(new_source);
            for stripe in &mut stripes {
                stripe.clear_cache();
            }
            true
        }

        /// Leads and immediately finishes `key` so its value is cached.
        fn seed_cache(&self, key: K, value: u32) {
            let mut stripe = self.stripes[Self::stripe_index(&key)].lock().unwrap();
            let outcome = stripe.try_register::<_, LoomWaiter>(&key, || unreachable!());
            assert!(matches!(outcome, Registration::Leading));
            let waiters = stripe.finish(key, FactLoadResult::Found(value));
            assert!(
                waiters.is_empty(),
                "seeding leads and finishes under one lock, so nobody can have joined"
            );
        }

        fn lead(&self, key: &K) {
            let mut stripe = self.stripes[Self::stripe_index(key)].lock().unwrap();
            let outcome = stripe.try_register::<_, LoomWaiter>(key, || unreachable!());
            assert!(matches!(outcome, Registration::Leading));
        }

        /// Finishes `key`, then signals its waiters after releasing the stripe lock.
        fn finish_leader(&self, key: K, value: u32) {
            let waiters = {
                let mut stripe = self.stripes[Self::stripe_index(&key)].lock().unwrap();
                stripe.finish(key, FactLoadResult::Found(value))
            };
            signal_all(waiters);
        }
    }

    /// A planner never pairs the new source with the stale cached value.
    #[test]
    fn loom_replacement_is_atomic_with_respect_to_planning() {
        loom::model(|| {
            let adapter = Arc::new(LoomAdapter::new(Some(1)));
            adapter.seed_cache(K(1), 100);
            let a1 = adapter.clone();
            let t1 = thread::spawn(move || a1.plan(&K(1)));
            let a2 = adapter.clone();
            let t2 = thread::spawn(move || a2.replace(2));
            let (source_seen, cached_seen) = t1.join().unwrap();
            let replaced = t2.join().unwrap();
            assert!(
                replaced,
                "stripe is idle (seed_cache finished); replace must succeed under every schedule"
            );
            match (source_seen, cached_seen) {
                (Some(1), Some(FactLoadResult::Found(100))) => {}
                (Some(2), None) => {}
                other => panic!(
                    "atomicity violation: planner observed {other:?}; \
                     new source must never be paired with stale cache"
                ),
            }
        });
    }

    /// Loads on keys that live on different stripes do not interfere.
    #[test]
    fn loom_unrelated_stripes_are_independent() {
        loom::model(|| {
            let adapter = Arc::new(LoomAdapter::new(None));
            assert_ne!(
                LoomAdapter::stripe_index(&K(0)),
                LoomAdapter::stripe_index(&K(1)),
                "test setup requires K(0) and K(1) on different stripes"
            );
            let a1 = adapter.clone();
            let t1 = thread::spawn(move || {
                a1.lead(&K(0));
                a1.finish_leader(K(0), 10);
            });
            let a2 = adapter.clone();
            let t2 = thread::spawn(move || {
                a2.lead(&K(1));
                a2.finish_leader(K(1), 20);
            });
            t1.join().unwrap();
            t2.join().unwrap();
            match adapter.plan(&K(0)).1 {
                Some(FactLoadResult::Found(10)) => {}
                other => panic!("K(0) cache: expected Found(10), got {other:?}"),
            }
            match adapter.plan(&K(1)).1 {
                Some(FactLoadResult::Found(20)) => {}
                other => panic!("K(1) cache: expected Found(20), got {other:?}"),
            }
        });
    }

    /// Replacement racing an in-flight load either is rejected (and the load's
    /// result is cached) or happens after the load finished (and clears it).
    #[test]
    fn loom_replacement_rejected_while_in_flight() {
        loom::model(|| {
            let adapter = Arc::new(LoomAdapter::new(Some(1)));
            adapter.lead(&K(1));
            let a1 = adapter.clone();
            let t_finish = thread::spawn(move || a1.finish_leader(K(1), 77));
            let a2 = adapter.clone();
            let t_replace = thread::spawn(move || a2.replace(2));
            t_finish.join().unwrap();
            let replaced = t_replace.join().unwrap();
            let (source, cached) = adapter.plan(&K(1));
            match (replaced, source, cached) {
                (false, Some(1), Some(FactLoadResult::Found(77))) => {}
                (true, Some(2), None) => {}
                other => panic!(
                    "in-flight replacement invariant violated: (replaced, source, cached) = {other:?}"
                ),
            }
        });
    }
}