1use crate::db::IndexDb;
2use crate::embedding_store::{EmbeddingChunk, ScoredChunk};
3use crate::project::ProjectRoot;
4use anyhow::{Context, Result};
5use fastembed::TextEmbedding;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use super::cache::{
10 ReusableEmbeddingKey, TextEmbeddingCache, reusable_embedding_key_for_chunk,
11 reusable_embedding_key_for_symbol,
12};
13use super::chunk_ops::{
14 CategoryScore, DuplicatePair, OutlierSymbol, StoredChunkKey, cosine_similarity,
15 duplicate_candidate_limit, duplicate_pair_key, stored_chunk_key, stored_chunk_key_for_score,
16};
17use super::ffi;
18use super::prompt::{
19 build_embedding_text, extract_leading_doc, is_test_only_symbol, split_identifier,
20};
21use super::runtime::{configured_rerank_blend, embed_batch_size, max_embed_symbols};
22use super::vec_store::SqliteVecStore;
23use super::{
24 CHANGED_FILE_QUERY_CHUNK, DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, EmbeddingEngine,
25 EmbeddingIndexInfo, EmbeddingRuntimeInfo, SemanticMatch,
26};
27use rusqlite::Connection;
28
29impl EmbeddingEngine {
30 fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
31 if texts.is_empty() {
32 return Ok(Vec::new());
33 }
34
35 let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
36 let mut missing_order: Vec<String> = Vec::new();
37 let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();
38
39 {
40 let mut cache = self
41 .text_embed_cache
42 .lock()
43 .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
44 for (index, text) in texts.iter().enumerate() {
45 if let Some(cached) = cache.get(text) {
46 resolved[index] = Some(cached);
47 } else {
48 let key = (*text).to_owned();
49 if !missing_positions.contains_key(&key) {
50 missing_order.push(key.clone());
51 }
52 missing_positions.entry(key).or_default().push(index);
53 }
54 }
55 }
56
57 if !missing_order.is_empty() {
58 let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
59 let embeddings = self
60 .model
61 .lock()
62 .map_err(|_| anyhow::anyhow!("model lock"))?
63 .embed(missing_refs, None)
64 .context("text embedding failed")?;
65
66 let mut cache = self
67 .text_embed_cache
68 .lock()
69 .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
70 for (text, embedding) in missing_order.into_iter().zip(embeddings.into_iter()) {
71 cache.insert(text.clone(), embedding.clone());
72 if let Some(indices) = missing_positions.remove(&text) {
73 for index in indices {
74 resolved[index] = Some(embedding.clone());
75 }
76 }
77 }
78 }
79
80 resolved
81 .into_iter()
82 .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
83 .collect()
84 }
85
86 pub fn new(project: &ProjectRoot) -> Result<Self> {
87 let (model, dimension, model_name, runtime_info) = super::runtime::load_codesearch_model()?;
88
89 let db_dir = project.as_path().join(".codelens/index");
90 std::fs::create_dir_all(&db_dir)?;
91 let db_path = db_dir.join("embeddings.db");
92
93 let store = SqliteVecStore::new(&db_path, dimension, &model_name)?;
94
95 Ok(Self {
96 model: std::sync::Mutex::new(model),
97 store,
98 model_name,
99 runtime_info,
100 text_embed_cache: std::sync::Mutex::new(TextEmbeddingCache::new(
101 super::runtime::configured_embedding_text_cache_size(),
102 )),
103 indexing: std::sync::atomic::AtomicBool::new(false),
104 })
105 }
106
    /// Name of the embedding model backing this engine.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
110
    /// Runtime details captured when the model was loaded.
    pub fn runtime_info(&self) -> &EmbeddingRuntimeInfo {
        &self.runtime_info
    }
114
    /// Whether a full indexing run is currently in progress.
    /// Relaxed load: the answer may be momentarily stale.
    pub fn is_indexing(&self) -> bool {
        self.indexing.load(std::sync::atomic::Ordering::Relaxed)
    }
