use crate::vectordb::embedding::{EmbeddingModelType};
use crate::vectordb::error::{Result, VectorDBError};
use crate::vectordb::provider::EmbeddingProvider;
#[cfg(feature = "ort")]
use crate::vectordb::provider::onnx::OnnxEmbeddingModel;
use std::path::PathBuf;
use std::sync::Mutex;
use crate::config::AppConfig;
/// Owns the embedding configuration and a lazily created embedding provider.
///
/// The provider sits behind a [`Mutex`] so that it can be created and used
/// through a shared `&self` reference.
#[derive(Debug)]
pub struct EmbeddingHandler {
    /// Which embedding backend this handler creates providers for.
    embedding_model_type: EmbeddingModelType,
    /// Path to the ONNX model, if one has been configured.
    onnx_model_path: Option<PathBuf>,
    /// Path to the ONNX tokenizer, if one has been configured.
    onnx_tokenizer_path: Option<PathBuf>,
    /// Cached provider; `None` until one has been created successfully, and
    /// reset to `None` whenever the ONNX paths are replaced.
    provider_cache: Mutex<Option<Box<dyn EmbeddingProvider>>>,
}
impl EmbeddingHandler {
    /// Creates a handler configured for the ONNX backend from `config`.
    ///
    /// With the `ort` feature enabled, an ONNX provider is constructed eagerly
    /// and cached. A failure to construct it is logged and otherwise ignored:
    /// the handler is still returned, and construction is attempted again by
    /// [`dimension`](Self::dimension) / [`embed`](Self::embed), which report
    /// the error to the caller.
    ///
    /// # Errors
    ///
    /// With the `ort` feature enabled, returns
    /// [`VectorDBError::ConfigurationError`] if either ONNX path is missing
    /// from `config`. Without the feature no validation is performed.
    pub fn new(config: &AppConfig) -> std::result::Result<Self, VectorDBError> {
        // The model type is currently fixed; both variants take the ONNX route below.
        let model_type = EmbeddingModelType::Onnx;

        // Convert the configured paths once and reuse them for both the eager
        // provider construction and the stored fields.
        let onnx_model_path = config.onnx_model_path.as_deref().map(PathBuf::from);
        let onnx_tokenizer_path = config.onnx_tokenizer_path.as_deref().map(PathBuf::from);

        let cached_provider: Option<Box<dyn EmbeddingProvider>> = match model_type {
            EmbeddingModelType::Onnx | EmbeddingModelType::Default => {
                #[cfg(feature = "ort")]
                {
                    let model_path = onnx_model_path.as_ref().ok_or_else(|| {
                        VectorDBError::ConfigurationError(
                            "ONNX model path not set in AppConfig".to_string(),
                        )
                    })?;
                    let tokenizer_path = onnx_tokenizer_path.as_ref().ok_or_else(|| {
                        VectorDBError::ConfigurationError(
                            "ONNX tokenizer path not set in AppConfig".to_string(),
                        )
                    })?;
                    match OnnxEmbeddingModel::new(model_path, tokenizer_path)
                        .map_err(VectorDBError::from)
                    {
                        Ok(model) => Some(Box::new(model) as Box<dyn EmbeddingProvider>),
                        Err(err) => {
                            // Deliberately non-fatal: the handler stays usable for
                            // path queries and the provider is re-created on demand.
                            log::warn!(
                                "Failed to create ONNX embedding provider; will retry on first use: {:?}",
                                err
                            );
                            None
                        }
                    }
                }
                #[cfg(not(feature = "ort"))]
                {
                    // No backend is compiled in; the error is reported when a
                    // provider is actually requested.
                    None
                }
            }
        };

        Ok(Self {
            embedding_model_type: model_type,
            onnx_model_path,
            onnx_tokenizer_path,
            provider_cache: Mutex::new(cached_provider),
        })
    }

