Skip to main content

oxirs_graphrag/graph/
embeddings.rs

//! Community-aware graph embeddings (GraphSAGE, Node2Vec)
2
3use crate::{GraphRAGError, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use scirs2_core::random::{rand_prelude::StdRng, seeded_rng, CoreRandom};
6use std::collections::{HashMap, HashSet};
7
/// Community structure for embeddings.
///
/// Stores one partition of the graph's nodes in both directions, so callers
/// can look up the community of a node or enumerate the members of a
/// community.
#[derive(Debug, Clone)]
pub struct CommunityStructure {
    /// Node to community mapping
    pub node_to_community: HashMap<String, usize>,
    /// Community to nodes mapping
    pub community_to_nodes: HashMap<usize, HashSet<String>>,
    /// Modularity score
    pub modularity: f64,
}

impl CommunityStructure {
    /// Create from community assignments.
    ///
    /// `assignments` is a list of `(node label, community id)` pairs and
    /// `modularity` is stored unchanged.
    ///
    /// If the same node appears more than once, the last assignment wins in
    /// *both* maps: the node is removed from the member set of its earlier
    /// community, and a community left without members is dropped. This keeps
    /// `node_to_community` and `community_to_nodes` describing the same
    /// partition.
    pub fn from_assignments(assignments: &[(String, usize)], modularity: f64) -> Self {
        let mut node_to_community: HashMap<String, usize> =
            HashMap::with_capacity(assignments.len());
        let mut community_to_nodes: HashMap<usize, HashSet<String>> = HashMap::new();

        for (node, comm) in assignments {
            // `insert` hands back the previous community when the node was
            // already assigned; undo that stale membership before recording
            // the new one.
            if let Some(previous) = node_to_community.insert(node.clone(), *comm) {
                if previous != *comm {
                    let mut now_empty = false;
                    if let Some(members) = community_to_nodes.get_mut(&previous) {
                        members.remove(node);
                        now_empty = members.is_empty();
                    }
                    if now_empty {
                        community_to_nodes.remove(&previous);
                    }
                }
            }

            community_to_nodes
                .entry(*comm)
                .or_default()
                .insert(node.clone());
        }

        Self {
            node_to_community,
            community_to_nodes,
            modularity,
        }
    }
}
40
/// Tunable parameters for the community-aware embedding generators.
///
/// Use [`Default::default`] for the documented defaults and override single
/// fields with struct-update syntax.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Length of every embedding vector (default: 128).
    pub embedding_dim: usize,
    /// Upper bound on the number of nodes in one Node2Vec walk (default: 80).
    pub walk_length: usize,
    /// How many walks are started from each node (default: 10).
    pub num_walks: usize,
    /// Node2Vec return parameter `p` (default: 1.0).
    pub p: f64,
    /// Node2Vec in-out parameter `q` (default: 1.0).
    pub q: f64,
    /// Weight multiplier for same-community neighbours (default: 2.0;
    /// larger values favour staying inside the community).
    pub community_bias: f64,
    /// Skip-gram context window, counted on each side of the target
    /// (default: 5).
    pub window_size: usize,
    /// Seed for the random number generator.
    pub random_seed: u64,
}

