Skip to main content

codesearch/embed/
batch.rs

1use super::embedder::FastEmbedder;
2use crate::chunker::Chunk;
3use anyhow::Result;
4use std::sync::{Arc, Mutex};
5
/// Counters collected while embedding a set of chunks.
///
/// Every rate helper returns `0.0` instead of dividing by zero.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)] // Only exercised from tests at the moment
pub struct EmbeddingStats {
    pub total_chunks: usize,
    pub embedded_chunks: usize,
    pub cached_chunks: usize,
    pub failed_chunks: usize,
    pub total_time_ms: u128,
}

impl EmbeddingStats {
    /// Fraction of `total_chunks` that were served from the cache.
    ///
    /// Lies in `0.0..=1.0` as long as `cached_chunks <= total_chunks`;
    /// returns `0.0` when `total_chunks` is zero.
    #[allow(dead_code)]
    pub fn cache_hit_rate(&self) -> f64 {
        Self::fraction(self.cached_chunks, self.total_chunks)
    }

    /// Fraction of `total_chunks` that were freshly embedded.
    ///
    /// Lies in `0.0..=1.0` as long as `embedded_chunks <= total_chunks`;
    /// returns `0.0` when `total_chunks` is zero.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        Self::fraction(self.embedded_chunks, self.total_chunks)
    }

    /// Embedding throughput: `embedded_chunks` per second of `total_time_ms`.
    ///
    /// Returns `0.0` when no time was recorded.
    #[allow(dead_code)]
    pub fn chunks_per_second(&self) -> f64 {
        match self.total_time_ms {
            0 => 0.0,
            elapsed_ms => (self.embedded_chunks as f64 / elapsed_ms as f64) * 1000.0,
        }
    }

    /// `part / whole` as a float, or `0.0` when `whole` is zero.
    #[allow(dead_code)] // Only reached through the (currently test-only) rate helpers
    fn fraction(part: usize, whole: usize) -> f64 {
        if whole == 0 {
            0.0
        } else {
            part as f64 / whole as f64
        }
    }
}
45
46/// Chunk with its embedding
47#[derive(Debug, Clone)]
48pub struct EmbeddedChunk {
49    pub chunk: Chunk,
50    pub embedding: Vec<f32>,
51}
52
53impl EmbeddedChunk {
54    pub fn new(chunk: Chunk, embedding: Vec<f32>) -> Self {
55        Self { chunk, embedding }
56    }
57}
58
59/// Batch processor for embedding chunks efficiently
60pub struct BatchEmbedder {
61    pub embedder: Arc<Mutex<FastEmbedder>>,
62    batch_size: usize,
63}
64
65impl BatchEmbedder {
66    /// Create a new batch embedder
67    pub fn new(embedder: Arc<Mutex<FastEmbedder>>) -> Self {
68        Self {
69            embedder,
70            batch_size: 32, // Default batch size
71        }
72    }
73
74    /// Create with custom batch size
75    #[allow(dead_code)] // Reserved for custom batch configuration
76    pub fn with_batch_size(embedder: Arc<Mutex<FastEmbedder>>, batch_size: usize) -> Self {
77        Self {
78            embedder,
79            batch_size,
80        }
81    }
82
83    /// Embed a batch of chunks
84    pub fn embed_chunks(&mut self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>> {
85        if chunks.is_empty() {
86            return Ok(Vec::new());
87        }
88
89        let total = chunks.len();
90        let _start = std::time::Instant::now();
91        let mut embedded_chunks = Vec::with_capacity(total);
92
93        // Process in batches
94        for chunk_batch in chunks.chunks(self.batch_size) {
95            // Prepare texts for embedding
96            let texts: Vec<String> = chunk_batch
97                .iter()
98                .map(|chunk| self.prepare_text(chunk))
99                .collect();
100
101            // Generate embeddings
102            let embeddings = self
103                .embedder
104                .lock()
105                .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
106                .embed_batch(texts)?;
107
108            // Combine chunks with embeddings
109            for (chunk, embedding) in chunk_batch.iter().zip(embeddings.into_iter()) {
110                embedded_chunks.push(EmbeddedChunk::new(chunk.clone(), embedding));
111            }
112        }
113
114        Ok(embedded_chunks)
115    }
116
117    /// Embed a single chunk
118    #[allow(dead_code)] // Reserved for single-chunk embedding
119    pub fn embed_chunk(&mut self, chunk: Chunk) -> Result<EmbeddedChunk> {
120        let text = self.prepare_text(&chunk);
121        let embedding = self
122            .embedder
123            .lock()
124            .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
125            .embed_one(&text)?;
126        Ok(EmbeddedChunk::new(chunk, embedding))
127    }
128
129    /// Prepare chunk text for embedding
130    ///
131    /// Combines different chunk metadata for better embeddings:
132    /// - Context breadcrumbs
133    /// - Function/Struct name (extracted from signature or content)
134    /// - Signature (if available)
135    /// - Docstring (if available)
136    /// - Content
137    fn prepare_text(&self, chunk: &Chunk) -> String {
138        let mut parts = Vec::new();
139
140        // Add context breadcrumbs (e.g., "File: main.rs > Class: Server")
141        if !chunk.context.is_empty() {
142            let context = chunk.context.join(" > ");
143            parts.push(format!("Context: {}", context));
144        }
145
146        // Add signature if available (e.g., "fn process(data: Vec<T>) -> Result<T>")
147        if let Some(sig) = &chunk.signature {
148            parts.push(format!("Signature: {}", sig));
149
150            // Extract function/struct name from signature for better searchability
151            // e.g., "fn handle_file_modified" -> "handle_file_modified"
152            if let Some(name) = sig.split_whitespace().nth(1) {
153                // Remove generic parameters and return type
154                let name = name
155                    .split('<')
156                    .next()
157                    .unwrap_or(name)
158                    .split('(')
159                    .next()
160                    .unwrap_or(name)
161                    .split('{')
162                    .next()
163                    .unwrap_or(name);
164                parts.push(format!("Name: {}", name));
165            }
166        }
167
168        // Add docstring if available
169        if let Some(doc) = &chunk.docstring {
170            // Clean up docstring
171            let cleaned = clean_docstring(doc);
172            if !cleaned.is_empty() {
173                parts.push(format!("Documentation: {}", cleaned));
174            }
175        }
176
177        // Add main content
178        parts.push(format!("Code:\n{}", chunk.content));
179
180        parts.join("\n")
181    }
182
183    /// Get embedding dimensions
184    pub fn dimensions(&self) -> usize {
185        self.embedder.lock().unwrap().dimensions()
186    }
187
188    /// Get embedder (locks mutex and returns copy of embedder for reading)
189    #[allow(dead_code)] // Reserved for diagnostics
190    pub fn embedder_info(&self) -> (String, usize) {
191        let embedder = self.embedder.lock().unwrap();
192        (embedder.model_name().to_string(), embedder.dimensions())
193    }
194}
195
/// Clean a raw docstring by removing comment markers and joining its lines.
///
/// Each line is processed as follows:
/// - Surrounding whitespace is trimmed.
/// - A trailing `*/` is removed.
/// - At most one leading marker is removed. The first match wins, checked in
///   the order `///`, `//!`, `//`, `/**`, `/*`, `*`, `"`.
///
/// Lines left empty are dropped and the rest are joined with single spaces.
/// Finally, one trailing `"` is removed.
///
/// Only one quote is stripped at each end, so a Python `"""text"""`
/// docstring comes out as `""text""`.
fn clean_docstring(doc: &str) -> String {
    // Strip the comment decoration from a single docstring line.
    fn strip_markers(line: &str) -> &str {
        const PREFIXES: [&str; 7] = ["///", "//!", "//", "/**", "/*", "*", "\""];

        let line = line.trim();
        // Remove a closing `*/` before the prefixes so that single-line block
        // comments such as `/** Returns the sum. */` don't leak the terminator
        // into the text. A bare `*/` line becomes empty and is dropped.
        let line = line.strip_suffix("*/").map_or(line, str::trim_end);
        PREFIXES
            .iter()
            .find_map(|prefix| line.strip_prefix(*prefix))
            .unwrap_or(line)
            .trim()
    }

    let joined = doc
        .lines()
        .map(strip_markers)
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join(" ");

    // Strip trailing quote if present (for Python-style docstrings)
    joined
        .strip_suffix('"')
        .unwrap_or(joined.as_str())
        .trim()
        .to_string()
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::ChunkKind;

    #[test]
    fn test_embedding_stats() {
        let stats = EmbeddingStats {
            total_chunks: 100,
            embedded_chunks: 80,
            cached_chunks: 20,
            failed_chunks: 0,
            total_time_ms: 1000,
        };

        assert_eq!(stats.cache_hit_rate(), 0.2);
        assert_eq!(stats.success_rate(), 0.8);
        assert_eq!(stats.chunks_per_second(), 80.0);
    }

    #[test]
    fn test_embedding_stats_empty() {
        // Every rate helper must return 0.0 rather than dividing by zero.
        let stats = EmbeddingStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);
        assert_eq!(stats.success_rate(), 0.0);
        assert_eq!(stats.chunks_per_second(), 0.0);
    }

    #[test]
    fn test_clean_docstring() {
        let rust_doc = "/// This is a doc comment\n/// with multiple lines";
        assert_eq!(
            clean_docstring(rust_doc),
            "This is a doc comment with multiple lines"
        );

        // Python docstrings with triple quotes: only one quote is stripped at
        // each end, so two remain on both sides.
        let python_doc = "\"\"\"This is a Python docstring\"\"\"";
        assert_eq!(
            clean_docstring(python_doc),
            "\"\"This is a Python docstring\"\""
        );

        let jsdoc = "/**\n * JSDoc comment\n * with multiple lines\n */";
        assert_eq!(clean_docstring(jsdoc), "JSDoc comment with multiple lines");

        // Test with quotes
        let quoted_doc = "\"This is a quoted docstring\"";
        assert_eq!(clean_docstring(quoted_doc), "This is a quoted docstring");
    }

    #[test]
    fn test_prepare_text() {
        // Point fastembed at a temporary cache directory so the test does not
        // create `.fastembed_cache` in the project root.
        let temp_dir = std::env::temp_dir().join("codesearch_test_cache");
        std::fs::create_dir_all(&temp_dir).ok();
        std::env::set_var(
            "FASTEMBED_CACHE_DIR",
            temp_dir.to_string_lossy().to_string(),
        );

        // `prepare_text` is only reachable through a `BatchEmbedder`, which
        // needs a real model. There is no mock, so a model that fails to load
        // fails this test.
        let embedder = Arc::new(Mutex::new(
            FastEmbedder::new().expect("Cannot create embedder in test"),
        ));

        let batch = BatchEmbedder::new(embedder);

        let mut chunk = Chunk::new(
            "fn test() { println!(\"test\"); }".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );
        chunk.context = vec!["File: test.rs".to_string(), "Function: test".to_string()];
        chunk.signature = Some("fn test()".to_string());
        chunk.docstring = Some("/// Test function".to_string());

        let text = batch.prepare_text(&chunk);

        assert!(text.contains("Context: File: test.rs > Function: test"));
        assert!(text.contains("Signature: fn test()"));
        assert!(text.contains("Name: test"));
        assert!(text.contains("Documentation: Test function"));
        assert!(text.contains("Code:"));

        // Clean up temp cache
        let _ = std::fs::remove_dir_all(temp_dir);
        std::env::remove_var("FASTEMBED_CACHE_DIR");
    }

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if mag_a == 0.0 || mag_b == 0.0 {
            return 0.0;
        }
        dot / (mag_a * mag_b)
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);

        let e = vec![1.0, 1.0, 0.0];
        let f = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&e, &f);
        assert!(sim > 0.7 && sim < 0.72); // Should be ~1/sqrt(2)
    }

    #[test]
    #[ignore] // Requires model
    fn test_batch_embedder() {
        let embedder = Arc::new(Mutex::new(FastEmbedder::new().unwrap()));
        let mut batch = BatchEmbedder::new(embedder);

        let chunks = vec![
            Chunk::new(
                "fn main() {}".to_string(),
                0,
                1,
                ChunkKind::Function,
                "test.rs".to_string(),
            ),
            Chunk::new(
                "struct Point { x: i32, y: i32 }".to_string(),
                2,
                3,
                ChunkKind::Struct,
                "test.rs".to_string(),
            ),
        ];

        let embedded = batch.embed_chunks(chunks).unwrap();
        assert_eq!(embedded.len(), 2);

        // Results must come back in input order.
        assert_eq!(embedded[0].chunk.content, "fn main() {}");
        assert_eq!(embedded[1].chunk.content, "struct Point { x: i32, y: i32 }");

        for emb_chunk in &embedded {
            assert_eq!(emb_chunk.embedding.len(), 384);
        }
    }
}