124
    /// (Re)index every symbol in `project`, reusing stored embeddings whose
    /// source text has not changed.
    ///
    /// Only one run may be active at a time: a compare-and-swap on `indexing`
    /// guards entry, and a drop guard clears the flag on every exit path.
    /// Files no longer present in the symbol DB have their embeddings purged.
    /// The run is "capped" (stops embedding new files) once processing the
    /// next file would exceed `max_embed_symbols()`.
    ///
    /// Returns the number of symbols (re)written to the vector store.
    ///
    /// # Errors
    /// Fails if another run is in progress, or on DB/model/store errors.
    pub fn index_from_project(&self, project: &ProjectRoot) -> Result<usize> {
        // Atomically claim the indexing flag; losing the race means another
        // run is already active.
        if self
            .indexing
            .compare_exchange(
                false,
                true,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            anyhow::bail!(
                "Embedding indexing already in progress — wait for the current run to complete before retrying."
            );
        }
        // RAII guard: clears the flag even when we return early with an error.
        struct IndexGuard<'a>(&'a std::sync::atomic::AtomicBool);
        impl Drop for IndexGuard<'_> {
            fn drop(&mut self) {
                self.0.store(false, std::sync::atomic::Ordering::Release);
            }
        }
        let _guard = IndexGuard(&self.indexing);

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;
        let batch_size = embed_batch_size();
        let max_symbols = max_embed_symbols();
        let mut total_indexed = 0usize;
        let mut total_seen = 0usize;
        // Model mutex guard, acquired lazily on the first real embed and then
        // held for the remainder of the run (threaded through reconcile).
        let mut model = None;
        // file path -> (reusable key -> stored chunk): embeddings already in
        // the store, consumed file-by-file during reconciliation below.
        let mut existing_embeddings: HashMap<
            String,
            HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        > = HashMap::new();
        let mut current_db_files = HashSet::new();
        let mut capped = false;

        self.store
            .for_each_file_embeddings(&mut |file_path, chunks| {
                existing_embeddings.insert(
                    file_path,
                    chunks
                        .into_iter()
                        .map(|chunk| (reusable_embedding_key_for_chunk(&chunk), chunk))
                        .collect(),
                );
                Ok(())
            })?;

        symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
            // Record the file even after the cap, so its stored embeddings are
            // not mistaken for those of a removed file at the end.
            current_db_files.insert(file_path.clone());
            if capped {
                return Ok(());
            }

            let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
            let relevant_symbols: Vec<_> = symbols
                .into_iter()
                .filter(|sym| !is_test_only_symbol(sym, source.as_deref()))
                .collect();

            // A file with only test symbols is fully de-indexed.
            if relevant_symbols.is_empty() {
                self.store.delete_by_file(&[file_path.as_str()])?;
                existing_embeddings.remove(&file_path);
                return Ok(());
            }

            // Stop before exceeding the symbol budget; files are all-or-nothing.
            if total_seen + relevant_symbols.len() > max_symbols {
                capped = true;
                return Ok(());
            }
            total_seen += relevant_symbols.len();

            let existing_for_file = existing_embeddings.remove(&file_path).unwrap_or_default();
            total_indexed += self.reconcile_file_embeddings(
                &file_path,
                relevant_symbols,
                source.as_deref(),
                existing_for_file,
                batch_size,
                &mut model,
            )?;
            Ok(())
        })?;

        // Whatever remains in `existing_embeddings` and is absent from the
        // symbol DB belongs to deleted files — purge those rows.
        let removed_files: Vec<String> = existing_embeddings
            .into_keys()
            .filter(|file_path| !current_db_files.contains(file_path))
            .collect();
        if !removed_files.is_empty() {
            let removed_refs: Vec<&str> = removed_files.iter().map(String::as_str).collect();
            self.store.delete_by_file(&removed_refs)?;
        }

        Ok(total_indexed)
    }
