pyannote_rs/
embedding.rs

1use crate::session;
2use eyre::{Context, ContextCompat, Result};
3use ndarray::Array2;
4use ort::{session::Session, value::Tensor};
5use std::path::Path;
6
7#[derive(Debug)]
8pub struct EmbeddingExtractor {
9    session: Session,
10}
11
12impl EmbeddingExtractor {
13    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
14        let session = session::create_session(model_path.as_ref())?;
15        Ok(Self { session })
16    }
17
18    pub fn compute(&mut self, samples: &[i16]) -> Result<impl Iterator<Item = f32>> {
19        // Convert to f32 precisely
20        let mut samples_f32 = vec![0.0; samples.len()];
21        knf_rs::convert_integer_to_float_audio(samples, &mut samples_f32);
22        let samples = &samples_f32;
23
24        let features: Array2<f32> = knf_rs::compute_fbank(samples)?;
25        let features = features.insert_axis(ndarray::Axis(0)); // Add batch dimension
26        let inputs = ort::inputs![
27        "feats" => Tensor::from_array(features)? // takes ownership of `features`
28        ];
29
30        let ort_outs = self.session.run(inputs)?;
31        let ort_out = ort_outs
32            .get("embs")
33            .context("Output tensor not found")?
34            .try_extract_tensor::<f32>()
35            .context("Failed to extract tensor")?;
36
37        // Collect the tensor data into a Vec to own it
38        let embeddings: Vec<f32> = ort_out.1.iter().copied().collect();
39
40        // Return an iterator over the Vec
41        Ok(embeddings.into_iter())
42    }
43}