use crate::vectordb::embedding::{EmbeddingModel, EmbeddingModelType};
use crate::vectordb::error::{Result, VectorDBError};
use std::path::PathBuf;
/// Configuration used to build an [`EmbeddingModel`].
///
/// Stores which embedding backend to use and, for the ONNX backend, the paths to
/// the model and tokenizer files. `EmbeddingHandler::new` and
/// `EmbeddingHandler::set_onnx_paths` check that supplied paths exist on disk.
#[derive(Clone, Debug)]
pub struct EmbeddingHandler {
/// Backend that `create_embedding_model` will construct.
embedding_model_type: EmbeddingModelType,
/// Path to the ONNX model file; `new` requires it when the type is `Onnx`.
onnx_model_path: Option<PathBuf>,
/// Path to the tokenizer file used together with the ONNX model.
onnx_tokenizer_path: Option<PathBuf>,
}
impl EmbeddingHandler {
pub fn new(
embedding_model_type: EmbeddingModelType,
onnx_model_path: Option<PathBuf>,
onnx_tokenizer_path: Option<PathBuf>,
) -> Result<Self> {
if embedding_model_type == EmbeddingModelType::Onnx {
match (&onnx_model_path, &onnx_tokenizer_path) {
(Some(model_p), Some(tok_p)) => {
if !model_p.exists() {
return Err(VectorDBError::FileNotFound(format!(
"ONNX model file not found: {}",
model_p.display()
)));
}
if !tok_p.exists() {
return Err(VectorDBError::FileNotFound(format!(
"ONNX tokenizer file not found: {}",
tok_p.display()
)));
}
}
_ => {
return Err(VectorDBError::ConfigurationError(
"ONNX model type requires both model and tokenizer paths.".to_string()
));
}
}
}
Ok(Self {
embedding_model_type,
onnx_model_path,
onnx_tokenizer_path,
})
}
pub fn create_embedding_model(&self) -> Result<EmbeddingModel> {
match self.embedding_model_type {
EmbeddingModelType::Onnx => {
if let (Some(model_path), Some(tokenizer_path)) =
(&self.onnx_model_path, &self.onnx_tokenizer_path)
{
EmbeddingModel::new_onnx(model_path, tokenizer_path)
.map_err(|e| VectorDBError::EmbeddingError(e.to_string()))
} else {
Err(VectorDBError::EmbeddingError(
"ONNX model paths not set in handler.".to_string(),
))
}
}
}
}
pub fn set_onnx_paths(
&mut self,
model_path: Option<PathBuf>,
tokenizer_path: Option<PathBuf>,
) -> Result<()> {
if let Some(model_p) = &model_path {
if !model_p.exists() {
return Err(VectorDBError::EmbeddingError(format!(
"ONNX model file not found: {}",
model_p.display()
)));
}
}
if let Some(tokenizer_p) = &tokenizer_path {
if !tokenizer_p.exists() {
return Err(VectorDBError::EmbeddingError(format!(
"ONNX tokenizer file not found: {}",
tokenizer_p.display()
)));
}
}
if model_path.is_some() || tokenizer_path.is_some() {
self.embedding_model_type = EmbeddingModelType::Onnx;
}
self.onnx_model_path = model_path;
self.onnx_tokenizer_path = tokenizer_path;
Ok(())
}
pub fn embedding_model_type(&self) -> EmbeddingModelType {
self.embedding_model_type
}
pub fn onnx_model_path(&self) -> Option<&PathBuf> {
self.onnx_model_path.as_ref()
}
pub fn onnx_tokenizer_path(&self) -> Option<&PathBuf> {
self.onnx_tokenizer_path.as_ref()
}
pub fn dimension(&self) -> Result<usize> {
let model = self.create_embedding_model()?;
Ok(model.dim())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vectordb::embedding::EmbeddingModelType;
    use std::fs::File;
    use std::path::Path;
    use tempfile::tempdir;

    /// Creates an empty file at `path`, creating any missing parent directories.
    fn create_dummy_file(path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        File::create(path)?;
        Ok(())
    }

    /// Builds an ONNX handler with no paths set.
    ///
    /// `EmbeddingHandler::new` rejects this configuration, so the struct is built
    /// directly. This lets the setter and model-creation paths be exercised from a
    /// known-empty state.
    fn unconfigured_onnx_handler() -> EmbeddingHandler {
        EmbeddingHandler {
            embedding_model_type: EmbeddingModelType::Onnx,
            onnx_model_path: None,
            onnx_tokenizer_path: None,
        }
    }

    #[test]
    fn test_embedding_handler_new_onnx_valid_paths() -> Result<()> {
        let dir = tempdir()?;
        let model_path = dir.path().join("model.onnx");
        let tokenizer_path = dir.path().join("tokenizer.json");
        File::create(&model_path)?;
        File::create(&tokenizer_path)?;
        let handler = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(model_path.clone()),
            Some(tokenizer_path.clone()),
        )?;
        assert_eq!(handler.embedding_model_type(), EmbeddingModelType::Onnx);
        assert_eq!(handler.onnx_model_path(), Some(&model_path));
        assert_eq!(handler.onnx_tokenizer_path(), Some(&tokenizer_path));
        Ok(())
    }

