1use crate::models::graphsage::SimpleLcg;
7use crate::EmbeddingError;
8use anyhow::anyhow;
9use scirs2_core::random::Random;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Hyper-parameters controlling how a `GraphSageEmbedder` is built and trained.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSageEmbedderConfig {
    /// Number of aggregation layers; one weight matrix is created per layer.
    pub num_layers: usize,
    /// Length of the random input feature vectors and of every non-final layer's output.
    pub hidden_dim: usize,
    /// Length of the final layer's output, i.e. of each embedding returned to callers.
    pub embedding_dim: usize,
    /// Maximum number of neighbours sampled for a node when aggregating.
    pub neighbor_sample_k: usize,
    /// Step size applied to the averaged weight updates during training.
    pub learning_rate: f64,
    /// Number of training passes over the triple set.
    pub num_epochs: usize,
    /// Margin of the hinge loss comparing a positive pair with a sampled negative.
    pub margin: f64,
    /// Seed for weight/feature initialisation and training-time sampling;
    /// training falls back to `42` when this is `None`.
    pub seed: Option<u64>,
}
33
34impl Default for GraphSageEmbedderConfig {
35 fn default() -> Self {
36 Self {
37 num_layers: 2,
38 hidden_dim: 64,
39 embedding_dim: 64,
40 neighbor_sample_k: 10,
41 learning_rate: 0.01,
42 num_epochs: 50,
43 margin: 1.0,
44 seed: None,
45 }
46 }
47}
48
49fn xavier_uniform<R>(rows: usize, cols: usize, rng: &mut Random<R>) -> Vec<Vec<f64>>
51where
52 R: scirs2_core::random::Rng,
53{
54 let limit = (6.0_f64 / (rows + cols) as f64).sqrt();
55 (0..rows)
56 .map(|_| (0..cols).map(|_| rng.random_range(-limit..limit)).collect())
57 .collect()
58}
59
/// Multiplies the row-major matrix `w` by the vector `x`.
///
/// Each output element is the dot product of one row of `w` with `x`. When a
/// row and `x` differ in length, the surplus elements of the longer one are
/// ignored.
#[inline]
fn matmul(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(w.len());
    for row in w {
        let dot: f64 = row.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
        product.push(dot);
    }
    product
}
66
/// Applies ReLU element-wise, returning a new vector in which every entry is
/// `max(x, 0.0)`.
#[inline]
fn relu_vec(v: &[f64]) -> Vec<f64> {
    let mut activated = Vec::with_capacity(v.len());
    activated.extend(v.iter().map(|&x| f64::max(x, 0.0)));
    activated
}
71
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-12` (including empty and all-zero
/// vectors) are left untouched to avoid dividing by a vanishing norm.
fn l2_normalize(v: &mut [f64]) {
    let squared_sum: f64 = v.iter().map(|x| x * x).sum();
    let length = squared_sum.sqrt();
    if length > 1e-12 {
        for component in v.iter_mut() {
            *component /= length;
        }
    }
}
78
/// Cosine similarity of `a` and `b`.
///
/// `1e-8` is added to the product of the two norms, so zero vectors yield `0.0`
/// instead of a division by zero (and identical vectors score marginally below 1).
/// The dot product only covers the overlapping prefix when lengths differ.
#[inline]
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    let euclidean_norm = |v: &[f64]| -> f64 { v.iter().map(|x| x * x).sum::<f64>().sqrt() };

    let dot: f64 = a.iter().zip(b).map(|(ai, bi)| ai * bi).sum();
    let norm_a = euclidean_norm(a);
    let norm_b = euclidean_norm(b);
    dot / (norm_a * norm_b + 1e-8)
}
86
/// Mean-aggregating GraphSAGE-style embedder over entities extracted from triples.
///
/// All state except `config` is populated by `fit`.
pub struct GraphSageEmbedder {
    /// Hyper-parameters supplied at construction time.
    config: GraphSageEmbedderConfig,
    /// One weight matrix per layer, stored as rows of `f64`.
    weights: Vec<Vec<Vec<f64>>>,
    /// Maps an entity string to its row index in `embeddings`.
    entity_index: HashMap<String, usize>,
    /// Final embedding of every indexed entity, produced at the end of `fit`.
    embeddings: Vec<Vec<f64>>,
    /// `true` once `fit` has completed successfully.
    trained: bool,
}
103
104impl std::fmt::Debug for GraphSageEmbedder {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("GraphSageEmbedder")
107 .field("num_entities", &self.entity_index.len())
108 .field("trained", &self.trained)
109 .field("num_layers", &self.config.num_layers)
110 .field("embedding_dim", &self.config.embedding_dim)
111 .finish()
112 }
113}
114
115impl GraphSageEmbedder {
116 pub fn new(config: GraphSageEmbedderConfig) -> Self {
118 Self {
119 config,
120 weights: Vec::new(),
121 entity_index: HashMap::new(),
122 embeddings: Vec::new(),
123 trained: false,
124 }
125 }
126
127 pub fn fit(
131 &mut self,
132 triples: &[(String, String, String)],
133 ) -> std::result::Result<(), EmbeddingError> {
134 if triples.is_empty() {
135 return Err(EmbeddingError::Other(anyhow!("Triple set is empty")));
136 }
137
138 let (entity_index, adjacency) = Self::build_graph(triples);
140 let num_entities = entity_index.len();
141 self.entity_index = entity_index;
142
143 let seed = self.config.seed.unwrap_or(42);
145 let mut rng = Random::seed(seed);
146 self.weights = Self::init_weights(&self.config, &mut rng);
147
148 let input_dim = self.config.hidden_dim;
150 let mut h0: Vec<Vec<f64>> = (0..num_entities)
151 .map(|_| {
152 let mut v: Vec<f64> = (0..input_dim)
153 .map(|_| rng.random_range(-0.5_f64..0.5_f64))
154 .collect();
155 l2_normalize(&mut v);
156 v
157 })
158 .collect();
159
160 let num_layers = self.config.num_layers;
162 let mut lcg = SimpleLcg::new(seed.wrapping_add(1));
163
164 for _epoch in 0..self.config.num_epochs {
165 let h_all = self.forward_all(&h0, &adjacency, num_entities, &mut lcg);
166 let mut deltas: Vec<Vec<Vec<f64>>> = self
167 .weights
168 .iter()
169 .map(|w| vec![vec![0.0; w[0].len()]; w.len()])
170 .collect();
171 let mut grad_count = 0usize;
172
173 for (s_str, _p_str, o_str) in triples {
174 let s_idx = match self.entity_index.get(s_str.as_str()) {
175 Some(&i) => i,
176 None => continue,
177 };
178 let o_idx = match self.entity_index.get(o_str.as_str()) {
179 Some(&i) => i,
180 None => continue,
181 };
182 let o_neg_idx = self.sample_negative(o_idx, num_entities, &mut lcg);
183 let h_s = &h_all[s_idx];
184 let h_o = &h_all[o_idx];
185 let h_neg = &h_all[o_neg_idx];
186 let loss =
187 (self.config.margin - cosine_sim(h_s, h_o) + cosine_sim(h_s, h_neg)).max(0.0);
188
189 if loss > 0.0 {
190 for (l, delta_layer) in deltas.iter_mut().enumerate().take(num_layers) {
191 let nr = self.weights[l].len();
192 for (r, delta_row) in delta_layer.iter_mut().enumerate().take(nr) {
193 let sign = if h_s.get(r % h_s.len()).copied().unwrap_or(0.0) > 0.0 {
194 1.0_f64
195 } else {
196 -1.0_f64
197 };
198 for delta in delta_row.iter_mut() {
199 *delta += sign * loss;
200 }
201 }
202 }
203 grad_count += 1;
204 }
205 }
206
207 if grad_count > 0 {
208 let scale = self.config.learning_rate / grad_count as f64;
209 for (l, delta_layer) in deltas.iter().enumerate().take(num_layers) {
210 for (r, delta_row) in delta_layer.iter().enumerate() {
211 let row_norm: f64 = delta_row.iter().map(|g| g * g).sum::<f64>().sqrt();
212 let clip = if row_norm > 1.0 { 1.0 / row_norm } else { 1.0 };
213 for (w, d) in self.weights[l][r].iter_mut().zip(delta_row.iter()) {
214 *w -= d * clip * scale;
215 }
216 }
217 }
218 }
219 for feat in h0.iter_mut() {
220 l2_normalize(feat);
221 }
222 }
223
224 let mut lcg_final = SimpleLcg::new(seed.wrapping_add(2));
226 self.embeddings = self.forward_all(&h0, &adjacency, num_entities, &mut lcg_final);
227
228 self.trained = true;
229 Ok(())
230 }
231
232 pub fn embed_entity(&self, entity: &str) -> std::result::Result<Vec<f64>, EmbeddingError> {
234 if !self.trained {
235 return Err(EmbeddingError::ModelNotTrained);
236 }
237 match self.entity_index.get(entity) {
238 Some(&idx) => Ok(self
239 .embeddings
240 .get(idx)
241 .cloned()
242 .unwrap_or_else(|| vec![0.0; self.config.embedding_dim])),
243 None => Ok(vec![0.0; self.config.embedding_dim]),
244 }
245 }
246
    /// Returns `true` once `fit` has completed successfully.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
    /// Returns the number of distinct entities currently indexed (zero before `fit`).
    pub fn num_entities(&self) -> usize {
        self.entity_index.len()
    }
    /// Returns the configured embedding length.
    pub fn embedding_dim(&self) -> usize {
        self.config.embedding_dim
    }
256
257 fn build_graph(
260 triples: &[(String, String, String)],
261 ) -> (HashMap<String, usize>, HashMap<String, Vec<String>>) {
262 let mut entity_index: HashMap<String, usize> = HashMap::new();
263 let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
264
265 let mut next_id = 0usize;
266 for (s, _p, o) in triples {
267 for entity in [s, o] {
268 entity_index.entry(entity.clone()).or_insert_with(|| {
269 let id = next_id;
270 next_id += 1;
271 id
272 });
273 }
274 adjacency.entry(s.clone()).or_default().push(o.clone());
276 adjacency.entry(o.clone()).or_default().push(s.clone());
277 }
278 (entity_index, adjacency)
279 }
280
281 fn init_weights<R>(config: &GraphSageEmbedderConfig, rng: &mut Random<R>) -> Vec<Vec<Vec<f64>>>
282 where
283 R: scirs2_core::random::Rng,
284 {
285 let mut weights = Vec::with_capacity(config.num_layers);
286 for l in 0..config.num_layers {
287 let in_dim = 2 * config.hidden_dim;
288 let out_dim = if l + 1 == config.num_layers {
289 config.embedding_dim
290 } else {
291 config.hidden_dim
292 };
293 weights.push(xavier_uniform(out_dim, in_dim, rng));
294 }
295 weights
296 }
297
298 fn forward_all(
299 &self,
300 h0: &[Vec<f64>],
301 adjacency: &HashMap<String, Vec<String>>,
302 num_entities: usize,
303 lcg: &mut SimpleLcg,
304 ) -> Vec<Vec<f64>> {
305 let mut id_to_iri: Vec<&str> = vec![""; num_entities];
307 for (iri, &idx) in &self.entity_index {
308 if idx < num_entities {
309 id_to_iri[idx] = iri.as_str();
310 }
311 }
312
313 let mut h_prev: Vec<Vec<f64>> = h0.to_vec();
314
315 for l in 0..self.config.num_layers {
316 let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(num_entities);
317
318 for node_idx in 0..num_entities {
319 let iri = id_to_iri[node_idx];
320 let neighbor_embeds = self.sample_and_collect(iri, adjacency, &h_prev, lcg);
321 let h_new =
322 self.aggregate_mean(&h_prev[node_idx], &neighbor_embeds, &self.weights[l]);
323 h_next.push(h_new);
324 }
325
326 h_prev = h_next;
327 }
328
329 h_prev
330 }
331
332 pub(crate) fn aggregate_mean(
334 &self,
335 node_embed: &[f64],
336 neighbor_embeds: &[Vec<f64>],
337 weight_matrix: &[Vec<f64>],
338 ) -> Vec<f64> {
339 let dim = node_embed.len();
340 let mean_neigh: Vec<f64> = if neighbor_embeds.is_empty() {
342 node_embed.to_vec()
343 } else {
344 let mut acc = vec![0.0_f64; dim];
345 for n_emb in neighbor_embeds {
346 for (a, &v) in acc.iter_mut().zip(n_emb.iter()) {
347 *a += v;
348 }
349 }
350 let n = neighbor_embeds.len() as f64;
351 acc.iter_mut().for_each(|a| *a /= n);
352 acc
353 };
354
355 let mut concat = Vec::with_capacity(dim + mean_neigh.len());
357 concat.extend_from_slice(node_embed);
358 concat.extend_from_slice(&mean_neigh);
359 let expected_cols = weight_matrix
361 .first()
362 .map(|r| r.len())
363 .unwrap_or(concat.len());
364 concat.resize(expected_cols, 0.0);
365
366 let mut h_new = relu_vec(&matmul(weight_matrix, &concat));
367 l2_normalize(&mut h_new);
368 h_new
369 }
370
371 #[inline]
373 pub fn relu(x: f64) -> f64 {
374 x.max(0.0)
375 }
376
377 pub fn sample_neighbors<'a>(
379 &self,
380 node_iri: &str,
381 adjacency: &'a HashMap<String, Vec<String>>,
382 ) -> Vec<&'a str> {
383 let neighbors = match adjacency.get(node_iri) {
384 Some(n) => n.as_slice(),
385 None => return Vec::new(),
386 };
387 let k = self.config.neighbor_sample_k;
388 if neighbors.len() <= k {
389 return neighbors.iter().map(|s| s.as_str()).collect();
390 }
391 let mut indices: Vec<usize> = (0..neighbors.len()).collect();
392 let mut lcg = SimpleLcg::new(42);
393 for i in 0..k {
394 let j = i + (lcg.next_usize() % (indices.len() - i));
395 indices.swap(i, j);
396 }
397 indices[..k]
398 .iter()
399 .map(|&i| neighbors[i].as_str())
400 .collect()
401 }
402
403 fn sample_and_collect(
404 &self,
405 node_iri: &str,
406 adjacency: &HashMap<String, Vec<String>>,
407 h_prev: &[Vec<f64>],
408 lcg: &mut SimpleLcg,
409 ) -> Vec<Vec<f64>> {
410 let neighbors = match adjacency.get(node_iri) {
411 Some(n) => n.as_slice(),
412 None => return Vec::new(),
413 };
414 let k = self.config.neighbor_sample_k;
415 let sampled: Vec<&str> = if neighbors.len() <= k {
416 neighbors.iter().map(|s| s.as_str()).collect()
417 } else {
418 let mut indices: Vec<usize> = (0..neighbors.len()).collect();
419 for i in 0..k {
420 let j = i + (lcg.next_usize() % (indices.len() - i));
421 indices.swap(i, j);
422 }
423 indices[..k]
424 .iter()
425 .map(|&idx| neighbors[idx].as_str())
426 .collect()
427 };
428
429 sampled
430 .into_iter()
431 .filter_map(|iri| {
432 self.entity_index
433 .get(iri)
434 .and_then(|&idx| h_prev.get(idx))
435 .cloned()
436 })
437 .collect()
438 }
439
440 fn sample_negative(
441 &self,
442 positive_idx: usize,
443 num_entities: usize,
444 lcg: &mut SimpleLcg,
445 ) -> usize {
446 if num_entities <= 1 {
447 return 0;
448 }
449 let mut candidate = lcg.next_usize() % num_entities;
450 let mut attempts = 0usize;
451 while candidate == positive_idx && attempts < num_entities {
452 candidate = (candidate + 1) % num_entities;
453 attempts += 1;
454 }
455 candidate
456 }
457}
458
// Unit tests for the GraphSAGE embedder: output shapes, determinism, sampling
// limits, handling of unseen entities and normalisation of the embeddings.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n_triples` triples forming a ring over `n_entities` entities:
    /// triple `i` links entity `i % n_entities` to entity `(i + 1) % n_entities`.
    fn toy_triples(n_entities: usize, n_triples: usize) -> Vec<(String, String, String)> {
        let mut triples = Vec::with_capacity(n_triples);
        for i in 0..n_triples {
            let s = format!("http://ex.org/e{}", i % n_entities);
            let p = "http://ex.org/rel".to_string();
            let o = format!("http://ex.org/e{}", (i + 1) % n_entities);
            triples.push((s, p, o));
        }
        triples
    }

    /// Every trained entity must have an embedding of length `embedding_dim`.
    #[test]
    fn test_forward_pass_shape() {
        let config = GraphSageEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            embedding_dim: 8,
            neighbor_sample_k: 5,
            learning_rate: 0.01,
            num_epochs: 1,
            margin: 1.0,
            seed: Some(1),
        };
        let triples = toy_triples(8, 16);
        let mut embedder = GraphSageEmbedder::new(config.clone());
        embedder.fit(&triples).expect("fit should succeed");

        for i in 0..8usize {
            let iri = format!("http://ex.org/e{}", i);
            let emb = embedder
                .embed_entity(&iri)
                .expect("embed_entity should succeed");
            assert_eq!(
                emb.len(),
                config.embedding_dim,
                "embedding length mismatch for entity {iri}"
            );
        }
    }

    /// Two embedders fitted with the same seed must end up with identical weights.
    #[test]
    fn test_deterministic_init() {
        let config = GraphSageEmbedderConfig {
            num_layers: 1,
            hidden_dim: 8,
            embedding_dim: 4,
            neighbor_sample_k: 3,
            // A zero learning rate leaves the weights at their initial values.
            learning_rate: 0.0,
            num_epochs: 1,
            margin: 1.0,
            seed: Some(99),
        };
        let triples = toy_triples(4, 8);

        let mut e1 = GraphSageEmbedder::new(config.clone());
        let mut e2 = GraphSageEmbedder::new(config.clone());
        e1.fit(&triples).expect("fit 1 should succeed");
        e2.fit(&triples).expect("fit 2 should succeed");

        assert_eq!(e1.weights.len(), e2.weights.len());
        for (l, (w1, w2)) in e1.weights.iter().zip(e2.weights.iter()).enumerate() {
            for (r, (row1, row2)) in w1.iter().zip(w2.iter()).enumerate() {
                for (c, (&v1, &v2)) in row1.iter().zip(row2.iter()).enumerate() {
                    assert!(
                        (v1 - v2).abs() < 1e-14,
                        "weight mismatch at layer={l} row={r} col={c}: {v1} vs {v2}"
                    );
                }
            }
        }
    }

    /// Longer training must not make linked entities markedly less similar:
    /// the average cosine similarity over the triples may drop by at most 0.5
    /// between a 1-epoch and a 50-epoch fit.
    #[test]
    fn test_loss_decreases() {
        let triples = toy_triples(10, 20);

        let make_config = |epochs: usize| GraphSageEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            embedding_dim: 8,
            neighbor_sample_k: 5,
            learning_rate: 0.05,
            num_epochs: epochs,
            margin: 1.0,
            seed: Some(7),
        };

        let mut e_early = GraphSageEmbedder::new(make_config(1));
        e_early.fit(&triples).expect("1-epoch fit should succeed");

        let mut e_trained = GraphSageEmbedder::new(make_config(50));
        e_trained
            .fit(&triples)
            .expect("50-epoch fit should succeed");

        // Mean cosine similarity between subject and object over all triples.
        let avg_sim = |embedder: &GraphSageEmbedder| -> f64 {
            let (mut total, mut count) = (0.0_f64, 0usize);
            for (s, _, o) in &triples {
                if let (Ok(hs), Ok(ho)) = (embedder.embed_entity(s), embedder.embed_entity(o)) {
                    total += cosine_sim(&hs, &ho);
                    count += 1;
                }
            }
            if count > 0 {
                total / count as f64
            } else {
                0.0
            }
        };
        let (sim_early, sim_trained) = (avg_sim(&e_early), avg_sim(&e_trained));
        assert!(
            sim_trained >= sim_early - 0.5,
            "similarity regression: early={sim_early:.4} trained={sim_trained:.4}"
        );
    }

    /// A hub with 15 leaves must never yield more than `neighbor_sample_k` samples.
    #[test]
    fn test_neighbor_sampling_k_limit() {
        let mut triples: Vec<(String, String, String)> = Vec::new();
        // Star graph: one hub connected to 15 distinct leaves.
        for i in 1..=15usize {
            triples.push((
                "http://ex.org/hub".to_string(),
                "http://ex.org/rel".to_string(),
                format!("http://ex.org/leaf{}", i),
            ));
        }

        let config = GraphSageEmbedderConfig {
            neighbor_sample_k: 3,
            num_epochs: 1,
            seed: Some(5),
            ..Default::default()
        };
        let mut embedder = GraphSageEmbedder::new(config.clone());
        embedder.fit(&triples).expect("fit should succeed");

        let (_, adjacency) = GraphSageEmbedder::build_graph(&triples);
        let sampled = embedder.sample_neighbors("http://ex.org/hub", &adjacency);
        assert!(
            sampled.len() <= config.neighbor_sample_k,
            "got {} neighbours, K={}",
            sampled.len(),
            config.neighbor_sample_k
        );
    }

    /// An entity absent from the training triples yields an all-zero vector of
    /// length `embedding_dim` instead of an error.
    #[test]
    fn test_inductive_unseen_entity() {
        let config = GraphSageEmbedderConfig {
            num_layers: 1,
            hidden_dim: 8,
            embedding_dim: 4,
            num_epochs: 2,
            seed: Some(3),
            ..Default::default()
        };
        let triples = toy_triples(5, 10);
        let mut embedder = GraphSageEmbedder::new(config.clone());
        embedder.fit(&triples).expect("fit should succeed");

        let unseen = "http://ex.org/TOTALLY_UNSEEN_ENTITY";
        let emb = embedder
            .embed_entity(unseen)
            .expect("embed_entity for unseen should not error");

        assert_eq!(emb.len(), config.embedding_dim);
        let all_zero = emb.iter().all(|&v| v == 0.0);
        assert!(all_zero, "unseen entity embedding must be a zero vector");
    }

    /// Every non-zero embedding must have an L2 norm within 0.1 of 1.
    #[test]
    fn test_l2_normalisation() {
        let config = GraphSageEmbedderConfig {
            num_layers: 2,
            hidden_dim: 16,
            embedding_dim: 8,
            neighbor_sample_k: 5,
            num_epochs: 3,
            seed: Some(11),
            ..Default::default()
        };
        let triples = toy_triples(6, 12);
        let mut embedder = GraphSageEmbedder::new(config.clone());
        embedder.fit(&triples).expect("fit should succeed");

        for i in 0..6usize {
            let iri = format!("http://ex.org/e{}", i);
            let emb = embedder
                .embed_entity(&iri)
                .expect("embed_entity should succeed");
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            // All-zero vectors (ReLU can zero out every component) are exempt.
            if norm > 1e-12 {
                assert!(
                    (norm - 1.0).abs() < 0.1,
                    "L2 norm out of tolerance for {iri}: {norm}"
                );
            }
        }
    }
}