1use crate::error::{GraphError, Result};
12
/// Minimal deterministic linear congruential generator used for
/// reproducible sampling without pulling in an external RNG crate.
struct Lcg(u64);

impl Lcg {
    /// Seeds the generator; XOR with a fixed constant decorrelates
    /// small consecutive seeds.
    fn new(seed: u64) -> Self {
        Self(seed ^ 0xdeadbeefcafe1234)
    }

    /// Advances the state one LCG step (Knuth MMIX multiplier/increment)
    /// and returns the new state.
    fn step(&mut self) -> u64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }

    /// Uniform f64 in [0, 1): the top 53 state bits scaled by 2^-53.
    fn next_f64(&mut self) -> f64 {
        let bits = self.step() >> 11;
        bits as f64 / (1u64 << 53) as f64
    }

    /// Uniform usize in [0, bound) via modulo reduction (slight modulo
    /// bias is acceptable here; panics if `bound == 0`).
    fn next_usize(&mut self, bound: usize) -> usize {
        ((self.step() >> 33) as usize) % bound
    }
}
40
/// Configuration for BERT-style masking of node feature vectors.
#[derive(Debug, Clone)]
pub struct NodeMaskingConfig {
    /// Fraction of nodes to mask; must be in (0, 1] (validated in `mask_nodes`).
    pub mask_rate: f64,
    /// Probability that a masked node receives random features instead of zeros.
    pub replace_rate: f64,
    /// Neighborhood size hint; not read anywhere in this file — TODO confirm intended use.
    pub n_neighbors: usize,
    /// Feature dimensionality hint; `mask_nodes` infers the actual dimension
    /// from its input rather than reading this field.
    pub feature_dim: usize,
}
59
60impl Default for NodeMaskingConfig {
61 fn default() -> Self {
62 Self {
63 mask_rate: 0.15,
64 replace_rate: 0.1,
65 n_neighbors: 2,
66 feature_dim: 64,
67 }
68 }
69}
70
/// Self-supervised pretrainer that masks node features and scores their
/// reconstruction (masked-attribute prediction objective).
pub struct NodeMaskingPretrainer {
    // Masking hyperparameters; see `NodeMaskingConfig`.
    config: NodeMaskingConfig,
}
85
86impl NodeMaskingPretrainer {
87 pub fn new(config: NodeMaskingConfig) -> Self {
89 Self { config }
90 }
91
92 pub fn mask_nodes(
107 &self,
108 features: &[Vec<f64>],
109 rng_seed: u64,
110 ) -> Result<(Vec<Vec<f64>>, Vec<usize>)> {
111 let n = features.len();
112 if n == 0 {
113 return Ok((vec![], vec![]));
114 }
115 let dim = features[0].len();
116 if dim == 0 {
117 return Err(GraphError::InvalidParameter {
118 param: "features".to_string(),
119 value: "empty feature vectors".to_string(),
120 expected: "non-empty feature vectors".to_string(),
121 context: "NodeMaskingPretrainer::mask_nodes".to_string(),
122 });
123 }
124 for (i, f) in features.iter().enumerate() {
125 if f.len() != dim {
126 return Err(GraphError::InvalidParameter {
127 param: format!("features[{i}]"),
128 value: format!("length {}", f.len()),
129 expected: format!("length {dim}"),
130 context: "NodeMaskingPretrainer::mask_nodes".to_string(),
131 });
132 }
133 }
134 if !(0.0 < self.config.mask_rate && self.config.mask_rate <= 1.0) {
135 return Err(GraphError::InvalidParameter {
136 param: "mask_rate".to_string(),
137 value: format!("{}", self.config.mask_rate),
138 expected: "value in (0, 1]".to_string(),
139 context: "NodeMaskingPretrainer::mask_nodes".to_string(),
140 });
141 }
142
143 let k = ((n as f64 * self.config.mask_rate).ceil() as usize).min(n);
144 let mut rng = Lcg::new(rng_seed);
145
146 let mut indices: Vec<usize> = (0..n).collect();
148 for i in (n - k..n).rev() {
149 let j = rng.next_usize(i + 1);
150 indices.swap(i, j);
151 }
152 let mut masked_indices: Vec<usize> = indices[n - k..].to_vec();
153 masked_indices.sort_unstable();
154
155 let mut masked = features.to_vec();
157 let masked_set: std::collections::HashSet<usize> = masked_indices.iter().cloned().collect();
158 for &node in &masked_indices {
159 let replace = rng.next_f64() < self.config.replace_rate;
160 masked[node] = if replace {
161 (0..dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect()
163 } else {
164 vec![0.0; dim]
166 };
167 }
168 let _ = masked_set;
170
171 Ok((masked, masked_indices))
172 }
173
174 pub fn reconstruction_loss(
185 &self,
186 predicted: &[Vec<f64>],
187 original: &[Vec<f64>],
188 masked_indices: &[usize],
189 ) -> Result<f64> {
190 if predicted.len() != original.len() {
191 return Err(GraphError::InvalidParameter {
192 param: "predicted / original".to_string(),
193 value: format!("lengths {} vs {}", predicted.len(), original.len()),
194 expected: "equal lengths".to_string(),
195 context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
196 });
197 }
198 if masked_indices.is_empty() {
199 return Ok(0.0);
200 }
201 let n = predicted.len();
202 let mut total = 0.0_f64;
203 let mut count = 0usize;
204 for &idx in masked_indices {
205 if idx >= n {
206 return Err(GraphError::InvalidParameter {
207 param: "masked_indices".to_string(),
208 value: format!("{idx}"),
209 expected: format!("index < {n}"),
210 context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
211 });
212 }
213 let p = &predicted[idx];
214 let o = &original[idx];
215 if p.len() != o.len() {
216 return Err(GraphError::InvalidParameter {
217 param: format!("predicted[{idx}]"),
218 value: format!("length {}", p.len()),
219 expected: format!("length {}", o.len()),
220 context: "NodeMaskingPretrainer::reconstruction_loss".to_string(),
221 });
222 }
223 for (a, b) in p.iter().zip(o.iter()) {
224 let diff = a - b;
225 total += diff * diff;
226 count += 1;
227 }
228 }
229 Ok(if count > 0 { total / count as f64 } else { 0.0 })
230 }
231}
232
/// Configuration for context-subgraph contrastive pretraining.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct GraphContextConfig {
    /// Maximum number of nodes in a sampled context subgraph (a minimum
    /// of 1 is enforced at sampling time).
    pub context_size: usize,
    /// Number of negative samples per anchor; not read anywhere in this
    /// file — TODO confirm intended use.
    pub negative_samples: usize,
    /// Embedding dimensionality hint; not read anywhere in this file.
    pub feature_dim: usize,
    /// Softmax temperature for the InfoNCE loss; must be positive.
    pub temperature: f64,
}
251
252impl Default for GraphContextConfig {
253 fn default() -> Self {
254 Self {
255 context_size: 8,
256 negative_samples: 4,
257 feature_dim: 64,
258 temperature: 0.07,
259 }
260 }
261}
262
/// Self-supervised pretrainer that samples context subgraphs and scores
/// anchor/context embeddings with an InfoNCE contrastive loss.
pub struct GraphContextPretrainer {
    // Sampling and loss hyperparameters; see `GraphContextConfig`.
    config: GraphContextConfig,
}
270
271impl GraphContextPretrainer {
272 pub fn new(config: GraphContextConfig) -> Self {
274 Self { config }
275 }
276
277 pub fn sample_context_subgraph(
292 &self,
293 adj: &[(usize, usize)],
294 center: usize,
295 n_nodes: usize,
296 seed: u64,
297 ) -> Result<Vec<usize>> {
298 if n_nodes == 0 {
299 return Ok(vec![]);
300 }
301 if center >= n_nodes {
302 return Err(GraphError::InvalidParameter {
303 param: "center".to_string(),
304 value: format!("{center}"),
305 expected: format!("index < {n_nodes}"),
306 context: "GraphContextPretrainer::sample_context_subgraph".to_string(),
307 });
308 }
309
310 let mut lists: Vec<Vec<usize>> = vec![Vec::new(); n_nodes];
312 for &(u, v) in adj {
313 if u < n_nodes && v < n_nodes && u != v {
314 lists[u].push(v);
315 lists[v].push(u);
316 }
317 }
318
319 let max_ctx = self.config.context_size.max(1);
320 let mut visited = vec![false; n_nodes];
321 let mut result = Vec::with_capacity(max_ctx);
322 let mut queue = std::collections::VecDeque::new();
323 let mut rng = Lcg::new(seed);
324
325 visited[center] = true;
326 queue.push_back(center);
327 result.push(center);
328
329 while let Some(v) = queue.pop_front() {
330 if result.len() >= max_ctx {
331 break;
332 }
333 let mut nbrs = lists[v].clone();
335 for i in (1..nbrs.len()).rev() {
336 let j = rng.next_usize(i + 1);
337 nbrs.swap(i, j);
338 }
339 for w in nbrs {
340 if result.len() >= max_ctx {
341 break;
342 }
343 if !visited[w] {
344 visited[w] = true;
345 result.push(w);
346 queue.push_back(w);
347 }
348 }
349 }
350
351 result.sort_unstable();
352 Ok(result)
353 }
354
355 pub fn contrastive_loss(
372 &self,
373 anchor: &[f64],
374 positive: &[f64],
375 negatives: &[Vec<f64>],
376 temperature: f64,
377 ) -> Result<f64> {
378 infonce_loss(anchor, positive, negatives, temperature)
379 }
380}
381
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when either vector has zero norm so callers never see
/// NaN or inf from a zero embedding.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let na = sq_a.sqrt();
    let nb = sq_b.sqrt();
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}
394
395pub fn infonce_loss(
401 anchor: &[f64],
402 positive: &[f64],
403 negatives: &[Vec<f64>],
404 temperature: f64,
405) -> Result<f64> {
406 let dim = anchor.len();
407 if dim == 0 {
408 return Err(GraphError::InvalidParameter {
409 param: "anchor".to_string(),
410 value: "empty".to_string(),
411 expected: "non-empty embedding vector".to_string(),
412 context: "infonce_loss".to_string(),
413 });
414 }
415 if positive.len() != dim {
416 return Err(GraphError::InvalidParameter {
417 param: "positive".to_string(),
418 value: format!("length {}", positive.len()),
419 expected: format!("length {dim}"),
420 context: "infonce_loss".to_string(),
421 });
422 }
423 if temperature <= 0.0 {
424 return Err(GraphError::InvalidParameter {
425 param: "temperature".to_string(),
426 value: format!("{temperature}"),
427 expected: "positive value".to_string(),
428 context: "infonce_loss".to_string(),
429 });
430 }
431 for (i, neg) in negatives.iter().enumerate() {
432 if neg.len() != dim {
433 return Err(GraphError::InvalidParameter {
434 param: format!("negatives[{i}]"),
435 value: format!("length {}", neg.len()),
436 expected: format!("length {dim}"),
437 context: "infonce_loss".to_string(),
438 });
439 }
440 }
441
442 let sim_pos = cosine_similarity(anchor, positive) / temperature;
443 let mut sims: Vec<f64> = std::iter::once(sim_pos)
445 .chain(
446 negatives
447 .iter()
448 .map(|n| cosine_similarity(anchor, n) / temperature),
449 )
450 .collect();
451 let max_sim = sims.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
452 for s in sims.iter_mut() {
453 *s = (*s - max_sim).exp();
454 }
455 let denom: f64 = sims.iter().sum();
456 let loss = -(sims[0].ln() - denom.ln());
457 Ok(loss)
458}
459
/// Configuration for the attribute-reconstruction MLP objective.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AttrReconConfig {
    /// Total number of linear layers in the MLP; must be at least 1.
    pub n_layers: usize,
    /// Width of the hidden layers (a minimum of 1 is enforced at build time).
    pub hidden_dim: usize,
    /// Dropout rate; currently not applied anywhere in this file — TODO confirm.
    pub dropout: f64,
}
475
476impl Default for AttrReconConfig {
477 fn default() -> Self {
478 Self {
479 n_layers: 2,
480 hidden_dim: 128,
481 dropout: 0.1,
482 }
483 }
484}
485
/// A dense (fully connected) layer: row-major weight matrix plus bias.
#[derive(Debug, Clone)]
struct LinearLayer {
    /// `weights[o][i]` multiplies input component `i` for output
    /// component `o` (shape: out_dim x in_dim).
    weights: Vec<Vec<f64>>,
    /// One bias per output component; initialized to zero.
    bias: Vec<f64>,
}
494
495impl LinearLayer {
496 fn new(in_dim: usize, out_dim: usize, seed: u64) -> Self {
498 let mut rng = Lcg::new(seed);
499 let scale = (6.0 / (in_dim + out_dim) as f64).sqrt();
500 let weights = (0..out_dim)
501 .map(|_| {
502 (0..in_dim)
503 .map(|_| (rng.next_f64() * 2.0 - 1.0) * scale)
504 .collect()
505 })
506 .collect();
507 let bias = vec![0.0f64; out_dim];
508 Self { weights, bias }
509 }
510
511 fn forward_tanh(&self, x: &[f64]) -> Vec<f64> {
513 self.weights
514 .iter()
515 .zip(self.bias.iter())
516 .map(|(row, b)| {
517 let pre: f64 = row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f64>() + b;
518 pre.tanh()
519 })
520 .collect()
521 }
522
523 fn forward_linear(&self, x: &[f64]) -> Vec<f64> {
525 self.weights
526 .iter()
527 .zip(self.bias.iter())
528 .map(|(row, b)| row.iter().zip(x.iter()).map(|(w, xi)| w * xi).sum::<f64>() + b)
529 .collect()
530 }
531}
532
/// Attribute-reconstruction objective: a small MLP that maps node
/// features through hidden layers and back, scored with `mse_loss`.
pub struct AttributeReconstructionObjective {
    // Architecture hyperparameters (n_layers, hidden_dim, dropout).
    config: AttrReconConfig,
    // The MLP layers, ordered input -> output.
    layers: Vec<LinearLayer>,
    // Dimensionality every input row passed to `forward` must have.
    input_dim: usize,
}
543
544impl AttributeReconstructionObjective {
545 pub fn new(input_dim: usize, config: AttrReconConfig, seed: u64) -> Result<Self> {
554 if input_dim == 0 {
555 return Err(GraphError::InvalidParameter {
556 param: "input_dim".to_string(),
557 value: "0".to_string(),
558 expected: "positive dimension".to_string(),
559 context: "AttributeReconstructionObjective::new".to_string(),
560 });
561 }
562 if config.n_layers == 0 {
563 return Err(GraphError::InvalidParameter {
564 param: "n_layers".to_string(),
565 value: "0".to_string(),
566 expected: "at least 1 layer".to_string(),
567 context: "AttributeReconstructionObjective::new".to_string(),
568 });
569 }
570 let hidden = config.hidden_dim.max(1);
571 let mut layers = Vec::with_capacity(config.n_layers);
572
573 layers.push(LinearLayer::new(input_dim, hidden, seed));
575
576 for i in 1..config.n_layers.saturating_sub(1) {
578 layers.push(LinearLayer::new(
579 hidden,
580 hidden,
581 seed.wrapping_add(i as u64),
582 ));
583 }
584
585 if config.n_layers > 1 {
587 layers.push(LinearLayer::new(
588 hidden,
589 input_dim,
590 seed.wrapping_add(config.n_layers as u64),
591 ));
592 }
593
594 Ok(Self {
595 config,
596 layers,
597 input_dim,
598 })
599 }
600
601 pub fn forward(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
613 features
614 .iter()
615 .enumerate()
616 .map(|(i, f)| {
617 if f.len() != self.input_dim {
618 return Err(GraphError::InvalidParameter {
619 param: format!("features[{i}]"),
620 value: format!("length {}", f.len()),
621 expected: format!("length {}", self.input_dim),
622 context: "AttributeReconstructionObjective::forward".to_string(),
623 });
624 }
625 let mut h = f.clone();
626 let last = self.layers.len().saturating_sub(1);
627 for (j, layer) in self.layers.iter().enumerate() {
628 h = if j < last {
629 layer.forward_tanh(&h)
630 } else {
631 layer.forward_linear(&h)
632 };
633 }
634 Ok(h)
635 })
636 .collect()
637 }
638
639 pub fn mse_loss(&self, predicted: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
644 if predicted.len() != target.len() {
645 return Err(GraphError::InvalidParameter {
646 param: "predicted".to_string(),
647 value: format!("length {}", predicted.len()),
648 expected: format!("length {}", target.len()),
649 context: "AttributeReconstructionObjective::mse_loss".to_string(),
650 });
651 }
652 if predicted.is_empty() {
653 return Ok(0.0);
654 }
655 let mut total = 0.0_f64;
656 let mut count = 0usize;
657 for (p_row, t_row) in predicted.iter().zip(target.iter()) {
658 if p_row.len() != t_row.len() {
659 return Err(GraphError::InvalidParameter {
660 param: "predicted row".to_string(),
661 value: format!("length {}", p_row.len()),
662 expected: format!("length {}", t_row.len()),
663 context: "AttributeReconstructionObjective::mse_loss".to_string(),
664 });
665 }
666 for (a, b) in p_row.iter().zip(t_row.iter()) {
667 let diff = a - b;
668 total += diff * diff;
669 count += 1;
670 }
671 }
672 Ok(if count > 0 { total / count as f64 } else { 0.0 })
673 }
674
675 pub fn config(&self) -> &AttrReconConfig {
677 &self.config
678 }
679}
680
#[cfg(test)]
mod tests {
    use super::*;

