Skip to main content

oxirs_graphrag/gnn_encoder/
message_passing.rs

1//! Message-passing GNN encoder for knowledge-graph entity embeddings.
2//!
3//! Implements a simple mean-aggregation multi-layer GNN with Xavier weight
4//! initialisation, ReLU + L2-normalisation activations, and a margin-based
5//! link-prediction training objective.  All random numbers are generated
6//! via an internal Linear Congruential Generator so the crate does not depend
7//! on the `rand` crate.
8
9use std::collections::HashMap;
10
11use crate::GraphRAGError;
12
13use super::adjacency::AdjacencyGraph;
14
15// ─────────────────────────────────────────────────────────────────────────────
16// Linear Congruential Generator (internal; no rand dependency)
17// ─────────────────────────────────────────────────────────────────────────────
18
/// Minimal 64-bit Linear Congruential Generator.
///
/// The multiplier/increment pair (`6364136223846793005`,
/// `1442695040888963407`) is the one used by Knuth's MMIX generator, with an
/// implicit modulus of 2^64 provided by wrapping arithmetic.  The sequence is
/// fully determined by the seed and is **not** cryptographically secure.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Create a generator from `seed`.
    ///
    /// The seed is offset by one (wrapping) before being stored as the
    /// initial state.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the generator and return the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform sample in [0, 1), built from the top 53 bits of the state.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Uniform sample in [-scale, scale)
    fn next_f64_range(&mut self, scale: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * scale
    }

    /// Random usize in [0, n); returns 0 without advancing the state when
    /// `n == 0`.
    fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        // The low bits of a power-of-two-modulus LCG have very short periods
        // (bit 0 simply alternates), so `state % n` is badly distributed for
        // small or even `n`.  A 128-bit multiply-shift maps the draw onto
        // [0, n) using the well-mixed high bits instead.
        let wide = u128::from(self.next_u64()) * (n as u128);
        (wide >> 64) as usize
    }
}
57
58// ─────────────────────────────────────────────────────────────────────────────
59// Configuration
60// ─────────────────────────────────────────────────────────────────────────────
61
/// Hyper-parameters controlling the message-passing GNN encoder.
///
/// Use [`GnnEncoderConfig::default`] together with struct-update syntax to
/// override individual settings.
#[derive(Debug, Clone)]
pub struct GnnEncoderConfig {
    /// How many message-passing layers are stacked (default: 2).
    pub num_layers: usize,
    /// Width of the hidden and output embedding vectors (default: 64).
    pub hidden_dim: usize,
    /// Number of passes over the triples during link-prediction training
    /// (default: 50).
    pub num_epochs: usize,
    /// Step size used by the SGD update (default: 0.01).
    pub learning_rate: f64,
    /// Margin term of the triplet margin loss (default: 1.0).
    pub margin: f64,
}

