use crate::error::{GnnError, GnnResult};
use crate::handle::LcgRng;
/// Random number generator taken by [`NeighborSampler::sample`]; an alias of [`LcgRng`].
pub type GnnRng = LcgRng;
/// Per-hop fanout configuration for [`NeighborSampler`].
///
/// Validated by `NeighborSampler::new`: `n_hops` must be at least 1,
/// `n_neighbors` must have exactly `n_hops` entries, and every entry must be > 0.
#[derive(Debug, Clone)]
pub struct NeighborSampleConfig {
/// Number of hops to expand outward from the seed nodes.
pub n_hops: usize,
/// Maximum number of neighbors drawn per frontier node at each hop;
/// `n_neighbors[h]` applies to hop `h` (nodes with fewer neighbors use all of them).
pub n_neighbors: Vec<usize>,
}
/// A sampled subgraph whose nodes are renumbered with compact local ids.
///
/// Local id `i` corresponds to global node `node_ids[i]`. Edges are stored as
/// two parallel arrays: edge `e` goes from local node `edge_src[e]` to local
/// node `edge_dst[e]`.
#[derive(Debug, Clone)]
pub struct SampledSubgraph {
    /// Global id of every sampled node, indexed by local id.
    pub node_ids: Vec<usize>,
    /// Local id of each edge's source node.
    pub edge_src: Vec<usize>,
    /// Local id of each edge's destination node (parallel to `edge_src`).
    pub edge_dst: Vec<usize>,
    /// Local ids of the seed nodes.
    pub seed_nodes: Vec<usize>,
}

impl SampledSubgraph {
    /// Returns the local id of `global_id`, or `None` if that node was not sampled.
    ///
    /// This is a linear scan over `node_ids`. If the id appears more than once,
    /// the lowest local id is returned.
    #[must_use]
    pub fn global_to_local(&self, global_id: usize) -> Option<usize> {
        for (local_id, &candidate) in self.node_ids.iter().enumerate() {
            if candidate == global_id {
                return Some(local_id);
            }
        }
        None
    }

    /// Number of nodes in the subgraph, i.e. the length of `node_ids`.
    #[must_use]
    pub fn n_nodes(&self) -> usize {
        let Self { node_ids, .. } = self;
        node_ids.len()
    }

    /// Number of edges in the subgraph, measured by the length of `edge_src`.
    #[must_use]
    pub fn n_edges(&self) -> usize {
        let Self { edge_src, .. } = self;
        edge_src.len()
    }
}
/// Samples multi-hop neighborhoods with a fixed per-hop fanout from a graph
/// stored as adjacency lists.
///
/// `adj[v]` lists the neighbors of node `v`. When node `v` samples neighbor
/// `u`, the resulting subgraph contains the edge `u -> v`: the source is the
/// sampled neighbor and the destination is the node that sampled it.
#[derive(Debug, Clone)]
pub struct NeighborSampler {
    adj: Vec<Vec<usize>>,
    config: NeighborSampleConfig,
}

impl NeighborSampler {
    /// Creates a sampler after validating the graph and the sampling config.
    ///
    /// # Errors
    ///
    /// - [`GnnError::EmptyGraph`] if `adj` is empty.
    /// - [`GnnError::InvalidLayerConfig`] if `config.n_hops` is 0, if
    ///   `config.n_neighbors.len() != config.n_hops`, or if any fanout is 0.
    /// - [`GnnError::NodeIndexOutOfRange`] if any neighbor id is `>= adj.len()`.
    pub fn new(adj: Vec<Vec<usize>>, config: NeighborSampleConfig) -> GnnResult<Self> {
        if adj.is_empty() {
            return Err(GnnError::EmptyGraph);
        }
        if config.n_hops == 0 {
            return Err(GnnError::InvalidLayerConfig(
                "n_hops must be at least 1".to_string(),
            ));
        }
        if config.n_neighbors.len() != config.n_hops {
            return Err(GnnError::InvalidLayerConfig(format!(
                "n_neighbors length {} != n_hops {}",
                config.n_neighbors.len(),
                config.n_hops
            )));
        }
        if config.n_neighbors.contains(&0) {
            return Err(GnnError::InvalidLayerConfig(
                "n_neighbors per hop must be > 0".to_string(),
            ));
        }
        let n = adj.len();
        // Reject dangling neighbor ids up front so `sample` can index `adj` freely.
        if let Some(&u) = adj.iter().flatten().find(|&&u| u >= n) {
            return Err(GnnError::NodeIndexOutOfRange { idx: u, n_nodes: n });
        }
        Ok(Self { adj, config })
    }

    /// Number of nodes in the underlying graph (`adj.len()`).
    #[must_use]
    pub fn n_nodes(&self) -> usize {
        self.adj.len()
    }

