Skip to main content

oxirs_embed/models/
graphsage.rs

//! GraphSAGE: Inductive Representation Learning on Large Graphs
//!
//! Hamilton, Ying, Leskovec (2017) - NeurIPS
//!
//! Key idea: learn functions that aggregate information from a node's neighborhood
//! instead of training a separate embedding per node (the transductive approach).
//! This enables inductive inference on nodes that were not seen during training.
8
9use crate::EmbeddingError;
10use anyhow::{anyhow, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// Aggregation strategy for neighborhood sampling
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub enum AggregatorType {
17    /// Mean of neighbor features (GCN-like, simple and effective)
18    Mean,
19    /// Max-pool with learned MLP (captures representative features)
20    MaxPool { hidden_dim: usize },
21    /// Concatenation + mean (original GraphSAGE mean aggregator)
22    MeanConcat,
23}
24
25/// GraphSAGE configuration
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct GraphSageConfig {
28    /// Dimensionality of input node features
29    pub input_dim: usize,
30    /// Dimensionality of hidden layers
31    pub hidden_dims: Vec<usize>,
32    /// Dimensionality of output embeddings
33    pub output_dim: usize,
34    /// Aggregation strategy
35    pub aggregator: AggregatorType,
36    /// Number of neighbor samples per hop, e.g. [25, 10]
37    pub num_samples: Vec<usize>,
38    /// Dropout rate (applied during training forward pass)
39    pub dropout: f64,
40    /// Learning rate for parameter updates
41    pub learning_rate: f64,
42    /// Number of training epochs
43    pub epochs: usize,
44    /// Mini-batch size for training
45    pub batch_size: usize,
46    /// L2-normalize output embeddings
47    pub normalize_output: bool,
48    /// Random seed for reproducibility
49    pub seed: u64,
50}
51
52impl Default for GraphSageConfig {
53    fn default() -> Self {
54        Self {
55            input_dim: 64,
56            hidden_dims: vec![256, 128],
57            output_dim: 64,
58            aggregator: AggregatorType::Mean,
59            num_samples: vec![25, 10],
60            dropout: 0.5,
61            learning_rate: 0.01,
62            epochs: 10,
63            batch_size: 512,
64            normalize_output: true,
65            seed: 42,
66        }
67    }
68}
69
/// A graph in adjacency-list form together with one feature vector per node.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// Feature vectors indexed by node id: `node_features[i]` belongs to node `i`.
    pub node_features: Vec<Vec<f64>>,
    /// Neighbor ids indexed by node id: `adjacency[i]` lists the neighbors of node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Optional class label per node, for supervised use.
    pub labels: Option<Vec<usize>>,
}
80
81impl GraphData {
82    /// Create a new graph from node features and adjacency list.
83    ///
84    /// Validates that all adjacency indices are within bounds.
85    pub fn new(features: Vec<Vec<f64>>, adjacency: Vec<Vec<usize>>) -> Result<Self> {
86        let num_nodes = features.len();
87        if adjacency.len() != num_nodes {
88            return Err(anyhow!(
89                "Adjacency list length {} does not match number of nodes {}",
90                adjacency.len(),
91                num_nodes
92            ));
93        }
94        // Validate all neighbor indices
95        for (i, neighbors) in adjacency.iter().enumerate() {
96            for &neighbor in neighbors {
97                if neighbor >= num_nodes {
98                    return Err(anyhow!(
99                        "Node {} has neighbor index {} which is out of bounds (num_nodes={})",
100                        i,
101                        neighbor,
102                        num_nodes
103                    ));
104                }
105            }
106        }
107        // Validate feature dimensions are consistent
108        if let Some(first) = features.first() {
109            let dim = first.len();
110            for (i, feat) in features.iter().enumerate() {
111                if feat.len() != dim {
112                    return Err(anyhow!(
113                        "Node {} has feature dimension {} but expected {}",
114                        i,
115                        feat.len(),
116                        dim
117                    ));
118                }
119            }
120        }
121        Ok(Self {
122            node_features: features,
123            adjacency,
124            labels: None,
125        })
126    }
127
128    /// Number of nodes in the graph
129    pub fn num_nodes(&self) -> usize {
130        self.node_features.len()
131    }
132
133    /// Dimensionality of node features
134    pub fn feature_dim(&self) -> usize {
135        self.node_features.first().map(|f| f.len()).unwrap_or(0)
136    }
137
138    /// Get the neighbors of a node
139    pub fn neighbors(&self, node: usize) -> &[usize] {
140        if node < self.adjacency.len() {
141            &self.adjacency[node]
142        } else {
143            &[]
144        }
145    }
146
147    /// Sample up to k neighbors uniformly at random using a simple LCG PRNG
148    pub fn sample_neighbors(&self, node: usize, k: usize, rng: &mut SimpleLcg) -> Vec<usize> {
149        let neighbors = self.neighbors(node);
150        if neighbors.is_empty() {
151            return Vec::new();
152        }
153        if neighbors.len() <= k {
154            return neighbors.to_vec();
155        }
156        // Fisher-Yates partial shuffle to sample k items
157        let mut indices: Vec<usize> = (0..neighbors.len()).collect();
158        for i in 0..k {
159            let j = i + (rng.next_usize() % (indices.len() - i));
160            indices.swap(i, j);
161        }
162        indices[..k].iter().map(|&idx| neighbors[idx]).collect()
163    }
164
165    /// Set node labels for supervised training
166    pub fn with_labels(mut self, labels: Vec<usize>) -> Result<Self> {
167        if labels.len() != self.num_nodes() {
168            return Err(anyhow!(
169                "Labels length {} does not match num_nodes {}",
170                labels.len(),
171                self.num_nodes()
172            ));
173        }
174        self.labels = Some(labels);
175        Ok(self)
176    }
177}
178
/// Minimal 64-bit linear congruential generator for reproducible sampling.
///
/// Avoids a dependency on the external `rand` crate; not suitable for anything
/// security-related. In an LCG with a power-of-two modulus the low state bits have very
/// short periods (bit 0 simply alternates), so derived values come from the high bits.
#[derive(Debug, Clone)]
pub struct SimpleLcg {
    state: u64,
}

