1use super::embedder::FastEmbedder;
2use crate::chunker::Chunk;
3use anyhow::Result;
4use std::sync::{Arc, Mutex};
5
/// Counters describing one embedding run, plus derived rate helpers.
///
/// The `Default` value has every counter at zero; all rate helpers return
/// `0.0` in that case instead of dividing by zero.
#[derive(Debug, Clone, Default)]
// Nothing in this file constructs the struct outside of tests, hence the allow.
#[allow(dead_code)]
pub struct EmbeddingStats {
    /// Number of chunks considered; denominator of [`EmbeddingStats::cache_hit_rate`]
    /// and [`EmbeddingStats::success_rate`].
    pub total_chunks: usize,
    /// Number of chunks that were embedded; numerator of
    /// [`EmbeddingStats::success_rate`] and [`EmbeddingStats::chunks_per_second`].
    pub embedded_chunks: usize,
    /// Number of chunks served from a cache; numerator of
    /// [`EmbeddingStats::cache_hit_rate`].
    pub cached_chunks: usize,
    /// Number of chunks that failed; not used by any helper in this block.
    pub failed_chunks: usize,
    /// Elapsed time in milliseconds; denominator of
    /// [`EmbeddingStats::chunks_per_second`].
    pub total_time_ms: u128,
}

impl EmbeddingStats {
    /// Returns `cached_chunks / total_chunks`, or `0.0` when `total_chunks` is zero.
    #[allow(dead_code)]
    pub fn cache_hit_rate(&self) -> f64 {
        match self.total_chunks {
            0 => 0.0,
            total => self.cached_chunks as f64 / total as f64,
        }
    }