223
224 pub fn generate_bridge_candidates(
228 &self,
229 project: &ProjectRoot,
230 ) -> Result<Vec<(String, String)>> {
231 let db_path = crate::db::index_db_path(project.as_path());
232 let symbol_db = IndexDb::open(&db_path)?;
233 let mut bridges: Vec<(String, String)> = Vec::new();
234 let mut seen_nl = HashSet::new();
235
236 symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
237 let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
238 for sym in &symbols {
239 if is_test_only_symbol(sym, source.as_deref()) {
240 continue;
241 }
242 let doc = source.as_deref().and_then(|src| {
243 extract_leading_doc(src, sym.start_byte as usize, sym.end_byte as usize)
244 });
245 let doc = match doc {
246 Some(d) if !d.is_empty() => d,
247 _ => continue,
248 };
249
250 let split = split_identifier(&sym.name);
252 let code_term = if split != sym.name {
253 format!("{} {}", sym.name, split)
254 } else {
255 sym.name.clone()
256 };
257
258 let first_line = doc.lines().next().unwrap_or("").trim().to_lowercase();
262 let clean = first_line.trim_end_matches(|c: char| c.is_ascii_punctuation());
264 let words: Vec<&str> = clean.split_whitespace().collect();
265 if words.len() < 2 {
266 continue;
267 }
268
269 for window in 2..=words.len().min(4) {
271 let key = words[..window].join(" ");
272 if key.len() < 5 || key.len() > 60 {
273 continue;
274 }
275 if seen_nl.insert(key.clone()) {
276 bridges.push((key, code_term.clone()));
277 }
278 }
279
280 if split != sym.name && !seen_nl.contains(&split.to_lowercase()) {
283 let lowered = split.to_lowercase();
284 if lowered.split_whitespace().count() >= 2 && seen_nl.insert(lowered.clone()) {
285 bridges.push((lowered, code_term.clone()));
286 }
287 }
288 }
289 Ok(())
290 })?;
291
292 Ok(bridges)
293 }
294
    /// Re-embed one file's symbols, reusing entries from
    /// `existing_embeddings` whose reusable key (symbol identity + embedding
    /// text) still matches.
    ///
    /// Symbols without a reusable entry are embedded in batches of
    /// `batch_size`. The model mutex guard in `model` is acquired lazily the
    /// first time a batch must actually be embedded, and stays locked for the
    /// caller's whole indexing run. The file's rows in the store are replaced
    /// wholesale (delete, then insert).
    ///
    /// Returns the number of chunks inserted for this file.
    fn reconcile_file_embeddings<'a>(
        &'a self,
        file_path: &str,
        symbols: Vec<crate::db::SymbolWithFile>,
        source: Option<&str>,
        mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        batch_size: usize,
        model: &mut Option<std::sync::MutexGuard<'a, TextEmbedding>>,
    ) -> Result<usize> {
        let mut reconciled_chunks = Vec::with_capacity(symbols.len());
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);

        for sym in symbols {
            let text = build_embedding_text(&sym, source);
            // Reuse hit: carry stored vectors over with refreshed metadata.
            if let Some(existing) =
                existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
            {
                reconciled_chunks.push(EmbeddingChunk {
                    file_path: sym.file_path.clone(),
                    symbol_name: sym.name.clone(),
                    kind: sym.kind.clone(),
                    line: sym.line as usize,
                    signature: sym.signature.clone(),
                    name_path: sym.name_path.clone(),
                    text,
                    embedding: existing.embedding,
                    doc_embedding: existing.doc_embedding,
                });
                continue;
            }

            batch_texts.push(text);
            batch_meta.push(sym);

            if batch_texts.len() >= batch_size {
                // Lock the model only when the first real embedding is needed.
                if model.is_none() {
                    *model = Some(
                        self.model
                            .lock()
                            .map_err(|_| anyhow::anyhow!("model lock"))?,
                    );
                }
                reconciled_chunks.extend(Self::embed_chunks(
                    model.as_mut().expect("model lock initialized"),
                    &batch_texts,
                    &batch_meta,
                )?);
                batch_texts.clear();
                batch_meta.clear();
            }
        }

        // Flush the final partial batch, if any.
        if !batch_texts.is_empty() {
            if model.is_none() {
                *model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            reconciled_chunks.extend(Self::embed_chunks(
                model.as_mut().expect("model lock initialized"),
                &batch_texts,
                &batch_meta,
            )?);
        }

        // Replace this file's rows wholesale.
        self.store.delete_by_file(&[file_path])?;
        if reconciled_chunks.is_empty() {
            return Ok(0);
        }
        self.store.insert(&reconciled_chunks)
    }