    #[test]
    fn test_embedding_handler_new_onnx_missing_paths() {
        let result = EmbeddingHandler::new(EmbeddingModelType::Onnx, None, None);
        assert!(matches!(result, Err(VectorDBError::ConfigurationError(_))));
        if let Err(VectorDBError::ConfigurationError(msg)) = result {
            assert!(msg.contains("requires both model and tokenizer paths"));
        }
    }

    #[test]
    fn test_embedding_handler_new_onnx_missing_model_path() {
        let dir = tempdir().unwrap();
        let tokenizer_path = dir.path().join("tokenizer.json");
        File::create(&tokenizer_path).unwrap();
        let result = EmbeddingHandler::new(EmbeddingModelType::Onnx, None, Some(tokenizer_path));
        assert!(matches!(result, Err(VectorDBError::ConfigurationError(_))));
        if let Err(VectorDBError::ConfigurationError(msg)) = result {
            assert!(msg.contains("requires both model and tokenizer paths"));
        }
    }

    #[test]
    fn test_embedding_handler_new_onnx_missing_tokenizer_path() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("model.onnx");
        File::create(&model_path).unwrap();
        let result = EmbeddingHandler::new(EmbeddingModelType::Onnx, Some(model_path), None);
        assert!(matches!(result, Err(VectorDBError::ConfigurationError(_))));
        if let Err(VectorDBError::ConfigurationError(msg)) = result {
            assert!(msg.contains("requires both model and tokenizer paths"));
        }
    }

    #[test]
    fn test_embedding_handler_new_onnx_invalid_model_path() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("non_existent_model.onnx");
        let tokenizer_path = dir.path().join("tokenizer.json");
        File::create(&tokenizer_path).unwrap();
        let result = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(model_path.clone()),
            Some(tokenizer_path),
        );
        assert!(matches!(result, Err(VectorDBError::FileNotFound(_))));
        if let Err(VectorDBError::FileNotFound(msg)) = result {
            assert!(msg.contains("ONNX model file not found"));
            assert!(msg.contains("non_existent_model.onnx"));
        }
    }

    #[test]
    fn test_embedding_handler_new_onnx_invalid_tokenizer_path() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("model.onnx");
        let tokenizer_path = dir.path().join("non_existent_tokenizer.json");
        File::create(&model_path).unwrap();
        let result = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(model_path),
            Some(tokenizer_path.clone()),
        );
        assert!(matches!(result, Err(VectorDBError::FileNotFound(_))));
        if let Err(VectorDBError::FileNotFound(msg)) = result {
            assert!(msg.contains("ONNX tokenizer file not found"));
            assert!(msg.contains("non_existent_tokenizer.json"));
        }
    }

    #[test]
    fn test_set_onnx_paths_valid() -> Result<()> {
        let dir = tempdir()?;
        let model_path = dir.path().join("model_v1.onnx");
        let tokenizer_path = dir.path().join("tokenizer_v1.json");
        File::create(&model_path)?;
        File::create(&tokenizer_path)?;
        // `EmbeddingHandler::new` always rejects an ONNX handler without paths, so
        // start from an explicitly unconfigured handler instead of a fallback.
        let mut handler = unconfigured_onnx_handler();
        assert_eq!(handler.onnx_model_path(), None);
        assert_eq!(handler.onnx_tokenizer_path(), None);
        handler.set_onnx_paths(Some(model_path.clone()), Some(tokenizer_path.clone()))?;
        assert_eq!(handler.embedding_model_type(), EmbeddingModelType::Onnx);
        assert_eq!(handler.onnx_model_path(), Some(&model_path));
        assert_eq!(handler.onnx_tokenizer_path(), Some(&tokenizer_path));
        Ok(())
    }

    #[test]
    fn test_set_onnx_paths_clear() -> Result<()> {
        let dir = tempdir()?;
        let model_path = dir.path().join("model.onnx");
        let tokenizer_path = dir.path().join("tokenizer.json");
        File::create(&model_path)?;
        File::create(&tokenizer_path)?;
        let mut handler = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(model_path.clone()),
            Some(tokenizer_path.clone()),
        )?;
        handler.set_onnx_paths(None, None)?;
        assert_eq!(handler.embedding_model_type(), EmbeddingModelType::Onnx);
        assert_eq!(handler.onnx_model_path(), None);
        assert_eq!(handler.onnx_tokenizer_path(), None);
        Ok(())
    }

    #[test]
    fn test_set_onnx_paths_invalid_model() {
        let dir = tempdir().unwrap();
        let invalid_model_path = dir.path().join("bad_model.onnx");
        let tokenizer_path = dir.path().join("good_tokenizer.json");
        File::create(&tokenizer_path).unwrap();
        let mut handler = unconfigured_onnx_handler();
        let result =
            handler.set_onnx_paths(Some(invalid_model_path.clone()), Some(tokenizer_path));
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX model file not found"));
            assert!(msg.contains("bad_model.onnx"));
        }
        // A failed update must leave the handler untouched.
        assert_eq!(handler.onnx_model_path(), None);
        assert_eq!(handler.onnx_tokenizer_path(), None);
    }

    #[test]
    fn test_set_onnx_paths_invalid_tokenizer() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("good_model.onnx");
        let invalid_tokenizer_path = dir.path().join("bad_tokenizer.json");
        File::create(&model_path).unwrap();
        let mut handler = unconfigured_onnx_handler();
        let result =
            handler.set_onnx_paths(Some(model_path), Some(invalid_tokenizer_path.clone()));
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX tokenizer file not found"));
            assert!(msg.contains("bad_tokenizer.json"));
        }
        // The valid model path must not have been applied either.
        assert_eq!(handler.onnx_model_path(), None);
        assert_eq!(handler.onnx_tokenizer_path(), None);
    }

    #[test]
    fn test_create_embedding_model_onnx_paths_none() {
        let handler = unconfigured_onnx_handler();
        let result = handler.create_embedding_model();
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX model paths not set in handler"));
        }
    }

    #[test]
    fn test_create_embedding_model_onnx_model_path_none() {
        let dir = tempdir().unwrap();
        let tokenizer_path = dir.path().join("tokenizer.json");
        File::create(&tokenizer_path).unwrap();
        let handler = EmbeddingHandler {
            onnx_tokenizer_path: Some(tokenizer_path),
            ..unconfigured_onnx_handler()
        };
        let result = handler.create_embedding_model();
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX model paths not set in handler"));
        }
    }

    #[test]
    fn test_create_embedding_model_onnx_tokenizer_path_none() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("model.onnx");
        File::create(&model_path).unwrap();
        let handler = EmbeddingHandler {
            onnx_model_path: Some(model_path),
            ..unconfigured_onnx_handler()
        };
        let result = handler.create_embedding_model();
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX model paths not set in handler"));
        }
    }

    #[test]
    fn test_embedding_handler_dimension_onnx_success() -> Result<()> {
        let model_path = PathBuf::from("onnx/all-minilm-l12-v2.onnx");
        let tokenizer_path = PathBuf::from("onnx/minilm_tokenizer.json");
        if !model_path.exists() || !tokenizer_path.exists() {
            println!("Skipping test_embedding_handler_dimension_onnx_success: ONNX files not found at expected paths.");
            return Ok(());
        }
        let handler = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(model_path.clone()),
            Some(tokenizer_path.clone()),
        )?;
        let dim = handler.dimension()?;
        assert_eq!(dim, 384, "Expected dimension for MiniLM L12 v2");
        Ok(())
    }

    #[test]
    fn test_embedding_handler_dimension_onnx_fail_missing_path() {
        let handler_no_paths = unconfigured_onnx_handler();
        let result = handler_no_paths.dimension();
        assert!(matches!(result, Err(VectorDBError::EmbeddingError(_))));
        if let Err(VectorDBError::EmbeddingError(msg)) = result {
            assert!(msg.contains("ONNX model paths not set in handler"));
        }

        let dir = tempdir().unwrap();
        let invalid_model_path = dir.path().join("invalid_model.onnx");
        let invalid_tokenizer_path = dir.path().join("invalid_tokenizer.json");
        create_dummy_file(&invalid_model_path).unwrap();
        create_dummy_file(&invalid_tokenizer_path).unwrap();
        let handler_invalid_files = EmbeddingHandler::new(
            EmbeddingModelType::Onnx,
            Some(invalid_model_path),
            Some(invalid_tokenizer_path),
        )
        .expect("Handler creation should succeed with existing (but invalid) files");
        let result_invalid = handler_invalid_files.dimension();
        assert!(
            matches!(result_invalid, Err(VectorDBError::EmbeddingError(_))),
            "Expected EmbeddingError for invalid ONNX model/tokenizer files"
        );
        if let Err(VectorDBError::EmbeddingError(msg)) = result_invalid {
            assert!(
                msg.contains("Failed to create ONNX provider"),
                "Error message mismatch: {}",
                msg
            );
        }
    }
}