pyannote_rs/
embedding.rs

1use crate::nn::{self, BurnBackend, BurnDevice};
2use anyhow::{Context, Result, anyhow, bail};
3use burn::tensor::{Tensor, TensorData};
4use kaldi_native_fbank::online::FeatureComputer;
5use kaldi_native_fbank::{FbankComputer, FbankOptions, OnlineFeature};
6use ndarray::{Array1, Array2, s};
7use std::path::Path;
8
/// Number of feature frames the feature matrix is cropped or padded to
/// (via `adjust_feature_length`) before it is passed to the model.
const TARGET_FRAME_COUNT: usize = 200;
10
/// Owned vector of `f32` values, as returned by [`EmbeddingExtractor`].
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Raw embedding components.
    values: Vec<f32>,
}

impl Embedding {
    /// Wraps an existing vector of components without copying it.
    pub fn new(values: Vec<f32>) -> Self {
        Embedding { values }
    }

    /// Borrows the components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.values.as_slice()
    }

    /// Consumes the embedding and hands back the underlying vector.
    pub fn into_inner(self) -> Vec<f32> {
        let Embedding { values } = self;
        values
    }
}

impl From<Vec<f32>> for Embedding {
    /// Builds an embedding that takes ownership of `values`.
    fn from(values: Vec<f32>) -> Self {
        Embedding { values }
    }
}

impl AsRef<[f32]> for Embedding {
    /// Borrows the components as a slice.
    fn as_ref(&self) -> &[f32] {
        &self.values
    }
}
41
42#[derive(Debug)]
43pub struct EmbeddingExtractor {
44    model: nn::speaker_identification::Model<BurnBackend>,
45    device: BurnDevice,
46}
47
48impl EmbeddingExtractor {
49    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
50        let device = BurnDevice::default();
51        let model_path = model_path
52            .as_ref()
53            .to_str()
54            .context("Model path must be valid UTF-8")?;
55        let model = nn::speaker_identification::Model::from_file(model_path, &device);
56
57        Ok(Self { model, device })
58    }
59
60    pub fn extract(&self, samples: &[i16], sample_rate: u32) -> Result<Embedding> {
61        let samples_f32 = normalize_i16_to_f32(samples);
62        self.extract_from_f32(&samples_f32, sample_rate)
63    }
64
65    pub fn extract_f32(&self, samples: &[f32], sample_rate: u32) -> Result<Embedding> {
66        self.extract_from_f32(samples, sample_rate)
67    }
68
69    fn extract_from_f32(&self, samples: &[f32], sample_rate: u32) -> Result<Embedding> {
70        if sample_rate == 0 {
71            bail!("sample_rate cannot be zero");
72        }
73        if samples.is_empty() {
74            bail!("samples cannot be empty");
75        }
76
77        let sample_rate = sample_rate as f32;
78        let mut fbank_opts = FbankOptions::default();
79        fbank_opts.mel_opts.num_bins = 80;
80        fbank_opts.use_energy = false;
81
82        {
83            let frame_opts = &mut fbank_opts.frame_opts;
84            frame_opts.dither = 0.0;
85            frame_opts.samp_freq = sample_rate;
86            frame_opts.snip_edges = true;
87        }
88
89        let fbank = FbankComputer::new(fbank_opts).map_err(|e| anyhow!(e))?;
90        let mut online_feature = OnlineFeature::new(FeatureComputer::Fbank(fbank));
91        online_feature.accept_waveform(sample_rate, samples);
92        online_feature.input_finished();
93
94        let frames = online_feature.features;
95        if frames.is_empty() {
96            bail!("No features computed");
97        }
98
99        let num_bins = frames[0].len();
100        let mut flattened = Vec::with_capacity(frames.len() * num_bins);
101        for frame in &frames {
102            if frame.len() != num_bins {
103                bail!("Inconsistent feature dimensions");
104            }
105            flattened.extend_from_slice(frame);
106        }
107
108        let features = Array2::from_shape_vec((frames.len(), num_bins), flattened)?;
109        let original_mean = features.mean_axis(ndarray::Axis(0)).context("mean")?;
110        let features = adjust_feature_length(features, TARGET_FRAME_COUNT, &original_mean);
111        let mean = features.mean_axis(ndarray::Axis(0)).context("mean")?;
112        let features: Array2<f32> = features - &mean;
113        let frame_count = features.nrows();
114
115        let (features, _) = features.into_raw_vec_and_offset();
116        let data = TensorData::new(features, [1, frame_count, num_bins]);
117        let input = Tensor::<BurnBackend, 3>::from_data(data, &self.device);
118        let output = self.model.forward(input);
119        let output_data = output.into_data();
120        let shape = output_data.shape.clone();
121
122        if shape.len() != 2 {
123            bail!("Unexpected embedding output shape: {:?}", shape);
124        }
125        if shape[0] != 1 {
126            bail!("Expected batch size 1, got {}", shape[0]);
127        }
128
129        let values = output_data
130            .into_vec::<f32>()
131            .map_err(|err| anyhow!("Failed to read embedding output: {err}"))?;
132
133        Ok(Embedding::new(values))
134    }
135}
136
/// Converts signed 16-bit samples to `f32` by dividing each one by 32768,
/// so `i16::MIN` maps to `-1.0` and `i16::MAX` to just under `1.0`.
fn normalize_i16_to_f32(samples: &[i16]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;

    let mut normalized = Vec::with_capacity(samples.len());
    normalized.extend(samples.iter().map(|&pcm| f32::from(pcm) / FULL_SCALE));
    normalized
}
143
144fn adjust_feature_length(
145    features: Array2<f32>,
146    target_frames: usize,
147    pad_value: &Array1<f32>,
148) -> Array2<f32> {
149    let frame_count = features.nrows();
150    let num_bins = features.ncols();
151
152    if frame_count > target_frames {
153        let start = (frame_count - target_frames) / 2;
154        return features.slice(s![start..start + target_frames, ..]).to_owned();
155    }
156
157    if frame_count == target_frames {
158        return features;
159    }
160
161    let mut padded = Array2::zeros((target_frames, num_bins));
162    let offset = (target_frames - frame_count) / 2;
163    padded
164        .slice_mut(s![offset..offset + frame_count, ..])
165        .assign(&features);
166
167    fill_with_mean(padded.slice_mut(s![..offset, ..]), pad_value);
168
169    let end_padding = target_frames - offset - frame_count;
170    if end_padding > 0 {
171        fill_with_mean(
172            padded.slice_mut(s![target_frames - end_padding.., ..]),
173            pad_value,
174        );
175    }
176
177    padded
178}
179
180fn fill_with_mean(mut view: ndarray::ArrayViewMut2<'_, f32>, mean: &Array1<f32>) {
181    for mut row in view.rows_mut() {
182        row.assign(mean);
183    }
184}