use std::{
collections::{HashMap, VecDeque, hash_map::Entry},
ops::Range,
};
/// Lightweight, copyable handle identifying a node by a 32-bit number.
///
/// The wrapped value is private; [`NodeId::index`] exposes it as a `usize`
/// so it can be used directly to index slices and vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(u32);

impl NodeId {
    /// Returns the wrapped id converted to `usize`.
    #[inline]
    pub fn index(self) -> usize {
        let NodeId(raw) = self;
        raw as usize
    }
}
/// Describes how a range of input row indices translates into a range of
/// output row indices for a node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IndexMapping {
    /// The output range is exactly the input range.
    Pointwise,
    /// The output range is the input range widened by `half` rows on each
    /// side (saturating at `0` and `usize::MAX`).
    LocalWindow { half: usize },
    /// Any non-empty input range maps to the full range `0..usize::MAX`.
    Scatter,
    /// Any non-empty input range maps to the full range `0..usize::MAX`.
    Reshape,
}

impl IndexMapping {
    /// Maps a range of input rows to the range of output rows it affects.
    ///
    /// * `Pointwise` returns `input` unchanged.
    /// * `LocalWindow { half }` extends both ends by `half`, saturating
    ///   instead of overflowing.
    /// * `Scatter` and `Reshape` return `0..usize::MAX`.
    ///
    /// An empty `input` (no rows at all) is returned unchanged for every
    /// variant: if no input row is covered, no output row is affected, so
    /// widening it into a non-empty window or the full range would
    /// invalidate rows for no reason.
    pub fn map_range(&self, input: Range<usize>) -> Range<usize> {
        if input.is_empty() {
            return input;
        }
        match self {
            IndexMapping::Pointwise => input,
            IndexMapping::LocalWindow { half } => {
                let start = input.start.saturating_sub(*half);
                let end = input.end.saturating_add(*half);
                start..end
            }
            IndexMapping::Scatter | IndexMapping::Reshape => 0..usize::MAX,
        }
    }

    /// Returns `true` for the variants whose non-empty ranges always map to
    /// the full range (`Scatter` and `Reshape`).
    #[inline]
    pub fn is_blocking(&self) -> bool {
        matches!(self, IndexMapping::Scatter | IndexMapping::Reshape)
    }
}
/// Description of a single graph node, as stored by `Graph` and returned by
/// `Graph::node`.
#[derive(Debug, Clone)]
pub struct NodeDescriptor {
/// Handle of this node; `Graph::add_node` sets it to the id it returns.
pub id: NodeId,
/// Label supplied by the caller when the node was added.
pub name: &'static str,
/// Ids of the nodes this node reads from, in the order they were passed to
/// `Graph::add_node`.
pub inputs: Vec<NodeId>,
/// How an input's dirty row range is translated into this node's dirty row
/// range (applied by `Graph::propagate_dirty`).
pub mapping: IndexMapping,
}
/// A graph of nodes connected by input edges, with reverse (consumer) edges
/// maintained alongside.
///
/// Within this impl, nodes are only ever appended via [`Graph::add_node`],
/// and a node may only name already-existing nodes as inputs.
#[derive(Debug, Default)]
pub struct Graph {
    /// Node descriptors; the node with id `i` lives at index `i`.
    nodes: Vec<NodeDescriptor>,
    /// `consumers[i]` lists the nodes that name node `i` as an input.
    /// Always the same length as `nodes`.
    consumers: Vec<Vec<NodeId>>,
}

