Skip to main content

codelens_engine/embedding/
engine_impl.rs

1use crate::db::IndexDb;
2use crate::embedding_store::{EmbeddingChunk, ScoredChunk};
3use crate::project::ProjectRoot;
4use anyhow::{Context, Result};
5use fastembed::TextEmbedding;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use super::cache::{
10    reusable_embedding_key_for_chunk, reusable_embedding_key_for_symbol, ReusableEmbeddingKey,
11    TextEmbeddingCache,
12};
13use super::chunk_ops::{
14    cosine_similarity, duplicate_candidate_limit, duplicate_pair_key, stored_chunk_key,
15    stored_chunk_key_for_score, CategoryScore, DuplicatePair, OutlierSymbol, StoredChunkKey,
16};
17use super::ffi;
18use super::prompt::{
19    build_embedding_text, extract_leading_doc, is_test_only_symbol, split_identifier,
20};
21use super::runtime::{configured_rerank_blend, embed_batch_size, max_embed_symbols};
22use super::vec_store::SqliteVecStore;
23use super::{
24    EmbeddingEngine, EmbeddingIndexInfo, EmbeddingRuntimeInfo, SemanticMatch,
25    CHANGED_FILE_QUERY_CHUNK, DEFAULT_DUPLICATE_SCAN_BATCH_SIZE,
26};
27use rusqlite::Connection;
28
29impl EmbeddingEngine {
30    fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
31        if texts.is_empty() {
32            return Ok(Vec::new());
33        }
34
35        let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
36        let mut missing_order: Vec<String> = Vec::new();
37        let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();
38
39        {
40            let mut cache = self
41                .text_embed_cache
42                .lock()
43                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
44            for (index, text) in texts.iter().enumerate() {
45                if let Some(cached) = cache.get(text) {
46                    resolved[index] = Some(cached);
47                } else {
48                    let key = (*text).to_owned();
49                    if !missing_positions.contains_key(&key) {
50                        missing_order.push(key.clone());
51                    }
52                    missing_positions.entry(key).or_default().push(index);
53                }
54            }
55        }
56
57        if !missing_order.is_empty() {
58            let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
59            let embeddings = self
60                .model
61                .lock()
62                .map_err(|_| anyhow::anyhow!("model lock"))?
63                .embed(missing_refs, None)
64                .context("text embedding failed")?;
65
66            let mut cache = self
67                .text_embed_cache
68                .lock()
69                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
70            for (text, embedding) in missing_order.into_iter().zip(embeddings.into_iter()) {
71                cache.insert(text.clone(), embedding.clone());
72                if let Some(indices) = missing_positions.remove(&text) {
73                    for index in indices {
74                        resolved[index] = Some(embedding.clone());
75                    }
76                }
77            }
78        }
79
80        resolved
81            .into_iter()
82            .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
83            .collect()
84    }
85
    /// Create an engine for `project`: load the code-search embedding model
    /// and open (or create) the sqlite-vec store under `.codelens/index/embeddings.db`.
    pub fn new(project: &ProjectRoot) -> Result<Self> {
        let (model, dimension, model_name, runtime_info) = super::runtime::load_codesearch_model()?;

        // The vector store lives alongside the project's other derived indexes.
        let db_dir = project.as_path().join(".codelens/index");
        std::fs::create_dir_all(&db_dir)?;
        let db_path = db_dir.join("embeddings.db");

        let store = SqliteVecStore::new(&db_path, dimension, &model_name)?;

        Ok(Self {
            model: std::sync::Mutex::new(model),
            store,
            model_name,
            runtime_info,
            // Bounded cache of text embeddings; capacity is configurable via
            // the runtime module.
            text_embed_cache: std::sync::Mutex::new(TextEmbeddingCache::new(
                super::runtime::configured_embedding_text_cache_size(),
            )),
            // Set while a full reindex runs; guards against concurrent runs.
            indexing: std::sync::atomic::AtomicBool::new(false),
        })
    }
106
    /// Name of the loaded embedding model.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
110
    /// Runtime information captured when the embedding model was loaded.
    pub fn runtime_info(&self) -> &EmbeddingRuntimeInfo {
        &self.runtime_info
    }
