1use super::{VADOption, VadEngine};
2use crate::media::{AudioFrame, Samples};
3use anyhow::Result;
4use realfft::{RealFftPlanner, RealToComplex};
5use std::sync::Arc;
6
// Sample rate in Hz. Used to map mel band edges onto FFT bins and to convert
// processed sample counts into millisecond timestamps.
const SAMPLE_RATE: u32 = 16000;
// HOP_SIZE: samples dropped from the front of the buffer after every decision.
// FFT_SIZE: length of the real FFT; the power spectrum has FFT_SIZE / 2 + 1 bins.
const HOP_SIZE: usize = 256; const FFT_SIZE: usize = 1024;
// Samples per analysis window. Smaller than FFT_SIZE, so the rest of the FFT
// input stays zero.
const WINDOW_SIZE: usize = 768;
// Number of triangular mel filters (rows of the filter bank).
const MEL_FILTER_BANK_NUM: usize = 40;
// FEATURE_LEN: MEL_FILTER_BANK_NUM log-mel values plus one trailing pitch slot.
// CONTEXT_WINDOW_LEN: number of consecutive feature vectors fed to the network.
const FEATURE_LEN: usize = 41; const CONTEXT_WINDOW_LEN: usize = 3;
// Hidden/cell state width of both LSTM layers.
const HIDDEN_SIZE: usize = 64;
// Small positive offset added before `ln` and before inverting the std-devs.
const EPS: f32 = 1e-20;
// Coefficient of the first-order pre-emphasis filter applied to incoming samples.
const PRE_EMPHASIS_COEFF: f32 = 0.97;
18
// Per-feature offsets subtracted in `TenFeatureExtractor::extract_features`
// before scaling. Entries 0..MEL_FILTER_BANK_NUM line up with the log-mel bands;
// the last entry lines up with the pitch slot.
// NOTE(review): presumably statistics from the model's training data — confirm.
const FEATURE_MEANS: [f32; FEATURE_LEN] = [
    -8.198_236,
    -6.265_716_6,
    -5.483_818_5,
    -4.758_691_3,
    -4.417_089,
    -4.142_893,
    -3.912_850_4,
    -3.845_928,
    -3.657_090_4,
    -3.723_418_7,
    -3.876_134_2,
    -3.843_891,
    -3.690_405_1,
    -3.756_065_8,
    -3.698_696_1,
    -3.650_463,
    -3.700_468_8,
    -3.567_321_3,
    -3.498_900_2,
    -3.477_807,
    -3.458_816,
    -3.444_923_9,
    -3.401_328_6,
    -3.306_261_3,
    -3.278_556_8,
    -3.233_250_9,
    -3.198_616,
    -3.204_526_4,
    -3.208_798_6,
    -3.257_838,
    -3.381_376_7,
    -3.534_021_4,
    -3.640_868,
    -3.726_858_9,
    -3.773_731,
    -3.804_667_2,
    -3.832_901,
    -3.871_120_5,
    -3.990_593,
    -4.480_289_5,
    9.235_69e1,
];
63
// Per-feature scales. `TenFeatureExtractor::new` precomputes `1 / (std + EPS)`
// from this table, and `extract_features` multiplies the mean-subtracted
// features by it. The layout matches `FEATURE_MEANS`: mel bands first, pitch
// slot last.
// NOTE(review): presumably statistics from the model's training data — confirm.
const FEATURE_STDS: [f32; FEATURE_LEN] = [
    5.166_064,
    4.977_21,
    4.698_896,
    4.630_621_4,
    4.634_348,
    4.641_156,
    4.640_676_5,
    4.666_367,
    4.650_534_6,
    4.640_021,
    4.637_4,
    4.620_099,
    4.596_316_3,
    4.562_655,
    4.554_36,
    4.566_910_7,
    4.562_49,
    4.562_413,
    4.585_299_5,
    4.600_179_7,
    4.592_846,
    4.585_923,
    4.583_496_6,
    4.626_093,
    4.626_958,
    4.626_289_4,
    4.637_006,
    4.683_016,
    4.726_814,
    4.734_29,
    4.753_227,
    4.849_723,
    4.869_435,
    4.884_483,
    4.921_327,
    4.959_212_3,
    4.996_619,
    5.044_823_6,
    5.072_217,
    5.096_439_4,
    1.152_136_9e2,
];
107
108pub struct TenFeatureExtractor {
109 mel_filters: ndarray::Array2<f32>,
110 mel_filter_ranges: Vec<(usize, usize)>,
111 window: Vec<f32>,
112 rfft: Arc<dyn RealToComplex<f32>>,
114 fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
115 fft_output: Vec<realfft::num_complex::Complex<f32>>,
116 fft_input: Vec<f32>,
117 power_spectrum: Vec<f32>,
118 inv_stds: Vec<f32>,
119}
120
121impl TenFeatureExtractor {
122 pub fn new() -> Self {
123 let (mel_filters, mel_filter_ranges) = Self::generate_mel_filters();
125
126 let window = super::utils::generate_hann_window(WINDOW_SIZE, false);
128
129 let mut planner = RealFftPlanner::<f32>::new();
131 let rfft = planner.plan_fft_forward(FFT_SIZE);
132 let fft_scratch = rfft.make_scratch_vec();
133 let fft_output = rfft.make_output_vec();
134 let fft_input = rfft.make_input_vec();
135 let power_spectrum = vec![0.0; FFT_SIZE / 2 + 1];
136
137 let inv_stds: Vec<f32> = FEATURE_STDS.iter().map(|&std| 1.0 / (std + EPS)).collect();
139
140 Self {
141 mel_filters,
142 mel_filter_ranges,
143 window,
144 rfft,
145 fft_scratch,
146 fft_output,
147 fft_input,
148 power_spectrum,
149 inv_stds,
150 }
151 }
152
153 fn generate_mel_filters() -> (ndarray::Array2<f32>, Vec<(usize, usize)>) {
154 let n_bins = FFT_SIZE / 2 + 1;
155
156 let low_mel = 2595.0_f32 * (1.0_f32 + 0.0_f32 / 700.0_f32).log10();
158 let high_mel = 2595.0_f32 * (1.0_f32 + 8000.0_f32 / 700.0_f32).log10();
159
160 let mut mel_points = Vec::new();
161 for i in 0..=MEL_FILTER_BANK_NUM + 1 {
162 let mel = low_mel + (high_mel - low_mel) * i as f32 / (MEL_FILTER_BANK_NUM + 1) as f32;
163 mel_points.push(mel);
164 }
165
166 let mut hz_points = Vec::new();
168 for mel in mel_points {
169 let hz = 700.0_f32 * (10.0_f32.powf(mel / 2595.0_f32) - 1.0_f32);
170 hz_points.push(hz);
171 }
172
173 let mut bin_points = Vec::new();
175 for hz in hz_points {
176 let bin = ((FFT_SIZE + 1) as f32 * hz / SAMPLE_RATE as f32).floor() as usize;
177 bin_points.push(bin);
178 }
179
180 let mut mel_filters = ndarray::Array2::<f32>::zeros((MEL_FILTER_BANK_NUM, n_bins));
182 let mut ranges = Vec::with_capacity(MEL_FILTER_BANK_NUM);
183
184 for i in 0..MEL_FILTER_BANK_NUM {
185 let start = bin_points[i];
186 let end = bin_points[i + 2];
187 ranges.push((start, end));
188
189 for j in bin_points[i]..bin_points[i + 1] {
191 if j < n_bins {
192 mel_filters[[i, j]] =
193 (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32;
194 }
195 }
196
197 for j in bin_points[i + 1]..bin_points[i + 2] {
199 if j < n_bins {
200 mel_filters[[i, j]] = (bin_points[i + 2] - j) as f32
201 / (bin_points[i + 2] - bin_points[i + 1]) as f32;
202 }
203 }
204 }
205
206 (mel_filters, ranges)
207 }
208
209 pub fn extract_features(&mut self, audio_frame: &[f32]) -> ndarray::Array1<f32> {
210 self.fft_input.fill(0.0);
213
214 let copy_len = audio_frame.len().min(WINDOW_SIZE);
216 self.fft_input[..copy_len].copy_from_slice(audio_frame);
217
218 for (i, sample) in self.fft_input.iter_mut().enumerate().take(copy_len) {
220 *sample *= self.window[i];
221 }
222
223 self.rfft
225 .process_with_scratch(
226 &mut self.fft_input,
227 &mut self.fft_output,
228 &mut self.fft_scratch,
229 )
230 .unwrap();
231
232 let n_bins = FFT_SIZE / 2 + 1;
234 let scale = 1.0 / FFT_SIZE as f32;
235
236 for (pow, complex) in self.power_spectrum.iter_mut().zip(self.fft_output.iter()) {
239 *pow = (complex.re * complex.re + complex.im * complex.im) * scale;
240 }
241
242 let mut mel_features = ndarray::Array1::<f32>::zeros(MEL_FILTER_BANK_NUM);
244
245 for i in 0..MEL_FILTER_BANK_NUM {
246 let (start, end) = self.mel_filter_ranges[i];
247 let valid_end = end.min(n_bins);
248
249 let mut sum = 0.0;
250 if start < valid_end {
251 let filter_row = self.mel_filters.row(i);
253 if let Some(filter_slice) = filter_row.as_slice() {
256 let filter_sub = &filter_slice[start..valid_end];
257 let power_sub = &self.power_spectrum[start..valid_end];
258
259 sum = super::simd::dot_product(filter_sub, power_sub);
261 } else {
262 for j in start..valid_end {
264 sum += self.mel_filters[[i, j]] * self.power_spectrum[j];
265 }
266 }
267 }
268 mel_features[i] = (sum + EPS).ln();
269 }
270
271 let pitch_freq = 0.0;
273
274 let mut features = ndarray::Array1::<f32>::zeros(FEATURE_LEN);
276 features
277 .slice_mut(ndarray::s![..MEL_FILTER_BANK_NUM])
278 .assign(&mel_features);
279 features[MEL_FILTER_BANK_NUM] = pitch_freq;
280
281 for (feat, (&mean, &inv_std)) in features
284 .iter_mut()
285 .zip(FEATURE_MEANS.iter().zip(self.inv_stds.iter()))
286 {
287 *feat = (*feat - mean) * inv_std;
288 }
289
290 features
291 }
292}
293
/// Dense `f32` tensor addressed as `(row, column, channel)` with the channel
/// index varying fastest in memory.
#[derive(Clone, Debug)]
struct Tensor3D {
    /// Flat storage of `h * w * c` values.
    data: Vec<f32>,
    /// Number of rows.
    h: usize,
    /// Number of columns.
    w: usize,
    /// Number of channels.
    c: usize,
}

impl Tensor3D {
    /// Allocates a zero-filled tensor with the given dimensions.
    fn new(h: usize, w: usize, c: usize) -> Self {
        let element_count = h * w * c;
        Tensor3D {
            data: vec![0.0; element_count],
            h,
            w,
            c,
        }
    }

    /// Resets every element to `0.0` without changing the shape.
    fn zeros(&mut self) {
        for value in self.data.iter_mut() {
            *value = 0.0;
        }
    }

    /// Flat index of element `(y, x, ch)`.
    #[inline(always)]
    fn offset(&self, y: usize, x: usize, ch: usize) -> usize {
        (y * self.w + x) * self.c + ch
    }

    /// Reads element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline(always)]
    fn get(&self, y: usize, x: usize, ch: usize) -> f32 {
        let at = self.offset(y, x, ch);
        self.data[at]
    }

    /// Writes `val` to element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline(always)]
    fn set(&mut self, y: usize, x: usize, ch: usize, val: f32) {
        let at = self.offset(y, x, ch);
        self.data[at] = val;
    }
}
328
329struct Conv2dLayer {
331 weights: Vec<f32>, bias: Option<Vec<f32>>, in_channels: usize,
334 out_channels: usize,
335 kernel_h: usize,
336 kernel_w: usize,
337 stride_h: usize,
338 stride_w: usize,
339 padding: [usize; 4], groups: usize,
341}
342
343impl Conv2dLayer {
344 fn new(
345 in_channels: usize,
346 out_channels: usize,
347 kernel_h: usize,
348 kernel_w: usize,
349 stride_h: usize,
350 stride_w: usize,
351 padding: [usize; 4],
352 groups: usize,
353 ) -> Self {
354 Self {
355 weights: vec![0.0; out_channels * (in_channels / groups) * kernel_h * kernel_w],
356 bias: None,
357 in_channels,
358 out_channels,
359 kernel_h,
360 kernel_w,
361 stride_h,
362 stride_w,
363 padding,
364 groups,
365 }
366 }
367
368 fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
370 let out_h = output.h;
371 let out_w = output.w;
372
373 if self.in_channels == 1
376 && self.out_channels == 1
377 && self.kernel_h == 3
378 && self.kernel_w == 3
379 && self.stride_h == 1
380 && self.stride_w == 1
381 && self.padding == [0, 0, 0, 0]
382 {
383 let bias = self.bias.as_ref().map(|b| b[0]).unwrap_or(0.0);
384 let w = &self.weights; for x in 0..out_w {
389 let mut sum = bias;
392
393 sum += input.get(0, x, 0) * w[0];
396 sum += input.get(0, x + 1, 0) * w[1];
397 sum += input.get(0, x + 2, 0) * w[2];
398
399 sum += input.get(1, x, 0) * w[3];
401 sum += input.get(1, x + 1, 0) * w[4];
402 sum += input.get(1, x + 2, 0) * w[5];
403
404 sum += input.get(2, x, 0) * w[6];
406 sum += input.get(2, x + 1, 0) * w[7];
407 sum += input.get(2, x + 2, 0) * w[8];
408
409 output.set(0, x, 0, sum);
410 }
411 return;
412 }
413
414 if self.in_channels == 1
417 && self.out_channels == 16
418 && self.kernel_h == 1
419 && self.kernel_w == 1
420 && self.stride_h == 1
421 && self.stride_w == 1
422 {
423 let w = &self.weights; let b = self.bias.as_ref(); for x in 0..out_w {
427 let val = input.get(0, x, 0);
428
429 for oc in 0..16 {
431 let bias = if let Some(bias_vec) = b {
432 bias_vec[oc]
433 } else {
434 0.0
435 };
436 let res = val * w[oc] + bias;
437 output.set(0, x, oc, res);
438 }
439 }
440 return;
441 }
442
443 if self.groups == 16
446 && self.in_channels == 16
447 && self.out_channels == 16
448 && self.kernel_h == 1
449 && self.kernel_w == 3
450 && self.stride_w == 2
451 && self.padding == [0, 1, 0, 1]
452 {
453 let w = &self.weights; let b = self.bias.as_ref();
455
456 for c in 0..16 {
457 let w_offset = c * 3;
458 let w0 = w[w_offset];
459 let w1 = w[w_offset + 1];
460 let w2 = w[w_offset + 2];
461 let bias = if let Some(bias_vec) = b {
462 bias_vec[c]
463 } else {
464 0.0
465 };
466
467 let val0 = input.get(0, 0, c);
469 let val1 = input.get(0, 1, c);
470 let sum0 = val0 * w1 + val1 * w2 + bias;
471 output.set(0, 0, c, sum0);
472
473 for x in 1..9 {
480 let in_x_origin = x * 2 - 1;
481 let v0 = input.get(0, in_x_origin, c);
482 let v1 = input.get(0, in_x_origin + 1, c);
483 let v2 = input.get(0, in_x_origin + 2, c);
484 let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
485 output.set(0, x, c, sum);
486 }
487
488 let v0 = input.get(0, 17, c);
490 let v1 = input.get(0, 18, c);
491 let sum9 = v0 * w0 + v1 * w1 + bias;
492 output.set(0, 9, c, sum9);
493 }
494 return;
495 }
496
497 if self.in_channels == 16
500 && self.out_channels == 16
501 && self.kernel_h == 1
502 && self.kernel_w == 1
503 && self.stride_h == 1
504 && self.stride_w == 1
505 && self.groups == 1
506 {
507 let w = &self.weights; let b = self.bias.as_ref();
509
510 for x in 0..out_w {
511 let mut in_vals = [0.0; 16];
513 for ic in 0..16 {
514 in_vals[ic] = input.get(0, x, ic);
515 }
516
517 for oc in 0..16 {
518 let mut sum = if let Some(bias_vec) = b {
519 bias_vec[oc]
520 } else {
521 0.0
522 };
523 let w_offset = oc * 16;
524
525 for ic in 0..16 {
527 sum += in_vals[ic] * w[w_offset + ic];
528 }
529 output.set(0, x, oc, sum);
530 }
531 }
532 return;
533 }
534
535 if self.groups == 16
538 && self.in_channels == 16
539 && self.out_channels == 16
540 && self.kernel_h == 1
541 && self.kernel_w == 3
542 && self.stride_w == 2
543 && self.padding == [0, 1, 0, 1]
544 && out_w == 5
545 {
546 let w = &self.weights;
547 let b = self.bias.as_ref();
548
549 for c in 0..16 {
550 let w_offset = c * 3;
551 let w0 = w[w_offset];
552 let w1 = w[w_offset + 1];
553 let w2 = w[w_offset + 2];
554 let bias = if let Some(bias_vec) = b {
555 bias_vec[c]
556 } else {
557 0.0
558 };
559
560 let val0 = input.get(0, 0, c);
562 let val1 = input.get(0, 1, c);
563 let sum0 = val0 * w1 + val1 * w2 + bias;
564 output.set(0, 0, c, sum0);
565
566 for x in 1..5 {
569 let in_x_origin = x * 2 - 1;
570 let v0 = input.get(0, in_x_origin, c);
571 let v1 = input.get(0, in_x_origin + 1, c);
572 let v2 = input.get(0, in_x_origin + 2, c);
573 let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
574 output.set(0, x, c, sum);
575 }
576 }
577 return;
578 }
579
580 if self.in_channels == 16
583 && self.out_channels == 32
584 && self.kernel_h == 1
585 && self.kernel_w == 1
586 && self.stride_h == 1
587 && self.stride_w == 1
588 && self.groups == 1
589 {
590 let w = &self.weights; let b = self.bias.as_ref();
592
593 for x in 0..out_w {
594 let mut in_vals = [0.0; 16];
595 for ic in 0..16 {
596 in_vals[ic] = input.get(0, x, ic);
597 }
598
599 for oc in 0..32 {
600 let mut sum = if let Some(bias_vec) = b {
601 bias_vec[oc]
602 } else {
603 0.0
604 };
605 let w_offset = oc * 16;
606
607 for ic in 0..16 {
608 sum += in_vals[ic] * w[w_offset + ic];
609 }
610 output.set(0, x, oc, sum);
611 }
612 }
613 return;
614 }
615
616 output.zeros();
618
619 let in_c_per_group = self.in_channels / self.groups;
620 let out_c_per_group = self.out_channels / self.groups;
621
622 if let Some(b) = &self.bias {
627 for g in 0..self.groups {
628 for oc in 0..out_c_per_group {
629 let out_ch_idx = g * out_c_per_group + oc;
630 let bias_val = b[out_ch_idx];
631 for y in 0..out_h {
633 for x in 0..out_w {
634 output.set(y, x, out_ch_idx, bias_val);
635 }
636 }
637 }
638 }
639 }
640
641 for g in 0..self.groups {
642 for oc in 0..out_c_per_group {
643 let out_ch_idx = g * out_c_per_group + oc;
644
645 let w_base = out_ch_idx * (in_c_per_group * self.kernel_h * self.kernel_w);
647
648 for y in 0..out_h {
649 let in_y_origin = (y * self.stride_h) as isize - self.padding[0] as isize;
650
651 for x in 0..out_w {
652 let in_x_origin = (x * self.stride_w) as isize - self.padding[1] as isize;
653
654 let mut sum = 0.0;
655
656 for ic in 0..in_c_per_group {
657 let in_ch_idx = g * in_c_per_group + ic;
658 let w_ic_base = w_base + ic * (self.kernel_h * self.kernel_w);
659
660 for ky in 0..self.kernel_h {
661 let in_y = in_y_origin + ky as isize;
662 if in_y >= 0 && in_y < input.h as isize {
663 let w_ky_base = w_ic_base + ky * self.kernel_w;
664
665 for kx in 0..self.kernel_w {
666 let in_x = in_x_origin + kx as isize;
667
668 if in_x >= 0 && in_x < input.w as isize {
669 let val =
671 input.get(in_y as usize, in_x as usize, in_ch_idx);
672 let w_idx = w_ky_base + kx;
673 let w = unsafe { *self.weights.get_unchecked(w_idx) };
675 sum += val * w;
676 }
677 }
678 }
679 }
680 }
681
682 let current = output.get(y, x, out_ch_idx);
684 output.set(y, x, out_ch_idx, current + sum);
685 }
686 }
687 }
688 }
689 }
690}
691
692struct MaxPool2dLayer {
694 kernel_h: usize,
695 kernel_w: usize,
696 stride_h: usize,
697 stride_w: usize,
698}
699
700impl MaxPool2dLayer {
701 fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
702 let out_h = output.h;
703 let out_w = output.w;
704
705 if self.kernel_h == 1 && self.kernel_w == 3 && self.stride_h == 1 && self.stride_w == 2 {
707 for c in 0..input.c {
708 for x in 0..out_w {
710 let in_x = x * 2;
711 let v0 = input.get(0, in_x, c);
713 let v1 = input.get(0, in_x + 1, c);
714 let v2 = input.get(0, in_x + 2, c);
715
716 let max_v = v0.max(v1).max(v2);
717 output.set(0, x, c, max_v);
718 }
719 }
720 return;
721 }
722
723 for c in 0..input.c {
724 for y in 0..out_h {
725 for x in 0..out_w {
726 let mut max_val = f32::NEG_INFINITY;
727
728 for ky in 0..self.kernel_h {
729 for kx in 0..self.kernel_w {
730 let in_y = y * self.stride_h + ky;
731 let in_x = x * self.stride_w + kx;
732 let val = input.get(in_y, in_x, c);
735 if val > max_val {
736 max_val = val;
737 }
738 }
739 }
740 output.set(y, x, c, max_val);
741 }
742 }
743 }
744 }
745}
746
/// Fully connected layer computing `output = W * input + bias`.
struct LinearLayer {
    /// Row-major weight matrix: `out_features` rows of `in_features` values.
    weights: Vec<f32>,
    /// One bias value per output feature.
    bias: Vec<f32>,
    in_features: usize,
    out_features: usize,
}

