Skip to main content

oxicuda_ssl/ssl/
sim_siam.rs

1//! SimSiam struct — owns projector + predictor weights.
2//!
3//! This is a struct-based incarnation of SimSiam (Chen & He 2021) that manages
4//! its own weight matrices in-process, unlike the functional API in
5//! [`crate::non_contrastive::simsiam`] which requires the caller to provide
6//! pre-computed projections.
7//!
8//! ## Architecture
9//! ```text
10//! Projector:  d_encoder → (Linear + ReLU) → d_projector → Linear → d_out → L2-norm
11//! Predictor:  d_out     → (Linear + ReLU) → d_predictor → Linear → d_out → L2-norm
12//! ```
13//!
14//! ## Loss
15//! ```text
16//! p1 = predict(project(z1)),   p2 = predict(project(z2))
17//! z1_p = project(z1),          z2_p = project(z2)
18//! L = (D(p1, sg(z2_p)) + D(p2, sg(z1_p))) / 2
19//! D(a, b) = -(a · b)     [both are already L2-normalised]
20//! ```
21
22use crate::error::{SslError, SslResult};
23use crate::handle::LcgRng;
24
25// ─── Configuration ────────────────────────────────────────────────────────────
26
/// Layer sizes for the struct-based [`SimSiam`] model.
///
/// Every field must be non-zero; [`SimSiam::new`] rejects a configuration
/// containing a zero dimension.
#[derive(Debug, Clone)]
pub struct SimSiamConfig {
    /// Width of the backbone output that is fed into the projector.
    pub d_encoder: usize,
    /// Width of the projector's hidden (ReLU) layer.
    pub d_projector: usize,
    /// Width of the predictor's hidden (ReLU) layer.
    pub d_predictor: usize,
    /// Width of the projector output, which is also the predictor's input and output width.
    pub d_out: usize,
}

