Skip to main content

graphify_analyze/
embedding.rs

//! Graph embedding via simplified Node2Vec random walks.
//!
//! Learns low-dimensional vector representations of graph nodes by:
//! 1. Performing random walks from each node
//! 2. Training Skip-gram embeddings with SGD
//! 3. Finding structurally similar node pairs via cosine similarity
7
8use std::collections::HashMap;
9
10use graphify_core::graph::KnowledgeGraph;
11use graphify_core::model::SimilarPair;
12
13/// Compute node embeddings using random walks + Skip-gram.
14///
15/// - `dim`: embedding dimension (default 64)
16/// - `walks_per_node`: random walks starting from each node (default 10)
17/// - `walk_length`: length of each walk (default 40)
18///
19/// Returns a map of node_id → embedding vector.
20pub fn compute_embeddings(
21    graph: &KnowledgeGraph,
22    dim: usize,
23    walks_per_node: usize,
24    walk_length: usize,
25) -> HashMap<String, Vec<f64>> {
26    let ids = graph.node_ids();
27    let n = ids.len();
28    if n == 0 {
29        return HashMap::new();
30    }
31
32    let id_to_idx: HashMap<&str, usize> = ids
33        .iter()
34        .enumerate()
35        .map(|(i, s)| (s.as_str(), i))
36        .collect();
37
38    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
39    for (src, tgt, _) in graph.edges_with_endpoints() {
40        if let (Some(&si), Some(&ti)) = (id_to_idx.get(src), id_to_idx.get(tgt)) {
41            adj[si].push(ti);
42            adj[ti].push(si);
43        }
44    }
45
46    let mut embeddings: Vec<Vec<f64>> = (0..n)
47        .map(|i| {
48            (0..dim)
49                .map(|d| {
50                    let seed = (i as u64)
51                        .wrapping_mul(6364136223846793005)
52                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
53                    ((seed as f64).sin() * 0.1).abs() - 0.05
54                })
55                .collect()
56        })
57        .collect();
58
59    let mut context_vecs: Vec<Vec<f64>> = (0..n)
60        .map(|i| {
61            (0..dim)
62                .map(|d| {
63                    let seed = ((i + n) as u64)
64                        .wrapping_mul(6364136223846793005)
65                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
66                    ((seed as f64).cos() * 0.1).abs() - 0.05
67                })
68                .collect()
69        })
70        .collect();
71
72    let window = 5usize;
73    let learning_rate = 0.025;
74
75    for walk_num in 0..walks_per_node {
76        for start in 0..n {
77            let walk = random_walk(&adj, start, walk_length, walk_num);
78
79            for (pos, &center) in walk.iter().enumerate() {
80                let ctx_start = pos.saturating_sub(window);
81                let ctx_end = (pos + window + 1).min(walk.len());
82                for (ctx_pos, &context) in walk[ctx_start..ctx_end].iter().enumerate() {
83                    let actual_pos = ctx_start + ctx_pos;
84                    if actual_pos == pos {
85                        continue;
86                    }
87                    let dot: f64 = embeddings[center]
88                        .iter()
89                        .zip(context_vecs[context].iter())
90                        .map(|(a, b)| a * b)
91                        .sum();
92                    let sigmoid = 1.0 / (1.0 + (-dot).exp());
93                    let err = 1.0 - sigmoid; // target = 1 for positive pair
94                    let lr = learning_rate * err;
95
96                    for d in 0..dim {
97                        let grad_e = lr * context_vecs[context][d];
98                        let grad_c = lr * embeddings[center][d];
99                        embeddings[center][d] += grad_e;
100                        context_vecs[context][d] += grad_c;
101                    }
102                }
103            }
104        }
105    }
106
107    ids.into_iter()
108        .enumerate()
109        .map(|(i, id)| (id, embeddings[i].clone()))
110        .collect()
111}
112
113/// Find top-N most similar node pairs by cosine similarity of embeddings.
114pub fn find_similar(
115    graph: &KnowledgeGraph,
116    embeddings: &HashMap<String, Vec<f64>>,
117    top_n: usize,
118) -> Vec<SimilarPair> {
119    let ids: Vec<&String> = embeddings.keys().collect();
120    let n = ids.len();
121    if n < 2 {
122        return Vec::new();
123    }
124
125    let norms: HashMap<&String, f64> = ids
126        .iter()
127        .map(|&id| {
128            let norm = embeddings[id]
129                .iter()
130                .map(|x| x * x)
131                .sum::<f64>()
132                .sqrt()
133                .max(1e-10);
134            (id, norm)
135        })
136        .collect();
137
138    let mut pairs: Vec<SimilarPair> = Vec::new();
139
140    let limit = n.min(500); // Cap to avoid O(n²) explosion on large graphs
141    for i in 0..limit {
142        for j in (i + 1)..limit {
143            let id_a = ids[i];
144            let id_b = ids[j];
145            let emb_a = &embeddings[id_a];
146            let emb_b = &embeddings[id_b];
147
148            let dot: f64 = emb_a.iter().zip(emb_b.iter()).map(|(a, b)| a * b).sum();
149            let sim = dot / (norms[id_a] * norms[id_b]);
150
151            if sim > 0.5 {
152                let label_a = graph
153                    .get_node(id_a)
154                    .map(|n| n.label.clone())
155                    .unwrap_or_default();
156                let label_b = graph
157                    .get_node(id_b)
158                    .map(|n| n.label.clone())
159                    .unwrap_or_default();
160                pairs.push(SimilarPair {
161                    node_a: id_a.clone(),
162                    node_b: id_b.clone(),
163                    similarity: sim,
164                    label_a,
165                    label_b,
166                });
167            }
168        }
169    }
170
171    pairs.sort_by(|a, b| {
172        b.similarity
173            .partial_cmp(&a.similarity)
174            .unwrap_or(std::cmp::Ordering::Equal)
175    });
176    pairs.truncate(top_n);
177    pairs
178}
179
/// Deterministic pseudo-random walk from a start node.
///
/// - `adj`: adjacency list; `adj[i]` holds the neighbor indices of node `i`
/// - `start`: index of the first node of the walk
/// - `length`: maximum number of nodes in the returned walk
/// - `seed`: extra seed so that several walks from the same node can differ
///
/// Returns the visited node indices in order, beginning with `start`. The walk is shorter
/// than `length` if it reaches a node without neighbors, and empty when `length` is 0.
/// The same arguments always produce the same walk.
///
/// # Panics
///
/// Panics if a node that has to be expanded (`start` or a listed neighbor) is not a
/// valid index into `adj`.
fn random_walk(adj: &[Vec<usize>], start: usize, length: usize, seed: usize) -> Vec<usize> {
    // 64-bit LCG constants (the ones Knuth uses for MMIX). The state is kept in `u64`
    // so the literals also fit on targets where `usize` is 32 bits.
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    if length == 0 {
        return Vec::new();
    }

    let mut walk = Vec::with_capacity(length);
    let mut current = start;
    let mut rng_state =
        (start as u64).wrapping_mul(2654435761) ^ (seed as u64).wrapping_mul(1103515245);

    walk.push(current);
    for _ in 1..length {
        let neighbors = &adj[current];
        if neighbors.is_empty() {
            // Dead end: nothing to step to, so the walk stops early.
            break;
        }
        rng_state = rng_state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);

        // Choose the neighbor from the HIGH bits of the state. The low bits of a
        // power-of-two-modulus LCG have very short periods (the lowest bit simply
        // alternates), so `state % len` would make a walk on degree-2 nodes bounce
        // between two nodes forever. Multiply-shift maps the top 32 bits onto
        // `0..len`; u128 keeps the product exact for any `len`.
        let high = u128::from(rng_state >> 32);
        let idx = ((high * neighbors.len() as u128) >> 32) as usize;

        current = neighbors[idx];
        walk.push(current);
    }
    walk
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// Function node in `test.rs` whose label is the same as its id.
    fn function_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.to_string(),
            label: id.to_string(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Function,
            community: None,
            extra: Default::default(),
        }
    }

    /// Extracted `calls` edge from `source` to `target` with unit weight and score.
    fn call_edge(source: &str, target: &str) -> GraphEdge {
        GraphEdge {
            source: source.to_string(),
            target: target.to_string(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            provenance: None,
            extra: Default::default(),
        }
    }

    /// Four nodes `a`..`d` connected as a ring: a-b, b-c, c-d, a-d.
    fn make_graph() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();

        let node_ids = ["a", "b", "c", "d"];
        for id in node_ids.iter() {
            kg.add_node(function_node(id)).unwrap();
        }

        let endpoints = [("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")];
        for (source, target) in endpoints.iter() {
            kg.add_edge(call_edge(source, target)).unwrap();
        }

        kg
    }

    /// Every node gets exactly one vector of the requested dimension.
    #[test]
    fn compute_embeddings_produces_correct_dims() {
        let embs = compute_embeddings(&make_graph(), 16, 5, 10);
        assert_eq!(embs.len(), 4);
        assert!(embs.values().all(|vector| vector.len() == 16));
    }

    /// With at least two embedded nodes, some similar pair is reported.
    #[test]
    fn find_similar_returns_pairs() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 16, 10, 20);
        let pairs = find_similar(&kg, &embs, 5);
        assert!(embs.len() < 2 || !pairs.is_empty());
    }

    /// A graph without nodes produces no embeddings.
    #[test]
    fn empty_graph_embeddings() {
        let embs = compute_embeddings(&KnowledgeGraph::new(), 16, 5, 10);
        assert!(embs.is_empty());
    }
}