1use crate::graph::{GraphRef, WeightedGraphRef};
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11
/// Configuration for weighted node2vec / node2vec+ walk generation
/// (`generate_biased_walks_weighted_ref` and `generate_biased_walks_weighted_plus_ref`).
#[derive(Debug, Clone, Copy)]
pub struct WeightedNode2VecPlusConfig {
    /// Maximum number of nodes per walk, including the start node. A walk ends early at a
    /// node without neighbors; `0` and `1` both yield single-node walks.
    pub length: usize,
    /// Number of walks started from every node; each round visits all nodes in a freshly
    /// shuffled order.
    pub walks_per_node: usize,
    /// Return parameter: the weight of stepping straight back to the previous node is
    /// divided by `p`.
    pub p: f32,
    /// In-out parameter: weights of out-edges (neighbors not adjacent to the previous node)
    /// are divided by `q` in weighted node2vec and rescaled via `1/q` in node2vec+.
    pub q: f32,
    /// Node2vec+ only: the noise threshold of node `v` is `max(mean + gamma * std, 0)` over
    /// `v`'s edge weights (population std). Ignored by the classic weighted variant.
    pub gamma: f32,
    /// Seed for the `ChaCha8Rng` that drives both start-node shuffles and neighbor sampling.
    pub seed: u64,
}
28
29impl Default for WeightedNode2VecPlusConfig {
30 fn default() -> Self {
31 Self {
32 length: 80,
33 walks_per_node: 10,
34 p: 1.0,
35 q: 1.0,
36 gamma: 0.0,
37 seed: 42,
38 }
39 }
40}
41
42pub fn generate_biased_walks_weighted_ref<G: WeightedGraphRef>(
43 graph: &G,
44 config: WeightedNode2VecPlusConfig,
45) -> Vec<Vec<usize>> {
46 generate_biased_walks_weighted_impl(graph, config, false)
47}
48
49pub fn generate_biased_walks_weighted_plus_ref<G: WeightedGraphRef>(
50 graph: &G,
51 config: WeightedNode2VecPlusConfig,
52) -> Vec<Vec<usize>> {
53 generate_biased_walks_weighted_impl(graph, config, true)
54}
55
56fn generate_biased_walks_weighted_impl<G: WeightedGraphRef>(
57 graph: &G,
58 config: WeightedNode2VecPlusConfig,
59 extend: bool,
60) -> Vec<Vec<usize>> {
61 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
62 let mut start_nodes: Vec<usize> = (0..graph.node_count()).collect();
63
64 let noise_thresholds = if extend {
65 compute_noise_thresholds(graph, config.gamma)
66 } else {
67 Vec::new()
68 };
69
70 let mut walks = Vec::with_capacity(graph.node_count() * config.walks_per_node);
71 for _ in 0..config.walks_per_node {
72 start_nodes.shuffle(&mut rng);
73 for &node in &start_nodes {
74 walks.push(weighted_walk(
75 graph,
76 node,
77 config,
78 extend,
79 &noise_thresholds,
80 &mut rng,
81 ));
82 }
83 }
84 walks
85}
86
87fn weighted_walk<G: WeightedGraphRef, R: Rng>(
88 graph: &G,
89 start: usize,
90 config: WeightedNode2VecPlusConfig,
91 extend: bool,
92 noise_thresholds: &[f32],
93 rng: &mut R,
94) -> Vec<usize> {
95 let mut walk = Vec::with_capacity(config.length);
96 walk.push(start);
97
98 let mut curr = start;
99 let mut prev: Option<usize> = None;
100 let mut buf: Vec<f32> = Vec::new();
101
102 for _ in 1..config.length {
103 let (nbrs, wts) = graph.neighbors_and_weights_ref(curr);
104 if nbrs.is_empty() {
105 break;
106 }
107 debug_assert_eq!(nbrs.len(), wts.len());
108
109 let next = if let Some(prev_idx) = prev {
110 if extend {
111 sample_next_node2vec_plus(
112 graph,
113 curr,
114 prev_idx,
115 nbrs,
116 wts,
117 config,
118 noise_thresholds,
119 &mut buf,
120 rng,
121 )
122 } else {
123 sample_next_node2vec_weighted(graph, prev_idx, nbrs, wts, config, &mut buf, rng)
124 }
125 } else {
126 sample_cdf(rng, nbrs, wts)
127 };
128
129 walk.push(next);
130 prev = Some(curr);
131 curr = next;
132 }
133
134 walk
135}
136
137fn sample_next_node2vec_weighted<G: WeightedGraphRef, R: Rng>(
138 graph: &G,
139 prev: usize,
140 nbrs: &[usize],
141 wts: &[f32],
142 config: WeightedNode2VecPlusConfig,
143 buf: &mut Vec<f32>,
144 rng: &mut R,
145) -> usize {
146 fill_next_node2vec_weighted_buf(graph, prev, nbrs, wts, config, buf);
147 sample_cdf(rng, nbrs, buf)
148}
149
150fn fill_next_node2vec_weighted_buf<G: WeightedGraphRef>(
151 graph: &G,
152 prev: usize,
153 nbrs: &[usize],
154 wts: &[f32],
155 config: WeightedNode2VecPlusConfig,
156 buf: &mut Vec<f32>,
157) {
158 let (prev_nbrs, _prev_wts) = graph.neighbors_and_weights_ref(prev);
160
161 buf.clear();
162 buf.extend_from_slice(wts);
163
164 if let Some(i) = nbrs.iter().position(|&x| x == prev) {
166 buf[i] /= config.p;
167 }
168
169 for i in 0..nbrs.len() {
170 let x = nbrs[i];
171 if x == prev {
172 continue;
173 }
174 let is_common = prev_nbrs.contains(&x);
175 if !is_common {
176 buf[i] /= config.q;
177 }
178 }
179}
180
181#[allow(clippy::too_many_arguments)]
182fn sample_next_node2vec_plus<G: WeightedGraphRef, R: Rng>(
183 graph: &G,
184 cur: usize,
185 prev: usize,
186 nbrs: &[usize],
187 wts: &[f32],
188 config: WeightedNode2VecPlusConfig,
189 noise_thresholds: &[f32],
190 buf: &mut Vec<f32>,
191 rng: &mut R,
192) -> usize {
193 fill_next_node2vec_plus_buf(graph, cur, prev, nbrs, wts, config, noise_thresholds, buf);
194 sample_cdf(rng, nbrs, buf)
195}
196
197#[allow(clippy::too_many_arguments)]
198fn fill_next_node2vec_plus_buf<G: WeightedGraphRef>(
199 graph: &G,
200 cur: usize,
201 prev: usize,
202 nbrs: &[usize],
203 wts: &[f32],
204 config: WeightedNode2VecPlusConfig,
205 noise_thresholds: &[f32],
206 buf: &mut Vec<f32>,
207) {
208 let (prev_nbrs, prev_wts) = graph.neighbors_and_weights_ref(prev);
216
217 buf.clear();
218 buf.extend_from_slice(wts);
219
220 if let Some(i) = nbrs.iter().position(|&x| x == prev) {
222 buf[i] /= config.p;
223 }
224
225 let inv_q = 1.0 / config.q;
226 let thr_cur = noise_thresholds[cur];
227
228 for i in 0..nbrs.len() {
229 let x = nbrs[i];
230 if x == prev {
231 continue;
232 }
233
234 let mut is_out = true;
235 let mut t: f32 = 0.0;
236
237 if let Some(j) = prev_nbrs.iter().position(|&y| y == x) {
238 let thr_x = noise_thresholds[x];
239 let w_prev_x = prev_wts[j];
240 if thr_x > 0.0 && w_prev_x >= thr_x {
241 is_out = false;
243 } else if thr_x > 0.0 {
244 t = (w_prev_x / thr_x).max(0.0);
246 }
247 }
248
249 if is_out {
250 let mut alpha = inv_q + (1.0 - inv_q) * t;
251 if buf[i] < thr_cur {
252 alpha = inv_q.min(1.0);
253 }
254 buf[i] *= alpha;
255 }
256 }
257}
258
259fn compute_noise_thresholds<G: WeightedGraphRef>(graph: &G, gamma: f32) -> Vec<f32> {
260 let n = graph.node_count();
261 let mut thr = vec![0.0f32; n];
262
263 for (v, thr_v) in thr.iter_mut().enumerate().take(n) {
264 let (_nbrs, wts) = graph.neighbors_and_weights_ref(v);
265 if wts.is_empty() {
266 *thr_v = 0.0;
267 continue;
268 }
269
270 let mean = wts.iter().copied().sum::<f32>() / (wts.len() as f32);
271 let var = wts
272 .iter()
273 .map(|&x| {
274 let d = x - mean;
275 d * d
276 })
277 .sum::<f32>()
278 / (wts.len() as f32);
279 let std = var.sqrt();
280
281 *thr_v = (mean + gamma * std).max(0.0);
282 }
283
284 thr
285}
286
287fn sample_cdf<R: Rng>(rng: &mut R, nbrs: &[usize], weights: &[f32]) -> usize {
288 debug_assert_eq!(nbrs.len(), weights.len());
289 if nbrs.len() == 1 {
290 return nbrs[0];
291 }
292
293 let sum = weights.iter().copied().sum::<f32>();
294 if !sum.is_finite() || sum <= 0.0 {
295 return *nbrs.choose(rng).unwrap();
296 }
297
298 let mut r = rng.random::<f32>() * sum;
299 for (i, &w) in weights.iter().enumerate() {
300 if r <= w {
301 return nbrs[i];
302 }
303 r -= w;
304 }
305 *nbrs.last().unwrap()
306}
307
/// Precomputed alias tables for fast unweighted node2vec walks with fixed `p` and `q`.
///
/// One alias table of length `deg(cur)` is stored for every (cur, prev) neighbor pair, so
/// memory grows with the sum of squared degrees. Build it with `PrecomputedBiasedWalks::new`.
#[derive(Debug, Clone)]
pub struct PrecomputedBiasedWalks {
    /// Adjacency lists, each sorted ascending.
    neighbors: Vec<Vec<usize>>,
    /// Degree of each node (length of its row in `neighbors`).
    alias_dim: Vec<u32>,
    /// Prefix sums of squared degrees (length `node_count + 1`); node `cur`'s tables occupy
    /// `alias_indptr[cur]..alias_indptr[cur + 1]` in `alias_j` / `alias_q`.
    alias_indptr: Vec<u64>,
    /// Alias indices; the row for previous-node position `prev_j` of node `cur` starts at
    /// `alias_indptr[cur] + deg(cur) * prev_j`.
    alias_j: Vec<u32>,
    /// Alias acceptance thresholds, laid out exactly like `alias_j`.
    alias_q: Vec<f32>,
    /// Return parameter the tables were built with.
    p: f32,
    /// In-out parameter the tables were built with.
    q: f32,
}
319
320impl PrecomputedBiasedWalks {
321 pub fn new<G: GraphRef>(graph: &G, p: f32, q: f32) -> Self {
322 let n = graph.node_count();
323 let mut neighbors: Vec<Vec<usize>> = Vec::with_capacity(n);
324 let mut alias_dim: Vec<u32> = Vec::with_capacity(n);
325
326 for v in 0..n {
327 let mut nbrs = graph.neighbors_ref(v).to_vec();
328 nbrs.sort_unstable();
329 alias_dim.push(nbrs.len() as u32);
330 neighbors.push(nbrs);
331 }
332
333 let mut alias_indptr: Vec<u64> = vec![0; n + 1];
334 for i in 0..n {
335 let deg = alias_dim[i] as u64;
336 alias_indptr[i + 1] = alias_indptr[i] + deg * deg;
337 }
338 let total = alias_indptr[n] as usize;
339
340 let mut alias_j = vec![0u32; total];
341 let mut alias_q = vec![0.0f32; total];
342
343 let mut out_ind: Vec<bool> = Vec::new();
344 let mut probs: Vec<f32> = Vec::new();
345
346 for cur in 0..n {
347 let deg = alias_dim[cur] as usize;
348 if deg == 0 {
349 continue;
350 }
351 let offset = alias_indptr[cur] as usize;
352 let cur_nbrs = &neighbors[cur];
353
354 out_ind.clear();
355 out_ind.resize(deg, true);
356 probs.clear();
357 probs.resize(deg, 1.0);
358
359 for prev_j in 0..deg {
360 let prev = cur_nbrs[prev_j];
361 let prev_nbrs = &neighbors[prev];
362
363 mark_non_common(cur_nbrs, prev_nbrs, &mut out_ind);
364 out_ind[prev_j] = false; probs.fill(1.0);
367 for i in 0..deg {
368 if out_ind[i] {
369 probs[i] /= q;
370 }
371 }
372 probs[prev_j] /= p;
373
374 normalize_in_place(&mut probs);
375 let (j, qtab) = alias_setup(&probs);
376
377 let start = offset + deg * prev_j;
378 let end = start + deg;
379 alias_j[start..end].copy_from_slice(&j);
380 alias_q[start..end].copy_from_slice(&qtab);
381 }
382 }
383
384 Self {
385 neighbors,
386 alias_dim,
387 alias_indptr,
388 alias_j,
389 alias_q,
390 p,
391 q,
392 }
393 }
394}
395
396pub fn generate_biased_walks_precomp_ref(
397 pre: &PrecomputedBiasedWalks,
398 config: crate::random_walk::WalkConfig,
399) -> Vec<Vec<usize>> {
400 let start_nodes: Vec<usize> = (0..pre.neighbors.len()).collect();
401 generate_biased_walks_precomp_ref_from_nodes(pre, &start_nodes, config)
402}
403
404pub fn generate_biased_walks_precomp_ref_from_nodes(
409 pre: &PrecomputedBiasedWalks,
410 start_nodes: &[usize],
411 config: crate::random_walk::WalkConfig,
412) -> Vec<Vec<usize>> {
413 if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
414 panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
415 }
416
417 let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
418 let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
419 let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
420
421 for _ in 0..config.walks_per_node {
422 epoch_nodes.shuffle(&mut rng);
423 for &node in &epoch_nodes {
424 walks.push(biased_walk_precomp(pre, node, config.length, &mut rng));
425 }
426 }
427
428 walks
429}
430
#[cfg(feature = "parallel")]
/// Parallel (rayon) variant of `generate_biased_walks_precomp_ref_from_nodes`.
///
/// Produces `start_nodes.len() * config.walks_per_node` walks in a deterministic order:
/// - epoch `e` shuffles the start nodes with its own RNG seeded from `config.seed` and `e`;
/// - every walk job then draws from its own RNG stream, derived from `config.seed` and the
///   job's position.
///
/// The result therefore depends only on the inputs, not on thread scheduling. It is not
/// identical to the sequential variant's output for the same seed.
///
/// # Panics
/// Panics if `config.p` or `config.q` differs by more than `1e-6` from the value `pre` was
/// built with.
pub fn generate_biased_walks_precomp_ref_parallel_from_nodes(
    pre: &PrecomputedBiasedWalks,
    start_nodes: &[usize],
    config: crate::random_walk::WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    // SplitMix64 finalizer: a bijective 64-bit mixer.
    fn mix64(mut x: u64) -> u64 {
        x ^= x >> 30;
        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94d049bb133111eb);
        x ^= x >> 31;
        x
    }

    // SplitMix64 Weyl increment (odd), so `(i + 1) * GOLDEN_GAMMA` is distinct for every job.
    const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;

    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
    }

    // Build the job list sequentially so the walk order is deterministic.
    let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<usize> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
    for epoch in 0..config.walks_per_node {
        let mut rng = ChaCha8Rng::seed_from_u64(mix64(config.seed ^ (epoch as u64)));
        epoch_nodes.shuffle(&mut rng);
        jobs.extend_from_slice(&epoch_nodes);
    }

    jobs.par_iter()
        .enumerate()
        .map(|(i, &node)| {
            // The job index alone identifies the job. Mixing in the node id by XOR (as before)
            // let distinct jobs collide on one seed whenever `node ^ i` coincided.
            let stream = (i as u64).wrapping_add(1).wrapping_mul(GOLDEN_GAMMA);
            let mut rng = ChaCha8Rng::seed_from_u64(mix64(config.seed.wrapping_add(stream)));
            biased_walk_precomp(pre, node, config.length, &mut rng)
        })
        .collect()
}
486
487fn biased_walk_precomp<R: Rng>(
488 pre: &PrecomputedBiasedWalks,
489 start: usize,
490 length: usize,
491 rng: &mut R,
492) -> Vec<usize> {
493 let mut walk = Vec::with_capacity(length);
494 walk.push(start);
495 let mut curr = start;
496 let mut prev: Option<usize> = None;
497
498 for _ in 1..length {
499 let nbrs = &pre.neighbors[curr];
500 if nbrs.is_empty() {
501 break;
502 }
503
504 let next = if let Some(p) = prev {
505 sample_precomp(pre, curr, p, rng)
506 } else {
507 *nbrs.choose(rng).unwrap()
508 };
509
510 walk.push(next);
511 prev = Some(curr);
512 curr = next;
513 }
514
515 walk
516}
517
518fn sample_precomp<R: Rng>(
519 pre: &PrecomputedBiasedWalks,
520 cur: usize,
521 prev: usize,
522 rng: &mut R,
523) -> usize {
524 let nbrs = &pre.neighbors[cur];
525 let deg = pre.alias_dim[cur] as usize;
526 let prev_j = match nbrs.binary_search(&prev) {
527 Ok(i) => i,
528 Err(_) => {
529 return *nbrs.choose(rng).unwrap();
536 }
537 };
538
539 let offset = pre.alias_indptr[cur] + (deg as u64) * (prev_j as u64);
540 let start = offset as usize;
541 let end = start + deg;
542
543 let choice = alias_draw(&pre.alias_j[start..end], &pre.alias_q[start..end], rng);
544 nbrs[choice]
545}
546
/// Scales `x` in place so that its entries sum to 1.
///
/// Leaves `x` untouched when the sum is not strictly positive (empty input, all zeros, or a
/// NaN sum), so it never divides by zero.
fn normalize_in_place(x: &mut [f32]) {
    let total: f32 = x.iter().sum();
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
555
/// Flags which entries of `cur` are absent from `prev`.
///
/// Sets `out[i] = true` when `cur[i]` does not occur in `prev`, and `false` otherwise. Both
/// slices must be sorted ascending, because the lookup is a single forward merge scan
/// (`O(cur.len() + prev.len())`). `out` must be as long as `cur`.
fn mark_non_common(cur: &[usize], prev: &[usize], out: &mut [bool]) {
    debug_assert_eq!(cur.len(), out.len());
    let mut remaining = prev;
    for (i, &x) in cur.iter().enumerate() {
        // Drop every `prev` entry strictly below `x`; since `cur` is sorted, none of them
        // can match a later element either.
        while let Some((&head, tail)) = remaining.split_first() {
            if head >= x {
                break;
            }
            remaining = tail;
        }
        out[i] = remaining.first() != Some(&x);
    }
}
566
/// Builds a Walker/Vose alias table for the discrete distribution `probs`.
///
/// Returns `(j, q)`, both of length `probs.len()`. To sample, pick column `kk` uniformly,
/// keep it with probability `q[kk]`, otherwise take `j[kk]`. `probs` is expected to be
/// normalized (sum to 1).
///
/// Columns still unpaired when one work stack runs dry only differ from 1 by float rounding
/// (or by a sum slightly off 1). They are finalized to `q = 1`, so they never fall through to
/// an unset alias (which would default to column 0).
fn alias_setup(probs: &[f32]) -> (Vec<u32>, Vec<f32>) {
    let k = probs.len();
    let mut q = vec![0.0f32; k];
    let mut j = vec![0u32; k];

    let mut smaller: Vec<usize> = Vec::with_capacity(k);
    let mut larger: Vec<usize> = Vec::with_capacity(k);

    for (kk, &prob) in probs.iter().enumerate() {
        q[kk] = (k as f32) * prob;
        if q[kk] < 1.0 {
            smaller.push(kk);
        } else {
            larger.push(kk);
        }
    }

    while let Some(small) = smaller.pop() {
        let large = match larger.pop() {
            Some(large) => large,
            None => {
                // No donor left: keep `small` so it is finalized below instead of lost.
                smaller.push(small);
                break;
            }
        };
        j[small] = large as u32;
        q[large] = q[large] + q[small] - 1.0;
        if q[large] < 1.0 {
            smaller.push(large);
        } else {
            larger.push(large);
        }
    }

    // Leftover columns own their full slot.
    for idx in smaller.into_iter().chain(larger) {
        q[idx] = 1.0;
    }

    (j, q)
}
606
607fn alias_draw<R: Rng>(j: &[u32], q: &[f32], rng: &mut R) -> usize {
608 debug_assert_eq!(j.len(), q.len());
609 let k = j.len();
610 let kk = rng.random_range(0..k);
611 if rng.random::<f32>() < q[kk] {
612 kk
613 } else {
614 j[kk] as usize
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal unweighted adjacency-list graph for tests; rows are sorted on construction.
    #[derive(Debug, Clone)]
    struct RefAdj {
        adj: Vec<Vec<usize>>,
    }

    impl RefAdj {
        fn new(mut adj: Vec<Vec<usize>>) -> Self {
            for nbrs in &mut adj {
                nbrs.sort_unstable();
            }
            Self { adj }
        }
    }

    impl GraphRef for RefAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        // Out-of-range nodes report no neighbors instead of panicking.
        fn neighbors_ref(&self, node: usize) -> &[usize] {
            self.adj.get(node).map(Vec::as_slice).unwrap_or(&[])
        }
    }

    /// Minimal weighted adjacency-list graph; each row's (neighbor, weight) pairs are sorted
    /// by neighbor id on construction, keeping weights aligned with their neighbors.
    #[derive(Debug, Clone)]
    struct RefWeightedAdj {
        adj: Vec<Vec<usize>>,
        wts: Vec<Vec<f32>>,
    }

    impl RefWeightedAdj {
        fn new(mut adj: Vec<Vec<usize>>, mut wts: Vec<Vec<f32>>) -> Self {
            assert_eq!(adj.len(), wts.len());
            for i in 0..adj.len() {
                assert_eq!(adj[i].len(), wts[i].len());
                let mut pairs: Vec<(usize, f32)> =
                    adj[i].iter().copied().zip(wts[i].iter().copied()).collect();
                pairs.sort_by_key(|(n, _)| *n);
                adj[i] = pairs.iter().map(|(n, _)| *n).collect();
                wts[i] = pairs.iter().map(|(_, w)| *w).collect();
            }
            Self { adj, wts }
        }
    }

    impl WeightedGraphRef for RefWeightedAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        // Out-of-range nodes report empty neighbor/weight slices.
        fn neighbors_and_weights_ref(&self, node: usize) -> (&[usize], &[f32]) {
            let nbrs = self.adj.get(node).map(Vec::as_slice).unwrap_or(&[]);
            let wts = self.wts.get(node).map(Vec::as_slice).unwrap_or(&[]);
            (nbrs, wts)
        }
    }

    /// Asserts `|a - b| <= eps`, reporting the actual difference on failure.
    fn assert_close_f32(a: f32, b: f32, eps: f32) {
        assert!(
            (a - b).abs() <= eps,
            "expected |{a} - {b}| <= {eps}, got {}",
            (a - b).abs()
        );
    }

    /// Path 0 - 1 - 2 with p = 0.5, q = 2: checks table layout and the two alias rows of the
    /// middle node.
    #[test]
    fn alias_tables_match_expected_for_line_graph() {
        let g = RefAdj::new(vec![vec![1], vec![0, 2], vec![1]]);
        let pre = PrecomputedBiasedWalks::new(&g, 0.5, 2.0);

        // Degrees 1, 2, 1 -> squared-degree prefix sums 0, 1, 5, 6.
        assert_eq!(pre.alias_dim, vec![1, 2, 1]);
        assert_eq!(pre.alias_indptr, vec![0, 1, 5, 6]);

        // cur = 1, prev = 0 (row 0 of node 1's block at offset 1): unnormalized weights
        // [1/p, 1/q] = [2, 0.5] -> [0.8, 0.2] -> alias q = [1.0, 0.4], both aliases -> 0.
        let j01 = &pre.alias_j[1..3];
        let q01 = &pre.alias_q[1..3];
        assert_eq!(j01, &[0u32, 0u32]);
        assert_close_f32(q01[0], 1.0, 1e-6);
        assert_close_f32(q01[1], 0.4, 1e-6);

        // cur = 1, prev = 2 (row 1 at offset 3): weights [1/q, 1/p] = [0.5, 2] -> [0.2, 0.8]
        // -> alias q = [0.4, 1.0], column 0 aliases to column 1.
        let j21 = &pre.alias_j[3..5];
        let q21 = &pre.alias_q[3..5];
        assert_eq!(j21, &[1u32, 0u32]);
        assert_close_f32(q21[0], 0.4, 1e-6);
        assert_close_f32(q21[1], 1.0, 1e-6);
    }

    /// Threshold = mean + gamma * population std of a node's edge weights.
    #[test]
    fn noise_thresholds_match_mean_plus_gamma_std() {
        // Single weight 1.0: std is 0, so the threshold is the mean regardless of gamma.
        let g = RefWeightedAdj::new(vec![vec![0]], vec![vec![1.0]]);
        let thr0 = compute_noise_thresholds(&g, 2.0);
        assert_eq!(thr0.len(), 1);
        assert_close_f32(thr0[0], 1.0, 1e-6);

        // Weights [1, 3]: mean 2, population std 1 -> 2 + 2 * 1 = 4. Only node 0 exists, so
        // only its threshold is computed.
        let g2 = RefWeightedAdj::new(vec![vec![0, 1]], vec![vec![1.0, 3.0]]);
        let thr2 = compute_noise_thresholds(&g2, 2.0);
        assert_eq!(thr2.len(), 1);
        assert_close_f32(thr2[0], 4.0, 1e-6);
    }

    /// With q < 1, classic weighted node2vec amplifies a weak out-edge by 1/q, while node2vec+
    /// caps the scale at min(1/q, 1) when the edge weight is below the current node's
    /// noise threshold.
    #[test]
    fn node2vec_plus_suppress_caps_inv_q_when_q_lt_1() {
        let g = RefWeightedAdj::new(
            vec![vec![1], vec![0, 2], vec![1]],
            vec![vec![1.0], vec![1.0, 0.9], vec![1.0]],
        );

        let cfg = WeightedNode2VecPlusConfig {
            length: 3,
            walks_per_node: 1,
            p: 1.0,
            q: 0.5, gamma: 0.0, seed: 0,
        };

        // gamma = 0: node 1's threshold is the mean of its weights (1.0 + 0.9) / 2.
        let thr = compute_noise_thresholds(&g, cfg.gamma);
        assert_eq!(thr.len(), 3);
        assert_close_f32(thr[1], 0.95, 1e-6);

        let (nbrs, wts) = g.neighbors_and_weights_ref(1);
        assert_eq!(nbrs, &[0, 2]);
        assert_eq!(wts, &[1.0, 0.9]);

        let mut buf_weighted = Vec::new();
        let mut buf_plus = Vec::new();

        // Step 0 -> 1 -> next; node 2 is not adjacent to node 0, so it is an out-edge.
        fill_next_node2vec_weighted_buf(&g, 0, nbrs, wts, cfg, &mut buf_weighted);
        fill_next_node2vec_plus_buf(&g, 1, 0, nbrs, wts, cfg, &thr, &mut buf_plus);

        // Classic: 0.9 / q = 1.8.
        assert_close_f32(buf_weighted[1], 1.8, 1e-6);

        // node2vec+: 0.9 < 0.95 is noisy, so the scale is min(1/q, 1) = 1 -> stays 0.9.
        assert_close_f32(buf_plus[1], 0.9, 1e-6);
    }

    /// Statistical smoke test: alias sampling of [0.1, 0.2, 0.7] passes a loose chi-square
    /// bound (2 degrees of freedom; 50 is far above any reasonable critical value).
    #[test]
    fn alias_draw_distribution_smoke() {
        let probs = vec![0.1f32, 0.2f32, 0.7f32];
        let (j, q) = alias_setup(&probs);

        let trials = 20_000usize;
        let mut counts = [0usize; 3];
        // A fresh seeded RNG per trial keeps the test deterministic.
        for t in 0..trials {
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            let k = alias_draw(&j, &q, &mut rng);
            counts[k] += 1;
        }

        let expected = [
            trials as f64 * 0.1,
            trials as f64 * 0.2,
            trials as f64 * 0.7,
        ];
        let chi2: f64 = counts
            .iter()
            .zip(expected.iter())
            .map(|(&c, &e)| {
                let diff = c as f64 - e;
                (diff * diff) / e
            })
            .sum();

        assert!(
            chi2 < 50.0,
            "chi2 too large (chi2={chi2:.2}). counts={counts:?} expected={expected:?}"
        );
    }
}
813}