Skip to main content

oxicuda_ssl/contrastive/
simclr.rs

1//! SimCLR — Chen et al. 2020 — symmetric NT-Xent contrastive loss.
2//!
3//! Given two augmented views `(z_a, z_b)` of the same `N` items, computes the
4//! symmetric InfoNCE loss with cosine similarity at temperature τ:
5//!
6//! ```text
//!   L = (1/2N) Σ_i [ −log p_{ab}(i, i) − log p_{ba}(i, i) ]
//!   p_{ab}(i, j) = exp(s_{ij}/τ) / Σ_k exp(s_{ik}/τ)
//!   s_{ij} = cos(z_a[i], z_b[j]);  p_{ba} is defined likewise with the two views swapped
9//! ```
10
11use crate::contrastive::info_nce::info_nce_loss;
12use crate::error::{SslError, SslResult};
13
/// Hyper-parameters for the SimCLR contrastive objective.
///
/// Use [`SimClrConfig::new`] to obtain a config whose temperature has been
/// validated, or [`Default::default`] for the default temperature of `0.1`.
/// The field is public, so a struct literal bypasses that validation.
///
/// `PartialEq` is derived so configurations can be compared directly
/// (e.g. in tests or when de-duplicating runs).
#[derive(Debug, Clone, PartialEq)]
pub struct SimClrConfig {
    /// Softmax temperature τ handed to the InfoNCE loss (default 0.1).
    pub temperature: f32,
}
20
21impl Default for SimClrConfig {
22    fn default() -> Self {
23        Self { temperature: 0.1 }
24    }
25}
26
27impl SimClrConfig {
28    /// Create a validated SimCLR config.
29    ///
30    /// # Errors
31    /// [`SslError::InvalidTemperature`] if `temperature <= 0` or non-finite.
32    pub fn new(temperature: f32) -> SslResult<Self> {
33        if !(temperature.is_finite() && temperature > 0.0) {
34            return Err(SslError::InvalidTemperature { temp: temperature });
35        }
36        Ok(Self { temperature })
37    }
38}
39
40/// Compute the SimCLR symmetric NT-Xent loss.
41///
42/// `z_a` and `z_b` are `[N, D]` row-major projection matrices for two
43/// augmented views of the same `N` items (positive pairs are diagonal).
44/// Returns `(loss, accuracy@1)`.
45///
46/// # Errors
47/// Propagates errors from [`info_nce_loss`].
48pub fn simclr_loss(
49    z_a: &[f32],
50    z_b: &[f32],
51    n: usize,
52    d: usize,
53    config: &SimClrConfig,
54) -> SslResult<(f32, f32)> {
55    info_nce_loss(z_a, z_b, n, d, config.temperature)
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must yield the documented temperature of 0.1.
    #[test]
    fn simclr_default_temperature() {
        let cfg = SimClrConfig::default();
        assert!((cfg.temperature - 0.1).abs() < 1e-7);
    }

    /// `new` rejects every non-positive or non-finite temperature and keeps
    /// a valid one unchanged.
    #[test]
    fn simclr_new_validates_temperature() {
        assert!(SimClrConfig::new(0.0).is_err());
        assert!(SimClrConfig::new(-1.0).is_err());
        assert!(SimClrConfig::new(f32::NAN).is_err());
        // Infinite values are non-finite too and must be rejected, not just NaN.
        assert!(SimClrConfig::new(f32::INFINITY).is_err());
        assert!(SimClrConfig::new(f32::NEG_INFINITY).is_err());

        assert!(SimClrConfig::new(0.5).is_ok());
        let cfg = SimClrConfig::new(0.5).expect("0.5 is a valid temperature");
        assert!((cfg.temperature - 0.5).abs() < 1e-7);
    }

    /// Identical one-hot views: positives are perfectly aligned and negatives
    /// are orthogonal, so the loss is small and top-1 accuracy is 1.
    #[test]
    fn simclr_loss_distinct_paired_views_low() {
        // Distinct rows so off-diagonal cosine similarity is small; identical
        // pair so diagonal cosine is 1.
        let n = 4;
        let d = 8;
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            // Row `i` is the `i`-th standard basis vector.
            z[i * d + i] = 1.0;
        }
        let cfg = SimClrConfig::default();
        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).expect("simclr_loss should succeed");
        assert!(loss < 0.5);
        assert!((acc - 1.0).abs() < 1e-6);
    }

    /// Unrelated (deterministic pseudo-random) views must still give a finite loss.
    #[test]
    fn simclr_loss_random_views_finite() {
        let n = 16;
        let d = 32;
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.0123).sin()).collect();
        let z_b: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.0451).cos()).collect();
        let cfg = SimClrConfig::default();
        let (loss, _) = simclr_loss(&z_a, &z_b, n, d, &cfg).expect("simclr_loss should succeed");
        assert!(loss.is_finite());
    }

    /// With well-aligned positives, a lower temperature must not increase the
    /// loss (up to a small tolerance).
    #[test]
    fn simclr_temperature_affects_loss() {
        let n = 4;
        let d = 8;
        let z: Vec<f32> = (0..n * d).map(|i| (i as f32) * 0.05).collect();
        let z2: Vec<f32> = z.iter().map(|v| v + 0.1).collect();
        let high_t = SimClrConfig { temperature: 1.0 };
        let low_t = SimClrConfig { temperature: 0.05 };
        let (l_high, _) = simclr_loss(&z, &z2, n, d, &high_t).expect("simclr_loss should succeed");
        let (l_low, _) = simclr_loss(&z, &z2, n, d, &low_t).expect("simclr_loss should succeed");
        // Lower temperature sharpens softmax; well-aligned positives → lower loss.
        assert!(l_low <= l_high + 1e-3, "l_low={l_low}, l_high={l_high}");
    }
}