use alloc::vec::Vec;
use core::hash::Hash;
use crate::cascade::{CascadeCycleError, ChannelCascade};
use crate::channel::Channel;
use crate::cross_channel::CrossChannelEdges;
use crate::drain::{DenseKey, DrainSorted, DrainSortedDeterministic};
use crate::drain_builder::{AnyOrder, DrainBuilder};
use crate::graph::{CycleError, CycleHandling, InvalidationGraph};
use crate::policy::PropagationPolicy;
use crate::scratch::TraversalScratch;
use crate::set::InvalidationSet;
use crate::trace::InvalidationTrace;
/// Tracks invalidated `(key, channel)` pairs together with the structure used
/// to propagate invalidation between them.
///
/// Propagation can travel along three kinds of edges:
/// - same-channel key dependencies stored in an [`InvalidationGraph`];
/// - channel-level cascades ([`ChannelCascade`]), which re-mark the same key
///   in further channels;
/// - explicit cross-channel edges ([`CrossChannelEdges`]) from one
///   `(key, channel)` pair to another.
#[derive(Debug, Clone)]
pub struct InvalidationTracker<K: Copy + Eq + Hash + DenseKey> {
    /// Same-channel dependency edges between keys.
    graph: InvalidationGraph<K>,
    /// The `(key, channel)` pairs currently marked as invalidated.
    invalidated: InvalidationSet<K>,
    /// Cycle handling used by `add_dependency` / `replace_dependencies`
    /// when the caller does not supply one explicitly.
    cycle_handling: CycleHandling,
    /// Channel-to-channel cascades applied to the key being marked.
    cascade: ChannelCascade,
    /// Explicit edges from one `(key, channel)` pair to another.
    cross_channel: CrossChannelEdges<K>,
}
impl<K> Default for InvalidationTracker<K>
where
K: Copy + Eq + Hash + DenseKey,
{
fn default() -> Self {
Self::new()
}
}
impl<K> InvalidationTracker<K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Starts configuring a drain of `channel`.
    ///
    /// The returned [`DrainBuilder`] mutably borrows the invalidated set and
    /// shares the dependency graph. Choose what to drain (for example
    /// `invalidated_only()` or `affected()`), optionally add `deterministic()`
    /// or `trace(..)`, then call `run()`.
    pub fn drain(&mut self, channel: Channel) -> DrainBuilder<'_, '_, '_, K, AnyOrder> {
        DrainBuilder::new(&mut self.invalidated, &self.graph, channel)
    }

    /// Creates an empty tracker.
    ///
    /// It uses the default [`CycleHandling`], has no cascades and has no
    /// cross-channel edges.
    #[must_use]
    pub fn new() -> Self {
        Self::with_cycle_handling(CycleHandling::default())
    }

    /// Creates an empty tracker whose dependency-adding methods use
    /// `cycle_handling` unless overridden per call.
    #[must_use]
    pub fn with_cycle_handling(cycle_handling: CycleHandling) -> Self {
        Self {
            graph: InvalidationGraph::new(),
            invalidated: InvalidationSet::new(),
            cycle_handling,
            cascade: ChannelCascade::new(),
            cross_channel: CrossChannelEdges::new(),
        }
    }

    /// Creates an empty tracker pre-populated with `(from, to)` channel
    /// cascades.
    ///
    /// # Errors
    ///
    /// Returns the [`CascadeCycleError`] produced by
    /// `ChannelCascade::from_edges` when it rejects the edges (by the error
    /// type's name, a cascade cycle).
    pub fn with_cascades(
        cascades: impl IntoIterator<Item = (Channel, Channel)>,
    ) -> Result<Self, CascadeCycleError> {
        Ok(Self {
            graph: InvalidationGraph::new(),
            invalidated: InvalidationSet::new(),
            cycle_handling: CycleHandling::default(),
            cascade: ChannelCascade::from_edges(cascades)?,
            cross_channel: CrossChannelEdges::new(),
        })
    }

    /// Wraps an existing dependency graph using the default [`CycleHandling`].
    #[must_use]
    pub fn from_graph(graph: InvalidationGraph<K>) -> Self {
        Self::from_graph_with_cycle_handling(graph, CycleHandling::default())
    }

    /// Wraps an existing dependency graph with the given default
    /// `cycle_handling`.
    ///
    /// The tracker starts with nothing invalidated, no cascades and no
    /// cross-channel edges.
    #[must_use]
    pub fn from_graph_with_cycle_handling(
        graph: InvalidationGraph<K>,
        cycle_handling: CycleHandling,
    ) -> Self {
        Self {
            graph,
            invalidated: InvalidationSet::new(),
            cycle_handling,
            cascade: ChannelCascade::new(),
            cross_channel: CrossChannelEdges::new(),
        }
    }

    /// Returns the same-channel dependency graph.
    #[inline]
    #[must_use]
    pub fn graph(&self) -> &InvalidationGraph<K> {
        &self.graph
    }

    /// Returns the set of currently invalidated `(key, channel)` pairs.
    #[inline]
    #[must_use]
    pub fn invalidated(&self) -> &InvalidationSet<K> {
        &self.invalidated
    }

    /// Returns the invalidated set's generation counter.
    ///
    /// The counter advances as keys are marked.
    #[inline]
    #[must_use]
    pub fn generation(&self) -> u64 {
        self.invalidated.generation()
    }

    /// Returns the default cycle handling used when adding dependencies.
    #[inline]
    #[must_use]
    pub fn cycle_handling(&self) -> CycleHandling {
        self.cycle_handling
    }

    /// Records that `from` depends on `to` in `channel`.
    ///
    /// Invalidating `to` may then propagate to `from`. Uses the tracker's
    /// default [`CycleHandling`].
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when the graph rejects the edge under the
    /// configured cycle handling (e.g. a self-edge with [`CycleHandling::Error`]).
    pub fn add_dependency(
        &mut self,
        from: K,
        to: K,
        channel: Channel,
    ) -> Result<bool, CycleError<K>> {
        self.graph
            .add_dependency(from, to, channel, self.cycle_handling)
    }

    /// Like [`Self::add_dependency`], but with an explicit per-call
    /// [`CycleHandling`] that overrides the tracker default.
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when the graph rejects the edge under `handling`.
    pub fn add_dependency_with(
        &mut self,
        from: K,
        to: K,
        channel: Channel,
        handling: CycleHandling,
    ) -> Result<bool, CycleError<K>> {
        self.graph.add_dependency(from, to, channel, handling)
    }

    /// Removes the `from` → `to` dependency in `channel`.
    ///
    /// Returns the graph's `bool` result (presumably whether an edge was
    /// removed).
    pub fn remove_dependency(&mut self, from: K, to: K, channel: Channel) -> bool {
        self.graph.remove_dependency(from, to, channel)
    }

    /// Forgets `key` everywhere.
    ///
    /// This covers its graph edges, its invalidation marks and its
    /// cross-channel edges. Channel cascades are key-independent and are left
    /// untouched.
    pub fn remove_key(&mut self, key: K) {
        self.graph.remove_key(key);
        self.invalidated.remove_key(key);
        self.cross_channel.remove_key(key);
    }

    /// Replaces all dependencies of `from` in `channel` with `to`.
    ///
    /// Uses the tracker's default [`CycleHandling`].
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when one of the new edges is rejected.
    pub fn replace_dependencies(
        &mut self,
        from: K,
        channel: Channel,
        to: impl IntoIterator<Item = K>,
    ) -> Result<bool, CycleError<K>> {
        self.graph
            .replace_dependencies(from, channel, to, self.cycle_handling)
    }

    /// Like [`Self::replace_dependencies`], but with an explicit per-call
    /// [`CycleHandling`].
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when one of the new edges is rejected under
    /// `handling`.
    pub fn replace_dependencies_with(
        &mut self,
        from: K,
        channel: Channel,
        to: impl IntoIterator<Item = K>,
        handling: CycleHandling,
    ) -> Result<bool, CycleError<K>> {
        self.graph.replace_dependencies(from, channel, to, handling)
    }

    /// Adds a cascade so that marking a key in `from` also marks it in `to`.
    ///
    /// Returns `Ok(true)` if the cascade is new and `Ok(false)` if it already
    /// existed.
    ///
    /// # Errors
    ///
    /// Returns [`CascadeCycleError`] if the cascade is rejected.
    pub fn add_cascade(&mut self, from: Channel, to: Channel) -> Result<bool, CascadeCycleError> {
        self.cascade.add_cascade(from, to)
    }

    /// Removes the `from` → `to` cascade.
    ///
    /// Returns `true` if it existed.
    pub fn remove_cascade(&mut self, from: Channel, to: Channel) -> bool {
        self.cascade.remove_cascade(from, to)
    }

    /// Returns the channel cascade table.
    #[inline]
    #[must_use]
    pub fn cascade(&self) -> &ChannelCascade {
        &self.cascade
    }

    /// Adds an edge so that invalidating `(from_key, from_ch)` via
    /// [`Self::mark_with`] also invalidates `(to_key, to_ch)`.
    ///
    /// Returns the edge set's `bool` result (presumably whether the edge is
    /// new).
    pub fn add_cross_dependency(
        &mut self,
        from_key: K,
        from_ch: Channel,
        to_key: K,
        to_ch: Channel,
    ) -> bool {
        self.cross_channel
            .add_edge(from_key, from_ch, to_key, to_ch)
    }

    /// Removes one cross-channel edge.
    ///
    /// Returns `true` if it existed.
    pub fn remove_cross_dependency(
        &mut self,
        from_key: K,
        from_ch: Channel,
        to_key: K,
        to_ch: Channel,
    ) -> bool {
        self.cross_channel
            .remove_edge(from_key, from_ch, to_key, to_ch)
    }

    /// Replaces every outgoing cross-channel edge of `(from_key, from_ch)`
    /// with `targets`.
    ///
    /// Returns `true` if the set of targets changed. Order is irrelevant.
    pub fn replace_cross_dependents(
        &mut self,
        from_key: K,
        from_ch: Channel,
        targets: impl IntoIterator<Item = (K, Channel)>,
    ) -> bool {
        self.cross_channel
            .replace_dependents(from_key, from_ch, targets)
    }

    /// Removes every outgoing cross-channel edge of `(from_key, from_ch)`.
    ///
    /// Returns `true` if any were removed.
    pub fn clear_cross_dependents(&mut self, from_key: K, from_ch: Channel) -> bool {
        self.cross_channel.clear_dependents(from_key, from_ch)
    }

    /// Removes every incoming cross-channel edge of `(to_key, to_ch)`.
    ///
    /// Returns `true` if any were removed.
    pub fn clear_cross_dependencies(&mut self, to_key: K, to_ch: Channel) -> bool {
        self.cross_channel.clear_dependencies(to_key, to_ch)
    }

    /// Returns the cross-channel edge set.
    #[inline]
    #[must_use]
    pub fn cross_channel(&self) -> &CrossChannelEdges<K> {
        &self.cross_channel
    }

    /// Marks `key` invalid in `channel` and in every channel cascaded from
    /// `channel`.
    ///
    /// This does *not* propagate through the dependency graph and does *not*
    /// follow cross-channel edges; use [`Self::mark_with`] for that.
    ///
    /// Returns what the invalidated set reports for the base
    /// `(key, channel)` pair (presumably `true` if it was not already
    /// invalidated). Marks made by cascades do not influence the result.
    #[inline]
    pub fn mark(&mut self, key: K, channel: Channel) -> bool {
        let base_result = self.invalidated.mark(key, channel);
        // Iterating an empty cascade set is already a no-op, so no separate
        // emptiness check is needed.
        for cascaded in self.cascade.cascades_from(channel) {
            self.invalidated.mark(key, cascaded);
        }
        base_result
    }

    /// Marks `(key, channel)` and lets `policy` decide how far invalidation
    /// spreads.
    ///
    /// With no cascades out of `channel` and no cross-channel edges at all,
    /// this simply calls `policy.propagate` once. Otherwise a worklist of
    /// `(key, channel)` pairs is processed; each pair is handed to the policy
    /// at most once.
    /// - The popped pair's cascades and cross-channel successors are always
    ///   enqueued.
    /// - For every transitive same-channel dependent of the popped key, its
    ///   cascades and cross-channel successors are enqueued only if that
    ///   dependent is invalidated after propagation. This includes dependents
    ///   that were invalidated before this call.
    ///
    /// Keys that a custom policy marks outside the graph traversal do not
    /// trigger cross-channel follow-up.
    pub fn mark_with<P>(&mut self, key: K, channel: Channel, policy: &P)
    where
        P: PropagationPolicy<K>,
    {
        // Fast path: without cascades or cross edges the policy alone
        // determines the outcome.
        if self.cascade.cascades_from(channel).is_empty() && self.cross_channel.is_empty() {
            policy.propagate(key, channel, &self.graph, &mut self.invalidated);
            return;
        }
        let mut worklist: Vec<(K, Channel)> = Vec::new();
        worklist.push((key, channel));
        // Pairs already handed to the policy. This also terminates cycles
        // formed by combinations of cascades and cross-channel edges.
        let mut processed = hashbrown::HashSet::<(K, Channel)>::new();
        while let Some((k, ch)) = worklist.pop() {
            if !processed.insert((k, ch)) {
                continue;
            }
            policy.propagate(k, ch, &self.graph, &mut self.invalidated);
            self.enqueue_cross_successors(k, ch, &processed, &mut worklist);
            // Only dependents the policy (or earlier marks) actually left
            // invalidated fan out further; under a lazy policy they are not.
            for dep in self.graph.transitive_dependents(k, ch) {
                if self.invalidated.is_invalidated(dep, ch) {
                    self.enqueue_cross_successors(dep, ch, &processed, &mut worklist);
                }
            }
        }
    }

    /// Returns `true` if `key` is currently invalidated in `channel`.
    #[inline]
    #[must_use]
    pub fn is_invalidated(&self, key: K, channel: Channel) -> bool {
        self.invalidated.is_invalidated(key, channel)
    }

    /// Returns `true` if any key is invalidated in `channel`.
    #[inline]
    #[must_use]
    pub fn has_invalidated(&self, channel: Channel) -> bool {
        self.invalidated.has_invalidated(channel)
    }

    /// Returns `true` if nothing is invalidated in any channel.
    #[inline]
    #[must_use]
    pub fn is_clean(&self) -> bool {
        self.invalidated.is_empty()
    }

    /// Drains the invalidated keys of `channel` in dependency order
    /// (dependencies before dependents).
    ///
    /// The channel is left with no invalidated keys.
    pub fn drain_sorted(&mut self, channel: Channel) -> DrainSorted<'_, K> {
        self.drain(channel).invalidated_only().run()
    }

    /// Drains `channel` using the builder's `affected()` mode.
    ///
    /// By its name this presumably also yields dependents of invalidated keys
    /// that were never marked themselves; confirm against `DrainBuilder`.
    pub fn drain_affected_sorted(&mut self, channel: Channel) -> DrainSorted<'_, K> {
        self.drain(channel).affected().run()
    }

    /// Same as [`Self::drain_affected_sorted`], but reports the traversal to
    /// `trace`.
    ///
    /// `scratch` is used as reusable traversal storage.
    pub fn drain_affected_sorted_with_trace<T>(
        &mut self,
        channel: Channel,
        scratch: &mut TraversalScratch<K>,
        trace: &mut T,
    ) -> DrainSorted<'_, K>
    where
        T: InvalidationTrace<K>,
    {
        self.drain(channel).affected().trace(scratch, trace).run()
    }

    /// Yields the invalidated keys of `channel` in the same order as
    /// [`Self::drain_sorted`], without clearing them.
    #[must_use]
    pub fn peek_sorted(&self, channel: Channel) -> DrainSorted<'_, K> {
        let cap = self.invalidated.len(channel);
        DrainSorted::from_iter_with_capacity(
            self.invalidated.iter(channel),
            cap,
            &self.graph,
            channel,
        )
    }

    /// Clears every invalidation mark in `channel`.
    pub fn clear(&mut self, channel: Channel) {
        self.invalidated.clear(channel);
    }

    /// Clears every invalidation mark in every channel.
    pub fn clear_all(&mut self) {
        self.invalidated.clear_all();
    }

    /// Drains each channel of `order`, in that sequence, with
    /// [`Self::drain_sorted`].
    ///
    /// Returns `(channel, key)` pairs. A channel listed twice yields nothing
    /// the second time, because the first drain already cleared it.
    pub fn drain_channels_sorted(&mut self, order: &[Channel]) -> Vec<(Channel, K)> {
        // Each channel drains at most its currently invalidated keys, so this
        // is an upper bound on the output length.
        let capacity = order.iter().map(|&ch| self.invalidated.len(ch)).sum();
        let mut results = Vec::with_capacity(capacity);
        for &ch in order {
            results.extend(self.drain_sorted(ch).map(|key| (ch, key)));
        }
        results
    }

    /// Returns every `(key, channel)` pair reachable from `(key, channel)`.
    ///
    /// Reachability follows same-channel dependents, channel cascades and
    /// cross-channel edges. The starting pair itself is never included, and
    /// each pair appears at most once in an unspecified order. This does not
    /// consult or modify the invalidated set.
    pub fn transitive_dependents_cross(&self, key: K, channel: Channel) -> Vec<(K, Channel)> {
        use hashbrown::HashSet;
        let mut visited = HashSet::new();
        let mut queue = Vec::new();
        let mut result = Vec::new();
        queue.push((key, channel));
        visited.insert((key, channel));
        while let Some((k, ch)) = queue.pop() {
            for dep in self.graph.dependents(k, ch) {
                if visited.insert((dep, ch)) {
                    result.push((dep, ch));
                    queue.push((dep, ch));
                }
            }
            self.for_each_cross_successor(k, ch, |next_key, next_ch| {
                if visited.insert((next_key, next_ch)) {
                    result.push((next_key, next_ch));
                    queue.push((next_key, next_ch));
                }
            });
        }
        result
    }

    /// Calls `f` for every pair directly reachable from `(key, channel)` by a
    /// cascade (same key, cascaded channel) or by a cross-channel edge.
    fn for_each_cross_successor(&self, key: K, channel: Channel, mut f: impl FnMut(K, Channel)) {
        for cascade_ch in self.cascade.cascades_from(channel) {
            f(key, cascade_ch);
        }
        for (to_key, to_ch) in self.cross_channel.dependents(key, channel) {
            f(to_key, to_ch);
        }
    }

    /// Pushes the cascade and cross-channel successors of `(key, channel)`
    /// that are not yet `processed` onto `worklist`.
    ///
    /// Duplicates in `worklist` are tolerated; `mark_with` skips them when
    /// popped.
    fn enqueue_cross_successors(
        &self,
        key: K,
        channel: Channel,
        processed: &hashbrown::HashSet<(K, Channel)>,
        worklist: &mut Vec<(K, Channel)>,
    ) {
        self.for_each_cross_successor(key, channel, |to_key, to_ch| {
            if !processed.contains(&(to_key, to_ch)) {
                worklist.push((to_key, to_ch));
            }
        });
    }
}
/// Drain and peek variants that additionally require `K: Ord`.
///
/// The `Ord` bound is what [`DrainSortedDeterministic`] needs. By the type's
/// name, keys left unordered by the dependency graph come out in a
/// reproducible order; confirm against `DrainBuilder::deterministic`.
impl<K: Copy + Eq + Hash + Ord + DenseKey> InvalidationTracker<K> {
    /// Deterministic counterpart of `drain_sorted`.
    ///
    /// Drains only the invalidated keys of `channel`.
    pub fn drain_sorted_deterministic(
        &mut self,
        channel: Channel,
    ) -> DrainSortedDeterministic<'_, K> {
        let builder = self.drain(channel).invalidated_only();
        builder.deterministic().run()
    }

    /// Deterministic counterpart of `drain_affected_sorted`.
    ///
    /// Drains `channel` in the builder's `affected()` mode.
    pub fn drain_affected_sorted_deterministic(
        &mut self,
        channel: Channel,
    ) -> DrainSortedDeterministic<'_, K> {
        let builder = self.drain(channel).affected();
        builder.deterministic().run()
    }

    /// Deterministic counterpart of `peek_sorted`.
    ///
    /// Yields the invalidated keys of `channel` without clearing them.
    #[must_use]
    pub fn peek_sorted_deterministic(&self, channel: Channel) -> DrainSortedDeterministic<'_, K> {
        let marked = self.invalidated.iter(channel);
        DrainSortedDeterministic::from_iter_with_capacity(
            marked,
            self.invalidated.len(channel),
            &self.graph,
            channel,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;
    use crate::policy::{EagerPolicy, LazyPolicy, PropagationPolicy};

    // Channel constants shared by every test, declared together so none is
    // used before its definition is visible to a reader.
    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);
    const COMPOSITE: Channel = Channel::new(2);

    /// Test policy that marks the root and one unrelated `extra_key`, ignoring
    /// the graph entirely.
    struct OffGraphMarkPolicy {
        extra_key: u32,
    }

    impl PropagationPolicy<u32> for OffGraphMarkPolicy {
        fn propagate(
            &self,
            key: u32,
            channel: Channel,
            _graph: &InvalidationGraph<u32>,
            invalidated: &mut InvalidationSet<u32>,
        ) {
            invalidated.mark(key, channel);
            invalidated.mark(self.extra_key, channel);
        }
    }

    #[test]
    fn basic_workflow() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(3, 2, LAYOUT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, LAYOUT));
        assert!(tracker.is_invalidated(3, LAYOUT));
        let order: Vec<_> = tracker.drain_sorted(LAYOUT).collect();
        assert_eq!(order, vec![1, 2, 3]);
        assert!(!tracker.has_invalidated(LAYOUT));
    }

    #[test]
    fn can_seed_tracker_from_graph() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut tracker =
            InvalidationTracker::from_graph_with_cycle_handling(graph, CycleHandling::Error);
        assert_eq!(tracker.cycle_handling(), CycleHandling::Error);
        assert!(tracker.graph().dependents(1, LAYOUT).any(|key| key == 2));
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        let order: Vec<_> = tracker.drain_sorted(LAYOUT).collect();
        assert_eq!(order, vec![1, 2]);
    }

    #[test]
    fn can_seed_tracker_with_static_cascades() {
        let mut tracker =
            InvalidationTracker::<u32>::with_cascades([(LAYOUT, PAINT), (PAINT, COMPOSITE)])
                .unwrap();
        tracker.mark(1, LAYOUT);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(tracker.is_invalidated(1, COMPOSITE));
    }

    #[test]
    fn manual_mark_no_propagation() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.mark(1, LAYOUT);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(!tracker.is_invalidated(2, LAYOUT));
    }

    #[test]
    fn replace_dependencies_uses_configured_cycle_handling() {
        let mut tracker = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        let err = tracker.replace_dependencies(1, LAYOUT, [1]).unwrap_err();
        assert_eq!(err.from, 1);
        assert_eq!(err.to, 1);
    }

    #[test]
    fn lazy_policy() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(3, 2, LAYOUT).unwrap();
        tracker.mark_with(1, LAYOUT, &LazyPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(!tracker.is_invalidated(2, LAYOUT));
        assert!(!tracker.is_invalidated(3, LAYOUT));
    }

    #[test]
    fn remove_key() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.mark(1, LAYOUT);
        tracker.mark(2, LAYOUT);
        tracker.remove_key(2);
        assert!(!tracker.graph().dependents(1, LAYOUT).any(|_| true));
        assert!(!tracker.is_invalidated(2, LAYOUT));
        assert!(tracker.is_invalidated(1, LAYOUT));
    }

    #[test]
    fn peek_sorted_preserves_state() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.mark(1, LAYOUT);
        tracker.mark(2, LAYOUT);
        let order: Vec<_> = tracker.peek_sorted(LAYOUT).collect();
        assert_eq!(order, vec![1, 2]);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, LAYOUT));
    }

    #[test]
    fn deterministic_peek_and_drain_follow_dependency_order() {
        // A linear chain admits exactly one dependency order, so the
        // deterministic variants must agree with `drain_sorted` here.
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(3, 2, LAYOUT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        let peeked: Vec<_> = tracker.peek_sorted_deterministic(LAYOUT).collect();
        assert_eq!(peeked, vec![1, 2, 3]);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(3, LAYOUT));
        let drained: Vec<_> = tracker.drain_sorted_deterministic(LAYOUT).collect();
        assert_eq!(drained, vec![1, 2, 3]);
        assert!(!tracker.has_invalidated(LAYOUT));
    }

    #[test]
    fn generation_tracking() {
        let mut tracker = InvalidationTracker::<u32>::new();
        let initial = tracker.generation();
        tracker.mark(1, LAYOUT);
        assert_eq!(tracker.generation(), initial + 1);
        tracker.mark(2, LAYOUT);
        assert_eq!(tracker.generation(), initial + 2);
    }

    #[test]
    fn explicit_cycle_handling_can_override_tracker_default() {
        let mut tracker = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        let result = tracker.add_dependency(1, 1, LAYOUT);
        assert!(result.is_err());
        let result = tracker.add_dependency_with(1, 1, LAYOUT, CycleHandling::Ignore);
        assert!(result.is_ok());
    }

    #[test]
    fn multiple_channels() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(2, 1, PAINT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, LAYOUT));
        assert!(!tracker.is_invalidated(1, PAINT));
        assert!(!tracker.is_invalidated(2, PAINT));
        tracker.mark_with(1, PAINT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(tracker.is_invalidated(2, PAINT));
    }

    #[test]
    fn clear_specific_channel() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, LAYOUT);
        tracker.mark(1, PAINT);
        tracker.clear(LAYOUT);
        assert!(!tracker.has_invalidated(LAYOUT));
        assert!(tracker.has_invalidated(PAINT));
    }

    #[test]
    fn clear_all() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, LAYOUT);
        tracker.mark(1, PAINT);
        tracker.clear_all();
        assert!(tracker.is_clean());
    }

    #[test]
    fn cascade_mark_propagates_to_cascaded_channels() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.mark(1, LAYOUT);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
    }

    #[test]
    fn cascade_mark_with_eager_propagates_across_channels() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(2, 1, PAINT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(tracker.is_invalidated(2, PAINT));
    }

    #[test]
    fn cascade_mark_with_lazy() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.mark_with(1, LAYOUT, &LazyPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(!tracker.is_invalidated(2, LAYOUT));
        assert!(!tracker.is_invalidated(2, PAINT));
    }

    #[test]
    fn cascade_transitive_chain() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cascade(PAINT, COMPOSITE).unwrap();
        tracker.mark(1, LAYOUT);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(tracker.is_invalidated(1, COMPOSITE));
    }

    #[test]
    fn no_cascade_zero_overhead() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, LAYOUT);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(!tracker.is_invalidated(1, PAINT));
    }

    #[test]
    fn cascade_add_remove() {
        let mut tracker = InvalidationTracker::<u32>::new();
        assert!(tracker.add_cascade(LAYOUT, PAINT).unwrap());
        assert!(!tracker.add_cascade(LAYOUT, PAINT).unwrap());
        assert!(tracker.remove_cascade(LAYOUT, PAINT));
        assert!(!tracker.remove_cascade(LAYOUT, PAINT));
        tracker.mark(1, LAYOUT);
        assert!(!tracker.is_invalidated(1, PAINT));
    }

    #[test]
    fn cascade_accessor() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        assert!(tracker.cascade().cascades_from(LAYOUT).contains(PAINT));
    }

    #[test]
    fn cross_channel_mark_with_follows_edges() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, PAINT));
    }

    #[test]
    fn cross_channel_with_same_channel_deps() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(3, 2, PAINT).unwrap();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(2, PAINT));
        assert!(tracker.is_invalidated(3, PAINT));
    }

    #[test]
    fn cross_channel_with_cascade() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cross_dependency(1, PAINT, 2, COMPOSITE);
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(tracker.is_invalidated(2, COMPOSITE));
    }

    #[test]
    fn cross_channel_remove_edge() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        assert!(tracker.remove_cross_dependency(1, LAYOUT, 2, PAINT));
        assert!(!tracker.remove_cross_dependency(1, LAYOUT, 2, PAINT));
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(!tracker.is_invalidated(2, PAINT));
    }

    #[test]
    fn cross_channel_replace_dependents() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.add_cross_dependency(1, LAYOUT, 3, COMPOSITE);
        assert!(tracker.replace_cross_dependents(1, LAYOUT, [(3, COMPOSITE), (4, PAINT)]));
        assert!(!tracker.replace_cross_dependents(1, LAYOUT, [(4, PAINT), (3, COMPOSITE)]));
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(!tracker.is_invalidated(2, PAINT));
        assert!(tracker.is_invalidated(3, COMPOSITE));
        assert!(tracker.is_invalidated(4, PAINT));
    }

    #[test]
    fn cross_channel_clear_dependents_and_dependencies() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 3, PAINT);
        tracker.add_cross_dependency(2, COMPOSITE, 3, PAINT);
        assert!(tracker.clear_cross_dependents(1, LAYOUT));
        assert!(!tracker.clear_cross_dependents(1, LAYOUT));
        assert!(
            tracker
                .cross_channel()
                .dependencies(3, PAINT)
                .eq([(2, COMPOSITE)])
        );
        assert!(tracker.clear_cross_dependencies(3, PAINT));
        assert!(!tracker.clear_cross_dependencies(3, PAINT));
        assert!(
            tracker
                .cross_channel()
                .dependents(2, COMPOSITE)
                .next()
                .is_none()
        );
    }

    #[test]
    fn cross_channel_remove_key_cleans_edges() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.mark(1, LAYOUT);
        tracker.remove_key(1);
        assert!(
            tracker
                .cross_channel()
                .dependents(1, LAYOUT)
                .next()
                .is_none()
        );
    }

    #[test]
    fn cross_channel_accessor() {
        let mut tracker = InvalidationTracker::<u32>::new();
        assert!(tracker.cross_channel().is_empty());
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        assert!(!tracker.cross_channel().is_empty());
    }

    #[test]
    fn drain_channels_sorted_basic() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        tracker.mark(5, PAINT);
        let results = tracker.drain_channels_sorted(&[LAYOUT, PAINT]);
        assert_eq!(results, vec![(LAYOUT, 1), (LAYOUT, 2), (PAINT, 5)]);
    }

    #[test]
    fn drain_channels_sorted_empty_channels_skipped() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, PAINT);
        let results = tracker.drain_channels_sorted(&[LAYOUT, PAINT]);
        assert_eq!(results, vec![(PAINT, 1)]);
    }

    #[test]
    fn drain_channels_sorted_respects_order() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, LAYOUT);
        tracker.mark(2, PAINT);
        let results = tracker.drain_channels_sorted(&[PAINT, LAYOUT]);
        assert_eq!(results, vec![(PAINT, 2), (LAYOUT, 1)]);
    }

    #[test]
    fn drain_channels_sorted_clears_channels() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.mark(1, LAYOUT);
        tracker.mark(2, PAINT);
        let _ = tracker.drain_channels_sorted(&[LAYOUT, PAINT]);
        assert!(!tracker.has_invalidated(LAYOUT));
        assert!(!tracker.has_invalidated(PAINT));
    }

    #[test]
    fn transitive_dependents_cross_same_channel_only() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(3, 2, LAYOUT).unwrap();
        let deps = tracker.transitive_dependents_cross(1, LAYOUT);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&(2, LAYOUT)));
        assert!(deps.contains(&(3, LAYOUT)));
    }

    #[test]
    fn transitive_dependents_cross_with_cascade() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_dependency(2, 1, PAINT).unwrap();
        let deps = tracker.transitive_dependents_cross(1, LAYOUT);
        assert!(deps.contains(&(1, PAINT)));
        assert!(deps.contains(&(2, PAINT)));
    }

    #[test]
    fn transitive_dependents_cross_with_cross_edges() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.add_dependency(3, 2, PAINT).unwrap();
        let deps = tracker.transitive_dependents_cross(1, LAYOUT);
        assert!(deps.contains(&(2, PAINT)));
        assert!(deps.contains(&(3, PAINT)));
    }

    #[test]
    fn transitive_dependents_cross_combined() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cross_dependency(2, LAYOUT, 3, COMPOSITE);
        let deps = tracker.transitive_dependents_cross(1, LAYOUT);
        assert!(deps.contains(&(2, LAYOUT)));
        assert!(deps.contains(&(1, PAINT)));
        assert!(deps.contains(&(2, PAINT)));
        assert!(deps.contains(&(3, COMPOSITE)));
    }

    #[test]
    fn transitive_dependents_cross_no_duplicates() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cross_dependency(1, LAYOUT, 1, PAINT);
        let deps = tracker.transitive_dependents_cross(1, LAYOUT);
        let paint_count = deps
            .iter()
            .filter(|&&(k, ch)| k == 1 && ch == PAINT)
            .count();
        assert_eq!(paint_count, 1);
    }

    #[test]
    fn cross_channel_from_propagated_dependent_not_just_root() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(1, 0, LAYOUT).unwrap();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.mark_with(0, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(0, LAYOUT));
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(
            tracker.is_invalidated(2, PAINT),
            "cross-channel edge from propagated dependent must fire"
        );
    }

    #[test]
    fn cascade_applies_to_all_propagated_dependents_not_just_root() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(2, LAYOUT));
        assert!(
            tracker.is_invalidated(1, PAINT),
            "cascade must apply to root"
        );
        assert!(
            tracker.is_invalidated(2, PAINT),
            "cascade must apply to propagated dependent"
        );
    }

    #[test]
    fn cascade_fires_from_cross_channel_target() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.add_cascade(PAINT, COMPOSITE).unwrap();
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(2, PAINT));
        assert!(
            tracker.is_invalidated(2, COMPOSITE),
            "cascade on cross-channel target channel must fire"
        );
    }

    #[test]
    fn chained_cross_channel_edges() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(1, LAYOUT, 2, PAINT);
        tracker.add_cross_dependency(2, PAINT, 3, COMPOSITE);
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        assert!(tracker.is_invalidated(2, PAINT));
        assert!(
            tracker.is_invalidated(3, COMPOSITE),
            "chained cross-channel edges must be followed transitively"
        );
    }

    #[test]
    fn mark_with_parity_with_transitive_dependents_cross() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_dependency(3, 2, LAYOUT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cross_dependency(2, LAYOUT, 4, COMPOSITE);
        tracker.add_cross_dependency(4, COMPOSITE, 5, PAINT);
        tracker.add_dependency(6, 5, PAINT).unwrap();
        let expected = tracker.transitive_dependents_cross(1, LAYOUT);
        tracker.mark_with(1, LAYOUT, &EagerPolicy);
        for (k, ch) in &expected {
            assert!(
                tracker.is_invalidated(*k, *ch),
                "mark_with must realize ({k}, {ch:?}) from transitive_dependents_cross"
            );
        }
    }

    #[test]
    fn lazy_policy_does_not_cascade_or_cross_channel_dependents() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_dependency(2, 1, LAYOUT).unwrap();
        tracker.add_cascade(LAYOUT, PAINT).unwrap();
        tracker.add_cross_dependency(2, LAYOUT, 3, COMPOSITE);
        tracker.mark_with(1, LAYOUT, &LazyPolicy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(1, PAINT));
        assert!(!tracker.is_invalidated(2, LAYOUT));
        assert!(!tracker.is_invalidated(3, COMPOSITE));
    }

    #[test]
    fn custom_policy_off_graph_marks_do_not_trigger_cross_channel_follow_up() {
        let mut tracker = InvalidationTracker::<u32>::new();
        tracker.add_cross_dependency(9, LAYOUT, 10, PAINT);
        let policy = OffGraphMarkPolicy { extra_key: 9 };
        tracker.mark_with(1, LAYOUT, &policy);
        assert!(tracker.is_invalidated(1, LAYOUT));
        assert!(tracker.is_invalidated(9, LAYOUT));
        assert!(
            !tracker.is_invalidated(10, PAINT),
            "cross-channel traversal is not defined for off-graph keys marked by a custom policy"
        );
    }
}