114
115    /// Index all symbols from the project's symbol database into the embedding index.
116    ///
117    /// Reconciles the embedding store file-by-file so unchanged symbols can
118    /// reuse their existing vectors and only changed/new symbols are re-embedded.
119    /// Caps at a configurable max to prevent runaway on huge projects.
120    /// Returns true if a full reindex is currently in progress.
121    pub fn is_indexing(&self) -> bool {
122        self.indexing.load(std::sync::atomic::Ordering::Relaxed)
123    }
124
    /// Index all symbols from the project's symbol database into the embedding index.
    ///
    /// Reconciles the embedding store file-by-file so unchanged symbols can
    /// reuse their existing vectors and only changed/new symbols are re-embedded.
    /// Caps at a configurable max to prevent runaway on huge projects.
    /// Returns the number of symbols written to the store.
    pub fn index_from_project(&self, project: &ProjectRoot) -> Result<usize> {
        // Guard against concurrent full reindex (14s+ operation)
        if self
            .indexing
            .compare_exchange(
                false,
                true,
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            anyhow::bail!(
                "Embedding indexing already in progress — wait for the current run to complete before retrying."
            );
        }
        // RAII guard to reset the flag on any exit path
        struct IndexGuard<'a>(&'a std::sync::atomic::AtomicBool);
        impl Drop for IndexGuard<'_> {
            fn drop(&mut self) {
                self.0.store(false, std::sync::atomic::Ordering::Release);
            }
        }
        let _guard = IndexGuard(&self.indexing);

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;
        let batch_size = embed_batch_size();
        let max_symbols = max_embed_symbols();
        let mut total_indexed = 0usize;
        let mut total_seen = 0usize;
        // Lazily-acquired model lock, shared across per-file reconciles.
        let mut model = None;
        let mut existing_embeddings: HashMap<
            String,
            HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        > = HashMap::new();
        let mut current_db_files = HashSet::new();
        let mut capped = false;

        // Phase 1: snapshot everything already in the store, keyed by file
        // and then by reusable-embedding key, so unchanged vectors can be reused.
        self.store
            .for_each_file_embeddings(&mut |file_path, chunks| {
                existing_embeddings.insert(
                    file_path,
                    chunks
                        .into_iter()
                        .map(|chunk| (reusable_embedding_key_for_chunk(&chunk), chunk))
                        .collect(),
                );
                Ok(())
            })?;

        // Phase 2: walk the symbol DB file-by-file, reconciling each file
        // against the snapshot. Once the symbol cap is hit we stop embedding
        // but keep recording file names so phase 3 does not delete them.
        symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
            current_db_files.insert(file_path.clone());
            if capped {
                return Ok(());
            }

            let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
            let relevant_symbols: Vec<_> = symbols
                .into_iter()
                .filter(|sym| !is_test_only_symbol(sym, source.as_deref()))
                .collect();

            // A file with only test symbols contributes nothing; purge it.
            if relevant_symbols.is_empty() {
                self.store.delete_by_file(&[file_path.as_str()])?;
                existing_embeddings.remove(&file_path);
                return Ok(());
            }

            if total_seen + relevant_symbols.len() > max_symbols {
                capped = true;
                return Ok(());
            }
            total_seen += relevant_symbols.len();

            let existing_for_file = existing_embeddings.remove(&file_path).unwrap_or_default();
            total_indexed += self.reconcile_file_embeddings(
                &file_path,
                relevant_symbols,
                source.as_deref(),
                existing_for_file,
                batch_size,
                &mut model,
            )?;
            Ok(())
        })?;

        // Phase 3: files still left in the snapshot and absent from the symbol
        // DB were deleted from the project — drop their embeddings.
        let removed_files: Vec<String> = existing_embeddings
            .into_keys()
            .filter(|file_path| !current_db_files.contains(file_path))
            .collect();
        if !removed_files.is_empty() {
            let removed_refs: Vec<&str> = removed_files.iter().map(String::as_str).collect();
            self.store.delete_by_file(&removed_refs)?;
        }

        Ok(total_indexed)
    }
