use alloc::vec::Vec;
use core::fmt;
use core::hash::Hash;
use hashbrown::HashSet;
use crate::channel::{Channel, ChannelSet};
use crate::drain::{DenseKey, prepare_dense_growth};
use crate::scratch::TraversalScratch;
/// Number of channels with their own adjacency tables; `Channel::index()` is
/// used directly as an index into arrays of this length.
const MAX_CHANNELS: usize = 64;
/// Describes a dependency edge that was rejected because it would close a
/// cycle. Produced by `InvalidationGraph::add_dependency` (and, through it,
/// `replace_dependencies`) when `CycleHandling::Error` is in effect.
#[derive(Clone, Eq, PartialEq)]
pub struct CycleError<K> {
    /// The key that would have gained the dependency.
    pub from: K,
    /// The key it would have depended on.
    pub to: K,
    /// The channel in which the edge would have been added.
    pub channel: Channel,
}
impl<K: fmt::Debug> fmt::Debug for CycleError<K> {
    /// Formats as `CycleError { from: .., to: .., channel: .. }`.
    ///
    /// Built with `Formatter::debug_struct`, so `{:#?}` pretty-prints across
    /// lines and formatter flags reach the fields; compact `{:?}` output is the
    /// same text the previous hand-written `write!` produced.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CycleError")
            .field("from", &self.from)
            .field("to", &self.to)
            .field("channel", &self.channel)
            .finish()
    }
}
impl<K: fmt::Debug> fmt::Display for CycleError<K> {
    // Human-readable message naming the rejected edge and its channel.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { from, to, channel } = self;
        write!(
            f,
            "adding dependency {:?} -> {:?} in {:?} would create a cycle",
            from, to, channel
        )
    }
}
impl<K: fmt::Debug> core::error::Error for CycleError<K> {}
/// Policy applied by `InvalidationGraph::add_dependency` when a new edge would
/// close a cycle (including a direct self-dependency).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum CycleHandling {
    /// Trip a `debug_assert!` (a panic when debug assertions are enabled) and
    /// otherwise skip the edge, returning `Ok(false)`. This is the default.
    #[default]
    DebugAssert,
    /// Skip the edge and return a `CycleError` describing it.
    Error,
    /// Silently skip the edge, returning `Ok(false)`.
    Ignore,
    /// Skip reachability-based cycle detection so the edge is inserted even if
    /// it closes a cycle. A direct self-dependency (`from == to`) is still
    /// skipped with `Ok(false)`.
    Allow,
}
/// Directed dependency graph with an independent edge set per channel.
///
/// An edge `from -> to` in a channel means "`from` depends on `to`". Each edge
/// is stored twice (forward and reverse) so both the dependencies and the
/// dependents of a key can be read straight from a dense, index-addressed table.
#[derive(Clone, Debug)]
pub struct InvalidationGraph<K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// `forward[ch][from.index()]` lists the keys `from` depends on in channel `ch`.
    forward: [Vec<Vec<K>>; MAX_CHANNELS],
    /// `reverse[ch][to.index()]` lists the keys that depend on `to` in channel `ch`.
    reverse: [Vec<Vec<K>>; MAX_CHANNELS],
    /// Per key: the channels in which it currently has at least one dependency.
    forward_channels: Vec<ChannelSet>,
    /// Per key: the channels in which it currently has at least one dependent.
    reverse_channels: Vec<ChannelSet>,
}
impl<K> Default for InvalidationGraph<K>
where
K: Copy + Eq + Hash + DenseKey,
{
fn default() -> Self {
Self::new()
}
}
/// Ensures `vec` can be indexed at `idx`, padding with `T::default()`.
///
/// When `idx` is out of bounds the vector is resized to the length returned by
/// `prepare_dense_growth` (assumed to be greater than `idx`, since callers index
/// `vec[idx]` right afterwards). In-bounds indices leave `vec` untouched.
#[inline]
fn grow<T: Default>(vec: &mut Vec<T>, idx: usize) {
    if idx < vec.len() {
        return;
    }
    let new_len = prepare_dense_growth(vec, idx, "dependency graph adjacency");
    vec.resize_with(new_len, Default::default);
}
/// Ensures the per-key channel cache can be indexed at `idx`, padding new slots
/// with `ChannelSet::EMPTY`.
///
/// The new length comes from `prepare_dense_growth`; in-bounds indices leave
/// `vec` untouched.
#[inline]
fn grow_channels(vec: &mut Vec<ChannelSet>, idx: usize) {
    if idx < vec.len() {
        return;
    }
    let new_len = prepare_dense_growth(vec, idx, "dependency graph channel cache");
    vec.resize(new_len, ChannelSet::EMPTY);
}
impl<K> InvalidationGraph<K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Creates an empty graph with no edges on any channel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            forward: core::array::from_fn(|_| Vec::new()),
            reverse: core::array::from_fn(|_| Vec::new()),
            forward_channels: Vec::new(),
            reverse_channels: Vec::new(),
        }
    }

    /// Returns `true` when no key has a dependency on any channel.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.forward_channels.iter().all(|cs| cs.is_empty())
    }

    /// Records that `from` depends on `to` in `channel`.
    ///
    /// Returns `Ok(true)` when a new edge was inserted, and `Ok(false)` when the
    /// edge already existed or was skipped as a cycle under a non-erroring
    /// `handling`. A self-dependency (`from == to`) is always treated as a
    /// cycle, even with [`CycleHandling::Allow`]; `Allow` only skips the
    /// reachability check for longer cycles.
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] when the edge would close a cycle and `handling`
    /// is [`CycleHandling::Error`].
    ///
    /// # Panics
    ///
    /// With [`CycleHandling::DebugAssert`], panics on a cycle when debug
    /// assertions are enabled. Growing the dense storage may also panic inside
    /// `prepare_dense_growth` (e.g. for an overly sparse key index). All storage
    /// is grown before anything is mutated, so such a panic leaves the graph
    /// unchanged.
    pub fn add_dependency(
        &mut self,
        from: K,
        to: K,
        channel: Channel,
        handling: CycleHandling,
    ) -> Result<bool, CycleError<K>> {
        if from == to
            || (handling != CycleHandling::Allow && self.would_create_cycle(from, to, channel))
        {
            return self.handle_cycle(from, to, channel, handling);
        }
        let ch = channel.index() as usize;
        let from_idx = from.index();
        let to_idx = to.index();

        // Grow every structure up front: previously the forward edge was pushed
        // before the reverse table grew, so a growth panic left a dangling edge.
        grow(&mut self.forward[ch], from_idx);
        grow(&mut self.reverse[ch], to_idx);
        grow_channels(&mut self.forward_channels, from_idx);
        grow_channels(&mut self.reverse_channels, to_idx);

        let deps = &mut self.forward[ch][from_idx];
        if deps.contains(&to) {
            return Ok(false);
        }
        deps.push(to);
        self.reverse[ch][to_idx].push(from);
        self.forward_channels[from_idx].insert(channel);
        self.reverse_channels[to_idx].insert(channel);
        Ok(true)
    }

    /// Applies `handling` to a rejected edge `from -> to` in `channel`.
    fn handle_cycle(
        &self,
        from: K,
        to: K,
        channel: Channel,
        handling: CycleHandling,
    ) -> Result<bool, CycleError<K>> {
        match handling {
            CycleHandling::DebugAssert => {
                debug_assert!(false, "adding dependency would create a cycle");
                Ok(false)
            }
            CycleHandling::Error => Err(CycleError { from, to, channel }),
            CycleHandling::Ignore | CycleHandling::Allow => Ok(false),
        }
    }

    /// Returns `true` if `from` is already reachable from `to` by following
    /// dependency edges in `channel`, i.e. adding `from -> to` would close a cycle.
    fn would_create_cycle(&self, from: K, to: K, channel: Channel) -> bool {
        // The adjacency table is the same for every step of the search.
        let fwd = &self.forward[channel.index() as usize];
        let mut visited = HashSet::new();
        let mut stack: Vec<K> = Vec::new();
        stack.push(to);
        while let Some(current) = stack.pop() {
            if current == from {
                return true;
            }
            if !visited.insert(current) {
                continue;
            }
            if let Some(deps) = fwd.get(current.index()) {
                stack.extend(deps.iter().copied());
            }
        }
        false
    }

    /// Removes the edge `from -> to` in `channel`.
    ///
    /// Returns `true` if the edge existed. Channel caches are cleared for keys
    /// left with no dependencies / dependents in `channel`.
    pub fn remove_dependency(&mut self, from: K, to: K, channel: Channel) -> bool {
        let ch = channel.index() as usize;
        let from_idx = from.index();

        let Some(deps) = self.forward[ch].get_mut(from_idx) else {
            return false;
        };
        let Some(pos) = deps.iter().position(|&k| k == to) else {
            return false;
        };
        deps.swap_remove(pos);
        if deps.is_empty()
            && let Some(cs) = self.forward_channels.get_mut(from_idx)
        {
            cs.remove(channel);
        }
        Self::unlink(
            &mut self.reverse[ch],
            &mut self.reverse_channels,
            to.index(),
            from,
            channel,
        );
        true
    }

    /// Replaces the full set of `from`'s dependencies in `channel` with `to`.
    ///
    /// Duplicates in `to` are ignored. Returns `Ok(false)` when the set is
    /// already equal to `to` or when no edge ended up changing. This can happen
    /// when every new edge is skipped under `Ignore`/`DebugAssert`, or is a
    /// self-dependency. Edges absent from `to` are removed first, then missing
    /// ones are added with `handling`. Dependencies in other channels are
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns [`CycleError`] if any new edge would close a cycle under
    /// [`CycleHandling::Error`]. The partial update is rolled back first, so
    /// the dependency set is restored (its iteration order may differ).
    ///
    /// # Panics
    ///
    /// With [`CycleHandling::DebugAssert`] and debug assertions enabled, a cycle
    /// panics after earlier removals have already been applied.
    pub fn replace_dependencies(
        &mut self,
        from: K,
        channel: Channel,
        to: impl IntoIterator<Item = K>,
        handling: CycleHandling,
    ) -> Result<bool, CycleError<K>> {
        // Deduplicate while keeping first-occurrence order.
        let mut new_set: Vec<K> = Vec::new();
        for k in to {
            if !new_set.contains(&k) {
                new_set.push(k);
            }
        }

        let old = self.forward_slice(from, channel);
        if old.len() == new_set.len() && old.iter().all(|dep| new_set.contains(dep)) {
            return Ok(false);
        }
        let to_remove: Vec<K> = old
            .iter()
            .copied()
            .filter(|dep| !new_set.contains(dep))
            .collect();
        let to_add: Vec<K> = new_set
            .iter()
            .copied()
            .filter(|dep| !old.contains(dep))
            .collect();

        let mut removed: Vec<K> = Vec::new();
        for dep in to_remove {
            if self.remove_dependency(from, dep, channel) {
                removed.push(dep);
            }
        }
        let mut added: Vec<K> = Vec::new();
        for dep in to_add {
            match self.add_dependency(from, dep, channel, handling) {
                Ok(true) => added.push(dep),
                Ok(false) => {}
                Err(err) => {
                    // Undo the partial update so the caller sees the original edges.
                    for d in added {
                        let _ = self.remove_dependency(from, d, channel);
                    }
                    for d in removed {
                        let _ = self.add_dependency(from, d, channel, CycleHandling::Allow);
                    }
                    return Err(err);
                }
            }
        }
        Ok(!removed.is_empty() || !added.is_empty())
    }

    /// Removes every edge touching `key`, in every channel, in both directions.
    pub fn remove_key(&mut self, key: K) {
        let key_idx = key.index();

        // Outgoing edges: drop `key`'s dependency lists and the back-references
        // held by each dependency.
        for channel in self.dependency_channels(key) {
            let ch = channel.index() as usize;
            let deps = self.forward[ch]
                .get_mut(key_idx)
                .map(core::mem::take)
                .unwrap_or_default();
            for dep in deps {
                Self::unlink(
                    &mut self.reverse[ch],
                    &mut self.reverse_channels,
                    dep.index(),
                    key,
                    channel,
                );
            }
        }
        if let Some(cs) = self.forward_channels.get_mut(key_idx) {
            *cs = ChannelSet::EMPTY;
        }

        // Incoming edges: drop `key`'s dependent lists and remove `key` from each
        // dependent's dependency list.
        for channel in self.dependent_channels(key) {
            let ch = channel.index() as usize;
            let dependents = self.reverse[ch]
                .get_mut(key_idx)
                .map(core::mem::take)
                .unwrap_or_default();
            for dependent in dependents {
                Self::unlink(
                    &mut self.forward[ch],
                    &mut self.forward_channels,
                    dependent.index(),
                    key,
                    channel,
                );
            }
        }
        if let Some(cs) = self.reverse_channels.get_mut(key_idx) {
            *cs = ChannelSet::EMPTY;
        }
    }

    /// Iterates the keys `key` directly depends on in `channel`.
    #[inline]
    pub fn dependencies(&self, key: K, channel: Channel) -> impl Iterator<Item = K> + '_ {
        self.forward_slice(key, channel).iter().copied()
    }

    /// Iterates the keys that directly depend on `key` in `channel`.
    #[inline]
    pub fn dependents(&self, key: K, channel: Channel) -> impl Iterator<Item = K> + '_ {
        self.reverse_slice(key, channel).iter().copied()
    }

    /// Lazily yields every key that transitively depends on `key` in `channel`,
    /// each at most once. `key` itself is yielded only if it lies on a cycle.
    pub fn transitive_dependents(&self, key: K, channel: Channel) -> impl Iterator<Item = K> + '_ {
        TransitiveDependentsIter::new(self, key, channel)
    }

    /// Calls `f` once per transitive dependent of `key` in `channel`, using
    /// `scratch` (reset on entry) for the traversal stack and visited set.
    pub fn for_each_transitive_dependent(
        &self,
        key: K,
        channel: Channel,
        scratch: &mut TraversalScratch<K>,
        mut f: impl FnMut(K),
    ) {
        scratch.reset();
        scratch.stack.extend(self.dependents(key, channel));
        while let Some(next) = scratch.stack.pop() {
            if scratch.visited.insert(next) {
                f(next);
                scratch.stack.extend(self.dependents(next, channel));
            }
        }
    }

    /// Channels in which `key` has at least one dependency.
    #[must_use]
    pub fn dependency_channels(&self, key: K) -> ChannelSet {
        self.forward_channels
            .get(key.index())
            .copied()
            .unwrap_or(ChannelSet::EMPTY)
    }

    /// Channels in which `key` has at least one dependent.
    #[must_use]
    pub fn dependent_channels(&self, key: K) -> ChannelSet {
        self.reverse_channels
            .get(key.index())
            .copied()
            .unwrap_or(ChannelSet::EMPTY)
    }

    /// Returns `true` if `key` depends on anything in `channel`.
    #[inline]
    #[must_use]
    pub fn has_dependencies(&self, key: K, channel: Channel) -> bool {
        !self.forward_slice(key, channel).is_empty()
    }

    /// Returns `true` if anything depends on `key` in `channel`.
    #[must_use]
    pub fn has_dependents(&self, key: K, channel: Channel) -> bool {
        !self.reverse_slice(key, channel).is_empty()
    }

    /// Number of direct dependencies of `key` in `channel` (its in-degree when
    /// edges are read in invalidation order, dependency -> dependent).
    #[must_use]
    pub fn in_degree(&self, key: K, channel: Channel) -> usize {
        self.forward_slice(key, channel).len()
    }

    /// Number of direct dependents of `key` in `channel` (its out-degree when
    /// edges are read in invalidation order, dependency -> dependent).
    #[must_use]
    pub fn out_degree(&self, key: K, channel: Channel) -> usize {
        self.reverse_slice(key, channel).len()
    }

    /// Every key that appears in at least one edge on any channel, each once.
    ///
    /// The keys are collected eagerly into an owned buffer.
    pub fn keys(&self) -> impl Iterator<Item = K> + '_ {
        let mut seen = HashSet::new();
        let mut all_keys: Vec<K> = Vec::new();
        for (fwd, rev) in self.forward.iter().zip(&self.reverse) {
            for &k in fwd.iter().chain(rev).flatten() {
                if seen.insert(k) {
                    all_keys.push(k);
                }
            }
        }
        all_keys.into_iter()
    }

    /// Same as [`Self::keys`], collected into a `Vec`.
    #[must_use]
    pub fn keys_vec(&self) -> Vec<K> {
        self.keys().collect()
    }

    /// Forward adjacency list of `key` in `channel`, or an empty slice.
    fn forward_slice(&self, key: K, channel: Channel) -> &[K] {
        self.forward[channel.index() as usize]
            .get(key.index())
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Reverse adjacency list of `key` in `channel`, or an empty slice.
    fn reverse_slice(&self, key: K, channel: Channel) -> &[K] {
        self.reverse[channel.index() as usize]
            .get(key.index())
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Removes one occurrence of `target` from `lists[idx]`. It also clears
    /// `channel` from `channels[idx]` when that list ends up empty.
    /// Out-of-range indices are ignored.
    fn unlink(
        lists: &mut [Vec<K>],
        channels: &mut [ChannelSet],
        idx: usize,
        target: K,
        channel: Channel,
    ) {
        let Some(list) = lists.get_mut(idx) else {
            return;
        };
        if let Some(pos) = list.iter().position(|&k| k == target) {
            list.swap_remove(pos);
        }
        if list.is_empty()
            && let Some(cs) = channels.get_mut(idx)
        {
            cs.remove(channel);
        }
    }
}
/// Depth-first walk over every key reachable through `dependents` edges from a
/// start key. The start key itself is only produced if a cycle leads back to it.
struct TransitiveDependentsIter<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Graph being traversed.
    graph: &'a InvalidationGraph<K>,
    /// Channel whose edges are followed.
    channel: Channel,
    /// Keys already yielded; guarantees each key is produced at most once.
    visited: HashSet<K>,
    /// Keys still to examine; may contain keys that were already visited.
    stack: Vec<K>,
}
impl<'a, K> TransitiveDependentsIter<'a, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    /// Seeds the walk with the direct dependents of `start` in `channel`.
    fn new(graph: &'a InvalidationGraph<K>, start: K, channel: Channel) -> Self {
        let stack: Vec<K> = graph.dependents(start, channel).collect();
        Self {
            graph,
            channel,
            visited: HashSet::new(),
            stack,
        }
    }
}
impl<K> Iterator for TransitiveDependentsIter<'_, K>
where
    K: Copy + Eq + Hash + DenseKey,
{
    type Item = K;

    /// Pops pending keys until an unvisited one appears. It then queues that
    /// key's dependents and yields it; returns `None` once nothing is pending.
    fn next(&mut self) -> Option<K> {
        loop {
            let candidate = self.stack.pop()?;
            if !self.visited.insert(candidate) {
                continue;
            }
            self.stack
                .extend(self.graph.dependents(candidate, self.channel));
            return Some(candidate);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);
    const A11Y: Channel = Channel::new(2);

    #[test]
    fn add_and_query_dependencies() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        assert!(graph.dependencies(2, LAYOUT).any(|k| k == 1));
        assert!(graph.dependents(1, LAYOUT).any(|k| k == 2));
        assert!(graph.dependents(2, LAYOUT).any(|k| k == 3));
    }

    #[test]
    fn add_duplicate_dependency_returns_false() {
        let mut graph = InvalidationGraph::<u32>::new();
        assert!(
            graph
                .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
                .unwrap()
        );
        assert!(
            !graph
                .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
                .unwrap()
        );
        assert_eq!(graph.in_degree(2, LAYOUT), 1);
        assert_eq!(graph.out_degree(1, LAYOUT), 1);
    }

    #[test]
    fn is_empty_tracks_edges() {
        let mut graph = InvalidationGraph::<u32>::new();
        assert!(graph.is_empty());
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        assert!(!graph.is_empty());
        assert!(graph.remove_dependency(2, 1, LAYOUT));
        assert!(graph.is_empty());
    }

    #[test]
    fn replace_dependencies_updates_in_place() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(10, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(10, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let changed = graph
            .replace_dependencies(10, LAYOUT, [3, 4], CycleHandling::Error)
            .unwrap();
        assert!(changed);
        let deps: Vec<_> = graph.dependencies(10, LAYOUT).collect();
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&3));
        assert!(deps.contains(&4));
        assert!(!deps.contains(&1));
        assert!(!deps.contains(&2));
    }

    #[test]
    fn replace_dependencies_rolls_back_on_cycle_error() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        let err = graph
            .replace_dependencies(1, LAYOUT, [2], CycleHandling::Error)
            .unwrap_err();
        assert_eq!(err.from, 1);
        assert_eq!(err.to, 2);
        let deps: Vec<_> = graph.dependencies(1, LAYOUT).collect();
        assert_eq!(deps, vec![3]);
        assert!(!graph.dependencies(1, LAYOUT).any(|k| k == 2));
        assert!(graph.dependencies(2, LAYOUT).any(|k| k == 1));
    }

    #[test]
    fn replace_dependencies_noop_when_set_unchanged_returns_false() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(10, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(10, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let changed = graph
            .replace_dependencies(10, LAYOUT, [2, 1, 2], CycleHandling::Error)
            .unwrap();
        assert!(!changed);
        let deps: Vec<_> = graph.dependencies(10, LAYOUT).collect();
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&1));
        assert!(deps.contains(&2));
    }

    #[test]
    fn replace_dependencies_rolls_back_mixed_delta_on_cycle_error() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(1, 3, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(1, 4, LAYOUT, CycleHandling::Error)
            .unwrap();
        let err = graph
            .replace_dependencies(1, LAYOUT, [4, 2], CycleHandling::Error)
            .unwrap_err();
        assert_eq!(err.from, 1);
        assert_eq!(err.to, 2);
        let deps: Vec<_> = graph.dependencies(1, LAYOUT).collect();
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&3));
        assert!(deps.contains(&4));
        assert!(!deps.contains(&2));
    }

    #[test]
    fn replace_dependencies_is_channel_scoped() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(7, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(7, 9, PAINT, CycleHandling::Error)
            .unwrap();
        graph
            .replace_dependencies(7, LAYOUT, [2, 3], CycleHandling::Error)
            .unwrap();
        let layout: Vec<_> = graph.dependencies(7, LAYOUT).collect();
        assert_eq!(layout.len(), 2);
        assert!(layout.contains(&2));
        assert!(layout.contains(&3));
        assert!(!layout.contains(&1));
        let paint: Vec<_> = graph.dependencies(7, PAINT).collect();
        assert_eq!(paint, vec![9]);
    }

    #[test]
    fn cycle_detection_error() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let result = graph.add_dependency(1, 3, LAYOUT, CycleHandling::Error);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.from, 1);
        assert_eq!(err.to, 3);
    }

    #[test]
    fn self_dependency_is_cycle() {
        let mut graph = InvalidationGraph::<u32>::new();
        let result = graph.add_dependency(1, 1, LAYOUT, CycleHandling::Error);
        assert!(result.is_err());
    }

    #[test]
    fn self_dependency_rejected_even_with_allow() {
        let mut graph = InvalidationGraph::<u32>::new();
        let added = graph
            .add_dependency(1, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        assert!(!added);
        assert!(!graph.has_dependencies(1, LAYOUT));
        assert!(graph.is_empty());
    }

    #[test]
    fn cycle_ignore() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Ignore)
            .unwrap();
        let result = graph.add_dependency(1, 1, LAYOUT, CycleHandling::Ignore);
        assert!(result.is_ok());
        // The cyclic edge is skipped rather than inserted.
        assert!(!result.unwrap());
    }

    #[test]
    fn cycle_allow() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Allow)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Allow)
            .unwrap();
        let result = graph.add_dependency(1, 3, LAYOUT, CycleHandling::Allow);
        assert!(result.is_ok());
        // `Allow` inserts the edge even though it closes a cycle.
        assert!(result.unwrap());
    }

    #[test]
    fn remove_dependency() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        assert!(graph.dependencies(2, LAYOUT).any(|k| k == 1));
        let removed = graph.remove_dependency(2, 1, LAYOUT);
        assert!(removed);
        assert!(!graph.dependencies(2, LAYOUT).any(|k| k == 1));
        assert!(!graph.remove_dependency(2, 1, LAYOUT));
    }

    #[test]
    fn remove_key() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(2, 1, PAINT, CycleHandling::Error)
            .unwrap();
        graph.remove_key(2);
        assert!(!graph.dependencies(2, LAYOUT).any(|_| true));
        assert!(!graph.dependents(1, LAYOUT).any(|_| true));
        assert!(!graph.dependencies(3, LAYOUT).any(|_| true));
    }

    #[test]
    fn remove_key_clears_channel_sets() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(2, 1, PAINT, CycleHandling::Error)
            .unwrap();
        graph.remove_key(2);
        assert!(graph.dependency_channels(2).is_empty());
        assert!(graph.dependent_channels(2).is_empty());
        assert!(graph.dependent_channels(1).is_empty());
        assert!(graph.dependency_channels(3).is_empty());
        assert!(graph.is_empty());
        assert!(graph.keys().next().is_none());
    }

    #[test]
    fn transitive_dependents() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(4, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        let transitive: Vec<_> = graph.transitive_dependents(1, LAYOUT).collect();
        assert_eq!(transitive.len(), 3);
        assert!(transitive.contains(&2));
        assert!(transitive.contains(&3));
        assert!(transitive.contains(&4));
    }

    #[test]
    fn channel_independence() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        assert!(graph.has_dependencies(2, LAYOUT));
        assert!(!graph.has_dependencies(2, PAINT));
        assert!(graph.has_dependents(1, LAYOUT));
        assert!(!graph.has_dependents(1, PAINT));
    }

    #[test]
    fn in_out_degree() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(3, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(3, 2, LAYOUT, CycleHandling::Error)
            .unwrap();
        assert_eq!(graph.in_degree(3, LAYOUT), 2);
        assert_eq!(graph.out_degree(1, LAYOUT), 1);
    }

    #[test]
    fn dependency_channels() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(2, 1, PAINT, CycleHandling::Error)
            .unwrap();
        let channels = graph.dependency_channels(2);
        assert!(channels.contains(LAYOUT));
        assert!(channels.contains(PAINT));
        assert!(!channels.contains(A11Y));
    }

    #[test]
    fn keys_and_keys_vec_are_unique() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        let keys: Vec<_> = graph.keys().collect();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&1));
        assert!(keys.contains(&2));
        let keys_vec = graph.keys_vec();
        assert_eq!(keys_vec.len(), 2);
        assert!(keys_vec.contains(&1));
        assert!(keys_vec.contains(&2));
    }

    #[test]
    fn keys_are_unique_across_channels() {
        let mut graph = InvalidationGraph::<u32>::new();
        graph
            .add_dependency(2, 1, LAYOUT, CycleHandling::Error)
            .unwrap();
        graph
            .add_dependency(2, 1, PAINT, CycleHandling::Error)
            .unwrap();
        let keys = graph.keys_vec();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&1));
        assert!(keys.contains(&2));
    }

    #[test]
    #[should_panic(expected = "DenseKey index")]
    fn add_dependency_rejects_sparse_key_space() {
        let mut graph = InvalidationGraph::<usize>::new();
        let _ = graph.add_dependency(usize::MAX, 7, LAYOUT, CycleHandling::Error);
    }

    #[test]
    #[should_panic(expected = "DenseKey index")]
    fn add_dependency_rejects_sparse_dependency_key() {
        let mut graph = InvalidationGraph::<usize>::new();
        let _ = graph.add_dependency(7, usize::MAX, LAYOUT, CycleHandling::Error);
    }
}