codelens_engine/embedding/engine_impl/
analysis.rs1use anyhow::Result;
2use std::collections::HashMap;
3
4use super::super::EmbeddingEngine;
5use super::super::chunk_ops::{CategoryScore, OutlierSymbol, cosine_similarity};
6use crate::embedding_store::{ArtifactEmbeddingChunk, ScoredArtifactChunk};
7
8impl EmbeddingEngine {
9 pub fn classify_symbol(
11 &self,
12 file_path: &str,
13 symbol_name: &str,
14 categories: &[&str],
15 ) -> Result<Vec<CategoryScore>> {
16 let target = match self.store.get_embedding(file_path, symbol_name)? {
17 Some(target) => target,
18 None => self
19 .store
20 .all_with_embeddings()?
21 .into_iter()
22 .find(|c| c.file_path == file_path && c.symbol_name == symbol_name)
23 .ok_or_else(|| anyhow::anyhow!("Symbol '{}' not found in index", symbol_name))?,
24 };
25
26 let embeddings = self.embed_texts_cached(categories)?;
27
28 let mut scores: Vec<CategoryScore> = categories
29 .iter()
30 .zip(embeddings.iter())
31 .map(|(cat, emb)| CategoryScore {
32 category: cat.to_string(),
33 score: cosine_similarity(&target.embedding, emb),
34 })
35 .collect();
36
37 scores.sort_by(|a, b| {
38 b.score
39 .partial_cmp(&a.score)
40 .unwrap_or(std::cmp::Ordering::Equal)
41 });
42 Ok(scores)
43 }
44
45 pub fn find_misplaced_code(&self, max_results: usize) -> Result<Vec<OutlierSymbol>> {
47 let mut outliers = Vec::new();
48
49 self.store
50 .for_each_file_embeddings(&mut |file_path, chunks| {
51 if chunks.len() < 2 {
52 return Ok(());
53 }
54
55 for (idx, chunk) in chunks.iter().enumerate() {
56 let mut sim_sum = 0.0;
57 let mut count = 0;
58 for (other_idx, other_chunk) in chunks.iter().enumerate() {
59 if other_idx == idx {
60 continue;
61 }
62 sim_sum += cosine_similarity(&chunk.embedding, &other_chunk.embedding);
63 count += 1;
64 }
65 if count > 0 {
66 let avg_sim = sim_sum / count as f64; outliers.push(OutlierSymbol {
68 file_path: file_path.clone(),
69 symbol_name: chunk.symbol_name.clone(),
70 kind: chunk.kind.clone(),
71 line: chunk.line,
72 avg_similarity_to_file: avg_sim,
73 });
74 }
75 }
76 Ok(())
77 })?;
78
79 outliers.sort_by(|a, b| {
80 let a_adj = a.avg_similarity_to_file + file_structural_role_boost(&a.file_path);
84 let b_adj = b.avg_similarity_to_file + file_structural_role_boost(&b.file_path);
85 a_adj
86 .partial_cmp(&b_adj)
87 .unwrap_or(std::cmp::Ordering::Equal)
88 });
89 outliers.truncate(max_results);
90 Ok(outliers)
91 }
92
93 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
94 self.embed_texts_cached(&[text])?
95 .into_iter()
96 .next()
97 .ok_or_else(|| anyhow::anyhow!("missing embedding for text"))
98 }
99
100 pub fn store_artifact_embeddings(&self, chunks: &[ArtifactEmbeddingChunk]) -> Result<usize> {
104 self.store.upsert_artifacts(chunks)
105 }
106
107 pub fn search_artifact_embeddings(
109 &self,
110 query: &str,
111 top_k: usize,
112 ) -> Result<Vec<ScoredArtifactChunk>> {
113 let query_embedding = self.embed_query_cached(query)?;
114 self.store.search_artifacts(&query_embedding, top_k)
115 }
116
117 pub fn artifact_embedding_count(&self) -> Result<usize> {
119 self.store.artifact_count()
120 }
121
122 pub fn prune_artifact_embeddings(&self, max_age_ms: u64) -> Result<usize> {
124 self.store.prune_artifacts_by_age(max_age_ms)
125 }
126
127 pub fn file_mean_embeddings(&self, file_paths: &[&str]) -> Result<HashMap<String, Vec<f32>>> {
129 let chunks = self.store.embeddings_for_files(file_paths)?;
130 let mut per_file: HashMap<String, Vec<Vec<f32>>> = HashMap::new();
131 for chunk in chunks {
132 per_file
133 .entry(chunk.file_path)
134 .or_default()
135 .push(chunk.embedding);
136 }
137 let mut result = HashMap::new();
138 for (file, embeddings) in per_file {
139 if embeddings.is_empty() {
140 continue;
141 }
142 let dim = embeddings[0].len();
143 let mut mean = vec![0.0f32; dim];
144 for emb in &embeddings {
145 for i in 0..dim {
146 mean[i] += emb[i];
147 }
148 }
149 let count = embeddings.len() as f32;
150 for v in &mut mean {
151 *v /= count;
152 }
153 result.insert(file, mean);
154 }
155 Ok(result)
156 }
157
/// Averages a set of embedding vectors component-wise.
///
/// The output has the length of the first vector. Vectors are expected to share that
/// length; if one is longer, its extra components are ignored, and if one is shorter, it
/// contributes nothing to the components it lacks (rather than panicking on an
/// out-of-bounds index). The divisor is always the total number of vectors.
///
/// # Returns
/// `None` when `embeddings` is empty, otherwise `Some(mean)`.
pub fn mean_of_embeddings(embeddings: &[Vec<f32>]) -> Option<Vec<f32>> {
    let first = embeddings.first()?;
    let mut mean = vec![0.0f32; first.len()];

    for emb in embeddings {
        // `zip` stops at the shorter side, so a ragged input can never index out of bounds.
        for (acc, value) in mean.iter_mut().zip(emb) {
            *acc += *value;
        }
    }

    let count = embeddings.len() as f32;
    for v in &mut mean {
        *v /= count;
    }
    Some(mean)
}
176}
177
/// Boost returned for test files and entry-point files.
const ROLE_BOOST_DIVERSE: f64 = 0.15;
/// Boost returned for handler files.
const ROLE_BOOST_HANDLER: f64 = 0.10;