    /// Samples a subgraph around `seeds`, expanding `config.n_hops` hops.
    ///
    /// The frontier starts as the distinct seeds. At hop `h`, each frontier node
    /// `v` that has neighbors draws `min(n_neighbors[h], deg(v))` distinct
    /// neighbors without replacement. Each drawn neighbor `u` adds the edge
    /// `u -> v` and joins the next frontier, at most once per hop.
    ///
    /// Local ids are assigned in order of first appearance. Distinct seeds
    /// therefore occupy the lowest local ids, in the order given. An edge that is
    /// produced more than once is emitted only once; this can happen when a node
    /// re-enters the frontier at a later hop.
    ///
    /// A seed repeated in `seeds` is expanded only once, so its first-hop fanout
    /// is never multiplied by its repeat count. `seed_nodes` still has one entry
    /// per element of `seeds`.
    ///
    /// Working memory grows with the size of the sample, not with the number of
    /// nodes in the whole graph.
    ///
    /// # Errors
    ///
    /// - [`GnnError::InvalidLayerConfig`] if `seeds` is empty.
    /// - [`GnnError::NodeIndexOutOfRange`] if any seed is `>= self.n_nodes()`.
    pub fn sample(&self, seeds: &[usize], rng: &mut GnnRng) -> GnnResult<SampledSubgraph> {
        if seeds.is_empty() {
            return Err(GnnError::InvalidLayerConfig(
                "seeds must be non-empty".to_string(),
            ));
        }
        let n = self.adj.len();
        if let Some(&s) = seeds.iter().find(|&&s| s >= n) {
            return Err(GnnError::NodeIndexOutOfRange { idx: s, n_nodes: n });
        }

        // `local_of` maps global id -> local id, and `node_ids` is the inverse.
        // A hash map keeps per-call memory proportional to the sample instead of
        // allocating lookup tables of size `n` on every call.
        let mut local_of: std::collections::HashMap<usize, usize> =
            std::collections::HashMap::with_capacity(seeds.len());
        let mut node_ids: Vec<usize> = Vec::with_capacity(seeds.len());
        for &s in seeds {
            let next_local = node_ids.len();
            local_of.entry(s).or_insert_with(|| {
                node_ids.push(s);
                next_local
            });
        }

        // Expand from the deduplicated seeds. Expanding a repeated seed once per
        // occurrence would inflate its first-hop fanout.
        let mut frontier: Vec<usize> = node_ids.clone();
        let mut raw_edges: Vec<(usize, usize)> = Vec::new();
        for &k in &self.config.n_neighbors {
            let mut in_next: std::collections::HashSet<usize> = std::collections::HashSet::new();
            let mut next_frontier: Vec<usize> = Vec::new();
            for &v in &frontier {
                let nbrs = &self.adj[v];
                if nbrs.is_empty() {
                    continue;
                }
                for u in sample_without_replacement(nbrs, k.min(nbrs.len()), rng) {
                    raw_edges.push((u, v));
                    let next_local = node_ids.len();
                    local_of.entry(u).or_insert_with(|| {
                        node_ids.push(u);
                        next_local
                    });
                    if in_next.insert(u) {
                        next_frontier.push(u);
                    }
                }
            }
            frontier = next_frontier;
        }

        // Translate edges to local ids, dropping duplicates and keeping
        // first-seen order. Every endpoint was registered in `local_of` when it
        // was seeded or sampled, so these lookups cannot miss.
        let mut seen_edges: std::collections::HashSet<(usize, usize)> =
            std::collections::HashSet::with_capacity(raw_edges.len());
        let mut edge_src: Vec<usize> = Vec::with_capacity(raw_edges.len());
        let mut edge_dst: Vec<usize> = Vec::with_capacity(raw_edges.len());
        for (gsrc, gdst) in raw_edges {
            let local_edge = (local_of[&gsrc], local_of[&gdst]);
            if seen_edges.insert(local_edge) {
                edge_src.push(local_edge.0);
                edge_dst.push(local_edge.1);
            }
        }
        let seed_nodes: Vec<usize> = seeds.iter().map(|s| local_of[s]).collect();
        Ok(SampledSubgraph {
            node_ids,
            edge_src,
            edge_dst,
            seed_nodes,
        })
    }
}
/// Draws up to `k` distinct entries of `items` without replacement.
///
/// The draw is a partial Fisher–Yates shuffle over a scratch copy, so `items`
/// itself is not modified. The function returns `min(k, items.len())` entries
/// in the order they were drawn, and calls `rng.next_usize` exactly that many
/// times.
///
/// `k` is clamped to `items.len()`. This means an oversized request cannot
/// underflow `n - i` in release builds.
///
/// Assumes `rng.next_usize(m)` returns a value in `0..m`; TODO confirm against
/// `LcgRng`.
fn sample_without_replacement(items: &[usize], k: usize, rng: &mut LcgRng) -> Vec<usize> {
    let mut buf: Vec<usize> = items.to_vec();
    let n = buf.len();
    let k = k.min(n);
    for i in 0..k {
        // Pick from the not-yet-drawn tail `buf[i..]` and move it into slot `i`.
        let j = i + rng.next_usize(n - i);
        buf.swap(i, j);
    }
    // Keep the drawn prefix and reuse the scratch allocation.
    buf.truncate(k);
    buf
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Directed ring: node `v` has exactly one neighbor, `(v + 1) % n`.
    fn make_ring(n: usize) -> Vec<Vec<usize>> {
        (0..n).map(|v| vec![(v + 1) % n]).collect()
    }

    /// Star: node 0 lists every other node, and every other node lists only 0.
    fn make_star(n: usize) -> Vec<Vec<usize>> {
        let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
        for v in 1..n {
            adj[0].push(v);
            adj[v].push(0);
        }
        adj
    }

    /// Complete graph without self-loops: every node lists all other nodes.
    fn make_complete(n: usize) -> Vec<Vec<usize>> {
        (0..n)
            .map(|v| (0..n).filter(|&u| u != v).collect())
            .collect()
    }

    /// Fixed-seed RNG for tests that do not care about a particular seed.
    fn rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// Valid config whose hop count matches the given per-hop fanouts.
    fn fanouts(per_hop: &[usize]) -> NeighborSampleConfig {
        NeighborSampleConfig {
            n_hops: per_hop.len(),
            n_neighbors: per_hop.to_vec(),
        }
    }

    /// Builds a sampler from test-controlled inputs that are expected to be valid.
    fn build_sampler(adj: Vec<Vec<usize>>, per_hop: &[usize]) -> NeighborSampler {
        NeighborSampler::new(adj, fanouts(per_hop)).expect("new should succeed")
    }

    #[test]
    fn sample_returns_seeds() {
        let sampler = build_sampler(make_ring(10), &[2]);
        let sg = sampler
            .sample(&[3, 7], &mut rng())
            .expect("sample should succeed");
        assert!(sg.node_ids.contains(&3));
        assert!(sg.node_ids.contains(&7));
    }

    #[test]
    fn subgraph_includes_all_seeds() {
        let sampler = build_sampler(make_star(8), &[3]);
        let seeds: Vec<usize> = (1..5).collect();
        let sg = sampler
            .sample(&seeds, &mut rng())
            .expect("sample should succeed");
        for &s in &seeds {
            assert!(sg.node_ids.contains(&s), "seed {s} missing");
        }
    }

    #[test]
    fn node_ids_include_seeds() {
        let sampler = build_sampler(make_complete(6), &[2]);
        let sg = sampler
            .sample(&[0, 5], &mut rng())
            .expect("sample should succeed");
        let local_0 = sg.global_to_local(0).expect("node 0 in subgraph");
        let local_5 = sg.global_to_local(5).expect("node 5 in subgraph");
        assert!(sg.seed_nodes.contains(&local_0));
        assert!(sg.seed_nodes.contains(&local_5));
    }

    #[test]
    fn seed_nodes_map_back_to_seed_ids() {
        let sampler = build_sampler(make_complete(8), &[2]);
        let seeds: [usize; 3] = [4, 1, 6];
        let sg = sampler
            .sample(&seeds, &mut rng())
            .expect("sample should succeed");
        assert_eq!(sg.seed_nodes.len(), seeds.len());
        for (&local, &global) in sg.seed_nodes.iter().zip(&seeds) {
            assert_eq!(sg.node_ids[local], global);
        }
    }

    #[test]
    fn edge_src_in_range() {
        let sampler = build_sampler(make_complete(5), &[2, 2]);
        let sg = sampler
            .sample(&[0], &mut rng())
            .expect("sample should succeed");
        let n_local = sg.n_nodes();
        for &s in &sg.edge_src {
            assert!(s < n_local, "edge_src {s} >= n_nodes {n_local}");
        }
    }

    #[test]
    fn edge_dst_in_range() {
        let sampler = build_sampler(make_complete(5), &[3]);
        let sg = sampler
            .sample(&[2], &mut rng())
            .expect("sample should succeed");
        let n_local = sg.n_nodes();
        for &d in &sg.edge_dst {
            assert!(d < n_local, "edge_dst {d} >= n_nodes {n_local}");
        }
    }

    #[test]
    fn edges_are_unique() {
        let sampler = build_sampler(make_complete(6), &[3, 3]);
        let sg = sampler
            .sample(&[0, 1], &mut rng())
            .expect("sample should succeed");
        assert_eq!(sg.edge_src.len(), sg.edge_dst.len());
        let pairs: std::collections::HashSet<(usize, usize)> = sg
            .edge_src
            .iter()
            .copied()
            .zip(sg.edge_dst.iter().copied())
            .collect();
        assert_eq!(pairs.len(), sg.n_edges());
    }

    #[test]
    fn n_neighbors_bounded() {
        let sampler = build_sampler(make_ring(5), &[10]);
        let sg = sampler
            .sample(&[0], &mut rng())
            .expect("sample should succeed");
        assert!(sg.n_edges() <= 1, "expected ≤ 1 edge, got {}", sg.n_edges());
    }

    #[test]
    fn isolated_node_works() {
        let sampler = build_sampler(vec![vec![], vec![], vec![]], &[5]);
        let sg = sampler
            .sample(&[1], &mut rng())
            .expect("sample should succeed");
        assert_eq!(sg.n_nodes(), 1);
        assert_eq!(sg.n_edges(), 0);
    }

    #[test]
    fn one_hop_sample() {
        // The hub has 9 leaves, so exactly 3 distinct leaves are drawn.
        let sampler = build_sampler(make_star(10), &[3]);
        let sg = sampler
            .sample(&[0], &mut rng())
            .expect("sample should succeed");
        assert_eq!(sg.n_nodes(), 4, "seed + exactly 3 sampled neighbors");
        assert_eq!(sg.n_edges(), 3);
    }

    #[test]
    fn two_hop_sample_larger_than_one_hop() {
        let adj = make_complete(8);
        let s1 = build_sampler(adj.clone(), &[2]);
        let s2 = build_sampler(adj, &[2, 2]);
        let sg1 = s1
            .sample(&[0], &mut LcgRng::new(7))
            .expect("sample should succeed");
        let sg2 = s2
            .sample(&[0], &mut LcgRng::new(7))
            .expect("sample should succeed");
        assert!(
            sg2.n_nodes() >= sg1.n_nodes(),
            "2-hop ({}) should have ≥ nodes than 1-hop ({})",
            sg2.n_nodes(),
            sg1.n_nodes()
        );
    }

    #[test]
    fn same_rng_seed_is_deterministic() {
        let sampler = build_sampler(make_complete(10), &[3, 2]);
        let a = sampler
            .sample(&[0, 5], &mut LcgRng::new(11))
            .expect("sample should succeed");
        let b = sampler
            .sample(&[0, 5], &mut LcgRng::new(11))
            .expect("sample should succeed");
        assert_eq!(a.node_ids, b.node_ids);
        assert_eq!(a.edge_src, b.edge_src);
        assert_eq!(a.edge_dst, b.edge_dst);
        assert_eq!(a.seed_nodes, b.seed_nodes);
    }

    #[test]
    fn n_nodes_accessor() {
        let sampler = build_sampler(make_ring(12), &[1]);
        assert_eq!(sampler.n_nodes(), 12);
    }

    #[test]
    fn empty_seeds_error() {
        let sampler = build_sampler(make_ring(4), &[2]);
        let result = sampler.sample(&[], &mut rng());
        assert!(matches!(result, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn out_of_range_seed_error() {
        let sampler = build_sampler(make_ring(4), &[2]);
        let result = sampler.sample(&[1, 4], &mut rng());
        assert!(matches!(
            result,
            Err(GnnError::NodeIndexOutOfRange { idx: 4, n_nodes: 4 })
        ));
    }

    #[test]
    fn invalid_n_hops_zero() {
        let config = NeighborSampleConfig {
            n_hops: 0,
            n_neighbors: vec![],
        };
        let result = NeighborSampler::new(make_ring(5), config);
        assert!(matches!(result, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn empty_graph_error() {
        let result = NeighborSampler::new(vec![], fanouts(&[1]));
        assert!(matches!(result, Err(GnnError::EmptyGraph)));
    }

    #[test]
    fn mismatched_fanout_length_error() {
        let config = NeighborSampleConfig {
            n_hops: 2,
            n_neighbors: vec![3],
        };
        let result = NeighborSampler::new(make_ring(5), config);
        assert!(matches!(result, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn zero_fanout_error() {
        let result = NeighborSampler::new(make_ring(5), fanouts(&[2, 0]));
        assert!(matches!(result, Err(GnnError::InvalidLayerConfig(_))));
    }

    #[test]
    fn out_of_range_neighbor_error() {
        let adj = vec![vec![1], vec![2]];
        let result = NeighborSampler::new(adj, fanouts(&[1]));
        assert!(matches!(
            result,
            Err(GnnError::NodeIndexOutOfRange { idx: 2, n_nodes: 2 })
        ));
    }
}