use super::{EmbeddingProvider, FastEmbedModel};
use anyhow::{Context, Result};
use async_trait::async_trait;
use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
use rig_fastembed::{Client as FastembedClient, FastembedModel as RigFastembedModel};
use std::sync::Arc;
/// Embedding provider that generates embeddings through the rig-fastembed client.
///
/// All fields are filled in once by [`FastEmbedProvider::with_model`] and are
/// only read afterwards by the code in this file.
pub struct FastEmbedProvider {
// Shared handle to the rig-fastembed client; used to obtain an embedding
// model each time documents are embedded.
client: Arc<FastembedClient>,
// Project-level model selection; source of the reported model name.
model: FastEmbedModel,
// rig-fastembed model identifier derived from `model` at construction time.
rig_model: RigFastembedModel,
// Embedding dimensionality, cached from `model.dimensions()` at construction.
dims: usize,
}
impl FastEmbedProvider {
pub fn new() -> Result<Self> {
Self::with_model(FastEmbedModel::default())
}
pub fn with_model(model: FastEmbedModel) -> Result<Self> {
let client = Arc::new(FastembedClient::new());
let rig_model = Self::to_rig_model(&model);
let dims = model.dimensions();
Ok(Self {
client,
model,
rig_model,
dims,
})
}
pub fn from_model_name(name: &str) -> Result<Self> {
let model: FastEmbedModel = name.parse()?;
Self::with_model(model)
}
fn to_rig_model(model: &FastEmbedModel) -> RigFastembedModel {
match model {
FastEmbedModel::AllMiniLM => RigFastembedModel::AllMiniLML6V2Q,
FastEmbedModel::BGESmallEN => RigFastembedModel::BGESmallENV15Q,
FastEmbedModel::BGEBaseEN => RigFastembedModel::BGEBaseENV15,
FastEmbedModel::BGELargeEN => RigFastembedModel::BGELargeENV15,
}
}
}
impl Default for FastEmbedProvider {
    /// Creates the provider through [`FastEmbedProvider::new`].
    ///
    /// # Panics
    ///
    /// Panics if `FastEmbedProvider::new` returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create default FastEmbed provider")
    }
}
#[async_trait]
impl EmbeddingProvider for FastEmbedProvider {
    /// Embeds every text in `texts` and returns one `f32` vector per result
    /// produced by the underlying model.
    ///
    /// An empty input short-circuits to an empty output without touching the
    /// client. The whole input is handed to the model in a single call; this
    /// method does not split it into batches itself.
    ///
    /// # Errors
    ///
    /// Returns the underlying embedding error, wrapped with the context
    /// message "FastEmbed failed to generate embeddings".
    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let model = self.client.embedding_model(&self.rig_model);
        let raw = model
            .embed_texts(texts)
            .await
            .context("FastEmbed failed to generate embeddings")?;

        // Each component is narrowed with `as f32` (the source component type
        // is presumably f64 — confirm against rig's `Embedding`).
        Ok(raw
            .into_iter()
            .map(|embedding| {
                embedding
                    .vec
                    .into_iter()
                    .map(|component| component as f32)
                    .collect::<Vec<f32>>()
            })
            .collect())
    }

    /// Dimensionality cached when the provider was constructed.
    fn dimensions(&self) -> usize {
        self.dims
    }

    /// Model name as reported by the selected `FastEmbedModel`.
    fn model_name(&self) -> &str {
        self.model.rig_model_name()
    }

    /// Fixed identifier of this provider: `"fastembed"`.
    fn provider_name(&self) -> &str {
        "fastembed"
    }

    /// Batch size advertised to callers (256).
    fn max_batch_size(&self) -> usize {
        256
    }

    /// Embeds a fixed probe string via `embed_query` and reports whether the
    /// result has the expected dimensionality.
    ///
    /// Best-effort by design: an embedding failure is reported as `Ok(false)`
    /// rather than propagated, so this method never returns `Err`.
    async fn health_check(&self) -> Result<bool> {
        let probe = self.embed_query("health check").await;
        let healthy = probe
            .map(|vector| vector.len() == self.dims)
            .unwrap_or(false);
        Ok(healthy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `to_rig_model` accepts each of the four listed variants.
    #[test]
    fn test_model_conversion() {
        let variants = [
            FastEmbedModel::AllMiniLM,
            FastEmbedModel::BGESmallEN,
            FastEmbedModel::BGEBaseEN,
            FastEmbedModel::BGELargeEN,
        ];
        for variant in variants.iter() {
            let _ = FastEmbedProvider::to_rig_model(variant);
        }
    }

    /// The default provider reports the expected dimensions and names.
    #[test]
    fn test_provider_creation() {
        let default_provider = FastEmbedProvider::new().unwrap();
        assert_eq!(default_provider.dimensions(), 384);
        assert_eq!(default_provider.model_name(), "all-minilm");
        assert_eq!(default_provider.provider_name(), "fastembed");
    }

    /// Providers built from a model name report that model's dimensionality.
    #[test]
    fn test_from_model_name() {
        let expectations: [(&str, usize); 2] = [("bge-small", 384), ("bge-base", 768)];
        for &(model_name, expected_dims) in expectations.iter() {
            let provider = FastEmbedProvider::from_model_name(model_name).unwrap();
            assert_eq!(provider.dimensions(), expected_dims);
        }
    }

    /// End-to-end embedding of two short texts; ignored by default because it
    /// requires a model download.
    #[tokio::test]
    #[ignore = "requires model download"]
    async fn test_embed_documents() {
        let provider = FastEmbedProvider::new().unwrap();
        let inputs: Vec<String> = ["Hello world", "How are you"]
            .iter()
            .map(|text| String::from(*text))
            .collect();

        let vectors = provider.embed_documents(inputs).await.unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0].len(), 384);
        assert_eq!(vectors[1].len(), 384);
    }

    /// Empty input yields empty output.
    #[tokio::test]
    async fn test_embed_empty() {
        let provider = FastEmbedProvider::new().unwrap();
        let vectors = provider.embed_documents(Vec::new()).await.unwrap();
        assert!(vectors.is_empty());
    }
}