// active_call/media/vad/tiny_ten.rs

1use super::{VADOption, VadEngine};
2use crate::media::{AudioFrame, Samples};
3use anyhow::Result;
4use realfft::{RealFftPlanner, RealToComplex};
5use std::sync::Arc;
6
// Constants

// --- Audio framing ---------------------------------------------------------

/// Sample rate (Hz) the mel filter bank's bin mapping is computed for.
const SAMPLE_RATE: u32 = 16_000;
/// Hop size in samples: 256 samples = 16 ms at 16 kHz (used outside this section).
const HOP_SIZE: usize = 256;
/// Number of samples per frame that are Hann-windowed: 768 samples = 48 ms at 16 kHz.
const WINDOW_SIZE: usize = 768;
/// Real-FFT length; each windowed frame is zero-padded up to this many samples.
const FFT_SIZE: usize = 1024;
/// Pre-emphasis coefficient (presumably `y[n] = x[n] - 0.97 * x[n - 1]`; the filter
/// itself is applied outside this section — confirm there).
const PRE_EMPHASIS_COEFF: f32 = 0.97;

// --- Features --------------------------------------------------------------

/// Number of triangular mel filters, i.e. log-mel features per frame.
const MEL_FILTER_BANK_NUM: usize = 40;
/// Length of one feature vector: all mel features followed by one pitch feature.
const FEATURE_LEN: usize = MEL_FILTER_BANK_NUM + 1;
/// Number of feature frames that make up one model input (rows of the input tensor).
const CONTEXT_WINDOW_LEN: usize = 3;
/// Tiny offset added before taking logs and to the stds before inverting them.
const EPS: f32 = 1e-20;

// --- Model -----------------------------------------------------------------

