1use ringbuf::traits::Consumer;
17
18use crate::{
19 arena::NodeId,
20 command::Command,
21 graph::DspGraph,
22 node::DspNode,
23 param::ParamBlock,
24 BUFFER_SIZE, MAX_COMMANDS_PER_TICK, MAX_INPUTS,
25};
26
27struct NodeTask {
37 output_buf_ptr: *mut [f32; BUFFER_SIZE],
38 params_ptr: *mut ParamBlock,
39 processor_ptr: *mut dyn DspNode,
40 inputs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS],
41}
42
43unsafe impl Send for NodeTask {}
49unsafe impl Sync for NodeTask {}
50
51pub struct Scheduler {
55 pub graph: DspGraph,
56 pub sample_rate: f32,
57 pub muted: bool,
58}
59
60impl Scheduler {
61 pub fn new(sample_rate: f32) -> Self {
62 Self {
63 graph: DspGraph::new(),
64 sample_rate,
65 muted: false,
66 }
67 }
68
69 pub fn process_block<C>(&mut self, cmd_consumer: &mut C, output: &mut [f32])
71 where
72 C: Consumer<Item = Command>,
73 {
74 let mut processed = 0;
75 while processed < MAX_COMMANDS_PER_TICK {
76 match cmd_consumer.try_pop() {
77 Some(cmd) => { self.apply_command(cmd); processed += 1; }
78 None => break,
79 }
80 }
81 self.process_graph(output);
82 }
83
84 pub fn process_block_simple(&mut self, output: &mut [f32]) {
88 self.process_graph(output);
89 }
90
91 fn process_graph(&mut self, output: &mut [f32]) {
92 let sr = self.sample_rate;
93 let level_count = self.graph.levels.len();
94
95 for level_idx in 0..level_count {
96 let level_len = self.graph.levels[level_idx].len();
97
98 if level_len == 0 {
99 continue;
100 } else if level_len == 1 {
101 let node_id = self.graph.levels[level_idx][0];
103 self.process_node(node_id, sr);
104 } else {
105 let mut tasks: Vec<NodeTask> = Vec::with_capacity(level_len);
113
114 for i in 0..level_len {
115 let node_id = self.graph.levels[level_idx][i];
116 let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] =
117 [None; MAX_INPUTS];
118
119 if let Some(record) = self.graph.arena.get(node_id) {
120 for (slot, maybe_src) in record.inputs.iter().enumerate() {
121 if let Some(src_id) = maybe_src {
122 if let Some(src_record) = self.graph.arena.get(*src_id) {
123 input_ptrs[slot] = Some(
124 self.graph.buffers.get(src_record.output_buffer)
125 as *const [f32; BUFFER_SIZE],
126 );
127 }
128 }
129 }
130 let record_mut = self.graph.arena.get_mut(node_id).unwrap();
131 let output_buf_ptr = self.graph.buffers.get_mut(record_mut.output_buffer)
132 as *mut [f32; BUFFER_SIZE];
133 let params_ptr = &mut record_mut.params as *mut ParamBlock;
134 let processor_ptr = &mut *record_mut.processor as *mut dyn DspNode;
135
136 tasks.push(NodeTask {
137 output_buf_ptr,
138 params_ptr,
139 processor_ptr,
140 inputs: input_ptrs,
141 });
142 }
143 }
144
145 rayon::scope(|s| {
149 for task in tasks.iter_mut() {
150 let ptr = task as *mut NodeTask as usize;
153 s.spawn(move |_| {
154 let t: &mut NodeTask = unsafe { &mut *(ptr as *mut NodeTask) };
156 let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
157 t.inputs.map(|p| p.map(|raw| unsafe { &*raw }));
158 unsafe {
159 (*t.processor_ptr).process(
160 &inputs,
161 &mut *t.output_buf_ptr,
162 &mut *t.params_ptr,
163 sr,
164 );
165 }
166 });
167 }
168 });
169 }
170 }
171
172 if self.muted {
174 output.fill(0.0);
175 return;
176 }
177 if let Some(out_id) = self.graph.output_node {
178 if let Some(record) = self.graph.arena.get(out_id) {
179 let buf = self.graph.buffers.get(record.output_buffer);
180 let frames = output.len() / 2;
181 for i in 0..frames.min(BUFFER_SIZE) {
182 output[i * 2] = buf[i];
183 output[i * 2 + 1] = buf[i];
184 }
185 }
186 } else {
187 output.fill(0.0);
189 }
190 }
191
192 fn process_node(&mut self, node_id: NodeId, sample_rate: f32) {
194 let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] = [None; MAX_INPUTS];
195
196 if let Some(record) = self.graph.arena.get(node_id) {
197 for (slot, maybe_src) in record.inputs.iter().enumerate() {
198 if let Some(src_id) = maybe_src {
199 if let Some(src_record) = self.graph.arena.get(*src_id) {
200 input_ptrs[slot] = Some(
201 self.graph.buffers.get(src_record.output_buffer)
202 as *const [f32; BUFFER_SIZE],
203 );
204 }
205 }
206 }
207 } else {
208 return;
209 }
210
211 let (output_buf_id, params_ptr, processor_ptr) = {
212 let record = self.graph.arena.get_mut(node_id).unwrap();
213 (
214 record.output_buffer,
215 &mut record.params as *mut ParamBlock,
216 &mut *record.processor as *mut dyn crate::node::DspNode,
217 )
218 };
219
220 let output_buf = self.graph.buffers.get_mut(output_buf_id);
221 let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
222 input_ptrs.map(|p| p.map(|ptr| unsafe { &*ptr }));
223
224 unsafe {
225 (*processor_ptr).process(&inputs, output_buf, &mut *params_ptr, sample_rate);
226 }
227 }
228
229 fn apply_command(&mut self, cmd: Command) {
230 match cmd {
231 Command::AddNode { id } => { let _ = id; }
232 Command::RemoveNode { id } => { self.graph.remove_node(id); }
233 Command::Connect { src, dst, slot } => { self.graph.connect(src, dst, slot); }
234 Command::Disconnect { dst, slot } => { self.graph.disconnect(dst, slot); }
235 Command::UpdateParam { node, param_index, new_param } => {
236 if let Some(record) = self.graph.arena.get_mut(node) {
237 if param_index < record.params.count {
238 record.params.params[param_index] = new_param;
239 }
240 }
241 }
242 Command::SetOutputNode { id } => { self.graph.set_output_node(id); }
243 Command::SetMute { muted } => { self.muted = muted; }
244 Command::ClearGraph => {
245 let ids: Vec<_> = self.graph.execution_order.clone();
246 for id in ids { self.graph.remove_node(id); }
247 self.graph.output_node = None;
248 }
249 }
250 }
251
252 #[cfg(test)]
255 fn process_graph_sequential(&mut self, output: &mut [f32]) {
256 let sr = self.sample_rate;
257
258 let order: Vec<NodeId> = self.graph.execution_order.clone();
262 for &node_id in &order {
263 self.process_node(node_id, sr);
264 }
265
266 if self.muted {
268 output.fill(0.0);
269 return;
270 }
271 if let Some(out_id) = self.graph.output_node {
272 if let Some(record) = self.graph.arena.get(out_id) {
273 let buf = self.graph.buffers.get(record.output_buffer);
274 let frames = output.len() / 2;
275 for i in 0..frames.min(BUFFER_SIZE) {
276 output[i * 2] = buf[i];
277 output[i * 2 + 1] = buf[i];
278 }
279 }
280 } else {
281 output.fill(0.0);
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::DspNode;
    use proptest::prelude::*;

    /// Minimal node for comparing the parallel and sequential schedulers.
    ///
    /// With at least one connected input it outputs the sum of its inputs times
    /// `gain`. With no connected inputs it acts as a source and emits a deterministic
    /// ramp scaled by `gain`. Without a source, every buffer would stay at `0.0` and
    /// the equivalence property below would pass trivially.
    struct TestNode {
        gain: f32,
    }

    impl TestNode {
        fn new(gain: f32) -> Self {
            Self { gain }
        }
    }

    impl DspNode for TestNode {
        fn process(
            &mut self,
            inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            if inputs.iter().all(Option::is_none) {
                for (i, out) in output.iter_mut().enumerate() {
                    *out = (i as f32 + 1.0) * self.gain;
                }
                return;
            }

            output.fill(0.0);
            // Slots are summed in index order in both schedulers, so results are bit-exact.
            for input in inputs.iter().flatten() {
                for (out, &x) in output.iter_mut().zip(input.iter()) {
                    *out += x * self.gain;
                }
            }
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    proptest! {
        // Level-parallel execution must produce bit-identical output to sequential
        // execution in `execution_order` for random acyclic graphs (edges only go
        // from a lower to a higher node index).
        #[test]
        fn prop_parallel_equiv_sequential(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50),
            seed in any::<u64>(),
        ) {
            let mut scheduler_parallel = Scheduler::new(48000.0);
            let mut scheduler_sequential = Scheduler::new(48000.0);

            let mut node_ids = Vec::new();

            for i in 0..num_nodes {
                let gain = ((seed.wrapping_add(i as u64) % 100) as f32) / 100.0;

                let id1 = scheduler_parallel.graph.add_node(Box::new(TestNode::new(gain)));
                let id2 = scheduler_sequential.graph.add_node(Box::new(TestNode::new(gain)));

                // Both graphs receive identical operations, so they must agree on success.
                prop_assert_eq!(id1.is_some(), id2.is_some());
                if let (Some(id1), Some(id2)) = (id1, id2) {
                    prop_assert_eq!(id1.index, id2.index);
                    prop_assert_eq!(id1.generation, id2.generation);
                    node_ids.push(id1);
                }
            }

            // Index by the nodes actually added, in case `add_node` returned `None`.
            let added = node_ids.len();
            for (src_idx, dst_idx, slot) in edges {
                if src_idx < added && dst_idx < added && src_idx < dst_idx {
                    let src = node_ids[src_idx];
                    let dst = node_ids[dst_idx];

                    scheduler_parallel.graph.connect(src, dst, slot);
                    scheduler_sequential.graph.connect(src, dst, slot);
                }
            }

            if let Some(&output_node) = node_ids.last() {
                scheduler_parallel.graph.set_output_node(output_node);
                scheduler_sequential.graph.set_output_node(output_node);
            }

            let mut output_parallel = vec![0.0f32; BUFFER_SIZE * 2];
            let mut output_sequential = vec![0.0f32; BUFFER_SIZE * 2];

            scheduler_parallel.process_graph(&mut output_parallel);
            scheduler_sequential.process_graph_sequential(&mut output_sequential);

            for (i, (&p, &s)) in output_parallel.iter().zip(output_sequential.iter()).enumerate() {
                prop_assert!(
                    p == s || (p.is_nan() && s.is_nan()),
                    "Output mismatch at sample {}: parallel={}, sequential={}",
                    i, p, s
                );
            }
        }
    }
}