223
224    /// Extract NL→code bridge candidates from indexed symbols.
225    /// For each symbol with a docstring, produces a (docstring_first_line, symbol_name) pair.
226    /// The caller writes these to `.codelens/bridges.json` for project-specific NL bridging.
227    pub fn generate_bridge_candidates(
228        &self,
229        project: &ProjectRoot,
230    ) -> Result<Vec<(String, String)>> {
231        let db_path = crate::db::index_db_path(project.as_path());
232        let symbol_db = IndexDb::open(&db_path)?;
233        let mut bridges: Vec<(String, String)> = Vec::new();
234        let mut seen_nl = HashSet::new();
235
236        symbol_db.for_each_file_symbols_with_bytes(|file_path, symbols| {
237            let source = std::fs::read_to_string(project.as_path().join(&file_path)).ok();
238            for sym in &symbols {
239                if is_test_only_symbol(sym, source.as_deref()) {
240                    continue;
241                }
242                let doc = source.as_deref().and_then(|src| {
243                    extract_leading_doc(src, sym.start_byte as usize, sym.end_byte as usize)
244                });
245                let doc = match doc {
246                    Some(d) if !d.is_empty() => d,
247                    _ => continue,
248                };
249
250                // Build code term: symbol_name + split words
251                let split = split_identifier(&sym.name);
252                let code_term = if split != sym.name {
253                    format!("{} {}", sym.name, split)
254                } else {
255                    sym.name.clone()
256                };
257
258                // Extract short NL phrases (3-6 words) from the docstring.
259                // This produces multiple bridge entries per symbol, each matching
260                // common NL query patterns like "render template" or "parse url".
261                let first_line = doc.lines().next().unwrap_or("").trim().to_lowercase();
262                // Remove trailing period/punctuation
263                let clean = first_line.trim_end_matches(|c: char| c.is_ascii_punctuation());
264                let words: Vec<&str> = clean.split_whitespace().collect();
265                if words.len() < 2 {
266                    continue;
267                }
268
269                // Generate short N-gram keys (2-4 words from the start)
270                for window in 2..=words.len().min(4) {
271                    let key = words[..window].join(" ");
272                    if key.len() < 5 || key.len() > 60 {
273                        continue;
274                    }
275                    if seen_nl.insert(key.clone()) {
276                        bridges.push((key, code_term.clone()));
277                    }
278                }
279
280                // Also add split_identifier words as a bridge key
281                // so "render template" → render_template
282                if split != sym.name && !seen_nl.contains(&split.to_lowercase()) {
283                    let lowered = split.to_lowercase();
284                    if lowered.split_whitespace().count() >= 2 && seen_nl.insert(lowered.clone()) {
285                        bridges.push((lowered, code_term.clone()));
286                    }
287                }
288            }
289            Ok(())
290        })?;
291
292        Ok(bridges)
293    }
294
    /// Rebuild one file's embeddings, reusing vectors from `existing_embeddings`
    /// for symbols whose reusable key (derived from the symbol and its embedding
    /// text) is unchanged, and embedding the rest in `batch_size` batches.
    ///
    /// The model mutex guard is acquired lazily on the first batch and cached
    /// in `model`, so a caller iterating many files locks the model at most once.
    /// The file's rows are deleted and re-inserted atomically per file; returns
    /// the number of chunks inserted.
    fn reconcile_file_embeddings<'a>(
        &'a self,
        file_path: &str,
        symbols: Vec<crate::db::SymbolWithFile>,
        source: Option<&str>,
        mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk>,
        batch_size: usize,
        model: &mut Option<std::sync::MutexGuard<'a, TextEmbedding>>,
    ) -> Result<usize> {
        let mut reconciled_chunks = Vec::with_capacity(symbols.len());
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);

        for sym in symbols {
            let text = build_embedding_text(&sym, source);
            // Reuse path: identical symbol + text means the stored vector is
            // still valid; carry it over without touching the model.
            if let Some(existing) =
                existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
            {
                reconciled_chunks.push(EmbeddingChunk {
                    file_path: sym.file_path.clone(),
                    symbol_name: sym.name.clone(),
                    kind: sym.kind.clone(),
                    line: sym.line as usize,
                    signature: sym.signature.clone(),
                    name_path: sym.name_path.clone(),
                    text,
                    embedding: existing.embedding,
                    doc_embedding: existing.doc_embedding,
                });
                continue;
            }

            // Miss path: queue for batched embedding.
            batch_texts.push(text);
            batch_meta.push(sym);

            if batch_texts.len() >= batch_size {
                if model.is_none() {
                    *model = Some(
                        self.model
                            .lock()
                            .map_err(|_| anyhow::anyhow!("model lock"))?,
                    );
                }
                reconciled_chunks.extend(Self::embed_chunks(
                    model.as_mut().expect("model lock initialized"),
                    &batch_texts,
                    &batch_meta,
                )?);
                batch_texts.clear();
                batch_meta.clear();
            }
        }

        // Flush the final partial batch, if any.
        if !batch_texts.is_empty() {
            if model.is_none() {
                *model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            reconciled_chunks.extend(Self::embed_chunks(
                model.as_mut().expect("model lock initialized"),
                &batch_texts,
                &batch_meta,
            )?);
        }

        // Replace the file's rows wholesale: delete, then insert the
        // reconciled set (reused + freshly embedded).
        self.store.delete_by_file(&[file_path])?;
        if reconciled_chunks.is_empty() {
            return Ok(0);
        }
        self.store.insert(&reconciled_chunks)
    }
