use thiserror::Error;
use std::path::Path;
use std::sync::Arc;
use candle_core::{Device, Tensor, DType, IndexOp};
use candle_transformers::models::bert::{BertModel, Config};
use tokenizers::Tokenizer;
#[cfg(not(feature = "custom-model"))]
const DEFAULT_EMBEDDING_MODEL: &str = "all-MiniLM-L6-v2";
#[cfg(feature = "custom-model")]
const DEFAULT_EMBEDDING_MODEL: &str = env!("DEFAULT_EMBEDDING_MODEL");
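/// Errors raised while loading the model, tokenizing input, or running inference.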
#[derive(Error, Debug)]
pub enum EmbeddingError {
#[error("Model loading failed: {0}")]
ModelLoading(String),
#[error("Tokenization failed: {0}")]
Tokenization(String),
#[error("Inference failed: {0}")]
Inference(String),
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
}
pub type Result<T> = std::result::Result<T, EmbeddingError>;
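/// Generates sentence embeddings by running a local BERT model on the CPU via
/// candle and taking the `[CLS]` token representation, L2-normalized.
///
/// A minimal usage sketch (marked `ignore`: it assumes the default model files
/// have already been downloaded to `./models/all-MiniLM-L6-v2`):
///
/// ```ignore
/// let generator = EmbeddingGenerator::new()?;
/// let embedding = generator.generate_embedding("hello world")?;
/// assert_eq!(embedding.len(), generator.dimension());
/// ```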
pub struct EmbeddingGenerator {
model: Arc<BertModel>,
tokenizer: Arc<Tokenizer>,
device: Device,
dimension: usize,
}
impl std::fmt::Debug for EmbeddingGenerator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EmbeddingGenerator")
.field("device", &self.device)
.field("dimension", &self.dimension)
.field("model", &"<BertModel>")
.field("tokenizer", &"<Tokenizer>")
.finish()
}
}
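/// Abstraction over embedding backends, so callers and tests can substitute
/// alternatives such as the mock generator behind the `mock-embeddings` feature.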
pub trait EmbeddingFunction: Send + Sync {
fn generate_embedding(&self, text: &str) -> Result<Vec<f64>>;
fn dimension(&self) -> usize;
}
impl EmbeddingFunction for EmbeddingGenerator {
fn generate_embedding(&self, text: &str) -> Result<Vec<f64>> {
let encoding = self.tokenizer.encode(text, true)
.map_err(|e| EmbeddingError::Tokenization(e.to_string()))?;
let token_ids = encoding.get_ids();
let input_ids = Tensor::new(token_ids, &self.device)
.map_err(|e| EmbeddingError::Inference(format!("Failed to create input tensor: {}", e)))?;
let input_ids = input_ids.unsqueeze(0)
.map_err(|e| EmbeddingError::Inference(format!("Failed to add batch dimension: {}", e)))?;
let token_type_ids = input_ids.zeros_like()
.map_err(|e| EmbeddingError::Inference(format!("Failed to create token type ids: {}", e)))?;
let outputs = self.model.forward(&input_ids, &token_type_ids, None)
.map_err(|e| EmbeddingError::Inference(format!("Model inference failed: {}", e)))?;
let cls_embedding = outputs.i((0, 0))
.map_err(|e| EmbeddingError::Inference(format!("Failed to extract CLS token: {}", e)))?;
let embedding_f32: Vec<f32> = cls_embedding.to_vec1()
.map_err(|e| EmbeddingError::Inference(format!("Failed to convert to Vec: {}", e)))?;
let embedding: Vec<f64> = embedding_f32.into_iter().map(|x| x as f64).collect();
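// L2-normalize so cosine similarity between embeddings reduces to a dot product.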
let norm: f64 = embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
let normalized: Vec<f64> = if norm > 0.0 {
embedding.iter().map(|x| x / norm).collect()
} else {
embedding
};
Ok(normalized)
}
fn dimension(&self) -> usize {
self.dimension
}
}
impl EmbeddingGenerator {
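/// Loads the default embedding model from `./models/` joined with
/// `DEFAULT_EMBEDDING_MODEL`, running on the CPU.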
pub fn new() -> Result<Self> {
Self::new_from_path(&format!("./models/{}", DEFAULT_EMBEDDING_MODEL))
}
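/// Loads a model directory that must contain `tokenizer.json`, `config.json`,
/// and `pytorch_model.bin`.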
pub fn new_from_path(model_path: &str) -> Result<Self> {
Self::configure_threading();
let device = Device::Cpu;
let (model, tokenizer, dimension) = Self::load_model_from_path(model_path, &device)?;
Ok(Self {
model: Arc::new(model),
tokenizer: Arc::new(tokenizer),
device,
dimension,
})
}
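/// Sets the rayon and candle thread counts to the number of logical CPUs.
/// This only takes effect if it runs before either library initializes its
/// thread pool.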
fn configure_threading() {
let num_threads = num_cpus::get();
// SAFETY: `std::env::set_var` is unsafe (Rust 2024) because mutating the
// environment is not thread-safe; this is expected to run during startup,
// before other threads read the environment.
unsafe {
std::env::set_var("RAYON_NUM_THREADS", num_threads.to_string());
std::env::set_var("CANDLE_NUM_THREADS", num_threads.to_string());
}
}
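/// Loads `tokenizer.json`, `config.json`, and `pytorch_model.bin` from
/// `model_path` and builds the BERT model on `device`. The returned dimension
/// is the config's `hidden_size`.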
fn load_model_from_path(model_path: &str, device: &Device) -> Result<(BertModel, Tokenizer, usize)> {
let model_dir = Path::new(model_path);
let tokenizer_path = model_dir.join("tokenizer.json");
if !tokenizer_path.exists() {
return Err(EmbeddingError::ModelLoading(format!(
"Tokenizer file not found: {}. Please ensure the model is properly downloaded.",
tokenizer_path.display()
)));
}
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| EmbeddingError::ModelLoading(format!("Failed to load tokenizer: {}", e)))?;
let config_path = model_dir.join("config.json");
if !config_path.exists() {
return Err(EmbeddingError::ModelLoading(format!(
"Config file not found: {}. Please ensure the model is properly downloaded.",
config_path.display()
)));
}
let config_str = std::fs::read_to_string(&config_path)
.map_err(|e| EmbeddingError::ModelLoading(format!("Failed to read config: {}", e)))?;
let config: Config = serde_json::from_str(&config_str)
.map_err(|e| EmbeddingError::ModelLoading(format!("Failed to parse config: {}", e)))?;
let dimension = config.hidden_size;
let model_file = model_dir.join("pytorch_model.bin");
if !model_file.exists() {
return Err(EmbeddingError::ModelLoading(format!(
"Model weights file not found: {}. Please ensure the model is properly downloaded.",
model_file.display()
)));
}
let weights = candle_nn::VarBuilder::from_pth(&model_file, DType::F32, device)
.map_err(|e| EmbeddingError::ModelLoading(format!("Failed to load weights: {}", e)))?;
let model = BertModel::load(weights, &config)
.map_err(|e| EmbeddingError::ModelLoading(format!("Failed to create model: {}", e)))?;
Ok((model, tokenizer, dimension))
}
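/// Returns the embedding dimension (the model config's `hidden_size`).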
pub fn dimension(&self) -> usize {
self.dimension
}
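/// Inherent convenience wrapper that delegates to the [`EmbeddingFunction`] implementation.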
pub fn generate_embedding(&self, text: &str) -> Result<Vec<f64>> {
<Self as EmbeddingFunction>::generate_embedding(self, text)
}
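/// Embeds each text in parallel with rayon, preserving input order; if any
/// text fails, the whole call returns an error.
///
/// A short sketch, assuming a previously constructed `generator`:
///
/// ```ignore
/// let texts = vec!["first".to_string(), "second".to_string()];
/// let embeddings = generator.generate_embeddings_batch(&texts)?;
/// assert_eq!(embeddings.len(), texts.len());
/// ```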
pub fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f64>>> {
use rayon::prelude::*;
texts.par_iter()
.map(|text| self.generate_embedding(text))
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
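// Builds a boxed embedding backend for tests: the hash-based mock when the
// `mock-embeddings` feature is enabled, otherwise the real model (which must
// be present on disk).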
fn create_test_generator() -> Box<dyn EmbeddingFunction> {
#[cfg(feature = "mock-embeddings")]
{
Box::new(MockEmbeddingGenerator::new())
}
#[cfg(not(feature = "mock-embeddings"))]
{
Box::new(EmbeddingGenerator::new().unwrap())
}
}
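// Deterministic, hash-seeded stand-in that mimics the real generator's
// 384-dimensional, L2-normalized output without loading any model files.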
#[cfg(feature = "mock-embeddings")]
struct MockEmbeddingGenerator {
dimension: usize,
}
#[cfg(feature = "mock-embeddings")]
impl MockEmbeddingGenerator {
fn new() -> Self {
Self { dimension: 384 }
}
}
#[cfg(feature = "mock-embeddings")]
impl EmbeddingFunction for MockEmbeddingGenerator {
fn generate_embedding(&self, _text: &str) -> Result<Vec<f64>> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
_text.hash(&mut hasher);
let hash = hasher.finish();
let mut embedding = vec![0.0; self.dimension];
for i in 0..self.dimension {
let seed = hash.wrapping_add(i as u64);
let value = (seed as f64) / (u64::MAX as f64) * 2.0 - 1.0;
embedding[i] = value;
}
let norm: f64 = embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
Ok(embedding)
}
fn dimension(&self) -> usize {
self.dimension
}
}
#[test]
fn test_embedding_generation() {
let generator = create_test_generator();
let text = "hello world this is a test";
let embedding = generator.generate_embedding(text).unwrap();
assert_eq!(embedding.len(), 384);
assert!(!embedding.iter().all(|&x| x == 0.0), "Embedding should not be all zeros");
}
#[test]
fn test_embedding_dimension() {
let generator = create_test_generator();
assert_eq!(generator.dimension(), 384);
}
#[test]
fn test_embedding_consistency() {
let generator = create_test_generator();
let text = "the quick brown fox";
let embedding1 = generator.generate_embedding(text).unwrap();
let embedding2 = generator.generate_embedding(text).unwrap();
for (a, b) in embedding1.iter().zip(embedding2.iter()) {
assert!((a - b).abs() < 1e-10, "Embeddings should be consistent");
}
}
#[test]
fn test_embedding_normalization() {
let generator = create_test_generator();
let text = "test normalization";
let embedding = generator.generate_embedding(text).unwrap();
let norm: f64 = embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
assert!((norm - 1.0).abs() < 1e-10, "Embedding should be L2 normalized");
}
#[test]
fn test_different_texts_different_embeddings() {
let generator = create_test_generator();
let text1 = "hello world";
let text2 = "goodbye universe";
let embedding1 = generator.generate_embedding(text1).unwrap();
let embedding2 = generator.generate_embedding(text2).unwrap();
let cosine_sim = crate::SimilarityMetric::Cosine.calculate(&embedding1, &embedding2);
assert!(cosine_sim < 0.99, "Different texts should produce different embeddings");
}
#[test]
fn test_batch_embedding_generation() {
let generator = create_test_generator();
let texts = vec![
"first text".to_string(),
"second text".to_string(),
"third text".to_string(),
];
let embeddings: Vec<Vec<f64>> = texts.iter()
.map(|text| generator.generate_embedding(text).unwrap())
.collect();
assert_eq!(embeddings.len(), 3);
assert_eq!(embeddings[0].len(), 384);
assert_eq!(embeddings[1].len(), 384);
assert_eq!(embeddings[2].len(), 384);
}
#[test]
fn test_empty_text_embedding() {
let generator = create_test_generator();
let embedding = generator.generate_embedding("").unwrap();
assert_eq!(embedding.len(), 384);
assert!(embedding.iter().all(|&x| x.abs() < 1.0), "Empty text should produce valid embedding");
}
}