369
370 fn embed_chunks(
371 model: &mut TextEmbedding,
372 texts: &[String],
373 meta: &[crate::db::SymbolWithFile],
374 ) -> Result<Vec<EmbeddingChunk>> {
375 let batch_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
376 let embeddings = model.embed(batch_refs, None).context("embedding failed")?;
377
378 Ok(meta
379 .iter()
380 .zip(embeddings)
381 .zip(texts.iter())
382 .map(|((sym, emb), text)| EmbeddingChunk {
383 file_path: sym.file_path.clone(),
384 symbol_name: sym.name.clone(),
385 kind: sym.kind.clone(),
386 line: sym.line as usize,
387 signature: sym.signature.clone(),
388 name_path: sym.name_path.clone(),
389 text: text.clone(),
390 embedding: emb,
391 doc_embedding: None,
392 })
393 .collect())
394 }
395
396 fn flush_batch(
398 model: &mut TextEmbedding,
399 store: &SqliteVecStore,
400 texts: &[String],
401 meta: &[crate::db::SymbolWithFile],
402 ) -> Result<usize> {
403 let chunks = Self::embed_chunks(model, texts, meta)?;
404 store.insert(&chunks)
405 }
406
407 pub fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticMatch>> {
409 let results = self.search_scored(query, max_results)?;
410 Ok(results.into_iter().map(SemanticMatch::from).collect())
411 }
412
    /// Semantic search with a lightweight lexical rerank.
    ///
    /// Embeds `query` (through the text cache), over-fetches
    /// `max_results * factor` vector candidates (factor from
    /// `CODELENS_RERANK_FACTOR`, default 5), then blends each vector score
    /// with the fraction of query tokens found in the candidate's
    /// name/path/signature text using `configured_rerank_blend()`. Results
    /// are sorted by blended score and truncated to `max_results`.
    pub fn search_scored(&self, query: &str, max_results: usize) -> Result<Vec<ScoredChunk>> {
        let query_embedding = self.embed_texts_cached(&[query])?;

        // Defensive: no embedding means nothing to search with.
        if query_embedding.is_empty() {
            return Ok(Vec::new());
        }

        // Oversampling factor for the rerank pool; read per call so the env
        // var can be changed at runtime.
        let factor = std::env::var("CODELENS_RERANK_FACTOR")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .unwrap_or(5);
        let candidate_count = max_results.saturating_mul(factor).max(max_results);
        let mut candidates = self.store.search(&query_embedding[0], candidate_count)?;

        // Pool already fits: nothing to rerank.
        if candidates.len() <= max_results {
            return Ok(candidates);
        }

        // Tokenize the query on whitespace/underscore/hyphen; 1-char tokens
        // are dropped as noise.
        let query_lower = query.to_lowercase();
        let query_tokens: Vec<&str> = query_lower
            .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
            .filter(|t| t.len() >= 2)
            .collect();

        if query_tokens.is_empty() {
            candidates.truncate(max_results);
            return Ok(candidates);
        }

        // Blend vector score with lexical token overlap.
        let blend = configured_rerank_blend();
        for chunk in &mut candidates {
            let split_name = split_identifier(&chunk.symbol_name);
            let searchable = format!(
                "{} {} {} {} {}",
                chunk.symbol_name.to_lowercase(),
                split_name.to_lowercase(),
                chunk.name_path.to_lowercase(),
                chunk.signature.to_lowercase(),
                chunk.file_path.to_lowercase(),
            );
            let overlap = query_tokens
                .iter()
                .filter(|t| searchable.contains(**t))
                .count() as f64;
            let overlap_ratio = overlap / query_tokens.len().max(1) as f64;
            // Per the formula: blend = 1.0 keeps the pure vector score,
            // blend = 0.0 uses pure lexical overlap.
            chunk.score = chunk.score * blend + overlap_ratio * (1.0 - blend);
        }

        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates.truncate(max_results);
        Ok(candidates)
    }
