Skip to main content

graphrag_core/graph/
embeddings.rs

//! Graph Embeddings
//!
//! This module provides graph embedding algorithms for converting graph structures
//! into dense vector representations:
//!
//! - **Node2Vec**: Random walk-based embeddings capturing network neighborhoods
//! - **GraphSAGE**: Inductive representation learning using neighborhood sampling
//! - **DeepWalk**: Simplified random walk embeddings
//! - **Struct2Vec**: Structure-aware graph embeddings
//!
//! ## Use Cases
//!
//! - Node classification and clustering
//! - Link prediction
//! - Graph visualization
//! - Similarity search in graph space
//! - Transfer learning across graphs
18
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use rand::Rng;
22
23/// Graph embedding configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EmbeddingConfig {
26    /// Embedding dimension
27    pub dimension: usize,
28    /// Walk length for random walks
29    pub walk_length: usize,
30    /// Number of walks per node
31    pub walks_per_node: usize,
32    /// Context window size (for Skip-Gram)
33    pub context_size: usize,
34    /// Return parameter (Node2Vec p)
35    pub return_param: f32,
36    /// In-out parameter (Node2Vec q)
37    pub inout_param: f32,
38    /// Learning rate
39    pub learning_rate: f32,
40    /// Number of negative samples
41    pub negative_samples: usize,
42    /// Number of training epochs
43    pub epochs: usize,
44}
45
46impl Default for EmbeddingConfig {
47    fn default() -> Self {
48        Self {
49            dimension: 128,
50            walk_length: 80,
51            walks_per_node: 10,
52            context_size: 10,
53            return_param: 1.0,
54            inout_param: 1.0,
55            learning_rate: 0.025,
56            negative_samples: 5,
57            epochs: 10,
58        }
59    }
60}
61
/// Undirected, weighted graph used as input for embedding generation.
///
/// Every edge handed to [`EmbeddingGraph::from_edges`] is stored in both
/// directions, so `b` is listed under `neighbors(a)` exactly when `a` is
/// listed under `neighbors(b)`.
pub struct EmbeddingGraph {
    /// Adjacency list: node_id -> [(neighbor_id, weight)]
    adjacency: HashMap<String, Vec<(String, f32)>>,
    /// All node IDs, in order of first appearance in the edge list
    nodes: Vec<String>,
    /// Node index mapping: node_id -> position in `nodes`
    node_index: HashMap<String, usize>,
}

impl EmbeddingGraph {
    /// Create embedding graph from edge list
    ///
    /// Nodes are numbered in the order they first appear in `edges` (source
    /// before target), so the same edge list always yields the same
    /// [`get_index`](Self::get_index) / [`get_node`](Self::get_node) mapping.
    ///
    /// Each edge is inserted in both directions. A self-loop (`source ==
    /// target`) is stored once, so it is not counted twice when neighbors are
    /// sampled by weight. Repeated edges are kept as separate adjacency
    /// entries.
    ///
    /// # Arguments
    /// * `edges` - List of (source, target, weight) tuples
    pub fn from_edges(edges: Vec<(String, String, f32)>) -> Self {
        let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
        let mut seen: HashSet<String> = HashSet::new();
        let mut nodes: Vec<String> = Vec::new();

        // Appends `id` to `nodes` the first time it is encountered. Checking
        // before inserting avoids allocating for nodes that are already known.
        let mut register = |id: &str| {
            if !seen.contains(id) {
                seen.insert(id.to_owned());
                nodes.push(id.to_owned());
            }
        };

        for (source, target, weight) in edges {
            register(&source);
            register(&target);

            if source == target {
                // Self-loop: both directions are the same adjacency entry.
                adjacency.entry(source).or_default().push((target, weight));
                continue;
            }

            adjacency
                .entry(source.clone())
                .or_default()
                .push((target.clone(), weight));
            adjacency.entry(target).or_default().push((source, weight));
        }

        let node_index: HashMap<String, usize> = nodes
            .iter()
            .enumerate()
            .map(|(position, id)| (id.clone(), position))
            .collect();

        Self {
            adjacency,
            nodes,
            node_index,
        }
    }

    /// Get node count
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Get neighbors of a node
    ///
    /// Returns `None` when `node` does not occur in any edge.
    pub fn neighbors(&self, node: &str) -> Option<&Vec<(String, f32)>> {
        self.adjacency.get(node)
    }

    /// Get node index
    ///
    /// Returns `None` when `node` does not occur in any edge.
    pub fn get_index(&self, node: &str) -> Option<usize> {
        self.node_index.get(node).copied()
    }

