1use rand::Rng;
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, HashSet};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EmbeddingConfig {
26 pub dimension: usize,
28 pub walk_length: usize,
30 pub walks_per_node: usize,
32 pub context_size: usize,
34 pub return_param: f32,
36 pub inout_param: f32,
38 pub learning_rate: f32,
40 pub negative_samples: usize,
42 pub epochs: usize,
44}
45
46impl Default for EmbeddingConfig {
47 fn default() -> Self {
48 Self {
49 dimension: 128,
50 walk_length: 80,
51 walks_per_node: 10,
52 context_size: 10,
53 return_param: 1.0,
54 inout_param: 1.0,
55 learning_rate: 0.025,
56 negative_samples: 5,
57 epochs: 10,
58 }
59 }
60}
61
/// Weighted graph stored as per-node adjacency lists keyed by node name.
pub struct EmbeddingGraph {
    /// For each node name, the `(neighbor name, edge weight)` pairs reachable
    /// from it in one hop.
    adjacency: HashMap<String, Vec<(String, f32)>>,

    /// Every distinct node name; a node's position in this list is its index.
    nodes: Vec<String>,

    /// Reverse lookup from a node name to its position in `nodes`.
    node_index: HashMap<String, usize>,
}
71
72impl EmbeddingGraph {
73 pub fn from_edges(edges: Vec<(String, String, f32)>) -> Self {
78 let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
79 let mut nodes_set = HashSet::new();
80
81 for (source, target, weight) in edges {
82 adjacency
83 .entry(source.clone())
84 .or_default()
85 .push((target.clone(), weight));
86
87 adjacency
88 .entry(target.clone())
89 .or_default()
90 .push((source.clone(), weight));
91
92 nodes_set.insert(source);
93 nodes_set.insert(target);
94 }
95
96 let nodes: Vec<String> = nodes_set.into_iter().collect();
97 let node_index: HashMap<String, usize> = nodes
98 .iter()
99 .enumerate()
100 .map(|(i, n)| (n.clone(), i))
101 .collect();
102
103 Self {
104 adjacency,
105 nodes,
106 node_index,
107 }
108 }
109
110 pub fn node_count(&self) -> usize {
112 self.nodes.len()
113 }
114
115 pub fn neighbors(&self, node: &str) -> Option<&Vec<(String, f32)>> {
117 self.adjacency.get(node)
118 }
119
120 pub fn get_index(&self, node: &str) -> Option<usize> {
122 self.node_index.get(node).copied()
123 }
124
125 pub fn get_node(&self, index: usize) -> Option<&String> {
127 self.nodes.get(index)
128 }
129}
130
131pub struct Node2Vec {
133 config: EmbeddingConfig,
134 embeddings: HashMap<String, Vec<f32>>,
136}
137
138impl Node2Vec {
139 pub fn new(config: EmbeddingConfig) -> Self {
141 Self {
142 config,
143 embeddings: HashMap::new(),
144 }
145 }
146
147 pub fn fit(&mut self, graph: &EmbeddingGraph) {
149 let walks = self.generate_walks(graph);
151
152 self.initialize_embeddings(graph);
154
155 self.train_skipgram(&walks);
157 }
158
159 fn generate_walks(&self, graph: &EmbeddingGraph) -> Vec<Vec<String>> {
161 let mut rng = rand::thread_rng();
162 let mut walks = Vec::new();
163
164 for _ in 0..self.config.walks_per_node {
165 for node in &graph.nodes {
166 let walk = self.random_walk(graph, node, &mut rng);
167 walks.push(walk);
168 }
169 }
170
171 walks
172 }
173
174 fn random_walk<R: Rng>(&self, graph: &EmbeddingGraph, start: &str, rng: &mut R) -> Vec<String> {
176 let mut walk = vec![start.to_string()];
177
178 for _ in 1..self.config.walk_length {
179 let current = walk.last().unwrap();
180
181 if let Some(neighbors) = graph.neighbors(current) {
182 if neighbors.is_empty() {
183 break;
184 }
185
186 let next = if walk.len() == 1 {
188 &neighbors[rng.gen_range(0..neighbors.len())].0
190 } else {
191 let prev = &walk[walk.len() - 2];
193 self.sample_next(prev, current, neighbors, rng)
194 };
195
196 walk.push(next.clone());
197 } else {
198 break;
199 }
200 }
201
202 walk
203 }
204
205 fn sample_next<'a, R: Rng>(
207 &self,
208 prev: &str,
209 _current: &str,
210 neighbors: &'a [(String, f32)],
211 rng: &mut R,
212 ) -> &'a String {
213 let mut probs: Vec<f32> = neighbors
215 .iter()
216 .map(|(neighbor, weight)| {
217 let alpha = if neighbor == prev {
218 1.0 / self.config.return_param
220 } else {
221 1.0 / self.config.inout_param
223 };
224 weight * alpha
225 })
226 .collect();
227
228 let sum: f32 = probs.iter().sum();
230 if sum > 0.0 {
231 for p in &mut probs {
232 *p /= sum;
233 }
234 }
235
236 let r: f32 = rng.gen();
238 let mut cumsum = 0.0;
239 for (i, &prob) in probs.iter().enumerate() {
240 cumsum += prob;
241 if r <= cumsum {
242 return &neighbors[i].0;
243 }
244 }
245
246 &neighbors[neighbors.len() - 1].0
247 }
248
249 fn initialize_embeddings(&mut self, graph: &EmbeddingGraph) {
251 let mut rng = rand::thread_rng();
252
253 for node in &graph.nodes {
254 let embedding: Vec<f32> = (0..self.config.dimension)
255 .map(|_| (rng.gen::<f32>() - 0.5) / self.config.dimension as f32)
256 .collect();
257
258 self.embeddings.insert(node.clone(), embedding);
259 }
260 }
261
262 fn train_skipgram(&mut self, walks: &[Vec<String>]) {
264 for _ in 0..self.config.epochs {
265 for walk in walks {
266 for (i, node) in walk.iter().enumerate() {
267 let start = i.saturating_sub(self.config.context_size);
269 let end = (i + self.config.context_size + 1).min(walk.len());
270
271 for (j, context_node) in walk.iter().enumerate().take(end).skip(start) {
272 if i != j {
273 self.update_embeddings(node, context_node);
274 }
275 }
276 }
277 }
278 }
279 }
280
281 fn update_embeddings(&mut self, target: &str, context: &str) {
283 let lr = self.config.learning_rate;
287
288 if let (Some(target_emb), Some(context_emb)) =
289 (self.embeddings.get(target), self.embeddings.get(context))
290 {
291 let mut target_new = target_emb.clone();
293 let mut context_new = context_emb.clone();
294
295 for i in 0..self.config.dimension {
296 let diff = context_emb[i] - target_emb[i];
297 target_new[i] += lr * diff;
298 context_new[i] -= lr * diff;
299 }
300
301 self.embeddings.insert(target.to_string(), target_new);
302 self.embeddings.insert(context.to_string(), context_new);
303 }
304 }
305
306 pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
308 self.embeddings.get(node)
309 }
310
311 pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
313 &self.embeddings
314 }
315}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct GraphSAGEConfig {
320 pub dimension: usize,
322 pub num_layers: usize,
324 pub samples_per_layer: Vec<usize>,
326 pub aggregator: Aggregator,
328}
329
330#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
332pub enum Aggregator {
333 Mean,
335 MaxPool,
337 Lstm,
339 Attention,
341}
342
343impl Default for GraphSAGEConfig {
344 fn default() -> Self {
345 Self {
346 dimension: 128,
347 num_layers: 2,
348 samples_per_layer: vec![25, 10],
349 aggregator: Aggregator::Mean,
350 }
351 }
352}
353
354pub struct GraphSAGE {
356 config: GraphSAGEConfig,
357 embeddings: HashMap<String, Vec<f32>>,
358}
359
360impl GraphSAGE {
361 pub fn new(config: GraphSAGEConfig) -> Self {
363 Self {
364 config,
365 embeddings: HashMap::new(),
366 }
367 }
368
369 pub fn fit(&mut self, graph: &EmbeddingGraph) {
371 let mut rng = rand::thread_rng();
373 let mut node_features: HashMap<String, Vec<f32>> = HashMap::new();
374
375 for node in &graph.nodes {
376 let features: Vec<f32> = (0..self.config.dimension)
377 .map(|_| rng.gen::<f32>())
378 .collect();
379 node_features.insert(node.clone(), features);
380 }
381
382 for layer in 0..self.config.num_layers {
384 let samples = self
385 .config
386 .samples_per_layer
387 .get(layer)
388 .copied()
389 .unwrap_or(10);
390 node_features = self.aggregate_layer(graph, &node_features, samples);
391 }
392
393 self.embeddings = node_features;
394 }
395
396 fn aggregate_layer(
398 &self,
399 graph: &EmbeddingGraph,
400 features: &HashMap<String, Vec<f32>>,
401 num_samples: usize,
402 ) -> HashMap<String, Vec<f32>> {
403 let mut rng = rand::thread_rng();
404 let mut new_features = HashMap::new();
405
406 for node in &graph.nodes {
407 let neighbors = if let Some(neighs) = graph.neighbors(node) {
409 let sample_size = num_samples.min(neighs.len());
410 let mut sampled = Vec::new();
411 let mut indices: Vec<usize> = (0..neighs.len()).collect();
412
413 for _ in 0..sample_size {
414 let idx = rng.gen_range(0..indices.len());
415 let neighbor_idx = indices.remove(idx);
416 sampled.push(&neighs[neighbor_idx].0);
417 }
418
419 sampled
420 } else {
421 Vec::new()
422 };
423
424 let aggregated = self.aggregate_neighbors(features, &neighbors);
426
427 let node_feat = features.get(node).unwrap();
429 let combined = self.combine_features(node_feat, &aggregated);
430
431 new_features.insert(node.clone(), combined);
432 }
433
434 new_features
435 }
436
437 fn aggregate_neighbors(
439 &self,
440 features: &HashMap<String, Vec<f32>>,
441 neighbors: &[&String],
442 ) -> Vec<f32> {
443 if neighbors.is_empty() {
444 return vec![0.0; self.config.dimension];
445 }
446
447 match self.config.aggregator {
448 Aggregator::Mean => {
449 let mut sum = vec![0.0; self.config.dimension];
450 for neighbor in neighbors {
451 if let Some(feat) = features.get(*neighbor) {
452 for i in 0..self.config.dimension {
453 sum[i] += feat[i];
454 }
455 }
456 }
457
458 for val in &mut sum {
459 *val /= neighbors.len() as f32;
460 }
461
462 sum
463 },
464 Aggregator::MaxPool => {
465 let mut max_feat = vec![f32::NEG_INFINITY; self.config.dimension];
467
468 for neighbor in neighbors {
469 if let Some(feat) = features.get(*neighbor) {
470 for i in 0..self.config.dimension {
471 max_feat[i] = max_feat[i].max(feat[i]);
472 }
473 }
474 }
475
476 if max_feat.iter().all(|&v| v == f32::NEG_INFINITY) {
478 vec![0.0; self.config.dimension]
479 } else {
480 max_feat
481 }
482 },
483 Aggregator::Attention => {
484 self.aggregate_attention(features, neighbors)
486 },
487 Aggregator::Lstm => {
488 self.aggregate_lstm(features, neighbors)
490 },
491 }
492 }
493
494 fn aggregate_attention(
496 &self,
497 features: &HashMap<String, Vec<f32>>,
498 neighbors: &[&String],
499 ) -> Vec<f32> {
500 if neighbors.is_empty() {
501 return vec![0.0; self.config.dimension];
502 }
503
504 let neighbor_feats: Vec<&Vec<f32>> =
506 neighbors.iter().filter_map(|n| features.get(*n)).collect();
507
508 if neighbor_feats.is_empty() {
509 return vec![0.0; self.config.dimension];
510 }
511
512 let mut attention_scores = Vec::with_capacity(neighbor_feats.len());
514 let mut score_sum = 0.0;
515
516 for feat in &neighbor_feats {
517 let score: f32 = feat.iter().sum();
519 let exp_score = score.exp();
520 attention_scores.push(exp_score);
521 score_sum += exp_score;
522 }
523
524 if score_sum > 0.0 {
526 for score in &mut attention_scores {
527 *score /= score_sum;
528 }
529 }
530
531 let mut result = vec![0.0; self.config.dimension];
533 for (feat, &weight) in neighbor_feats.iter().zip(attention_scores.iter()) {
534 for i in 0..self.config.dimension {
535 result[i] += feat[i] * weight;
536 }
537 }
538
539 result
540 }
541
542 fn aggregate_lstm(
544 &self,
545 features: &HashMap<String, Vec<f32>>,
546 neighbors: &[&String],
547 ) -> Vec<f32> {
548 if neighbors.is_empty() {
549 return vec![0.0; self.config.dimension];
550 }
551
552 let mut hidden_state = vec![0.0; self.config.dimension];
555 let decay: f32 = 0.9; for (idx, neighbor) in neighbors.iter().enumerate() {
558 if let Some(feat) = features.get(*neighbor) {
559 let weight = decay.powi(idx as i32);
561 for i in 0..self.config.dimension {
562 hidden_state[i] = hidden_state[i] * decay + feat[i] * weight;
563 }
564 }
565 }
566
567 let seq_len = neighbors.len() as f32;
569 for val in &mut hidden_state {
570 *val /= seq_len;
571 }
572
573 hidden_state
574 }
575
576 fn combine_features(&self, node_feat: &[f32], neighbor_feat: &[f32]) -> Vec<f32> {
578 let mut combined = Vec::with_capacity(self.config.dimension);
582
583 for i in 0..self.config.dimension {
584 combined.push((node_feat[i] + neighbor_feat[i]) / 2.0);
586 }
587
588 combined
589 }
590
591 pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
593 self.embeddings.get(node)
594 }
595
596 pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
598 &self.embeddings
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-node fixture: triangle A-B-C, C and B both linked to D, and a tail D-E.
    fn create_test_graph() -> EmbeddingGraph {
        let edges = vec![
            ("A".to_string(), "B".to_string(), 1.0),
            ("A".to_string(), "C".to_string(), 1.0),
            ("B".to_string(), "C".to_string(), 1.0),
            ("B".to_string(), "D".to_string(), 1.0),
            ("C".to_string(), "D".to_string(), 1.0),
            ("D".to_string(), "E".to_string(), 1.0),
        ];

        EmbeddingGraph::from_edges(edges)
    }

    #[test]
    fn test_embedding_graph_creation() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 5);
        assert!(graph.neighbors("A").is_some());
        assert_eq!(graph.neighbors("A").unwrap().len(), 2);

        // Name -> index -> name must round-trip for every node.
        for index in 0..graph.node_count() {
            let name = graph.get_node(index).unwrap();
            assert_eq!(graph.get_index(name), Some(index));
        }
        assert!(graph.get_node(graph.node_count()).is_none());
        assert!(graph.get_index("missing").is_none());
    }

    #[test]
    fn test_node2vec_initialization() {
        let config = EmbeddingConfig::default();
        let node2vec = Node2Vec::new(config);
        assert_eq!(node2vec.embeddings.len(), 0);
    }

    #[test]
    fn test_node2vec_fit() {
        let graph = create_test_graph();
        let config = EmbeddingConfig {
            dimension: 64,
            walk_length: 10,
            walks_per_node: 5,
            epochs: 1,
            ..Default::default()
        };

        let mut node2vec = Node2Vec::new(config);
        node2vec.fit(&graph);

        assert_eq!(node2vec.embeddings.len(), 5);

        for node in &graph.nodes {
            let emb = node2vec.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_graphsage_fit() {
        let graph = create_test_graph();
        let config = GraphSAGEConfig {
            dimension: 64,
            num_layers: 2,
            samples_per_layer: vec![3, 2],
            aggregator: Aggregator::Mean,
        };

        let mut graphsage = GraphSAGE::new(config);
        graphsage.fit(&graph);

        assert_eq!(graphsage.embeddings.len(), 5);

        for node in &graph.nodes {
            let emb = graphsage.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    #[test]
    fn test_graphsage_all_aggregators_produce_finite_embeddings() {
        let graph = create_test_graph();
        let aggregators = [
            Aggregator::Mean,
            Aggregator::MaxPool,
            Aggregator::Lstm,
            Aggregator::Attention,
        ];

        for &aggregator in &aggregators {
            let config = GraphSAGEConfig {
                dimension: 16,
                num_layers: 2,
                samples_per_layer: vec![3, 2],
                aggregator,
            };

            let mut graphsage = GraphSAGE::new(config);
            graphsage.fit(&graph);

            assert_eq!(graphsage.embeddings().len(), 5);
            for node in &graph.nodes {
                let emb = graphsage.get_embedding(node).unwrap();
                assert_eq!(emb.len(), 16);
                assert!(emb.iter().all(|v| v.is_finite()), "{:?}", aggregator);
            }
        }
    }

    #[test]
    fn test_random_walk_generation() {
        let graph = create_test_graph();
        let config = EmbeddingConfig {
            walk_length: 5,
            walks_per_node: 1,
            ..Default::default()
        };

        let node2vec = Node2Vec::new(config);
        let walks = node2vec.generate_walks(&graph);

        // One walk per node, since walks_per_node is 1.
        assert_eq!(walks.len(), 5);

        for walk in &walks {
            assert!(walk.len() <= 5);
            assert!(!walk.is_empty());

            // Every step must follow an existing edge.
            for step in walk.windows(2) {
                let linked = graph
                    .neighbors(&step[0])
                    .unwrap()
                    .iter()
                    .any(|(name, _)| name == &step[1]);
                assert!(linked);
            }
        }
    }
}