// scirs2_graph/ssl/contrastive.rs
1//! Graph Contrastive Learning: GraphCL (You et al. 2020) and SimGRACE (Xia 2022).
2//!
3//! ## GraphCL Augmentations
4//!
5//! GraphCL creates two augmented views of each graph and trains an encoder to
6//! maximise agreement between the views using the **NT-Xent** (normalised
7//! temperature-scaled cross-entropy) loss:
8//!
9//! ```text
10//! L = -(1/2N) Σ_i [ log exp(s(z_i, z_i') / τ) / Σ_{k≠i} exp(s(z_i, z_k) / τ) ]
11//! ```
12//!
13//! Supported augmentations:
14//! - **Feature masking**: zero out a random fraction of node feature dimensions.
15//! - **Edge perturbation**: randomly drop existing edges and/or insert new ones.
16//!
17//! ## SimGRACE
18//!
19//! SimGRACE (Xia et al. 2022) avoids explicit augmentations by creating the
20//! second view through small Gaussian perturbations of the encoder weights.
21
22use scirs2_core::ndarray::{Array1, Array2, Axis};
23use scirs2_core::random::{Rng, RngExt, SeedableRng};
24
25// ============================================================================
26// Configuration
27// ============================================================================
28
/// Configuration for GraphCL-style contrastive learning.
///
/// The defaults (see the [`Default`] impl) use τ = 0.5, a 128-dimensional
/// projection head, 10 % feature masking, 10 % edge dropping, and no edge
/// insertion.
#[derive(Debug, Clone)]
pub struct GraphClConfig {
    /// Temperature parameter τ in the NT-Xent loss; must be > 0
    /// (asserted in [`nt_xent_loss`]).
    pub temperature: f64,
    /// Output dimension of the projection head.
    pub proj_dim: usize,
    /// Fraction of node feature dimensions to zero out in each view
    /// (consumed by [`augment_features`]).
    pub mask_feature_rate: f64,
    /// Fraction of edges to randomly drop from the adjacency matrix
    /// (consumed by [`augment_edges`]).
    pub drop_edge_rate: f64,
    /// Fraction of non-edges to randomly add to the adjacency matrix
    /// (consumed by [`augment_edges`]).
    pub add_edge_rate: f64,
}
43
44impl Default for GraphClConfig {
45    fn default() -> Self {
46        Self {
47            temperature: 0.5,
48            proj_dim: 128,
49            mask_feature_rate: 0.1,
50            drop_edge_rate: 0.1,
51            add_edge_rate: 0.0,
52        }
53    }
54}
55
56// ============================================================================
57// Augmentation functions
58// ============================================================================
59
60/// Feature-masking augmentation (GraphCL).
61///
62/// Independently zeros out each feature dimension of every node with
63/// probability `mask_rate`.  When `mask_rate = 0.0` the input is returned
64/// unchanged; when `mask_rate = 1.0` a zero matrix is returned.
65///
66/// # Arguments
67/// * `features`  – node feature matrix `[n_nodes × feature_dim]`
68/// * `mask_rate` – probability of zeroing each feature dimension
69/// * `seed`      – RNG seed for reproducibility
70pub fn augment_features(features: &Array2<f64>, mask_rate: f64, seed: u64) -> Array2<f64> {
71    if mask_rate <= 0.0 {
72        return features.clone();
73    }
74    if mask_rate >= 1.0 {
75        return Array2::zeros(features.dim());
76    }
77
78    let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
79    let mut out = features.clone();
80    let (n_nodes, feat_dim) = features.dim();
81
82    for i in 0..n_nodes {
83        for j in 0..feat_dim {
84            if rng.random::<f64>() < mask_rate {
85                out[[i, j]] = 0.0;
86            }
87        }
88    }
89    out
90}
91
92/// Edge-perturbation augmentation (GraphCL).
93///
94/// For each existing edge, drops it with probability `drop_rate`.
95/// For each non-edge, adds it with probability `add_rate`.
96///
97/// The returned matrix is forced to be symmetric (undirected graph).
98///
99/// # Arguments
100/// * `adj`       – adjacency matrix `[n × n]` (any non-zero entry = edge)
101/// * `drop_rate` – probability of removing an existing edge
102/// * `add_rate`  – probability of adding a new edge between non-adjacent nodes
103/// * `seed`      – RNG seed
104pub fn augment_edges(adj: &Array2<f64>, drop_rate: f64, add_rate: f64, seed: u64) -> Array2<f64> {
105    let n = adj.dim().0;
106    let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
107    let mut out = adj.clone();
108
109    for i in 0..n {
110        for j in (i + 1)..n {
111            if adj[[i, j]] > 0.0 {
112                // Existing edge: maybe drop
113                if drop_rate > 0.0 && rng.random::<f64>() < drop_rate {
114                    out[[i, j]] = 0.0;
115                    out[[j, i]] = 0.0;
116                }
117            } else {
118                // Non-edge: maybe add
119                if add_rate > 0.0 && rng.random::<f64>() < add_rate {
120                    out[[i, j]] = 1.0;
121                    out[[j, i]] = 1.0;
122                }
123            }
124        }
125    }
126    out
127}
128
129// ============================================================================
130// NT-Xent loss
131// ============================================================================
132
133/// NT-Xent (normalised temperature-scaled cross-entropy) contrastive loss.
134///
135/// Given two batches of projected representations `z1` and `z2` (each of
136/// shape `[batch × proj_dim]`), where `(z1[i], z2[i])` are positive pairs,
137/// computes the symmetric InfoNCE loss over all `2N` samples.
138///
139/// # Arguments
140/// * `z1`         – first-view projections `[N × D]`
141/// * `z2`         – second-view projections `[N × D]`
142/// * `temperature` – temperature τ > 0; lower values create sharper distributions
143///
144/// # Returns
145/// Scalar loss value.
146pub fn nt_xent_loss(z1: &Array2<f64>, z2: &Array2<f64>, temperature: f64) -> f64 {
147    let (n, _d) = z1.dim();
148    assert_eq!(z1.dim(), z2.dim(), "z1 and z2 must have the same shape");
149    assert!(temperature > 0.0, "temperature must be positive");
150
151    // L2-normalise each row
152    let norm_z1 = l2_normalise_rows(z1);
153    let norm_z2 = l2_normalise_rows(z2);
154
155    // Stack: rows 0..N from z1, rows N..2N from z2  →  [2N × D]
156    let mut stacked = Array2::zeros((2 * n, z1.dim().1));
157    for i in 0..n {
158        for d in 0..z1.dim().1 {
159            stacked[[i, d]] = norm_z1[[i, d]];
160            stacked[[i + n, d]] = norm_z2[[i, d]];
161        }
162    }
163
164    // Compute full cosine similarity matrix [2N × 2N] / tau
165    let two_n = 2 * n;
166    let mut sim = Array2::zeros((two_n, two_n));
167    for i in 0..two_n {
168        for j in 0..two_n {
169            let mut dot = 0.0;
170            for d in 0..stacked.dim().1 {
171                dot += stacked[[i, d]] * stacked[[j, d]];
172            }
173            sim[[i, j]] = dot / temperature;
174        }
175    }
176
177    // Mask diagonal (self-similarity) with large negative value
178    for i in 0..two_n {
179        sim[[i, i]] = f64::NEG_INFINITY;
180    }
181
182    // Positive pair indices:
183    //   for i in 0..N: positive = i+N
184    //   for i in N..2N: positive = i-N
185    let mut loss = 0.0;
186    for i in 0..two_n {
187        let pos_j = if i < n { i + n } else { i - n };
188        let pos_score = sim[[i, pos_j]];
189
190        // log-sum-exp over all non-self entries
191        let row_scores: Vec<f64> = (0..two_n)
192            .filter(|&j| j != i)
193            .map(|j| sim[[i, j]])
194            .collect();
195        let max_s = row_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
196        let log_sum_exp = max_s
197            + row_scores
198                .iter()
199                .map(|&s| (s - max_s).exp())
200                .sum::<f64>()
201                .ln();
202
203        loss += -(pos_score - log_sum_exp);
204    }
205
206    loss / two_n as f64
207}
208
209/// L2-normalise each row of a 2-D array.
210fn l2_normalise_rows(x: &Array2<f64>) -> Array2<f64> {
211    let norms: Array1<f64> = x.map_axis(Axis(1), |row| {
212        let s: f64 = row.iter().map(|&v| v * v).sum();
213        s.sqrt().max(1e-12)
214    });
215    let mut out = x.clone();
216    let (n, _d) = x.dim();
217    for i in 0..n {
218        for d in 0.._d {
219            out[[i, d]] /= norms[i];
220        }
221    }
222    out
223}
224
225// ============================================================================
226// Projection head
227// ============================================================================
228
/// Two-layer MLP projection head used in contrastive learning.
///
/// Architecture: `in_dim → hidden_dim (ReLU) → out_dim`
pub struct ProjectionHead {
    /// First-layer weights `[in_dim × hidden_dim]`.
    w1: Array2<f64>,
    /// First-layer bias `[hidden_dim]`.
    b1: Array1<f64>,
    /// Second-layer weights `[hidden_dim × out_dim]`.
    w2: Array2<f64>,
    /// Second-layer bias `[out_dim]`.
    b2: Array1<f64>,
}
238
239impl ProjectionHead {
240    /// Construct a new projection head with Xavier-uniform initialised weights.
241    ///
242    /// # Arguments
243    /// * `in_dim`     – input dimension (encoder output size)
244    /// * `hidden_dim` – hidden layer dimension
245    /// * `out_dim`    – projection output dimension
246    /// * `seed`       – RNG seed
247    pub fn new(in_dim: usize, hidden_dim: usize, out_dim: usize, seed: u64) -> Self {
248        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
249
250        let s1 = (6.0 / (in_dim + hidden_dim) as f64).sqrt();
251        let w1 = Array2::from_shape_fn((in_dim, hidden_dim), |_| {
252            rng.random::<f64>() * 2.0 * s1 - s1
253        });
254        let b1 = Array1::zeros(hidden_dim);
255
256        let s2 = (6.0 / (hidden_dim + out_dim) as f64).sqrt();
257        let w2 = Array2::from_shape_fn((hidden_dim, out_dim), |_| {
258            rng.random::<f64>() * 2.0 * s2 - s2
259        });
260        let b2 = Array1::zeros(out_dim);
261
262        ProjectionHead { w1, b1, w2, b2 }
263    }
264
265    /// Forward pass: `x → W1 x + b1 → ReLU → W2 x + b2`
266    ///
267    /// # Arguments
268    /// * `x` – input `[batch × in_dim]`
269    ///
270    /// # Returns
271    /// Projected representations `[batch × out_dim]`
272    pub fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
273        let batch = x.dim().0;
274        let hidden_dim = self.w1.dim().1;
275        let out_dim = self.w2.dim().1;
276
277        // First linear layer + ReLU
278        let mut h = Array2::zeros((batch, hidden_dim));
279        for i in 0..batch {
280            for j in 0..hidden_dim {
281                let mut val = self.b1[j];
282                for d in 0..x.dim().1 {
283                    val += x[[i, d]] * self.w1[[d, j]];
284                }
285                h[[i, j]] = if val > 0.0 { val } else { 0.0 };
286            }
287        }
288
289        // Second linear layer
290        let mut out = Array2::zeros((batch, out_dim));
291        for i in 0..batch {
292            for k in 0..out_dim {
293                let mut val = self.b2[k];
294                for j in 0..hidden_dim {
295                    val += h[[i, j]] * self.w2[[j, k]];
296                }
297                out[[i, k]] = val;
298            }
299        }
300
301        out
302    }
303
304    /// Input dimension.
305    pub fn in_dim(&self) -> usize {
306        self.w1.dim().0
307    }
308
309    /// Output projection dimension.
310    pub fn out_dim(&self) -> usize {
311        self.w2.dim().1
312    }
313}
314
315// ============================================================================
316// SimGRACE perturbation
317// ============================================================================
318
319/// SimGRACE weight perturbation.
320///
321/// Creates a second view by adding Gaussian noise to a weight matrix,
322/// effectively simulating a slightly different encoder without explicit
323/// graph augmentation.
324///
325/// # Arguments
326/// * `weights`     – weight matrix to perturb `[r × c]`
327/// * `noise_scale` – standard deviation of the Gaussian perturbation
328/// * `seed`        – RNG seed
329///
330/// # Returns
331/// Perturbed weight matrix of the same shape.
332pub fn simgrace_perturb(weights: &Array2<f64>, noise_scale: f64, seed: u64) -> Array2<f64> {
333    let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);
334    weights.mapv(|v| {
335        // Box-Muller for Gaussian noise
336        let u1: f64 = rng.random::<f64>().max(1e-12);
337        let u2: f64 = rng.random::<f64>();
338        let noise = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
339        v + noise_scale * noise
340    })
341}
342
343// ============================================================================
344// Tests
345// ============================================================================
346
#[cfg(test)]
mod tests {
    use super::*;

