use anyhow::{Context, Result};
pub use brainwires_core::EmbeddingProvider as EmbeddingProviderTrait;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use lru::LruCache;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
/// Capacity (in entries) of the LRU cache created for each `CachedEmbeddingProvider`.
const DEFAULT_CACHE_SIZE: usize = 1_000;
/// Output width of the MiniLM models (the same width is recorded for `bge-small-en-v1.5`).
const EMBEDDING_DIM_MINILM: usize = 384;
/// Output width of `BAAI/bge-base-en-v1.5`.
const EMBEDDING_DIM_BGE_BASE: usize = 768;
/// Wrapper around a FastEmbed [`TextEmbedding`] model guarded by an [`RwLock`].
///
/// The manager stores the expected output dimension and a display name next to the
/// model, so callers can query them without running inference.
pub struct FastEmbedManager {
    /// The loaded FastEmbed model; inference acquires this lock for writing.
    model: RwLock<TextEmbedding>,
    /// Expected length of each embedding vector produced by `model`.
    dimension: usize,
    /// Human-readable identifier of the loaded model.
    model_name: String,
}
impl FastEmbedManager {
    /// Creates a manager backed by the default `all-MiniLM-L6-v2` model.
    ///
    /// # Errors
    ///
    /// Returns an error if the FastEmbed model cannot be initialized.
    pub fn new() -> Result<Self> {
        Self::with_model(EmbeddingModel::AllMiniLML6V2)
    }

    /// Creates a manager from a human-readable model identifier.
    ///
    /// Recognized identifiers are `all-MiniLM-L6-v2`, `all-MiniLM-L12-v2`,
    /// `BAAI/bge-base-en-v1.5` and `BAAI/bge-small-en-v1.5`. Any other value logs a
    /// warning and falls back to `all-MiniLM-L6-v2`.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected model cannot be initialized.
    pub fn from_model_name(model_name: &str) -> Result<Self> {
        let model = match model_name {
            "all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,
            "all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2,
            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
            _ => {
                tracing::warn!(
                    "Unknown model '{}', falling back to all-MiniLM-L6-v2",
                    model_name
                );
                EmbeddingModel::AllMiniLML6V2
            }
        };
        Self::with_model(model)
    }

    /// Creates a manager for an explicit FastEmbed model.
    ///
    /// The dimension and display name come from a built-in table for the models listed in
    /// [`from_model_name`](Self::from_model_name). For any other model, the display name is
    /// the model's `Debug` representation, and the dimension is measured by embedding a short
    /// probe string once. This keeps [`dimension`](Self::dimension) in agreement with the
    /// vectors the model actually produces.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be initialized. For models outside the built-in
    /// table, it also returns an error if the probe embedding fails.
    pub fn with_model(model: EmbeddingModel) -> Result<Self> {
        tracing::info!("Initializing FastEmbed model: {:?}", model);
        let known = Self::known_model_info(&model);
        // Computed before `model` is moved into the init options below.
        let fallback_name = format!("{model:?}");

        let mut options = InitOptions::default();
        options.model_name = model;
        options.show_download_progress = true;
        let embedding_model =
            TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;

        let mut manager = Self {
            model: RwLock::new(embedding_model),
            dimension: known.map_or(0, |(dimension, _)| dimension),
            model_name: known.map_or(fallback_name, |(_, name)| name.to_string()),
        };

        if known.is_none() {
            // Models outside the table used to be reported as 384-d MiniLM, even though the
            // vectors they produce may have a different width. Measure the real width instead.
            manager.dimension = manager
                .embed("dimension probe")
                .context("Failed to determine embedding dimension")?
                .len();
            tracing::info!(
                "Measured embedding dimension {} for model {}",
                manager.dimension,
                manager.model_name
            );
        }

        Ok(manager)
    }

    /// Returns `(dimension, display name)` for models with a hard-coded entry, or `None`.
    fn known_model_info(model: &EmbeddingModel) -> Option<(usize, &'static str)> {
        match model {
            EmbeddingModel::AllMiniLML6V2 => Some((EMBEDDING_DIM_MINILM, "all-MiniLM-L6-v2")),
            EmbeddingModel::AllMiniLML12V2 => Some((EMBEDDING_DIM_MINILM, "all-MiniLM-L12-v2")),
            EmbeddingModel::BGEBaseENV15 => Some((EMBEDDING_DIM_BGE_BASE, "BAAI/bge-base-en-v1.5")),
            EmbeddingModel::BGESmallENV15 => Some((EMBEDDING_DIM_MINILM, "BAAI/bge-small-en-v1.5")),
            _ => None,
        }
    }