/// Returns a ranking boost derived purely from a file's path.
///
/// Matching is ASCII case-insensitive:
/// - test files (`tests.rs`, `*_test.rs`, `*_tests.rs`, or any `/`-separated path segment
///   equal to `tests` or `test`) and entry points (`mod.rs`, `lib.rs`, `main.*`,
///   `index.*`) get `ROLE_BOOST_DIVERSE`;
/// - handler files (`handlers.rs`, `*_handler.rs`, `*_handlers.rs`) get
///   `ROLE_BOOST_HANDLER`;
/// - everything else gets `0.0`.
fn file_structural_role_boost(path: &str) -> f64 {
    let lowered_path = path.to_ascii_lowercase();
    let name = match std::path::Path::new(path)
        .file_name()
        .and_then(|n| n.to_str())
    {
        Some(n) => n.to_ascii_lowercase(),
        None => String::new(),
    };

    let named_like_test =
        name == "tests.rs" || name.ends_with("_test.rs") || name.ends_with("_tests.rs");
    let inside_test_dir = lowered_path
        .split('/')
        .any(|segment| matches!(segment, "tests" | "test"));
    let entry_point = ["mod.rs", "lib.rs", "main.rs"].contains(&name.as_str())
        || name.starts_with("main.")
        || name.starts_with("index.");

    if named_like_test || inside_test_dir || entry_point {
        return ROLE_BOOST_DIVERSE;
    }

    let handler_file = name == "handlers.rs"
        || name.ends_with("_handler.rs")
        || name.ends_with("_handlers.rs");
    if handler_file {
        return ROLE_BOOST_HANDLER;
    }

    0.0
}
222
#[cfg(test)]
mod g5_role_boost_tests {
    use super::file_structural_role_boost;

    /// `lib.rs`, `mod.rs` and `main.*` files receive a positive boost.
    #[test]
    fn entry_point_files_get_boost() {
        for path in ["src/lib.rs", "a/b/mod.rs", "pkg/main.py"] {
            assert!(file_structural_role_boost(path) > 0.0);
        }
    }

    /// Test files, whether recognised by name or by directory, receive a positive boost.
    #[test]
    fn test_files_get_boost() {
        for path in [
            "src/embedding/tests.rs",
            "foo/bar_test.rs",
            "tests/integration.rs",
        ] {
            assert!(file_structural_role_boost(path) > 0.0);
        }
    }

    /// Handler files receive a positive boost.
    #[test]
    fn handler_aggregators_get_boost() {
        for path in ["tools/handlers.rs", "foo_handler.rs"] {
            assert!(file_structural_role_boost(path) > 0.0);
        }
    }

    /// Ordinary source files receive exactly zero.
    #[test]
    fn plain_code_files_get_no_boost() {
        for path in ["src/embedding/duplicates.rs", "src/ranking.rs"] {
            assert_eq!(file_structural_role_boost(path), 0.0);
        }
    }

    /// Whatever the category, the boost stays within `0.0..=0.3`.
    #[test]
    fn boost_stays_bounded() {
        let allowed = 0.0..=0.3;
        for p in ["lib.rs", "tests.rs", "x_handler.rs", "normal.rs"] {
            let b = file_structural_role_boost(p);
            assert!(allowed.contains(&b), "{p} -> {b}");
        }
    }
}