/// Strategy used to fill a projection's weight matrix at construction time.
///
/// The type is a plain value (unit variants plus a `u64` seed), so it is
/// `Copy`, `Eq` and `Hash` in addition to `Clone`/`PartialEq`; callers can
/// pass it by value repeatedly or use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InitMethod {
    /// Every weight starts at `0.0`.
    Zeros,
    /// Ones on the main diagonal (as far as the smaller dimension reaches),
    /// zeros everywhere else.
    Identity,
    /// Deterministic pseudo-random weights derived from the given seed;
    /// equal seeds yield equal weights.
    Random(u64),
}
17
/// Element-wise non-linearity that can be applied to a projection's output.
///
/// All variants are field-less, so the type is `Copy`, `Eq` and `Hash` in
/// addition to `Clone`/`PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationFn {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Identity: the value is passed through unchanged.
    None,
}
30
/// Weights and bias of an affine map from `input_dim` to `output_dim` values.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectionMatrix {
    /// Number of columns in every weight row.
    pub input_dim: usize,
    /// Number of weight rows, and the length of `bias`.
    pub output_dim: usize,
    /// Row-major weights: `output_dim` rows of `input_dim` values each.
    pub weights: Vec<Vec<f64>>,
    /// One additive offset per output row.
    pub bias: Vec<f64>,
}

impl ProjectionMatrix {
    /// Builds a matrix whose weights and bias are all `0.0`.
    fn new_zeros(input_dim: usize, output_dim: usize) -> Self {
        let zero_row = vec![0.0; input_dim];
        Self {
            input_dim,
            output_dim,
            weights: vec![zero_row; output_dim],
            bias: vec![0.0; output_dim],
        }
    }

    /// Builds a (possibly rectangular) identity: row `r` carries a single
    /// `1.0` at column `r` whenever that column exists; the bias is zero.
    fn new_identity(input_dim: usize, output_dim: usize) -> Self {
        let weights: Vec<Vec<f64>> = (0..output_dim)
            .map(|r| {
                let mut row = vec![0.0; input_dim];
                // `r` is already below `output_dim`, so the diagonal entry
                // exists exactly when `r` is also a valid column index.
                if let Some(diagonal) = row.get_mut(r) {
                    *diagonal = 1.0;
                }
                row
            })
            .collect();
        Self {
            input_dim,
            output_dim,
            weights,
            bias: vec![0.0; output_dim],
        }
    }

