use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::HashMap;
use crate::channel::Channel;
/// Adjacency lists keyed by a `(key, channel)` node: each node maps to the
/// nodes on the other end of its edges.
type AdjacencyMap<K> = HashMap<(K, Channel), Vec<(K, Channel)>>;

/// Directed edges between `(key, channel)` nodes, indexed in both directions
/// so that both a node's dependents and its dependencies can be looked up
/// directly.
#[derive(Clone, Debug)]
pub struct CrossChannelEdges<K> {
    /// Outgoing side: source node -> nodes returned by `dependents`.
    forward: AdjacencyMap<K>,
    /// Incoming side: target node -> nodes returned by `dependencies`.
    reverse: AdjacencyMap<K>,
}
impl<K> Default for CrossChannelEdges<K>
where
K: Copy + Eq + Hash,
{
fn default() -> Self {
Self::new()
}
}
impl<K> CrossChannelEdges<K>
where
K: Copy + Eq + Hash,
{
#[must_use]
pub fn new() -> Self {
Self {
forward: HashMap::new(),
reverse: HashMap::new(),
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.forward.is_empty()
}
pub fn add_edge(&mut self, from_key: K, from_ch: Channel, to_key: K, to_ch: Channel) -> bool {
let from = (from_key, from_ch);
let to = (to_key, to_ch);
let fwd = self.forward.entry(from).or_default();
if fwd.contains(&to) {
return false;
}
fwd.push(to);
self.reverse.entry(to).or_default().push(from);
true
}
pub fn replace_dependents(
&mut self,
from_key: K,
from_ch: Channel,
targets: impl IntoIterator<Item = (K, Channel)>,
) -> bool {
let mut new_set: Vec<(K, Channel)> = Vec::new();
for target in targets {
if !new_set.contains(&target) {
new_set.push(target);
}
}
let from = (from_key, from_ch);
let old: &[(K, Channel)] = self.forward.get(&from).map_or(&[], Vec::as_slice);
let unchanged =
old.len() == new_set.len() && old.iter().all(|target| new_set.contains(target));
if unchanged {
return false;
}
self.clear_dependents(from_key, from_ch);
for (to_key, to_ch) in new_set {
self.add_edge(from_key, from_ch, to_key, to_ch);
}
true
}
pub fn remove_edge(
&mut self,
from_key: K,
from_ch: Channel,
to_key: K,
to_ch: Channel,
) -> bool {
let from = (from_key, from_ch);
let to = (to_key, to_ch);
let removed = if let Some(fwd) = self.forward.get_mut(&from) {
if let Some(pos) = fwd.iter().position(|e| *e == to) {
fwd.swap_remove(pos);
if fwd.is_empty() {
self.forward.remove(&from);
}
true
} else {
false
}
} else {
false
};
if removed
&& let Some(rev) = self.reverse.get_mut(&to)
&& let Some(pos) = rev.iter().position(|e| *e == from)
{
rev.swap_remove(pos);
if rev.is_empty() {
self.reverse.remove(&to);
}
}
removed
}
pub fn clear_dependents(&mut self, from_key: K, from_ch: Channel) -> bool {
let from = (from_key, from_ch);
let Some(targets) = self.forward.remove(&from) else {
return false;
};
for to in targets {
if let Some(rev) = self.reverse.get_mut(&to) {
if let Some(pos) = rev.iter().position(|e| *e == from) {
rev.swap_remove(pos);
}
if rev.is_empty() {
self.reverse.remove(&to);
}
}
}
true
}
pub fn clear_dependencies(&mut self, to_key: K, to_ch: Channel) -> bool {
let to = (to_key, to_ch);
let Some(sources) = self.reverse.remove(&to) else {
return false;
};
for from in sources {
if let Some(fwd) = self.forward.get_mut(&from) {
if let Some(pos) = fwd.iter().position(|e| *e == to) {
fwd.swap_remove(pos);
}
if fwd.is_empty() {
self.forward.remove(&from);
}
}
}
true
}
pub fn dependents(&self, key: K, ch: Channel) -> impl Iterator<Item = (K, Channel)> + '_ {
self.forward
.get(&(key, ch))
.into_iter()
.flat_map(|v| v.iter())
.copied()
}
pub fn dependencies(&self, key: K, ch: Channel) -> impl Iterator<Item = (K, Channel)> + '_ {
self.reverse
.get(&(key, ch))
.into_iter()
.flat_map(|v| v.iter())
.copied()
}
pub fn remove_key(&mut self, key: K) {
let fwd_keys: Vec<(K, Channel)> = self
.forward
.keys()
.filter(|(k, _)| *k == key)
.copied()
.collect();
for fwd_key in &fwd_keys {
if let Some(targets) = self.forward.remove(fwd_key) {
for to in targets {
if let Some(rev) = self.reverse.get_mut(&to) {
if let Some(pos) = rev.iter().position(|e| *e == *fwd_key) {
rev.swap_remove(pos);
}
if rev.is_empty() {
self.reverse.remove(&to);
}
}
}
}
}
let rev_keys: Vec<(K, Channel)> = self
.reverse
.keys()
.filter(|(k, _)| *k == key)
.copied()
.collect();
for rev_key in &rev_keys {
if let Some(sources) = self.reverse.remove(rev_key) {
for from in sources {
if let Some(fwd) = self.forward.get_mut(&from) {
if let Some(pos) = fwd.iter().position(|e| *e == *rev_key) {
fwd.swap_remove(pos);
}
if fwd.is_empty() {
self.forward.remove(&from);
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);
    const COMPOSITE: Channel = Channel::new(2);

    /// A freshly constructed edge set holds no edges.
    #[test]
    fn new_is_empty() {
        let edges = CrossChannelEdges::<u32>::new();
        assert!(edges.is_empty());
    }

    /// `Default` produces the same empty set as `new`.
    #[test]
    fn default_is_empty() {
        let edges = CrossChannelEdges::<u32>::default();
        assert!(edges.is_empty());
        assert_eq!(edges.dependents(1, LAYOUT).count(), 0);
    }

    /// An added edge is visible from both endpoints.
    #[test]
    fn add_and_query_edge() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.is_empty());
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps, vec![(2, PAINT)]);
        let sources: Vec<_> = edges.dependencies(2, PAINT).collect();
        assert_eq!(sources, vec![(1, LAYOUT)]);
    }

    /// Re-adding an existing edge is rejected. The same keys on a different
    /// source channel form a distinct edge.
    #[test]
    fn add_duplicate_returns_false() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.add_edge(1, LAYOUT, 2, PAINT));
        assert!(edges.add_edge(1, PAINT, 2, PAINT));
        assert_eq!(edges.dependencies(2, PAINT).count(), 2);
    }

    /// Removing one edge leaves the other edges from the same source intact.
    #[test]
    fn remove_edge() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        assert!(edges.remove_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.remove_edge(1, LAYOUT, 2, PAINT));
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps, vec![(3, COMPOSITE)]);
        assert!(edges.dependencies(2, PAINT).next().is_none());
    }

    /// Removing an edge whose source was never registered reports `false`.
    #[test]
    fn remove_edge_unknown_source_returns_false() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(!edges.remove_edge(9, LAYOUT, 2, PAINT));
        edges.add_edge(1, LAYOUT, 2, PAINT);
        assert!(!edges.remove_edge(1, LAYOUT, 2, COMPOSITE));
        assert!(!edges.is_empty());
    }

    /// Removing the last edge prunes both maps back to empty, and the edge can
    /// be added again afterwards.
    #[test]
    fn removing_last_edge_empties_graph() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        assert!(edges.remove_edge(1, LAYOUT, 2, PAINT));
        assert!(edges.is_empty());
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
    }

    /// `replace_dependents` deduplicates its input, rewires the reverse index,
    /// and reports `false` when the target set is unchanged.
    #[test]
    fn replace_dependents_updates_outgoing_edges() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        assert!(edges.replace_dependents(1, LAYOUT, [(3, COMPOSITE), (4, PAINT), (4, PAINT)]));
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&(3, COMPOSITE)));
        assert!(deps.contains(&(4, PAINT)));
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert_eq!(
            edges.dependencies(4, PAINT).collect::<Vec<_>>(),
            vec![(1, LAYOUT)]
        );
        assert!(!edges.replace_dependents(1, LAYOUT, [(4, PAINT), (3, COMPOSITE)]));
    }

    /// Replacing with no targets clears the node's outgoing edges.
    #[test]
    fn replace_dependents_with_empty_clears_outgoing() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        assert!(edges.replace_dependents(1, LAYOUT, Vec::<(u32, Channel)>::new()));
        assert!(edges.is_empty());
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert!(edges.dependencies(3, COMPOSITE).next().is_none());
        assert!(!edges.replace_dependents(1, LAYOUT, Vec::<(u32, Channel)>::new()));
    }

    /// Clearing a node's dependents also removes the mirrored reverse entries.
    #[test]
    fn clear_dependents_removes_outgoing_edges() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        assert!(edges.clear_dependents(1, LAYOUT));
        assert!(!edges.clear_dependents(1, LAYOUT));
        assert!(edges.dependents(1, LAYOUT).next().is_none());
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert!(edges.dependencies(3, COMPOSITE).next().is_none());
        assert!(edges.is_empty());
    }

    /// Clearing a node's dependencies also removes the mirrored forward entries.
    #[test]
    fn clear_dependencies_removes_incoming_edges() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 3, PAINT);
        edges.add_edge(2, COMPOSITE, 3, PAINT);
        assert!(edges.clear_dependencies(3, PAINT));
        assert!(!edges.clear_dependencies(3, PAINT));
        assert!(edges.dependencies(3, PAINT).next().is_none());
        assert!(edges.dependents(1, LAYOUT).next().is_none());
        assert!(edges.dependents(2, COMPOSITE).next().is_none());
        assert!(edges.is_empty());
    }

    /// `remove_key` drops edges touching the key on every channel, in both
    /// directions.
    #[test]
    fn remove_key_clears_all_channels() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, PAINT, 3, COMPOSITE);
        edges.add_edge(5, COMPOSITE, 1, LAYOUT);
        edges.remove_key(1);
        assert!(edges.dependents(1, LAYOUT).next().is_none());
        assert!(edges.dependents(1, PAINT).next().is_none());
        assert!(edges.dependencies(1, LAYOUT).next().is_none());
        assert!(edges.dependents(5, COMPOSITE).next().is_none());
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert!(edges.dependencies(3, COMPOSITE).next().is_none());
        assert!(edges.is_empty());
    }

    /// `remove_key` leaves edges between other keys in place.
    #[test]
    fn remove_key_keeps_unrelated_edges() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(3, LAYOUT, 4, PAINT);
        edges.add_edge(3, LAYOUT, 1, COMPOSITE);
        edges.remove_key(1);
        assert!(!edges.is_empty());
        assert_eq!(
            edges.dependents(3, LAYOUT).collect::<Vec<_>>(),
            vec![(4, PAINT)]
        );
        assert_eq!(
            edges.dependencies(4, PAINT).collect::<Vec<_>>(),
            vec![(3, LAYOUT)]
        );
        assert!(edges.dependencies(2, PAINT).next().is_none());
    }

    /// Edges between two channels of the same key are fully removed.
    #[test]
    fn remove_key_handles_edges_within_same_key() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 1, PAINT);
        edges.add_edge(1, PAINT, 1, LAYOUT);
        edges.remove_key(1);
        assert!(edges.is_empty());
        assert!(edges.dependencies(1, LAYOUT).next().is_none());
        assert!(edges.dependencies(1, PAINT).next().is_none());
    }

    /// A single source may have several dependents.
    #[test]
    fn multiple_edges_from_same_source() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        edges.add_edge(1, LAYOUT, 4, PAINT);
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps.len(), 3);
    }

    /// Querying an unknown node yields empty iterators rather than panicking.
    #[test]
    fn no_dependents_returns_empty() {
        let edges = CrossChannelEdges::<u32>::new();
        assert_eq!(edges.dependents(1, LAYOUT).count(), 0);
        assert_eq!(edges.dependencies(1, LAYOUT).count(), 0);
    }
}