Skip to main content

resonant_stream/
graph.rs

//! Pipeline graph composition.
//!
//! Provides concrete combinator types that implement [`DspNode`]. They are
//! built with the methods of the [`GraphExt`] extension trait, which is
//! blanket-implemented for every [`DspNode`]:
//!
//! | Method | Type | Semantics |
//! |---|---|---|
//! | `a.serial(b)`   | [`Serial<A, B>`]   | Feed `a`'s output into `b` |
//! | `a.parallel(b)` | [`Parallel<A, B>`] | Sum `a` and `b` over the same input |
//! | `a.stack(b)`    | [`Stack<A, B>`]    | Concatenate `a` and `b` outputs |
//!
//! This module defines no operator overloads (`>>`, `&`, `|`). Compose nodes
//! with the methods above.
//!
//! # Examples
//!
//! ```
//! use resonant_stream::{Chunk, DspNode, StreamError};
//! use resonant_stream::graph::GraphExt;
//!
//! struct Scale(f32);
//! impl DspNode for Scale {
//!     fn process(&mut self, mut input: Chunk) -> Result<Chunk, StreamError> {
//!         for s in input.data_mut() { *s *= self.0; }
//!         Ok(input)
//!     }
//!     fn reset(&mut self) {}
//! }
//!
//! // Serial: halve, then halve again → ×0.25
//! let mut graph = Scale(0.5).serial(Scale(0.5));
//! let chunk = Chunk::new(vec![4.0, 8.0], 44100, 1);
//! let out = graph.process(chunk).unwrap();
//! assert_eq!(out.data(), &[1.0, 2.0]);
//! ```
33
34use crate::{Chunk, DspNode, StreamError};
35
36/// Serial composition: `A`'s output becomes `B`'s input.
37///
38/// Equivalent to `a >> b` via the [`GraphExt`] extension trait.
39pub struct Serial<A, B> {
40    first: A,
41    second: B,
42}
43
44impl<A: DspNode, B: DspNode> DspNode for Serial<A, B> {
45    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
46        let mid = self.first.process(input)?;
47        self.second.process(mid)
48    }
49
50    fn reset(&mut self) {
51        self.first.reset();
52        self.second.reset();
53    }
54}
55
56/// Parallel composition: both nodes receive the same input; outputs are summed.
57///
58/// Both nodes must produce the same number of samples. Returns
59/// [`StreamError::ChannelMismatch`] if output lengths differ.
60///
61/// Equivalent to `a & b` via [`GraphExt`].
62pub struct Parallel<A, B> {
63    left: A,
64    right: B,
65}
66
67impl<A: DspNode, B: DspNode> DspNode for Parallel<A, B> {
68    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
69        let out_l = self.left.process(input.clone())?;
70        let out_r = self.right.process(input)?;
71
72        if out_l.len() != out_r.len() {
73            return Err(StreamError::ChannelMismatch {
74                expected: out_l.len() as u16,
75                got: out_r.len() as u16,
76            });
77        }
78
79        let sr = out_l.sample_rate();
80        let ch = out_l.channels();
81        let summed: alloc::vec::Vec<f32> = out_l
82            .into_data()
83            .iter()
84            .zip(out_r.into_data().iter())
85            .map(|(a, b)| a + b)
86            .collect();
87
88        Ok(Chunk::new(summed, sr, ch))
89    }
90
91    fn reset(&mut self) {
92        self.left.reset();
93        self.right.reset();
94    }
95}
96
97/// Stack composition: both nodes receive the same input; outputs are concatenated.
98///
99/// Useful for building multi-band or multi-output processors.
100/// The output chunk will have a sample count equal to the sum of both outputs.
101///
102/// Equivalent to `a | b` via [`GraphExt`].
103pub struct Stack<A, B> {
104    top: A,
105    bottom: B,
106}
107
108impl<A: DspNode, B: DspNode> DspNode for Stack<A, B> {
109    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
110        let out_t = self.top.process(input.clone())?;
111        let out_b = self.bottom.process(input)?;
112
113        let sr = out_t.sample_rate();
114        let ch = out_t.channels();
115        let mut combined = out_t.into_data();
116        combined.extend(out_b.into_data());
117
118        Ok(Chunk::new(combined, sr, ch))
119    }
120
121    fn reset(&mut self) {
122        self.top.reset();
123        self.bottom.reset();
124    }
125}
126
127/// Extension trait that adds graph-composition methods to every [`DspNode`].
128///
129/// Prefer these methods over the operator impls when chaining more than two
130/// nodes, since they avoid having to spell out type parameters:
131///
132/// ```
133/// use resonant_stream::{Chunk, DspNode, StreamError};
134/// use resonant_stream::graph::GraphExt;
135///
136/// struct Noop;
137/// impl DspNode for Noop {
138///     fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> { Ok(input) }
139///     fn reset(&mut self) {}
140/// }
141///
142/// let _ = Noop.serial(Noop).serial(Noop); // (Noop >> Noop) >> Noop
143/// ```
144pub trait GraphExt: DspNode + Sized {
145    /// Chain `self` into `other` serially: `self`'s output → `other`'s input.
146    fn serial<B: DspNode>(self, other: B) -> Serial<Self, B> {
147        Serial {
148            first: self,
149            second: other,
150        }
151    }
152
153    /// Run `self` and `other` on the same input; sum their outputs.
154    fn parallel<B: DspNode>(self, other: B) -> Parallel<Self, B> {
155        Parallel {
156            left: self,
157            right: other,
158        }
159    }
160
161    /// Run `self` and `other` on the same input; concatenate their outputs.
162    fn stack<B: DspNode>(self, other: B) -> Stack<Self, B> {
163        Stack {
164            top: self,
165            bottom: other,
166        }
167    }
168}
169
170impl<T: DspNode + Sized> GraphExt for T {}
171
172/// A `DspNode` that was constructed via graph combinators.
173///
174/// This is a convenience alias used in `Pipeline::from_graph`.
175/// Any type implementing `DspNode + Send + 'static` qualifies.
176pub trait NodeGraph: DspNode + Send + 'static {}
177impl<T: DspNode + Send + 'static> NodeGraph for T {}
178
179extern crate alloc;
180
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicUsize, Ordering};

    fn scale(factor: f32) -> impl DspNode {
        struct Scale(f32);
        impl DspNode for Scale {
            fn process(&mut self, mut input: Chunk) -> Result<Chunk, StreamError> {
                for s in input.data_mut() {
                    *s *= self.0;
                }
                Ok(input)
            }
            fn reset(&mut self) {}
        }
        Scale(factor)
    }

    /// Pass-through node that counts how many times `reset` is called.
    struct CountResets(alloc::sync::Arc<AtomicUsize>);

    impl DspNode for CountResets {
        fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
            Ok(input)
        }
        fn reset(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    fn reset_counter() -> alloc::sync::Arc<AtomicUsize> {
        alloc::sync::Arc::new(AtomicUsize::new(0))
    }

    fn make_chunk(data: alloc::vec::Vec<f32>) -> Chunk {
        Chunk::new(data, 44100, 1)
    }

    #[test]
    fn serial_chains_nodes() {
        let mut g = scale(2.0).serial(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0, 2.0])).unwrap();
        assert_eq!(out.data(), &[6.0, 12.0]);
    }

    #[test]
    fn serial_matches_manual_pipeline() {
        let mut g = scale(2.0).serial(scale(0.5));
        let out = g.process(make_chunk(alloc::vec![4.0])).unwrap();
        assert_eq!(out.data(), &[4.0]); // 4 * 2 * 0.5
    }

    #[test]
    fn serial_associativity() {
        // a.serial(b).serial(c) should equal a.serial(b.serial(c))
        let chunk_a = make_chunk(alloc::vec![1.0]);
        let chunk_b = make_chunk(alloc::vec![1.0]);
        let mut left = scale(2.0).serial(scale(3.0)).serial(scale(4.0));
        let mut right = scale(2.0).serial(scale(3.0).serial(scale(4.0)));
        assert_eq!(
            left.process(chunk_a).unwrap().into_data(),
            right.process(chunk_b).unwrap().into_data()
        );
    }

    #[test]
    fn serial_error_propagates() {
        struct Fail;
        impl DspNode for Fail {
            fn process(&mut self, _: Chunk) -> Result<Chunk, StreamError> {
                Err(StreamError::ProcessingError("fail".into()))
            }
            fn reset(&mut self) {}
        }
        let mut g = scale(2.0).serial(Fail);
        assert!(g.process(make_chunk(alloc::vec![1.0])).is_err());
    }

    #[test]
    fn serial_reset_propagates() {
        let resets = reset_counter();
        let mut g = CountResets(resets.clone()).serial(CountResets(resets.clone()));
        g.reset();
        assert_eq!(resets.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn parallel_sums_outputs() {
        // scale(2) and scale(3) on [1.0, 1.0] → each sample 1·2 + 1·3 = 5
        let mut g = scale(2.0).parallel(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0, 1.0])).unwrap();
        assert_eq!(out.data(), &[5.0, 5.0]);
    }

    #[test]
    fn parallel_identity_doubles() {
        // Running the same passthrough in parallel sums each sample with itself
        struct Pass;
        impl DspNode for Pass {
            fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
                Ok(input)
            }
            fn reset(&mut self) {}
        }
        let mut g = Pass.parallel(Pass);
        let out = g.process(make_chunk(alloc::vec![1.0, -0.5])).unwrap();
        assert_eq!(out.data(), &[2.0, -1.0]);
    }

    #[test]
    fn parallel_length_mismatch_is_error() {
        // Always emits a single-sample mono chunk, whatever the input length.
        struct OneSample;
        impl DspNode for OneSample {
            fn process(&mut self, _: Chunk) -> Result<Chunk, StreamError> {
                Ok(make_chunk(alloc::vec![0.0]))
            }
            fn reset(&mut self) {}
        }
        let mut g = scale(1.0).parallel(OneSample);
        let result = g.process(make_chunk(alloc::vec![1.0, 2.0]));
        assert!(matches!(
            result,
            Err(StreamError::ChannelMismatch {
                expected: 2,
                got: 1
            })
        ));
    }

    #[test]
    fn parallel_reset_propagates() {
        let resets = reset_counter();
        let mut g = CountResets(resets.clone()).parallel(CountResets(resets.clone()));
        g.reset();
        assert_eq!(resets.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn stack_concatenates_outputs() {
        let mut g = scale(2.0).stack(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0])).unwrap();
        // [1.0 * 2] ++ [1.0 * 3] = [2.0, 3.0]
        assert_eq!(out.data(), &[2.0, 3.0]);
    }

    #[test]
    fn stack_output_length_is_sum() {
        let mut g = scale(1.0).stack(scale(1.0));
        let out = g.process(make_chunk(alloc::vec![1.0; 4])).unwrap();
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn stack_reset_propagates() {
        let resets = reset_counter();
        let mut g = CountResets(resets.clone()).stack(CountResets(resets.clone()));
        g.reset();
        assert_eq!(resets.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn graph_ext_serial_method() {
        let mut g = scale(4.0).serial(scale(0.25));
        let out = g.process(make_chunk(alloc::vec![2.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
    }

    #[test]
    fn graph_ext_parallel_method() {
        let mut g = scale(1.0).parallel(scale(1.0));
        let out = g.process(make_chunk(alloc::vec![0.5])).unwrap();
        assert_eq!(out.data(), &[1.0]);
    }

    #[test]
    fn graph_ext_stack_method() {
        let mut g = scale(1.0).stack(scale(2.0));
        let out = g.process(make_chunk(alloc::vec![1.0])).unwrap();
        assert_eq!(out.data(), &[1.0, 2.0]);
    }
}