1use super::core::{Embedding, EmbeddingModel};
12use super::negative_sampling::NegativeSampler;
13use super::random_walk::RandomWalkGenerator;
14use super::types::{DeepWalkConfig, RandomWalk};
15use crate::base::{DiGraph, EdgeWeight, Graph, Node};
16use crate::error::{GraphError, Result};
17use scirs2_core::random::seq::SliceRandom;
18use scirs2_core::random::{Rng, RngExt};
19use std::collections::HashMap;
20
/// Training objective used by [`DeepWalk::train`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeepWalkMode {
    /// Skip-gram with negative sampling (default; see `train_negative_sampling`).
    NegativeSampling,
    /// Skip-gram with hierarchical softmax over a Huffman tree of node
    /// frequencies (see `train_hierarchical_softmax`).
    HierarchicalSoftmax,
}
29
/// Per-leaf output of the Huffman construction: the binary code assigned to
/// one node plus its root-to-leaf path through the internal nodes.
#[derive(Debug, Clone)]
struct HuffmanNode {
    /// Code bits in root-to-leaf order (`false`/`true` = branch direction).
    code: Vec<bool>,
    /// Internal-node indices along the path, offset so they index directly
    /// into `DeepWalk::internal_vectors` (i.e. already shifted by `-n`).
    point: Vec<usize>,
}
38
/// Huffman tree over node frequencies, used by hierarchical softmax.
#[derive(Debug)]
struct HuffmanTree {
    /// One entry per input frequency, in the caller's original index order.
    codes: Vec<HuffmanNode>,
    /// Number of internal nodes (`n - 1` for `n >= 2`; 1 for the synthetic
    /// single-leaf tree) — the required length of `internal_vectors`.
    num_internal: usize,
}
47
48impl HuffmanTree {
49 fn build(frequencies: &[f64]) -> Result<Self> {
51 let n = frequencies.len();
52 if n == 0 {
53 return Err(GraphError::InvalidGraph(
54 "Cannot build Huffman tree from empty frequency list".to_string(),
55 ));
56 }
57
58 if n == 1 {
59 let codes = vec![HuffmanNode {
61 code: vec![false],
62 point: vec![0],
63 }];
64 return Ok(HuffmanTree {
65 codes,
66 num_internal: 1,
67 });
68 }
69
70 let total = 2 * n - 1;
73 let mut count = vec![0.0f64; total];
74 let mut parent = vec![0usize; total];
75 let mut binary = vec![false; total]; for (i, &freq) in frequencies.iter().enumerate() {
79 count[i] = freq.max(1e-10); }
81
82 for i in n..total {
84 count[i] = f64::MAX;
85 }
86
87 let mut pos1 = n - 1; let mut pos2 = n; let mut sorted_indices: Vec<usize> = (0..n).collect();
93 sorted_indices.sort_by(|&a, &b| {
94 count[a]
95 .partial_cmp(&count[b])
96 .unwrap_or(std::cmp::Ordering::Equal)
97 });
98
99 let mut sorted_counts = vec![0.0; n];
101 let mut reverse_map = vec![0usize; n]; for (sorted_pos, &orig_idx) in sorted_indices.iter().enumerate() {
103 sorted_counts[sorted_pos] = count[orig_idx];
104 reverse_map[orig_idx] = sorted_pos;
105 }
106 count[..n].copy_from_slice(&sorted_counts[..n]);
107
108 for internal_idx in n..total {
110 let min1;
112 let min2;
113
114 if pos1 < n && (pos2 >= internal_idx || count[pos1] < count[pos2]) {
116 min1 = pos1;
117 pos1 = pos1.wrapping_sub(1); if pos1 == usize::MAX {
119 pos1 = n; }
121 } else {
122 min1 = pos2;
123 pos2 += 1;
124 }
125
126 if pos1 < n && (pos2 >= internal_idx || count[pos1] < count[pos2]) {
128 min2 = pos1;
129 pos1 = pos1.wrapping_sub(1);
130 if pos1 == usize::MAX {
131 pos1 = n;
132 }
133 } else if pos2 < internal_idx {
134 min2 = pos2;
135 pos2 += 1;
136 } else {
137 min2 = min1; }
139
140 count[internal_idx] = count[min1] + count[min2];
141 parent[min1] = internal_idx;
142 parent[min2] = internal_idx;
143 binary[min2] = true; }
145
146 let mut codes = vec![
148 HuffmanNode {
149 code: Vec::new(),
150 point: Vec::new(),
151 };
152 n
153 ];
154
155 for sorted_pos in 0..n {
156 let mut code = Vec::new();
157 let mut point = Vec::new();
158
159 let mut current = sorted_pos;
160 while current < total - 1 {
161 code.push(binary[current]);
163 let par = parent[current];
164 if par >= n {
166 point.push(par - n);
167 }
168 current = par;
169 }
170
171 code.reverse();
173 point.reverse();
174
175 let orig_idx = sorted_indices[sorted_pos];
177 codes[orig_idx] = HuffmanNode { code, point };
178 }
179
180 Ok(HuffmanTree {
181 codes,
182 num_internal: n - 1,
183 })
184 }
185}
186
/// DeepWalk node-embedding trainer: generates uniform random walks and fits
/// skip-gram embeddings via negative sampling or hierarchical softmax.
pub struct DeepWalk<N: Node> {
    /// Walk/training hyperparameters (dimensions, walk length, epochs, ...).
    config: DeepWalkConfig,
    /// Learned input/context embedding tables.
    model: EmbeddingModel<N>,
    /// Source of uniform random walks over the graph.
    walk_generator: RandomWalkGenerator<N>,
    /// Which training objective `train` dispatches to.
    mode: DeepWalkMode,
    /// Hierarchical-softmax parameters: one vector per Huffman internal node,
    /// indexed by `HuffmanNode::point` entries. Empty in negative-sampling mode.
    internal_vectors: Vec<Vec<f64>>,
}
201
impl<N: Node> DeepWalk<N> {
    /// Creates a DeepWalk trainer using negative sampling (the default mode).
    pub fn new(config: DeepWalkConfig) -> Self {
        DeepWalk {
            // `dimensions` is read before `config` is moved into the struct.
            model: EmbeddingModel::new(config.dimensions),
            config,
            walk_generator: RandomWalkGenerator::new(),
            mode: DeepWalkMode::NegativeSampling,
            internal_vectors: Vec::new(),
        }
    }