/// Hidden-state width of both LSTM layers.
const HIDDEN_SIZE: usize = 64;
18
// Feature normalization parameters
/// Per-feature means for z-score normalization of the extractor output:
/// each feature becomes `(feature - FEATURE_MEANS[i]) * (1 / (FEATURE_STDS[i] + EPS))`.
///
/// Indices `0..MEL_FILTER_BANK_NUM` are the log-mel energies; the last entry is the
/// pitch feature.
const FEATURE_MEANS: [f32; FEATURE_LEN] = [
    -8.198_236,
    -6.265_716_6,
    -5.483_818_5,
    -4.758_691_3,
    -4.417_089,
    -4.142_893,
    -3.912_850_4,
    -3.845_928,
    -3.657_090_4,
    -3.723_418_7,
    -3.876_134_2,
    -3.843_891,
    -3.690_405_1,
    -3.756_065_8,
    -3.698_696_1,
    -3.650_463,
    -3.700_468_8,
    -3.567_321_3,
    -3.498_900_2,
    -3.477_807,
    -3.458_816,
    -3.444_923_9,
    -3.401_328_6,
    -3.306_261_3,
    -3.278_556_8,
    -3.233_250_9,
    -3.198_616,
    -3.204_526_4,
    -3.208_798_6,
    -3.257_838,
    -3.381_376_7,
    -3.534_021_4,
    -3.640_868,
    -3.726_858_9,
    -3.773_731,
    -3.804_667_2,
    -3.832_901,
    -3.871_120_5,
    -3.990_593,
    -4.480_289_5,
    // Pitch feature (index MEL_FILTER_BANK_NUM).
    9.235_69e1,
];
63
/// Per-feature standard deviations paired with `FEATURE_MEANS`.
///
/// They are not divided by directly: `TenFeatureExtractor::new` precomputes
/// `1 / (std + EPS)` once and normalization multiplies by that.
/// The last entry belongs to the pitch feature.
const FEATURE_STDS: [f32; FEATURE_LEN] = [
    5.166_064,
    4.977_21,
    4.698_896,
    4.630_621_4,
    4.634_348,
    4.641_156,
    4.640_676_5,
    4.666_367,
    4.650_534_6,
    4.640_021,
    4.637_4,
    4.620_099,
    4.596_316_3,
    4.562_655,
    4.554_36,
    4.566_910_7,
    4.562_49,
    4.562_413,
    4.585_299_5,
    4.600_179_7,
    4.592_846,
    4.585_923,
    4.583_496_6,
    4.626_093,
    4.626_958,
    4.626_289_4,
    4.637_006,
    4.683_016,
    4.726_814,
    4.734_29,
    4.753_227,
    4.849_723,
    4.869_435,
    4.884_483,
    4.921_327,
    4.959_212_3,
    4.996_619,
    5.044_823_6,
    5.072_217,
    5.096_439_4,
    // Pitch feature (index MEL_FILTER_BANK_NUM).
    1.152_136_9e2,
];
107
/// Turns one (already pre-emphasized) audio frame into a normalized feature vector
/// of `FEATURE_LEN` values: `MEL_FILTER_BANK_NUM` log-mel energies plus one pitch slot.
///
/// The mel filter bank, analysis window, FFT plan and all scratch buffers are built
/// once in `new` and reused by every `extract_features` call.
pub struct TenFeatureExtractor {
    // Triangular mel filter bank, shape (MEL_FILTER_BANK_NUM, FFT_SIZE / 2 + 1).
    mel_filters: ndarray::Array2<f32>,
    // Per filter, the [start, end) FFT-bin span outside which its weights are zero;
    // lets the filter/spectrum dot product skip the zero tails.
    mel_filter_ranges: Vec<(usize, usize)>,
    // Hann analysis window of WINDOW_SIZE samples.
    window: Vec<f32>,
    // FFT related fields
    // Forward real-FFT plan of length FFT_SIZE.
    rfft: Arc<dyn RealToComplex<f32>>,
    fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
    fft_output: Vec<realfft::num_complex::Complex<f32>>,
    // FFT_SIZE samples: the windowed frame followed by zero padding.
    fft_input: Vec<f32>,
    // |X[k]|^2 / FFT_SIZE for each of the FFT_SIZE / 2 + 1 bins.
    power_spectrum: Vec<f32>,
    // Precomputed 1 / (FEATURE_STDS[i] + EPS) used during normalization.
    inv_stds: Vec<f32>,
}
120
121impl TenFeatureExtractor {
122    pub fn new() -> Self {
123        // Generate mel filter bank
124        let (mel_filters, mel_filter_ranges) = Self::generate_mel_filters();
125
126        // Generate Hann window
127        let window = super::utils::generate_hann_window(WINDOW_SIZE, false);
128
129        // Initialize FFT
130        let mut planner = RealFftPlanner::<f32>::new();
131        let rfft = planner.plan_fft_forward(FFT_SIZE);
132        let fft_scratch = rfft.make_scratch_vec();
133        let fft_output = rfft.make_output_vec();
134        let fft_input = rfft.make_input_vec();
135        let power_spectrum = vec![0.0; FFT_SIZE / 2 + 1];
136
137        // Pre-calculate inverse STDs
138        let inv_stds: Vec<f32> = FEATURE_STDS.iter().map(|&std| 1.0 / (std + EPS)).collect();
139
140        Self {
141            mel_filters,
142            mel_filter_ranges,
143            window,
144            rfft,
145            fft_scratch,
146            fft_output,
147            fft_input,
148            power_spectrum,
149            inv_stds,
150        }
151    }
152
153    fn generate_mel_filters() -> (ndarray::Array2<f32>, Vec<(usize, usize)>) {
154        let n_bins = FFT_SIZE / 2 + 1;
155
156        // Generate mel frequency points
157        let low_mel = 2595.0_f32 * (1.0_f32 + 0.0_f32 / 700.0_f32).log10();
158        let high_mel = 2595.0_f32 * (1.0_f32 + 8000.0_f32 / 700.0_f32).log10();
159
160        let mut mel_points = Vec::new();
161        for i in 0..=MEL_FILTER_BANK_NUM + 1 {
162            let mel = low_mel + (high_mel - low_mel) * i as f32 / (MEL_FILTER_BANK_NUM + 1) as f32;
163            mel_points.push(mel);
164        }
165
166        // Convert to Hz
167        let mut hz_points = Vec::new();
168        for mel in mel_points {
169            let hz = 700.0_f32 * (10.0_f32.powf(mel / 2595.0_f32) - 1.0_f32);
170            hz_points.push(hz);
171        }
172
173        // Convert to FFT bin indices
174        let mut bin_points = Vec::new();
175        for hz in hz_points {
176            let bin = ((FFT_SIZE + 1) as f32 * hz / SAMPLE_RATE as f32).floor() as usize;
177            bin_points.push(bin);
178        }
179
180        // Build mel filter bank
181        let mut mel_filters = ndarray::Array2::<f32>::zeros((MEL_FILTER_BANK_NUM, n_bins));
182        let mut ranges = Vec::with_capacity(MEL_FILTER_BANK_NUM);
183
184        for i in 0..MEL_FILTER_BANK_NUM {
185            let start = bin_points[i];
186            let end = bin_points[i + 2];
187            ranges.push((start, end));
188
189            // Left slope
190            for j in bin_points[i]..bin_points[i + 1] {
191                if j < n_bins {
192                    mel_filters[[i, j]] =
193                        (j - bin_points[i]) as f32 / (bin_points[i + 1] - bin_points[i]) as f32;
194                }
195            }
196
197            // Right slope
198            for j in bin_points[i + 1]..bin_points[i + 2] {
199                if j < n_bins {
200                    mel_filters[[i, j]] = (bin_points[i + 2] - j) as f32
201                        / (bin_points[i + 2] - bin_points[i + 1]) as f32;
202                }
203            }
204        }
205
206        (mel_filters, ranges)
207    }
208
209    pub fn extract_features(&mut self, audio_frame: &[f32]) -> ndarray::Array1<f32> {
210        // Prepare FFT input buffer
211        // 1. Clear buffer
212        self.fft_input.fill(0.0);
213
214        // 2. Use provided pre-emphasized audio
215        let copy_len = audio_frame.len().min(WINDOW_SIZE);
216        self.fft_input[..copy_len].copy_from_slice(audio_frame);
217
218        // 3. Windowing
219        for (i, sample) in self.fft_input.iter_mut().enumerate().take(copy_len) {
220            *sample *= self.window[i];
221        }
222
223        // 4. FFT
224        self.rfft
225            .process_with_scratch(
226                &mut self.fft_input,
227                &mut self.fft_output,
228                &mut self.fft_scratch,
229            )
230            .unwrap();
231
232        // 5. Power spectrum
233        let n_bins = FFT_SIZE / 2 + 1;
234        let scale = 1.0 / FFT_SIZE as f32;
235
236        // Compute power spectrum once
237        // Use iterators to avoid bounds checks
238        for (pow, complex) in self.power_spectrum.iter_mut().zip(self.fft_output.iter()) {
239            *pow = (complex.re * complex.re + complex.im * complex.im) * scale;
240        }
241
242        // Mel filter bank features
243        let mut mel_features = ndarray::Array1::<f32>::zeros(MEL_FILTER_BANK_NUM);
244
245        for i in 0..MEL_FILTER_BANK_NUM {
246            let (start, end) = self.mel_filter_ranges[i];
247            let valid_end = end.min(n_bins);
248
249            let mut sum = 0.0;
250            if start < valid_end {
251                // Use slices for dot product to enable vectorization
252                let filter_row = self.mel_filters.row(i);
253                // Safety: we know the row is contiguous because we created it that way
254                // and we haven't modified layout.
255                if let Some(filter_slice) = filter_row.as_slice() {
256                    let filter_sub = &filter_slice[start..valid_end];
257                    let power_sub = &self.power_spectrum[start..valid_end];
258
259                    // This dot product should be auto-vectorized
260                    sum = super::simd::dot_product(filter_sub, power_sub);
261                } else {
262                    // Fallback if not contiguous (should not happen)
263                    for j in start..valid_end {
264                        sum += self.mel_filters[[i, j]] * self.power_spectrum[j];
265                    }
266                }
267            }
268            mel_features[i] = (sum + EPS).ln();
269        }
270
271        // Simple pitch estimation (using 0 as in Python code)
272        let pitch_freq = 0.0;
273
274        // Combine features
275        let mut features = ndarray::Array1::<f32>::zeros(FEATURE_LEN);
276        features
277            .slice_mut(ndarray::s![..MEL_FILTER_BANK_NUM])
278            .assign(&mel_features);
279        features[MEL_FILTER_BANK_NUM] = pitch_freq;
280
281        // Feature normalization
282        // Use pre-calculated inverse STDs and iterators
283        for (feat, (&mean, &inv_std)) in features
284            .iter_mut()
285            .zip(FEATURE_MEANS.iter().zip(self.inv_stds.iter()))
286        {
287            *feat = (*feat - mean) * inv_std;
288        }
289
290        features
291    }
292}
293
// 3D Tensor (H, W, C)
/// Dense, row-major 3-D tensor of `f32` shaped (height, width, channels).
///
/// Element `(y, x, ch)` lives at flat index `(y * w + x) * c + ch`, so the channels
/// of one pixel are contiguous. Only the flat index is bounds-checked, not each
/// coordinate individually.
#[derive(Clone, Debug)]
struct Tensor3D {
    data: Vec<f32>,
    h: usize,
    w: usize,
    c: usize,
}

