use crate::buffer::AudioBuffer;
use serde::{Deserialize, Serialize};
/// A gain matrix that mixes the channels of an interleaved [`AudioBuffer`] into a new
/// channel layout.
///
/// Gains are stored one row per output channel: `gains[output][input]` is the factor
/// applied to input channel `input` when summing into output channel `output`.
///
/// NOTE(review): the derived `Deserialize` fills `gains`, `inputs` and `outputs`
/// independently and does not check that `gains` has `outputs` rows of `inputs` entries
/// each, so a hand-edited or corrupted payload can produce an inconsistent matrix.
#[must_use]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingMatrix {
    /// Crosspoint gains, indexed `[output][input]`.
    gains: Vec<Vec<f32>>,
    /// Number of input channels this matrix accepts.
    inputs: usize,
    /// Number of output channels this matrix produces.
    outputs: usize,
}
impl RoutingMatrix {
    /// Creates an `outputs` x `inputs` matrix with every crosspoint gain set to `0.0`,
    /// so nothing is routed until gains are assigned with [`RoutingMatrix::set_gain`].
    pub fn new(inputs: usize, outputs: usize) -> Self {
        Self {
            gains: vec![vec![0.0; inputs]; outputs],
            inputs,
            outputs,
        }
    }

    /// Creates a square `channels` x `channels` matrix that routes each input channel to
    /// the output channel with the same index at unity gain.
    pub fn identity(channels: usize) -> Self {
        let mut m = Self::new(channels, channels);
        for (ch, row) in m.gains.iter_mut().enumerate() {
            row[ch] = 1.0;
        }
        m
    }

    /// 1-in / 2-out matrix that copies the single input channel to both outputs at
    /// unity gain.
    pub fn mono_to_stereo() -> Self {
        Self {
            gains: vec![
                vec![1.0], // out 0 = in 0
                vec![1.0], // out 1 = in 0
            ],
            inputs: 1,
            outputs: 2,
        }
    }

    /// 2-in / 1-out matrix that averages the two inputs: `out0 = 0.5 * in0 + 0.5 * in1`.
    pub fn stereo_to_mono() -> Self {
        Self {
            gains: vec![
                vec![0.5, 0.5], // out 0 = (in 0 + in 1) / 2
            ],
            inputs: 2,
            outputs: 1,
        }
    }

    /// 2x2 mid/side encoder: `out0 = (in0 + in1) / 2`, `out1 = (in0 - in1) / 2`.
    ///
    /// [`RoutingMatrix::mid_side_decode`] inverts this matrix.
    pub fn mid_side_encode() -> Self {
        Self {
            gains: vec![
                vec![0.5, 0.5],  // mid  = (in 0 + in 1) / 2
                vec![0.5, -0.5], // side = (in 0 - in 1) / 2
            ],
            inputs: 2,
            outputs: 2,
        }
    }

    /// 2x2 mid/side decoder: `out0 = in0 + in1`, `out1 = in0 - in1`.
    ///
    /// This is the inverse of [`RoutingMatrix::mid_side_encode`].
    pub fn mid_side_decode() -> Self {
        Self {
            gains: vec![
                vec![1.0, 1.0],  // out 0 = mid + side
                vec![1.0, -1.0], // out 1 = mid - side
            ],
            inputs: 2,
            outputs: 2,
        }
    }

    /// Sets the gain applied from `input_ch` into `output_ch`.
    ///
    /// Returns `true` if the crosspoint exists and was updated. Returns `false`, leaving
    /// the matrix untouched, if either index is outside the matrix.
    pub fn set_gain(&mut self, output_ch: usize, input_ch: usize, gain: f32) -> bool {
        if output_ch >= self.outputs || input_ch >= self.inputs {
            return false;
        }
        // `get_mut` instead of indexing: a deserialized matrix may carry a `gains` table
        // smaller than its declared shape, and that must not panic.
        match self.gains.get_mut(output_ch).and_then(|row| row.get_mut(input_ch)) {
            Some(slot) => {
                *slot = gain;
                true
            }
            None => false,
        }
    }