impl LinearLayer {
    /// Creates a layer whose weights and biases are all zero.
    fn new(in_features: usize, out_features: usize) -> Self {
        let weight_count = out_features * in_features;
        LinearLayer {
            weights: vec![0.0; weight_count],
            bias: vec![0.0; out_features],
            in_features,
            out_features,
        }
    }

    /// Writes the affine transform of `input` into `output`.
    ///
    /// # Panics
    ///
    /// Panics if `input` or `output` does not match the layer dimensions.
    fn forward(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.in_features);
        assert_eq!(output.len(), self.out_features);

        let width = self.in_features;
        for row in 0..self.out_features {
            let begin = row * width;
            let row_weights = &self.weights[begin..begin + width];

            let acc = row_weights
                .iter()
                .zip(input)
                .map(|(&w, &x)| w * x)
                .sum::<f32>();

            output[row] = acc + self.bias[row];
        }
    }
}
788
789struct LstmLayer {
791 input_size: usize,
792 hidden_size: usize,
793 weight_ih: Vec<f32>, weight_hh: Vec<f32>, bias_ih: Vec<f32>, bias_hh: Vec<f32>, gates_buffer: Vec<f32>, }
802
803impl LstmLayer {
804 fn new(input_size: usize, hidden_size: usize) -> Self {
805 Self {
806 input_size,
807 hidden_size,
808 weight_ih: vec![0.0; 4 * hidden_size * input_size],
809 weight_hh: vec![0.0; 4 * hidden_size * hidden_size],
810 bias_ih: vec![0.0; 4 * hidden_size],
811 bias_hh: vec![0.0; 4 * hidden_size],
812 gates_buffer: vec![0.0; 4 * hidden_size],
813 }
814 }
815
816 fn forward_optimized(&mut self, input: &[f32], hidden: &mut [f32], cell: &mut [f32]) {
817 let h_size = self.hidden_size;
818
819 for i in 0..4 * h_size {
821 let w_start = i * self.input_size;
822 let w_row = &self.weight_ih[w_start..w_start + self.input_size];
823 let dot: f32 = w_row.iter().zip(input).map(|(&w, &x)| w * x).sum();
824 self.gates_buffer[i] = dot + self.bias_ih[i];
825 }
826
827 for i in 0..4 * h_size {
830 let w_start = i * h_size;
831 let w_row = &self.weight_hh[w_start..w_start + h_size];
832 let dot: f32 = w_row.iter().zip(hidden.iter()).map(|(&w, &h)| w * h).sum();
833 self.gates_buffer[i] += dot + self.bias_hh[i];
834 }
835
836 for i in 0..h_size {
839 let i_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i]);
840 let o_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + h_size]);
841 let f_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + 2 * h_size]);
842 let g_gate = crate::media::vad::utils::tanh(self.gates_buffer[i + 3 * h_size]);
843
844 cell[i] = f_gate * cell[i] + i_gate * g_gate;
846
847 hidden[i] = o_gate * crate::media::vad::utils::tanh(cell[i]);
849 }
850 }
851}
852
/// Streaming voice-activity detector: buffers pre-emphasized audio, extracts
/// features per window and scores them with a small conv + LSTM + dense network.
pub struct TinyTen {
    // Detector options; `voice_threshold` is compared against each score.
    config: VADOption,
    // buffer: pre-emphasized, scaled samples waiting to be windowed.
    // pre_emphasis_prev: last raw sample seen, carried across frames for the filter.
    buffer: Vec<f32>, pre_emphasis_prev: f32,
    // Timestamp of the first frame passed to `process`.
    current_timestamp: u64,
    // Samples consumed so far; advances by HOP_SIZE per emitted decision.
    processed_samples: u64,
    // False until the first PCM frame has been seen.
    initialized_timestamp: bool,

