use crate::embedding::{EmbeddingError, EmbeddingExtractor};
use crate::types::DiarizationConfig;
use crate::utils::l2_normalize;
use std::path::Path;
#[cfg(feature = "onnx")]
/// Speaker-embedding extractor backed by an ONNX Runtime model.
///
/// Holds a fixed-size pool of `ort` sessions so multiple threads can run
/// inference concurrently without sharing a single mutable session.
pub struct OnnxEmbeddingExtractor {
// Lock-free pool of ready-to-use ONNX sessions; `extract` pops one and a
// guard (`PooledSession`) pushes it back on drop.
pool: crossbeam_queue::ArrayQueue<ort::session::Session>,
// Expected length of the embedding vector the model produces.
embedding_dim: usize,
// Exact number of audio samples the model accepts per inference window.
window_samples: usize,
}
#[cfg(feature = "onnx")]
impl OnnxEmbeddingExtractor {
    /// Builds the extractor by loading `pool_size` independent sessions of the
    /// model at `model_path`.
    ///
    /// # Errors
    /// Fails if the session builder cannot be created, if any session fails to
    /// load from the model file, or if a session cannot be stored in the pool.
    pub fn new(
        model_path: &Path,
        embedding_dim: usize,
        window_samples: usize,
        pool_size: usize,
    ) -> anyhow::Result<Self> {
        let pool = crossbeam_queue::ArrayQueue::new(pool_size);
        for idx in 0..pool_size {
            // Each session is loaded independently so one failure pinpoints
            // which slot broke.
            let builder = ort::session::Session::builder()
                .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
            let session = builder
                .commit_from_file(model_path)
                .map_err(|e| EmbeddingError::InferenceFailed(format!("session {idx}: {e}")))?;
            if pool.push(session).is_err() {
                // Cannot happen while the queue capacity equals pool_size,
                // but surface it rather than silently dropping a session.
                return Err(anyhow::anyhow!("failed to push session into pool"));
            }
        }
        Ok(Self {
            pool,
            embedding_dim,
            window_samples,
        })
    }

    /// Pops a session from the pool, wrapping it in an RAII guard that
    /// returns it on drop. `None` means every session is currently in use.
    fn checkout(&self) -> Option<PooledSession<'_>> {
        let session = self.pool.pop()?;
        Some(PooledSession {
            session: Some(session),
            pool: &self.pool,
        })
    }
}
#[cfg(feature = "onnx")]
impl EmbeddingExtractor for OnnxEmbeddingExtractor {
    /// Runs the ONNX model over one fixed-size window of audio samples and
    /// returns the L2-normalized embedding vector.
    ///
    /// # Errors
    /// - `InvalidInput` when `samples.len()` differs from the configured
    ///   window size.
    /// - `InferenceFailed` when the pool is exhausted, the model run fails,
    ///   the model produces no outputs, or the output dimension does not
    ///   match `embedding_dim`.
    fn extract(&self, samples: &[f32], _config: &DiarizationConfig) -> Result<Vec<f32>, EmbeddingError> {
        // FIX: validate the input length *before* checking a session out of
        // the pool. Previously a malformed call held a pooled session for the
        // duration of the error path, needlessly starving concurrent callers.
        if samples.len() != self.window_samples {
            return Err(EmbeddingError::InvalidInput {
                expected: self.window_samples,
                got: samples.len(),
            });
        }
        let mut guard = self.checkout().ok_or_else(|| {
            EmbeddingError::InferenceFailed("ONNX session pool exhausted".to_string())
        })?;
        // The model expects a [batch = 1, window_samples] f32 tensor; borrow
        // the caller's slice directly instead of copying it.
        let input_tensor = ort::value::TensorRef::from_array_view(
            ([1_usize, self.window_samples], samples),
        )
        .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        let session = guard
            .session
            .as_mut()
            .ok_or_else(|| EmbeddingError::InferenceFailed("session not available".to_string()))?;
        let outputs = session
            .run(ort::inputs![input_tensor])
            .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        // Guard the `outputs[0]` index below: an empty output set would
        // otherwise panic instead of returning an error.
        if outputs.iter().next().is_none() {
            return Err(EmbeddingError::InferenceFailed(
                "ONNX model produced no outputs".to_string(),
            ));
        }
        let (_, data) = &outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        if data.len() != self.embedding_dim {
            return Err(EmbeddingError::InferenceFailed(format!(
                "expected embedding dim {}, got {}",
                self.embedding_dim,
                data.len()
            )));
        }
        // Copy the tensor data out in one step (length verified above);
        // avoids zero-initializing a buffer only to overwrite it.
        let mut embedding = data.to_vec();
        l2_normalize(&mut embedding);
        Ok(embedding)
    }

    /// Dimensionality of the embeddings produced by `extract`.
    fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}
#[cfg(feature = "onnx")]
/// RAII guard over a session checked out of the extractor's pool.
///
/// `session` is an `Option` so `Drop` can move the session back into the
/// pool; it is `Some` for the guard's entire useful lifetime.
struct PooledSession<'a> {
// The checked-out session; taken (set to None) when returned on drop.
session: Option<ort::session::Session>,
// The pool the session is returned to when this guard is dropped.
pool: &'a crossbeam_queue::ArrayQueue<ort::session::Session>,
}
#[cfg(feature = "onnx")]
impl Drop for PooledSession<'_> {
    /// Returns the session to the shared pool when the guard goes out of
    /// scope, making it available to other callers.
    fn drop(&mut self) {
        let Some(session) = self.session.take() else {
            return;
        };
        // Best-effort return: push can only fail if the queue is full, in
        // which case the session is simply dropped.
        let _ = self.pool.push(session);
    }
}
#[cfg(not(feature = "onnx"))]
/// Placeholder type used when the `onnx` feature is disabled; its constructor
/// always fails, so no instance can exist in non-onnx builds.
pub struct OnnxEmbeddingExtractor;
#[cfg(not(feature = "onnx"))]
impl OnnxEmbeddingExtractor {
    /// Stub constructor mirroring the onnx-enabled signature.
    ///
    /// # Errors
    /// Always returns an error, since ONNX support is compiled out.
    pub fn new(
        _model_path: &Path,
        _embedding_dim: usize,
        _window_samples: usize,
        _pool_size: usize,
    ) -> anyhow::Result<Self> {
        Err(anyhow::anyhow!("the `onnx` feature is not enabled"))
    }
}