// aether_core/src/scheduler.rs
//! Real-time audio scheduler.
//!
//! This is the hot path. It:
//!   1. Drains bounded commands from the SPSC ring.
//!   2. Executes the topologically sorted node list level by level.
//!      Nodes within the same BFS level are independent and run in parallel
//!      via Rayon's work-stealing thread pool.
//!   3. Copies the output node's buffer to the DAC output.
//!
//! HARD RT RULES enforced here:
//!   - No locks
//!   - No I/O
//!   - No unbounded loops
//!   - Allocation is bounded, not zero: a `Vec<NodeTask>` sized with
//!     `with_capacity(level_len)` is heap-allocated once per parallel level
//!     each block (and `rayon::scope` allocates internally). NOTE(review):
//!     a pre-allocated scratch buffer on `Scheduler`, bounded by MAX_NODES,
//!     would remove the per-block allocation — TODO confirm and fix.
use ringbuf::traits::Consumer;

use crate::{
    arena::NodeId,
    command::Command,
    graph::DspGraph,
    node::DspNode,
    param::ParamBlock,
    BUFFER_SIZE, MAX_COMMANDS_PER_TICK, MAX_INPUTS,
};
26
// ── Parallel dispatch helpers ─────────────────────────────────────────────────

/// Per-node data bundle collected before parallel dispatch.
///
/// SAFETY INVARIANT: Within a single BFS level, every node writes to a distinct
/// `BufferId` (guaranteed by the DAG structure — no two nodes in the same level
/// share an output buffer). The `BufferPool` stores buffers in a flat `Vec`, so
/// tasks writing to different `BufferId`s write to non-overlapping index ranges.
/// This makes the concurrent writes safe despite using raw pointers.
struct NodeTask {
    // Destination buffer this node renders into; unique within its BFS level.
    output_buf_ptr: *mut [f32; BUFFER_SIZE],
    // Parameter block owned by exactly one NodeRecord.
    params_ptr: *mut ParamBlock,
    // Type-erased DSP processor, also owned by exactly one NodeRecord.
    processor_ptr: *mut dyn DspNode,
    // Read-only upstream buffers indexed by input slot; None = unconnected.
    inputs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS],
}
42
/// SAFETY: Within a BFS level each task accesses disjoint memory:
/// - distinct output buffer (different BufferId → different Vec index range)
/// - distinct processor and params (each belongs to exactly one NodeRecord)
///
/// No two tasks in the same level share any pointed-to memory.
unsafe impl Send for NodeTask {}
// SAFETY: same disjointness argument as `Send`; a `&NodeTask` only exposes raw
// pointer values — dereferencing them still requires `unsafe` at the use site.
unsafe impl Sync for NodeTask {}
50
// ── Scheduler ─────────────────────────────────────────────────────────────────