impl Default for EmbeddingConfig {
    /// Returns the defaults listed on each field; the seed defaults to 42.
    fn default() -> Self {
        // `p` and `q` share the same default value.
        let node2vec_default = 1.0;

        EmbeddingConfig {
            random_seed: 42,
            window_size: 5,
            community_bias: 2.0,
            q: node2vec_default,
            p: node2vec_default,
            num_walks: 10,
            walk_length: 80,
            embedding_dim: 128,
        }
    }
}
76
77/// Community-aware graph embeddings
78pub struct CommunityAwareEmbeddings {
79    config: EmbeddingConfig,
80    rng: CoreRandom<StdRng>,
81}
82
83impl CommunityAwareEmbeddings {
84    /// Create new embeddings generator
85    pub fn new(config: EmbeddingConfig) -> Self {
86        let rng = seeded_rng(config.random_seed);
87        Self { config, rng }
88    }
89
90    /// Generate embeddings using GraphSAGE with community awareness
91    pub fn embed_graphsage(
92        &mut self,
93        triples: &[Triple],
94        communities: &CommunityStructure,
95    ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
96        let (graph, node_map) = self.build_graph(triples);
97
98        if graph.node_count() == 0 {
99            return Ok(HashMap::new());
100        }
101
102        let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
103
104        // Initialize with random features
105        for (node_label, &node_idx) in &node_map {
106            let mut features = vec![0.0; self.config.embedding_dim];
107            for f in &mut features {
108                *f = self.rng.random_range(0.0..1.0) * 2.0 - 1.0; // [-1, 1]
109            }
110            embeddings.insert(node_label.clone(), features);
111        }
112
113        // GraphSAGE aggregation (2 iterations)
114        for _ in 0..2 {
115            let mut new_embeddings = embeddings.clone();
116
117            for (node_label, &node_idx) in &node_map {
118                let node_community = communities.node_to_community.get(node_label);
119
120                // Get neighbors, prioritizing same-community neighbors
121                let mut same_comm_neighbors = Vec::new();
122                let mut other_neighbors = Vec::new();
123
124                for neighbor_idx in graph.neighbors(node_idx) {
125                    if let Some(neighbor_label) = graph.node_weight(neighbor_idx) {
126                        let neighbor_community = communities.node_to_community.get(neighbor_label);
127
128                        if node_community == neighbor_community {
129                            same_comm_neighbors.push(neighbor_label.clone());
130                        } else {
131                            other_neighbors.push(neighbor_label.clone());
132                        }
133                    }
134                }
135
136                // Aggregate: prioritize same-community neighbors
137                let mut aggregated = vec![0.0; self.config.embedding_dim];
138                let mut count = 0.0;
139
140                for neighbor in &same_comm_neighbors {
141                    if let Some(neighbor_emb) = embeddings.get(neighbor) {
142                        for (i, &val) in neighbor_emb.iter().enumerate() {
143                            aggregated[i] += val * self.config.community_bias;
144                        }
145                        count += self.config.community_bias;
146                    }
147                }
148
149                for neighbor in &other_neighbors {
150                    if let Some(neighbor_emb) = embeddings.get(neighbor) {
151                        for (i, &val) in neighbor_emb.iter().enumerate() {
152                            aggregated[i] += val;
153                        }
154                        count += 1.0;
155                    }
156                }
157
158                if count > 0.0 {
159                    for val in &mut aggregated {
160                        *val /= count;
161                    }
162
163                    // Combine with own embedding
164                    if let Some(own_emb) = embeddings.get(node_label) {
165                        for (i, &val) in own_emb.iter().enumerate() {
166                            aggregated[i] = (aggregated[i] + val) / 2.0;
167                        }
168                    }
169
170                    new_embeddings.insert(node_label.clone(), aggregated);
171                }
172            }
173
174            embeddings = new_embeddings;
175        }
176
177        // Normalize embeddings
178        for emb in embeddings.values_mut() {
179            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
180            if norm > 0.0 {
181                for val in emb {
182                    *val /= norm;
183                }
184            }
185        }
186
187        Ok(embeddings)
188    }
189
190    /// Generate embeddings using Node2Vec with community-biased random walks
191    pub fn embed_node2vec(
192        &mut self,
193        triples: &[Triple],
194        communities: &CommunityStructure,
195    ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
196        let (graph, node_map) = self.build_graph(triples);
197
198        if graph.node_count() == 0 {
199            return Ok(HashMap::new());
200        }
201
202        // Generate community-biased random walks
203        let walks = self.generate_community_biased_walks(&graph, &node_map, communities)?;
204
205        // Train skip-gram model (simplified)
206        let embeddings = self.train_skip_gram(&walks, &node_map)?;
207
208        Ok(embeddings)
209    }
210
211    /// Generate community-biased random walks
212    fn generate_community_biased_walks(
213        &mut self,
214        graph: &UnGraph<String, ()>,
215        node_map: &HashMap<String, NodeIndex>,
216        communities: &CommunityStructure,
217    ) -> GraphRAGResult<Vec<Vec<String>>> {
218        let mut walks = Vec::new();
219
220        for _ in 0..self.config.num_walks {
221            for (node_label, &start_idx) in node_map {
222                let walk = self.node2vec_walk(graph, start_idx, node_label, communities);
223                walks.push(walk);
224            }
225        }
226
227        Ok(walks)
228    }
229
230    /// Single Node2Vec random walk with community bias
231    fn node2vec_walk(
232        &mut self,
233        graph: &UnGraph<String, ()>,
234        start: NodeIndex,
235        start_label: &str,
236        communities: &CommunityStructure,
237    ) -> Vec<String> {
238        let mut walk = vec![start_label.to_string()];
239        let mut current = start;
240        let mut prev: Option<NodeIndex> = None;
241        let start_community = communities.node_to_community.get(start_label);
242
243        for _ in 1..self.config.walk_length {
244            let neighbors: Vec<NodeIndex> = graph.neighbors(current).collect();
245
246            if neighbors.is_empty() {
247                break;
248            }
249
250            // Calculate transition probabilities with community bias
251            let mut probs = vec![0.0; neighbors.len()];
252
253            for (i, &neighbor) in neighbors.iter().enumerate() {
254                let mut prob = 1.0;
255
256                // Node2Vec bias
257                if let Some(p) = prev {
258                    if neighbor == p {
259                        prob /= self.config.p; // Return parameter
260                    } else if !graph.neighbors(p).any(|n| n == neighbor) {
261                        prob /= self.config.q; // In-out parameter
262                    }
263                }
264
265                // Community bias: prefer staying in same community
266                if let Some(neighbor_label) = graph.node_weight(neighbor) {
267                    let neighbor_community = communities.node_to_community.get(neighbor_label);
268                    if start_community == neighbor_community {
269                        prob *= self.config.community_bias;
270                    }
271                }
272
273                probs[i] = prob;
274            }
275
276            // Normalize probabilities
277            let sum: f64 = probs.iter().sum();
278            if sum > 0.0 {
279                for p in &mut probs {
280                    *p /= sum;
281                }
282            }
283
284            // Sample next node
285            let r = self.rng.random_range(0.0..1.0);
286            let mut cumsum = 0.0;
287            let mut next_idx = 0;
288
289            for (i, &p) in probs.iter().enumerate() {
290                cumsum += p;
291                if r < cumsum {
292                    next_idx = i;
293                    break;
294                }
295            }
296
297            let next = neighbors[next_idx];
298            if let Some(next_label) = graph.node_weight(next) {
299                walk.push(next_label.clone());
300            }
301
302            prev = Some(current);
303            current = next;
304        }
305
306        walk
307    }
308
309    /// Train skip-gram model on walks (simplified)
310    fn train_skip_gram(
311        &mut self,
312        walks: &[Vec<String>],
313        node_map: &HashMap<String, NodeIndex>,
314    ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
315        // Initialize embeddings randomly
316        let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
317        for node_label in node_map.keys() {
318            let mut emb = vec![0.0; self.config.embedding_dim];
319            for val in &mut emb {
320                *val = (self.rng.random_range(0.0..1.0) - 0.5) * 0.1; // Small random init
321            }
322            embeddings.insert(node_label.clone(), emb);
323        }
324
325        // Skip-gram training (simplified, no negative sampling)
326        let learning_rate = 0.025;
327        let num_epochs = 5;
328
329        for _ in 0..num_epochs {
330            for walk in walks {
331                for (i, target) in walk.iter().enumerate() {
332                    let start = i.saturating_sub(self.config.window_size);
333                    let end = (i + self.config.window_size + 1).min(walk.len());
334
335                    for (offset, context) in walk[start..end].iter().enumerate() {
336                        let j = start + offset;
337                        if i == j {
338                            continue;
339                        }
340
341                        // Update embeddings to be similar
342                        if let (Some(target_emb), Some(context_emb)) =
343                            (embeddings.get(target), embeddings.get(context))
344                        {
345                            let mut target_update = vec![0.0; self.config.embedding_dim];
346                            let mut context_update = vec![0.0; self.config.embedding_dim];
347
348                            for k in 0..self.config.embedding_dim {
349                                let diff = context_emb[k] - target_emb[k];
350                                target_update[k] = learning_rate * diff;
351                                context_update[k] = -learning_rate * diff;
352                            }
353
354                            if let Some(emb) = embeddings.get_mut(target) {
355                                for (k, &update) in target_update.iter().enumerate() {
356                                    emb[k] += update;
357                                }
358                            }
359
360                            if let Some(emb) = embeddings.get_mut(context) {
361                                for (k, &update) in context_update.iter().enumerate() {
362                                    emb[k] += update;
363                                }
364                            }
365                        }
366                    }
367                }
368            }
369        }
370
371        // Normalize
372        for emb in embeddings.values_mut() {
373            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
374            if norm > 0.0 {
375                for val in emb {
376                    *val /= norm;
377                }
378            }
379        }
380
381        Ok(embeddings)
382    }
383
384    /// Build graph from triples
385    fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
386        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
387        let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
388
389        for triple in triples {
390            let subj_idx = *node_map
391                .entry(triple.subject.clone())
392                .or_insert_with(|| graph.add_node(triple.subject.clone()));
393            let obj_idx = *node_map
394                .entry(triple.object.clone())
395                .or_insert_with(|| graph.add_node(triple.object.clone()));
396
397            if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
398                graph.add_edge(subj_idx, obj_idx, ());
399            }
400        }
401
402        (graph, node_map)
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that there are `count` embeddings, each with `dim` components
    /// and unit L2 norm.
    fn assert_unit_embeddings(embeddings: &HashMap<String, Vec<f64>>, count: usize, dim: usize) {
        assert_eq!(embeddings.len(), count);
        for (label, emb) in embeddings {
            assert_eq!(emb.len(), dim, "wrong dimension for {}", label);
            let norm = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-9,
                "embedding for {} has norm {}",
                label,
                norm
            );
        }
    }

    /// GraphSAGE on a triangle whose nodes all share one community.
    #[test]
    fn test_community_aware_embeddings() {
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
        ];

        let assignments = vec![
            ("http://a".to_string(), 0),
            ("http://b".to_string(), 0),
            ("http://c".to_string(), 0),
        ];

        let communities = CommunityStructure::from_assignments(&assignments, 0.8);

        let config = EmbeddingConfig {
            embedding_dim: 16,
            ..Default::default()
        };

        let mut embedder = CommunityAwareEmbeddings::new(config);
        let embeddings = embedder
            .embed_graphsage(&triples, &communities)
            .expect("embeddings failed");

        assert_unit_embeddings(&embeddings, 3, 16);
    }

    /// Node2Vec on a four-node path split into two communities.
    #[test]
    fn test_node2vec_embeddings() {
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://c", "http://rel", "http://d"),
        ];

        let assignments = vec![
            ("http://a".to_string(), 0),
            ("http://b".to_string(), 0),
            ("http://c".to_string(), 1),
            ("http://d".to_string(), 1),
        ];

        let communities = CommunityStructure::from_assignments(&assignments, 0.7);

        let config = EmbeddingConfig {
            embedding_dim: 16,
            walk_length: 10,
            num_walks: 5,
            ..Default::default()
        };

        let mut embedder = CommunityAwareEmbeddings::new(config);
        let embeddings = embedder
            .embed_node2vec(&triples, &communities)
            .expect("embeddings failed");

        assert_unit_embeddings(&embeddings, 4, 16);
    }
}