// rlm_rs/embedding/mod.rs
1//! Embedding generation for semantic search.
2//!
3//! Provides embedding generation using fastembed (when available) or a
4//! hash-based fallback for deterministic pseudo-embeddings.
5//!
6//! # Feature Flags
7//!
8//! - `fastembed-embeddings`: Enables `FastEmbed` with BGE-M3 (1024 dimensions, 8192 token max)
9//! - Without the feature: Uses hash-based fallback (deterministic but not semantic)
10
11mod fallback;
12
13#[cfg(feature = "fastembed-embeddings")]
14mod fastembed_impl;
15
16pub use fallback::FallbackEmbedder;
17
18#[cfg(feature = "fastembed-embeddings")]
19pub use fastembed_impl::FastEmbedEmbedder;
20
21use crate::Result;
22
/// Embedding dimensionality of the default (BGE-M3) model.
///
/// Single source of truth for vector width across the codebase: every
/// vector backend and embedder should reference this constant instead of
/// hard-coding its own value.
pub const DEFAULT_DIMENSIONS: usize = 1024;
28
29/// Trait for embedding generators.
30///
31/// Implementations must be thread-safe (`Send + Sync`) to support parallel
32/// embedding generation during chunk loading.
33///
34/// # Examples
35///
36/// ```
37/// use rlm_rs::embedding::{Embedder, FallbackEmbedder, DEFAULT_DIMENSIONS};
38///
39/// let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
40/// let embedding = embedder.embed("Hello, world!").unwrap();
41/// assert_eq!(embedding.len(), DEFAULT_DIMENSIONS);
42/// ```
43pub trait Embedder: Send + Sync {
44 /// Returns the embedding dimensions.
45 fn dimensions(&self) -> usize;
46
47 /// Returns the model name/version identifier.
48 ///
49 /// This is stored with embeddings to detect model changes.
50 fn model_name(&self) -> &'static str;
51
52 /// Generates an embedding for the given text.
53 ///
54 /// # Errors
55 ///
56 /// Returns an error if embedding generation fails.
57 fn embed(&self, text: &str) -> Result<Vec<f32>>;
58
59 /// Generates embeddings for multiple texts.
60 ///
61 /// The default implementation calls `embed` for each text sequentially.
62 /// Implementations may override this for batch optimization.
63 ///
64 /// # Errors
65 ///
66 /// Returns an error if embedding generation fails for any text.
67 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
68 texts.iter().map(|t| self.embed(t)).collect()
69 }
70}
71
/// Constructs the default [`Embedder`] for the enabled feature set.
///
/// - With `fastembed-embeddings`: Returns `FastEmbedEmbedder`
/// - Without: Returns `FallbackEmbedder`
///
/// # Errors
///
/// Returns an error if embedder initialization fails.
#[cfg(feature = "fastembed-embeddings")]
pub fn create_embedder() -> Result<Box<dyn Embedder>> {
    let embedder = FastEmbedEmbedder::new()?;
    Ok(Box::new(embedder))
}
84
85/// Creates the default embedder based on available features.
86///
87/// - With `fastembed-embeddings`: Returns `FastEmbedEmbedder`
88/// - Without: Returns `FallbackEmbedder`
89///
90/// # Errors
91///
92/// Returns an error if embedder initialization fails (never fails for fallback).
93#[cfg(not(feature = "fastembed-embeddings"))]
94pub fn create_embedder() -> Result<Box<dyn Embedder>> {
95 Ok(Box::new(FallbackEmbedder::new(DEFAULT_DIMENSIONS)))
96}
97
/// Computes cosine similarity between two embedding vectors.
///
/// Returns a value between -1.0 (opposite) and 1.0 (identical).
/// For normalized vectors (L2 norm = 1), this is equivalent to the dot product.
///
/// Degenerate inputs never panic: if the slices have different lengths, or
/// either vector has zero magnitude, `0.0` is returned. (The previous docs
/// placed this under a `# Panics` heading, which by rustdoc convention lists
/// conditions that *do* panic — misleading for a function that never does.)
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: accumulate the dot product and both squared magnitudes
    // together instead of traversing each slice three times. Left-to-right
    // summation order matches the separate-pass version exactly.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();

    // Guard each magnitude separately so a zero vector yields 0.0 rather
    // than NaN/inf from a division by zero.
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    dot / (mag_a * mag_b)
}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let pos = [1.0, 0.0, 0.0];
        let neg = [-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&pos, &neg) + 1.0).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let short = [1.0, 0.0];
        let long = [1.0, 0.0, 0.0];
        assert!(cosine_similarity(&short, &long).abs() < EPS);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero = [0.0, 0.0, 0.0];
        let unit = [1.0, 0.0, 0.0];
        assert!(cosine_similarity(&zero, &unit).abs() < EPS);
    }

    #[test]
    fn test_create_embedder() {
        let embedder = create_embedder().expect("embedder construction should succeed");
        assert_eq!(embedder.dimensions(), DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_embed_batch_default_impl() {
        // Exercises the trait's default `embed_batch` implementation.
        let embedder = create_embedder().expect("embedder construction should succeed");
        let embeddings = embedder.embed_batch(&["hello", "world", "test"]).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == embedder.dimensions()));
    }

    #[test]
    fn test_embed_batch_empty() {
        // An empty input slice should yield an empty result.
        let embedder = create_embedder().expect("embedder construction should succeed");
        assert!(embedder.embed_batch(&[]).unwrap().is_empty());
    }
}
194}