Skip to main content

rlm_rs/embedding/
mod.rs

1//! Embedding generation for semantic search.
2//!
3//! Provides embedding generation using fastembed (when available) or a
4//! hash-based fallback for deterministic pseudo-embeddings.
5//!
6//! # Feature Flags
7//!
8//! - `fastembed-embeddings`: Enables `FastEmbed` with BGE-M3 (1024 dimensions, 8192 token max)
9//! - Without the feature: Uses hash-based fallback (deterministic but not semantic)
10
11mod fallback;
12
13#[cfg(feature = "fastembed-embeddings")]
14mod fastembed_impl;
15
16pub use fallback::FallbackEmbedder;
17
18#[cfg(feature = "fastembed-embeddings")]
19pub use fastembed_impl::FastEmbedEmbedder;
20
21use crate::Result;
22
/// Default embedding dimensions for the BGE-M3 model.
///
/// This is the authoritative source for embedding dimensions across the codebase.
/// All vector backends should use this constant for consistency.
/// Matches the 1024-dimension output documented for `fastembed-embeddings`
/// in the module docs above; the hash fallback is sized to it as well.
pub const DEFAULT_DIMENSIONS: usize = 1024;
28
29/// Trait for embedding generators.
30///
31/// Implementations must be thread-safe (`Send + Sync`) to support parallel
32/// embedding generation during chunk loading.
33///
34/// # Examples
35///
36/// ```
37/// use rlm_rs::embedding::{Embedder, FallbackEmbedder, DEFAULT_DIMENSIONS};
38///
39/// let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
40/// let embedding = embedder.embed("Hello, world!").unwrap();
41/// assert_eq!(embedding.len(), DEFAULT_DIMENSIONS);
42/// ```
43pub trait Embedder: Send + Sync {
44    /// Returns the embedding dimensions.
45    fn dimensions(&self) -> usize;
46
47    /// Returns the model name/version identifier.
48    ///
49    /// This is stored with embeddings to detect model changes.
50    fn model_name(&self) -> &'static str;
51
52    /// Generates an embedding for the given text.
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if embedding generation fails.
57    fn embed(&self, text: &str) -> Result<Vec<f32>>;
58
59    /// Generates embeddings for multiple texts.
60    ///
61    /// The default implementation calls `embed` for each text sequentially.
62    /// Implementations may override this for batch optimization.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error if embedding generation fails for any text.
67    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
68        texts.iter().map(|t| self.embed(t)).collect()
69    }
70}
71
/// Creates the default embedder based on available features.
///
/// - With `fastembed-embeddings`: Returns `FastEmbedEmbedder`
/// - Without: Returns `FallbackEmbedder`
///
/// # Errors
///
/// Returns an error if embedder initialization fails.
#[cfg(feature = "fastembed-embeddings")]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    // Model download/load can fail, hence the fallible constructor.
    let embedder = FastEmbedEmbedder::new()?;
    Ok(Box::new(embedder))
}
84
85/// Creates the default embedder based on available features.
86///
87/// - With `fastembed-embeddings`: Returns `FastEmbedEmbedder`
88/// - Without: Returns `FallbackEmbedder`
89///
90/// # Errors
91///
92/// Returns an error if embedder initialization fails (never fails for fallback).
93#[cfg(not(feature = "fastembed-embeddings"))]
94pub fn create_embedder() -> Result<Box<dyn Embedder>> {
95    Ok(Box::new(FallbackEmbedder::new(DEFAULT_DIMENSIONS)))
96}
97
/// Computes cosine similarity between two embedding vectors.
///
/// Returns a value between -1.0 (opposite) and 1.0 (identical).
/// For normalized vectors (L2 norm = 1), this is equivalent to the dot product.
///
/// Never panics: mismatched lengths and zero-magnitude inputs both yield 0.0
/// rather than an error, since neither is a meaningful similarity.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    dot / (mag_a * mag_b)
}
122
#[cfg(test)]
mod tests {
    use super::*;

    // Identical unit vectors: similarity must be exactly 1.0 (within f32 eps).
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // Perpendicular vectors: similarity must be 0.0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    // Antiparallel vectors: similarity must be -1.0.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    // Length mismatch is not an error: the function documents a 0.0 return.
    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    // A zero-magnitude vector yields 0.0 rather than a divide-by-zero NaN.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    // Whichever embedder the feature flags select, it must report the
    // canonical dimension count.
    #[test]
    fn test_create_embedder() {
        let embedder = create_embedder().unwrap();
        assert_eq!(embedder.dimensions(), DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_embed_batch_default_impl() {
        // Exercises the trait's default `embed_batch` (sequential `embed` calls).
        let embedder = create_embedder().unwrap();
        let texts = vec!["hello", "world", "test"];
        let embeddings = embedder.embed_batch(&texts).unwrap();

        // One embedding per input, each at the embedder's declared width.
        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), embedder.dimensions());
        }
    }

    #[test]
    fn test_embed_batch_empty() {
        // An empty batch must succeed and produce an empty result.
        let embedder = create_embedder().unwrap();
        let texts: Vec<&str> = vec![];
        let embeddings = embedder.embed_batch(&texts).unwrap();
        assert!(embeddings.is_empty());
    }
}