//! Batch embedding wrapper.
//!
//! Groups input texts into batches and calls the underlying Embedder
//! for batch-optimized inference, improving throughput for index building.

use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;
/// Batch embedding processor with progress reporting.
///
/// Wraps an [`Embedder`] and splits input into fixed-size batches so the
/// inner embedder can perform batch-optimized inference.
pub struct BatchEmbedder<E: Embedder> {
    // The wrapped embedder that performs the actual inference.
    embedder: E,
    // Number of texts per `embed_batch` call; `new` clamps this to at least 1.
    batch_size: usize,
}
14
15impl<E: Embedder> BatchEmbedder<E> {
16    /// Create a new BatchEmbedder wrapping the given embedder.
17    pub fn new(embedder: E, batch_size: usize) -> Self {
18        Self {
19            embedder,
20            batch_size: batch_size.max(1),
21        }
22    }
23
24    /// Get the embedding dimension.
25    pub fn dimension(&self) -> usize {
26        self.embedder.dimension()
27    }
28
29    /// Embed all texts in batches, calling the progress callback after each batch.
30    ///
31    /// `progress_fn` receives (completed_count, total_count).
32    pub fn embed_all_with_progress<F>(
33        &self,
34        texts: &[String],
35        mut progress_fn: F,
36    ) -> Result<Vec<Vec<f32>>, EmbedderError>
37    where
38        F: FnMut(usize, usize),
39    {
40        let total = texts.len();
41        let mut all_embeddings = Vec::with_capacity(total);
42        let mut completed = 0;
43
44        for chunk in texts.chunks(self.batch_size) {
45            let refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
46            let batch_result = self.embedder.embed_batch(&refs)?;
47            all_embeddings.extend(batch_result);
48            completed += chunk.len();
49            progress_fn(completed, total);
50        }
51
52        Ok(all_embeddings)
53    }
54
55    /// Embed all texts in batches without progress reporting.
56    pub fn embed_all(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
57        self.embed_all_with_progress(texts, |_, _| {})
58    }
59
60    /// Get a reference to the inner embedder.
61    pub fn inner(&self) -> &E {
62        &self.embedder
63    }
64}
65
/// A dummy embedder for testing that produces random-like but deterministic vectors.
pub struct DummyEmbedder {
    // Output embedding dimension: the length of each vector returned by `embed`.
    dim: usize,
}
70
71impl DummyEmbedder {
72    /// Create a new dummy embedder with the given dimension.
73    pub fn new(dim: usize) -> Self {
74        Self { dim }
75    }
76}
77
78impl Embedder for DummyEmbedder {
79    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
80        // Generate a deterministic pseudo-random embedding based on text content
81        let mut embedding = vec![0.0f32; self.dim];
82        let mut hash: u64 = 5381;
83
84        for byte in text.bytes() {
85            hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
86        }
87
88        for (i, val) in embedding.iter_mut().enumerate() {
89            hash = hash
90                .wrapping_mul(6364136223846793005)
91                .wrapping_add(1442695040888963407);
92            *val = ((hash >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
93            // Mix in position
94            let _ = i; // suppress unused warning, position affects hash via iteration
95        }
96
97        // L2 normalize
98        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
99        if norm > 0.0 {
100            for x in &mut embedding {
101                *x /= norm;
102            }
103        }
104
105        Ok(embedding)
106    }
107
108    fn dimension(&self) -> usize {
109        self.dim
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        // Output must have the requested dimension and unit L2 norm.
        let embedding = DummyEmbedder::new(384).embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.01,
            "Embedding should be L2 normalized"
        );
    }

    #[test]
    fn test_dummy_embedder_deterministic() {
        let embedder = DummyEmbedder::new(384);
        assert_eq!(
            embedder.embed("test").unwrap(),
            embedder.embed("test").unwrap(),
            "Same input should produce same embedding"
        );
    }

    #[test]
    fn test_dummy_embedder_different_inputs() {
        let embedder = DummyEmbedder::new(384);
        assert_ne!(
            embedder.embed("hello").unwrap(),
            embedder.embed("world").unwrap(),
            "Different inputs should produce different embeddings"
        );
    }

    #[test]
    fn test_batch_embedder() {
        let batch = BatchEmbedder::new(DummyEmbedder::new(128), 2);
        let texts: Vec<String> = ["hello", "world", "foo", "bar", "baz"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let mut progress_calls = Vec::new();
        let results = batch
            .embed_all_with_progress(&texts, |done, total| progress_calls.push((done, total)))
            .unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].len(), 128);

        // Five texts in batches of 2 → batch sizes 2, 2, 1.
        assert_eq!(progress_calls, vec![(2, 5), (4, 5), (5, 5)]);
    }
}
179}