use core::hash::Hash;
use crate::channel::Channel;
use crate::drain::DenseKey;
use crate::graph::InvalidationGraph;
use crate::scratch::TraversalScratch;
use crate::set::InvalidationSet;
use crate::trace::InvalidationTrace;
/// Strategy for recording an invalidation of `key` on `channel` into an
/// [`InvalidationSet`].
///
/// Each implementation decides how far an invalidation spreads through the
/// dependency `graph`: [`EagerPolicy`] marks every transitive dependent as
/// well as the key, while [`LazyPolicy`] marks only the key itself.
pub trait PropagationPolicy<K: Copy + Eq + Hash + DenseKey> {
    /// Records that `key` became invalid on `channel`.
    ///
    /// * `graph` – dependency information the policy may consult.
    /// * `invalidated` – the set that receives every key this policy marks.
    fn propagate(
        &self,
        key: K,
        channel: Channel,
        graph: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
    );
}
/// Propagation policy that invalidates a key together with every key that
/// transitively depends on it on the same channel.
#[derive(Clone, Copy, Debug, Default)]
pub struct EagerPolicy;
/// Eager propagation: `key` is marked first, then every key yielded by
/// `graph.transitive_dependents(key, channel)` is marked on the same channel.
impl<K: Copy + Eq + Hash + DenseKey> PropagationPolicy<K> for EagerPolicy {
    fn propagate(
        &self,
        key: K,
        channel: Channel,
        graph: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
    ) {
        // The root is always marked, even when nothing depends on it.
        invalidated.mark(key, channel);

        let dependents = graph.transitive_dependents(key, channel);
        for dependent in dependents {
            invalidated.mark(dependent, channel);
        }
    }
}
impl EagerPolicy {
    /// Eager propagation using caller-provided traversal storage.
    ///
    /// Marks `key` on `channel`. It then marks every dependent that
    /// `graph.for_each_transitive_dependent` reports, passing `scratch` as the
    /// traversal's working space. That presumably lets repeated calls reuse its
    /// buffers.
    pub fn propagate_with_scratch<K>(
        &self,
        key: K,
        channel: Channel,
        graph: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
        scratch: &mut TraversalScratch<K>,
    ) where
        K: Copy + Eq + Hash + DenseKey,
    {
        invalidated.mark(key, channel);

        let mark_dependent = |dependent: K| {
            invalidated.mark(dependent, channel);
        };
        graph.for_each_transitive_dependent(key, channel, scratch, mark_dependent);
    }

