use crate::{Error, KnowledgeGraph};
use rand::prelude::*;
use rand_xorshift::XorShiftRng;
use rayon::prelude::*;
use std::collections::HashMap;
use std::collections::HashSet;
/// Parameters controlling random-walk generation.
///
/// With `p == q == 1` walks are uniform over neighbors; otherwise every step
/// after the first is biased node2vec-style using `p` and `q`.
#[derive(Debug, Clone, Copy)]
pub struct RandomWalkConfig {
    /// Maximum number of nodes per walk, counting the start node. A walk always
    /// contains at least its start node.
    pub walk_length: usize,
    /// Number of rounds; each round starts one walk from every node.
    pub num_walks: usize,
    /// Return parameter: stepping back to the previous node has unnormalized
    /// weight `1 / p`.
    pub p: f32,
    /// In-out parameter: stepping to a node that is not a neighbor of the previous
    /// node has unnormalized weight `1 / q` (neighbors of the previous node weigh 1).
    pub q: f32,
    /// Base RNG seed; round `i` is seeded from `seed + i`.
    pub seed: u64,
}

impl Default for RandomWalkConfig {
    /// 80-node walks, 10 rounds, unbiased (`p = q = 1.0`), seed 42.
    fn default() -> Self {
        let (p, q) = (1.0, 1.0);
        Self {
            seed: 42,
            num_walks: 10,
            walk_length: 80,
            p,
            q,
        }
    }
}
/// Generates random walks over `kg` as sequences of node ID strings.
///
/// Convenience wrapper around [`Node2Vec::walk`]: the result holds
/// `config.num_walks` rounds, each containing one walk per graph node.
#[must_use]
pub fn generate_walks(kg: &KnowledgeGraph, config: RandomWalkConfig) -> Vec<Vec<String>> {
    Node2Vec::new(kg, config).walk()
}
/// Random walks encoded compactly as dense node indices.
///
/// Every `u32` inside `walks` is a position in `node_ids`.
#[derive(Debug, Clone)]
pub struct WalkCorpus {
    /// Node IDs, indexed by dense node index.
    pub node_ids: Vec<String>,
    /// Walks as sequences of indices into `node_ids`.
    pub walks: Vec<Vec<u32>>,
}
pub fn generate_walk_corpus(
kg: &KnowledgeGraph,
config: RandomWalkConfig,
) -> crate::Result<WalkCorpus> {
let graph = kg.as_petgraph();
let mut node_indices: Vec<_> = graph.node_indices().collect();
node_indices.sort_by_key(|n| n.index());
let node_ids: Vec<String> = node_indices
.iter()
.map(|&n| graph[n].id.as_str().to_owned())
.collect();
let id_to_dense: HashMap<String, u32> = node_ids
.iter()
.enumerate()
.map(|(i, id)| (id.clone(), i as u32))
.collect();
let walks_ids = generate_walks(kg, config);
let mut walks = Vec::with_capacity(walks_ids.len());
for w in walks_ids {
let mut dense_walk = Vec::with_capacity(w.len());
for id in w {
let idx = *id_to_dense.get(&id).ok_or(Error::EntityNotFound(id))?;
dense_walk.push(idx);
}
walks.push(dense_walk);
}
Ok(WalkCorpus { node_ids, walks })
}
pub struct Node2Vec<'a> {
kg: &'a KnowledgeGraph,
config: RandomWalkConfig,
}
impl<'a> Node2Vec<'a> {
#[must_use]
pub const fn new(kg: &'a KnowledgeGraph, config: RandomWalkConfig) -> Self {
Self { kg, config }
}
#[must_use]
pub fn walk(&self) -> Vec<Vec<String>> {
let node_indices: Vec<_> = self.kg.as_petgraph().node_indices().collect();
let is_unbiased = (self.config.p - 1.0).abs() < f32::EPSILON
&& (self.config.q - 1.0).abs() < f32::EPSILON;
(0..self.config.num_walks)
.into_par_iter()
.flat_map(|iter_idx| {
let mut rng = XorShiftRng::seed_from_u64(self.config.seed + iter_idx as u64);
let mut walks = Vec::with_capacity(node_indices.len());
let mut shuffled = node_indices.clone();
shuffled.shuffle(&mut rng);
for &start in &shuffled {
let walk = if is_unbiased {
self.unbiased_walk(start, &mut rng)
} else {
self.biased_walk(start, &mut rng)
};
walks.push(walk);
}
walks
})
.collect()
}
fn unbiased_walk<R: Rng>(&self, start: petgraph::graph::NodeIndex, rng: &mut R) -> Vec<String> {
let graph = self.kg.as_petgraph();
let mut walk = Vec::with_capacity(self.config.walk_length);
walk.push(graph[start].id.as_str().to_owned());
let mut curr = start;
for _ in 1..self.config.walk_length {
let neighbors: Vec<_> = graph.neighbors(curr).collect();
if neighbors.is_empty() {
break;
}
curr = *neighbors
.choose(rng)
.expect("internal: neighbors non-empty after check");
walk.push(graph[curr].id.as_str().to_owned());
}
walk
}
fn biased_walk<R: Rng>(&self, start: petgraph::graph::NodeIndex, rng: &mut R) -> Vec<String> {
let graph = self.kg.as_petgraph();
let mut walk = Vec::with_capacity(self.config.walk_length);
walk.push(graph[start].id.as_str().to_owned());
let mut curr = start;
let mut prev: Option<petgraph::graph::NodeIndex> = None;
let mut prev_neighbors: HashSet<petgraph::graph::NodeIndex> = HashSet::new();
for _ in 1..self.config.walk_length {
let neighbors: Vec<_> = graph.neighbors(curr).collect();
if neighbors.is_empty() {
break;
}
let next = if let Some(prev_node) = prev {
self.sample_biased_rejection(rng, prev_node, &prev_neighbors, &neighbors)
} else {
*neighbors
.choose(rng)
.expect("internal: neighbors non-empty after check")
};
walk.push(graph[next].id.as_str().to_owned());
prev = Some(curr);
prev_neighbors.clear();
prev_neighbors.extend(graph.neighbors(curr));
curr = next;
}
walk
}
fn sample_biased_rejection<R: Rng>(
&self,
rng: &mut R,
prev_node: petgraph::graph::NodeIndex,
prev_neighbors: &HashSet<petgraph::graph::NodeIndex>,
neighbors: &[petgraph::graph::NodeIndex],
) -> petgraph::graph::NodeIndex {
let p = f64::from(self.config.p);
let q = f64::from(self.config.q);
let max_prob = (1.0 / p).max(1.0).max(1.0 / q);
loop {
let candidate = *neighbors
.choose(rng)
.expect("internal: neighbors non-empty (caller guarantees)");
let r: f64 = rng.random();
let unnorm_prob = if candidate == prev_node {
1.0 / p } else if prev_neighbors.contains(&candidate) {
1.0 } else {
1.0 / q };
if r < unnorm_prob / max_prob {
return candidate;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// Builds a graph with a `rel` edge in both directions for each pair. Every
    /// node then has at least one neighbor, so walks never stop early.
    fn bidirectional(pairs: &[(&'static str, &'static str)]) -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        for &(a, b) in pairs {
            kg.add_triple(Triple::new(a, "rel", b));
            kg.add_triple(Triple::new(b, "rel", a));
        }
        kg
    }

    #[test]
    fn test_random_walk_uniform() {
        let kg = bidirectional(&[("A", "B"), ("B", "C")]);
        let config = RandomWalkConfig {
            walk_length: 10,
            num_walks: 2,
            p: 1.0,
            q: 1.0,
            seed: 42,
        };
        let walks = generate_walks(&kg, config);
        // 3 nodes x 2 rounds.
        assert_eq!(walks.len(), 3 * 2);
        // Every node has a neighbor, so every walk reaches full length.
        for walk in &walks {
            assert_eq!(walk.len(), 10);
        }
    }

    #[test]
    fn test_random_walk_biased() {
        let kg = bidirectional(&[("A", "B"), ("B", "C"), ("C", "D")]);
        let config = RandomWalkConfig {
            walk_length: 20,
            num_walks: 5,
            p: 0.5,
            q: 2.0,
            seed: 123,
        };
        let walks = generate_walks(&kg, config);
        assert_eq!(walks.len(), 4 * 5);
        for walk in &walks {
            assert_eq!(walk.len(), 20);
        }
    }

    #[test]
    fn test_each_round_starts_from_every_node_once() {
        let kg = bidirectional(&[("A", "B"), ("B", "C")]);
        let config = RandomWalkConfig {
            walk_length: 5,
            num_walks: 4,
            ..Default::default()
        };
        let walks = generate_walks(&kg, config);
        assert_eq!(walks.len(), 3 * 4);

        let sorted_starts = |round: &[Vec<String>]| {
            let mut starts: Vec<String> = round.iter().map(|w| w[0].clone()).collect();
            starts.sort_unstable();
            starts
        };
        let first_round = sorted_starts(&walks[..3]);
        let mut distinct = first_round.clone();
        distinct.dedup();
        assert_eq!(distinct.len(), 3);
        for round in walks.chunks(3) {
            assert_eq!(sorted_starts(round), first_round);
        }
    }

    #[test]
    fn test_random_walk_reproducible() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "C"));
        let config = RandomWalkConfig {
            walk_length: 10,
            num_walks: 3,
            seed: 999,
            ..Default::default()
        };
        let walks1 = generate_walks(&kg, config);
        let walks2 = generate_walks(&kg, config);
        assert_eq!(walks1, walks2);
    }

    #[test]
    fn test_biased_walk_reproducible() {
        let kg = bidirectional(&[("A", "B"), ("B", "C"), ("C", "D"), ("B", "D")]);
        let config = RandomWalkConfig {
            walk_length: 15,
            num_walks: 3,
            p: 0.25,
            q: 4.0,
            seed: 7,
        };
        assert_eq!(generate_walks(&kg, config), generate_walks(&kg, config));
    }

    #[test]
    fn test_walk_corpus_matches_string_walks() {
        let kg = bidirectional(&[("A", "B"), ("B", "C")]);
        let config = RandomWalkConfig {
            walk_length: 6,
            num_walks: 2,
            p: 0.5,
            q: 2.0,
            seed: 5,
        };
        let Ok(corpus) = generate_walk_corpus(&kg, config) else {
            panic!("walk corpus generation failed");
        };
        assert_eq!(corpus.node_ids.len(), 3);
        let decoded: Vec<Vec<String>> = corpus
            .walks
            .iter()
            .map(|walk| {
                walk.iter()
                    .map(|&i| corpus.node_ids[i as usize].clone())
                    .collect()
            })
            .collect();
        assert_eq!(decoded, generate_walks(&kg, config));
    }
}