485
    /// Incrementally re-index only `changed_files`.
    ///
    /// Stored embeddings for those files are snapshotted (keyed by reusable
    /// key), then deleted, then re-inserted: unchanged symbols reuse their old
    /// vectors, changed ones are re-embedded in batches of `embed_batch_size`.
    /// File-keyed queries are chunked by `CHANGED_FILE_QUERY_CHUNK` to bound
    /// SQL parameter counts.
    ///
    /// Returns the number of chunks written to the store.
    pub fn index_changed_files(
        &self,
        project: &ProjectRoot,
        changed_files: &[&str],
    ) -> Result<usize> {
        if changed_files.is_empty() {
            return Ok(0);
        }
        let batch_size = embed_batch_size();
        // Snapshot current embeddings for the changed files so unchanged
        // symbols can skip the model entirely.
        let mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk> = HashMap::new();
        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            for chunk in self.store.embeddings_for_files(file_chunk)? {
                existing_embeddings.insert(reusable_embedding_key_for_chunk(&chunk), chunk);
            }
        }
        self.store.delete_by_file(changed_files)?;

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;

        let mut total_indexed = 0usize;
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
        let mut batch_reused: Vec<EmbeddingChunk> = Vec::with_capacity(batch_size);
        // Each file's source is read from disk at most once.
        let mut file_cache: std::collections::HashMap<String, Option<String>> =
            std::collections::HashMap::new();
        // Model mutex guard, acquired lazily on the first re-embed.
        let mut model = None;

        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            let relevant = symbol_db.symbols_for_files(file_chunk)?;
            for sym in relevant {
                let source = file_cache.entry(sym.file_path.clone()).or_insert_with(|| {
                    std::fs::read_to_string(project.as_path().join(&sym.file_path)).ok()
                });
                if is_test_only_symbol(&sym, source.as_deref()) {
                    continue;
                }
                let text = build_embedding_text(&sym, source.as_deref());
                // Reuse path: same symbol identity + same text -> keep vectors.
                if let Some(existing) =
                    existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
                {
                    batch_reused.push(EmbeddingChunk {
                        file_path: sym.file_path.clone(),
                        symbol_name: sym.name.clone(),
                        kind: sym.kind.clone(),
                        line: sym.line as usize,
                        signature: sym.signature.clone(),
                        name_path: sym.name_path.clone(),
                        text,
                        embedding: existing.embedding,
                        doc_embedding: existing.doc_embedding,
                    });
                    if batch_reused.len() >= batch_size {
                        total_indexed += self.store.insert(&batch_reused)?;
                        batch_reused.clear();
                    }
                    continue;
                }
                batch_texts.push(text);
                batch_meta.push(sym);

                if batch_texts.len() >= batch_size {
                    if model.is_none() {
                        model = Some(
                            self.model
                                .lock()
                                .map_err(|_| anyhow::anyhow!("model lock"))?,
                        );
                    }
                    total_indexed += Self::flush_batch(
                        model.as_mut().expect("model lock initialized"),
                        &self.store,
                        &batch_texts,
                        &batch_meta,
                    )?;
                    batch_texts.clear();
                    batch_meta.clear();
                }
            }
        }

        // Flush the partial reused batch, if any.
        if !batch_reused.is_empty() {
            total_indexed += self.store.insert(&batch_reused)?;
        }

        // Flush the partial embed batch, if any.
        if !batch_texts.is_empty() {
            if model.is_none() {
                model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            total_indexed += Self::flush_batch(
                model.as_mut().expect("model lock initialized"),
                &self.store,
                &batch_texts,
                &batch_meta,
            )?;
        }

        Ok(total_indexed)
    }
