use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::{HashMap, HashSet};
use crate::Channel;
/// Why a particular `(key, channel)` pair was recorded as invalidated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InvalidationCause<K> {
    /// The key was reported as the starting point of a propagation
    /// (see [`InvalidationTrace::root`]).
    Root,
    /// The key was reported as invalidated because of another key on the same channel
    /// (see [`InvalidationTrace::caused_by`]).
    Because {
        /// The key whose invalidation led to this one.
        because: K,
    },
}
/// Observer notified while invalidations are propagated, so callers can record why each
/// key was invalidated.
pub trait InvalidationTrace<K> {
    /// Reports `key` as the starting point of a propagation on `channel`.
    ///
    /// `newly_invalidated` is presumably `true` when `key` was not already invalidated on
    /// `channel` before this call — confirm against the propagation policy.
    fn root(&mut self, key: K, channel: Channel, newly_invalidated: bool);

    /// Reports that `key` was reached on `channel` because `because` was invalidated.
    ///
    /// `newly_invalidated` is presumably `false` when `key` was already invalidated
    /// beforehand — confirm against the propagation policy.
    fn caused_by(&mut self, key: K, because: K, channel: Channel, newly_invalidated: bool);
}
#[derive(Debug, Default, Clone)]
pub struct OneParentRecorder<K>
where
K: Copy + Eq + Hash,
{
causes: HashMap<(K, Channel), InvalidationCause<K>>,
}
impl<K> OneParentRecorder<K>
where
    K: Copy + Eq + Hash,
{
    /// Creates a recorder with no causes recorded.
    #[must_use]
    pub fn new() -> Self {
        let causes = HashMap::new();
        Self { causes }
    }

    /// Forgets every recorded cause, on all channels.
    pub fn clear(&mut self) {
        self.causes.clear();
    }

    /// Returns the cause recorded for `key` on `channel`, or `None` if nothing was recorded.
    #[must_use]
    pub fn cause(&self, key: K, channel: Channel) -> Option<InvalidationCause<K>> {
        let slot = (key, channel);
        self.causes.get(&slot).copied()
    }

    /// Reconstructs the chain of recorded causes that ends at `key` on `channel`.
    ///
    /// On success the keys are ordered from the root of the chain to `key`, so `key` is
    /// always the last element (a key recorded as a root yields `[key]`).
    ///
    /// Returns `None` if any key on the chain has no recorded cause on `channel`, or if
    /// following the recorded causes revisits a key (the causes form a cycle).
    #[must_use]
    pub fn explain_path(&self, key: K, channel: Channel) -> Option<Vec<K>> {
        let mut visited: HashSet<K> = HashSet::new();
        let mut chain: Vec<K> = Vec::new();
        let mut cursor = key;
        while visited.insert(cursor) {
            chain.push(cursor);
            match self.cause(cursor, channel)? {
                InvalidationCause::Root => {
                    // Collected child-to-root; hand back root-first order.
                    chain.reverse();
                    return Some(chain);
                }
                InvalidationCause::Because { because } => cursor = because,
            }
        }
        // `cursor` is already on the chain: the recorded causes loop back on themselves.
        None
    }
}
impl<K> InvalidationTrace<K> for OneParentRecorder<K>
where
    K: Copy + Eq + Hash,
{
    /// Records `key` as a root on `channel`, unless a cause is already recorded for that pair.
    ///
    /// `newly_invalidated` is ignored; the first cause seen for a pair is the one kept.
    fn root(&mut self, key: K, channel: Channel, _newly_invalidated: bool) {
        let first = InvalidationCause::Root;
        self.causes.entry((key, channel)).or_insert(first);
    }

    /// Records `because` as the cause of `key` on `channel`, unless a cause is already
    /// recorded for that pair.
    ///
    /// `newly_invalidated` is ignored; the first cause seen for a pair is the one kept.
    fn caused_by(&mut self, key: K, because: K, channel: Channel, _newly_invalidated: bool) {
        let first = InvalidationCause::Because { because };
        self.causes.entry((key, channel)).or_insert(first);
    }
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{CycleHandling, EagerPolicy, InvalidationGraph, InvalidationSet, TraversalScratch};
    use alloc::vec;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);

    /// Builds the dependency chain on `LAYOUT` where 2 depends on 1 and 3 depends on 2.
    fn chain_graph() -> InvalidationGraph<u32> {
        let mut g = InvalidationGraph::<u32>::new();
        g.add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        g.add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        g
    }

    #[test]
    fn records_one_parent_path() {
        let g = chain_graph();
        let mut invalidated = InvalidationSet::<u32>::new();
        let mut scratch = TraversalScratch::new();
        let mut rec = OneParentRecorder::new();
        EagerPolicy.propagate_with_trace(1, LAYOUT, &g, &mut invalidated, &mut scratch, &mut rec);
        assert_eq!(rec.explain_path(3, LAYOUT).unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn can_fill_in_missing_causes_for_already_invalidated_keys() {
        let g = chain_graph();
        let mut invalidated = InvalidationSet::<u32>::new();
        invalidated.mark(2, LAYOUT);
        let mut scratch = TraversalScratch::new();
        let mut rec = OneParentRecorder::new();
        EagerPolicy.propagate_with_trace(1, LAYOUT, &g, &mut invalidated, &mut scratch, &mut rec);
        assert_eq!(rec.explain_path(2, LAYOUT).unwrap(), vec![1, 2]);
    }

    #[test]
    fn unknown_key_has_no_cause_or_path() {
        let rec = OneParentRecorder::<u32>::new();
        assert_eq!(rec.cause(42, LAYOUT), None);
        assert_eq!(rec.explain_path(42, LAYOUT), None);
    }

    #[test]
    fn first_recorded_cause_wins() {
        let mut rec = OneParentRecorder::<u32>::new();
        rec.root(1, LAYOUT, true);
        rec.caused_by(1, 9, LAYOUT, false);
        rec.caused_by(2, 1, LAYOUT, true);
        rec.caused_by(2, 7, LAYOUT, false);
        rec.root(2, LAYOUT, false);
        assert_eq!(rec.cause(1, LAYOUT), Some(InvalidationCause::Root));
        assert_eq!(
            rec.cause(2, LAYOUT),
            Some(InvalidationCause::Because { because: 1 })
        );
        assert_eq!(rec.explain_path(2, LAYOUT), Some(vec![1, 2]));
    }

    #[test]
    fn causes_are_tracked_per_channel() {
        let mut rec = OneParentRecorder::<u32>::new();
        rec.root(1, LAYOUT, true);
        assert_eq!(rec.cause(1, PAINT), None);
        assert_eq!(rec.explain_path(1, PAINT), None);
        assert_eq!(rec.explain_path(1, LAYOUT), Some(vec![1]));
    }

    #[test]
    fn cyclic_causes_yield_no_path() {
        let mut rec = OneParentRecorder::<u32>::new();
        rec.caused_by(1, 2, LAYOUT, true);
        rec.caused_by(2, 1, LAYOUT, true);
        assert_eq!(rec.explain_path(1, LAYOUT), None);
        assert_eq!(rec.explain_path(2, LAYOUT), None);
    }

    #[test]
    fn clear_forgets_all_causes() {
        let mut rec = OneParentRecorder::<u32>::new();
        rec.root(1, LAYOUT, true);
        rec.caused_by(2, 1, LAYOUT, true);
        rec.clear();
        assert_eq!(rec.cause(1, LAYOUT), None);
        assert_eq!(rec.explain_path(2, LAYOUT), None);
    }
}