use core::fmt;
use crate::channel::{Channel, ChannelSet};
/// Number of per-channel slots in each [`ChannelCascade`] table.
///
/// Tables are indexed with `Channel::index()`, so every channel handed to a
/// `ChannelCascade` must have an index below this value or the lookup panics.
// NOTE(review): presumably mirrors the capacity of `ChannelSet`; confirm against
// `crate::channel` if either changes.
const MAX_CHANNELS: usize = 64;
/// Error returned when adding a cascade edge would make the cascade graph cyclic.
///
/// A self-cascade (`from == to`) is reported with this error as well.
///
/// Both fields are plain channel values, so the error is `Copy` and can be
/// passed around or stored without cloning.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CascadeCycleError {
    /// Source channel of the rejected edge.
    pub from: Channel,
    /// Target channel of the rejected edge.
    pub to: Channel,
}
impl fmt::Debug for CascadeCycleError {
    /// Formats as `CascadeCycleError { from: .., to: .. }`.
    ///
    /// Uses the formatter's struct builder (like `ChannelCascade`'s Debug impl),
    /// so `{:#?}` pretty-prints across lines instead of always emitting one line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CascadeCycleError")
            .field("from", &self.from)
            .field("to", &self.to)
            .finish()
    }
}
impl fmt::Display for CascadeCycleError {
    /// Writes a one-line, human-readable message naming both endpoints of the
    /// rejected edge (each rendered with its `Debug` form).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { from, to } = self;
        f.write_fmt(format_args!(
            "adding cascade {:?} -> {:?} would create a cycle",
            from, to
        ))
    }
}

/// Lets [`CascadeCycleError`] be used wherever a standard error is expected;
/// it wraps no underlying cause, so the default (`None`) `source` is kept.
impl core::error::Error for CascadeCycleError {}
/// Directed graph of cascade edges between channels, plus its transitive closure.
///
/// Edges are added with `add_cascade`, which rejects any edge that would create
/// a cycle, so the graph stays acyclic.
#[derive(Clone)]
pub struct ChannelCascade {
    /// Slot `i` holds the channels that channel `i` cascades to directly.
    direct: [ChannelSet; MAX_CHANNELS],
    /// Slot `i` holds every channel reachable from channel `i` through one or
    /// more `direct` edges; updated on every mutation of `direct`.
    transitive: [ChannelSet; MAX_CHANNELS],
}
impl fmt::Debug for ChannelCascade {
    /// Formats both per-channel tables, including every empty slot.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { direct, transitive } = self;
        let mut builder = f.debug_struct("ChannelCascade");
        builder.field("direct", direct);
        builder.field("transitive", transitive);
        builder.finish()
    }
}
impl Default for ChannelCascade {
fn default() -> Self {
Self::new()
}
}
impl ChannelCascade {
    /// Creates a cascade with no edges.
    #[must_use]
    pub fn new() -> Self {
        Self {
            direct: [ChannelSet::EMPTY; MAX_CHANNELS],
            transitive: [ChannelSet::EMPTY; MAX_CHANNELS],
        }
    }

    /// Builds a cascade by adding each `(from, to)` edge in iteration order.
    ///
    /// Duplicate edges are accepted and ignored.
    ///
    /// # Errors
    ///
    /// Returns the [`CascadeCycleError`] for the first edge that is a
    /// self-cascade or would close a cycle; later edges are not examined.
    pub fn from_edges(
        edges: impl IntoIterator<Item = (Channel, Channel)>,
    ) -> Result<Self, CascadeCycleError> {
        let mut cascade = Self::new();
        for (from, to) in edges {
            cascade.add_cascade(from, to)?;
        }
        Ok(cascade)
    }

    /// Adds the direct cascade edge `from -> to`.
    ///
    /// Returns `Ok(true)` if the edge was inserted and `Ok(false)` if it was
    /// already present. When an error is returned, the cascade is unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`CascadeCycleError`] if `from == to`, or if `from` is already
    /// reachable from `to` (the new edge would close a cycle).
    pub fn add_cascade(&mut self, from: Channel, to: Channel) -> Result<bool, CascadeCycleError> {
        if from == to {
            return Err(CascadeCycleError { from, to });
        }
        let from_idx = Self::slot(from);
        if self.direct[from_idx].contains(to) {
            return Ok(false);
        }
        if self.is_reachable(to, from) {
            return Err(CascadeCycleError { from, to });
        }
        self.direct[from_idx].insert(to);
        self.extend_transitive(from, to);
        Ok(true)
    }

    /// Removes the direct edge `from -> to`, returning whether it existed.
    ///
    /// The closure is rebuilt from scratch, because a removed edge may or may
    /// not disconnect channels depending on alternative paths.
    pub fn remove_cascade(&mut self, from: Channel, to: Channel) -> bool {
        let idx = Self::slot(from);
        if !self.direct[idx].contains(to) {
            return false;
        }
        self.direct[idx].remove(to);
        self.recompute_transitive();
        true
    }

