1use crate::{
7 arena::{Arena, NodeId},
8 buffer_pool::BufferPool,
9 node::{DspNode, NodeRecord},
10 MAX_NODES,
11};
12use std::collections::HashMap;
13
14pub struct DspGraph {
16 pub arena: Arena<NodeRecord>,
17 pub buffers: BufferPool,
18 pub execution_order: Vec<NodeId>,
20 pub levels: Vec<Vec<NodeId>>,
23 pub output_node: Option<NodeId>,
25 forward_edges: HashMap<u32, Vec<(NodeId, usize)>>,
27 index_to_id: HashMap<u32, NodeId>,
29}
30
31impl DspGraph {
32 pub fn new() -> Self {
33 Self {
34 arena: Arena::with_capacity(MAX_NODES),
35 buffers: BufferPool::default(),
36 execution_order: Vec::with_capacity(MAX_NODES),
37 levels: Vec::with_capacity(MAX_NODES),
38 output_node: None,
39 forward_edges: HashMap::new(),
40 index_to_id: HashMap::new(),
41 }
42 }
43
44 pub fn add_node(&mut self, processor: Box<dyn DspNode>) -> Option<NodeId> {
46 let buf = self.buffers.acquire()?;
47 let record = NodeRecord::new(processor, buf);
48 let id = self.arena.insert(record)?;
49 self.forward_edges.insert(id.index, Vec::new());
50 self.index_to_id.insert(id.index, id);
51 self.rebuild_execution_order();
52 Some(id)
53 }
54
55 pub fn remove_node(&mut self, id: NodeId) -> bool {
57 if let Some(record) = self.arena.remove(id) {
58 self.buffers.release(record.output_buffer);
59 self.forward_edges.remove(&id.index);
60 self.index_to_id.remove(&id.index);
61 for edges in self.forward_edges.values_mut() {
62 edges.retain(|(dst, _)| dst.index != id.index);
63 }
64 self.rebuild_execution_order();
65 true
66 } else {
67 false
68 }
69 }
70
71 pub fn connect(&mut self, src: NodeId, dst: NodeId, slot: usize) -> bool {
73 if self.arena.get(src).is_none() || self.arena.get(dst).is_none() {
74 return false;
75 }
76 if let Some(edges) = self.forward_edges.get_mut(&src.index) {
78 edges.push((dst, slot));
79 }
80 if let Some(record) = self.arena.get_mut(dst) {
82 record.inputs[slot] = Some(src);
83 }
84 self.rebuild_execution_order();
85 true
86 }
87
88 pub fn disconnect(&mut self, dst: NodeId, slot: usize) -> bool {
90 let src_id = self.arena.get(dst).and_then(|r| r.inputs[slot]);
91 if let Some(src) = src_id {
92 if let Some(edges) = self.forward_edges.get_mut(&src.index) {
93 edges.retain(|(d, s)| !(d.index == dst.index && *s == slot));
94 }
95 }
96 if let Some(record) = self.arena.get_mut(dst) {
97 record.inputs[slot] = None;
98 self.rebuild_execution_order();
99 true
100 } else {
101 false
102 }
103 }
104
105 fn rebuild_execution_order(&mut self) {
107 self.execution_order.clear();
108 self.levels.clear();
109
110 let mut in_degree: HashMap<u32, usize> = self.index_to_id.keys().map(|&k| (k, 0)).collect();
112 for edges in self.forward_edges.values() {
113 for (dst, _) in edges {
114 *in_degree.entry(dst.index).or_insert(0) += 1;
115 }
116 }
117
118 let mut current_wave: Vec<u32> = in_degree
120 .iter()
121 .filter(|(_, °)| deg == 0)
122 .map(|(&idx, _)| idx)
123 .collect();
124
125 while !current_wave.is_empty() {
126 let mut level_ids: Vec<NodeId> = Vec::with_capacity(current_wave.len());
127 let mut next_wave: Vec<u32> = Vec::new();
128
129 for idx in ¤t_wave {
130 if let Some(&id) = self.index_to_id.get(idx) {
131 level_ids.push(id);
132 self.execution_order.push(id);
133 }
134 if let Some(edges) = self.forward_edges.get(idx) {
135 for (dst, _) in edges.clone() {
136 let deg = in_degree.entry(dst.index).or_insert(0);
137 if *deg > 0 {
138 *deg -= 1;
139 if *deg == 0 {
140 next_wave.push(dst.index);
141 }
142 }
143 }
144 }
145 }
146
147 self.levels.push(level_ids);
148 current_wave = next_wave;
149 }
150 }
151
152 pub fn set_output_node(&mut self, id: NodeId) {
153 self.output_node = Some(id);
154 }
155}
156
157impl Default for DspGraph {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{node::DspNode, param::ParamBlock, state::StateBlob, BUFFER_SIZE, MAX_INPUTS};
    use proptest::prelude::*;

    /// Stateless processor that ignores its inputs and writes zeros; it only
    /// exists so the graph has something to schedule.
    struct TestNode;

    impl DspNode for TestNode {
        fn process(
            &mut self,
            _inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            output.fill(0.0);
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    proptest! {
        // For every edge that is still live after all connections were made,
        // the source must sit on a strictly earlier level than the destination.
        #[test]
        fn prop_topological_level_ordering_invariant(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50)
        ) {
            let mut graph = DspGraph::new();
            let mut node_ids = Vec::new();

            for _ in 0..num_nodes {
                if let Some(id) = graph.add_node(Box::new(TestNode)) {
                    node_ids.push(id);
                }
            }

            // Only forward edges (lower index -> higher index) are wired, so
            // the generated graph is acyclic by construction. Bounds are taken
            // from `node_ids` so a failed `add_node` cannot cause an
            // out-of-range index below.
            for &(src_idx, dst_idx, slot) in &edges {
                if src_idx < dst_idx && dst_idx < node_ids.len() {
                    graph.connect(node_ids[src_idx], node_ids[dst_idx], slot);
                }
            }

            let mut node_to_level: HashMap<u32, usize> = HashMap::new();
            for (level_idx, level_nodes) in graph.levels.iter().enumerate() {
                for &node_id in level_nodes {
                    node_to_level.insert(node_id.index, level_idx);
                }
            }

            // `edges` is borrowed in both loops; consuming it in the first one
            // would leave nothing to verify against here.
            for &(src_idx, dst_idx, slot) in &edges {
                if src_idx < dst_idx && dst_idx < node_ids.len() {
                    let src = node_ids[src_idx];
                    let dst = node_ids[dst_idx];

                    if let Some(record) = graph.arena.get(dst) {
                        // A later connection may have taken over this slot;
                        // only edges that are still in place are checked.
                        if record.inputs[slot] == Some(src) {
                            let src_level = node_to_level.get(&src.index).copied();
                            let dst_level = node_to_level.get(&dst.index).copied();

                            if let (Some(src_lvl), Some(dst_lvl)) = (src_level, dst_level) {
                                prop_assert!(
                                    src_lvl < dst_lvl,
                                    "Level ordering violated: node {} at level {} → node {} at level {}",
                                    src.index, src_lvl, dst.index, dst_lvl
                                );
                            }
                        }
                    }
                }
            }
        }
    }
}