    /// Get node by index
    ///
    /// Returns `None` when `index` is not smaller than [`node_count`](Self::node_count).
    pub fn get_node(&self, index: usize) -> Option<&String> {
        self.nodes.get(index)
    }
}
130
131/// Node2Vec embeddings generator
132pub struct Node2Vec {
133    config: EmbeddingConfig,
134    /// Learned embeddings: node_id -> embedding vector
135    embeddings: HashMap<String, Vec<f32>>,
136}
137
138impl Node2Vec {
139    /// Create new Node2Vec generator
140    pub fn new(config: EmbeddingConfig) -> Self {
141        Self {
142            config,
143            embeddings: HashMap::new(),
144        }
145    }
146
147    /// Generate embeddings for graph
148    pub fn fit(&mut self, graph: &EmbeddingGraph) {
149        // Generate random walks
150        let walks = self.generate_walks(graph);
151
152        // Initialize embeddings randomly
153        self.initialize_embeddings(graph);
154
155        // Train Skip-Gram model on walks
156        self.train_skipgram(&walks);
157    }
158
159    /// Generate biased random walks (Node2Vec)
160    fn generate_walks(&self, graph: &EmbeddingGraph) -> Vec<Vec<String>> {
161        let mut rng = rand::thread_rng();
162        let mut walks = Vec::new();
163
164        for _ in 0..self.config.walks_per_node {
165            for node in &graph.nodes {
166                let walk = self.random_walk(graph, node, &mut rng);
167                walks.push(walk);
168            }
169        }
170
171        walks
172    }
173
174    /// Perform single biased random walk from starting node
175    fn random_walk<R: Rng>(
176        &self,
177        graph: &EmbeddingGraph,
178        start: &str,
179        rng: &mut R,
180    ) -> Vec<String> {
181        let mut walk = vec![start.to_string()];
182
183        for _ in 1..self.config.walk_length {
184            let current = walk.last().unwrap();
185
186            if let Some(neighbors) = graph.neighbors(current) {
187                if neighbors.is_empty() {
188                    break;
189                }
190
191                // Sample next node using biased probabilities
192                let next = if walk.len() == 1 {
193                    // First step: uniform random
194                    &neighbors[rng.gen_range(0..neighbors.len())].0
195                } else {
196                    // Subsequent steps: use Node2Vec bias
197                    let prev = &walk[walk.len() - 2];
198                    self.sample_next(prev, current, neighbors, rng)
199                };
200
201                walk.push(next.clone());
202            } else {
203                break;
204            }
205        }
206
207        walk
208    }
209
210    /// Sample next node with Node2Vec bias (p, q parameters)
211    fn sample_next<'a, R: Rng>(
212        &self,
213        prev: &str,
214        _current: &str,
215        neighbors: &'a [(String, f32)],
216        rng: &mut R,
217    ) -> &'a String {
218        // Calculate transition probabilities based on p and q
219        let mut probs: Vec<f32> = neighbors
220            .iter()
221            .map(|(neighbor, weight)| {
222                let alpha = if neighbor == prev {
223                    // Return to previous node
224                    1.0 / self.config.return_param
225                } else {
226                    // Check if neighbor is also neighbor of prev (BFS vs DFS)
227                    1.0 / self.config.inout_param
228                };
229                weight * alpha
230            })
231            .collect();
232
233        // Normalize probabilities
234        let sum: f32 = probs.iter().sum();
235        if sum > 0.0 {
236            for p in &mut probs {
237                *p /= sum;
238            }
239        }
240
241        // Sample using cumulative distribution
242        let r: f32 = rng.gen();
243        let mut cumsum = 0.0;
244        for (i, &prob) in probs.iter().enumerate() {
245            cumsum += prob;
246            if r <= cumsum {
247                return &neighbors[i].0;
248            }
249        }
250
251        &neighbors[neighbors.len() - 1].0
252    }
253
254    /// Initialize random embeddings
255    fn initialize_embeddings(&mut self, graph: &EmbeddingGraph) {
256        let mut rng = rand::thread_rng();
257
258        for node in &graph.nodes {
259            let embedding: Vec<f32> = (0..self.config.dimension)
260                .map(|_| (rng.gen::<f32>() - 0.5) / self.config.dimension as f32)
261                .collect();
262
263            self.embeddings.insert(node.clone(), embedding);
264        }
265    }
266
267    /// Train Skip-Gram model on walks
268    fn train_skipgram(&mut self, walks: &[Vec<String>]) {
269        for _ in 0..self.config.epochs {
270            for walk in walks {
271                for (i, node) in walk.iter().enumerate() {
272                    // Define context window
273                    let start = i.saturating_sub(self.config.context_size);
274                    let end = (i + self.config.context_size + 1).min(walk.len());
275
276                    for (j, context_node) in walk.iter().enumerate().take(end).skip(start) {
277                        if i != j {
278                            self.update_embeddings(node, context_node);
279                        }
280                    }
281                }
282            }
283        }
284    }
285
286    /// Update embeddings using Skip-Gram objective (simplified)
287    fn update_embeddings(&mut self, target: &str, context: &str) {
288        // Simplified update: move embeddings closer for positive pairs
289        // Real implementation would use negative sampling and gradient descent
290
291        let lr = self.config.learning_rate;
292
293        if let (Some(target_emb), Some(context_emb)) =
294            (self.embeddings.get(target), self.embeddings.get(context))
295        {
296            // Calculate gradient direction (simplified)
297            let mut target_new = target_emb.clone();
298            let mut context_new = context_emb.clone();
299
300            for i in 0..self.config.dimension {
301                let diff = context_emb[i] - target_emb[i];
302                target_new[i] += lr * diff;
303                context_new[i] -= lr * diff;
304            }
305
306            self.embeddings.insert(target.to_string(), target_new);
307            self.embeddings.insert(context.to_string(), context_new);
308        }
309    }
310
311    /// Get embedding for a node
312    pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
313        self.embeddings.get(node)
314    }
315
316    /// Get all embeddings
317    pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
318        &self.embeddings
319    }
320}
321
322/// GraphSAGE configuration
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct GraphSAGEConfig {
325    /// Embedding dimension
326    pub dimension: usize,
327    /// Number of layers
328    pub num_layers: usize,
329    /// Samples per layer
330    pub samples_per_layer: Vec<usize>,
331    /// Aggregation function
332    pub aggregator: Aggregator,
333}
334
335/// Aggregation functions for GraphSAGE
336#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
337pub enum Aggregator {
338    /// Mean aggregation
339    Mean,
340    /// Max pooling
341    MaxPool,
342    /// LSTM aggregation
343    Lstm,
344    /// Attention-based
345    Attention,
346}
347
348impl Default for GraphSAGEConfig {
349    fn default() -> Self {
350        Self {
351            dimension: 128,
352            num_layers: 2,
353            samples_per_layer: vec![25, 10],
354            aggregator: Aggregator::Mean,
355        }
356    }
357}
358
359/// GraphSAGE embeddings generator
360pub struct GraphSAGE {
361    config: GraphSAGEConfig,
362    embeddings: HashMap<String, Vec<f32>>,
363}
364
365impl GraphSAGE {
366    /// Create new GraphSAGE generator
367    pub fn new(config: GraphSAGEConfig) -> Self {
368        Self {
369            config,
370            embeddings: HashMap::new(),
371        }
372    }
373
374    /// Generate embeddings for graph (simplified inductive approach)
375    pub fn fit(&mut self, graph: &EmbeddingGraph) {
376        // Initialize with node features (random for now)
377        let mut rng = rand::thread_rng();
378        let mut node_features: HashMap<String, Vec<f32>> = HashMap::new();
379
380        for node in &graph.nodes {
381            let features: Vec<f32> = (0..self.config.dimension)
382                .map(|_| rng.gen::<f32>())
383                .collect();
384            node_features.insert(node.clone(), features);
385        }
386
387        // Iteratively aggregate neighborhood information
388        for layer in 0..self.config.num_layers {
389            let samples = self.config.samples_per_layer.get(layer).copied().unwrap_or(10);
390            node_features = self.aggregate_layer(graph, &node_features, samples);
391        }
392
393        self.embeddings = node_features;
394    }
395
396    /// Aggregate one layer of neighborhood information
397    fn aggregate_layer(
398        &self,
399        graph: &EmbeddingGraph,
400        features: &HashMap<String, Vec<f32>>,
401        num_samples: usize,
402    ) -> HashMap<String, Vec<f32>> {
403        let mut rng = rand::thread_rng();
404        let mut new_features = HashMap::new();
405
406        for node in &graph.nodes {
407            // Sample neighbors
408            let neighbors = if let Some(neighs) = graph.neighbors(node) {
409                let sample_size = num_samples.min(neighs.len());
410                let mut sampled = Vec::new();
411                let mut indices: Vec<usize> = (0..neighs.len()).collect();
412
413                for _ in 0..sample_size {
414                    let idx = rng.gen_range(0..indices.len());
415                    let neighbor_idx = indices.remove(idx);
416                    sampled.push(&neighs[neighbor_idx].0);
417                }
418
419                sampled
420            } else {
421                Vec::new()
422            };
423
424            // Aggregate neighbor features
425            let aggregated = self.aggregate_neighbors(features, &neighbors);
426
427            // Combine with node's own features
428            let node_feat = features.get(node).unwrap();
429            let combined = self.combine_features(node_feat, &aggregated);
430
431            new_features.insert(node.clone(), combined);
432        }
433
434        new_features
435    }
436
437    /// Aggregate neighbor features
438    fn aggregate_neighbors(
439        &self,
440        features: &HashMap<String, Vec<f32>>,
441        neighbors: &[&String],
442    ) -> Vec<f32> {
443        if neighbors.is_empty() {
444            return vec![0.0; self.config.dimension];
445        }
446
447        match self.config.aggregator {
448            Aggregator::Mean => {
449                let mut sum = vec![0.0; self.config.dimension];
450                for neighbor in neighbors {
451                    if let Some(feat) = features.get(*neighbor) {
452                        for i in 0..self.config.dimension {
453                            sum[i] += feat[i];
454                        }
455                    }
456                }
457
458                for val in &mut sum {
459                    *val /= neighbors.len() as f32;
460                }
461
462                sum
463            }
464            _ => {
465                // For now, default to mean for other aggregators
466                // TODO: Implement MaxPool, LSTM, Attention
467                let mut sum = vec![0.0; self.config.dimension];
468                for neighbor in neighbors {
469                    if let Some(feat) = features.get(*neighbor) {
470                        for i in 0..self.config.dimension {
471                            sum[i] += feat[i];
472                        }
473                    }
474                }
475
476                for val in &mut sum {
477                    *val /= neighbors.len() as f32;
478                }
479
480                sum
481            }
482        }
483    }
484
485    /// Combine node features with aggregated neighbor features
486    fn combine_features(&self, node_feat: &[f32], neighbor_feat: &[f32]) -> Vec<f32> {
487        // Simple concatenation followed by projection (simplified)
488        // Real implementation would use learned weight matrices
489
490        let mut combined = Vec::with_capacity(self.config.dimension);
491
492        for i in 0..self.config.dimension {
493            // Weighted combination
494            combined.push((node_feat[i] + neighbor_feat[i]) / 2.0);
495        }
496
497        combined
498    }
499
500    /// Get embedding for a node
501    pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
502        self.embeddings.get(node)
503    }
504
505    /// Get all embeddings
506    pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
507        &self.embeddings
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: five nodes (A-E) joined by six unit-weight edges.
    fn create_test_graph() -> EmbeddingGraph {
        let pairs = [
            ("A", "B"),
            ("A", "C"),
            ("B", "C"),
            ("B", "D"),
            ("C", "D"),
            ("D", "E"),
        ];

        let edges = pairs
            .iter()
            .map(|&(source, target)| (source.to_string(), target.to_string(), 1.0))
            .collect();

        EmbeddingGraph::from_edges(edges)
    }

    #[test]
    fn test_embedding_graph_creation() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 5);

        // A is connected to B and C only.
        let neighbors_of_a = graph.neighbors("A");
        assert!(neighbors_of_a.is_some());
        assert_eq!(neighbors_of_a.unwrap().len(), 2);
    }

    #[test]
    fn test_node2vec_initialization() {
        // A fresh generator holds no embeddings until it is fitted.
        let node2vec = Node2Vec::new(EmbeddingConfig::default());
        assert_eq!(node2vec.embeddings.len(), 0);
    }

    #[test]
    fn test_node2vec_fit() {
        let graph = create_test_graph();
        let mut node2vec = Node2Vec::new(EmbeddingConfig {
            dimension: 64,
            walk_length: 10,
            walks_per_node: 5,
            epochs: 1,
            ..Default::default()
        });

        node2vec.fit(&graph);

        // One vector of the configured size per node.
        assert_eq!(node2vec.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = node2vec.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_graphsage_fit() {
        let graph = create_test_graph();
        let mut graphsage = GraphSAGE::new(GraphSAGEConfig {
            dimension: 64,
            num_layers: 2,
            samples_per_layer: vec![3, 2],
            aggregator: Aggregator::Mean,
        });

        graphsage.fit(&graph);

        // One vector of the configured size per node.
        assert_eq!(graphsage.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = graphsage.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_random_walk_generation() {
        let graph = create_test_graph();
        let node2vec = Node2Vec::new(EmbeddingConfig {
            walk_length: 5,
            walks_per_node: 1,
            ..Default::default()
        });

        let walks = node2vec.generate_walks(&graph);

        assert_eq!(walks.len(), 5); // 5 nodes * 1 walk per node
        for walk in &walks {
            // Every walk contains at least its start node and never exceeds
            // the configured length.
            assert!(walk.len() <= 5);
            assert!(!walk.is_empty());
        }
    }
}