1use image::{ImageBuffer, Rgb};
2use ndarray::{concatenate, s, Array, Array2, Axis};
3use std::collections::HashSet;
4
/// Thresholds that control the gradient-based voice-activity detection.
///
/// The values are consumed by [`vad_boundaries`] and
/// [`VoiceActivityDetector::add`].
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct DetectionSettings {
    /// Minimum Sobel gradient magnitude for a cell to count as an edge
    /// (cells are compared with `>=`).
    pub min_energy: f64,
    /// Minimum number of edge cells a column needs to be classified as active.
    pub min_y: usize,
    /// Number of most recent frames [`VoiceActivityDetector::add`] analyses
    /// on each call.
    pub min_x: usize,
    /// Lowest gradient-map row index that is considered; rows with a smaller
    /// index are ignored.
    pub min_mel: usize,
}

impl DetectionSettings {
    /// Creates settings from the individual thresholds.
    pub fn new(min_energy: f64, min_y: usize, min_x: usize, min_mel: usize) -> Self {
        Self {
            min_energy,
            min_y,
            min_x,
            min_mel,
        }
    }

    /// Returns the minimum gradient magnitude for a cell to count as an edge.
    pub fn min_energy(&self) -> f64 {
        self.min_energy
    }

    /// Returns the minimum number of edge cells per active column.
    pub fn min_y(&self) -> usize {
        self.min_y
    }

    /// Returns the number of frames analysed per streaming step.
    pub fn min_x(&self) -> usize {
        self.min_x
    }

