use super::types::RandomWalk;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
use scirs2_core::random::rand_prelude::IndexedRandom;
use scirs2_core::random::{Rng, RngExt};
/// Generator of random walks over graphs whose nodes have type `N`.
///
/// The generator owns the random-number source that its walk methods draw
/// from, which is why those methods take `&mut self`.
pub struct RandomWalkGenerator<N: Node> {
/// Random source used for every sampling decision made by the walk methods.
rng: scirs2_core::random::rngs::ThreadRng,
/// Zero-sized marker that ties the generator to `N` without storing a node.
_phantom: std::marker::PhantomData<N>,
}
impl<N: Node> Default for RandomWalkGenerator<N> {
fn default() -> Self {
Self::new()
}
}
impl<N: Node> RandomWalkGenerator<N> {
pub fn new() -> Self {
RandomWalkGenerator {
rng: scirs2_core::random::rng(),
_phantom: std::marker::PhantomData,
}
}
/// Performs a uniform random walk on an undirected graph.
///
/// The walk begins at `start`; at every step the next node is drawn
/// uniformly from `graph.neighbors(..)` of the node the walk currently
/// stands on.
///
/// # Arguments
/// * `graph` - graph to walk on.
/// * `start` - node the walk begins at.
/// * `length` - upper bound on the number of nodes in the walk, `start`
///   included. Values `0` and `1` both yield a walk holding only `start`.
///
/// # Returns
/// The visited nodes in order. The walk is shorter than `length` when it
/// reaches a node for which `graph.neighbors(..)` is empty.
///
/// # Errors
/// * The error built by `GraphError::node_not_found` when
///   `graph.contains_node(start)` is false.
/// * Any error returned by `graph.neighbors(..)`.
pub fn simple_random_walk<E, Ix>(
    &mut self,
    graph: &Graph<N, E, Ix>,
    start: &N,
    length: usize,
) -> Result<RandomWalk<N>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    if !graph.contains_node(start) {
        return Err(GraphError::node_not_found("node"));
    }
    // `current` is the head of the walk. A node is moved into `visited` only
    // once the walk steps away from it, so each step costs a single clone.
    let mut visited = Vec::new();
    let mut current = start.clone();
    for _ in 1..length {
        let neighbors = graph.neighbors(&current)?;
        // `choose` yields `None` only for an empty slice, i.e. a dead end.
        let Some(next) = neighbors.choose(&mut self.rng).cloned() else {
            break;
        };
        visited.push(std::mem::replace(&mut current, next));
    }
    visited.push(current);
    Ok(RandomWalk { nodes: visited })
}
/// Performs a uniform random walk on a directed graph.
///
/// The walk begins at `start`; at every step the next node is drawn
/// uniformly from `graph.successors(..)` of the node the walk currently
/// stands on.
///
/// # Arguments
/// * `graph` - directed graph to walk on.
/// * `start` - node the walk begins at.
/// * `length` - upper bound on the number of nodes in the walk, `start`
///   included. Values `0` and `1` both yield a walk holding only `start`.
///
/// # Returns
/// The visited nodes in order. The walk is shorter than `length` when it
/// reaches a node for which `graph.successors(..)` is empty.
///
/// # Errors
/// * The error built by `GraphError::node_not_found` when
///   `graph.contains_node(start)` is false.
/// * Any error returned by `graph.successors(..)`.
pub fn simple_random_walk_digraph<E, Ix>(
    &mut self,
    graph: &DiGraph<N, E, Ix>,
    start: &N,
    length: usize,
) -> Result<RandomWalk<N>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    if !graph.contains_node(start) {
        return Err(GraphError::node_not_found("node"));
    }
    // `current` is the head of the walk. A node is moved into `visited` only
    // once the walk steps away from it, so each step costs a single clone.
    let mut visited = Vec::new();
    let mut current = start.clone();
    for _ in 1..length {
        let successors = graph.successors(&current)?;
        // `choose` yields `None` only for an empty slice, i.e. a dead end.
        let Some(next) = successors.choose(&mut self.rng).cloned() else {
            break;
        };
        visited.push(std::mem::replace(&mut current, next));
    }
    visited.push(current);
    Ok(RandomWalk { nodes: visited })
}
/// Performs a second-order biased (Node2Vec-style) walk on an undirected graph.
///
/// The first step is drawn uniformly from `graph.neighbors(start)`. Every
/// later step weighs each neighbor of the current node relative to the
/// previously visited node `prev`:
///
/// * `1 / p` when the candidate is `prev` itself (stepping back),
/// * `1` when `graph.has_edge(prev, candidate)` is true,
/// * `1 / q` otherwise,
///
/// and samples a candidate with probability proportional to its weight.
/// Edge weights are not consulted; the `Into<f64>` bound on `E` is kept only
/// so the signature stays unchanged.
///
/// # Arguments
/// * `graph` - graph to walk on.
/// * `start` - node the walk begins at.
/// * `length` - upper bound on the number of nodes in the walk, `start`
///   included. Values `0` and `1` both yield a walk holding only `start`.
/// * `p` - return parameter; must be greater than zero.
/// * `q` - in-out parameter; must be greater than zero.
///
/// # Returns
/// The visited nodes in order. The walk ends early at a node without
/// neighbors, or when the candidate weights sum to zero.
///
/// # Errors
/// * The error built by `GraphError::node_not_found` when
///   `graph.contains_node(start)` is false.
/// * `GraphError::InvalidParameter` when `p` or `q` is NaN or not positive.
/// * Any error returned by `graph.neighbors(..)`.
pub fn node2vec_walk<E, Ix>(
    &mut self,
    graph: &Graph<N, E, Ix>,
    start: &N,
    length: usize,
    p: f64,
    q: f64,
) -> Result<RandomWalk<N>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight + Into<f64>,
    Ix: petgraph::graph::IndexType,
{
    if !graph.contains_node(start) {
        return Err(GraphError::node_not_found("node"));
    }
    // NaN fails every ordered comparison, so `<= 0.0` alone would let it
    // through and poison all transition weights; reject it explicitly.
    if p.is_nan() || q.is_nan() || p <= 0.0 || q <= 0.0 {
        return Err(GraphError::InvalidParameter {
            param: "p/q".to_string(),
            value: format!("p={p}, q={q}"),
            expected: "p > 0 and q > 0".to_string(),
            context: "Node2Vec walk parameters".to_string(),
        });
    }
    let mut walk = vec![start.clone()];
    if length <= 1 {
        return Ok(RandomWalk { nodes: walk });
    }
    // The first step has no predecessor to bias against, so it is uniform.
    // `choose` yields `None` only for an empty slice, i.e. an isolated start.
    let first_neighbors = graph.neighbors(start)?;
    let Some(first) = first_neighbors.choose(&mut self.rng).cloned() else {
        return Ok(RandomWalk { nodes: walk });
    };
    walk.push(first);
    for _ in 2..length {
        // The walk holds at least two nodes here, so this pattern always matches.
        let [.., prev, current] = walk.as_slice() else {
            break;
        };
        let candidates = graph.neighbors(current)?;
        if candidates.is_empty() {
            break;
        }
        let weights: Vec<f64> = candidates
            .iter()
            .map(|candidate| {
                if candidate == prev {
                    1.0 / p
                } else if graph.has_edge(prev, candidate) {
                    1.0
                } else {
                    1.0 / q
                }
            })
            .collect();
        let total_weight: f64 = weights.iter().sum();
        if total_weight <= 0.0 {
            break;
        }
        // Roulette-wheel selection: subtract weights from a threshold in
        // `[0, total_weight)` until it is used up.
        let mut remaining = self.rng.random::<f64>() * total_weight;
        let selected_index = weights
            .iter()
            .position(|&weight| {
                remaining -= weight;
                remaining <= 0.0
            })
            // Round-off can leave a sliver of `remaining` after the last
            // subtraction; that sliver belongs to the tail of the
            // distribution, not to index 0.
            .or_else(|| weights.iter().rposition(|&weight| weight > 0.0))
            .unwrap_or(0);
        let next = candidates[selected_index].clone();
        walk.push(next);
    }
    Ok(RandomWalk { nodes: walk })
}
/// Performs a second-order biased (Node2Vec-style) walk on a directed graph.
///
/// The first step is drawn uniformly from `graph.successors(start)`. Every
/// later step weighs each successor of the current node relative to the
/// previously visited node `prev`:
///
/// * `1 / p` when the candidate is `prev` itself (stepping back),
/// * `1` when `graph.has_edge(prev, candidate)` is true,
/// * `1 / q` otherwise,
///
/// and samples a candidate with probability proportional to its weight.
/// Edge weights are not consulted; the `Into<f64>` bound on `E` is kept only
/// so the signature stays unchanged.
///
/// # Arguments
/// * `graph` - directed graph to walk on.
/// * `start` - node the walk begins at.
/// * `length` - upper bound on the number of nodes in the walk, `start`
///   included. Values `0` and `1` both yield a walk holding only `start`.
/// * `p` - return parameter; must be greater than zero.
/// * `q` - in-out parameter; must be greater than zero.
///
/// # Returns
/// The visited nodes in order. The walk ends early at a node without
/// successors, or when the candidate weights sum to zero.
///
/// # Errors
/// * The error built by `GraphError::node_not_found` when
///   `graph.contains_node(start)` is false.
/// * `GraphError::InvalidParameter` when `p` or `q` is NaN or not positive.
/// * Any error returned by `graph.successors(..)`.
pub fn node2vec_walk_digraph<E, Ix>(
    &mut self,
    graph: &DiGraph<N, E, Ix>,
    start: &N,
    length: usize,
    p: f64,
    q: f64,
) -> Result<RandomWalk<N>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight + Into<f64>,
    Ix: petgraph::graph::IndexType,
{
    if !graph.contains_node(start) {
        return Err(GraphError::node_not_found("node"));
    }
    // NaN fails every ordered comparison, so `<= 0.0` alone would let it
    // through and poison all transition weights; reject it explicitly.
    if p.is_nan() || q.is_nan() || p <= 0.0 || q <= 0.0 {
        return Err(GraphError::InvalidParameter {
            param: "p/q".to_string(),
            value: format!("p={p}, q={q}"),
            expected: "p > 0 and q > 0".to_string(),
            context: "Node2Vec walk parameters".to_string(),
        });
    }
    let mut walk = vec![start.clone()];
    if length <= 1 {
        return Ok(RandomWalk { nodes: walk });
    }
    // The first step has no predecessor to bias against, so it is uniform.
    // `choose` yields `None` only for an empty slice, i.e. a start node
    // without successors.
    let first_successors = graph.successors(start)?;
    let Some(first) = first_successors.choose(&mut self.rng).cloned() else {
        return Ok(RandomWalk { nodes: walk });
    };
    walk.push(first);
    for _ in 2..length {
        // The walk holds at least two nodes here, so this pattern always matches.
        let [.., prev, current] = walk.as_slice() else {
            break;
        };
        let candidates = graph.successors(current)?;
        if candidates.is_empty() {
            break;
        }
        let weights: Vec<f64> = candidates
            .iter()
            .map(|candidate| {
                if candidate == prev {
                    1.0 / p
                } else if graph.has_edge(prev, candidate) {
                    1.0
                } else {
                    1.0 / q
                }
            })
            .collect();
        let total_weight: f64 = weights.iter().sum();
        if total_weight <= 0.0 {
            break;
        }
        // Roulette-wheel selection: subtract weights from a threshold in
        // `[0, total_weight)` until it is used up.
        let mut remaining = self.rng.random::<f64>() * total_weight;
        let selected_index = weights
            .iter()
            .position(|&weight| {
                remaining -= weight;
                remaining <= 0.0
            })
            // Round-off can leave a sliver of `remaining` after the last
            // subtraction; that sliver belongs to the tail of the
            // distribution, not to index 0.
            .or_else(|| weights.iter().rposition(|&weight| weight > 0.0))
            .unwrap_or(0);
        let next = candidates[selected_index].clone();
        walk.push(next);
    }
    Ok(RandomWalk { nodes: walk })
}
/// Runs `simple_random_walk` from `start` `num_walks` times and gathers the
/// results in the order they were produced.
///
/// `walk_length` is forwarded unchanged as the `length` of every walk.
///
/// # Errors
/// Stops at the first walk that fails and returns that walk's error.
pub fn generate_walks<E, Ix>(
    &mut self,
    graph: &Graph<N, E, Ix>,
    start: &N,
    num_walks: usize,
    walk_length: usize,
) -> Result<Vec<RandomWalk<N>>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    // Collecting into `Result` short-circuits on the first `Err`.
    (0..num_walks)
        .map(|_| self.simple_random_walk(graph, start, walk_length))
        .collect()
}
/// Runs `simple_random_walk_digraph` from `start` `num_walks` times and
/// gathers the results in the order they were produced.
///
/// `walk_length` is forwarded unchanged as the `length` of every walk.
///
/// # Errors
/// Stops at the first walk that fails and returns that walk's error.
pub fn generate_walks_digraph<E, Ix>(
    &mut self,
    graph: &DiGraph<N, E, Ix>,
    start: &N,
    num_walks: usize,
    walk_length: usize,
) -> Result<Vec<RandomWalk<N>>>
where
    N: Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    // Collecting into `Result` short-circuits on the first `Err`.
    (0..num_walks)
        .map(|_| self.simple_random_walk_digraph(graph, start, walk_length))
        .collect()
}
}