use crate::base::{EdgeWeight, Graph, IndexType, Node};
use crate::error::{GraphError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::parallel_ops::*;
use scirs2_core::random::RngExt;
use std::collections::HashMap;
use std::hash::Hash;
/// Performs a uniform random walk with restart starting from `start`.
///
/// At each of the `steps` iterations the walker either jumps back to `start`
/// (with probability `restart_probability`) or moves to a uniformly chosen
/// neighbour of the current node. A node without neighbours also sends the
/// walker back to `start`.
///
/// The returned walk begins with `start`, and every completed step appends
/// exactly one node.
///
/// # Errors
///
/// Returns an error if `start` is not a node of `graph`.
#[allow(dead_code)]
pub fn random_walk<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    start: &N,
    steps: usize,
    restart_probability: f64,
) -> Result<Vec<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    use scirs2_core::random::{Rng, RngExt};

    if !graph.contains_node(start) {
        return Err(GraphError::node_not_found("node"));
    }
    let mut walk = vec![start.clone()];
    let mut current = start.clone();
    let mut rng = scirs2_core::random::rng();
    for _ in 0..steps {
        if rng.random::<f64>() < restart_probability {
            current = start.clone();
            walk.push(current.clone());
            continue;
        }
        // `current` is always a node reached through the graph, so this lookup is
        // expected to succeed; on the error path the step is skipped without pushing.
        if let Ok(neighbors) = graph.neighbors(&current) {
            current = if neighbors.is_empty() {
                // Dead end: restart from the source node.
                start.clone()
            } else {
                let idx = rng.random_range(0..neighbors.len());
                neighbors[idx].clone()
            };
            walk.push(current.clone());
        }
    }
    Ok(walk)
}
/// Builds the row-stochastic transition matrix of `graph`.
///
/// Returns the node ordering together with an `n x n` matrix whose entry
/// `[i, j]` is the probability of moving from `nodes[i]` to `nodes[j]`. That
/// probability is the edge weight divided by the total weight of `nodes[i]`'s
/// resolvable neighbour edges. Repeated neighbour entries are accumulated.
///
/// Some rows are filled uniformly with `1 / n` so that every row still sums to 1:
/// - rows whose total weight is not positive (no neighbours, zero weights);
/// - rows whose neighbour lookup fails.
///
/// Note: mapping a neighbour to its column is a linear scan over `nodes`.
///
/// # Errors
///
/// Returns [`GraphError::InvalidGraph`] if the graph has no nodes.
#[allow(dead_code)]
pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
where
    N: Node + Clone + std::fmt::Debug,
    E: EdgeWeight + Into<f64>,
    Ix: IndexType,
{
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let n = nodes.len();
    if n == 0 {
        return Err(GraphError::InvalidGraph("Empty graph".to_string()));
    }
    let uniform = 1.0 / n as f64;
    let mut matrix = Array2::<f64>::zeros((n, n));
    for (i, node) in nodes.iter().enumerate() {
        // (column index, weight) for every neighbour whose index and edge weight resolve.
        let neighbor_weights: Vec<(usize, f64)> = match graph.neighbors(node) {
            Ok(neighbors) => neighbors
                .into_iter()
                .filter_map(|neighbor| {
                    let j = nodes.iter().position(|candidate| candidate == &neighbor)?;
                    let weight = graph.edge_weight(node, &neighbor).ok()?;
                    Some((j, weight.into()))
                })
                .collect(),
            // Treat a failed lookup like a dangling node so the row stays stochastic.
            Err(_) => Vec::new(),
        };
        let total_weight: f64 = neighbor_weights.iter().map(|&(_, w)| w).sum();
        if total_weight > 0.0 {
            for (j, weight) in neighbor_weights {
                // Accumulate so repeated neighbour entries still produce a row summing to 1.
                matrix[[i, j]] += weight / total_weight;
            }
        } else {
            matrix.row_mut(i).fill(uniform);
        }
    }
    Ok((nodes, matrix))
}
/// Computes personalized PageRank scores with all teleportation mass on `source`.
///
/// Runs power iteration `pr <- damping * Pᵀ · pr + (1 - damping) · e_source`:
/// - `P` is the matrix returned by [`transition_matrix`];
/// - `e_source` is the indicator vector of `source`.
///
/// Iteration stops after `max_iter` rounds, or once the L1 distance between two
/// successive vectors drops below `tolerance`. The most recently computed vector
/// is returned.
///
/// # Errors
///
/// - Returns an error if `source` is not in the graph, which includes the empty graph.
/// - Propagates any error from [`transition_matrix`].
#[allow(dead_code)]
pub fn personalized_pagerank<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    source: &N,
    damping: f64,
    tolerance: f64,
    max_iter: usize,
) -> Result<HashMap<N, f64>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight + Into<f64>,
    Ix: IndexType,
{
    if !graph.contains_node(source) {
        return Err(GraphError::node_not_found("node"));
    }
    // Use the node order that belongs to the matrix itself, so row/column `i`
    // is guaranteed to correspond to `nodes[i]`.
    let (nodes, trans_matrix) = transition_matrix(graph)?;
    let n = nodes.len();
    let source_idx = nodes
        .iter()
        .position(|node| node == source)
        .ok_or_else(|| GraphError::node_not_found("node"))?;
    let mut personalization = Array1::<f64>::zeros(n);
    personalization[source_idx] = 1.0;
    let mut pr = personalization.clone();
    for _ in 0..max_iter {
        let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
        let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
        // Keep the newest iterate even when it triggers convergence.
        pr = new_pr;
        if diff < tolerance {
            break;
        }
    }
    Ok(nodes
        .into_iter()
        .enumerate()
        .map(|(i, node)| (node, pr[i]))
        .collect())
}
/// Runs one independent [`random_walk`] from every node in `starts`, in parallel.
///
/// The output preserves the order of `starts`: element `k` is the walk that began
/// at `starts[k]`. Each walk takes `walk_length` steps and restarts with
/// probability `restart_probability`.
///
/// # Errors
///
/// Returns an error if any individual walk fails, e.g. a start node that is not
/// in the graph.
#[allow(dead_code)]
pub fn parallel_random_walks<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    starts: &[N],
    walk_length: usize,
    restart_probability: f64,
) -> Result<Vec<Vec<N>>>
where
    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
    E: EdgeWeight + Send + Sync,
    Ix: IndexType + Send + Sync,
{
    let walk_from = |origin: &N| random_walk(graph, origin, walk_length, restart_probability);
    starts.par_iter().map(walk_from).collect()
}
/// Precomputed sampler for generating many weighted random walks over one graph.
///
/// Nodes receive dense indices at construction time. Each node's neighbour
/// weights are turned into an alias table, which is used to sample the next
/// neighbour during a walk.
pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
    /// Dense index -> node.
    idx_to_node: Vec<N>,
    /// Node -> dense index (inverse of `idx_to_node`).
    node_to_idx: HashMap<N, usize>,
    /// One alias table per node, indexed like `idx_to_node`.
    alias_tables: Vec<AliasTable>,
    /// Cumulative transition probabilities per node; stored but not read by the sampler.
    #[allow(dead_code)]
    transition_probs: Vec<Vec<f64>>,
}
/// Lookup table for the alias method of sampling from a discrete distribution.
///
/// Sampling picks a bucket `i` uniformly, then keeps `i` with probability
/// `prob[i]` and otherwise returns `alias[i]`.
#[derive(Clone, Debug)]
struct AliasTable {
    /// Probability of keeping bucket `i` rather than redirecting to `alias[i]`.
    prob: Vec<f64>,
    /// Redirect target used when bucket `i` is not kept.
    alias: Vec<usize>,
}
impl AliasTable {
    /// Builds an alias table for the discrete distribution proportional to `weights`.
    ///
    /// Uses Vose's alias method. First the weights are rescaled to mean 1. Then every
    /// under-full bucket (`< 1`) is paired with an over-full one, which donates the
    /// missing mass and becomes its alias. Buckets remaining once either work list
    /// runs out are numerically ~1; they are pinned to exactly 1 so floating-point
    /// round-off can never send them to a meaningless alias.
    ///
    /// Degenerate inputs:
    /// - empty `weights` produce an empty table;
    /// - weights without a positive, finite sum produce a table that samples
    ///   uniformly over the indices.
    fn new(weights: &[f64]) -> Self {
        let n = weights.len();
        // Identity aliases: a bucket that is never paired redirects to itself.
        let mut alias: Vec<usize> = (0..n).collect();
        if n == 0 {
            return AliasTable {
                prob: Vec::new(),
                alias,
            };
        }
        let sum: f64 = weights.iter().sum();
        if !(sum.is_finite() && sum > 0.0) {
            return AliasTable {
                prob: vec![1.0; n],
                alias,
            };
        }
        let scale = n as f64 / sum;
        let mut prob: Vec<f64> = weights.iter().map(|w| w * scale).collect();
        let (mut small, mut large): (Vec<usize>, Vec<usize>) =
            (0..n).partition(|&i| prob[i] < 1.0);
        // Peek before popping so no index is lost when one list runs out first.
        while let (Some(&s), Some(&l)) = (small.last(), large.last()) {
            small.pop();
            large.pop();
            alias[s] = l;
            prob[l] = (prob[l] + prob[s]) - 1.0;
            if prob[l] < 1.0 {
                small.push(l);
            } else {
                large.push(l);
            }
        }
        for i in small.into_iter().chain(large) {
            prob[i] = 1.0;
        }
        AliasTable { prob, alias }
    }