    /// Returns `embedded_chunks / total_chunks`, or `0.0` when `total_chunks` is zero.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_chunks {
            0 => 0.0,
            total => self.embedded_chunks as f64 / total as f64,
        }
    }

    /// Returns embedded chunks per second, or `0.0` when `total_time_ms` is zero.
    #[allow(dead_code)]
    pub fn chunks_per_second(&self) -> f64 {
        match self.total_time_ms {
            0 => 0.0,
            // Per-millisecond rate first, then scaled to seconds.
            elapsed_ms => (self.embedded_chunks as f64 / elapsed_ms as f64) * 1000.0,
        }
    }
}
45
46#[derive(Debug, Clone)]
48pub struct EmbeddedChunk {
49 pub chunk: Chunk,
50 pub embedding: Vec<f32>,
51}
52
53impl EmbeddedChunk {
54 pub fn new(chunk: Chunk, embedding: Vec<f32>) -> Self {
55 Self { chunk, embedding }
56 }
57}
58
59pub struct BatchEmbedder {
61 pub embedder: Arc<Mutex<FastEmbedder>>,
62 batch_size: usize,
63}
64
65impl BatchEmbedder {
66 pub fn new(embedder: Arc<Mutex<FastEmbedder>>) -> Self {
68 Self {
69 embedder,
70 batch_size: 32, }
72 }
73
74 #[allow(dead_code)] pub fn with_batch_size(embedder: Arc<Mutex<FastEmbedder>>, batch_size: usize) -> Self {
77 Self {
78 embedder,
79 batch_size,
80 }
81 }
82
83 pub fn embed_chunks(&mut self, chunks: Vec<Chunk>) -> Result<Vec<EmbeddedChunk>> {
85 if chunks.is_empty() {
86 return Ok(Vec::new());
87 }
88
89 let total = chunks.len();
90 let _start = std::time::Instant::now();
91 let mut embedded_chunks = Vec::with_capacity(total);
92
93 for chunk_batch in chunks.chunks(self.batch_size) {
95 let texts: Vec<String> = chunk_batch
97 .iter()
98 .map(|chunk| self.prepare_text(chunk))
99 .collect();
100
101 let embeddings = self
103 .embedder
104 .lock()
105 .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
106 .embed_batch(texts)?;
107
108 for (chunk, embedding) in chunk_batch.iter().zip(embeddings.into_iter()) {
110 embedded_chunks.push(EmbeddedChunk::new(chunk.clone(), embedding));
111 }
112 }
113
114 Ok(embedded_chunks)
115 }
116
117 #[allow(dead_code)] pub fn embed_chunk(&mut self, chunk: Chunk) -> Result<EmbeddedChunk> {
120 let text = self.prepare_text(&chunk);
121 let embedding = self
122 .embedder
123 .lock()
124 .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
125 .embed_one(&text)?;
126 Ok(EmbeddedChunk::new(chunk, embedding))
127 }
128
129 fn prepare_text(&self, chunk: &Chunk) -> String {
138 let mut parts = Vec::new();
139
140 if !chunk.context.is_empty() {
142 let context = chunk.context.join(" > ");
143 parts.push(format!("Context: {}", context));
144 }
145
146 if let Some(sig) = &chunk.signature {
148 parts.push(format!("Signature: {}", sig));
149
150 if let Some(name) = sig.split_whitespace().nth(1) {
153 let name = name
155 .split('<')
156 .next()
157 .unwrap_or(name)
158 .split('(')
159 .next()
160 .unwrap_or(name)
161 .split('{')
162 .next()
163 .unwrap_or(name);
164 parts.push(format!("Name: {}", name));
165 }
166 }
167
168 if let Some(doc) = &chunk.docstring {
170 let cleaned = clean_docstring(doc);
172 if !cleaned.is_empty() {
173 parts.push(format!("Documentation: {}", cleaned));
174 }
175 }
176
177 parts.push(format!("Code:\n{}", chunk.content));
179
180 parts.join("\n")
181 }
182
183 pub fn dimensions(&self) -> usize {
185 self.embedder.lock().unwrap().dimensions()
186 }
187
188 #[allow(dead_code)] pub fn embedder_info(&self) -> (String, usize) {
191 let embedder = self.embedder.lock().unwrap();
192 (embedder.model_name().to_string(), embedder.dimensions())
193 }
194}
195
/// Strips comment and quote markers from a raw docstring and flattens it to one line.
///
/// Each line is trimmed and then handled as follows:
///
/// * Line comments: a leading `///`, `//!` or `//` is removed.
/// * Anything else: a trailing `*/` is removed, then the first matching leading
///   marker among `/**`, `/*`, `*` and `"` is removed.
///
/// Lines that end up empty are dropped and the rest are joined with single
/// spaces. Finally one trailing `"` is removed from the joined text.
///
/// Only a single leading quote per line and a single trailing quote overall are
/// stripped, so a triple-quoted string keeps its remaining quotes.
fn clean_docstring(doc: &str) -> String {
    let joined = doc
        .lines()
        .map(|line| {
            let trimmed = line.trim();

            let line_comment = trimmed
                .strip_prefix("///")
                .or_else(|| trimmed.strip_prefix("//!"))
                .or_else(|| trimmed.strip_prefix("//"));

            let stripped = match line_comment {
                Some(rest) => rest,
                None => {
                    // Drop a block-comment terminator wherever it closes the line,
                    // so `/** Foo */` and ` * last line */` do not leak `*/`.
                    let body = trimmed.strip_suffix("*/").unwrap_or(trimmed);
                    body.strip_prefix("/**")
                        .or_else(|| body.strip_prefix("/*"))
                        .or_else(|| body.strip_prefix('*'))
                        .or_else(|| body.strip_prefix('"'))
                        .unwrap_or(body)
                }
            };

            stripped.trim()
        })
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join(" ");

    joined
        .strip_suffix('"')
        .unwrap_or(joined.as_str())
        .trim()
        .to_string()
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::ChunkKind;

    /// Rates computed from non-zero counters.
    #[test]
    fn test_embedding_stats() {
        let stats = EmbeddingStats {
            total_chunks: 100,
            embedded_chunks: 80,
            cached_chunks: 20,
            failed_chunks: 0,
            total_time_ms: 1000,
        };

        assert_eq!(stats.cache_hit_rate(), 0.2);
        assert_eq!(stats.success_rate(), 0.8);
        assert_eq!(stats.chunks_per_second(), 80.0);
    }

    /// Every rate helper must return 0.0 instead of dividing by zero.
    #[test]
    fn test_embedding_stats_zero_denominators() {
        let stats = EmbeddingStats {
            embedded_chunks: 3,
            cached_chunks: 2,
            ..Default::default()
        };

        assert_eq!(stats.cache_hit_rate(), 0.0);
        assert_eq!(stats.success_rate(), 0.0);
        assert_eq!(stats.chunks_per_second(), 0.0);
    }

    #[test]
    fn test_clean_docstring() {
        let rust_doc = "/// This is a doc comment\n/// with multiple lines";
        assert_eq!(
            clean_docstring(rust_doc),
            "This is a doc comment with multiple lines"
        );

        // Only one quote is stripped from each end, so a triple-quoted string
        // keeps two quotes on both sides.
        let python_doc = "\"\"\"This is a Python docstring\"\"\"";
        assert_eq!(
            clean_docstring(python_doc),
            "\"\"This is a Python docstring\"\""
        );

        let jsdoc = "/**\n * JSDoc comment\n * with multiple lines\n */";
        assert_eq!(clean_docstring(jsdoc), "JSDoc comment with multiple lines");

        let quoted_doc = "\"This is a quoted docstring\"";
        assert_eq!(clean_docstring(quoted_doc), "This is a quoted docstring");
    }

    #[test]
    fn test_prepare_text() {
        // Point the embedder cache at a throwaway directory.
        // NOTE(review): presumably `FastEmbedder::new` reads FASTEMBED_CACHE_DIR;
        // confirm. The variable is process-global, so this can interfere with
        // other tests running in parallel.
        let temp_dir = std::env::temp_dir().join("codesearch_test_cache");
        std::fs::create_dir_all(&temp_dir).ok();
        std::env::set_var(
            "FASTEMBED_CACHE_DIR",
            temp_dir.to_string_lossy().to_string(),
        );

        // `expect` keeps the underlying error in the panic message.
        let embedder = Arc::new(Mutex::new(
            FastEmbedder::new().expect("Cannot create embedder in test"),
        ));

        let batch = BatchEmbedder::new(embedder);

        let mut chunk = Chunk::new(
            "fn test() { println!(\"test\"); }".to_string(),
            0,
            1,
            ChunkKind::Function,
            "test.rs".to_string(),
        );
        chunk.context = vec!["File: test.rs".to_string(), "Function: test".to_string()];
        chunk.signature = Some("fn test()".to_string());
        chunk.docstring = Some("/// Test function".to_string());

        let text = batch.prepare_text(&chunk);

        // Clean up before asserting so a failed assertion does not leave the
        // temp directory or the environment variable behind.
        let _ = std::fs::remove_dir_all(temp_dir);
        std::env::remove_var("FASTEMBED_CACHE_DIR");

        assert!(text.contains("Context: File: test.rs > Function: test"));
        assert!(text.contains("Signature: fn test()"));
        assert!(text.contains("Documentation: Test function"));
        assert!(text.contains("Code:"));
    }

    /// Cosine similarity of two vectors; returns 0.0 if either has zero magnitude.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if mag_a == 0.0 || mag_b == 0.0 {
            return 0.0;
        }
        dot / (mag_a * mag_b)
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);

        // 45 degrees apart: cos = 1 / sqrt(2), roughly 0.707.
        let e = vec![1.0, 1.0, 0.0];
        let f = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&e, &f);
        assert!(sim > 0.7 && sim < 0.72);
    }

    // Ignored by default; run explicitly with `cargo test -- --ignored`.
    // NOTE(review): presumably ignored because it needs the real embedding
    // model to be available; confirm.
    #[test]
    #[ignore]
    fn test_batch_embedder() {
        let embedder = Arc::new(Mutex::new(FastEmbedder::new().unwrap()));
        let mut batch = BatchEmbedder::new(embedder);

        let chunks = vec![
            Chunk::new(
                "fn main() {}".to_string(),
                0,
                1,
                ChunkKind::Function,
                "test.rs".to_string(),
            ),
            Chunk::new(
                "struct Point { x: i32, y: i32 }".to_string(),
                2,
                3,
                ChunkKind::Struct,
                "test.rs".to_string(),
            ),
        ];

        let embedded = batch.embed_chunks(chunks).unwrap();
        assert_eq!(embedded.len(), 2);

        for emb_chunk in &embedded {
            assert_eq!(emb_chunk.embedding.len(), 384);
        }
    }
}