590
591 pub fn is_indexed(&self) -> bool {
593 self.store.count().unwrap_or(0) > 0
594 }
595
596 pub fn index_info(&self) -> EmbeddingIndexInfo {
597 EmbeddingIndexInfo {
598 model_name: self.model_name.clone(),
599 indexed_symbols: self.store.count().unwrap_or(0),
600 }
601 }
602
    /// Peek at an on-disk embedding index without constructing an engine
    /// (and without loading the model).
    ///
    /// Returns `Ok(None)` when no index file exists or the stored metadata
    /// has no model name; a failed symbol count reads as zero.
    pub fn inspect_existing_index(project: &ProjectRoot) -> Result<Option<EmbeddingIndexInfo>> {
        let db_path = project.as_path().join(".codelens/index/embeddings.db");
        if !db_path.exists() {
            return Ok(None);
        }

        // Open via the corruption-recovery wrapper; the closure is the raw
        // open recipe it retries after handling a damaged file.
        let conn =
            crate::db::open_derived_sqlite_with_recovery(&db_path, "embedding index", || {
                ffi::register_sqlite_vec()?;
                let conn = Connection::open(&db_path)?;
                conn.execute_batch("PRAGMA busy_timeout=5000;")?;
                // Touch the schema to surface corruption early.
                conn.query_row("PRAGMA schema_version", [], |_row| Ok(()))?;
                Ok(conn)
            })?;

        let model_name: Option<String> = conn
            .query_row(
                "SELECT value FROM meta WHERE key = 'model' LIMIT 1",
                [],
                |row| row.get(0),
            )
            .ok();
        let indexed_symbols: usize = conn
            .query_row("SELECT COUNT(*) FROM symbols", [], |row| {
                row.get::<_, i64>(0)
            })
            .map(|count| count.max(0) as usize)
            .unwrap_or(0);

        // No recorded model name -> treat the index as absent.
        Ok(model_name.map(|model_name| EmbeddingIndexInfo {
            model_name,
            indexed_symbols,
        }))
    }
637
638 pub fn find_similar_code(
642 &self,
643 file_path: &str,
644 symbol_name: &str,
645 max_results: usize,
646 ) -> Result<Vec<SemanticMatch>> {
647 let target = self
648 .store
649 .get_embedding(file_path, symbol_name)?
650 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?;
651
652 let oversample = max_results.saturating_add(8).max(1);
653 let scored = self
654 .store
655 .search(&target.embedding, oversample)?
656 .into_iter()
657 .filter(|c| !(c.file_path == file_path && c.symbol_name == symbol_name))
658 .take(max_results)
659 .map(SemanticMatch::from)
660 .collect();
661 Ok(scored)
662 }
663
    /// Find pairs of near-duplicate symbols across the whole index.
    ///
    /// Streams stored embeddings in batches of
    /// `DEFAULT_DUPLICATE_SCAN_BATCH_SIZE`. For each chunk, nearest
    /// neighbours (up to `duplicate_candidate_limit(max_pairs)`) are fetched,
    /// the chunk's own row filtered out, missing candidate embeddings
    /// bulk-loaded into a cache, and cosine similarity computed exactly.
    /// Pairs are deduplicated by an order-independent key, kept only when
    /// `similarity >= threshold`, capped at `max_pairs`, and returned sorted
    /// most-similar first.
    pub fn find_duplicates(&self, threshold: f64, max_pairs: usize) -> Result<Vec<DuplicatePair>> {
        let mut pairs = Vec::new();
        // Order-independent (a,b) keys already emitted.
        let mut seen_pairs = HashSet::new();
        // Candidate embeddings fetched so far; shared across batches.
        let mut embedding_cache: HashMap<StoredChunkKey, Arc<EmbeddingChunk>> = HashMap::new();
        let candidate_limit = duplicate_candidate_limit(max_pairs);
        // Once the pair cap is hit, remaining batches are drained cheaply.
        let mut done = false;

        self.store
            .for_each_embedding_batch(DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, &mut |batch| {
                if done {
                    return Ok(());
                }

                // Pass 1: collect neighbour lists and note which candidate
                // embeddings are not cached yet (deduped via `missing_keys`).
                let mut candidate_lists = Vec::with_capacity(batch.len());
                let mut missing_candidates = Vec::new();
                let mut missing_keys = HashSet::new();

                for chunk in &batch {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    // Drop the chunk's own row from its neighbour list.
                    let filtered: Vec<ScoredChunk> = self
                        .store
                        .search(&chunk.embedding, candidate_limit)?
                        .into_iter()
                        .filter(|candidate| {
                            !(chunk.file_path == candidate.file_path
                                && chunk.symbol_name == candidate.symbol_name
                                && chunk.line == candidate.line
                                && chunk.signature == candidate.signature
                                && chunk.name_path == candidate.name_path)
                        })
                        .collect();

                    for candidate in &filtered {
                        let cache_key = stored_chunk_key_for_score(candidate);
                        if !embedding_cache.contains_key(&cache_key)
                            && missing_keys.insert(cache_key)
                        {
                            missing_candidates.push(candidate.clone());
                        }
                    }

                    candidate_lists.push(filtered);
                }

                // Bulk-fetch all missing candidate embeddings in one query.
                if !missing_candidates.is_empty() {
                    for candidate_chunk in self
                        .store
                        .embeddings_for_scored_chunks(&missing_candidates)?
                    {
                        embedding_cache
                            .entry(stored_chunk_key(&candidate_chunk))
                            .or_insert_with(|| Arc::new(candidate_chunk));
                    }
                }

                // Pass 2: score each (chunk, candidate) pair exactly.
                for (chunk, candidates) in batch.iter().zip(candidate_lists.iter()) {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    for candidate in candidates {
                        let pair_key = duplicate_pair_key(
                            &chunk.file_path,
                            &chunk.symbol_name,
                            &candidate.file_path,
                            &candidate.symbol_name,
                        );
                        if !seen_pairs.insert(pair_key) {
                            continue;
                        }

                        // Cache miss means this candidate's embedding could
                        // not be loaded; skip it rather than fail the scan.
                        let Some(candidate_chunk) =
                            embedding_cache.get(&stored_chunk_key_for_score(candidate))
                        else {
                            continue;
                        };

                        let sim = cosine_similarity(&chunk.embedding, &candidate_chunk.embedding);
                        if sim < threshold {
                            continue;
                        }

                        pairs.push(DuplicatePair {
                            symbol_a: format!("{}:{}", chunk.file_path, chunk.symbol_name),
                            symbol_b: format!(
                                "{}:{}",
                                candidate_chunk.file_path, candidate_chunk.symbol_name
                            ),
                            file_a: chunk.file_path.clone(),
                            file_b: candidate_chunk.file_path.clone(),
                            line_a: chunk.line,
                            line_b: candidate_chunk.line,
                            similarity: sim,
                        });
                        if pairs.len() >= max_pairs {
                            done = true;
                            break;
                        }
                    }
                }
                Ok(())
            })?;

        pairs.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(pairs)
    }
