1use crate::{config::MelConfig, mel::MelSpectrogram, stft};
2use ndarray::Array2;
3
4#[cfg(feature = "rtrb")]
5use rtrb::RingBuffer as RtrbBuffer;
6#[cfg(feature = "rtrb")]
7use rtrb::{Consumer, Producer};
8
9#[cfg(not(feature = "rtrb"))]
10use std::collections::VecDeque;
11
12pub struct RingBuffer {
13 accumulated_samples: Vec<f32>,
14
15 #[cfg(feature = "rtrb")]
16 producer: Producer<f32>,
17 #[cfg(feature = "rtrb")]
18 consumer: Consumer<f32>,
19
20 #[cfg(not(feature = "rtrb"))]
21 buffer: VecDeque<f32>,
22
23 fft: stft::Spectrogram,
24 mel: MelSpectrogram,
25 config: MelConfig,
26}
27
28impl RingBuffer {
29 pub fn new(config: MelConfig, capacity: usize) -> Self {
30 let hop_size = config.hop_size();
31 let fft_size = config.fft_size();
32 let sample_rate = config.sampling_rate();
33
34 #[cfg(feature = "rtrb")]
35 let (producer, consumer) = RtrbBuffer::<f32>::new(capacity);
36
37 #[cfg(not(feature = "rtrb"))]
38 let buffer = VecDeque::with_capacity(capacity);
39
40 Self {
41 config: config.clone(),
42 accumulated_samples: Vec::with_capacity(hop_size),
43 #[cfg(feature = "rtrb")]
44 producer,
45 #[cfg(feature = "rtrb")]
46 consumer,
47 #[cfg(not(feature = "rtrb"))]
48 buffer,
49 fft: stft::Spectrogram::new(fft_size, hop_size),
50 mel: MelSpectrogram::new(fft_size, sample_rate, config.n_mels()),
51 }
52 }
53
54 pub fn add_frame(&mut self, samples: &[f32]) {
55 #[cfg(feature = "rtrb")]
56 {
57 for &s in samples {
59 let _ = self.producer.push(s);
60 }
61 }
62 #[cfg(not(feature = "rtrb"))]
63 {
64 let available = self.buffer.capacity() - self.buffer.len();
65 if samples.len() > available {
66 self.buffer.drain(0..(samples.len() - available));
67 }
68 self.buffer.extend(samples);
69 }
70 }
71
72 pub fn add(&mut self, sample: f32) {
73 #[cfg(feature = "rtrb")]
74 {
75 let _ = self.producer.push(sample);
76 }
77 #[cfg(not(feature = "rtrb"))]
78 {
79 if self.buffer.len() == self.buffer.capacity() {
80 self.buffer.pop_front();
81 }
82 self.buffer.push_back(sample);
83 }
84 }
85
86 pub fn maybe_mel(&mut self) -> Option<Array2<f64>> {
87 let hop_size = self.config.hop_size();
88
89 #[cfg(feature = "rtrb")]
91 {
92 while self.accumulated_samples.len() < hop_size {
93 if let Ok(s) = self.consumer.pop() {
94 self.accumulated_samples.push(s);
95 } else {
96 break;
97 }
98 }
99 }
100 #[cfg(not(feature = "rtrb"))]
101 {
102 let to_add = hop_size - self.accumulated_samples.len();
103 let available = self.buffer.len().min(to_add);
104 self.accumulated_samples
105 .extend(self.buffer.drain(..available));
106 }
107
108 if self.accumulated_samples.len() < hop_size {
109 return None;
110 }
111
112 let mut frame = Vec::new();
114 std::mem::swap(&mut frame, &mut self.accumulated_samples);
115
116 let fft_res = self.fft.add(&frame);
117 match fft_res {
118 Some(fft) => Some(self.mel.add(&fft)),
119 None => None,
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mel::interleave_frames;
    use ndarray::{Array2, Zip};
    use ndarray_npy::read_npy;
    use soundkit::{audio_bytes::deinterleave_vecs_f32, wav::WavStreamProcessor};
    use std::fs::File;
    use std::io::Read;

    /// Streams `jfk_f32le.wav` through a `RingBuffer` in small chunks and compares
    /// the resulting mel frames against the golden `.npy` reference.
    #[test]
    fn test_ringbuffer() {
        let fft_size = 512;
        let hop_size = 160;
        let n_mels = 80;
        let sampling_rate = 16_000.0;
        let config = MelConfig::new(fft_size, hop_size, n_mels, sampling_rate);
        let mut rb = RingBuffer::new(config, 1024);

        let mut file = File::open("./testdata/jfk_f32le.wav").unwrap();
        let mut processor = WavStreamProcessor::new();
        // Small reads simulate a live stream delivering a few samples at a time.
        let mut buf = [0_u8; 128];
        let mut frames: Vec<Array2<f64>> = Vec::new();

        loop {
            let n = file.read(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            if let Ok(Some(audio)) = processor.add(&buf[..n]) {
                let samples = deinterleave_vecs_f32(audio.data(), 1);
                rb.add_frame(&samples[0]);
                // `maybe_mel` consumes at most one hop per call; drain every ready
                // frame so the test stays correct if a chunk exceeds one hop.
                while let Some(mel_frame) = rb.maybe_mel() {
                    frames.push(mel_frame);
                }
            }
        }

        assert!(
            !frames.is_empty(),
            "no mel frames were produced from the test audio"
        );

        let flat_f32: Vec<f32> = interleave_frames(&frames, false, 0);
        let flat: Vec<f64> = flat_f32.into_iter().map(f64::from).collect();

        let t = frames.len();
        let f = frames[0].dim().0;
        let got: Array2<f64> = Array2::from_shape_vec((f, t), flat).unwrap();

        let want: Array2<f32> = read_npy("./testdata/rust_jfk_golden.npy").unwrap();

        assert_eq!(got.shape(), want.shape());

        Zip::indexed(&got).and(&want).for_each(|(i, j), &a, &b| {
            let a = a as f32;
            assert!(
                (a - b).abs() <= 1e-6,
                "mel mismatch at ({}, {}): got {}, want {}",
                i,
                j,
                a,
                b
            );
        });
    }
}