mod types;
mod fastembed;
mod openai;
mod ollama;
mod factory;
pub use types::*;
pub use fastembed::FastEmbedProvider;
pub use openai::OpenAIEmbedProvider;
pub use ollama::OllamaProvider;
pub use factory::{EmbeddingProviderFactory, create_provider};
use async_trait::async_trait;
use anyhow::Result;
/// Common async interface for text-embedding backends.
///
/// Implementors must supply [`embed_documents`](EmbeddingProvider::embed_documents)
/// plus the three metadata accessors; every other method has a default
/// implementation layered on top of those.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a batch of texts.
    ///
    /// Implementations are expected to return one vector per input text, in
    /// input order; the default methods of this trait rely on that.
    ///
    /// # Errors
    /// Returns whatever error the backend reports.
    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Embeds a single query string.
    ///
    /// The default implementation sends a one-element batch through
    /// [`embed_documents`](EmbeddingProvider::embed_documents) and returns the
    /// first vector.
    ///
    /// # Errors
    /// Propagates any error from `embed_documents`, and fails if the backend
    /// returned no vectors at all.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed_documents(vec![text.to_owned()]).await?;
        results.into_iter().next().ok_or_else(|| {
            anyhow::anyhow!("embed_documents returned empty result for single query")
        })
    }

    /// Dimensionality of the vectors this provider produces.
    fn dimensions(&self) -> usize;

    /// Name of the model used by this provider.
    fn model_name(&self) -> &str;

    /// Name of the provider itself.
    fn provider_name(&self) -> &str;

    /// Probes the provider by embedding the literal query `"test"`.
    ///
    /// The default implementation is best-effort: it yields `Ok(true)` when
    /// the probe succeeds and `Ok(false)` when it fails, and never returns
    /// `Err` itself.
    async fn health_check(&self) -> Result<bool> {
        Ok(self.embed_query("test").await.is_ok())
    }

    /// Largest number of texts that
    /// [`embed_documents_batched`](EmbeddingProvider::embed_documents_batched)
    /// passes to a single `embed_documents` call. Defaults to 100.
    fn max_batch_size(&self) -> usize {
        100
    }

    /// Embeds `texts`, splitting them into batches of at most
    /// [`max_batch_size`](EmbeddingProvider::max_batch_size) items.
    ///
    /// Batches are processed sequentially and their results concatenated in
    /// input order. A reported batch size of `0` is treated as `1` instead of
    /// panicking.
    ///
    /// # Errors
    /// Stops at, and returns, the first error produced by `embed_documents`;
    /// embeddings from earlier batches are discarded.
    async fn embed_documents_batched(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // A zero batch size can never make progress, so clamp it to one.
        let batch_size = self.max_batch_size().max(1);
        if texts.len() <= batch_size {
            return self.embed_documents(texts).await;
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());
        // Drain the owned input so each batch moves its strings rather than cloning them.
        let mut remaining = texts.into_iter();
        loop {
            let batch: Vec<String> = remaining.by_ref().take(batch_size).collect();
            if batch.is_empty() {
                break;
            }
            all_embeddings.extend(self.embed_documents(batch).await?);
        }
        Ok(all_embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory provider that maps every text to a constant vector of
    /// `dims` entries, so the trait's default methods can be exercised
    /// without a real backend.
    struct MockProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockProvider {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            // One identical vector per input text.
            let constant = vec![0.1; self.dims];
            Ok(vec![constant; texts.len()])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }

        fn provider_name(&self) -> &str {
            "mock"
        }

        // Deliberately tiny so the five-document test needs more than one batch.
        fn max_batch_size(&self) -> usize {
            2
        }
    }

    /// The default `embed_query` returns a vector of the mock's dimensionality.
    #[tokio::test]
    async fn test_embed_query_default() {
        let mock = MockProvider { dims: 384 };
        let vector = mock.embed_query("test query").await.unwrap();
        assert_eq!(vector.len(), 384);
    }

    /// Five documents with a batch size of two still yield five vectors.
    #[tokio::test]
    async fn test_embed_documents_batched() {
        let mock = MockProvider { dims: 3 };
        let mut docs: Vec<String> = Vec::with_capacity(5);
        for i in 0..5 {
            docs.push(format!("doc{}", i));
        }

        let vectors = mock.embed_documents_batched(docs).await.unwrap();

        assert_eq!(vectors.len(), 5);
        vectors.iter().for_each(|v| assert_eq!(v.len(), 3));
    }

    /// The default `health_check` reports healthy when embedding succeeds.
    #[tokio::test]
    async fn test_health_check_default() {
        let mock = MockProvider { dims: 3 };
        let is_healthy = mock.health_check().await.unwrap();
        assert!(is_healthy);
    }
}