use anyhow::{Context, Result};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
/// Common interface for components that turn text into `f32` embedding vectors.
///
/// Implementors must be `Send + Sync`.
pub trait EmbeddingProvider: Send + Sync {
/// Embeds a single text and returns its vector.
///
/// # Errors
///
/// Returns an error when the implementation cannot produce an embedding.
fn embed(&mut self, text: &str) -> Result<Vec<f32>>;
/// Embeds several texts in one call, returning a list of vectors.
///
/// NOTE(review): presumably one vector per input, in input order; the
/// trait itself does not enforce this — confirm against implementors.
///
/// # Errors
///
/// Returns an error when the implementation cannot produce the embeddings.
fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Reports the dimension count associated with this provider
/// (presumably the length of each vector it produces).
fn dimensions(&self) -> usize;
/// Returns the identifier string of the model behind this provider.
fn model_id(&self) -> &str;
}
/// Embedding provider that wraps a `fastembed` [`TextEmbedding`] model.
pub struct FastEmbedProvider {
// Model handle that embedding calls are forwarded to.
model: TextEmbedding,
// Identifier string handed back by `model_id()`.
model_id: String,
// Value handed back by `dimensions()`.
dimensions: usize,
}
impl FastEmbedProvider {
    /// Identifier string recorded for the model this provider loads.
    const MODEL_ID: &'static str = "BAAI/bge-base-en-v1.5";

    /// Dimension count recorded for the model this provider loads.
    const DIMENSIONS: usize = 768;

    /// Builds a provider around the `BGEBaseENV15` FastEmbed model.
    ///
    /// The model is initialised with the cache directory returned by
    /// `crate::paths::fastembed_cache_dir()` and with download-progress
    /// display switched on.
    ///
    /// # Errors
    ///
    /// Returns an error (with the context "Failed to initialize FastEmbed
    /// model") when `TextEmbedding::try_new` fails.
    pub fn new() -> Result<Self> {
        let cache_root = crate::paths::fastembed_cache_dir();
        let options = InitOptions::new(EmbeddingModel::BGEBaseENV15)
            .with_cache_dir(cache_root)
            .with_show_download_progress(true);

        let model =
            TextEmbedding::try_new(options).context("Failed to initialize FastEmbed model")?;

        Ok(Self {
            model,
            model_id: String::from(Self::MODEL_ID),
            dimensions: Self::DIMENSIONS,
        })
    }
}
/// [`EmbeddingProvider`] implementation that forwards to the wrapped FastEmbed model.
impl EmbeddingProvider for FastEmbedProvider {
    /// Embeds a single text and returns the first vector the model produces.
    ///
    /// # Errors
    ///
    /// Returns an error when the model call fails ("Failed to generate
    /// embedding") or when it succeeds but yields no vectors ("No embedding
    /// returned").
    fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
        // The model accepts borrowed, slice-like input (the batch path below
        // hands it `&[String]`), so a one-element array of `&str` is enough;
        // there is no need to allocate an owned `String` and a `Vec` per call.
        self.model
            .embed([text], None)
            .context("Failed to generate embedding")?
            .into_iter()
            .next()
            .context("No embedding returned")
    }

    /// Embeds every text in `texts` with a single model call.
    ///
    /// # Errors
    ///
    /// Returns an error ("Failed to generate batch embeddings") when the
    /// model call fails.
    fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.model
            .embed(texts, None)
            .context("Failed to generate batch embeddings")
    }

    /// Returns the dimension count stored when the provider was constructed.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the model identifier stored when the provider was constructed.
    fn model_id(&self) -> &str {
        &self.model_id
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Vector length the assertions below expect from `FastEmbedProvider`.
    const EXPECTED_DIMENSIONS: usize = 768;

    /// A single text yields one vector of the expected length.
    #[test]
    #[serial]
    fn test_embed_single() -> Result<()> {
        let mut embedder = FastEmbedProvider::new()?;
        let vector = embedder.embed("Hello, world!")?;

        assert_eq!(vector.len(), EXPECTED_DIMENSIONS);
        Ok(())
    }

    /// Two texts yield two vectors, each of the expected length.
    #[test]
    #[serial]
    fn test_embed_batch() -> Result<()> {
        let mut embedder = FastEmbedProvider::new()?;
        let inputs: Vec<String> = ["First text", "Second text"]
            .iter()
            .map(|text| text.to_string())
            .collect();

        let vectors = embedder.embed_batch(&inputs)?;

        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), EXPECTED_DIMENSIONS);
        }
        Ok(())
    }

    /// The provider reports the expected dimension count.
    #[test]
    #[serial]
    fn test_dimensions() -> Result<()> {
        let embedder = FastEmbedProvider::new()?;

        assert_eq!(embedder.dimensions(), EXPECTED_DIMENSIONS);
        Ok(())
    }
}