369
370    fn embed_chunks(
371        model: &mut TextEmbedding,
372        texts: &[String],
373        meta: &[crate::db::SymbolWithFile],
374    ) -> Result<Vec<EmbeddingChunk>> {
375        let batch_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
376        let embeddings = model.embed(batch_refs, None).context("embedding failed")?;
377
378        Ok(meta
379            .iter()
380            .zip(embeddings)
381            .zip(texts.iter())
382            .map(|((sym, emb), text)| EmbeddingChunk {
383                file_path: sym.file_path.clone(),
384                symbol_name: sym.name.clone(),
385                kind: sym.kind.clone(),
386                line: sym.line as usize,
387                signature: sym.signature.clone(),
388                name_path: sym.name_path.clone(),
389                text: text.clone(),
390                embedding: emb,
391                doc_embedding: None,
392            })
393            .collect())
394    }
395
    /// Embed one batch of texts and upsert immediately, then the caller drops the batch.
    ///
    /// Returns the count reported by the store's `insert`.
    fn flush_batch(
        model: &mut TextEmbedding,
        store: &SqliteVecStore,
        texts: &[String],
        meta: &[crate::db::SymbolWithFile],
    ) -> Result<usize> {
        let chunks = Self::embed_chunks(model, texts, meta)?;
        store.insert(&chunks)
    }