impl Tensor3D {
    /// Allocates an `h x w x c` tensor with every element set to `0.0`.
    fn new(h: usize, w: usize, c: usize) -> Self {
        let len = h * w * c;
        Tensor3D {
            h,
            w,
            c,
            data: vec![0.0; len],
        }
    }

    /// Resets every element to `0.0` in place, keeping the allocation.
    fn zeros(&mut self) {
        self.data.iter_mut().for_each(|v| *v = 0.0);
    }

    /// Reads element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline(always)]
    fn get(&self, y: usize, x: usize, ch: usize) -> f32 {
        let idx = (y * self.w + x) * self.c + ch;
        self.data[idx]
    }

    /// Writes `val` to element `(y, x, ch)`; panics if the flat index is out of range.
    #[inline(always)]
    fn set(&mut self, y: usize, x: usize, ch: usize, val: f32) {
        let idx = (y * self.w + x) * self.c + ch;
        self.data[idx] = val;
    }
}
328
329// Conv2D Layer
330struct Conv2dLayer {
331    weights: Vec<f32>,      // [out_c, in_c/groups, kh, kw]
332    bias: Option<Vec<f32>>, // [out_c]
333    in_channels: usize,
334    out_channels: usize,
335    kernel_h: usize,
336    kernel_w: usize,
337    stride_h: usize,
338    stride_w: usize,
339    padding: [usize; 4], // [top, left, bottom, right]
340    groups: usize,
341}
342
343impl Conv2dLayer {
344    fn new(
345        in_channels: usize,
346        out_channels: usize,
347        kernel_h: usize,
348        kernel_w: usize,
349        stride_h: usize,
350        stride_w: usize,
351        padding: [usize; 4],
352        groups: usize,
353    ) -> Self {
354        Self {
355            weights: vec![0.0; out_channels * (in_channels / groups) * kernel_h * kernel_w],
356            bias: None,
357            in_channels,
358            out_channels,
359            kernel_h,
360            kernel_w,
361            stride_h,
362            stride_w,
363            padding,
364            groups,
365        }
366    }
367
368    // Optimized forward pass with pre-allocated output buffer
369    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
370        let out_h = output.h;
371        let out_w = output.w;
372
373        // Optimization for Conv1_DW (3x3, s=1, p=0, in=1, out=1)
374        // Input: [3, 41, 1], Output: [1, 39, 1]
375        if self.in_channels == 1
376            && self.out_channels == 1
377            && self.kernel_h == 3
378            && self.kernel_w == 3
379            && self.stride_h == 1
380            && self.stride_w == 1
381            && self.padding == [0, 0, 0, 0]
382        {
383            let bias = self.bias.as_ref().map(|b| b[0]).unwrap_or(0.0);
384            let w = &self.weights; // 9 elements
385
386            // Hardcoded 3x3 convolution
387            // y is always 0 because out_h=1 (input_h=3, k=3, s=1 -> (3-3)/1 + 1 = 1)
388            for x in 0..out_w {
389                // input x range: x to x+3
390                // input y range: 0 to 3
391                let mut sum = bias;
392
393                // Unroll 3x3 kernel
394                // Row 0
395                sum += input.get(0, x, 0) * w[0];
396                sum += input.get(0, x + 1, 0) * w[1];
397                sum += input.get(0, x + 2, 0) * w[2];
398
399                // Row 1
400                sum += input.get(1, x, 0) * w[3];
401                sum += input.get(1, x + 1, 0) * w[4];
402                sum += input.get(1, x + 2, 0) * w[5];
403
404                // Row 2
405                sum += input.get(2, x, 0) * w[6];
406                sum += input.get(2, x + 1, 0) * w[7];
407                sum += input.get(2, x + 2, 0) * w[8];
408
409                output.set(0, x, 0, sum);
410            }
411            return;
412        }
413
414        // Optimization for Conv1_PW (1x1, s=1, p=0, in=1, out=16)
415        // Input: [1, 39, 1], Output: [1, 39, 16]
416        if self.in_channels == 1
417            && self.out_channels == 16
418            && self.kernel_h == 1
419            && self.kernel_w == 1
420            && self.stride_h == 1
421            && self.stride_w == 1
422        {
423            let w = &self.weights; // 16 elements
424            let b = self.bias.as_ref(); // 16 elements
425
426            for x in 0..out_w {
427                let val = input.get(0, x, 0);
428
429                // Unroll 16 channels
430                for oc in 0..16 {
431                    let bias = if let Some(bias_vec) = b {
432                        bias_vec[oc]
433                    } else {
434                        0.0
435                    };
436                    let res = val * w[oc] + bias;
437                    output.set(0, x, oc, res);
438                }
439            }
440            return;
441        }
442
443        // Optimization for Conv2_DW (1x3, s=2, p=[0,1,0,1], in=16, out=16, groups=16)
444        // Input: [1, 19, 16], Output: [1, 10, 16]
445        if self.groups == 16
446            && self.in_channels == 16
447            && self.out_channels == 16
448            && self.kernel_h == 1
449            && self.kernel_w == 3
450            && self.stride_w == 2
451            && self.padding == [0, 1, 0, 1]
452        {
453            let w = &self.weights; // 16 * 1 * 1 * 3 = 48 elements
454            let b = self.bias.as_ref();
455
456            for c in 0..16 {
457                let w_offset = c * 3;
458                let w0 = w[w_offset];
459                let w1 = w[w_offset + 1];
460                let w2 = w[w_offset + 2];
461                let bias = if let Some(bias_vec) = b {
462                    bias_vec[c]
463                } else {
464                    0.0
465                };
466
467                // x=0: in_x = -1, 0, 1. Valid: 0, 1. (w1, w2)
468                let val0 = input.get(0, 0, c);
469                let val1 = input.get(0, 1, c);
470                let sum0 = val0 * w1 + val1 * w2 + bias;
471                output.set(0, 0, c, sum0);
472
473                // x=1..9: in_x = 1, 3, 5, ... 17.
474                // x=1: in_x_origin = 1. kx=0->1, kx=1->2, kx=2->3.
475                // ...
476                // x=9: in_x_origin = 17. kx=0->17, kx=1->18, kx=2->19(skip).
477
478                // Middle loop x=1..8
479                for x in 1..9 {
480                    let in_x_origin = x * 2 - 1;
481                    let v0 = input.get(0, in_x_origin, c);
482                    let v1 = input.get(0, in_x_origin + 1, c);
483                    let v2 = input.get(0, in_x_origin + 2, c);
484                    let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
485                    output.set(0, x, c, sum);
486                }
487
488                // x=9: in_x_origin = 17. Valid: 17, 18. (w0, w1)
489                let v0 = input.get(0, 17, c);
490                let v1 = input.get(0, 18, c);
491                let sum9 = v0 * w0 + v1 * w1 + bias;
492                output.set(0, 9, c, sum9);
493            }
494            return;
495        }
496
497        // Optimization for Conv2_PW (1x1, s=1, p=0, in=16, out=16)
498        // Input: [1, 10, 16], Output: [1, 10, 16]
499        if self.in_channels == 16
500            && self.out_channels == 16
501            && self.kernel_h == 1
502            && self.kernel_w == 1
503            && self.stride_h == 1
504            && self.stride_w == 1
505            && self.groups == 1
506        {
507            let w = &self.weights; // 16 * 16 = 256 elements
508            let b = self.bias.as_ref();
509
510            for x in 0..out_w {
511                // Pre-load input channel values for this pixel to registers (hopefully)
512                let mut in_vals = [0.0; 16];
513                for ic in 0..16 {
514                    in_vals[ic] = input.get(0, x, ic);
515                }
516
517                for oc in 0..16 {
518                    let mut sum = if let Some(bias_vec) = b {
519                        bias_vec[oc]
520                    } else {
521                        0.0
522                    };
523                    let w_offset = oc * 16;
524
525                    // Unroll dot product
526                    for ic in 0..16 {
527                        sum += in_vals[ic] * w[w_offset + ic];
528                    }
529                    output.set(0, x, oc, sum);
530                }
531            }
532            return;
533        }
534
535        // Optimization for Conv3_DW (1x3, s=2, p=[0,1,0,1], in=16, out=16, groups=16)
536        // Input: [1, 10, 16], Output: [1, 5, 16]
537        if self.groups == 16
538            && self.in_channels == 16
539            && self.out_channels == 16
540            && self.kernel_h == 1
541            && self.kernel_w == 3
542            && self.stride_w == 2
543            && self.padding == [0, 1, 0, 1]
544            && out_w == 5
545        {
546            let w = &self.weights;
547            let b = self.bias.as_ref();
548
549            for c in 0..16 {
550                let w_offset = c * 3;
551                let w0 = w[w_offset];
552                let w1 = w[w_offset + 1];
553                let w2 = w[w_offset + 2];
554                let bias = if let Some(bias_vec) = b {
555                    bias_vec[c]
556                } else {
557                    0.0
558                };
559
560                // x=0: in_x = -1. Valid: 0, 1. (w1, w2)
561                let val0 = input.get(0, 0, c);
562                let val1 = input.get(0, 1, c);
563                let sum0 = val0 * w1 + val1 * w2 + bias;
564                output.set(0, 0, c, sum0);
565
566                // x=1..5: in_x = 1, 3, 5, 7.
567                // Max index accessed: 7 + 2 = 9. Input width is 10 (0..9). Safe.
568                for x in 1..5 {
569                    let in_x_origin = x * 2 - 1;
570                    let v0 = input.get(0, in_x_origin, c);
571                    let v1 = input.get(0, in_x_origin + 1, c);
572                    let v2 = input.get(0, in_x_origin + 2, c);
573                    let sum = v0 * w0 + v1 * w1 + v2 * w2 + bias;
574                    output.set(0, x, c, sum);
575                }
576            }
577            return;
578        }
579
580        // Optimization for Conv3_PW (1x1, s=1, p=0, in=16, out=32)
581        // Input: [1, 5, 16], Output: [1, 5, 32]
582        if self.in_channels == 16
583            && self.out_channels == 32
584            && self.kernel_h == 1
585            && self.kernel_w == 1
586            && self.stride_h == 1
587            && self.stride_w == 1
588            && self.groups == 1
589        {
590            let w = &self.weights; // 32 * 16 = 512 elements
591            let b = self.bias.as_ref();
592
593            for x in 0..out_w {
594                let mut in_vals = [0.0; 16];
595                for ic in 0..16 {
596                    in_vals[ic] = input.get(0, x, ic);
597                }
598
599                for oc in 0..32 {
600                    let mut sum = if let Some(bias_vec) = b {
601                        bias_vec[oc]
602                    } else {
603                        0.0
604                    };
605                    let w_offset = oc * 16;
606
607                    for ic in 0..16 {
608                        sum += in_vals[ic] * w[w_offset + ic];
609                    }
610                    output.set(0, x, oc, sum);
611                }
612            }
613            return;
614        }
615
616        // Reset output buffer
617        output.zeros();
618
619        let in_c_per_group = self.in_channels / self.groups;
620        let out_c_per_group = self.out_channels / self.groups;
621
622        // Optimization: Check if we can use fast path (no padding, stride 1, etc)
623        // But here we have padding and strides.
624
625        // Optimization: Lift bias addition out of inner loop
626        if let Some(b) = &self.bias {
627            for g in 0..self.groups {
628                for oc in 0..out_c_per_group {
629                    let out_ch_idx = g * out_c_per_group + oc;
630                    let bias_val = b[out_ch_idx];
631                    // Initialize output with bias
632                    for y in 0..out_h {
633                        for x in 0..out_w {
634                            output.set(y, x, out_ch_idx, bias_val);
635                        }
636                    }
637                }
638            }
639        }
640
641        for g in 0..self.groups {
642            for oc in 0..out_c_per_group {
643                let out_ch_idx = g * out_c_per_group + oc;
644
645                // Pre-calculate weight offset for this output channel
646                let w_base = out_ch_idx * (in_c_per_group * self.kernel_h * self.kernel_w);
647
648                for y in 0..out_h {
649                    let in_y_origin = (y * self.stride_h) as isize - self.padding[0] as isize;
650
651                    for x in 0..out_w {
652                        let in_x_origin = (x * self.stride_w) as isize - self.padding[1] as isize;
653
654                        let mut sum = 0.0;
655
656                        for ic in 0..in_c_per_group {
657                            let in_ch_idx = g * in_c_per_group + ic;
658                            let w_ic_base = w_base + ic * (self.kernel_h * self.kernel_w);
659
660                            for ky in 0..self.kernel_h {
661                                let in_y = in_y_origin + ky as isize;
662                                if in_y >= 0 && in_y < input.h as isize {
663                                    let w_ky_base = w_ic_base + ky * self.kernel_w;
664
665                                    for kx in 0..self.kernel_w {
666                                        let in_x = in_x_origin + kx as isize;
667
668                                        if in_x >= 0 && in_x < input.w as isize {
669                                            // Hot path
670                                            let val =
671                                                input.get(in_y as usize, in_x as usize, in_ch_idx);
672                                            let w_idx = w_ky_base + kx;
673                                            // Safety: w_idx is within bounds by construction
674                                            let w = unsafe { *self.weights.get_unchecked(w_idx) };
675                                            sum += val * w;
676                                        }
677                                    }
678                                }
679                            }
680                        }
681
682                        // Accumulate to output (which already has bias)
683                        let current = output.get(y, x, out_ch_idx);
684                        output.set(y, x, out_ch_idx, current + sum);
685                    }
686                }
687            }
688        }
689    }
690}
691
692// MaxPool2D Layer
693struct MaxPool2dLayer {
694    kernel_h: usize,
695    kernel_w: usize,
696    stride_h: usize,
697    stride_w: usize,
698}
699
700impl MaxPool2dLayer {
701    fn forward_into(&self, input: &Tensor3D, output: &mut Tensor3D) {
702        let out_h = output.h;
703        let out_w = output.w;
704
705        // Optimization for MaxPool (1x3, s=1x2)
706        if self.kernel_h == 1 && self.kernel_w == 3 && self.stride_h == 1 && self.stride_w == 2 {
707            for c in 0..input.c {
708                // y is always 0
709                for x in 0..out_w {
710                    let in_x = x * 2;
711                    // We assume valid padding so in_x+2 is within bounds
712                    let v0 = input.get(0, in_x, c);
713                    let v1 = input.get(0, in_x + 1, c);
714                    let v2 = input.get(0, in_x + 2, c);
715
716                    let max_v = v0.max(v1).max(v2);
717                    output.set(0, x, c, max_v);
718                }
719            }
720            return;
721        }
722
723        for c in 0..input.c {
724            for y in 0..out_h {
725                for x in 0..out_w {
726                    let mut max_val = f32::NEG_INFINITY;
727
728                    for ky in 0..self.kernel_h {
729                        for kx in 0..self.kernel_w {
730                            let in_y = y * self.stride_h + ky;
731                            let in_x = x * self.stride_w + kx;
732                            // MaxPool usually doesn't have padding in this model (valid padding)
733                            // So we can skip bounds check if we trust output size calculation
734                            let val = input.get(in_y, in_x, c);
735                            if val > max_val {
736                                max_val = val;
737                            }
738                        }
739                    }
740                    output.set(y, x, c, max_val);
741                }
742            }
743        }
744    }
745}
746
// Simple Linear Layer
/// Fully connected layer computing `y = W·x + b`.
///
/// `weights` is row-major `[out_features, in_features]` (one row per output) and
/// `bias` holds one value per output.
struct LinearLayer {
    weights: Vec<f32>, // Flattened [out_features, in_features]
    bias: Vec<f32>,    // [out_features]
    in_features: usize,
    out_features: usize,
}

