use crate::{Chunk, DspNode, StreamError};
/// Series composition of two nodes.
///
/// When processed, the chunk produced by `first` is handed to `second`, and
/// the output of `second` is the output of the composition.
pub struct Serial<A, B> {
/// Stage that receives the caller's input.
first: A,
/// Stage that receives the output of `first`.
second: B,
}
/// Runs the two wrapped nodes back to back.
impl<A, B> DspNode for Serial<A, B>
where
    A: DspNode,
    B: DspNode,
{
    /// Feeds `input` through `first` and passes the resulting chunk to `second`.
    ///
    /// # Errors
    ///
    /// Returns the error of whichever stage fails first. `second` is not
    /// invoked when `first` fails.
    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
        self.first
            .process(input)
            .and_then(|intermediate| self.second.process(intermediate))
    }

    /// Resets `first`, then `second`.
    fn reset(&mut self) {
        let Self { first, second } = self;
        first.reset();
        second.reset();
    }
}
/// Parallel (mixing) composition of two nodes.
///
/// When processed, both nodes receive the same input (one of them a clone)
/// and their outputs are added together element by element.
pub struct Parallel<A, B> {
/// Branch that is processed first; its output supplies the sample rate and
/// channel count of the mixed result.
left: A,
/// Branch that is processed second and added onto the `left` output.
right: B,
}
/// Mixes two branches that process the same input by summing their outputs.
impl<A: DspNode, B: DspNode> DspNode for Parallel<A, B> {
    /// Runs `left` and `right` on the same input and returns the element-wise
    /// sum of their outputs.
    ///
    /// `left` receives a clone of `input`, `right` receives `input` itself.
    /// The sample rate and channel count of the result are the ones reported
    /// by the `left` output.
    ///
    /// NOTE(review): only the output lengths are compared; the sample rate and
    /// channel count of the `right` output are not checked against `left` —
    /// confirm that both branches are expected to agree on them.
    ///
    /// # Errors
    ///
    /// * Any error from `left` or `right` is returned unchanged (`left` runs
    ///   first, so `right` is skipped when `left` fails).
    /// * [`StreamError::ChannelMismatch`] when the two outputs differ in
    ///   length. `expected` is the `left` length and `got` the `right` length;
    ///   lengths above `u16::MAX` are reported as `u16::MAX`.
    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
        /// Clamps a length into the `u16` error fields. A plain `as u16`
        /// wraps around, so e.g. 65536 vs 131072 would be reported as 0 vs 0.
        fn clamp_len(len: usize) -> u16 {
            len.min(usize::from(u16::MAX)) as u16
        }

        let out_l = self.left.process(input.clone())?;
        let out_r = self.right.process(input)?;

        let (len_l, len_r) = (out_l.len(), out_r.len());
        if len_l != len_r {
            return Err(StreamError::ChannelMismatch {
                expected: clamp_len(len_l),
                got: clamp_len(len_r),
            });
        }

        // Capture the metadata before `into_data` consumes the left chunk.
        let sr = out_l.sample_rate();
        let ch = out_l.channels();

        // Accumulate into the left buffer instead of allocating a third one.
        let mut mixed = out_l.into_data();
        let addend = out_r.into_data();
        // Pairwise summation stops at the shorter buffer; keep the output the
        // same length as the pairs actually summed (a no-op when both match).
        mixed.truncate(addend.len());
        for (acc, rhs) in mixed.iter_mut().zip(addend) {
            *acc += rhs;
        }