impl Graph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a node and returns its id.
    ///
    /// `inputs` must all refer to nodes that already exist in this graph.
    /// The new node is registered as a consumer of each of them (once per
    /// occurrence in `inputs`).
    ///
    /// # Panics
    ///
    /// * If any entry of `inputs` does not exist in the graph.
    /// * If the graph already holds more nodes than a `u32` id can address.
    ///
    /// All checks run before any state is modified, so a panic leaves the
    /// graph exactly as it was.
    pub fn add_node(
        &mut self,
        name: &'static str,
        inputs: &[NodeId],
        mapping: IndexMapping,
    ) -> NodeId {
        let node_count = self.nodes.len();

        // Validate every input up front. Pushing consumer edges while still
        // validating would leave dangling edges to a node that never gets
        // created if a later input turns out to be missing.
        for &inp in inputs {
            assert!(
                inp.index() < node_count,
                "Graph::add_node: input {:?} does not exist (graph has {} nodes)",
                inp,
                node_count,
            );
        }

        // A plain `as u32` would silently wrap and produce colliding ids.
        assert!(
            node_count <= u32::MAX as usize,
            "Graph::add_node: node count {} does not fit in a u32 node id",
            node_count,
        );
        let id = NodeId(node_count as u32);

        for &inp in inputs {
            self.consumers[inp.index()].push(id);
        }
        self.consumers.push(Vec::new());
        self.nodes.push(NodeDescriptor { id, name, inputs: inputs.to_vec(), mapping });
        id
    }

    /// Number of nodes in the graph.
    #[inline]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the graph contains no nodes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the descriptor of node `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this graph.
    pub fn node(&self, id: NodeId) -> &NodeDescriptor {
        &self.nodes[id.index()]
    }

    /// Returns the nodes that list `id` among their inputs, in the order
    /// they were added.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this graph.
    pub fn consumers_of(&self, id: NodeId) -> &[NodeId] {
        &self.consumers[id.index()]
    }

    /// Returns all nodes ordered so that every node appears after all of its
    /// inputs.
    ///
    /// Nodes without inputs are emitted first; a node is emitted once all of
    /// its input edges have been accounted for.
    ///
    /// # Panics
    ///
    /// Panics if not every node could be ordered (reported as a cycle).
    pub fn topological_order(&self) -> Vec<NodeId> {
        let n = self.nodes.len();
        let mut in_degree: Vec<usize> = self.nodes.iter().map(|d| d.inputs.len()).collect();

        // Seed with every node that has no inputs. `d.id` is the id that
        // `add_node` stored for that position, so no index cast is needed.
        let mut queue: VecDeque<NodeId> = self
            .nodes
            .iter()
            .filter(|d| d.inputs.is_empty())
            .map(|d| d.id)
            .collect();

        let mut order = Vec::with_capacity(n);
        while let Some(id) = queue.pop_front() {
            order.push(id);
            for &consumer_id in &self.consumers[id.index()] {
                in_degree[consumer_id.index()] -= 1;
                if in_degree[consumer_id.index()] == 0 {
                    queue.push_back(consumer_id);
                }
            }
        }
        assert_eq!(
            order.len(),
            n,
            "Graph::topological_order: cycle detected ({} / {} nodes processed)",
            order.len(),
            n,
        );
        order
    }

    /// Computes which rows of which nodes are affected when `row_range` of
    /// `source` changes.
    ///
    /// The result maps `source` to `row_range` and every node reachable from
    /// it through consumer edges to a single covering range. Each consumer's
    /// range is obtained by passing its input's range through the consumer's
    /// [`IndexMapping::map_range`]; when a node is reached more than once,
    /// the ranges are merged into the smallest single range spanning all of
    /// them. Nodes not reachable from `source` are absent from the map.
    ///
    /// # Panics
    ///
    /// Panics if `source` is out of range for this graph.
    pub fn propagate_dirty(
        &self,
        source: NodeId,
        row_range: Range<usize>,
    ) -> HashMap<NodeId, Range<usize>> {
        let mut dirty: HashMap<NodeId, Range<usize>> = HashMap::new();
        let mut queue: VecDeque<NodeId> = VecDeque::new();
        dirty.insert(source, row_range);
        queue.push_back(source);

        while let Some(node_id) = queue.pop_front() {
            // Clone so `dirty` can be mutated while this node's range is pushed on.
            let range = dirty[&node_id].clone();
            for &consumer_id in &self.consumers[node_id.index()] {
                let consumer = &self.nodes[consumer_id.index()];
                let mapped = consumer.mapping.map_range(range.clone());
                match dirty.entry(consumer_id) {
                    Entry::Vacant(v) => {
                        v.insert(mapped);
                        queue.push_back(consumer_id);
                    }
                    Entry::Occupied(mut o) => {
                        let existing = o.get_mut();
                        let new_start = existing.start.min(mapped.start);
                        let new_end = existing.end.max(mapped.end);
                        // Only revisit the consumer if its range actually
                        // grew; otherwise its own consumers are up to date.
                        if new_start < existing.start || new_end > existing.end {
                            *existing = new_start..new_end;
                            queue.push_back(consumer_id);
                        }
                    }
                }
            }
        }
        dirty
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every node in `graph` is positioned after all of its
    /// inputs in `order`. Panics if `order` is missing a node of the graph.
    fn is_valid_topo_order(order: &[NodeId], graph: &Graph) -> bool {
        let mut rank: HashMap<NodeId, usize> = HashMap::new();
        for (position, &id) in order.iter().enumerate() {
            rank.insert(id, position);
        }
        for desc in &graph.nodes {
            for inp in &desc.inputs {
                if rank[inp] >= rank[&desc.id] {
                    return false;
                }
            }
        }
        true
    }

    /// Shorthand for adding a `Pointwise` node.
    fn pointwise(g: &mut Graph, name: &'static str, inputs: &[NodeId]) -> NodeId {
        g.add_node(name, inputs, IndexMapping::Pointwise)
    }

    // ---- IndexMapping ----------------------------------------------------

    #[test]
    fn mapping_pointwise_identity() {
        let mapped = IndexMapping::Pointwise.map_range(100..200);
        assert_eq!(mapped, 100..200);
    }

    #[test]
    fn mapping_local_window_expands_range() {
        let window = IndexMapping::LocalWindow { half: 10 };
        let mapped = window.map_range(100..200);
        assert_eq!(mapped, 90..210);
    }

    #[test]
    fn mapping_local_window_saturates_at_zero() {
        let window = IndexMapping::LocalWindow { half: 50 };
        let mapped = window.map_range(10..20);
        assert_eq!(mapped, 0..70);
    }

    #[test]
    fn mapping_scatter_is_full_range() {
        let mapped = IndexMapping::Scatter.map_range(100..200);
        assert_eq!(mapped, 0..usize::MAX);
    }

    #[test]
    fn mapping_reshape_is_full_range() {
        let mapped = IndexMapping::Reshape.map_range(100..200);
        assert_eq!(mapped, 0..usize::MAX);
    }

    #[test]
    fn mapping_is_blocking_flags() {
        let non_blocking = vec![IndexMapping::Pointwise, IndexMapping::LocalWindow { half: 1 }];
        for mapping in &non_blocking {
            assert!(!mapping.is_blocking());
        }
        let blocking = vec![IndexMapping::Scatter, IndexMapping::Reshape];
        for mapping in &blocking {
            assert!(mapping.is_blocking());
        }
    }

    // ---- Graph construction ----------------------------------------------

    #[test]
    fn add_and_retrieve_nodes() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);

        assert_eq!(graph.node(a).name, "a");
        let b_desc = graph.node(b);
        assert_eq!(b_desc.name, "b");
        assert_eq!(b_desc.inputs.as_slice(), &[a][..]);
        assert_eq!(graph.consumers_of(a), &[b][..]);
    }

    #[test]
    #[should_panic(expected = "does not exist")]
    fn add_node_with_missing_input_panics() {
        let mut graph = Graph::new();
        let phantom = NodeId(99);
        pointwise(&mut graph, "x", &[phantom]);
    }

    // ---- Graph::topological_order ----------------------------------------

    #[test]
    fn topo_order_single_node() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        assert_eq!(graph.topological_order(), vec![a]);
    }

    #[test]
    fn topo_order_linear_chain() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let c = pointwise(&mut graph, "c", &[b]);
        let d = pointwise(&mut graph, "d", &[c]);

        let order = graph.topological_order();
        assert!(is_valid_topo_order(&order, &graph));
        assert_eq!(order.len(), 4);
        assert_eq!(order, vec![a, b, c, d]);
    }

    #[test]
    fn topo_order_diamond() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let c = pointwise(&mut graph, "c", &[a]);
        let d = pointwise(&mut graph, "d", &[b, c]);

        let order = graph.topological_order();
        assert!(is_valid_topo_order(&order, &graph));
        assert_eq!(order.len(), 4);
        assert_eq!(order[0], a);
        assert_eq!(order[3], d);
    }

    // ---- Graph::propagate_dirty ------------------------------------------

    #[test]
    fn propagate_dirty_source_only() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);

        let dirty = graph.propagate_dirty(a, 100..200);
        assert_eq!(dirty.len(), 1);
        assert_eq!(dirty[&a], 100..200);
    }

    #[test]
    fn propagate_dirty_simple_pointwise_chain() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let c = pointwise(&mut graph, "c", &[b]);

        let dirty = graph.propagate_dirty(a, 50..150);
        for id in vec![a, b, c] {
            assert_eq!(dirty[&id], 50..150);
        }
    }

    #[test]
    fn propagate_dirty_local_window_expands() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = graph.add_node("b", &[a], IndexMapping::LocalWindow { half: 10 });

        let dirty = graph.propagate_dirty(a, 100..200);
        assert_eq!(dirty[&b], 90..210);
    }

    #[test]
    fn propagate_dirty_blocking_node_full_output() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = graph.add_node("b", &[a], IndexMapping::Scatter);

        let dirty = graph.propagate_dirty(a, 100..200);
        assert_eq!(dirty[&b], 0..usize::MAX);
    }

    #[test]
    fn propagate_dirty_five_node_mixed_graph() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let c = graph.add_node("c", &[b], IndexMapping::LocalWindow { half: 10 });
        let d = graph.add_node("d", &[b], IndexMapping::Scatter);
        let e = pointwise(&mut graph, "e", &[c]);

        let dirty = graph.propagate_dirty(a, 100..200);
        let expected = vec![
            (a, 100..200),
            (b, 100..200),
            (c, 90..210),
            (d, 0..usize::MAX),
            (e, 90..210),
        ];
        for (id, range) in expected {
            assert_eq!(dirty[&id], range);
        }
        assert_eq!(dirty.len(), 5);
    }

    #[test]
    fn propagate_dirty_diamond_merges_ranges() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let c = graph.add_node("c", &[a], IndexMapping::LocalWindow { half: 10 });
        let d = pointwise(&mut graph, "d", &[b, c]);

        let dirty = graph.propagate_dirty(a, 100..200);
        let expected = vec![(b, 100..200), (c, 90..210), (d, 90..210)];
        for (id, range) in expected {
            assert_eq!(dirty[&id], range);
        }
    }

    #[test]
    fn propagate_dirty_does_not_include_unrelated_nodes() {
        let mut graph = Graph::new();
        let a = pointwise(&mut graph, "a", &[]);
        let b = pointwise(&mut graph, "b", &[a]);
        let _c = pointwise(&mut graph, "c", &[b]);
        let d = pointwise(&mut graph, "d", &[]);

        let dirty = graph.propagate_dirty(d, 0..10);
        assert_eq!(dirty.len(), 1);
        assert!(dirty.contains_key(&d));
    }
}