/// The RT scheduler. Owns the graph and processes audio callbacks.
pub struct Scheduler {
    // The DSP graph: node arena, buffer pool, BFS levels, execution order.
    pub graph: DspGraph,
    // Sample rate in Hz, forwarded to every node's `process` call.
    pub sample_rate: f32,
    // When true, the DAC output is zero-filled (graph still runs — see
    // process_graph, which checks this only after processing the levels).
    pub muted: bool,
}
59
60impl Scheduler {
61    pub fn new(sample_rate: f32) -> Self {
62        Self {
63            graph: DspGraph::new(),
64            sample_rate,
65            muted: false,
66        }
67    }
68
69    /// Called once per audio callback from the CPAL stream.
70    pub fn process_block<C>(&mut self, cmd_consumer: &mut C, output: &mut [f32])
71    where
72        C: Consumer<Item = Command>,
73    {
74        let mut processed = 0;
75        while processed < MAX_COMMANDS_PER_TICK {
76            match cmd_consumer.try_pop() {
77                Some(cmd) => { self.apply_command(cmd); processed += 1; }
78                None => break,
79            }
80        }
81        self.process_graph(output);
82    }
83
84    /// Simplified process block — no ring buffer.
85    /// Used when the scheduler is shared via Arc<Mutex<>> and the control thread
86    /// mutates it directly. The audio thread calls this after acquiring try_lock().
87    pub fn process_block_simple(&mut self, output: &mut [f32]) {
88        self.process_graph(output);
89    }
90
91    fn process_graph(&mut self, output: &mut [f32]) {
92        let sr = self.sample_rate;
93        let level_count = self.graph.levels.len();
94
95        for level_idx in 0..level_count {
96            let level_len = self.graph.levels[level_idx].len();
97
98            if level_len == 0 {
99                continue;
100            } else if level_len == 1 {
101                // Zero-overhead path: single node, no Rayon overhead.
102                let node_id = self.graph.levels[level_idx][0];
103                self.process_node(node_id, sr);
104            } else {
105                // Parallel path: collect raw pointers while holding &mut self,
106                // then dispatch DSP work in parallel via rayon::scope.
107                //
108                // SAFETY: Within a BFS level, every node writes to a distinct
109                // output buffer (disjoint BufferId). The BufferPool stores buffers
110                // in a flat Vec; tasks write to non-overlapping index ranges.
111                // Each processor and ParamBlock belongs to exactly one node.
112                let mut tasks: Vec<NodeTask> = Vec::with_capacity(level_len);
113
114                for i in 0..level_len {
115                    let node_id = self.graph.levels[level_idx][i];
116                    let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] =
117                        [None; MAX_INPUTS];
118
119                    if let Some(record) = self.graph.arena.get(node_id) {
120                        for (slot, maybe_src) in record.inputs.iter().enumerate() {
121                            if let Some(src_id) = maybe_src {
122                                if let Some(src_record) = self.graph.arena.get(*src_id) {
123                                    input_ptrs[slot] = Some(
124                                        self.graph.buffers.get(src_record.output_buffer)
125                                            as *const [f32; BUFFER_SIZE],
126                                    );
127                                }
128                            }
129                        }
130                        let record_mut = self.graph.arena.get_mut(node_id).unwrap();
131                        let output_buf_ptr = self.graph.buffers.get_mut(record_mut.output_buffer)
132                            as *mut [f32; BUFFER_SIZE];
133                        let params_ptr = &mut record_mut.params as *mut ParamBlock;
134                        let processor_ptr = &mut *record_mut.processor as *mut dyn DspNode;
135
136                        tasks.push(NodeTask {
137                            output_buf_ptr,
138                            params_ptr,
139                            processor_ptr,
140                            inputs: input_ptrs,
141                        });
142                    }
143                }
144
145                // SAFETY: each element of `tasks` points to disjoint memory.
146                // We pass a raw pointer per task so each closure captures a
147                // distinct non-aliasing pointer.
148                rayon::scope(|s| {
149                    for task in tasks.iter_mut() {
150                        // Capture the raw pointer value (usize) to avoid the
151                        // borrow checker complaining about &mut Vec element borrows.
152                        let ptr = task as *mut NodeTask as usize;
153                        s.spawn(move |_| {
154                            // SAFETY: ptr is a valid, exclusively-owned NodeTask.
155                            let t: &mut NodeTask = unsafe { &mut *(ptr as *mut NodeTask) };
156                            let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
157                                t.inputs.map(|p| p.map(|raw| unsafe { &*raw }));
158                            unsafe {
159                                (*t.processor_ptr).process(
160                                    &inputs,
161                                    &mut *t.output_buf_ptr,
162                                    &mut *t.params_ptr,
163                                    sr,
164                                );
165                            }
166                        });
167                    }
168                });
169            }
170        }
171
172        // Copy output node buffer to DAC
173        if self.muted {
174            output.fill(0.0);
175            return;
176        }
177        if let Some(out_id) = self.graph.output_node {
178            if let Some(record) = self.graph.arena.get(out_id) {
179                let buf = self.graph.buffers.get(record.output_buffer);
180                let frames = output.len() / 2;
181                for i in 0..frames.min(BUFFER_SIZE) {
182                    output[i * 2] = buf[i];
183                    output[i * 2 + 1] = buf[i];
184                }
185            }
186        } else {
187            // INVARIANT: empty graph → silence.
188            output.fill(0.0);
189        }
190    }
191
192    /// Process a single node on the calling thread.
193    fn process_node(&mut self, node_id: NodeId, sample_rate: f32) {
194        let mut input_ptrs: [Option<*const [f32; BUFFER_SIZE]>; MAX_INPUTS] = [None; MAX_INPUTS];
195
196        if let Some(record) = self.graph.arena.get(node_id) {
197            for (slot, maybe_src) in record.inputs.iter().enumerate() {
198                if let Some(src_id) = maybe_src {
199                    if let Some(src_record) = self.graph.arena.get(*src_id) {
200                        input_ptrs[slot] = Some(
201                            self.graph.buffers.get(src_record.output_buffer)
202                                as *const [f32; BUFFER_SIZE],
203                        );
204                    }
205                }
206            }
207        } else {
208            return;
209        }
210
211        let (output_buf_id, params_ptr, processor_ptr) = {
212            let record = self.graph.arena.get_mut(node_id).unwrap();
213            (
214                record.output_buffer,
215                &mut record.params as *mut ParamBlock,
216                &mut *record.processor as *mut dyn crate::node::DspNode,
217            )
218        };
219
220        let output_buf = self.graph.buffers.get_mut(output_buf_id);
221        let inputs: [Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS] =
222            input_ptrs.map(|p| p.map(|ptr| unsafe { &*ptr }));
223
224        unsafe {
225            (*processor_ptr).process(&inputs, output_buf, &mut *params_ptr, sample_rate);
226        }
227    }
228
229    fn apply_command(&mut self, cmd: Command) {
230        match cmd {
231            Command::AddNode { id } => { let _ = id; }
232            Command::RemoveNode { id } => { self.graph.remove_node(id); }
233            Command::Connect { src, dst, slot } => { self.graph.connect(src, dst, slot); }
234            Command::Disconnect { dst, slot } => { self.graph.disconnect(dst, slot); }
235            Command::UpdateParam { node, param_index, new_param } => {
236                if let Some(record) = self.graph.arena.get_mut(node) {
237                    if param_index < record.params.count {
238                        record.params.params[param_index] = new_param;
239                    }
240                }
241            }
242            Command::SetOutputNode { id } => { self.graph.set_output_node(id); }
243            Command::SetMute { muted } => { self.muted = muted; }
244            Command::ClearGraph => {
245                let ids: Vec<_> = self.graph.execution_order.clone();
246                for id in ids { self.graph.remove_node(id); }
247                self.graph.output_node = None;
248            }
249        }
250    }
251
252    /// Reference sequential implementation for testing.
253    /// Processes nodes in flat execution_order without parallelism.
254    #[cfg(test)]
255    fn process_graph_sequential(&mut self, output: &mut [f32]) {
256        let sr = self.sample_rate;
257
258        // Collect execution order into a local Vec to avoid borrow conflict
259        // between the immutable borrow of execution_order and the mutable
260        // borrow inside process_node.
261        let order: Vec<NodeId> = self.graph.execution_order.clone();
262        for &node_id in &order {
263            self.process_node(node_id, sr);
264        }
265
266        // Copy output node buffer to DAC
267        if self.muted {
268            output.fill(0.0);
269            return;
270        }
271        if let Some(out_id) = self.graph.output_node {
272            if let Some(record) = self.graph.arena.get(out_id) {
273                let buf = self.graph.buffers.get(record.output_buffer);
274                let frames = output.len() / 2;
275                for i in 0..frames.min(BUFFER_SIZE) {
276                    output[i * 2] = buf[i];
277                    output[i * 2 + 1] = buf[i];
278                }
279            }
280        } else {
281            output.fill(0.0);
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::DspNode;
    use proptest::prelude::*;

    /// Minimal deterministic test node for property testing.
    /// Sums all inputs and multiplies by a fixed gain.
    struct TestNode {
        // Constant multiplier applied to every input sample.
        gain: f32,
    }

    impl TestNode {
        fn new(gain: f32) -> Self {
            Self { gain }
        }
    }

    impl DspNode for TestNode {
        // Deterministic: output[i] = sum over connected inputs of in[i]*gain.
        // Unconnected nodes therefore produce all-zero output.
        fn process(
            &mut self,
            inputs: &[Option<&[f32; BUFFER_SIZE]>; MAX_INPUTS],
            output: &mut [f32; BUFFER_SIZE],
            _params: &mut ParamBlock,
            _sample_rate: f32,
        ) {
            output.fill(0.0);
            for input_opt in inputs.iter() {
                if let Some(input) = input_opt {
                    for i in 0..BUFFER_SIZE {
                        output[i] += input[i] * self.gain;
                    }
                }
            }
        }

        fn type_name(&self) -> &'static str {
            "TestNode"
        }
    }

    // Property 1
    proptest! {
        /// **Validates: Requirements 1.1, 1.4**
        ///
        /// Feature: aether-engine-upgrades, Property 1: parallel execution is output-equivalent
        ///
        /// Property 1: Parallel execution is output-equivalent to sequential execution.
        ///
        /// For any valid DSP patch (any combination of nodes and edges forming a valid DAG),
        /// processing a block with the parallel Rayon scheduler SHALL produce a bit-identical
        /// output buffer to processing the same block with the original sequential scheduler,
        /// given the same initial node state and the same input.
        #[test]
        fn prop_parallel_equiv_sequential(
            num_nodes in 1usize..=20,
            edges in prop::collection::vec((0usize..20, 0usize..20, 0usize..MAX_INPUTS), 0..50),
            seed in any::<u64>(),
        ) {
            // Create two identical schedulers
            let mut scheduler_parallel = Scheduler::new(48000.0);
            let mut scheduler_sequential = Scheduler::new(48000.0);

            let mut node_ids = Vec::new();

            // Add nodes to both schedulers with deterministic gains based on seed
            for i in 0..num_nodes {
                // Gain in [0, 0.99] derived deterministically from the seed.
                let gain = ((seed.wrapping_add(i as u64) % 100) as f32) / 100.0;

                let id1 = scheduler_parallel.graph.add_node(Box::new(TestNode::new(gain)));
                let id2 = scheduler_sequential.graph.add_node(Box::new(TestNode::new(gain)));

                if let (Some(id1), Some(id2)) = (id1, id2) {
                    // Verify both schedulers assigned the same NodeId
                    prop_assert_eq!(id1.index, id2.index);
                    prop_assert_eq!(id1.generation, id2.generation);
                    node_ids.push(id1);
                }
            }

            // Add edges to both schedulers (filter to maintain DAG invariant: src < dst)
            for (src_idx, dst_idx, slot) in edges {
                if src_idx < num_nodes && dst_idx < num_nodes && src_idx < dst_idx {
                    let src = node_ids[src_idx];
                    let dst = node_ids[dst_idx];

                    scheduler_parallel.graph.connect(src, dst, slot);
                    scheduler_sequential.graph.connect(src, dst, slot);
                }
            }

            // Set output node to the last node if we have any nodes
            if !node_ids.is_empty() {
                let output_node = node_ids[num_nodes - 1];
                scheduler_parallel.graph.set_output_node(output_node);
                scheduler_sequential.graph.set_output_node(output_node);
            }

            // Prepare output buffers (stereo: BUFFER_SIZE frames = 2*BUFFER_SIZE samples)
            let mut output_parallel = vec![0.0f32; BUFFER_SIZE * 2];
            let mut output_sequential = vec![0.0f32; BUFFER_SIZE * 2];

            // Process one block with both schedulers
            scheduler_parallel.process_graph(&mut output_parallel);
            scheduler_sequential.process_graph_sequential(&mut output_sequential);

            // Assert bit-identical output (NaN == NaN treated as equal,
            // since NaN != NaN under IEEE-754 comparison)
            for (i, (&p, &s)) in output_parallel.iter().zip(output_sequential.iter()).enumerate() {
                prop_assert!(
                    p == s || (p.is_nan() && s.is_nan()),
                    "Output mismatch at sample {}: parallel={}, sequential={}",
                    i, p, s
                );
            }
        }
    }
}