impl LinearLayer {
    /// Builds a layer whose weights and bias are all zero. Trained parameters are
    /// expected to be copied in before use.
    fn new(in_features: usize, out_features: usize) -> Self {
        LinearLayer {
            in_features,
            out_features,
            weights: vec![0.0; in_features * out_features],
            bias: vec![0.0; out_features],
        }
    }

    /// Writes `W·input + bias` into `output`.
    ///
    /// # Panics
    /// If `input.len() != in_features` or `output.len() != out_features`.
    fn forward(&self, input: &[f32], output: &mut [f32]) {
        assert_eq!(input.len(), self.in_features);
        assert_eq!(output.len(), self.out_features);

        let width = self.in_features;
        for row in 0..self.out_features {
            let start = row * width;
            let weights_row = &self.weights[start..start + width];

            // Dot product of this weight row with the input vector.
            let acc: f32 = weights_row
                .iter()
                .zip(input.iter())
                .map(|(&w, &x)| w * x)
                .sum();

            output[row] = acc + self.bias[row];
        }
    }
}
788
789// LSTM Layer
790struct LstmLayer {
791    input_size: usize,
792    hidden_size: usize,
793    // Weights: 4 * hidden_size rows (i, f, g, o)
794    weight_ih: Vec<f32>, // [4 * hidden_size, input_size]
795    weight_hh: Vec<f32>, // [4 * hidden_size, hidden_size]
796    bias_ih: Vec<f32>,   // [4 * hidden_size]
797    bias_hh: Vec<f32>,   // [4 * hidden_size]
798
799    // Scratch buffers
800    gates_buffer: Vec<f32>, // [4 * hidden_size]
801}
802
803impl LstmLayer {
804    fn new(input_size: usize, hidden_size: usize) -> Self {
805        Self {
806            input_size,
807            hidden_size,
808            weight_ih: vec![0.0; 4 * hidden_size * input_size],
809            weight_hh: vec![0.0; 4 * hidden_size * hidden_size],
810            bias_ih: vec![0.0; 4 * hidden_size],
811            bias_hh: vec![0.0; 4 * hidden_size],
812            gates_buffer: vec![0.0; 4 * hidden_size],
813        }
814    }
815
816    fn forward_optimized(&mut self, input: &[f32], hidden: &mut [f32], cell: &mut [f32]) {
817        let h_size = self.hidden_size;
818
819        // 1. Compute W_ih * x + b_ih for all gates (i, f, g, o)
820        for i in 0..4 * h_size {
821            let w_start = i * self.input_size;
822            let w_row = &self.weight_ih[w_start..w_start + self.input_size];
823            let dot: f32 = w_row.iter().zip(input).map(|(&w, &x)| w * x).sum();
824            self.gates_buffer[i] = dot + self.bias_ih[i];
825        }
826
827        // 2. Compute W_hh * h + b_hh for all gates
828        // We can add directly to gates_buffer
829        for i in 0..4 * h_size {
830            let w_start = i * h_size;
831            let w_row = &self.weight_hh[w_start..w_start + h_size];
832            let dot: f32 = w_row.iter().zip(hidden.iter()).map(|(&w, &h)| w * h).sum();
833            self.gates_buffer[i] += dot + self.bias_hh[i];
834        }
835
836        // 3. Apply activations and update states
837        // ONNX Gates order: i, o, f, g (c)
838        for i in 0..h_size {
839            let i_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i]);
840            let o_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + h_size]);
841            let f_gate = crate::media::vad::utils::sigmoid(self.gates_buffer[i + 2 * h_size]);
842            let g_gate = crate::media::vad::utils::tanh(self.gates_buffer[i + 3 * h_size]);
843
844            // c_t = f_t * c_{t-1} + i_t * g_t
845            cell[i] = f_gate * cell[i] + i_gate * g_gate;
846
847            // h_t = o_t * tanh(c_t)
848            hidden[i] = o_gate * crate::media::vad::utils::tanh(cell[i]);
849        }
850    }
851}
852
/// Voice-activity detector: `TenFeatureExtractor` features fed through a small
/// depthwise-separable CNN, two stacked LSTMs and two dense layers.
///
/// Construction rejects any sample rate other than 16 kHz.
pub struct TinyTen {
    config: VADOption,
    buffer: Vec<f32>, // Store pre-emphasized f32 samples
    // Last raw sample seen by the pre-emphasis filter (presumably carried across
    // calls so the filter is continuous between chunks — confirm in the processing code).
    pre_emphasis_prev: f32,
    // NOTE(review): timestamp/sample bookkeeping; units and update rules live outside
    // this section — confirm before relying on them.
    current_timestamp: u64,
    processed_samples: u64,
    initialized_timestamp: bool,