    /// A mask rate of 0.0 must leave the features untouched.
    #[test]
    fn test_augment_features_zero_rate_identity() {
        let x = Array2::from_shape_vec((3, 4), (0..12).map(|v| v as f64).collect()).expect("ok");
        let out = augment_features(&x, 0.0, 0);
        for (a, b) in x.iter().zip(out.iter()) {
            assert_eq!(a, b);
        }
    }

    /// A mask rate of 1.0 must zero every entry.
    #[test]
    fn test_augment_features_full_rate_zeros() {
        let x = Array2::ones((5, 8));
        let out = augment_features(&x, 1.0, 0);
        for v in out.iter() {
            assert_eq!(*v, 0.0);
        }
    }

    #[test]
    fn test_nt_xent_identical_views_low_loss() {
        // Identical one-hot (orthogonal) views: every positive pair has cosine
        // similarity 1 and every negative 0, so the per-anchor loss is
        // ln(e^{1/τ} + (2N - 2)) - 1/τ — strictly positive because the
        // negatives still contribute to the denominator.
        let z = Array2::from_shape_fn((8, 16), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let loss = nt_xent_loss(&z, &z, 0.5);
        assert!(loss >= 0.0, "loss should be non-negative, got {loss}");
        // And lower than a random baseline.
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(0);
        let z_rand = Array2::from_shape_fn((8, 16), |_| rng.random::<f64>() - 0.5);
        let loss_rand = nt_xent_loss(&z_rand, &z_rand, 0.5);
        // Perfectly aligned views should not do worse than random ones.
        assert!(loss <= loss_rand + 1e-6);
    }

    /// Two unrelated random views should give a clearly positive loss.
    #[test]
    fn test_nt_xent_random_views_positive_loss() {
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(42);
        let z1 = Array2::from_shape_fn((6, 8), |_| rng.random::<f64>() - 0.5);
        let z2 = Array2::from_shape_fn((6, 8), |_| rng.random::<f64>() - 0.5);
        let loss = nt_xent_loss(&z1, &z2, 0.5);
        assert!(
            loss > 0.0,
            "loss with random views should be positive, got {loss}"
        );
    }

    /// Forward pass must map `[batch × in_dim]` to `[batch × out_dim]`.
    #[test]
    fn test_projection_head_output_shape() {
        let head = ProjectionHead::new(32, 64, 16, 0);
        let x = Array2::ones((10, 32));
        let out = head.forward(&x);
        assert_eq!(out.dim(), (10, 16));
    }

    /// Accessors must report the constructor's dimensions.
    #[test]
    fn test_projection_head_dims() {
        let head = ProjectionHead::new(32, 64, 16, 0);
        assert_eq!(head.in_dim(), 32);
        assert_eq!(head.out_dim(), 16);
    }

    /// Non-zero noise scale must actually move the weights.
    #[test]
    fn test_simgrace_perturb_changes_weights() {
        let w = Array2::ones((8, 8));
        let perturbed = simgrace_perturb(&w, 0.1, 99);
        let diff: f64 = w
            .iter()
            .zip(perturbed.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-10,
            "perturbed weights should differ from original"
        );
    }

    /// Zero noise scale must be a (numerical) no-op.
    #[test]
    fn test_simgrace_zero_noise_preserves_weights() {
        let w = Array2::ones((4, 4));
        let perturbed = simgrace_perturb(&w, 0.0, 0);
        for (a, b) in w.iter().zip(perturbed.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    /// Augmentation of a symmetric graph must stay symmetric.
    #[test]
    fn test_augment_edges_symmetry() {
        // Build a symmetric adjacency matrix (path graph 0-1-2-3)
        let mut adj = Array2::zeros((4, 4));
        adj[[0, 1]] = 1.0;
        adj[[1, 0]] = 1.0;
        adj[[1, 2]] = 1.0;
        adj[[2, 1]] = 1.0;
        adj[[2, 3]] = 1.0;
        adj[[3, 2]] = 1.0;

        let aug = augment_edges(&adj, 0.3, 0.1, 7);
        let n = 4;
        for i in 0..n {
            for j in 0..n {
                assert_eq!(
                    aug[[i, j]],
                    aug[[j, i]],
                    "augmented adjacency must remain symmetric at ({i},{j})"
                );
            }
        }
    }

    /// τ rescales the logits, so different temperatures give different losses.
    #[test]
    fn test_temperature_sensitivity() {
        // Lower temperature → sharper distribution → different (usually lower) loss for aligned views
        let z = Array2::from_shape_fn((4, 8), |(i, j)| if i == j { 1.0 } else { 0.0 });
        let loss_low_t = nt_xent_loss(&z, &z, 0.1);
        let loss_high_t = nt_xent_loss(&z, &z, 2.0);
        // Both should be non-negative; they should differ
        assert!(loss_low_t >= 0.0);
        assert!(loss_high_t >= 0.0);
        assert!(
            (loss_low_t - loss_high_t).abs() > 1e-6,
            "temperature should affect loss magnitude"
        );
    }
}
474}