    // ceil(100 * 0.15) = 15: the masked set size is exact, not probabilistic.
    #[test]
    fn test_masking_correct_fraction() {
        let n = 100;
        let dim = 8;
        let features: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.15,
            replace_rate: 0.0,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (_, indices) = pretrainer.mask_nodes(&features, 7).unwrap();
        assert_eq!(indices.len(), 15, "should mask exactly 15 nodes");
    }

    // With replace_rate = 0 every masked node is zeroed, and every
    // unmasked node keeps its original (non-zero) features.
    #[test]
    fn test_masking_features_differ() {
        let n = 20;
        let dim = 4;
        let features: Vec<Vec<f64>> = (0..n).map(|i| vec![(i + 1) as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.5,
            replace_rate: 0.0,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&features, 99).unwrap();
        for &idx in &indices {
            assert_eq!(masked[idx], vec![0.0; dim], "node {idx} should be zeroed");
        }
        for i in 0..n {
            if !indices.contains(&i) {
                assert_eq!(masked[i], features[i], "node {i} should be unchanged");
            }
        }
    }

    // Masked-position MSE on a valid (masked, original, indices) triple is
    // finite and non-negative.
    #[test]
    fn test_reconstruction_loss_finite_positive() {
        let n = 10;
        let dim = 6;
        let original: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64; dim]).collect();
        let cfg = NodeMaskingConfig {
            mask_rate: 0.3,
            feature_dim: dim,
            ..Default::default()
        };
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&original, 11).unwrap();
        let loss = pretrainer
            .reconstruction_loss(&masked, &original, &indices)
            .unwrap();
        assert!(loss.is_finite(), "loss should be finite");
        assert!(loss >= 0.0, "loss should be non-negative");
    }

    // On a 10-node path graph the sampled context never exceeds
    // `context_size` nodes.
    #[test]
    fn test_context_subgraph_bounded() {
        let edges: Vec<(usize, usize)> = (0..9).map(|i| (i, i + 1)).collect();
        let cfg = GraphContextConfig {
            context_size: 4,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg.clone());
        let ctx = pretrainer
            .sample_context_subgraph(&edges, 5, 10, 42)
            .unwrap();
        assert!(
            ctx.len() <= cfg.context_size,
            "context size {} should be ≤ {}",
            ctx.len(),
            cfg.context_size
        );
    }

    // The BFS always seeds the context with the center node itself.
    #[test]
    fn test_context_subgraph_contains_center() {
        let edges = vec![(0, 1), (1, 2), (2, 3)];
        let cfg = GraphContextConfig {
            context_size: 3,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg);
        let ctx = pretrainer.sample_context_subgraph(&edges, 1, 4, 0).unwrap();
        assert!(ctx.contains(&1), "context should include center node 1");
    }

    // Negatives pointing away from the anchor must give a lower InfoNCE
    // loss than negatives identical to the anchor.
    #[test]
    fn test_contrastive_loss_pos_closer_lower_loss() {
        let anchor = vec![1.0_f64, 0.0];
        let positive = vec![1.0_f64, 0.0];
        let negatives = vec![vec![-1.0_f64, 0.0]; 4];
        let cfg = GraphContextConfig {
            temperature: 0.07,
            ..Default::default()
        };
        let pretrainer = GraphContextPretrainer::new(cfg.clone());
        let loss = pretrainer
            .contrastive_loss(&anchor, &positive, &negatives, cfg.temperature)
            .unwrap();
        let far_negatives = vec![vec![1.0_f64, 0.0]; 4];
        let high_loss = pretrainer
            .contrastive_loss(&anchor, &positive, &far_negatives, cfg.temperature)
            .unwrap();
        assert!(
            loss < high_loss,
            "loss with far negatives ({loss}) should be lower than loss with close negatives ({high_loss})"
        );
    }

    // Basic sanity: InfoNCE on arbitrary embeddings is finite and >= 0.
    #[test]
    fn test_contrastive_loss_finite() {
        let anchor = vec![0.5, 0.3, 0.2];
        let positive = vec![0.4, 0.4, 0.2];
        let negatives = vec![vec![0.1, 0.1, 0.8], vec![-0.1, 0.5, 0.4]];
        let loss = infonce_loss(&anchor, &positive, &negatives, 0.1).unwrap();
        assert!(loss.is_finite(), "InfoNCE loss should be finite");
        assert!(loss >= 0.0, "InfoNCE loss should be non-negative");
    }

    // A 2-layer reconstruction MLP must map dim-8 inputs back to dim-8.
    #[test]
    fn test_attr_recon_forward_shape() {
        let cfg = AttrReconConfig {
            n_layers: 2,
            hidden_dim: 16,
            dropout: 0.0,
        };
        let obj = AttributeReconstructionObjective::new(8, cfg, 123).unwrap();
        let features: Vec<Vec<f64>> = (0..5).map(|_| vec![1.0; 8]).collect();
        let out = obj.forward(&features).unwrap();
        assert_eq!(out.len(), 5, "output should have same number of nodes");
        for row in &out {
            assert_eq!(row.len(), 8, "each output vector should have dim 8");
        }
    }

    // Pin the documented Default values for all three config types.
    #[test]
    fn test_config_defaults() {
        let pr = NodeMaskingConfig::default();
        assert!((pr.mask_rate - 0.15).abs() < 1e-9);
        assert!((pr.replace_rate - 0.1).abs() < 1e-9);
        assert_eq!(pr.n_neighbors, 2);
        assert_eq!(pr.feature_dim, 64);

        let gc = GraphContextConfig::default();
        assert_eq!(gc.context_size, 8);
        assert_eq!(gc.negative_samples, 4);
        assert!((gc.temperature - 0.07).abs() < 1e-9);

        let ar = AttrReconConfig::default();
        assert_eq!(ar.n_layers, 2);
        assert_eq!(ar.hidden_dim, 128);
        assert!((ar.dropout - 0.1).abs() < 1e-9);
    }

    // Empty inputs are a no-op for both pretrainers rather than an error.
    #[test]
    fn test_empty_graph_handling() {
        let cfg = NodeMaskingConfig::default();
        let pretrainer = NodeMaskingPretrainer::new(cfg);
        let (masked, indices) = pretrainer.mask_nodes(&[], 0).unwrap();
        assert!(masked.is_empty());
        assert!(indices.is_empty());

        let cfg2 = GraphContextConfig::default();
        let pretrainer2 = GraphContextPretrainer::new(cfg2);
        let ctx = pretrainer2.sample_context_subgraph(&[], 0, 0, 0).unwrap();
        assert!(ctx.is_empty());
    }
}