        Ok(Chunk::new(mixed, sr, ch))
    }

    /// Resets `left`, then `right`.
    fn reset(&mut self) {
        self.left.reset();
        self.right.reset();
    }
}
/// Stacking composition of two nodes.
///
/// When processed, both nodes receive the same input (one of them a clone)
/// and the output data of `bottom` is appended after the output data of `top`.
pub struct Stack<A, B> {
/// Node whose output comes first in the combined data; it also supplies the
/// sample rate and channel count of the result.
top: A,
/// Node whose output is appended after the `top` output.
bottom: B,
}
/// Concatenates the outputs of two nodes that process the same input.
impl<A, B> DspNode for Stack<A, B>
where
    A: DspNode,
    B: DspNode,
{
    /// Runs `top` on a clone of `input` and `bottom` on `input` itself, then
    /// returns the `top` data followed by the `bottom` data in one chunk.
    ///
    /// The sample rate and channel count of the result are the ones reported
    /// by the `top` output.
    ///
    /// NOTE(review): the result keeps `top`'s channel count even though it
    /// carries the samples of both outputs — confirm this is the intended
    /// layout.
    ///
    /// # Errors
    ///
    /// Any error from `top` or `bottom` is returned unchanged (`top` runs
    /// first, so `bottom` is skipped when `top` fails).
    fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
        let upper = self.top.process(input.clone())?;
        let lower = self.bottom.process(input)?;

        // Read the metadata before `into_data` consumes the upper chunk.
        let (rate, channels) = (upper.sample_rate(), upper.channels());

        let mut samples = upper.into_data();
        samples.extend(lower.into_data());

        Ok(Chunk::new(samples, rate, channels))
    }

    /// Resets `top`, then `bottom`.
    fn reset(&mut self) {
        let Self { top, bottom } = self;
        top.reset();
        bottom.reset();
    }
}
/// Combinator methods for composing [`DspNode`]s into larger graphs.
///
/// Every method consumes `self` and `other` and returns a new node that owns
/// both.
pub trait GraphExt: DspNode + Sized {
    /// Builds a [`Serial`] node in which `self` runs first and `other` runs on
    /// its output.
    fn serial<B>(self, other: B) -> Serial<Self, B>
    where
        B: DspNode,
    {
        Serial {
            first: self,
            second: other,
        }
    }

    /// Builds a [`Parallel`] node with `self` as the left branch and `other`
    /// as the right branch.
    fn parallel<B>(self, other: B) -> Parallel<Self, B>
    where
        B: DspNode,
    {
        Parallel {
            left: self,
            right: other,
        }
    }

    /// Builds a [`Stack`] node with `self` on top and `other` at the bottom.
    fn stack<B>(self, other: B) -> Stack<Self, B>
    where
        B: DspNode,
    {
        Stack {
            top: self,
            bottom: other,
        }
    }
}
/// Blanket implementation: every sized [`DspNode`] gets the [`GraphExt`]
/// combinators.
impl<T> GraphExt for T where T: DspNode {}
/// Marker trait for nodes that are [`DspNode`], `Send` and `'static`.
///
/// It declares no methods of its own and is implemented automatically for
/// every type that meets those bounds.
pub trait NodeGraph: DspNode + Send + 'static {}

/// Blanket implementation covering every `DspNode + Send + 'static` type.
impl<T> NodeGraph for T where T: DspNode + Send + 'static {}
extern crate alloc;
#[cfg(test)]
mod tests {
    use super::*;

    /// Node that multiplies every sample by `factor`.
    fn scale(factor: f32) -> impl DspNode {
        struct Scale(f32);
        impl DspNode for Scale {
            fn process(&mut self, mut input: Chunk) -> Result<Chunk, StreamError> {
                for s in input.data_mut() {
                    *s *= self.0;
                }
                Ok(input)
            }
            fn reset(&mut self) {}
        }
        Scale(factor)
    }

    /// Stateful node used to observe `reset()`.
    ///
    /// The N-th `process` call since construction or the last `reset` adds
    /// `N` to every sample, so the output reveals whether the node was reset.
    fn ramp() -> impl DspNode {
        struct Ramp(f32);
        impl DspNode for Ramp {
            fn process(&mut self, mut input: Chunk) -> Result<Chunk, StreamError> {
                self.0 += 1.0;
                for s in input.data_mut() {
                    *s += self.0;
                }
                Ok(input)
            }
            fn reset(&mut self) {
                self.0 = 0.0;
            }
        }
        Ramp(0.0)
    }

    /// Builds a chunk from `data` with the fixed metadata used by these tests.
    fn make_chunk(data: alloc::vec::Vec<f32>) -> Chunk {
        Chunk::new(data, 44100, 1)
    }

