Skip to main content

speech_prep/decoder/
mixer.rs

1use crate::error::{Error, Result};
2
3use super::MixedAudio;
4
/// Stateless down-mixer that folds interleaved multi-channel audio into mono.
///
/// The type holds no data; every operation is an associated function, so a
/// `ChannelMixer` value is free to create, copy and pass around.
#[derive(Clone, Copy, Debug, Default)]
pub struct ChannelMixer;
8
9impl ChannelMixer {
10    /// Mix multi-channel audio into a mono buffer via averaging.
11    pub fn mix_to_mono(samples: &[f32], channels: u8) -> Result<MixedAudio> {
12        if channels == 0 {
13            return Err(Error::InvalidInput("channel count cannot be zero".into()));
14        }
15        if ![1, 2, 4, 6].contains(&channels) {
16            return Err(Error::InvalidInput(format!(
17                "unsupported channel count: {channels} (supports 1, 2, 4, 6 only)"
18            )));
19        }
20        if !samples.len().is_multiple_of(usize::from(channels)) {
21            let sample_len = samples.len();
22            return Err(Error::InvalidInput(format!(
23                "sample count {sample_len} not divisible by channel count {channels}"
24            )));
25        }
26
27        let peak_before = Self::calculate_peak(samples);
28
29        if channels == 1 {
30            return Ok(MixedAudio {
31                samples: samples.to_vec(),
32                original_channels: 1,
33                peak_before_mix: peak_before,
34                peak_after_mix: peak_before,
35            });
36        }
37
38        let frame_count = samples.len() / usize::from(channels);
39        let mut mixed = Vec::with_capacity(frame_count);
40
41        for frame in samples.chunks_exact(usize::from(channels)) {
42            let sum: f32 = frame.iter().sum();
43            let avg = sum / f32::from(channels);
44            mixed.push(avg.clamp(-1.0, 1.0));
45        }
46
47        let peak_after = Self::calculate_peak(&mixed);
48
49        Ok(MixedAudio {
50            samples: mixed,
51            original_channels: channels,
52            peak_before_mix: peak_before,
53            peak_after_mix: peak_after,
54        })
55    }
56
57    #[inline]
58    fn calculate_peak(samples: &[f32]) -> f32 {
59        samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    /// Tolerance for comparing mixed samples, which pass through f32 sums and divisions.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` matches `expected` element-wise within [`TOLERANCE`].
    fn assert_samples_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < TOLERANCE, "{actual:?} != {expected:?}");
        }
    }

    #[test]
    fn test_mix_identity_for_mono() -> TestResult<()> {
        let mono = vec![0.1, -0.2, 0.3];
        let mixed = ChannelMixer::mix_to_mono(&mono, 1).map_err(|e| e.to_string())?;
        assert_eq!(mixed.samples, mono);
        assert_eq!(mixed.original_channels, 1);
        // Mono is passed through, so mixing cannot change the peak.
        assert_eq!(mixed.peak_before_mix, mixed.peak_after_mix);
        Ok(())
    }

    #[test]
    fn test_mix_stereo_to_mono() -> TestResult<()> {
        let stereo = vec![0.5, -0.5, 0.8, 0.2];
        let mixed = ChannelMixer::mix_to_mono(&stereo, 2).map_err(|e| e.to_string())?;
        // Frames are (0.5, -0.5) and (0.8, 0.2); their means are 0.0 and 0.5.
        assert_samples_close(&mixed.samples, &[0.0, 0.5]);
        assert_eq!(mixed.original_channels, 2);
        assert!((mixed.peak_before_mix - 0.8).abs() < TOLERANCE);
        assert!((mixed.peak_after_mix - 0.5).abs() < TOLERANCE);
        Ok(())
    }

    #[test]
    fn test_mix_quad_to_mono() -> TestResult<()> {
        let quad = vec![0.4, 0.0, 0.0, 0.0, -0.8, 0.0, 0.0, 0.0];
        let mixed = ChannelMixer::mix_to_mono(&quad, 4).map_err(|e| e.to_string())?;
        assert_samples_close(&mixed.samples, &[0.1, -0.2]);
        assert_eq!(mixed.original_channels, 4);
        Ok(())
    }

    #[test]
    fn test_mix_surround_to_mono() -> TestResult<()> {
        let surround = vec![0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3];
        let mixed = ChannelMixer::mix_to_mono(&surround, 6).map_err(|e| e.to_string())?;
        assert_samples_close(&mixed.samples, &[0.1, 0.3]);
        assert_eq!(mixed.original_channels, 6);
        Ok(())
    }

    #[test]
    fn test_mix_clamps_out_of_range_average() -> TestResult<()> {
        // The mean of this frame is 1.5, which the mixer clamps to 1.0.
        let loud = vec![1.5, 1.5];
        let mixed = ChannelMixer::mix_to_mono(&loud, 2).map_err(|e| e.to_string())?;
        assert_samples_close(&mixed.samples, &[1.0]);
        assert!((mixed.peak_before_mix - 1.5).abs() < TOLERANCE);
        assert!((mixed.peak_after_mix - 1.0).abs() < TOLERANCE);
        Ok(())
    }

    #[test]
    fn test_mix_reject_zero_channels() {
        let samples: Vec<f32> = Vec::new();
        assert!(ChannelMixer::mix_to_mono(&samples, 0).is_err());
    }

    #[test]
    fn test_mix_reject_invalid_channels() {
        let samples = vec![0.0, 0.0, 0.0];
        assert!(ChannelMixer::mix_to_mono(&samples, 3).is_err());
    }

    #[test]
    fn test_mix_reject_misaligned_samples() {
        let samples = vec![0.0, 0.0, 0.0];
        assert!(ChannelMixer::mix_to_mono(&samples, 2).is_err());
    }

    #[test]
    fn test_mix_empty_input() -> TestResult<()> {
        let empty: Vec<f32> = Vec::new();
        let mixed = ChannelMixer::mix_to_mono(&empty, 1).map_err(|e| e.to_string())?;
        assert_eq!(mixed.samples.len(), 0);
        assert_eq!(mixed.sample_count(), 0);

        // Empty input is also a valid (zero-frame) multi-channel buffer.
        let mixed_stereo = ChannelMixer::mix_to_mono(&empty, 2).map_err(|e| e.to_string())?;
        assert!(mixed_stereo.samples.is_empty());
        Ok(())
    }

    #[test]
    fn test_mix_single_frame_stereo() -> TestResult<()> {
        let stereo = vec![0.6, 0.4];
        let mixed = ChannelMixer::mix_to_mono(&stereo, 2).map_err(|e| e.to_string())?;
        assert_eq!(mixed.samples.len(), 1);
        assert!((mixed.samples[0] - 0.5).abs() < f32::EPSILON);
        Ok(())
    }

    #[test]
    fn test_mix_is_clipped_detection() -> TestResult<()> {
        let clipped = vec![1.0, 1.0];
        let unclipped = vec![0.5, 0.5];

        let mixed_clipped = ChannelMixer::mix_to_mono(&clipped, 2).map_err(|e| e.to_string())?;
        assert!(mixed_clipped.is_clipped());

        let mixed_unclipped =
            ChannelMixer::mix_to_mono(&unclipped, 2).map_err(|e| e.to_string())?;
        assert!(!mixed_unclipped.is_clipped());
        Ok(())
    }

    #[test]
    fn test_mix_peak_ratio_behavior() -> TestResult<()> {
        let stereo = vec![0.8, -0.8, 0.4, -0.4];
        let mixed = ChannelMixer::mix_to_mono(&stereo, 2).map_err(|e| e.to_string())?;
        assert!(mixed.peak_before_mix > 0.0);
        assert!(mixed.peak_ratio() <= 1.0);
        Ok(())
    }
}