781}
782
783impl EmbeddingEngine {
784 pub fn classify_symbol(
786 &self,
787 file_path: &str,
788 symbol_name: &str,
789 categories: &[&str],
790 ) -> Result<Vec<CategoryScore>> {
791 let target = match self.store.get_embedding(file_path, symbol_name)? {
792 Some(target) => target,
793 None => self
794 .store
795 .all_with_embeddings()?
796 .into_iter()
797 .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
798 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
799 };
800
801 let embeddings = self.embed_texts_cached(categories)?;
802
803 let mut scores: Vec<CategoryScore> = categories
804 .iter()
805 .zip(embeddings.iter())
806 .map(|(cat, emb)| CategoryScore {
807 category: cat.to_string(),
808 score: cosine_similarity(&target.embedding, emb),
809 })
810 .collect();
811
812 scores.sort_by(|a, b| {
813 b.score
814 .partial_cmp(&a.score)
815 .unwrap_or(std::cmp::Ordering::Equal)
816 });
817 Ok(scores)
818 }
819
820 pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
822 let mut outliers = Vec::new();
823
824 self.store
825 .for_each_file_embeddings(&mut |file_path, chunks| {
826 if chunks.len() < 2 {
827 return Ok(());
828 }
829
830 for (idx, chunk) in chunks.iter().enumerate() {
831 let mut sim_sum = 0.0;
832 let mut count = 0;
833 for (other_idx, other_chunk) in chunks.iter().enumerate() {
834 if other_idx == idx {
835 continue;
836 }
837 sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
838 count += 1;
839 }
840 if count > 0 {
841 let avg_sim = sim_sum / count as f64; outliers.push(OutlierSymbol {
843 file_path: file_path.clone(),
844 symbol_name: chunk.symbol_name.clone(),
845 kind: chunk.kind.clone(),
846 line: chunk.line,
847 avg_similarity_to_file: avg_sim,
848 });
849 }
850 }
851 Ok(())
852 })?;
853
854 outliers.sort_by(|a, b| {
855 a.avg_similarity_to_file
856 .partial_cmp(&b.avg_similarity_to_file)
857 .unwrap_or(std::cmp::Ordering::Equal)
858 });
859 outliers.truncate(max_results);
860 Ok(outliers)
861 }
862}