406
407    /// Search for symbols semantically similar to the query.
408    pub fn search(&self, query: &str, max_results: usize) -> Result<Vec<SemanticMatch>> {
409        let results = self.search_scored(query, max_results)?;
410        Ok(results.into_iter().map(SemanticMatch::from).collect())
411    }
412
    /// Search returning raw ScoredChunks with optional reranking.
    ///
    /// Pipeline: bi-encoder → candidate pool (3× requested) → rerank → top-N.
    /// Reranking uses query-document text overlap scoring to refine bi-encoder
    /// cosine similarity. This catches cases where embedding similarity is high
    /// but the actual text relevance is low (or vice versa).
    pub fn search_scored(&self, query: &str, max_results: usize) -> Result<Vec<ScoredChunk>> {
        let query_embedding = self.embed_texts_cached(&[query])?;

        // Defensive: embed_texts_cached returns one vector per input, so this
        // should not trigger for the single-element slice above.
        if query_embedding.is_empty() {
            return Ok(Vec::new());
        }

        // Fetch N× candidates for reranking headroom (default 5×, override via
        // CODELENS_RERANK_FACTOR). More candidates = better rerank quality at
        // marginal latency cost (sqlite-vec scan is fast).
        let factor = std::env::var("CODELENS_RERANK_FACTOR")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .unwrap_or(5);
        let candidate_count = max_results.saturating_mul(factor).max(max_results);
        let mut candidates = self.store.search(&query_embedding[0], candidate_count)?;

        // Nothing to rerank if the pool already fits the request.
        if candidates.len() <= max_results {
            return Ok(candidates);
        }

        // Lightweight rerank: blend bi-encoder score with text overlap signal.
        // This is a stopgap until a proper cross-encoder is plugged in.
        let query_lower = query.to_lowercase();
        let query_tokens: Vec<&str> = query_lower
            .split(|c: char| c.is_whitespace() || c == '_' || c == '-')
            .filter(|t| t.len() >= 2)
            .collect();

        // No usable tokens (e.g. punctuation-only query): keep bi-encoder order.
        if query_tokens.is_empty() {
            candidates.truncate(max_results);
            return Ok(candidates);
        }

        let blend = configured_rerank_blend();
        for chunk in &mut candidates {
            // Build searchable text: symbol_name + split identifier words +
            // name_path (parent context) + signature + file_path.
            // split_identifier turns "parseSymbols" into "parse Symbols" for
            // better NL token matching.
            let split_name = split_identifier(&chunk.symbol_name);
            let searchable = format!(
                "{} {} {} {} {}",
                chunk.symbol_name.to_lowercase(),
                split_name.to_lowercase(),
                chunk.name_path.to_lowercase(),
                chunk.signature.to_lowercase(),
                chunk.file_path.to_lowercase(),
            );
            // Fraction of query tokens that appear in the searchable text.
            let overlap = query_tokens
                .iter()
                .filter(|t| searchable.contains(**t))
                .count() as f64;
            let overlap_ratio = overlap / query_tokens.len().max(1) as f64;
            // Blend: configurable bi-encoder + text overlap (default 75/25)
            chunk.score = chunk.score * blend + overlap_ratio * (1.0 - blend);
        }

        // Descending by blended score; NaN-safe via Ordering::Equal fallback.
        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates.truncate(max_results);
        Ok(candidates)
    }
