use crate::embedding::{EmbeddingError, EmbeddingExtractor};
use crate::types::DiarizationConfig;
use crate::utils::l2_normalize;
use std::path::Path;
/// Minimum file size, in bytes, that [`validate_onnx_header`] requires; it also reads exactly
/// this many leading bytes to inspect the header.
pub const ONNX_MIN_HEADER_BYTES: usize = 64;
/// Error returned by [`validate_onnx_header`] when a file cannot be read or does not look like
/// an ONNX model.
///
/// Displays as `ONNX header validation failed for <path>: <detail>`.
#[derive(thiserror::Error, Debug)]
#[error("ONNX header validation failed for {path}: {detail}")]
pub struct OnnxValidationError {
    /// The file that was checked.
    pub path: std::path::PathBuf,
    /// Human-readable description of why validation failed.
    pub detail: String,
}
/// Performs a cheap sanity check that `path` looks like an ONNX model before it is handed to
/// the ONNX runtime.
///
/// The file must be at least [`ONNX_MIN_HEADER_BYTES`] long. Of its first
/// [`ONNX_MIN_HEADER_BYTES`] bytes, either the ASCII sequence `ONNX` must appear entirely within
/// the first 16 bytes, or byte 0 must be `0x08`. This is a heuristic, not a full protobuf parse.
///
/// # Errors
///
/// Returns [`OnnxValidationError`] if the file cannot be opened, its metadata or header cannot
/// be read, it is shorter than [`ONNX_MIN_HEADER_BYTES`], or neither header check matches.
pub fn validate_onnx_header(path: &Path) -> Result<(), OnnxValidationError> {
    use std::io::Read;

    // Every failure reports the offending path, so build errors through one helper.
    let fail = |detail: String| OnnxValidationError {
        path: path.to_path_buf(),
        detail,
    };

    // Open first and query metadata through the handle so the size check and the header read
    // observe the same file, even if `path` is replaced in between.
    let file = std::fs::File::open(path).map_err(|e| fail(format!("cannot open file: {e}")))?;
    let len = file
        .metadata()
        .map_err(|e| fail(format!("cannot read metadata: {e}")))?
        .len();
    if len < ONNX_MIN_HEADER_BYTES as u64 {
        return Err(fail(format!(
            "file too small ({len} bytes, need at least {ONNX_MIN_HEADER_BYTES})"
        )));
    }

    // A single `Read::read` call may legally return fewer bytes than requested. `take` +
    // `read_to_end` keeps reading (retrying `Interrupted`) until the header is full or EOF.
    let mut header = Vec::with_capacity(ONNX_MIN_HEADER_BYTES);
    let n = file
        .take(ONNX_MIN_HEADER_BYTES as u64)
        .read_to_end(&mut header)
        .map_err(|e| fail(format!("cannot read header: {e}")))?;
    if n < ONNX_MIN_HEADER_BYTES {
        return Err(fail(format!(
            "short read ({n} bytes, need at least {ONNX_MIN_HEADER_BYTES})"
        )));
    }

    let has_onnx_magic = header[..16].windows(4).any(|w| w == b"ONNX");
    // 0x08 is the protobuf tag byte for field 1 with varint wire type; presumably this targets
    // a serialized ModelProto that begins with its `ir_version` field.
    let has_protobuf_header = header[0] == 0x08;
    if !has_onnx_magic && !has_protobuf_header {
        return Err(fail(
            "ONNX magic bytes not found and file does not start with a valid ONNX protobuf header"
                .to_string(),
        ));
    }
    Ok(())
}
#[cfg(feature = "onnx")]
/// Embedding extractor backed by a fixed-size pool of ONNX Runtime sessions.
///
/// Concurrent callers each pop an idle session from `pool`. When none is idle, extraction fails
/// instead of blocking.
pub struct OnnxEmbeddingExtractor {
    /// Idle sessions; a session is popped for each extraction and pushed back afterwards.
    pool: crossbeam_queue::ArrayQueue<ort::session::Session>,
    /// Number of values the model's first output must contain.
    embedding_dim: usize,
    /// Exact number of input samples each extraction window must contain.
    window_samples: usize,
}
#[cfg(feature = "onnx")]
impl OnnxEmbeddingExtractor {
    /// Loads `pool_size` independent sessions of the model at `model_path`.
    ///
    /// * `model_path` – ONNX model file; checked with [`validate_onnx_header`] before loading.
    /// * `embedding_dim` – expected length of the model's output embedding.
    /// * `window_samples` – exact number of samples each input window must contain.
    /// * `pool_size` – number of sessions, i.e. the maximum number of concurrent extractions.
    ///
    /// # Errors
    ///
    /// Fails if `pool_size` is zero, if the header check fails, or if any session cannot be
    /// built from the model file.
    pub fn new(
        model_path: &Path,
        embedding_dim: usize,
        window_samples: usize,
        pool_size: usize,
    ) -> anyhow::Result<Self> {
        // `ArrayQueue::new` panics on zero capacity, and an empty pool could never serve a
        // request anyway, so reject this configuration up front with a proper error.
        if pool_size == 0 {
            anyhow::bail!("ONNX session pool size must be at least 1");
        }
        validate_onnx_header(model_path)
            .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        let pool = crossbeam_queue::ArrayQueue::new(pool_size);
        for i in 0..pool_size {
            let session = ort::session::Session::builder()
                .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?
                .commit_from_file(model_path)
                .map_err(|e| EmbeddingError::InferenceFailed(format!("session {i}: {e}")))?;
            // The queue was created with exactly `pool_size` slots, so this cannot overflow.
            pool.push(session)
                .map_err(|_| anyhow::anyhow!("failed to push session into pool"))?;
        }
        Ok(Self {
            pool,
            embedding_dim,
            window_samples,
        })
    }

