use crate::base::Graph;
use crate::error::Result;
use scirs2_core::random::{Rng, RngExt};
/// Performs a single node2vec-style biased random walk starting at `start`.
///
/// Convenience wrapper around [`biased_random_walk_with_rng`]. It uses the default generator
/// returned by `scirs2_core::random::rng()`. See that function for the meaning of `p`/`q`,
/// the walk-length semantics, and the error conditions.
pub fn biased_random_walk(
    graph: &Graph<usize, f64>,
    start: usize,
    length: usize,
    p: f64,
    q: f64,
) -> Result<Vec<usize>> {
    biased_random_walk_with_rng(
        graph,
        start,
        length,
        p,
        q,
        &mut scirs2_core::random::rng(),
    )
}
/// Performs a single node2vec-style second-order biased random walk using the supplied RNG.
///
/// The first step from `start` is chosen uniformly among its neighbors. Every later step is
/// taken from `current`, having arrived from `prev`. It picks neighbor `x` with an
/// unnormalized weight:
///
/// * `1 / p` if `x == prev` (return to the previous node),
/// * `1`     if `x` is adjacent to `prev` (stay at distance one from `prev`),
/// * `1 / q` otherwise (move outward).
///
/// Edge weights stored in the graph are not consulted; only adjacency is used.
///
/// # Arguments
/// * `graph` - graph to walk on.
/// * `start` - node the walk begins at; it is always the first element of the result.
/// * `length` - maximum number of nodes in the walk. Both `0` and `1` yield `[start]`.
/// * `p` - return parameter; must be strictly positive (NaN is rejected).
/// * `q` - in-out parameter; must be strictly positive (NaN is rejected).
/// * `rng` - source of randomness; pass a seeded generator for reproducible walks.
///
/// # Returns
/// The visited nodes in order. The walk stops early, shorter than `length`, in two cases:
/// * it reaches a node with no neighbors, or
/// * every candidate weight is zero, which can only happen with an infinite `p` and/or `q`.
///
/// # Errors
/// * A node-not-found error if `start` is not in `graph`.
/// * `GraphError::InvalidParameter` if `p` or `q` is not strictly positive (including NaN).
/// * Any error propagated from `graph.neighbors`.
pub fn biased_random_walk_with_rng<R: Rng>(
    graph: &Graph<usize, f64>,
    start: usize,
    length: usize,
    p: f64,
    q: f64,
    rng: &mut R,
) -> Result<Vec<usize>> {
    use crate::error::GraphError;
    if !graph.has_node(&start) {
        return Err(GraphError::node_not_found(format!("{start}")));
    }
    // Phrased as "not strictly positive" so NaN is rejected. A plain `p <= 0.0` check lets NaN
    // through (`NaN <= 0.0` is false). NaN weights would then make the sampler always fall back
    // to the last neighbor.
    let is_positive = |x: f64| x > 0.0;
    if !(is_positive(p) && is_positive(q)) {
        return Err(GraphError::InvalidParameter {
            param: "p/q".to_string(),
            value: format!("p={p}, q={q}"),
            expected: "p > 0 and q > 0".to_string(),
            context: "biased_random_walk".to_string(),
        });
    }
    let mut walk = vec![start];
    if length <= 1 {
        return Ok(walk);
    }
    let first_neighbors = graph.neighbors(&start)?;
    if first_neighbors.is_empty() {
        return Ok(walk);
    }
    // No previous node exists yet, so the first hop is unbiased.
    let first_idx = rng.random_range(0..first_neighbors.len());
    let mut prev = start;
    let mut current = first_neighbors[first_idx];
    walk.push(current);

    // Reused across steps to avoid one allocation per step.
    let mut weights: Vec<f64> = Vec::new();
    for _ in 2..length {
        let neighbors = graph.neighbors(&current)?;
        if neighbors.is_empty() {
            break;
        }
        weights.clear();
        weights.extend(neighbors.iter().map(|&nbr| {
            if nbr == prev {
                1.0 / p
            } else if graph.has_edge(&prev, &nbr) {
                1.0
            } else {
                1.0 / q
            }
        }));
        let total: f64 = weights.iter().sum();
        if total <= 0.0 {
            break;
        }
        // Inverse-CDF sampling over the unnormalized weights. The last neighbor is the fallback
        // because rounding can leave `remaining` slightly positive after all subtractions.
        let mut remaining = rng.random::<f64>() * total;
        let chosen = weights
            .iter()
            .position(|&w| {
                remaining -= w;
                remaining <= 0.0
            })
            .unwrap_or(neighbors.len() - 1);
        prev = current;
        current = neighbors[chosen];
        walk.push(current);
    }
    Ok(walk)
}
/// Generates `num_walks` biased random walks of up to `walk_length` nodes from every node.
///
/// Walks are grouped by start node. All walks from the first node yielded by `graph.nodes()`
/// come first, then those from the next node, and so on. A single generator obtained from
/// `scirs2_core::random::rng()` is shared by every walk.
///
/// # Errors
/// Returns the first error produced by [`biased_random_walk_with_rng`] (e.g. invalid `p`/`q`),
/// discarding any walks already generated. Validation happens per walk, so invalid `p`/`q` are
/// not reported when the graph has no nodes or `num_walks == 0`.
pub fn generate_walks(
    graph: &Graph<usize, f64>,
    num_walks: usize,
    walk_length: usize,
    p: f64,
    q: f64,
) -> Result<Vec<Vec<usize>>> {
    let mut rng = scirs2_core::random::rng();
    let mut walks = Vec::new();
    for &origin in graph.nodes() {
        // Collecting into `Result` stops at the first failing walk, just like `?` per walk.
        let batch = (0..num_walks)
            .map(|_| biased_random_walk_with_rng(graph, origin, walk_length, p, q, &mut rng))
            .collect::<Result<Vec<_>>>()?;
        walks.extend(batch);
    }
    Ok(walks)
}
/// Generates unbiased (DeepWalk-style) random walks from every node of `graph`.
///
/// This is [`generate_walks`] with `p = q = 1`. Every neighbor then receives the same
/// transition weight, so each step is a uniform choice among the current node's neighbors.
pub fn deepwalk_walks(
    graph: &Graph<usize, f64>,
    num_walks: usize,
    walk_length: usize,
) -> Result<Vec<Vec<usize>>> {
    // p = q = 1 makes returning, staying nearby and moving outward equally likely.
    const UNBIASED: f64 = 1.0;
    generate_walks(graph, num_walks, walk_length, UNBIASED, UNBIASED)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 4-cycle 0-1-2-3-0. Every node has exactly two neighbors, so walks on it
    /// never hit a dead end and always reach their requested length.
    fn make_square_graph() -> Graph<usize, f64> {
        let mut g: Graph<usize, f64> = Graph::new();
        for i in 0..4 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let _ = g.add_edge(3, 0, 1.0);
        g
    }

    /// Asserts that every consecutive pair of nodes in `walk` is joined by an edge of `g`.
    fn assert_walk_follows_edges(g: &Graph<usize, f64>, walk: &[usize]) {
        for pair in walk.windows(2) {
            assert!(
                g.has_edge(&pair[0], &pair[1]),
                "walk step {} -> {} is not an edge",
                pair[0],
                pair[1]
            );
        }
    }

    #[test]
    fn test_biased_random_walk_length() {
        let g = make_square_graph();
        let walk = biased_random_walk(&g, 0, 10, 1.0, 1.0).expect("walk should succeed");
        // The cycle has no dead ends, so the walk must reach the full requested length.
        assert_eq!(walk.len(), 10);
        assert_eq!(walk[0], 0);
    }

    #[test]
    fn test_biased_random_walk_nodes_valid() {
        let g = make_square_graph();
        let walk = biased_random_walk(&g, 0, 20, 1.0, 1.0).expect("walk should succeed");
        for &node in &walk {
            assert!(node < 4, "walk should only visit nodes 0..3, got {node}");
        }
        assert_walk_follows_edges(&g, &walk);
    }

    #[test]
    fn test_biased_random_walk_invalid_node() {
        let g = make_square_graph();
        let result = biased_random_walk(&g, 99, 10, 1.0, 1.0);
        assert!(result.is_err(), "should fail for non-existent node");
    }

    #[test]
    fn test_biased_random_walk_invalid_params() {
        let g = make_square_graph();
        assert!(
            biased_random_walk(&g, 0, 10, -1.0, 1.0).is_err(),
            "p < 0 should return error"
        );
        assert!(
            biased_random_walk(&g, 0, 10, 0.0, 1.0).is_err(),
            "p == 0 should return error"
        );
        assert!(
            biased_random_walk(&g, 0, 10, 1.0, -1.0).is_err(),
            "q < 0 should return error"
        );
        assert!(
            biased_random_walk(&g, 0, 10, 1.0, 0.0).is_err(),
            "q == 0 should return error"
        );
    }

    #[test]
    fn test_generate_walks_count() {
        let g = make_square_graph();
        let walks = generate_walks(&g, 3, 5, 1.0, 1.0).expect("should succeed");
        assert_eq!(walks.len(), 12);
        // Each node must start exactly `num_walks` walks.
        let mut starts = [0usize; 4];
        for walk in &walks {
            assert_eq!(walk.len(), 5);
            starts[walk[0]] += 1;
        }
        assert_eq!(starts, [3; 4]);
    }

    #[test]
    fn test_generate_walks_all_start_valid() {
        let g = make_square_graph();
        let walks = generate_walks(&g, 2, 8, 2.0, 0.5).expect("should succeed");
        for walk in &walks {
            assert_eq!(walk.len(), 8);
            for &node in walk {
                assert!(node < 4, "walk node {node} out of range");
            }
            assert_walk_follows_edges(&g, walk);
        }
    }

    #[test]
    fn test_deepwalk_walks_equivalence() {
        let g = make_square_graph();
        let dw = deepwalk_walks(&g, 3, 6).expect("should succeed");
        assert_eq!(dw.len(), 12, "4 nodes × 3 walks = 12");
        for walk in &dw {
            assert_eq!(walk.len(), 6);
            assert_walk_follows_edges(&g, walk);
        }
    }

    #[test]
    fn test_biased_walk_length_one() {
        let g = make_square_graph();
        let walk = biased_random_walk(&g, 0, 1, 1.0, 1.0).expect("walk should succeed");
        assert_eq!(walk, vec![0]);
    }

    #[test]
    fn test_biased_walk_isolated_node() {
        let mut g: Graph<usize, f64> = Graph::new();
        g.add_node(0);
        g.add_node(1);
        let _ = g.add_edge(0, 1, 1.0);
        g.add_node(2);
        let walk = biased_random_walk(&g, 2, 10, 1.0, 1.0).expect("walk should succeed");
        assert_eq!(walk.len(), 1);
        assert_eq!(walk[0], 2);
    }
}