use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::HashMap;
use crate::channel::Channel;
/// Bidirectional index of directed edges between `(key, channel)` endpoints.
///
/// Every edge `from -> to` is stored twice: once in `forward` under `from`
/// and once in `reverse` under `to`. This lets [`dependents`](Self::dependents)
/// and [`dependencies`](Self::dependencies) each answer with a single map
/// lookup.
#[derive(Debug, Clone)]
pub struct CrossChannelEdges<K> {
/// Outgoing edges: source endpoint -> the endpoints it points at.
forward: HashMap<(K, Channel), Vec<(K, Channel)>>,
/// Incoming edges: target endpoint -> the endpoints pointing at it.
reverse: HashMap<(K, Channel), Vec<(K, Channel)>>,
}
impl<K> Default for CrossChannelEdges<K>
where
K: Copy + Eq + Hash,
{
fn default() -> Self {
Self::new()
}
}
impl<K> CrossChannelEdges<K>
where
K: Copy + Eq + Hash,
{
#[must_use]
pub fn new() -> Self {
Self {
forward: HashMap::new(),
reverse: HashMap::new(),
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.forward.is_empty()
}
pub fn add_edge(&mut self, from_key: K, from_ch: Channel, to_key: K, to_ch: Channel) -> bool {
let from = (from_key, from_ch);
let to = (to_key, to_ch);
let fwd = self.forward.entry(from).or_default();
if fwd.contains(&to) {
return false;
}
fwd.push(to);
self.reverse.entry(to).or_default().push(from);
true
}
pub fn remove_edge(
&mut self,
from_key: K,
from_ch: Channel,
to_key: K,
to_ch: Channel,
) -> bool {
let from = (from_key, from_ch);
let to = (to_key, to_ch);
let removed = if let Some(fwd) = self.forward.get_mut(&from) {
if let Some(pos) = fwd.iter().position(|e| *e == to) {
fwd.swap_remove(pos);
if fwd.is_empty() {
self.forward.remove(&from);
}
true
} else {
false
}
} else {
false
};
if removed
&& let Some(rev) = self.reverse.get_mut(&to)
&& let Some(pos) = rev.iter().position(|e| *e == from)
{
rev.swap_remove(pos);
if rev.is_empty() {
self.reverse.remove(&to);
}
}
removed
}
pub fn dependents(&self, key: K, ch: Channel) -> impl Iterator<Item = (K, Channel)> + '_ {
self.forward
.get(&(key, ch))
.into_iter()
.flat_map(|v| v.iter())
.copied()
}
pub fn dependencies(&self, key: K, ch: Channel) -> impl Iterator<Item = (K, Channel)> + '_ {
self.reverse
.get(&(key, ch))
.into_iter()
.flat_map(|v| v.iter())
.copied()
}
pub fn remove_key(&mut self, key: K) {
let fwd_keys: Vec<(K, Channel)> = self
.forward
.keys()
.filter(|(k, _)| *k == key)
.copied()
.collect();
for fwd_key in &fwd_keys {
if let Some(targets) = self.forward.remove(fwd_key) {
for to in targets {
if let Some(rev) = self.reverse.get_mut(&to) {
if let Some(pos) = rev.iter().position(|e| *e == *fwd_key) {
rev.swap_remove(pos);
}
if rev.is_empty() {
self.reverse.remove(&to);
}
}
}
}
}
let rev_keys: Vec<(K, Channel)> = self
.reverse
.keys()
.filter(|(k, _)| *k == key)
.copied()
.collect();
for rev_key in &rev_keys {
if let Some(sources) = self.reverse.remove(rev_key) {
for from in sources {
if let Some(fwd) = self.forward.get_mut(&from) {
if let Some(pos) = fwd.iter().position(|e| *e == *rev_key) {
fwd.swap_remove(pos);
}
if fwd.is_empty() {
self.forward.remove(&from);
}
}
}
}
}
}
}
/// Unit tests for [`CrossChannelEdges`], covering both lookup directions and
/// the clean-up performed by the removal methods.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);
    const COMPOSITE: Channel = Channel::new(2);

    #[test]
    fn new_is_empty() {
        let edges = CrossChannelEdges::<u32>::new();
        assert!(edges.is_empty());
    }

    #[test]
    fn default_is_empty() {
        let edges = CrossChannelEdges::<u32>::default();
        assert!(edges.is_empty());
    }

    #[test]
    fn add_and_query_edge() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.is_empty());
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps, vec![(2, PAINT)]);
        let sources: Vec<_> = edges.dependencies(2, PAINT).collect();
        assert_eq!(sources, vec![(1, LAYOUT)]);
    }

    #[test]
    fn add_duplicate_returns_false() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.add_edge(1, LAYOUT, 2, PAINT));
        // A rejected duplicate must not leave a second entry in either index.
        assert_eq!(edges.dependents(1, LAYOUT).count(), 1);
        assert_eq!(edges.dependencies(2, PAINT).count(), 1);
    }

    #[test]
    fn remove_edge() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        assert!(edges.remove_edge(1, LAYOUT, 2, PAINT));
        assert!(!edges.remove_edge(1, LAYOUT, 2, PAINT));
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps, vec![(3, COMPOSITE)]);
        // The reverse index must reflect the removal too.
        assert!(edges.dependencies(2, PAINT).next().is_none());
        let sources: Vec<_> = edges.dependencies(3, COMPOSITE).collect();
        assert_eq!(sources, vec![(1, LAYOUT)]);
    }

    #[test]
    fn remove_missing_edge_returns_false() {
        let mut edges = CrossChannelEdges::<u32>::new();
        assert!(!edges.remove_edge(1, LAYOUT, 2, PAINT));
        edges.add_edge(1, LAYOUT, 2, PAINT);
        // Same source, different target: nothing to remove, nothing disturbed.
        assert!(!edges.remove_edge(1, LAYOUT, 3, PAINT));
        assert_eq!(edges.dependents(1, LAYOUT).count(), 1);
    }

    #[test]
    fn removing_last_edge_makes_set_empty() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        assert!(edges.remove_edge(1, LAYOUT, 2, PAINT));
        assert!(edges.is_empty());
        // The edge can be added again after removal.
        assert!(edges.add_edge(1, LAYOUT, 2, PAINT));
    }

    #[test]
    fn remove_key_clears_all_channels() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, PAINT, 3, COMPOSITE);
        edges.add_edge(5, COMPOSITE, 1, LAYOUT);
        edges.remove_key(1);
        assert!(edges.dependents(1, LAYOUT).next().is_none());
        assert!(edges.dependents(1, PAINT).next().is_none());
        assert!(edges.dependencies(1, LAYOUT).next().is_none());
        assert!(edges.dependents(5, COMPOSITE).next().is_none());
        // Former neighbours must not keep stale mirror entries.
        assert!(edges.dependencies(2, PAINT).next().is_none());
        assert!(edges.dependencies(3, COMPOSITE).next().is_none());
        assert!(edges.is_empty());
    }

    #[test]
    fn remove_key_keeps_unrelated_edges() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(3, LAYOUT, 2, PAINT);
        edges.remove_key(1);
        let sources: Vec<_> = edges.dependencies(2, PAINT).collect();
        assert_eq!(sources, vec![(3, LAYOUT)]);
        let deps: Vec<_> = edges.dependents(3, LAYOUT).collect();
        assert_eq!(deps, vec![(2, PAINT)]);
    }

    #[test]
    fn remove_key_handles_edge_within_same_key() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 1, PAINT);
        edges.remove_key(1);
        assert!(edges.is_empty());
        assert!(edges.dependencies(1, PAINT).next().is_none());
    }

    #[test]
    fn multiple_edges_from_same_source() {
        let mut edges = CrossChannelEdges::<u32>::new();
        edges.add_edge(1, LAYOUT, 2, PAINT);
        edges.add_edge(1, LAYOUT, 3, COMPOSITE);
        edges.add_edge(1, LAYOUT, 4, PAINT);
        let deps: Vec<_> = edges.dependents(1, LAYOUT).collect();
        assert_eq!(deps.len(), 3);
    }

    #[test]
    fn no_dependents_returns_empty() {
        let edges = CrossChannelEdges::<u32>::new();
        assert_eq!(edges.dependents(1, LAYOUT).count(), 0);
        assert_eq!(edges.dependencies(1, LAYOUT).count(), 0);
    }
}