    /// Pops an idle session, wrapped in a guard that pushes it back onto the pool when dropped.
    ///
    /// Returns `None` when every session is currently checked out.
    fn checkout(&self) -> Option<PooledSession<'_>> {
        self.pool.pop().map(|s| PooledSession {
            session: Some(s),
            pool: &self.pool,
        })
    }
}
#[cfg(feature = "onnx")]
impl EmbeddingExtractor for OnnxEmbeddingExtractor {
    /// Runs the model on one window of samples and returns the L2-normalised embedding.
    ///
    /// `samples` is fed to the model as a `[1, window_samples]` tensor. `_config` is not used by
    /// this backend.
    ///
    /// # Errors
    ///
    /// * [`EmbeddingError::InvalidInput`] if `samples.len()` differs from `window_samples`.
    /// * [`EmbeddingError::InferenceFailed`] in any of these cases:
    ///   * every pooled session is busy;
    ///   * tensor construction or inference fails;
    ///   * the model yields no outputs;
    ///   * the first output cannot be extracted as an `f32` tensor;
    ///   * the output's element count differs from `embedding_dim`.
    fn extract(
        &self,
        samples: &[f32],
        _config: &DiarizationConfig,
    ) -> Result<Vec<f32>, EmbeddingError> {
        // Validate before touching the pool. A malformed request should report `InvalidInput`
        // (not "pool exhausted") and must not tie up a session other callers could use.
        if samples.len() != self.window_samples {
            return Err(EmbeddingError::InvalidInput {
                expected: self.window_samples,
                got: samples.len(),
            });
        }
        let input_tensor =
            ort::value::TensorRef::from_array_view(([1_usize, self.window_samples], samples))
                .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;

        // The guard returns the session to the pool when it goes out of scope, including on
        // every early-return error path below.
        let mut guard = self.checkout().ok_or_else(|| {
            EmbeddingError::InferenceFailed("ONNX session pool exhausted".to_string())
        })?;
        let session = guard
            .session
            .as_mut()
            .ok_or_else(|| EmbeddingError::InferenceFailed("session not available".to_string()))?;
        let outputs = session
            .run(ort::inputs![input_tensor])
            .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        if outputs.iter().next().is_none() {
            return Err(EmbeddingError::InferenceFailed(
                "ONNX model produced no outputs".to_string(),
            ));
        }
        let (_, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| EmbeddingError::InferenceFailed(e.to_string()))?;
        if data.len() != self.embedding_dim {
            return Err(EmbeddingError::InferenceFailed(format!(
                "expected embedding dim {}, got {}",
                self.embedding_dim,
                data.len()
            )));
        }
        let mut embedding = data.to_vec();
        l2_normalize(&mut embedding);
        Ok(embedding)
    }

