use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// Hyper-parameters for a [`PinSage`] layer.
///
/// `PinSage::new` consumes the config, so it derives `Clone` to let callers
/// keep a copy (e.g. to build several models from one configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PinSageConfig {
    /// Number of nodes in the graph; valid node ids are `0..n_nodes`.
    pub n_nodes: usize,
    /// Random walks started from the query node when sampling neighbours (must be > 0).
    pub n_walks: usize,
    /// Maximum number of steps taken by each random walk.
    pub walk_len: usize,
    /// Maximum number of sampled neighbours kept (most-visited first).
    pub n_neighbors: usize,
    /// Per-node input feature dimension (must be > 0).
    pub d_input: usize,
    /// Output embedding dimension (must be > 0).
    pub d_output: usize,
}
/// A single PinSage-style graph-convolution layer.
///
/// Concatenates a node's own features with the mean of its random-walk-sampled
/// neighbours' features, applies one dense layer (`W x + b`), then ReLU.
#[derive(Debug)]
pub struct PinSage {
    /// Dense-layer weights, row-major: `d_output` rows of `2 * d_input` columns
    /// (first half multiplies the node's own features, second half the pooled ones).
    w: Vec<f32>,
    /// Dense-layer bias, one value per output dimension (`d_output` entries).
    b: Vec<f32>,
    /// Hyper-parameters the layer was built with.
    config: PinSageConfig,
}
impl PinSage {
    /// Builds a layer with Glorot/Xavier-scaled random weights and zero bias.
    ///
    /// The weight matrix has `d_output` rows of `2 * d_input` columns (row-major).
    /// Each entry is `rng.next_normal() * sqrt(2 / (2 * d_input + d_output))`.
    /// This is Glorot-normal assuming `next_normal` draws from N(0, 1), which
    /// should be confirmed against `LcgRng`.
    ///
    /// # Errors
    ///
    /// * [`RecsysError::InvalidConfig`] if `config.n_walks == 0`.
    /// * [`RecsysError::InvalidEmbeddingDim`] if `config.d_input == 0` or
    ///   `config.d_output == 0`.
    pub fn new(config: PinSageConfig, rng: &mut LcgRng) -> RecsysResult<Self> {
        if config.n_walks == 0 {
            return Err(RecsysError::InvalidConfig {
                msg: "n_walks must be > 0".into(),
            });
        }
        if config.d_input == 0 {
            return Err(RecsysError::InvalidEmbeddingDim { d: 0 });
        }
        if config.d_output == 0 {
            return Err(RecsysError::InvalidEmbeddingDim { d: 0 });
        }
        let d_in2 = 2 * config.d_input;
        let d_out = config.d_output;
        // Glorot scaling: std = sqrt(2 / (fan_in + fan_out)).
        let scale = (2.0 / (d_in2 + d_out) as f32).sqrt();
        let w: Vec<f32> = (0..d_out * d_in2)
            .map(|_| rng.next_normal() * scale)
            .collect();
        let b = vec![0.0_f32; d_out];
        Ok(Self { w, b, config })
    }

    /// Neighbour list of `node`, or an empty slice when `adj` has no entry for it.
    fn neighbors_of(adj: &[Vec<usize>], node: usize) -> &[usize] {
        adj.get(node).map(Vec::as_slice).unwrap_or_default()
    }

    /// Selects up to `n_neighbors` important neighbours of `node_id` by random walks.
    ///
    /// Runs `n_walks` walks of at most `walk_len` steps from `node_id`. It counts
    /// how often every other node with id `< n_nodes` is visited. A walk stops
    /// early at a node without edges. Nodes are ranked by visit count
    /// (descending), ties broken by ascending node id, and the first
    /// `n_neighbors` ids are returned.
    ///
    /// Nodes with no entry in `adj` (index `>= adj.len()`) are treated as having
    /// no edges. Returns an empty vector when `node_id` itself has no edges.
    /// Work and memory scale with `n_walks * walk_len`, not with `n_nodes`.
    ///
    /// # Errors
    ///
    /// [`RecsysError::ItemOutOfBounds`] if `node_id >= n_nodes`.
    pub fn sample_neighbors(
        &self,
        node_id: usize,
        adj: &[Vec<usize>],
        rng: &mut LcgRng,
    ) -> RecsysResult<Vec<usize>> {
        if node_id >= self.config.n_nodes {
            return Err(RecsysError::ItemOutOfBounds {
                idx: node_id,
                n: self.config.n_nodes,
            });
        }
        if Self::neighbors_of(adj, node_id).is_empty() {
            return Ok(Vec::new());
        }
        // Sparse counter: only nodes actually reached get an entry, so a query
        // does not pay O(n_nodes) to allocate and scan a dense array.
        let mut visit_count: std::collections::HashMap<usize, u32> =
            std::collections::HashMap::new();
        for _ in 0..self.config.n_walks {
            let mut current = node_id;
            for _ in 0..self.config.walk_len {
                let nbrs = Self::neighbors_of(adj, current);
                if nbrs.is_empty() {
                    break;
                }
                let next = nbrs[rng.next_usize(nbrs.len())];
                // The start node and ids outside the graph are walked through
                // but never counted as neighbours.
                if next != node_id && next < self.config.n_nodes {
                    let cnt = visit_count.entry(next).or_insert(0);
                    *cnt = cnt.saturating_add(1);
                }
                current = next;
            }
        }
        let mut ranked: Vec<(usize, u32)> = visit_count.into_iter().collect();
        // HashMap iteration order is unspecified, so sort by a total order:
        // visit count descending, then node id ascending. This makes the
        // top-k selection well-defined when counts tie.
        ranked.sort_unstable_by_key(|&(node, cnt)| (std::cmp::Reverse(cnt), node));
        Ok(ranked
            .into_iter()
            .take(self.config.n_neighbors)
            .map(|(node, _)| node)
            .collect())
    }

