// oxirs_graphrag/embeddings/node2vec.rs

1//! Node2Vec: Scalable Feature Learning for Networks
2//!
3//! Reference: Grover & Leskovec (KDD 2016) — <https://arxiv.org/abs/1607.00653>
4//!
5//! # Algorithm Overview
6//!
7//! Node2Vec generates node embeddings via **biased second-order random walks**.
8//! Two hyperparameters control the walk bias:
9//!
//! * **p** (return parameter) – controls the likelihood of immediately
//!   revisiting the previous node: the return transition gets weight 1/p,
//!   so high p discourages backtracking and pushes the walk onward.
//! * **q** (in-out parameter) – controls outward exploration: transitions to
//!   nodes not adjacent to the previous node get weight 1/q, so low q →
//!   DFS-like walk (moves outward) and high q → BFS-like walk (stays near
//!   the source).
14//!
15//! The algorithm:
16//! 1. Pre-compute per-edge alias tables for O(1) biased sampling.
17//! 2. Simulate `num_walks` second-order random walks of length `walk_length`
18//!    from every node.
19//! 3. Train a simplified skip-gram model on the walk corpus to produce
20//!    `embedding_dim`-dimensional embeddings.
21//!
22//! # Implementation Notes
23//!
24//! * No external ML or linear-algebra crates are required.
25//! * Random number generation uses `scirs2_core::random` per project policy.
26//! * All unsafe code is avoided; heavy loops use Rust iterators.
27
28use petgraph::graph::{NodeIndex, UnGraph};
29use scirs2_core::random::rand_prelude::StdRng;
30use scirs2_core::random::{seeded_rng, CoreRandom, Random};
31use std::collections::HashMap;
32
33use crate::{GraphRAGError, GraphRAGResult, Triple};
34
35// ─── Configuration ───────────────────────────────────────────────────────────
36
/// Configuration for the random-walk phase of Node2Vec.
///
/// The bias parameters follow Grover & Leskovec (2016): the transition
/// weight back to the previous node is `1/p`, and the weight toward nodes
/// not adjacent to the previous node is `1/q` (see `build_alias_tables`).
#[derive(Debug, Clone)]
pub struct Node2VecWalkConfig {
    /// Number of random walks starting from each node.
    pub num_walks: usize,
    /// Length (number of steps) of each random walk.
    pub walk_length: usize,
    /// Return parameter *p*: values above 1 discourage immediately revisiting
    /// the previous node (return weight is `1/p`). Must be positive.
    pub p: f64,
    /// In-out parameter *q*: values below 1 bias walks outward (DFS-like),
    /// values above 1 keep walks near the source (BFS-like). Must be positive.
    pub q: f64,
    /// Random seed for reproducibility of walks and training.
    pub random_seed: u64,
}
51
52impl Default for Node2VecWalkConfig {
53    fn default() -> Self {
54        Self {
55            num_walks: 10,
56            walk_length: 80,
57            p: 1.0,
58            q: 1.0,
59            random_seed: 42,
60        }
61    }
62}
63
/// Full Node2Vec configuration (walks + skip-gram training).
#[derive(Debug, Clone)]
pub struct Node2VecConfig {
    /// Walk phase configuration.
    pub walk: Node2VecWalkConfig,
    /// Embedding dimension.
    pub embedding_dim: usize,
    /// Skip-gram context window radius (applied on both sides of the target).
    pub window_size: usize,
    /// Number of training epochs over the full walk corpus.
    pub num_epochs: usize,
    /// Initial learning rate for skip-gram SGD (decayed linearly per epoch).
    pub learning_rate: f64,
    /// Whether to L2-normalize final embeddings to unit length.
    pub normalize: bool,
}
80
81impl Default for Node2VecConfig {
82    fn default() -> Self {
83        Self {
84            walk: Node2VecWalkConfig::default(),
85            embedding_dim: 128,
86            window_size: 5,
87            num_epochs: 5,
88            learning_rate: 0.025,
89            normalize: true,
90        }
91    }
92}
93
94// ─── Output ──────────────────────────────────────────────────────────────────
95
/// Computed Node2Vec embeddings.
#[derive(Debug, Clone)]
pub struct Node2VecEmbeddings {
    /// Map from node URI/label → embedding vector.
    pub embeddings: HashMap<String, Vec<f64>>,
    /// Dimension of each embedding vector.
    pub dim: usize,
    /// Total number of random-walk steps generated.
    pub total_walk_steps: usize,
}

