1use crate::graph::CodeGraph;
8
/// Norm threshold under which a vector is treated as zero-length.
///
/// Such vectors carry no usable direction for cosine similarity, so the index
/// refuses to store them and `search` refuses them as queries.
const NORM_EPSILON: f32 = 1.0e-10;
11
/// Brute-force cosine-similarity index over code-unit feature vectors.
///
/// Holds a copy of each indexed unit's feature vector together with the unit id;
/// `search` compares a query against every stored entry.
#[derive(Clone, Debug)]
pub struct EmbeddingIndex {
    /// `(unit_id, feature_vector)` pairs, in the order `build` visited the units.
    entries: Vec<(u64, Vec<f32>)>,
    /// Vector length that `build` requires of stored vectors and `search` requires of queries.
    dimension: usize,
}
23
/// A single hit returned by `EmbeddingIndex::search`.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingMatch {
    /// Id of the matching code unit.
    pub unit_id: u64,
    /// Cosine similarity between the query and the unit's feature vector.
    pub score: f32,
}
32
33impl EmbeddingIndex {
34 pub fn build(graph: &CodeGraph) -> Self {
39 let dimension = graph.dimension();
40 let mut entries = Vec::with_capacity(graph.unit_count());
41
42 for unit in graph.units() {
43 if unit.feature_vec.len() == dimension {
44 let norm = vec_norm(&unit.feature_vec);
45 if norm > NORM_EPSILON {
46 entries.push((unit.id, unit.feature_vec.clone()));
47 }
48 }
49 }
50
51 Self { entries, dimension }
52 }
53
54 pub fn search(&self, query: &[f32], top_k: usize, min_similarity: f32) -> Vec<EmbeddingMatch> {
62 if query.len() != self.dimension {
63 return Vec::new();
64 }
65
66 let query_norm = vec_norm(query);
67 if query_norm < NORM_EPSILON {
68 return Vec::new();
69 }
70
71 let mut results: Vec<EmbeddingMatch> = self
72 .entries
73 .iter()
74 .filter_map(|(id, vec)| {
75 let score = cosine_similarity(query, vec, query_norm);
76 if score >= min_similarity {
77 Some(EmbeddingMatch {
78 unit_id: *id,
79 score,
80 })
81 } else {
82 None
83 }
84 })
85 .collect();
86
87 results.sort_by(|a, b| {
89 b.score
90 .partial_cmp(&a.score)
91 .unwrap_or(std::cmp::Ordering::Equal)
92 .then_with(|| a.unit_id.cmp(&b.unit_id))
93 });
94
95 results.truncate(top_k);
96 results
97 }
98
99 pub fn dimension(&self) -> usize {
101 self.dimension
102 }
103
104 pub fn len(&self) -> usize {
106 self.entries.len()
107 }
108
109 pub fn is_empty(&self) -> bool {
111 self.entries.is_empty()
112 }
113}
114
115impl Default for EmbeddingIndex {
116 fn default() -> Self {
117 Self {
118 entries: Vec::new(),
119 dimension: crate::types::DEFAULT_DIMENSION,
120 }
121 }
122}
123
/// Euclidean (L2) length of `v`; zero for an empty slice.
fn vec_norm(v: &[f32]) -> f32 {
    v.iter()
        .map(|&component| component * component)
        .sum::<f32>()
        .sqrt()
}
129
130fn cosine_similarity(a: &[f32], b: &[f32], a_norm: f32) -> f32 {
133 let b_norm = vec_norm(b);
134 if b_norm < NORM_EPSILON || a_norm < NORM_EPSILON {
135 return 0.0;
136 }
137
138 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
139 dot / (a_norm * b_norm)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::CodeGraph;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    // NOTE(review): the unit ids asserted below assume `CodeGraph::add_unit` numbers
    // units 0, 1, 2, ... in insertion order.

    /// Creates a placeholder Rust function unit carrying `feature_vec`.
    fn make_unit_with_vec(feature_vec: Vec<f32>) -> CodeUnit {
        let mut unit = CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "test_fn".to_string(),
            "mod::test_fn".to_string(),
            PathBuf::from("src/lib.rs"),
            Span::new(1, 0, 10, 0),
        );
        unit.feature_vec = feature_vec;
        unit
    }

    #[test]
    fn test_empty_index() {
        let graph = CodeGraph::default();
        let index = EmbeddingIndex::build(&graph);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.dimension(), 256);
    }

    #[test]
    fn test_zero_vectors_excluded() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        // Zero vector: no direction, must be skipped by `build`.
        graph.add_unit(make_unit_with_vec(vec![0.0; dim]));
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_search_identical_vector() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.0, 1.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);

        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 10, 0.0);
        assert_eq!(results.len(), 2);
        // Identical direction scores 1.0.
        assert_eq!(results[0].unit_id, 0);
        assert!((results[0].score - 1.0).abs() < 1e-6);
        // Orthogonal direction scores 0.0 and still passes a 0.0 threshold.
        assert_eq!(results[1].unit_id, 1);
        assert!(results[1].score.abs() < 1e-6);
    }

    #[test]
    fn test_search_top_k() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.9, 0.1, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.5, 0.5, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 2, 0.0);
        assert_eq!(results.len(), 2);
        // The two vectors closest to the query, best first.
        assert_eq!(results[0].unit_id, 0);
        assert_eq!(results[1].unit_id, 1);
        assert!(results[0].score >= results[1].score);
    }

    #[test]
    fn test_search_top_k_zero() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 0, 0.0).is_empty());
    }

    #[test]
    fn test_search_ties_ordered_by_unit_id() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.0, 2.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.0, 1.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.0, 3.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        let results = index.search(&[0.0, 1.0, 0.0, 0.0], 2, 0.0);
        let ids: Vec<u64> = results.iter().map(|m| m.unit_id).collect();
        // Units 1..=3 all score exactly 1.0; the tie-break keeps the lowest ids.
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn test_search_min_similarity() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));
        graph.add_unit(make_unit_with_vec(vec![0.0, 1.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], 10, 0.5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].unit_id, 0);
    }

    #[test]
    fn test_search_wrong_dimension() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        let results = index.search(&[1.0, 0.0], 10, 0.0);
        // A 2-component query against a 4-dimensional index yields nothing.
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_zero_query() {
        let dim = 4;
        let mut graph = CodeGraph::new(dim);
        graph.add_unit(make_unit_with_vec(vec![1.0, 0.0, 0.0, 0.0]));

        let index = EmbeddingIndex::build(&graph);
        let results = index.search(&[0.0; 4], 10, 0.0);
        assert!(results.is_empty());
    }
}
252}