use petgraph::graph::{NodeIndex, UnGraph};
use scirs2_core::random::rand_prelude::StdRng;
use scirs2_core::random::{seeded_rng, CoreRandom, Random};
use std::collections::HashMap;
use crate::{GraphRAGError, GraphRAGResult, Triple};
/// Parameters controlling the biased random walks that node2vec samples.
#[derive(Debug, Clone)]
pub struct Node2VecWalkConfig {
    /// Number of walks started from every node of the graph.
    pub num_walks: usize,
    /// Maximum number of nodes in a single walk; a walk stops early when it
    /// reaches a node with no neighbors.
    pub walk_length: usize,
    /// Return parameter: stepping back to the previous node gets weight `1/p`.
    pub p: f64,
    /// In-out parameter: moving to a node not adjacent to the previous node gets
    /// weight `1/q` (neighbors of the previous node get weight 1).
    pub q: f64,
    /// Seed for the random number generator that drives walks and training.
    pub random_seed: u64,
}

impl Default for Node2VecWalkConfig {
    /// 10 walks of up to 80 nodes per node, unbiased (`p = q = 1`), seed 42.
    fn default() -> Self {
        Node2VecWalkConfig {
            random_seed: 42,
            q: 1.0,
            p: 1.0,
            walk_length: 80,
            num_walks: 10,
        }
    }
}

/// Full node2vec configuration: walk sampling plus skip-gram training.
#[derive(Debug, Clone)]
pub struct Node2VecConfig {
    /// Random-walk sampling parameters.
    pub walk: Node2VecWalkConfig,
    /// Length of every produced embedding vector.
    pub embedding_dim: usize,
    /// Number of walk positions on each side of a node treated as its context.
    pub window_size: usize,
    /// Number of training passes over the walk corpus.
    pub num_epochs: usize,
    /// Initial SGD learning rate (decayed during training).
    pub learning_rate: f64,
    /// Whether the final embeddings are scaled to unit L2 norm.
    pub normalize: bool,
}

impl Default for Node2VecConfig {
    /// Default walks, 128-dimensional vectors, window 5, 5 epochs, learning
    /// rate 0.025, normalized output.
    fn default() -> Self {
        let walk = Node2VecWalkConfig::default();
        Node2VecConfig {
            normalize: true,
            learning_rate: 0.025,
            num_epochs: 5,
            window_size: 5,
            embedding_dim: 128,
            walk,
        }
    }
}
/// Trained node2vec embeddings keyed by node label.
#[derive(Debug, Clone)]
pub struct Node2VecEmbeddings {
    /// One vector per node label.
    pub embeddings: HashMap<String, Vec<f64>>,
    /// Dimensionality of each embedding vector.
    pub dim: usize,
    /// Total number of nodes visited across all sampled random walks.
    pub total_walk_steps: usize,
}

impl Node2VecEmbeddings {
    /// Norm below which a vector is treated as zero for cosine computations.
    const ZERO_NORM: f64 = 1e-12;

    /// Returns the embedding of `node`, or `None` if the node is unknown.
    pub fn get(&self, node: &str) -> Option<&[f64]> {
        self.embeddings.get(node).map(Vec::as_slice)
    }

    /// Cosine similarity between the embeddings of `a` and `b`, clamped to
    /// `[-1, 1]` to absorb floating-point rounding.
    ///
    /// Returns `None` if either node is unknown or has a (near-)zero vector.
    pub fn cosine_similarity(&self, a: &str, b: &str) -> Option<f64> {
        let ea = self.embeddings.get(a)?;
        let eb = self.embeddings.get(b)?;
        let norm_a = Self::norm(ea);
        let norm_b = Self::norm(eb);
        if norm_a < Self::ZERO_NORM || norm_b < Self::ZERO_NORM {
            return None;
        }
        Some((Self::dot(ea, eb) / (norm_a * norm_b)).clamp(-1.0, 1.0))
    }

