// aether_core/graph.rs
1//! Directed Acyclic Graph (DAG) for DSP routing.
2//!
3//! The graph owns the node arena and buffer pool.
4//! Topological sort produces a flat execution order — no recursion in the RT path.
5
6use crate::{
7    arena::{Arena, NodeId},
8    buffer_pool::BufferPool,
9    node::{DspNode, NodeRecord},
10    MAX_NODES,
11};
12use std::collections::HashMap;
13
14/// The DSP graph. Lives on the RT thread after initial construction.
/// The DSP graph. Lives on the RT thread after initial construction.
///
/// Owns the node arena and the buffer pool. Structural mutations
/// (add/remove/connect/disconnect) rebuild the flat execution order so the
/// real-time process path never needs recursion.
pub struct DspGraph {
    /// Arena holding every node's record (processor + its output buffer).
    pub arena: Arena<NodeRecord>,
    /// Pool of output buffers; one is acquired per node on insertion and
    /// released on removal.
    pub buffers: BufferPool,
    /// Topologically sorted execution order. Rebuilt on structural mutations.
    pub execution_order: Vec<NodeId>,
    /// BFS wave levels: each inner Vec contains nodes that can execute in parallel.
    /// Level[i] nodes all depend only on nodes in levels 0..i.
    pub levels: Vec<Vec<NodeId>>,
    /// The node whose output buffer is sent to the DAC.
    pub output_node: Option<NodeId>,
    /// Adjacency list: node index → list of (dst_node, slot) it feeds into.
    forward_edges: HashMap<u32, Vec<(NodeId, usize)>>,
    /// Maps slot index → full NodeId (for topo sort without generation scanning).
    index_to_id: HashMap<u32, NodeId>,
}
30
31impl DspGraph {
32    pub fn new() -> Self {
33        Self {
34            arena: Arena::with_capacity(MAX_NODES),
35            buffers: BufferPool::default(),
36            execution_order: Vec::with_capacity(MAX_NODES),
37            levels: Vec::with_capacity(MAX_NODES),
38            output_node: None,
39            forward_edges: HashMap::new(),
40            index_to_id: HashMap::new(),
41        }
42    }
43
44    /// Add a node to the graph. Returns its NodeId.
45    pub fn add_node(&mut self, processor: Box<dyn DspNode>) -> Option<NodeId> {
46        let buf = self.buffers.acquire()?;
47        let record = NodeRecord::new(processor, buf);
48        let id = self.arena.insert(record)?;
49        self.forward_edges.insert(id.index, Vec::new());
50        self.index_to_id.insert(id.index, id);
51        self.rebuild_execution_order();
52        Some(id)
53    }
54
55    /// Remove a node, releasing its buffer.
56    pub fn remove_node(&mut self, id: NodeId) -> bool {
57        if let Some(record) = self.arena.remove(id) {
58            self.buffers.release(record.output_buffer);
59            self.forward_edges.remove(&id.index);
60            self.index_to_id.remove(&id.index);
61            for edges in self.forward_edges.values_mut() {
62                edges.retain(|(dst, _)| dst.index != id.index);
63            }
64            self.rebuild_execution_order();
65            true
66        } else {
67            false
68        }
69    }
70
71    /// Connect src output → dst input[slot].
72    pub fn connect(&mut self, src: NodeId, dst: NodeId, slot: usize) -> bool {
73        if self.arena.get(src).is_none() || self.arena.get(dst).is_none() {
74            return false;
75        }
76        // Record forward edge for topo sort.
77        if let Some(edges) = self.forward_edges.get_mut(&src.index) {
78            edges.push((dst, slot));
79        }
80        // Record backward reference in dst node.
81        if let Some(record) = self.arena.get_mut(dst) {
82            record.inputs[slot] = Some(src);
83        }
84        self.rebuild_execution_order();
85        true
86    }
87
88    /// Disconnect dst input[slot].
89    pub fn disconnect(&mut self, dst: NodeId, slot: usize) -> bool {
90        let src_id = self.arena.get(dst).and_then(|r| r.inputs[slot]);
91        if let Some(src) = src_id {
92            if let Some(edges) = self.forward_edges.get_mut(&src.index) {
93                edges.retain(|(d, s)| !(d.index == dst.index && *s == slot));
94            }
95        }
96        if let Some(record) = self.arena.get_mut(dst) {
97            record.inputs[slot] = None;
98            self.rebuild_execution_order();
99            true
100        } else {
101            false
102        }
103    }
104
105    /// Kahn's algorithm topological sort. O(V+E), bounded by MAX_NODES.
106    fn rebuild_execution_order(&mut self) {
107        self.execution_order.clear();
108        self.levels.clear();
109
110        // Compute in-degrees from forward edges.
111        let mut in_degree: HashMap<u32, usize> = self.index_to_id.keys().map(|&k| (k, 0)).collect();
112        for edges in self.forward_edges.values() {
113            for (dst, _) in edges {
114                *in_degree.entry(dst.index).or_insert(0) += 1;
115            }
116        }
117
118        // Seed the first wave: all nodes with in-degree 0.
119        let mut current_wave: Vec<u32> = in_degree
120            .iter()
121            .filter(|(_, &deg)| deg == 0)
122            .map(|(&idx, _)| idx)
123            .collect();
124
125        while !current_wave.is_empty() {
126            let mut level_ids: Vec<NodeId> = Vec::with_capacity(current_wave.len());
127            let mut next_wave: Vec<u32> = Vec::new();
128
129            for idx in &current_wave {
130                if let Some(&id) = self.index_to_id.get(idx) {
131                    level_ids.push(id);
132                    self.execution_order.push(id);
133                }
134                if let Some(edges) = self.forward_edges.get(idx) {
135                    for (dst, _) in edges.clone() {
136                        let deg = in_degree.entry(dst.index).or_insert(0);
137                        if *deg > 0 {
138                            *deg -= 1;
139                            if *deg == 0 {
140                                next_wave.push(dst.index);
141                            }
142                        }
143                    }
144                }
145            }
146
147            self.levels.push(level_ids);
148            current_wave = next_wave;
149        }
150    }
151
    /// Designate `id` as the node whose output buffer is sent to the DAC.
    ///
    /// NOTE(review): `id` is stored without checking it against the arena —
    /// a stale or never-inserted id is accepted as-is. Confirm callers
    /// guarantee liveness.
    pub fn set_output_node(&mut self, id: NodeId) {
        self.output_node = Some(id);
    }
155}
156
157impl Default for DspGraph {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163#[cfg(test)]
164mod tests {
165    use super::*;
166    use crate::{node::DspNode, param::ParamBlock, BUFFER_SIZE, MAX_INPUTS};
167    use proptest::prelude::*;
168
    /// Minimal test node for graph topology testing.
    struct TestNode;

    impl DspNode for TestNode {
        // Writes silence into the output buffer and ignores all inputs and
        // params — topology tests only exercise graph structure, not audio.
        fn process(
            &mut self,
            _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            output.fill(0.0);
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }
187
188    // Property 2
189    proptest! {
190        /// **Validates: Requirements 1.2, 1.9**
191        ///
192        /// Property 2: Topological level assignments satisfy the dependency ordering invariant.
193        ///
194        /// For any DAG after `rebuild_execution_order`, every node at level L SHALL have all
195        /// its input-connected nodes at levels strictly less than L. Equivalently, no node at
196        /// level L depends on any other node at level L.
197        #[test]
198        fn prop_topological_level_ordering_invariant(
199            num_nodes in 1usize..=20,
200            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50)
201        ) {
202            let mut graph = DspGraph::new();
203            let mut node_ids = Vec::new();
204
205            // Add nodes
206            for _ in 0..num_nodes {
207                if let Some(id) = graph.add_node(Box::new(TestNode)) {
208                    node_ids.push(id);
209                }
210            }
211
212            // Add edges, filtering to maintain DAG invariant (src < dst to prevent cycles)
213            for &(src_idx, dst_idx, slot) in &edges {
214                if src_idx < num_nodes && dst_idx < num_nodes && src_idx < dst_idx {
215                    let src = node_ids[src_idx];
216                    let dst = node_ids[dst_idx];
217                    graph.connect(src, dst, slot);
218                }
219            }
220
221            // Build a map from NodeId to level index
222            let mut node_to_level: HashMap<u32, usize> = HashMap::new();
223            for (level_idx, level_nodes) in graph.levels.iter().enumerate() {
224                for &node_id in level_nodes {
225                    node_to_level.insert(node_id.index, level_idx);
226                }
227            }
228
229            // Verify the invariant: for every edge (src → dst), level[src] < level[dst]
230            for &(src_idx, dst_idx, slot) in &edges {
231                if src_idx < num_nodes && dst_idx < num_nodes && src_idx < dst_idx {
232                    let src = node_ids[src_idx];
233                    let dst = node_ids[dst_idx];
234
235                    // Check if the edge was actually added (connect may fail if slot already occupied)
236                    if let Some(record) = graph.arena.get(dst) {
237                        if record.inputs[slot] == Some(src) {
238                            // Edge exists, verify level ordering
239                            let src_level = node_to_level.get(&src.index).copied();
240                            let dst_level = node_to_level.get(&dst.index).copied();
241
242                            if let (Some(src_lvl), Some(dst_lvl)) = (src_level, dst_level) {
243                                prop_assert!(
244                                    src_lvl < dst_lvl,
245                                    "Level ordering violated: node {} at level {} → node {} at level {}",
246                                    src.index, src_lvl, dst.index, dst_lvl
247                                );
248                            }
249                        }
250                    }
251                }
252            }
253        }
254    }
255}