Skip to main content

graphops/
node2vec.rs

//! Node2Vec / Node2Vec+ walk generation, in two modes: OTF (on-the-fly transition weights)
//! and PreComp (precomputed alias tables).
//!
//! Grounded against PecanPy:
//! - `src/pecanpy/rw/sparse_rw.py` (`get_normalized_probs`, `get_extended_normalized_probs`)
//! - `src/pecanpy/rw/dense_rw.py` (same semantics for dense graphs)
//! - `src/pecanpy/pecanpy.py` (overall walk skeleton)
7
8use crate::graph::{GraphRef, WeightedGraphRef};
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11
/// Configuration shared by the weighted node2vec and node2vec+ walk generators.
///
/// Every field is a plain value, so the config is `Copy` and is passed around by value.
#[derive(Clone, Copy, Debug)]
pub struct WeightedNode2VecPlusConfig {
    /// Maximum number of nodes per walk, counting the start node.
    pub length: usize,
    /// How many walks are started from every node.
    pub walks_per_node: usize,
    /// Return parameter \(p\): the weight of stepping back to the previous node is divided by it.
    pub p: f32,
    /// In-out parameter \(q\): controls the bias applied to "outward" candidates
    /// (classic node2vec divides their weight by it).
    pub q: f32,
    /// Node2vec+ parameter \(\gamma\): each node's noise threshold is
    /// `max(0, mean + gamma * std)` of its edge weights.
    pub gamma: f32,
    /// Seed for the deterministic RNG.
    pub seed: u64,
}
28
29impl Default for WeightedNode2VecPlusConfig {
30    fn default() -> Self {
31        Self {
32            length: 80,
33            walks_per_node: 10,
34            p: 1.0,
35            q: 1.0,
36            gamma: 0.0,
37            seed: 42,
38        }
39    }
40}
41
42pub fn generate_biased_walks_weighted_ref<G: WeightedGraphRef>(
43    graph: &G,
44    config: WeightedNode2VecPlusConfig,
45) -> Vec<Vec<usize>> {
46    generate_biased_walks_weighted_impl(graph, config, false)
47}
48
49pub fn generate_biased_walks_weighted_plus_ref<G: WeightedGraphRef>(
50    graph: &G,
51    config: WeightedNode2VecPlusConfig,
52) -> Vec<Vec<usize>> {
53    generate_biased_walks_weighted_impl(graph, config, true)
54}
55
56fn generate_biased_walks_weighted_impl<G: WeightedGraphRef>(
57    graph: &G,
58    config: WeightedNode2VecPlusConfig,
59    extend: bool,
60) -> Vec<Vec<usize>> {
61    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
62    let mut start_nodes: Vec<usize> = (0..graph.node_count()).collect();
63
64    let noise_thresholds = if extend {
65        compute_noise_thresholds(graph, config.gamma)
66    } else {
67        Vec::new()
68    };
69
70    let mut walks = Vec::with_capacity(graph.node_count() * config.walks_per_node);
71    for _ in 0..config.walks_per_node {
72        start_nodes.shuffle(&mut rng);
73        for &node in &start_nodes {
74            walks.push(weighted_walk(
75                graph,
76                node,
77                config,
78                extend,
79                &noise_thresholds,
80                &mut rng,
81            ));
82        }
83    }
84    walks
85}
86
87fn weighted_walk<G: WeightedGraphRef, R: Rng>(
88    graph: &G,
89    start: usize,
90    config: WeightedNode2VecPlusConfig,
91    extend: bool,
92    noise_thresholds: &[f32],
93    rng: &mut R,
94) -> Vec<usize> {
95    let mut walk = Vec::with_capacity(config.length);
96    walk.push(start);
97
98    let mut curr = start;
99    let mut prev: Option<usize> = None;
100    let mut buf: Vec<f32> = Vec::new();
101
102    for _ in 1..config.length {
103        let (nbrs, wts) = graph.neighbors_and_weights_ref(curr);
104        if nbrs.is_empty() {
105            break;
106        }
107        debug_assert_eq!(nbrs.len(), wts.len());
108
109        let next = if let Some(prev_idx) = prev {
110            if extend {
111                sample_next_node2vec_plus(
112                    graph,
113                    curr,
114                    prev_idx,
115                    nbrs,
116                    wts,
117                    config,
118                    noise_thresholds,
119                    &mut buf,
120                    rng,
121                )
122            } else {
123                sample_next_node2vec_weighted(graph, prev_idx, nbrs, wts, config, &mut buf, rng)
124            }
125        } else {
126            sample_cdf(rng, nbrs, wts)
127        };
128
129        walk.push(next);
130        prev = Some(curr);
131        curr = next;
132    }
133
134    walk
135}
136
137fn sample_next_node2vec_weighted<G: WeightedGraphRef, R: Rng>(
138    graph: &G,
139    prev: usize,
140    nbrs: &[usize],
141    wts: &[f32],
142    config: WeightedNode2VecPlusConfig,
143    buf: &mut Vec<f32>,
144    rng: &mut R,
145) -> usize {
146    fill_next_node2vec_weighted_buf(graph, prev, nbrs, wts, config, buf);
147    sample_cdf(rng, nbrs, buf)
148}
149
150fn fill_next_node2vec_weighted_buf<G: WeightedGraphRef>(
151    graph: &G,
152    prev: usize,
153    nbrs: &[usize],
154    wts: &[f32],
155    config: WeightedNode2VecPlusConfig,
156    buf: &mut Vec<f32>,
157) {
158    // Classic node2vec: out edges are neighbors(cur) that are not neighbors(prev).
159    let (prev_nbrs, _prev_wts) = graph.neighbors_and_weights_ref(prev);
160
161    buf.clear();
162    buf.extend_from_slice(wts);
163
164    // return bias
165    if let Some(i) = nbrs.iter().position(|&x| x == prev) {
166        buf[i] /= config.p;
167    }
168
169    for i in 0..nbrs.len() {
170        let x = nbrs[i];
171        if x == prev {
172            continue;
173        }
174        let is_common = prev_nbrs.contains(&x);
175        if !is_common {
176            buf[i] /= config.q;
177        }
178    }
179}
180
181#[allow(clippy::too_many_arguments)]
182fn sample_next_node2vec_plus<G: WeightedGraphRef, R: Rng>(
183    graph: &G,
184    cur: usize,
185    prev: usize,
186    nbrs: &[usize],
187    wts: &[f32],
188    config: WeightedNode2VecPlusConfig,
189    noise_thresholds: &[f32],
190    buf: &mut Vec<f32>,
191    rng: &mut R,
192) -> usize {
193    fill_next_node2vec_plus_buf(graph, cur, prev, nbrs, wts, config, noise_thresholds, buf);
194    sample_cdf(rng, nbrs, buf)
195}
196
197#[allow(clippy::too_many_arguments)]
198fn fill_next_node2vec_plus_buf<G: WeightedGraphRef>(
199    graph: &G,
200    cur: usize,
201    prev: usize,
202    nbrs: &[usize],
203    wts: &[f32],
204    config: WeightedNode2VecPlusConfig,
205    noise_thresholds: &[f32],
206    buf: &mut Vec<f32>,
207) {
208    // PecanPy semantics (SparseRWGraph.get_extended_normalized_probs):
209    // - Determine out edges via `isnotin_extended`.
210    // - alpha(out) = 1/q + (1 - 1/q) * t(out), where:
211    //   - t=0 for non-common neighbors
212    //   - t=w(prev,x)/threshold[x] for “loose common” edges (when w(prev,x) < threshold[x])
213    // - suppress: if w(cur,x) < threshold[cur], alpha = min(1, 1/q).
214
215    let (prev_nbrs, prev_wts) = graph.neighbors_and_weights_ref(prev);
216
217    buf.clear();
218    buf.extend_from_slice(wts);
219
220    // return bias
221    if let Some(i) = nbrs.iter().position(|&x| x == prev) {
222        buf[i] /= config.p;
223    }
224
225    let inv_q = 1.0 / config.q;
226    let thr_cur = noise_thresholds[cur];
227
228    for i in 0..nbrs.len() {
229        let x = nbrs[i];
230        if x == prev {
231            continue;
232        }
233
234        let mut is_out = true;
235        let mut t: f32 = 0.0;
236
237        if let Some(j) = prev_nbrs.iter().position(|&y| y == x) {
238            let thr_x = noise_thresholds[x];
239            let w_prev_x = prev_wts[j];
240            if thr_x > 0.0 && w_prev_x >= thr_x {
241                // strong common edge => in-edge
242                is_out = false;
243            } else if thr_x > 0.0 {
244                // loose common edge => out-edge with t in (0, 1)
245                t = (w_prev_x / thr_x).max(0.0);
246            }
247        }
248
249        if is_out {
250            let mut alpha = inv_q + (1.0 - inv_q) * t;
251            if buf[i] < thr_cur {
252                alpha = inv_q.min(1.0);
253            }
254            buf[i] *= alpha;
255        }
256    }
257}
258
259fn compute_noise_thresholds<G: WeightedGraphRef>(graph: &G, gamma: f32) -> Vec<f32> {
260    let n = graph.node_count();
261    let mut thr = vec![0.0f32; n];
262
263    for (v, thr_v) in thr.iter_mut().enumerate().take(n) {
264        let (_nbrs, wts) = graph.neighbors_and_weights_ref(v);
265        if wts.is_empty() {
266            *thr_v = 0.0;
267            continue;
268        }
269
270        let mean = wts.iter().copied().sum::<f32>() / (wts.len() as f32);
271        let var = wts
272            .iter()
273            .map(|&x| {
274                let d = x - mean;
275                d * d
276            })
277            .sum::<f32>()
278            / (wts.len() as f32);
279        let std = var.sqrt();
280
281        *thr_v = (mean + gamma * std).max(0.0);
282    }
283
284    thr
285}
286
287fn sample_cdf<R: Rng>(rng: &mut R, nbrs: &[usize], weights: &[f32]) -> usize {
288    debug_assert_eq!(nbrs.len(), weights.len());
289    if nbrs.len() == 1 {
290        return nbrs[0];
291    }
292
293    let sum = weights.iter().copied().sum::<f32>();
294    if !sum.is_finite() || sum <= 0.0 {
295        return *nbrs.choose(rng).unwrap();
296    }
297
298    let mut r = rng.random::<f32>() * sum;
299    for (i, &w) in weights.iter().enumerate() {
300        if r <= w {
301            return nbrs[i];
302        }
303        r -= w;
304    }
305    *nbrs.last().unwrap()
306}
307
/// Precomputed (PreComp mode) alias tables for classic, unweighted node2vec biased walks.
///
/// Node `cur` of degree `d` owns `d * d` consecutive table entries: one alias table of `d`
/// entries for each possible previous node, in the order of `cur`'s sorted neighbor list.
#[derive(Clone, Debug)]
pub struct PrecomputedBiasedWalks {
    /// Sorted neighbor list of every node.
    neighbors: Vec<Vec<usize>>,
    /// Degree of every node, which is also the width of each of its alias tables.
    alias_dim: Vec<u32>,
    /// Offsets into the flattened tables: node `v` owns `alias_indptr[v]..alias_indptr[v + 1]`.
    alias_indptr: Vec<u64>,
    /// Flattened alias targets.
    alias_j: Vec<u32>,
    /// Flattened acceptance probabilities.
    alias_q: Vec<f32>,
    /// Return parameter the tables were built with.
    p: f32,
    /// In-out parameter the tables were built with.
    q: f32,
}
319
320impl PrecomputedBiasedWalks {
321    pub fn new<G: GraphRef>(graph: &G, p: f32, q: f32) -> Self {
322        let n = graph.node_count();
323        let mut neighbors: Vec<Vec<usize>> = Vec::with_capacity(n);
324        let mut alias_dim: Vec<u32> = Vec::with_capacity(n);
325
326        for v in 0..n {
327            let mut nbrs = graph.neighbors_ref(v).to_vec();
328            nbrs.sort_unstable();
329            alias_dim.push(nbrs.len() as u32);
330            neighbors.push(nbrs);
331        }
332
333        let mut alias_indptr: Vec<u64> = vec![0; n + 1];
334        for i in 0..n {
335            let deg = alias_dim[i] as u64;
336            alias_indptr[i + 1] = alias_indptr[i] + deg * deg;
337        }
338        let total = alias_indptr[n] as usize;
339
340        let mut alias_j = vec![0u32; total];
341        let mut alias_q = vec![0.0f32; total];
342
343        let mut out_ind: Vec<bool> = Vec::new();
344        let mut probs: Vec<f32> = Vec::new();
345
346        for cur in 0..n {
347            let deg = alias_dim[cur] as usize;
348            if deg == 0 {
349                continue;
350            }
351            let offset = alias_indptr[cur] as usize;
352            let cur_nbrs = &neighbors[cur];
353
354            out_ind.clear();
355            out_ind.resize(deg, true);
356            probs.clear();
357            probs.resize(deg, 1.0);
358
359            for prev_j in 0..deg {
360                let prev = cur_nbrs[prev_j];
361                let prev_nbrs = &neighbors[prev];
362
363                mark_non_common(cur_nbrs, prev_nbrs, &mut out_ind);
364                out_ind[prev_j] = false; // exclude prev from out biases
365
366                probs.fill(1.0);
367                for i in 0..deg {
368                    if out_ind[i] {
369                        probs[i] /= q;
370                    }
371                }
372                probs[prev_j] /= p;
373
374                normalize_in_place(&mut probs);
375                let (j, qtab) = alias_setup(&probs);
376
377                let start = offset + deg * prev_j;
378                let end = start + deg;
379                alias_j[start..end].copy_from_slice(&j);
380                alias_q[start..end].copy_from_slice(&qtab);
381            }
382        }
383
384        Self {
385            neighbors,
386            alias_dim,
387            alias_indptr,
388            alias_j,
389            alias_q,
390            p,
391            q,
392        }
393    }
394}
395
396pub fn generate_biased_walks_precomp_ref(
397    pre: &PrecomputedBiasedWalks,
398    config: crate::random_walk::WalkConfig,
399) -> Vec<Vec<usize>> {
400    let start_nodes: Vec<usize> = (0..pre.neighbors.len()).collect();
401    generate_biased_walks_precomp_ref_from_nodes(pre, &start_nodes, config)
402}
403
404/// Precomputed node2vec biased walks, restricted to an explicit set of start nodes.
405///
406/// This is the “delta walk” primitive for PreComp mode: generate new walks only for the
407/// subset of nodes whose neighborhood changed (dynamic graphs), or for sharding.
408pub fn generate_biased_walks_precomp_ref_from_nodes(
409    pre: &PrecomputedBiasedWalks,
410    start_nodes: &[usize],
411    config: crate::random_walk::WalkConfig,
412) -> Vec<Vec<usize>> {
413    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
414        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
415    }
416
417    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
418    let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
419    let mut walks = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
420
421    for _ in 0..config.walks_per_node {
422        epoch_nodes.shuffle(&mut rng);
423        for &node in &epoch_nodes {
424            walks.push(biased_walk_precomp(pre, node, config.length, &mut rng));
425        }
426    }
427
428    walks
429}
430
/// Deterministic parallel precomputed node2vec biased walks (delta/sharded start nodes).
///
/// Invariant: output is stable for a fixed `seed`, independent of Rayon thread count.
///
/// The job list is built sequentially first. For each of the `config.walks_per_node` epochs,
/// `start_nodes` is shuffled with an RNG seeded from `(seed, epoch)` and appended to the list.
/// Job `i` then runs on its own ChaCha8 RNG, seeded with the `i`-th SplitMix64 output for
/// `seed`. Every job therefore gets a distinct stream, and results never depend on scheduling.
///
/// # Panics
/// Panics if `config.p` or `config.q` differs by more than `1e-6` from the value `pre` was
/// built with.
#[cfg(feature = "parallel")]
pub fn generate_biased_walks_precomp_ref_parallel_from_nodes(
    pre: &PrecomputedBiasedWalks,
    start_nodes: &[usize],
    config: crate::random_walk::WalkConfig,
) -> Vec<Vec<usize>> {
    use rayon::prelude::*;

    // SplitMix64 finalizer, a bijection on u64. It is kept local to avoid exposing
    // `random_walk::mix64` publicly.
    fn mix64(mut x: u64) -> u64 {
        x ^= x >> 30;
        x = x.wrapping_mul(0xbf58476d1ce4e5b9);
        x ^= x >> 27;
        x = x.wrapping_mul(0x94d049bb133111eb);
        x ^= x >> 31;
        x
    }

    // SplitMix64's Weyl-sequence increment. It is odd, so `k * GOLDEN_GAMMA` is injective
    // mod 2^64.
    const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;

    if (pre.p - config.p).abs() > 1e-6 || (pre.q - config.q).abs() > 1e-6 {
        panic!("PrecomputedBiasedWalks p/q do not match WalkConfig");
    }

    // Copy start nodes once; shuffle per epoch using a seed that depends only on (seed, epoch).
    let mut epoch_nodes: Vec<usize> = start_nodes.to_vec();
    let mut jobs: Vec<usize> = Vec::with_capacity(start_nodes.len() * config.walks_per_node);
    for epoch in 0..(config.walks_per_node as u32) {
        let mut rng = ChaCha8Rng::seed_from_u64(mix64(config.seed ^ (epoch as u64)));
        epoch_nodes.shuffle(&mut rng);
        jobs.extend_from_slice(&epoch_nodes);
    }

    jobs.par_iter()
        .enumerate()
        .map(|(i, &node)| {
            // Seed from the job index alone. `seed + (i + 1) * GOLDEN_GAMMA` is injective in
            // `i`, and `mix64` is a bijection, so no two jobs share an RNG stream. XOR-ing node
            // ids with job indices instead would collide whenever
            // `node_a ^ i_a == node_b ^ i_b`.
            let state = config
                .seed
                .wrapping_add((i as u64).wrapping_add(1).wrapping_mul(GOLDEN_GAMMA));
            let mut rng = ChaCha8Rng::seed_from_u64(mix64(state));
            biased_walk_precomp(pre, node, config.length, &mut rng)
        })
        .collect()
}
486
487fn biased_walk_precomp<R: Rng>(
488    pre: &PrecomputedBiasedWalks,
489    start: usize,
490    length: usize,
491    rng: &mut R,
492) -> Vec<usize> {
493    let mut walk = Vec::with_capacity(length);
494    walk.push(start);
495    let mut curr = start;
496    let mut prev: Option<usize> = None;
497
498    for _ in 1..length {
499        let nbrs = &pre.neighbors[curr];
500        if nbrs.is_empty() {
501            break;
502        }
503
504        let next = if let Some(p) = prev {
505            sample_precomp(pre, curr, p, rng)
506        } else {
507            *nbrs.choose(rng).unwrap()
508        };
509
510        walk.push(next);
511        prev = Some(curr);
512        curr = next;
513    }
514
515    walk
516}
517
518fn sample_precomp<R: Rng>(
519    pre: &PrecomputedBiasedWalks,
520    cur: usize,
521    prev: usize,
522    rng: &mut R,
523) -> usize {
524    let nbrs = &pre.neighbors[cur];
525    let deg = pre.alias_dim[cur] as usize;
526    let prev_j = match nbrs.binary_search(&prev) {
527        Ok(i) => i,
528        Err(_) => {
529            // This can happen on directed / non-reciprocal graphs: we might have walked
530            // from `prev -> cur`, but `cur` may not have `prev` in its neighbor list.
531            // PecanPy prints "FATAL ERROR! Neighbor not found." in this situation.
532            //
533            // In Rust, we choose a safe fallback that preserves determinism and avoids
534            // returning a nonsense index: fall back to a 1st-order uniform step.
535            return *nbrs.choose(rng).unwrap();
536        }
537    };
538
539    let offset = pre.alias_indptr[cur] + (deg as u64) * (prev_j as u64);
540    let start = offset as usize;
541    let end = start + deg;
542
543    let choice = alias_draw(&pre.alias_j[start..end], &pre.alias_q[start..end], rng);
544    nbrs[choice]
545}
546
/// Rescales `x` in place so it sums to one. Leaves `x` untouched when its sum is not positive.
fn normalize_in_place(x: &mut [f32]) {
    let total: f32 = x.iter().sum();
    if total > 0.0 {
        x.iter_mut().for_each(|v| *v /= total);
    }
}
555
/// For each entry of `cur`, sets the matching slot of `out` to `true` if the entry is absent
/// from `prev` (not a common neighbor) and to `false` otherwise.
///
/// Both `cur` and `prev` must be sorted ascending. Runs in a single linear merge pass.
fn mark_non_common(cur: &[usize], prev: &[usize], out: &mut [bool]) {
    debug_assert_eq!(cur.len(), out.len());
    let mut remaining = prev;
    for (i, &x) in cur.iter().enumerate() {
        // Drop every `prev` entry smaller than `x`; `remaining` stays sorted.
        let skipped = remaining.iter().take_while(|&&y| y < x).count();
        remaining = &remaining[skipped..];
        out[i] = remaining.first() != Some(&x);
    }
}
566
/// Builds a Walker/Vose alias table for `probs` so that `alias_draw` can sample in O(1).
///
/// Returns `(j, q)`: column `kk` is kept with probability `q[kk]`, otherwise its alias
/// `j[kk]` is used. `probs` is expected to sum to one; callers normalize first.
///
/// The pairing loop mirrors PecanPy's implementation (same stack discipline), which reduces
/// drift when comparing outputs. After pairing, every column still on either stack gets
/// `q = 1`, as in Vose's method. With exact arithmetic those columns would be exactly full,
/// so any deviation is rounding. Leaving it in place would send the missing mass to the
/// default alias `0`, which may be a zero-probability column. As a consequence, an all-zero
/// input yields a uniform table.
fn alias_setup(probs: &[f32]) -> (Vec<u32>, Vec<f32>) {
    // References:
    // - Walker (1974): An efficient method for generating discrete random variables with general distributions.
    // - Vose (1991): A linear algorithm for generating random numbers with a given distribution.
    // - PecanPy (software reference implementation): https://github.com/krishnanlab/PecanPy
    let k = probs.len();
    let mut q = vec![0.0f32; k];
    let mut j = vec![0u32; k];

    let mut smaller: Vec<usize> = Vec::with_capacity(k);
    let mut larger: Vec<usize> = Vec::with_capacity(k);

    for (kk, &prob) in probs.iter().enumerate() {
        q[kk] = (k as f32) * prob;
        if q[kk] < 1.0 {
            smaller.push(kk);
        } else {
            larger.push(kk);
        }
    }

    // Pair one under-full column with one over-full column at a time. Peek before popping so
    // that nothing is discarded when one stack runs dry.
    while let (Some(&small), Some(&large)) = (smaller.last(), larger.last()) {
        smaller.pop();
        larger.pop();
        j[small] = large as u32;
        q[large] = q[large] + q[small] - 1.0;
        if q[large] < 1.0 {
            smaller.push(large);
        } else {
            larger.push(large);
        }
    }

    // Leftover columns are full up to rounding.
    for idx in smaller.into_iter().chain(larger) {
        q[idx] = 1.0;
    }

    (j, q)
}
606
607fn alias_draw<R: Rng>(j: &[u32], q: &[f32], rng: &mut R) -> usize {
608    debug_assert_eq!(j.len(), q.len());
609    let k = j.len();
610    let kk = rng.random_range(0..k);
611    if rng.random::<f32>() < q[kk] {
612        kk
613    } else {
614        j[kk] as usize
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal unweighted adjacency-list graph; neighbor lists are sorted on construction.
    #[derive(Debug, Clone)]
    struct RefAdj {
        adj: Vec<Vec<usize>>,
    }

    impl RefAdj {
        fn new(mut adj: Vec<Vec<usize>>) -> Self {
            for nbrs in &mut adj {
                nbrs.sort_unstable();
            }
            Self { adj }
        }
    }

    impl GraphRef for RefAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_ref(&self, node: usize) -> &[usize] {
            self.adj.get(node).map(Vec::as_slice).unwrap_or(&[])
        }
    }

    /// Minimal weighted adjacency-list graph. Each node's (neighbor, weight) pairs are sorted
    /// by neighbor id on construction.
    #[derive(Debug, Clone)]
    struct RefWeightedAdj {
        adj: Vec<Vec<usize>>,
        wts: Vec<Vec<f32>>,
    }

    impl RefWeightedAdj {
        fn new(mut adj: Vec<Vec<usize>>, mut wts: Vec<Vec<f32>>) -> Self {
            assert_eq!(adj.len(), wts.len());
            for i in 0..adj.len() {
                assert_eq!(adj[i].len(), wts[i].len());
                let mut pairs: Vec<(usize, f32)> =
                    adj[i].iter().copied().zip(wts[i].iter().copied()).collect();
                pairs.sort_by_key(|(n, _)| *n);
                adj[i] = pairs.iter().map(|(n, _)| *n).collect();
                wts[i] = pairs.iter().map(|(_, w)| *w).collect();
            }
            Self { adj, wts }
        }
    }

    impl WeightedGraphRef for RefWeightedAdj {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors_and_weights_ref(&self, node: usize) -> (&[usize], &[f32]) {
            let nbrs = self.adj.get(node).map(Vec::as_slice).unwrap_or(&[]);
            let wts = self.wts.get(node).map(Vec::as_slice).unwrap_or(&[]);
            (nbrs, wts)
        }
    }

    fn assert_close_f32(a: f32, b: f32, eps: f32) {
        assert!(
            (a - b).abs() <= eps,
            "expected |{a} - {b}| <= {eps}, got {}",
            (a - b).abs()
        );
    }

    #[test]
    fn alias_tables_match_expected_for_line_graph() {
        // Graph: 0 -- 1 -- 2
        // Using p=0.5, q=2.0, when at cur=1 coming from prev=0:
        // weights are [1/p, 1/q] => [2.0, 0.5] => normalized [0.8, 0.2]
        let g = RefAdj::new(vec![vec![1], vec![0, 2], vec![1]]);
        let pre = PrecomputedBiasedWalks::new(&g, 0.5, 2.0);

        assert_eq!(pre.alias_dim, vec![1, 2, 1]);
        assert_eq!(pre.alias_indptr, vec![0, 1, 5, 6]);

        // cur = 1, neighbors = [0, 2], deg=2.
        // prev=0 corresponds to prev_j=0, slice is offset=alias_indptr[1]=1, start=1, end=3.
        let j01 = &pre.alias_j[1..3];
        let q01 = &pre.alias_q[1..3];
        assert_eq!(j01, &[0u32, 0u32]);
        assert_close_f32(q01[0], 1.0, 1e-6);
        assert_close_f32(q01[1], 0.4, 1e-6);

        // prev=2 corresponds to prev_j=1, start=3, end=5.
        let j21 = &pre.alias_j[3..5];
        let q21 = &pre.alias_q[3..5];
        assert_eq!(j21, &[1u32, 0u32]);
        assert_close_f32(q21[0], 0.4, 1e-6);
        assert_close_f32(q21[1], 1.0, 1e-6);
    }

    #[test]
    fn noise_thresholds_match_mean_plus_gamma_std() {
        // tau = max(0, mean + gamma * std), with the population std.
        //
        // Single weight [1.0]: std = 0, so tau = mean = 1 regardless of gamma.
        let g = RefWeightedAdj::new(vec![vec![0]], vec![vec![1.0]]);
        let thr0 = compute_noise_thresholds(&g, 2.0);
        assert_eq!(thr0.len(), 1);
        assert_close_f32(thr0[0], 1.0, 1e-6);

        // Two weights [1, 3]: mean = 2, std = 1, so tau = 2 + 2 * 1 = 4.
        // Only the weights matter for the threshold, so node 0 may list a neighbor id (1)
        // that does not exist in this one-node graph.
        let g2 = RefWeightedAdj::new(vec![vec![0, 1]], vec![vec![1.0, 3.0]]);
        let thr2 = compute_noise_thresholds(&g2, 2.0);
        assert_eq!(thr2.len(), 1);
        assert_close_f32(thr2[0], 4.0, 1e-6);
    }

    #[test]
    fn node2vec_plus_suppress_caps_inv_q_when_q_lt_1() {
        // Construct a situation where:
        // - q < 1 (so inv_q > 1 would amplify out-edges in classic node2vec)
        // - node2vec+ suppresses that amplification for “noisy” edges with
        //   w(cur, x) < threshold[cur].
        //
        // Graph: 0 -- 1 -- 2 (weighted, symmetric for existence, but asymmetric weights at node 1)
        // At cur=1 coming from prev=0, candidate x=2 is an out-edge.
        let g = RefWeightedAdj::new(
            vec![vec![1], vec![0, 2], vec![1]],
            vec![vec![1.0], vec![1.0, 0.9], vec![1.0]],
        );

        let cfg = WeightedNode2VecPlusConfig {
            length: 3,
            walks_per_node: 1,
            p: 1.0,
            q: 0.5,     // inv_q = 2.0
            gamma: 0.0, // threshold is mean
            seed: 0,
        };

        let thr = compute_noise_thresholds(&g, cfg.gamma);
        assert_eq!(thr.len(), 3);
        // For node 1: weights [1.0, 0.9], mean=0.95 => threshold=0.95
        assert_close_f32(thr[1], 0.95, 1e-6);

        let (nbrs, wts) = g.neighbors_and_weights_ref(1);
        assert_eq!(nbrs, &[0, 2]);
        assert_eq!(wts, &[1.0, 0.9]);

        let mut buf_weighted = Vec::new();
        let mut buf_plus = Vec::new();

        fill_next_node2vec_weighted_buf(&g, 0, nbrs, wts, cfg, &mut buf_weighted);
        fill_next_node2vec_plus_buf(&g, 1, 0, nbrs, wts, cfg, &thr, &mut buf_plus);

        // For x=2 (out-edge) classic weighted node2vec divides by q (q=0.5 => multiply by 2).
        assert_close_f32(buf_weighted[1], 1.8, 1e-6);

        // For node2vec+, since w(cur,2)=0.9 < threshold[cur]=0.95 and inv_q>1,
        // suppress caps alpha at 1.0, so out-edge stays at 0.9.
        assert_close_f32(buf_plus[1], 0.9, 1e-6);
    }

    #[test]
    fn mark_non_common_flags_entries_missing_from_prev() {
        // Sorted inputs; 2 and 9 are shared, 1 and 5 are not.
        let mut out = vec![false; 4];
        mark_non_common(&[1, 2, 5, 9], &[0, 2, 3, 9, 10], &mut out);
        assert_eq!(out, vec![true, false, true, false]);
    }

    #[test]
    fn alias_draw_distribution_smoke() {
        // Deterministic chi-squared smoke test: catches egregious alias bugs
        // without being overly sensitive/flaky.
        //
        // Distribution: [0.1, 0.2, 0.7]
        let probs = vec![0.1f32, 0.2f32, 0.7f32];
        let (j, q) = alias_setup(&probs);

        let trials = 20_000usize;
        let mut counts = [0usize; 3];
        for t in 0..trials {
            let mut rng = ChaCha8Rng::seed_from_u64(t as u64);
            let k = alias_draw(&j, &q, &mut rng);
            counts[k] += 1;
        }

        let expected = [
            trials as f64 * 0.1,
            trials as f64 * 0.2,
            trials as f64 * 0.7,
        ];
        let chi2: f64 = counts
            .iter()
            .zip(expected.iter())
            .map(|(&c, &e)| {
                let diff = c as f64 - e;
                (diff * diff) / e
            })
            .sum();

        // df = 2; E[chi2] ~ 2, Var ~ 4. Use a very conservative cutoff.
        assert!(
            chi2 < 50.0,
            "chi2 too large (chi2={chi2:.2}). counts={counts:?} expected={expected:?}"
        );
    }
}