oxicuda_ssl/contrastive/
simclr.rs1use crate::contrastive::info_nce::info_nce_loss;
12use crate::error::{SslError, SslResult};
13
/// Configuration for the SimCLR contrastive objective.
///
/// Build it with [`SimClrConfig::new`] to have the temperature validated, or use
/// [`Default`] for a temperature of `0.1`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimClrConfig {
    /// Softmax temperature passed straight through to `info_nce_loss` by `simclr_loss`.
    ///
    /// `SimClrConfig::new` only accepts finite, strictly positive values. Constructing the
    /// struct directly (the field is public) bypasses that check.
    pub temperature: f32,
}
20
21impl Default for SimClrConfig {
22 fn default() -> Self {
23 Self { temperature: 0.1 }
24 }
25}
26
27impl SimClrConfig {
28 pub fn new(temperature: f32) -> SslResult<Self> {
33 if !(temperature.is_finite() && temperature > 0.0) {
34 return Err(SslError::InvalidTemperature { temp: temperature });
35 }
36 Ok(Self { temperature })
37 }
38}
39
40pub fn simclr_loss(
49 z_a: &[f32],
50 z_b: &[f32],
51 n: usize,
52 d: usize,
53 config: &SimClrConfig,
54) -> SslResult<(f32, f32)> {
55 info_nce_loss(z_a, z_b, n, d, config.temperature)
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a temperature of 0.1.
    #[test]
    fn simclr_default_temperature() {
        let cfg = SimClrConfig::default();
        assert!((cfg.temperature - 0.1).abs() < 1e-7);
    }

    /// `new` must reject every non-finite or non-positive temperature and keep valid ones.
    #[test]
    fn simclr_new_validates_temperature() {
        assert!(SimClrConfig::new(0.0).is_err());
        assert!(SimClrConfig::new(-0.0).is_err());
        assert!(SimClrConfig::new(-1.0).is_err());
        assert!(SimClrConfig::new(f32::NAN).is_err());
        // The `is_finite()` half of the guard: infinities must be rejected too.
        assert!(SimClrConfig::new(f32::INFINITY).is_err());
        assert!(SimClrConfig::new(f32::NEG_INFINITY).is_err());

        let cfg = SimClrConfig::new(0.5).expect("0.5 is a valid temperature");
        assert!((cfg.temperature - 0.5).abs() < 1e-7);
    }

    /// Identical, mutually orthogonal (one-hot) views should give a low loss and
    /// perfect accuracy.
    #[test]
    fn simclr_loss_distinct_paired_views_low() {
        let n = 4;
        let d = 8;
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            z[i * d + i] = 1.0;
        }
        let cfg = SimClrConfig::default();
        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).expect("simclr_loss should succeed");
        assert!(loss < 0.5, "loss={loss}");
        assert!((acc - 1.0).abs() < 1e-6, "acc={acc}");
    }

    /// Arbitrary (deterministic, pseudo-random) views must still produce a finite loss.
    #[test]
    fn simclr_loss_random_views_finite() {
        let n = 16;
        let d = 32;
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.0123).sin()).collect();
        let z_b: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.0451).cos()).collect();
        let cfg = SimClrConfig::default();
        let (loss, _) = simclr_loss(&z_a, &z_b, n, d, &cfg).expect("simclr_loss should succeed");
        assert!(loss.is_finite(), "loss={loss}");
    }

    /// For these closely aligned views, a lower temperature should not increase the loss.
    #[test]
    fn simclr_temperature_affects_loss() {
        let n = 4;
        let d = 8;
        let z: Vec<f32> = (0..n * d).map(|i| (i as f32) * 0.05).collect();
        let z2: Vec<f32> = z.iter().map(|v| v + 0.1).collect();
        let high_t = SimClrConfig { temperature: 1.0 };
        let low_t = SimClrConfig { temperature: 0.05 };
        let (l_high, _) = simclr_loss(&z, &z2, n, d, &high_t).expect("simclr_loss should succeed");
        let (l_low, _) = simclr_loss(&z, &z2, n, d, &low_t).expect("simclr_loss should succeed");
        assert!(l_low <= l_high + 1e-3, "l_low={l_low}, l_high={l_high}");
    }
}