oxicuda_ssl/head/
projector.rs1use crate::error::{SslError, SslResult};
7use crate::handle::LcgRng;
8
/// Two-layer MLP projector: `Linear(in_dim → hidden_dim) → ReLU → Linear(hidden_dim → out_dim)`.
///
/// Both weight matrices are stored flat in row-major order with one row per
/// output unit of that layer: `w1` holds `hidden_dim` rows of `in_dim` weights,
/// and `w2` holds `out_dim` rows of `hidden_dim` weights.
#[derive(Clone, Debug)]
pub struct MlpProjector {
    /// Length of every input vector.
    pub in_dim: usize,
    /// Number of units in the ReLU hidden layer.
    pub hidden_dim: usize,
    /// Length of every projected output vector.
    pub out_dim: usize,
    /// First-layer weights, `hidden_dim × in_dim`, row-major.
    pub w1: Vec<f32>,
    /// First-layer biases, one per hidden unit (`hidden_dim` entries).
    pub b1: Vec<f32>,
    /// Second-layer weights, `out_dim × hidden_dim`, row-major.
    pub w2: Vec<f32>,
    /// Second-layer biases, one per output unit (`out_dim` entries).
    pub b2: Vec<f32>,
}
27
28impl MlpProjector {
29 pub fn new(
34 in_dim: usize,
35 hidden_dim: usize,
36 out_dim: usize,
37 rng: &mut LcgRng,
38 ) -> SslResult<Self> {
39 if in_dim == 0 || hidden_dim == 0 || out_dim == 0 {
40 return Err(SslError::InvalidProjectorDim);
41 }
42 let scale1 = (2.0_f32 / in_dim as f32).sqrt();
43 let mut w1 = vec![0.0_f32; hidden_dim * in_dim];
44 rng.fill_normal(&mut w1);
45 for v in w1.iter_mut() {
46 *v *= scale1;
47 }
48 let scale2 = (2.0_f32 / hidden_dim as f32).sqrt();
49 let mut w2 = vec![0.0_f32; out_dim * hidden_dim];
50 rng.fill_normal(&mut w2);
51 for v in w2.iter_mut() {
52 *v *= scale2;
53 }
54 Ok(Self {
55 in_dim,
56 hidden_dim,
57 out_dim,
58 w1,
59 b1: vec![0.0_f32; hidden_dim],
60 w2,
61 b2: vec![0.0_f32; out_dim],
62 })
63 }
64
65 pub fn forward(&self, x: &[f32]) -> SslResult<Vec<f32>> {
70 if x.len() != self.in_dim {
71 return Err(SslError::DimensionMismatch {
72 expected: self.in_dim,
73 got: x.len(),
74 });
75 }
76 let mut h = vec![0.0_f32; self.hidden_dim];
77 for ((hj, b), row) in h
78 .iter_mut()
79 .zip(self.b1.iter())
80 .zip(self.w1.chunks(self.in_dim))
81 {
82 let mut acc = *b;
83 for (w, &xi) in row.iter().zip(x.iter()) {
84 acc += w * xi;
85 }
86 *hj = acc.max(0.0);
87 }
88 let mut out = vec![0.0_f32; self.out_dim];
89 for ((oj, b), row) in out
90 .iter_mut()
91 .zip(self.b2.iter())
92 .zip(self.w2.chunks(self.hidden_dim))
93 {
94 let mut acc = *b;
95 for (w, &hi) in row.iter().zip(h.iter()) {
96 acc += w * hi;
97 }
98 *oj = acc;
99 }
100 Ok(out)
101 }
102
103 pub fn forward_batch(&self, x: &[f32], n: usize) -> SslResult<Vec<f32>> {
108 if x.len() != n * self.in_dim {
109 return Err(SslError::DimensionMismatch {
110 expected: n * self.in_dim,
111 got: x.len(),
112 });
113 }
114 let mut out = Vec::with_capacity(n * self.out_dim);
115 for i in 0..n {
116 let row = &x[i * self.in_dim..(i + 1) * self.in_dim];
117 out.extend_from_slice(&self.forward(row)?);
118 }
119 Ok(out)
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2 → 2 → 1 projector with hand-picked parameters, so outputs can be
    /// verified by hand.
    fn hand_projector() -> MlpProjector {
        MlpProjector {
            in_dim: 2,
            hidden_dim: 2,
            out_dim: 1,
            w1: vec![1.0, -1.0, -1.0, 1.0],
            b1: vec![0.0, 0.5],
            w2: vec![2.0, 5.0],
            b2: vec![0.25],
        }
    }

    #[test]
    fn projector_construction_correct_shapes() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        assert_eq!(p.in_dim, 8);
        assert_eq!(p.hidden_dim, 16);
        assert_eq!(p.out_dim, 4);
        assert_eq!(p.w1.len(), 16 * 8);
        assert_eq!(p.b1.len(), 16);
        assert_eq!(p.w2.len(), 4 * 16);
        assert_eq!(p.b2.len(), 4);
    }

    #[test]
    fn projector_rejects_zero_dim() {
        let mut rng = LcgRng::new(0);
        assert!(MlpProjector::new(0, 4, 4, &mut rng).is_err());
        assert!(MlpProjector::new(4, 0, 4, &mut rng).is_err());
        assert!(MlpProjector::new(4, 4, 0, &mut rng).is_err());
    }

    #[test]
    fn projector_forward_correct_shape() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let x = vec![0.0_f32; 8];
        let y = p.forward(&x).expect("forward should succeed");
        assert_eq!(y.len(), 4);
    }

    #[test]
    fn projector_zero_input_returns_zero_when_no_bias() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let y = p.forward(&[0.0_f32; 8]).expect("forward should succeed");
        for &v in &y {
            assert!(v.abs() < 1e-6);
        }
    }

    #[test]
    fn projector_forward_matches_hand_computation() {
        let p = hand_projector();
        // h = relu([3 - 1 + 0, -3 + 1 + 0.5]) = [2, 0]; y = 2*2 + 5*0 + 0.25.
        let y = p.forward(&[3.0_f32, 1.0]).expect("forward should succeed");
        assert_eq!(y, vec![4.25_f32]);
        // h = relu([1 - 3 + 0, -1 + 3 + 0.5]) = [0, 2.5]; y = 2*0 + 5*2.5 + 0.25.
        let y = p.forward(&[1.0_f32, 3.0]).expect("forward should succeed");
        assert_eq!(y, vec![12.75_f32]);
        // The same two rows packed into one batch; the second row's ReLU zeroes
        // the first hidden unit, which the first row had set to 2.
        let y = p
            .forward_batch(&[3.0_f32, 1.0, 1.0, 3.0], 2)
            .expect("forward_batch should succeed");
        assert_eq!(y, vec![4.25_f32, 12.75]);
    }

    #[test]
    fn projector_forward_rejects_dim_mismatch() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let r = p.forward(&[0.0_f32; 4]);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch { expected: 8, got: 4 })
        ));
    }

    #[test]
    fn projector_forward_batch_correct_shape() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let x = vec![0.1_f32; 4 * 8];
        let y = p
            .forward_batch(&x, 4)
            .expect("forward_batch should succeed");
        assert_eq!(y.len(), 4 * 4);
    }

    #[test]
    fn projector_forward_batch_matches_per_row_forward() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let x: Vec<f32> = (0..3 * 8).map(|i| i as f32 * 0.1 - 1.0).collect();
        let batched = p
            .forward_batch(&x, 3)
            .expect("forward_batch should succeed");
        let rowwise: Vec<f32> = x
            .chunks(8)
            .flat_map(|row| p.forward(row).expect("forward should succeed"))
            .collect();
        assert_eq!(batched, rowwise);
    }

    #[test]
    fn projector_forward_batch_accepts_empty_batch() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let y = p
            .forward_batch(&[0.0_f32; 0], 0)
            .expect("empty batch should succeed");
        assert!(y.is_empty());
    }

    #[test]
    fn projector_forward_batch_rejects_dim_mismatch() {
        let mut rng = LcgRng::new(0);
        let p = MlpProjector::new(8, 16, 4, &mut rng).expect("new should succeed");
        let r = p.forward_batch(&[0.0_f32; 16], 4);
        assert!(matches!(
            r,
            Err(SslError::DimensionMismatch {
                expected: 32,
                got: 16
            })
        ));
    }
}