mel_spec/
vad.rs

1use image::{ImageBuffer, Rgb};
2use ndarray::{concatenate, s, Array, Array2, Axis};
3use std::collections::HashSet;
4
/// Thresholds used to detect the "edges" of features in the mel
/// spectrogram, favouring gradients that are longer in the time axis and
/// above a certain power threshold.
///
/// Speech is characteristic for occupying several mel frequency bins at once
/// and continuing for several frames.
///
/// We naively sketch these waves as vectors, and find vertical columns where
/// there are no intersections above a certain threshold - ie, short gaps in
/// speech.
///
/// * `min_energy`: the relative power of the signal, set at around 1 to
///   discard noise.
/// * `min_y`: the number of frames the gradient intersects: i.e., its
///   duration along the x-axis.
/// * `min_x`: the number of distinct gradients that must cross a column on
///   the x-axis for it to be considered generally intersected. The reasoning
///   here is that speech will always occupy more than one mel frequency bin
///   so a time frame with speech will have several or more intersections.
/// * `min_mel`: bins below this won't be counted. Useful if the signal
///   is noisy in the very low (first couple of bins) frequencies.
///
/// NOTE(review): an earlier revision of this comment also described a
/// `min_frames` setting ("min frames to accumulate before looking for a
/// boundary to split at", at least 50-100; 100 frames is 1 second using
/// Whisper FFT settings) — that field no longer exists on this struct.
///
/// See `doc/jfk_vad_boundaries.png` for a visualisation.
#[derive(Copy, Clone, Default)]
pub struct DetectionSettings {
    pub min_energy: f64,
    pub min_y: usize,
    pub min_x: usize,
    pub min_mel: usize,
}

impl DetectionSettings {
    /// Builds settings from the four detection thresholds.
    pub fn new(min_energy: f64, min_y: usize, min_x: usize, min_mel: usize) -> Self {
        Self {
            min_energy,
            min_y,
            min_x,
            min_mel,
        }
    }

    /// Signals below this threshold will not be counted in edge detection.
    /// `1.0` is a good default.
    pub fn min_energy(&self) -> f64 {
        self.min_energy
    }

    /// The min length of a detectable gradient in x-axis frames.
    /// `10` is a good default.
    pub fn min_y(&self) -> usize {
        self.min_y
    }

    /// The min number of gradients allowed to intersect an x-axis column
    /// before it is discounted from being a speech boundary.
    /// `10` is a good default.
    pub fn min_x(&self) -> usize {
        self.min_x
    }

    /// Ignore mel bands below this setting.
    /// `0` is a good default.
    pub fn min_mel(&self) -> usize {
        self.min_mel
    }
}
75
/// Streaming wrapper around [`vad_boundaries`]: mel frames are accumulated
/// one at a time and each call to [`VoiceActivityDetector::add`] reports on
/// the trailing window of frames.
pub struct VoiceActivityDetector {
    // Recently added mel frames (one entry per call to `add`).
    mel_buffer: Vec<Array2<f64>>,
    // Detection thresholds applied to every window.
    settings: DetectionSettings,
    // Logical write position; tracks `mel_buffer.len()` between truncations.
    idx: usize,
}

impl VoiceActivityDetector {
    /// Creates an empty detector with the given settings.
    pub fn new(settings: &DetectionSettings) -> Self {
        let mel_buffer: Vec<Array2<f64>> = Vec::new();

        Self {
            mel_buffer,
            settings: settings.to_owned(),
            idx: 0,
        }
    }