    /// Builds a fresh provider for the configured model type.
    ///
    /// The result is not stored in the internal cache.
    ///
    /// # Errors
    ///
    /// * [`VectorDBError::EmbeddingError`] if the ONNX model or tokenizer path is unset.
    /// * Whatever error `OnnxEmbeddingModel::new` reports, converted via `?`.
    /// * [`VectorDBError::FeatureNotEnabled`] if the crate was built without the `ort` feature.
    /// * [`VectorDBError::NotImplemented`] for [`EmbeddingModelType::Default`].
    pub fn create_embedding_model(&self) -> Result<Box<dyn EmbeddingProvider>> {
        match self.embedding_model_type {
            EmbeddingModelType::Onnx => {
                // `OnnxEmbeddingModel` is only imported when `ort` is enabled, so
                // every use of it has to sit behind the same feature gate.
                #[cfg(feature = "ort")]
                {
                    let model_path = self.onnx_model_path.as_ref().ok_or_else(|| {
                        VectorDBError::EmbeddingError(
                            "ONNX model path not set in handler.".to_string(),
                        )
                    })?;
                    let tokenizer_path = self.onnx_tokenizer_path.as_ref().ok_or_else(|| {
                        VectorDBError::EmbeddingError(
                            "ONNX tokenizer path not set in handler.".to_string(),
                        )
                    })?;
                    let provider: Box<dyn EmbeddingProvider> =
                        Box::new(OnnxEmbeddingModel::new(model_path, tokenizer_path)?);
                    Ok(provider)
                }
                #[cfg(not(feature = "ort"))]
                {
                    Err(VectorDBError::FeatureNotEnabled("ort".to_string()))
                }
            }
            EmbeddingModelType::Default => Err(VectorDBError::NotImplemented(
                "Default embedding model provider not yet implemented".to_string(),
            )),
        }
    }

    /// Replaces the ONNX model and tokenizer paths and drops any cached provider.
    ///
    /// Each supplied path must exist on disk. If at least one path is `Some`,
    /// the model type is switched to [`EmbeddingModelType::Onnx`]. Both stored
    /// paths are overwritten with the given values, including `None`.
    ///
    /// # Errors
    ///
    /// Returns [`VectorDBError::EmbeddingError`] if a supplied path does not
    /// exist; in that case the handler is left unmodified.
    pub fn set_onnx_paths(
        &mut self,
        model_path: Option<PathBuf>,
        tokenizer_path: Option<PathBuf>,
    ) -> Result<()> {
        // Validate everything before mutating so a failure leaves `self` untouched.
        if let Some(model_p) = &model_path {
            if !model_p.exists() {
                return Err(VectorDBError::EmbeddingError(format!(
                    "ONNX model file not found: {}",
                    model_p.display()
                )));
            }
        }
        if let Some(tokenizer_p) = &tokenizer_path {
            if !tokenizer_p.exists() {
                return Err(VectorDBError::EmbeddingError(format!(
                    "ONNX tokenizer file not found: {}",
                    tokenizer_p.display()
                )));
            }
        }

        if model_path.is_some() || tokenizer_path.is_some() {
            self.embedding_model_type = EmbeddingModelType::Onnx;
        }
        self.onnx_model_path = model_path;
        self.onnx_tokenizer_path = tokenizer_path;

        // `&mut self` already guarantees exclusive access, so no lock is taken.
        // A poisoned mutex is harmless here because the cached value is
        // discarded regardless of its state.
        let cache = self
            .provider_cache
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *cache = None;
        Ok(())
    }

    /// Returns the configured embedding model type.
    pub fn embedding_model_type(&self) -> EmbeddingModelType {
        self.embedding_model_type
    }

    /// Returns the configured ONNX model path, if any.
    pub fn onnx_model_path(&self) -> Option<&PathBuf> {
        self.onnx_model_path.as_ref()
    }

    /// Returns the configured ONNX tokenizer path, if any.
    pub fn onnx_tokenizer_path(&self) -> Option<&PathBuf> {
        self.onnx_tokenizer_path.as_ref()
    }

    /// Returns the dimension reported by the provider, creating and caching
    /// the provider first if none is cached.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider has to be created and
    /// [`create_embedding_model`](Self::create_embedding_model) fails, or if
    /// the cache lock is poisoned.
    pub fn dimension(&self) -> Result<usize> {
        self.with_provider("dimension", |provider| Ok(provider.dimension()))
    }