impl Node2VecEmbeddings {
    /// Return the embedding for a specific node, if present.
    pub fn get(&self, node: &str) -> Option<&[f64]> {
        self.embeddings.get(node).map(|v| v.as_slice())
    }

    /// Cosine similarity between two node embeddings.
    ///
    /// Returns `None` if either node is missing or either embedding has
    /// (near-)zero norm, since the similarity is undefined in that case.
    pub fn cosine_similarity(&self, a: &str, b: &str) -> Option<f64> {
        let ea = self.embeddings.get(a)?;
        let eb = self.embeddings.get(b)?;

        let dot: f64 = ea.iter().zip(eb.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f64 = ea.iter().map(|x| x * x).sum::<f64>().sqrt();
        let norm_b: f64 = eb.iter().map(|x| x * x).sum::<f64>().sqrt();

        if norm_a < 1e-12 || norm_b < 1e-12 {
            None
        } else {
            Some(dot / (norm_a * norm_b))
        }
    }

    /// Return the `k` most similar nodes to `query` by cosine similarity,
    /// in descending order. The query node itself is excluded; candidates
    /// with (near-)zero norm score 0.0. Returns an empty vector when the
    /// query is missing or has zero norm.
    pub fn top_k_similar(&self, query: &str, k: usize) -> Vec<(String, f64)> {
        let Some(eq) = self.embeddings.get(query) else {
            return vec![];
        };

        let norm_q: f64 = eq.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm_q < 1e-12 {
            return vec![];
        }

        let mut scored: Vec<(String, f64)> = self
            .embeddings
            .iter()
            .filter(|(node, _)| node.as_str() != query)
            .map(|(node, emb)| {
                let dot: f64 = emb.iter().zip(eq.iter()).map(|(x, y)| x * y).sum();
                let norm_e: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
                let sim = if norm_e < 1e-12 {
                    0.0
                } else {
                    dot / (norm_q * norm_e)
                };
                (node.clone(), sim)
            })
            .collect();

        // BUG FIX: `partial_cmp(..).unwrap_or(Equal)` is not a total order
        // when NaN scores appear, and `sort_by` may panic on an inconsistent
        // comparator. `f64::total_cmp` is a total order, so the sort is
        // always safe; `sort_unstable_by` avoids the stable sort's buffer.
        scored.sort_unstable_by(|(_, a), (_, b)| b.total_cmp(a));
        scored.truncate(k);
        scored
    }
}
163
164// ─── Alias sampling helper ────────────────────────────────────────────────────
165
166/// Vose's alias method for O(1) sampling from a discrete distribution.
167///
168/// Reference: <https://www.keithschwarz.com/darts-dice-coins/>
169struct AliasTable {
170    prob: Vec<f64>,
171    alias: Vec<usize>,
172}
173
174impl AliasTable {
175    /// Build a table from a slice of unnormalized weights.
176    fn build(weights: &[f64]) -> Option<Self> {
177        let n = weights.len();
178        if n == 0 {
179            return None;
180        }
181
182        let sum: f64 = weights.iter().sum();
183        if sum <= 0.0 {
184            return None;
185        }
186
187        // Normalize.
188        let prob_norm: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
189
190        let mut small: Vec<usize> = Vec::with_capacity(n);
191        let mut large: Vec<usize> = Vec::with_capacity(n);
192        let mut prob = prob_norm.clone();
193        let mut alias = vec![0usize; n];
194
195        for (i, &p) in prob_norm.iter().enumerate() {
196            if p < 1.0 {
197                small.push(i);
198            } else {
199                large.push(i);
200            }
201        }
202
203        while !small.is_empty() && !large.is_empty() {
204            let l = small.pop().expect("checked non-empty");
205            let g = large.last().copied().expect("checked non-empty");
206
207            alias[l] = g;
208            prob[g] -= 1.0 - prob[l];
209
210            if prob[g] < 1.0 {
211                large.pop();
212                small.push(g);
213            }
214        }
215
216        Some(Self { prob, alias })
217    }
218
219    /// Draw one sample index in O(1).
220    fn sample(&self, rng: &mut CoreRandom<StdRng>) -> usize {
221        let n = self.prob.len();
222        // Pick a column uniformly at random.
223        let i = (rng.random_range(0.0..1.0) * n as f64) as usize;
224        let i = i.min(n - 1);
225        // Flip a biased coin.
226        if rng.random_range(0.0..1.0) < self.prob[i] {
227            i
228        } else {
229            self.alias[i]
230        }
231    }
232}
233
234// ─── Pre-computed transition tables ──────────────────────────────────────────
235
/// Alias table keyed by (previous_node, current_node) → table for the next
/// step of a second-order walk, plus the neighbor list the table indexes into.
type EdgeAlias = HashMap<(NodeIndex, NodeIndex), (Vec<NodeIndex>, AliasTable)>;
/// Alias table for the *first* step from each node (no previous node yet),
/// plus the neighbor list the table indexes into.
type NodeAlias = HashMap<NodeIndex, (Vec<NodeIndex>, AliasTable)>;
240
241// ─── Main embedder ────────────────────────────────────────────────────────────
242
/// Node2Vec graph embedding generator.
///
/// Pipeline: build an undirected graph from triples → pre-compute alias
/// tables → simulate biased second-order walks → train a simplified
/// skip-gram model on the walk corpus.
///
/// # Example
///
/// ```rust,ignore
/// use oxirs_graphrag::{Triple, embeddings::node2vec::{Node2VecConfig, Node2VecEmbedder}};
///
/// let triples = vec![
///     Triple::new("a", "knows", "b"),
///     Triple::new("b", "knows", "c"),
/// ];
/// let embedder = Node2VecEmbedder::new(Node2VecConfig::default());
/// let embs = embedder.embed(&triples)?;
/// println!("a→b similarity: {:?}", embs.cosine_similarity("a", "b"));
/// ```
pub struct Node2VecEmbedder {
    // Full configuration (walk parameters + skip-gram hyperparameters).
    config: Node2VecConfig,
}
261
262impl Node2VecEmbedder {
263    /// Create a new embedder with the given configuration.
264    pub fn new(config: Node2VecConfig) -> Self {
265        Self { config }
266    }
267
268    /// Create a new embedder with default configuration.
269    pub fn with_defaults() -> Self {
270        Self::new(Node2VecConfig::default())
271    }
272
    /// Build embeddings for all nodes reachable via `triples`.
    ///
    /// Runs the full pipeline: graph construction → alias-table
    /// pre-computation → biased random walks → skip-gram training.
    ///
    /// # Errors
    ///
    /// Propagates errors from alias-table construction and training.
    pub fn embed(&self, triples: &[Triple]) -> GraphRAGResult<Node2VecEmbeddings> {
        // 1. Build adjacency graph.
        let (graph, node_map) = self.build_graph(triples);

        // Empty input → empty result (no walks, no training).
        if graph.node_count() == 0 {
            return Ok(Node2VecEmbeddings {
                embeddings: HashMap::new(),
                dim: self.config.embedding_dim,
                total_walk_steps: 0,
            });
        }

        // A single seeded RNG drives both walk simulation and training,
        // making the whole pipeline reproducible for a given seed.
        let mut rng = seeded_rng(self.config.walk.random_seed);

        // 2. Pre-compute alias tables for O(1) per-step sampling.
        let (node_alias, edge_alias) = self.build_alias_tables(&graph)?;

        // 3. Simulate random walks.
        let (walks, total_steps) =
            self.simulate_walks(&graph, &node_map, &node_alias, &edge_alias, &mut rng);

        // 4. Train skip-gram.
        let embeddings = self.train_skip_gram(&walks, &node_map, &mut rng)?;

        Ok(Node2VecEmbeddings {
            embeddings,
            dim: self.config.embedding_dim,
            total_walk_steps: total_steps,
        })
    }
304
305    // ─── Graph construction ───────────────────────────────────────────────────
306
307    fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
308        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
309        let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
310
311        for triple in triples {
312            let s = *node_map
313                .entry(triple.subject.clone())
314                .or_insert_with(|| graph.add_node(triple.subject.clone()));
315            let o = *node_map
316                .entry(triple.object.clone())
317                .or_insert_with(|| graph.add_node(triple.object.clone()));
318            if s != o && graph.find_edge(s, o).is_none() {
319                graph.add_edge(s, o, ());
320            }
321        }
322
323        (graph, node_map)
324    }
325
326    // ─── Alias table construction ─────────────────────────────────────────────
327
328    /// Build alias tables for:
329    /// * first-step transitions from each node (uniform over neighbors), and
330    /// * second-order transitions given (prev, cur) pairs.
331    fn build_alias_tables(
332        &self,
333        graph: &UnGraph<String, ()>,
334    ) -> GraphRAGResult<(NodeAlias, EdgeAlias)> {
335        let p = self.config.walk.p;
336        let q = self.config.walk.q;
337
338        // Per-node table (first step).
339        let mut node_alias: NodeAlias = HashMap::new();
340        for node in graph.node_indices() {
341            let neighbors: Vec<NodeIndex> = graph.neighbors(node).collect();
342            if neighbors.is_empty() {
343                continue;
344            }
345            let weights: Vec<f64> = vec![1.0; neighbors.len()];
346            if let Some(table) = AliasTable::build(&weights) {
347                node_alias.insert(node, (neighbors, table));
348            }
349        }
350
351        // Per-edge table (second-order Markov transitions).
352        let mut edge_alias: EdgeAlias = HashMap::new();
353        for edge in graph.edge_indices() {
354            let (u, v) = graph
355                .edge_endpoints(edge)
356                .ok_or_else(|| GraphRAGError::InternalError("bad edge".to_string()))?;
357
358            // Process both directed orientations of the undirected edge.
359            for (prev, cur) in [(u, v), (v, u)] {
360                let neighbors: Vec<NodeIndex> = graph.neighbors(cur).collect();
361                if neighbors.is_empty() {
362                    continue;
363                }
364                let weights: Vec<f64> = neighbors
365                    .iter()
366                    .map(|&next| {
367                        if next == prev {
368                            // Return to previous node → weight = 1/p.
369                            1.0 / p
370                        } else if graph.find_edge(prev, next).is_some() {
371                            // Neighbor of prev → stay in neighborhood → weight 1.
372                            1.0
373                        } else {
374                            // Further away → weight = 1/q.
375                            1.0 / q
376                        }
377                    })
378                    .collect();
379
380                if let Some(table) = AliasTable::build(&weights) {
381                    edge_alias.insert((prev, cur), (neighbors, table));
382                }
383            }
384        }
385
386        Ok((node_alias, edge_alias))
387    }
388
389    // ─── Random walk simulation ───────────────────────────────────────────────
390
    /// Simulate `num_walks` rounds of biased random walks — one walk per node
    /// per round — returning the walk corpus and the total step count.
    ///
    /// NOTE: the shuffle and the walks share the single pipeline RNG, so any
    /// change to the call order here changes results for a given seed.
    fn simulate_walks(
        &self,
        graph: &UnGraph<String, ()>,
        node_map: &HashMap<String, NodeIndex>,
        node_alias: &NodeAlias,
        edge_alias: &EdgeAlias,
        rng: &mut CoreRandom<StdRng>,
    ) -> (Vec<Vec<String>>, usize) {
        let walk_length = self.config.walk.walk_length;
        let num_walks = self.config.walk.num_walks;

        let node_indices: Vec<NodeIndex> = node_map.values().copied().collect();
        let mut walks: Vec<Vec<String>> = Vec::with_capacity(num_walks * node_indices.len());
        let mut total_steps = 0usize;

        for _ in 0..num_walks {
            // Shuffle node order per walk round (Fisher–Yates). The uniform
            // float in [0, 1) is scaled to an index in 0..=i; `min(i)` is a
            // belt-and-braces guard on the exclusive upper bound.
            let mut order = node_indices.clone();
            for i in (1..order.len()).rev() {
                let j = (rng.random_range(0.0..1.0) * (i + 1) as f64) as usize;
                order.swap(i, j.min(i));
            }

            for &start in &order {
                let walk = self.single_walk(graph, start, walk_length, node_alias, edge_alias, rng);
                total_steps += walk.len();
                walks.push(walk);
            }
        }

        (walks, total_steps)
    }
423
424    fn single_walk(
425        &self,
426        graph: &UnGraph<String, ()>,
427        start: NodeIndex,
428        walk_length: usize,
429        node_alias: &NodeAlias,
430        edge_alias: &EdgeAlias,
431        rng: &mut CoreRandom<StdRng>,
432    ) -> Vec<String> {
433        let mut walk: Vec<String> = Vec::with_capacity(walk_length);
434
435        // Add start node label.
436        if let Some(label) = graph.node_weight(start) {
437            walk.push(label.clone());
438        } else {
439            return walk;
440        }
441
442        let mut current = start;
443        let mut prev: Option<NodeIndex> = None;
444
445        for _ in 1..walk_length {
446            let next = if let Some(p) = prev {
447                // Second-order: use per-edge alias table.
448                if let Some((neighbors, table)) = edge_alias.get(&(p, current)) {
449                    let idx = table.sample(rng);
450                    neighbors.get(idx).copied()
451                } else {
452                    None
453                }
454            } else {
455                // First step: uniform over neighbors.
456                if let Some((neighbors, table)) = node_alias.get(&current) {
457                    let idx = table.sample(rng);
458                    neighbors.get(idx).copied()
459                } else {
460                    None
461                }
462            };
463
464            match next {
465                Some(n) => {
466                    if let Some(label) = graph.node_weight(n) {
467                        walk.push(label.clone());
468                    }
469                    prev = Some(current);
470                    current = n;
471                }
472                None => break, // Dead end (isolated node after first step).
473            }
474        }
475
476        walk
477    }
478
479    // ─── Skip-gram training ───────────────────────────────────────────────────
480
481    /// Simplified skip-gram with stochastic gradient descent.
482    ///
483    /// For each (target, context) pair within the window, the update maximizes
484    /// the inner-product similarity between target and context embeddings.
485    /// Negative sampling is approximated via L2 regularization to bound norms.
486    fn train_skip_gram(
487        &self,
488        walks: &[Vec<String>],
489        node_map: &HashMap<String, NodeIndex>,
490        rng: &mut CoreRandom<StdRng>,
491    ) -> GraphRAGResult<HashMap<String, Vec<f64>>> {
492        let dim = self.config.embedding_dim;
493        let window = self.config.window_size;
494        let lr_init = self.config.learning_rate;
495
496        // Initialize embeddings with small random values.
497        let mut embeddings: HashMap<String, Vec<f64>> = HashMap::new();
498        for node_label in node_map.keys() {
499            let emb: Vec<f64> = (0..dim)
500                .map(|_| (rng.random_range(0.0..1.0) - 0.5) / dim as f64)
501                .collect();
502            embeddings.insert(node_label.clone(), emb);
503        }
504
505        // Context embeddings (separate parameter matrix as in word2vec).
506        let mut ctx_embeddings: HashMap<String, Vec<f64>> = HashMap::new();
507        for node_label in node_map.keys() {
508            ctx_embeddings.insert(node_label.clone(), vec![0.0f64; dim]);
509        }
510
511        let total_epochs = self.config.num_epochs;
512        let total_pairs: usize = walks
513            .iter()
514            .map(|w| w.len() * (2 * window).min(if w.len() > 1 { w.len() - 1 } else { 0 }))
515            .sum();
516
517        let mut pair_count = 0usize;
518
519        for epoch in 0..total_epochs {
520            // Linear learning rate decay.
521            let lr = lr_init * (1.0 - epoch as f64 / total_epochs as f64).max(0.001);
522
523            for walk in walks {
524                for (i, target) in walk.iter().enumerate() {
525                    let start = i.saturating_sub(window);
526                    let end = (i + window + 1).min(walk.len());
527
528                    for (j, context) in walk[start..end].iter().enumerate() {
529                        let abs_j = start + j;
530                        if abs_j == i || context == target {
531                            continue;
532                        }
533
534                        // Decay lr further within epoch based on pair index.
535                        let local_lr =
536                            lr * (1.0 - pair_count as f64 / (total_pairs + 1) as f64).max(0.001);
537
538                        self.sgd_update(
539                            target,
540                            context,
541                            &mut embeddings,
542                            &mut ctx_embeddings,
543                            local_lr,
544                            dim,
545                        );
546
547                        pair_count += 1;
548                    }
549                }
550            }
551        }
552
553        // Merge context vectors into main embeddings (average of both).
554        for (node, emb) in &mut embeddings {
555            if let Some(ctx) = ctx_embeddings.get(node) {
556                for (e, c) in emb.iter_mut().zip(ctx.iter()) {
557                    *e = (*e + c) / 2.0;
558                }
559            }
560        }
561
562        // Optionally normalize to unit vectors.
563        if self.config.normalize {
564            for emb in embeddings.values_mut() {
565                let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
566                if norm > 1e-12 {
567                    for v in emb.iter_mut() {
568                        *v /= norm;
569                    }
570                }
571            }
572        }
573
574        Ok(embeddings)
575    }
576
577    /// Single SGD step for one (target, context) positive pair.
578    ///
579    /// Gradient of log-sigmoid loss:
580    ///   L = log σ(v_t · v_c)
581    ///   ∂L/∂v_t = (1 - σ(score)) · v_c
582    ///   ∂L/∂v_c = (1 - σ(score)) · v_t
583    fn sgd_update(
584        &self,
585        target: &str,
586        context: &str,
587        embeddings: &mut HashMap<String, Vec<f64>>,
588        ctx_embeddings: &mut HashMap<String, Vec<f64>>,
589        lr: f64,
590        dim: usize,
591    ) {
592        // Compute inner product (score).
593        let score = {
594            let Some(te) = embeddings.get(target) else {
595                return;
596            };
597            let Some(ce) = ctx_embeddings.get(context) else {
598                return;
599            };
600            te.iter().zip(ce.iter()).map(|(a, b)| a * b).sum::<f64>()
601        };
602
603        // Sigmoid of score → gradient weight.
604        let sigma = 1.0 / (1.0 + (-score).exp());
605        let grad = (1.0 - sigma) * lr;
606
607        // Capture snapshots to avoid borrow conflicts.
608        let te_snap: Vec<f64> = match embeddings.get(target) {
609            Some(v) => v.clone(),
610            None => return,
611        };
612        let ce_snap: Vec<f64> = match ctx_embeddings.get(context) {
613            Some(v) => v.clone(),
614            None => return,
615        };
616
617        // Update target embedding.
618        if let Some(te) = embeddings.get_mut(target) {
619            for k in 0..dim {
620                te[k] += grad * ce_snap[k];
621            }
622        }
623
624        // Update context embedding.
625        if let Some(ce) = ctx_embeddings.get_mut(context) {
626            for k in 0..dim {
627                ce[k] += grad * te_snap[k];
628            }
629        }
630    }
631}
632
633// ─── Tests ───────────────────────────────────────────────────────────────────
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// Build a ring graph: node_0 — node_1 — … — node_{n-1} — node_0.
    fn ring_triples(n: usize) -> Vec<Triple> {
        (0..n)
            .map(|i| {
                Triple::new(
                    format!("node_{}", i),
                    "connects",
                    format!("node_{}", (i + 1) % n),
                )
            })
            .collect()
    }

    /// Build a complete graph on `n` nodes (every pair connected once).
    fn complete_triples(n: usize) -> Vec<Triple> {
        let mut ts = Vec::new();
        for i in 0..n {
            for j in i + 1..n {
                ts.push(Triple::new(format!("n{}", i), "edge", format!("n{}", j)));
            }
        }
        ts
    }

    /// Small, fast configuration shared by most tests below.
    fn small_config() -> Node2VecConfig {
        Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 5,
                walk_length: 10,
                p: 1.0,
                q: 1.0,
                random_seed: 99,
            },
            embedding_dim: 16,
            window_size: 2,
            num_epochs: 3,
            learning_rate: 0.05,
            normalize: true,
        }
    }

    #[test]
    fn test_embed_produces_correct_number_of_embeddings() {
        let triples = ring_triples(6);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        // A 6-node ring should produce 6 embeddings.
        assert_eq!(result.embeddings.len(), 6);
        assert_eq!(result.dim, 16);
    }

    #[test]
    fn test_embed_correct_dimension() {
        let triples = complete_triples(4);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        // Every embedding must match the configured dimension.
        for emb in result.embeddings.values() {
            assert_eq!(emb.len(), 16);
        }
    }

    #[test]
    fn test_normalized_embeddings_have_unit_norm() {
        let triples = ring_triples(5);
        let config = Node2VecConfig {
            normalize: true,
            ..small_config()
        };
        let embedder = Node2VecEmbedder::new(config);
        let result = embedder.embed(&triples).expect("embed failed");
        for (node, emb) in &result.embeddings {
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-6,
                "node {} has non-unit norm {:.6}",
                node,
                norm
            );
        }
    }

    #[test]
    fn test_cosine_similarity_in_range() {
        let triples = complete_triples(5);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");

        let nodes: Vec<String> = result.embeddings.keys().cloned().collect();
        if nodes.len() >= 2 {
            if let Some(sim) = result.cosine_similarity(&nodes[0], &nodes[1]) {
                // Cosine similarity must lie in [-1, 1] up to float error.
                assert!(
                    (-1.0 - 1e-9..=1.0 + 1e-9).contains(&sim),
                    "cosine similarity out of range: {}",
                    sim
                );
            }
        }
    }

    #[test]
    fn test_top_k_similar_returns_at_most_k() {
        let triples = ring_triples(8);
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");

        let similar = result.top_k_similar("node_0", 3);
        assert!(similar.len() <= 3);
    }

    #[test]
    fn test_empty_triples_returns_empty_embeddings() {
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&[]).expect("embed failed");
        assert!(result.embeddings.is_empty());
        assert_eq!(result.total_walk_steps, 0);
    }

    #[test]
    fn test_single_node_isolated() {
        // A single a—b triple yields two connected nodes; both should
        // receive an embedding.
        let triples = vec![Triple::new("a", "r", "b")];
        let embedder = Node2VecEmbedder::new(small_config());
        let result = embedder.embed(&triples).expect("embed failed");
        assert_eq!(result.embeddings.len(), 2);
    }

    #[test]
    fn test_walk_bias_dfs_vs_bfs() {
        // With p=q=0.25 the walk is biased outward (DFS-like); with p=q=4.0
        // it stays local (BFS-like). Both settings must still embed every
        // node of the ring.
        let triples = ring_triples(10);

        let dfs_config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 3,
                walk_length: 20,
                p: 0.25,
                q: 0.25,
                random_seed: 1,
            },
            ..small_config()
        };
        let bfs_config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 3,
                walk_length: 20,
                p: 4.0,
                q: 4.0,
                random_seed: 1,
            },
            ..small_config()
        };

        let embedder_dfs = Node2VecEmbedder::new(dfs_config);
        let embedder_bfs = Node2VecEmbedder::new(bfs_config);

        let res_dfs = embedder_dfs.embed(&triples).expect("dfs embed failed");
        let res_bfs = embedder_bfs.embed(&triples).expect("bfs embed failed");

        // Both should produce embeddings for all 10 nodes.
        assert_eq!(res_dfs.embeddings.len(), 10);
        assert_eq!(res_bfs.embeddings.len(), 10);
    }

    #[test]
    fn test_total_walk_steps_is_plausible() {
        let n = 5usize;
        let triples = ring_triples(n);
        let config = Node2VecConfig {
            walk: Node2VecWalkConfig {
                num_walks: 2,
                walk_length: 10,
                ..Default::default()
            },
            ..small_config()
        };
        let embedder = Node2VecEmbedder::new(config);
        let result = embedder.embed(&triples).expect("embed failed");

        // Each node starts a walk of up to `walk_length` steps.
        // total_steps ≥ num_nodes × num_walks (at least 1 step per walk).
        assert!(
            result.total_walk_steps >= n * 2,
            "expected ≥{} steps, got {}",
            n * 2,
            result.total_walk_steps
        );
        assert!(
            result.total_walk_steps <= n * 2 * 10 + n * 2,
            "unexpectedly many steps: {}",
            result.total_walk_steps
        );
    }

    #[test]
    fn test_alias_table_samples_valid_index() {
        // Every sampled index must be a valid position in the weight slice.
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(777);
        for _ in 0..100 {
            let idx = table.sample(&mut rng);
            assert!(idx < weights.len());
        }
    }

    #[test]
    fn test_alias_table_uniform_weights() {
        // All weights equal → each index should be sampled roughly equally.
        let weights = vec![1.0; 4];
        let table = AliasTable::build(&weights).expect("alias build failed");
        let mut rng = seeded_rng(42);
        let mut counts = [0usize; 4];
        for _ in 0..4000 {
            counts[table.sample(&mut rng)] += 1;
        }
        // Expect each bucket ~1000 ± 200.
        for c in counts {
            assert!(c > 800 && c < 1200, "bucket count out of range: {}", c);
        }
    }
}