    /// Add Mel spectrogram - should be a single frame.
    ///
    /// Returns `None` until at least `min_x` frames have been buffered.
    /// Afterwards returns `Some(true)` when the first intersected column of
    /// the trailing `min_x`-frame window is column 0, `Some(false)` otherwise.
    pub fn add(&mut self, frame: &Array2<f64>) -> Option<bool> {
        let min_x = self.settings.min_x;
        // Cap the buffer at 128 frames: discard all but the trailing `min_x`
        // frames still needed for the next window.
        // NOTE(review): this assumes `min_x <= 128`; a larger value would
        // underflow the slice start below — confirm against callers.
        if self.idx == 128 {
            self.mel_buffer = self.mel_buffer[(self.mel_buffer.len() - min_x)..].to_vec();
            self.idx = min_x;
        }
        self.mel_buffer.push(frame.to_owned());
        self.idx += 1;
        if self.idx < min_x {
            return None;
        }

        // check if we are at cutable frame position
        // (`idx` equals `mel_buffer.len()` here, so this is the last `min_x` frames)
        let window = &self.mel_buffer[self.idx - min_x..];
        let edge_info = vad_boundaries(&window, &self.settings);
        let ni = edge_info.intersected();
        if ni.is_empty() {
            Some(false)
        } else {
            Some(ni[0] == 0)
        }
    }
}
117
118pub fn vad_on(edge_info: &EdgeInfo, n: usize) -> bool {
119    let intersected_columns = &edge_info.intersected_columns;
120
121    if intersected_columns.is_empty() {
122        return false;
123    }
124
125    let mut contiguous_count = 1;
126    let mut prev_index = intersected_columns[0];
127
128    for &index in &intersected_columns[1..] {
129        if index == prev_index + 1 {
130            contiguous_count += 1;
131        } else {
132            contiguous_count = 1;
133        }
134
135        if contiguous_count >= n {
136            return true;
137        }
138
139        prev_index = index;
140    }
141
142    false
143}
144
145pub fn vad_boundaries(frames: &[Array2<f64>], settings: &DetectionSettings) -> EdgeInfo {
146    let array_views: Vec<_> = frames.iter().map(|a| a.view()).collect();
147    let min_energy = settings.min_energy;
148    let min_y = settings.min_y;
149    let min_mel = settings.min_mel;
150
151    // Concatenate the array views along the time (x) axis.
152    let merged_frames = concatenate(Axis(1), &array_views).unwrap();
153    let shape = merged_frames.raw_dim();
154    let width = shape[1];
155    let height = shape[0];
156
157    // Sobel kernels for edge detection.
158    let sobel_x =
159        Array::from_shape_vec((3, 3), vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0])
160            .unwrap();
161    let sobel_y =
162        Array::from_shape_vec((3, 3), vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0])
163            .unwrap();
164
165    // Compute gradient magnitude for each valid (3x3) patch.
166    let gradient_mag = Array::from_shape_fn((height - 2, width - 2), |(y, x)| {
167        let view = merged_frames.slice(s![y..y + 3, x..x + 3]);
168        let mut gradient_x = 0.0;
169        let mut gradient_y = 0.0;
170        for j in 0..3 {
171            for i in 0..3 {
172                gradient_x += view[[j, i]] * sobel_x[[j, i]];
173                gradient_y += view[[j, i]] * sobel_y[[j, i]];
174            }
175        }
176        (gradient_x * gradient_x + gradient_y * gradient_y).sqrt()
177    });
178
179    // Build a raw binary classification for each column based on count of above-threshold pixels.
180    let mut raw_classification = Vec::with_capacity(width - 2);
181    for x in 0..(width - 2) {
182        let mut count = 0;
183        for y in 0..(height - 2) {
184            let grad = gradient_mag[(y, x)];
185            if y >= min_mel && grad >= min_energy {
186                count += 1;
187            }
188        }
189        raw_classification.push(count >= min_y);
190    }
191
192    // Apply temporal smoothing via a moving-window majority vote.
193    // For each index, we consider a window of neighboring columns (window size can be adjusted).
194    let smoothed_classification = smooth_mask(&raw_classification, 4);
195
196    // Split the smoothed results into active (intersected) and inactive (non-intersected) columns.
197    let mut intersected_columns = Vec::new();
198    let mut non_intersected_columns = Vec::new();
199    for (x, &active) in smoothed_classification.iter().enumerate() {
200        if active {
201            intersected_columns.push(x);
202        } else {
203            non_intersected_columns.push(x);
204        }
205    }
206
207    // We leave gradient_positions empty in this version.
208    let gradient_positions = HashSet::new();
209
210    EdgeInfo::new(
211        non_intersected_columns,
212        intersected_columns,
213        gradient_positions,
214    )
215}
216
/// Temporal smoothing of a binary mask via a moving-window majority vote.
///
/// For each index `i`, the window covers `[i - window, i + window]`,
/// clamped to the mask bounds; the smoothed value at `i` is true when at
/// least half of the values inside that window are true.
fn smooth_mask(mask: &[bool], window: usize) -> Vec<bool> {
    let len = mask.len();
    (0..len)
        .map(|i| {
            let lo = i.saturating_sub(window);
            let hi = usize::min(i + window + 1, len);
            // Majority vote without floating point: 2 * trues >= window size.
            let trues = mask[lo..hi].iter().filter(|&&v| v).count();
            trues * 2 >= hi - lo
        })
        .collect()
}
237
/// EdgeInfo is the result of Voice Activity Detection.
/// `non_intersected_columns` are good places to cut and send to
/// speech-to-text.
#[derive(Debug)]
pub struct EdgeInfo {
    non_intersected_columns: Vec<usize>,
    intersected_columns: Vec<usize>,
    gradient_positions: HashSet<(usize, usize)>,
}

