active_call/offline/sensevoice/frontend.rs

1use anyhow::Result;
2use ndarray::{s, Array2};
3use std::ffi::c_float;
4
/// Tunable parameters of the audio feature front end.
///
/// Describes the expected input sample rate, the filterbank framing, the number of
/// filterbank bins per frame, and the window/stride used for low-frame-rate (LFR)
/// stacking of consecutive frames.
#[derive(Debug, Clone, Copy)]
pub struct FrontendConfig {
    /// Sample rate the input PCM must be delivered at (default 16000).
    pub sample_rate: usize,
    /// Number of filterbank bins each extracted frame is expected to have.
    pub n_mels: usize,
    /// Analysis window length, in milliseconds.
    pub frame_length_ms: f32,
    /// Hop between consecutive analysis windows, in milliseconds.
    pub frame_shift_ms: f32,
    /// LFR window: how many consecutive frames are concatenated into one output row.
    pub lfr_m: usize,
    /// LFR stride: how many input frames the window advances per output row.
    pub lfr_n: usize,
}

impl Default for FrontendConfig {
    /// 16000 sample rate, 80 bins, 25 ms windows with a 10 ms hop, and LFR stacking
    /// of 7 frames with a stride of 6.
    fn default() -> Self {
        FrontendConfig {
            sample_rate: 16_000,
            n_mels: 80,
            frame_length_ms: 25.0,
            frame_shift_ms: 10.0,
            lfr_m: 7,
            lfr_n: 6,
        }
    }
}
27
28pub struct FeaturePipeline {
29    cfg: FrontendConfig,
30    scaled_buf: Vec<f32>,
31}
32
33impl FeaturePipeline {
34    pub fn new(cfg: FrontendConfig) -> Self {
35        Self {
36            cfg,
37            scaled_buf: Vec::new(),
38        }
39    }
40
41    pub fn compute_features(&mut self, pcm: &[f32], sample_rate: u32) -> Result<Array2<f32>> {
42        anyhow::ensure!(
43            sample_rate as usize == self.cfg.sample_rate,
44            "expect sample rate {} but got {}",
45            self.cfg.sample_rate,
46            sample_rate
47        );
48        if pcm.is_empty() {
49            anyhow::bail!("audio length too short for feature extraction");
50        }
51
52        // Keep config knobs alive for potential future customization of the backend extractor.
53        let _ = (self.cfg.frame_length_ms, self.cfg.frame_shift_ms);
54        self.scaled_buf.resize(pcm.len(), 0.0);
55        let scale = (1 << 15) as f32;
56        for (dst, src) in self.scaled_buf.iter_mut().zip(pcm.iter()) {
57            *dst = *src * scale;
58        }
59
60        let mut result = unsafe {
61            knf_rs_sys::ComputeFbank(
62                self.scaled_buf.as_ptr() as *const c_float,
63                self.scaled_buf.len() as i32,
64            )
65        };
66
67        anyhow::ensure!(
68            result.num_bins > 0 && result.num_frames > 0,
69            "fbank extraction failed"
70        );
71        let frame_count = result.num_frames as usize;
72        let mel_bins = result.num_bins as usize;
73        anyhow::ensure!(
74            mel_bins == self.cfg.n_mels,
75            "expected {} mel bins but got {}",
76            self.cfg.n_mels,
77            mel_bins
78        );
79
80        let fbank_vec =
81            unsafe { std::slice::from_raw_parts(result.frames, frame_count * mel_bins).to_vec() };
82        unsafe {
83            knf_rs_sys::DestroyFbankResult(&mut result as *mut _);
84        }
85        let fbank = Array2::<f32>::from_shape_vec((frame_count, mel_bins), fbank_vec)?;
86
87        let lfr = apply_lfr(&fbank, self.cfg.lfr_m, self.cfg.lfr_n);
88        Ok(lfr)
89    }
90}
91
92pub(crate) fn apply_lfr(fbank: &Array2<f32>, lfr_m: usize, lfr_n: usize) -> Array2<f32> {
93    if lfr_m == 1 && lfr_n == 1 {
94        return fbank.to_owned();
95    }
96    let t = fbank.len_of(ndarray::Axis(0));
97    let d = fbank.len_of(ndarray::Axis(1));
98    if t == 0 {
99        return Array2::<f32>::zeros((0, d * lfr_m));
100    }
101    let pad = (lfr_m - 1) / 2;
102    let t_lfr = ((t as f32) / (lfr_n as f32)).ceil() as usize;
103    let mut out = Array2::<f32>::zeros((t_lfr, d * lfr_m));
104
105    for i in 0..t_lfr {
106        let start = i * lfr_n;
107        for m in 0..lfr_m {
108            let effective_idx = start + m;
109            let row_idx = if effective_idx < pad {
110                0
111            } else {
112                let shifted = effective_idx - pad;
113                if shifted < t {
114                    shifted
115                } else {
116                    t - 1
117                }
118            };
119            let src_row = fbank.row(row_idx);
120            let src_slice = src_row
121                .as_slice()
122                .expect("mel features row should be contiguous");
123            let mut row_out = out.slice_mut(s![i, m * d..(m + 1) * d]);
124            let dst_slice = row_out
125                .as_slice_mut()
126                .expect("output row should be contiguous");
127            dst_slice.copy_from_slice(src_slice);
128        }
129    }
130
131    out
132}