    /// Returns the lowest gradient-map row index that is considered.
    pub fn min_mel(&self) -> usize {
        self.min_mel
    }
}
75
76pub struct VoiceActivityDetector {
77 mel_buffer: Vec<Array2<f64>>,
78 settings: DetectionSettings,
79 idx: usize,
80}
81
82impl VoiceActivityDetector {
83 pub fn new(settings: &DetectionSettings) -> Self {
84 let mel_buffer: Vec<Array2<f64>> = Vec::new();
85
86 Self {
87 mel_buffer,
88 settings: settings.to_owned(),
89 idx: 0,
90 }
91 }
92
93 pub fn add(&mut self, frame: &Array2<f64>) -> Option<bool> {
95 let min_x = self.settings.min_x;
96 if self.idx == 128 {
97 self.mel_buffer = self.mel_buffer[(self.mel_buffer.len() - min_x)..].to_vec();
98 self.idx = min_x;
99 }
100 self.mel_buffer.push(frame.to_owned());
101 self.idx += 1;
102 if self.idx < min_x {
103 return None;
104 }
105
106 let window = &self.mel_buffer[self.idx - min_x..];
108 let edge_info = vad_boundaries(&window, &self.settings);
109 let ni = edge_info.intersected();
110 if ni.is_empty() {
111 Some(false)
112 } else {
113 Some(ni[0] == 0)
114 }
115 }
116}
117
118pub fn vad_on(edge_info: &EdgeInfo, n: usize) -> bool {
119 let intersected_columns = &edge_info.intersected_columns;
120
121 if intersected_columns.is_empty() {
122 return false;
123 }
124
125 let mut contiguous_count = 1;
126 let mut prev_index = intersected_columns[0];
127
128 for &index in &intersected_columns[1..] {
129 if index == prev_index + 1 {
130 contiguous_count += 1;
131 } else {
132 contiguous_count = 1;
133 }
134
135 if contiguous_count >= n {
136 return true;
137 }
138
139 prev_index = index;
140 }
141
142 false
143}
144
145pub fn vad_boundaries(frames: &[Array2<f64>], settings: &DetectionSettings) -> EdgeInfo {
146 let array_views: Vec<_> = frames.iter().map(|a| a.view()).collect();
147 let min_energy = settings.min_energy;
148 let min_y = settings.min_y;
149 let min_mel = settings.min_mel;
150
151 let merged_frames = concatenate(Axis(1), &array_views).unwrap();
153 let shape = merged_frames.raw_dim();
154 let width = shape[1];
155 let height = shape[0];
156
157 let sobel_x =
159 Array::from_shape_vec((3, 3), vec![-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0])
160 .unwrap();
161 let sobel_y =
162 Array::from_shape_vec((3, 3), vec![-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0])
163 .unwrap();
164
165 let gradient_mag = Array::from_shape_fn((height - 2, width - 2), |(y, x)| {
167 let view = merged_frames.slice(s![y..y + 3, x..x + 3]);
168 let mut gradient_x = 0.0;
169 let mut gradient_y = 0.0;
170 for j in 0..3 {
171 for i in 0..3 {
172 gradient_x += view[[j, i]] * sobel_x[[j, i]];
173 gradient_y += view[[j, i]] * sobel_y[[j, i]];
174 }
175 }
176 (gradient_x * gradient_x + gradient_y * gradient_y).sqrt()
177 });
178
179 let mut raw_classification = Vec::with_capacity(width - 2);
181 for x in 0..(width - 2) {
182 let mut count = 0;
183 for y in 0..(height - 2) {
184 let grad = gradient_mag[(y, x)];
185 if y >= min_mel && grad >= min_energy {
186 count += 1;
187 }
188 }
189 raw_classification.push(count >= min_y);
190 }
191
192 let smoothed_classification = smooth_mask(&raw_classification, 4);
195
196 let mut intersected_columns = Vec::new();
198 let mut non_intersected_columns = Vec::new();
199 for (x, &active) in smoothed_classification.iter().enumerate() {
200 if active {
201 intersected_columns.push(x);
202 } else {
203 non_intersected_columns.push(x);
204 }
205 }
206
207 let gradient_positions = HashSet::new();
209
210 EdgeInfo::new(
211 non_intersected_columns,
212 intersected_columns,
213 gradient_positions,
214 )
215}
216
/// Smooths a boolean mask with a sliding majority vote.
///
/// Each output element is `true` when at least half of the input elements
/// within `window` positions on either side (clipped to the mask bounds) are
/// `true`. Ties count as `true`.
fn smooth_mask(mask: &[bool], window: usize) -> Vec<bool> {
    let len = mask.len();

    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(window);
            let hi = (center + window + 1).min(len);
            let neighbourhood = &mask[lo..hi];
            let active = neighbourhood.iter().filter(|&&flag| flag).count();
            active * 2 >= neighbourhood.len()
        })
        .collect()
}
237
/// Result of a column classification pass.
///
/// Stores the column indices classified as inactive and active, plus a set
/// of `(usize, usize)` positions.
#[derive(Debug)]
pub struct EdgeInfo {
    /// Indices of columns classified as inactive.
    non_intersected_columns: Vec<usize>,
    /// Indices of columns classified as active.
    intersected_columns: Vec<usize>,
    /// Coordinate pairs; `as_image` looks them up as `(x, y)` when tinting pixels.
    gradient_positions: HashSet<(usize, usize)>,
}

impl EdgeInfo {
    /// Bundles the three collections into an `EdgeInfo`.
    pub fn new(
        non_intersected_columns: Vec<usize>,
        intersected_columns: Vec<usize>,
        gradient_positions: HashSet<(usize, usize)>,
    ) -> Self {
        Self {
            non_intersected_columns,
            intersected_columns,
            gradient_positions,
        }
    }

    /// Returns a copy of the inactive column indices.
    pub fn non_intersected(&self) -> Vec<usize> {
        self.non_intersected_columns.to_vec()
    }

    /// Returns a copy of the active column indices.
    pub fn intersected(&self) -> Vec<usize> {
        self.intersected_columns.to_vec()
    }