    /// Returns up to `k` other nodes ranked by descending cosine similarity to
    /// `query`.
    ///
    /// Candidates with a (near-)zero vector score `0.0`. NaN scores rank last,
    /// and equal scores are ordered by node label so the result is
    /// deterministic. Returns an empty list if `query` is unknown, has a
    /// (near-)zero vector, or `k == 0`.
    pub fn top_k_similar(&self, query: &str, k: usize) -> Vec<(String, f64)> {
        if k == 0 {
            return Vec::new();
        }
        let Some(eq) = self.embeddings.get(query) else {
            return Vec::new();
        };
        let norm_q = Self::norm(eq);
        if norm_q < Self::ZERO_NORM {
            return Vec::new();
        }
        let mut scored: Vec<(String, f64)> = self
            .embeddings
            .iter()
            .filter(|(node, _)| node.as_str() != query)
            .map(|(node, emb)| {
                let norm_e = Self::norm(emb);
                let sim = if norm_e < Self::ZERO_NORM {
                    0.0
                } else {
                    (Self::dot(emb, eq) / (norm_q * norm_e)).clamp(-1.0, 1.0)
                };
                (node.clone(), sim)
            })
            .collect();
        // `total_cmp` gives a true total order; `partial_cmp` + `Equal` for NaN
        // is inconsistent and newer std sort implementations may panic on it.
        scored.sort_by(|(name_a, sim_a), (name_b, sim_b)| {
            Self::rank_key(*sim_b)
                .total_cmp(&Self::rank_key(*sim_a))
                .then_with(|| name_a.cmp(name_b))
        });
        scored.truncate(k);
        scored
    }

    /// Dot product over the common prefix of `a` and `b`.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Euclidean (L2) norm of `v`.
    fn norm(v: &[f64]) -> f64 {
        Self::dot(v, v).sqrt()
    }

    /// Sort key that places NaN similarities below every real value.
    fn rank_key(sim: f64) -> f64 {
        if sim.is_nan() {
            f64::NEG_INFINITY
        } else {
            sim
        }
    }
}
/// Alias-method table for drawing indices from a fixed discrete distribution
/// with a constant amount of work per draw (see `AliasTable::sample`).
struct AliasTable {
/// Per-column probability of returning the column itself rather than its alias.
prob: Vec<f64>,
/// Per-column fallback index returned when the column is not kept.
alias: Vec<usize>,
}
impl AliasTable {
    /// Builds an alias table (Vose's method) that samples index `i` with
    /// probability `weights[i] / sum(weights)`.
    ///
    /// Returns `None` when `weights` is empty, contains a negative or
    /// non-finite value, or sums to zero or to a non-finite total, because no
    /// valid distribution exists in those cases.
    fn build(weights: &[f64]) -> Option<Self> {
        let n = weights.len();
        if n == 0 || weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
            return None;
        }
        let sum: f64 = weights.iter().sum();
        if !sum.is_finite() || sum <= 0.0 {
            return None;
        }
        // Scale so the mean column probability is exactly 1.
        let mut prob: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
        // Until paired, every column aliases itself.
        let mut alias: Vec<usize> = (0..n).collect();
        let (mut small, mut large): (Vec<usize>, Vec<usize>) =
            (0..n).partition(|&i| prob[i] < 1.0);
        while !small.is_empty() && !large.is_empty() {
            let l = small.pop().expect("small checked non-empty");
            let g = *large.last().expect("large checked non-empty");
            alias[l] = g;
            // `g` donates the mass that fills up column `l`.
            prob[g] -= 1.0 - prob[l];
            if prob[g] < 1.0 {
                large.pop();
                small.push(g);
            }
        }
        // Unpaired columns hold exactly 1.0 in exact arithmetic; pin them so
        // rounding error cannot leak probability into an unrelated alias.
        for i in small.into_iter().chain(large) {
            prob[i] = 1.0;
        }
        Some(Self { prob, alias })
    }

    /// Draws one index: picks a column uniformly, then keeps it with
    /// probability `prob[column]` or returns `alias[column]` otherwise.
    fn sample(&self, rng: &mut CoreRandom<StdRng>) -> usize {
        let n = self.prob.len();
        // `min` guards against `u * n` rounding up to exactly `n`.
        let column = ((rng.random_range(0.0..1.0) * n as f64) as usize).min(n - 1);
        if rng.random_range(0.0..1.0) < self.prob[column] {
            column
        } else {
            self.alias[column]
        }
    }
}
/// Second-order transition tables keyed by `(previous, current)` node: the
/// neighbors of `current` paired with an alias table over their `p`/`q`-biased
/// node2vec transition weights.
type EdgeAlias = HashMap<(NodeIndex, NodeIndex), (Vec<NodeIndex>, AliasTable)>;
/// First-step transition tables keyed by node: its neighbors paired with an
/// alias table over them (built with equal weights in `build_alias_tables`).
type NodeAlias = HashMap<NodeIndex, (Vec<NodeIndex>, AliasTable)>;
/// Embeds the entities of a triple set with node2vec: biased random walks over
/// the graph built from the triples, followed by skip-gram training on the walks.
pub struct Node2VecEmbedder {
/// Walk and training hyperparameters used by `embed`.
config: Node2VecConfig,
}
impl Node2VecEmbedder {
    /// Noise (negative) samples drawn for every positive (target, context) pair.
    const NEGATIVE_SAMPLES: usize = 5;
    /// Exponent applied to per-node walk frequencies to form the noise
    /// distribution (the unigram^0.75 heuristic from word2vec).
    const NOISE_EXPONENT: f64 = 0.75;
    /// Floor for the linearly decayed learning rate, as a fraction of the initial rate.
    const MIN_LR_FRACTION: f64 = 1e-4;

