1use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21use rand::Rng;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EmbeddingConfig {
26 pub dimension: usize,
28 pub walk_length: usize,
30 pub walks_per_node: usize,
32 pub context_size: usize,
34 pub return_param: f32,
36 pub inout_param: f32,
38 pub learning_rate: f32,
40 pub negative_samples: usize,
42 pub epochs: usize,
44}
45
46impl Default for EmbeddingConfig {
47 fn default() -> Self {
48 Self {
49 dimension: 128,
50 walk_length: 80,
51 walks_per_node: 10,
52 context_size: 10,
53 return_param: 1.0,
54 inout_param: 1.0,
55 learning_rate: 0.025,
56 negative_samples: 5,
57 epochs: 10,
58 }
59 }
60}
61
/// Undirected, weighted graph keyed by string node identifiers.
///
/// Built once from an edge list via [`EmbeddingGraph::from_edges`]; the
/// structure is read-only afterwards.
pub struct EmbeddingGraph {
    /// Neighbor lists: node id -> `(neighbor id, edge weight)` pairs.
    adjacency: HashMap<String, Vec<(String, f32)>>,
    /// Every node id exactly once, in order of first appearance in the edge
    /// list (source before target within one edge).
    nodes: Vec<String>,
    /// Reverse lookup: node id -> position in `nodes`.
    node_index: HashMap<String, usize>,
}

impl EmbeddingGraph {
    /// Builds an undirected graph from `(source, target, weight)` triples.
    ///
    /// Each edge is recorded in the neighbor lists of both endpoints. A
    /// self-loop (`source == target`) is recorded once, so it is not counted
    /// twice when neighbor lists are sampled by weight. Repeated edges are
    /// kept as separate entries.
    ///
    /// Node indices are deterministic: nodes are numbered in the order in
    /// which they first appear in `edges`, so the same edge list always
    /// yields the same `get_index` / `get_node` mapping.
    pub fn from_edges(edges: Vec<(String, String, f32)>) -> Self {
        let mut adjacency: HashMap<String, Vec<(String, f32)>> = HashMap::new();
        let mut nodes: Vec<String> = Vec::new();
        let mut seen: HashSet<String> = HashSet::new();

        for (source, target, weight) in edges {
            // `insert` returns true only the first time an id is observed.
            if seen.insert(source.clone()) {
                nodes.push(source.clone());
            }
            if seen.insert(target.clone()) {
                nodes.push(target.clone());
            }

            // Reverse direction; skipped for self-loops so the loop appears
            // only once in the node's own neighbor list.
            if source != target {
                adjacency
                    .entry(target.clone())
                    .or_default()
                    .push((source.clone(), weight));
            }

            adjacency.entry(source).or_default().push((target, weight));
        }

        let node_index: HashMap<String, usize> = nodes
            .iter()
            .enumerate()
            .map(|(position, id)| (id.clone(), position))
            .collect();

        Self {
            adjacency,
            nodes,
            node_index,
        }
    }

    /// Number of distinct nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Neighbor list of `node` as `(neighbor id, weight)` pairs, or `None`
    /// when the node is not part of the graph.
    pub fn neighbors(&self, node: &str) -> Option<&Vec<(String, f32)>> {
        self.adjacency.get(node)
    }

    /// Dense index assigned to `node`, or `None` for an unknown id.
    pub fn get_index(&self, node: &str) -> Option<usize> {
        self.node_index.get(node).copied()
    }