    feature_extractor: TenFeatureExtractor,
    // Feature context of shape (CONTEXT_WINDOW_LEN, FEATURE_LEN); presumably the most
    // recent frames — confirm in the update code.
    feature_buffer: ndarray::Array2<f32>,

    // Model Layers
    // Block 1
    conv1_dw: Conv2dLayer, // 3x3, 1 -> 1 channel, no padding
    conv1_pw: Conv2dLayer, // 1x1, 1 -> 16 channels
    maxpool: MaxPool2dLayer, // 1x3 window, stride (1, 2)

    // Block 2
    conv2_dw: Conv2dLayer, // depthwise 1x3, stride 2, padding [0, 1, 0, 1], 16 channels
    conv2_pw: Conv2dLayer, // 1x1, 16 -> 16 channels

    // Block 3
    conv3_dw: Conv2dLayer, // depthwise 1x3, stride 2, padding [0, 0, 0, 1], 16 channels
    conv3_pw: Conv2dLayer, // 1x1, 16 -> 16 channels

    lstm1: LstmLayer,      // 80 -> HIDDEN_SIZE
    lstm2: LstmLayer,      // HIDDEN_SIZE -> HIDDEN_SIZE
    dense1: LinearLayer,   // HIDDEN_SIZE * 2 -> 32
    dense2: LinearLayer,   // 32 -> 1

