Skip to main content

oxicuda_ssl/head/
predictor.rs

//! BYOL / SimSiam predictor head — an additional 2-layer MLP placed only on
//! the *online* branch, which makes the two branches (and hence the loss)
//! asymmetric.
//!
//! Architecturally identical to [`crate::head::projector::MlpProjector`] but
//! semantically separate: the predictor parameters are updated *only* through
//! the online stream's gradients (the target stream uses an EMA copy of the
//! projector — never a copy of the predictor).
8
9use crate::error::{SslError, SslResult};
10use crate::handle::LcgRng;
11
/// Two-layer MLP predictor head (BYOL / SimSiam) that sits on the online branch.
///
/// `forward` evaluates `W2 · relu(W1 · x + b1) + b2`. Both weight matrices are
/// stored flattened in row-major order with one row per output unit of the
/// layer, so `w1` holds `hidden_dim` rows of `in_dim` values and `w2` holds
/// `out_dim` rows of `hidden_dim` values.
#[derive(Debug, Clone)]
pub struct PredictorHead {
    /// Length of the feature vector accepted by `forward`.
    pub in_dim: usize,
    /// Width of the hidden (ReLU) layer.
    pub hidden_dim: usize,
    /// Length of the produced vector (same as the projector output for both BYOL and SimSiam).
    pub out_dim: usize,
    /// First-layer weights, `[hidden × in]`, row-major.
    pub w1: Vec<f32>,
    /// First-layer bias, one entry per hidden unit.
    pub b1: Vec<f32>,
    /// Second-layer weights, `[out × hidden]`, row-major.
    pub w2: Vec<f32>,
    /// Second-layer bias, one entry per output unit.
    pub b2: Vec<f32>,
}
30
31impl PredictorHead {
32    /// New predictor with Kaiming-init weights.
33    ///
34    /// # Errors
35    /// [`SslError::InvalidProjectorDim`] if any dim is zero.
36    pub fn new(
37        in_dim: usize,
38        hidden_dim: usize,
39        out_dim: usize,
40        rng: &mut LcgRng,
41    ) -> SslResult<Self> {
42        if in_dim == 0 || hidden_dim == 0 || out_dim == 0 {
43            return Err(SslError::InvalidProjectorDim);
44        }
45        let scale1 = (2.0_f32 / in_dim as f32).sqrt();
46        let mut w1 = vec![0.0_f32; hidden_dim * in_dim];
47        rng.fill_normal(&mut w1);
48        for v in w1.iter_mut() {
49            *v *= scale1;
50        }
51        let scale2 = (2.0_f32 / hidden_dim as f32).sqrt();
52        let mut w2 = vec![0.0_f32; out_dim * hidden_dim];
53        rng.fill_normal(&mut w2);
54        for v in w2.iter_mut() {
55            *v *= scale2;
56        }
57        Ok(Self {
58            in_dim,
59            hidden_dim,
60            out_dim,
61            w1,
62            b1: vec![0.0_f32; hidden_dim],
63            w2,
64            b2: vec![0.0_f32; out_dim],
65        })
66    }
67
68    /// Forward pass on a single feature vector.
69    ///
70    /// # Errors
71    /// [`SslError::DimensionMismatch`] if `x.len() != self.in_dim`.
72    pub fn forward(&self, x: &[f32]) -> SslResult<Vec<f32>> {
73        if x.len() != self.in_dim {
74            return Err(SslError::DimensionMismatch {
75                expected: self.in_dim,
76                got: x.len(),
77            });
78        }
79        let mut h = vec![0.0_f32; self.hidden_dim];
80        for ((hj, b), row) in h
81            .iter_mut()
82            .zip(self.b1.iter())
83            .zip(self.w1.chunks(self.in_dim))
84        {
85            let mut acc = *b;
86            for (w, &xi) in row.iter().zip(x.iter()) {
87                acc += w * xi;
88            }
89            *hj = acc.max(0.0);
90        }
91        let mut out = vec![0.0_f32; self.out_dim];
92        for ((oj, b), row) in out
93            .iter_mut()
94            .zip(self.b2.iter())
95            .zip(self.w2.chunks(self.hidden_dim))
96        {
97            let mut acc = *b;
98            for (w, &hi) in row.iter().zip(h.iter()) {
99                acc += w * hi;
100            }
101            *oj = acc;
102        }
103        Ok(out)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predictor from a fixed seed so every test sees identical weights.
    fn seeded(in_dim: usize, hidden_dim: usize, out_dim: usize) -> SslResult<PredictorHead> {
        PredictorHead::new(in_dim, hidden_dim, out_dim, &mut LcgRng::new(0))
    }

    #[test]
    fn predictor_shapes_correct() {
        let head = seeded(4, 8, 4).expect("new should succeed");
        let lengths = [head.w1.len(), head.b1.len(), head.w2.len(), head.b2.len()];
        assert_eq!(lengths, [8 * 4, 8, 4 * 8, 4]);
    }

    #[test]
    fn predictor_forward_correct_dim() {
        let head = seeded(4, 8, 4).expect("new should succeed");
        let output = head
            .forward(&[0.5_f32; 4])
            .expect("forward should succeed");
        assert_eq!(output.len(), 4);
    }

    #[test]
    fn predictor_rejects_zero_dim() {
        // Each triple zeroes a different dimension.
        for (in_dim, hidden_dim, out_dim) in [(0, 4, 4), (4, 0, 4), (4, 4, 0)] {
            assert!(seeded(in_dim, hidden_dim, out_dim).is_err());
        }
    }

    #[test]
    fn predictor_rejects_dim_mismatch() {
        let head = seeded(4, 8, 4).expect("new should succeed");
        let too_long = [0.0_f32; 5];
        assert!(head.forward(&too_long).is_err());
    }
}