impl Default for SimSiamConfig {
    /// Returns `d_encoder = 64`, `d_projector = 128`, `d_predictor = 64`, `d_out = 32`.
    ///
    /// With these sizes `d_predictor == 2 * d_out`, which is the shape required by
    /// [`SimSiam::set_identity_predictor`].
    fn default() -> Self {
        let d_encoder = 64;
        let d_projector = 128;
        let d_predictor = 64;
        let d_out = 32;
        Self {
            d_encoder,
            d_projector,
            d_predictor,
            d_out,
        }
    }
}
50
51// ─── SimSiam model ───────────────────────────────────────────────────────────
52
53/// Struct-based SimSiam model that owns its projector and predictor weights.
54///
55/// All weight matrices use Kaiming (He) initialisation with `scale = sqrt(2 / fan_in)`.
56#[derive(Debug, Clone)]
57pub struct SimSiam {
58    /// Projector first layer weights `[d_projector × d_encoder]`.
59    proj_w1: Vec<f32>,
60    /// Projector first layer bias `[d_projector]`.
61    proj_b1: Vec<f32>,
62    /// Projector second layer weights `[d_out × d_projector]`.
63    proj_w2: Vec<f32>,
64    /// Projector second layer bias `[d_out]`.
65    proj_b2: Vec<f32>,
66    /// Predictor first layer weights `[d_predictor × d_out]`.
67    pred_w1: Vec<f32>,
68    /// Predictor first layer bias `[d_predictor]`.
69    pred_b1: Vec<f32>,
70    /// Predictor second layer weights `[d_out × d_predictor]`.
71    pred_w2: Vec<f32>,
72    /// Predictor second layer bias `[d_out]`.
73    pred_b2: Vec<f32>,
74    /// Configuration used to construct this model.
75    config: SimSiamConfig,
76}
77
78impl SimSiam {
79    /// Create a new `SimSiam` model with Kaiming-initialised weights.
80    ///
81    /// # Errors
82    /// [`SslError::InvalidParameter`] when any dimension in `config` is zero.
83    pub fn new(config: SimSiamConfig, rng: &mut LcgRng) -> SslResult<Self> {
84        if config.d_encoder == 0 {
85            return Err(SslError::InvalidParameter {
86                name: "d_encoder".into(),
87                reason: "must be > 0".into(),
88            });
89        }
90        if config.d_projector == 0 {
91            return Err(SslError::InvalidParameter {
92                name: "d_projector".into(),
93                reason: "must be > 0".into(),
94            });
95        }
96        if config.d_predictor == 0 {
97            return Err(SslError::InvalidParameter {
98                name: "d_predictor".into(),
99                reason: "must be > 0".into(),
100            });
101        }
102        if config.d_out == 0 {
103            return Err(SslError::InvalidParameter {
104                name: "d_out".into(),
105                reason: "must be > 0".into(),
106            });
107        }
108
109        let proj_w1 = kaiming_init(config.d_projector, config.d_encoder, rng);
110        let proj_b1 = vec![0.0_f32; config.d_projector];
111        let proj_w2 = kaiming_init(config.d_out, config.d_projector, rng);
112        let proj_b2 = vec![0.0_f32; config.d_out];
113
114        let pred_w1 = kaiming_init(config.d_predictor, config.d_out, rng);
115        let pred_b1 = vec![0.0_f32; config.d_predictor];
116        let pred_w2 = kaiming_init(config.d_out, config.d_predictor, rng);
117        let pred_b2 = vec![0.0_f32; config.d_out];
118
119        Ok(Self {
120            proj_w1,
121            proj_b1,
122            proj_w2,
123            proj_b2,
124            pred_w1,
125            pred_b1,
126            pred_w2,
127            pred_b2,
128            config,
129        })
130    }
131
132    /// Project a single encoder output vector.
133    ///
134    /// Computes `z = L2_norm(proj_w2 · ReLU(proj_w1 · x + proj_b1) + proj_b2)`.
135    ///
136    /// # Arguments
137    /// * `z` — encoder output `[d_encoder]`.
138    ///
139    /// # Errors
140    /// [`SslError::DimensionMismatch`] when `z.len() != d_encoder`.
141    pub fn project(&self, z: &[f32]) -> SslResult<Vec<f32>> {
142        let d = self.config.d_encoder;
143        if z.len() != d {
144            return Err(SslError::DimensionMismatch {
145                expected: d,
146                got: z.len(),
147            });
148        }
149        let hidden = linear_relu(&self.proj_w1, &self.proj_b1, z, d, self.config.d_projector);
150        let out = linear(
151            &self.proj_w2,
152            &self.proj_b2,
153            &hidden,
154            self.config.d_projector,
155            self.config.d_out,
156        );
157        Ok(l2_normalize(out))
158    }
159
160    /// Apply the predictor to a projected representation.
161    ///
162    /// Computes `p = L2_norm(pred_w2 · ReLU(pred_w1 · proj + pred_b1) + pred_b2)`.
163    ///
164    /// # Arguments
165    /// * `p` — projected representation `[d_out]`.
166    ///
167    /// # Errors
168    /// [`SslError::DimensionMismatch`] when `p.len() != d_out`.
169    pub fn predict(&self, p: &[f32]) -> SslResult<Vec<f32>> {
170        let d = self.config.d_out;
171        if p.len() != d {
172            return Err(SslError::DimensionMismatch {
173                expected: d,
174                got: p.len(),
175            });
176        }
177        let hidden = linear_relu(&self.pred_w1, &self.pred_b1, p, d, self.config.d_predictor);
178        let out = linear(
179            &self.pred_w2,
180            &self.pred_b2,
181            &hidden,
182            self.config.d_predictor,
183            self.config.d_out,
184        );
185        Ok(l2_normalize(out))
186    }
187
188    /// Compute the symmetric SimSiam loss for two encoder outputs.
189    ///
190    /// Implements `L = (D(p1, sg(z2_p)) + D(p2, sg(z1_p))) / 2` where
191    /// `D(a, b) = -(a · b)` for unit-norm vectors and `sg` denotes stop-gradient
192    /// (a no-op in this pure-Rust implementation since there is no autograd engine).
193    ///
194    /// # Arguments
195    /// * `z1` — encoder output from view 1 `[d_encoder]`.
196    /// * `z2` — encoder output from view 2 `[d_encoder]`.
197    ///
198    /// # Errors
199    /// Propagates dimension mismatch errors from [`Self::project`] and [`Self::predict`].
200    pub fn loss(&self, z1: &[f32], z2: &[f32]) -> SslResult<f32> {
201        let z1_proj = self.project(z1)?;
202        let z2_proj = self.project(z2)?;
203        let p1 = self.predict(&z1_proj)?;
204        let p2 = self.predict(&z2_proj)?;
205
206        // D(p, z_sg) = -(p · z_sg); both are L2-normalised → cosine distance
207        let d1 = neg_dot(&p1, &z2_proj);
208        let d2 = neg_dot(&p2, &z1_proj);
209        Ok((d1 + d2) * 0.5)
210    }
211
212    /// Return the output dimension of the projector (= predictor I/O dim).
213    #[inline]
214    #[must_use]
215    pub fn d_out(&self) -> usize {
216        self.config.d_out
217    }
218
219    /// Overwrite the predictor with an exact direction-preserving identity map.
220    ///
221    /// The predictor MLP `L2(W2 · ReLU(W1·p + b1) + b2)` is normally a learned,
222    /// randomly-initialised non-linear transform, so for a *random* predictor the
223    /// SimSiam loss of two identical views is some arbitrary value in `[-1, 1]`.
224    /// SimSiam's negative-cosine loss only attains its minimum of `-1` for
225    /// identical views when the predictor preserves the projection's direction.
226    ///
227    /// This installs such an identity predictor. The ReLU non-linearity is bridged
228    /// with the standard positive/negative split: the hidden layer computes
229    /// `[ReLU(p), ReLU(-p)]` and the output layer reconstructs `p = p⁺ - p⁻`,
230    /// reproducing the input exactly. The trailing L2-norm then leaves an
231    /// already unit-norm projection unchanged, so `predict(project(z)) == project(z)`
232    /// and `loss(z, z) == -1`.
233    ///
234    /// This requires the predictor hidden dimension to be exactly twice the output
235    /// dimension so the two halves can hold the positive and negative parts.
236    ///
237    /// # Errors
238    /// [`SslError::InvalidParameter`] when `d_predictor != 2 * d_out`.
239    pub fn set_identity_predictor(&mut self) -> SslResult<()> {
240        let d_out = self.config.d_out;
241        let d_pred = self.config.d_predictor;
242        if d_pred != 2 * d_out {
243            return Err(SslError::InvalidParameter {
244                name: "d_predictor".into(),
245                reason: "identity predictor requires d_predictor == 2 * d_out".into(),
246            });
247        }
248
249        // W1: `[d_predictor × d_out]`. Row i selects +p_i, row (d_out + i) selects -p_i.
250        let mut pred_w1 = vec![0.0_f32; d_pred * d_out];
251        for i in 0..d_out {
252            pred_w1[i * d_out + i] = 1.0;
253            pred_w1[(d_out + i) * d_out + i] = -1.0;
254        }
255        // W2: `[d_out × d_predictor]`. Output i = p⁺_i - p⁻_i.
256        let mut pred_w2 = vec![0.0_f32; d_out * d_pred];
257        for i in 0..d_out {
258            pred_w2[i * d_pred + i] = 1.0;
259            pred_w2[i * d_pred + (d_out + i)] = -1.0;
260        }
261
262        self.pred_w1 = pred_w1;
263        self.pred_b1 = vec![0.0_f32; d_pred];
264        self.pred_w2 = pred_w2;
265        self.pred_b2 = vec![0.0_f32; d_out];
266        Ok(())
267    }
268}
269
270// ─── Internal helpers ────────────────────────────────────────────────────────
271
272/// Allocate a `[out_dim × in_dim]` weight matrix with Kaiming normal init.
273///
274/// `scale = sqrt(2 / in_dim)` (He initialisation for ReLU activations).
275fn kaiming_init(out_dim: usize, in_dim: usize, rng: &mut LcgRng) -> Vec<f32> {
276    let scale = (2.0_f32 / in_dim as f32).sqrt();
277    let mut w = vec![0.0_f32; out_dim * in_dim];
278    rng.fill_normal(&mut w);
279    for v in w.iter_mut() {
280        *v *= scale;
281    }
282    w
283}
284
/// Affine map `y = W·x + b` with no activation.
///
/// `w` is a row-major `[out_dim × in_dim]` matrix, so
/// `y[i] = b[i] + Σ_j w[i·in_dim + j] · x[j]` and the result has length `out_dim`.
/// Each output starts from its bias and accumulates the products in ascending `j`,
/// which pins down the floating-point summation order.
///
/// Panics (through slice indexing) when a row needs elements beyond the end of
/// `w`, `b` or `x`.
fn linear(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
    (0..out_dim)
        .map(|row| {
            let start = row * in_dim;
            let weights = &w[start..start + in_dim];
            let inputs = &x[..in_dim];
            weights
                .iter()
                .zip(inputs)
                .fold(b[row], |acc, (&wij, &xj)| acc + wij * xj)
        })
        .collect()
}
300
301/// `linear` followed by element-wise ReLU.
302fn linear_relu(w: &[f32], b: &[f32], x: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
303    let mut out = linear(w, b, x, in_dim, out_dim);
304    for v in out.iter_mut() {
305        *v = v.max(0.0);
306    }
307    out
308}
309
/// Scale `v` to unit Euclidean length and hand the same buffer back.
///
/// The divisor is `max(‖v‖₂, 1e-12)`, so an all-zero vector is returned
/// unchanged instead of turning into NaNs.
fn l2_normalize(mut v: Vec<f32>) -> Vec<f32> {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sum_of_squares.sqrt().max(1e-12);
    v.iter_mut().for_each(|x| *x /= norm);
    v
}
320
/// Compute the negative dot product `-(a · b)`.
///
/// Within this module, `a` and `b` are always both `d_out` long. `Iterator::zip`
/// would silently stop at the shorter slice and return a truncated dot product, so
/// a length mismatch is caught by a `debug_assert_eq!` in debug builds instead of
/// being masked.
fn neg_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "neg_dot: length mismatch");
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    -dot
}
328
329// ─── Tests ────────────────────────────────────────────────────────────────────
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Small model (d_encoder 16, d_projector 32, d_predictor 16, d_out 8).
    /// `d_predictor == 2 * d_out`, so `set_identity_predictor` succeeds on it.
    fn make_simsiam(seed: u64) -> SimSiam {
        let mut rng = LcgRng::new(seed);
        SimSiam::new(
            SimSiamConfig {
                d_encoder: 16,
                d_projector: 32,
                d_predictor: 16,
                d_out: 8,
            },
            &mut rng,
        )
        .expect("value should be present")
    }

    fn random_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut v = vec![0.0_f32; n];
        rng.fill_normal(&mut v);
        v
    }

    #[test]
    fn project_shape() {
        let ss = make_simsiam(1);
        let z = random_vec(16, 2);
        let out = ss.project(&z).expect("project should succeed");
        assert_eq!(out.len(), 8, "project output must have len == d_out");
    }

    #[test]
    fn predict_shape() {
        let ss = make_simsiam(3);
        let p = random_vec(8, 4);
        let out = ss.predict(&p).expect("predict should succeed");
        assert_eq!(out.len(), 8, "predict output must have len == d_out");
    }

    #[test]
    fn loss_finite() {
        let ss = make_simsiam(5);
        let z1 = random_vec(16, 6);
        let z2 = random_vec(16, 7);
        let l = ss.loss(&z1, &z2).expect("loss should succeed");
        assert!(l.is_finite(), "loss must be finite, got {l}");
    }

    #[test]
    fn loss_in_range() {
        let ss = make_simsiam(8);
        let z1 = random_vec(16, 9);
        let z2 = random_vec(16, 10);
        let l = ss.loss(&z1, &z2).expect("loss should succeed");
        assert!(
            (-1.0 - 1e-5..=1.0 + 1e-5).contains(&l),
            "loss={l} must be in [-1, 1]"
        );
    }

    #[test]
    fn loss_symmetric() {
        let ss = make_simsiam(11);
        let z1 = random_vec(16, 12);
        let z2 = random_vec(16, 13);
        let l12 = ss.loss(&z1, &z2).expect("loss should succeed");
        let l21 = ss.loss(&z2, &z1).expect("loss should succeed");
        assert!(
            (l12 - l21).abs() < 1e-5,
            "loss(z1,z2)={l12} != loss(z2,z1)={l21}"
        );
    }

    #[test]
    fn different_views_different_projections() {
        let ss = make_simsiam(14);
        let z1 = random_vec(16, 15);
        let z2 = random_vec(16, 16);
        let p1 = ss.project(&z1).expect("project should succeed");
        let p2 = ss.project(&z2).expect("project should succeed");
        let diff: f32 = p1.iter().zip(p2.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "projections of different inputs must differ, diff={diff}"
        );
    }

    #[test]
    fn identical_views_low_loss() {
        // For identical views z1 == z2 we have z1_p == z2_p and p1 == p2, so the
        // symmetric loss collapses to -dot(p, z_p). This reaches its minimum of -1
        // only when the predictor preserves the projection's direction. With a
        // *random* predictor MLP, dot(p, z_p) is some arbitrary cosine in [-1, 1],
        // so the loss is not necessarily near -1. Installing a direction-preserving
        // identity predictor makes predict(project(z)) == project(z) exactly, which
        // drives the identical-view loss to its -1 floor.
        let mut ss = make_simsiam(17);
        ss.set_identity_predictor()
            .expect("config has d_predictor == 2 * d_out");
        let z = random_vec(16, 18);
        let l = ss.loss(&z, &z).expect("loss should succeed");
        assert!(
            (l - (-1.0)).abs() < 1e-5,
            "with a direction-preserving predictor, loss for identical views must be -1, got {l}"
        );
    }

    #[test]
    fn identity_predictor_is_direction_preserving() {
        // The identity predictor must reproduce its (unit-norm) input exactly,
        // so predict(project(z)) == project(z) for every z.
        let mut ss = make_simsiam(27);
        ss.set_identity_predictor()
            .expect("config has d_predictor == 2 * d_out");
        for seed in 0..6_u64 {
            let z = random_vec(16, seed + 200);
            let zp = ss.project(&z).expect("project should succeed");
            let p = ss.predict(&zp).expect("predict should succeed");
            let max_diff = zp
                .iter()
                .zip(p.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f32, f32::max);
            assert!(
                max_diff < 1e-5,
                "identity predictor must reproduce input, max|p-zp|={max_diff} (seed={seed})"
            );
        }
    }

    #[test]
    fn set_identity_predictor_requires_double_hidden() {
        // d_predictor (8) != 2 * d_out (16) → must error, never panic.
        let mut rng = LcgRng::new(28);
        let mut ss = SimSiam::new(
            SimSiamConfig {
                d_encoder: 16,
                d_projector: 32,
                d_predictor: 8,
                d_out: 8,
            },
            &mut rng,
        )
        .expect("value should be present");
        assert!(
            ss.set_identity_predictor().is_err(),
            "identity predictor with d_predictor != 2*d_out must return Err"
        );
    }

    #[test]
    fn d_out_0_error() {
        let mut rng = LcgRng::new(19);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 16,
                d_predictor: 8,
                d_out: 0,
            },
            &mut rng,
        );
        assert!(result.is_err(), "d_out=0 must return Err");
    }

    #[test]
    fn project_output_normalized() {
        let ss = make_simsiam(20);
        let z = random_vec(16, 21);
        let out = ss.project(&z).expect("project should succeed");
        let norm: f32 = out.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "project output must be unit-norm, norm={norm}"
        );
    }

    #[test]
    fn loss_stop_grad_invariant() {
        // Verify loss is finite for various arbitrary inputs — stop-grad is a
        // no-op in pure Rust and must not cause numerical issues.
        let ss = make_simsiam(22);
        for seed in 0..8_u64 {
            let z1 = random_vec(16, seed * 2 + 100);
            let z2 = random_vec(16, seed * 2 + 101);
            let l = ss.loss(&z1, &z2).expect("loss should succeed");
            assert!(
                l.is_finite(),
                "loss must be finite for seed={seed}, got {l}"
            );
        }
    }

    #[test]
    fn d_encoder_0_error() {
        let mut rng = LcgRng::new(23);
        assert!(
            SimSiam::new(
                SimSiamConfig {
                    d_encoder: 0,
                    d_projector: 16,
                    d_predictor: 8,
                    d_out: 8
                },
                &mut rng
            )
            .is_err()
        );
    }

    #[test]
    fn d_projector_0_error() {
        let mut rng = LcgRng::new(29);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 0,
                d_predictor: 8,
                d_out: 8,
            },
            &mut rng,
        );
        assert!(result.is_err(), "d_projector=0 must return Err");
    }

    #[test]
    fn d_predictor_0_error() {
        let mut rng = LcgRng::new(30);
        let result = SimSiam::new(
            SimSiamConfig {
                d_encoder: 8,
                d_projector: 16,
                d_predictor: 0,
                d_out: 8,
            },
            &mut rng,
        );
        assert!(result.is_err(), "d_predictor=0 must return Err");
    }

    #[test]
    fn project_rejects_wrong_length() {
        let ss = make_simsiam(31);
        let err = ss
            .project(&[0.0_f32; 15])
            .expect_err("len 15 != d_encoder 16 must be rejected");
        assert!(
            matches!(
                err,
                SslError::DimensionMismatch {
                    expected: 16,
                    got: 15,
                    ..
                }
            ),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn predict_rejects_wrong_length() {
        let ss = make_simsiam(32);
        let err = ss
            .predict(&[0.0_f32; 9])
            .expect_err("len 9 != d_out 8 must be rejected");
        assert!(
            matches!(
                err,
                SslError::DimensionMismatch {
                    expected: 8,
                    got: 9,
                    ..
                }
            ),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn loss_propagates_dimension_mismatch() {
        let ss = make_simsiam(33);
        let good = random_vec(16, 34);
        let bad = random_vec(15, 35);
        assert!(
            ss.loss(&good, &bad).is_err(),
            "loss must reject a wrong-length second view"
        );
        assert!(
            ss.loss(&bad, &good).is_err(),
            "loss must reject a wrong-length first view"
        );
    }

    #[test]
    fn predict_output_normalized() {
        let ss = make_simsiam(24);
        let p = random_vec(8, 25);
        let out = ss.predict(&p).expect("predict should succeed");
        let norm: f32 = out.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "predict output must be unit-norm, norm={norm}"
        );
    }

    #[test]
    fn d_out_accessor() {
        let ss = make_simsiam(26);
        assert_eq!(ss.d_out(), 8);
    }
}