// graphify_analyze/embedding.rs
use std::collections::HashMap;

use graphify_core::graph::KnowledgeGraph;
use graphify_core::model::SimilarPair;

13pub fn compute_embeddings(
21 graph: &KnowledgeGraph,
22 dim: usize,
23 walks_per_node: usize,
24 walk_length: usize,
25) -> HashMap<String, Vec<f64>> {
26 let ids = graph.node_ids();
27 let n = ids.len();
28 if n == 0 {
29 return HashMap::new();
30 }
31
32 let id_to_idx: HashMap<&str, usize> = ids
33 .iter()
34 .enumerate()
35 .map(|(i, s)| (s.as_str(), i))
36 .collect();
37
38 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
39 for (src, tgt, _) in graph.edges_with_endpoints() {
40 if let (Some(&si), Some(&ti)) = (id_to_idx.get(src), id_to_idx.get(tgt)) {
41 adj[si].push(ti);
42 adj[ti].push(si);
43 }
44 }
45
46 let mut embeddings: Vec<Vec<f64>> = (0..n)
47 .map(|i| {
48 (0..dim)
49 .map(|d| {
50 let seed = (i as u64)
51 .wrapping_mul(6364136223846793005)
52 .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
53 ((seed as f64).sin() * 0.1).abs() - 0.05
54 })
55 .collect()
56 })
57 .collect();
58
59 let mut context_vecs: Vec<Vec<f64>> = (0..n)
60 .map(|i| {
61 (0..dim)
62 .map(|d| {
63 let seed = ((i + n) as u64)
64 .wrapping_mul(6364136223846793005)
65 .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
66 ((seed as f64).cos() * 0.1).abs() - 0.05
67 })
68 .collect()
69 })
70 .collect();
71
72 let window = 5usize;
73 let learning_rate = 0.025;
74
75 for walk_num in 0..walks_per_node {
76 for start in 0..n {
77 let walk = random_walk(&adj, start, walk_length, walk_num);
78
79 for (pos, ¢er) in walk.iter().enumerate() {
80 let ctx_start = pos.saturating_sub(window);
81 let ctx_end = (pos + window + 1).min(walk.len());
82 for (ctx_pos, &context) in walk[ctx_start..ctx_end].iter().enumerate() {
83 let actual_pos = ctx_start + ctx_pos;
84 if actual_pos == pos {
85 continue;
86 }
87 let dot: f64 = embeddings[center]
88 .iter()
89 .zip(context_vecs[context].iter())
90 .map(|(a, b)| a * b)
91 .sum();
92 let sigmoid = 1.0 / (1.0 + (-dot).exp());
93 let err = 1.0 - sigmoid; let lr = learning_rate * err;
95
96 for d in 0..dim {
97 let grad_e = lr * context_vecs[context][d];
98 let grad_c = lr * embeddings[center][d];
99 embeddings[center][d] += grad_e;
100 context_vecs[context][d] += grad_c;
101 }
102 }
103 }
104 }
105 }
106
107 ids.into_iter()
108 .enumerate()
109 .map(|(i, id)| (id, embeddings[i].clone()))
110 .collect()
111}
112
113pub fn find_similar(
115 graph: &KnowledgeGraph,
116 embeddings: &HashMap<String, Vec<f64>>,
117 top_n: usize,
118) -> Vec<SimilarPair> {
119 let ids: Vec<&String> = embeddings.keys().collect();
120 let n = ids.len();
121 if n < 2 {
122 return Vec::new();
123 }
124
125 let norms: HashMap<&String, f64> = ids
126 .iter()
127 .map(|&id| {
128 let norm = embeddings[id]
129 .iter()
130 .map(|x| x * x)
131 .sum::<f64>()
132 .sqrt()
133 .max(1e-10);
134 (id, norm)
135 })
136 .collect();
137
138 let mut pairs: Vec<SimilarPair> = Vec::new();
139
140 let limit = n.min(500); for i in 0..limit {
142 for j in (i + 1)..limit {
143 let id_a = ids[i];
144 let id_b = ids[j];
145 let emb_a = &embeddings[id_a];
146 let emb_b = &embeddings[id_b];
147
148 let dot: f64 = emb_a.iter().zip(emb_b.iter()).map(|(a, b)| a * b).sum();
149 let sim = dot / (norms[id_a] * norms[id_b]);
150
151 if sim > 0.5 {
152 let label_a = graph
153 .get_node(id_a)
154 .map(|n| n.label.clone())
155 .unwrap_or_default();
156 let label_b = graph
157 .get_node(id_b)
158 .map(|n| n.label.clone())
159 .unwrap_or_default();
160 pairs.push(SimilarPair {
161 node_a: id_a.clone(),
162 node_b: id_b.clone(),
163 similarity: sim,
164 label_a,
165 label_b,
166 });
167 }
168 }
169 }
170
171 pairs.sort_by(|a, b| {
172 b.similarity
173 .partial_cmp(&a.similarity)
174 .unwrap_or(std::cmp::Ordering::Equal)
175 });
176 pairs.truncate(top_n);
177 pairs
178}
179
/// Produces one reproducible pseudo-random walk over `adj`, beginning at `start`.
///
/// `adj[i]` lists the neighbours of node `i`. The walk always contains `start` as its
/// first element (even when `length` is 0) and holds at most `length.max(1)` nodes; it
/// ends early when it reaches a node without neighbours. The same
/// `(adj, start, length, seed)` always yields the same walk.
///
/// # Panics
///
/// Panics if `start`, or any neighbour index stored in `adj`, is out of bounds for `adj`.
fn random_walk(adj: &[Vec<usize>], start: usize, length: usize, seed: usize) -> Vec<usize> {
    // Multiplier and increment of the 64-bit linear congruential generator.
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let mut walk = Vec::with_capacity(length);
    let mut current = start;
    // Kept as `u64` so the 64-bit constants are valid on every target width.
    let mut rng_state =
        (start as u64).wrapping_mul(2654435761) ^ (seed as u64).wrapping_mul(1103515245);

    walk.push(current);
    for _ in 1..length {
        let neighbors = &adj[current];
        if neighbors.is_empty() {
            break;
        }
        rng_state = rng_state
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(LCG_INCREMENT);
        // The low bits of a power-of-two-modulus LCG cycle with tiny periods (the lowest
        // bit simply alternates), so `state % len` would pick neighbours in a fixed
        // pattern. Draw the index from the high bits instead. After the shift the value
        // fits in 31 bits, so the cast is lossless.
        let idx = ((rng_state >> 33) as usize) % neighbors.len();
        current = neighbors[idx];
        walk.push(current);
    }
    walk
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// Embedding width requested in every test.
    const DIM: usize = 16;
    /// Node ids of the fixture graph.
    const NODE_IDS: [&str; 4] = ["a", "b", "c", "d"];
    /// Edges of the fixture graph: the four nodes form a single cycle.
    const EDGES: [(&str, &str); 4] = [("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")];

    /// Builds a function node whose id and label are both `name`.
    fn function_node(name: &str) -> GraphNode {
        GraphNode {
            id: name.to_string(),
            label: name.to_string(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Function,
            community: None,
            extra: Default::default(),
        }
    }

    /// Builds an extracted "calls" edge from `from` to `to`.
    fn calls_edge(from: &str, to: &str) -> GraphEdge {
        GraphEdge {
            source: from.to_string(),
            target: to.to_string(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: Default::default(),
        }
    }

    /// Assembles the four-node fixture graph used by the tests below.
    fn make_graph() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        for name in NODE_IDS.iter().copied() {
            kg.add_node(function_node(name)).unwrap();
        }
        for (from, to) in EDGES.iter().copied() {
            kg.add_edge(calls_edge(from, to)).unwrap();
        }
        kg
    }

    #[test]
    fn compute_embeddings_produces_correct_dims() {
        let embs = compute_embeddings(&make_graph(), DIM, 5, 10);
        assert_eq!(embs.len(), 4);
        assert!(embs.values().all(|v| v.len() == DIM));
    }

    #[test]
    fn find_similar_returns_pairs() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, DIM, 10, 20);
        let pairs = find_similar(&kg, &embs, 5);
        // With at least two embeddings, at least one similar pair is expected.
        assert!(embs.len() < 2 || !pairs.is_empty());
    }

    #[test]
    fn empty_graph_embeddings() {
        let embs = compute_embeddings(&KnowledgeGraph::new(), DIM, 5, 10);
        assert!(embs.is_empty());
    }
}
263}