mod fallback;
#[cfg(feature = "fastembed-embeddings")]
mod fastembed_impl;
pub use fallback::FallbackEmbedder;
#[cfg(feature = "fastembed-embeddings")]
pub use fastembed_impl::FastEmbedEmbedder;
use crate::Result;
/// Default embedding vector length.
///
/// Passed to `FallbackEmbedder::new` by `create_embedder` when the
/// `fastembed-embeddings` feature is disabled.
pub const DEFAULT_DIMENSIONS: usize = 1024;
/// Common interface for anything that turns text into an embedding vector.
///
/// Implementors must be `Send + Sync`, so an embedder can be used behind a
/// shared reference from multiple threads.
pub trait Embedder: Send + Sync {
    /// Reports the number of components in the vectors this embedder produces.
    fn dimensions(&self) -> usize;

    /// Reports the name of the model behind this embedder.
    fn model_name(&self) -> &'static str;

    /// Embeds a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports for this input.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, preserving input order.
    ///
    /// The default implementation calls [`Embedder::embed`] once per entry and
    /// stops at the first failure. Implementations with a native batch API can
    /// override it.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Embedder::embed`]; entries after
    /// the failing one are not embedded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
/// Builds the embedder used when the `fastembed-embeddings` feature is enabled.
///
/// # Errors
///
/// Propagates any error returned by `FastEmbedEmbedder::new`.
#[cfg(feature = "fastembed-embeddings")]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    let embedder = FastEmbedEmbedder::new()?;
    Ok(Box::new(embedder))
}
/// Builds the embedder used when the `fastembed-embeddings` feature is disabled.
///
/// Wraps a `FallbackEmbedder` constructed with [`DEFAULT_DIMENSIONS`]. This
/// variant always returns `Ok`; the `Result` return type matches the
/// feature-enabled signature.
#[cfg(not(feature = "fastembed-embeddings"))]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    let embedder: Box<dyn Embedder> = Box::new(FallbackEmbedder::new(DEFAULT_DIMENSIONS));
    Ok(embedder)
}
/// Computes the cosine of the angle between two vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (including two empty slices); otherwise returns the dot
/// product divided by the product of the two Euclidean norms.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        // No direction to compare against; also avoids dividing by zero.
        return 0.0;
    }

    let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot_product / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing similarity scores.
    const EPSILON: f32 = 1e-6;

    /// Returns `true` when `actual` is within [`EPSILON`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPSILON
    }

    /// Vectors pointing the same way score 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let lhs = [1.0, 0.0, 0.0];
        let rhs = [1.0, 0.0, 0.0];
        assert!(close(cosine_similarity(&lhs, &rhs), 1.0));
    }

    /// Perpendicular vectors score 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let lhs = [1.0, 0.0, 0.0];
        let rhs = [0.0, 1.0, 0.0];
        assert!(close(cosine_similarity(&lhs, &rhs), 0.0));
    }

    /// Vectors pointing in opposite directions score -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let lhs = [1.0, 0.0, 0.0];
        let rhs = [-1.0, 0.0, 0.0];
        assert!(close(cosine_similarity(&lhs, &rhs), -1.0));
    }

    /// Mismatched lengths score 0.
    #[test]
    fn test_cosine_similarity_different_lengths() {
        let shorter = [1.0, 0.0];
        let longer = [1.0, 0.0, 0.0];
        assert!(close(cosine_similarity(&shorter, &longer), 0.0));
    }

    /// A zero-magnitude vector scores 0 against anything.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = [0.0, 0.0, 0.0];
        let unit = [1.0, 0.0, 0.0];
        assert!(close(cosine_similarity(&zero, &unit), 0.0));
    }

    /// The factory returns an embedder reporting the default dimensionality.
    #[test]
    fn test_create_embedder() {
        let embedder = create_embedder().unwrap();
        assert_eq!(embedder.dimensions(), DEFAULT_DIMENSIONS);
    }

    /// Batch embedding yields one correctly sized vector per input.
    #[test]
    fn test_embed_batch_default_impl() {
        let embedder = create_embedder().unwrap();
        let inputs = ["hello", "world", "test"];
        let vectors = embedder.embed_batch(&inputs).unwrap();
        assert_eq!(vectors.len(), 3);
        vectors
            .iter()
            .for_each(|vector| assert_eq!(vector.len(), embedder.dimensions()));
    }

    /// An empty batch yields an empty result.
    #[test]
    fn test_embed_batch_empty() {
        let embedder = create_embedder().unwrap();
        let inputs: &[&str] = &[];
        let vectors = embedder.embed_batch(inputs).unwrap();
        assert!(vectors.is_empty());
    }
}