    /// Length of the embeddings produced by [`Self::extract`].
    fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}
#[cfg(feature = "onnx")]
/// A session borrowed from an extractor's session pool.
///
/// Dropping the guard pushes the session, if still present, back onto `pool`.
struct PooledSession<'a> {
    /// The borrowed session; wrapped in `Option` so `Drop` can move it out with `take`.
    session: Option<ort::session::Session>,
    /// Queue the session is returned to on drop.
    pool: &'a crossbeam_queue::ArrayQueue<ort::session::Session>,
}
#[cfg(feature = "onnx")]
impl Drop for PooledSession<'_> {
    fn drop(&mut self) {
        let Some(borrowed) = self.session.take() else {
            return;
        };
        // `push` only fails when the queue is full; the result is deliberately ignored, so a
        // rejected session is simply dropped.
        let _ = self.pool.push(borrowed);
    }
}
#[cfg(not(feature = "onnx"))]
/// Placeholder compiled when the `onnx` feature is disabled; its `new` always fails.
pub struct OnnxEmbeddingExtractor;
#[cfg(not(feature = "onnx"))]
impl OnnxEmbeddingExtractor {
    /// Mirrors the signature of the `onnx`-enabled constructor so callers compile either way.
    ///
    /// # Errors
    ///
    /// Always returns an error, because ONNX support was not compiled in.
    pub fn new(
        _model_path: &Path,
        _embedding_dim: usize,
        _window_samples: usize,
        _pool_size: usize,
    ) -> anyhow::Result<Self> {
        Err(anyhow::anyhow!("the `onnx` feature is not enabled"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates a temporary file containing `bytes`; the file is deleted when the handle drops.
    fn temp_file_with(bytes: &[u8]) -> tempfile::NamedTempFile {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        tmp.write_all(bytes).unwrap();
        tmp
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn valid_onnx_file_passes_validation() {
        // Only runs when the real model is checked out alongside the crate.
        let path = std::path::Path::new("models/silero_vad.onnx");
        if !path.exists() {
            return;
        }
        assert!(validate_onnx_header(path).is_ok());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn random_64_bytes_fails_validation() {
        let tmp = temp_file_with(&[0xAB; 64]);
        let result = validate_onnx_header(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ONNX magic") || msg.contains("protobuf header"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn empty_file_fails_validation() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let result = validate_onnx_header(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("too small"),
            "unexpected error: {err}"
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn file_one_byte_below_minimum_fails_validation() {
        // Even a valid-looking first byte must not rescue a truncated file.
        let tmp = temp_file_with(&[0x08; ONNX_MIN_HEADER_BYTES - 1]);
        let err = validate_onnx_header(tmp.path()).unwrap_err();
        assert!(
            err.to_string().contains("too small"),
            "unexpected error: {err}"
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn missing_file_fails_validation() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does_not_exist.onnx");
        assert!(validate_onnx_header(&missing).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn file_with_onnx_magic_passes() {
        let mut data = vec![0u8; 64];
        data[4..8].copy_from_slice(b"ONNX");
        let tmp = temp_file_with(&data);
        assert!(validate_onnx_header(tmp.path()).is_ok());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn onnx_magic_must_lie_within_first_16_bytes() {
        let mut inside = vec![0u8; ONNX_MIN_HEADER_BYTES];
        inside[12..16].copy_from_slice(b"ONNX");
        let tmp = temp_file_with(&inside);
        assert!(validate_onnx_header(tmp.path()).is_ok());

        let mut straddling = vec![0u8; ONNX_MIN_HEADER_BYTES];
        straddling[13..17].copy_from_slice(b"ONNX");
        let tmp = temp_file_with(&straddling);
        assert!(validate_onnx_header(tmp.path()).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn file_with_protobuf_header_passes() {
        let mut data = vec![0u8; 64];
        data[0] = 0x08;
        data[1] = 0x08;
        let tmp = temp_file_with(&data);
        assert!(validate_onnx_header(tmp.path()).is_ok());
    }
}