    /// Creates an embedder with the given configuration.
    pub fn new(config: Node2VecConfig) -> Self {
        Self { config }
    }

    /// Creates an embedder using `Node2VecConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(Node2VecConfig::default())
    }

    /// Learns one embedding per entity that appears as a subject or object in
    /// `triples`.
    ///
    /// Steps:
    /// 1. Turn the triples into an undirected, unweighted graph. Predicates are
    ///    ignored, and self-loops and duplicate edges are dropped.
    /// 2. Sample `p`/`q`-biased random walks from every node.
    /// 3. Train skip-gram with negative sampling on those walks.
    ///
    /// Nodes are visited in graph insertion order (not `HashMap` order), so the
    /// same triples and `random_seed` reproduce the same embeddings.
    ///
    /// # Errors
    /// Returns an error when the graph is non-empty and `p` or `q` is not a
    /// finite, strictly positive number.
    pub fn embed(&self, triples: &[Triple]) -> GraphRAGResult<Node2VecEmbeddings> {
        let graph = self.build_graph(triples);
        if graph.node_count() == 0 {
            return Ok(Node2VecEmbeddings {
                embeddings: HashMap::new(),
                dim: self.config.embedding_dim,
                total_walk_steps: 0,
            });
        }
        let mut rng = seeded_rng(self.config.walk.random_seed);
        let (node_alias, edge_alias) = self.build_alias_tables(&graph)?;
        let (walks, total_walk_steps) =
            self.simulate_walks(&graph, &node_alias, &edge_alias, &mut rng);
        let embeddings = self.train_skip_gram(&graph, &walks, &mut rng)?;
        Ok(Node2VecEmbeddings {
            embeddings,
            dim: self.config.embedding_dim,
            total_walk_steps,
        })
    }

    /// Builds an undirected graph with one node per distinct subject/object
    /// label, in first-appearance order, and one edge per connected pair.
    fn build_graph(&self, triples: &[Triple]) -> UnGraph<String, ()> {
        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
        let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
        for triple in triples {
            let s = *node_map
                .entry(triple.subject.clone())
                .or_insert_with(|| graph.add_node(triple.subject.clone()));
            let o = *node_map
                .entry(triple.object.clone())
                .or_insert_with(|| graph.add_node(triple.object.clone()));
            // Self-loops and parallel edges add nothing to unweighted walks.
            if s != o && graph.find_edge(s, o).is_none() {
                graph.add_edge(s, o, ());
            }
        }
        graph
    }

    /// Precomputes the transition tables used by the walks.
    ///
    /// - First steps use a uniform table per node.
    /// - Later steps use a table per directed edge `(prev, cur)`. From `cur`,
    ///   returning to `prev` weighs `1/p`, a neighbor of `prev` weighs 1, and
    ///   any other node weighs `1/q`.
    ///
    /// # Errors
    /// Returns an error if `p` or `q` is not finite and strictly positive.
    fn build_alias_tables(
        &self,
        graph: &UnGraph<String, ()>,
    ) -> GraphRAGResult<(NodeAlias, EdgeAlias)> {
        let p = self.config.walk.p;
        let q = self.config.walk.q;
        // Zero, negative or NaN parameters would turn every weight into NaN/inf.
        if !(p.is_finite() && p > 0.0 && q.is_finite() && q > 0.0) {
            return Err(GraphRAGError::InternalError(format!(
                "node2vec parameters p and q must be finite and positive (p = {p}, q = {q})"
            )));
        }
        let mut node_alias: NodeAlias = HashMap::new();
        for node in graph.node_indices() {
            let neighbors: Vec<NodeIndex> = graph.neighbors(node).collect();
            if neighbors.is_empty() {
                continue;
            }
            let weights = vec![1.0; neighbors.len()];
            if let Some(table) = AliasTable::build(&weights) {
                node_alias.insert(node, (neighbors, table));
            }
        }
        let mut edge_alias: EdgeAlias = HashMap::new();
        for edge in graph.edge_indices() {
            let (u, v) = graph
                .edge_endpoints(edge)
                .ok_or_else(|| GraphRAGError::InternalError("bad edge".to_string()))?;
            for (prev, cur) in [(u, v), (v, u)] {
                let neighbors: Vec<NodeIndex> = graph.neighbors(cur).collect();
                if neighbors.is_empty() {
                    continue;
                }
                let weights: Vec<f64> = neighbors
                    .iter()
                    .map(|&next| {
                        if next == prev {
                            1.0 / p
                        } else if graph.find_edge(prev, next).is_some() {
                            1.0
                        } else {
                            1.0 / q
                        }
                    })
                    .collect();
                if let Some(table) = AliasTable::build(&weights) {
                    edge_alias.insert((prev, cur), (neighbors, table));
                }
            }
        }
        Ok((node_alias, edge_alias))
    }

    /// Samples `num_walks` rounds of walks. Each round starts one walk from
    /// every node, in a freshly shuffled order.
    ///
    /// Returns the walks and the total number of nodes they contain.
    fn simulate_walks(
        &self,
        graph: &UnGraph<String, ()>,
        node_alias: &NodeAlias,
        edge_alias: &EdgeAlias,
        rng: &mut CoreRandom<StdRng>,
    ) -> (Vec<Vec<NodeIndex>>, usize) {
        let walk_length = self.config.walk.walk_length;
        let num_walks = self.config.walk.num_walks;
        if walk_length == 0 {
            return (Vec::new(), 0);
        }
        // Graph order is deterministic, unlike HashMap iteration, so the seeded
        // shuffle below reproduces the same start order on every run.
        let nodes: Vec<NodeIndex> = graph.node_indices().collect();
        let mut walks: Vec<Vec<NodeIndex>> =
            Vec::with_capacity(num_walks.saturating_mul(nodes.len()));
        let mut total_steps = 0usize;
        for _ in 0..num_walks {
            let mut order = nodes.clone();
            // Fisher-Yates shuffle.
            for i in (1..order.len()).rev() {
                let j = (rng.random_range(0.0..1.0) * (i + 1) as f64) as usize;
                order.swap(i, j.min(i));
            }
            for &start in &order {
                let walk = Self::single_walk(start, walk_length, node_alias, edge_alias, rng);
                total_steps += walk.len();
                walks.push(walk);
            }
        }
        (walks, total_steps)
    }

    /// Samples one walk of at most `walk_length` nodes starting at `start`.
    ///
    /// The walk stops early if no transition table exists for the current
    /// position (e.g. an isolated node).
    fn single_walk(
        start: NodeIndex,
        walk_length: usize,
        node_alias: &NodeAlias,
        edge_alias: &EdgeAlias,
        rng: &mut CoreRandom<StdRng>,
    ) -> Vec<NodeIndex> {
        let mut walk = Vec::with_capacity(walk_length);
        if walk_length == 0 {
            return walk;
        }
        walk.push(start);
        let mut prev: Option<NodeIndex> = None;
        let mut current = start;
        while walk.len() < walk_length {
            // The first step is uniform over neighbors; later steps use the
            // p/q-biased table of the edge just traversed.
            let entry = match prev {
                Some(p) => edge_alias.get(&(p, current)),
                None => node_alias.get(&current),
            };
            let Some(next) = entry
                .and_then(|(neighbors, table)| neighbors.get(table.sample(rng)).copied())
            else {
                break;
            };
            walk.push(next);
            prev = Some(current);
            current = next;
        }
        walk
    }

    /// Trains skip-gram with negative sampling on `walks`.
    ///
    /// - Input vectors start as small random values and output vectors as zeros.
    /// - The learning rate decays linearly over all epochs, down to a floor of
    ///   `MIN_LR_FRACTION` of the initial rate.
    /// - Each node's result is the mean of its input and output vectors,
    ///   L2-normalized when configured.
    fn train_skip_gram(
        &self,
        graph: &UnGraph<String, ()>,
        walks: &[Vec<NodeIndex>],
        rng: &mut CoreRandom<StdRng>,
    ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
        let node_count = graph.node_count();
        let dim = self.config.embedding_dim;
        let window = self.config.window_size;
        let lr_init = self.config.learning_rate;
        let epochs = self.config.num_epochs;

        // Vectors are indexed by `NodeIndex::index()`, drawn in graph order.
        let mut input: Vec<Vec<f64>> = Vec::with_capacity(node_count);
        for _ in 0..node_count {
            let vector: Vec<f64> = (0..dim)
                .map(|_| (rng.random_range(0.0..1.0) - 0.5) / dim as f64)
                .collect();
            input.push(vector);
        }
        let mut output: Vec<Vec<f64>> = vec![vec![0.0f64; dim]; node_count];

        let noise = Self::build_noise_table(walks, node_count);
        // Decay is computed over every update of every epoch, so later epochs
        // still train instead of sitting at the floor rate.
        let total_work = Self::count_pairs(walks, window).saturating_mul(epochs);
        let mut scratch = vec![0.0f64; dim];
        let mut processed = 0usize;
        if total_work > 0 {
            for _ in 0..epochs {
                for walk in walks {
                    for (i, &target) in walk.iter().enumerate() {
                        let span = Self::context_span(i, walk.len(), window);
                        let first = span.start;
                        for (offset, &context) in walk[span].iter().enumerate() {
                            if first + offset == i || context == target {
                                continue;
                            }
                            let progress = processed as f64 / total_work as f64;
                            let lr = lr_init * (1.0 - progress).max(Self::MIN_LR_FRACTION);
                            Self::sgd_update(
                                &mut input[target.index()],
                                &mut output,
                                context.index(),
                                noise.as_ref(),
                                lr,
                                &mut scratch,
                                rng,
                            );
                            processed += 1;
                        }
                    }
                }
            }
        }

        let mut embeddings = HashMap::with_capacity(node_count);
        for ((node, mut vector), context_vector) in graph.node_indices().zip(input).zip(output) {
            for (v, c) in vector.iter_mut().zip(&context_vector) {
                *v = (*v + c) / 2.0;
            }
            if self.config.normalize {
                let norm: f64 = vector.iter().map(|x| x * x).sum::<f64>().sqrt();
                if norm > 1e-12 {
                    vector.iter_mut().for_each(|x| *x /= norm);
                }
            }
            embeddings.insert(graph[node].clone(), vector);
        }
        Ok(embeddings)
    }

    /// Builds the negative-sampling distribution: each node weighted by its
    /// walk frequency raised to `NOISE_EXPONENT`. `None` if no node was visited.
    fn build_noise_table(walks: &[Vec<NodeIndex>], node_count: usize) -> Option<AliasTable> {
        let mut frequency = vec![0usize; node_count];
        for node in walks.iter().flatten() {
            frequency[node.index()] += 1;
        }
        let weights: Vec<f64> = frequency
            .iter()
            .map(|&count| (count as f64).powf(Self::NOISE_EXPONENT))
            .collect();
        AliasTable::build(&weights)
    }

    /// Counts the (target, context) updates performed in one epoch, using the
    /// same skip rules as the training loop.
    fn count_pairs(walks: &[Vec<NodeIndex>], window: usize) -> usize {
        walks
            .iter()
            .map(|walk| {
                walk.iter()
                    .enumerate()
                    .map(|(i, &target)| {
                        let span = Self::context_span(i, walk.len(), window);
                        let first = span.start;
                        walk[span]
                            .iter()
                            .enumerate()
                            .filter(|&(offset, &context)| first + offset != i && context != target)
                            .count()
                    })
                    .sum::<usize>()
            })
            .sum()
    }

    /// Walk positions within `window` of `center` (inclusive), clipped to `len`.
    fn context_span(center: usize, len: usize, window: usize) -> std::ops::Range<usize> {
        center.saturating_sub(window)..center.saturating_add(window).saturating_add(1).min(len)
    }

    /// One SGNS step for a positive pair plus its noise samples.
    ///
    /// - The positive `context` gets label 1; each noise draw gets label 0.
    ///   Noise draws equal to `context` are skipped.
    /// - Each output vector is updated immediately.
    /// - The target's input vector is updated once, from the gradient
    ///   accumulated in `scratch`.
    fn sgd_update(
        target: &mut [f64],
        outputs: &mut [Vec<f64>],
        context: usize,
        noise: Option<&AliasTable>,
        lr: f64,
        scratch: &mut [f64],
        rng: &mut CoreRandom<StdRng>,
    ) {
        scratch.fill(0.0);
        let draws = noise.map_or(0, |_| Self::NEGATIVE_SAMPLES);
        for draw in 0..=draws {
            let (out_idx, label) = match (draw, noise) {
                (0, _) => (context, 1.0),
                (_, Some(table)) => {
                    let sampled = table.sample(rng);
                    if sampled == context {
                        continue;
                    }
                    (sampled, 0.0)
                }
                (_, None) => break,
            };
            let out_vec = &mut outputs[out_idx];
            let score: f64 = target.iter().zip(out_vec.iter()).map(|(t, o)| t * o).sum();
            let step = (label - Self::sigmoid(score)) * lr;
            for ((acc, o), &t) in scratch.iter_mut().zip(out_vec.iter_mut()).zip(target.iter()) {
                // Accumulate with the pre-update output value, then update it.
                *acc += step * *o;
                *o += step * t;
            }
        }
        for (t, acc) in target.iter_mut().zip(scratch.iter()) {
            *t += acc;
        }
    }

    /// Logistic sigmoid.
    fn sigmoid(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// Triples forming an `n`-node cycle: node_i -- node_{(i+1) mod n}.
    fn ring_triples(n: usize) -> Vec<Triple> {
        (0..n)
            .map(|i| {
                Triple::new(
                    format!("node_{}", i),
                    "connects",
                    format!("node_{}", (i + 1) % n),
                )
            })
            .collect()
    }

    /// Triples forming the complete graph on `n` nodes.
    fn complete_triples(n: usize) -> Vec<Triple> {
        let mut ts = Vec::new();
        for i in 0..n {
            for j in i + 1..n {
                ts.push(Triple::new(format!("n{}", i), "edge", format!("n{}", j)));
            }
        }
        ts
    }

    /// Small, fast configuration shared by most tests.
    fn small_config() -> Node2VecConfig {
        Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 5,
                walk_length: 10,
                p: 1.0,
                q: 1.0,
                random_seed: 99,
            },
            embedding_dim: 16,
            window_size: 2,
            num_epochs: 3,
            learning_rate: 0.05,
            normalize: true,
        }
    }

    #[test]
    fn test_embed_produces_correct_number_of_embeddings() {
        let triples = ring_triples(6);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        assert_eq!(result.embeddings.len(), 6);
        assert_eq!(result.dim, 16);
    }

    #[test]
    fn test_embed_correct_dimension() {
        let triples = complete_triples(4);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        for emb in result.embeddings.values() {
            assert_eq!(emb.len(), 16);
        }
    }

    #[test]
    fn test_normalized_embeddings_have_unit_norm() {
        let triples = ring_triples(5);
        let config = Node2VecConfig {
            normalize: true,
            ..small_config()
        };
        let embedder = Node2VecEmbedder::new(config);
        let result = embedder.embed(&triples).expect("embed failed");
        for (node, emb) in &result.embeddings {
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "node {} has non-unit norm {:.6}",
                node,
                norm
            );
        }
    }

    #[test]
    fn test_cosine_similarity_in_range() {
        let triples = complete_triples(5);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        let nodes: Vec<String> = result.embeddings.keys().cloned().collect();
        if nodes.len() >= 2 {
            if let Some(sim) = result.cosine_similarity(&nodes[0], &nodes[1]) {
                assert!(
                    (-1.0 - 1e-9..=1.0 + 1e-9).contains(&sim),
                    "cosine similarity out of range: {}",
                    sim
                );
            }
        }
    }

    #[test]
    fn test_cosine_similarity_unknown_node_is_none() {
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&ring_triples(4)).expect("embed failed");
        assert!(result.cosine_similarity("node_0", "missing").is_none());
        assert!(result.top_k_similar("missing", 3).is_empty());
    }

    #[test]
    fn test_top_k_similar_returns_at_most_k() {
        let triples = ring_triples(8);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        let similar = result.top_k_similar("node_0", 3);
        assert!(similar.len() <= 3);
    }

    #[test]
    fn test_top_k_similar_sorted_and_excludes_query() {
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&ring_triples(8)).expect("embed failed");
        let similar = result.top_k_similar("node_0", 7);
        assert_eq!(similar.len(), 7);
        assert!(similar.iter().all(|(node, _)| node != "node_0"));
        assert!(
            similar.windows(2).all(|pair| pair[0].1 >= pair[1].1),
            "results not sorted by descending similarity: {:?}",
            similar
        );
    }

    #[test]
    fn test_empty_triples_returns_empty_embeddings() {
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&[]).expect("embed failed");
        assert!(result.embeddings.is_empty());
        assert_eq!(result.total_walk_steps, 0);
    }

    #[test]
    fn test_single_edge_yields_two_embeddings() {
        let triples = vec![Triple::new("a", "r", "b")];
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        assert_eq!(result.embeddings.len(), 2);
    }

    #[test]
    fn test_self_loop_only_node_is_embedded() {
        // A self-loop adds no edge, so "a" is a genuinely isolated node: every
        // walk is just ["a"] and no skip-gram pairs exist.
        let triples = vec![Triple::new("a", "r", "a")];
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        assert_eq!(result.embeddings.len(), 1);
        assert_eq!(result.total_walk_steps, 5);
        let emb = result.get("a").expect("isolated node must still be embedded");
        assert_eq!(emb.len(), 16);
    }

    #[test]
    fn test_dfs_and_bfs_biases_embed_every_node() {
        // Smoke test: both walk-bias regimes run and cover every node. It does
        // not measure how the bias changes the walks.
        let triples = ring_triples(10);
        let dfs_config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 3,
                walk_length: 20,
                p: 0.25,
                q: 0.25,
                random_seed: 1,
            },
            ..small_config()
        };
        let bfs_config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 3,
                walk_length: 20,
                p: 4.0,
                q: 4.0,
                random_seed: 1,
            },
            ..small_config()
        };
        let embedder_dfs = Node2VecEmbedder::new(dfs_config);
        let embedder_bfs = Node2VecEmbedder::new(bfs_config);
        let res_dfs = embedder_dfs.embed(&triples).expect("dfs embed failed");
        let res_bfs = embedder_bfs.embed(&triples).expect("bfs embed failed");
        assert_eq!(res_dfs.embeddings.len(), 10);
        assert_eq!(res_bfs.embeddings.len(), 10);
    }

    #[test]
    fn test_total_walk_steps_is_plausible() {
        let n = 5usize;
        let triples = ring_triples(n);
        let config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 2,
                walk_length: 10,
                ..Default::default()
            },
            ..small_config()
        };
        let embedder = Node2VecEmbedder::new(config);
        let result = embedder.embed(&triples).expect("embed failed");
        assert!(
            result.total_walk_steps >= n * 2,
            "expected ≥{} steps, got {}",
            n * 2,
            result.total_walk_steps
        );
        assert!(
            result.total_walk_steps <= n * 2 * 10 + n * 2,
            "unexpectedly many steps: {}",
            result.total_walk_steps
        );
    }

    #[test]
    fn test_alias_table_samples_valid_index() {
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(777);
        for _ in 0..100 {
            let idx = table.sample(&mut rng);
            assert!(idx < weights.len());
        }
    }

    #[test]
    fn test_alias_table_uniform_weights() {
        let weights = vec![1.0; 4];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(42);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[table.sample(&mut rng)] += 1;
        }
        for c in counts {
            assert!(c > 800 && c < 1200, "bucket count out of range: {}", c);
        }
    }

    #[test]
    fn test_alias_table_matches_weights() {
        let weights = [1.0, 2.0, 3.0, 4.0];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(2024);
        let draws = 10_000usize;
        let mut counts = [0usize; 4];
        for _ in 0..draws {
            counts[table.sample(&mut rng)] += 1;
        }
        for (i, &c) in counts.iter().enumerate() {
            let expected = draws as f64 * weights[i] / 10.0;
            assert!(
                (c as f64 - expected).abs() < 300.0,
                "bucket {} count {} far from expected {}",
                i,
                c,
                expected
            );
        }
    }

    #[test]
    fn test_alias_table_rejects_degenerate_weights() {
        assert!(AliasTable::build(&[]).is_none());
        assert!(AliasTable::build(&[0.0, 0.0]).is_none());
    }
}