485
    /// Incrementally re-index only the given files.
    ///
    /// Snapshots existing embeddings for the changed files (so unchanged
    /// symbols keep their vectors), deletes the files' rows, then re-inserts:
    /// reused chunks and freshly-embedded chunks are flushed in separate
    /// `batch_size` batches. Returns the number of chunks written.
    pub fn index_changed_files(
        &self,
        project: &ProjectRoot,
        changed_files: &[&str],
    ) -> Result<usize> {
        if changed_files.is_empty() {
            return Ok(0);
        }
        let batch_size = embed_batch_size();
        // Snapshot reusable vectors before deleting the files' rows below.
        // Files are queried in CHANGED_FILE_QUERY_CHUNK groups to bound the
        // size of each store query.
        let mut existing_embeddings: HashMap<ReusableEmbeddingKey, EmbeddingChunk> = HashMap::new();
        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            for chunk in self.store.embeddings_for_files(file_chunk)? {
                existing_embeddings.insert(reusable_embedding_key_for_chunk(&chunk), chunk);
            }
        }
        self.store.delete_by_file(changed_files)?;

        let db_path = crate::db::index_db_path(project.as_path());
        let symbol_db = IndexDb::open(&db_path)?;

        let mut total_indexed = 0usize;
        // Two parallel pipelines: texts awaiting embedding, and chunks whose
        // vectors were reused as-is. Each flushes independently at batch_size.
        let mut batch_texts: Vec<String> = Vec::with_capacity(batch_size);
        let mut batch_meta: Vec<crate::db::SymbolWithFile> = Vec::with_capacity(batch_size);
        let mut batch_reused: Vec<EmbeddingChunk> = Vec::with_capacity(batch_size);
        // Per-file source cache so each file is read from disk at most once.
        let mut file_cache: std::collections::HashMap<String, Option<String>> =
            std::collections::HashMap::new();
        // Lazily-acquired model lock; only taken if something needs embedding.
        let mut model = None;

        for file_chunk in changed_files.chunks(CHANGED_FILE_QUERY_CHUNK) {
            let relevant = symbol_db.symbols_for_files(file_chunk)?;
            for sym in relevant {
                let source = file_cache.entry(sym.file_path.clone()).or_insert_with(|| {
                    std::fs::read_to_string(project.as_path().join(&sym.file_path)).ok()
                });
                if is_test_only_symbol(&sym, source.as_deref()) {
                    continue;
                }
                let text = build_embedding_text(&sym, source.as_deref());
                // Reuse path: same symbol + text means the snapshotted vector
                // is still valid.
                if let Some(existing) =
                    existing_embeddings.remove(&reusable_embedding_key_for_symbol(&sym, &text))
                {
                    batch_reused.push(EmbeddingChunk {
                        file_path: sym.file_path.clone(),
                        symbol_name: sym.name.clone(),
                        kind: sym.kind.clone(),
                        line: sym.line as usize,
                        signature: sym.signature.clone(),
                        name_path: sym.name_path.clone(),
                        text,
                        embedding: existing.embedding,
                        doc_embedding: existing.doc_embedding,
                    });
                    if batch_reused.len() >= batch_size {
                        total_indexed += self.store.insert(&batch_reused)?;
                        batch_reused.clear();
                    }
                    continue;
                }
                batch_texts.push(text);
                batch_meta.push(sym);

                if batch_texts.len() >= batch_size {
                    if model.is_none() {
                        model = Some(
                            self.model
                                .lock()
                                .map_err(|_| anyhow::anyhow!("model lock"))?,
                        );
                    }
                    total_indexed += Self::flush_batch(
                        model.as_mut().expect("model lock initialized"),
                        &self.store,
                        &batch_texts,
                        &batch_meta,
                    )?;
                    batch_texts.clear();
                    batch_meta.clear();
                }
            }
        }

        // Flush any remaining partial batches (reused first, then embedded).
        if !batch_reused.is_empty() {
            total_indexed += self.store.insert(&batch_reused)?;
        }

        if !batch_texts.is_empty() {
            if model.is_none() {
                model = Some(
                    self.model
                        .lock()
                        .map_err(|_| anyhow::anyhow!("model lock"))?,
                );
            }
            total_indexed += Self::flush_batch(
                model.as_mut().expect("model lock initialized"),
                &self.store,
                &batch_texts,
                &batch_meta,
            )?;
        }

        Ok(total_indexed)
    }
590
591    /// Whether the embedding index has been populated.
592    pub fn is_indexed(&self) -> bool {
593        self.store.count().unwrap_or(0) > 0
594    }
595
596    pub fn index_info(&self) -> EmbeddingIndexInfo {
597        EmbeddingIndexInfo {
598            model_name: self.model_name.clone(),
599            indexed_symbols: self.store.count().unwrap_or(0),
600        }
601    }
602
    /// Inspect an existing embeddings database without loading the model or
    /// constructing an engine.
    ///
    /// Returns `None` when the database file does not exist or records no
    /// model name; symbol-count query failures fall back to 0 rather than
    /// erroring, so callers always get a best-effort snapshot.
    pub fn inspect_existing_index(project: &ProjectRoot) -> Result<Option<EmbeddingIndexInfo>> {
        let db_path = project.as_path().join(".codelens/index/embeddings.db");
        if !db_path.exists() {
            return Ok(None);
        }

        // Open via the shared recovery helper; the closure is the raw open
        // path it retries after recovering a corrupt derived database.
        let conn =
            crate::db::open_derived_sqlite_with_recovery(&db_path, "embedding index", || {
                ffi::register_sqlite_vec()?;
                let conn = Connection::open(&db_path)?;
                conn.execute_batch("PRAGMA busy_timeout=5000;")?;
                // Cheap read that forces schema validation up front.
                conn.query_row("PRAGMA schema_version", [], |_row| Ok(()))?;
                Ok(conn)
            })?;

        let model_name: Option<String> = conn
            .query_row(
                "SELECT value FROM meta WHERE key = 'model' LIMIT 1",
                [],
                |row| row.get(0),
            )
            .ok();
        let indexed_symbols: usize = conn
            .query_row("SELECT COUNT(*) FROM symbols", [], |row| {
                row.get::<_, i64>(0)
            })
            .map(|count| count.max(0) as usize)
            .unwrap_or(0);

        Ok(model_name.map(|model_name| EmbeddingIndexInfo {
            model_name,
            indexed_symbols,
        }))
    }
