oxicuda_ssl/head/
predictor.rs1use crate::error::{SslError, SslResult};
10use crate::handle::LcgRng;
11
/// Parameters of a two-layer perceptron head: linear, ReLU, linear.
///
/// Both weight matrices are stored flat in row-major order, one row per output
/// unit of the layer, so row `j` of `w1` occupies `w1[j * in_dim..(j + 1) * in_dim]`.
#[derive(Clone, Debug)]
pub struct PredictorHead {
    /// Length of the input vector accepted by `forward`.
    pub in_dim: usize,
    /// Number of units in the hidden (ReLU) layer.
    pub hidden_dim: usize,
    /// Length of the output vector produced by `forward`.
    pub out_dim: usize,
    /// First-layer weights, `hidden_dim * in_dim` values (row-major, `hidden_dim` rows).
    pub w1: Vec<f32>,
    /// First-layer biases, `hidden_dim` values.
    pub b1: Vec<f32>,
    /// Second-layer weights, `out_dim * hidden_dim` values (row-major, `out_dim` rows).
    pub w2: Vec<f32>,
    /// Second-layer biases, `out_dim` values.
    pub b2: Vec<f32>,
}
30
31impl PredictorHead {
32 pub fn new(
37 in_dim: usize,
38 hidden_dim: usize,
39 out_dim: usize,
40 rng: &mut LcgRng,
41 ) -> SslResult<Self> {
42 if in_dim == 0 || hidden_dim == 0 || out_dim == 0 {
43 return Err(SslError::InvalidProjectorDim);
44 }
45 let scale1 = (2.0_f32 / in_dim as f32).sqrt();
46 let mut w1 = vec![0.0_f32; hidden_dim * in_dim];
47 rng.fill_normal(&mut w1);
48 for v in w1.iter_mut() {
49 *v *= scale1;
50 }
51 let scale2 = (2.0_f32 / hidden_dim as f32).sqrt();
52 let mut w2 = vec![0.0_f32; out_dim * hidden_dim];
53 rng.fill_normal(&mut w2);
54 for v in w2.iter_mut() {
55 *v *= scale2;
56 }
57 Ok(Self {
58 in_dim,
59 hidden_dim,
60 out_dim,
61 w1,
62 b1: vec![0.0_f32; hidden_dim],
63 w2,
64 b2: vec![0.0_f32; out_dim],
65 })
66 }
67
68 pub fn forward(&self, x: &[f32]) -> SslResult<Vec<f32>> {
73 if x.len() != self.in_dim {
74 return Err(SslError::DimensionMismatch {
75 expected: self.in_dim,
76 got: x.len(),
77 });
78 }
79 let mut h = vec![0.0_f32; self.hidden_dim];
80 for ((hj, b), row) in h
81 .iter_mut()
82 .zip(self.b1.iter())
83 .zip(self.w1.chunks(self.in_dim))
84 {
85 let mut acc = *b;
86 for (w, &xi) in row.iter().zip(x.iter()) {
87 acc += w * xi;
88 }
89 *hj = acc.max(0.0);
90 }
91 let mut out = vec![0.0_f32; self.out_dim];
92 for ((oj, b), row) in out
93 .iter_mut()
94 .zip(self.b2.iter())
95 .zip(self.w2.chunks(self.hidden_dim))
96 {
97 let mut acc = *b;
98 for (w, &hi) in row.iter().zip(h.iter()) {
99 acc += w * hi;
100 }
101 *oj = acc;
102 }
103 Ok(out)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 4 -> 8 -> 4 head used by most tests, seeded with 0.
    fn sample_head() -> PredictorHead {
        let mut rng = LcgRng::new(0);
        PredictorHead::new(4, 8, 4, &mut rng).expect("new should succeed")
    }

    /// Parameter buffers have the sizes implied by the requested dimensions.
    #[test]
    fn predictor_shapes_correct() {
        let PredictorHead { w1, b1, w2, b2, .. } = sample_head();
        assert_eq!(w1.len(), 8 * 4);
        assert_eq!(b1.len(), 8);
        assert_eq!(w2.len(), 4 * 8);
        assert_eq!(b2.len(), 4);
    }

    /// A correctly sized input yields an output of length `out_dim`.
    #[test]
    fn predictor_forward_correct_dim() {
        let head = sample_head();
        let input = vec![0.5_f32; 4];
        let output = head.forward(&input).expect("forward should succeed");
        assert_eq!(output.len(), 4);
    }

    /// A zero in any of the three dimensions is rejected by the constructor.
    #[test]
    fn predictor_rejects_zero_dim() {
        let mut rng = LcgRng::new(0);
        for (in_dim, hidden_dim, out_dim) in [(0, 4, 4), (4, 0, 4), (4, 4, 0)] {
            assert!(PredictorHead::new(in_dim, hidden_dim, out_dim, &mut rng).is_err());
        }
    }

    /// An input whose length differs from `in_dim` is rejected by `forward`.
    #[test]
    fn predictor_rejects_dim_mismatch() {
        let head = sample_head();
        let too_long = [0.0_f32; 5];
        assert!(head.forward(&too_long).is_err());
    }
}