1use ringbuf::traits::Consumer;
17
18use crate::{
19 arena::NodeId,
20 command::Command,
21 graph::DspGraph,
22 node::DspNode,
23 param::ParamBlock,
24 BUFFER_SIZE, MAX_COMMANDS_PER_TICK, MAX_INPUTS,
25};
26
27struct NodeTask {
37 output_buf_ptr: *mut [f32; BUFFER_SIZE],
38 params_ptr: *mut ParamBlock,
39 processor_ptr: *mut dyn DspNode,
40 inputs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS],
41}
42
43unsafe impl Send for NodeTask {}
48unsafe impl Sync for NodeTask {}
49
50pub struct Scheduler {
54 pub graph: DspGraph,
55 pub sample_rate: f32,
56 pub muted: bool,
57}
58
59impl Scheduler {
60 pub fn new(sample_rate: f32) -> Self {
61 Self {
62 graph: DspGraph::new(),
63 sample_rate,
64 muted: false,
65 }
66 }
67
68 pub fn process_block<C>(&mut self, cmd_consumer: &mut C, output: &mut [f32])
70 where
71 C: Consumer<Item = Command>,
72 {
73 let mut processed = 0;
74 while processed < MAX_COMMANDS_PER_TICK {
75 match cmd_consumer.try_pop() {
76 Some(cmd) => { self.apply_command(cmd); processed += 1; }
77 None => break,
78 }
79 }
80 self.process_graph(output);
81 }
82
83 pub fn process_block_simple(&mut self, output: &mut [f32]) {
87 self.process_graph(output);
88 }
89
90 fn process_graph(&mut self, output: &mut [f32]) {
91 let sr = self.sample_rate;
92 let level_count = self.graph.levels.len();
93
94 for level_idx in 0..level_count {
95 let level_len = self.graph.levels[level_idx].len();
96
97 if level_len == 0 {
98 continue;
99 } else if level_len == 1 {
100 let node_id = self.graph.levels[level_idx][0];
102 self.process_node(node_id, sr);
103 } else {
104 let mut tasks: Vec<NodeTask> = Vec::with_capacity(level_len);
112
113 for i in 0..level_len {
114 let node_id = self.graph.levels[level_idx][i];
115 let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] =
116 [None; MAX_INPUTS];
117
118 if let Some(record) = self.graph.arena.get(node_id) {
119 for (slot, maybe_src) in record.inputs.iter().enumerate() {
120 if let Some(src_id) = maybe_src {
121 if let Some(src_record) = self.graph.arena.get(*src_id) {
122 input_ptrs[slot] = Some(
123 self.graph.buffers.get(src_record.output_buffer)
124 as *const [f32; BUFFER_SIZE],
125 );
126 }
127 }
128 }
129 let record_mut = self.graph.arena.get_mut(node_id).unwrap();
130 let output_buf_ptr = self.graph.buffers.get_mut(record_mut.output_buffer)
131 as *mut [f32; BUFFER_SIZE];
132 let params_ptr = &mut record_mut.params as *mut ParamBlock;
133 let processor_ptr = &mut *record_mut.processor as *mut dyn DspNode;
134
135 tasks.push(NodeTask {
136 output_buf_ptr,
137 params_ptr,
138 processor_ptr,
139 inputs: input_ptrs,
140 });
141 }
142 }
143
144 rayon::scope(|s| {
148 for task in tasks.iter_mut() {
149 let ptr = task as *mut NodeTask as usize;
152 s.spawn(move |_| {
153 let t: &mut NodeTask = unsafe { &mut *(ptr as *mut NodeTask) };
155 let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
156 t.inputs.map(|p| p.map(|raw| unsafe { &*raw }));
157 unsafe {
158 (*t.processor_ptr).process(
159 &inputs,
160 &mut *t.output_buf_ptr,
161 &mut *t.params_ptr,
162 sr,
163 );
164 }
165 });
166 }
167 });
168 }
169 }
170
171 if self.muted {
173 output.fill(0.0);
174 return;
175 }
176 if let Some(out_id) = self.graph.output_node {
177 if let Some(record) = self.graph.arena.get(out_id) {
178 let buf = self.graph.buffers.get(record.output_buffer);
179 let frames = output.len() / 2;
180 for i in 0..frames.min(BUFFER_SIZE) {
181 output[i * 2] = buf[i];
182 output[i * 2 + 1] = buf[i];
183 }
184 }
185 } else {
186 output.fill(0.0);
188 }
189 }
190
191 fn process_node(&mut self, node_id: NodeId, sample_rate: f32) {
193 let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] = [None; MAX_INPUTS];
194
195 if let Some(record) = self.graph.arena.get(node_id) {
196 for (slot, maybe_src) in record.inputs.iter().enumerate() {
197 if let Some(src_id) = maybe_src {
198 if let Some(src_record) = self.graph.arena.get(*src_id) {
199 input_ptrs[slot] = Some(
200 self.graph.buffers.get(src_record.output_buffer)
201 as *const [f32; BUFFER_SIZE],
202 );
203 }
204 }
205 }
206 } else {
207 return;
208 }
209
210 let (output_buf_id, params_ptr, processor_ptr) = {
211 let record = self.graph.arena.get_mut(node_id).unwrap();
212 (
213 record.output_buffer,
214 &mut record.params as *mut ParamBlock,
215 &mut *record.processor as *mut dyn crate::node::DspNode,
216 )
217 };
218
219 let output_buf = self.graph.buffers.get_mut(output_buf_id);
220 let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
221 input_ptrs.map(|p| p.map(|ptr| unsafe { &*ptr }));
222
223 unsafe {
224 (*processor_ptr).process(&inputs, output_buf, &mut *params_ptr, sample_rate);
225 }
226 }
227
228 fn apply_command(&mut self, cmd: Command) {
229 match cmd {
230 Command::AddNode { id } => { let _ = id; }
231 Command::RemoveNode { id } => { self.graph.remove_node(id); }
232 Command::Connect { src, dst, slot } => { self.graph.connect(src, dst, slot); }
233 Command::Disconnect { dst, slot } => { self.graph.disconnect(dst, slot); }
234 Command::UpdateParam { node, param_index, new_param } => {
235 if let Some(record) = self.graph.arena.get_mut(node) {
236 if param_index < record.params.count {
237 record.params.params[param_index] = new_param;
238 }
239 }
240 }
241 Command::SetOutputNode { id } => { self.graph.set_output_node(id); }
242 Command::SetMute { muted } => { self.muted = muted; }
243 Command::ClearGraph => {
244 let ids: Vec<_> = self.graph.execution_order.clone();
245 for id in ids { self.graph.remove_node(id); }
246 self.graph.output_node = None;
247 }
248 }
249 }
250
251 #[cfg(test)]
254 fn process_graph_sequential(&mut self, output: &mut [f32]) {
255 let sr = self.sample_rate;
256
257 for &node_id in &self.graph.execution_order {
259 self.process_node(node_id, sr);
260 }
261
262 if self.muted {
264 output.fill(0.0);
265 return;
266 }
267 if let Some(out_id) = self.graph.output_node {
268 if let Some(record) = self.graph.arena.get(out_id) {
269 let buf = self.graph.buffers.get(record.output_buffer);
270 let frames = output.len() / 2;
271 for i in 0..frames.min(BUFFER_SIZE) {
272 output[i * 2] = buf[i];
273 output[i * 2 + 1] = buf[i];
274 }
275 }
276 } else {
277 output.fill(0.0);
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::DspNode;
    use proptest::prelude::*;

    /// Deterministic test node: emits a gain-scaled ramp and adds every
    /// connected input scaled by the same gain.
    ///
    /// The ramp matters: without a signal source every node in the graph would
    /// render silence and the parallel/sequential comparison below would pass
    /// trivially.
    struct TestNode {
        gain: f32,
    }

    impl TestNode {
        fn new(gain: f32) -> Self {
            Self { gain }
        }
    }

    impl DspNode for TestNode {
        fn process(
            &mut self,
            inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            // Own signal: sample i is gain * (i + 1), so each sample index is
            // distinguishable in the rendered output.
            for (i, sample) in output.iter_mut().enumerate() {
                *sample = self.gain * (i + 1) as f32;
            }
            // Inputs are accumulated in slot order, so the result depends only
            // on the input buffers, not on which thread ran the node.
            for input in inputs.iter().flatten() {
                for (out, &x) in output.iter_mut().zip(input.iter()) {
                    *out += x * self.gain;
                }
            }
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    proptest! {
        /// Level-parallel processing must render exactly what node-by-node
        /// processing in execution order renders for the same graph.
        #[test]
        fn prop_parallel_equiv_sequential(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50),
            seed in any::<u64>(),
        ) {
            let mut scheduler_parallel = Scheduler::new(48000.0);
            let mut scheduler_sequential = Scheduler::new(48000.0);

            let mut node_ids = Vec::new();

            for i in 0..num_nodes {
                let gain = ((seed.wrapping_add(i as u64) % 100) as f32) / 100.0;

                let id1 = scheduler_parallel.graph.add_node(Box::new(TestNode::new(gain)));
                let id2 = scheduler_sequential.graph.add_node(Box::new(TestNode::new(gain)));

                if let (Some(id1), Some(id2)) = (id1, id2) {
                    // Both graphs must hand out identical ids so one id list
                    // can address either of them.
                    prop_assert_eq!(id1.index, id2.index);
                    prop_assert_eq!(id1.generation, id2.generation);
                    node_ids.push(id1);
                }
            }

            for (src_idx, dst_idx, slot) in edges {
                // Only forward edges keep the graph acyclic. Bound by the
                // nodes actually created, which may be fewer than `num_nodes`
                // if `add_node` ever returned `None`.
                if src_idx < dst_idx && dst_idx < node_ids.len() {
                    let src = node_ids[src_idx];
                    let dst = node_ids[dst_idx];

                    scheduler_parallel.graph.connect(src, dst, slot);
                    scheduler_sequential.graph.connect(src, dst, slot);
                }
            }

            if let Some(&output_node) = node_ids.last() {
                scheduler_parallel.graph.set_output_node(output_node);
                scheduler_sequential.graph.set_output_node(output_node);
            }

            let mut output_parallel = vec![0.0f32; BUFFER_SIZE * 2];
            let mut output_sequential = vec![0.0f32; BUFFER_SIZE * 2];

            scheduler_parallel.process_graph(&mut output_parallel);
            scheduler_sequential.process_graph_sequential(&mut output_sequential);

            for (i, (&p, &s)) in output_parallel.iter().zip(output_sequential.iter()).enumerate() {
                prop_assert!(
                    p == s || (p.is_nan() && s.is_nan()),
                    "Output mismatch at sample {}: parallel={}, sequential={}",
                    i, p, s
                );
            }
        }
    }
}