use std::sync::Mutex;
use crate::error::{AlayaError, Result};
use crate::provider::EmbeddingProvider;
pub use fastembed::EmbeddingModel;
/// Embedding provider backed by an in-process `fastembed` text-embedding model.
pub struct LocalEmbeddingProvider {
    /// The loaded model, kept behind a mutex so the `&self` trait methods can
    /// obtain the mutable access that embedding calls are made through.
    model: Mutex<fastembed::TextEmbedding>,
    /// Length of the vectors the model produces, measured once at construction.
    dimensions: usize,
}
impl LocalEmbeddingProvider {
    /// Creates a provider using the `AllMiniLML6V2` model.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::with_model`].
    pub fn new() -> Result<Self> {
        Self::with_model(EmbeddingModel::AllMiniLML6V2)
    }

    /// Loads `model` and records the length of the vectors it produces.
    ///
    /// The model is initialised with download-progress display disabled. The
    /// dimension is measured by embedding a short probe string once, so it
    /// always reflects the model that was actually loaded.
    ///
    /// # Errors
    ///
    /// Returns [`AlayaError::InvalidInput`] if the model cannot be loaded, if
    /// the probe embedding fails, or if the probe yields no vector.
    pub fn with_model(model: EmbeddingModel) -> Result<Self> {
        let options = fastembed::InitOptions::new(model).with_show_download_progress(false);
        let mut text_embedding = fastembed::TextEmbedding::try_new(options).map_err(|e| {
            AlayaError::InvalidInput(format!("Failed to load embedding model: {e}"))
        })?;

        let probe = text_embedding.embed(vec!["test"], None).map_err(|e| {
            AlayaError::InvalidInput(format!("Failed to determine dimensions: {e}"))
        })?;

        // A hard-coded fallback size would only be right for one particular
        // model; reporting a guessed dimension for any other model would be
        // silently wrong, so a missing probe vector is treated as an error.
        let dimensions = probe.first().map(|v| v.len()).ok_or_else(|| {
            AlayaError::InvalidInput(
                "Failed to determine dimensions: model returned no embedding".into(),
            )
        })?;

        Ok(Self {
            model: Mutex::new(text_embedding),
            dimensions,
        })
    }

    /// Returns the embedding vector length measured when the model was loaded.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
impl EmbeddingProvider for LocalEmbeddingProvider {
    /// Embeds a single piece of text and returns its vector.
    ///
    /// # Errors
    ///
    /// Returns [`AlayaError::InvalidInput`] if the model mutex is poisoned, if
    /// the model reports a failure, or if it returns no vector for the input.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let mut model = self
            .model
            .lock()
            .map_err(|e| AlayaError::InvalidInput(format!("Lock poisoned: {e}")))?;

        let results = model
            .embed(vec![text], None)
            .map_err(|e| AlayaError::InvalidInput(format!("Embedding failed: {e}")))?;

        // Exactly one text went in, so only the first vector is of interest.
        results
            .into_iter()
            .next()
            .ok_or_else(|| AlayaError::InvalidInput("No embedding returned".into()))
    }

    /// Embeds every text in `texts` with a single model call and returns the
    /// vectors the model produced.
    ///
    /// # Errors
    ///
    /// Returns [`AlayaError::InvalidInput`] if the model mutex is poisoned or
    /// if the model reports a failure.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // The model accepts borrowed strings (the single-text path already
        // passes a `Vec<&str>`), so copying the slice of references is enough;
        // no per-text `String` allocation is needed.
        let inputs = texts.to_vec();

        let mut model = self
            .model
            .lock()
            .map_err(|e| AlayaError::InvalidInput(format!("Lock poisoned: {e}")))?;

        model
            .embed(inputs, None)
            .map_err(|e| AlayaError::InvalidInput(format!("Batch embedding failed: {e}")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The tests marked `#[ignore]` construct a real provider, which requires
    // the embedding model to load successfully. They are skipped by default;
    // run them with `cargo test -- --ignored`.

    /// A single embedding is non-empty and has the expected length.
    #[test]
    #[ignore]
    fn test_local_embedding_produces_vector() {
        let embedder = LocalEmbeddingProvider::new().unwrap();
        let vector = embedder.embed("Hello, world!").unwrap();
        assert!(!vector.is_empty());
        assert_eq!(vector.len(), 384);
    }

    /// Embedding the same text twice yields identical vectors.
    #[test]
    #[ignore]
    fn test_local_embedding_consistent() {
        let embedder = LocalEmbeddingProvider::new().unwrap();
        let first = embedder.embed("test text").unwrap();
        let second = embedder.embed("test text").unwrap();
        assert_eq!(first, second);
    }

    /// Two unrelated texts do not produce the same vector.
    #[test]
    #[ignore]
    fn test_local_embedding_different_texts_differ() {
        let embedder = LocalEmbeddingProvider::new().unwrap();
        let cat = embedder.embed("cat").unwrap();
        let physics = embedder.embed("quantum physics").unwrap();
        assert_ne!(cat, physics);
    }

    /// A batch call returns one vector per input text.
    #[test]
    #[ignore]
    fn test_local_embedding_batch() {
        let embedder = LocalEmbeddingProvider::new().unwrap();
        let vectors = embedder.embed_batch(&["hello", "world"]).unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), 384);
    }

    /// Needs no model: only checks that the provider type has a non-zero size.
    #[test]
    fn test_dimensions_type_exists() {
        let size = std::mem::size_of::<LocalEmbeddingProvider>();
        assert!(size > 0);
    }
}