637
638    // ── Embedding-powered analysis ─────────────────────────────────
639
640    /// Find code symbols most similar to the given symbol.
641    pub fn find_similar_code(
642        &self,
643        file_path: &str,
644        symbol_name: &str,
645        max_results: usize,
646    ) -> Result<Vec<SemanticMatch>> {
647        let target = self
648            .store
649            .get_embedding(file_path, symbol_name)?
650            .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?;
651
652        let oversample = max_results.saturating_add(8).max(1);
653        let scored = self
654            .store
655            .search(&target.embedding, oversample)?
656            .into_iter()
657            .filter(|c| !(c.file_path == file_path && c.symbol_name == symbol_name))
658            .take(max_results)
659            .map(SemanticMatch::from)
660            .collect();
661        Ok(scored)
662    }
663
    /// Find near-duplicate code pairs across the codebase.
    /// Returns pairs with cosine similarity above the threshold (default 0.85).
    ///
    /// Scans the store in batches: for each chunk, fetch its nearest
    /// neighbours, batch-load any candidate embeddings not yet cached, then
    /// score each (chunk, candidate) pair once. Stops as soon as `max_pairs`
    /// pairs are collected; results are sorted by similarity, descending.
    pub fn find_duplicates(&self, threshold: f64, max_pairs: usize) -> Result<Vec<DuplicatePair>> {
        let mut pairs = Vec::new();
        // Unordered pair keys already emitted, so (A,B) and (B,A) count once.
        let mut seen_pairs = HashSet::new();
        // Candidate embeddings persist across batches to avoid re-fetching.
        let mut embedding_cache: HashMap<StoredChunkKey, Arc<EmbeddingChunk>> = HashMap::new();
        let candidate_limit = duplicate_candidate_limit(max_pairs);
        // Latch set once max_pairs is reached; later batches become no-ops.
        let mut done = false;

        self.store
            .for_each_embedding_batch(DEFAULT_DUPLICATE_SCAN_BATCH_SIZE, &mut |batch| {
                if done {
                    return Ok(());
                }

                // Pass 1: collect each chunk's neighbour list and note which
                // candidate embeddings are missing from the cache.
                let mut candidate_lists = Vec::with_capacity(batch.len());
                let mut missing_candidates = Vec::new();
                let mut missing_keys = HashSet::new();

                for chunk in &batch {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    // Drop the chunk itself from its own neighbour list
                    // (matched on every identifying field).
                    let filtered: Vec<ScoredChunk> = self
                        .store
                        .search(&chunk.embedding, candidate_limit)?
                        .into_iter()
                        .filter(|candidate| {
                            !(chunk.file_path == candidate.file_path
                                && chunk.symbol_name == candidate.symbol_name
                                && chunk.line == candidate.line
                                && chunk.signature == candidate.signature
                                && chunk.name_path == candidate.name_path)
                        })
                        .collect();

                    for candidate in &filtered {
                        let cache_key = stored_chunk_key_for_score(candidate);
                        if !embedding_cache.contains_key(&cache_key)
                            && missing_keys.insert(cache_key)
                        {
                            missing_candidates.push(candidate.clone());
                        }
                    }

                    candidate_lists.push(filtered);
                }

                // Batch-load all missing candidate embeddings in one query.
                if !missing_candidates.is_empty() {
                    for candidate_chunk in self
                        .store
                        .embeddings_for_scored_chunks(&missing_candidates)?
                    {
                        embedding_cache
                            .entry(stored_chunk_key(&candidate_chunk))
                            .or_insert_with(|| Arc::new(candidate_chunk));
                    }
                }

                // Pass 2: score every (chunk, candidate) pair against the
                // threshold, de-duplicating by unordered pair key.
                for (chunk, candidates) in batch.iter().zip(candidate_lists.iter()) {
                    if pairs.len() >= max_pairs {
                        done = true;
                        break;
                    }

                    for candidate in candidates {
                        let pair_key = duplicate_pair_key(
                            &chunk.file_path,
                            &chunk.symbol_name,
                            &candidate.file_path,
                            &candidate.symbol_name,
                        );
                        if !seen_pairs.insert(pair_key) {
                            continue;
                        }

                        // Candidate may be absent if the batch fetch above
                        // could not resolve it; skip rather than fail.
                        let Some(candidate_chunk) =
                            embedding_cache.get(&stored_chunk_key_for_score(candidate))
                        else {
                            continue;
                        };

                        let sim = cosine_similarity(&chunk.embedding, &candidate_chunk.embedding);
                        if sim < threshold {
                            continue;
                        }

                        pairs.push(DuplicatePair {
                            symbol_a: format!("{}:{}", chunk.file_path, chunk.symbol_name),
                            symbol_b: format!(
                                "{}:{}",
                                candidate_chunk.file_path, candidate_chunk.symbol_name
                            ),
                            file_a: chunk.file_path.clone(),
                            file_b: candidate_chunk.file_path.clone(),
                            line_a: chunk.line,
                            line_b: candidate_chunk.line,
                            similarity: sim,
                        });
                        if pairs.len() >= max_pairs {
                            done = true;
                            break;
                        }
                    }
                }
                Ok(())
            })?;

        // Most-similar first; NaN-safe via Ordering::Equal fallback.
        pairs.sort_by(|a, b| {
            b.similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(pairs)
    }
781}
782
783impl EmbeddingEngine {
784    /// Classify a code symbol into one of the given categories using zero-shot embedding similarity.
785    pub fn classify_symbol(
786        &self,
787        file_path: &str,
788        symbol_name: &str,
789        categories: &[&str],
790    ) -> Result<Vec<CategoryScore>> {
791        let target = match self.store.get_embedding(file_path, symbol_name)? {
792            Some(target) => target,
793            None => self
794                .store
795                .all_with_embeddings()?
796                .into_iter()
797                .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
798                .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
799        };
800
801        let embeddings = self.embed_texts_cached(categories)?;
802
803        let mut scores: Vec<CategoryScore> = categories
804            .iter()
805            .zip(embeddings.iter())
806            .map(|(cat, emb)| CategoryScore {
807                category: cat.to_string(),
808                score: cosine_similarity(&target.embedding, emb),
809            })
810            .collect();
811
812        scores.sort_by(|a, b| {
813            b.score
814                .partial_cmp(&a.score)
815                .unwrap_or(std::cmp::Ordering::Equal)
816        });
817        Ok(scores)
818    }
819
820    /// Find symbols that are outliers — semantically distant from their file's other symbols.
821    pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
822        let mut outliers = Vec::new();
823
824        self.store
825            .for_each_file_embeddings(&mut |file_path, chunks| {
826                if chunks.len() < 2 {
827                    return Ok(());
828                }
829
830                for (idx, chunk) in chunks.iter().enumerate() {
831                    let mut sim_sum = 0.0;
832                    let mut count = 0;
833                    for (other_idx, other_chunk) in chunks.iter().enumerate() {
834                        if other_idx == idx {
835                            continue;
836                        }
837                        sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
838                        count += 1;
839                    }
840                    if count > 0 {
841                        let avg_sim = sim_sum / count as f64; // Lower means more misplaced.
842                        outliers.push(OutlierSymbol {
843                            file_path: file_path.clone(),
844                            symbol_name: chunk.symbol_name.clone(),
845                            kind: chunk.kind.clone(),
846                            line: chunk.line,
847                            avg_similarity_to_file: avg_sim,
848                        });
849                    }
850                }
851                Ok(())
852            })?;
853
854        outliers.sort_by(|a, b| {
855            a.avg_similarity_to_file
856                .partial_cmp(&b.avg_similarity_to_file)
857                .unwrap_or(std::cmp::Ordering::Equal)
858        });
859        outliers.truncate(max_results);
860        Ok(outliers)
861    }
862}