    /// Returns the gain applied from `input_ch` into `output_ch`.
    ///
    /// Returns `0.0` for indexes outside the matrix.
    #[must_use]
    pub fn gain(&self, output_ch: usize, input_ch: usize) -> f32 {
        if output_ch >= self.outputs || input_ch >= self.inputs {
            return 0.0;
        }
        self.gains
            .get(output_ch)
            .and_then(|row| row.get(input_ch))
            .copied()
            .unwrap_or(0.0)
    }

    /// Number of input channels [`RoutingMatrix::apply`] expects.
    #[must_use]
    pub fn inputs(&self) -> usize {
        self.inputs
    }

    /// Number of channels in the buffer produced by [`RoutingMatrix::apply`].
    #[must_use]
    pub fn outputs(&self) -> usize {
        self.outputs
    }

    /// Mixes `buf` through the matrix and returns a new interleaved buffer.
    ///
    /// The result has [`RoutingMatrix::outputs`] channels, the same frame count and the
    /// same sample rate. For every frame, output channel `o` is the sum over input channels
    /// `i` of `sample[i] * gain(o, i)`.
    ///
    /// Crosspoints whose gain is exactly `0.0` are skipped rather than multiplied. As a
    /// result, a non-finite sample on a muted input does not turn that output into NaN.
    ///
    /// # Errors
    ///
    /// Returns [`crate::NadaError::Conversion`] in these cases:
    /// - `buf` does not have exactly [`RoutingMatrix::inputs`] channels;
    /// - the gain table is smaller than the declared shape (possible only for a
    ///   deserialized matrix).
    ///
    /// Any error from [`AudioBuffer::from_interleaved`] is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `buf.samples` holds fewer than `buf.frames * channels` samples.
    /// NOTE(review): presumably `AudioBuffer` never allows this; confirm.
    pub fn apply(&self, buf: &AudioBuffer) -> crate::Result<AudioBuffer> {
        let in_ch = buf.channels as usize;
        if in_ch != self.inputs {
            return Err(crate::NadaError::Conversion(format!(
                "routing matrix expects {} input channels, got {}",
                self.inputs, in_ch
            )));
        }
        let out_ch = self.outputs;

        // The serde derive fills the fields independently, so `gains` may not cover
        // the declared shape. The zips below would silently truncate such a table
        // and mix the wrong result, so reject it up front.
        let covers_shape = self.gains.len() >= out_ch
            && self.gains[..out_ch].iter().all(|row| row.len() >= in_ch);
        if !covers_shape {
            return Err(crate::NadaError::Conversion(format!(
                "routing matrix gain table does not cover its declared {} x {} shape",
                out_ch, in_ch
            )));
        }

        let frames = buf.frames;
        let mut output = vec![0.0f32; frames * out_ch];

        // `chunks_exact(0)` panics. With zero input or output channels every output
        // sample is 0.0, which the zero-initialised buffer already holds.
        if in_ch > 0 && out_ch > 0 {
            let input = &buf.samples[..frames * in_ch];
            for (in_frame, out_frame) in input
                .chunks_exact(in_ch)
                .zip(output.chunks_exact_mut(out_ch))
            {
                for (out_sample, row) in out_frame.iter_mut().zip(&self.gains) {
                    // Explicit `fold` from +0.0 in input order. This keeps the exact
                    // floating-point result of a plain left-to-right accumulation.
                    *out_sample = row
                        .iter()
                        .zip(in_frame)
                        .filter(|(gain, _)| **gain != 0.0)
                        .fold(0.0f32, |acc, (gain, sample)| acc + sample * gain);
                }
            }
        }

        AudioBuffer::from_interleaved(output, out_ch as u32, buf.sample_rate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity_passthrough() {
        let m = RoutingMatrix::identity(2);
        let buf = AudioBuffer::from_interleaved(vec![0.5, -0.5, 0.3, -0.3], 2, 44100).unwrap();
        let out = m.apply(&buf).unwrap();
        assert_eq!(out.channels(), 2);
        assert_eq!(out.samples(), buf.samples());
    }

    #[test]
    fn identity_gains_form_unit_diagonal() {
        let m = RoutingMatrix::identity(3);
        assert_eq!(m.inputs(), 3);
        assert_eq!(m.outputs(), 3);
        for o in 0..3 {
            for i in 0..3 {
                let expected = if o == i { 1.0 } else { 0.0 };
                assert_eq!(m.gain(o, i), expected, "gain({o}, {i})");
            }
        }
    }

    #[test]
    fn mono_to_stereo() {
        let m = RoutingMatrix::mono_to_stereo();
        let buf = AudioBuffer::from_interleaved(vec![0.5, 0.3], 1, 44100).unwrap();
        let out = m.apply(&buf).unwrap();
        assert_eq!(out.channels(), 2);
        // Two frames of two channels each; no extra or missing samples.
        assert_eq!(out.samples().len(), 4);
        assert_eq!(out.samples()[0], 0.5);
        assert_eq!(out.samples()[1], 0.5);
        assert_eq!(out.samples()[2], 0.3);
        assert_eq!(out.samples()[3], 0.3);
    }

    #[test]
    fn stereo_to_mono() {
        let m = RoutingMatrix::stereo_to_mono();
        let buf = AudioBuffer::from_interleaved(vec![1.0, 0.5], 2, 44100).unwrap();
        let out = m.apply(&buf).unwrap();
        assert_eq!(out.channels(), 1);
        assert!((out.samples()[0] - 0.75).abs() < 1e-6);
    }

    #[test]
    fn mid_side_encode_values() {
        let buf = AudioBuffer::from_interleaved(vec![0.8, 0.2], 2, 44100).unwrap();
        let out = RoutingMatrix::mid_side_encode().apply(&buf).unwrap();
        assert_eq!(out.channels(), 2);
        // mid = (0.8 + 0.2) / 2, side = (0.8 - 0.2) / 2
        assert!((out.samples()[0] - 0.5).abs() < 1e-6);
        assert!((out.samples()[1] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn mid_side_roundtrip() {
        let buf = AudioBuffer::from_interleaved(vec![0.8, 0.2, 0.6, 0.4], 2, 44100).unwrap();
        let encoded = RoutingMatrix::mid_side_encode().apply(&buf).unwrap();
        let decoded = RoutingMatrix::mid_side_decode().apply(&encoded).unwrap();
        for (orig, dec) in buf.samples().iter().zip(decoded.samples().iter()) {
            assert!((orig - dec).abs() < 1e-5, "M/S roundtrip: {orig} != {dec}");
        }
    }

    #[test]
    fn custom_crosspoint_gain() {
        let mut m = RoutingMatrix::new(2, 2);
        assert!(m.set_gain(0, 0, 0.7));
        assert!(m.set_gain(1, 1, 0.3));
        assert_eq!(m.gain(0, 0), 0.7);
        assert_eq!(m.gain(1, 1), 0.3);
        assert_eq!(m.gain(0, 1), 0.0);
    }

    #[test]
    fn new_reports_dimensions() {
        let m = RoutingMatrix::new(3, 5);
        assert_eq!(m.inputs(), 3);
        assert_eq!(m.outputs(), 5);
    }

    #[test]
    fn channel_mismatch_error() {
        let m = RoutingMatrix::new(2, 2);
        let buf = AudioBuffer::from_interleaved(vec![0.5], 1, 44100).unwrap();
        assert!(m.apply(&buf).is_err());
    }

    // A freshly constructed (all-zero) matrix routes nothing: every output sample is silent.
    #[test]
    fn zero_gain_passthrough() {
        let m = RoutingMatrix::new(2, 2);
        let buf = AudioBuffer::from_interleaved(vec![1.0, 1.0], 2, 44100).unwrap();
        let out = m.apply(&buf).unwrap();
        assert!(out.samples().iter().all(|&s| s == 0.0));
    }

    #[test]
    fn out_of_bounds_set_gain() {
        let mut m = RoutingMatrix::new(2, 2);
        assert!(!m.set_gain(5, 0, 1.0));
        assert_eq!(m.gain(5, 0), 0.0);
        // An out-of-range input index is rejected the same way.
        assert!(!m.set_gain(0, 2, 1.0));
        assert_eq!(m.gain(0, 2), 0.0);
        // Rejected writes must not have touched any in-range crosspoint.
        for o in 0..2 {
            for i in 0..2 {
                assert_eq!(m.gain(o, i), 0.0);
            }
        }
    }
}