1use crate::{GraphRAGError, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use scirs2_core::random::{rand_prelude::StdRng, seeded_rng, CoreRandom};
6use std::collections::{HashMap, HashSet};
7
/// A partition of graph nodes into communities, indexed in both directions.
#[derive(Debug, Clone)]
pub struct CommunityStructure {
    /// Community id assigned to each node label.
    pub node_to_community: HashMap<String, usize>,
    /// Node labels that belong to each community id.
    pub community_to_nodes: HashMap<usize, HashSet<String>>,
    /// Modularity score of the partition, stored exactly as supplied by the caller.
    pub modularity: f64,
}

impl CommunityStructure {
    /// Builds both lookup maps from `(node, community)` pairs.
    ///
    /// If a node is listed more than once, the last assignment wins. The node is
    /// also removed from its earlier community, so that `node_to_community` and
    /// `community_to_nodes` always agree. A community left without members is
    /// dropped. `modularity` is stored verbatim; it is not computed here.
    pub fn from_assignments(assignments: &[(String, usize)], modularity: f64) -> Self {
        let mut node_to_community: HashMap<String, usize> =
            HashMap::with_capacity(assignments.len());
        let mut community_to_nodes: HashMap<usize, HashSet<String>> = HashMap::new();

        for (node, community) in assignments {
            let community = *community;
            let previous = node_to_community.insert(node.clone(), community);

            if let Some(old) = previous.filter(|&old| old != community) {
                // The node moved: take it out of its old community so the reverse
                // index stays consistent, and forget communities left empty.
                if let Some(members) = community_to_nodes.get_mut(&old) {
                    members.remove(node);
                    if members.is_empty() {
                        community_to_nodes.remove(&old);
                    }
                }
            }

            community_to_nodes
                .entry(community)
                .or_default()
                .insert(node.clone());
        }

        Self {
            node_to_community,
            community_to_nodes,
            modularity,
        }
    }
}
40
/// Hyper-parameters for [`CommunityAwareEmbeddings`].
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Length of every produced embedding vector.
    pub embedding_dim: usize,
    /// Upper bound on nodes per node2vec walk. The start node counts, and a walk
    /// always contains at least the start node.
    pub walk_length: usize,
    /// Number of node2vec walks started from every node.
    pub num_walks: usize,
    /// node2vec return parameter: the weight for stepping back to the previous
    /// node is divided by `p`.
    pub p: f64,
    /// node2vec in-out parameter: the weight for moving to a node that is not
    /// adjacent to the previous node is divided by `q`.
    pub q: f64,
    /// Multiplier applied to neighbours that fall in the same community.
    pub community_bias: f64,
    /// Skip-gram context window: how many positions on each side of a target
    /// count as context.
    pub window_size: usize,
    /// Seed for the embedder's random number generator.
    pub random_seed: u64,
}

