use alloc::vec::Vec;
use core::hash::Hash;
use core::marker::PhantomData;
use hashbrown::HashSet;
use crate::Channel;
use crate::DenseKey;
use crate::DrainSorted;
use crate::DrainSortedDeterministic;
use crate::InvalidationGraph;
use crate::InvalidationSet;
use crate::TraversalScratch;
use crate::trace::InvalidationTrace;
/// Ordering marker for [`DrainBuilder`]: the default output ordering.
///
/// A builder carrying this marker produces a [`DrainSorted`] when run, and can
/// be switched to [`DeterministicOrder`] with `deterministic()`.
#[derive(Clone, Copy, Default, Debug)]
pub struct AnyOrder;
/// Ordering marker for [`DrainBuilder`]: deterministic output ordering.
///
/// Selected via `deterministic()`, which requires `K: Ord`; a builder carrying
/// this marker produces a [`DrainSortedDeterministic`] when run.
#[derive(Clone, Copy, Default, Debug)]
pub struct DeterministicOrder;
/// Which keys a drain yields.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DrainMode {
    /// Only the keys that were marked invalidated (the drained roots).
    InvalidatedOnly,
    /// The drained roots plus every key reachable from them along dependent
    /// edges, subject to any `within_*` restriction.
    Affected,
}
/// Restriction on which keys a drain may take or visit.
#[derive(Clone, Copy, Debug)]
enum Within<'w, K> {
    /// No restriction: every invalidated key is eligible.
    All,
    /// Only keys contained in the borrowed slice are eligible.
    Keys(&'w [K]),
    /// Only the given key and its transitive dependencies are eligible.
    DependenciesOf(K),
}
/// Configures and performs a drain of invalidated keys on a single channel.
///
/// Chain the configuration methods (mode, `within_*` restriction, scratch,
/// trace, ordering) and finish with `run`, which consumes the builder.
pub struct DrainBuilder<'d, 'g, 's, K: Copy + Eq + Hash + DenseKey, O = AnyOrder> {
    /// Invalidation state that the drained roots are taken out of.
    invalidated: &'d mut InvalidationSet<K>,
    /// Dependency graph used for traversal and handed to the sorted drain.
    graph: &'g InvalidationGraph<K>,
    /// Channel being drained.
    channel: Channel,
    /// Whether to yield only invalidated roots or also their dependents.
    mode: DrainMode,
    /// Restriction on which keys may be drained or visited.
    within: Within<'d, K>,
    /// Caller-supplied traversal buffers; `None` means allocate per run.
    scratch: Option<&'s mut TraversalScratch<K>>,
    /// Optional observer of traversal decisions.
    trace: Option<&'s mut dyn InvalidationTrace<K>>,
    /// Type-level ordering marker (`AnyOrder` or `DeterministicOrder`).
    _order: PhantomData<O>,
}
impl<K, O> core::fmt::Debug for DrainBuilder<'_, '_, '_, K, O>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Shows the channel and drain mode; the borrowed state, restriction,
    /// scratch and trace are elided (marked non-exhaustive).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("DrainBuilder");
        out.field("channel", &self.channel);
        out.field("mode", &self.mode);
        out.finish_non_exhaustive()
    }
}
impl<'d, 'g, K> DrainBuilder<'d, 'g, 'd, K, AnyOrder>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Creates a builder that drains `channel` from `invalidated`, consulting
    /// `graph` for dependency edges.
    ///
    /// Defaults: invalidated keys only, no restriction, no scratch, no trace,
    /// and [`AnyOrder`] ordering.
    pub(crate) fn new(
        invalidated: &'d mut InvalidationSet<K>,
        graph: &'g InvalidationGraph<K>,
        channel: Channel,
    ) -> Self {
        let mode = DrainMode::InvalidatedOnly;
        let within = Within::All;
        Self {
            _order: PhantomData,
            trace: None,
            scratch: None,
            within,
            mode,
            channel,
            graph,
            invalidated,
        }
    }
}
impl<'d, 'g, 's, K, O> DrainBuilder<'d, 'g, 's, K, O>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Yields only the keys that were marked invalidated (the default mode).
    #[must_use]
    pub fn invalidated_only(self) -> Self {
        Self {
            mode: DrainMode::InvalidatedOnly,
            ..self
        }
    }

    /// Yields the invalidated keys plus every key reachable from them along
    /// dependent edges, subject to any `within_*` restriction.
    #[must_use]
    pub fn affected(self) -> Self {
        Self {
            mode: DrainMode::Affected,
            ..self
        }
    }

    /// Restricts the drain to keys contained in `keys`, replacing any earlier
    /// restriction. Invalidated keys outside the slice are left in place.
    #[must_use]
    pub fn within_keys(self, keys: &'d [K]) -> Self {
        Self {
            within: Within::Keys(keys),
            ..self
        }
    }

    /// Restricts the drain to `key` and its transitive dependencies, replacing
    /// any earlier restriction. Other invalidated keys are left in place.
    #[must_use]
    pub fn within_dependencies_of(self, key: K) -> Self {
        Self {
            within: Within::DependenciesOf(key),
            ..self
        }
    }

    /// Reuses `scratch` for traversal bookkeeping instead of allocating
    /// buffers during the run.
    ///
    /// Any trace configured earlier is dropped: debug builds panic via
    /// `debug_assert!`, release builds silently discard it. Use `trace` to
    /// supply a scratch buffer and a trace together.
    #[must_use]
    pub fn scratch<'s2>(
        self,
        scratch: &'s2 mut TraversalScratch<K>,
    ) -> DrainBuilder<'d, 'g, 's2, K, O> {
        let Self {
            invalidated,
            graph,
            channel,
            mode,
            within,
            trace: previous_trace,
            ..
        } = self;
        debug_assert!(
            previous_trace.is_none(),
            "calling `DrainBuilder::scratch` after configuring trace is not supported; call `DrainBuilder::trace` instead",
        );
        DrainBuilder {
            scratch: Some(scratch),
            trace: None,
            invalidated,
            graph,
            channel,
            mode,
            within,
            _order: PhantomData,
        }
    }

    /// Reports traversal decisions to `trace`, using `scratch` for bookkeeping.
    ///
    /// Trace callbacks are only issued when the drain runs in `affected` mode.
    #[must_use]
    pub fn trace<'s2, T>(
        self,
        scratch: &'s2 mut TraversalScratch<K>,
        trace: &'s2 mut T,
    ) -> DrainBuilder<'d, 'g, 's2, K, O>
    where
        T: InvalidationTrace<K>,
    {
        let Self {
            invalidated,
            graph,
            channel,
            mode,
            within,
            ..
        } = self;
        DrainBuilder {
            scratch: Some(scratch),
            trace: Some(trace),
            invalidated,
            graph,
            channel,
            mode,
            within,
            _order: PhantomData,
        }
    }
}
impl<'d, 'g, 's, K> DrainBuilder<'d, 'g, 's, K, AnyOrder>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Switches the builder to [`DeterministicOrder`], so that `run` returns a
    /// [`DrainSortedDeterministic`].
    ///
    /// Requires `K: Ord`. Mode, restriction, scratch and trace carry over
    /// unchanged.
    #[must_use]
    pub fn deterministic(self) -> DrainBuilder<'d, 'g, 's, K, DeterministicOrder>
    where
        K: Ord + DenseKey,
    {
        let Self {
            invalidated,
            graph,
            channel,
            mode,
            within,
            scratch,
            trace,
            _order: _,
        } = self;
        DrainBuilder {
            _order: PhantomData,
            trace,
            scratch,
            within,
            mode,
            channel,
            graph,
            invalidated,
        }
    }
}
impl<'d, 'g, 's, K, O> DrainBuilder<'d, 'g, 's, K, O>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Consumes the builder and computes the keys to drain.
    ///
    /// Takes the eligible invalidated keys (the roots) out of `invalidated` for
    /// this channel, then, in `affected` mode, extends them with reachable
    /// dependents. Returns the keys plus the graph and channel the caller needs
    /// to order them.
    fn drain_keys(self) -> (Vec<K>, &'g InvalidationGraph<K>, Channel) {
        let DrainBuilder {
            invalidated,
            graph,
            channel,
            mode,
            within,
            mut scratch,
            trace,
            ..
        } = self;
        // Materialize every restriction as a hash set up front. Membership is
        // then O(1) during the root scan and traversal, instead of a linear
        // scan of the `within_keys` slice for every key and every edge.
        let allowed_set: Option<HashSet<K>> = match within {
            Within::All => None,
            Within::Keys(keys) => Some(keys.iter().copied().collect()),
            Within::DependenciesOf(key) => Some(Self::compute_allowed_dependencies(
                graph,
                channel,
                key,
                scratch.as_deref_mut(),
            )),
        };
        let allowed = allowed_set.as_ref();
        let roots = Self::take_roots(invalidated, channel, allowed);
        let keys = match mode {
            DrainMode::InvalidatedOnly => roots,
            DrainMode::Affected => {
                Self::collect_affected(graph, channel, roots, allowed, scratch, trace)
            }
        };
        (keys, graph, channel)
    }

    /// Returns `true` when `key` passes the restriction; `None` means unrestricted.
    fn is_allowed(allowed: Option<&HashSet<K>>, key: K) -> bool {
        match allowed {
            None => true,
            Some(set) => set.contains(&key),
        }
    }

    /// Returns `key` together with all of its transitive dependencies on
    /// `channel`.
    ///
    /// When `scratch` is given, it is reset and its stack is reused as the
    /// work list; otherwise a local stack is used.
    fn compute_allowed_dependencies(
        graph: &InvalidationGraph<K>,
        channel: Channel,
        key: K,
        scratch: Option<&mut TraversalScratch<K>>,
    ) -> HashSet<K> {
        // `Vec::new` does not allocate, so this costs nothing when a scratch
        // buffer is supplied.
        let mut owned_stack: Vec<K> = Vec::new();
        let stack = match scratch {
            Some(s) => {
                s.reset();
                &mut s.stack
            }
            None => &mut owned_stack,
        };
        let mut allowed: HashSet<K> = HashSet::new();
        allowed.insert(key);
        stack.push(key);
        while let Some(next) = stack.pop() {
            for dep in graph.dependencies(next, channel) {
                if allowed.insert(dep) {
                    stack.push(dep);
                }
            }
        }
        allowed
    }

    /// Removes and returns the invalidated keys on `channel` that pass `allowed`.
    ///
    /// Unrestricted drains take the whole channel. Restricted drains take only
    /// the matching keys, so keys outside the restriction stay invalidated.
    fn take_roots(
        invalidated: &mut InvalidationSet<K>,
        channel: Channel,
        allowed: Option<&HashSet<K>>,
    ) -> Vec<K> {
        let Some(set) = allowed else {
            return invalidated.drain(channel).collect();
        };
        // Collect first: the set cannot be mutated while it is being iterated.
        let roots: Vec<K> = invalidated
            .iter(channel)
            .filter(|k| set.contains(k))
            .collect();
        for &k in &roots {
            let _ = invalidated.take(k, channel);
        }
        roots
    }

    /// Expands `roots` with every key reachable along dependent edges, staying
    /// inside `allowed`.
    ///
    /// Uses the caller's `scratch` buffers (after resetting them) when present,
    /// otherwise a local visited set and stack.
    fn collect_affected(
        graph: &InvalidationGraph<K>,
        channel: Channel,
        roots: Vec<K>,
        allowed: Option<&HashSet<K>>,
        scratch: Option<&mut TraversalScratch<K>>,
        trace: Option<&mut dyn InvalidationTrace<K>>,
    ) -> Vec<K> {
        match scratch {
            Some(s) => {
                s.reset();
                // Split the borrow so the visited set and the stack can be
                // used independently.
                let visited = &mut s.visited;
                Self::walk_dependents(
                    graph,
                    channel,
                    roots,
                    allowed,
                    |k| visited.insert(k),
                    &mut s.stack,
                    trace,
                )
            }
            None => {
                let mut visited: HashSet<K> = HashSet::new();
                let mut stack: Vec<K> = Vec::new();
                Self::walk_dependents(
                    graph,
                    channel,
                    roots,
                    allowed,
                    |k| visited.insert(k),
                    &mut stack,
                    trace,
                )
            }
        }
    }

    /// Depth-first walk from `roots` along dependent edges, confined to `allowed`.
    ///
    /// `first_visit(k)` must return `true` only the first time `k` is seen.
    /// `stack` is the work list and must start empty. Keys are returned in
    /// discovery order. If `trace` is present, it receives a `root` event for
    /// every allowed root and a `caused_by` event for every allowed dependent
    /// edge examined, each flagged with whether the key was newly visited.
    fn walk_dependents(
        graph: &InvalidationGraph<K>,
        channel: Channel,
        roots: Vec<K>,
        allowed: Option<&HashSet<K>>,
        mut first_visit: impl FnMut(K) -> bool,
        stack: &mut Vec<K>,
        mut trace: Option<&mut dyn InvalidationTrace<K>>,
    ) -> Vec<K> {
        let mut out = Vec::new();
        for root in roots.into_iter().filter(|&k| Self::is_allowed(allowed, k)) {
            let newly = first_visit(root);
            if newly {
                out.push(root);
                stack.push(root);
            }
            if let Some(t) = trace.as_deref_mut() {
                t.root(root, channel, newly);
            }
        }
        while let Some(because) = stack.pop() {
            for dependent in graph.dependents(because, channel) {
                if !Self::is_allowed(allowed, dependent) {
                    continue;
                }
                let newly = first_visit(dependent);
                if let Some(t) = trace.as_deref_mut() {
                    t.caused_by(dependent, because, channel, newly);
                }
                if newly {
                    out.push(dependent);
                    stack.push(dependent);
                }
            }
        }
        out
    }
}

