oxirs_graphrag/gnn_encoder/
message_passing.rs1use std::collections::HashMap;
10
11use crate::GraphRAGError;
12
13use super::adjacency::AdjacencyGraph;
14
/// Small deterministic linear congruential generator.
///
/// Used so that initialisation and negative sampling are reproducible without
/// pulling in an external RNG crate. Not suitable for cryptographic use.
struct Lcg {
    /// Current 64-bit generator state; it is also the most recent output.
    state: u64,
}

impl Lcg {
    /// Creates a generator whose initial state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        let state = seed.wrapping_add(1);
        Self { state }
    }

    /// Advances the state by one step (`state * a + c`, wrapping modulo 2^64)
    /// and returns the new state.
    fn next_u64(&mut self) -> u64 {
        const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const INCREMENT: u64 = 1_442_695_040_888_963_407;

        let advanced = self
            .state
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next output.
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }

    /// Returns a value in `[-scale, scale)` by stretching [`Lcg::next_f64`].
    fn next_f64_range(&mut self, scale: f64) -> f64 {
        let unit = self.next_f64();
        (unit * 2.0 - 1.0) * scale
    }

    /// Returns an index in `0..n` (plain modulo of the next raw output).
    ///
    /// For `n == 0` the result is `0` and the generator state is left untouched.
    fn next_usize(&mut self, n: usize) -> usize {
        match n {
            0 => 0,
            bound => (self.next_u64() as usize) % bound,
        }
    }
}
57
/// Hyper-parameters controlling [`GnnEncoder`] training and inference.
#[derive(Debug, Clone)]
pub struct GnnEncoderConfig {
    /// Number of message-passing layers (one weight matrix is created per layer).
    pub num_layers: usize,
    /// Length of every entity embedding and side length of each weight matrix.
    pub hidden_dim: usize,
    /// Number of full passes over the training triples.
    pub num_epochs: usize,
    /// Step size applied to embedding updates during training.
    pub learning_rate: f64,
    /// Margin added to the ranking loss before it is clamped at zero.
    pub margin: f64,
}

