1use crate::autograd::matmul_nt;
6use crate::Tensor;
7use std::collections::HashMap;
8
9use super::config::TransformerConfig;
10
11pub struct FeedForward {
13 config: TransformerConfig,
15 pub w_gate: Tensor,
17 pub w_up: Tensor,
19 pub w_down: Tensor,
21}
22
23impl FeedForward {
24 pub fn new(config: &TransformerConfig) -> Self {
26 use super::init::{get_init_seed, rand_normal_seeded};
27 let hidden_size = config.hidden_size;
28 let intermediate_size = config.intermediate_size;
29 let seed = get_init_seed();
30
31 Self {
32 config: config.clone(),
33 w_gate: Tensor::from_vec(
34 rand_normal_seeded(hidden_size * intermediate_size, seed, "w_gate"),
35 true,
36 ),
37 w_up: Tensor::from_vec(
38 rand_normal_seeded(hidden_size * intermediate_size, seed, "w_up"),
39 true,
40 ),
41 w_down: Tensor::from_vec(
42 rand_normal_seeded(intermediate_size * hidden_size, seed, "w_down"),
43 true,
44 ),
45 }
46 }
47
48 pub fn from_params(
59 config: &TransformerConfig,
60 params: &HashMap<String, Tensor>,
61 prefix: &str,
62 ) -> Option<Self> {
63 let w_gate = params.get(&format!("{prefix}.gate_proj.weight"))?.clone();
64 let w_up = params.get(&format!("{prefix}.up_proj.weight"))?.clone();
65 let w_down = params.get(&format!("{prefix}.down_proj.weight"))?.clone();
66
67 let expected_gate_up = config.hidden_size * config.intermediate_size;
68 let expected_down = config.intermediate_size * config.hidden_size;
69
70 let checks: &[(&str, &Tensor, usize)] = &[
72 ("gate_proj", &w_gate, expected_gate_up),
73 ("up_proj", &w_up, expected_gate_up),
74 ("down_proj", &w_down, expected_down),
75 ];
76 for &(name, tensor, expected) in checks {
77 if tensor.len() != expected {
78 eprintln!(
79 "[PMAT-333] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
80 tensor.len()
81 );
82 return None;
83 }
84 }
85
86 Some(Self { config: config.clone(), w_gate, w_up, w_down })
87 }
88
89 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
100 let hidden_size = self.config.hidden_size;
101 let intermediate_size = self.config.intermediate_size;
102
103 let gate = matmul_nt(x, &self.w_gate, seq_len, hidden_size, intermediate_size);
105
106 let up = matmul_nt(x, &self.w_up, seq_len, hidden_size, intermediate_size);
108
109 let gate_activated = crate::autograd::swish(&gate);
111 let hidden = crate::autograd::mul(&gate_activated, &up);
112 contract_post_swiglu!(hidden.data().as_slice().unwrap_or(&[]));
113
114 matmul_nt(&hidden, &self.w_down, seq_len, intermediate_size, hidden_size)
116 }
117
118 pub fn parameters(&self) -> Vec<&Tensor> {
120 vec![&self.w_gate, &self.w_up, &self.w_down]
121 }
122
123 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
125 vec![&mut self.w_gate, &mut self.w_up, &mut self.w_down]
126 }
127}
128
/// GELU activation, tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn gelu(x: f32) -> f32 {
    let sqrt_two_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    // Multiplication order is kept left-to-right so rounding is unchanged.
    let cubic_term = 0.044715 * x * x * x;
    let tanh_arg = sqrt_two_over_pi * (x + cubic_term);
    let half_x = 0.5 * x;
    half_x * (1.0 + tanh_arg.tanh())
}
137
138pub struct EncoderFeedForward {
148 config: TransformerConfig,
149 pub w_up: Tensor,
151 pub b_up: Tensor,
153 pub w_down: Tensor,
155 pub b_down: Tensor,
157}
158
159impl EncoderFeedForward {
160 pub fn new(config: &TransformerConfig) -> Self {
162 use super::init::{get_init_seed, rand_normal_seeded};
163 let h = config.hidden_size;
164 let inter = config.intermediate_size;
165 let seed = get_init_seed();
166
167 Self {
168 config: config.clone(),
169 w_up: Tensor::from_vec(rand_normal_seeded(h * inter, seed, "enc_w_up"), true),
170 b_up: Tensor::from_vec(vec![0.0; inter], true),
171 w_down: Tensor::from_vec(rand_normal_seeded(inter * h, seed, "enc_w_down"), true),
172 b_down: Tensor::from_vec(vec![0.0; h], true),
173 }
174 }
175
176 pub fn from_params(
184 config: &TransformerConfig,
185 params: &HashMap<String, Tensor>,
186 prefix: &str,
187 ) -> Option<Self> {
188 let w_up = params.get(&format!("{prefix}.intermediate.dense.weight"))?.clone();
189 let b_up = params.get(&format!("{prefix}.intermediate.dense.bias"))?.clone();
190 let w_down = params.get(&format!("{prefix}.output.dense.weight"))?.clone();
191 let b_down = params.get(&format!("{prefix}.output.dense.bias"))?.clone();
192
193 let expected_up = config.hidden_size * config.intermediate_size;
194 let expected_down = config.intermediate_size * config.hidden_size;
195
196 if w_up.len() != expected_up {
197 eprintln!(
198 "[ENC-004] {prefix}.intermediate.dense.weight: shape mismatch — \
199 got {} elements, expected {expected_up}",
200 w_up.len()
201 );
202 return None;
203 }
204 if w_down.len() != expected_down {
205 eprintln!(
206 "[ENC-004] {prefix}.output.dense.weight: shape mismatch — \
207 got {} elements, expected {expected_down}",
208 w_down.len()
209 );
210 return None;
211 }
212
213 Some(Self { config: config.clone(), w_up, b_up, w_down, b_down })
214 }
215
216 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
218 let h = self.config.hidden_size;
219 let inter = self.config.intermediate_size;
220
221 let up = matmul_nt(x, &self.w_up, seq_len, h, inter);
223 let up_data = up.data();
224 let up_slice = up_data.as_slice().expect("contiguous");
225 let b_up_slice = self.b_up.data().as_slice().expect("contiguous");
226
227 let activated: Vec<f32> =
229 (0..seq_len * inter).map(|i| gelu(up_slice[i] + b_up_slice[i % inter])).collect();
230 let activated_t = Tensor::from_vec(activated, true);
231
232 let down = matmul_nt(&activated_t, &self.w_down, seq_len, inter, h);
234 let down_data = down.data();
235 let down_slice = down_data.as_slice().expect("contiguous");
236 let b_down_slice = self.b_down.data().as_slice().expect("contiguous");
237
238 let output: Vec<f32> =
239 (0..seq_len * h).map(|i| down_slice[i] + b_down_slice[i % h]).collect();
240 Tensor::from_vec(output, true)
241 }
242
243 pub fn parameters(&self) -> Vec<&Tensor> {
245 vec![&self.w_up, &self.b_up, &self.w_down, &self.b_down]
246 }
247}
248
#[cfg(test)]
mod tests {
    //! Unit tests for [`FeedForward`], [`EncoderFeedForward`] and `gelu`.

