use alloc::collections::BinaryHeap;
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::cmp::Reverse;
use core::hash::Hash;
use hashbrown::HashMap;
use hashbrown::hash_map::Entry;
use crate::DrainBuilder;
use crate::channel::Channel;
use crate::graph::InvalidationGraph;
use crate::scratch::TraversalScratch;
use crate::trace::InvalidationTrace;
/// A key that can be mapped to a small dense `usize` index.
///
/// The deterministic drain stores per-key state in a `Vec` addressed by this index
/// instead of a hash map. Distinct keys must map to distinct indices: keys that share
/// an index are treated as duplicates. Indices should stay compact, because the dense
/// table is grown to the largest index seen.
pub trait DenseKey: Copy {
    /// Returns the dense slot index for this key.
    fn index(self) -> usize;
}

impl DenseKey for usize {
    /// A `usize` key is its own index.
    #[inline]
    fn index(self) -> usize {
        self
    }
}

impl DenseKey for u32 {
    /// Widens the `u32` key to `usize`.
    #[inline]
    fn index(self) -> usize {
        self as usize
    }
}

/// Marker stored in dense `u32` tables for slots that hold no live entry.
const DENSE_SENTINEL: u32 = u32::MAX;
/// Reserves room in `vec` so that slot `idx` can be stored.
///
/// Returns the length the caller should resize to (`idx + 1`). Only capacity is
/// reserved; `vec.len()` is left untouched. If `vec` is already at least that long,
/// nothing is reserved.
///
/// # Panics
///
/// Panics if `idx + 1` overflows `usize`, or if the reservation fails (capacity
/// overflow or allocation failure). Both messages name `storage` to identify the
/// table being grown.
#[inline]
pub(crate) fn prepare_dense_growth<T>(vec: &mut Vec<T>, idx: usize, storage: &str) -> usize {
    let Some(target_len) = idx.checked_add(1) else {
        panic!("DenseKey index {idx} overflows addressable capacity for {storage}");
    };
    // Number of slots still missing; zero when `vec` already covers `idx`.
    let shortfall = target_len.saturating_sub(vec.len());
    if shortfall > 0 {
        if let Err(err) = vec.try_reserve_exact(shortfall) {
            panic!(
                "DenseKey index {idx} requires growing {storage} to length {target_len}: {err:?}; use a compact dense key space or intern::Interner"
            );
        }
    }
    target_len
}
/// How a drain finished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainCompletion {
    /// Every key in the drain set was yielded.
    Complete,
    /// The drain ran out of ready keys while `remaining` keys were still unyielded;
    /// those keys are blocked by a dependency cycle among themselves.
    Stalled {
        /// Number of keys that were never yielded.
        remaining: usize,
    },
}
/// Iterator that yields a set of keys in dependency order (Kahn's algorithm).
///
/// In-degrees live in a hash map. Keys that become ready are yielded in FIFO order,
/// so ties are resolved by discovery order rather than by key value.
#[derive(Debug)]
pub struct DrainSorted<'a, K: Copy + Eq + Hash + DenseKey> {
    /// Graph whose edges on `channel` define the ordering.
    graph: &'a InvalidationGraph<K>,
    /// Channel whose dependency edges are followed.
    channel: Channel,
    /// Keys whose in-set dependencies have all been yielded, in FIFO order.
    queue: VecDeque<K>,
    /// Not-yet-yielded keys mapped to their count of unyielded in-set dependencies.
    in_degree: HashMap<K, usize>,
    /// Set once `next` returns `None` while keys are still pending.
    stalled: bool,
}
/// Iterator that yields a set of keys in dependency order, breaking ties by `K: Ord`.
///
/// Among keys that are ready at the same time, the smallest is always yielded first,
/// so the output sequence depends only on the key set and the graph. In-degrees are
/// stored in a dense `Vec` addressed by [`DenseKey::index`].
#[derive(Debug)]
pub struct DrainSortedDeterministic<'a, K: Copy + Eq + Hash + Ord + DenseKey> {
    /// Graph whose edges on `channel` define the ordering.
    graph: &'a InvalidationGraph<K>,
    /// Channel whose dependency edges are followed.
    channel: Channel,
    /// Min-heap (via `Reverse`) of keys whose in-set dependencies have all been yielded.
    ready: BinaryHeap<Reverse<K>>,
    /// Dense in-degree table indexed by `DenseKey::index`; `DENSE_SENTINEL` marks
    /// indices that are not in the drain set or have already been yielded.
    in_degree: Vec<u32>,
    /// Number of drain-set keys not yet yielded.
    remaining: usize,
    /// Set once `next` returns `None` while keys are still pending.
    stalled: bool,
}
impl<'a, K> DrainSorted<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Builds a topological drain (Kahn's algorithm) over the keys yielded by
    /// `invalidated_keys`.
    ///
    /// Duplicate keys are collapsed, keeping first-seen order. A key's in-degree counts
    /// only those of its `graph` dependencies on `channel` that are also in the drain
    /// set; dependencies outside the set never block it. Keys that start with zero
    /// in-degree are queued in first-seen order.
    ///
    /// `cap` is a capacity hint for the in-degree map and the key buffer.
    pub(crate) fn from_iter_with_capacity<I>(
        invalidated_keys: I,
        cap: usize,
        graph: &'a InvalidationGraph<K>,
        channel: Channel,
    ) -> Self
    where
        I: Iterator<Item = K>,
    {
        let mut in_degree: HashMap<K, usize> = HashMap::with_capacity(cap);
        // Unique keys in first-seen order; drives the initial queue order.
        let mut unique_keys = Vec::with_capacity(cap);
        for key in invalidated_keys {
            if let Entry::Vacant(e) = in_degree.entry(key) {
                e.insert(0);
                unique_keys.push(key);
            }
        }
        for &key in &unique_keys {
            // Count in-set dependencies first, then write the total with a single map
            // update instead of one lookup per matching dependency.
            let mut in_set_deps = 0;
            for dep in graph.dependencies(key, channel) {
                if in_degree.contains_key(&dep) {
                    in_set_deps += 1;
                }
            }
            *in_degree.get_mut(&key).expect("key is in in_degree") += in_set_deps;
        }
        let mut queue = VecDeque::with_capacity(in_degree.len());
        queue.extend(
            unique_keys
                .into_iter()
                .filter(|k| in_degree.get(k).is_some_and(|&deg| deg == 0)),
        );
        Self {
            graph,
            channel,
            queue,
            in_degree,
            stalled: false,
        }
    }

    /// Returns `true` when no key is currently ready to be yielded.
    ///
    /// This can be `true` while [`remaining`](Self::remaining) is non-zero if the
    /// pending keys are blocked by a dependency cycle.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Number of drain-set keys not yet yielded, including keys stuck in a cycle.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.in_degree.len()
    }

    /// Returns `true` once `next` has returned `None` while keys were still pending.
    #[must_use]
    pub fn is_stalled(&self) -> bool {
        self.stalled
    }

    /// Reports how the drain finished so far.
    ///
    /// Returns `Stalled` with the unyielded-key count if a stall has been observed;
    /// otherwise returns `Complete`.
    #[must_use]
    pub fn completion(&self) -> DrainCompletion {
        if self.stalled {
            DrainCompletion::Stalled {
                remaining: self.remaining(),
            }
        } else {
            DrainCompletion::Complete
        }
    }

    /// Yields every key the drain can produce, in order, together with the final
    /// [`DrainCompletion`].
    #[must_use]
    pub fn collect_with_completion(mut self) -> (Vec<K>, DrainCompletion) {
        let mut out = Vec::with_capacity(self.in_degree.len());
        out.extend(&mut self);
        let completion = self.completion();
        (out, completion)
    }
}
impl<'a, K> DrainSortedDeterministic<'a, K>
where
    K: Copy + Eq + Hash + Ord + DenseKey,
{
    /// Builds a drain whose output order is fully determined by `K: Ord`.
    ///
    /// In-degrees are kept in a dense `Vec<u32>` indexed by [`DenseKey::index`], grown to
    /// cover the largest invalidated index. Slots holding `DENSE_SENTINEL` are not part
    /// of the drain set. Duplicate keys are collapsed. A key's in-degree counts only its
    /// `graph` dependencies on `channel` that are also in the drain set.
    ///
    /// `cap` is a capacity hint for the buffer of unique keys.
    ///
    /// # Panics
    ///
    /// Panics (via `prepare_dense_growth`) if the in-degree table cannot be grown to
    /// cover a key's dense index, e.g. `usize::MAX`.
    pub(crate) fn from_iter_with_capacity<I>(
        invalidated_keys: I,
        cap: usize,
        graph: &'a InvalidationGraph<K>,
        channel: Channel,
    ) -> Self
    where
        I: Iterator<Item = K>,
    {
        let mut in_degree: Vec<u32> = Vec::new();
        let mut members: Vec<K> = Vec::with_capacity(cap);
        for key in invalidated_keys {
            let slot = key.index();
            if slot >= in_degree.len() {
                let grown_len =
                    prepare_dense_growth(&mut in_degree, slot, "deterministic drain in-degree");
                in_degree.resize(grown_len, DENSE_SENTINEL);
            }
            // A non-sentinel slot means this key has already been recorded.
            if in_degree[slot] != DENSE_SENTINEL {
                continue;
            }
            in_degree[slot] = 0;
            members.push(key);
        }

        for &key in &members {
            let own = key.index();
            for dep in graph.dependencies(key, channel) {
                let dep_in_set = in_degree
                    .get(dep.index())
                    .is_some_and(|&deg| deg != DENSE_SENTINEL);
                if dep_in_set {
                    in_degree[own] += 1;
                }
            }
        }

        let remaining = members.len();
        let mut ready = BinaryHeap::with_capacity(remaining);
        members
            .into_iter()
            .filter(|&key| in_degree[key.index()] == 0)
            .for_each(|key| ready.push(Reverse(key)));

        Self {
            graph,
            channel,
            ready,
            in_degree,
            remaining,
            stalled: false,
        }
    }

    /// Returns `true` when no key is currently ready to be yielded.
    ///
    /// This can be `true` while keys remain if they are blocked by a dependency cycle.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.ready.is_empty()
    }

    /// Number of drain-set keys not yet yielded, including keys stuck in a cycle.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.remaining
    }

    /// Returns `true` once `next` has returned `None` while keys were still pending.
    #[must_use]
    pub fn is_stalled(&self) -> bool {
        self.stalled
    }

    /// Reports how the drain finished so far: `Stalled` with the unyielded-key count if
    /// a stall has been observed, otherwise `Complete`.
    #[must_use]
    pub fn completion(&self) -> DrainCompletion {
        if !self.stalled {
            return DrainCompletion::Complete;
        }
        DrainCompletion::Stalled {
            remaining: self.remaining,
        }
    }

    /// Yields every key the drain can produce, in order, together with the final
    /// [`DrainCompletion`].
    #[must_use]
    pub fn collect_with_completion(mut self) -> (Vec<K>, DrainCompletion) {
        let mut ordered = Vec::with_capacity(self.remaining);
        ordered.extend(self.by_ref());
        (ordered, self.completion())
    }
}
impl<K> Iterator for DrainSorted<'_, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    type Item = K;

    /// Yields the next key whose in-set dependencies have all been yielded.
    ///
    /// Returns `None` once no key is ready. If keys are still pending at that point,
    /// they are blocked by a cycle and the drain is marked stalled.
    fn next(&mut self) -> Option<Self::Item> {
        match self.queue.pop_front() {
            None => {
                self.stalled |= !self.in_degree.is_empty();
                None
            }
            Some(key) => {
                self.in_degree.remove(&key);
                let graph = self.graph;
                for dependent in graph.dependents(key, self.channel) {
                    // Dependents outside the drain set are not tracked.
                    let Some(deg) = self.in_degree.get_mut(&dependent) else {
                        continue;
                    };
                    *deg -= 1;
                    if *deg == 0 {
                        self.queue.push_back(dependent);
                    }
                }
                Some(key)
            }
        }
    }

    /// Reports the number of unyielded keys as both bounds.
    // NOTE(review): if a cycle stalls the drain, fewer than `pending` keys are yielded,
    // so the lower bound overestimates in that case; the existing tests pin exact hints.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending = self.in_degree.len();
        (pending, Some(pending))
    }
}
impl<K> Iterator for DrainSortedDeterministic<'_, K>
where
    K: Copy + Eq + Hash + Ord + DenseKey,
{
    type Item = K;

    /// Yields the smallest ready key, i.e. the smallest key whose in-set dependencies
    /// have all been yielded.
    ///
    /// Returns `None` once no key is ready. If keys are still pending at that point,
    /// the drain is marked stalled.
    fn next(&mut self) -> Option<Self::Item> {
        let Reverse(key) = match self.ready.pop() {
            Some(top) => top,
            None => {
                self.stalled |= self.remaining != 0;
                return None;
            }
        };
        // Mark as yielded so later decrements from other keys skip this slot.
        self.in_degree[key.index()] = DENSE_SENTINEL;
        self.remaining -= 1;
        let graph = self.graph;
        for dependent in graph.dependents(key, self.channel) {
            let Some(deg) = self.in_degree.get_mut(dependent.index()) else {
                continue;
            };
            // Sentinel: not in the drain set (or already yielded).
            if *deg == DENSE_SENTINEL {
                continue;
            }
            *deg -= 1;
            if *deg == 0 {
                self.ready.push(Reverse(dependent));
            }
        }
        Some(key)
    }

    /// Reports the number of unyielded keys as both bounds.
    // NOTE(review): overestimates the lower bound when a cycle stalls the drain.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining;
        (left, Some(left))
    }
}
/// Drains the keys invalidated on `channel` and yields them in dependency order.
///
/// A key is yielded only after every one of its drained dependencies on `channel`.
/// This uses the builder's `invalidated_only` mode; [`drain_affected_sorted`] also
/// pulls in dependents. If a dependency cycle blocks progress, iteration ends early
/// and [`DrainSorted::completion`] reports [`DrainCompletion::Stalled`].
pub fn drain_sorted<'a, K>(
    invalidated: &mut crate::InvalidationSet<K>,
    graph: &'a InvalidationGraph<K>,
    channel: Channel,
) -> DrainSorted<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    let builder = DrainBuilder::new(invalidated, graph, channel);
    builder.invalidated_only().run()
}
/// Like [`drain_sorted`], but among keys that are ready at the same time the smallest
/// (by `K: Ord`) is always yielded first, so the output order is reproducible.
///
/// # Panics
///
/// Panics if the dense in-degree table cannot be grown to cover a key's
/// [`DenseKey::index`] (for example `usize::MAX`).
pub fn drain_sorted_deterministic<'a, K>(
    invalidated: &mut crate::InvalidationSet<K>,
    graph: &'a InvalidationGraph<K>,
    channel: Channel,
) -> DrainSortedDeterministic<'a, K>
where
    K: Copy + Eq + Hash + Ord + DenseKey,
{
    let only_invalidated = DrainBuilder::new(invalidated, graph, channel).invalidated_only();
    only_invalidated.deterministic().run()
}
/// Drains the keys invalidated on `channel`, expanded with their dependents, and
/// yields the whole affected set in dependency order.
///
/// For example, with a chain `4 -> 3 -> 2 -> 1` (each depending on the next) and only
/// `1` marked, the drain yields `1, 2, 3, 4`.
pub fn drain_affected_sorted<'a, K>(
    invalidated: &mut crate::InvalidationSet<K>,
    graph: &'a InvalidationGraph<K>,
    channel: Channel,
) -> DrainSorted<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    let builder = DrainBuilder::new(invalidated, graph, channel);
    builder.affected().run()
}
/// Like [`drain_affected_sorted`], but reports the expansion of the affected set to
/// `trace`, using `scratch` as working storage for the traversal.
pub fn drain_affected_sorted_with_trace<'a, K, T>(
    invalidated: &mut crate::InvalidationSet<K>,
    graph: &'a InvalidationGraph<K>,
    channel: Channel,
    scratch: &mut TraversalScratch<K>,
    trace: &mut T,
) -> DrainSorted<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
    T: InvalidationTrace<K>,
{
    let affected = DrainBuilder::new(invalidated, graph, channel).affected();
    affected.trace(scratch, trace).run()
}
/// Like [`drain_affected_sorted`], but among keys that are ready at the same time the
/// smallest (by `K: Ord`) is always yielded first.
///
/// # Panics
///
/// Panics if the dense in-degree table cannot be grown to cover a key's
/// [`DenseKey::index`].
pub fn drain_affected_sorted_deterministic<'a, K>(
    invalidated: &mut crate::InvalidationSet<K>,
    graph: &'a InvalidationGraph<K>,
    channel: Channel,
) -> DrainSortedDeterministic<'a, K>
where
    K: Copy + Eq + Hash + Ord + DenseKey,
{
    let affected = DrainBuilder::new(invalidated, graph, channel).affected();
    affected.deterministic().run()
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;
    use crate::TraversalScratch;
    use crate::graph::CycleHandling;
    use crate::set::InvalidationSet;
    use crate::trace::OneParentRecorder;

    const LAYOUT: Channel = Channel::new(0);

    #[test]
    fn topological_order_chain() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys = vec![4, 2, 1, 3];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert_eq!(sorted, vec![1, 2, 3, 4]);
    }

    #[test]
    fn topological_order_diamond() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys = vec![4, 3, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert_eq!(sorted[0], 1);
        assert_eq!(sorted[3], 4);
        assert!(sorted[1] == 2 || sorted[1] == 3);
        assert!(sorted[2] == 2 || sorted[2] == 3);
    }

    #[test]
    fn partial_invalidated_set() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys = vec![3, 2];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert_eq!(sorted, vec![2, 3]);
    }

    #[test]
    fn no_dependencies() {
        let graph = InvalidationGraph::<u32>::new();
        let invalidated_keys = vec![3, 1, 2];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert_eq!(sorted.len(), 3);
    }

    #[test]
    fn drain_sorted_function() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        let sorted: Vec<_> = drain_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted, vec![1, 2]);
        assert!(!invalidated.has_invalidated(LAYOUT));
    }

    #[test]
    fn empty_invalidated_set() {
        let graph = InvalidationGraph::<u32>::new();
        let invalidated_keys: Vec<u32> = vec![];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert!(sorted.is_empty());
    }

    #[test]
    fn size_hint_accurate() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys = vec![1, 2];
        let cap = invalidated_keys.len();
        let mut drain =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT);
        assert_eq!(drain.size_hint(), (2, Some(2)));
        assert_eq!(drain.remaining(), 2);
        let _ = drain.next();
        assert_eq!(drain.size_hint(), (1, Some(1)));
        let _ = drain.next();
        assert_eq!(drain.size_hint(), (0, Some(0)));
        assert!(drain.is_empty());
    }

    #[test]
    fn duplicate_keys_deduplicated() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys = vec![1, 2, 1, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT)
                .collect();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted, vec![1, 2]);
    }

    #[test]
    fn cycles_stall_drain() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();
        let invalidated_keys = vec![1, 2, 3];
        let cap = invalidated_keys.len();
        let mut drain =
            DrainSorted::from_iter_with_capacity(invalidated_keys.into_iter(), cap, &graph, LAYOUT);
        let sorted: Vec<_> = drain.by_ref().collect();
        assert!(
            sorted.is_empty(),
            "cycle should prevent any keys from being yielded"
        );
        assert!(drain.is_stalled());
        assert_eq!(
            drain.completion(),
            DrainCompletion::Stalled { remaining: 3 }
        );
    }

    #[test]
    fn cycles_stall_drain_collect_with_completion() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(2, LAYOUT);
        invalidated.mark(3, LAYOUT);
        let (sorted, completion) =
            drain_sorted(&mut invalidated, &graph, LAYOUT).collect_with_completion();
        assert!(sorted.is_empty());
        assert_eq!(completion, DrainCompletion::Stalled { remaining: 3 });
    }

    #[test]
    fn drain_affected_sorted_expands_dependents() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        let sorted: Vec<_> = drain_affected_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted, vec![1, 2, 3, 4]);
        assert!(!invalidated.has_invalidated(LAYOUT));
    }

    #[test]
    fn drain_affected_sorted_multiple_roots() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        invalidated.mark(3, LAYOUT);
        let sorted: Vec<_> = drain_affected_sorted(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted.len(), 4);
    }

    #[test]
    fn deterministic_topological_order_diamond_is_total() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let invalidated_keys: Vec<u32> = vec![4, 3, 2, 1];
        let cap = invalidated_keys.len();
        let sorted: Vec<_> = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        )
        .collect();
        assert_eq!(sorted, vec![1, 2, 3, 4]);
    }

    #[test]
    #[should_panic(expected = "DenseKey index")]
    fn deterministic_drain_rejects_sparse_key_space() {
        let graph = InvalidationGraph::<usize>::new();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(usize::MAX, LAYOUT);
        let _: Vec<_> = drain_sorted_deterministic(&mut invalidated, &graph, LAYOUT).collect();
    }

    #[test]
    fn affected_sorted_with_trace_records_one_path() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(1, LAYOUT);
        let mut scratch = TraversalScratch::new();
        let mut rec = OneParentRecorder::new();
        let sorted: Vec<_> = drain_affected_sorted_with_trace(
            &mut invalidated,
            &graph,
            LAYOUT,
            &mut scratch,
            &mut rec,
        )
        .collect();
        assert_eq!(sorted, vec![1, 2, 3]);
        assert_eq!(rec.explain_path(3, LAYOUT).unwrap(), vec![1, 2, 3]);
    }

    /// The dense/heap drain must report a stall the same way as the hash-map drain.
    #[test]
    fn deterministic_cycles_stall_drain() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();
        let invalidated_keys: Vec<u32> = vec![1, 2, 3];
        let cap = invalidated_keys.len();
        let (sorted, completion) = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        )
        .collect_with_completion();
        assert!(sorted.is_empty());
        assert_eq!(completion, DrainCompletion::Stalled { remaining: 3 });
    }

    /// Keys outside a cycle are still yielded; only the cyclic keys remain.
    #[test]
    fn deterministic_partial_cycle_yields_acyclic_prefix() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(2, 3, LAYOUT, CycleHandling::Allow)
            .unwrap();
        let invalidated_keys: Vec<u32> = vec![3, 2, 1];
        let cap = invalidated_keys.len();
        let mut drain = DrainSortedDeterministic::from_iter_with_capacity(
            invalidated_keys.into_iter(),
            cap,
            &graph,
            LAYOUT,
        );
        let sorted: Vec<_> = drain.by_ref().collect();
        assert_eq!(sorted, vec![1]);
        assert!(drain.is_stalled());
        assert_eq!(
            drain.completion(),
            DrainCompletion::Stalled { remaining: 2 }
        );
    }

    /// Independent roots marked out of order are still drained in ascending key order.
    #[test]
    fn drain_affected_sorted_deterministic_orders_ready_keys_by_key() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let mut invalidated = InvalidationSet::new();
        invalidated.mark(3, LAYOUT);
        invalidated.mark(1, LAYOUT);
        let sorted: Vec<_> =
            drain_affected_sorted_deterministic(&mut invalidated, &graph, LAYOUT).collect();
        assert_eq!(sorted, vec![1, 2, 3, 4]);
    }
}