use std::path::Path;
pub use polyvoice::{
DiarizationConfig, EmbeddingError, EmbeddingExtractor, OnlineDiarizer, OnnxEmbeddingExtractor,
SampleRate, SpeakerId,
};
/// Embedding dimension handed to [`OnnxEmbeddingExtractor::new`] by [`load_extractor`].
///
/// NOTE(review): presumably this must equal the output width of
/// `wespeaker_resnet34.onnx` — confirm against the model.
pub const EMBEDDING_DIM: usize = 256;

/// Segment length, in samples, handed to [`OnnxEmbeddingExtractor::new`] by
/// [`load_extractor`].
///
/// NOTE(review): presumably 1.5 s of audio at 16 kHz — confirm against the sample rate
/// the model expects.
pub const SEGMENT_SAMPLES: usize = 24_000;
/// Loads the WeSpeaker ResNet34 speaker-embedding model found in `model_dir`.
///
/// Looks for a file named `wespeaker_resnet34.onnx` directly inside `model_dir`. It then
/// builds an [`OnnxEmbeddingExtractor`] configured with [`EMBEDDING_DIM`] and
/// [`SEGMENT_SAMPLES`]. `pool_size` is forwarded unchanged to
/// [`OnnxEmbeddingExtractor::new`].
///
/// # Errors
///
/// Returns an error if:
/// - `wespeaker_resnet34.onnx` is missing from `model_dir`, or is not a regular file
///   (for example a directory with that name). The message names the file and the
///   directory.
/// - [`OnnxEmbeddingExtractor::new`] fails. The message names the model path and includes
///   the underlying error rendered with `{:#}`.
pub fn load_extractor(
    model_dir: &Path,
    pool_size: usize,
) -> anyhow::Result<OnnxEmbeddingExtractor> {
    const MODEL_FILE: &str = "wespeaker_resnet34.onnx";

    let path = model_dir.join(MODEL_FILE);
    // `is_file` rather than `exists`: a directory (or other non-file entry) with the
    // model's name would pass an `exists` check and only fail later, inside the ONNX
    // loader, with a much less helpful message.
    if !path.is_file() {
        anyhow::bail!("{MODEL_FILE} not found in {}", model_dir.display());
    }

    // Keep the original `{e:#}` rendering of the loader error, but say which file failed.
    OnnxEmbeddingExtractor::new(&path, EMBEDDING_DIM, SEGMENT_SAMPLES, pool_size)
        .map_err(|e| anyhow::anyhow!("failed to load {}: {e:#}", path.display()))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the public constants and the model loader in the parent module.

    use super::*;

    /// Guards the embedding width against accidental edits.
    #[test]
    fn test_embedding_dim_constant() {
        let expected_dim: usize = 256;
        assert_eq!(EMBEDDING_DIM, expected_dim);
    }

    /// Guards the segment length against accidental edits.
    #[test]
    fn test_segment_samples_constant() {
        let expected_samples: usize = 24_000;
        assert_eq!(SEGMENT_SAMPLES, expected_samples);
    }

    /// A directory that does not exist must yield an error naming the model file.
    #[test]
    fn test_load_extractor_missing_file() {
        let missing_dir = Path::new("/nonexistent/path");
        // Destructure with `match` rather than `unwrap_err`, so the test does not need
        // the success type to implement `Debug`.
        let err = match load_extractor(missing_dir, 1) {
            Ok(_) => panic!("expected load_extractor to fail for a missing model directory"),
            Err(err) => err,
        };
        let rendered = err.to_string();
        assert!(rendered.contains("wespeaker_resnet34.onnx"));
    }
}