Skip to main content

codelens_engine/embedding/engine_impl/
analysis.rs

1use anyhow::Result;
2use std::collections::HashMap;
3
4use super::super::EmbeddingEngine;
5use super::super::chunk_ops::{CategoryScore, OutlierSymbol, cosine_similarity};
6use crate::embedding_store::{ArtifactEmbeddingChunk, ScoredArtifactChunk};
7
8impl EmbeddingEngine {
9    /// Classify a code symbol into one of the given categories using zero-shot embedding similarity.
10    pub fn classify_symbol(
11        &self,
12        file_path: &str,
13        symbol_name: &str,
14        categories: &[&str],
15    ) -> Result<Vec<CategoryScore>> {
16        let target = match self.store.get_embedding(file_path, symbol_name)? {
17            Some(target) => target,
18            None => self
19                .store
20                .all_with_embeddings()?
21                .into_iter()
22                .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
23                .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
24        };
25
26        let embeddings = self.embed_texts_cached(categories)?;
27
28        let mut scores: Vec<CategoryScore> = categories
29            .iter()
30            .zip(embeddings.iter())
31            .map(|(cat, emb)| CategoryScore {
32                category: cat.to_string(),
33                score: cosine_similarity(&target.embedding, emb),
34            })
35            .collect();
36
37        scores.sort_by(|a, b| {
38            b.score
39                .partial_cmp(&a.score)
40                .unwrap_or(std::cmp::Ordering::Equal)
41        });
42        Ok(scores)
43    }
44
45    /// Find symbols that are outliers — semantically distant from their file's other symbols.
46    pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
47        let mut outliers = Vec::new();
48
49        self.store
50            .for_each_file_embeddings(&mut |file_path, chunks| {
51                if chunks.len() < 2 {
52                    return Ok(());
53                }
54
55                for (idx, chunk) in chunks.iter().enumerate() {
56                    let mut sim_sum = 0.0;
57                    let mut count = 0;
58                    for (other_idx, other_chunk) in chunks.iter().enumerate() {
59                        if other_idx == idx {
60                            continue;
61                        }
62                        sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
63                        count += 1;
64                    }
65                    if count > 0 {
66                        let avg_sim = sim_sum / count as f64; // Lower means more misplaced.
67                        outliers.push(OutlierSymbol {
68                            file_path: file_path.clone(),
69                            symbol_name: chunk.symbol_name.clone(),
70                            kind: chunk.kind.clone(),
71                            line: chunk.line,
72                            avg_similarity_to_file: avg_sim,
73                        });
74                    }
75                }
76                Ok(())
77            })?;
78
79        outliers.sort_by(|a, b| {
80            // G5: bias the ranking by structural role so expected-diverse
81            // files (entry points, tests, handler aggregators) fall below
82            // genuine misplacements instead of crowding the top.
83            let a_adj = a.avg_similarity_to_file + file_structural_role_boost(&a.file_path);
84            let b_adj = b.avg_similarity_to_file + file_structural_role_boost(&b.file_path);
85            a_adj
86                .partial_cmp(&b_adj)
87                .unwrap_or(std::cmp::Ordering::Equal)
88        });
89        outliers.truncate(max_results);
90        Ok(outliers)
91    }
92
93    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
94        self.embed_texts_cached(&[text])?
95            .into_iter()
96            .next()
97            .ok_or_else(|| anyhow::anyhow!("missing embedding for text"))
98    }
99
100    // ── Artifact memory API (Phase 1 — v0.15+) ────────────────────────────
101
102    /// Store pre-computed artifact embeddings in the semantic index.
103    pub fn store_artifact_embeddings(&self, chunks: &[ArtifactEmbeddingChunk]) -> Result<usize> {
104        self.store.upsert_artifacts(chunks)
105    }
106
107    /// Semantic search over stored artifact analyses.
108    pub fn search_artifact_embeddings(
109        &self,
110        query: &str,
111        top_k: usize,
112    ) -> Result<Vec<ScoredArtifactChunk>> {
113        let query_embedding = self.embed_query_cached(query)?;
114        self.store.search_artifacts(&query_embedding, top_k)
115    }
116
117    /// Count stored artifact embeddings.
118    pub fn artifact_embedding_count(&self) -> Result<usize> {
119        self.store.artifact_count()
120    }
121
122    /// Prune artifact embeddings older than the given duration (ms).
123    pub fn prune_artifact_embeddings(&self, max_age_ms: u64) -> Result<usize> {
124        self.store.prune_artifacts_by_age(max_age_ms)
125    }
126
127    /// Compute mean embedding for each file from indexed symbol embeddings.
128    pub fn file_mean_embeddings(&self, file_paths: &[&str]) -> Result<HashMap<String, Vec<f32>>> {
129        let chunks = self.store.embeddings_for_files(file_paths)?;
130        let mut per_file: HashMap<String, Vec<Vec<f32>>> = HashMap::new();
131        for chunk in chunks {
132            per_file
133                .entry(chunk.file_path)
134                .or_default()
135                .push(chunk.embedding);
136        }
137        let mut result = HashMap::new();
138        for (file, embeddings) in per_file {
139            if embeddings.is_empty() {
140                continue;
141            }
142            let dim = embeddings[0].len();
143            let mut mean = vec![0.0f32; dim];
144            for emb in &embeddings {
145                for i in 0..dim {
146                    mean[i] += emb[i];
147                }
148            }
149            let count = embeddings.len() as f32;
150            for v in &mut mean {
151                *v /= count;
152            }
153            result.insert(file, mean);
154        }
155        Ok(result)
156    }
157
158    /// Compute mean embedding of multiple file embeddings.
159    pub fn mean_of_embeddings(embeddings: &[Vec<f32>]) -> Option<Vec<f32>> {
160        if embeddings.is_empty() {
161            return None;
162        }
163        let dim = embeddings[0].len();
164        let mut mean = vec![0.0f32; dim];
165        for emb in embeddings {
166            for i in 0..dim {
167                mean[i] += emb[i];
168            }
169        }
170        let count = embeddings.len() as f32;
171        for v in &mut mean {
172            *v /= count;
173        }
174        Some(mean)
175    }
176}
177
// ── G5: role-aware outlier weighting ───────────────────────────────
// find_misplaced_code flags symbols whose embedding is dissimilar to the
// rest of their file. Entry points (mod.rs/lib.rs/main.*/index.*), test
// files, and handler/dispatch aggregators legitimately hold heterogeneous
// symbols, so their low intra-file similarity is expected — not
// "misplaced". A small role boost on the sort key pushes those
// expected-diverse files down the outlier ranking, reducing false
// positives without dropping data.