    use super::*;

    /// Forward pass on a 2-row input yields `2 * hidden_size` values.
    #[test]
    fn test_feed_forward_tiny() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    /// The gated block exposes exactly three parameters (gate, up, down).
    #[test]
    fn test_feed_forward_parameters() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let params = ffn.parameters();
        assert_eq!(params.len(), 3);
    }

    /// Output length scales with the sequence length.
    #[test]
    fn test_ffn_longer_sequence() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
        let output = ffn.forward(&x, 8);
        assert_eq!(output.len(), 8 * config.hidden_size);
    }

    /// Freshly initialised weights have the configured element counts.
    #[test]
    fn test_ffn_weight_sizes() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(ffn.w_gate.len(), config.hidden_size * config.intermediate_size);
        assert_eq!(ffn.w_up.len(), config.hidden_size * config.intermediate_size);
        assert_eq!(ffn.w_down.len(), config.intermediate_size * config.hidden_size);
    }

    /// `from_params` succeeds when all three correctly sized tensors exist.
    #[test]
    fn test_feed_forward_from_params_success() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "ffn.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );

        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(ffn.is_some());
        let ffn = ffn.expect("operation should succeed");
        assert_eq!(ffn.w_gate.len(), hidden_size * intermediate_size);
    }

    /// `from_params` returns `None` when a required key is missing.
    #[test]
    fn test_feed_forward_from_params_missing_key() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "ffn.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        // up_proj and down_proj are deliberately absent.
        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(ffn.is_none());
    }

    /// Spot-checks the tanh GELU approximation at a few reference points.
    #[test]
    fn enc_004_gelu_approximation() {
        // gelu(0) == 0
        assert!((gelu(0.0)).abs() < 1e-6);
        // Large positive inputs pass through almost unchanged.
        assert!((gelu(3.0) - 3.0).abs() < 0.01);
        // Large negative inputs are squashed towards zero.
        assert!(gelu(-3.0).abs() < 0.01);
        // Reference value at x = 1.
        assert!((gelu(1.0) - 0.8412).abs() < 0.01);
    }

    /// Encoder forward pass on 4 rows yields `4 * hidden_size` values.
    #[test]
    fn enc_004_encoder_ffn_output_shape() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 4 * config.hidden_size], true);
        let output = ffn.forward(&x, 4);
        assert_eq!(output.len(), 4 * config.hidden_size);
    }

    /// The encoder block exposes four parameters (two weights, two biases).
    #[test]
    fn enc_004_encoder_ffn_has_4_params() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        assert_eq!(ffn.parameters().len(), 4);
    }

    /// Encoder output contains no NaN or infinity for bounded input.
    #[test]
    fn enc_004_encoder_ffn_output_finite() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert!(output.data().iter().all(|v| v.is_finite()));
    }

    /// Encoder `from_params` succeeds with four correctly sized tensors.
    #[test]
    fn enc_004_encoder_ffn_from_params() {
        let config = TransformerConfig::tiny();
        let h = config.hidden_size;
        let inter = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "layer.intermediate.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; h * inter], true),
        );
        params.insert(
            "layer.intermediate.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; inter], true),
        );
        params.insert(
            "layer.output.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; inter * h], true),
        );
        params.insert("layer.output.dense.bias".to_string(), Tensor::from_vec(vec![0.0; h], true));

        let ffn = EncoderFeedForward::from_params(&config, &params, "layer");
        assert!(ffn.is_some());
    }

    /// Encoder `from_params` rejects an up-projection weight of the wrong size.
    #[test]
    fn enc_004_encoder_ffn_from_params_rejects_wrong_shape() {
        let config = TransformerConfig::tiny();
        let mut params = HashMap::new();
        params.insert(
            "layer.intermediate.dense.weight".to_string(),
            // Deliberately wrong element count.
            Tensor::from_vec(vec![0.1; 42], true),
        );
        params.insert(
            "layer.intermediate.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; config.intermediate_size], true),
        );
        params.insert(
            "layer.output.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; config.intermediate_size * config.hidden_size], true),
        );
        params.insert(
            "layer.output.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; config.hidden_size], true),
        );

        let ffn = EncoderFeedForward::from_params(&config, &params, "layer");
        assert!(ffn.is_none());
    }

    /// FALSIFY-F1e: a wrong-sized `gate_proj` must be rejected.
    #[test]
    fn falsify_f1e_from_params_rejects_wrong_shape_gate() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        // Deliberately wrong element count for the gate projection.
        params.insert("ffn.gate_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 42], true));
        params.insert(
            "ffn.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );

        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(
            ffn.is_none(),
            "FALSIFY-F1e: PMAT-333 fix — from_params MUST reject wrong-shape gate_proj"
        );
    }

    /// FALSIFY-F2e: forward output has `seq_len * hidden_size` elements.
    #[test]
    fn falsify_f2e_swiglu_forward_correct_dims() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let seq_len = 4;
        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
        let output = ffn.forward(&x, seq_len);
        assert_eq!(
            output.len(),
            seq_len * config.hidden_size,
            "FALSIFY-F2e: FFN output must be seq_len * hidden_size"
        );
    }

    /// FALSIFY-F3e: bounded inputs produce finite outputs.
    #[test]
    fn falsify_f3e_ffn_output_finite() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert!(
            output.data().iter().all(|v| v.is_finite()),
            "FALSIFY-F3e: FFN output must be finite for bounded inputs"
        );
    }

    /// FALSIFY-F4e: gate and up projections have the same element count.
    #[test]
    fn falsify_f4e_gate_up_shape_parity() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(
            ffn.w_gate.len(),
            ffn.w_up.len(),
            "FALSIFY-F4e: gate_proj and up_proj must have identical size for SwiGLU multiply"
        );
    }

    /// FALSIFY-F5e: the down projection has the same total size as the gate.
    #[test]
    fn falsify_f5e_down_proj_reversed_same_total() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(
            ffn.w_gate.len(),
            ffn.w_down.len(),
            "FALSIFY-F5e: gate and down must have same total elements (H*I)"
        );
        assert_eq!(
            ffn.w_down.len(),
            config.hidden_size * config.intermediate_size,
            "FALSIFY-F5e: down_proj must have hidden*intermediate elements"
        );
    }

    /// After a backward pass every weight has a gradient attached.
    #[test]
    fn test_ffn_backward_gradient_exists() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        // Seed the backward pass with an all-ones upstream gradient.
        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        assert!(ffn.w_gate.grad().is_some());
        assert!(ffn.w_up.grad().is_some());
        assert!(ffn.w_down.grad().is_some());
    }

    /// Gradients of all three weights are finite.
    #[test]
    fn test_ffn_backward_gradients_finite() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        let grad_up = ffn.w_up.grad().expect("gradient should be available");
        let grad_down = ffn.w_down.grad().expect("gradient should be available");

        assert!(grad_gate.iter().all(|&v| v.is_finite()));
        assert!(grad_up.iter().all(|&v| v.is_finite()));
        assert!(grad_down.iter().all(|&v| v.is_finite()));
    }

    /// Gate gradients stay finite across several input magnitudes.
    #[test]
    fn test_ffn_backward_swiglu_activation() {
        let config = TransformerConfig::tiny();

        for scale in [0.1, 1.0, 2.0] {
            // A fresh block per scale keeps the runs independent.
            let ffn = FeedForward::new(&config);
            let x = Tensor::from_vec(
                (0..2 * config.hidden_size).map(|i| (i as f32 * 0.01).sin() * scale).collect(),
                true,
            );
            let mut output = ffn.forward(&x, 2);

            let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
            crate::autograd::backward(&mut output, Some(grad_out));

            let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
            assert!(
                grad_gate.iter().all(|&v| v.is_finite()),
                "Gradients not finite for scale {scale}"
            );
        }
    }

    /// The gate gradient is not identically zero for a non-zero input.
    #[test]
    fn test_ffn_backward_gradient_nonzero() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        let sum: f32 = grad_gate.iter().map(|v| v.abs()).sum();
        assert!(sum > 0.0, "FFN gate gradients should not be all zero");
    }

    /// Backward works and stays finite for several sequence lengths.
    #[test]
    fn test_ffn_backward_different_seq_lengths() {
        let config = TransformerConfig::tiny();

        for seq_len in [1, 2, 4, 8] {
            let ffn = FeedForward::new(&config);
            let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
            let mut output = ffn.forward(&x, seq_len);

            let grad_out = ndarray::Array1::ones(seq_len * config.hidden_size);
            crate::autograd::backward(&mut output, Some(grad_out));

            let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
            assert!(
                grad_gate.iter().all(|&v| v.is_finite()),
                "Non-finite gradient for seq_len {seq_len}"
            );
        }
    }

    /// The gate gradient read after a second backward pass differs from the
    /// one read after the first.
    #[test]
    fn test_ffn_backward_gradient_accumulation() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);

        // First forward/backward pass.
        let x1 = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let mut output1 = ffn.forward(&x1, 2);
        let grad_out1 = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output1, Some(grad_out1));
        let grad1 = ffn.w_gate.grad().expect("gradient should be available").to_vec();

        // Second pass on the same block with a different input.
        let x2 = Tensor::from_vec(vec![0.2; 2 * config.hidden_size], true);
        let mut output2 = ffn.forward(&x2, 2);
        let grad_out2 = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output2, Some(grad_out2));
        let grad2 = ffn.w_gate.grad().expect("gradient should be available").to_vec();

        assert!(
            grad2.iter().zip(grad1.iter()).any(|(g2, g1)| g2.abs() != g1.abs()),
            "Gradients should accumulate across backward passes"
        );
    }

    /// An all-zero input still produces finite gate gradients.
    #[test]
    fn test_ffn_backward_with_zero_input() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.0; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        assert!(grad_gate.iter().all(|&v| v.is_finite()));
    }

    /// A large-magnitude input still produces finite gate gradients.
    #[test]
    fn test_ffn_backward_large_input() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![10.0; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        assert!(grad_gate.iter().all(|&v| v.is_finite()));
    }
}