graphify_analyze/
embedding.rs1use std::collections::HashMap;
9
10use graphify_core::graph::KnowledgeGraph;
11use graphify_core::model::SimilarPair;
12
13pub fn compute_embeddings(
21 graph: &KnowledgeGraph,
22 dim: usize,
23 walks_per_node: usize,
24 walk_length: usize,
25) -> HashMap<String, Vec<f64>> {
26 let ids = graph.node_ids();
27 let n = ids.len();
28 if n == 0 {
29 return HashMap::new();
30 }
31
32 let id_to_idx: HashMap<&str, usize> = ids
33 .iter()
34 .enumerate()
35 .map(|(i, s)| (s.as_str(), i))
36 .collect();
37
38 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
40 for (src, tgt, _) in graph.edges_with_endpoints() {
41 if let (Some(&si), Some(&ti)) = (id_to_idx.get(src), id_to_idx.get(tgt)) {
42 adj[si].push(ti);
43 adj[ti].push(si);
44 }
45 }
46
47 let mut embeddings: Vec<Vec<f64>> = (0..n)
49 .map(|i| {
50 (0..dim)
51 .map(|d| {
52 let seed = (i as u64)
53 .wrapping_mul(6364136223846793005)
54 .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
55 ((seed as f64).sin() * 0.1).abs() - 0.05
56 })
57 .collect()
58 })
59 .collect();
60
61 let mut context_vecs: Vec<Vec<f64>> = (0..n)
62 .map(|i| {
63 (0..dim)
64 .map(|d| {
65 let seed = ((i + n) as u64)
66 .wrapping_mul(6364136223846793005)
67 .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
68 ((seed as f64).cos() * 0.1).abs() - 0.05
69 })
70 .collect()
71 })
72 .collect();
73
74 let window = 5usize;
75 let learning_rate = 0.025;
76
77 for walk_num in 0..walks_per_node {
79 for start in 0..n {
80 let walk = random_walk(&adj, start, walk_length, walk_num);
82
83 for (pos, ¢er) in walk.iter().enumerate() {
85 let ctx_start = pos.saturating_sub(window);
86 let ctx_end = (pos + window + 1).min(walk.len());
87 for ctx_pos in ctx_start..ctx_end {
88 if ctx_pos == pos {
89 continue;
90 }
91 let context = walk[ctx_pos];
92 let dot: f64 = embeddings[center]
94 .iter()
95 .zip(context_vecs[context].iter())
96 .map(|(a, b)| a * b)
97 .sum();
98 let sigmoid = 1.0 / (1.0 + (-dot).exp());
99 let err = 1.0 - sigmoid; let lr = learning_rate * err;
101
102 for d in 0..dim {
103 let grad_e = lr * context_vecs[context][d];
104 let grad_c = lr * embeddings[center][d];
105 embeddings[center][d] += grad_e;
106 context_vecs[context][d] += grad_c;
107 }
108 }
109 }
110 }
111 }
112
113 ids.into_iter()
115 .enumerate()
116 .map(|(i, id)| (id, embeddings[i].clone()))
117 .collect()
118}
119
120pub fn find_similar(
122 graph: &KnowledgeGraph,
123 embeddings: &HashMap<String, Vec<f64>>,
124 top_n: usize,
125) -> Vec<SimilarPair> {
126 let ids: Vec<&String> = embeddings.keys().collect();
127 let n = ids.len();
128 if n < 2 {
129 return Vec::new();
130 }
131
132 let norms: HashMap<&String, f64> = ids
134 .iter()
135 .map(|&id| {
136 let norm = embeddings[id]
137 .iter()
138 .map(|x| x * x)
139 .sum::<f64>()
140 .sqrt()
141 .max(1e-10);
142 (id, norm)
143 })
144 .collect();
145
146 let mut pairs: Vec<SimilarPair> = Vec::new();
147
148 let limit = n.min(500); for i in 0..limit {
151 for j in (i + 1)..limit {
152 let id_a = ids[i];
153 let id_b = ids[j];
154 let emb_a = &embeddings[id_a];
155 let emb_b = &embeddings[id_b];
156
157 let dot: f64 = emb_a.iter().zip(emb_b.iter()).map(|(a, b)| a * b).sum();
158 let sim = dot / (norms[id_a] * norms[id_b]);
159
160 if sim > 0.5 {
161 let label_a = graph
162 .get_node(id_a)
163 .map(|n| n.label.clone())
164 .unwrap_or_default();
165 let label_b = graph
166 .get_node(id_b)
167 .map(|n| n.label.clone())
168 .unwrap_or_default();
169 pairs.push(SimilarPair {
170 node_a: id_a.clone(),
171 node_b: id_b.clone(),
172 similarity: sim,
173 label_a,
174 label_b,
175 });
176 }
177 }
178 }
179
180 pairs.sort_by(|a, b| {
181 b.similarity
182 .partial_cmp(&a.similarity)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 });
185 pairs.truncate(top_n);
186 pairs
187}
188
/// Generates a deterministic pseudo-random walk over `adj`, starting at `start`.
///
/// At each step the next node is chosen uniformly (via an LCG) from the current
/// node's neighbour list.
///
/// # Arguments
/// * `adj` - adjacency lists indexed by node; `start` and every listed neighbour must
///   be valid indices into `adj` (otherwise this panics on indexing).
/// * `start` - first node of the walk.
/// * `length` - maximum number of nodes in the walk.
/// * `seed` - walk number; combined with `start` to seed the generator.
///
/// # Returns
/// At most `length` nodes, beginning with `start` (empty when `length == 0`). The walk
/// stops early at a node with no neighbours.
fn random_walk(adj: &[Vec<usize>], start: usize, length: usize, seed: usize) -> Vec<usize> {
    if length == 0 {
        return Vec::new();
    }

    let mut walk = Vec::with_capacity(length);
    let mut current = start;
    // u64 state so that the 64-bit LCG constants compile and behave identically on
    // 32-bit targets (where they would overflow `usize`).
    let mut rng_state =
        (start as u64).wrapping_mul(2654435761) ^ (seed as u64).wrapping_mul(1103515245);

    walk.push(current);
    for _ in 1..length {
        let neighbors = &adj[current];
        if neighbors.is_empty() {
            break;
        }
        rng_state = rng_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Take the high bits. Bit 0 of this LCG simply alternates every step, so
        // `state % 2` would make degree-2 choices depend only on step parity.
        let idx = ((rng_state >> 33) % neighbors.len() as u64) as usize;
        current = neighbors[idx];
        walk.push(current);
    }
    walk
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// Builds the 4-cycle a-b-c-d-a, connected by `calls` edges.
    fn make_graph() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        for id in &["a", "b", "c", "d"] {
            kg.add_node(GraphNode {
                id: id.to_string(),
                label: id.to_string(),
                source_file: "test.rs".into(),
                source_location: None,
                node_type: NodeType::Function,
                community: None,
                extra: Default::default(),
            })
            .unwrap();
        }
        for (s, t) in &[("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")] {
            kg.add_edge(GraphEdge {
                source: s.to_string(),
                target: t.to_string(),
                relation: "calls".into(),
                confidence: Confidence::Extracted,
                confidence_score: 1.0,
                source_file: "test.rs".into(),
                source_location: None,
                weight: 1.0,
                extra: Default::default(),
            })
            .unwrap();
        }
        kg
    }

    #[test]
    fn compute_embeddings_produces_correct_dims() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 16, 5, 10);
        assert_eq!(embs.len(), 4);
        for vec in embs.values() {
            assert_eq!(vec.len(), 16);
        }
    }

    #[test]
    fn compute_embeddings_is_deterministic() {
        let kg = make_graph();
        assert_eq!(
            compute_embeddings(&kg, 8, 3, 10),
            compute_embeddings(&kg, 8, 3, 10)
        );
    }

    #[test]
    fn find_similar_returns_pairs() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 16, 10, 20);
        let pairs = find_similar(&kg, &embs, 5);
        // The fixture always has 4 nodes, so at least one pair is expected.
        assert!(!pairs.is_empty());
        assert!(pairs.len() <= 5);
        for pair in &pairs {
            assert_ne!(pair.node_a, pair.node_b);
            assert!(pair.similarity > 0.5 && pair.similarity <= 1.0 + 1e-9);
        }
        assert!(pairs.windows(2).all(|w| w[0].similarity >= w[1].similarity));
    }

    #[test]
    fn random_walk_follows_edges() {
        let adj = vec![vec![1, 3], vec![0, 2], vec![1, 3], vec![2, 0]];
        for seed in 0..5 {
            let walk = random_walk(&adj, 0, 12, seed);
            assert_eq!(walk.len(), 12);
            assert_eq!(walk[0], 0);
            assert!(walk.windows(2).all(|w| adj[w[0]].contains(&w[1])));
        }
    }

    #[test]
    fn random_walk_stops_at_isolated_node() {
        let adj: Vec<Vec<usize>> = vec![Vec::new()];
        assert_eq!(random_walk(&adj, 0, 10, 0), vec![0]);
    }

    #[test]
    fn empty_graph_embeddings() {
        let kg = KnowledgeGraph::new();
        let embs = compute_embeddings(&kg, 16, 5, 10);
        assert!(embs.is_empty());
    }
}