    /// Embeds an owned batch of texts, returning one vector per input in input order.
    ///
    /// An empty batch returns an empty result without touching the model.
    ///
    /// # Errors
    ///
    /// Returns an error if the model fails to generate embeddings.
    pub fn embed_batch_vec(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        tracing::debug!("Generating embeddings for {} texts", texts.len());
        // A panic in a previous inference poisons the lock; the model itself is still
        // usable, so recover instead of failing every later call.
        let mut model = self.model.write().unwrap_or_else(|poisoned| {
            tracing::warn!("FastEmbed model lock was poisoned, recovering...");
            poisoned.into_inner()
        });
        let embeddings = model
            .embed(texts, None)
            .context("Failed to generate embeddings")?;
        Ok(embeddings)
    }

    /// Embeds a single text.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding fails or the model returns no vector.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No embedding generated"))
    }

    /// Embeds a borrowed batch of texts (copies them and delegates to
    /// [`embed_batch_vec`](Self::embed_batch_vec)).
    ///
    /// # Errors
    ///
    /// Returns an error if the model fails to generate embeddings.
    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_vec(texts.to_vec())
    }

    /// Length of each embedding vector produced by this manager.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Display name of the loaded model.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
}
impl EmbeddingProviderTrait for FastEmbedManager {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let embeddings = self.embed_batch_vec(vec![text.to_string()])?;
embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow::anyhow!("No embedding generated"))
}
fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
self.embed_batch_vec(texts.to_vec())
}
fn dimension(&self) -> usize {
self.dimension
}
fn model_name(&self) -> &str {
&self.model_name
}
}
impl Default for FastEmbedManager {
fn default() -> Self {
Self::new().expect("Failed to initialize default FastEmbed model")
}
}
/// Embedding provider that memoizes single-text embeddings in an LRU cache.
///
/// The underlying [`FastEmbedManager`] is held in an [`Arc`], so several providers can
/// share one loaded model.
pub struct CachedEmbeddingProvider {
    /// Shared model used on cache misses and for batch requests.
    inner: Arc<FastEmbedManager>,
    /// Embeddings keyed by a 64-bit hash of the input text.
    cache: RwLock<LruCache<u64, Vec<f32>>>,
}
impl CachedEmbeddingProvider {
pub fn new() -> Result<Self> {
let inner = FastEmbedManager::new().context("Failed to create embedding provider")?;
Ok(Self {
inner: Arc::new(inner),
cache: RwLock::new(LruCache::new(
NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
)),
})
}
pub fn with_manager(manager: Arc<FastEmbedManager>) -> Self {
Self {
inner: manager,
cache: RwLock::new(LruCache::new(
NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
)),
}
}
fn hash_text(text: &str) -> u64 {
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
pub fn embed_cached(&self, text: &str) -> Result<Vec<f32>> {
let cache_key = Self::hash_text(text);
if let Ok(cache) = self.cache.read()
&& let Some(embedding) = cache.peek(&cache_key)
{
return Ok(embedding.clone());
}
let embedding = self.inner.embed(text)?;
if let Ok(mut cache) = self.cache.write() {
cache.put(cache_key, embedding.clone());
}
Ok(embedding)
}
pub fn cache_len(&self) -> usize {
self.cache.read().map(|c| c.len()).unwrap_or(0)
}
pub fn clear_cache(&self) {
if let Ok(mut cache) = self.cache.write() {
cache.clear();
}
}
pub fn inner(&self) -> &Arc<FastEmbedManager> {
&self.inner
}
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
self.embed_cached(text)
}
pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
self.inner.embed_batch_vec(texts.to_vec())
}
pub fn dimension(&self) -> usize {
self.inner.dimension
}
pub fn model_name(&self) -> &str {
&self.inner.model_name
}
}
/// Exposes [`CachedEmbeddingProvider`] through the shared embedding-provider trait.
///
/// Single-text requests go through the cache. Batch requests go straight to the underlying
/// model and are not cached.
impl EmbeddingProviderTrait for CachedEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_cached(text)
    }

    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let owned: Vec<String> = texts.to_vec();
        self.inner.embed_batch_vec(owned)
    }

    fn dimension(&self) -> usize {
        self.inner.dimension
    }

    fn model_name(&self) -> &str {
        &self.inner.model_name
    }
}
impl Clone for CachedEmbeddingProvider {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
cache: RwLock::new(LruCache::new(
NonZeroUsize::new(DEFAULT_CACHE_SIZE).expect("DEFAULT_CACHE_SIZE is non-zero"),
)),
}
}
}
/// Alias for [`CachedEmbeddingProvider`], the cached FastEmbed-backed provider.
pub type EmbeddingProvider = CachedEmbeddingProvider;
#[cfg(test)]
mod tests {
    use super::*;