impl EdgeInfo {
    /// Bundles the column classifications and gradient bitmap produced by
    /// [`vad_boundaries`].
    pub fn new(
        non_intersected_columns: Vec<usize>,
        intersected_columns: Vec<usize>,
        gradient_positions: HashSet<(usize, usize)>,
    ) -> Self {
        Self {
            gradient_positions,
            intersected_columns,
            non_intersected_columns,
        }
    }

    /// The x-index of frames that don't intersect an edge.
    pub fn non_intersected(&self) -> Vec<usize> {
        self.non_intersected_columns.to_vec()
    }

    /// The x-index of frames that intersect an edge.
    pub fn intersected(&self) -> Vec<usize> {
        self.intersected_columns.to_vec()
    }

    /// A bitmap of gradient pixels, primarily used by [`as_image`].
    pub fn gradient_positions(&self) -> HashSet<(usize, usize)> {
        self.gradient_positions.clone()
    }
}
275
/// An image of the mel spectrogram, useful for testing detection settings.
/// Edge detection is overlayed in red and boundary detection in green.
///
/// `frames`: mel frames, concatenated along the time (x) axis.
/// `non_intersected_columns`: x-indices tinted green (candidate cut points).
/// `gradient_positions`: (x, y) pixels tinted red (detected gradients).
pub fn as_image(
    frames: &[Array2<f64>],
    non_intersected_columns: &[usize],
    gradient_positions: &HashSet<(usize, usize)>,
) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
    let array_views: Vec<_> = frames.iter().map(|a| a.view()).collect();
    let array_view = concatenate(Axis(1), &array_views).unwrap();
    let shape = array_view.raw_dim();
    let width = shape[1];
    let height = shape[0];
    let mut img_buffer = ImageBuffer::new(width as u32, height as u32);

    // Scale to 0-255 grayscale relative to the spectrogram's peak value.
    // NOTE(review): if the input is all zeros, `max_val` is 0 and the
    // division below yields NaN/inf (cast to u8 as 0) — confirm inputs are
    // never fully silent.
    let max_val = array_view.fold(0.0, |acc: f64, &val| acc.max(val));
    let scaled_image: Array2<u8> = array_view.mapv(|val| (val * (255.0 / max_val)) as u8);

    // Red-channel value used to mark gradient pixels.
    let tint_value = 200;

    // Rows are iterated in reverse so low mel bins render at the bottom.
    for (y, row) in scaled_image.outer_iter().rev().enumerate() {
        for (x, &val) in row.into_iter().enumerate() {
            let mut rgb_pixel = Rgb([val, val, val]);

            if non_intersected_columns.contains(&x) {
                if y < 10 {
                    // Set the pixel to be entirely green for the top 10 rows
                    let green_tint = Rgb([0, 255, 0]);
                    rgb_pixel = green_tint;
                } else {
                    // Apply a subtle green tint to the pixel for the rest of the rows
                    let green_tint_value = 60;
                    let green_tint = Rgb([val, val.saturating_add(green_tint_value), val]);
                    rgb_pixel = green_tint;
                }
            }

            // Map the flipped row index back into gradient coordinates.
            // NOTE(review): the `+ 3` offset presumably accounts for the 3x3
            // Sobel window shrinking the gradient grid — confirm against
            // `vad_boundaries`. `checked_sub(...).unwrap_or(0)` clamps at the
            // top edge instead of panicking.
            let inverted_y = height.checked_sub(y + 3).unwrap_or(0);
            if gradient_positions.contains(&(x, inverted_y)) {
                let tint = Rgb([tint_value, 0, 0]);
                rgb_pixel = Rgb([
                    rgb_pixel[0].saturating_add(tint[0]),
                    rgb_pixel[1].saturating_add(tint[1]),
                    rgb_pixel[2].saturating_add(tint[2]),
                ]);
            }

            img_buffer.put_pixel(x as u32, y as u32, rgb_pixel);
        }
    }

    img_buffer
}
328
/// Returns the number of FFT frames needed to cover `duration_ms`
/// milliseconds of audio, rounding up so the frames fully cover the
/// duration.
///
/// `hop_size`: FFT hop size in samples.
/// `sampling_rate`: sample rate in Hz.
/// `duration_ms`: duration to cover, in milliseconds.
pub fn n_frames_for_duration(hop_size: usize, sampling_rate: f64, duration_ms: usize) -> usize {
    // Compute in f64 throughout, consistent with `duration_ms_for_n_frames`
    // (the previous version downcast to f32, losing precision for large
    // durations).
    let frame_duration_ms = hop_size as f64 / sampling_rate * 1000.0;
    (duration_ms as f64 / frame_duration_ms).ceil() as usize
}
335
/// Returns how many milliseconds `total_frames` FFT frames represent.
///
/// `hop_size`: FFT hop size in samples.
/// `sampling_rate`: sample rate in Hz.
pub fn duration_ms_for_n_frames(hop_size: usize, sampling_rate: f64, total_frames: usize) -> usize {
    // Milliseconds per frame, then scaled by the frame count and truncated.
    let ms_per_frame = hop_size as f64 / sampling_rate * 1000.0;
    let total_ms = total_frames as f64 * ms_per_frame;
    total_ms as usize
}
341
/// Formats a millisecond count as `HH:MM:SS.MS`.
pub fn format_milliseconds(milliseconds: u64) -> String {
    let ms = milliseconds % 1000;
    let secs = milliseconds / 1000;
    // Integer divmod into hours / minutes / seconds.
    let hours = secs / 3600;
    let minutes = (secs % 3600) / 60;
    let seconds = secs % 60;

    format!("{:02}:{:02}:{:02}.{:03}", hours, minutes, seconds, ms)
}
353
/// Smoke test - see the generated `./test/vad.png`.
/// green lines are the cut points predicted to not intersect speech,
/// red pixels are the detected gradients.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quant::{load_tga_8bit, to_array2};

    // Golden-file smoke test: frames under ./testdata/blank must not trigger
    // VAD while frames under ./testdata/speech must. Rendered PNGs are saved
    // so detection settings can be inspected visually.
    #[test]
    fn test_speech_detection() {
        let n_mels = 80;
        let min_x = 10;
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 10,
            min_x,
            min_mel: 0,
        };

        // Non-speech ("blank") frames: VAD must stay off.
        let ids = vec![21168, 23760, 41492, 41902, 63655, 7497, 39744];
        for id in ids {
            let file_path = format!("./testdata/blank/frame_{}.tga", id);
            let dequantized_mel = load_tga_8bit(&file_path).unwrap();
            let frames = to_array2(&dequantized_mel, n_mels);

            let edge_info = vad_boundaries(&[frames.clone()], &settings);
            let img = as_image(
                &[frames.clone()],
                &edge_info.non_intersected(),
                &edge_info.gradient_positions(),
            );

            dbg!(file_path);
            assert!(vad_on(&edge_info, min_x) == false);
            let path = format!("./testdata/vad_off_{}.png", id);
            img.save(path).unwrap();
        }

        // Speech frames: VAD must switch on.
        let ids = vec![11648, 2889, 4694, 4901, 27125];
        for id in ids {
            let file_path = format!("./testdata/speech/frame_{}.tga", id);
            let dequantized_mel = load_tga_8bit(&file_path).unwrap();
            let frames = to_array2(&dequantized_mel, n_mels);

            let edge_info = vad_boundaries(&[frames.clone()], &settings);
            let img = as_image(
                &[frames.clone()],
                &edge_info.non_intersected(),
                &edge_info.gradient_positions(),
            );

            assert!(vad_on(&edge_info, min_x) == true);
            let path = format!("./testdata/vad_on_{}.png", id);
            img.save(path).unwrap();

            //assert!(edge_info.gradient_count > 800);
        }
    }

    // Manual debug aid over the full JFK speech chunk; ignored by default
    // because it writes an image into ./doc.
    #[ignore]
    #[test]
    fn test_vad_debug() {
        let n_mels = 80;
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 6,
            min_x: 1,
            min_mel: 0,
        };

        let start = std::time::Instant::now();
        let file_path = "./testdata/jfk_full_speech_chunk0_golden.tga";
        let dequantized_mel = load_tga_8bit(file_path).unwrap();
        let frames = to_array2(&dequantized_mel, n_mels);

        let edge_info = vad_boundaries(&[frames.clone()], &settings);

        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);
        let img = as_image(
            &[frames.clone()],
            &edge_info.non_intersected(),
            &edge_info.gradient_positions(),
        );

        img.save("./doc/debug.png").unwrap();
    }

    // Renders the golden mel with overlays to ./doc/vad.png (the
    // visualisation referenced by the DetectionSettings docs).
    #[test]
    fn test_vad_boundaries() {
        let n_mels = 80;
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 3,
            min_x: 6,
            min_mel: 0,
        };

        let start = std::time::Instant::now();
        let file_path = "./testdata/quantized_mel_golden.tga";
        let dequantized_mel = load_tga_8bit(file_path).unwrap();
        dbg!(&dequantized_mel);

        let frames = to_array2(&dequantized_mel, n_mels);

        let edge_info = vad_boundaries(&[frames.clone()], &settings);

        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);
        let img = as_image(
            &[frames.clone()],
            &edge_info.non_intersected(),
            &edge_info.gradient_positions(),
        );

        img.save("./doc/vad.png").unwrap();
    }

    // Feeds the golden mel one frame at a time through the streaming
    // VoiceActivityDetector and times the run; ignored by default.
    #[ignore]
    #[test]
    fn test_stage() {
        let n_mels = 80;
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 3,
            min_x: 3,
            min_mel: 0,
        };
        let mut stage = VoiceActivityDetector::new(&settings);

        let file_path = "./testdata/quantized_mel_golden.tga";
        let dequantized_mel = load_tga_8bit(file_path).unwrap();
        let frames = to_array2(&dequantized_mel, n_mels);
        let chunk_size = 1;
        let chunks: Vec<Array2<f64>> = frames
            .axis_chunks_iter(Axis(1), chunk_size)
            .map(|chunk| chunk.to_owned())
            .collect();

        let start = std::time::Instant::now();

        for mel in &chunks {
            if let Some(_) = stage.add(&mel) {}
        }
        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);
    }
}
501}