    feature_extractor: TenFeatureExtractor,
    // The last CONTEXT_WINDOW_LEN feature vectors, oldest row first.
    feature_buffer: ndarray::Array2<f32>,

    // First conv block: 3x3 conv, 1x1 conv to 16 channels, then max pooling.
    conv1_dw: Conv2dLayer,
    conv1_pw: Conv2dLayer,
    maxpool: MaxPool2dLayer,

    // Second conv block: per-channel 1x3 conv followed by a 1x1 conv.
    conv2_dw: Conv2dLayer,
    conv2_pw: Conv2dLayer,

    // Third conv block: per-channel 1x3 conv followed by a 1x1 conv.
    conv3_dw: Conv2dLayer,
    conv3_pw: Conv2dLayer,

    // Two stacked LSTM cells followed by two dense layers.
    lstm1: LstmLayer,
    lstm2: LstmLayer,
    dense1: LinearLayer,
    dense2: LinearLayer,

    // Hidden (h) and cell (c) state of each LSTM, kept between `predict` calls.
    h1: Vec<f32>,
    c1: Vec<f32>,
    h2: Vec<f32>,
    c2: Vec<f32>,

    // Preallocated activation tensors, one per layer output.
    t_input: Tensor3D,
    t_conv1_dw: Tensor3D,
    t_conv1_pw: Tensor3D,
    t_maxpool: Tensor3D,
    t_conv2_dw: Tensor3D,
    t_conv2_pw: Tensor3D,
    t_conv3_dw: Tensor3D,
    t_conv3_pw: Tensor3D,

