Skip to main content

graphrag_core/graph/
embeddings.rs

1//! Graph Embeddings
2//!
3//! This module provides graph embedding algorithms for converting graph structures
4//! into dense vector representations:
5//!
6//! - **Node2Vec**: Random walk-based embeddings capturing network neighborhoods
7//! - **GraphSAGE**: Inductive representation learning using neighborhood sampling
8//! - **DeepWalk**: Simplified random walk embeddings
9//! - **Struc2Vec**: Structure-aware graph embeddings
10//!
11//! ## Use Cases
12//!
13//! - Node classification and clustering
14//! - Link prediction
15//! - Graph visualization
16//! - Similarity search in graph space
17//! - Transfer learning across graphs
18
19use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, HashSet};
22
23/// Graph embedding configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EmbeddingConfig {
26    /// Embedding dimension
27    pub dimension: usize,
28    /// Walk length for random walks
29    pub walk_length: usize,
30    /// Number of walks per node
31    pub walks_per_node: usize,
32    /// Context window size (for Skip-Gram)
33    pub context_size: usize,
34    /// Return parameter (Node2Vec p)
35    pub return_param: f32,
36    /// In-out parameter (Node2Vec q)
37    pub inout_param: f32,
38    /// Learning rate
39    pub learning_rate: f32,
40    /// Number of negative samples
41    pub negative_samples: usize,
42    /// Number of training epochs
43    pub epochs: usize,
44}
45
46impl Default for EmbeddingConfig {
47    fn default() -> Self {
48        Self {
49            dimension: 128,
50            walk_length: 80,
51            walks_per_node: 10,
52            context_size: 10,
53            return_param: 1.0,
54            inout_param: 1.0,
55            learning_rate: 0.025,
56            negative_samples: 5,
57            epochs: 10,
58        }
59    }
60}
61
/// Graph for embedding generation
///
/// Edges are stored undirected: each input edge is recorded in the adjacency
/// list of both of its endpoints.
pub struct EmbeddingGraph {
    /// Adjacency list: node_id -> [(neighbor_id, weight)]
    adjacency: HashMap<String, Vec<(String, f32)>>,
    /// All node IDs, in order of first appearance in the edge list
    nodes: Vec<String>,
    /// Node index mapping: node_id -> position in `nodes`
    node_index: HashMap<String, usize>,
}

