use {
super::get_session_builder,
crate::OperationError,
ndarray::Array2,
ort::{
inputs,
session::{RunOptions, Session},
value::TensorRef,
},
std::{array::from_fn, path::Path},
};
/// Computes speaker embeddings by running an ONNX Runtime (`ort`) model session.
pub struct SpeakerEmbeddingExtractor {
    /// Session for the loaded speaker-embedding model; built in `new` and run by `extract`.
    model: Session,
}
impl SpeakerEmbeddingExtractor {
pub fn new<P>(model_path: P) -> Result<Self, OperationError>
where
P: AsRef<Path>,
{
let model = get_session_builder()?.commit_from_file(model_path)?;
Ok(Self { model })
}
pub async fn extract(
&mut self,
audio: &[f32],
channels: usize,
) -> Result<Vec<[f32; 256]>, OperationError> {
let len = audio.len();
let x = Array2::from_shape_vec((len / channels, channels), audio.to_vec())?;
let options = RunOptions::new()?;
let outputs = self
.model
.run_async(
inputs![
"audio" => TensorRef::from_array_view(&x)?
],
&options,
)?
.await?;
let se = outputs["se"].try_extract_array::<f32>()?;
let mut out = Vec::with_capacity(channels);
for i in 0..channels {
let row: [f32; 256] = from_fn(|j| se[[i, j]]);
out.push(row);
}
Ok(out)
}
}