    /// Computes the `d_output`-dimensional embedding of `node_id`.
    ///
    /// `node_feats` is a row-major `n_nodes x d_input` matrix; row `i` holds the
    /// features of node `i`. The node's own features are concatenated with the
    /// mean of its sampled neighbours' features (all zeros when none are
    /// sampled). The result goes through `W x + b` and is clamped with ReLU.
    /// `rng` drives the neighbour sampling (see [`PinSage::sample_neighbors`]).
    ///
    /// # Errors
    ///
    /// * [`RecsysError::ItemOutOfBounds`] if `node_id >= n_nodes`.
    /// * [`RecsysError::DimensionMismatch`] if
    ///   `node_feats.len() != n_nodes * d_input`.
    pub fn forward(
        &self,
        node_id: usize,
        node_feats: &[f32],
        adj: &[Vec<usize>],
        rng: &mut LcgRng,
    ) -> RecsysResult<Vec<f32>> {
        if node_id >= self.config.n_nodes {
            return Err(RecsysError::ItemOutOfBounds {
                idx: node_id,
                n: self.config.n_nodes,
            });
        }
        let expected = self.config.n_nodes * self.config.d_input;
        if node_feats.len() != expected {
            return Err(RecsysError::DimensionMismatch {
                expected,
                got: node_feats.len(),
            });
        }
        let d_in = self.config.d_input;
        let neighbors = self.sample_neighbors(node_id, adj, rng)?;

        // Mean-pool neighbour features; stays all-zero for isolated nodes.
        let mut pooled = vec![0.0_f32; d_in];
        if !neighbors.is_empty() {
            for &nb in &neighbors {
                let feat = &node_feats[nb * d_in..(nb + 1) * d_in];
                for (acc, &x) in pooled.iter_mut().zip(feat) {
                    *acc += x;
                }
            }
            let inv_n = 1.0 / neighbors.len() as f32;
            pooled.iter_mut().for_each(|v| *v *= inv_n);
        }

        let self_feat = &node_feats[node_id * d_in..(node_id + 1) * d_in];
        let mut concat = Vec::with_capacity(2 * d_in);
        concat.extend_from_slice(self_feat);
        concat.extend_from_slice(&pooled);

        // One weight row per output unit; `concat.len() == 2 * d_in > 0` because
        // `new` rejects `d_input == 0`, so `chunks_exact` cannot panic.
        let out: Vec<f32> = self
            .w
            .chunks_exact(concat.len())
            .zip(&self.b)
            .map(|(w_row, &bias)| {
                let z = w_row
                    .iter()
                    .zip(&concat)
                    .map(|(&wi, &xi)| wi * xi)
                    .sum::<f32>()
                    + bias;
                // ReLU as a comparison, so a NaN pre-activation is passed through.
                if z < 0.0 {
                    0.0
                } else {
                    z
                }
            })
            .collect();
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Config with fixed `n_neighbors = 3`, `d_input = 4`, `d_output = 8`.
    fn make_config(n_nodes: usize, n_walks: usize, walk_len: usize) -> PinSageConfig {
        PinSageConfig {
            n_nodes,
            n_walks,
            walk_len,
            n_neighbors: 3,
            d_input: 4,
            d_output: 8,
        }
    }

    fn make_model(n_nodes: usize, n_walks: usize, walk_len: usize) -> PinSage {
        let mut rng = LcgRng::new(42);
        PinSage::new(make_config(n_nodes, n_walks, walk_len), &mut rng)
            .expect("model construction must succeed")
    }

    /// Complete graph without self-loops.
    fn dense_adj(n: usize) -> Vec<Vec<usize>> {
        (0..n)
            .map(|i| (0..n).filter(|&j| j != i).collect())
            .collect()
    }

    /// Row-major `n_nodes x d_input` features, all strictly positive.
    fn make_feats(n_nodes: usize, d_input: usize) -> Vec<f32> {
        (0..n_nodes * d_input)
            .map(|i| (i as f32) * 0.01 + 0.5)
            .collect()
    }

    #[test]
    fn forward_shape() {
        let n = 6;
        let model = make_model(n, 4, 3);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let out = model
            .forward(0, &feats, &adj, &mut rng)
            .expect("forward must succeed");
        assert_eq!(out.len(), 8, "forward output must have d_output elements");
    }

    #[test]
    fn forward_finite() {
        let n = 6;
        let model = make_model(n, 4, 3);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let out = model
            .forward(2, &feats, &adj, &mut rng)
            .expect("forward must succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v.is_finite(), "output[{i}] = {v} must be finite");
        }
    }

    #[test]
    fn sample_neighbors_returns_at_most_n() {
        let n = 10;
        let model = make_model(n, 8, 4);
        let adj = dense_adj(n);
        let mut rng = LcgRng::new(42);
        let neighbors = model
            .sample_neighbors(0, &adj, &mut rng)
            .expect("sample must succeed");
        assert!(
            neighbors.len() <= 3,
            "sample_neighbors must return at most n_neighbors, got {}",
            neighbors.len()
        );
    }

    #[test]
    fn isolated_node_works() {
        let n = 4;
        let model = make_model(n, 4, 3);
        let adj: Vec<Vec<usize>> = vec![vec![]; n];
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let out = model
            .forward(1, &feats, &adj, &mut rng)
            .expect("isolated forward must succeed");
        assert_eq!(
            out.len(),
            8,
            "isolated node forward must return d_output elements"
        );
        for &v in &out {
            assert!(v >= 0.0, "ReLU output must be non-negative; got {v}");
        }
    }

    #[test]
    fn n_walks_zero_error() {
        let mut rng = LcgRng::new(42);
        let cfg = PinSageConfig {
            n_nodes: 6,
            n_walks: 0,
            walk_len: 3,
            n_neighbors: 2,
            d_input: 4,
            d_output: 8,
        };
        let result = PinSage::new(cfg, &mut rng);
        assert!(
            matches!(result, Err(RecsysError::InvalidConfig { .. })),
            "expected InvalidConfig, got: {result:?}"
        );
    }

    #[test]
    fn node_out_of_range_error() {
        let n = 4;
        let model = make_model(n, 4, 2);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let result = model.forward(99, &feats, &adj, &mut rng);
        assert!(
            matches!(result, Err(RecsysError::ItemOutOfBounds { idx: 99, n: 4 })),
            "expected ItemOutOfBounds, got: {result:?}"
        );
    }

    #[test]
    fn different_nodes_different_output() {
        let n = 8;
        let model = make_model(n, 6, 3);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng_a = LcgRng::new(42);
        let mut rng_b = LcgRng::new(42);
        let out_0 = model
            .forward(0, &feats, &adj, &mut rng_a)
            .expect("forward node 0");
        let out_7 = model
            .forward(7, &feats, &adj, &mut rng_b)
            .expect("forward node 7");
        let diff: f32 = out_0
            .iter()
            .zip(out_7.iter())
            .map(|(&a, &b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-6,
            "different nodes should have different forward output (diff={diff})"
        );
    }

    #[test]
    fn output_relu_nonneg() {
        let n = 6;
        let model = make_model(n, 4, 3);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let out = model
            .forward(3, &feats, &adj, &mut rng)
            .expect("forward must succeed");
        for (i, &v) in out.iter().enumerate() {
            assert!(v >= 0.0, "output[{i}] = {v} must be >= 0 after ReLU");
        }
    }

    #[test]
    fn walk_len_1_works() {
        let n = 5;
        let model = make_model(n, 4, 1);
        let adj = dense_adj(n);
        let feats = make_feats(n, 4);
        let mut rng = LcgRng::new(42);
        let out = model
            .forward(2, &feats, &adj, &mut rng)
            .expect("walk_len=1 must work");
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn d_input_zero_error() {
        let mut rng = LcgRng::new(42);
        let cfg = PinSageConfig {
            n_nodes: 4,
            n_walks: 2,
            walk_len: 2,
            n_neighbors: 2,
            d_input: 0,
            d_output: 4,
        };
        let result = PinSage::new(cfg, &mut rng);
        assert!(
            matches!(result, Err(RecsysError::InvalidEmbeddingDim { d: 0 })),
            "expected InvalidEmbeddingDim, got: {result:?}"
        );
    }

    #[test]
    fn d_output_zero_error() {
        let mut rng = LcgRng::new(42);
        let cfg = PinSageConfig {
            n_nodes: 4,
            n_walks: 2,
            walk_len: 2,
            n_neighbors: 2,
            d_input: 4,
            d_output: 0,
        };
        let result = PinSage::new(cfg, &mut rng);
        assert!(
            matches!(result, Err(RecsysError::InvalidEmbeddingDim { d: 0 })),
            "expected InvalidEmbeddingDim, got: {result:?}"
        );
    }

    #[test]
    fn feature_length_mismatch_error() {
        let n = 6;
        let model = make_model(n, 4, 3);
        let adj = dense_adj(n);
        // 6 nodes x 3 features = 18 values, but the model expects 6 x 4 = 24.
        let feats = make_feats(n, 3);
        let mut rng = LcgRng::new(42);
        let result = model.forward(0, &feats, &adj, &mut rng);
        assert!(
            matches!(
                result,
                Err(RecsysError::DimensionMismatch {
                    expected: 24,
                    got: 18
                })
            ),
            "expected DimensionMismatch, got: {result:?}"
        );
    }

    #[test]
    fn sample_neighbors_node_out_of_range() {
        let n = 4;
        let model = make_model(n, 2, 2);
        let adj = dense_adj(n);
        let mut rng = LcgRng::new(42);
        let result = model.sample_neighbors(99, &adj, &mut rng);
        assert!(
            matches!(result, Err(RecsysError::ItemOutOfBounds { idx: 99, n: 4 })),
            "expected ItemOutOfBounds, got: {result:?}"
        );
    }

    #[test]
    fn sample_neighbors_excludes_self_and_is_unique() {
        let n = 10;
        let model = make_model(n, 8, 4);
        let adj = dense_adj(n);
        let mut rng = LcgRng::new(7);
        let neighbors = model
            .sample_neighbors(0, &adj, &mut rng)
            .expect("sample must succeed");
        // The first step of every walk lands on some other node, so at least one is counted.
        assert!(!neighbors.is_empty(), "dense graph must yield neighbours");
        assert!(!neighbors.contains(&0), "start node must not be returned");
        assert!(neighbors.iter().all(|&v| v < n), "ids must be < n_nodes");
        let mut dedup = neighbors.clone();
        dedup.sort_unstable();
        dedup.dedup();
        assert_eq!(dedup.len(), neighbors.len(), "neighbours must be unique");
    }

    #[test]
    fn sample_neighbors_is_deterministic_for_same_seed() {
        let n = 10;
        let model = make_model(n, 8, 4);
        let adj = dense_adj(n);
        let a = model
            .sample_neighbors(3, &adj, &mut LcgRng::new(123))
            .expect("sample a");
        let b = model
            .sample_neighbors(3, &adj, &mut LcgRng::new(123))
            .expect("sample b");
        assert_eq!(a, b, "same seed must give the same neighbours");
    }

    #[test]
    fn sample_neighbors_ranks_most_visited_first() {
        // Star graph: hub 0 linked to leaves 1..=4. A 3-step walk from leaf 1
        // visits the hub twice and one leaf once, so the hub must rank first.
        let n = 5;
        let model = make_model(n, 8, 3);
        let adj: Vec<Vec<usize>> = (0..n)
            .map(|i| {
                if i == 0 {
                    (1..n).collect::<Vec<usize>>()
                } else {
                    vec![0]
                }
            })
            .collect();
        let mut rng = LcgRng::new(42);
        let neighbors = model
            .sample_neighbors(1, &adj, &mut rng)
            .expect("sample must succeed");
        assert_eq!(neighbors.first(), Some(&0), "hub must be ranked first");
    }

    #[test]
    fn node_missing_from_short_adj_is_isolated() {
        let n = 4;
        let model = make_model(n, 4, 3);
        // Adjacency only covers nodes 0 and 1; node 3 has no entry.
        let adj = dense_adj(2);
        let mut rng = LcgRng::new(42);
        let neighbors = model
            .sample_neighbors(3, &adj, &mut rng)
            .expect("sample must succeed");
        assert!(neighbors.is_empty(), "node without adj entry has no neighbours");
        let feats = make_feats(n, 4);
        let out = model
            .forward(3, &feats, &adj, &mut rng)
            .expect("forward must succeed");
        assert_eq!(out.len(), 8);
    }
}