    /// Returns every channel reachable from `channel` through one or more edges.
    #[inline]
    #[must_use]
    pub fn cascades_from(&self, channel: Channel) -> ChannelSet {
        self.transitive[Self::slot(channel)]
    }

    /// Returns only the channels that `channel` cascades to directly.
    #[inline]
    #[must_use]
    pub fn direct_cascades_from(&self, channel: Channel) -> ChannelSet {
        self.direct[Self::slot(channel)]
    }

    /// Returns `true` if no cascade edges are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.direct.iter().all(|cs| cs.is_empty())
    }

    /// Maps a channel to its slot in the per-channel tables.
    #[inline]
    fn slot(channel: Channel) -> usize {
        channel.index() as usize
    }

    /// Returns `true` if `target` is `start` itself or is reachable from it.
    ///
    /// Answered from the `transitive` table, which is kept in sync with
    /// `direct` on every mutation, so no graph walk is needed.
    fn is_reachable(&self, start: Channel, target: Channel) -> bool {
        start == target || self.transitive[Self::slot(start)].contains(target)
    }

    /// Incrementally updates `transitive` after the edge `from -> to` was added.
    ///
    /// Requires that `to` did not already reach `from`, which `add_cascade`
    /// checks first. Under that condition, `transitive[to]` is unaffected by
    /// the new edge. The only channels that gain reachability are `from` and
    /// those already reaching `from`. Each of them gains `to` plus everything
    /// `to` reaches.
    fn extend_transitive(&mut self, from: Channel, to: Channel) {
        let from_idx = Self::slot(from);
        let mut gained = self.transitive[Self::slot(to)];
        gained.insert(to);
        for (idx, reach) in self.transitive.iter_mut().enumerate() {
            if idx == from_idx || reach.contains(from) {
                *reach |= gained;
            }
        }
    }

    /// Rebuilds every `transitive` slot from `direct` with a frontier-based walk.
    fn recompute_transitive(&mut self) {
        // Disjoint field borrows: read `direct` while writing `transitive`.
        let direct = &self.direct;
        for (reach, seeds) in self.transitive.iter_mut().zip(direct.iter()) {
            let mut reachable = ChannelSet::EMPTY;
            let mut frontier = *seeds;
            while !frontier.is_empty() {
                reachable |= frontier;
                let mut next_frontier = ChannelSet::EMPTY;
                for ch in frontier {
                    // Only enqueue channels not already known to be reachable.
                    next_frontier |= direct[Self::slot(ch)] & !reachable;
                }
                frontier = next_frontier;
            }
            *reach = reachable;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const LAYOUT: Channel = Channel::new(0);
    const PAINT: Channel = Channel::new(1);
    const COMPOSITE: Channel = Channel::new(2);
    const A11Y: Channel = Channel::new(3);

    #[test]
    fn new_cascade_is_empty() {
        let cascade = ChannelCascade::new();
        assert!(cascade.is_empty());
        assert!(cascade.cascades_from(LAYOUT).is_empty());
        assert!(cascade.direct_cascades_from(LAYOUT).is_empty());
    }

    #[test]
    fn add_single_cascade() {
        let mut cascade = ChannelCascade::new();
        let added = cascade.add_cascade(LAYOUT, PAINT).unwrap();
        assert!(added);
        assert!(!cascade.is_empty());
        assert!(cascade.direct_cascades_from(LAYOUT).contains(PAINT));
        assert!(cascade.cascades_from(LAYOUT).contains(PAINT));
        assert!(!cascade.cascades_from(PAINT).contains(LAYOUT));
    }

    #[test]
    fn add_duplicate_returns_false() {
        let mut cascade = ChannelCascade::new();
        assert!(cascade.add_cascade(LAYOUT, PAINT).unwrap());
        assert!(!cascade.add_cascade(LAYOUT, PAINT).unwrap());
    }

    #[test]
    fn from_edges_builds_static_cascade() {
        let cascade =
            ChannelCascade::from_edges([(LAYOUT, PAINT), (PAINT, COMPOSITE), (LAYOUT, PAINT)])
                .unwrap();
        let targets = cascade.cascades_from(LAYOUT);
        assert!(targets.contains(PAINT));
        assert!(targets.contains(COMPOSITE));
    }

    #[test]
    fn from_edges_rejects_cycles() {
        let err = ChannelCascade::from_edges([(LAYOUT, PAINT), (PAINT, LAYOUT)]).unwrap_err();
        assert_eq!(err.from, PAINT);
        assert_eq!(err.to, LAYOUT);
    }

    #[test]
    fn transitive_closure() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        let direct = cascade.direct_cascades_from(LAYOUT);
        assert!(direct.contains(PAINT));
        assert!(!direct.contains(COMPOSITE));
        let transitive = cascade.cascades_from(LAYOUT);
        assert!(transitive.contains(PAINT));
        assert!(transitive.contains(COMPOSITE));
        let paint_trans = cascade.cascades_from(PAINT);
        assert!(paint_trans.contains(COMPOSITE));
        assert!(!paint_trans.contains(LAYOUT));
    }

    #[test]
    fn reverse_insertion_order_propagates_upstream() {
        // Downstream edges first: adding LAYOUT -> PAINT last must give LAYOUT
        // everything PAINT already reaches.
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        cascade.add_cascade(COMPOSITE, A11Y).unwrap();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        let targets = cascade.cascades_from(LAYOUT);
        assert!(targets.contains(PAINT));
        assert!(targets.contains(COMPOSITE));
        assert!(targets.contains(A11Y));
        assert_eq!(targets.len(), 3);
        assert!(cascade.cascades_from(COMPOSITE).contains(A11Y));
        assert!(!cascade.cascades_from(COMPOSITE).contains(PAINT));
    }

    #[test]
    fn middle_edge_links_existing_chains() {
        // LAYOUT -> PAINT and COMPOSITE -> A11Y exist; PAINT -> COMPOSITE joins them.
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(COMPOSITE, A11Y).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        assert!(cascade.cascades_from(LAYOUT).contains(A11Y));
        assert!(cascade.cascades_from(PAINT).contains(A11Y));
        assert_eq!(cascade.cascades_from(LAYOUT).len(), 3);
        assert!(cascade.add_cascade(A11Y, LAYOUT).is_err());
    }

    #[test]
    fn self_cascade_is_cycle() {
        let mut cascade = ChannelCascade::new();
        let err = cascade.add_cascade(LAYOUT, LAYOUT).unwrap_err();
        assert_eq!(err.from, LAYOUT);
        assert_eq!(err.to, LAYOUT);
    }

    #[test]
    fn direct_cycle_detected() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        let err = cascade.add_cascade(PAINT, LAYOUT).unwrap_err();
        assert_eq!(err.from, PAINT);
        assert_eq!(err.to, LAYOUT);
    }

    #[test]
    fn transitive_cycle_detected() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        let err = cascade.add_cascade(COMPOSITE, LAYOUT).unwrap_err();
        assert_eq!(err.from, COMPOSITE);
        assert_eq!(err.to, LAYOUT);
    }

    #[test]
    fn rejected_edge_leaves_cascade_unchanged() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        assert!(cascade.add_cascade(PAINT, LAYOUT).is_err());
        assert!(cascade.direct_cascades_from(PAINT).is_empty());
        assert!(cascade.cascades_from(PAINT).is_empty());
        assert!(cascade.cascades_from(LAYOUT).contains(PAINT));
    }

    #[test]
    fn remove_cascade() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        assert!(cascade.remove_cascade(LAYOUT, PAINT));
        assert!(cascade.cascades_from(LAYOUT).is_empty());
        assert!(cascade.direct_cascades_from(LAYOUT).is_empty());
        assert!(cascade.cascades_from(PAINT).contains(COMPOSITE));
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let mut cascade = ChannelCascade::new();
        assert!(!cascade.remove_cascade(LAYOUT, PAINT));
    }

    #[test]
    fn remove_allows_previously_cyclic_edge() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        assert!(cascade.add_cascade(COMPOSITE, LAYOUT).is_err());
        cascade.remove_cascade(LAYOUT, PAINT);
        assert!(cascade.add_cascade(COMPOSITE, LAYOUT).is_ok());
    }

    #[test]
    fn diamond_cascade() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(LAYOUT, A11Y).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        cascade.add_cascade(A11Y, COMPOSITE).unwrap();
        let targets = cascade.cascades_from(LAYOUT);
        assert!(targets.contains(PAINT));
        assert!(targets.contains(A11Y));
        assert!(targets.contains(COMPOSITE));
        assert_eq!(targets.len(), 3);
    }

    #[test]
    fn removing_one_diamond_branch_keeps_shared_target() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.add_cascade(LAYOUT, A11Y).unwrap();
        cascade.add_cascade(PAINT, COMPOSITE).unwrap();
        cascade.add_cascade(A11Y, COMPOSITE).unwrap();
        assert!(cascade.remove_cascade(LAYOUT, PAINT));
        // COMPOSITE is still reachable through A11Y.
        let targets = cascade.cascades_from(LAYOUT);
        assert!(!targets.contains(PAINT));
        assert!(targets.contains(A11Y));
        assert!(targets.contains(COMPOSITE));
        assert_eq!(targets.len(), 2);
    }

    #[test]
    fn is_empty_after_remove_all() {
        let mut cascade = ChannelCascade::new();
        cascade.add_cascade(LAYOUT, PAINT).unwrap();
        cascade.remove_cascade(LAYOUT, PAINT);
        assert!(cascade.is_empty());
    }

    #[test]
    fn cascade_cycle_error_display() {
        let err = CascadeCycleError {
            from: LAYOUT,
            to: PAINT,
        };
        let msg = alloc::format!("{err}");
        assert!(msg.contains("cascade"));
        assert!(msg.contains("cycle"));
    }

    #[test]
    fn cascade_cycle_error_debug() {
        let err = CascadeCycleError {
            from: LAYOUT,
            to: PAINT,
        };
        let msg = alloc::format!("{err:?}");
        assert!(msg.starts_with("CascadeCycleError"));
        assert!(msg.contains("from: "));
        assert!(msg.contains("to: "));
    }
}