Skip to main content

aether_core/
scheduler.rs

1//! Real-time audio scheduler.
2//!
3//! This is the hot path. It:
4//!   1. Drains bounded commands from the SPSC ring.
5//!   2. Executes the topologically sorted node list level by level.
6//!      Nodes within the same BFS level are independent and run in parallel
7//!      via Rayon's work-stealing thread pool.
8//!   3. Copies the output node's buffer to the DAC output.
9//!
10//! HARD RT RULES enforced here:
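//!   - No allocation on the single-node path. NOTE: the multi-node (parallel)
//!     path currently allocates a `Vec<NodeTask>` for each such level on every
//!     block, `rayon::scope` heap-allocates each spawned job, and `ClearGraph`
//!     clones the execution order. TODO: move these onto pre-allocated storage.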
12//!   - No locks
13//!   - No I/O
14//!   - No unbounded loops
15
16use ringbuf::traits::Consumer;
17
18use crate::{
19    arena::NodeId,
20    command::Command,
21    graph::DspGraph,
22    node::DspNode,
23    param::ParamBlock,
24    BUFFER_SIZE, MAX_COMMANDS_PER_TICK, MAX_INPUTS,
25};
26
27// ── Parallel dispatch helpers ─────────────────────────────────────────────────
28
29/// Per-node data bundle collected before parallel dispatch.
30///
31/// SAFETY INVARIANT: Within a single BFS level, every node writes to a distinct
32/// `BufferId` (guaranteed by the DAG structure — no two nodes in the same level
33/// share an output buffer). The `BufferPool` stores buffers in a flat `Vec`, so
34/// tasks writing to different `BufferId`s write to non-overlapping index ranges.
35/// This makes the concurrent writes safe despite using raw pointers.
36struct NodeTask {
37    output_buf_ptr: *mut [f32; BUFFER_SIZE],
38    params_ptr: *mut ParamBlock,
39    processor_ptr: *mut dyn DspNode,
40    inputs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS],
41}
42
43/// SAFETY: Within a BFS level each task accesses disjoint memory:
44/// - distinct output buffer (different BufferId → different Vec index range)
45/// - distinct processor and params (each belongs to exactly one NodeRecord)
46/// No two tasks in the same level share any pointed-to memory.
47unsafe impl Send for NodeTask {}
48unsafe impl Sync for NodeTask {}
49
50// ── Scheduler ─────────────────────────────────────────────────────────────────
51
52/// The RT scheduler. Owns the graph and processes audio callbacks.
53pub struct Scheduler {
54    pub graph: DspGraph,
55    pub sample_rate: f32,
56    pub muted: bool,
57}
58
59impl Scheduler {
60    pub fn new(sample_rate: f32) -> Self {
61        Self {
62            graph: DspGraph::new(),
63            sample_rate,
64            muted: false,
65        }
66    }
67
68    /// Called once per audio callback from the CPAL stream.
69    pub fn process_block<C>(&mut self, cmd_consumer: &mut C, output: &mut [f32])
70    where
71        C: Consumer<Item = Command>,
72    {
73        let mut processed = 0;
74        while processed < MAX_COMMANDS_PER_TICK {
75            match cmd_consumer.try_pop() {
76                Some(cmd) => { self.apply_command(cmd); processed += 1; }
77                None => break,
78            }
79        }
80        self.process_graph(output);
81    }
82
83    /// Simplified process block — no ring buffer.
84    /// Used when the scheduler is shared via Arc<Mutex<>> and the control thread
85    /// mutates it directly. The audio thread calls this after acquiring try_lock().
86    pub fn process_block_simple(&mut self, output: &mut [f32]) {
87        self.process_graph(output);
88    }
89
90    fn process_graph(&mut self, output: &mut [f32]) {
91        let sr = self.sample_rate;
92        let level_count = self.graph.levels.len();
93
94        for level_idx in 0..level_count {
95            let level_len = self.graph.levels[level_idx].len();
96
97            if level_len == 0 {
98                continue;
99            } else if level_len == 1 {
100                // Zero-overhead path: single node, no Rayon overhead.
101                let node_id = self.graph.levels[level_idx][0];
102                self.process_node(node_id, sr);
103            } else {
104                // Parallel path: collect raw pointers while holding &mut self,
105                // then dispatch DSP work in parallel via rayon::scope.
106                //
107                // SAFETY: Within a BFS level, every node writes to a distinct
108                // output buffer (disjoint BufferId). The BufferPool stores buffers
109                // in a flat Vec; tasks write to non-overlapping index ranges.
110                // Each processor and ParamBlock belongs to exactly one node.
111                let mut tasks: Vec<NodeTask> = Vec::with_capacity(level_len);
112
113                for i in 0..level_len {
114                    let node_id = self.graph.levels[level_idx][i];
115                    let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] =
116                        [None; MAX_INPUTS];
117
118                    if let Some(record) = self.graph.arena.get(node_id) {
119                        for (slot, maybe_src) in record.inputs.iter().enumerate() {
120                            if let Some(src_id) = maybe_src {
121                                if let Some(src_record) = self.graph.arena.get(*src_id) {
122                                    input_ptrs[slot] = Some(
123                                        self.graph.buffers.get(src_record.output_buffer)
124                                            as *const [f32; BUFFER_SIZE],
125                                    );
126                                }
127                            }
128                        }
129                        let record_mut = self.graph.arena.get_mut(node_id).unwrap();
130                        let output_buf_ptr = self.graph.buffers.get_mut(record_mut.output_buffer)
131                            as *mut [f32; BUFFER_SIZE];
132                        let params_ptr = &mut record_mut.params as *mut ParamBlock;
133                        let processor_ptr = &mut *record_mut.processor as *mut dyn DspNode;
134
135                        tasks.push(NodeTask {
136                            output_buf_ptr,
137                            params_ptr,
138                            processor_ptr,
139                            inputs: input_ptrs,
140                        });
141                    }
142                }
143
144                // SAFETY: each element of `tasks` points to disjoint memory.
145                // We pass a raw pointer per task so each closure captures a
146                // distinct non-aliasing pointer.
147                rayon::scope(|s| {
148                    for task in tasks.iter_mut() {
149                        // Capture the raw pointer value (usize) to avoid the
150                        // borrow checker complaining about &mut Vec element borrows.
151                        let ptr = task as *mut NodeTask as usize;
152                        s.spawn(move |_| {
153                            // SAFETY: ptr is a valid, exclusively-owned NodeTask.
154                            let t: &mut NodeTask = unsafe { &mut *(ptr as *mut NodeTask) };
155                            let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
156                                t.inputs.map(|p| p.map(|raw| unsafe { &*raw }));
157                            unsafe {
158                                (*t.processor_ptr).process(
159                                    &inputs,
160                                    &mut *t.output_buf_ptr,
161                                    &mut *t.params_ptr,
162                                    sr,
163                                );
164                            }
165                        });
166                    }
167                });
168            }
169        }
170
171        // Copy output node buffer to DAC
172        if self.muted {
173            output.fill(0.0);
174            return;
175        }
176        if let Some(out_id) = self.graph.output_node {
177            if let Some(record) = self.graph.arena.get(out_id) {
178                let buf = self.graph.buffers.get(record.output_buffer);
179                let frames = output.len() / 2;
180                for i in 0..frames.min(BUFFER_SIZE) {
181                    output[i * 2] = buf[i];
182                    output[i * 2 + 1] = buf[i];
183                }
184            }
185        } else {
186            // INVARIANT: empty graph → silence.
187            output.fill(0.0);
188        }
189    }
190
191    /// Process a single node on the calling thread.
192    fn process_node(&mut self, node_id: NodeId, sample_rate: f32) {
193        let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] = [None; MAX_INPUTS];
194
195        if let Some(record) = self.graph.arena.get(node_id) {
196            for (slot, maybe_src) in record.inputs.iter().enumerate() {
197                if let Some(src_id) = maybe_src {
198                    if let Some(src_record) = self.graph.arena.get(*src_id) {
199                        input_ptrs[slot] = Some(
200                            self.graph.buffers.get(src_record.output_buffer)
201                                as *const [f32; BUFFER_SIZE],
202                        );
203                    }
204                }
205            }
206        } else {
207            return;
208        }
209
210        let (output_buf_id, params_ptr, processor_ptr) = {
211            let record = self.graph.arena.get_mut(node_id).unwrap();
212            (
213                record.output_buffer,
214                &mut record.params as *mut ParamBlock,
215                &mut *record.processor as *mut dyn crate::node::DspNode,
216            )
217        };
218
219        let output_buf = self.graph.buffers.get_mut(output_buf_id);
220        let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
221            input_ptrs.map(|p| p.map(|ptr| unsafe { &*ptr }));
222
223        unsafe {
224            (*processor_ptr).process(&inputs, output_buf, &mut *params_ptr, sample_rate);
225        }
226    }
227
228    fn apply_command(&mut self, cmd: Command) {
229        match cmd {
230            Command::AddNode { id } => { let _ = id; }
231            Command::RemoveNode { id } => { self.graph.remove_node(id); }
232            Command::Connect { src, dst, slot } => { self.graph.connect(src, dst, slot); }
233            Command::Disconnect { dst, slot } => { self.graph.disconnect(dst, slot); }
234            Command::UpdateParam { node, param_index, new_param } => {
235                if let Some(record) = self.graph.arena.get_mut(node) {
236                    if param_index < record.params.count {
237                        record.params.params[param_index] = new_param;
238                    }
239                }
240            }
241            Command::SetOutputNode { id } => { self.graph.set_output_node(id); }
242            Command::SetMute { muted } => { self.muted = muted; }
243            Command::ClearGraph => {
244                let ids: Vec<_> = self.graph.execution_order.clone();
245                for id in ids { self.graph.remove_node(id); }
246                self.graph.output_node = None;
247            }
248        }
249    }
250
251    /// Reference sequential implementation for testing.
252    /// Processes nodes in flat execution_order without parallelism.
253    #[cfg(test)]
254    fn process_graph_sequential(&mut self, output: &mut [f32]) {
255        let sr = self.sample_rate;
256
257        // Process nodes in flat execution order (sequential)
258        for &node_id in &self.graph.execution_order {
259            self.process_node(node_id, sr);
260        }
261
262        // Copy output node buffer to DAC
263        if self.muted {
264            output.fill(0.0);
265            return;
266        }
267        if let Some(out_id) = self.graph.output_node {
268            if let Some(record) = self.graph.arena.get(out_id) {
269                let buf = self.graph.buffers.get(record.output_buffer);
270                let frames = output.len() / 2;
271                for i in 0..frames.min(BUFFER_SIZE) {
272                    output[i * 2] = buf[i];
273                    output[i * 2 + 1] = buf[i];
274                }
275            }
276        } else {
277            output.fill(0.0);
278        }
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::DspNode;
    use proptest::prelude::*;

    /// Minimal deterministic test node for property testing.
    ///
    /// Emits a gain-scaled ramp and adds the gain-scaled sum of all connected
    /// inputs on top. Source nodes (no inputs) therefore still produce a
    /// non-zero, node-specific signal. Without the ramp every node would output
    /// silence, and the equivalence property below would hold trivially.
    struct TestNode {
        gain: f32,
    }

    impl TestNode {
        fn new(gain: f32) -> Self {
            Self { gain }
        }
    }

    impl DspNode for TestNode {
        fn process(
            &mut self,
            inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            for (i, out) in output.iter_mut().enumerate() {
                *out = self.gain * (i + 1) as f32;
            }
            for input in inputs.iter().flatten() {
                for (out, &x) in output.iter_mut().zip(input.iter()) {
                    *out += x * self.gain;
                }
            }
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    // Property 1
    proptest! {
        /// **Validates: Requirements 1.1, 1.4**
        ///
        /// Feature: aether-engine-upgrades, Property 1: parallel execution is output-equivalent
        ///
        /// Property 1: Parallel execution is output-equivalent to sequential execution.
        ///
        /// For any valid DSP patch (any combination of nodes and edges forming a valid DAG),
        /// processing a block with the parallel Rayon scheduler SHALL produce a bit-identical
        /// output buffer to processing the same block with the original sequential scheduler,
        /// given the same initial node state and the same input.
        #[test]
        fn prop_parallel_equiv_sequential(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50),
            seed in any::<u64>(),
        ) {
            // Two identically built schedulers.
            let mut scheduler_parallel = Scheduler::new(48000.0);
            let mut scheduler_sequential = Scheduler::new(48000.0);

            let mut node_ids = Vec::new();

            // Add nodes to both schedulers with deterministic gains derived from the seed.
            for i in 0..num_nodes {
                let gain = ((seed.wrapping_add(i as u64) % 100) as f32) / 100.0;

                let id1 = scheduler_parallel.graph.add_node(Box::new(TestNode::new(gain)));
                let id2 = scheduler_sequential.graph.add_node(Box::new(TestNode::new(gain)));

                if let (Some(id1), Some(id2)) = (id1, id2) {
                    // Both schedulers must assign the same NodeId.
                    prop_assert_eq!(id1.index, id2.index);
                    prop_assert_eq!(id1.generation, id2.generation);
                    node_ids.push(id1);
                }
            }

            // Only forward edges (src < dst in insertion order) keep the graph a DAG.
            // Bound by `node_ids.len()` in case any `add_node` returned `None`.
            for (src_idx, dst_idx, slot) in edges {
                if src_idx < dst_idx && dst_idx < node_ids.len() {
                    let (src, dst) = (node_ids[src_idx], node_ids[dst_idx]);
                    scheduler_parallel.graph.connect(src, dst, slot);
                    scheduler_sequential.graph.connect(src, dst, slot);
                }
            }

            // The last successfully added node is the output node.
            if let Some(&output_node) = node_ids.last() {
                scheduler_parallel.graph.set_output_node(output_node);
                scheduler_sequential.graph.set_output_node(output_node);
            }

            // Interleaved stereo: BUFFER_SIZE frames x 2 channels.
            let mut output_parallel = vec![0.0f32; BUFFER_SIZE * 2];
            let mut output_sequential = vec![0.0f32; BUFFER_SIZE * 2];

            // Process one block with both schedulers.
            scheduler_parallel.process_graph(&mut output_parallel);
            scheduler_sequential.process_graph_sequential(&mut output_sequential);

            // Bit-identical output (NaN compares equal to NaN here).
            for (i, (&p, &s)) in output_parallel.iter().zip(output_sequential.iter()).enumerate() {
                prop_assert!(
                    p == s || (p.is_nan() && s.is_nan()),
                    "Output mismatch at sample {}: parallel={}, sequential={}",
                    i, p, s
                );
            }
        }
    }
}
394                );
395            }
396        }
397    }
398}