    // Concatenation of h2 and h1 that feeds dense1.
    dense_input_buffer: Vec<f32>,
    // Output of dense1 (after ReLU) that feeds dense2.
    dense1_out_buffer: Vec<f32>,

    // Score returned by the most recent `predict` call, if any.
    last_score: Option<f32>,
}
904
// Model weights embedded at compile time; parsed by
// `TinyTen::load_weights_from_bytes` when a detector is constructed.
const WEIGHTS_BYTES: &[u8] = include_bytes!("tiny_tenvad.bin");
906
907impl TinyTen {
908 pub fn new(config: VADOption) -> Result<Self> {
909 if config.samplerate != 16000 {
910 return Err(anyhow::anyhow!("TinyVad only supports 16kHz audio"));
911 }
912
913 let feature_extractor = TenFeatureExtractor::new();
914 let feature_buffer = ndarray::Array2::<f32>::zeros((CONTEXT_WINDOW_LEN, FEATURE_LEN));
915
916 let conv1_dw = Conv2dLayer::new(1, 1, 3, 3, 1, 1, [0, 0, 0, 0], 1);
920 let conv1_pw = Conv2dLayer::new(1, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
922
923 let maxpool = MaxPool2dLayer {
925 kernel_h: 1,
926 kernel_w: 3,
927 stride_h: 1,
928 stride_w: 2,
929 };
930
931 let conv2_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 1, 0, 1], 16);
934 let conv2_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
936
937 let conv3_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 0, 0, 1], 16);
940 let conv3_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
942
943 let lstm1 = LstmLayer::new(80, HIDDEN_SIZE);
945 let lstm2 = LstmLayer::new(HIDDEN_SIZE, HIDDEN_SIZE);
946
947 let dense1 = LinearLayer::new(HIDDEN_SIZE * 2, 32);
948 let dense2 = LinearLayer::new(32, 1);
949
950 let t_input = Tensor3D::new(CONTEXT_WINDOW_LEN, FEATURE_LEN, 1);
952 let t_conv1_dw = Tensor3D::new(1, 39, 1);
953 let t_conv1_pw = Tensor3D::new(1, 39, 16);
954 let t_maxpool = Tensor3D::new(1, 19, 16);
955 let t_conv2_dw = Tensor3D::new(1, 10, 16);
956 let t_conv2_pw = Tensor3D::new(1, 10, 16);
957 let t_conv3_dw = Tensor3D::new(1, 5, 16);
958 let t_conv3_pw = Tensor3D::new(1, 5, 16);
959
960 let dense_input_buffer = vec![0.0; HIDDEN_SIZE * 2];
961 let dense1_out_buffer = vec![0.0; 32];
962
963 let mut vad = Self {
964 config,
965 buffer: Vec::with_capacity(768),
966 pre_emphasis_prev: 0.0,
967 current_timestamp: 0,
968 processed_samples: 0,
969 initialized_timestamp: false,
970 feature_extractor,
971 feature_buffer,
972 conv1_dw,
973 conv1_pw,
974 maxpool,
975 conv2_dw,
976 conv2_pw,
977 conv3_dw,
978 conv3_pw,
979 lstm1,
980 lstm2,
981 dense1,
982 dense2,
983 h1: vec![0.0; HIDDEN_SIZE],
984 c1: vec![0.0; HIDDEN_SIZE],
985 h2: vec![0.0; HIDDEN_SIZE],
986 c2: vec![0.0; HIDDEN_SIZE],
987 t_input,
988 t_conv1_dw,
989 t_conv1_pw,
990 t_maxpool,
991 t_conv2_dw,
992 t_conv2_pw,
993 t_conv3_dw,
994 t_conv3_pw,
995 dense_input_buffer,
996 dense1_out_buffer,
997 last_score: None,
998 };
999
1000 vad.load_weights_from_bytes(WEIGHTS_BYTES)?;
1001 Ok(vad)
1002 }
1003
1004 pub fn predict(&mut self, audio_frame: &[f32]) -> f32 {
1005 let features = self.feature_extractor.extract_features(audio_frame);
1007
1008 for i in 0..CONTEXT_WINDOW_LEN - 1 {
1010 for j in 0..FEATURE_LEN {
1011 self.feature_buffer[[i, j]] = self.feature_buffer[[i + 1, j]];
1012 }
1013 }
1014 for j in 0..FEATURE_LEN {
1015 self.feature_buffer[[CONTEXT_WINDOW_LEN - 1, j]] = features[j];
1016 }
1017
1018 for i in 0..CONTEXT_WINDOW_LEN {
1022 for j in 0..FEATURE_LEN {
1023 self.t_input.set(i, j, 0, self.feature_buffer[[i, j]]);
1024 }
1025 }
1026
1027 self.conv1_dw
1030 .forward_into(&self.t_input, &mut self.t_conv1_dw);
1031 self.conv1_pw
1032 .forward_into(&self.t_conv1_dw, &mut self.t_conv1_pw);
1033
1034 for val in self.t_conv1_pw.data.iter_mut() {
1036 *val = val.max(0.0);
1037 }
1038
1039 self.maxpool
1040 .forward_into(&self.t_conv1_pw, &mut self.t_maxpool);
1041
1042 self.conv2_dw
1044 .forward_into(&self.t_maxpool, &mut self.t_conv2_dw);
1045 self.conv2_pw
1046 .forward_into(&self.t_conv2_dw, &mut self.t_conv2_pw);
1047
1048 for val in self.t_conv2_pw.data.iter_mut() {
1049 *val = val.max(0.0);
1050 }
1051
1052 self.conv3_dw
1054 .forward_into(&self.t_conv2_pw, &mut self.t_conv3_dw);
1055 self.conv3_pw
1056 .forward_into(&self.t_conv3_dw, &mut self.t_conv3_pw);
1057
1058 for val in self.t_conv3_pw.data.iter_mut() {
1059 *val = val.max(0.0);
1060 }
1061
1062 let lstm_input = &self.t_conv3_pw.data;
1065
1066 self.lstm1
1068 .forward_optimized(lstm_input, &mut self.h1, &mut self.c1);
1069
1070 self.lstm2
1072 .forward_optimized(&self.h1, &mut self.h2, &mut self.c2);
1073
1074 let h_size = HIDDEN_SIZE;
1077 self.dense_input_buffer[0..h_size].copy_from_slice(&self.h2);
1078 self.dense_input_buffer[h_size..2 * h_size].copy_from_slice(&self.h1);
1079
1080 self.dense1
1082 .forward(&self.dense_input_buffer, &mut self.dense1_out_buffer);
1083 for val in self.dense1_out_buffer.iter_mut() {
1085 *val = val.max(0.0);
1086 }
1087
1088 let mut output = [0.0; 1];
1090 self.dense2.forward(&self.dense1_out_buffer, &mut output);
1091
1092 let score = 1.0 / (1.0 + (-output[0]).exp()); self.last_score = Some(score);
1094
1095 score
1096 }
1097
1098 fn load_weights_from_bytes(&mut self, bytes: &[u8]) -> Result<()> {
1099 let mut offset = 0;
1100
1101 let read_u32 = |offset: &mut usize, buf: &[u8]| -> u32 {
1103 let val = u32::from_le_bytes(buf[*offset..*offset + 4].try_into().unwrap());
1104 *offset += 4;
1105 val
1106 };
1107
1108 let num_tensors = read_u32(&mut offset, bytes);
1109
1110 let mut weights = std::collections::HashMap::new();
1111
1112 for _ in 0..num_tensors {
1113 let name_len = read_u32(&mut offset, bytes) as usize;
1114 let name_bytes = &bytes[offset..offset + name_len];
1115 let name = std::str::from_utf8(name_bytes)?.to_string();
1116 offset += name_len;
1117
1118 let shape_len = read_u32(&mut offset, bytes) as usize;
1119 let mut shape = Vec::new();
1120 for _ in 0..shape_len {
1121 shape.push(read_u32(&mut offset, bytes));
1122 }
1123
1124 let data_len = read_u32(&mut offset, bytes) as usize;
1125 let data_bytes = &bytes[offset..offset + data_len];
1126 let floats: Vec<f32> = data_bytes
1127 .chunks_exact(4)
1128 .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
1129 .collect();
1130 offset += data_len;
1131
1132 weights.insert(name, (shape, floats));
1133 }
1134
1135 if let Some(w) = weights.get("conv1_dw_weight") {
1137 self.conv1_dw.weights = w.1.clone();
1138 }
1139 if let Some(w) = weights.get("conv1_pw_weight") {
1140 self.conv1_pw.weights = w.1.clone();
1141 }
1142 if let Some(w) = weights.get("conv1_bias") {
1143 self.conv1_pw.bias = Some(w.1.clone());
1144 }
1145
1146 if let Some(w) = weights.get("conv2_dw_weight") {
1147 self.conv2_dw.weights = w.1.clone();
1148 }
1149 if let Some(w) = weights.get("conv2_pw_weight") {
1150 self.conv2_pw.weights = w.1.clone();
1151 }
1152 if let Some(w) = weights.get("conv2_bias") {
1153 self.conv2_pw.bias = Some(w.1.clone());
1154 }
1155
1156 if let Some(w) = weights.get("conv3_dw_weight") {
1157 self.conv3_dw.weights = w.1.clone();
1158 }
1159 if let Some(w) = weights.get("conv3_pw_weight") {
1160 self.conv3_pw.weights = w.1.clone();
1161 }
1162 if let Some(w) = weights.get("conv3_bias") {
1163 self.conv3_pw.bias = Some(w.1.clone());
1164 }
1165
1166 if let Some(w) = weights.get("lstm1_w_ih") {
1167 self.lstm1.weight_ih = w.1.clone();
1168 }
1169 if let Some(w) = weights.get("lstm1_w_hh") {
1170 self.lstm1.weight_hh = w.1.clone();
1171 }
1172 if let Some(w) = weights.get("lstm1_bias") {
1173 let b = &w.1;
1177 if b.len() == 8 * HIDDEN_SIZE {
1178 self.lstm1.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1179 self.lstm1.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1180 }
1181 }
1182
1183 if let Some(w) = weights.get("lstm2_w_ih") {
1184 self.lstm2.weight_ih = w.1.clone();
1185 }
1186 if let Some(w) = weights.get("lstm2_w_hh") {
1187 self.lstm2.weight_hh = w.1.clone();
1188 }
1189 if let Some(w) = weights.get("lstm2_bias") {
1190 let b = &w.1;
1191 if b.len() == 8 * HIDDEN_SIZE {
1192 self.lstm2.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1193 self.lstm2.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1194 }
1195 }
1196
1197 if let Some(w) = weights.get("dense1_weight") {
1198 self.dense1.weights = w.1.clone();
1199 }
1200 if let Some(w) = weights.get("dense1_bias") {
1201 self.dense1.bias = w.1.clone();
1202 }
1203
1204 if let Some(w) = weights.get("dense2_weight") {
1205 self.dense2.weights = w.1.clone();
1206 }
1207 if let Some(w) = weights.get("dense2_bias") {
1208 self.dense2.bias = w.1.clone();
1209 }
1210 Ok(())
1211 }
1212}
1213
1214impl VadEngine for TinyTen {
1215 fn process(&mut self, frame: &mut AudioFrame) -> Vec<(bool, u64)> {
1216 let samples = match &frame.samples {
1217 Samples::PCM { samples } => samples,
1218 _ => return vec![(false, frame.timestamp)],
1219 };
1220
1221 if !self.initialized_timestamp {
1222 self.current_timestamp = frame.timestamp;
1223 self.initialized_timestamp = true;
1224 self.buffer.resize(512, 0.0);
1228 }
1229
1230 let inv_scale = 1.0 / 32768.0;
1232 for &sample in samples {
1233 let s_f32 = sample as f32;
1234 let pre_emphasized = (s_f32 - PRE_EMPHASIS_COEFF * self.pre_emphasis_prev) * inv_scale;
1235 self.buffer.push(pre_emphasized);
1236 self.pre_emphasis_prev = s_f32;
1237 }
1238
1239 let mut results = Vec::new();
1240
1241 while self.buffer.len() >= WINDOW_SIZE {
1242 let window = self.buffer[..WINDOW_SIZE].to_vec();
1244 let score = self.predict(&window);
1245
1246 let is_voice = score > self.config.voice_threshold;
1247
1248 let chunk_timestamp =
1249 self.current_timestamp + (self.processed_samples * 1000) / (SAMPLE_RATE as u64);
1250 self.processed_samples += HOP_SIZE as u64;
1251
1252 results.push((is_voice, chunk_timestamp));
1253 self.buffer.drain(..HOP_SIZE);
1254 }
1255
1256 results
1257 }
1258}