    /// Creates a DeepWalk trainer using hierarchical softmax.
    pub fn with_hierarchical_softmax(config: DeepWalkConfig) -> Self {
        DeepWalk {
            model: EmbeddingModel::new(config.dimensions),
            config,
            walk_generator: RandomWalkGenerator::new(),
            mode: DeepWalkMode::HierarchicalSoftmax,
            internal_vectors: Vec::new(),
        }
    }

    /// Switches the objective used by subsequent `train` calls.
    pub fn set_mode(&mut self, mode: DeepWalkMode) {
        self.mode = mode;
    }

    /// Returns the currently selected training mode.
    pub fn mode(&self) -> DeepWalkMode {
        self.mode
    }

    /// Generates `num_walks` simple random walks of length `walk_length`
    /// starting from every node of the undirected graph.
    ///
    /// # Errors
    /// Propagates any error from the underlying walk generator.
    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let mut all_walks = Vec::new();

        for node in graph.nodes() {
            for _ in 0..self.config.num_walks {
                let walk =
                    self.walk_generator
                        .simple_random_walk(graph, node, self.config.walk_length)?;
                all_walks.push(walk);
            }
        }

        Ok(all_walks)
    }

    /// Directed-graph counterpart of [`Self::generate_walks`]: walks follow
    /// edge direction via `simple_random_walk_digraph`.
    pub fn generate_walks_digraph<E, Ix>(
        &mut self,
        graph: &DiGraph<N, E, Ix>,
    ) -> Result<Vec<RandomWalk<N>>>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let mut all_walks = Vec::new();

        for node in graph.nodes() {
            for _ in 0..self.config.num_walks {
                let walk = self.walk_generator.simple_random_walk_digraph(
                    graph,
                    node,
                    self.config.walk_length,
                )?;
                all_walks.push(walk);
            }
        }

        Ok(all_walks)
    }

    /// Trains embeddings on an undirected graph: initializes the model
    /// randomly, then dispatches to the objective selected by `self.mode`.
    ///
    /// # Errors
    /// Propagates walk-generation and training errors; hierarchical softmax
    /// additionally fails on an empty graph.
    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let mut rng = scirs2_core::random::rng();
        self.model.initialize_random(graph, &mut rng);

        match self.mode {
            DeepWalkMode::NegativeSampling => {
                self.train_negative_sampling(graph, &mut rng)?;
            }
            DeepWalkMode::HierarchicalSoftmax => {
                self.train_hierarchical_softmax(graph, &mut rng)?;
            }
        }

        Ok(())
    }

    /// Skip-gram training with negative sampling: regenerates walks each
    /// epoch, shuffles the (target, context) pairs, and delegates the SGD
    /// step to `EmbeddingModel::train_skip_gram`.
    fn train_negative_sampling<E, Ix>(
        &mut self,
        graph: &Graph<N, E, Ix>,
        rng: &mut impl Rng,
    ) -> Result<()>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let negative_sampler = NegativeSampler::new(graph);

        for epoch in 0..self.config.epochs {
            let walks = self.generate_walks(graph)?;
            let context_pairs =
                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);

            let mut shuffled_pairs = context_pairs;
            shuffled_pairs.shuffle(rng);

            // Linear learning-rate decay, floored at 0.01% of the initial rate.
            let current_lr = self.config.learning_rate
                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);

            self.model.train_skip_gram(
                &shuffled_pairs,
                &negative_sampler,
                current_lr,
                self.config.negative_samples,
                rng,
            )?;
        }

        Ok(())
    }

    /// Skip-gram training with hierarchical softmax.
    ///
    /// Builds a Huffman tree over per-node frequencies (degree + 1 as
    /// add-one smoothing), allocates one zero-initialized vector per internal
    /// node, then for every in-window (target, context) pair in every walk
    /// performs one SGD step along the context's Huffman path.
    fn train_hierarchical_softmax<E, Ix>(
        &mut self,
        graph: &Graph<N, E, Ix>,
        rng: &mut impl Rng,
    ) -> Result<()>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
        let n = nodes.len();

        if n == 0 {
            return Err(GraphError::InvalidGraph(
                "Cannot train on empty graph".to_string(),
            ));
        }

        // Stable node -> dense-index mapping shared with `frequencies`.
        let node_to_idx: HashMap<N, usize> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i))
            .collect();

        let frequencies: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64 + 1.0).collect();

        let huffman = HuffmanTree::build(&frequencies)?;

        let dim = self.config.dimensions;
        self.internal_vectors = (0..huffman.num_internal).map(|_| vec![0.0; dim]).collect();

        for epoch in 0..self.config.epochs {
            let walks = self.generate_walks(graph)?;

            // Linear learning-rate decay, floored at 0.01% of the initial rate.
            let current_lr = self.config.learning_rate
                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);

            for walk in &walks {
                // Drop any walk node not present in the index map.
                let walk_indices: Vec<usize> = walk
                    .nodes
                    .iter()
                    .filter_map(|n| node_to_idx.get(n).copied())
                    .collect();

                for (i, &target_idx) in walk_indices.iter().enumerate() {
                    // Symmetric window around position i, clamped to the walk.
                    let start = i.saturating_sub(self.config.window_size);
                    let end = (i + self.config.window_size + 1).min(walk_indices.len());

                    for j in start..end {
                        if i == j {
                            continue;
                        }

                        let context_idx = walk_indices[j];
                        self.hierarchical_softmax_update(
                            &nodes[target_idx],
                            context_idx,
                            &huffman,
                            current_lr,
                        );
                    }
                }
            }

            // rng is unused in this mode; this keeps the signature parallel to
            // `train_negative_sampling` without an unused-parameter warning.
            let _ = rng;
        }

        Ok(())
    }

    /// One hierarchical-softmax SGD step: updates the target node's embedding
    /// and the internal-node vectors along `context_idx`'s Huffman path.
    ///
    /// For each path step: sigmoid(target . internal), label = 1 - code bit
    /// (`is_right` -> 0.0), g = lr * (label - sigmoid). The target gradient is
    /// accumulated from the *pre-update* internal vector before that vector is
    /// modified, and applied to the target embedding only after the whole path
    /// is processed.
    fn hierarchical_softmax_update(
        &mut self,
        target_node: &N,
        context_idx: usize,
        huffman: &HuffmanTree,
        learning_rate: f64,
    ) where
        N: Clone,
    {
        let dim = self.config.dimensions;

        // Silently skip indices outside the tree (defensive; walk indices are
        // built from the same node list as the frequencies).
        if context_idx >= huffman.codes.len() {
            return;
        }

        let huffman_node = &huffman.codes[context_idx];

        // Snapshot the target vector; updates below use this fixed copy.
        let target_emb = match self.model.embeddings.get(target_node) {
            Some(e) => e.vector.clone(),
            None => return,
        };

        let mut grad = vec![0.0; dim];

        for (step, (&is_right, &internal_idx)) in huffman_node
            .code
            .iter()
            .zip(huffman_node.point.iter())
            .enumerate()
        {
            if internal_idx >= self.internal_vectors.len() {
                continue;
            }

            let dot: f64 = target_emb
                .iter()
                .zip(self.internal_vectors[internal_idx].iter())
                .map(|(a, b)| a * b)
                .sum();

            let sig = 1.0 / (1.0 + (-dot).exp());

            // word2vec convention: label = 1 - code bit.
            let label = if is_right { 0.0 } else { 1.0 };
            let g = learning_rate * (label - sig);

            // Accumulate the target gradient from the internal vector BEFORE
            // updating it — the ordering matters.
            for d in 0..dim {
                grad[d] += g * self.internal_vectors[internal_idx][d];
            }

            for d in 0..dim {
                self.internal_vectors[internal_idx][d] += g * target_emb[d];
            }

            // Depth along the path; currently unused.
            let _ = step;
        }

        if let Some(emb) = self.model.embeddings.get_mut(target_node) {
            for d in 0..dim {
                emb.vector[d] += grad[d];
            }
        }
    }

    /// Trains embeddings on a directed graph with an inline negative-sampling
    /// SGD loop.
    ///
    /// Negatives are drawn from the smoothed unigram distribution
    /// (degree+1 normalized, raised to the 0.75 power, renormalized) via a
    /// linear scan of its cumulative distribution.
    ///
    /// NOTE(review): this duplicates the negative-sampling logic rather than
    /// reusing `NegativeSampler`/`train_skip_gram` as the undirected path
    /// does — consider consolidating.
    pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let mut rng = scirs2_core::random::rng();
        self.model.initialize_random_digraph(graph, &mut rng);

        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
        let degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64 + 1.0).collect();
        let total: f64 = degrees.iter().sum();
        let powered: Vec<f64> = degrees.iter().map(|d| (d / total).powf(0.75)).collect();
        let total_powered: f64 = powered.iter().sum();
        let probs: Vec<f64> = powered.iter().map(|p| p / total_powered).collect();

        // Cumulative distribution for inverse-transform sampling below.
        let mut cumulative = vec![0.0; probs.len()];
        if !cumulative.is_empty() {
            cumulative[0] = probs[0];
            for i in 1..probs.len() {
                cumulative[i] = cumulative[i - 1] + probs[i];
            }
        }

        for epoch in 0..self.config.epochs {
            let walks = self.generate_walks_digraph(graph)?;
            let context_pairs =
                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);

            let mut shuffled_pairs = context_pairs;
            shuffled_pairs.shuffle(&mut rng);

            // Linear learning-rate decay, floored at 0.01% of the initial rate.
            let current_lr = self.config.learning_rate
                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);

            let dim = self.config.dimensions;
            let num_neg = self.config.negative_samples;

            for pair in &shuffled_pairs {
                // Snapshots: all gradients below are computed from these
                // pre-update copies.
                let target_emb = match self.model.embeddings.get(&pair.target) {
                    Some(e) => e.clone(),
                    None => continue,
                };
                let context_emb = match self.model.context_embeddings.get(&pair.context) {
                    Some(e) => e.clone(),
                    None => continue,
                };

                // Positive example: push sigmoid(target . context) toward 1.
                let dot: f64 = target_emb
                    .vector
                    .iter()
                    .zip(context_emb.vector.iter())
                    .map(|(a, b)| a * b)
                    .sum();
                let sig = 1.0 / (1.0 + (-dot).exp());
                let g = current_lr * (1.0 - sig);

                let mut target_grad = vec![0.0; dim];
                for d in 0..dim {
                    target_grad[d] = g * context_emb.vector[d];
                }

                if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
                    for d in 0..dim {
                        ctx.vector[d] += g * target_emb.vector[d];
                    }
                }

                for _ in 0..num_neg {
                    // Inverse-transform sample from the cumulative distribution.
                    let r = rng.random::<f64>();
                    let neg_idx = cumulative
                        .iter()
                        .position(|&c| r <= c)
                        .unwrap_or(cumulative.len().saturating_sub(1));

                    if neg_idx >= nodes.len() {
                        continue;
                    }
                    let neg_node = &nodes[neg_idx];
                    // Never use the pair itself as a negative.
                    if neg_node == &pair.target || neg_node == &pair.context {
                        continue;
                    }

                    // Negative example: push sigmoid(target . negative) toward 0.
                    if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
                        let neg_dot: f64 = target_emb
                            .vector
                            .iter()
                            .zip(neg_emb.vector.iter())
                            .map(|(a, b)| a * b)
                            .sum();
                        let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
                        let neg_g = current_lr * (-neg_sig);

                        for d in 0..dim {
                            target_grad[d] += neg_g * neg_emb.vector[d];
                        }

                        if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
                            for d in 0..dim {
                                neg_ctx.vector[d] += neg_g * target_emb.vector[d];
                            }
                        }
                    }
                }

                // Apply the accumulated positive + negative gradient once.
                if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
                    for d in 0..dim {
                        target.vector[d] += target_grad[d];
                    }
                }
            }
        }

        Ok(())
    }

    /// Read access to the trained embedding model.
    pub fn model(&self) -> &EmbeddingModel<N> {
        &self.model
    }

    /// Mutable access to the trained embedding model.
    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
        &mut self.model
    }
}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Complete graph on three nodes (a triangle).
    fn make_triangle() -> Graph<i32, f64> {
        let mut graph = Graph::new();
        for id in 0..3 {
            graph.add_node(id);
        }
        for &(u, v) in &[(0, 1), (1, 2), (0, 2)] {
            let _ = graph.add_edge(u, v, 1.0);
        }
        graph
    }

    /// Path graph 0 - 1 - 2 - 3 - 4.
    fn make_path_graph() -> Graph<i32, f64> {
        let mut graph = Graph::new();
        for id in 0..5 {
            graph.add_node(id);
        }
        for &(u, v) in &[(0, 1), (1, 2), (2, 3), (3, 4)] {
            let _ = graph.add_edge(u, v, 1.0);
        }
        graph
    }

    /// Directed cycle 0 -> 1 -> 2 -> 3 -> 0.
    fn make_directed_cycle() -> DiGraph<i32, f64> {
        let mut graph = DiGraph::new();
        for id in 0..4 {
            graph.add_node(id);
        }
        for &(u, v) in &[(0, 1), (1, 2), (2, 3), (3, 0)] {
            let _ = graph.add_edge(u, v, 1.0);
        }
        graph
    }

    #[test]
    fn test_deepwalk_negative_sampling() {
        let graph = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut deepwalk = DeepWalk::new(config);
        assert_eq!(deepwalk.mode(), DeepWalkMode::NegativeSampling);

        assert!(
            deepwalk.train(&graph).is_ok(),
            "DeepWalk negative sampling training should succeed"
        );

        // Every node of the triangle must end up with an embedding.
        for node in [0, 1, 2] {
            assert!(
                deepwalk.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding"
            );
        }
    }

    #[test]
    fn test_deepwalk_hierarchical_softmax() {
        let graph = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut deepwalk = DeepWalk::with_hierarchical_softmax(config);
        assert_eq!(deepwalk.mode(), DeepWalkMode::HierarchicalSoftmax);

        assert!(
            deepwalk.train(&graph).is_ok(),
            "DeepWalk hierarchical softmax training should succeed"
        );

        for node in [0, 1, 2] {
            assert!(
                deepwalk.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding"
            );
        }
    }

    #[test]
    fn test_deepwalk_walk_generation() {
        let graph = make_path_graph();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 2,
            ..Default::default()
        };

        let mut deepwalk = DeepWalk::new(config);
        let walks = deepwalk.generate_walks(&graph);
        assert!(walks.is_ok());

        // 5 nodes x 2 walks per node.
        let walks = walks.expect("walks should be valid");
        assert_eq!(walks.len(), 10);

        for walk in &walks {
            assert!(!walk.nodes.is_empty());
            assert!(walk.nodes.len() <= 4);
            // Walks must stay within the path graph's node set.
            for node in &walk.nodes {
                assert!((0..5).contains(node));
            }
        }
    }

    #[test]
    fn test_deepwalk_digraph() {
        let graph = make_directed_cycle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 6,
            num_walks: 3,
            window_size: 2,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut deepwalk = DeepWalk::new(config);
        assert!(
            deepwalk.train_digraph(&graph).is_ok(),
            "DiGraph DeepWalk training should succeed"
        );

        for node in 0..4 {
            assert!(
                deepwalk.model().get_embedding(&node).is_some(),
                "Node {node} should have embedding in directed graph"
            );
        }
    }

    #[test]
    fn test_deepwalk_mode_switching() {
        let graph = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 2,
            epochs: 1,
            ..Default::default()
        };

        let mut deepwalk = DeepWalk::new(config);
        assert_eq!(deepwalk.mode(), DeepWalkMode::NegativeSampling);

        // Flip to hierarchical softmax and train under the new mode.
        deepwalk.set_mode(DeepWalkMode::HierarchicalSoftmax);
        assert_eq!(deepwalk.mode(), DeepWalkMode::HierarchicalSoftmax);

        assert!(deepwalk.train(&graph).is_ok());
    }

    #[test]
    fn test_deepwalk_embedding_dimensions() {
        let graph = make_triangle();
        let config = DeepWalkConfig {
            dimensions: 32,
            walk_length: 5,
            num_walks: 2,
            epochs: 1,
            ..Default::default()
        };

        let mut deepwalk = DeepWalk::new(config);
        let _ = deepwalk.train(&graph);

        // Each embedding must match the configured dimensionality.
        for node in [0, 1, 2] {
            let embedding = deepwalk.model().get_embedding(&node);
            assert!(embedding.is_some());
            assert_eq!(embedding.map(|e| e.dimensions()).unwrap_or(0), 32);
        }
    }

    #[test]
    fn test_huffman_tree_basic() {
        let freqs = vec![5.0, 2.0, 1.0, 3.0];
        let tree = HuffmanTree::build(&freqs);
        assert!(tree.is_ok());

        let tree = tree.expect("tree should be valid");
        assert_eq!(tree.codes.len(), 4);
        assert_eq!(tree.num_internal, 3);

        // Every leaf gets a non-trivial code and a non-trivial path.
        for (i, leaf) in tree.codes.iter().enumerate() {
            assert!(
                !leaf.code.is_empty(),
                "Node {i} should have non-empty Huffman code"
            );
            assert!(
                !leaf.point.is_empty(),
                "Node {i} should have non-empty path"
            );
        }
    }

    #[test]
    fn test_huffman_tree_single_node() {
        let freqs = vec![1.0];
        let tree = HuffmanTree::build(&freqs);
        assert!(tree.is_ok());

        assert_eq!(tree.expect("tree should be valid").codes.len(), 1);
    }
}