impl SimpleLcg {
    /// Creates a generator whose internal state starts at `seed + 1` (wrapping).
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator and returns the raw 64-bit state.
    ///
    /// The low bits of this value are weak (bit 0 alternates on every call); prefer
    /// [`next_usize`](Self::next_usize) or [`next_f64`](Self::next_f64) for derived values.
    pub fn next_u64(&mut self) -> u64 {
        // Multiplier and increment of Knuth's MMIX LCG (modulus 2^64).
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state
    }

    /// Returns a pseudo-random `usize`, intended to be reduced by the caller (e.g. `% n`).
    ///
    /// The state's bits are reversed so its strongest (highest) bits become the lowest
    /// bits of the result. Those low bits are what `% n` with a small or power-of-two `n`,
    /// and a truncating cast on 32-bit targets, depend on. Returning the raw state
    /// made `next_usize() % 2` strictly alternate.
    pub fn next_usize(&mut self) -> usize {
        self.next_u64().reverse_bits() as usize
    }

    /// Returns a pseudo-random `f64` in `[0, 1)` built from the top 53 state bits.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Returns a pseudo-random `f64` in `[-scale, scale)` for non-negative `scale`.
    pub fn next_f64_range(&mut self, scale: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * scale
    }
}
219
/// Fully connected affine map `y = W x + b`; it applies no activation itself.
#[derive(Debug, Clone)]
struct DenseLayer {
    /// Weight matrix with one row per output unit (`output_dim` rows of `input_dim`).
    weights: Vec<Vec<f64>>,
    /// Additive bias, one entry per output unit.
    bias: Vec<f64>,
    /// Expected length of the input vector.
    input_dim: usize,
    /// Length of the produced output vector.
    output_dim: usize,
}
228
229impl DenseLayer {
230    /// Xavier/Glorot uniform initialization
231    fn new(input_dim: usize, output_dim: usize, rng: &mut SimpleLcg) -> Self {
232        let scale = (6.0 / (input_dim + output_dim) as f64).sqrt();
233        let weights = (0..output_dim)
234            .map(|_| (0..input_dim).map(|_| rng.next_f64_range(scale)).collect())
235            .collect();
236        let bias = vec![0.0; output_dim];
237        Self {
238            weights,
239            bias,
240            input_dim,
241            output_dim,
242        }
243    }
244
245    /// Forward pass: compute W*x + b
246    fn forward(&self, input: &[f64]) -> Vec<f64> {
247        debug_assert_eq!(input.len(), self.input_dim);
248        let mut output = self.bias.clone();
249        for (i, row) in self.weights.iter().enumerate() {
250            for (j, &w) in row.iter().enumerate() {
251                output[i] += w * input[j];
252            }
253        }
254        output
255    }
256
257    /// ReLU activation: max(0, x)
258    fn relu(x: &[f64]) -> Vec<f64> {
259        x.iter().map(|&v| v.max(0.0)).collect()
260    }
261}
262
263/// A single GraphSAGE layer that aggregates neighbor information
264#[derive(Debug, Clone)]
265struct SageLayer {
266    /// Transform for self node features
267    self_transform: DenseLayer,
268    /// Transform for aggregated neighbor features
269    neigh_transform: DenseLayer,
270    /// Optional pooling MLP for MaxPool aggregator
271    pool_mlp: Option<DenseLayer>,
272    /// Output dimensionality
273    output_dim: usize,
274}
275
276impl SageLayer {
277    /// Create a new SAGE layer.
278    ///
279    /// For MeanConcat: input to self_transform is input_dim,
280    /// input to neigh_transform is neigh_dim.
281    /// Final output is concat of both => output_dim each.
282    fn new(
283        input_dim: usize,
284        neigh_dim: usize,
285        output_dim: usize,
286        pool_hidden: Option<usize>,
287        rng: &mut SimpleLcg,
288    ) -> Self {
289        let self_transform = DenseLayer::new(input_dim, output_dim, rng);
290        let neigh_transform = DenseLayer::new(neigh_dim, output_dim, rng);
291        let pool_mlp = pool_hidden.map(|hidden| DenseLayer::new(neigh_dim, hidden, rng));
292        Self {
293            self_transform,
294            neigh_transform,
295            pool_mlp,
296            output_dim,
297        }
298    }
299
300    /// Mean aggregation: element-wise mean of neighbor features
301    fn aggregate_mean(neighbor_features: &[Vec<f64>]) -> Vec<f64> {
302        if neighbor_features.is_empty() {
303            return Vec::new();
304        }
305        let dim = neighbor_features[0].len();
306        let mut result = vec![0.0f64; dim];
307        for feat in neighbor_features {
308            for (r, &v) in result.iter_mut().zip(feat.iter()) {
309                *r += v;
310            }
311        }
312        let n = neighbor_features.len() as f64;
313        result.iter_mut().for_each(|v| *v /= n);
314        result
315    }
316
317    /// MaxPool aggregation: apply MLP then element-wise max
318    fn aggregate_maxpool(neighbor_features: &[Vec<f64>], pool_layer: &DenseLayer) -> Vec<f64> {
319        if neighbor_features.is_empty() {
320            return Vec::new();
321        }
322        let transformed: Vec<Vec<f64>> = neighbor_features
323            .iter()
324            .map(|feat| DenseLayer::relu(&pool_layer.forward(feat)))
325            .collect();
326        let dim = transformed[0].len();
327        let mut result = vec![f64::NEG_INFINITY; dim];
328        for feat in &transformed {
329            for (r, &v) in result.iter_mut().zip(feat.iter()) {
330                if v > *r {
331                    *r = v;
332                }
333            }
334        }
335        result
336    }
337
338    /// Forward pass for a single node.
339    ///
340    /// Computes: h = ReLU(W_self * self_feat + W_neigh * agg_neigh)
341    /// Then normalizes if configured.
342    fn forward(
343        &self,
344        self_feat: &[f64],
345        neighbor_feats: &[Vec<f64>],
346        aggregator: &AggregatorType,
347    ) -> Vec<f64> {
348        let agg = if neighbor_feats.is_empty() {
349            vec![0.0; self_feat.len()]
350        } else {
351            match aggregator {
352                AggregatorType::Mean | AggregatorType::MeanConcat => {
353                    Self::aggregate_mean(neighbor_feats)
354                }
355                AggregatorType::MaxPool { .. } => {
356                    if let Some(pool_layer) = &self.pool_mlp {
357                        Self::aggregate_maxpool(neighbor_feats, pool_layer)
358                    } else {
359                        Self::aggregate_mean(neighbor_feats)
360                    }
361                }
362            }
363        };
364
365        // Ensure agg has correct size for neigh_transform
366        let agg_padded = if agg.len() != self.neigh_transform.input_dim {
367            let mut padded = vec![0.0f64; self.neigh_transform.input_dim];
368            let copy_len = agg.len().min(self.neigh_transform.input_dim);
369            padded[..copy_len].copy_from_slice(&agg[..copy_len]);
370            padded
371        } else {
372            agg
373        };
374
375        // Ensure self_feat has correct size for self_transform
376        let self_padded = if self_feat.len() != self.self_transform.input_dim {
377            let mut padded = vec![0.0f64; self.self_transform.input_dim];
378            let copy_len = self_feat.len().min(self.self_transform.input_dim);
379            padded[..copy_len].copy_from_slice(&self_feat[..copy_len]);
380            padded
381        } else {
382            self_feat.to_vec()
383        };
384
385        let h_self = self.self_transform.forward(&self_padded);
386        let h_neigh = self.neigh_transform.forward(&agg_padded);
387
388        // Concatenate or add depending on aggregation type
389        let combined = match aggregator {
390            AggregatorType::MeanConcat => {
391                // Concatenate self and neighbor, then project
392                let mut concat = h_self;
393                concat.extend(h_neigh);
394                // For simplicity, we truncate/pad to output_dim
395                concat.truncate(self.output_dim);
396                while concat.len() < self.output_dim {
397                    concat.push(0.0);
398                }
399                concat
400            }
401            _ => {
402                // Element-wise sum
403                h_self
404                    .iter()
405                    .zip(h_neigh.iter())
406                    .map(|(a, b)| a + b)
407                    .collect()
408            }
409        };
410
411        // Apply ReLU activation (not on final layer typically, but standard here)
412        DenseLayer::relu(&combined)
413    }
414}
415
416/// GraphSAGE model for inductive node embedding
417///
418/// Implements the GraphSAGE algorithm from Hamilton et al. (2017).
419/// Key property: can generate embeddings for nodes not seen during training
420/// by aggregating from their neighborhoods.
421#[derive(Debug, Clone)]
422pub struct GraphSage {
423    config: GraphSageConfig,
424    layers: Vec<SageLayer>,
425    rng: SimpleLcg,
426}
427
428impl GraphSage {
429    /// Create a new GraphSAGE model with the given configuration
430    pub fn new(config: GraphSageConfig) -> Result<Self> {
431        if config.input_dim == 0 {
432            return Err(anyhow!("input_dim must be > 0"));
433        }
434        if config.output_dim == 0 {
435            return Err(anyhow!("output_dim must be > 0"));
436        }
437        if config.num_samples.is_empty() {
438            return Err(anyhow!("num_samples must have at least one entry"));
439        }
440
441        let mut rng = SimpleLcg::new(config.seed);
442        let pool_hidden = match &config.aggregator {
443            AggregatorType::MaxPool { hidden_dim } => Some(*hidden_dim),
444            _ => None,
445        };
446
447        // Build layer dimensions
448        // Layer 0: input_dim -> hidden_dims[0]
449        // Layer i: hidden_dims[i-1] -> hidden_dims[i]
450        // Last layer: hidden_dims[-1] -> output_dim
451        let mut dims: Vec<usize> = vec![config.input_dim];
452        dims.extend(config.hidden_dims.iter().copied());
453        dims.push(config.output_dim);
454
455        let num_layers = dims.len() - 1;
456        let mut layers = Vec::with_capacity(num_layers);
457
458        for i in 0..num_layers {
459            let in_dim = dims[i];
460            let out_dim = dims[i + 1];
461            // Neighbor aggregation input dim: same as node feature dim at this layer
462            let neigh_dim = in_dim;
463            layers.push(SageLayer::new(
464                in_dim,
465                neigh_dim,
466                out_dim,
467                pool_hidden,
468                &mut rng,
469            ));
470        }
471
472        Ok(Self {
473            config,
474            layers,
475            rng,
476        })
477    }
478
479    /// Generate embeddings for all nodes via inductive forward pass.
480    ///
481    /// Performs K-hop neighborhood aggregation where K is the number of layers.
482    pub fn embed(&self, graph: &GraphData) -> Result<GraphSageEmbeddings> {
483        if graph.num_nodes() == 0 {
484            return Err(anyhow!("Graph has no nodes"));
485        }
486        if graph.feature_dim() != self.config.input_dim {
487            return Err(anyhow!(
488                "Graph feature dim {} does not match model input_dim {}",
489                graph.feature_dim(),
490                self.config.input_dim
491            ));
492        }
493
494        // Memoization: compute embeddings layer by layer for all nodes
495        // h_prev[node] = node representation at previous layer
496        let mut h_prev: Vec<Vec<f64>> = graph.node_features.clone();
497
498        for (layer_idx, layer) in self.layers.iter().enumerate() {
499            let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(graph.num_nodes());
500
501            // Determine neighbor sample count for this layer
502            let num_samples = self
503                .config
504                .num_samples
505                .get(layer_idx)
506                .copied()
507                .unwrap_or(25);
508
509            // Use a deterministic sampling order for inference
510            let mut local_rng = SimpleLcg::new(self.config.seed.wrapping_add(layer_idx as u64));
511
512            for node in 0..graph.num_nodes() {
513                let sampled = graph.sample_neighbors(node, num_samples, &mut local_rng);
514                let neighbor_feats: Vec<Vec<f64>> = sampled
515                    .iter()
516                    .filter_map(|&n| h_prev.get(n).cloned())
517                    .collect();
518
519                let self_feat = h_prev.get(node).cloned().unwrap_or_default();
520                let h = layer.forward(&self_feat, &neighbor_feats, &self.config.aggregator);
521                h_next.push(h);
522            }
523
524            h_prev = h_next;
525        }
526
527        // Apply L2 normalization to final embeddings
528        let embeddings: Vec<Vec<f64>> = if self.config.normalize_output {
529            h_prev.into_iter().map(|v| Self::normalize(&v)).collect()
530        } else {
531            h_prev
532        };
533
534        let dim = self.config.output_dim;
535        let num_nodes = graph.num_nodes();
536
537        Ok(GraphSageEmbeddings {
538            embeddings,
539            config: self.config.clone(),
540            num_nodes,
541            dim,
542        })
543    }
544
545    /// Train the model with unsupervised random-walk loss.
546    ///
547    /// Uses a simple positive/negative sampling strategy where:
548    /// - Positive pairs: nodes connected by an edge (BFS neighbors)
549    /// - Negative pairs: randomly sampled unconnected nodes
550    ///
551    /// Loss: -log(sigma(pos_score)) - log(1 - sigma(neg_score))
552    pub fn train_unsupervised(&mut self, graph: &GraphData) -> Result<GraphSageTrainingMetrics> {
553        if graph.num_nodes() < 2 {
554            return Err(anyhow!("Graph must have at least 2 nodes for training"));
555        }
556        if graph.feature_dim() != self.config.input_dim {
557            return Err(anyhow!(
558                "Graph feature dim {} != model input_dim {}",
559                graph.feature_dim(),
560                self.config.input_dim
561            ));
562        }
563
564        let mut loss_history = Vec::with_capacity(self.config.epochs);
565
566        for epoch in 0..self.config.epochs {
567            let embeddings = self.embed(graph)?;
568            let epoch_loss = self.compute_unsupervised_loss(&embeddings, graph);
569            loss_history.push(epoch_loss);
570
571            // Gradient update: simple random perturbation for demonstration
572            // In production, proper backpropagation would be used
573            self.apply_random_gradient_step(epoch_loss);
574
575            tracing::debug!(epoch = epoch, loss = epoch_loss, "GraphSAGE training step");
576        }
577
578        let final_loss = loss_history.last().copied().unwrap_or(f64::NAN);
579        let convergence = loss_history.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-4);
580
581        Ok(GraphSageTrainingMetrics {
582            epochs_completed: self.config.epochs,
583            final_loss,
584            loss_history,
585            convergence_achieved: convergence,
586        })
587    }
588
589    /// Compute unsupervised loss using positive/negative node pairs
590    fn compute_unsupervised_loss(
591        &self,
592        embeddings: &GraphSageEmbeddings,
593        graph: &GraphData,
594    ) -> f64 {
595        let num_nodes = graph.num_nodes();
596        if num_nodes < 2 {
597            return 0.0;
598        }
599
600        let mut total_loss = 0.0;
601        let mut count = 0usize;
602        let mut local_rng = SimpleLcg::new(self.rng.state);
603
604        // Collect a sample of positive edges
605        let sample_nodes: Vec<usize> = (0..num_nodes.min(self.config.batch_size))
606            .map(|i| i % num_nodes)
607            .collect();
608
609        for &node in &sample_nodes {
610            let neighbors = graph.neighbors(node);
611            if neighbors.is_empty() {
612                continue;
613            }
614            // Positive: a direct neighbor
615            let pos_neighbor = neighbors[local_rng.next_usize() % neighbors.len()];
616
617            // Negative: a random non-neighbor node
618            let neg_node = local_rng.next_usize() % num_nodes;
619
620            if let (Some(h_u), Some(h_pos), Some(h_neg)) = (
621                embeddings.get(node),
622                embeddings.get(pos_neighbor),
623                embeddings.get(neg_node),
624            ) {
625                let pos_score = dot_product(h_u, h_pos);
626                let neg_score = dot_product(h_u, h_neg);
627
628                // Cross-entropy loss
629                let pos_loss = -sigmoid(pos_score).max(1e-10).ln();
630                let neg_loss = -(1.0 - sigmoid(neg_score)).max(1e-10).ln();
631                total_loss += pos_loss + neg_loss;
632                count += 1;
633            }
634        }
635
636        if count > 0 {
637            total_loss / count as f64
638        } else {
639            0.0
640        }
641    }
642
643    /// Apply a small random perturbation to layer weights (stub for full backprop)
644    fn apply_random_gradient_step(&mut self, loss: f64) {
645        let noise_scale = self.config.learning_rate * loss.abs().min(1.0) * 0.01;
646        for layer in self.layers.iter_mut() {
647            for row in layer.self_transform.weights.iter_mut() {
648                for w in row.iter_mut() {
649                    *w -= noise_scale * self.rng.next_f64_range(1.0);
650                }
651            }
652        }
653    }
654
655    /// L2 normalize a vector
656    pub fn normalize(v: &[f64]) -> Vec<f64> {
657        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
658        if norm < 1e-12 {
659            return v.to_vec();
660        }
661        v.iter().map(|x| x / norm).collect()
662    }
663}
664
665/// Training metrics from a GraphSAGE training run
666#[derive(Debug, Clone, Serialize, Deserialize)]
667pub struct GraphSageTrainingMetrics {
668    pub epochs_completed: usize,
669    pub final_loss: f64,
670    pub loss_history: Vec<f64>,
671    pub convergence_achieved: bool,
672}
673
674/// Output embeddings from GraphSAGE inference
675#[derive(Debug, Clone)]
676pub struct GraphSageEmbeddings {
677    /// Embedding vectors indexed by node ID
678    pub embeddings: Vec<Vec<f64>>,
679    /// Configuration used for generation
680    pub config: GraphSageConfig,
681    /// Number of nodes
682    pub num_nodes: usize,
683    /// Embedding dimensionality
684    pub dim: usize,
685}
686
687impl GraphSageEmbeddings {
688    /// Get the embedding for a specific node
689    pub fn get(&self, node: usize) -> Option<&[f64]> {
690        self.embeddings.get(node).map(|v| v.as_slice())
691    }
692
693    /// Compute cosine similarity between two node embeddings
694    pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
695        let va = self.get(a)?;
696        let vb = self.get(b)?;
697        Some(cosine_similarity_vecs(va, vb))
698    }
699
700    /// Find the top-k most similar nodes to a given node
701    ///
702    /// Returns a sorted list of (node_id, similarity_score) pairs.
703    pub fn top_k_similar(&self, node: usize, k: usize) -> Vec<(usize, f64)> {
704        let query = match self.get(node) {
705            Some(v) => v,
706            None => return Vec::new(),
707        };
708
709        let mut similarities: Vec<(usize, f64)> = (0..self.num_nodes)
710            .filter(|&i| i != node)
711            .filter_map(|i| self.get(i).map(|v| (i, cosine_similarity_vecs(query, v))))
712            .collect();
713
714        // Sort by similarity descending
715        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
716        similarities.truncate(k);
717        similarities
718    }
719
720    /// Build a lookup from node label to embedding for classified nodes
721    pub fn labeled_embeddings(&self, labels: &[usize]) -> HashMap<usize, Vec<Vec<f64>>> {
722        let mut map: HashMap<usize, Vec<Vec<f64>>> = HashMap::new();
723        for (node, &label) in labels.iter().enumerate() {
724            if let Some(emb) = self.get(node) {
725                map.entry(label).or_default().push(emb.to_vec());
726            }
727        }
728        map
729    }
730}
731
/// Sum of the pairwise products of `a` and `b`.
///
/// Intended for equal-length slices; if the lengths differ, the trailing elements of
/// the longer slice are ignored.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let pairwise = a.iter().zip(b).map(|(&x, &y)| x * y);
    pairwise.sum()
}
736
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Results lie in `[0, 1]`, saturating to exactly 0 or 1 for large `|x|`.
pub fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
741
/// Cosine of the angle between `a` and `b`, clamped to `[-1, 1]`.
///
/// Returns 0.0 when either vector has an L2 norm below `1e-12`.
pub fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
    let l2 = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (l2(a), l2(b));
    if norm_a < 1e-12 || norm_b < 1e-12 {
        return 0.0;
    }
    let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