    /// Builds a matrix of seeded pseudo-random weights and a zero bias.
    ///
    /// Weights are drawn uniformly from `[-limit, limit)` with
    /// `limit = sqrt(6 / (input_dim + output_dim))` (the Glorot/Xavier
    /// uniform bound). Values come from a 64-bit linear congruential
    /// generator, filled row by row, so a given seed always reproduces the
    /// same matrix.
    fn new_random(input_dim: usize, output_dim: usize, seed: u64) -> Self {
        const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const INCREMENT: u64 = 1_442_695_040_888_963_407;

        let limit = (6.0_f64 / (input_dim + output_dim) as f64).sqrt();
        // 2^53: the top 53 bits of the state map onto a double in [0, 1).
        let mantissa_range = (1u64 << 53) as f64;
        // Offset the seed by one so that seed 0 does not start from state 0.
        let mut state = seed.wrapping_add(1);

        let mut next_weight = || {
            state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
            let unit = (state >> 11) as f64 / mantissa_range;
            (unit * 2.0 - 1.0) * limit
        };

        let weights: Vec<Vec<f64>> = (0..output_dim)
            .map(|_| (0..input_dim).map(|_| next_weight()).collect::<Vec<f64>>())
            .collect();

        Self {
            input_dim,
            output_dim,
            weights,
            bias: vec![0.0; output_dim],
        }
    }
}
94
95fn apply_activation(val: f64, act: &ActivationFn) -> f64 {
101 match act {
102 ActivationFn::ReLU => val.max(0.0),
103 ActivationFn::Tanh => val.tanh(),
104 ActivationFn::Sigmoid => 1.0 / (1.0 + (-val).exp()),
105 ActivationFn::None => val,
106 }
107}
108
109#[derive(Debug, Clone)]
115pub struct ProjectionLayer {
116 matrix: ProjectionMatrix,
117 activation: Option<ActivationFn>,
118}
119
120impl ProjectionLayer {
121 pub fn new(input_dim: usize, output_dim: usize, init: InitMethod) -> Self {
123 let matrix = match init {
124 InitMethod::Zeros => ProjectionMatrix::new_zeros(input_dim, output_dim),
125 InitMethod::Identity => ProjectionMatrix::new_identity(input_dim, output_dim),
126 InitMethod::Random(seed) => ProjectionMatrix::new_random(input_dim, output_dim, seed),
127 };
128 Self {
129 matrix,
130 activation: None,
131 }
132 }
133
134 pub fn with_activation(mut self, activation: ActivationFn) -> Self {
136 self.activation = Some(activation);
137 self
138 }
139
140 pub fn project(&self, input: &[f64]) -> Vec<f64> {
144 debug_assert_eq!(
145 input.len(),
146 self.matrix.input_dim,
147 "input dimension mismatch"
148 );
149 let mut output = Vec::with_capacity(self.matrix.output_dim);
150 for (i, row) in self.matrix.weights.iter().enumerate() {
151 let mut sum = self.matrix.bias[i];
152 for (w, x) in row.iter().zip(input.iter()) {
153 sum += w * x;
154 }
155 let activated = if let Some(act) = &self.activation {
156 apply_activation(sum, act)
157 } else {
158 sum
159 };
160 output.push(activated);
161 }
162 output
163 }
164
165 pub fn project_batch(&self, inputs: &[Vec<f64>]) -> Vec<Vec<f64>> {
167 inputs.iter().map(|inp| self.project(inp)).collect()
168 }
169
170 pub fn input_dim(&self) -> usize {
172 self.matrix.input_dim
173 }
174
175 pub fn output_dim(&self) -> usize {
177 self.matrix.output_dim
178 }
179
180 pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) -> Result<(), String> {
184 if weights.len() != self.matrix.output_dim {
185 return Err(format!(
186 "expected {} output rows, got {}",
187 self.matrix.output_dim,
188 weights.len()
189 ));
190 }
191 for (i, row) in weights.iter().enumerate() {
192 if row.len() != self.matrix.input_dim {
193 return Err(format!(
194 "row {} has {} columns, expected {}",
195 i,
196 row.len(),
197 self.matrix.input_dim
198 ));
199 }
200 }
201 self.matrix.weights = weights;
202 Ok(())
203 }
204
205 pub fn set_bias(&mut self, bias: Vec<f64>) -> Result<(), String> {
209 if bias.len() != self.matrix.output_dim {
210 return Err(format!(
211 "expected bias length {}, got {}",
212 self.matrix.output_dim,
213 bias.len()
214 ));
215 }
216 self.matrix.bias = bias;
217 Ok(())
218 }
219
220 pub fn parameter_count(&self) -> usize {
222 self.matrix.input_dim * self.matrix.output_dim + self.matrix.output_dim
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    // ---- apply_activation ----

    #[test]
    fn test_activation_relu_positive() {
        let out = apply_activation(2.5, &ActivationFn::ReLU);
        assert_eq!(out, 2.5);
    }

    #[test]
    fn test_activation_relu_negative() {
        let out = apply_activation(-3.0, &ActivationFn::ReLU);
        assert_eq!(out, 0.0);
    }

    #[test]
    fn test_activation_relu_zero() {
        let out = apply_activation(0.0, &ActivationFn::ReLU);
        assert_eq!(out, 0.0);
    }

    #[test]
    fn test_activation_tanh() {
        assert_close(apply_activation(0.0, &ActivationFn::Tanh), 0.0, 1e-10);
        assert_close(
            apply_activation(1.0, &ActivationFn::Tanh),
            1.0_f64.tanh(),
            1e-10,
        );
    }

    #[test]
    fn test_activation_sigmoid_at_zero() {
        assert_close(apply_activation(0.0, &ActivationFn::Sigmoid), 0.5, 1e-10);
    }

    #[test]
    fn test_activation_sigmoid_large_positive() {
        assert_close(apply_activation(100.0, &ActivationFn::Sigmoid), 1.0, 1e-6);
    }

    #[test]
    fn test_activation_sigmoid_large_negative() {
        let out = apply_activation(-100.0, &ActivationFn::Sigmoid);
        assert!(out < 1e-6);
    }

    // 3.14 is an arbitrary sample value here, not an approximation of pi.
    #[test]
    #[allow(clippy::approx_constant)]
    fn test_activation_none_is_identity() {
        for sample in [3.14, -7.0] {
            assert_eq!(apply_activation(sample, &ActivationFn::None), sample);
        }
    }

    // ---- construction ----

    #[test]
    fn test_new_zeros() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Zeros);
        assert_eq!((layer.input_dim(), layer.output_dim()), (4, 2));
        assert_eq!(layer.project(&[1.0, 2.0, 3.0, 4.0]), [0.0, 0.0]);
    }

    #[test]
    fn test_new_identity_square() {
        let layer = ProjectionLayer::new(3, 3, InitMethod::Identity);
        assert_eq!(layer.project(&[1.0, 2.0, 3.0]), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_new_identity_reduce_dim() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Identity);
        let out = layer.project(&[5.0, 7.0, 9.0, 11.0]);
        assert_close(out[0], 5.0, 1e-10);
        assert_close(out[1], 7.0, 1e-10);
    }

    #[test]
    fn test_new_random_produces_output() {
        let layer = ProjectionLayer::new(8, 4, InitMethod::Random(42));
        let out = layer.project(&[1.0; 8]);
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_new_random_different_seeds_differ() {
        let input = [1.0, 1.0, 1.0, 1.0];
        let first = ProjectionLayer::new(4, 2, InitMethod::Random(1)).project(&input);
        let second = ProjectionLayer::new(4, 2, InitMethod::Random(2)).project(&input);
        assert_ne!(first, second);
    }

    #[test]
    fn test_new_random_same_seed_same_output() {
        let input = [1.0, 0.5, -0.5, -1.0];
        let first = ProjectionLayer::new(4, 2, InitMethod::Random(99)).project(&input);
        let second = ProjectionLayer::new(4, 2, InitMethod::Random(99)).project(&input);
        assert_eq!(first, second);
    }

    // ---- parameter_count ----

    #[test]
    fn test_parameter_count() {
        let layer = ProjectionLayer::new(10, 5, InitMethod::Zeros);
        assert_eq!(layer.parameter_count(), 55);
    }

    #[test]
    fn test_parameter_count_large() {
        let layer = ProjectionLayer::new(768, 128, InitMethod::Zeros);
        assert_eq!(layer.parameter_count(), 768 * 128 + 128);
    }

    // ---- set_weights ----

    #[test]
    fn test_set_weights_valid() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        let outcome = layer.set_weights(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        assert_eq!(outcome, Ok(()));
        let out = layer.project(&[1.0, 1.0, 1.0]);
        assert_close(out[0], 6.0, 1e-10);
        assert_close(out[1], 15.0, 1e-10);
    }

    #[test]
    fn test_set_weights_wrong_row_count() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        let outcome = layer.set_weights(vec![vec![1.0, 2.0, 3.0]]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_set_weights_wrong_col_count() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        let outcome = layer.set_weights(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert!(outcome.is_err());
    }

    // ---- set_bias ----

    #[test]
    fn test_set_bias_valid() {
        let mut layer = ProjectionLayer::new(2, 2, InitMethod::Identity);
        assert_eq!(layer.set_bias(vec![10.0, 20.0]), Ok(()));
        let out = layer.project(&[1.0, 2.0]);
        assert_close(out[0], 11.0, 1e-10);
        assert_close(out[1], 22.0, 1e-10);
    }

    #[test]
    fn test_set_bias_wrong_length() {
        let mut layer = ProjectionLayer::new(2, 2, InitMethod::Zeros);
        let outcome = layer.set_bias(vec![1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
    }

    // ---- activations through the layer ----

    #[test]
    fn test_relu_activation_clips_negative() {
        let mut layer =
            ProjectionLayer::new(2, 2, InitMethod::Identity).with_activation(ActivationFn::ReLU);
        assert_eq!(layer.set_bias(vec![-5.0, -5.0]), Ok(()));
        assert_eq!(layer.project(&[1.0, 1.0]), [0.0, 0.0]);
    }

    #[test]
    fn test_tanh_activation_bounds() {
        let layer =
            ProjectionLayer::new(1, 1, InitMethod::Identity).with_activation(ActivationFn::Tanh);
        let out = layer.project(&[100.0]);
        assert_close(out[0], 1.0, 1e-6);
    }

    #[test]
    fn test_sigmoid_activation_bounds() {
        let layer = ProjectionLayer::new(1, 1, InitMethod::Identity)
            .with_activation(ActivationFn::Sigmoid);
        let out = layer.project(&[0.0]);
        assert_close(out[0], 0.5, 1e-10);
    }

    #[test]
    fn test_no_activation_is_linear() {
        let layer =
            ProjectionLayer::new(1, 1, InitMethod::Identity).with_activation(ActivationFn::None);
        let out = layer.project(&[42.0]);
        assert_close(out[0], 42.0, 1e-10);
    }

    // ---- project_batch ----

    #[test]
    fn test_project_batch_empty() {
        let layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert!(layer.project_batch(&[]).is_empty());
    }

    #[test]
    fn test_project_batch_multiple() {
        let layer = ProjectionLayer::new(2, 2, InitMethod::Identity);
        let results = layer.project_batch(&[vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], [1.0, 2.0]);
        assert_eq!(results[1], [3.0, 4.0]);
    }

    #[test]
    fn test_project_batch_consistency() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Random(7));
        let input = vec![1.0, 0.5, -0.5, -1.0];
        let single = layer.project(&input);
        let batch = layer.project_batch(std::slice::from_ref(&input));
        assert_eq!(batch[0], single);
    }

    // ---- dimension changes ----

    #[test]
    fn test_reduce_dim_768_to_128() {
        let layer = ProjectionLayer::new(768, 128, InitMethod::Zeros);
        assert_eq!(layer.input_dim(), 768);
        assert_eq!(layer.output_dim(), 128);
        assert_eq!(layer.project(&[0.0; 768]).len(), 128);
    }

    #[test]
    fn test_expand_dim() {
        let layer = ProjectionLayer::new(32, 256, InitMethod::Identity);
        assert_eq!(layer.output_dim(), 256);
        assert_eq!(layer.project(&[1.0; 32]).len(), 256);
    }

    // ---- small cases ----

    #[test]
    fn test_single_dim_projection() {
        let mut layer = ProjectionLayer::new(1, 1, InitMethod::Zeros);
        assert_eq!(layer.set_weights(vec![vec![3.0]]), Ok(()));
        assert_eq!(layer.set_bias(vec![1.0]), Ok(()));
        // 3.0 * 2.0 + 1.0
        assert_close(layer.project(&[2.0])[0], 7.0, 1e-10);
    }

    #[test]
    fn test_zero_input_with_bias() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert_eq!(layer.set_bias(vec![1.0, 2.0]), Ok(()));
        assert_eq!(layer.project(&[0.0, 0.0, 0.0]), [1.0, 2.0]);
    }

    // ---- derived equality ----

    #[test]
    fn test_init_method_equality() {
        assert_eq!(InitMethod::Zeros, InitMethod::Zeros);
        assert_eq!(InitMethod::Random(42), InitMethod::Random(42));
        assert_ne!(InitMethod::Random(1), InitMethod::Random(2));
    }

    #[test]
    fn test_activation_fn_equality() {
        assert_eq!(ActivationFn::ReLU, ActivationFn::ReLU);
        assert_ne!(ActivationFn::ReLU, ActivationFn::Tanh);
    }
}