    /// Node id stored at `index`, or `None` when `index` is out of range.
    pub fn get_node(&self, index: usize) -> Option<&String> {
        self.nodes.get(index)
    }
}
130
131pub struct Node2Vec {
133 config: EmbeddingConfig,
134 embeddings: HashMap<String, Vec<f32>>,
136}
137
138impl Node2Vec {
139 pub fn new(config: EmbeddingConfig) -> Self {
141 Self {
142 config,
143 embeddings: HashMap::new(),
144 }
145 }
146
147 pub fn fit(&mut self, graph: &EmbeddingGraph) {
149 let walks = self.generate_walks(graph);
151
152 self.initialize_embeddings(graph);
154
155 self.train_skipgram(&walks);
157 }
158
159 fn generate_walks(&self, graph: &EmbeddingGraph) -> Vec<Vec<String>> {
161 let mut rng = rand::thread_rng();
162 let mut walks = Vec::new();
163
164 for _ in 0..self.config.walks_per_node {
165 for node in &graph.nodes {
166 let walk = self.random_walk(graph, node, &mut rng);
167 walks.push(walk);
168 }
169 }
170
171 walks
172 }
173
174 fn random_walk<R: Rng>(
176 &self,
177 graph: &EmbeddingGraph,
178 start: &str,
179 rng: &mut R,
180 ) -> Vec<String> {
181 let mut walk = vec![start.to_string()];
182
183 for _ in 1..self.config.walk_length {
184 let current = walk.last().unwrap();
185
186 if let Some(neighbors) = graph.neighbors(current) {
187 if neighbors.is_empty() {
188 break;
189 }
190
191 let next = if walk.len() == 1 {
193 &neighbors[rng.gen_range(0..neighbors.len())].0
195 } else {
196 let prev = &walk[walk.len() - 2];
198 self.sample_next(prev, current, neighbors, rng)
199 };
200
201 walk.push(next.clone());
202 } else {
203 break;
204 }
205 }
206
207 walk
208 }
209
210 fn sample_next<'a, R: Rng>(
212 &self,
213 prev: &str,
214 _current: &str,
215 neighbors: &'a [(String, f32)],
216 rng: &mut R,
217 ) -> &'a String {
218 let mut probs: Vec<f32> = neighbors
220 .iter()
221 .map(|(neighbor, weight)| {
222 let alpha = if neighbor == prev {
223 1.0 / self.config.return_param
225 } else {
226 1.0 / self.config.inout_param
228 };
229 weight * alpha
230 })
231 .collect();
232
233 let sum: f32 = probs.iter().sum();
235 if sum > 0.0 {
236 for p in &mut probs {
237 *p /= sum;
238 }
239 }
240
241 let r: f32 = rng.gen();
243 let mut cumsum = 0.0;
244 for (i, &prob) in probs.iter().enumerate() {
245 cumsum += prob;
246 if r <= cumsum {
247 return &neighbors[i].0;
248 }
249 }
250
251 &neighbors[neighbors.len() - 1].0
252 }
253
254 fn initialize_embeddings(&mut self, graph: &EmbeddingGraph) {
256 let mut rng = rand::thread_rng();
257
258 for node in &graph.nodes {
259 let embedding: Vec<f32> = (0..self.config.dimension)
260 .map(|_| (rng.gen::<f32>() - 0.5) / self.config.dimension as f32)
261 .collect();
262
263 self.embeddings.insert(node.clone(), embedding);
264 }
265 }
266
267 fn train_skipgram(&mut self, walks: &[Vec<String>]) {
269 for _ in 0..self.config.epochs {
270 for walk in walks {
271 for (i, node) in walk.iter().enumerate() {
272 let start = i.saturating_sub(self.config.context_size);
274 let end = (i + self.config.context_size + 1).min(walk.len());
275
276 for (j, context_node) in walk.iter().enumerate().take(end).skip(start) {
277 if i != j {
278 self.update_embeddings(node, context_node);
279 }
280 }
281 }
282 }
283 }
284 }
285
286 fn update_embeddings(&mut self, target: &str, context: &str) {
288 let lr = self.config.learning_rate;
292
293 if let (Some(target_emb), Some(context_emb)) =
294 (self.embeddings.get(target), self.embeddings.get(context))
295 {
296 let mut target_new = target_emb.clone();
298 let mut context_new = context_emb.clone();
299
300 for i in 0..self.config.dimension {
301 let diff = context_emb[i] - target_emb[i];
302 target_new[i] += lr * diff;
303 context_new[i] -= lr * diff;
304 }
305
306 self.embeddings.insert(target.to_string(), target_new);
307 self.embeddings.insert(context.to_string(), context_new);
308 }
309 }
310
311 pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
313 self.embeddings.get(node)
314 }
315
316 pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
318 &self.embeddings
319 }
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct GraphSAGEConfig {
325 pub dimension: usize,
327 pub num_layers: usize,
329 pub samples_per_layer: Vec<usize>,
331 pub aggregator: Aggregator,
333}
334
335#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
337pub enum Aggregator {
338 Mean,
340 MaxPool,
342 Lstm,
344 Attention,
346}
347
348impl Default for GraphSAGEConfig {
349 fn default() -> Self {
350 Self {
351 dimension: 128,
352 num_layers: 2,
353 samples_per_layer: vec![25, 10],
354 aggregator: Aggregator::Mean,
355 }
356 }
357}
358
359pub struct GraphSAGE {
361 config: GraphSAGEConfig,
362 embeddings: HashMap<String, Vec<f32>>,
363}
364
365impl GraphSAGE {
366 pub fn new(config: GraphSAGEConfig) -> Self {
368 Self {
369 config,
370 embeddings: HashMap::new(),
371 }
372 }
373
374 pub fn fit(&mut self, graph: &EmbeddingGraph) {
376 let mut rng = rand::thread_rng();
378 let mut node_features: HashMap<String, Vec<f32>> = HashMap::new();
379
380 for node in &graph.nodes {
381 let features: Vec<f32> = (0..self.config.dimension)
382 .map(|_| rng.gen::<f32>())
383 .collect();
384 node_features.insert(node.clone(), features);
385 }
386
387 for layer in 0..self.config.num_layers {
389 let samples = self.config.samples_per_layer.get(layer).copied().unwrap_or(10);
390 node_features = self.aggregate_layer(graph, &node_features, samples);
391 }
392
393 self.embeddings = node_features;
394 }
395
396 fn aggregate_layer(
398 &self,
399 graph: &EmbeddingGraph,
400 features: &HashMap<String, Vec<f32>>,
401 num_samples: usize,
402 ) -> HashMap<String, Vec<f32>> {
403 let mut rng = rand::thread_rng();
404 let mut new_features = HashMap::new();
405
406 for node in &graph.nodes {
407 let neighbors = if let Some(neighs) = graph.neighbors(node) {
409 let sample_size = num_samples.min(neighs.len());
410 let mut sampled = Vec::new();
411 let mut indices: Vec<usize> = (0..neighs.len()).collect();
412
413 for _ in 0..sample_size {
414 let idx = rng.gen_range(0..indices.len());
415 let neighbor_idx = indices.remove(idx);
416 sampled.push(&neighs[neighbor_idx].0);
417 }
418
419 sampled
420 } else {
421 Vec::new()
422 };
423
424 let aggregated = self.aggregate_neighbors(features, &neighbors);
426
427 let node_feat = features.get(node).unwrap();
429 let combined = self.combine_features(node_feat, &aggregated);
430
431 new_features.insert(node.clone(), combined);
432 }
433
434 new_features
435 }
436
437 fn aggregate_neighbors(
439 &self,
440 features: &HashMap<String, Vec<f32>>,
441 neighbors: &[&String],
442 ) -> Vec<f32> {
443 if neighbors.is_empty() {
444 return vec![0.0; self.config.dimension];
445 }
446
447 match self.config.aggregator {
448 Aggregator::Mean => {
449 let mut sum = vec![0.0; self.config.dimension];
450 for neighbor in neighbors {
451 if let Some(feat) = features.get(*neighbor) {
452 for i in 0..self.config.dimension {
453 sum[i] += feat[i];
454 }
455 }
456 }
457
458 for val in &mut sum {
459 *val /= neighbors.len() as f32;
460 }
461
462 sum
463 }
464 _ => {
465 let mut sum = vec![0.0; self.config.dimension];
468 for neighbor in neighbors {
469 if let Some(feat) = features.get(*neighbor) {
470 for i in 0..self.config.dimension {
471 sum[i] += feat[i];
472 }
473 }
474 }
475
476 for val in &mut sum {
477 *val /= neighbors.len() as f32;
478 }
479
480 sum
481 }
482 }
483 }
484
485 fn combine_features(&self, node_feat: &[f32], neighbor_feat: &[f32]) -> Vec<f32> {
487 let mut combined = Vec::with_capacity(self.config.dimension);
491
492 for i in 0..self.config.dimension {
493 combined.push((node_feat[i] + neighbor_feat[i]) / 2.0);
495 }
496
497 combined
498 }
499
500 pub fn get_embedding(&self, node: &str) -> Option<&Vec<f32>> {
502 self.embeddings.get(node)
503 }
504
505 pub fn embeddings(&self) -> &HashMap<String, Vec<f32>> {
507 &self.embeddings
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: five nodes (A–E) joined by six unit-weight edges.
    fn create_test_graph() -> EmbeddingGraph {
        let pairs = [
            ("A", "B"),
            ("A", "C"),
            ("B", "C"),
            ("B", "D"),
            ("C", "D"),
            ("D", "E"),
        ];

        let edges = pairs
            .iter()
            .map(|&(from, to)| (from.to_string(), to.to_string(), 1.0))
            .collect();

        EmbeddingGraph::from_edges(edges)
    }

    /// The fixture exposes five nodes and `A` has exactly two neighbors.
    #[test]
    fn test_embedding_graph_creation() {
        let graph = create_test_graph();
        assert_eq!(graph.node_count(), 5);

        let neighbors_of_a = graph.neighbors("A");
        assert!(neighbors_of_a.is_some());
        assert_eq!(neighbors_of_a.unwrap().len(), 2);
    }

    /// A freshly constructed model holds no embeddings.
    #[test]
    fn test_node2vec_initialization() {
        let node2vec = Node2Vec::new(EmbeddingConfig::default());
        assert_eq!(node2vec.embeddings.len(), 0);
    }

    /// Fitting yields one 64-dimensional vector per node.
    #[test]
    fn test_node2vec_fit() {
        let graph = create_test_graph();
        let mut node2vec = Node2Vec::new(EmbeddingConfig {
            dimension: 64,
            walk_length: 10,
            walks_per_node: 5,
            epochs: 1,
            ..Default::default()
        });

        node2vec.fit(&graph);

        assert_eq!(node2vec.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = node2vec.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    /// Fitting yields one 64-dimensional vector per node.
    #[test]
    fn test_graphsage_fit() {
        let graph = create_test_graph();
        let mut graphsage = GraphSAGE::new(GraphSAGEConfig {
            dimension: 64,
            num_layers: 2,
            samples_per_layer: vec![3, 2],
            aggregator: Aggregator::Mean,
        });

        graphsage.fit(&graph);

        assert_eq!(graphsage.embeddings.len(), 5);
        for node in &graph.nodes {
            let emb = graphsage.get_embedding(node).unwrap();
            assert_eq!(emb.len(), 64);
        }
    }

    /// One walk per node, each holding between one and `walk_length` nodes.
    #[test]
    fn test_random_walk_generation() {
        let graph = create_test_graph();
        let node2vec = Node2Vec::new(EmbeddingConfig {
            walk_length: 5,
            walks_per_node: 1,
            ..Default::default()
        });

        let walks = node2vec.generate_walks(&graph);

        assert_eq!(walks.len(), 5);
        for walk in &walks {
            assert!(walk.len() <= 5);
            assert!(!walk.is_empty());
        }
    }
}