    /// Embeds `texts` with the provider's `embed_batch`, creating and caching
    /// the provider first if none is cached.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider has to be created and
    /// [`create_embedding_model`](Self::create_embedding_model) fails, if the
    /// cache lock is poisoned, or if `embed_batch` itself fails.
    pub fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.with_provider("embed", |provider| provider.embed_batch(texts))
    }

    /// Runs `op` on the cached provider while holding the cache lock.
    ///
    /// On a cache miss the provider is created via
    /// [`create_embedding_model`](Self::create_embedding_model) and stored
    /// before `op` runs. `purpose` only labels the cache-miss log line.
    fn with_provider<T, F>(&self, purpose: &str, op: F) -> Result<T>
    where
        F: FnOnce(&mut Box<dyn EmbeddingProvider>) -> Result<T>,
    {
        // Report poisoning as an error instead of panicking inside a
        // `Result`-returning API.
        let mut cache_guard = self.provider_cache.lock().map_err(|_| {
            VectorDBError::EmbeddingError(
                "Embedding provider cache lock is poisoned.".to_string(),
            )
        })?;

        // Move the cached provider out (or build one) and put it straight back.
        // If creation fails the slot was already empty, so nothing is lost.
        let provider = match cache_guard.take() {
            Some(cached) => cached,
            None => {
                log::debug!("Provider cache miss for {}. Creating provider...", purpose);
                self.create_embedding_model()?
            }
        };
        op(cache_guard.insert(provider))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AppConfig;
    // Only the `ort`-gated test below inspects the error variant.
    #[cfg(feature = "ort")]
    use crate::vectordb::error::VectorDBError;
    use std::fs;
    use tempfile::tempdir;

    /// Builds an `AppConfig` with the given ONNX paths and fixed values for
    /// every other field.
    fn create_test_config(
        model_path: Option<String>,
        tokenizer_path: Option<String>,
    ) -> AppConfig {
        AppConfig {
            qdrant_url: "http://localhost:6333".to_string(),
            onnx_model_path: model_path,
            onnx_tokenizer_path: tokenizer_path,
            server_api_key_path: None,
            repositories: vec![],
            active_repository: None,
            repositories_base_path: None,
        }
    }

    /// `new` succeeds when both configured paths point at existing files.
    ///
    /// The file contents are placeholders; `new` does not propagate
    /// provider-construction failures, so it returns `Ok` either way.
    #[test]
    fn test_embedding_handler_new_onnx_valid_paths() {
        let dir = tempdir().unwrap();
        let model_path = dir.path().join("model.onnx");
        let tokenizer_path = dir.path().join("tokenizer.json");
        fs::write(&model_path, b"dummy model").unwrap();
        fs::write(&tokenizer_path, b"dummy tokenizer").unwrap();

        let config = create_test_config(
            Some(model_path.to_str().unwrap().to_string()),
            Some(tokenizer_path.to_str().unwrap().to_string()),
        );
        let result = EmbeddingHandler::new(&config);
        assert!(
            result.is_ok(),
            "Expected Ok when both ONNX paths exist, got {:?}",
            result
        );
    }

    /// Missing paths are rejected up front.
    ///
    /// `new` only validates the paths when the `ort` feature is enabled, so
    /// this expectation holds only under that feature.
    #[cfg(feature = "ort")]
    #[test]
    fn test_embedding_handler_new_onnx_missing_paths() {
        let config = create_test_config(None, None);
        let result = EmbeddingHandler::new(&config);
        assert!(
            matches!(result, Err(VectorDBError::ConfigurationError(_))),
            "Expected ConfigurationError when ONNX paths are None, got {:?}",
            result
        );
    }

    /// A configured path that does not exist on disk is not an error for
    /// `new`: provider-construction failures are discarded there.
    #[test]
    fn test_embedding_handler_new_onnx_invalid_paths() {
        let dir = tempdir().unwrap();
        let model_path_buf = dir.path().join("model.onnx");
        let tokenizer_path_buf = dir.path().join("tokenizer.json");
        fs::write(&model_path_buf, b"").expect("Failed to create dummy model file");
        fs::write(&tokenizer_path_buf, b"").expect("Failed to create dummy tokenizer file");
        // Remove the model again so the configured model path is dangling.
        fs::remove_file(&model_path_buf).expect("Failed to remove dummy model file");

        let config = create_test_config(
            Some(model_path_buf.to_str().unwrap().to_string()),
            Some(tokenizer_path_buf.to_str().unwrap().to_string()),
        );
        let result = EmbeddingHandler::new(&config);
        assert!(
            result.is_ok(),
            "Expected Ok even with non-existent file, got {:?}",
            result
        );
        dir.close().expect("Failed to close temp dir");
    }
}