752
753/// Convert a `crate::EmbeddingError` to anyhow::Error for use in Results
754pub fn embedding_err(msg: impl Into<String>) -> crate::EmbeddingError {
755    EmbeddingError::Other(anyhow!(msg.into()))
756}
757
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a star graph: node 0 is linked (in both directions) to every node in `1..n`.
    fn star_graph(n: usize, feat_dim: usize, seed: u64) -> GraphData {
        let mut rng = SimpleLcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 1..n {
            adjacency[0].push(i);
            adjacency[i].push(0);
        }
        GraphData::new(features, adjacency).expect("star graph construction should succeed")
    }

    /// Asserts that `norm` is (numerically) either zero or one.
    fn assert_zero_or_unit(norm: f64) {
        assert!(
            norm < 1e-12 || (norm - 1.0).abs() < 1e-9,
            "norm {} should be 0 or 1",
            norm
        );
    }

    #[test]
    fn test_graphsage_default_config() {
        let config = GraphSageConfig::default();
        assert_eq!(config.input_dim, 64);
        assert_eq!(config.output_dim, 64);
        assert!(!config.num_samples.is_empty());
    }

    #[test]
    fn test_graphdata_construction() {
        let graph = star_graph(5, 8, 1);
        assert_eq!(graph.num_nodes(), 5);
        assert_eq!(graph.feature_dim(), 8);
        assert_eq!(graph.neighbors(0).len(), 4);
        assert_eq!(graph.neighbors(1).len(), 1);
        assert_eq!(graph.neighbors(1)[0], 0);
        assert!(graph.neighbors(99).is_empty());
    }

    #[test]
    fn test_graphdata_invalid_adjacency() {
        let features = vec![vec![1.0, 2.0]; 3];
        let adjacency = vec![
            vec![1usize, 99], // 99 is out of bounds
            vec![0],
            vec![0],
        ];
        assert!(GraphData::new(features, adjacency).is_err());
    }

    #[test]
    fn test_graphdata_adjacency_length_mismatch() {
        let features = vec![vec![1.0]; 2];
        let adjacency = vec![vec![]];
        assert!(GraphData::new(features, adjacency).is_err());
    }

    #[test]
    fn test_graphdata_inconsistent_feature_dims() {
        let features = vec![vec![1.0, 2.0], vec![1.0, 2.0, 3.0]];
        let adjacency = vec![vec![1], vec![0]];
        assert!(GraphData::new(features, adjacency).is_err());
    }

    #[test]
    fn test_with_labels() {
        assert!(star_graph(3, 2, 1).with_labels(vec![0, 1]).is_err());
        let labeled = star_graph(3, 2, 1)
            .with_labels(vec![0, 1, 1])
            .expect("matching label count should succeed");
        assert_eq!(labeled.labels, Some(vec![0, 1, 1]));
    }

    #[test]
    fn test_graphsage_new_rejects_invalid_config() {
        let zero_input = GraphSageConfig {
            input_dim: 0,
            ..Default::default()
        };
        assert!(GraphSage::new(zero_input).is_err());

        let zero_output = GraphSageConfig {
            output_dim: 0,
            ..Default::default()
        };
        assert!(GraphSage::new(zero_output).is_err());

        let no_samples = GraphSageConfig {
            num_samples: vec![],
            ..Default::default()
        };
        assert!(GraphSage::new(no_samples).is_err());
    }

    #[test]
    fn test_graphsage_embed_shape() {
        let config = GraphSageConfig {
            input_dim: 8,
            hidden_dims: vec![16],
            output_dim: 4,
            num_samples: vec![3],
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model construction should succeed");
        let graph = star_graph(5, 8, 42);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        assert_eq!(embeddings.num_nodes, 5);
        assert_eq!(embeddings.dim, 4);
        for i in 0..5 {
            let emb = embeddings
                .get(i)
                .expect("should have embedding for every node");
            assert_eq!(emb.len(), 4);
        }
    }

    #[test]
    fn test_embed_rejects_mismatched_or_empty_graph() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            num_samples: vec![2],
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        assert!(model.embed(&star_graph(3, 8, 1)).is_err());

        let empty = GraphData::new(vec![], vec![]).expect("empty graph is valid");
        assert!(model.embed(&empty).is_err());
    }

    #[test]
    fn test_embed_is_deterministic() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            num_samples: vec![2],
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(6, 4, 3);
        let first = model.embed(&graph).expect("embed should succeed");
        let second = model.embed(&graph).expect("embed should succeed");
        assert_eq!(first.embeddings, second.embeddings);
    }

    #[test]
    fn test_graphsage_normalized_output() {
        let config = GraphSageConfig {
            input_dim: 8,
            hidden_dims: vec![],
            output_dim: 8,
            num_samples: vec![5],
            normalize_output: true,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 8, 7);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        // Each embedding has unit norm, or stays zero if ReLU killed every activation.
        for i in 0..5 {
            let emb = embeddings.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert_zero_or_unit(norm);
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            num_samples: vec![5],
            normalize_output: false,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 4, 13);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        let sim = embeddings
            .cosine_similarity(0, 0)
            .expect("node 0 has an embedding");
        // Self-similarity is ~1, or exactly 0 for an all-zero embedding.
        assert!(
            sim == 0.0 || (sim - 1.0).abs() < 1e-9,
            "self-similarity {} should be 0 or 1",
            sim
        );
        assert!(embeddings.cosine_similarity(0, 99).is_none());
    }

    #[test]
    fn test_top_k_similar() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            num_samples: vec![5],
            normalize_output: true,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(6, 4, 17);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        let top3 = embeddings.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        assert!(top3.iter().all(|&(id, _)| id != 0));
        // Similarities should be in descending order.
        for window in top3.windows(2) {
            assert!(window[0].1 >= window[1].1 - 1e-10);
        }

        // Asking for more than exist returns every other node.
        assert_eq!(embeddings.top_k_similar(0, 100).len(), 5);
        assert!(embeddings.top_k_similar(99, 3).is_empty());
    }

    #[test]
    fn test_maxpool_aggregator() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator: AggregatorType::MaxPool { hidden_dim: 8 },
            num_samples: vec![3],
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct with MaxPool");
        let graph = star_graph(4, 4, 99);
        let embeddings = model.embed(&graph).expect("embed should succeed");
        assert_eq!(embeddings.num_nodes, 4);
        for i in 0..4 {
            assert_eq!(embeddings.get(i).expect("embedding exists").len(), 4);
        }
    }

    #[test]
    fn test_meanconcat_aggregator_shape() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![6],
            output_dim: 5,
            aggregator: AggregatorType::MeanConcat,
            num_samples: vec![3],
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct with MeanConcat");
        let graph = star_graph(4, 4, 5);
        let embeddings = model.embed(&graph).expect("embed should succeed");
        for i in 0..4 {
            assert_eq!(embeddings.get(i).expect("embedding exists").len(), 5);
        }
    }

    #[test]
    fn test_train_unsupervised() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            num_samples: vec![3],
            epochs: 3,
            batch_size: 4,
            ..Default::default()
        };
        let mut model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 4, 42);
        let metrics = model
            .train_unsupervised(&graph)
            .expect("training should succeed");
        assert_eq!(metrics.epochs_completed, 3);
        assert_eq!(metrics.loss_history.len(), 3);
        assert!(metrics.loss_history.iter().all(|l| l.is_finite()));
    }

    #[test]
    fn test_train_requires_two_nodes() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            num_samples: vec![1],
            ..Default::default()
        };
        let mut model = GraphSage::new(config).expect("model should construct");
        let single = GraphData::new(vec![vec![0.0; 4]], vec![vec![]]).expect("valid graph");
        assert!(model.train_unsupervised(&single).is_err());
    }

    #[test]
    fn test_simplelcg_reproducibility() {
        let mut rng1 = SimpleLcg::new(42);
        let mut rng2 = SimpleLcg::new(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn test_sample_neighbors() {
        let graph = star_graph(10, 4, 1);
        let mut rng = SimpleLcg::new(55);
        let mut sampled = graph.sample_neighbors(0, 3, &mut rng);
        // Node 0 has 9 neighbors, so exactly 3 distinct ones are drawn.
        assert_eq!(sampled.len(), 3);
        for &n in &sampled {
            assert!(graph.neighbors(0).contains(&n));
        }
        sampled.sort_unstable();
        sampled.dedup();
        assert_eq!(sampled.len(), 3);

        // Small neighborhoods are returned whole.
        assert_eq!(graph.sample_neighbors(1, 3, &mut rng), vec![0]);
    }
}