    /// Output width of the default `all-MiniLM-L6-v2` model.
    const MINILM_DIM: usize = 384;

    /// Asserts that `embeddings` holds `expected_count` vectors of `MINILM_DIM` each.
    fn assert_minilm_batch(embeddings: &[Vec<f32>], expected_count: usize) {
        assert_eq!(embeddings.len(), expected_count);
        for embedding in embeddings {
            assert_eq!(embedding.len(), MINILM_DIM);
        }
    }

    #[test]
    fn test_fastembed_creation() {
        let manager = FastEmbedManager::new().unwrap();
        assert_eq!(manager.dimension(), MINILM_DIM);
        assert_eq!(manager.model_name(), "all-MiniLM-L6-v2");
    }

    #[test]
    fn test_fastembed_embed_single() {
        let manager = FastEmbedManager::new().unwrap();
        let vector = manager.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), MINILM_DIM);
    }

    #[test]
    fn test_fastembed_embed_batch() {
        let manager = FastEmbedManager::new().unwrap();
        let inputs: Vec<String> = [
            "fn main() { println!(\"Hello, world!\"); }",
            "pub struct Vector { x: f32, y: f32 }",
        ]
        .into_iter()
        .map(String::from)
        .collect();
        let vectors = manager.embed_batch(&inputs).unwrap();
        assert_minilm_batch(&vectors, 2);
    }

    #[test]
    fn test_fastembed_empty_batch() {
        let manager = FastEmbedManager::new().unwrap();
        let vectors = manager.embed_batch_vec(Vec::new()).unwrap();
        assert!(vectors.is_empty());
    }

    #[test]
    fn test_fastembed_default() {
        assert_eq!(FastEmbedManager::default().dimension(), MINILM_DIM);
    }

    #[test]
    fn test_fastembed_from_model_name() {
        let manager = FastEmbedManager::from_model_name("all-MiniLM-L6-v2").unwrap();
        assert_eq!(manager.dimension(), MINILM_DIM);
    }

    #[test]
    fn test_fastembed_unknown_model_fallback() {
        let manager = FastEmbedManager::from_model_name("unknown-model").unwrap();
        assert_eq!(manager.dimension(), MINILM_DIM);
        assert_eq!(manager.model_name(), "all-MiniLM-L6-v2");
    }

    #[test]
    fn test_cached_provider_creation() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        assert_eq!(provider.dimension(), MINILM_DIM);
    }

    #[test]
    fn test_cached_provider_embed_single() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        let vector = provider.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), MINILM_DIM);
        // The model output is expected to be (approximately) unit length.
        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_cached_provider_embed_batch() {
        let provider = CachedEmbeddingProvider::new().unwrap();
        let inputs: Vec<String> = ["First message", "Second message", "Third message"]
            .into_iter()
            .map(String::from)
            .collect();
        let vectors = provider.embed_batch(&inputs).unwrap();
        assert_minilm_batch(&vectors, 3);
    }

    #[test]
    fn test_cached_provider_clone() {
        let original = CachedEmbeddingProvider::new().unwrap();
        let copy = original.clone();
        assert_eq!(original.dimension(), copy.dimension());
    }

    #[test]
    fn test_cached_provider_caching() {
        let provider = CachedEmbeddingProvider::new().unwrap();

        let first = provider.embed_cached("test query").unwrap();
        assert_eq!(provider.cache_len(), 1);

        // A repeated query is served from the cache and does not add an entry.
        let second = provider.embed_cached("test query").unwrap();
        assert_eq!(provider.cache_len(), 1);
        assert_eq!(first, second);

        let _other = provider.embed_cached("different query").unwrap();
        assert_eq!(provider.cache_len(), 2);

        provider.clear_cache();
        assert_eq!(provider.cache_len(), 0);
    }
}