    // Model States
    // Presumably the hidden (h) / cell (c) states passed to lstm1 and lstm2 — confirm.
    h1: Vec<f32>,
    c1: Vec<f32>,
    h2: Vec<f32>,
    c2: Vec<f32>,

    // Scratch Buffers (Pre-allocated)
    t_input: Tensor3D,    // (CONTEXT_WINDOW_LEN, FEATURE_LEN, 1)
    t_conv1_dw: Tensor3D, // (1, 39, 1)
    t_conv1_pw: Tensor3D, // (1, 39, 16)
    t_maxpool: Tensor3D,  // (1, 19, 16)
    t_conv2_dw: Tensor3D, // (1, 10, 16)
    t_conv2_pw: Tensor3D, // (1, 10, 16)
    // NOTE(review): conv3 output shapes are set outside this section; expected (1, 5, 16).
    t_conv3_dw: Tensor3D,
    t_conv3_pw: Tensor3D,

    // Presumably dense1's input (HIDDEN_SIZE * 2 values) and output (32 values) — confirm.
    dense_input_buffer: Vec<f32>,
    dense1_out_buffer: Vec<f32>,

    // Most recent model output, if one has been produced (semantics set outside this section).
    last_score: Option<f32>,
}
904
/// Raw model weight blob, embedded into the binary at compile time by `include_bytes!`
/// from `tiny_tenvad.bin` (path resolved relative to this source file).
/// NOTE(review): the byte layout is parsed by loader code outside this section — confirm there.
const WEIGHTS_BYTES: &[u8] = include_bytes!("tiny_tenvad.bin");
906
907impl TinyTen {
908    pub fn new(config: VADOption) -> Result<Self> {
909        if config.samplerate != 16000 {
910            return Err(anyhow::anyhow!("TinyVad only supports 16kHz audio"));
911        }
912
913        let feature_extractor = TenFeatureExtractor::new();
914        let feature_buffer = ndarray::Array2::<f32>::zeros((CONTEXT_WINDOW_LEN, FEATURE_LEN));
915
916        // Initialize layers
917        // Conv1: Input [1, 3, 41, 1]
918        // DW: 3x3, stride 1, pad 0. Out: [1, 1, 39, 1]
919        let conv1_dw = Conv2dLayer::new(1, 1, 3, 3, 1, 1, [0, 0, 0, 0], 1);
920        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 39, 16]
921        let conv1_pw = Conv2dLayer::new(1, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
922
923        // MaxPool: 1x3, stride 1x2. Out: [1, 1, 19, 16]
924        let maxpool = MaxPool2dLayer {
925            kernel_h: 1,
926            kernel_w: 3,
927            stride_h: 1,
928            stride_w: 2,
929        };
930
931        // Conv2: Input [1, 1, 19, 16]
932        // DW: 1x3, stride 2x2, pad [0, 1, 0, 1]. Out: [1, 1, 10, 16]
933        let conv2_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 1, 0, 1], 16);
934        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 10, 16]
935        let conv2_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
936
937        // Conv3: Input [1, 1, 10, 16]
938        // DW: 1x3, stride 2x2, pad [0, 0, 0, 1]. Out: [1, 1, 5, 16]
939        let conv3_dw = Conv2dLayer::new(16, 16, 1, 3, 2, 2, [0, 0, 0, 1], 16);
940        // PW: 1x1, stride 1, pad 0. Out: [1, 1, 5, 16]
941        let conv3_pw = Conv2dLayer::new(16, 16, 1, 1, 1, 1, [0, 0, 0, 0], 1);
942
943        // LSTM Input size: 5 * 16 = 80.
944        let lstm1 = LstmLayer::new(80, HIDDEN_SIZE);
945        let lstm2 = LstmLayer::new(HIDDEN_SIZE, HIDDEN_SIZE);
946
947        let dense1 = LinearLayer::new(HIDDEN_SIZE * 2, 32);
948        let dense2 = LinearLayer::new(32, 1);
949
950        // Pre-allocate scratch buffers
951        let t_input = Tensor3D::new(CONTEXT_WINDOW_LEN, FEATURE_LEN, 1);
952        let t_conv1_dw = Tensor3D::new(1, 39, 1);
953        let t_conv1_pw = Tensor3D::new(1, 39, 16);
954        let t_maxpool = Tensor3D::new(1, 19, 16);
955        let t_conv2_dw = Tensor3D::new(1, 10, 16);
956        let t_conv2_pw = Tensor3D::new(1, 10, 16);
957        let t_conv3_dw = Tensor3D::new(1, 5, 16);
958        let t_conv3_pw = Tensor3D::new(1, 5, 16);
959
960        let dense_input_buffer = vec![0.0; HIDDEN_SIZE * 2];
961        let dense1_out_buffer = vec![0.0; 32];
962
963        let mut vad = Self {
964            config,
965            buffer: Vec::with_capacity(768),
966            pre_emphasis_prev: 0.0,
967            current_timestamp: 0,
968            processed_samples: 0,
969            initialized_timestamp: false,
970            feature_extractor,
971            feature_buffer,
972            conv1_dw,
973            conv1_pw,
974            maxpool,
975            conv2_dw,
976            conv2_pw,
977            conv3_dw,
978            conv3_pw,
979            lstm1,
980            lstm2,
981            dense1,
982            dense2,
983            h1: vec![0.0; HIDDEN_SIZE],
984            c1: vec![0.0; HIDDEN_SIZE],
985            h2: vec![0.0; HIDDEN_SIZE],
986            c2: vec![0.0; HIDDEN_SIZE],
987            t_input,
988            t_conv1_dw,
989            t_conv1_pw,
990            t_maxpool,
991            t_conv2_dw,
992            t_conv2_pw,
993            t_conv3_dw,
994            t_conv3_pw,
995            dense_input_buffer,
996            dense1_out_buffer,
997            last_score: None,
998        };
999
1000        vad.load_weights_from_bytes(WEIGHTS_BYTES)?;
1001        Ok(vad)
1002    }
1003
1004    pub fn predict(&mut self, audio_frame: &[f32]) -> f32 {
1005        // 1. Extract features
1006        let features = self.feature_extractor.extract_features(audio_frame);
1007
1008        // 2. Update context window
1009        for i in 0..CONTEXT_WINDOW_LEN - 1 {
1010            for j in 0..FEATURE_LEN {
1011                self.feature_buffer[[i, j]] = self.feature_buffer[[i + 1, j]];
1012            }
1013        }
1014        for j in 0..FEATURE_LEN {
1015            self.feature_buffer[[CONTEXT_WINDOW_LEN - 1, j]] = features[j];
1016        }
1017
1018        // 3. Prepare Input Tensor [1, 3, 41, 1]
1019        // H=3 (Time), W=41 (Freq), C=1
1020        // Reuse t_input
1021        for i in 0..CONTEXT_WINDOW_LEN {
1022            for j in 0..FEATURE_LEN {
1023                self.t_input.set(i, j, 0, self.feature_buffer[[i, j]]);
1024            }
1025        }
1026
1027        // 4. Forward Pass
1028        // Block 1
1029        self.conv1_dw
1030            .forward_into(&self.t_input, &mut self.t_conv1_dw);
1031        self.conv1_pw
1032            .forward_into(&self.t_conv1_dw, &mut self.t_conv1_pw);
1033
1034        // Apply Relu
1035        for val in self.t_conv1_pw.data.iter_mut() {
1036            *val = val.max(0.0);
1037        }
1038
1039        self.maxpool
1040            .forward_into(&self.t_conv1_pw, &mut self.t_maxpool);
1041
1042        // Block 2
1043        self.conv2_dw
1044            .forward_into(&self.t_maxpool, &mut self.t_conv2_dw);
1045        self.conv2_pw
1046            .forward_into(&self.t_conv2_dw, &mut self.t_conv2_pw);
1047
1048        for val in self.t_conv2_pw.data.iter_mut() {
1049            *val = val.max(0.0);
1050        }
1051
1052        // Block 3
1053        self.conv3_dw
1054            .forward_into(&self.t_conv2_pw, &mut self.t_conv3_dw);
1055        self.conv3_pw
1056            .forward_into(&self.t_conv3_dw, &mut self.t_conv3_pw);
1057
1058        for val in self.t_conv3_pw.data.iter_mut() {
1059            *val = val.max(0.0);
1060        }
1061
1062        // Flatten for LSTM
1063        // x shape should be [1, 5, 16] -> 80 elements
1064        let lstm_input = &self.t_conv3_pw.data;
1065
1066        // LSTM 1
1067        self.lstm1
1068            .forward_optimized(lstm_input, &mut self.h1, &mut self.c1);
1069
1070        // LSTM 2
1071        self.lstm2
1072            .forward_optimized(&self.h1, &mut self.h2, &mut self.c2);
1073
1074        // Concat h2, h1 (Graph says concat_1 inputs: lstm2, lstm1)
1075        // dense_input_buffer is [h2, h1]
1076        let h_size = HIDDEN_SIZE;
1077        self.dense_input_buffer[0..h_size].copy_from_slice(&self.h2);
1078        self.dense_input_buffer[h_size..2 * h_size].copy_from_slice(&self.h1);
1079
1080        // Dense 1
1081        self.dense1
1082            .forward(&self.dense_input_buffer, &mut self.dense1_out_buffer);
1083        // Relu
1084        for val in self.dense1_out_buffer.iter_mut() {
1085            *val = val.max(0.0);
1086        }
1087
1088        // Dense 2
1089        let mut output = [0.0; 1];
1090        self.dense2.forward(&self.dense1_out_buffer, &mut output);
1091
1092        let score = 1.0 / (1.0 + (-output[0]).exp()); // Sigmoid
1093        self.last_score = Some(score);
1094
1095        score
1096    }
1097
1098    fn load_weights_from_bytes(&mut self, bytes: &[u8]) -> Result<()> {
1099        let mut offset = 0;
1100
1101        // Helper to read u32
1102        let read_u32 = |offset: &mut usize, buf: &[u8]| -> u32 {
1103            let val = u32::from_le_bytes(buf[*offset..*offset + 4].try_into().unwrap());
1104            *offset += 4;
1105            val
1106        };
1107
1108        let num_tensors = read_u32(&mut offset, bytes);
1109
1110        let mut weights = std::collections::HashMap::new();
1111
1112        for _ in 0..num_tensors {
1113            let name_len = read_u32(&mut offset, bytes) as usize;
1114            let name_bytes = &bytes[offset..offset + name_len];
1115            let name = std::str::from_utf8(name_bytes)?.to_string();
1116            offset += name_len;
1117
1118            let shape_len = read_u32(&mut offset, bytes) as usize;
1119            let mut shape = Vec::new();
1120            for _ in 0..shape_len {
1121                shape.push(read_u32(&mut offset, bytes));
1122            }
1123
1124            let data_len = read_u32(&mut offset, bytes) as usize;
1125            let data_bytes = &bytes[offset..offset + data_len];
1126            let floats: Vec<f32> = data_bytes
1127                .chunks_exact(4)
1128                .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
1129                .collect();
1130            offset += data_len;
1131
1132            weights.insert(name, (shape, floats));
1133        }
1134
1135        // Assign weights
1136        if let Some(w) = weights.get("conv1_dw_weight") {
1137            self.conv1_dw.weights = w.1.clone();
1138        }
1139        if let Some(w) = weights.get("conv1_pw_weight") {
1140            self.conv1_pw.weights = w.1.clone();
1141        }
1142        if let Some(w) = weights.get("conv1_bias") {
1143            self.conv1_pw.bias = Some(w.1.clone());
1144        }
1145
1146        if let Some(w) = weights.get("conv2_dw_weight") {
1147            self.conv2_dw.weights = w.1.clone();
1148        }
1149        if let Some(w) = weights.get("conv2_pw_weight") {
1150            self.conv2_pw.weights = w.1.clone();
1151        }
1152        if let Some(w) = weights.get("conv2_bias") {
1153            self.conv2_pw.bias = Some(w.1.clone());
1154        }
1155
1156        if let Some(w) = weights.get("conv3_dw_weight") {
1157            self.conv3_dw.weights = w.1.clone();
1158        }
1159        if let Some(w) = weights.get("conv3_pw_weight") {
1160            self.conv3_pw.weights = w.1.clone();
1161        }
1162        if let Some(w) = weights.get("conv3_bias") {
1163            self.conv3_pw.bias = Some(w.1.clone());
1164        }
1165
1166        if let Some(w) = weights.get("lstm1_w_ih") {
1167            self.lstm1.weight_ih = w.1.clone();
1168        }
1169        if let Some(w) = weights.get("lstm1_w_hh") {
1170            self.lstm1.weight_hh = w.1.clone();
1171        }
1172        if let Some(w) = weights.get("lstm1_bias") {
1173            // Split bias into ih and hh if needed, or just use as is.
1174            // ONNX LSTM bias is [8*H]. Our LstmLayer expects bias_ih [4*H] and bias_hh [4*H].
1175            // Usually first half is W_b, second half is R_b.
1176            let b = &w.1;
1177            if b.len() == 8 * HIDDEN_SIZE {
1178                self.lstm1.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1179                self.lstm1.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1180            }
1181        }
1182
1183        if let Some(w) = weights.get("lstm2_w_ih") {
1184            self.lstm2.weight_ih = w.1.clone();
1185        }
1186        if let Some(w) = weights.get("lstm2_w_hh") {
1187            self.lstm2.weight_hh = w.1.clone();
1188        }
1189        if let Some(w) = weights.get("lstm2_bias") {
1190            let b = &w.1;
1191            if b.len() == 8 * HIDDEN_SIZE {
1192                self.lstm2.bias_ih = b[0..4 * HIDDEN_SIZE].to_vec();
1193                self.lstm2.bias_hh = b[4 * HIDDEN_SIZE..].to_vec();
1194            }
1195        }
1196
1197        if let Some(w) = weights.get("dense1_weight") {
1198            self.dense1.weights = w.1.clone();
1199        }
1200        if let Some(w) = weights.get("dense1_bias") {
1201            self.dense1.bias = w.1.clone();
1202        }
1203
1204        if let Some(w) = weights.get("dense2_weight") {
1205            self.dense2.weights = w.1.clone();
1206        }
1207        if let Some(w) = weights.get("dense2_bias") {
1208            self.dense2.bias = w.1.clone();
1209        }
1210        Ok(())
1211    }
1212}
1213
1214impl VadEngine for TinyTen {
1215    fn process(&mut self, frame: &mut AudioFrame) -> Vec<(bool, u64)> {
1216        let samples = match &frame.samples {
1217            Samples::PCM { samples } => samples,
1218            _ => return vec![(false, frame.timestamp)],
1219        };
1220
1221        if !self.initialized_timestamp {
1222            self.current_timestamp = frame.timestamp;
1223            self.initialized_timestamp = true;
1224            // Pre-pad with zeros for the initial context
1225            // librosa center=True pads n_fft/2 (512).
1226            // 512 context makes the first chunk start at the end of the window.
1227            self.buffer.resize(512, 0.0);
1228        }
1229
1230        // Apply pre-emphasis once to NEW samples and push to buffer
1231        let inv_scale = 1.0 / 32768.0;
1232        for &sample in samples {
1233            let s_f32 = sample as f32;
1234            let pre_emphasized = (s_f32 - PRE_EMPHASIS_COEFF * self.pre_emphasis_prev) * inv_scale;
1235            self.buffer.push(pre_emphasized);
1236            self.pre_emphasis_prev = s_f32;
1237        }
1238
1239        let mut results = Vec::new();
1240
1241        while self.buffer.len() >= WINDOW_SIZE {
1242            // Predict uses the full window but we only drain the hop size
1243            let window = self.buffer[..WINDOW_SIZE].to_vec();
1244            let score = self.predict(&window);
1245
1246            let is_voice = score > self.config.voice_threshold;
1247
1248            let chunk_timestamp =
1249                self.current_timestamp + (self.processed_samples * 1000) / (SAMPLE_RATE as u64);
1250            self.processed_samples += HOP_SIZE as u64;
1251
1252            results.push((is_voice, chunk_timestamp));
1253            self.buffer.drain(..HOP_SIZE);
1254        }
1255
1256        results
1257    }
1258}