Skip to main content

entrenar/transformer/
feedforward.rs

1//! Feed-forward network module
2//!
3//! This module provides position-wise feed-forward networks with SwiGLU activation.
4
5use crate::autograd::matmul_nt;
6use crate::Tensor;
7use std::collections::HashMap;
8
9use super::config::TransformerConfig;
10
11/// Position-wise Feed-Forward Network
12pub struct FeedForward {
13    /// Configuration
14    config: TransformerConfig,
15    /// Gate projection weight (hidden_size x intermediate_size)
16    pub w_gate: Tensor,
17    /// Up projection weight (hidden_size x intermediate_size)
18    pub w_up: Tensor,
19    /// Down projection weight (intermediate_size x hidden_size)
20    pub w_down: Tensor,
21}
22
23impl FeedForward {
24    /// Create new FFN layer with random normal initialization (C-INIT-001).
25    pub fn new(config: &TransformerConfig) -> Self {
26        use super::init::{get_init_seed, rand_normal_seeded};
27        let hidden_size = config.hidden_size;
28        let intermediate_size = config.intermediate_size;
29        let seed = get_init_seed();
30
31        Self {
32            config: config.clone(),
33            w_gate: Tensor::from_vec(
34                rand_normal_seeded(hidden_size * intermediate_size, seed, "w_gate"),
35                true,
36            ),
37            w_up: Tensor::from_vec(
38                rand_normal_seeded(hidden_size * intermediate_size, seed, "w_up"),
39                true,
40            ),
41            w_down: Tensor::from_vec(
42                rand_normal_seeded(intermediate_size * hidden_size, seed, "w_down"),
43                true,
44            ),
45        }
46    }
47
48    /// Create FFN layer from parameter map
49    ///
50    /// Expected parameter names (following HuggingFace convention):
51    /// - `{prefix}.gate_proj.weight`
52    /// - `{prefix}.up_proj.weight`
53    /// - `{prefix}.down_proj.weight`
54    /// # Contract (PMAT-333)
55    /// Validates gate/up/down projection shapes against config dimensions.
56    /// gate/up: hidden_size * intermediate_size, down: intermediate_size * hidden_size
57    /// Returns None if any key is missing or shape is wrong.
58    pub fn from_params(
59        config: &TransformerConfig,
60        params: &HashMap<String, Tensor>,
61        prefix: &str,
62    ) -> Option<Self> {
63        let w_gate = params.get(&format!("{prefix}.gate_proj.weight"))?.clone();
64        let w_up = params.get(&format!("{prefix}.up_proj.weight"))?.clone();
65        let w_down = params.get(&format!("{prefix}.down_proj.weight"))?.clone();
66
67        let expected_gate_up = config.hidden_size * config.intermediate_size;
68        let expected_down = config.intermediate_size * config.hidden_size;
69
70        // PMAT-333: Shape validation for FFN projections
71        let checks: &[(&str, &Tensor, usize)] = &[
72            ("gate_proj", &w_gate, expected_gate_up),
73            ("up_proj", &w_up, expected_gate_up),
74            ("down_proj", &w_down, expected_down),
75        ];
76        for &(name, tensor, expected) in checks {
77            if tensor.len() != expected {
78                eprintln!(
79                    "[PMAT-333] {prefix}.{name}: shape mismatch — got {} elements, expected {expected}",
80                    tensor.len()
81                );
82                return None;
83            }
84        }
85
86        Some(Self { config: config.clone(), w_gate, w_up, w_down })
87    }
88
89    /// Forward pass with SwiGLU activation
90    ///
91    /// FFN(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x))
92    ///
93    /// # Arguments
94    /// * `x` - Input tensor (seq_len * hidden_size, flattened)
95    /// * `seq_len` - Sequence length
96    ///
97    /// # Returns
98    /// Output tensor (seq_len * hidden_size, flattened)
99    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
100        let hidden_size = self.config.hidden_size;
101        let intermediate_size = self.config.intermediate_size;
102
103        // Gate projection — HF weights [intermediate, hidden] (ENT-269)
104        let gate = matmul_nt(x, &self.w_gate, seq_len, hidden_size, intermediate_size);
105
106        // Up projection — HF weights [intermediate, hidden] (ENT-269)
107        let up = matmul_nt(x, &self.w_up, seq_len, hidden_size, intermediate_size);
108
109        // SwiGLU: SiLU(gate) * up
110        let gate_activated = crate::autograd::swish(&gate);
111        let hidden = crate::autograd::mul(&gate_activated, &up);
112        contract_post_swiglu!(hidden.data().as_slice().unwrap_or(&[]));
113
114        // Down projection — HF weights [hidden, intermediate] (ENT-269)
115        matmul_nt(&hidden, &self.w_down, seq_len, intermediate_size, hidden_size)
116    }
117
118    /// Get all parameters as a vector
119    pub fn parameters(&self) -> Vec<&Tensor> {
120        vec![&self.w_gate, &self.w_up, &self.w_down]
121    }
122
123    /// Get all parameters as mutable references for optimizer
124    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
125        vec![&mut self.w_gate, &mut self.w_up, &mut self.w_down]
126    }
127}
128
/// Tanh approximation of the GELU activation, as used by BERT/RoBERTa encoders.
///
/// Exact form: `GELU(x) = x * Φ(x)`, where `Φ` is the standard normal CDF.
/// Computed here as `0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))`.
fn gelu(x: f32) -> f32 {
    let sqrt_two_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    // Cubic correction term inside the tanh; evaluated as x*x*x, left to right.
    let inner = sqrt_two_over_pi * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
137
138/// Encoder feed-forward network with GELU activation (BERT/RoBERTa/CodeBERT).
139///
140/// Unlike decoder SwiGLU (3 projections), encoder FFN uses 2 projections:
141///   FFN(x) = W_down * GELU(W_up * x + b_up) + b_down
142///
143/// # Contract (ENC-004)
144/// - Uses GELU activation (not SiLU/SwiGLU)
145/// - 2 projections (up + down), not 3
146/// - Supports bias terms (BERT convention)
147pub struct EncoderFeedForward {
148    config: TransformerConfig,
149    /// Up projection (hidden → intermediate)
150    pub w_up: Tensor,
151    /// Up projection bias
152    pub b_up: Tensor,
153    /// Down projection (intermediate → hidden)
154    pub w_down: Tensor,
155    /// Down projection bias
156    pub b_down: Tensor,
157}
158
159impl EncoderFeedForward {
160    /// Create new encoder FFN with random normal initialization (C-INIT-001).
161    pub fn new(config: &TransformerConfig) -> Self {
162        use super::init::{get_init_seed, rand_normal_seeded};
163        let h = config.hidden_size;
164        let inter = config.intermediate_size;
165        let seed = get_init_seed();
166
167        Self {
168            config: config.clone(),
169            w_up: Tensor::from_vec(rand_normal_seeded(h * inter, seed, "enc_w_up"), true),
170            b_up: Tensor::from_vec(vec![0.0; inter], true),
171            w_down: Tensor::from_vec(rand_normal_seeded(inter * h, seed, "enc_w_down"), true),
172            b_down: Tensor::from_vec(vec![0.0; h], true),
173        }
174    }
175
176    /// Create from pre-trained parameters (BERT/RoBERTa weight names)
177    ///
178    /// Expected keys:
179    /// - `{prefix}.intermediate.dense.weight` (hidden → intermediate)
180    /// - `{prefix}.intermediate.dense.bias`
181    /// - `{prefix}.output.dense.weight` (intermediate → hidden)
182    /// - `{prefix}.output.dense.bias`
183    pub fn from_params(
184        config: &TransformerConfig,
185        params: &HashMap<String, Tensor>,
186        prefix: &str,
187    ) -> Option<Self> {
188        let w_up = params.get(&format!("{prefix}.intermediate.dense.weight"))?.clone();
189        let b_up = params.get(&format!("{prefix}.intermediate.dense.bias"))?.clone();
190        let w_down = params.get(&format!("{prefix}.output.dense.weight"))?.clone();
191        let b_down = params.get(&format!("{prefix}.output.dense.bias"))?.clone();
192
193        let expected_up = config.hidden_size * config.intermediate_size;
194        let expected_down = config.intermediate_size * config.hidden_size;
195
196        if w_up.len() != expected_up {
197            eprintln!(
198                "[ENC-004] {prefix}.intermediate.dense.weight: shape mismatch — \
199                 got {} elements, expected {expected_up}",
200                w_up.len()
201            );
202            return None;
203        }
204        if w_down.len() != expected_down {
205            eprintln!(
206                "[ENC-004] {prefix}.output.dense.weight: shape mismatch — \
207                 got {} elements, expected {expected_down}",
208                w_down.len()
209            );
210            return None;
211        }
212
213        Some(Self { config: config.clone(), w_up, b_up, w_down, b_down })
214    }
215
216    /// Forward pass: FFN(x) = W_down * GELU(W_up * x + b_up) + b_down
217    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
218        let h = self.config.hidden_size;
219        let inter = self.config.intermediate_size;
220
221        // Up projection — HF weights [inter, h] (ENT-269)
222        let up = matmul_nt(x, &self.w_up, seq_len, h, inter);
223        let up_data = up.data();
224        let up_slice = up_data.as_slice().expect("contiguous");
225        let b_up_slice = self.b_up.data().as_slice().expect("contiguous");
226
227        // GELU(W_up * x + b_up)
228        let activated: Vec<f32> =
229            (0..seq_len * inter).map(|i| gelu(up_slice[i] + b_up_slice[i % inter])).collect();
230        let activated_t = Tensor::from_vec(activated, true);
231
232        // Down projection — HF weights [h, inter] (ENT-269)
233        let down = matmul_nt(&activated_t, &self.w_down, seq_len, inter, h);
234        let down_data = down.data();
235        let down_slice = down_data.as_slice().expect("contiguous");
236        let b_down_slice = self.b_down.data().as_slice().expect("contiguous");
237
238        let output: Vec<f32> =
239            (0..seq_len * h).map(|i| down_slice[i] + b_down_slice[i % h]).collect();
240        Tensor::from_vec(output, true)
241    }
242
243    /// Get all parameters
244    pub fn parameters(&self) -> Vec<&Tensor> {
245        vec![&self.w_up, &self.b_up, &self.w_down, &self.b_down]
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feed_forward_tiny() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_feed_forward_parameters() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let params = ffn.parameters();
        assert_eq!(params.len(), 3); // w_gate, w_up, w_down
    }

    #[test]
    fn test_ffn_longer_sequence() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 8 * config.hidden_size], true);
        let output = ffn.forward(&x, 8);
        assert_eq!(output.len(), 8 * config.hidden_size);
    }

    #[test]
    fn test_ffn_weight_sizes() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(ffn.w_gate.len(), config.hidden_size * config.intermediate_size);
        assert_eq!(ffn.w_up.len(), config.hidden_size * config.intermediate_size);
        assert_eq!(ffn.w_down.len(), config.intermediate_size * config.hidden_size);
    }

    #[test]
    fn test_feed_forward_from_params_success() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "ffn.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );

        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(ffn.is_some());
        let ffn = ffn.expect("operation should succeed");
        assert_eq!(ffn.w_gate.len(), hidden_size * intermediate_size);
    }

    #[test]
    fn test_feed_forward_from_params_missing_key() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "ffn.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        // Missing up_proj, down_proj

        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(ffn.is_none());
    }

    // =========================================================================
    // ENC-004: EncoderFeedForward (GELU) tests
    // =========================================================================

    #[test]
    fn enc_004_gelu_approximation() {
        // GELU(0) = 0
        assert!((gelu(0.0)).abs() < 1e-6);
        // GELU is approximately identity for large positive x
        assert!((gelu(3.0) - 3.0).abs() < 0.01);
        // GELU is approximately 0 for large negative x
        assert!(gelu(-3.0).abs() < 0.01);
        // GELU(1) ≈ 0.8412
        assert!((gelu(1.0) - 0.8412).abs() < 0.01);
    }

    #[test]
    fn enc_004_encoder_ffn_output_shape() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 4 * config.hidden_size], true);
        let output = ffn.forward(&x, 4);
        assert_eq!(output.len(), 4 * config.hidden_size);
    }

    #[test]
    fn enc_004_encoder_ffn_has_4_params() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        assert_eq!(ffn.parameters().len(), 4); // w_up, b_up, w_down, b_down
    }

    #[test]
    fn enc_004_encoder_ffn_output_finite() {
        let config = TransformerConfig::tiny();
        let ffn = EncoderFeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert!(output.data().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn enc_004_encoder_ffn_from_params() {
        let config = TransformerConfig::tiny();
        let h = config.hidden_size;
        let inter = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "layer.intermediate.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; h * inter], true),
        );
        params.insert(
            "layer.intermediate.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; inter], true),
        );
        params.insert(
            "layer.output.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; inter * h], true),
        );
        params.insert("layer.output.dense.bias".to_string(), Tensor::from_vec(vec![0.0; h], true));

        let ffn = EncoderFeedForward::from_params(&config, &params, "layer");
        assert!(ffn.is_some());
    }

    #[test]
    fn enc_004_encoder_ffn_from_params_rejects_wrong_shape() {
        let config = TransformerConfig::tiny();
        let mut params = HashMap::new();
        params.insert(
            "layer.intermediate.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; 42], true), // wrong size
        );
        params.insert(
            "layer.intermediate.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; config.intermediate_size], true),
        );
        params.insert(
            "layer.output.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; config.intermediate_size * config.hidden_size], true),
        );
        params.insert(
            "layer.output.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; config.hidden_size], true),
        );

        let ffn = EncoderFeedForward::from_params(&config, &params, "layer");
        assert!(ffn.is_none());
    }

    /// ENC-004: the output (down) projection is validated as well as the up projection.
    #[test]
    fn enc_004_encoder_ffn_from_params_rejects_wrong_down_shape() {
        let config = TransformerConfig::tiny();
        let h = config.hidden_size;
        let inter = config.intermediate_size;

        let mut params = HashMap::new();
        params.insert(
            "layer.intermediate.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; h * inter], true),
        );
        params.insert(
            "layer.intermediate.dense.bias".to_string(),
            Tensor::from_vec(vec![0.0; inter], true),
        );
        params.insert(
            "layer.output.dense.weight".to_string(),
            Tensor::from_vec(vec![0.1; inter * h + 1], true), // one element too many
        );
        params.insert("layer.output.dense.bias".to_string(), Tensor::from_vec(vec![0.0; h], true));

        let ffn = EncoderFeedForward::from_params(&config, &params, "layer");
        assert!(ffn.is_none(), "ENC-004: from_params MUST reject wrong-shape output.dense.weight");
    }

    // =========================================================================
    // FALSIFY-F: §2.1.4 FFN Projections — Five-Whys Gap Analysis (Refs PMAT-333)
    //
    // Contract: tensor-layout-v1.yaml §tensors.gate_proj/up_proj/down_proj
    //   gate_proj: [intermediate, hidden], up_proj: [intermediate, hidden]
    //   down_proj: [hidden, intermediate]
    //   SwiGLU: FFN(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x))
    //
    // Five-Whys (root cause of the original PMAT-333 gap; from_params now
    // validates element counts, and the tests below guard that fix):
    //   Why 1: from_params accepted ANY tensor shape without validation
    //   Why 2: FeedForward stores raw Tensor, no ValidatedWeight wrapper
    //   Why 3: entrenar uses flattened 1D Tensors — no shape metadata
    //   Why 4: Shape errors only manifest at matmul time (runtime panic)
    //   Why 5: No constructor-time shape check existed (PMAT-333 gap)
    //
    // Popper (1959): "These tests attempt to falsify the claim that
    // entrenar's FFN construction prevents shape-related runtime panics."
    // =========================================================================

    /// FALSIFY-F1e: from_params rejects wrong-shape gate_proj (PMAT-333 fix)
    ///
    /// gate_proj should be [hidden_size * intermediate_size] elements.
    /// from_params now validates shape against config dimensions.
    #[test]
    fn falsify_f1e_from_params_rejects_wrong_shape_gate() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();
        // WRONG: gate_proj has 42 elements instead of hidden*intermediate
        params.insert("ffn.gate_proj.weight".to_string(), Tensor::from_vec(vec![0.1; 42], true));
        params.insert(
            "ffn.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "ffn.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );

        // FIXED (PMAT-333): now rejected
        let ffn = FeedForward::from_params(&config, &params, "ffn");
        assert!(
            ffn.is_none(),
            "FALSIFY-F1e: PMAT-333 fix — from_params MUST reject wrong-shape gate_proj"
        );
    }

    /// FALSIFY-F1e (cont.): from_params rejects wrong-shape up_proj and down_proj
    ///
    /// PMAT-333 validates all three projections, not only gate_proj.
    #[test]
    fn falsify_f1e_from_params_rejects_wrong_shape_up_and_down() {
        let config = TransformerConfig::tiny();
        let full = config.hidden_size * config.intermediate_size;

        for (up_len, down_len) in [(full + 1, full), (full, full + 1)] {
            let mut params = HashMap::new();
            params
                .insert("ffn.gate_proj.weight".to_string(), Tensor::from_vec(vec![0.1; full], true));
            params
                .insert("ffn.up_proj.weight".to_string(), Tensor::from_vec(vec![0.1; up_len], true));
            params.insert(
                "ffn.down_proj.weight".to_string(),
                Tensor::from_vec(vec![0.1; down_len], true),
            );

            let ffn = FeedForward::from_params(&config, &params, "ffn");
            assert!(
                ffn.is_none(),
                "FALSIFY-F1e: from_params MUST reject up_proj={up_len}, down_proj={down_len} \
                 (expected {full} each)"
            );
        }
    }

    /// FALSIFY-F2e: SwiGLU forward produces correct output dimensions
    ///
    /// For correct weights, output.len() == seq_len * hidden_size.
    #[test]
    fn falsify_f2e_swiglu_forward_correct_dims() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let seq_len = 4;
        let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
        let output = ffn.forward(&x, seq_len);
        assert_eq!(
            output.len(),
            seq_len * config.hidden_size,
            "FALSIFY-F2e: FFN output must be seq_len * hidden_size"
        );
    }

    /// FALSIFY-F3e: FFN output is finite for valid inputs
    ///
    /// SwiGLU with bounded inputs must produce finite outputs.
    /// If gate/up/down weights contain NaN/Inf, output would be NaN.
    #[test]
    fn falsify_f3e_ffn_output_finite() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let output = ffn.forward(&x, 2);
        assert!(
            output.data().iter().all(|v| v.is_finite()),
            "FALSIFY-F3e: FFN output must be finite for bounded inputs"
        );
    }

    /// FALSIFY-F4e: gate_proj and up_proj share dimensions
    ///
    /// SwiGLU requires SiLU(gate(x)) * up(x) — element-wise multiply.
    /// gate_proj and up_proj must produce identically-sized outputs.
    #[test]
    fn falsify_f4e_gate_up_shape_parity() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(
            ffn.w_gate.len(),
            ffn.w_up.len(),
            "FALSIFY-F4e: gate_proj and up_proj must have identical size for SwiGLU multiply"
        );
    }

    /// FALSIFY-F5e: down_proj dimensions reversed from gate/up
    ///
    /// gate/up: [intermediate, hidden] (HF layout, ENT-269)
    /// down: [hidden, intermediate] (reversed)
    /// Total elements should still be hidden * intermediate.
    #[test]
    fn falsify_f5e_down_proj_reversed_same_total() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        assert_eq!(
            ffn.w_gate.len(),
            ffn.w_down.len(),
            "FALSIFY-F5e: gate and down must have same total elements (H*I)"
        );
        assert_eq!(
            ffn.w_down.len(),
            config.hidden_size * config.intermediate_size,
            "FALSIFY-F5e: down_proj must have hidden*intermediate elements"
        );
    }

    #[test]
    fn test_ffn_backward_gradient_exists() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        // Backward pass
        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        // All FFN weights should have gradients
        assert!(ffn.w_gate.grad().is_some());
        assert!(ffn.w_up.grad().is_some());
        assert!(ffn.w_down.grad().is_some());
    }

    #[test]
    fn test_ffn_backward_gradients_finite() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        // All gradients should be finite
        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        let grad_up = ffn.w_up.grad().expect("gradient should be available");
        let grad_down = ffn.w_down.grad().expect("gradient should be available");

        assert!(grad_gate.iter().all(|&v| v.is_finite()));
        assert!(grad_up.iter().all(|&v| v.is_finite()));
        assert!(grad_down.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_ffn_backward_swiglu_activation() {
        // Test that SwiGLU activation in FFN has proper gradients
        let config = TransformerConfig::tiny();

        // Test with various input magnitudes
        for scale in [0.1, 1.0, 2.0] {
            let ffn = FeedForward::new(&config);
            let x = Tensor::from_vec(
                (0..2 * config.hidden_size).map(|i| (i as f32 * 0.01).sin() * scale).collect(),
                true,
            );
            let mut output = ffn.forward(&x, 2);

            let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
            crate::autograd::backward(&mut output, Some(grad_out));

            let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
            assert!(
                grad_gate.iter().all(|&v| v.is_finite()),
                "Gradients not finite for scale {scale}"
            );
        }
    }

    #[test]
    fn test_ffn_backward_gradient_nonzero() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.5; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        // Gradients should not be all zero
        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        let sum: f32 = grad_gate.iter().map(|v| v.abs()).sum();
        assert!(sum > 0.0, "FFN gate gradients should not be all zero");
    }

    #[test]
    fn test_ffn_backward_different_seq_lengths() {
        let config = TransformerConfig::tiny();

        for seq_len in [1, 2, 4, 8] {
            let ffn = FeedForward::new(&config);
            let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
            let mut output = ffn.forward(&x, seq_len);

            let grad_out = ndarray::Array1::ones(seq_len * config.hidden_size);
            crate::autograd::backward(&mut output, Some(grad_out));

            let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
            assert!(
                grad_gate.iter().all(|&v| v.is_finite()),
                "Non-finite gradient for seq_len {seq_len}"
            );
        }
    }

    #[test]
    fn test_ffn_backward_gradient_accumulation() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);

        // First forward-backward
        let x1 = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let mut output1 = ffn.forward(&x1, 2);
        let grad_out1 = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output1, Some(grad_out1));
        let grad1 = ffn.w_gate.grad().expect("gradient should be available").to_vec();

        // Second forward-backward should accumulate
        let x2 = Tensor::from_vec(vec![0.2; 2 * config.hidden_size], true);
        let mut output2 = ffn.forward(&x2, 2);
        let grad_out2 = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output2, Some(grad_out2));
        let grad2 = ffn.w_gate.grad().expect("gradient should be available").to_vec();

        // NOTE: this only detects that the stored gradient changed after the
        // second pass; it cannot distinguish true accumulation from overwrite.
        assert!(
            grad2.iter().zip(grad1.iter()).any(|(g2, g1)| g2.abs() != g1.abs()),
            "Gradients should accumulate across backward passes"
        );
    }

    #[test]
    fn test_ffn_backward_with_zero_input() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![0.0; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        // Should still produce finite gradients
        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        assert!(grad_gate.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_ffn_backward_large_input() {
        let config = TransformerConfig::tiny();
        let ffn = FeedForward::new(&config);
        let x = Tensor::from_vec(vec![10.0; 2 * config.hidden_size], true);
        let mut output = ffn.forward(&x, 2);

        let grad_out = ndarray::Array1::ones(2 * config.hidden_size);
        crate::autograd::backward(&mut output, Some(grad_out));

        // Should still produce finite gradients
        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
        assert!(grad_gate.iter().all(|&v| v.is_finite()));
    }
}
689        // Should still produce finite gradients
690        let grad_gate = ffn.w_gate.grad().expect("gradient should be available");
691        assert!(grad_gate.iter().all(|&v| v.is_finite()));
692    }
693}