mod fallback;
#[cfg(feature = "fastembed-embeddings")]
mod fastembed_impl;
pub use fallback::FallbackEmbedder;
#[cfg(feature = "fastembed-embeddings")]
pub use fastembed_impl::FastEmbedEmbedder;
use crate::Result;
/// Default embedding dimension count.
///
/// This is the value handed to `FallbackEmbedder::new` by `create_embedder`
/// when the `fastembed-embeddings` feature is disabled.
pub const DEFAULT_DIMENSIONS: usize = 1024;
/// Common interface implemented by every text-embedding backend.
///
/// The `Send + Sync` supertraits mean an implementor can be moved to and
/// shared between threads.
pub trait Embedder: Send + Sync {
    /// Dimension count reported by this embedder — presumably the length of
    /// every vector produced by [`Embedder::embed`] (implementors should
    /// confirm they uphold this).
    fn dimensions(&self) -> usize;

    /// Static name identifying the model behind this embedder.
    fn model_name(&self) -> &'static str;

    /// Produces the embedding vector for a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation fails to embed `text`.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, yielding one vector per input in the
    /// same order.
    ///
    /// The provided implementation simply calls [`Embedder::embed`] once per
    /// text; backends with a native batch path can override it.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by
    /// [`Embedder::embed`]; texts after the failing one are not embedded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
/// Builds the crate's default embedder as a boxed trait object.
///
/// This variant is compiled only when the `fastembed-embeddings` feature is
/// enabled and returns a [`FastEmbedEmbedder`].
///
/// # Errors
///
/// Propagates any error returned by `FastEmbedEmbedder::new`.
#[cfg(feature = "fastembed-embeddings")]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    let fastembed = FastEmbedEmbedder::new()?;
    Ok(Box::new(fastembed))
}
/// Builds the crate's default embedder as a boxed trait object.
///
/// This variant is compiled only when the `fastembed-embeddings` feature is
/// disabled and returns a [`FallbackEmbedder`] constructed with
/// [`DEFAULT_DIMENSIONS`].
///
/// # Errors
///
/// This variant always returns `Ok`; the `Result` keeps the signature
/// identical to the feature-enabled variant.
#[cfg(not(feature = "fastembed-embeddings"))]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    let fallback: Box<dyn Embedder> = Box::new(FallbackEmbedder::new(DEFAULT_DIMENSIONS));
    Ok(fallback)
}
/// Computes the cosine similarity of two vectors.
///
/// The result lies in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors and `-1.0` for opposite vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (this includes two empty slices), since the similarity is
/// undefined in those cases.
///
/// The dot product and squared norms are accumulated in `f64`. Squaring an
/// `f32` component can underflow to zero (components around `1e-30`) or
/// overflow to infinity (components around `1e30`); in `f32` that would
/// wrongly yield `0.0` or `NaN` for perfectly valid finite inputs.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass over both slices: dot product and both squared norms.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_sq_a, norm_sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_sq_a + x * x, norm_sq_b + y * y)
        },
    );

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    // Take the square roots separately so the denominator cannot overflow,
    // then clamp away rounding error that could push the ratio past +/-1.
    let cosine = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());

    // The clamped value is within [-1, 1] (or NaN for non-finite input), so
    // narrowing back to f32 cannot overflow.
    cosine.clamp(-1.0, 1.0) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the default embedder and runs one probe embedding through it.
    ///
    /// Returns `None` (after logging the reason to stderr) when either step
    /// fails, so callers can skip instead of failing the test.
    fn embedder_or_skip(test: &str) -> Option<Box<dyn Embedder>> {
        let probed = create_embedder().and_then(|candidate| {
            // Only the success of the probe matters; its vector is discarded.
            candidate.embed("warmup probe")?;
            Ok(candidate)
        });

        probed
            .map_err(|err| eprintln!("skipping {test}: embedding model unavailable: {err}"))
            .ok()
    }

    /// Minimal embedder that returns the same 3-element vector for any text;
    /// used to exercise the trait's default `embed_batch`.
    struct StubEmbedder;

    impl Embedder for StubEmbedder {
        fn dimensions(&self) -> usize {
            3
        }

        fn model_name(&self) -> &'static str {
            "stub"
        }

        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![1.0, 2.0, 3.0])
        }
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let same: [f32; 3] = [1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&unit_x, &same);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let unit_y: [f32; 3] = [0.0, 1.0, 0.0];
        let similarity = cosine_similarity(&unit_x, &unit_y);
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let negated: [f32; 3] = [-1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&unit_x, &negated);
        assert!((similarity + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let shorter: [f32; 2] = [1.0, 0.0];
        let longer: [f32; 3] = [1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&shorter, &longer);
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zeros: [f32; 3] = [0.0, 0.0, 0.0];
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&zeros, &unit_x);
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_create_embedder() {
        // Skipped (passes trivially) when no embedding model is available.
        if let Some(embedder) = embedder_or_skip("test_create_embedder") {
            assert_eq!(embedder.dimensions(), DEFAULT_DIMENSIONS);
        }
    }

    #[test]
    fn test_embed_batch_default_impl() {
        let stub = StubEmbedder;
        let inputs = ["hello", "world", "test"];
        let vectors = stub.embed_batch(&inputs).unwrap();

        assert_eq!(vectors.len(), 3);
        let expected_len = stub.dimensions();
        assert!(vectors.iter().all(|vector| vector.len() == expected_len));
    }

    #[test]
    fn test_embed_batch_empty() {
        let stub = StubEmbedder;
        let no_inputs: [&str; 0] = [];
        let vectors = stub.embed_batch(&no_inputs).unwrap();
        assert!(vectors.is_empty());
    }
}