1use crate::graph::{GraphRef, WeightedGraphRef};
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11
/// Parameters for weighted node2vec / node2vec+ random-walk generation.
#[derive(Debug, Clone, Copy)]
pub struct WeightedNode2VecPlusConfig {
    /// Upper bound on the number of nodes in a walk, start node included.
    /// A walk ends early when it reaches a node without neighbors.
    pub length: usize,
    /// Number of passes over the node set; every pass starts one walk from each node.
    pub walks_per_node: usize,
    /// Return parameter: the weight of the edge leading back to the previous
    /// node is divided by `p`.
    pub p: f32,
    /// In-out parameter: candidates that are not adjacent to the previous node
    /// are biased through `1 / q`.
    pub q: f32,
    /// Multiplier applied to the standard deviation in the per-node noise
    /// threshold `mean + gamma * std` (only used by the node2vec+ variant).
    pub gamma: f32,
    /// Seed for the ChaCha8 generator that drives shuffling and sampling.
    pub seed: u64,
}
28
29impl Default for WeightedNode2VecPlusConfig {
30 fn default() -> Self {
31 Self {
32 length: 80,
33 walks_per_node: 10,
34 p: 1.0,
35 q: 1.0,
36 gamma: 0.0,
37 seed: 42,
38 }
39 }
40}
41
42pub fn generate_biased_walks_weighted_ref<G: WeightedGraphRef>(
44 graph: &G,
45 config: WeightedNode2VecPlusConfig,
46) -> Vec<Vec<usize>> {
47 generate_biased_walks_weighted_impl(graph, config, false)
48}
49
50pub fn generate_biased_walks_weighted_plus_ref<G: WeightedGraphRef>(
52 graph: &G,
53 config: WeightedNode2VecPlusConfig,
54) -> Vec<Vec<usize>> {
55 generate_biased_walks_weighted_impl(graph, config, true)
56}
57
58fn generate_biased_walks_weighted_impl<G: WeightedGraphRef>(
59 graph: &G,
60 config: WeightedNode2VecPlusConfig,
61 extend: bool,
62) -> Vec<Vec<usize>> {
63 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
64 let mut start_nodes: Vec<usize> = (0..graph.node_count()).collect();
65
66 let noise_thresholds = if extend {
67 compute_noise_thresholds(graph, config.gamma)
68 } else {
69 Vec::new()
70 };
71
72 let mut walks = Vec::with_capacity(graph.node_count() * config.walks_per_node);
73 for _ in 0..config.walks_per_node {
74 start_nodes.shuffle(&mut rng);
75 for &node in &start_nodes {
76 walks.push(weighted_walk(
77 graph,
78 node,
79 config,
80 extend,
81 &noise_thresholds,
82 &mut rng,
83 ));
84 }
85 }
86 walks
87}
88
89fn weighted_walk<G: WeightedGraphRef, R: Rng>(
90 graph: &G,
91 start: usize,
92 config: WeightedNode2VecPlusConfig,
93 extend: bool,
94 noise_thresholds: &[f32],
95 rng: &mut R,
96) -> Vec<usize> {
97 let mut walk = Vec::with_capacity(config.length);
98 walk.push(start);
99
100 let mut curr = start;
101 let mut prev: Option<usize> = None;
102 let mut buf: Vec<f32> = Vec::new();
103
104 for _ in 1..config.length {
105 let (nbrs, wts) = graph.neighbors_and_weights_ref(curr);
106 if nbrs.is_empty() {
107 break;
108 }
109 debug_assert_eq!(nbrs.len(), wts.len());
110
111 let next = if let Some(prev_idx) = prev {
112 if extend {
113 sample_next_node2vec_plus(
114 graph,
115 curr,
116 prev_idx,
117 nbrs,
118 wts,
119 config,
120 noise_thresholds,
121 &mut buf,
122 rng,
123 )
124 } else {
125 sample_next_node2vec_weighted(graph, prev_idx, nbrs, wts, config, &mut buf, rng)
126 }
127 } else {
128 sample_cdf(rng, nbrs, wts)
129 };
130
131 walk.push(next);
132 prev = Some(curr);
133 curr = next;
134 }
135
136 walk
137}
138
139fn sample_next_node2vec_weighted<G: WeightedGraphRef, R: Rng>(
140 graph: &G,
141 prev: usize,
142 nbrs: &[usize],
143 wts: &[f32],
144 config: WeightedNode2VecPlusConfig,
145 buf: &mut Vec<f32>,
146 rng: &mut R,
147) -> usize {
148 fill_next_node2vec_weighted_buf(graph, prev, nbrs, wts, config, buf);
149 sample_cdf(rng, nbrs, buf)
150}
151
152fn fill_next_node2vec_weighted_buf<G: WeightedGraphRef>(
153 graph: &G,
154 prev: usize,
155 nbrs: &[usize],
156 wts: &[f32],
157 config: WeightedNode2VecPlusConfig,
158 buf: &mut Vec<f32>,
159) {
160 let (prev_nbrs, _prev_wts) = graph.neighbors_and_weights_ref(prev);
162
163 buf.clear();
164 buf.extend_from_slice(wts);
165
166 if let Some(i) = nbrs.iter().position(|&x| x == prev) {
168 buf[i] /= config.p;
169 }
170
171 for i in 0..nbrs.len() {
172 let x = nbrs[i];
173 if x == prev {
174 continue;
175 }
176 let is_common = prev_nbrs.contains(&x);
177 if !is_common {
178 buf[i] /= config.q;
179 }
180 }
181}
182
183#[allow(clippy::too_many_arguments)]
184fn sample_next_node2vec_plus<G: WeightedGraphRef, R: Rng>(
185 graph: &G,
186 cur: usize,
187 prev: usize,
188 nbrs: &[usize],
189 wts: &[f32],
190 config: WeightedNode2VecPlusConfig,
191 noise_thresholds: &[f32],
192 buf: &mut Vec<f32>,
193 rng: &mut R,
194) -> usize {
195 fill_next_node2vec_plus_buf(graph, cur, prev, nbrs, wts, config, noise_thresholds, buf);
196 sample_cdf(rng, nbrs, buf)
197}
198
199#[allow(clippy::too_many_arguments)]
200fn fill_next_node2vec_plus_buf<G: WeightedGraphRef>(
201 graph: &G,
202 cur: usize,
203 prev: usize,
204 nbrs: &[usize],
205 wts: &[f32],
206 config: WeightedNode2VecPlusConfig,
207 noise_thresholds: &[f32],
208 buf: &mut Vec<f32>,
209) {
210 let (prev_nbrs, prev_wts) = graph.neighbors_and_weights_ref(prev);
218
219 buf.clear();
220 buf.extend_from_slice(wts);
221
222 if let Some(i) = nbrs.iter().position(|&x| x == prev) {
224 buf[i] /= config.p;
225 }
226
227 let inv_q = 1.0 / config.q;
228 let thr_cur = noise_thresholds[cur];
229
230 for i in 0..nbrs.len() {
231 let x = nbrs[i];
232 if x == prev {
233 continue;
234 }
235
236 let mut is_out = true;
237 let mut t: f32 = 0.0;
238
239 if let Some(j) = prev_nbrs.iter().position(|&y| y == x) {
240 let thr_x = noise_thresholds[x];
241 let w_prev_x = prev_wts[j];
242 if thr_x > 0.0 && w_prev_x >= thr_x {
243 is_out = false;
245 } else if thr_x > 0.0 {
246 t = (w_prev_x / thr_x).max(0.0);
248 }
249 }
250
251 if is_out {
252 let mut alpha = inv_q + (1.0 - inv_q) * t;
253 if buf[i] < thr_cur {
254 alpha = inv_q.min(1.0);
255 }
256 buf[i] *= alpha;
257 }
258 }
259}
260
261fn compute_noise_thresholds<G: WeightedGraphRef>(graph: &G, gamma: f32) -> Vec<f32> {
262 let n = graph.node_count();
263 let mut thr = vec![0.0f32; n];
264
265 for (v, thr_v) in thr.iter_mut().enumerate().take(n) {
266 let (_nbrs, wts) = graph.neighbors_and_weights_ref(v);
267 if wts.is_empty() {
268 *thr_v = 0.0;
269 continue;
270 }
271
272 let mean = wts.iter().copied().sum::<f32>() / (wts.len() as f32);
273 let var = wts
274 .iter()
275 .map(|&x| {
276 let d = x - mean;
277 d * d
278 })
279 .sum::<f32>()
280 / (wts.len() as f32);
281 let std = var.sqrt();
282
283 *thr_v = (mean + gamma * std).max(0.0);
284 }
285
286 thr
287}
288
289fn sample_cdf<R: Rng>(rng: &mut R, nbrs: &[usize], weights: &[f32]) -> usize {
290 debug_assert_eq!(nbrs.len(), weights.len());
291 if nbrs.len() == 1 {
292 return nbrs[0];
293 }
294
295 let sum = weights.iter().copied().sum::<f32>();
296 if !sum.is_finite() || sum <= 0.0 {
297 return *nbrs.choose(rng).unwrap();
298 }
299
300 let mut r = rng.random::<f32>() * sum;
301 for (i, &w) in weights.iter().enumerate() {
302 if r <= w {
303 return nbrs[i];
304 }
305 r -= w;
306 }
307 *nbrs.last().unwrap()
308}
309
/// Alias tables for unweighted node2vec, precomputed for every
/// `(current node, previous neighbor)` pair.
///
/// The tables of node `v` occupy `alias_j` / `alias_q` in the range
/// `alias_indptr[v]..alias_indptr[v + 1]`, which holds `deg(v)` rows of
/// `deg(v)` entries each; row `r` belongs to the walk that arrived from
/// `neighbors[v][r]`.
#[derive(Debug, Clone)]
pub struct PrecomputedBiasedWalks {
    /// Sorted neighbor list of every node.
    neighbors: Vec<Vec<usize>>,
    /// Degree of every node (length of one alias row).
    alias_dim: Vec<u32>,
    /// Prefix sums of `deg * deg`; has `node_count + 1` entries.
    alias_indptr: Vec<u64>,
    /// Alias column indices, relative to the node's neighbor list.
    alias_j: Vec<u32>,
    /// Acceptance probabilities matching `alias_j`.
    alias_q: Vec<f32>,
    /// Return parameter the tables were built with.
    p: f32,
    /// In-out parameter the tables were built with.
    q: f32,
}
321
322impl PrecomputedBiasedWalks {
323 pub fn new<G: GraphRef>(graph: &G, p: f32, q: f32) -> Self {
326 let n = graph.node_count();
327 let mut neighbors: Vec<Vec<usize>> = Vec::with_capacity(n);
328 let mut alias_dim: Vec<u32> = Vec::with_capacity(n);
329
330 for v in 0..n {
331 let mut nbrs = graph.neighbors_ref(v).to_vec();
332 nbrs.sort_unstable();
333 alias_dim.push(nbrs.len() as u32);
334 neighbors.push(nbrs);
335 }
336
337 let mut alias_indptr: Vec<u64> = vec![0; n + 1];
338 for i in 0..n {
339 let deg = alias_dim[i] as u64;
340 alias_indptr[i + 1] = alias_indptr[i] + deg * deg;
341 }
342 let total = alias_indptr[n] as usize;
343
344 let mut alias_j = vec![0u32; total];
345 let mut alias_q = vec![0.0f32; total];
346
347 let mut out_ind: Vec<bool> = Vec::new();
348 let mut probs: Vec<f32> = Vec::new();
349
350 for cur in 0..n {
351 let deg = alias_dim[cur] as usize;
352 if deg == 0 {
353 continue;
354 }
355 let offset = alias_indptr[cur] as usize;
356 let cur_nbrs = &neighbors[cur];
357
358 out_ind.clear();
359 out_ind.resize(deg, true);
360 probs.clear();
361 probs.resize(deg, 1.0);
362
363 for prev_j in 0..deg {
364 let prev = cur_nbrs[prev_j];
365 let prev_nbrs = &neighbors[prev];
366
367 mark_non_common(cur_nbrs, prev_nbrs, &mut out_ind);
368 out_ind[prev_j] = false; probs.fill(1.0);
371 for i in 0..deg {
372 if out_ind[i] {
373 probs[i] /= q;
374 }
375 }
376 probs[prev_j] /= p;
377
378 normalize_in_place(&mut probs);
379 let (j, qtab) = alias_setup(&probs);
380
381 let start = offset + deg * prev_j;
382 let end = start + deg;
383 alias_j[start..end].copy_from_slice(&j);
384 alias_q[start..end].copy_from_slice(&qtab);
385 }
386 }
387
388 Self {
389 neighbors,
390 alias_dim,
391 alias_indptr,
392 alias_j,
393 alias_q,
394 p,
395 q,
396 }
397 }
398}
399
400pub fn generate_biased_walks_precomp_ref(
403 pre: &PrecomputedBiasedWalks,
404 config: crate::random_walk::WalkConfig,
405) -> Vec<Vec<usize>> {
406 let start_nodes: Vec<usize> = (0..pre.neighbors.len()).collect();
407 generate_biased_walks_precomp_ref_from_nodes(pre, &start_nodes, config)
408}
409
410pub fn generate_biased_walks_precomp_ref_from_nodes(
415 pre: &PrecomputedBiasedWalks,
416 start_nodes: &[usize],
417 config: crate::random_walk::WalkConfig,
418) -> Vec<Vec<usize>> {
419 if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
420 panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
421 }
422
423 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
424 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
425 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
426
427 for _ in 0..config.walks_per_node {
428 epoch_nodes.shuffle(&mut rng);
429 for &node in &epoch_nodes {
430 walks.push(biased_walk_precomp(pre, node, config.length, &mut rng));
431 }
432 }
433
434 walks
435}
436
/// Parallel (rayon) variant of the precomputed walk generator.
///
/// The job list is built sequentially: pass `epoch` shuffles the start nodes
/// with a generator seeded from `mix64(seed ^ epoch)`. Every job then gets its
/// own ChaCha8 generator whose seed mixes the global seed, the epoch, the
/// start node and the job index, so the result does not depend on thread
/// scheduling.
///
/// # Panics
/// Panics if `config.p` / `config.q` differ from the values `pre` was built
/// with by more than `1e-6`, or if a start node is out of range for `pre`.
#[cfg(feature = "parallel")]
pub fn generate_biased_walks_precomp_ref_parallel_from_nodes(
    pre: &PrecomputedBiasedWalks,
    start_nodes: &[usize],
    config: crate::random_walk::WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    /// 64-bit finalizer (xor-shift / multiply rounds) used to decorrelate
    /// the structured seed inputs.
    fn mix64(mut x: u64) -> u64 {
        x ^= x >> 30;
        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94d049bb133111eb);
        x ^= x >> 31;
        x
    }

    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
    }

    let mut pass_order: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<(u32, usize)> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);

    for epoch in 0..(config.walks_per_node as u32) {
        let mut shuffle_rng = ChaCha8Rng::seed_from_u64(mix64(config.seed ^ (epoch as u64)));
        pass_order.shuffle(&mut shuffle_rng);
        jobs.extend(pass_order.iter().map(|&node| (epoch, node)));
    }

    jobs.par_iter()
        .enumerate()
        .map(|(job_idx, &(epoch, node))| {
            // NOTE(review): `node ^ job_idx` can coincide for two different
            // jobs of the same epoch, which would give them the same seed.
            // Left unchanged to keep existing seeded output reproducible.
            let job_seed =
                mix64(config.seed ^ ((epoch as u64) << 32) ^ (node as u64) ^ (job_idx as u64));
            let mut job_rng = ChaCha8Rng::seed_from_u64(job_seed);
            biased_walk_precomp(pre, node, config.length, &mut job_rng)
        })
        .collect()
}
492
493fn biased_walk_precomp<R: Rng>(
494 pre: &PrecomputedBiasedWalks,
495 start: usize,
496 length: usize,
497 rng: &mut R,
498) -> Vec<usize> {
499 let mut walk = Vec::with_capacity(length);
500 walk.push(start);
501 let mut curr = start;
502 let mut prev: Option<usize> = None;
503
504 for _ in 1..length {
505 let nbrs = &pre.neighbors[curr];
506 if nbrs.is_empty() {
507 break;
508 }
509
510 let next = if let Some(p) = prev {
511 sample_precomp(pre, curr, p, rng)
512 } else {
513 *nbrs.choose(rng).unwrap()
514 };
515
516 walk.push(next);
517 prev = Some(curr);
518 curr = next;
519 }
520
521 walk
522}
523
524fn sample_precomp<R: Rng>(
525 pre: &PrecomputedBiasedWalks,
526 cur: usize,
527 prev: usize,
528 rng: &mut R,
529) -> usize {
530 let nbrs = &pre.neighbors[cur];
531 let deg = pre.alias_dim[cur] as usize;
532 let prev_j = match nbrs.binary_search(&prev) {
533 Ok(i) => i,
534 Err(_) => {
535 return *nbrs.choose(rng).unwrap();
542 }
543 };
544
545 let offset = pre.alias_indptr[cur] + (deg as u64) * (prev_j as u64);
546 let start = offset as usize;
547 let end = start + deg;
548
549 let choice = alias_draw(&pre.alias_j[start..end], &pre.alias_q[start..end], rng);
550 nbrs[choice]
551}
552
/// Scales `x` so that its entries sum to one.
///
/// The slice is left untouched when the sum is not strictly positive
/// (including the empty slice and a NaN sum).
fn normalize_in_place(x: &mut [f32]) {
    let total = x.iter().copied().sum::<f32>();
    if total > 0.0 {
        x.iter_mut().for_each(|entry| *entry /= total);
    }
}
561
/// For every entry of `cur`, records in `out` whether it is absent from `prev`
/// (`true` = not shared).
///
/// Uses a single forward merge, so both `cur` and `prev` must be sorted in
/// ascending order.
///
/// # Panics
/// Panics if `out` is shorter than `cur`.
fn mark_non_common(cur: &[usize], prev: &[usize], out: &mut [bool]) {
    debug_assert_eq!(cur.len(), out.len());
    let mut cursor = 0usize;
    for (slot, &node) in cur.iter().enumerate() {
        // Advance past every entry of `prev` smaller than `node`.
        cursor += prev[cursor..].iter().take_while(|&&other| other < node).count();
        let shared = prev.get(cursor) == Some(&node);
        out[slot] = !shared;
    }
}
572
/// Builds Walker/Vose alias tables for the discrete distribution `probs`.
///
/// Returns `(j, q)`: column `i` is kept with probability `q[i]` and otherwise
/// redirected to column `j[i]`. `probs` is expected to sum to one.
///
/// Columns that remain on a worklist when pairing stops have their acceptance
/// probability saturated to exactly `1.0`. Without this, rounding drift
/// (a final value such as `0.99999994`) would leak probability to column 0
/// through the default alias, and an all-zero input would make every draw
/// return column 0; with it, such degenerate input yields a uniform draw.
/// The alias of a saturated column is left at its default (`0`) and is never
/// consulted.
fn alias_setup(probs: &[f32]) -> (Vec<u32>, Vec<f32>) {
    let k = probs.len();
    let scale = k as f32;

    let mut q: Vec<f32> = probs.iter().map(|&prob| scale * prob).collect();
    let mut j = vec![0u32; k];

    // Worklists of under-full (`< 1`) and over-full (`>= 1`) columns.
    let mut smaller: Vec<usize> = Vec::with_capacity(k);
    let mut larger: Vec<usize> = Vec::with_capacity(k);
    for (column, &scaled) in q.iter().enumerate() {
        if scaled < 1.0 {
            smaller.push(column);
        } else {
            larger.push(column);
        }
    }

    // Peek before popping so that no column is dropped when one list runs dry.
    while let (Some(&small), Some(&large)) = (smaller.last(), larger.last()) {
        smaller.pop();
        larger.pop();

        // Top up the under-full column with mass taken from the over-full one.
        j[small] = large as u32;
        q[large] = q[large] + q[small] - 1.0;
        if q[large] < 1.0 {
            smaller.push(large);
        } else {
            larger.push(large);
        }
    }

    // Whatever is left should hold exactly one unit of mass; force it.
    for column in smaller.into_iter().chain(larger) {
        q[column] = 1.0;
    }

    (j, q)
}
612
613fn alias_draw<R: Rng>(j: &[u32], q: &[f32], rng: &mut R) -> usize {
614 debug_assert_eq!(j.len(), q.len());
615 let k = j.len();
616 let kk = rng.random_range(0..k);
617 if rng.random::<f32>() < q[kk] {
618 kk
619 } else {
620 j[kk] as usize
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Unweighted adjacency-list fixture; neighbor lists are sorted on construction.
    #[derive(Debug, Clone)]
    struct RefAdj {
        adj: Vec<Vec<usize>>,
    }

    impl RefAdj {
        fn new(mut adj: Vec<Vec<usize>>) -> Self {
            adj.iter_mut().for_each(|nbrs| nbrs.sort_unstable());
            Self { adj }
        }
    }

    impl GraphRef for RefAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_ref(&self, node: usize) -> &[usize] {
            match self.adj.get(node) {
                Some(nbrs) => nbrs.as_slice(),
                None => &[],
            }
        }
    }

    /// Weighted adjacency-list fixture; each node's (neighbor, weight) pairs
    /// are sorted by neighbor id on construction.
    #[derive(Debug, Clone)]
    struct RefWeightedAdj {
        adj: Vec<Vec<usize>>,
        wts: Vec<Vec<f32>>,
    }

    impl RefWeightedAdj {
        fn new(adj: Vec<Vec<usize>>, wts: Vec<Vec<f32>>) -> Self {
            assert_eq!(adj.len(), wts.len());

            let mut sorted_adj = Vec::with_capacity(adj.len());
            let mut sorted_wts = Vec::with_capacity(wts.len());
            for (nbrs, weights) in adj.into_iter().zip(wts) {
                assert_eq!(nbrs.len(), weights.len());
                let mut pairs: Vec<(usize, f32)> = nbrs.into_iter().zip(weights).collect();
                pairs.sort_by_key(|&(node, _)| node);
                let (nodes, weights): (Vec<usize>, Vec<f32>) = pairs.into_iter().unzip();
                sorted_adj.push(nodes);
                sorted_wts.push(weights);
            }

            Self {
                adj: sorted_adj,
                wts: sorted_wts,
            }
        }
    }

    impl WeightedGraphRef for RefWeightedAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_and_weights_ref(&self, node: usize) -> (&[usize], &[f32]) {
            let nbrs: &[usize] = match self.adj.get(node) {
                Some(list) => list.as_slice(),
                None => &[],
            };
            let wts: &[f32] = match self.wts.get(node) {
                Some(list) => list.as_slice(),
                None => &[],
            };
            (nbrs, wts)
        }
    }

    /// Asserts that `a` and `b` differ by at most `eps`.
    fn assert_close_f32(a: f32, b: f32, eps: f32) {
        let gap = (a - b).abs();
        assert!(gap <= eps, "expected |{a} - {b}| <= {eps}, got {}", gap);
    }

    #[test]
    fn alias_tables_match_expected_for_line_graph() {
        // Path graph 0 - 1 - 2.
        let graph = RefAdj::new(vec![vec![1], vec![0, 2], vec![1]]);
        let tables = PrecomputedBiasedWalks::new(&graph, 0.5, 2.0);

        assert_eq!(tables.alias_dim, vec![1, 2, 1]);
        assert_eq!(tables.alias_indptr, vec![0, 1, 5, 6]);

        let check_row = |slots: std::ops::Range<usize>, want_j: [u32; 2], want_q: [f32; 2]| {
            assert_eq!(&tables.alias_j[slots.clone()], &want_j);
            for (&got, &want) in tables.alias_q[slots].iter().zip(&want_q) {
                assert_close_f32(got, want, 1e-6);
            }
        };

        // Node 1 owns slots 1..5: first the row for prev = 0, then prev = 2.
        check_row(1..3, [0, 0], [1.0, 0.4]);
        check_row(3..5, [1, 0], [0.4, 1.0]);
    }

    #[test]
    fn noise_thresholds_match_mean_plus_gamma_std() {
        // A single edge has zero spread, so the threshold equals its weight.
        let single = RefWeightedAdj::new(vec![vec![0]], vec![vec![1.0]]);
        let thr_single = compute_noise_thresholds(&single, 2.0);
        assert_eq!(thr_single.len(), 1);
        assert_close_f32(thr_single[0], 1.0, 1e-6);

        // Weights {1, 3}: mean 2, std 1, so the threshold is 2 + 2 * 1.
        let pair = RefWeightedAdj::new(vec![vec![0, 1]], vec![vec![1.0, 3.0]]);
        let thr_pair = compute_noise_thresholds(&pair, 2.0);
        assert_eq!(thr_pair.len(), 1);
        assert_close_f32(thr_pair[0], 4.0, 1e-6);
    }

    #[test]
    fn node2vec_plus_suppress_caps_inv_q_when_q_lt_1() {
        let graph = RefWeightedAdj::new(
            vec![vec![1], vec![0, 2], vec![1]],
            vec![vec![1.0], vec![1.0, 0.9], vec![1.0]],
        );

        let cfg = WeightedNode2VecPlusConfig {
            length: 3,
            walks_per_node: 1,
            p: 1.0,
            q: 0.5,
            gamma: 0.0,
            seed: 0,
        };

        let thresholds = compute_noise_thresholds(&graph, cfg.gamma);
        assert_eq!(thresholds.len(), 3);
        assert_close_f32(thresholds[1], 0.95, 1e-6);

        let (nbrs, wts) = graph.neighbors_and_weights_ref(1);
        assert_eq!(nbrs, &[0, 2]);
        assert_eq!(wts, &[1.0, 0.9]);

        let mut plain = Vec::new();
        let mut plus = Vec::new();
        fill_next_node2vec_weighted_buf(&graph, 0, nbrs, wts, cfg, &mut plain);
        fill_next_node2vec_plus_buf(&graph, 1, 0, nbrs, wts, cfg, &thresholds, &mut plus);

        // Plain weighted node2vec boosts the out edge by 1/q = 2.
        assert_close_f32(plain[1], 1.8, 1e-6);

        // node2vec+ sees 0.9 < 0.95 and caps the bias at 1.
        assert_close_f32(plus[1], 0.9, 1e-6);
    }

    #[test]
    fn alias_draw_distribution_smoke() {
        let (j, q) = alias_setup(&[0.1f32, 0.2f32, 0.7f32]);

        let trials = 20_000usize;
        let mut counts = [0usize; 3];
        for seed in 0..trials as u64 {
            let mut rng = ChaCha8Rng::seed_from_u64(seed);
            counts[alias_draw(&j, &q, &mut rng)] += 1;
        }

        let expected = [
            trials as f64 * 0.1,
            trials as f64 * 0.2,
            trials as f64 * 0.7,
        ];
        let mut chi2 = 0.0f64;
        for (&observed, &want) in counts.iter().zip(expected.iter()) {
            let diff = observed as f64 - want;
            chi2 += (diff * diff) / want;
        }

        assert!(
            chi2 < 50.0,
            "chi2 too large (chi2={chi2:.2}). counts={counts:?} expected={expected:?}"
        );
    }
}