Skip to main content

oxicuda_ssl/head/
projector.rs

1//! 2-layer MLP projection head used by SimCLR/BYOL/MoCo v2/v3.
2//!
3//! Architecture: `Linear(in → hidden) → ReLU → Linear(hidden → out)`.
//! Both Linear layers are initialised Kaiming-style: weights are drawn with
//! `LcgRng::fill_normal` and scaled by `sqrt(2 / fan_in)`; biases start at zero.
5
6use crate::error::{SslError, SslResult};
7use crate::handle::LcgRng;
8
/// 2-layer MLP projection head: `Linear(in → hidden) → ReLU → Linear(hidden → out)`.
///
/// Both weight matrices are stored flat in row-major order, one row per output
/// unit of the layer. All fields are public, so code that edits them directly is
/// responsible for keeping the buffer lengths consistent with the three dims.
///
/// Equality (`PartialEq`) compares dims and parameters element-wise as `f32`, so a
/// projector containing a NaN parameter never compares equal, not even to its clone.
#[derive(Debug, Clone, PartialEq)]
pub struct MlpProjector {
    /// Input dim: length of the feature vector accepted by the first layer.
    pub in_dim: usize,
    /// Hidden dim: width of the ReLU layer between the two Linear layers.
    pub hidden_dim: usize,
    /// Output (projection) dim.
    pub out_dim: usize,
    /// First layer weights `[hidden × in]` (row-major).
    pub w1: Vec<f32>,
    /// First layer bias `[hidden]`.
    pub b1: Vec<f32>,
    /// Second layer weights `[out × hidden]` (row-major).
    pub w2: Vec<f32>,
    /// Second layer bias `[out]`.
    pub b2: Vec<f32>,
}
27
28impl MlpProjector {
29    /// New projector with Kaiming-init weights.
30    ///
31    /// # Errors
32    /// [`SslError::InvalidProjectorDim`] if any dim is zero.
33    pub fn new(
34        in_dim: usize,
35        hidden_dim: usize,
36        out_dim: usize,
37        rng: &mut LcgRng,
38    ) -> SslResult<Self> {
39        if in_dim == 0 || hidden_dim == 0 || out_dim == 0 {
40            return Err(SslError::InvalidProjectorDim);
41        }
42        let scale1 = (2.0_f32 / in_dim as f32).sqrt();
43        let mut w1 = vec![0.0_f32; hidden_dim * in_dim];
44        rng.fill_normal(&mut w1);
45        for v in w1.iter_mut() {
46            *v *= scale1;
47        }
48        let scale2 = (2.0_f32 / hidden_dim as f32).sqrt();
49        let mut w2 = vec![0.0_f32; out_dim * hidden_dim];
50        rng.fill_normal(&mut w2);
51        for v in w2.iter_mut() {
52            *v *= scale2;
53        }
54        Ok(Self {
55            in_dim,
56            hidden_dim,
57            out_dim,
58            w1,
59            b1: vec![0.0_f32; hidden_dim],
60            w2,
61            b2: vec![0.0_f32; out_dim],
62        })
63    }
64
65    /// Forward pass on a single feature vector `[in_dim]` → `[out_dim]`.
66    ///
67    /// # Errors
68    /// [`SslError::DimensionMismatch`] if `x.len() != self.in_dim`.
69    pub fn forward(&self, x: &[f32]) -> SslResult<Vec<f32>> {
70        if x.len() != self.in_dim {
71            return Err(SslError::DimensionMismatch {
72                expected: self.in_dim,
73                got: x.len(),
74            });
75        }
76        let mut h = vec![0.0_f32; self.hidden_dim];
77        for ((hj, b), row) in h
78            .iter_mut()
79            .zip(self.b1.iter())
80            .zip(self.w1.chunks(self.in_dim))
81        {
82            let mut acc = *b;
83            for (w, &xi) in row.iter().zip(x.iter()) {
84                acc += w * xi;
85            }
86            *hj = acc.max(0.0);
87        }
88        let mut out = vec![0.0_f32; self.out_dim];
89        for ((oj, b), row) in out
90            .iter_mut()
91            .zip(self.b2.iter())
92            .zip(self.w2.chunks(self.hidden_dim))
93        {
94            let mut acc = *b;
95            for (w, &hi) in row.iter().zip(h.iter()) {
96                acc += w * hi;
97            }
98            *oj = acc;
99        }
100        Ok(out)
101    }
102
103    /// Forward pass on a batch `[N × in_dim]` → `[N × out_dim]`.
104    ///
105    /// # Errors
106    /// [`SslError::DimensionMismatch`] if `x.len() != n*in_dim`.
107    pub fn forward_batch(&self, x: &[f32], n: usize) -> SslResult<Vec<f32>> {
108        if x.len() != n * self.in_dim {
109            return Err(SslError::DimensionMismatch {
110                expected: n * self.in_dim,
111                got: x.len(),
112            });
113        }
114        let mut out = Vec::with_capacity(n * self.out_dim);
115        for i in 0..n {
116            let row = &x[i * self.in_dim..(i + 1) * self.in_dim];
117            out.extend_from_slice(&self.forward(row)?);
118        }
119        Ok(out)
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 8 → 16 → 4 projector shared by these tests from `LcgRng::new(0)`.
    fn small_projector() -> MlpProjector {
        let mut rng = LcgRng::new(0);
        MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed")
    }

    /// Dims and parameter buffer lengths match the constructor arguments.
    #[test]
    fn projector_construction_correct_shapes() {
        let proj = small_projector();
        assert_eq!((proj.in_dim, proj.hidden_dim, proj.out_dim), (8, 16, 4));
        assert_eq!(proj.w1.len(), 16 * 8);
        assert_eq!(proj.b1.len(), 16);
        assert_eq!(proj.w2.len(), 4 * 16);
        assert_eq!(proj.b2.len(), 4);
    }

    /// A zero in any of the three dims is rejected.
    #[test]
    fn projector_rejects_zero_dim() {
        let mut rng = LcgRng::new(0);
        let zero_in = MlpProjector::new(0, 4, 4, &mut rng);
        assert!(zero_in.is_err());
        let zero_hidden = MlpProjector::new(4, 0, 4, &mut rng);
        assert!(zero_hidden.is_err());
        let zero_out = MlpProjector::new(4, 4, 0, &mut rng);
        assert!(zero_out.is_err());
    }

    /// A single forward pass yields `out_dim` values.
    #[test]
    fn projector_forward_correct_shape() {
        let proj = small_projector();
        let input = [0.0_f32; 8];
        let projected = proj.forward(&input).expect("forward should succeed");
        assert_eq!(projected.len(), 4);
    }

    /// With the zero biases set by `new`, an all-zero input projects to all zeros.
    #[test]
    fn projector_zero_input_returns_zero_when_no_bias() {
        let proj = small_projector();
        let projected = proj.forward(&[0.0_f32; 8]).expect("forward should succeed");
        assert!(projected.iter().all(|v| v.abs() < 1e-6));
    }

    /// An input shorter than `in_dim` is an error.
    #[test]
    fn projector_forward_rejects_dim_mismatch() {
        let proj = small_projector();
        assert!(proj.forward(&[0.0_f32; 4]).is_err());
    }

    /// A batch of 4 samples yields `4 * out_dim` values.
    #[test]
    fn projector_forward_batch_correct_shape() {
        let proj = small_projector();
        let batch = vec![0.1_f32; 4 * 8];
        let projected = proj
            .forward_batch(&batch, 4)
            .expect("forward_batch should succeed");
        assert_eq!(projected.len(), 4 * 4);
    }

    /// A buffer whose length is not `n * in_dim` is an error.
    #[test]
    fn projector_forward_batch_rejects_dim_mismatch() {
        let proj = small_projector();
        assert!(proj.forward_batch(&[0.0_f32; 16], 4).is_err());
    }
}