1use crate::session;
2use eyre::{Context, ContextCompat, Result};
3use ndarray::Array2;
4use ort::{session::Session, value::Tensor};
5use std::path::Path;
6
7#[derive(Debug)]
8pub struct EmbeddingExtractor {
9 session: Session,
10}
11
12impl EmbeddingExtractor {
13 pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
14 let session = session::create_session(model_path.as_ref())?;
15 Ok(Self { session })
16 }
17
18 pub fn compute(&mut self, samples: &[i16]) -> Result<impl Iterator<Item = f32>> {
19 let mut samples_f32 = vec![0.0; samples.len()];
21 knf_rs::convert_integer_to_float_audio(samples, &mut samples_f32);
22 let samples = &samples_f32;
23
24 let features: Array2<f32> = knf_rs::compute_fbank(samples)?;
25 let features = features.insert_axis(ndarray::Axis(0)); let inputs = ort::inputs![
27 "feats" => Tensor::from_array(features)? ];
29
30 let ort_outs = self.session.run(inputs)?;
31 let ort_out = ort_outs
32 .get("embs")
33 .context("Output tensor not found")?
34 .try_extract_tensor::<f32>()
35 .context("Failed to extract tensor")?;
36
37 let embeddings: Vec<f32> = ort_out.1.iter().copied().collect();
39
40 Ok(embeddings.into_iter())
42 }
43}