1use crate::nn::{self, BurnBackend, BurnDevice};
2use anyhow::{Context, Result, anyhow, bail};
3use burn::tensor::{Tensor, TensorData};
4use kaldi_native_fbank::online::FeatureComputer;
5use kaldi_native_fbank::{FbankComputer, FbankOptions, OnlineFeature};
6use ndarray::{Array1, Array2, s};
7use std::path::Path;
8
9const TARGET_FRAME_COUNT: usize = 200;
10
/// A speaker embedding: the raw `f32` vector returned by [`EmbeddingExtractor`].
#[derive(Debug, Clone)]
pub struct Embedding {
    values: Vec<f32>,
}

impl Embedding {
    /// Wraps an existing vector of embedding components without copying it.
    pub fn new(values: Vec<f32>) -> Self {
        Embedding { values }
    }

    /// Borrows the embedding components.
    pub fn as_slice(&self) -> &[f32] {
        self.values.as_slice()
    }

    /// Consumes the embedding and hands back the underlying vector.
    pub fn into_inner(self) -> Vec<f32> {
        let Embedding { values } = self;
        values
    }
}

impl From<Vec<f32>> for Embedding {
    /// Equivalent to [`Embedding::new`].
    fn from(raw: Vec<f32>) -> Self {
        Embedding::new(raw)
    }
}

impl AsRef<[f32]> for Embedding {
    /// Borrows the components, so an `Embedding` can be passed wherever `AsRef<[f32]>` is accepted.
    fn as_ref(&self) -> &[f32] {
        &self.values
    }
}
41
42#[derive(Debug)]
43pub struct EmbeddingExtractor {
44 model: nn::speaker_identification::Model<BurnBackend>,
45 device: BurnDevice,
46}
47
48impl EmbeddingExtractor {
49 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
50 let device = BurnDevice::default();
51 let model_path = model_path
52 .as_ref()
53 .to_str()
54 .context("Model path must be valid UTF-8")?;
55 let model = nn::speaker_identification::Model::from_file(model_path, &device);
56
57 Ok(Self { model, device })
58 }
59
60 pub fn extract(&self, samples: &[i16], sample_rate: u32) -> Result<Embedding> {
61 let samples_f32 = normalize_i16_to_f32(samples);
62 self.extract_from_f32(&samples_f32, sample_rate)
63 }
64
65 pub fn extract_f32(&self, samples: &[f32], sample_rate: u32) -> Result<Embedding> {
66 self.extract_from_f32(samples, sample_rate)
67 }
68
69 fn extract_from_f32(&self, samples: &[f32], sample_rate: u32) -> Result<Embedding> {
70 if sample_rate == 0 {
71 bail!("sample_rate cannot be zero");
72 }
73 if samples.is_empty() {
74 bail!("samples cannot be empty");
75 }
76
77 let sample_rate = sample_rate as f32;
78 let mut fbank_opts = FbankOptions::default();
79 fbank_opts.mel_opts.num_bins = 80;
80 fbank_opts.use_energy = false;
81
82 {
83 let frame_opts = &mut fbank_opts.frame_opts;
84 frame_opts.dither = 0.0;
85 frame_opts.samp_freq = sample_rate;
86 frame_opts.snip_edges = true;
87 }
88
89 let fbank = FbankComputer::new(fbank_opts).map_err(|e| anyhow!(e))?;
90 let mut online_feature = OnlineFeature::new(FeatureComputer::Fbank(fbank));
91 online_feature.accept_waveform(sample_rate, samples);
92 online_feature.input_finished();
93
94 let frames = online_feature.features;
95 if frames.is_empty() {
96 bail!("No features computed");
97 }
98
99 let num_bins = frames[0].len();
100 let mut flattened = Vec::with_capacity(frames.len() * num_bins);
101 for frame in &frames {
102 if frame.len() != num_bins {
103 bail!("Inconsistent feature dimensions");
104 }
105 flattened.extend_from_slice(frame);
106 }
107
108 let features = Array2::from_shape_vec((frames.len(), num_bins), flattened)?;
109 let original_mean = features.mean_axis(ndarray::Axis(0)).context("mean")?;
110 let features = adjust_feature_length(features, TARGET_FRAME_COUNT, &original_mean);
111 let mean = features.mean_axis(ndarray::Axis(0)).context("mean")?;
112 let features: Array2<f32> = features - &mean;
113 let frame_count = features.nrows();
114
115 let (features, _) = features.into_raw_vec_and_offset();
116 let data = TensorData::new(features, [1, frame_count, num_bins]);
117 let input = Tensor::<BurnBackend, 3>::from_data(data, &self.device);
118 let output = self.model.forward(input);
119 let output_data = output.into_data();
120 let shape = output_data.shape.clone();
121
122 if shape.len() != 2 {
123 bail!("Unexpected embedding output shape: {:?}", shape);
124 }
125 if shape[0] != 1 {
126 bail!("Expected batch size 1, got {}", shape[0]);
127 }
128
129 let values = output_data
130 .into_vec::<f32>()
131 .map_err(|err| anyhow!("Failed to read embedding output: {err}"))?;
132
133 Ok(Embedding::new(values))
134 }
135}
136
/// Converts signed 16-bit PCM samples to `f32` by dividing each one by 32768.
///
/// `i16::MIN` maps to exactly `-1.0`, and `i16::MAX` maps to just below `1.0`.
fn normalize_i16_to_f32(samples: &[i16]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;

    let mut scaled = Vec::with_capacity(samples.len());
    for &pcm in samples {
        scaled.push(f32::from(pcm) / FULL_SCALE);
    }
    scaled
}
143
144fn adjust_feature_length(
145 features: Array2<f32>,
146 target_frames: usize,
147 pad_value: &Array1<f32>,
148) -> Array2<f32> {
149 let frame_count = features.nrows();
150 let num_bins = features.ncols();
151
152 if frame_count > target_frames {
153 let start = (frame_count - target_frames) / 2;
154 return features.slice(s![start..start + target_frames, ..]).to_owned();
155 }
156
157 if frame_count == target_frames {
158 return features;
159 }
160
161 let mut padded = Array2::zeros((target_frames, num_bins));
162 let offset = (target_frames - frame_count) / 2;
163 padded
164 .slice_mut(s![offset..offset + frame_count, ..])
165 .assign(&features);
166
167 fill_with_mean(padded.slice_mut(s![..offset, ..]), pad_value);
168
169 let end_padding = target_frames - offset - frame_count;
170 if end_padding > 0 {
171 fill_with_mean(
172 padded.slice_mut(s![target_frames - end_padding.., ..]),
173 pad_value,
174 );
175 }
176
177 padded
178}
179
180fn fill_with_mean(mut view: ndarray::ArrayViewMut2<'_, f32>, mean: &Array1<f32>) {
181 for mut row in view.rows_mut() {
182 row.assign(mean);
183 }
184}