impl Default for GnnEncoderConfig {
    /// Build the default configuration: 2 layers, 64 dimensions, 50 epochs,
    /// learning rate 0.01 and margin 1.0.
    fn default() -> Self {
        let num_layers = 2;
        let hidden_dim = 64;
        let num_epochs = 50;
        let learning_rate = 0.01;
        let margin = 1.0;

        Self {
            num_layers,
            hidden_dim,
            num_epochs,
            learning_rate,
            margin,
        }
    }
}
88
89// ─────────────────────────────────────────────────────────────────────────────
90// GnnEncoder
91// ─────────────────────────────────────────────────────────────────────────────
92
93/// Multi-layer mean-aggregation GNN encoder.
94///
95/// After calling [`GnnEncoder::fit`] on a set of RDF triples the encoder can
96/// produce a fixed-dimensional embedding for any entity seen during training
97/// via [`GnnEncoder::embed_entity`].  Unknown entities return a zero vector.
98pub struct GnnEncoder {
99    /// Encoder configuration
100    config: GnnEncoderConfig,
101    /// Entity embeddings: `entity_embeddings[i]` is the embedding for node `i`
102    entity_embeddings: Vec<Vec<f64>>,
103    /// Weight matrices for each layer: `weight_matrices[l][i][j]`
104    weight_matrices: Vec<Vec<Vec<f64>>>,
105    /// Map from entity name → integer index
106    entity_index: HashMap<String, usize>,
107}
108
109impl GnnEncoder {
110    /// Create a new, untrained encoder.
111    pub fn new(config: GnnEncoderConfig) -> Self {
112        Self {
113            config,
114            entity_embeddings: Vec::new(),
115            weight_matrices: Vec::new(),
116            entity_index: HashMap::new(),
117        }
118    }
119
120    /// Fit the encoder to the provided RDF triples.
121    ///
122    /// Builds the adjacency graph, initialises embeddings and weight matrices,
123    /// then runs stochastic gradient descent with a margin-based
124    /// link-prediction loss.
125    pub fn fit(&mut self, triples: &[(String, String, String)]) -> Result<(), GraphRAGError> {
126        if triples.is_empty() {
127            return Err(GraphRAGError::EmbeddingError(
128                "Cannot fit GnnEncoder on empty triple set".into(),
129            ));
130        }
131
132        let graph = AdjacencyGraph::from_triples(triples);
133        let n = graph.entity_count();
134        let d = self.config.hidden_dim;
135
136        // Store entity index map
137        self.entity_index = graph.entity_to_idx.clone();
138
139        let mut rng = Lcg::new(42);
140
141        // Initialise entity embeddings with Xavier uniform
142        self.entity_embeddings = Self::xavier_init(n, d, &mut rng);
143
144        // Initialise one weight matrix per layer (d × d)
145        self.weight_matrices = (0..self.config.num_layers)
146            .map(|_| Self::xavier_init(d, d, &mut rng))
147            .collect();
148
149        // Training loop
150        for _epoch in 0..self.config.num_epochs {
151            // For each triple, run a positive + negative sample update
152            for (s_str, _p_str, o_str) in triples {
153                let Some(&s_idx) = self.entity_index.get(s_str.as_str()) else {
154                    continue;
155                };
156                let Some(&o_idx) = self.entity_index.get(o_str.as_str()) else {
157                    continue;
158                };
159
160                // Sample a random negative entity
161                let neg_idx = loop {
162                    let candidate = rng.next_usize(n);
163                    if candidate != o_idx {
164                        break candidate;
165                    }
166                };
167
168                // Forward pass: compute embeddings for s, o, neg
169                let emb_s = self.forward_entity(s_idx, &graph);
170                let emb_o = self.forward_entity(o_idx, &graph);
171                let emb_neg = self.forward_entity(neg_idx, &graph);
172
173                let loss = Self::margin_loss(&emb_s, &emb_o, &emb_neg, self.config.margin);
174
175                // Only update if loss is positive (violated margin)
176                if loss > 0.0 {
177                    self.sgd_update(s_idx, o_idx, neg_idx, &graph);
178                }
179            }
180        }
181
182        // Final forward pass to store steady-state embeddings
183        for i in 0..n {
184            self.entity_embeddings[i] = self.forward_entity(i, &graph);
185        }
186
187        Ok(())
188    }
189
190    /// Return the embedding vector for a named entity.
191    /// Returns a zero vector of length `hidden_dim` if the entity is unknown.
192    pub fn embed_entity(&self, entity: &str) -> Vec<f64> {
193        match self.entity_index.get(entity) {
194            Some(&idx) if idx < self.entity_embeddings.len() => self.entity_embeddings[idx].clone(),
195            _ => vec![0.0; self.config.hidden_dim],
196        }
197    }
198
199    // ─────────────────────────────────────────────────────────────────────────
200    // Private helpers
201    // ─────────────────────────────────────────────────────────────────────────
202
203    /// Xavier/Glorot uniform initialisation for an (rows × cols) weight matrix.
204    fn xavier_init(rows: usize, cols: usize, rng: &mut Lcg) -> Vec<Vec<f64>> {
205        let scale = (6.0 / (rows + cols) as f64).sqrt();
206        (0..rows)
207            .map(|_| (0..cols).map(|_| rng.next_f64_range(scale)).collect())
208            .collect()
209    }
210
211    /// Run a single forward pass for node `idx`, producing a `hidden_dim`-dimensional
212    /// embedding by iterating over each message-passing layer.
213    fn forward_entity(&self, idx: usize, graph: &AdjacencyGraph) -> Vec<f64> {
214        let d = self.config.hidden_dim;
215        let mut h = if idx < self.entity_embeddings.len() {
216            self.entity_embeddings[idx].clone()
217        } else {
218            vec![0.0; d]
219        };
220
221        for layer in 0..self.config.num_layers {
222            // Collect neighbour embeddings for mean aggregation
223            let neighbors = graph.neighbors(idx);
224            let neighbor_embs: Vec<&Vec<f64>> = neighbors
225                .iter()
226                .filter_map(|&nidx| self.entity_embeddings.get(nidx))
227                .collect();
228
229            let aggregated = if neighbor_embs.is_empty() {
230                h.clone()
231            } else {
232                // Mean of self + neighbours
233                let mut combined = neighbor_embs.clone();
234                combined.push(&h);
235                Self::mean_aggregate(&combined)
236            };
237
238            // Apply weight matrix for this layer: h_new = W * aggregated
239            let w = &self.weight_matrices[layer];
240            let mut new_h = vec![0.0; d];
241            for (i, row) in w.iter().enumerate() {
242                let dot: f64 = row.iter().zip(aggregated.iter()).map(|(a, b)| a * b).sum();
243                new_h[i] = dot;
244            }
245
246            Self::relu_and_normalize(&mut new_h);
247            h = new_h;
248        }
249
250        h
251    }
252
253    /// One SGD step pushing the positive pair (s, o) closer and the negative
254    /// pair (s, neg) further apart in embedding space.
255    fn sgd_update(&mut self, s_idx: usize, o_idx: usize, neg_idx: usize, graph: &AdjacencyGraph) {
256        let lr = self.config.learning_rate;
257        let d = self.config.hidden_dim;
258
259        let emb_s = self.forward_entity(s_idx, graph);
260        let emb_o = self.forward_entity(o_idx, graph);
261        let emb_neg = self.forward_entity(neg_idx, graph);
262
263        // Gradient for entity embeddings:
264        //   push s closer to o: emb_s -= lr * (emb_s - emb_o)
265        //   push s away from neg: emb_s += lr * (emb_s - emb_neg)
266        for j in 0..d {
267            if s_idx < self.entity_embeddings.len() {
268                let grad_pos = emb_s[j] - emb_o[j];
269                let grad_neg = emb_s[j] - emb_neg[j];
270                self.entity_embeddings[s_idx][j] -= lr * (grad_pos - grad_neg);
271            }
272        }
273
274        // Normalise updated embedding
275        if s_idx < self.entity_embeddings.len() {
276            let v = &mut self.entity_embeddings[s_idx];
277            Self::relu_and_normalize(v);
278        }
279    }
280
281    /// Compute the mean embedding of a non-empty slice of embedding vectors.
282    pub fn mean_aggregate(embeddings: &[&Vec<f64>]) -> Vec<f64> {
283        if embeddings.is_empty() {
284            return Vec::new();
285        }
286        let d = embeddings[0].len();
287        let mut mean = vec![0.0_f64; d];
288        for emb in embeddings {
289            for (j, &val) in emb.iter().enumerate() {
290                if j < mean.len() {
291                    mean[j] += val;
292                }
293            }
294        }
295        let n = embeddings.len() as f64;
296        for v in &mut mean {
297            *v /= n;
298        }
299        mean
300    }
301
302    /// Apply ReLU activation then L2-normalise the vector in-place.
303    /// If the L2 norm is near zero the vector is left unchanged.
304    pub fn relu_and_normalize(v: &mut [f64]) {
305        // ReLU
306        for x in v.iter_mut() {
307            if *x < 0.0 {
308                *x = 0.0;
309            }
310        }
311        // L2 normalise
312        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
313        if norm > 1e-10 {
314            for x in v.iter_mut() {
315                *x /= norm;
316            }
317        }
318    }
319
320    /// Triplet margin loss: max(0, d(s, o) - d(s, neg) + margin)
321    /// where d is squared Euclidean distance.
322    pub fn margin_loss(pos_s: &[f64], pos_o: &[f64], neg_o: &[f64], margin: f64) -> f64 {
323        let d_pos: f64 = pos_s
324            .iter()
325            .zip(pos_o.iter())
326            .map(|(a, b)| (a - b).powi(2))
327            .sum();
328        let d_neg: f64 = pos_s
329            .iter()
330            .zip(neg_o.iter())
331            .map(|(a, b)| (a - b).powi(2))
332            .sum();
333        (d_pos - d_neg + margin).max(0.0)
334    }
335}
336
337// ─────────────────────────────────────────────────────────────────────────────
338// Tests
339// ─────────────────────────────────────────────────────────────────────────────
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Small five-triple fixture graph over Alice, Bob, Carol, Dave and Acme.
    fn triples() -> Vec<(String, String, String)> {
        const RAW: [(&str, &str, &str); 5] = [
            ("Alice", "knows", "Bob"),
            ("Bob", "worksAt", "Acme"),
            ("Carol", "worksAt", "Acme"),
            ("Alice", "friendOf", "Carol"),
            ("Dave", "knows", "Alice"),
        ];
        RAW.iter()
            .map(|&(s, p, o)| (s.to_string(), p.to_string(), o.to_string()))
            .collect()
    }

    /// Build an encoder with the given shape and fit it on the fixture graph.
    fn fitted_encoder(num_layers: usize, hidden_dim: usize, num_epochs: usize) -> GnnEncoder {
        let config = GnnEncoderConfig {
            num_layers,
            hidden_dim,
            num_epochs,
            ..Default::default()
        };
        let mut encoder = GnnEncoder::new(config);
        encoder.fit(&triples()).expect("fit should succeed");
        encoder
    }

    #[test]
    fn test_fit_completes() {
        let _encoder = fitted_encoder(2, 16, 5);
    }

    #[test]
    fn test_embed_shape_correct() {
        let encoder = fitted_encoder(2, 32, 3);
        let emb = encoder.embed_entity("Alice");
        assert_eq!(emb.len(), 32, "Embedding dimension must match hidden_dim");
    }

    #[test]
    fn test_unseen_entity_returns_zero_vec() {
        let encoder = fitted_encoder(1, 8, 2);
        let emb = encoder.embed_entity("UnknownEntity_XYZ");
        assert_eq!(emb.len(), 8);
        assert!(
            emb.iter().all(|&x| x == 0.0),
            "Unknown entity must map to zero vector"
        );
    }

    #[test]
    fn test_loss_is_non_negative() {
        // Three mutually orthogonal unit vectors: anchor, positive, negative.
        let anchor = vec![1.0_f64, 0.0, 0.0];
        let positive = vec![0.0_f64, 1.0, 0.0];
        let negative = vec![0.0_f64, 0.0, 1.0];
        let loss = GnnEncoder::margin_loss(&anchor, &positive, &negative, 1.0);
        assert!(loss >= 0.0, "Margin loss must be non-negative");
    }

    #[test]
    fn test_embeddings_l2_normalized() {
        let encoder = fitted_encoder(2, 16, 5);

        for entity in ["Alice", "Bob", "Acme"].iter() {
            let emb = encoder.embed_entity(entity);
            let norm = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            // ReLU + L2 normalisation yields a unit vector, unless every
            // activation was clipped to zero.
            let is_unit = (norm - 1.0).abs() < 1e-6;
            let is_zero = norm < 1e-10;
            assert!(
                is_unit || is_zero,
                "Entity {} norm={} should be 1 or 0 (all-zero)",
                entity,
                norm
            );
        }
    }

    #[test]
    fn test_mean_aggregation_correct() {
        let a = vec![1.0_f64, 2.0, 3.0];
        let b = vec![3.0_f64, 4.0, 5.0];
        let result = GnnEncoder::mean_aggregate(&[&a, &b]);
        assert_eq!(result.len(), 3);
        for (got, want) in result.iter().zip([2.0_f64, 3.0, 4.0].iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }
}