mel_spec/
rb.rs

1use crate::{config::MelConfig, mel::MelSpectrogram, stft};
2use ndarray::Array2;
3
4#[cfg(feature = "rtrb")]
5use rtrb::RingBuffer as RtrbBuffer;
6#[cfg(feature = "rtrb")]
7use rtrb::{Consumer, Producer};
8
9#[cfg(not(feature = "rtrb"))]
10use std::collections::VecDeque;
11
12pub struct RingBuffer {
13    accumulated_samples: Vec<f32>,
14
15    #[cfg(feature = "rtrb")]
16    producer: Producer<f32>,
17    #[cfg(feature = "rtrb")]
18    consumer: Consumer<f32>,
19
20    #[cfg(not(feature = "rtrb"))]
21    buffer: VecDeque<f32>,
22
23    fft: stft::Spectrogram,
24    mel: MelSpectrogram,
25    config: MelConfig,
26}
27
28impl RingBuffer {
29    pub fn new(config: MelConfig, capacity: usize) -> Self {
30        let hop_size = config.hop_size();
31        let fft_size = config.fft_size();
32        let sample_rate = config.sampling_rate();
33
34        #[cfg(feature = "rtrb")]
35        let (producer, consumer) = RtrbBuffer::<f32>::new(capacity);
36
37        #[cfg(not(feature = "rtrb"))]
38        let buffer = VecDeque::with_capacity(capacity);
39
40        Self {
41            config: config.clone(),
42            accumulated_samples: Vec::with_capacity(hop_size),
43            #[cfg(feature = "rtrb")]
44            producer,
45            #[cfg(feature = "rtrb")]
46            consumer,
47            #[cfg(not(feature = "rtrb"))]
48            buffer,
49            fft: stft::Spectrogram::new(fft_size, hop_size),
50            mel: MelSpectrogram::new(fft_size, sample_rate, config.n_mels()),
51        }
52    }
53
54    pub fn add_frame(&mut self, samples: &[f32]) {
55        #[cfg(feature = "rtrb")]
56        {
57            // rtrb::Producer::push will overwrite old data if full
58            for &s in samples {
59                let _ = self.producer.push(s);
60            }
61        }
62        #[cfg(not(feature = "rtrb"))]
63        {
64            let available = self.buffer.capacity() - self.buffer.len();
65            if samples.len() > available {
66                self.buffer.drain(0..(samples.len() - available));
67            }
68            self.buffer.extend(samples);
69        }
70    }
71
72    pub fn add(&mut self, sample: f32) {
73        #[cfg(feature = "rtrb")]
74        {
75            let _ = self.producer.push(sample);
76        }
77        #[cfg(not(feature = "rtrb"))]
78        {
79            if self.buffer.len() == self.buffer.capacity() {
80                self.buffer.pop_front();
81            }
82            self.buffer.push_back(sample);
83        }
84    }
85
86    pub fn maybe_mel(&mut self) -> Option<Array2<f64>> {
87        let hop_size = self.config.hop_size();
88
89        // first, accumulate into `accumulated_samples`
90        #[cfg(feature = "rtrb")]
91        {
92            while self.accumulated_samples.len() < hop_size {
93                if let Ok(s) = self.consumer.pop() {
94                    self.accumulated_samples.push(s);
95                } else {
96                    break;
97                }
98            }
99        }
100        #[cfg(not(feature = "rtrb"))]
101        {
102            let to_add = hop_size - self.accumulated_samples.len();
103            let available = self.buffer.len().min(to_add);
104            self.accumulated_samples
105                .extend(self.buffer.drain(..available));
106        }
107
108        if self.accumulated_samples.len() < hop_size {
109            return None;
110        }
111
112        // we have enough to do one frame
113        let mut frame = Vec::new();
114        std::mem::swap(&mut frame, &mut self.accumulated_samples);
115
116        let fft_res = self.fft.add(&frame);
117        match fft_res {
118            Some(fft) => Some(self.mel.add(&fft)),
119            None => None,
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mel::interleave_frames;
    use ndarray::{Array2, Zip};
    use ndarray_npy::read_npy;
    use soundkit::{audio_bytes::deinterleave_vecs_f32, wav::WavStreamProcessor};
    use std::fs::File;
    use std::io::Read;

    /// Streams the JFK WAV fixture through the ring buffer in 128-byte reads
    /// and compares the collected mel frames against the golden `.npy` file.
    #[test]
    fn test_ringbuffer() {
        let config = MelConfig::new(512, 160, 80, 16_000.0);
        let mut rb = RingBuffer::new(config, 1024);

        let mut wav = File::open("./testdata/jfk_f32le.wav").unwrap();
        let mut processor = WavStreamProcessor::new();
        let mut chunk = [0_u8; 128];
        let mut frames: Vec<Array2<f64>> = Vec::new();

        loop {
            let read = wav.read(&mut chunk).unwrap();
            if read == 0 {
                break;
            }

            // Skip reads for which the WAV processor yields no audio.
            let audio = match processor.add(&chunk[..read]) {
                Ok(Some(audio)) => audio,
                _ => continue,
            };

            let channels = deinterleave_vecs_f32(audio.data(), 1);
            rb.add_frame(&channels[0]);
            // At most one mel frame is collected per decoded chunk.
            frames.extend(rb.maybe_mel());
        }

        // Flatten the frames and widen to f64 to match the `got` array type.
        let interleaved: Vec<f32> = interleave_frames(&frames, false, 0);
        let flat: Vec<f64> = interleaved.into_iter().map(f64::from).collect();

        let n_frames = frames.len();
        let (n_bins, _) = frames[0].dim();
        let got: Array2<f64> = Array2::from_shape_vec((n_bins, n_frames), flat).unwrap();

        // The golden reference is stored as f32.
        let want: Array2<f32> = read_npy("./testdata/rust_jfk_golden.npy").unwrap();

        assert_eq!(got.shape(), want.shape());

        Zip::from(&got).and(&want).for_each(|&actual, &expected| {
            let actual = actual as f32;
            assert!((actual - expected).abs() <= 1e-6);
        });
    }
}