Skip to main content

graphify_analyze/
embedding.rs

//! Graph embedding via simplified Node2Vec random walks (uniform, unbiased walks).
//!
//! Learns low-dimensional vector representations of graph nodes by:
//! 1. Performing random walks from each node
//! 2. Training Skip-gram embeddings with SGD
//! 3. Finding structurally similar node pairs via cosine similarity
7
8use std::collections::HashMap;
9
10use graphify_core::graph::KnowledgeGraph;
11use graphify_core::model::SimilarPair;
12
13/// Compute node embeddings using random walks + Skip-gram.
14///
15/// - `dim`: embedding dimension (default 64)
16/// - `walks_per_node`: random walks starting from each node (default 10)
17/// - `walk_length`: length of each walk (default 40)
18///
19/// Returns a map of node_id → embedding vector.
20pub fn compute_embeddings(
21    graph: &KnowledgeGraph,
22    dim: usize,
23    walks_per_node: usize,
24    walk_length: usize,
25) -> HashMap<String, Vec<f64>> {
26    let ids = graph.node_ids();
27    let n = ids.len();
28    if n == 0 {
29        return HashMap::new();
30    }
31
32    let id_to_idx: HashMap<&str, usize> = ids
33        .iter()
34        .enumerate()
35        .map(|(i, s)| (s.as_str(), i))
36        .collect();
37
38    // Build adjacency list with indices
39    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
40    for (src, tgt, _) in graph.edges_with_endpoints() {
41        if let (Some(&si), Some(&ti)) = (id_to_idx.get(src), id_to_idx.get(tgt)) {
42            adj[si].push(ti);
43            adj[ti].push(si);
44        }
45    }
46
47    // Initialize embeddings randomly (deterministic seed via simple hash)
48    let mut embeddings: Vec<Vec<f64>> = (0..n)
49        .map(|i| {
50            (0..dim)
51                .map(|d| {
52                    let seed = (i as u64)
53                        .wrapping_mul(6364136223846793005)
54                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
55                    ((seed as f64).sin() * 0.1).abs() - 0.05
56                })
57                .collect()
58        })
59        .collect();
60
61    let mut context_vecs: Vec<Vec<f64>> = (0..n)
62        .map(|i| {
63            (0..dim)
64                .map(|d| {
65                    let seed = ((i + n) as u64)
66                        .wrapping_mul(6364136223846793005)
67                        .wrapping_add((d as u64).wrapping_mul(1442695040888963407));
68                    ((seed as f64).cos() * 0.1).abs() - 0.05
69                })
70                .collect()
71        })
72        .collect();
73
74    let window = 5usize;
75    let learning_rate = 0.025;
76
77    // Generate walks and train
78    for walk_num in 0..walks_per_node {
79        for start in 0..n {
80            // Random walk
81            let walk = random_walk(&adj, start, walk_length, walk_num);
82
83            // Skip-gram: for each position, update with context window
84            for (pos, &center) in walk.iter().enumerate() {
85                let ctx_start = pos.saturating_sub(window);
86                let ctx_end = (pos + window + 1).min(walk.len());
87                for (ctx_pos, &context) in walk[ctx_start..ctx_end].iter().enumerate() {
88                    let actual_pos = ctx_start + ctx_pos;
89                    if actual_pos == pos {
90                        continue;
91                    }
92                    // Simplified SGD update (no negative sampling for speed)
93                    let dot: f64 = embeddings[center]
94                        .iter()
95                        .zip(context_vecs[context].iter())
96                        .map(|(a, b)| a * b)
97                        .sum();
98                    let sigmoid = 1.0 / (1.0 + (-dot).exp());
99                    let err = 1.0 - sigmoid; // target = 1 for positive pair
100                    let lr = learning_rate * err;
101
102                    for d in 0..dim {
103                        let grad_e = lr * context_vecs[context][d];
104                        let grad_c = lr * embeddings[center][d];
105                        embeddings[center][d] += grad_e;
106                        context_vecs[context][d] += grad_c;
107                    }
108                }
109            }
110        }
111    }
112
113    // Build result map
114    ids.into_iter()
115        .enumerate()
116        .map(|(i, id)| (id, embeddings[i].clone()))
117        .collect()
118}
119
120/// Find top-N most similar node pairs by cosine similarity of embeddings.
121pub fn find_similar(
122    graph: &KnowledgeGraph,
123    embeddings: &HashMap<String, Vec<f64>>,
124    top_n: usize,
125) -> Vec<SimilarPair> {
126    let ids: Vec<&String> = embeddings.keys().collect();
127    let n = ids.len();
128    if n < 2 {
129        return Vec::new();
130    }
131
132    // Pre-compute norms
133    let norms: HashMap<&String, f64> = ids
134        .iter()
135        .map(|&id| {
136            let norm = embeddings[id]
137                .iter()
138                .map(|x| x * x)
139                .sum::<f64>()
140                .sqrt()
141                .max(1e-10);
142            (id, norm)
143        })
144        .collect();
145
146    let mut pairs: Vec<SimilarPair> = Vec::new();
147
148    // Compare all pairs (for large graphs, could use LSH)
149    let limit = n.min(500); // Cap to avoid O(n²) explosion on large graphs
150    for i in 0..limit {
151        for j in (i + 1)..limit {
152            let id_a = ids[i];
153            let id_b = ids[j];
154            let emb_a = &embeddings[id_a];
155            let emb_b = &embeddings[id_b];
156
157            let dot: f64 = emb_a.iter().zip(emb_b.iter()).map(|(a, b)| a * b).sum();
158            let sim = dot / (norms[id_a] * norms[id_b]);
159
160            if sim > 0.5 {
161                let label_a = graph
162                    .get_node(id_a)
163                    .map(|n| n.label.clone())
164                    .unwrap_or_default();
165                let label_b = graph
166                    .get_node(id_b)
167                    .map(|n| n.label.clone())
168                    .unwrap_or_default();
169                pairs.push(SimilarPair {
170                    node_a: id_a.clone(),
171                    node_b: id_b.clone(),
172                    similarity: sim,
173                    label_a,
174                    label_b,
175                });
176            }
177        }
178    }
179
180    pairs.sort_by(|a, b| {
181        b.similarity
182            .partial_cmp(&a.similarity)
183            .unwrap_or(std::cmp::Ordering::Equal)
184    });
185    pairs.truncate(top_n);
186    pairs
187}
188
/// Deterministic pseudo-random walk of at most `length` nodes from `start`.
///
/// The walk begins at `start` and at each step moves to a uniformly chosen
/// neighbour from `adj`; it stops early at a node with no neighbours. A
/// `length` of 0 yields an empty walk. The same (`start`, `length`, `seed`)
/// always produces the same walk.
///
/// Indexes `adj` with `start` (when `length >= 2`) and with neighbour ids
/// stored in `adj`, so those must be valid indices or this panics.
fn random_walk(adj: &[Vec<usize>], start: usize, length: usize, seed: usize) -> Vec<usize> {
    let mut walk = Vec::with_capacity(length);
    if length == 0 {
        return walk;
    }

    // u64 state: the 64-bit LCG constants do not fit in a 32-bit usize.
    let mut rng_state =
        (start as u64).wrapping_mul(2654435761) ^ (seed as u64).wrapping_mul(1103515245);
    let mut current = start;
    walk.push(current);

    while walk.len() < length {
        let neighbors = &adj[current];
        if neighbors.is_empty() {
            break;
        }
        rng_state = rng_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // The low bits of a power-of-two-modulus LCG are very weak (bit 0 just
        // alternates), so derive the choice from the high 32 bits instead.
        let idx = ((rng_state >> 32) as usize) % neighbors.len();
        current = neighbors[idx];
        walk.push(current);
    }
    walk
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};

    /// Builds the 4-cycle a–b–c–d–a (labels equal ids) shared by the tests.
    fn make_graph() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        for id in &["a", "b", "c", "d"] {
            kg.add_node(GraphNode {
                id: id.to_string(),
                label: id.to_string(),
                source_file: "test.rs".into(),
                source_location: None,
                node_type: NodeType::Function,
                community: None,
                extra: Default::default(),
            })
            .unwrap();
        }
        for (s, t) in &[("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")] {
            kg.add_edge(GraphEdge {
                source: s.to_string(),
                target: t.to_string(),
                relation: "calls".into(),
                confidence: Confidence::Extracted,
                confidence_score: 1.0,
                source_file: "test.rs".into(),
                source_location: None,
                weight: 1.0,
                extra: Default::default(),
            })
            .unwrap();
        }
        kg
    }

    #[test]
    fn compute_embeddings_produces_correct_dims() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 16, 5, 10);
        assert_eq!(embs.len(), 4);
        for vec in embs.values() {
            assert_eq!(vec.len(), 16);
        }
    }

    #[test]
    fn compute_embeddings_is_deterministic() {
        let kg = make_graph();
        assert_eq!(
            compute_embeddings(&kg, 8, 3, 10),
            compute_embeddings(&kg, 8, 3, 10)
        );
    }

    #[test]
    fn find_similar_returns_sorted_bounded_pairs() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 16, 10, 20);
        let pairs = find_similar(&kg, &embs, 5);
        assert!(!pairs.is_empty(), "a 4-cycle should yield at least one similar pair");
        assert!(pairs.len() <= 5);
        for pair in &pairs {
            assert!(pair.similarity > 0.5 && pair.similarity <= 1.0 + 1e-9);
            assert_ne!(pair.node_a, pair.node_b);
            // make_graph uses the id as the label.
            assert_eq!(pair.label_a, pair.node_a);
            assert_eq!(pair.label_b, pair.node_b);
        }
        assert!(pairs
            .windows(2)
            .all(|w| w[0].similarity >= w[1].similarity));
    }

    #[test]
    fn find_similar_with_zero_top_n_is_empty() {
        let kg = make_graph();
        let embs = compute_embeddings(&kg, 8, 2, 5);
        assert!(find_similar(&kg, &embs, 0).is_empty());
    }

    #[test]
    fn random_walk_follows_edges() {
        let adj: Vec<Vec<usize>> = vec![vec![1, 2], vec![0], vec![0], vec![]];
        let walk = random_walk(&adj, 0, 12, 7);
        assert_eq!(walk.len(), 12);
        assert_eq!(walk[0], 0);
        assert!(walk.windows(2).all(|w| adj[w[0]].contains(&w[1])));
        // An isolated node cannot move anywhere.
        assert_eq!(random_walk(&adj, 3, 12, 7), vec![3]);
    }

    #[test]
    fn empty_graph_embeddings() {
        let kg = KnowledgeGraph::new();
        let embs = compute_embeddings(&kg, 16, 5, 10);
        assert!(embs.is_empty());
        assert!(find_similar(&kg, &embs, 5).is_empty());
    }
}