impl<'d, 'g, 's, K> DrainBuilder<'d, 'g, 's, K, AnyOrder>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Performs the drain and returns the selected keys as a [`DrainSorted`].
    ///
    /// The invalidated keys that pass the `within_*` restriction are removed
    /// from the invalidation set for this channel; keys outside the
    /// restriction stay invalidated.
    pub fn run(self) -> DrainSorted<'g, K> {
        let (keys, graph, channel) = self.drain_keys();
        let capacity = keys.len();
        DrainSorted::from_iter_with_capacity(keys.into_iter(), capacity, graph, channel)
    }
}

impl<'d, 'g, 's, K> DrainBuilder<'d, 'g, 's, K, DeterministicOrder>
where
    K: Copy + Eq + Hash + Ord + DenseKey,
{
    /// Performs the drain and returns the selected keys as a
    /// [`DrainSortedDeterministic`].
    ///
    /// The invalidated keys that pass the `within_*` restriction are removed
    /// from the invalidation set for this channel; keys outside the
    /// restriction stay invalidated.
    pub fn run(self) -> DrainSortedDeterministic<'g, K> {
        let (keys, graph, channel) = self.drain_keys();
        let capacity = keys.len();
        DrainSortedDeterministic::from_iter_with_capacity(keys.into_iter(), capacity, graph, channel)
    }
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use alloc::vec;
    use crate::CycleHandling;
    use crate::InvalidationTracker;
    use crate::trace::OneParentRecorder;
    const LAYOUT: Channel = Channel::new(0);

    // A `within_keys` drain must leave invalidated keys outside the subset alone.
    #[test]
    fn within_keys_does_not_clear_outside_roots() {
        let mut t = InvalidationTracker::<u32>::new();
        t.mark(1, LAYOUT);
        t.mark(2, LAYOUT);
        let subset = [1];
        let order: Vec<_> = t
            .drain(LAYOUT)
            .invalidated_only()
            .within_keys(&subset)
            .run()
            .collect();
        assert_eq!(order, vec![1]);
        assert!(t.is_invalidated(2, LAYOUT));
    }

    // `within_dependencies_of(3)` admits 3 and its transitive dependencies only.
    #[test]
    fn within_dependencies_of_filters_invalidated_only() {
        let mut t = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        t.add_dependency(2, 1, LAYOUT).unwrap();
        t.add_dependency(3, 2, LAYOUT).unwrap();
        t.mark(1, LAYOUT);
        t.mark(2, LAYOUT);
        t.mark(3, LAYOUT);
        t.mark(9, LAYOUT);
        let order: Vec<_> = t
            .drain(LAYOUT)
            .invalidated_only()
            .within_dependencies_of(3)
            .deterministic()
            .run()
            .collect();
        assert_eq!(order, vec![1, 2, 3]);
        assert!(t.is_invalidated(9, LAYOUT));
    }

    // An `affected` drain with a trace records a parent path back to the root.
    #[test]
    fn affected_with_trace_records_one_plausible_path() {
        let mut t = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        t.add_dependency(2, 1, LAYOUT).unwrap();
        t.add_dependency(3, 2, LAYOUT).unwrap();
        t.mark(1, LAYOUT);
        let mut scratch = TraversalScratch::new();
        let mut rec = OneParentRecorder::new();
        let order: Vec<_> = t
            .drain(LAYOUT)
            .affected()
            .trace(&mut scratch, &mut rec)
            .run()
            .collect();
        assert_eq!(order, vec![1, 2, 3]);
        assert_eq!(rec.explain_path(3, LAYOUT).unwrap(), vec![1, 2, 3]);
    }

    // `affected` traversal must not escape a `within_keys` subset, and marked
    // keys outside the subset stay invalidated.
    #[test]
    fn within_keys_limits_affected_traversal() {
        let mut t = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        t.add_dependency(2, 1, LAYOUT).unwrap();
        t.add_dependency(3, 2, LAYOUT).unwrap();
        t.mark(1, LAYOUT);
        t.mark(3, LAYOUT);
        let subset = [1, 2];
        let order: Vec<_> = t
            .drain(LAYOUT)
            .affected()
            .within_keys(&subset)
            .deterministic()
            .run()
            .collect();
        assert_eq!(order, vec![1, 2]);
        assert!(t.is_invalidated(3, LAYOUT));
    }

    // `affected` traversal must not escape the dependency closure selected by
    // `within_dependencies_of`.
    #[test]
    fn within_dependencies_of_limits_affected_traversal() {
        let mut t = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        t.add_dependency(2, 1, LAYOUT).unwrap();
        t.add_dependency(3, 2, LAYOUT).unwrap();
        t.mark(1, LAYOUT);
        t.mark(9, LAYOUT);
        let order: Vec<_> = t
            .drain(LAYOUT)
            .affected()
            .within_dependencies_of(2)
            .deterministic()
            .run()
            .collect();
        assert_eq!(order, vec![1, 2]);
        assert!(t.is_invalidated(9, LAYOUT));
    }

    // Deterministic ordering gives a total order on a diamond.
    #[test]
    fn deterministic_diamond_is_total() {
        let mut t = InvalidationTracker::<u32>::with_cycle_handling(CycleHandling::Error);
        t.add_dependency(2, 1, LAYOUT).unwrap();
        t.add_dependency(3, 1, LAYOUT).unwrap();
        t.add_dependency(4, 2, LAYOUT).unwrap();
        t.add_dependency(4, 3, LAYOUT).unwrap();
        t.mark(1, LAYOUT);
        t.mark(2, LAYOUT);
        t.mark(3, LAYOUT);
        t.mark(4, LAYOUT);
        let order: Vec<_> = t
            .drain(LAYOUT)
            .invalidated_only()
            .deterministic()
            .run()
            .collect();
        assert_eq!(order, vec![1, 2, 3, 4]);
    }
}