    /// Draws an index distributed according to the weights the table was built from.
    ///
    /// Returns 0 for an empty table.
    fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> usize {
        if self.prob.is_empty() {
            return 0;
        }
        let i = rng.random_range(0..self.prob.len());
        let coin_flip = rng.random::<f64>();
        // Strict `<`: assuming `random::<f64>()` draws from [0, 1), a bucket with
        // probability 0 is never kept, and a bucket with probability 1 always is.
        if coin_flip < self.prob[i] {
            i
        } else {
            self.alias[i]
        }
    }
}
impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
    /// Precomputes per-node sampling tables for weighted random walks on `graph`.
    ///
    /// For every node, the weights of the edges to its neighbours are turned into
    /// an alias table. The weights follow the order returned by `graph.neighbors`.
    /// A neighbour whose edge weight cannot be read gets weight 0, so the table stays
    /// aligned position-by-position with the neighbour list that `single_walk` indexes.
    ///
    /// # Errors
    ///
    /// Currently always succeeds; the `Result` is part of the public signature.
    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
    where
        E: EdgeWeight + Into<f64>,
        Ix: IndexType,
        N: std::fmt::Debug,
    {
        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
        let node_to_idx: HashMap<N, usize> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i))
            .collect();
        let mut transition_probs = Vec::with_capacity(nodes.len());
        let mut alias_tables = Vec::with_capacity(nodes.len());
        for node in &nodes {
            let neighbor_weights: Vec<f64> = match graph.neighbors(node) {
                Ok(neighbors) => neighbors
                    .iter()
                    .map(|neighbor| {
                        graph
                            .edge_weight(node, neighbor)
                            .map_or(0.0, |w| w.into())
                    })
                    .collect(),
                Err(_) => Vec::new(),
            };
            transition_probs.push(Self::cumulative_probabilities(&neighbor_weights));
            alias_tables.push(AliasTable::new(&neighbor_weights));
        }
        Ok(BatchRandomWalker {
            node_to_idx,
            idx_to_node: nodes,
            transition_probs,
            alias_tables,
        })
    }

    /// Running sums of the normalised `weights` (empty for no weights).
    ///
    /// Weights without a positive, finite total are treated as uniform, to avoid
    /// NaN entries.
    fn cumulative_probabilities(weights: &[f64]) -> Vec<f64> {
        if weights.is_empty() {
            return Vec::new();
        }
        let total: f64 = weights.iter().sum();
        let normalisable = total.is_finite() && total > 0.0;
        let uniform = 1.0 / weights.len() as f64;
        let mut running = 0.0;
        weights
            .iter()
            .map(|&w| {
                running += if normalisable { w / total } else { uniform };
                running
            })
            .collect()
    }

    /// Generates `num_walks_per_node` walks of `walk_length` steps from each start.
    ///
    /// Start nodes are processed in parallel, each with its own RNG. The result is
    /// grouped by start in the order of `starts`.
    ///
    /// # Errors
    ///
    /// Returns an error if any start node is unknown to the walker, instead of
    /// silently omitting its walks.
    pub fn generate_walks<E, Ix>(
        &self,
        graph: &Graph<N, E, Ix>,
        starts: &[N],
        walk_length: usize,
        num_walks_per_node: usize,
    ) -> Result<Vec<Vec<N>>>
    where
        E: EdgeWeight,
        Ix: IndexType + std::marker::Sync,
        N: Send + Sync + std::fmt::Debug,
    {
        let per_start: Vec<Vec<Vec<N>>> = starts
            .par_iter()
            .map(|start| {
                let mut rng = scirs2_core::random::rng();
                (0..num_walks_per_node)
                    .map(|_| self.single_walk(graph, start, walk_length, &mut rng))
                    .collect::<Result<Vec<_>>>()
            })
            .collect::<Result<Vec<_>>>()?;
        let mut all_walks =
            Vec::with_capacity(starts.len().saturating_mul(num_walks_per_node));
        all_walks.extend(per_start.into_iter().flatten());
        Ok(all_walks)
    }

    /// Performs one weighted walk of at most `walk_length` steps from `start`.
    ///
    /// The walk stops early in any of these cases:
    /// - the current node has no neighbours, or neighbour lookup fails;
    /// - the sampled index is out of range for the current neighbour list;
    /// - the walk reaches a node that was not present when the walker was built.
    ///
    /// # Errors
    ///
    /// Returns an error if `start` is unknown to the walker.
    fn single_walk<E, Ix>(
        &self,
        graph: &Graph<N, E, Ix>,
        start: &N,
        walk_length: usize,
        rng: &mut impl scirs2_core::random::Rng,
    ) -> Result<Vec<N>>
    where
        E: EdgeWeight,
        Ix: IndexType,
    {
        let mut current_idx = *self
            .node_to_idx
            .get(start)
            .ok_or_else(|| GraphError::node_not_found("node"))?;
        let mut walk = Vec::with_capacity(walk_length + 1);
        walk.push(start.clone());
        for _ in 0..walk_length {
            let neighbors = match graph.neighbors(&self.idx_to_node[current_idx]) {
                Ok(neighbors) if !neighbors.is_empty() => neighbors,
                _ => break,
            };
            let choice = self.alias_tables[current_idx].sample(rng);
            let next_node = match neighbors.get(choice) {
                Some(node) => node,
                None => break,
            };
            walk.push(next_node.clone());
            match self.node_to_idx.get(next_node) {
                Some(&next_idx) => current_idx = next_idx,
                // No precomputed table exists for this node, so the walk cannot continue.
                None => break,
            }
        }
        Ok(walk)
    }
}
/// Generates one node2vec (second-order, biased) random walk from `start`.
///
/// The first step picks a uniformly random neighbour of `start`. Each later step
/// is taken from `current`, having arrived from `previous`. Every neighbour `x`
/// is weighted by the edge weight `w(current, x)` (1.0 if it cannot be read),
/// multiplied by a bias:
/// - `1 / p` if `x == previous` (return parameter);
/// - `1` if `previous` and `x` are adjacent;
/// - `1 / q` otherwise (in-out parameter).
///
/// The walk holds at most `walk_length + 1` nodes: the start plus up to
/// `walk_length` steps. It stops early in any of these cases:
/// - the current node has no neighbours;
/// - neighbour lookup fails;
/// - the biased weights do not have a positive total.
///
/// A `start` without neighbours, or whose lookup fails, yields `[start]`.
#[allow(dead_code)]
pub fn node2vec_walk<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    start: &N,
    walk_length: usize,
    p: f64,
    q: f64,
    rng: &mut impl scirs2_core::random::Rng,
) -> Result<Vec<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight + Into<f64>,
    Ix: IndexType,
{
    let mut walk = vec![start.clone()];
    if walk_length == 0 {
        return Ok(walk);
    }
    let first = match graph.neighbors(start) {
        Ok(neighbors) if !neighbors.is_empty() => {
            let idx = rng.random_range(0..neighbors.len());
            neighbors[idx].clone()
        }
        _ => return Ok(walk),
    };
    let mut previous = start.clone();
    let mut current = first.clone();
    walk.push(first);
    for _ in 1..walk_length {
        let neighbors = match graph.neighbors(&current) {
            Ok(neighbors) if !neighbors.is_empty() => neighbors,
            _ => break,
        };
        let weights: Vec<f64> = neighbors
            .iter()
            .map(|neighbor| {
                let bias = if *neighbor == previous {
                    1.0 / p
                } else if graph.has_edge(&previous, neighbor) {
                    1.0
                } else {
                    1.0 / q
                };
                let edge_weight: f64 = graph
                    .edge_weight(&current, neighbor)
                    .map(|w| w.into())
                    .unwrap_or(1.0);
                bias * edge_weight
            })
            .collect();
        let total: f64 = weights.iter().sum();
        if total.is_nan() || total <= 0.0 {
            // No valid transition distribution from here.
            break;
        }
        // Inverse-CDF sampling. If round-off leaves `r` above the final cumulative
        // value, take the last neighbour rather than skipping the step (the
        // previous behaviour, which later caused an out-of-bounds index).
        let r = rng.random::<f64>();
        let mut cumulative = 0.0;
        let chosen = weights
            .iter()
            .position(|w| {
                cumulative += w / total;
                r <= cumulative
            })
            .unwrap_or(neighbors.len() - 1);
        let next = neighbors[chosen].clone();
        previous = std::mem::replace(&mut current, next.clone());
        walk.push(next);
    }
    Ok(walk)
}
/// Generates `num_walks` node2vec walks for every node in `starts`, in parallel.
///
/// Output order: walk `k` begins at `starts[k % starts.len()]`. The result is
/// therefore `num_walks` consecutive rounds, each with one walk per start in the
/// order of `starts`. Every walk is driven by a freshly obtained RNG.
///
/// # Errors
///
/// Returns an error if any individual walk fails.
#[allow(dead_code)]
pub fn parallel_node2vec_walks<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    starts: &[N],
    walk_length: usize,
    num_walks: usize,
    p: f64,
    q: f64,
) -> Result<Vec<Vec<N>>>
where
    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + Send + Sync,
    Ix: IndexType + Send + Sync,
{
    let start_count = starts.len();
    (0..start_count * num_walks)
        .into_par_iter()
        .map(|walk_id| {
            let origin = &starts[walk_id % start_count];
            node2vec_walk(graph, origin, walk_length, p, q, &mut scirs2_core::random::rng())
        })
        .collect()
}
/// Random walk with restart driven by a caller-supplied RNG.
///
/// Despite the name, this implementation is scalar; it does not use SIMD.
/// At each of `walk_length` steps the walker either:
/// - returns to `start` with probability `restart_prob`; or
/// - moves to a uniformly random neighbour of the current node, or back to
///   `start` if the current node has no neighbours.
///
/// The walk stops early if neighbour lookup fails, which presumably covers a
/// `start` that is not in the graph. Unlike [`random_walk`], an unknown `start`
/// is therefore not reported as an error.
#[allow(dead_code)]
pub fn simd_random_walk_with_restart<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    start: &N,
    walk_length: usize,
    restart_prob: f64,
    rng: &mut impl scirs2_core::random::Rng,
) -> Result<Vec<N>>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    let mut walk = Vec::with_capacity(walk_length + 1);
    walk.push(start.clone());
    let mut current = start.clone();
    for _ in 0..walk_length {
        if rng.random::<f64>() < restart_prob {
            current = start.clone();
        } else {
            match graph.neighbors(&current) {
                // Dead end: restart from the source node.
                Ok(neighbors) if neighbors.is_empty() => current = start.clone(),
                Ok(neighbors) => {
                    let idx = rng.random_range(0..neighbors.len());
                    current = neighbors[idx].clone();
                }
                Err(_) => break,
            }
        }
        walk.push(current.clone());
    }
    Ok(walk)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    /// A walk on a connected path starts at the source, has one node per step plus
    /// the start, and never leaves the graph.
    #[test]
    fn test_random_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(u, v, ())?;
        }
        let walk = random_walk(&graph, &"A", 10, 0.1)?;
        assert_eq!(walk.first(), Some(&"A"));
        assert_eq!(walk.len(), 11);
        assert!(walk.iter().all(|node| graph.contains_node(node)));
        Ok(())
    }

    /// Every row of the transition matrix of a triangle is a probability distribution.
    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(u, v, 1.0)?;
        }
        let (nodes, matrix) = transition_matrix(&graph)?;
        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);
        for row in matrix.outer_iter() {
            assert!((row.sum() - 1.0).abs() < 1e-6);
        }
        Ok(())
    }

    /// On a star centred at the source, the scores sum to 1 and the source ranks highest.
    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, 1.0)?;
        }
        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;
        assert_eq!(pagerank.len(), 4);
        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);
        let a_rank = pagerank[&"A"];
        assert!(pagerank
            .iter()
            .filter(|(node, _)| **node != "A")
            .all(|(_, &rank)| a_rank >= rank));
        Ok(())
    }
}