Skip to main content

graphops/
node2vec.rs

1//! Node2Vec / Node2Vec+ walk generation (OTF + PreComp).
2//!
3//! Grounded against PecanPy:
4//! - `src/pecanpy/rw/sparse_rw.py` (`get_normalized_probs`, `get_extended_normalized_probs`)
5//! - `src/pecanpy/rw/dense_rw.py` (same semantics for dense graphs)
6//! - `src/pecanpy/pecanpy.py` (overall walk skeleton)
7
8use crate::graph::{GraphRef, WeightedGraphRef};
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11
/// Parameters for weighted node2vec / node2vec+ walk generation.
///
/// The config is `Copy` and is passed by value to every weighted walk generator in this
/// module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WeightedNode2VecPlusConfig {
    /// Maximum walk length, counted in nodes (the start node included).
    ///
    /// A walk ends early when it reaches a node without neighbors; values of 0 and 1
    /// both produce walks that contain only the start node.
    pub length: usize,
    /// Number of walks started from every node.
    pub walks_per_node: usize,
    /// Return parameter \(p\): the weight of the edge leading back to the previous node
    /// is divided by `p`, so `p > 1` discourages immediate backtracking.
    pub p: f32,
    /// In-out parameter \(q\): weights of "out" edges (towards nodes that are not
    /// neighbors of the previous node) are divided by `q` in classic node2vec; node2vec+
    /// interpolates that `1/q` factor for loosely-connected common neighbors.
    pub q: f32,
    /// Node2vec+ parameter \(\gamma\) controlling the “noisy edge” threshold: each node's
    /// threshold is `max(0, mean + gamma * std)` of the weights returned for its
    /// neighbors (population standard deviation). Ignored by classic node2vec.
    pub gamma: f32,
    /// Seed for the deterministic RNG.
    pub seed: u64,
}
28
29impl Default for WeightedNode2VecPlusConfig {
30    fn default() -> Self {
31        Self {
32            length: 80,
33            walks_per_node: 10,
34            p: 1.0,
35            q: 1.0,
36            gamma: 0.0,
37            seed: 42,
38        }
39    }
40}
41
42/// Weighted node2vec biased walks over a `WeightedGraphRef`.
43pub fn generate_biased_walks_weighted_ref<G: WeightedGraphRef>(
44    graph: &G,
45    config: WeightedNode2VecPlusConfig,
46) -> Vec<Vec<usize>> {
47    generate_biased_walks_weighted_impl(graph, config, false)
48}
49
50/// Weighted node2vec+ biased walks over a `WeightedGraphRef`.
51pub fn generate_biased_walks_weighted_plus_ref<G: WeightedGraphRef>(
52    graph: &G,
53    config: WeightedNode2VecPlusConfig,
54) -> Vec<Vec<usize>> {
55    generate_biased_walks_weighted_impl(graph, config, true)
56}
57
58fn generate_biased_walks_weighted_impl<G: WeightedGraphRef>(
59    graph: &G,
60    config: WeightedNode2VecPlusConfig,
61    extend: bool,
62) -> Vec<Vec<usize>> {
63    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
64    let mut start_nodes: Vec<usize> = (0..graph.node_count()).collect();
65
66    let noise_thresholds = if extend {
67        compute_noise_thresholds(graph, config.gamma)
68    } else {
69        Vec::new()
70    };
71
72    let mut walks = Vec::with_capacity(graph.node_count() * config.walks_per_node);
73    for _ in 0..config.walks_per_node {
74        start_nodes.shuffle(&mut rng);
75        for &node in &start_nodes {
76            walks.push(weighted_walk(
77                graph,
78                node,
79                config,
80                extend,
81                &noise_thresholds,
82                &mut rng,
83            ));
84        }
85    }
86    walks
87}
88
89fn weighted_walk<G: WeightedGraphRef, R: Rng>(
90    graph: &G,
91    start: usize,
92    config: WeightedNode2VecPlusConfig,
93    extend: bool,
94    noise_thresholds: &[f32],
95    rng: &mut R,
96) -> Vec<usize> {
97    let mut walk = Vec::with_capacity(config.length);
98    walk.push(start);
99
100    let mut curr = start;
101    let mut prev: Option<usize> = None;
102    let mut buf: Vec<f32> = Vec::new();
103
104    for _ in 1..config.length {
105        let (nbrs, wts) = graph.neighbors_and_weights_ref(curr);
106        if nbrs.is_empty() {
107            break;
108        }
109        debug_assert_eq!(nbrs.len(), wts.len());
110
111        let next = if let Some(prev_idx) = prev {
112            if extend {
113                sample_next_node2vec_plus(
114                    graph,
115                    curr,
116                    prev_idx,
117                    nbrs,
118                    wts,
119                    config,
120                    noise_thresholds,
121                    &mut buf,
122                    rng,
123                )
124            } else {
125                sample_next_node2vec_weighted(graph, prev_idx, nbrs, wts, config, &mut buf, rng)
126            }
127        } else {
128            sample_cdf(rng, nbrs, wts)
129        };
130
131        walk.push(next);
132        prev = Some(curr);
133        curr = next;
134    }
135
136    walk
137}
138
139fn sample_next_node2vec_weighted<G: WeightedGraphRef, R: Rng>(
140    graph: &G,
141    prev: usize,
142    nbrs: &[usize],
143    wts: &[f32],
144    config: WeightedNode2VecPlusConfig,
145    buf: &mut Vec<f32>,
146    rng: &mut R,
147) -> usize {
148    fill_next_node2vec_weighted_buf(graph, prev, nbrs, wts, config, buf);
149    sample_cdf(rng, nbrs, buf)
150}
151
152fn fill_next_node2vec_weighted_buf<G: WeightedGraphRef>(
153    graph: &G,
154    prev: usize,
155    nbrs: &[usize],
156    wts: &[f32],
157    config: WeightedNode2VecPlusConfig,
158    buf: &mut Vec<f32>,
159) {
160    // Classic node2vec: out edges are neighbors(cur) that are not neighbors(prev).
161    let (prev_nbrs, _prev_wts) = graph.neighbors_and_weights_ref(prev);
162
163    buf.clear();
164    buf.extend_from_slice(wts);
165
166    // return bias
167    if let Some(i) = nbrs.iter().position(|&x| x == prev) {
168        buf[i] /= config.p;
169    }
170
171    for i in 0..nbrs.len() {
172        let x = nbrs[i];
173        if x == prev {
174            continue;
175        }
176        let is_common = prev_nbrs.contains(&x);
177        if !is_common {
178            buf[i] /= config.q;
179        }
180    }
181}
182
183#[allow(clippy::too_many_arguments)]
184fn sample_next_node2vec_plus<G: WeightedGraphRef, R: Rng>(
185    graph: &G,
186    cur: usize,
187    prev: usize,
188    nbrs: &[usize],
189    wts: &[f32],
190    config: WeightedNode2VecPlusConfig,
191    noise_thresholds: &[f32],
192    buf: &mut Vec<f32>,
193    rng: &mut R,
194) -> usize {
195    fill_next_node2vec_plus_buf(graph, cur, prev, nbrs, wts, config, noise_thresholds, buf);
196    sample_cdf(rng, nbrs, buf)
197}
198
199#[allow(clippy::too_many_arguments)]
200fn fill_next_node2vec_plus_buf<G: WeightedGraphRef>(
201    graph: &G,
202    cur: usize,
203    prev: usize,
204    nbrs: &[usize],
205    wts: &[f32],
206    config: WeightedNode2VecPlusConfig,
207    noise_thresholds: &[f32],
208    buf: &mut Vec<f32>,
209) {
210    // PecanPy semantics (SparseRWGraph.get_extended_normalized_probs):
211    // - Determine out edges via `isnotin_extended`.
212    // - alpha(out) = 1/q + (1 - 1/q) * t(out), where:
213    //   - t=0 for non-common neighbors
214    //   - t=w(prev,x)/threshold[x] for “loose common” edges (when w(prev,x) < threshold[x])
215    // - suppress: if w(cur,x) < threshold[cur], alpha = min(1, 1/q).
216
217    let (prev_nbrs, prev_wts) = graph.neighbors_and_weights_ref(prev);
218
219    buf.clear();
220    buf.extend_from_slice(wts);
221
222    // return bias
223    if let Some(i) = nbrs.iter().position(|&x| x == prev) {
224        buf[i] /= config.p;
225    }
226
227    let inv_q = 1.0 / config.q;
228    let thr_cur = noise_thresholds[cur];
229
230    for i in 0..nbrs.len() {
231        let x = nbrs[i];
232        if x == prev {
233            continue;
234        }
235
236        let mut is_out = true;
237        let mut t: f32 = 0.0;
238
239        if let Some(j) = prev_nbrs.iter().position(|&y| y == x) {
240            let thr_x = noise_thresholds[x];
241            let w_prev_x = prev_wts[j];
242            if thr_x > 0.0 && w_prev_x >= thr_x {
243                // strong common edge => in-edge
244                is_out = false;
245            } else if thr_x > 0.0 {
246                // loose common edge => out-edge with t in (0, 1)
247                t = (w_prev_x / thr_x).max(0.0);
248            }
249        }
250
251        if is_out {
252            let mut alpha = inv_q + (1.0 - inv_q) * t;
253            if buf[i] < thr_cur {
254                alpha = inv_q.min(1.0);
255            }
256            buf[i] *= alpha;
257        }
258    }
259}
260
261fn compute_noise_thresholds<G: WeightedGraphRef>(graph: &G, gamma: f32) -> Vec<f32> {
262    let n = graph.node_count();
263    let mut thr = vec![0.0f32; n];
264
265    for (v, thr_v) in thr.iter_mut().enumerate().take(n) {
266        let (_nbrs, wts) = graph.neighbors_and_weights_ref(v);
267        if wts.is_empty() {
268            *thr_v = 0.0;
269            continue;
270        }
271
272        let mean = wts.iter().copied().sum::<f32>() / (wts.len() as f32);
273        let var = wts
274            .iter()
275            .map(|&x| {
276                let d = x - mean;
277                d * d
278            })
279            .sum::<f32>()
280            / (wts.len() as f32);
281        let std = var.sqrt();
282
283        *thr_v = (mean + gamma * std).max(0.0);
284    }
285
286    thr
287}
288
289fn sample_cdf<R: Rng>(rng: &mut R, nbrs: &[usize], weights: &[f32]) -> usize {
290    debug_assert_eq!(nbrs.len(), weights.len());
291    if nbrs.len() == 1 {
292        return nbrs[0];
293    }
294
295    let sum = weights.iter().copied().sum::<f32>();
296    if !sum.is_finite() || sum <= 0.0 {
297        return *nbrs.choose(rng).unwrap();
298    }
299
300    let mut r = rng.random::<f32>() * sum;
301    for (i, &w) in weights.iter().enumerate() {
302        if r <= w {
303            return nbrs[i];
304        }
305        r -= w;
306    }
307    *nbrs.last().unwrap()
308}
309
/// Precomputed alias tables for classic node2vec biased walks (unweighted).
///
/// Built by `PrecomputedBiasedWalks::new`; consumed by the `generate_biased_walks_precomp_*`
/// functions.
#[derive(Debug, Clone)]
pub struct PrecomputedBiasedWalks {
    /// Neighbor list of every node, sorted ascending (so it can be binary-searched).
    neighbors: Vec<Vec<usize>>,
    /// Degree of every node, which is also the size of each of its alias tables.
    alias_dim: Vec<u32>,
    /// Offsets into `alias_j` / `alias_q`: node `v` owns `alias_dim[v]^2` entries starting
    /// at `alias_indptr[v]`, one row of `alias_dim[v]` entries per previous neighbor.
    alias_indptr: Vec<u64>,
    /// Alias indices of all tables, laid out as described for `alias_indptr`.
    alias_j: Vec<u32>,
    /// Acceptance probabilities of all tables, parallel to `alias_j`.
    alias_q: Vec<f32>,
    /// Return parameter the tables were built with (checked against the walk config).
    p: f32,
    /// In-out parameter the tables were built with (checked against the walk config).
    q: f32,
}
322impl PrecomputedBiasedWalks {
323    /// Precompute the alias tables needed for node2vec walks with return bias `p`
324    /// and in-out bias `q`.
325    pub fn new<G: GraphRef>(graph: &G, p: f32, q: f32) -> Self {
326        let n = graph.node_count();
327        let mut neighbors: Vec<Vec<usize>> = Vec::with_capacity(n);
328        let mut alias_dim: Vec<u32> = Vec::with_capacity(n);
329
330        for v in 0..n {
331            let mut nbrs = graph.neighbors_ref(v).to_vec();
332            nbrs.sort_unstable();
333            alias_dim.push(nbrs.len() as u32);
334            neighbors.push(nbrs);
335        }
336
337        let mut alias_indptr: Vec<u64> = vec![0; n + 1];
338        for i in 0..n {
339            let deg = alias_dim[i] as u64;
340            alias_indptr[i + 1] = alias_indptr[i] + deg * deg;
341        }
342        let total = alias_indptr[n] as usize;
343
344        let mut alias_j = vec![0u32; total];
345        let mut alias_q = vec![0.0f32; total];
346
347        let mut out_ind: Vec<bool> = Vec::new();
348        let mut probs: Vec<f32> = Vec::new();
349
350        for cur in 0..n {
351            let deg = alias_dim[cur] as usize;
352            if deg == 0 {
353                continue;
354            }
355            let offset = alias_indptr[cur] as usize;
356            let cur_nbrs = &neighbors[cur];
357
358            out_ind.clear();
359            out_ind.resize(deg, true);
360            probs.clear();
361            probs.resize(deg, 1.0);
362
363            for prev_j in 0..deg {
364                let prev = cur_nbrs[prev_j];
365                let prev_nbrs = &neighbors[prev];
366
367                mark_non_common(cur_nbrs, prev_nbrs, &mut out_ind);
368                out_ind[prev_j] = false; // exclude prev from out biases
369
370                probs.fill(1.0);
371                for i in 0..deg {
372                    if out_ind[i] {
373                        probs[i] /= q;
374                    }
375                }
376                probs[prev_j] /= p;
377
378                normalize_in_place(&mut probs);
379                let (j, qtab) = alias_setup(&probs);
380
381                let start = offset + deg * prev_j;
382                let end = start + deg;
383                alias_j[start..end].copy_from_slice(&j);
384                alias_q[start..end].copy_from_slice(&qtab);
385            }
386        }
387
388        Self {
389            neighbors,
390            alias_dim,
391            alias_indptr,
392            alias_j,
393            alias_q,
394            p,
395            q,
396        }
397    }
398}
399
400/// Generate node2vec biased walks using precomputed alias tables (`PrecomputedBiasedWalks`),
401/// amortizing the per-step sampling cost across multiple walks over the same graph.
402pub fn generate_biased_walks_precomp_ref(
403    pre: &PrecomputedBiasedWalks,
404    config: crate::random_walk::WalkConfig,
405) -> Vec<Vec<usize>> {
406    let start_nodes: Vec<usize> = (0..pre.neighbors.len()).collect();
407    generate_biased_walks_precomp_ref_from_nodes(pre, &start_nodes, config)
408}
409
410/// Precomputed node2vec biased walks, restricted to an explicit set of start nodes.
411///
412/// This is the “delta walk” primitive for PreComp mode: generate new walks only for the
413/// subset of nodes whose neighborhood changed (dynamic graphs), or for sharding.
414pub fn generate_biased_walks_precomp_ref_from_nodes(
415    pre: &PrecomputedBiasedWalks,
416    start_nodes: &[usize],
417    config: crate::random_walk::WalkConfig,
418) -> Vec<Vec<usize>> {
419    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
420        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
421    }
422
423    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
424    let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
425    let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
426
427    for _ in 0..config.walks_per_node {
428        epoch_nodes.shuffle(&mut rng);
429        for &node in &epoch_nodes {
430            walks.push(biased_walk_precomp(pre, node, config.length, &mut rng));
431        }
432    }
433
434    walks
435}
436
/// Deterministic parallel precomputed node2vec biased walks (delta/sharded start nodes).
///
/// Invariant: output is stable for a fixed `seed`, independent of Rayon thread count.
///
/// Each epoch shuffles `start_nodes` with an RNG seeded from `(seed, epoch)`. Every walk
/// then gets its own RNG seeded with `mix64(base ^ i)`, where `i` is the walk's position
/// in the flattened job list and `base` depends only on `seed`. Because `mix64` is a
/// bijection, no two walks share a seed. For the same seed, the output differs from the
/// sequential `generate_biased_walks_precomp_ref_from_nodes`.
///
/// # Panics
///
/// Panics if `config.p` / `config.q` differ from the values `pre` was built with.
#[cfg(feature = "parallel")]
pub fn generate_biased_walks_precomp_ref_parallel_from_nodes(
    pre: &PrecomputedBiasedWalks,
    start_nodes: &[usize],
    config: crate::random_walk::WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    /// SplitMix64 finalizer. Every step (xor-shift, odd multiply) is invertible, so this
    /// is a bijection on `u64`. Kept local to avoid exposing `random_walk::mix64`.
    fn mix64(mut x: u64) -> u64 {
        x ^= x >> 30;
        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94d049bb133111eb);
        x ^= x >> 31;
        x
    }

    /// Domain separator so per-walk seeds do not coincide with the epoch-shuffle seeds.
    const WALK_STREAM_DOMAIN: u64 = 0x9E37_79B9_7F4A_7C15;

    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
    }

    // Build the job order sequentially. The shuffle seed depends only on (seed, epoch),
    // so the order is independent of thread count.
    let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<usize> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
    for epoch in 0..config.walks_per_node {
        let mut rng = ChaCha8Rng::seed_from_u64(mix64(config.seed ^ (epoch as u64)));
        epoch_nodes.shuffle(&mut rng);
        jobs.extend_from_slice(&epoch_nodes);
    }

    // `i` is unique per job, so `base ^ i` is unique and `mix64` keeps it unique.
    // Previously, XOR-ing the node id and the index together let distinct jobs collide
    // onto the same seed.
    let base = mix64(config.seed ^ WALK_STREAM_DOMAIN);
    jobs.par_iter()
        .enumerate()
        .map(|(i, &node)| {
            let mut rng = ChaCha8Rng::seed_from_u64(mix64(base ^ (i as u64)));
            biased_walk_precomp(pre, node, config.length, &mut rng)
        })
        .collect()
}
492
493fn biased_walk_precomp<R: Rng>(
494    pre: &PrecomputedBiasedWalks,
495    start: usize,
496    length: usize,
497    rng: &mut R,
498) -> Vec<usize> {
499    let mut walk = Vec::with_capacity(length);
500    walk.push(start);
501    let mut curr = start;
502    let mut prev: Option<usize> = None;
503
504    for _ in 1..length {
505        let nbrs = &pre.neighbors[curr];
506        if nbrs.is_empty() {
507            break;
508        }
509
510        let next = if let Some(p) = prev {
511            sample_precomp(pre, curr, p, rng)
512        } else {
513            *nbrs.choose(rng).unwrap()
514        };
515
516        walk.push(next);
517        prev = Some(curr);
518        curr = next;
519    }
520
521    walk
522}
523
524fn sample_precomp<R: Rng>(
525    pre: &PrecomputedBiasedWalks,
526    cur: usize,
527    prev: usize,
528    rng: &mut R,
529) -> usize {
530    let nbrs = &pre.neighbors[cur];
531    let deg = pre.alias_dim[cur] as usize;
532    let prev_j = match nbrs.binary_search(&prev) {
533        Ok(i) => i,
534        Err(_) => {
535            // This can happen on directed / non-reciprocal graphs: we might have walked
536            // from `prev -> cur`, but `cur` may not have `prev` in its neighbor list.
537            // PecanPy prints "FATAL ERROR! Neighbor not found." in this situation.
538            //
539            // In Rust, we choose a safe fallback that preserves determinism and avoids
540            // returning a nonsense index: fall back to a 1st-order uniform step.
541            return *nbrs.choose(rng).unwrap();
542        }
543    };
544
545    let offset = pre.alias_indptr[cur] + (deg as u64) * (prev_j as u64);
546    let start = offset as usize;
547    let end = start + deg;
548
549    let choice = alias_draw(&pre.alias_j[start..end], &pre.alias_q[start..end], rng);
550    nbrs[choice]
551}
552
/// Scale `x` in place so its entries sum to 1.
///
/// `x` is left untouched when its sum is not strictly positive: empty, all-zero,
/// negative, or NaN.
fn normalize_in_place(x: &mut [f32]) {
    let total = x.iter().copied().sum::<f32>();
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
561
/// Flag which entries of `cur` are absent from `prev` ("non-common" neighbors).
///
/// Both slices are expected to be sorted ascending. The scan is a single merge pass,
/// O(|cur| + |prev|), and sets `out[i]` to `true` when `cur[i]` does not occur in `prev`.
fn mark_non_common(cur: &[usize], prev: &[usize], out: &mut [bool]) {
    debug_assert_eq!(cur.len(), out.len());
    let mut remaining = prev;
    for (i, &x) in cur.iter().enumerate() {
        // Drop `prev` entries below `x`; with sorted input they cannot match later ones.
        let skipped = remaining.iter().take_while(|&&y| y < x).count();
        remaining = &remaining[skipped..];
        out[i] = remaining.first() != Some(&x);
    }
}
572
/// Build a Walker/Vose alias table for the categorical distribution `probs`.
///
/// Returns `(j, q)`: for column `i`, `q[i]` is the probability of keeping `i` and `j[i]`
/// is the alias used otherwise. Construction is O(k).
///
/// The algorithm is intentionally structured to mirror PecanPy's implementation, which
/// reduces drift when comparing outputs and edge cases.
///
/// References:
/// - Walker (1974): An efficient method for generating discrete random variables with
///   general distributions.
/// - Vose (1991): A linear algorithm for generating random numbers with a given
///   distribution.
/// - PecanPy (software reference implementation): https://github.com/krishnanlab/PecanPy
fn alias_setup(probs: &[f32]) -> (Vec<u32>, Vec<f32>) {
    let k = probs.len();
    let scale = k as f32;
    let mut q: Vec<f32> = probs.iter().map(|&prob| scale * prob).collect();
    let mut j = vec![0u32; k];

    // Columns with scaled mass below 1 go to `smaller`; the rest go to `larger`,
    // both in ascending index order.
    let (mut smaller, mut larger): (Vec<usize>, Vec<usize>) =
        (0..k).partition(|&idx| q[idx] < 1.0);

    loop {
        let (small, large) = match (smaller.last(), larger.last()) {
            (Some(&s), Some(&l)) => (s, l),
            _ => break,
        };
        smaller.pop();
        larger.pop();

        j[small] = large as u32;
        // Written as (q_large + q_small) - 1 on purpose: `+=` would group the
        // subtraction first and change the f32 rounding.
        q[large] = q[large] + q[small] - 1.0;
        if q[large] < 1.0 {
            smaller.push(large);
        } else {
            larger.push(large);
        }
    }

    (j, q)
}
612
613fn alias_draw<R: Rng>(j: &[u32], q: &[f32], rng: &mut R) -> usize {
614    debug_assert_eq!(j.len(), q.len());
615    let k = j.len();
616    let kk = rng.random_range(0..k);
617    if rng.random::<f32>() < q[kk] {
618        kk
619    } else {
620        j[kk] as usize
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct RefAdj {
        adj: Vec<Vec<usize>>,
    }

    impl RefAdj {
        fn new(mut adj: Vec<Vec<usize>>) -> Self {
            for nbrs in &mut adj {
                nbrs.sort_unstable();
            }
            Self { adj }
        }
    }

    impl GraphRef for RefAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_ref(&self, node: usize) -> &[usize] {
            self.adj.get(node).map(Vec::as_slice).unwrap_or(&[])
        }
    }

    #[derive(Debug, Clone)]
    struct RefWeightedAdj {
        adj: Vec<Vec<usize>>,
        wts: Vec<Vec<f32>>,
    }

    impl RefWeightedAdj {
        fn new(mut adj: Vec<Vec<usize>>, mut wts: Vec<Vec<f32>>) -> Self {
            assert_eq!(adj.len(), wts.len());
            for i in 0..adj.len() {
                assert_eq!(adj[i].len(), wts[i].len());
                let mut pairs: Vec<(usize, f32)> =
                    adj[i].iter().copied().zip(wts[i].iter().copied()).collect();
                pairs.sort_by_key(|(n, _)| *n);
                adj[i] = pairs.iter().map(|(n, _)| *n).collect();
                wts[i] = pairs.iter().map(|(_, w)| *w).collect();
            }
            Self { adj, wts }
        }
    }

    impl WeightedGraphRef for RefWeightedAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_and_weights_ref(&self, node: usize) -> (&[usize], &[f32]) {
            let nbrs = self.adj.get(node).map(Vec::as_slice).unwrap_or(&[]);
            let wts = self.wts.get(node).map(Vec::as_slice).unwrap_or(&[]);
            (nbrs, wts)
        }
    }

    fn assert_close_f32(a: f32, b: f32, eps: f32) {
        assert!(
            (a - b).abs() <= eps,
            "expected |{a} - {b}| <= {eps}, got {}",
            (a - b).abs()
        );
    }

    #[test]
    fn alias_tables_match_expected_for_line_graph() {
        // Graph: 0 -- 1 -- 2
        // Using p=0.5, q=2.0, when at cur=1 coming from prev=0:
        // weights are [1/p, 1/q] => [2.0, 0.5] => normalized [0.8, 0.2]
        let g = RefAdj::new(vec![vec![1], vec![0, 2], vec![1]]);
        let pre = PrecomputedBiasedWalks::new(&g, 0.5, 2.0);

        assert_eq!(pre.alias_dim, vec![1, 2, 1]);
        assert_eq!(pre.alias_indptr, vec![0, 1, 5, 6]);

        // cur = 1, neighbors = [0, 2], deg=2.
        // prev=0 corresponds to prev_j=0, slice is offset=alias_indptr[1]=1, start=1, end=3.
        let j01 = &pre.alias_j[1..3];
        let q01 = &pre.alias_q[1..3];
        assert_eq!(j01, &[0u32, 0u32]);
        assert_close_f32(q01[0], 1.0, 1e-6);
        assert_close_f32(q01[1], 0.4, 1e-6);

        // prev=2 corresponds to prev_j=1, start=3, end=5.
        let j21 = &pre.alias_j[3..5];
        let q21 = &pre.alias_q[3..5];
        assert_eq!(j21, &[1u32, 0u32]);
        assert_close_f32(q21[0], 0.4, 1e-6);
        assert_close_f32(q21[1], 1.0, 1e-6);
    }

    #[test]
    fn noise_thresholds_match_mean_plus_gamma_std() {
        // tau = max(0, mean + gamma * std), using the population std of a node's weights.
        //
        // One node with a single weight [1]: std=0, so tau=mean=1 regardless of gamma.
        let g = RefWeightedAdj::new(vec![vec![0]], vec![vec![1.0]]);
        let thr0 = compute_noise_thresholds(&g, 2.0);
        assert_eq!(thr0.len(), 1);
        assert_close_f32(thr0[0], 1.0, 1e-6);

        // One node with two outgoing weights [1, 3]: mean=2, std=1, tau=2+2*1=4.
        let g2 = RefWeightedAdj::new(vec![vec![0, 1]], vec![vec![1.0, 3.0]]);
        let thr2 = compute_noise_thresholds(&g2, 2.0);
        assert_eq!(thr2.len(), 1);
        assert_close_f32(thr2[0], 4.0, 1e-6);
    }

    #[test]
    fn node2vec_plus_suppress_caps_inv_q_when_q_lt_1() {
        // Construct a situation where:
        // - q < 1 (so inv_q > 1 would amplify out-edges in classic node2vec)
        // - node2vec+ suppresses that amplification for “noisy” edges with
        //   w(cur, x) < threshold[cur].
        //
        // Graph: 0 -- 1 -- 2 (weighted, symmetric for existence, but asymmetric weights at node 1)
        // At cur=1 coming from prev=0, candidate x=2 is an out-edge.
        let g = RefWeightedAdj::new(
            vec![vec![1], vec![0, 2], vec![1]],
            vec![vec![1.0], vec![1.0, 0.9], vec![1.0]],
        );

        let cfg = WeightedNode2VecPlusConfig {
            length: 3,
            walks_per_node: 1,
            p: 1.0,
            q: 0.5,     // inv_q = 2.0
            gamma: 0.0, // threshold is mean
            seed: 0,
        };

        let thr = compute_noise_thresholds(&g, cfg.gamma);
        assert_eq!(thr.len(), 3);
        // For node 1: weights [1.0, 0.9], mean=0.95 => threshold=0.95
        assert_close_f32(thr[1], 0.95, 1e-6);

        let (nbrs, wts) = g.neighbors_and_weights_ref(1);
        assert_eq!(nbrs, &[0, 2]);
        assert_eq!(wts, &[1.0, 0.9]);

        let mut buf_weighted = Vec::new();
        let mut buf_plus = Vec::new();

        fill_next_node2vec_weighted_buf(&g, 0, nbrs, wts, cfg, &mut buf_weighted);
        fill_next_node2vec_plus_buf(&g, 1, 0, nbrs, wts, cfg, &thr, &mut buf_plus);

        // For x=2 (out-edge) classic weighted node2vec divides by q (q=0.5 => multiply by 2).
        assert_close_f32(buf_weighted[1], 1.8, 1e-6);

        // For node2vec+, since w(cur,2)=0.9 < threshold[cur]=0.95 and inv_q>1,
        // suppress caps alpha at 1.0, so out-edge stays at 0.9.
        assert_close_f32(buf_plus[1], 0.9, 1e-6);
    }

    #[test]
    fn alias_draw_distribution_smoke() {
        // Deterministic chi-squared smoke test: catches egregious alias bugs
        // without being overly sensitive/flaky.
        //
        // Distribution: [0.1, 0.2, 0.7]
        let probs = vec![0.1f32, 0.2f32, 0.7f32];
        let (j, q) = alias_setup(&probs);

        let trials = 20_000usize;
        let mut counts = [0usize; 3];
        for t in 0..trials {
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            let k = alias_draw(&j, &q, &mut rng);
            counts[k] += 1;
        }

        let expected = [
            trials as f64 * 0.1,
            trials as f64 * 0.2,
            trials as f64 * 0.7,
        ];
        let chi2: f64 = counts
            .iter()
            .zip(expected.iter())
            .map(|(&c, &e)| {
                let diff = c as f64 - e;
                (diff * diff) / e
            })
            .sum();

        // df = 2; E[chi2] ~ 2, Var ~ 4. Use a very conservative cutoff.
        assert!(
            chi2 < 50.0,
            "chi2 too large (chi2={chi2:.2}). counts={counts:?} expected={expected:?}"
        );
    }

    #[test]
    fn mark_non_common_flags_neighbors_missing_from_prev() {
        let cur = [1usize, 3, 5, 7];
        let prev = [0usize, 3, 4, 7, 9];
        let mut out = vec![false; cur.len()];
        mark_non_common(&cur, &prev, &mut out);
        assert_eq!(out, vec![true, false, true, false]);
    }

    #[test]
    fn weighted_walks_follow_edges_and_are_deterministic() {
        // 4-cycle 0-1-2-3-0 with chord 0-2; symmetric positive weights, no dead ends.
        let g = RefWeightedAdj::new(
            vec![vec![1, 2, 3], vec![0, 2], vec![0, 1, 3], vec![0, 2]],
            vec![
                vec![1.0, 0.5, 2.0],
                vec![1.0, 1.5],
                vec![0.5, 1.5, 1.0],
                vec![2.0, 1.0],
            ],
        );
        let cfg = WeightedNode2VecPlusConfig {
            length: 12,
            walks_per_node: 3,
            p: 0.5,
            q: 2.0,
            gamma: 0.5,
            seed: 7,
        };

        for plus in [false, true] {
            let run = || {
                if plus {
                    generate_biased_walks_weighted_plus_ref(&g, cfg)
                } else {
                    generate_biased_walks_weighted_ref(&g, cfg)
                }
            };
            let walks = run();
            assert_eq!(walks.len(), 4 * cfg.walks_per_node);

            let mut starts = [0usize; 4];
            for walk in &walks {
                assert_eq!(walk.len(), cfg.length);
                starts[walk[0]] += 1;
                for step in walk.windows(2) {
                    assert!(g.adj[step[0]].contains(&step[1]), "non-edge step {step:?}");
                }
            }
            assert_eq!(starts, [cfg.walks_per_node; 4]);
            assert_eq!(walks, run(), "walks must be deterministic for a fixed seed");
        }
    }

    #[test]
    fn precomp_walks_follow_edges() {
        let g = RefAdj::new(vec![vec![1, 2, 3], vec![0, 2], vec![0, 1, 3], vec![0, 2]]);
        let pre = PrecomputedBiasedWalks::new(&g, 0.5, 2.0);
        let mut rng = ChaCha8Rng::seed_from_u64(3);
        for start in 0..4 {
            let walk = biased_walk_precomp(&pre, start, 10, &mut rng);
            assert_eq!(walk.len(), 10);
            assert_eq!(walk[0], start);
            for step in walk.windows(2) {
                assert!(g.adj[step[0]].contains(&step[1]), "non-edge step {step:?}");
            }
        }
    }
}