    #[test]
    fn serial_chains_nodes() {
        let mut g = scale(2.0).serial(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0, 2.0])).unwrap();
        assert_eq!(out.data(), &[6.0, 12.0]);
    }

    #[test]
    fn serial_matches_manual_pipeline() {
        let mut g = scale(2.0).serial(scale(0.5));
        let out = g.process(make_chunk(alloc::vec![4.0])).unwrap();
        assert_eq!(out.data(), &[4.0]);
    }

    #[test]
    fn serial_associativity() {
        let chunk_a = make_chunk(alloc::vec![1.0]);
        let chunk_b = make_chunk(alloc::vec![1.0]);
        let mut left = scale(2.0).serial(scale(3.0)).serial(scale(4.0));
        let mut right = scale(2.0).serial(scale(3.0).serial(scale(4.0)));
        assert_eq!(
            left.process(chunk_a).unwrap().into_data(),
            right.process(chunk_b).unwrap().into_data()
        );
    }

    #[test]
    fn serial_error_propagates() {
        struct Fail;
        impl DspNode for Fail {
            fn process(&mut self, _: Chunk) -> Result<Chunk, StreamError> {
                Err(StreamError::ProcessingError("fail".into()))
            }
            fn reset(&mut self) {}
        }
        let mut g = scale(2.0).serial(Fail);
        assert!(g.process(make_chunk(alloc::vec![1.0])).is_err());
    }

    #[test]
    fn serial_reset_propagates() {
        let mut g = ramp().serial(ramp());
        // Each stage adds its call count: 0 + 1 + 1, then 0 + 2 + 2.
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[4.0]);
        g.reset();
        // Resetting only one stage would yield 1 + 3 = 4 here.
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
    }

    #[test]
    fn parallel_sums_outputs() {
        let mut g = scale(2.0).parallel(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0, 1.0])).unwrap();
        assert_eq!(out.data(), &[5.0, 5.0]);
    }

    #[test]
    fn parallel_identity_doubles() {
        struct Pass;
        impl DspNode for Pass {
            fn process(&mut self, input: Chunk) -> Result<Chunk, StreamError> {
                Ok(input)
            }
            fn reset(&mut self) {}
        }
        let mut g = Pass.parallel(Pass);
        let out = g.process(make_chunk(alloc::vec![1.0, -0.5])).unwrap();
        assert_eq!(out.data(), &[2.0, -1.0]);
    }

    #[test]
    fn parallel_length_mismatch_is_error() {
        /// Ignores its input and always emits a single sample.
        struct OneSample;
        impl DspNode for OneSample {
            fn process(&mut self, _: Chunk) -> Result<Chunk, StreamError> {
                Ok(make_chunk(alloc::vec![0.0]))
            }
            fn reset(&mut self) {}
        }
        let mut g = scale(1.0).parallel(OneSample);
        let result = g.process(make_chunk(alloc::vec![1.0, 2.0]));
        assert!(matches!(
            result,
            Err(StreamError::ChannelMismatch { .. })
        ));
    }

    #[test]
    fn parallel_reset_propagates() {
        let mut g = ramp().parallel(ramp());
        // Each branch adds its call count: (0 + 1) + (0 + 1), then 2 + 2.
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[4.0]);
        g.reset();
        // Resetting only one branch would yield 1 + 3 = 4 here.
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
    }

    #[test]
    fn stack_concatenates_outputs() {
        let mut g = scale(2.0).stack(scale(3.0));
        let out = g.process(make_chunk(alloc::vec![1.0])).unwrap();
        assert_eq!(out.data(), &[2.0, 3.0]);
    }

    #[test]
    fn stack_output_length_is_sum() {
        let mut g = scale(1.0).stack(scale(1.0));
        let out = g.process(make_chunk(alloc::vec![1.0; 4])).unwrap();
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn stack_reset_propagates() {
        let mut g = ramp().stack(ramp());
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[1.0, 1.0]);
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[2.0, 2.0]);
        g.reset();
        // Resetting only one node would leave a 3.0 in one of the two slots.
        let out = g.process(make_chunk(alloc::vec![0.0])).unwrap();
        assert_eq!(out.data(), &[1.0, 1.0]);
    }

    #[test]
    fn graph_ext_serial_method() {
        let mut g = scale(4.0).serial(scale(0.25));
        let out = g.process(make_chunk(alloc::vec![2.0])).unwrap();
        assert_eq!(out.data(), &[2.0]);
    }

    #[test]
    fn graph_ext_parallel_method() {
        let mut g = scale(1.0).parallel(scale(1.0));
        let out = g.process(make_chunk(alloc::vec![0.5])).unwrap();
        assert_eq!(out.data(), &[1.0]);
    }

    #[test]
    fn graph_ext_stack_method() {
        let mut g = scale(1.0).stack(scale(2.0));
        let out = g.process(make_chunk(alloc::vec![1.0])).unwrap();
        assert_eq!(out.data(), &[1.0, 2.0]);
    }
}