/// Sort-key boost for files whose role makes heterogeneous symbols normal.
/// Tuned conservatively; revisit with live dogfood false-positive metrics.
const ROLE_BOOST_DIVERSE: f64 = 0.15; // entry points + test files
const ROLE_BOOST_HANDLER: f64 = 0.10; // handler/dispatch aggregators

/// Outlier-score boost for files whose structural role makes low intra-file
/// similarity expected. Returns `0.0` for ordinary code files.
///
/// Matching is case-insensitive and uses path segments (split on `/` and `\`),
/// the file name, and the file stem (name minus its last extension). This keeps
/// the rules language-agnostic:
/// - tests: a `test`/`tests` path segment, or a stem that is `tests`, ends with
///   `_test`/`_tests`/`.test`/`.spec`, or starts with `test_`;
/// - entry points: `mod.rs`, `lib.rs`, `main.*`, `index.*`;
/// - handler aggregators: a stem that is `handlers` or ends with `_handler`/`_handlers`.
///
/// Tests and entry points take precedence over the handler rule.
fn file_structural_role_boost(path: &str) -> f64 {
    let path_lower = path.to_ascii_lowercase();
    // Split on both separators so `\`-style paths are classified like `/`-style
    // ones on every host (`std::path::Path` only honours `\` on Windows).
    let segments: Vec<&str> = path_lower
        .split(|c: char| c == '/' || c == '\\')
        .filter(|seg| !seg.is_empty())
        .collect();
    let file = segments.last().copied().unwrap_or("");
    // Stem = file name minus its final extension ("store_test.go" -> "store_test",
    // "app.spec.ts" -> "app.spec"), so the suffix rules apply to any language.
    let stem = match file.rsplit_once('.') {
        Some((stem, _)) if !stem.is_empty() => stem,
        _ => file,
    };

    let is_test = segments.iter().any(|&seg| seg == "tests" || seg == "test")
        || stem == "tests"
        || stem.ends_with("_test")
        || stem.ends_with("_tests")
        || stem.starts_with("test_")
        || stem.ends_with(".test")
        || stem.ends_with(".spec");
    let is_entry = matches!(file, "mod.rs" | "lib.rs" | "main.rs")
        || file.starts_with("main.")
        || file.starts_with("index.");
    let is_handler =
        stem == "handlers" || stem.ends_with("_handler") || stem.ends_with("_handlers");

    if is_test || is_entry {
        ROLE_BOOST_DIVERSE
    } else if is_handler {
        ROLE_BOOST_HANDLER
    } else {
        0.0
    }
}
222
#[cfg(test)]
mod g5_role_boost_tests {
    use super::file_structural_role_boost;

    /// True when `path` is given a strictly positive role boost.
    fn is_boosted(path: &str) -> bool {
        file_structural_role_boost(path) > 0.0
    }

    #[test]
    fn entry_point_files_get_boost() {
        for path in ["src/lib.rs", "a/b/mod.rs", "pkg/main.py"] {
            assert!(is_boosted(path), "expected entry-point boost for {path}");
        }
    }

    #[test]
    fn test_files_get_boost() {
        for path in [
            "src/embedding/tests.rs",
            "foo/bar_test.rs",
            "tests/integration.rs",
        ] {
            assert!(is_boosted(path), "expected test-file boost for {path}");
        }
    }

    #[test]
    fn handler_aggregators_get_boost() {
        for path in ["tools/handlers.rs", "foo_handler.rs"] {
            assert!(is_boosted(path), "expected handler boost for {path}");
        }
    }

    #[test]
    fn plain_code_files_get_no_boost() {
        for path in ["src/embedding/duplicates.rs", "src/ranking.rs"] {
            assert_eq!(
                file_structural_role_boost(path),
                0.0,
                "unexpected boost for {path}"
            );
        }
    }

    #[test]
    fn boost_stays_bounded() {
        let bounds = 0.0..=0.3;
        for p in ["lib.rs", "tests.rs", "x_handler.rs", "normal.rs"] {
            let b = file_structural_role_boost(p);
            assert!(bounds.contains(&b), "{p} -> {b}");
        }
    }
}
263}