    /// Eager propagation that reports every marking to `trace`.
    ///
    /// `trace.root` is called once for `key`. A stack-driven walk over
    /// `graph.dependents` then calls `trace.caused_by(child, parent, ..)` the
    /// first time each dependent is reached, naming the key it was reached from.
    /// Every report carries the value returned by `InvalidationSet::mark` for
    /// that key.
    ///
    /// `scratch` is reset before the walk. Its `visited` set (seeded with `key`)
    /// ensures that each key is reported and expanded at most once, so cycles
    /// terminate.
    pub fn propagate_with_trace<K, T>(
        &self,
        key: K,
        channel: Channel,
        graph: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
        scratch: &mut TraversalScratch<K>,
        trace: &mut T,
    ) where
        K: Copy + Eq + Hash + DenseKey,
        T: InvalidationTrace<K>,
    {
        let root_mark = invalidated.mark(key, channel);
        trace.root(key, channel, root_mark);

        // Start from a clean traversal state seeded with the root.
        scratch.reset();
        scratch.stack.push(key);
        scratch.visited.insert(key);

        while let Some(parent) = scratch.stack.pop() {
            for child in graph.dependents(parent, channel) {
                // Only the first discovery of a key is recorded and expanded.
                let first_visit = scratch.visited.insert(child);
                if first_visit {
                    let child_mark = invalidated.mark(child, channel);
                    trace.caused_by(child, parent, channel, child_mark);
                    scratch.stack.push(child);
                }
            }
        }
    }
}
/// Propagation policy that invalidates only the key it is given and leaves
/// that key's dependents untouched.
#[derive(Clone, Copy, Debug, Default)]
pub struct LazyPolicy;
/// Lazy propagation: only `key` is marked on `channel`. The dependency graph
/// is never consulted.
impl<K: Copy + Eq + Hash + DenseKey> PropagationPolicy<K> for LazyPolicy {
    fn propagate(
        &self,
        key: K,
        channel: Channel,
        _: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
    ) {
        invalidated.mark(key, channel);
    }
}
/// Forwarding impl so that a borrowed policy (including
/// `&dyn PropagationPolicy<K>`) can be passed wherever a
/// `PropagationPolicy<K>` is expected by value.
impl<K, P> PropagationPolicy<K> for &P
where
    K: Copy + Eq + Hash + DenseKey,
    P: PropagationPolicy<K> + ?Sized,
{
    fn propagate(
        &self,
        key: K,
        channel: Channel,
        graph: &InvalidationGraph<K>,
        invalidated: &mut InvalidationSet<K>,
    ) {
        // Fully qualified so the call unambiguously targets the referent's impl
        // rather than this forwarding impl.
        <P as PropagationPolicy<K>>::propagate(*self, key, channel, graph, invalidated);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;
    use crate::graph::CycleHandling;

    /// Channel used by every test in this module.
    const LAYOUT: Channel = Channel::new(0);

    /// Builds a linear chain on [`LAYOUT`]: 2 depends on 1, 3 on 2, and 4 on 3.
    /// Invalidating 1 is therefore expected to reach every other key.
    fn setup_chain_graph() -> InvalidationGraph<u32> {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
    }

    /// Runs `policy` through a generic `PropagationPolicy<u32>` bound.
    ///
    /// When `policy` is a reference, this dispatches via the blanket
    /// `impl PropagationPolicy<K> for &P`. Calling `.propagate` directly on a
    /// reference does not exercise that impl, because method resolution picks
    /// the referent's impl first.
    fn propagate_generic<P: PropagationPolicy<u32>>(
        policy: P,
        key: u32,
        graph: &InvalidationGraph<u32>,
        invalidated: &mut InvalidationSet<u32>,
    ) {
        policy.propagate(key, LAYOUT, graph, invalidated);
    }

    #[test]
    fn eager_policy_marks_all_dependents() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        let eager = EagerPolicy;
        eager.propagate(1, LAYOUT, &graph, &mut invalidated);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
        assert!(invalidated.is_invalidated(3, LAYOUT));
        assert!(invalidated.is_invalidated(4, LAYOUT));
    }

    #[test]
    fn eager_policy_from_middle() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        let eager = EagerPolicy;
        eager.propagate(2, LAYOUT, &graph, &mut invalidated);
        // Propagation only flows towards dependents, never back to dependencies.
        assert!(!invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
        assert!(invalidated.is_invalidated(3, LAYOUT));
        assert!(invalidated.is_invalidated(4, LAYOUT));
    }

    #[test]
    fn lazy_policy_only_marks_key() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        let lazy = LazyPolicy;
        lazy.propagate(1, LAYOUT, &graph, &mut invalidated);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(!invalidated.is_invalidated(2, LAYOUT));
        assert!(!invalidated.is_invalidated(3, LAYOUT));
        assert!(!invalidated.is_invalidated(4, LAYOUT));
    }

    /// Calls `.propagate` on a `&dyn PropagationPolicy`. This dispatches through
    /// the trait object's vtable, not through the blanket `&P` impl; the
    /// `blanket_reference_impl_*` tests below cover that impl.
    #[test]
    fn policy_through_reference() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        let eager = EagerPolicy;
        let policy: &dyn PropagationPolicy<u32> = &eager;
        policy.propagate(1, LAYOUT, &graph, &mut invalidated);
        let invalidated_keys: Vec<_> = invalidated.iter(LAYOUT).collect();
        assert_eq!(invalidated_keys.len(), 4);
    }

    #[test]
    fn blanket_reference_impl_forwards_to_eager() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        propagate_generic(&EagerPolicy, 1, &graph, &mut invalidated);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(4, LAYOUT));
        assert_eq!(invalidated.len(LAYOUT), 4);
    }

    #[test]
    fn blanket_reference_impl_forwards_through_nested_references() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        // `&&LazyPolicy` forwards twice through the blanket impl.
        propagate_generic(&&LazyPolicy, 1, &graph, &mut invalidated);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(!invalidated.is_invalidated(2, LAYOUT));
        assert_eq!(invalidated.len(LAYOUT), 1);
    }

    #[test]
    fn blanket_reference_impl_accepts_trait_objects() {
        let graph = setup_chain_graph();
        let mut invalidated = InvalidationSet::new();
        let policy: &dyn PropagationPolicy<u32> = &EagerPolicy;
        // `P = dyn PropagationPolicy<u32>` exercises the `?Sized` bound.
        propagate_generic(policy, 2, &graph, &mut invalidated);
        assert!(!invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
        assert_eq!(invalidated.len(LAYOUT), 3);
    }

    /// Key 4 is reachable from 1 along two paths; it must be marked exactly once.
    #[test]
    fn eager_handles_diamond() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        EagerPolicy.propagate(1, LAYOUT, &graph, &mut invalidated);
        assert!(invalidated.is_invalidated(1, LAYOUT));
        assert!(invalidated.is_invalidated(2, LAYOUT));
        assert!(invalidated.is_invalidated(3, LAYOUT));
        assert!(invalidated.is_invalidated(4, LAYOUT));
        assert_eq!(invalidated.len(LAYOUT), 4);
    }
}