impl EmbeddingGraph {
    /// Create embedding graph from edge list
    ///
    /// Node indices are assigned deterministically in order of first
    /// appearance (source before target, edge by edge), so the same edge list
    /// always yields the same `get_index` / `get_node` mapping.
    ///
    /// Each edge is inserted in both directions. Parallel edges are kept as
    /// separate adjacency entries, and a self-loop (`source == target`) is
    /// recorded twice in that node's adjacency list.
    ///
    /// # Arguments
    /// * `edges` - List of (source, target, weight) tuples
    pub fn from_edges(edges: Vec<(String, String, f32)>) -> Self {
        let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
        // `seen` answers "is this node new?", `nodes` remembers the order in
        // which new nodes showed up. Collecting straight out of a `HashSet`
        // would make the index assignment depend on hash iteration order.
        let mut seen: HashSet<String> = HashSet::new();
        let mut nodes: Vec<String> = Vec::new();

        for (source, target, weight) in edges {
            for endpoint in [&source, &target] {
                if !seen.contains(endpoint) {
                    seen.insert(endpoint.clone());
                    nodes.push(endpoint.clone());
                }
            }

            adjacency
                .entry(source.clone())
                .or_default()
                .push((target.clone(), weight));

            // Last use of both strings: move them instead of cloning again.
            adjacency.entry(target).or_default().push((source, weight));
        }

        let node_index: HashMap<String, usize> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i))
            .collect();

        Self {
            adjacency,
            nodes,
            node_index,
        }
    }

    /// Get node count (number of distinct node IDs seen in the edge list)
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get neighbors of a node as `(neighbor_id, weight)` pairs
    ///
    /// Returns `None` when `node` did not occur in any edge.
    pub fn neighbors(&self, node: &str) -> Option<&Vec<(String, f32)>> {
        self.adjacency.get(node)
    }

    /// Get node index, or `None` for an unknown node ID
    pub fn get_index(&self, node: &str) -> Option<usize> {
        self.node_index.get(node).copied()
    }

    /// Get node by index, or `None` when `index` is out of range
    pub fn get_node(&self, index: usize) -> Option<&String> {
        self.nodes.get(index)
    }
}
130
131/// Node2Vec embeddings generator
132pub struct Node2Vec {
133    config: EmbeddingConfig,
134    /// Learned embeddings: node_id -> embedding vector
135    embeddings: HashMap<String, Vec<f32>>,
136}
137
138impl Node2Vec {
139    /// Create new Node2Vec generator
140    pub fn new(config: EmbeddingConfig) -> Self {
141        Self {
142            config,
143            embeddings: HashMap::new(),
144        }
145    }
146
147    /// Generate embeddings for graph
148    pub fn fit(&mut self, graph: &EmbeddingGraph) {
149        // Generate random walks
150        let walks = self.generate_walks(graph);
151
152        // Initialize embeddings randomly
153        self.initialize_embeddings(graph);
154
155        // Train Skip-Gram model on walks
156        self.train_skipgram(&walks);
157    }
158
159    /// Generate biased random walks (Node2Vec)
160    fn generate_walks(&self, graph: &EmbeddingGraph) -> Vec<Vec<String>> {
161        let mut rng = rand::thread_rng();
162        let mut walks = Vec::new();
163
164        for _ in 0..self.config.walks_per_node {
165            for node in &graph.nodes {
166                let walk = self.random_walk(graph, node, &mut rng);
167                walks.push(walk);
168            }
169        }
170
171        walks
172    }
173
174    /// Perform single biased random walk from starting node
175    fn random_walk<R: Rng>(&self, graph: &EmbeddingGraph, start: &str, rng: &mut R) -> Vec<String> {
176        let mut walk = vec![start.to_string()];
177
178        for _ in 1..self.config.walk_length {
179            let current = walk.last().unwrap();
180
181            if let Some(neighbors) = graph.neighbors(current) {
182                if neighbors.is_empty() {
183                    break;
184                }
185
186                // Sample next node using biased probabilities
187                let next = if walk.len() == 1 {
188                    // First step: uniform random
189                    &neighbors[rng.gen_range(0..neighbors.len())].0
190                } else {
191                    // Subsequent steps: use Node2Vec bias
192                    let prev = &walk[walk.len() - 2];
193                    self.sample_next(prev, current, neighbors, rng)
194                };
195
196                walk.push(next.clone());
197            } else {
198                break;
199            }
200        }
201
202        walk
203    }
204
205    /// Sample next node with Node2Vec bias (p, q parameters)
206    fn sample_next<'a, R: Rng>(
207        &self,
208        prev: &str,
209        _current: &str,
210        neighbors: &'a [(String, f32)],
211        rng: &mut R,
212    ) -> &'a String {
213        // Calculate transition probabilities based on p and q
214        let mut probs: Vec<f32> = neighbors
215            .iter()
216            .map(|(neighbor, weight)| {
217                let alpha = if neighbor == prev {
218                    // Return to previous node
219                    1.0 / self.config.return_param
220                } else {
221                    // Check if neighbor is also neighbor of prev (BFS vs DFS)
222                    1.0 / self.config.inout_param
223                };
224                weight * alpha
225            })
226            .collect();
227
228        // Normalize probabilities
229        let sum: f32 = probs.iter().sum();
230        if sum > 0.0 {
231            for p in &mut probs {
232                *p /= sum;
233            }
234        }
235
236        // Sample using cumulative distribution
237        let r: f32 = rng.gen();
238        let mut cumsum = 0.0;
239        for (i, &prob) in probs.iter().enumerate() {
240            cumsum += prob;
241            if r <= cumsum {
242                return &neighbors[i].0;
243            }
244        }
245
246        &neighbors[neighbors.len() - 1].0
247    }
248
249    /// Initialize random embeddings
250    fn initialize_embeddings(&mut self, graph: &EmbeddingGraph) {
251        let mut rng = rand::thread_rng();
252
253        for node in &graph.nodes {
254            let embedding: Vec<f32> = (0..self.config.dimension)
255                .map(|_| (rng.gen::<f32>() - 0.5) / self.config.dimension as f32)
256                .collect();
257
258            self.embeddings.insert(node.clone(), embedding);
259        }
260    }
261
262    /// Train Skip-Gram model on walks
263    fn train_skipgram(&mut self, walks: &[Vec<String>]) {
264        for _ in 0..self.config.epochs {
265            for walk in walks {
266                for (i, node) in walk.iter().enumerate() {
267                    // Define context window
268                    let start = i.saturating_sub(self.config.context_size);
269                    let end = (i + self.config.context_size + 1).min(walk.len());
270
271                    for (j, context_node) in walk.iter().enumerate().take(end).skip(start) {
272                        if i != j {
273                            self.update_embeddings(node, context_node);
274                        }
275                    }
276                }
277            }
278        }
279    }
280
281    /// Update embeddings using Skip-Gram objective (simplified)
282    fn update_embeddings(&mut self, target: &str, context: &str) {
283        // Simplified update: move embeddings closer for positive pairs
284        // Real implementation would use negative sampling and gradient descent
285
286        let lr = self.config.learning_rate;
287
288        if let (Some(target_emb), Some(context_emb)) =
289            (self.embeddings.get(target), self.embeddings.get(context))
290        {
291            // Calculate gradient direction (simplified)
292            let mut target_new = target_emb.clone();
293            let mut context_new = context_emb.clone();
294
295            for i in 0..self.config.dimension {
296                let diff = context_emb[i] - target_emb[i];
297                target_new[i] += lr * diff;
298                context_new[i] -= lr * diff;
299            }
300
301            self.embeddings.insert(target.to_string(), target_new);
302            self.embeddings.insert(context.to_string(), context_new);
303        }
304    }
305
306    /// Get embedding for a node
307    pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
308        self.embeddings.get(node)
309    }
310
311    /// Get all embeddings
312    pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
313        &self.embeddings
314    }
315}
316
317/// GraphSAGE configuration
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct GraphSAGEConfig {
320    /// Embedding dimension
321    pub dimension: usize,
322    /// Number of layers
323    pub num_layers: usize,
324    /// Samples per layer
325    pub samples_per_layer: Vec<usize>,
326    /// Aggregation function
327    pub aggregator: Aggregator,
328}
329
330/// Aggregation functions for GraphSAGE
331#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
332pub enum Aggregator {
333    /// Mean aggregation
334    Mean,
335    /// Max pooling
336    MaxPool,
337    /// LSTM aggregation
338    Lstm,
339    /// Attention-based
340    Attention,
341}
342
343impl Default for GraphSAGEConfig {
344    fn default() -> Self {
345        Self {
346            dimension: 128,
347            num_layers: 2,
348            samples_per_layer: vec![25, 10],
349            aggregator: Aggregator::Mean,
350        }
351    }
352}
353
354/// GraphSAGE embeddings generator
355pub struct GraphSAGE {
356    config: GraphSAGEConfig,
357    embeddings: HashMap<String, Vec<f32>>,
358}
359
360impl GraphSAGE {
361    /// Create new GraphSAGE generator
362    pub fn new(config: GraphSAGEConfig) -> Self {
363        Self {
364            config,
365            embeddings: HashMap::new(),
366        }
367    }
368
369    /// Generate embeddings for graph (simplified inductive approach)
370    pub fn fit(&mut self, graph: &EmbeddingGraph) {
371        // Initialize with node features (random for now)
372        let mut rng = rand::thread_rng();
373        let mut node_features: HashMap<String, Vec<f32>> = HashMap::new();
374
375        for node in &graph.nodes {
376            let features: Vec<f32> = (0..self.config.dimension)
377                .map(|_| rng.gen::<f32>())
378                .collect();
379            node_features.insert(node.clone(), features);
380        }
381
382        // Iteratively aggregate neighborhood information
383        for layer in 0..self.config.num_layers {
384            let samples = self
385                .config
386                .samples_per_layer
387                .get(layer)
388                .copied()
389                .unwrap_or(10);
390            node_features = self.aggregate_layer(graph, &node_features, samples);
391        }
392
393        self.embeddings = node_features;
394    }
395
396    /// Aggregate one layer of neighborhood information
397    fn aggregate_layer(
398        &self,
399        graph: &EmbeddingGraph,
400        features: &HashMap<String, Vec<f32>>,
401        num_samples: usize,
402    ) -> HashMap<String, Vec<f32>> {
403        let mut rng = rand::thread_rng();
404        let mut new_features = HashMap::new();
405
406        for node in &graph.nodes {
407            // Sample neighbors
408            let neighbors = if let Some(neighs) = graph.neighbors(node) {
409                let sample_size = num_samples.min(neighs.len());
410                let mut sampled = Vec::new();
411                let mut indices: Vec<usize> = (0..neighs.len()).collect();
412
413                for _ in 0..sample_size {
414                    let idx = rng.gen_range(0..indices.len());
415                    let neighbor_idx = indices.remove(idx);
416                    sampled.push(&neighs[neighbor_idx].0);
417                }
418
419                sampled
420            } else {
421                Vec::new()
422            };
423
424            // Aggregate neighbor features
425            let aggregated = self.aggregate_neighbors(features, &neighbors);
426
427            // Combine with node's own features
428            let node_feat = features.get(node).unwrap();
429            let combined = self.combine_features(node_feat, &aggregated);
430
431            new_features.insert(node.clone(), combined);
432        }
433
434        new_features
435    }
436
437    /// Aggregate neighbor features
438    fn aggregate_neighbors(
439        &self,
440        features: &HashMap<String, Vec<f32>>,
441        neighbors: &[&String],
442    ) -> Vec<f32> {
443        if neighbors.is_empty() {
444            return vec![0.0; self.config.dimension];
445        }
446
447        match self.config.aggregator {
448            Aggregator::Mean => {
449                let mut sum = vec![0.0; self.config.dimension];
450                for neighbor in neighbors {
451                    if let Some(feat) = features.get(*neighbor) {
452                        for i in 0..self.config.dimension {
453                            sum[i] += feat[i];
454                        }
455                    }
456                }
457
458                for val in &mut sum {
459                    *val /= neighbors.len() as f32;
460                }
461
462                sum
463            },
464            Aggregator::MaxPool => {
465                // Element-wise maximum across all neighbor features
466                let mut max_feat = vec![f32::NEG_INFINITY; self.config.dimension];
467
468                for neighbor in neighbors {
469                    if let Some(feat) = features.get(*neighbor) {
470                        for i in 0..self.config.dimension {
471                            max_feat[i] = max_feat[i].max(feat[i]);
472                        }
473                    }
474                }
475
476                // If no valid features found, return zeros
477                if max_feat.iter().all(|&v| v == f32::NEG_INFINITY) {
478                    vec![0.0; self.config.dimension]
479                } else {
480                    max_feat
481                }
482            },
483            Aggregator::Attention => {
484                // Attention-weighted aggregation
485                self.aggregate_attention(features, neighbors)
486            },
487            Aggregator::Lstm => {
488                // LSTM-based aggregation (order-dependent)
489                self.aggregate_lstm(features, neighbors)
490            },
491        }
492    }
493
494    /// Aggregate neighbors using attention mechanism
495    fn aggregate_attention(
496        &self,
497        features: &HashMap<String, Vec<f32>>,
498        neighbors: &[&String],
499    ) -> Vec<f32> {
500        if neighbors.is_empty() {
501            return vec![0.0; self.config.dimension];
502        }
503
504        // Collect neighbor features
505        let neighbor_feats: Vec<&Vec<f32>> =
506            neighbors.iter().filter_map(|n| features.get(*n)).collect();
507
508        if neighbor_feats.is_empty() {
509            return vec![0.0; self.config.dimension];
510        }
511
512        // Compute attention scores (simplified: dot product similarity)
513        let mut attention_scores = Vec::with_capacity(neighbor_feats.len());
514        let mut score_sum = 0.0;
515
516        for feat in &neighbor_feats {
517            // Simplified attention: sum of features as query
518            let score: f32 = feat.iter().sum();
519            let exp_score = score.exp();
520            attention_scores.push(exp_score);
521            score_sum += exp_score;
522        }
523
524        // Normalize attention scores (softmax)
525        if score_sum > 0.0 {
526            for score in &mut attention_scores {
527                *score /= score_sum;
528            }
529        }
530
531        // Weighted sum based on attention
532        let mut result = vec![0.0; self.config.dimension];
533        for (feat, &weight) in neighbor_feats.iter().zip(attention_scores.iter()) {
534            for i in 0..self.config.dimension {
535                result[i] += feat[i] * weight;
536            }
537        }
538
539        result
540    }
541
542    /// Aggregate neighbors using LSTM (order-dependent)
543    fn aggregate_lstm(
544        &self,
545        features: &HashMap<String, Vec<f32>>,
546        neighbors: &[&String],
547    ) -> Vec<f32> {
548        if neighbors.is_empty() {
549            return vec![0.0; self.config.dimension];
550        }
551
552        // Simplified LSTM aggregation without full LSTM cell
553        // Uses a running weighted average with decay
554        let mut hidden_state = vec![0.0; self.config.dimension];
555        let decay: f32 = 0.9; // Decay factor for previous states
556
557        for (idx, neighbor) in neighbors.iter().enumerate() {
558            if let Some(feat) = features.get(*neighbor) {
559                // Simple recurrent combination with decay
560                let weight = decay.powi(idx as i32);
561                for i in 0..self.config.dimension {
562                    hidden_state[i] = hidden_state[i] * decay + feat[i] * weight;
563                }
564            }
565        }
566
567        // Normalize by sequence length
568        let seq_len = neighbors.len() as f32;
569        for val in &mut hidden_state {
570            *val /= seq_len;
571        }
572
573        hidden_state
574    }
575
576    /// Combine node features with aggregated neighbor features
577    fn combine_features(&self, node_feat: &[f32], neighbor_feat: &[f32]) -> Vec<f32> {
578        // Simple concatenation followed by projection (simplified)
579        // Real implementation would use learned weight matrices
580
581        let mut combined = Vec::with_capacity(self.config.dimension);
582
583        for i in 0..self.config.dimension {
584            // Weighted combination
585            combined.push((node_feat[i] + neighbor_feat[i]) / 2.0);
586        }
587
588        combined
589    }
590
591    /// Get embedding for a node
592    pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
593        self.embeddings.get(node)
594    }
595
596    /// Get all embeddings
597    pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
598        &self.embeddings
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-node fixture (A..E) with six unit-weight edges.
    fn create_test_graph() -> EmbeddingGraph {
        let pairs = [
            ("A", "B"),
            ("A", "C"),
            ("B", "C"),
            ("B", "D"),
            ("C", "D"),
            ("D", "E"),
        ];

        let edges: Vec<(String, String, f32)> = pairs
            .iter()
            .map(|&(source, target)| (source.to_string(), target.to_string(), 1.0))
            .collect();

        EmbeddingGraph::from_edges(edges)
    }

    #[test]
    fn test_embedding_graph_creation() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 5);

        let a_neighbors = graph.neighbors("A");
        assert!(a_neighbors.is_some());
        assert_eq!(a_neighbors.unwrap().len(), 2);
    }

    #[test]
    fn test_node2vec_initialization() {
        // A freshly constructed generator holds no vectors.
        let node2vec = Node2Vec::new(EmbeddingConfig::default());
        assert_eq!(node2vec.embeddings.len(), 0);
    }

    #[test]
    fn test_node2vec_fit() {
        let graph = create_test_graph();
        let mut node2vec = Node2Vec::new(EmbeddingConfig {
            dimension: 64,
            walk_length: 10,
            walks_per_node: 5,
            epochs: 1,
            ..Default::default()
        });

        node2vec.fit(&graph);

        // One 64-dimensional vector per node.
        assert_eq!(node2vec.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = node2vec.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_graphsage_fit() {
        let graph = create_test_graph();
        let mut graphsage = GraphSAGE::new(GraphSAGEConfig {
            dimension: 64,
            num_layers: 2,
            samples_per_layer: vec![3, 2],
            aggregator: Aggregator::Mean,
        });

        graphsage.fit(&graph);

        // One 64-dimensional vector per node.
        assert_eq!(graphsage.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = graphsage.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_random_walk_generation() {
        let graph = create_test_graph();
        let node2vec = Node2Vec::new(EmbeddingConfig {
            walk_length: 5,
            walks_per_node: 1,
            ..Default::default()
        });

        let walks = node2vec.generate_walks(&graph);

        assert_eq!(walks.len(), 5); // 5 nodes * 1 walk per node
        for walk in &walks {
            assert!(walk.len() <= 5);
            assert!(!walk.is_empty());
        }
    }
}
695}