Skip to main content

scirs2_graph/
sampling.rs

1//! Graph Sampling Algorithms
2//!
3//! This module provides a comprehensive suite of graph sampling methods including:
4//!
5//! - **Random walks**: Uniform random walk, Node2Vec biased random walk
6//! - **Graph sampling**: Frontier sampling, forest-fire sampling, snowball sampling
7//! - **Subgraph operations**: Induced subgraph extraction
8//!
9//! All algorithms operate on adjacency-list representations for efficiency.
10//!
11//! ## References
12//! - Leskovec & Faloutsos (2006): Sampling from Large Graphs. KDD 2006.
13//! - Grover & Leskovec (2016): node2vec: Scalable Feature Learning for Networks. KDD 2016.
14//! - Stumpf et al. (2005): Subnets of scale-free networks are not scale-free. PNAS.
15
16use std::collections::{HashMap, HashSet, VecDeque};
17
18use crate::error::{GraphError, Result};
19
20/// Return type for [`induced_subgraph`]: `(subgraph_adjacency, original_indices)`.
21type InducedSubgraphResult = (Vec<Vec<(usize, f64)>>, Vec<usize>);
22
23// ─────────────────────────────────────────────────────────────────────────────
24// Minimal LCG-based PRNG (avoids external rand dependency)
25// ─────────────────────────────────────────────────────────────────────────────
26
/// A fast, seedable linear-congruential pseudo-random number generator.
///
/// Uses the multiplier and increment of Knuth's MMIX 64-bit LCG.  Because the
/// modulus is 2^64, the low-order bits of the raw state have very short
/// periods (the lowest bit simply alternates 0, 1, 0, 1, …).  Every derived
/// value is therefore taken from the *high-order* bits of the state.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Create a generator whose output sequence is fully determined by `seed`.
    fn new(seed: u64) -> Self {
        // Mix in a constant so that seed = 0 does not start from an all-zero state.
        Self {
            state: seed.wrapping_add(6364136223846793005),
        }
    }

    /// Advance the state and return the next raw u64.
    ///
    /// Callers that need a bounded value should use [`Lcg::next_usize`] or
    /// [`Lcg::next_f64`] rather than reducing this value themselves, since the
    /// low bits are of poor quality.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state
    }

    /// Return a uniform f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        // Use the upper 53 bits for the mantissa.
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Return a (near-)uniform usize in `0..n` (exclusive).
    ///
    /// # Panics
    /// Panics if `n == 0`, in debug and release builds alike.
    fn next_usize(&mut self, n: usize) -> usize {
        assert!(n > 0, "n must be > 0");
        // Multiply-shift range reduction: floor(x * n / 2^64).  Unlike `x % n`
        // this is driven by the high bits of `x`, so small or power-of-two `n`
        // do not expose the short-period low bits of the LCG.  The result is
        // always < n, hence the final cast back to usize is lossless.
        ((u128::from(self.next_u64()) * n as u128) >> 64) as usize
    }
}
63
64// ─────────────────────────────────────────────────────────────────────────────
65// Random Walk
66// ─────────────────────────────────────────────────────────────────────────────
67
68/// Perform a uniform random walk on an unweighted graph.
69///
70/// Starting from `start_node`, at each step a uniformly random neighbour is
71/// chosen.  If the current node has no neighbours the walk terminates early.
72///
73/// # Parameters
74/// - `adjacency`   – adjacency list (unweighted); `adjacency[u]` contains the
75///   neighbours of node `u`.
76/// - `start_node`  – index of the walk's first node.
77/// - `walk_length` – desired total number of nodes in the walk (including the
78///   starting node).
79/// - `rng_seed`    – seed for the internal pseudo-random number generator.
80///
81/// # Returns
82/// A `Vec<usize>` of length ≤ `walk_length` with the visited node sequence.
83///
84/// # Errors
85/// Returns [`GraphError::InvalidParameter`] if `start_node` is out of range.
86pub fn random_walk(
87    adjacency: &[Vec<usize>],
88    start_node: usize,
89    walk_length: usize,
90    rng_seed: u64,
91) -> Result<Vec<usize>> {
92    let n = adjacency.len();
93    if start_node >= n {
94        return Err(GraphError::invalid_parameter(
95            "start_node",
96            start_node,
97            format!("must be < n_nodes ({})", n),
98        ));
99    }
100    if walk_length == 0 {
101        return Ok(Vec::new());
102    }
103
104    let mut rng = Lcg::new(rng_seed);
105    let mut walk = Vec::with_capacity(walk_length);
106    walk.push(start_node);
107
108    let mut current = start_node;
109    for _ in 1..walk_length {
110        let neighbours = &adjacency[current];
111        if neighbours.is_empty() {
112            break;
113        }
114        current = neighbours[rng.next_usize(neighbours.len())];
115        walk.push(current);
116    }
117
118    Ok(walk)
119}
120
121// ─────────────────────────────────────────────────────────────────────────────
122// Node2Vec Biased Random Walk
123// ─────────────────────────────────────────────────────────────────────────────
124
125/// Perform a Node2Vec biased random walk on a weighted graph.
126///
127/// Node2Vec generalises DeepWalk by interpolating between BFS-like (p<1) and
128/// DFS-like (q<1) exploration using the *return* parameter `p` and the
129/// *in-out* parameter `q`.
130///
131/// The transition probability from node `v` to neighbour `x` (when the
132/// previous node was `t`) is proportional to:
133/// - `1/p` if `x == t`  (backtrack)
134/// - `1`   if `x` is also a neighbour of `t`  (same distance)
135/// - `1/q` otherwise    (explore further)
136///
137/// # Parameters
138/// - `adjacency`   – weighted adjacency list; `adjacency[u]` is a list of
139///   `(neighbour, weight)` pairs.
140/// - `start_node`  – starting node index.
141/// - `walk_length` – desired walk length (≥ 1).
142/// - `p`           – return parameter (> 0). Higher values discourage backtracking.
143/// - `q`           – in-out parameter (> 0). < 1 favours DFS-like walks; > 1
144///   favours BFS-like walks.
145/// - `rng_seed`    – PRNG seed.
146///
147/// # Errors
148/// Returns [`GraphError::InvalidParameter`] for out-of-range `start_node` or
149/// non-positive `p`/`q`.
150pub fn node2vec_walk(
151    adjacency: &[Vec<(usize, f64)>],
152    start_node: usize,
153    walk_length: usize,
154    p: f64,
155    q: f64,
156    rng_seed: u64,
157) -> Result<Vec<usize>> {
158    let n = adjacency.len();
159    if start_node >= n {
160        return Err(GraphError::invalid_parameter(
161            "start_node",
162            start_node,
163            format!("must be < n_nodes ({})", n),
164        ));
165    }
166    if p <= 0.0 {
167        return Err(GraphError::invalid_parameter(
168            "p",
169            p,
170            "must be strictly positive",
171        ));
172    }
173    if q <= 0.0 {
174        return Err(GraphError::invalid_parameter(
175            "q",
176            q,
177            "must be strictly positive",
178        ));
179    }
180    if walk_length == 0 {
181        return Ok(Vec::new());
182    }
183
184    // Pre-build a fast neighbour-set lookup for bias computation.
185    // neighbour_set[u] is the set of indices adjacent to u.
186    let neighbour_sets: Vec<HashSet<usize>> = adjacency
187        .iter()
188        .map(|nbrs| nbrs.iter().map(|&(v, _)| v).collect())
189        .collect();
190
191    let mut rng = Lcg::new(rng_seed);
192    let mut walk: Vec<usize> = Vec::with_capacity(walk_length);
193    walk.push(start_node);
194
195    // First step: uniform over neighbours (no previous node).
196    if walk_length == 1 || adjacency[start_node].is_empty() {
197        return Ok(walk);
198    }
199    let first_idx = rng.next_usize(adjacency[start_node].len());
200    let first_next = adjacency[start_node][first_idx].0;
201    walk.push(first_next);
202
203    // Subsequent steps: biased by p and q relative to previous node.
204    for _ in 2..walk_length {
205        let prev = walk[walk.len() - 2];
206        let curr = walk[walk.len() - 1];
207
208        let nbrs = &adjacency[curr];
209        if nbrs.is_empty() {
210            break;
211        }
212
213        // Compute unnormalised weights for each candidate.
214        let prev_set = &neighbour_sets[prev];
215        let weights: Vec<f64> = nbrs
216            .iter()
217            .map(|&(x, edge_w)| {
218                let bias = if x == prev {
219                    1.0 / p
220                } else if prev_set.contains(&x) {
221                    1.0
222                } else {
223                    1.0 / q
224                };
225                (edge_w.max(0.0)) * bias
226            })
227            .collect();
228
229        let total: f64 = weights.iter().sum();
230        let next_node = if total <= 0.0 {
231            // Fallback to uniform if all weights are zero.
232            nbrs[rng.next_usize(nbrs.len())].0
233        } else {
234            let threshold = rng.next_f64() * total;
235            let mut cumulative = 0.0;
236            let mut chosen = nbrs.last().map(|&(v, _)| v).unwrap_or(curr);
237            for (idx, &w) in weights.iter().enumerate() {
238                cumulative += w;
239                if cumulative >= threshold {
240                    chosen = nbrs[idx].0;
241                    break;
242                }
243            }
244            chosen
245        };
246
247        walk.push(next_node);
248    }
249
250    Ok(walk)
251}
252
253// ─────────────────────────────────────────────────────────────────────────────
254// Frontier Sampling
255// ─────────────────────────────────────────────────────────────────────────────
256
257/// Frontier-based graph sampling.
258///
259/// Maintains a *frontier* set of nodes and at each step:
260/// 1. Picks a random frontier node `u`.
261/// 2. Picks a random neighbour `v` of `u`.
262/// 3. If `v` is not yet sampled, adds it to the sample and the frontier; if
263///    `v` already sampled, reinserts `u` into the frontier (Frontier Sampling
264///    per Stumpf et al. / Leskovec & Faloutsos 2006).
265///
266/// Frontier sampling preserves degree distribution better than naive random
267/// node or random edge sampling.
268///
269/// # Parameters
270/// - `adjacency`   – unweighted adjacency list.
271/// - `n_nodes`     – total number of nodes (= `adjacency.len()`).
272/// - `sample_size` – desired number of nodes in the sample.
273/// - `rng_seed`    – PRNG seed.
274///
275/// # Returns
276/// Sorted `Vec<usize>` of sampled node indices (length ≤ `sample_size`).
277///
278/// # Errors
279/// Returns [`GraphError::InvalidParameter`] if `n_nodes` is 0 or
280/// `sample_size > n_nodes`.
281pub fn frontier_sampling(
282    adjacency: &[Vec<usize>],
283    n_nodes: usize,
284    sample_size: usize,
285    rng_seed: u64,
286) -> Result<Vec<usize>> {
287    if n_nodes == 0 {
288        return Err(GraphError::invalid_parameter(
289            "n_nodes",
290            0usize,
291            "must be > 0",
292        ));
293    }
294    if sample_size > n_nodes {
295        return Err(GraphError::invalid_parameter(
296            "sample_size",
297            sample_size,
298            format!("must be ≤ n_nodes ({})", n_nodes),
299        ));
300    }
301    if sample_size == 0 {
302        return Ok(Vec::new());
303    }
304
305    let mut rng = Lcg::new(rng_seed);
306    let mut sampled: HashSet<usize> = HashSet::with_capacity(sample_size);
307    let mut frontier: Vec<usize> = Vec::new();
308
309    // Seed with a random starting node.
310    let seed = rng.next_usize(n_nodes);
311    sampled.insert(seed);
312    frontier.push(seed);
313
314    let mut iters = 0usize;
315    let max_iters = sample_size * n_nodes.max(100) * 10;
316
317    while sampled.len() < sample_size && !frontier.is_empty() && iters < max_iters {
318        iters += 1;
319        // Pick random frontier node.
320        let fi = rng.next_usize(frontier.len());
321        let u = frontier[fi];
322
323        let nbrs = &adjacency[u];
324        if nbrs.is_empty() {
325            // Dead-end: remove u from frontier.
326            frontier.swap_remove(fi);
327            continue;
328        }
329
330        let v = nbrs[rng.next_usize(nbrs.len())];
331        if sampled.insert(v) {
332            // New node: add to frontier.
333            frontier.push(v);
334        }
335        // Whether new or not, keep u in frontier (it may have other unvisited neighbours).
336    }
337
338    // If graph is disconnected and we haven't reached sample_size, inject random unsampled nodes.
339    if sampled.len() < sample_size {
340        for candidate in 0..n_nodes {
341            if sampled.len() >= sample_size {
342                break;
343            }
344            sampled.insert(candidate);
345        }
346    }
347
348    let mut result: Vec<usize> = sampled.into_iter().collect();
349    result.sort_unstable();
350    Ok(result)
351}
352
353// ─────────────────────────────────────────────────────────────────────────────
354// Forest-Fire Sampling
355// ─────────────────────────────────────────────────────────────────────────────
356
357/// Forest-fire graph sampling.
358///
359/// Mimics a "fire spreading" process: from each burning node, a geometrically
360/// distributed number of unvisited neighbours are "burned" with forward
361/// probability `forward_prob`.  The process regenerates from a new random seed
362/// when all fires die out.
363///
364/// Forest-fire sampling is known to preserve heavy-tail degree distributions
365/// and densification patterns (Leskovec et al. 2005).
366///
367/// # Parameters
368/// - `adjacency`    – unweighted adjacency list.
369/// - `n_nodes`      – total number of nodes.
370/// - `sample_size`  – target number of sampled nodes.
371/// - `forward_prob` – probability of spreading to each neighbour (0 < p < 1).
372/// - `rng_seed`     – PRNG seed.
373///
374/// # Errors
375/// Returns [`GraphError::InvalidParameter`] for invalid inputs.
376pub fn forest_fire_sampling(
377    adjacency: &[Vec<usize>],
378    n_nodes: usize,
379    sample_size: usize,
380    forward_prob: f64,
381    rng_seed: u64,
382) -> Result<Vec<usize>> {
383    if n_nodes == 0 {
384        return Err(GraphError::invalid_parameter(
385            "n_nodes",
386            0usize,
387            "must be > 0",
388        ));
389    }
390    if sample_size > n_nodes {
391        return Err(GraphError::invalid_parameter(
392            "sample_size",
393            sample_size,
394            format!("must be ≤ n_nodes ({})", n_nodes),
395        ));
396    }
397    if forward_prob <= 0.0 || forward_prob >= 1.0 {
398        return Err(GraphError::invalid_parameter(
399            "forward_prob",
400            forward_prob,
401            "must be in (0, 1)",
402        ));
403    }
404    if sample_size == 0 {
405        return Ok(Vec::new());
406    }
407
408    let mut rng = Lcg::new(rng_seed);
409    let mut sampled: HashSet<usize> = HashSet::with_capacity(sample_size);
410    // Queue of currently-burning nodes.
411    let mut burning: VecDeque<usize> = VecDeque::new();
412
413    // Helper: geometric-distributed number of links to burn.
414    // Draw from Geometric(1 - forward_prob): # of successes before first failure.
415    let geometric_draw = |rng: &mut Lcg| -> usize {
416        let mut count = 0usize;
417        while rng.next_f64() < forward_prob {
418            count += 1;
419        }
420        count
421    };
422
423    while sampled.len() < sample_size {
424        // Light a new fire from a random unsampled node.
425        if burning.is_empty() {
426            // Find an unsampled node.
427            let start = rng.next_usize(n_nodes);
428            let mut found = false;
429            for offset in 0..n_nodes {
430                let candidate = (start + offset) % n_nodes;
431                if sampled.insert(candidate) {
432                    burning.push_back(candidate);
433                    found = true;
434                    break;
435                }
436            }
437            if !found {
438                break; // All nodes sampled.
439            }
440        }
441
442        // Spread the fire.
443        while let Some(u) = burning.pop_front() {
444            if sampled.len() >= sample_size {
445                break;
446            }
447            let nbrs = &adjacency[u];
448            if nbrs.is_empty() {
449                continue;
450            }
451
452            // Number of neighbours to burn (capped by available).
453            let n_burn = geometric_draw(&mut rng).min(nbrs.len());
454            if n_burn == 0 {
455                continue;
456            }
457
458            // Pick n_burn distinct unsampled neighbours (reservoir sample).
459            // Shuffle first n_burn positions of a candidate list.
460            let mut candidates: Vec<usize> = nbrs.clone();
461            for i in 0..n_burn {
462                let j = i + rng.next_usize(candidates.len() - i);
463                candidates.swap(i, j);
464            }
465            for &v in candidates.iter().take(n_burn) {
466                if sampled.len() >= sample_size {
467                    break;
468                }
469                if sampled.insert(v) {
470                    burning.push_back(v);
471                }
472            }
473        }
474    }
475
476    let mut result: Vec<usize> = sampled.into_iter().collect();
477    result.sort_unstable();
478    Ok(result)
479}
480
481// ─────────────────────────────────────────────────────────────────────────────
482// Snowball Sampling
483// ─────────────────────────────────────────────────────────────────────────────
484
485/// Snowball (BFS-neighbourhood) sampling.
486///
487/// Starting from the given `seed_nodes`, collects all nodes reachable within
488/// `n_hops` hops.  This is equivalent to an ego-network expansion.
489///
490/// # Parameters
491/// - `adjacency`  – unweighted adjacency list.
492/// - `seed_nodes` – starting node indices.
493/// - `n_hops`     – number of BFS expansion steps (0 = seed nodes only).
494///
495/// # Returns
496/// Sorted `Vec<usize>` of all nodes within `n_hops` hops of any seed node.
497///
498/// # Errors
499/// Returns [`GraphError::InvalidParameter`] if any seed node index is
500/// out of range or if the adjacency list is empty.
501pub fn snowball_sampling(
502    adjacency: &[Vec<usize>],
503    seed_nodes: &[usize],
504    n_hops: usize,
505) -> Result<Vec<usize>> {
506    let n = adjacency.len();
507    if n == 0 {
508        return Err(GraphError::invalid_parameter(
509            "adjacency",
510            "empty",
511            "graph must have at least one node",
512        ));
513    }
514    for &s in seed_nodes {
515        if s >= n {
516            return Err(GraphError::invalid_parameter(
517                "seed_node",
518                s,
519                format!("must be < n_nodes ({})", n),
520            ));
521        }
522    }
523
524    let mut visited: HashSet<usize> = seed_nodes.iter().cloned().collect();
525    let mut frontier: Vec<usize> = seed_nodes.to_vec();
526
527    for _ in 0..n_hops {
528        let mut next_frontier: Vec<usize> = Vec::new();
529        for &u in &frontier {
530            for &v in &adjacency[u] {
531                if visited.insert(v) {
532                    next_frontier.push(v);
533                }
534            }
535        }
536        if next_frontier.is_empty() {
537            break;
538        }
539        frontier = next_frontier;
540    }
541
542    let mut result: Vec<usize> = visited.into_iter().collect();
543    result.sort_unstable();
544    Ok(result)
545}
546
547// ─────────────────────────────────────────────────────────────────────────────
548// Induced Subgraph
549// ─────────────────────────────────────────────────────────────────────────────
550
551/// Extract the induced subgraph on a set of nodes.
552///
553/// Given a weighted adjacency list and a set of node indices, returns:
554/// - A new weighted adjacency list on the *re-indexed* subgraph (nodes are
555///   re-numbered 0..node_set.len() in the order they appear after sorting).
556/// - A mapping `original_indices[i]` = original node index of subgraph node `i`.
557///
558/// Only edges where **both** endpoints are in `node_set` are retained.
559///
560/// # Parameters
561/// - `adjacency` – weighted adjacency list of the full graph.
562/// - `node_set`  – node indices to include (may contain duplicates; duplicates
563///   are silently deduplicated).
564///
565/// # Returns
566/// `(subgraph_adjacency, original_indices)` where:
567/// - `subgraph_adjacency[i]` is a list of `(j, weight)` pairs in subgraph
568///   coordinates.
569/// - `original_indices[i]` is the original node index for subgraph node `i`.
570///
571/// # Errors
572/// Returns [`GraphError::InvalidParameter`] if any node index in `node_set`
573/// is out of range.
574pub fn induced_subgraph(
575    adjacency: &[Vec<(usize, f64)>],
576    node_set: &[usize],
577) -> Result<InducedSubgraphResult> {
578    let n = adjacency.len();
579    for &v in node_set {
580        if v >= n {
581            return Err(GraphError::invalid_parameter(
582                "node_set entry",
583                v,
584                format!("must be < n_nodes ({})", n),
585            ));
586        }
587    }
588
589    // Deduplicate and sort to get a stable ordering.
590    let mut original_indices: Vec<usize> = {
591        let mut s: Vec<usize> = node_set.to_vec();
592        s.sort_unstable();
593        s.dedup();
594        s
595    };
596    original_indices.sort_unstable();
597
598    let sub_n = original_indices.len();
599
600    // Build reverse map: original_index → subgraph_index.
601    let mut rev_map: HashMap<usize, usize> = HashMap::with_capacity(sub_n);
602    for (sub_i, &orig_i) in original_indices.iter().enumerate() {
603        rev_map.insert(orig_i, sub_i);
604    }
605
606    // Build subgraph adjacency.
607    let mut sub_adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); sub_n];
608    for (sub_i, &orig_i) in original_indices.iter().enumerate() {
609        for &(orig_j, w) in &adjacency[orig_i] {
610            if let Some(&sub_j) = rev_map.get(&orig_j) {
611                sub_adj[sub_i].push((sub_j, w));
612            }
613        }
614    }
615
616    Ok((sub_adj, original_indices))
617}
618
619// ─────────────────────────────────────────────────────────────────────────────
620// Tests
621// ─────────────────────────────────────────────────────────────────────────────
622
623#[cfg(test)]
624mod tests {
625    use super::*;
626
627    // ── helpers ────────────────────────────────────────────────────────────
628
629    /// Path graph  0 – 1 – 2 – … – (n-1)  (unweighted)
630    fn path_adj(n: usize) -> Vec<Vec<usize>> {
631        let mut adj = vec![vec![]; n];
632        for i in 0..n.saturating_sub(1) {
633            adj[i].push(i + 1);
634            adj[i + 1].push(i);
635        }
636        adj
637    }
638
639    /// Two cliques of size k connected by a single bridge (unweighted)
640    fn two_clique_adj(k: usize) -> Vec<Vec<usize>> {
641        let n = 2 * k;
642        let mut adj = vec![vec![]; n];
643        for i in 0..k {
644            for j in (i + 1)..k {
645                adj[i].push(j);
646                adj[j].push(i);
647            }
648        }
649        for i in 0..k {
650            for j in (i + 1)..k {
651                adj[k + i].push(k + j);
652                adj[k + j].push(k + i);
653            }
654        }
655        // Bridge: 0 — k
656        adj[0].push(k);
657        adj[k].push(0);
658        adj
659    }
660
661    /// Weighted cycle  0–1–2–…–(n-1)–0
662    fn weighted_cycle(n: usize) -> Vec<Vec<(usize, f64)>> {
663        let mut adj = vec![vec![]; n];
664        for i in 0..n {
665            let j = (i + 1) % n;
666            adj[i].push((j, 1.0));
667            adj[j].push((i, 1.0));
668        }
669        adj
670    }
671
672    // ── random_walk ────────────────────────────────────────────────────────
673
674    #[test]
675    fn test_random_walk_length() {
676        let adj = path_adj(10);
677        let walk = random_walk(&adj, 0, 8, 42).expect("random_walk");
678        assert!(walk.len() <= 8, "walk too long: {}", walk.len());
679        assert_eq!(walk[0], 0, "must start at start_node");
680    }
681
682    #[test]
683    fn test_random_walk_all_valid_nodes() {
684        let adj = two_clique_adj(5);
685        let walk = random_walk(&adj, 0, 20, 7).expect("random_walk");
686        let n = adj.len();
687        for &node in &walk {
688            assert!(node < n, "node {} out of range", node);
689        }
690    }
691
692    #[test]
693    fn test_random_walk_isolated_node_stops_early() {
694        // Node 0 has no neighbours.
695        let adj: Vec<Vec<usize>> = vec![vec![], vec![0]];
696        let walk = random_walk(&adj, 0, 5, 0).expect("random_walk");
697        // Should stop after the first step (no neighbours).
698        assert_eq!(walk, vec![0]);
699    }
700
701    #[test]
702    fn test_random_walk_zero_length() {
703        let adj = path_adj(5);
704        let walk = random_walk(&adj, 0, 0, 0).expect("random_walk");
705        assert!(walk.is_empty());
706    }
707
708    #[test]
709    fn test_random_walk_invalid_start() {
710        let adj = path_adj(5);
711        assert!(random_walk(&adj, 99, 5, 0).is_err());
712    }
713
714    #[test]
715    fn test_random_walk_consecutive_valid_edges() {
716        // Every consecutive pair in the walk must be an edge.
717        let adj = two_clique_adj(4);
718        let walk = random_walk(&adj, 0, 30, 123).expect("random_walk");
719        for window in walk.windows(2) {
720            let u = window[0];
721            let v = window[1];
722            assert!(
723                adj[u].contains(&v),
724                "edge ({}, {}) does not exist in adjacency list",
725                u,
726                v
727            );
728        }
729    }
730
731    // ── node2vec_walk ──────────────────────────────────────────────────────
732
733    #[test]
734    fn test_node2vec_walk_length() {
735        let adj = weighted_cycle(8);
736        let walk = node2vec_walk(&adj, 0, 10, 1.0, 1.0, 42).expect("node2vec_walk");
737        assert!(walk.len() <= 10);
738        assert_eq!(walk[0], 0);
739    }
740
741    #[test]
742    fn test_node2vec_walk_all_valid_nodes() {
743        let adj = weighted_cycle(6);
744        let n = adj.len();
745        let walk = node2vec_walk(&adj, 2, 20, 2.0, 0.5, 77).expect("node2vec_walk");
746        for &v in &walk {
747            assert!(v < n, "invalid node index {}", v);
748        }
749    }
750
751    #[test]
752    fn test_node2vec_walk_consecutive_edges() {
753        let adj = weighted_cycle(6);
754        let walk = node2vec_walk(&adj, 0, 15, 1.0, 1.0, 0).expect("node2vec_walk");
755        let unweighted: Vec<Vec<usize>> = adj
756            .iter()
757            .map(|nbrs| nbrs.iter().map(|&(v, _)| v).collect())
758            .collect();
759        for w in walk.windows(2) {
760            let u = w[0];
761            let v = w[1];
762            assert!(unweighted[u].contains(&v), "({}, {}) not an edge", u, v);
763        }
764    }
765
766    #[test]
767    fn test_node2vec_walk_invalid_p() {
768        let adj = weighted_cycle(4);
769        assert!(node2vec_walk(&adj, 0, 5, 0.0, 1.0, 0).is_err());
770        assert!(node2vec_walk(&adj, 0, 5, -1.0, 1.0, 0).is_err());
771    }
772
773    #[test]
774    fn test_node2vec_walk_invalid_q() {
775        let adj = weighted_cycle(4);
776        assert!(node2vec_walk(&adj, 0, 5, 1.0, 0.0, 0).is_err());
777    }
778
779    #[test]
780    fn test_node2vec_walk_zero_length() {
781        let adj = weighted_cycle(4);
782        let walk = node2vec_walk(&adj, 0, 0, 1.0, 1.0, 0).expect("node2vec_walk");
783        assert!(walk.is_empty());
784    }
785
786    #[test]
787    fn test_node2vec_walk_length_one() {
788        let adj = weighted_cycle(4);
789        let walk = node2vec_walk(&adj, 1, 1, 1.0, 1.0, 0).expect("node2vec_walk");
790        assert_eq!(walk, vec![1]);
791    }
792
793    // ── frontier_sampling ──────────────────────────────────────────────────
794
795    #[test]
796    fn test_frontier_sampling_basic() {
797        let adj = two_clique_adj(5);
798        let n = adj.len();
799        let sample = frontier_sampling(&adj, n, 6, 42).expect("frontier_sampling");
800        assert_eq!(sample.len(), 6);
801        // All returned nodes must be valid.
802        for &v in &sample {
803            assert!(v < n);
804        }
805        // No duplicates.
806        let set: HashSet<usize> = sample.iter().cloned().collect();
807        assert_eq!(set.len(), sample.len());
808    }
809
810    #[test]
811    fn test_frontier_sampling_full_graph() {
812        let adj = path_adj(5);
813        let sample = frontier_sampling(&adj, 5, 5, 0).expect("frontier_sampling");
814        assert_eq!(sample.len(), 5);
815    }
816
817    #[test]
818    fn test_frontier_sampling_zero_size() {
819        let adj = path_adj(5);
820        let sample = frontier_sampling(&adj, 5, 0, 0).expect("frontier_sampling");
821        assert!(sample.is_empty());
822    }
823
824    #[test]
825    fn test_frontier_sampling_invalid_n_nodes() {
826        let adj: Vec<Vec<usize>> = vec![];
827        assert!(frontier_sampling(&adj, 0, 1, 0).is_err());
828    }
829
830    #[test]
831    fn test_frontier_sampling_sample_exceeds_n() {
832        let adj = path_adj(3);
833        assert!(frontier_sampling(&adj, 3, 5, 0).is_err());
834    }
835
836    #[test]
837    fn test_frontier_sampling_sorted_output() {
838        let adj = two_clique_adj(4);
839        let n = adj.len();
840        let sample = frontier_sampling(&adj, n, 5, 99).expect("frontier_sampling");
841        let mut sorted = sample.clone();
842        sorted.sort_unstable();
843        assert_eq!(sample, sorted, "output must be sorted");
844    }
845
846    // ── forest_fire_sampling ───────────────────────────────────────────────
847
848    #[test]
849    fn test_forest_fire_basic() {
850        let adj = two_clique_adj(5);
851        let n = adj.len();
852        let sample = forest_fire_sampling(&adj, n, 6, 0.7, 42).expect("forest_fire");
853        assert_eq!(sample.len(), 6);
854        for &v in &sample {
855            assert!(v < n);
856        }
857        let set: HashSet<usize> = sample.iter().cloned().collect();
858        assert_eq!(set.len(), sample.len());
859    }
860
861    #[test]
862    fn test_forest_fire_full_graph() {
863        let adj = path_adj(4);
864        let sample = forest_fire_sampling(&adj, 4, 4, 0.5, 0).expect("forest_fire");
865        assert_eq!(sample.len(), 4);
866    }
867
868    #[test]
869    fn test_forest_fire_zero_size() {
870        let adj = path_adj(5);
871        let sample = forest_fire_sampling(&adj, 5, 0, 0.5, 0).expect("forest_fire");
872        assert!(sample.is_empty());
873    }
874
875    #[test]
876    fn test_forest_fire_invalid_prob() {
877        let adj = path_adj(5);
878        assert!(forest_fire_sampling(&adj, 5, 3, 0.0, 0).is_err());
879        assert!(forest_fire_sampling(&adj, 5, 3, 1.0, 0).is_err());
880        assert!(forest_fire_sampling(&adj, 5, 3, -0.5, 0).is_err());
881    }
882
883    #[test]
884    fn test_forest_fire_sorted_output() {
885        let adj = two_clique_adj(4);
886        let n = adj.len();
887        let sample = forest_fire_sampling(&adj, n, 5, 0.6, 13).expect("forest_fire");
888        let mut sorted = sample.clone();
889        sorted.sort_unstable();
890        assert_eq!(sample, sorted);
891    }
892
893    // ── snowball_sampling ──────────────────────────────────────────────────
894
895    #[test]
896    fn test_snowball_sampling_zero_hops() {
897        let adj = path_adj(8);
898        let sample = snowball_sampling(&adj, &[3], 0).expect("snowball");
899        assert_eq!(sample, vec![3]);
900    }
901
902    #[test]
903    fn test_snowball_sampling_one_hop_path() {
904        let adj = path_adj(6);
905        // From node 3: neighbours are 2 and 4.
906        let sample = snowball_sampling(&adj, &[3], 1).expect("snowball");
907        let set: HashSet<usize> = sample.iter().cloned().collect();
908        assert!(set.contains(&2));
909        assert!(set.contains(&3));
910        assert!(set.contains(&4));
911        assert_eq!(sample.len(), 3);
912    }
913
914    #[test]
915    fn test_snowball_sampling_two_hops_path() {
916        let adj = path_adj(7);
917        // From node 3, 2 hops: nodes 1, 2, 3, 4, 5.
918        let sample = snowball_sampling(&adj, &[3], 2).expect("snowball");
919        let set: HashSet<usize> = sample.iter().cloned().collect();
920        for v in [1, 2, 3, 4, 5] {
921            assert!(set.contains(&v), "node {} missing", v);
922        }
923    }
924
925    #[test]
926    fn test_snowball_sampling_multiple_seeds() {
927        let adj = path_adj(10);
928        // Seeds 0 and 9 (endpoints) with 1 hop each.
929        let sample = snowball_sampling(&adj, &[0, 9], 1).expect("snowball");
930        let set: HashSet<usize> = sample.iter().cloned().collect();
931        // From 0: {0, 1}; From 9: {8, 9}.
932        assert!(set.contains(&0) && set.contains(&1));
933        assert!(set.contains(&8) && set.contains(&9));
934    }
935
936    #[test]
937    fn test_snowball_sampling_empty_adj() {
938        let adj: Vec<Vec<usize>> = vec![];
939        assert!(snowball_sampling(&adj, &[0], 1).is_err());
940    }
941
942    #[test]
943    fn test_snowball_sampling_out_of_range_seed() {
944        let adj = path_adj(4);
945        assert!(snowball_sampling(&adj, &[99], 1).is_err());
946    }
947
948    #[test]
949    fn test_snowball_sampling_sorted_no_duplicates() {
950        let adj = two_clique_adj(4);
951        let sample = snowball_sampling(&adj, &[0, 1], 2).expect("snowball");
952        let mut sorted = sample.clone();
953        sorted.sort_unstable();
954        sorted.dedup();
955        assert_eq!(sample, sorted, "output must be sorted with no duplicates");
956    }
957
958    // ── induced_subgraph ───────────────────────────────────────────────────
959
960    #[test]
961    fn test_induced_subgraph_basic() {
962        //  0 ─ 1 ─ 2 ─ 3  (path graph, weighted)
963        let adj = vec![
964            vec![(1, 1.0)],
965            vec![(0, 1.0), (2, 1.0)],
966            vec![(1, 1.0), (3, 1.0)],
967            vec![(2, 1.0)],
968        ];
969        // Take nodes {1, 2}.
970        let (sub, orig) = induced_subgraph(&adj, &[1, 2]).expect("induced_subgraph");
971        assert_eq!(orig, vec![1, 2]);
972        assert_eq!(sub.len(), 2);
973        // Subgraph node 0 (original 1) → subgraph node 1 (original 2) with w=1.0.
974        assert_eq!(sub[0].len(), 1);
975        assert_eq!(sub[0][0], (1, 1.0));
976        // Subgraph node 1 (original 2) → subgraph node 0 (original 1).
977        assert_eq!(sub[1].len(), 1);
978        assert_eq!(sub[1][0], (0, 1.0));
979    }
980
981    #[test]
982    fn test_induced_subgraph_no_internal_edges() {
983        // Star graph centred at 0.
984        let adj = vec![
985            vec![(1, 1.0), (2, 1.0), (3, 1.0)],
986            vec![(0, 1.0)],
987            vec![(0, 1.0)],
988            vec![(0, 1.0)],
989        ];
990        // Take leaves only: {1, 2, 3}. No edges among them.
991        let (sub, orig) = induced_subgraph(&adj, &[1, 2, 3]).expect("induced_subgraph");
992        assert_eq!(orig, vec![1, 2, 3]);
993        for nbrs in &sub {
994            assert!(
995                nbrs.is_empty(),
996                "leaves should have no edges among themselves"
997            );
998        }
999    }
1000
1001    #[test]
1002    fn test_induced_subgraph_full_graph() {
1003        let adj = vec![vec![(1, 2.0)], vec![(0, 2.0), (2, 3.0)], vec![(1, 3.0)]];
1004        let (sub, orig) = induced_subgraph(&adj, &[0, 1, 2]).expect("induced_subgraph");
1005        assert_eq!(orig, vec![0, 1, 2]);
1006        // Subgraph should equal the original.
1007        assert_eq!(sub, adj);
1008    }
1009
1010    #[test]
1011    fn test_induced_subgraph_duplicates_in_node_set() {
1012        let adj = vec![vec![(1, 1.0)], vec![(0, 1.0), (2, 1.0)], vec![(1, 1.0)]];
1013        // Passing duplicates: {0, 0, 1} → should give sub on {0, 1}.
1014        let (sub, orig) = induced_subgraph(&adj, &[0, 0, 1]).expect("induced_subgraph");
1015        assert_eq!(orig, vec![0, 1]);
1016        assert_eq!(sub.len(), 2);
1017    }
1018
1019    #[test]
1020    fn test_induced_subgraph_out_of_range() {
1021        let adj = vec![vec![(1, 1.0)], vec![(0, 1.0)]];
1022        assert!(induced_subgraph(&adj, &[0, 99]).is_err());
1023    }
1024
1025    #[test]
1026    fn test_induced_subgraph_empty_node_set() {
1027        let adj = vec![vec![(1, 1.0)], vec![(0, 1.0)]];
1028        let (sub, orig) = induced_subgraph(&adj, &[]).expect("induced_subgraph");
1029        assert!(sub.is_empty());
1030        assert!(orig.is_empty());
1031    }
1032
1033    #[test]
1034    fn test_induced_subgraph_preserves_weights() {
1035        //  0 ──(5.0)── 1 ──(3.0)── 2
1036        let adj = vec![vec![(1, 5.0)], vec![(0, 5.0), (2, 3.0)], vec![(1, 3.0)]];
1037        let (sub, _) = induced_subgraph(&adj, &[0, 1]).expect("induced_subgraph");
1038        // sub[0] should contain (1, 5.0) in subgraph coords.
1039        assert_eq!(sub[0], vec![(1, 5.0)]);
1040        assert_eq!(sub[1], vec![(0, 5.0)]);
1041    }
1042}