    /// Returns a copy of the stored positions.
    pub fn gradient_positions(&self) -> HashSet<(usize, usize)> {
        HashSet::clone(&self.gradient_positions)
    }
}
275
276pub fn as_image(
279 frames: &[Array2<f64>],
280 non_intersected_columns: &[usize],
281 gradient_positions: &HashSet<(usize, usize)>,
282) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
283 let array_views: Vec<_> = frames.iter().map(|a| a.view()).collect();
284 let array_view = concatenate(Axis(1), &array_views).unwrap();
285 let shape = array_view.raw_dim();
286 let width = shape[1];
287 let height = shape[0];
288 let mut img_buffer = ImageBuffer::new(width as u32, height as u32);
289
290 let max_val = array_view.fold(0.0, |acc: f64, &val| acc.max(val));
291 let scaled_image: Array2<u8> = array_view.mapv(|val| (val * (255.0 / max_val)) as u8);
292
293 let tint_value = 200;
294
295 for (y, row) in scaled_image.outer_iter().rev().enumerate() {
296 for (x, &val) in row.into_iter().enumerate() {
297 let mut rgb_pixel = Rgb([val, val, val]);
298
299 if non_intersected_columns.contains(&x) {
300 if y < 10 {
301 let green_tint = Rgb([0, 255, 0]);
303 rgb_pixel = green_tint;
304 } else {
305 let green_tint_value = 60;
307 let green_tint = Rgb([val, val.saturating_add(green_tint_value), val]);
308 rgb_pixel = green_tint;
309 }
310 }
311
312 let inverted_y = height.checked_sub(y + 3).unwrap_or(0);
313 if gradient_positions.contains(&(x, inverted_y)) {
314 let tint = Rgb([tint_value, 0, 0]);
315 rgb_pixel = Rgb([
316 rgb_pixel[0].saturating_add(tint[0]),
317 rgb_pixel[1].saturating_add(tint[1]),
318 rgb_pixel[2].saturating_add(tint[2]),
319 ]);
320 }
321
322 img_buffer.put_pixel(x as u32, y as u32, rgb_pixel);
323 }
324 }
325
326 img_buffer
327}
328
/// Returns the number of frames needed to cover `duration_ms` milliseconds,
/// rounded up.
///
/// One frame spans `hop_size` samples at `sampling_rate` samples per second.
/// The calculation is done in `f64`, multiplying before dividing, so that
/// durations which are an exact multiple of the frame length are not rounded
/// up by an extra frame, and large durations keep full precision.
///
/// A non-finite quotient (for example when `hop_size` is 0) goes through
/// Rust's saturating float-to-integer cast.
pub fn n_frames_for_duration(hop_size: usize, sampling_rate: f64, duration_ms: usize) -> usize {
    // frames = duration_ms / (hop_size / sampling_rate * 1000)
    //        = duration_ms * sampling_rate / (hop_size * 1000)
    let numerator = duration_ms as f64 * sampling_rate;
    let denominator = hop_size as f64 * 1000.0;
    (numerator / denominator).ceil() as usize
}
335
/// Returns the duration in whole milliseconds covered by `total_frames`
/// frames, truncating any fractional millisecond.
///
/// One frame spans `hop_size` samples at `sampling_rate` samples per second.
pub fn duration_ms_for_n_frames(hop_size: usize, sampling_rate: f64, total_frames: usize) -> usize {
    let ms_per_frame = hop_size as f64 / sampling_rate * 1000.0;
    let total_ms = total_frames as f64 * ms_per_frame;
    total_ms as usize
}
341
/// Formats a millisecond count as `HH:MM:SS.mmm`.
///
/// Hours are not wrapped, so values of 100 hours or more use more than two
/// hour digits.
pub fn format_milliseconds(milliseconds: u64) -> String {
    const MS_PER_SECOND: u64 = 1000;
    const SECONDS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;

    let (whole_seconds, ms) = (milliseconds / MS_PER_SECOND, milliseconds % MS_PER_SECOND);
    let (whole_minutes, seconds) = (
        whole_seconds / SECONDS_PER_MINUTE,
        whole_seconds % SECONDS_PER_MINUTE,
    );
    let (hours, minutes) = (
        whole_minutes / MINUTES_PER_HOUR,
        whole_minutes % MINUTES_PER_HOUR,
    );

    format!("{:02}:{:02}:{:02}.{:03}", hours, minutes, seconds, ms)
}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quant::{load_tga_8bit, to_array2};

    /// Loads a TGA fixture and converts it with `to_array2`, passing 80 as
    /// the mel-band count (the value every test in this module uses).
    fn load_frames(file_path: &str) -> Array2<f64> {
        let n_mels = 80;
        let dequantized_mel = load_tga_8bit(file_path).unwrap();
        to_array2(&dequantized_mel, n_mels)
    }

    /// Renders the debug overlay for `frames` using the columns and positions
    /// stored in `edge_info`.
    fn render(frames: &Array2<f64>, edge_info: &EdgeInfo) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
        as_image(
            std::slice::from_ref(frames),
            &edge_info.non_intersected(),
            &edge_info.gradient_positions(),
        )
    }

    /// Runs detection on one fixture, asserts the expected `vad_on` decision
    /// and, if it holds, writes the debug overlay to `output_path`.
    fn check_detection(
        input_path: &str,
        output_path: &str,
        settings: &DetectionSettings,
        expect_speech: bool,
    ) {
        let frames = load_frames(input_path);
        let edge_info = vad_boundaries(std::slice::from_ref(&frames), settings);
        let img = render(&frames, &edge_info);

        assert_eq!(
            vad_on(&edge_info, settings.min_x),
            expect_speech,
            "unexpected VAD decision for {}",
            input_path
        );
        img.save(output_path).unwrap();
    }

    #[test]
    fn test_speech_detection() {
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 10,
            min_x: 10,
            min_mel: 0,
        };

        // Fixtures without speech must not trigger the detector.
        for id in &[21168, 23760, 41492, 41902, 63655, 7497, 39744] {
            let file_path = format!("./testdata/blank/frame_{}.tga", id);
            let path = format!("./testdata/vad_off_{}.png", id);
            check_detection(&file_path, &path, &settings, false);
        }

        // Fixtures with speech must trigger the detector.
        for id in &[11648, 2889, 4694, 4901, 27125] {
            let file_path = format!("./testdata/speech/frame_{}.tga", id);
            let path = format!("./testdata/vad_on_{}.png", id);
            check_detection(&file_path, &path, &settings, true);
        }
    }

    #[ignore]
    #[test]
    fn test_vad_debug() {
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 6,
            min_x: 1,
            min_mel: 0,
        };

        // Times fixture loading plus edge detection.
        let start = std::time::Instant::now();
        let frames = load_frames("./testdata/jfk_full_speech_chunk0_golden.tga");
        let edge_info = vad_boundaries(std::slice::from_ref(&frames), &settings);
        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);

        let img = render(&frames, &edge_info);
        img.save("./doc/debug.png").unwrap();
    }

    #[test]
    fn test_vad_boundaries() {
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 3,
            min_x: 6,
            min_mel: 0,
        };

        // Times fixture loading plus edge detection.
        let start = std::time::Instant::now();
        let frames = load_frames("./testdata/quantized_mel_golden.tga");
        let edge_info = vad_boundaries(std::slice::from_ref(&frames), &settings);
        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);

        let img = render(&frames, &edge_info);
        img.save("./doc/vad.png").unwrap();
    }

    #[ignore]
    #[test]
    fn test_stage() {
        let settings = DetectionSettings {
            min_energy: 1.0,
            min_y: 3,
            min_x: 3,
            min_mel: 0,
        };
        let mut stage = VoiceActivityDetector::new(&settings);

        // Feed the fixture to the streaming detector one column at a time.
        let frames = load_frames("./testdata/quantized_mel_golden.tga");
        let chunk_size = 1;
        let chunks: Vec<Array2<f64>> = frames
            .axis_chunks_iter(Axis(1), chunk_size)
            .map(|chunk| chunk.to_owned())
            .collect();

        let start = std::time::Instant::now();
        for mel in &chunks {
            // Only the running time matters here; the decision is discarded.
            let _ = stage.add(mel);
        }
        let elapsed = start.elapsed().as_millis();
        dbg!(elapsed);
    }
}