impl Default for GnnEncoderConfig {
    /// Two layers, 64 dimensions, 50 epochs, learning rate 0.01, margin 1.0.
    fn default() -> Self {
        let (num_layers, hidden_dim, num_epochs) = (2, 64, 50);
        let (learning_rate, margin) = (0.01, 1.0);
        Self {
            num_layers,
            hidden_dim,
            num_epochs,
            learning_rate,
            margin,
        }
    }
}
88
/// Message-passing encoder that learns one embedding per knowledge-graph entity.
///
/// An encoder starts empty (see `new`); `fit` populates the embeddings, layer
/// weights and the name-to-index map from a set of triples, after which
/// `embed_entity` can be used for lookups.
pub struct GnnEncoder {
    /// Hyper-parameters supplied at construction time.
    config: GnnEncoderConfig,
    /// One vector per entity, addressed by the indices stored in `entity_index`.
    /// Empty until `fit` has run.
    entity_embeddings: Vec<Vec<f64>>,
    /// One square weight matrix per layer, stored as rows. Empty until `fit` has run.
    weight_matrices: Vec<Vec<Vec<f64>>>,
    /// Maps an entity name to its row in `entity_embeddings`.
    entity_index: HashMap<String, usize>,
}
108
109impl GnnEncoder {
110 pub fn new(config: GnnEncoderConfig) -> Self {
112 Self {
113 config,
114 entity_embeddings: Vec::new(),
115 weight_matrices: Vec::new(),
116 entity_index: HashMap::new(),
117 }
118 }
119
120 pub fn fit(&mut self, triples: &[(String, String, String)]) -> Result<(), GraphRAGError> {
126 if triples.is_empty() {
127 return Err(GraphRAGError::EmbeddingError(
128 "Cannot fit GnnEncoder on empty triple set".into(),
129 ));
130 }
131
132 let graph = AdjacencyGraph::from_triples(triples);
133 let n = graph.entity_count();
134 let d = self.config.hidden_dim;
135
136 self.entity_index = graph.entity_to_idx.clone();
138
139 let mut rng = Lcg::new(42);
140
141 self.entity_embeddings = Self::xavier_init(n, d, &mut rng);
143
144 self.weight_matrices = (0..self.config.num_layers)
146 .map(|_| Self::xavier_init(d, d, &mut rng))
147 .collect();
148
149 for _epoch in 0..self.config.num_epochs {
151 for (s_str, _p_str, o_str) in triples {
153 let Some(&s_idx) = self.entity_index.get(s_str.as_str()) else {
154 continue;
155 };
156 let Some(&o_idx) = self.entity_index.get(o_str.as_str()) else {
157 continue;
158 };
159
160 let neg_idx = loop {
162 let candidate = rng.next_usize(n);
163 if candidate != o_idx {
164 break candidate;
165 }
166 };
167
168 let emb_s = self.forward_entity(s_idx, &graph);
170 let emb_o = self.forward_entity(o_idx, &graph);
171 let emb_neg = self.forward_entity(neg_idx, &graph);
172
173 let loss = Self::margin_loss(&emb_s, &emb_o, &emb_neg, self.config.margin);
174
175 if loss > 0.0 {
177 self.sgd_update(s_idx, o_idx, neg_idx, &graph);
178 }
179 }
180 }
181
182 for i in 0..n {
184 self.entity_embeddings[i] = self.forward_entity(i, &graph);
185 }
186
187 Ok(())
188 }
189
190 pub fn embed_entity(&self, entity: &str) -> Vec<f64> {
193 match self.entity_index.get(entity) {
194 Some(&idx) if idx < self.entity_embeddings.len() => self.entity_embeddings[idx].clone(),
195 _ => vec![0.0; self.config.hidden_dim],
196 }
197 }
198
199 fn xavier_init(rows: usize, cols: usize, rng: &mut Lcg) -> Vec<Vec<f64>> {
205 let scale = (6.0 / (rows + cols) as f64).sqrt();
206 (0..rows)
207 .map(|_| (0..cols).map(|_| rng.next_f64_range(scale)).collect())
208 .collect()
209 }
210
211 fn forward_entity(&self, idx: usize, graph: &AdjacencyGraph) -> Vec<f64> {
214 let d = self.config.hidden_dim;
215 let mut h = if idx < self.entity_embeddings.len() {
216 self.entity_embeddings[idx].clone()
217 } else {
218 vec![0.0; d]
219 };
220
221 for layer in 0..self.config.num_layers {
222 let neighbors = graph.neighbors(idx);
224 let neighbor_embs: Vec<&Vec<f64>> = neighbors
225 .iter()
226 .filter_map(|&nidx| self.entity_embeddings.get(nidx))
227 .collect();
228
229 let aggregated = if neighbor_embs.is_empty() {
230 h.clone()
231 } else {
232 let mut combined = neighbor_embs.clone();
234 combined.push(&h);
235 Self::mean_aggregate(&combined)
236 };
237
238 let w = &self.weight_matrices[layer];
240 let mut new_h = vec![0.0; d];
241 for (i, row) in w.iter().enumerate() {
242 let dot: f64 = row.iter().zip(aggregated.iter()).map(|(a, b)| a * b).sum();
243 new_h[i] = dot;
244 }
245
246 Self::relu_and_normalize(&mut new_h);
247 h = new_h;
248 }
249
250 h
251 }
252
253 fn sgd_update(&mut self, s_idx: usize, o_idx: usize, neg_idx: usize, graph: &AdjacencyGraph) {
256 let lr = self.config.learning_rate;
257 let d = self.config.hidden_dim;
258
259 let emb_s = self.forward_entity(s_idx, graph);
260 let emb_o = self.forward_entity(o_idx, graph);
261 let emb_neg = self.forward_entity(neg_idx, graph);
262
263 for j in 0..d {
267 if s_idx < self.entity_embeddings.len() {
268 let grad_pos = emb_s[j] - emb_o[j];
269 let grad_neg = emb_s[j] - emb_neg[j];
270 self.entity_embeddings[s_idx][j] -= lr * (grad_pos - grad_neg);
271 }
272 }
273
274 if s_idx < self.entity_embeddings.len() {
276 let v = &mut self.entity_embeddings[s_idx];
277 Self::relu_and_normalize(v);
278 }
279 }
280
/// Returns the component-wise mean of `embeddings`.
///
/// The output has the length of the first vector. Components of longer vectors
/// beyond that length are ignored, and shorter vectors simply contribute nothing
/// to the missing positions; the divisor is always the number of vectors.
/// An empty input yields an empty vector.
pub fn mean_aggregate(embeddings: &[&Vec<f64>]) -> Vec<f64> {
    let Some(first) = embeddings.first() else {
        return Vec::new();
    };

    let count = embeddings.len() as f64;
    let mut totals = vec![0.0_f64; first.len()];
    for vector in embeddings {
        // `zip` stops at the shorter side, which bounds writes to `totals`.
        for (slot, &value) in totals.iter_mut().zip(vector.iter()) {
            *slot += value;
        }
    }

    totals.into_iter().map(|sum| sum / count).collect()
}
301
/// Clamps negative components of `v` to zero, then rescales `v` to unit L2 norm.
///
/// The rescaling is skipped when the norm is not greater than `1e-10`, so an
/// all-zero (or near-zero) vector is left as is after clamping.
pub fn relu_and_normalize(v: &mut [f64]) {
    v.iter_mut()
        .filter(|component| **component < 0.0)
        .for_each(|component| *component = 0.0);

    let squared_sum: f64 = v.iter().map(|component| component * component).sum();
    let length = squared_sum.sqrt();

    if length > 1e-10 {
        v.iter_mut().for_each(|component| *component /= length);
    }
}
319
/// Margin ranking loss: `max(0, |s - o|^2 - |s - neg|^2 + margin)`.
///
/// `pos_s` is the subject vector, `pos_o` the true object and `neg_o` the
/// corrupted object. Distances are squared Euclidean distances computed over
/// the overlapping prefix of each pair of slices.
pub fn margin_loss(pos_s: &[f64], pos_o: &[f64], neg_o: &[f64], margin: f64) -> f64 {
    let squared_distance_from_subject = |other: &[f64]| -> f64 {
        pos_s
            .iter()
            .zip(other)
            .map(|(lhs, rhs)| (lhs - rhs).powi(2))
            .sum()
    };

    let positive_gap = squared_distance_from_subject(pos_o);
    let negative_gap = squared_distance_from_subject(neg_o);

    let hinge_input = positive_gap - negative_gap + margin;
    hinge_input.max(0.0)
}
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Small five-triple graph shared by the tests below.
    fn triples() -> Vec<(String, String, String)> {
        [
            ("Alice", "knows", "Bob"),
            ("Bob", "worksAt", "Acme"),
            ("Carol", "worksAt", "Acme"),
            ("Alice", "friendOf", "Carol"),
            ("Dave", "knows", "Alice"),
        ]
        .iter()
        .map(|&(s, p, o)| (s.to_string(), p.to_string(), o.to_string()))
        .collect()
    }

    /// Builds an encoder with the given sizes (other settings default) and fits
    /// it on [`triples`], panicking if fitting fails.
    fn fitted_encoder(num_layers: usize, hidden_dim: usize, num_epochs: usize) -> GnnEncoder {
        let config = GnnEncoderConfig {
            num_layers,
            hidden_dim,
            num_epochs,
            ..Default::default()
        };
        let mut encoder = GnnEncoder::new(config);
        encoder.fit(&triples()).expect("fit should succeed");
        encoder
    }

    #[test]
    fn test_fit_completes() {
        let _encoder = fitted_encoder(2, 16, 5);
    }

    #[test]
    fn test_embed_shape_correct() {
        let encoder = fitted_encoder(2, 32, 3);
        let emb = encoder.embed_entity("Alice");
        assert_eq!(emb.len(), 32, "Embedding dimension must match hidden_dim");
    }

    #[test]
    fn test_unseen_entity_returns_zero_vec() {
        let encoder = fitted_encoder(1, 8, 2);
        let emb = encoder.embed_entity("UnknownEntity_XYZ");
        assert_eq!(emb.len(), 8);
        assert!(
            emb.iter().all(|&x| x == 0.0),
            "Unknown entity must map to zero vector"
        );
    }

    #[test]
    fn test_loss_is_non_negative() {
        // Three mutually orthogonal unit vectors.
        let subject = vec![1.0_f64, 0.0, 0.0];
        let positive = vec![0.0_f64, 1.0, 0.0];
        let negative = vec![0.0_f64, 0.0, 1.0];
        let loss = GnnEncoder::margin_loss(&subject, &positive, &negative, 1.0);
        assert!(loss >= 0.0, "Margin loss must be non-negative");
    }

    #[test]
    fn test_embeddings_l2_normalized() {
        let encoder = fitted_encoder(2, 16, 5);

        for entity in &["Alice", "Bob", "Acme"] {
            let emb = encoder.embed_entity(entity);
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            // ReLU can zero out a whole vector, in which case the norm stays 0.
            assert!(
                (norm - 1.0).abs() < 1e-6 || norm < 1e-10,
                "Entity {} norm={} should be 1 or 0 (all-zero)",
                entity,
                norm
            );
        }
    }

    #[test]
    fn test_mean_aggregation_correct() {
        let a = vec![1.0_f64, 2.0, 3.0];
        let b = vec![3.0_f64, 4.0, 5.0];
        let result = GnnEncoder::mean_aggregate(&[&a, &b]);
        assert_eq!(result.len(), 3);
        for (got, want) in result.iter().zip([2.0_f64, 3.0, 4.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }
}