impl Default for EmbeddingConfig {
    /// Stock settings:
    /// - 128-dimensional vectors
    /// - 10 walks of length 80 per node
    /// - unbiased `p = q = 1`
    /// - doubled same-community weight
    /// - context window of 5
    /// - seed 42
    fn default() -> Self {
        let (p, q) = (1.0, 1.0);
        EmbeddingConfig {
            random_seed: 42,
            window_size: 5,
            community_bias: 2.0,
            q,
            p,
            num_walks: 10,
            walk_length: 80,
            embedding_dim: 128,
        }
    }
}
76
77pub struct CommunityAwareEmbeddings {
79 config: EmbeddingConfig,
80 rng: CoreRandom<StdRng>,
81}
82
83impl CommunityAwareEmbeddings {
84 pub fn new(config: EmbeddingConfig) -> Self {
86 let rng = seeded_rng(config.random_seed);
87 Self { config, rng }
88 }
89
90 pub fn embed_graphsage(
92 &mut self,
93 triples: &[Triple],
94 communities: &CommunityStructure,
95 ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
96 let (graph, node_map) = self.build_graph(triples);
97
98 if graph.node_count() == 0 {
99 return Ok(HashMap::new());
100 }
101
102 let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
103
104 for (node_label, &node_idx) in &node_map {
106 let mut features = vec![0.0; self.config.embedding_dim];
107 for f in &mut features {
108 *f = self.rng.random_range(0.0..1.0) * 2.0 - 1.0; }
110 embeddings.insert(node_label.clone(), features);
111 }
112
113 for _ in 0..2 {
115 let mut new_embeddings = embeddings.clone();
116
117 for (node_label, &node_idx) in &node_map {
118 let node_community = communities.node_to_community.get(node_label);
119
120 let mut same_comm_neighbors = Vec::new();
122 let mut other_neighbors = Vec::new();
123
124 for neighbor_idx in graph.neighbors(node_idx) {
125 if let Some(neighbor_label) = graph.node_weight(neighbor_idx) {
126 let neighbor_community = communities.node_to_community.get(neighbor_label);
127
128 if node_community == neighbor_community {
129 same_comm_neighbors.push(neighbor_label.clone());
130 } else {
131 other_neighbors.push(neighbor_label.clone());
132 }
133 }
134 }
135
136 let mut aggregated = vec![0.0; self.config.embedding_dim];
138 let mut count = 0.0;
139
140 for neighbor in &same_comm_neighbors {
141 if let Some(neighbor_emb) = embeddings.get(neighbor) {
142 for (i, &val) in neighbor_emb.iter().enumerate() {
143 aggregated[i] += val * self.config.community_bias;
144 }
145 count += self.config.community_bias;
146 }
147 }
148
149 for neighbor in &other_neighbors {
150 if let Some(neighbor_emb) = embeddings.get(neighbor) {
151 for (i, &val) in neighbor_emb.iter().enumerate() {
152 aggregated[i] += val;
153 }
154 count += 1.0;
155 }
156 }
157
158 if count > 0.0 {
159 for val in &mut aggregated {
160 *val /= count;
161 }
162
163 if let Some(own_emb) = embeddings.get(node_label) {
165 for (i, &val) in own_emb.iter().enumerate() {
166 aggregated[i] = (aggregated[i] + val) / 2.0;
167 }
168 }
169
170 new_embeddings.insert(node_label.clone(), aggregated);
171 }
172 }
173
174 embeddings = new_embeddings;
175 }
176
177 for emb in embeddings.values_mut() {
179 let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
180 if norm > 0.0 {
181 for val in emb {
182 *val /= norm;
183 }
184 }
185 }
186
187 Ok(embeddings)
188 }
189
190 pub fn embed_node2vec(
192 &mut self,
193 triples: &[Triple],
194 communities: &CommunityStructure,
195 ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
196 let (graph, node_map) = self.build_graph(triples);
197
198 if graph.node_count() == 0 {
199 return Ok(HashMap::new());
200 }
201
202 let walks = self.generate_community_biased_walks(&graph, &node_map, communities)?;
204
205 let embeddings = self.train_skip_gram(&walks, &node_map)?;
207
208 Ok(embeddings)
209 }
210
211 fn generate_community_biased_walks(
213 &mut self,
214 graph: &UnGraph<String, ()>,
215 node_map: &HashMap<String, NodeIndex>,
216 communities: &CommunityStructure,
217 ) -> GraphRAGResult<Vec<Vec<String>>> {
218 let mut walks = Vec::new();
219
220 for _ in 0..self.config.num_walks {
221 for (node_label, &start_idx) in node_map {
222 let walk = self.node2vec_walk(graph, start_idx, node_label, communities);
223 walks.push(walk);
224 }
225 }
226
227 Ok(walks)
228 }
229
230 fn node2vec_walk(
232 &mut self,
233 graph: &UnGraph<String, ()>,
234 start: NodeIndex,
235 start_label: &str,
236 communities: &CommunityStructure,
237 ) -> Vec<String> {
238 let mut walk = vec![start_label.to_string()];
239 let mut current = start;
240 let mut prev: Option<NodeIndex> = None;
241 let start_community = communities.node_to_community.get(start_label);
242
243 for _ in 1..self.config.walk_length {
244 let neighbors: Vec<NodeIndex> = graph.neighbors(current).collect();
245
246 if neighbors.is_empty() {
247 break;
248 }
249
250 let mut probs = vec![0.0; neighbors.len()];
252
253 for (i, &neighbor) in neighbors.iter().enumerate() {
254 let mut prob = 1.0;
255
256 if let Some(p) = prev {
258 if neighbor == p {
259 prob /= self.config.p; } else if !graph.neighbors(p).any(|n| n == neighbor) {
261 prob /= self.config.q; }
263 }
264
265 if let Some(neighbor_label) = graph.node_weight(neighbor) {
267 let neighbor_community = communities.node_to_community.get(neighbor_label);
268 if start_community == neighbor_community {
269 prob *= self.config.community_bias;
270 }
271 }
272
273 probs[i] = prob;
274 }
275
276 let sum: f64 = probs.iter().sum();
278 if sum > 0.0 {
279 for p in &mut probs {
280 *p /= sum;
281 }
282 }
283
284 let r = self.rng.random_range(0.0..1.0);
286 let mut cumsum = 0.0;
287 let mut next_idx = 0;
288
289 for (i, &p) in probs.iter().enumerate() {
290 cumsum += p;
291 if r < cumsum {
292 next_idx = i;
293 break;
294 }
295 }
296
297 let next = neighbors[next_idx];
298 if let Some(next_label) = graph.node_weight(next) {
299 walk.push(next_label.clone());
300 }
301
302 prev = Some(current);
303 current = next;
304 }
305
306 walk
307 }
308
309 fn train_skip_gram(
311 &mut self,
312 walks: &[Vec<String>],
313 node_map: &HashMap<String, NodeIndex>,
314 ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
315 let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
317 for node_label in node_map.keys() {
318 let mut emb = vec![0.0; self.config.embedding_dim];
319 for val in &mut emb {
320 *val = (self.rng.random_range(0.0..1.0) - 0.5) * 0.1; }
322 embeddings.insert(node_label.clone(), emb);
323 }
324
325 let learning_rate = 0.025;
327 let num_epochs = 5;
328
329 for _ in 0..num_epochs {
330 for walk in walks {
331 for (i, target) in walk.iter().enumerate() {
332 let start = i.saturating_sub(self.config.window_size);
333 let end = (i + self.config.window_size + 1).min(walk.len());
334
335 for (offset, context) in walk[start..end].iter().enumerate() {
336 let j = start + offset;
337 if i == j {
338 continue;
339 }
340
341 if let (Some(target_emb), Some(context_emb)) =
343 (embeddings.get(target), embeddings.get(context))
344 {
345 let mut target_update = vec![0.0; self.config.embedding_dim];
346 let mut context_update = vec![0.0; self.config.embedding_dim];
347
348 for k in 0..self.config.embedding_dim {
349 let diff = context_emb[k] - target_emb[k];
350 target_update[k] = learning_rate * diff;
351 context_update[k] = -learning_rate * diff;
352 }
353
354 if let Some(emb) = embeddings.get_mut(target) {
355 for (k, &update) in target_update.iter().enumerate() {
356 emb[k] += update;
357 }
358 }
359
360 if let Some(emb) = embeddings.get_mut(context) {
361 for (k, &update) in context_update.iter().enumerate() {
362 emb[k] += update;
363 }
364 }
365 }
366 }
367 }
368 }
369 }
370
371 for emb in embeddings.values_mut() {
373 let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
374 if norm > 0.0 {
375 for val in emb {
376 *val /= norm;
377 }
378 }
379 }
380
381 Ok(embeddings)
382 }
383
384 fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
386 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
387 let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
388
389 for triple in triples {
390 let subj_idx = *node_map
391 .entry(triple.subject.clone())
392 .or_insert_with(|| graph.add_node(triple.subject.clone()));
393 let obj_idx = *node_map
394 .entry(triple.object.clone())
395 .or_insert_with(|| graph.add_node(triple.object.clone()));
396
397 if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
398 graph.add_edge(subj_idx, obj_idx, ());
399 }
400 }
401
402 (graph, node_map)
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every embedding has `dim` components and unit L2 norm.
    fn assert_unit_vectors(embeddings: &HashMap<String, Vec<f64>>, dim: usize) {
        for (label, emb) in embeddings {
            assert_eq!(emb.len(), dim, "wrong dimension for {}", label);
            let norm = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-9,
                "embedding for {} is not unit length: {}",
                label,
                norm
            );
        }
    }

    #[test]
    fn test_community_aware_embeddings() {
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
        ];

        let assignments = vec![
            ("http://a".to_string(), 0),
            ("http://b".to_string(), 0),
            ("http://c".to_string(), 0),
        ];

        let communities = CommunityStructure::from_assignments(&assignments, 0.8);

        let config = EmbeddingConfig {
            embedding_dim: 16,
            ..Default::default()
        };

        let mut embedder = CommunityAwareEmbeddings::new(config);
        let embeddings = embedder
            .embed_graphsage(&triples, &communities)
            .expect("embeddings failed");

        assert_eq!(embeddings.len(), 3);
        assert_unit_vectors(&embeddings, 16);
    }

    #[test]
    fn test_node2vec_embeddings() {
        let triples = vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://c", "http://rel", "http://d"),
        ];

        let assignments = vec![
            ("http://a".to_string(), 0),
            ("http://b".to_string(), 0),
            ("http://c".to_string(), 1),
            ("http://d".to_string(), 1),
        ];

        let communities = CommunityStructure::from_assignments(&assignments, 0.7);

        let config = EmbeddingConfig {
            embedding_dim: 16,
            walk_length: 10,
            num_walks: 5,
            ..Default::default()
        };

        let mut embedder = CommunityAwareEmbeddings::new(config);
        let embeddings = embedder
            .embed_node2vec(&triples, &communities)
            .expect("embeddings failed");

        assert_eq!(embeddings.len(), 4);
        assert_unit_vectors(&embeddings, 16);
    }

    #[test]
    fn test_empty_input_yields_no_embeddings() {
        let communities = CommunityStructure::from_assignments(&[], 0.0);
        let mut embedder = CommunityAwareEmbeddings::new(EmbeddingConfig::default());

        let graphsage = embedder
            .embed_graphsage(&[], &communities)
            .expect("graphsage failed");
        assert!(graphsage.is_empty());

        let node2vec = embedder
            .embed_node2vec(&[], &communities)
            .expect("node2vec failed");
        assert!(node2vec.is_empty());
    }
}