Skip to main content

ferrotorch_nn/
transformer.rs

1//! LLM-critical transformer building blocks.
2//!
3//! This module provides the components needed to build modern large language
4//! models:
5//!
6//! - [`RotaryPositionEmbedding`] — Rotary Position Embeddings (RoPE) as
7//!   described in Su et al. (2021). Precomputes sin/cos tables and applies
8//!   pairwise rotation to queries and keys.
9//!
10//! - [`SwiGLU`] — The gated linear unit with SiLU activation used in
11//!   LLaMA, Mistral, and other modern architectures: `w3(silu(w1(x)) * w2(x))`.
12//!
13//! - [`KVCache`] — Key-value cache for efficient autoregressive inference,
14//!   concatenating new key/value pairs with previously cached ones.
15//!
16//! - [`TransformerEncoderLayer`] — A pre-norm encoder block:
17//!   `norm -> self-attn -> residual -> norm -> ffn -> residual`.
18//!
19//! - [`TransformerDecoderLayer`] — A pre-norm decoder block with
20//!   self-attention, cross-attention, and feedforward sub-layers.
21//!
22//! ## REQ status (per `.design/ferrotorch-nn/transformer.md`)
23//!
24//! | REQ | Status | Evidence |
25//! |---|---|---|
26//! | REQ-1 | SHIPPED | the `RoPEConvention` enum here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:248` + `ferrotorch-llama/src/attention.rs:23` |
27//! | REQ-2 | SHIPPED | the `RotaryPositionEmbedding<T>` struct with `apply_rope` here; non-test consumer: re-export at `lib.rs:248` + `ferrotorch-llama/src/attention.rs:23` |
28//! | REQ-3 | SHIPPED | the `RoPEScaling` enum here; non-test consumer: re-export at `lib.rs:248` + `ferrotorch-llama/src/attention.rs:23` (Llama-3 long-context scaling) |
29//! | REQ-4 | SHIPPED | the `SwiGLU<T>` struct + `impl Module<T> for SwiGLU<T>` here; non-test consumer: re-export at `lib.rs:248` |
30//! | REQ-5 | SHIPPED | the `KVCache<T>` struct with `append_kv` / `clear` here; non-test consumer: re-export at `lib.rs:248` |
31//! | REQ-6 | SHIPPED | the `TransformerEncoderLayer<T>` struct here mirroring upstream `transformer.py:659-980`; non-test consumer: re-export at `lib.rs:248` |
32//! | REQ-7 | SHIPPED | the `TransformerDecoderLayer<T>` struct here mirroring upstream `transformer.py:981-1100`; non-test consumer: re-export at `lib.rs:248` |
33//! | REQ-8 | SHIPPED | the `TransformerEncoder<T>` struct here mirroring upstream `transformer.py:318-553`; non-test consumer: re-export at `lib.rs:248` |
34//! | REQ-9 | SHIPPED | the `TransformerDecoder<T>` struct here mirroring upstream `transformer.py:554-658`; non-test consumer: re-export at `lib.rs:248` |
35//! | REQ-10 | SHIPPED | the `Transformer<T>` struct here mirroring upstream `transformer.py:58-317`; non-test consumer: re-export at `lib.rs:248` |
36//! | REQ-11 | SHIPPED | the `impl<T: Float> Module<T>` blocks for every transformer struct here; non-test consumer: re-export at `lib.rs:248` |
37//! | REQ-12 | SHIPPED | the `RoPEBackward<T>` struct + `impl GradFn<T>` here; non-test consumer: re-export at `lib.rs:248` |
38
39use std::sync::Arc;
40
41use ferrotorch_core::autograd::no_grad::is_grad_enabled;
42use ferrotorch_core::grad_fns::activation::silu;
43use ferrotorch_core::grad_fns::arithmetic::{add, mul};
44use ferrotorch_core::grad_fns::shape::reshape;
45use ferrotorch_core::tensor::GradFn;
46use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
47
48use crate::attention::MultiheadAttention;
49use crate::dropout::Dropout;
50use crate::linear::Linear;
51use crate::module::Module;
52use crate::norm::LayerNorm;
53use crate::parameter::Parameter;
54
55// ===========================================================================
56// RotaryPositionEmbedding (RoPE)
57// ===========================================================================
58
/// How [`RotaryPositionEmbedding`] groups last-axis elements into the 2-D
/// pairs that get rotated.
///
/// - **`Interleaved`** (the default) rotates neighbouring elements
///   `(x[2i], x[2i+1])`, as in the original RoFormer paper.
/// - **`HalfRotation`** rotates `(x[i], x[i + d/2])`, matching LLaMA,
///   GPT-NeoX, and Pythia.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RoPEConvention {
    /// Neighbouring pairs `(x[2i], x[2i+1])` — RoFormer.
    #[default]
    Interleaved,
    /// Split-half pairs `(x[i], x[i+d/2])` — LLaMA / GPT-NeoX.
    HalfRotation,
}

impl std::fmt::Display for RoPEConvention {
    /// Writes the snake_case name of the convention
    /// (`"interleaved"` or `"half_rotation"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Interleaved => "interleaved",
            Self::HalfRotation => "half_rotation",
        };
        f.write_str(label)
    }
}
82
83// ---------------------------------------------------------------------------
84// RoPEBackward
85// ---------------------------------------------------------------------------
86
87/// Backward node for RoPE.
88///
89/// RoPE applies a linear rotation per position, so the backward pass
90/// applies the *inverse* rotation (transpose of the rotation matrix,
91/// which is just cos / -sin swap).
92#[derive(Debug)]
93struct RoPEBackward<T: Float> {
94    input: Tensor<T>,
95    cos_flat: Vec<T>,
96    sin_flat: Vec<T>,
97    half_dim: usize,
98    seq_len: usize,
99    batch_dims: usize,
100    dim: usize,
101    seq_offset: usize,
102    convention: RoPEConvention,
103}
104
105impl<T: Float> GradFn<T> for RoPEBackward<T> {
106    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
107        let da = if self.input.requires_grad() {
108            let go_data = grad_output.data_vec()?;
109            let total = go_data.len();
110            let mut grad_input = Vec::with_capacity(total);
111
112            match self.convention {
113                RoPEConvention::Interleaved => {
114                    for b in 0..self.batch_dims {
115                        for s in 0..self.seq_len {
116                            let cache_start = (self.seq_offset + s) * self.half_dim;
117                            let go_start = b * self.seq_len * self.dim + s * self.dim;
118
119                            for i in 0..self.half_dim {
120                                let go_even = go_data[go_start + 2 * i];
121                                let go_odd = go_data[go_start + 2 * i + 1];
122                                let cos_val = self.cos_flat[cache_start + i];
123                                let sin_val = self.sin_flat[cache_start + i];
124
125                                // Inverse rotation: R^T * grad
126                                grad_input.push(go_even * cos_val + go_odd * sin_val);
127                                grad_input.push(-go_even * sin_val + go_odd * cos_val);
128                            }
129                        }
130                    }
131                }
132                RoPEConvention::HalfRotation => {
133                    for b in 0..self.batch_dims {
134                        for s in 0..self.seq_len {
135                            let cache_start = (self.seq_offset + s) * self.half_dim;
136                            let go_start = b * self.seq_len * self.dim + s * self.dim;
137
138                            // First half of grad_input: dx[i] = go_first[i]*cos + go_second[i]*sin
139                            for i in 0..self.half_dim {
140                                let go_first = go_data[go_start + i];
141                                let go_second = go_data[go_start + self.half_dim + i];
142                                let cos_val = self.cos_flat[cache_start + i];
143                                let sin_val = self.sin_flat[cache_start + i];
144
145                                grad_input.push(go_first * cos_val + go_second * sin_val);
146                            }
147                            // Second half: dx[i+d/2] = -go_first[i]*sin + go_second[i]*cos
148                            for i in 0..self.half_dim {
149                                let go_first = go_data[go_start + i];
150                                let go_second = go_data[go_start + self.half_dim + i];
151                                let cos_val = self.cos_flat[cache_start + i];
152                                let sin_val = self.sin_flat[cache_start + i];
153
154                                grad_input.push(-go_first * sin_val + go_second * cos_val);
155                            }
156                        }
157                    }
158                }
159            }
160
161            let g = Tensor::from_storage(
162                TensorStorage::cpu(grad_input),
163                self.input.shape().to_vec(),
164                false,
165            )?;
166            Some(if self.input.is_cuda() {
167                g.to(self.input.device())?
168            } else {
169                g
170            })
171        } else {
172            None
173        };
174        Ok(vec![da])
175    }
176
177    fn inputs(&self) -> Vec<&Tensor<T>> {
178        vec![&self.input]
179    }
180
181    fn name(&self) -> &'static str {
182        "RoPEBackward"
183    }
184}
185
186/// Rotary Position Embeddings (RoPE).
187///
188/// Precomputes sin/cos frequency tables up to `max_seq_len` and applies
189/// pairwise rotation to an input tensor, encoding absolute position
190/// information that degrades gracefully with relative distance.
191///
192/// Two element-pairing conventions are supported (see [`RoPEConvention`]):
193///
194/// - **Interleaved** (default): pairs `(x[2i], x[2i+1])` — RoFormer.
195/// - **HalfRotation**: pairs `(x[i], x[i+d/2])` — LLaMA, GPT-NeoX, Pythia.
196///
197/// RoPE is **not** a [`Module`] — it is a stateless utility applied inside
198/// attention layers before the dot product.
199///
200/// # Shape contract
201///
202/// - Input: `[..., seq_len, dim]` where `dim` is even.
203/// - Output: same shape as input.
204///
205/// # Reference
206///
207/// Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021).
208#[derive(Debug)]
209pub struct RotaryPositionEmbedding<T: Float> {
210    dim: usize,
211    max_seq_len: usize,
212    base: f64,
213    convention: RoPEConvention,
214    scaling: RoPEScaling,
215    /// Precomputed cosines: `[max_seq_len, dim/2]`.
216    cos_cache: Tensor<T>,
217    /// Precomputed sines: `[max_seq_len, dim/2]`.
218    sin_cache: Tensor<T>,
219}
220
/// Frequency-scaling strategy for [`RotaryPositionEmbedding`].
///
/// Determines how `inv_freq[i] = 1 / base^(2i / dim)` is modified to
/// support context lengths beyond the model's training distribution.
/// Default is [`RoPEScaling::None`] (classical RoPE with no modification).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RoPEScaling {
    /// No scaling. `inv_freq[i] = 1 / base^(2i / dim)`. Matches the
    /// original RoFormer formulation and the Llama 3 8B base model
    /// (8k context, no stretching).
    #[default]
    None,

    /// Linear / positional interpolation (Chen et al. 2023,
    /// kaiokendev's "linear RoPE scaling"). Divides every frequency
    /// uniformly by `factor`, so position `p` rotates by the same angle
    /// as position `p / factor` under the unscaled schedule. Extends
    /// context by `factor` at the cost of short-text quality.
    Linear {
        /// Target-context / original-context ratio. >= 1.
        factor: f64,
    },

    /// NTK-aware scaling (bloc97's "NTK-Aware Scaled RoPE"). Scales the
    /// base itself via `base_new = base * factor^(dim / (dim - 2))`.
    /// This leaves the highest frequency (`i = 0`) unchanged and divides
    /// the lowest one (`i = dim/2 - 1`) by exactly `factor`. As a result,
    /// short-range token ordering is preserved while long-range position
    /// information is stretched.
    NtkAware {
        /// Target-context / original-context ratio. >= 1.
        factor: f64,
        /// Original maximum context length the model was trained on.
        /// Reserved for future use — the current NTK formulation only
        /// depends on `factor` and `dim`. Held on the variant so the
        /// config the model ships with is fully captured.
        original_max_pos_embeddings: usize,
    },

    /// YARN scaling (Peng et al. 2023, "YaRN: Efficient Context Window
    /// Extension of Large Language Models").
    ///
    /// The frequencies are treated per dimension, based on how many full
    /// rotations each completes over `original_max_pos_embeddings`
    /// positions:
    ///
    /// - more than roughly `beta_fast` rotations: frequency kept
    ///   (extrapolation);
    /// - fewer than roughly `beta_slow` rotations: frequency divided by
    ///   `factor` (linear interpolation);
    /// - in between: a linear ramp blends the two.
    ///
    /// Only the inverse frequencies are modified here. The reference YaRN
    /// implementations additionally scale cos/sin by an attention
    /// temperature (`0.1 * ln(factor) + 1`). That factor is not folded into
    /// the frequencies computed here; apply it separately if the model
    /// requires it.
    Yarn {
        /// Target-context / original-context ratio. >= 1.
        factor: f64,
        /// Original maximum context length the model was trained on.
        original_max_pos_embeddings: usize,
        /// Number of rotations above which a dimension extrapolates
        /// (no interpolation). Default in the paper is 32.
        beta_fast: f64,
        /// Number of rotations below which a dimension fully
        /// interpolates (linear PI). Default in the paper is 1.
        beta_slow: f64,
    },
}

impl RoPEScaling {
    /// Convenience: YARN with the paper's default `beta_fast = 32`, `beta_slow = 1`.
    pub const fn yarn_default(factor: f64, original_max_pos_embeddings: usize) -> Self {
        RoPEScaling::Yarn {
            factor,
            original_max_pos_embeddings,
            beta_fast: 32.0,
            beta_slow: 1.0,
        }
    }
}

/// YARN helper: solve for the (fractional) pair index `i` at which the
/// unscaled schedule `inv_freq[i] = base^(-2i/dim)` completes exactly
/// `num_rotations` full turns over `original_max_pos_embeddings` positions.
///
/// The result is already in the `inv_freq` index space (`0..dim/2`). It may
/// fall outside that range; [`yarn_find_correction_range`] clamps it.
fn yarn_find_correction_dim(
    num_rotations: f64,
    dim: usize,
    base: f64,
    original_max_pos_embeddings: usize,
) -> f64 {
    // rotations(i) = L * inv_freq[i] / (2*pi) = L / (2*pi * base^(2i/dim)).
    // Setting rotations(i) = r and solving for i:
    //   i = dim * ln(L / (2*pi*r)) / (2 * ln(base)).
    let ratio =
        original_max_pos_embeddings as f64 / (num_rotations * 2.0 * std::f64::consts::PI);
    dim as f64 * ratio.ln() / (2.0 * base.ln())
}

/// YARN helper: the `(low, high)` pair-index range over which the
/// extrapolation/interpolation ramp runs.
///
/// `low` is derived from `low_rot` (the larger rotation count, i.e.
/// `beta_fast`) and floored. `high` is derived from `high_rot` (`beta_slow`)
/// and ceiled. Both are clamped to `[0, dim - 1]`, as in the reference
/// implementation.
fn yarn_find_correction_range(
    low_rot: f64,
    high_rot: f64,
    dim: usize,
    base: f64,
    original_max_pos_embeddings: usize,
) -> (f64, f64) {
    let low = yarn_find_correction_dim(low_rot, dim, base, original_max_pos_embeddings).floor();
    let high = yarn_find_correction_dim(high_rot, dim, base, original_max_pos_embeddings).ceil();
    // `saturating_sub` keeps a degenerate `dim == 0` from underflowing.
    (low.max(0.0), high.min(dim.saturating_sub(1) as f64))
}

/// Compute `inv_freq[i] = 1 / base^(2i / dim)` for `i` in `0..dim/2`.
fn compute_base_inv_freq(dim: usize, base: f64) -> Vec<f64> {
    let half = dim / 2;
    (0..half)
        .map(|i| 1.0 / base.powf(2.0 * i as f64 / dim as f64))
        .collect()
}

/// Compute the per-pair `inv_freq` vector (length `dim / 2`) under the given
/// scaling policy.
///
/// Exposed at `pub(crate)` so unit tests can verify the math directly
/// without round-tripping through the precomputed cos/sin caches.
pub(crate) fn compute_scaled_inv_freq(dim: usize, base: f64, scaling: RoPEScaling) -> Vec<f64> {
    match scaling {
        RoPEScaling::None => compute_base_inv_freq(dim, base),

        RoPEScaling::Linear { factor } => {
            let mut iv = compute_base_inv_freq(dim, base);
            for v in iv.iter_mut() {
                *v /= factor;
            }
            iv
        }

        RoPEScaling::NtkAware { factor, .. } => {
            // NTK-Aware: base' = base * factor^(dim / (dim - 2)).
            // The exponent is chosen so the *highest* frequency (i = 0) is
            // unchanged while the *lowest* frequency (i = dim/2 - 1) is
            // scaled by exactly 1/factor, matching linear PI at the long end.
            let exp = dim as f64 / (dim as f64 - 2.0);
            let base_scaled = base * factor.powf(exp);
            compute_base_inv_freq(dim, base_scaled)
        }

        RoPEScaling::Yarn {
            factor,
            original_max_pos_embeddings,
            beta_fast,
            beta_slow,
        } => {
            let half = dim / 2;
            // `yarn_find_correction_dim` already solves for the pair index of
            // `inv_freq`, so the range is used as-is over `0..dim/2`, exactly
            // as the reference implementations do. (Halving it again would
            // push the ramp toward the high-frequency end and interpolate
            // dimensions that should extrapolate.)
            let (low, high) = yarn_find_correction_range(
                beta_fast,
                beta_slow,
                dim,
                base,
                original_max_pos_embeddings,
            );
            // Same degenerate-range nudge as the reference `linear_ramp_factor`.
            let denom = if high == low { 0.001 } else { high - low };

            (0..half)
                .map(|i| {
                    let pos_freq = base.powf(2.0 * i as f64 / dim as f64);
                    let extrapolation = 1.0 / pos_freq;
                    let interpolation = 1.0 / (factor * pos_freq);
                    // ramp is 0 at i <= low and 1 at i >= high; the
                    // extrapolation weight is its complement, so
                    // high-frequency pairs (small i) keep their frequency.
                    let ramp = ((i as f64 - low) / denom).clamp(0.0, 1.0);
                    let extrapolation_weight = 1.0 - ramp;
                    interpolation * (1.0 - extrapolation_weight)
                        + extrapolation * extrapolation_weight
                })
                .collect()
        }
    }
}
395
396impl<T: Float> RotaryPositionEmbedding<T> {
397    /// Create a new RoPE instance with the default interleaved convention
398    /// and no frequency scaling.
399    ///
400    /// # Arguments
401    ///
402    /// - `dim` - The embedding dimension (must be even).
403    /// - `max_seq_len` - Maximum sequence length to precompute.
404    /// - `base` - Base for the frequency computation (default: 10 000.0).
405    ///
406    /// # Errors
407    ///
408    /// Returns an error if `dim` is odd or zero, or if `max_seq_len` is zero.
409    pub fn new(dim: usize, max_seq_len: usize, base: f64) -> FerrotorchResult<Self> {
410        Self::with_scaling(
411            dim,
412            max_seq_len,
413            base,
414            RoPEConvention::default(),
415            RoPEScaling::None,
416        )
417    }
418
419    /// Create a new RoPE instance with a specified pairing convention
420    /// and no frequency scaling.
421    ///
422    /// Use [`RoPEConvention::HalfRotation`] for LLaMA, GPT-NeoX, and Pythia
423    /// compatibility.
424    pub fn with_convention(
425        dim: usize,
426        max_seq_len: usize,
427        base: f64,
428        convention: RoPEConvention,
429    ) -> FerrotorchResult<Self> {
430        Self::with_scaling(dim, max_seq_len, base, convention, RoPEScaling::None)
431    }
432
433    /// Create a new RoPE instance with explicit convention and scaling.
434    ///
435    /// `scaling` selects between no scaling, linear PI, NTK-aware, and
436    /// YARN extension strategies. See [`RoPEScaling`] for the per-variant
437    /// semantics; [`RoPEScaling::None`] reproduces classical RoPE.
438    pub fn with_scaling(
439        dim: usize,
440        max_seq_len: usize,
441        base: f64,
442        convention: RoPEConvention,
443        scaling: RoPEScaling,
444    ) -> FerrotorchResult<Self> {
445        if dim == 0 || dim % 2 != 0 {
446            return Err(FerrotorchError::InvalidArgument {
447                message: format!("RoPE dim must be even and positive, got {dim}"),
448            });
449        }
450        if max_seq_len == 0 {
451            return Err(FerrotorchError::InvalidArgument {
452                message: "RoPE max_seq_len must be positive".into(),
453            });
454        }
455        if let RoPEScaling::Linear { factor }
456        | RoPEScaling::NtkAware { factor, .. }
457        | RoPEScaling::Yarn { factor, .. } = scaling
458        {
459            if !(factor.is_finite() && factor > 0.0) {
460                return Err(FerrotorchError::InvalidArgument {
461                    message: format!("RoPE scaling factor must be finite and > 0, got {factor}"),
462                });
463            }
464        }
465
466        let half_dim = dim / 2;
467        let thetas = compute_scaled_inv_freq(dim, base, scaling);
468
469        // cos_cache[pos, i] = cos(pos * theta_i)
470        // sin_cache[pos, i] = sin(pos * theta_i)
471        let total = max_seq_len * half_dim;
472        let mut cos_data = Vec::with_capacity(total);
473        let mut sin_data = Vec::with_capacity(total);
474
475        for pos in 0..max_seq_len {
476            for &theta in &thetas {
477                let angle = pos as f64 * theta;
478                cos_data.push(T::from(angle.cos()).unwrap());
479                sin_data.push(T::from(angle.sin()).unwrap());
480            }
481        }
482
483        let cos_cache = Tensor::from_storage(
484            TensorStorage::cpu(cos_data),
485            vec![max_seq_len, half_dim],
486            false,
487        )?;
488        let sin_cache = Tensor::from_storage(
489            TensorStorage::cpu(sin_data),
490            vec![max_seq_len, half_dim],
491            false,
492        )?;
493
494        Ok(Self {
495            dim,
496            max_seq_len,
497            base,
498            convention,
499            scaling,
500            cos_cache,
501            sin_cache,
502        })
503    }
504
505    /// Apply rotary embeddings to `x` starting at position `seq_offset`.
506    ///
507    /// # Shape
508    ///
509    /// - `x`: any shape where the last dimension equals `dim` and the
510    ///   second-to-last dimension is the sequence length.
511    /// - Returns a tensor of the same shape.
512    ///
513    /// For a typical attention head input of shape `[batch, num_heads, seq_len, head_dim]`,
514    /// the rotation is applied over `(seq_len, head_dim)`.
515    ///
516    /// # Errors
517    ///
518    /// Returns an error if `seq_offset + seq_len > max_seq_len` or if the
519    /// last dimension of `x` does not equal `dim`.
520    pub fn apply(&self, x: &Tensor<T>, seq_offset: usize) -> FerrotorchResult<Tensor<T>> {
521        let shape = x.shape();
522        let ndim = shape.len();
523        if ndim < 2 {
524            return Err(FerrotorchError::InvalidArgument {
525                message: format!(
526                    "RoPE input must be at least 2-D, got {ndim}-D with shape {shape:?}"
527                ),
528            });
529        }
530
531        let last_dim = shape[ndim - 1];
532        if last_dim != self.dim {
533            return Err(FerrotorchError::ShapeMismatch {
534                message: format!("RoPE: last dim of input ({last_dim}) != dim ({})", self.dim),
535            });
536        }
537
538        let seq_len = shape[ndim - 2];
539        if seq_offset + seq_len > self.max_seq_len {
540            return Err(FerrotorchError::InvalidArgument {
541                message: format!(
542                    "RoPE: seq_offset ({seq_offset}) + seq_len ({seq_len}) > max_seq_len ({})",
543                    self.max_seq_len
544                ),
545            });
546        }
547
548        let device = x.device();
549        let half_dim = self.dim / 2;
550        let cos_data = self.cos_cache.data_vec()?;
551        let sin_data = self.sin_cache.data_vec()?;
552        let x_data = x.data_vec()?;
553
554        // Number of independent "instances" before the (seq, dim) axes.
555        let batch_dims: usize = shape[..ndim - 2].iter().product();
556
557        let total = x.numel();
558        let mut output = Vec::with_capacity(total);
559
560        match self.convention {
561            RoPEConvention::Interleaved => {
562                // Pair (x[2i], x[2i+1]) — original RoFormer convention.
563                for b in 0..batch_dims {
564                    for s in 0..seq_len {
565                        let pos = seq_offset + s;
566                        let cache_start = pos * half_dim;
567                        let x_start = b * seq_len * self.dim + s * self.dim;
568
569                        for i in 0..half_dim {
570                            let x_even = x_data[x_start + 2 * i];
571                            let x_odd = x_data[x_start + 2 * i + 1];
572                            let cos_val = cos_data[cache_start + i];
573                            let sin_val = sin_data[cache_start + i];
574
575                            output.push(x_even * cos_val - x_odd * sin_val);
576                            output.push(x_even * sin_val + x_odd * cos_val);
577                        }
578                    }
579                }
580            }
581            RoPEConvention::HalfRotation => {
582                // Pair (x[i], x[i + d/2]) — LLaMA/GPT-NeoX convention.
583                // Output layout: first half then second half (same as input).
584                for b in 0..batch_dims {
585                    for s in 0..seq_len {
586                        let pos = seq_offset + s;
587                        let cache_start = pos * half_dim;
588                        let x_start = b * seq_len * self.dim + s * self.dim;
589
590                        // First half: x_rot[i] = x[i] * cos - x[i + d/2] * sin
591                        for i in 0..half_dim {
592                            let x_first = x_data[x_start + i];
593                            let x_second = x_data[x_start + half_dim + i];
594                            let cos_val = cos_data[cache_start + i];
595                            let sin_val = sin_data[cache_start + i];
596
597                            output.push(x_first * cos_val - x_second * sin_val);
598                        }
599                        // Second half: x_rot[i + d/2] = x[i] * sin + x[i + d/2] * cos
600                        for i in 0..half_dim {
601                            let x_first = x_data[x_start + i];
602                            let x_second = x_data[x_start + half_dim + i];
603                            let cos_val = cos_data[cache_start + i];
604                            let sin_val = sin_data[cache_start + i];
605
606                            output.push(x_first * sin_val + x_second * cos_val);
607                        }
608                    }
609                }
610            }
611        }
612
613        let result = if is_grad_enabled() && x.requires_grad() {
614            Tensor::from_operation(
615                TensorStorage::cpu(output),
616                shape.to_vec(),
617                Arc::new(RoPEBackward {
618                    input: x.clone(),
619                    cos_flat: cos_data,
620                    sin_flat: sin_data,
621                    half_dim,
622                    seq_len,
623                    batch_dims,
624                    dim: self.dim,
625                    seq_offset,
626                    convention: self.convention,
627                }),
628            )?
629        } else {
630            Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?
631        };
632        if device.is_cuda() {
633            result.to(device)
634        } else {
635            Ok(result)
636        }
637    }
638
639    /// The embedding dimension.
640    #[inline]
641    pub fn dim(&self) -> usize {
642        self.dim
643    }
644
645    /// The maximum sequence length the cache supports.
646    #[inline]
647    pub fn max_seq_len(&self) -> usize {
648        self.max_seq_len
649    }
650
651    /// The frequency base.
652    #[inline]
653    pub fn base(&self) -> f64 {
654        self.base
655    }
656
657    /// The pairing convention.
658    #[inline]
659    pub fn convention(&self) -> RoPEConvention {
660        self.convention
661    }
662
663    /// The frequency-scaling strategy this RoPE instance was built with.
664    #[inline]
665    pub fn scaling(&self) -> RoPEScaling {
666        self.scaling
667    }
668}
669
670// ===========================================================================
671// SwiGLU
672// ===========================================================================
673
674/// Gated Linear Unit with SiLU activation (SwiGLU).
675///
676/// Applies the feedforward network used in LLaMA, Mistral, and other
677/// modern transformer architectures:
678///
679/// ```text
680/// SwiGLU(x) = w3(silu(w1(x)) * w2(x))
681/// ```
682///
683/// where `w1` is the gate projection, `w2` is the up projection, and
684/// `w3` is the down projection.
685///
686/// # Shape contract
687///
688/// - Input: `[batch, seq_len, in_features]` (3-D) or `[batch, in_features]` (2-D).
689/// - Output: same shape as input.
690///
691/// Internally, `w1` and `w2` project from `in_features` to `hidden_features`,
692/// and `w3` projects back from `hidden_features` to `in_features`.
693#[derive(Debug)]
694pub struct SwiGLU<T: Float> {
695    /// Gate projection: `[in_features] -> [hidden_features]`.
696    w1: Linear<T>,
697    /// Up projection: `[in_features] -> [hidden_features]`.
698    w2: Linear<T>,
699    /// Down projection: `[hidden_features] -> [in_features]`.
700    w3: Linear<T>,
701    training: bool,
702}
703
704impl<T: Float> SwiGLU<T> {
705    /// Create a new SwiGLU layer.
706    ///
707    /// # Arguments
708    ///
709    /// - `in_features` - Input (and output) dimension.
710    /// - `hidden_features` - Hidden dimension of the gate/up projections.
711    ///   A common choice is `(8/3) * in_features` rounded to a multiple of 256.
712    /// - `bias` - Whether to include bias in the linear layers.
713    pub fn new(in_features: usize, hidden_features: usize, bias: bool) -> FerrotorchResult<Self> {
714        let w1 = Linear::new(in_features, hidden_features, bias)?;
715        let w2 = Linear::new(in_features, hidden_features, bias)?;
716        let w3 = Linear::new(hidden_features, in_features, bias)?;
717
718        Ok(Self {
719            w1,
720            w2,
721            w3,
722            training: true,
723        })
724    }
725
726    /// Forward pass for 3-D input `[batch, seq_len, in_features]`.
727    ///
728    /// Internally reshapes to 2-D, applies the SwiGLU computation, then
729    /// reshapes back. Uses differentiable `reshape` to preserve autograd.
730    fn forward_3d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
731        let shape = input.shape();
732        let batch = shape[0];
733        let seq_len = shape[1];
734
735        // Flatten to [batch * seq_len, features] — differentiable reshape.
736        let flat = reshape(input, &[(batch * seq_len) as isize, -1])?;
737
738        let output_flat = self.forward_2d(&flat)?;
739
740        // Reshape back to [batch, seq_len, out_features] — differentiable.
741        let out_features = output_flat.shape()[1];
742        reshape(
743            &output_flat,
744            &[batch as isize, seq_len as isize, out_features as isize],
745        )
746    }
747
748    /// Forward pass for 2-D input `[batch, in_features]`.
749    fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
750        // gate = silu(w1(x))
751        let w1_out = self.w1.forward(input)?;
752        let gate = silu(&w1_out)?;
753
754        // up = w2(x)
755        let up = self.w2.forward(input)?;
756
757        // gated = gate * up (elementwise)
758        let gated = mul(&gate, &up)?;
759
760        // down = w3(gated)
761        self.w3.forward(&gated)
762    }
763}
764
765impl<T: Float> Module<T> for SwiGLU<T> {
766    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
767        match input.ndim() {
768            2 => self.forward_2d(input),
769            3 => self.forward_3d(input),
770            _ => Err(FerrotorchError::InvalidArgument {
771                message: format!(
772                    "SwiGLU expects 2-D or 3-D input, got {}-D with shape {:?}",
773                    input.ndim(),
774                    input.shape()
775                ),
776            }),
777        }
778    }
779
780    fn parameters(&self) -> Vec<&Parameter<T>> {
781        let mut params = self.w1.parameters();
782        params.extend(self.w2.parameters());
783        params.extend(self.w3.parameters());
784        params
785    }
786
787    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
788        let mut params = self.w1.parameters_mut();
789        params.extend(self.w2.parameters_mut());
790        params.extend(self.w3.parameters_mut());
791        params
792    }
793
794    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
795        let mut params = Vec::new();
796        for (name, param) in self.w1.named_parameters() {
797            params.push((format!("w1.{name}"), param));
798        }
799        for (name, param) in self.w2.named_parameters() {
800            params.push((format!("w2.{name}"), param));
801        }
802        for (name, param) in self.w3.named_parameters() {
803            params.push((format!("w3.{name}"), param));
804        }
805        params
806    }
807
808    fn train(&mut self) {
809        self.training = true;
810        self.w1.train();
811        self.w2.train();
812        self.w3.train();
813    }
814
815    fn eval(&mut self) {
816        self.training = false;
817        self.w1.eval();
818        self.w2.eval();
819        self.w3.eval();
820    }
821
822    fn is_training(&self) -> bool {
823        self.training
824    }
825}
826
827// ===========================================================================
828// KVCache
829// ===========================================================================
830
831/// Dimensions a [`KVCache`] is pinned to after its first update (or when
832/// pre-declared via [`KVCache::with_dims`]). Every subsequent update must
833/// match these exactly on `batch`, `num_kv_heads`, and `head_dim`.
834#[derive(Debug, Clone, Copy, PartialEq, Eq)]
835struct CacheDims {
836    batch: usize,
837    num_kv_heads: usize,
838    head_dim: usize,
839}
840
841/// Key-value cache for efficient autoregressive (incremental) inference.
842///
843/// During generation, previously computed keys and values are cached so that
844/// each new token only requires computing its own K and V, then concatenating
845/// with the cache before attention.
846///
847/// # Shape convention
848///
849/// Keys and values are stored as `[batch, num_kv_heads, seq_len, head_dim]`.
850/// The cache grows along the `seq_len` axis (dimension 2).
851///
852/// ## Grouped-Query Attention
853///
854/// For models with grouped-query attention (e.g. Llama 3 8B: 32 Q heads,
855/// 8 KV heads) the cache stores at KV-head granularity — `dim 1 = num_kv_heads`,
856/// not `num_q_heads`. This keeps the cache ~1/4 the size versus storing
857/// at Q-head granularity. `repeat_kv` happens at *read* time, inside the
858/// attention computation, not at cache time.
859///
860/// Pre-declare the expected shape with [`KVCache::with_dims`] to get
861/// strict validation from the first update, or use [`KVCache::new`] to
862/// let the dims be inferred from the first push (matching the pre-GQA
863/// behaviour).
864#[derive(Debug)]
865pub struct KVCache<T: Float> {
866    /// Cached keys: `[B, num_kv_heads, cached_seq, head_dim]`, or `None` if empty.
867    key_cache: Option<Tensor<T>>,
868    /// Cached values: same shape as `key_cache`.
869    value_cache: Option<Tensor<T>>,
870    /// Maximum sequence length the cache will hold.
871    max_seq_len: usize,
872    /// Pinned dimensions (batch, num_kv_heads, head_dim). `None` means the
873    /// cache hasn't been populated or pre-declared yet.
874    dims: Option<CacheDims>,
875}
876
877impl<T: Float> KVCache<T> {
878    /// Create an empty cache with dims inferred from the first update.
879    ///
880    /// `max_seq_len` is the upper bound on the total cached sequence length.
881    pub fn new(max_seq_len: usize) -> Self {
882        Self {
883            key_cache: None,
884            value_cache: None,
885            max_seq_len,
886            dims: None,
887        }
888    }
889
890    /// Create an empty cache with pre-declared dimensions.
891    ///
892    /// Every subsequent [`update`](Self::update) must supply tensors with
893    /// shape `[batch, num_kv_heads, _, head_dim]`. Mismatches fail on the
894    /// very first update rather than silently poisoning the cache.
895    ///
896    /// For Llama 3 8B: `with_dims(max_seq_len, 1, 8, 128)`.
897    pub fn with_dims(
898        max_seq_len: usize,
899        batch: usize,
900        num_kv_heads: usize,
901        head_dim: usize,
902    ) -> Self {
903        Self {
904            key_cache: None,
905            value_cache: None,
906            max_seq_len,
907            dims: Some(CacheDims {
908                batch,
909                num_kv_heads,
910                head_dim,
911            }),
912        }
913    }
914
915    /// Append new keys and values to the cache.
916    ///
917    /// Returns the **full** (concatenated) keys and values for use in the
918    /// current attention step.
919    ///
920    /// # Arguments
921    ///
922    /// - `key` - New key tensor: `[B, num_kv_heads, new_seq, head_dim]`.
923    /// - `value` - New value tensor: same shape as `key`.
924    ///
925    /// # Returns
926    ///
927    /// `(full_key, full_value)` with shape
928    /// `[B, num_kv_heads, cached_seq + new_seq, head_dim]`.
929    pub fn update(
930        &mut self,
931        key: Tensor<T>,
932        value: Tensor<T>,
933    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
934        if key.ndim() != 4 || value.ndim() != 4 {
935            return Err(FerrotorchError::InvalidArgument {
936                message: format!(
937                    "KVCache expects 4-D [B, kv_heads, seq, dim] tensors, \
938                     got key {:?}, value {:?}",
939                    key.shape(),
940                    value.shape()
941                ),
942            });
943        }
944
945        if key.shape() != value.shape() {
946            return Err(FerrotorchError::ShapeMismatch {
947                message: format!(
948                    "KVCache: key shape {:?} != value shape {:?}",
949                    key.shape(),
950                    value.shape()
951                ),
952            });
953        }
954
955        let ks = key.shape();
956        let incoming = CacheDims {
957            batch: ks[0],
958            num_kv_heads: ks[1],
959            head_dim: ks[3],
960        };
961
962        match &self.dims {
963            Some(expected) if expected != &incoming => {
964                return Err(FerrotorchError::ShapeMismatch {
965                    message: format!(
966                        "KVCache: update shape [B={}, kv_heads={}, _, dim={}] does not \
967                         match pinned dims [B={}, kv_heads={}, _, dim={}]",
968                        incoming.batch,
969                        incoming.num_kv_heads,
970                        incoming.head_dim,
971                        expected.batch,
972                        expected.num_kv_heads,
973                        expected.head_dim,
974                    ),
975                });
976            }
977            None => self.dims = Some(incoming),
978            _ => {}
979        }
980
981        let (full_key, full_value) = match (&self.key_cache, &self.value_cache) {
982            (Some(ck), Some(cv)) => {
983                let fk = concat_along_dim2(ck, &key)?;
984                let fv = concat_along_dim2(cv, &value)?;
985                (fk, fv)
986            }
987            _ => (key.clone(), value.clone()),
988        };
989
990        // Check that the total cached sequence length does not exceed the limit.
991        let total_seq = full_key.shape()[2];
992        if total_seq > self.max_seq_len {
993            return Err(FerrotorchError::InvalidArgument {
994                message: format!(
995                    "KVCache: total sequence length ({total_seq}) exceeds max_seq_len ({})",
996                    self.max_seq_len
997                ),
998            });
999        }
1000
1001        self.key_cache = Some(full_key.clone());
1002        self.value_cache = Some(full_value.clone());
1003
1004        Ok((full_key, full_value))
1005    }
1006
1007    /// Reset the cache, discarding all stored keys and values. The pinned
1008    /// dimensions (if any) are preserved — a cache created via
1009    /// [`with_dims`](Self::with_dims) still validates the next update
1010    /// against the original declaration.
1011    pub fn reset(&mut self) {
1012        self.key_cache = None;
1013        self.value_cache = None;
1014    }
1015
1016    /// The current cached sequence length (0 if empty).
1017    pub fn seq_len(&self) -> usize {
1018        self.key_cache.as_ref().map(|k| k.shape()[2]).unwrap_or(0)
1019    }
1020
1021    /// Whether the cache holds any keys/values.
1022    pub fn is_empty(&self) -> bool {
1023        self.key_cache.is_none()
1024    }
1025
1026    /// Maximum sequence length.
1027    #[inline]
1028    pub fn max_seq_len(&self) -> usize {
1029        self.max_seq_len
1030    }
1031
1032    /// Number of KV heads the cache is pinned to, once an update has
1033    /// happened (or pre-declared via [`with_dims`](Self::with_dims)).
1034    /// Returns `None` for a fresh [`new`](Self::new) cache that has not
1035    /// yet been updated.
1036    pub fn num_kv_heads(&self) -> Option<usize> {
1037        self.dims.map(|d| d.num_kv_heads)
1038    }
1039
1040    /// Head dimension (`head_dim`), once pinned. See [`num_kv_heads`].
1041    pub fn head_dim(&self) -> Option<usize> {
1042        self.dims.map(|d| d.head_dim)
1043    }
1044
1045    /// Batch size, once pinned. See [`num_kv_heads`].
1046    pub fn batch_size(&self) -> Option<usize> {
1047        self.dims.map(|d| d.batch)
1048    }
1049}
1050
1051/// Concatenate two 4-D tensors along dimension 2 (the sequence axis).
1052///
1053/// Shapes must match on dims 0, 1, 3.
1054fn concat_along_dim2<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1055    let sa = a.shape();
1056    let sb = b.shape();
1057
1058    if sa[0] != sb[0] || sa[1] != sb[1] || sa[3] != sb[3] {
1059        return Err(FerrotorchError::ShapeMismatch {
1060            message: format!(
1061                "concat_along_dim2: shapes {:?} and {:?} must match on dims 0, 1, 3",
1062                sa, sb
1063            ),
1064        });
1065    }
1066
1067    let device = a.device();
1068    let (batch, heads, seq_a, dim) = (sa[0], sa[1], sa[2], sa[3]);
1069    let seq_b = sb[2];
1070    let seq_out = seq_a + seq_b;
1071
1072    let a_data = a.data_vec()?;
1073    let b_data = b.data_vec()?;
1074
1075    let mut output = Vec::with_capacity(batch * heads * seq_out * dim);
1076
1077    for ba in 0..batch {
1078        for h in 0..heads {
1079            // Copy rows from a.
1080            let a_start = (ba * heads + h) * seq_a * dim;
1081            output.extend_from_slice(&a_data[a_start..a_start + seq_a * dim]);
1082            // Copy rows from b.
1083            let b_start = (ba * heads + h) * seq_b * dim;
1084            output.extend_from_slice(&b_data[b_start..b_start + seq_b * dim]);
1085        }
1086    }
1087
1088    let result = Tensor::from_storage(
1089        TensorStorage::cpu(output),
1090        vec![batch, heads, seq_out, dim],
1091        false,
1092    )?;
1093    if device.is_cuda() {
1094        result.to(device)
1095    } else {
1096        Ok(result)
1097    }
1098}
1099
1100// ===========================================================================
1101// TransformerEncoderLayer
1102// ===========================================================================
1103
1104/// A single pre-norm transformer encoder layer.
1105///
1106/// Applies the following computation:
1107///
1108/// ```text
1109/// x = x + dropout(self_attn(norm1(x)))
1110/// x = x + dropout(ffn(norm2(x)))
1111/// ```
1112///
1113/// This matches the pre-norm (Pre-LN) style used in GPT-2, LLaMA, and most
1114/// modern LLMs, which trains more stably than the original post-norm design.
1115///
1116/// # Shape contract
1117///
1118/// - Input: `[batch, seq_len, d_model]`
1119/// - Output: `[batch, seq_len, d_model]`
1120#[derive(Debug)]
1121pub struct TransformerEncoderLayer<T: Float> {
1122    self_attn: MultiheadAttention<T>,
1123    ffn: SwiGLU<T>,
1124    norm1: LayerNorm<T>,
1125    norm2: LayerNorm<T>,
1126    dropout: Dropout<T>,
1127    training: bool,
1128}
1129
1130impl<T: Float> TransformerEncoderLayer<T> {
1131    /// Create a new transformer encoder layer.
1132    ///
1133    /// # Arguments
1134    ///
1135    /// - `d_model` - The model dimension (embedding size).
1136    /// - `num_heads` - Number of attention heads.
1137    /// - `d_ff` - Hidden dimension of the SwiGLU feedforward network.
1138    /// - `dropout_p` - Dropout probability (applied after attention and FFN).
1139    /// - `layer_norm_eps` - Epsilon for layer normalization.
1140    /// - `bias` - Whether to use bias in attention and FFN projections.
1141    pub fn new(
1142        d_model: usize,
1143        num_heads: usize,
1144        d_ff: usize,
1145        dropout_p: f64,
1146        layer_norm_eps: f64,
1147        bias: bool,
1148    ) -> FerrotorchResult<Self> {
1149        let self_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1150        let ffn = SwiGLU::new(d_model, d_ff, bias)?;
1151        let norm1 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1152        let norm2 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1153        let dropout = Dropout::new(dropout_p)?;
1154
1155        Ok(Self {
1156            self_attn,
1157            ffn,
1158            norm1,
1159            norm2,
1160            dropout,
1161            training: true,
1162        })
1163    }
1164
1165    /// Override the post-attention / post-FFN dropout probability after
1166    /// construction. Forwards to [`Dropout::set_p`]; same validation.
1167    /// Used by MC-dropout inference (see ferrotorch-paged's
1168    /// `predict_all_layers_logits_mc`).
1169    pub fn set_dropout_p(&mut self, p: f64) -> FerrotorchResult<()> {
1170        self.dropout.set_p(p)
1171    }
1172
1173    /// Current dropout probability for this encoder layer.
1174    pub fn dropout_p(&self) -> f64 {
1175        self.dropout.p()
1176    }
1177}
1178
1179impl<T: Float> Module<T> for TransformerEncoderLayer<T> {
1180    /// Forward pass with pre-norm residual connections.
1181    ///
1182    /// Input shape: `[batch, seq_len, d_model]`.
1183    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1184        if input.ndim() != 3 {
1185            return Err(FerrotorchError::InvalidArgument {
1186                message: format!(
1187                    "TransformerEncoderLayer expects 3-D [batch, seq, d_model], got {:?}",
1188                    input.shape()
1189                ),
1190            });
1191        }
1192
1193        // Pre-norm self-attention block.
1194        let normed1 = self.norm1.forward(input)?;
1195        let attn_out = self.self_attn.forward(&normed1)?;
1196        let attn_out = self.dropout.forward(&attn_out)?;
1197        let residual1 = add(input, &attn_out)?;
1198
1199        // Pre-norm feedforward block.
1200        let normed2 = self.norm2.forward(&residual1)?;
1201        let ffn_out = self.ffn.forward(&normed2)?;
1202        let ffn_out = self.dropout.forward(&ffn_out)?;
1203        let residual2 = add(&residual1, &ffn_out)?;
1204
1205        Ok(residual2)
1206    }
1207
1208    fn parameters(&self) -> Vec<&Parameter<T>> {
1209        let mut params = Vec::new();
1210        params.extend(self.self_attn.parameters());
1211        params.extend(self.ffn.parameters());
1212        params.extend(self.norm1.parameters());
1213        params.extend(self.norm2.parameters());
1214        // Dropout has no parameters.
1215        params
1216    }
1217
1218    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1219        let mut params = Vec::new();
1220        params.extend(self.self_attn.parameters_mut());
1221        params.extend(self.ffn.parameters_mut());
1222        params.extend(self.norm1.parameters_mut());
1223        params.extend(self.norm2.parameters_mut());
1224        params
1225    }
1226
1227    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1228        let mut params = Vec::new();
1229        for (name, param) in self.self_attn.named_parameters() {
1230            params.push((format!("self_attn.{name}"), param));
1231        }
1232        for (name, param) in self.ffn.named_parameters() {
1233            params.push((format!("ffn.{name}"), param));
1234        }
1235        for (name, param) in self.norm1.named_parameters() {
1236            params.push((format!("norm1.{name}"), param));
1237        }
1238        for (name, param) in self.norm2.named_parameters() {
1239            params.push((format!("norm2.{name}"), param));
1240        }
1241        params
1242    }
1243
1244    fn train(&mut self) {
1245        self.training = true;
1246        self.self_attn.train();
1247        self.ffn.train();
1248        self.norm1.train();
1249        self.norm2.train();
1250        self.dropout.train();
1251    }
1252
1253    fn eval(&mut self) {
1254        self.training = false;
1255        self.self_attn.eval();
1256        self.ffn.eval();
1257        self.norm1.eval();
1258        self.norm2.eval();
1259        self.dropout.eval();
1260    }
1261
1262    fn is_training(&self) -> bool {
1263        self.training
1264    }
1265}
1266
1267// ===========================================================================
1268// TransformerDecoderLayer
1269// ===========================================================================
1270
1271/// A single pre-norm transformer decoder layer with cross-attention.
1272///
1273/// Applies the following computation:
1274///
1275/// ```text
1276/// x = x + dropout(self_attn(norm1(x)))
1277/// x = x + dropout(cross_attn(norm2(x), memory, memory))
1278/// x = x + dropout(ffn(norm3(x)))
1279/// ```
1280///
1281/// The self-attention sub-layer uses causal masking. The cross-attention
1282/// sub-layer attends over encoder output (`memory`).
1283///
1284/// # Shape contract
1285///
1286/// - `input` (decoder): `[batch, tgt_seq, d_model]`
1287/// - `memory` (encoder output): `[batch, src_seq, d_model]`
1288/// - Output: `[batch, tgt_seq, d_model]`
1289#[derive(Debug)]
1290pub struct TransformerDecoderLayer<T: Float> {
1291    self_attn: MultiheadAttention<T>,
1292    cross_attn: MultiheadAttention<T>,
1293    ffn: SwiGLU<T>,
1294    norm1: LayerNorm<T>,
1295    norm2: LayerNorm<T>,
1296    norm3: LayerNorm<T>,
1297    dropout: Dropout<T>,
1298    training: bool,
1299}
1300
1301impl<T: Float> TransformerDecoderLayer<T> {
1302    /// Create a new transformer decoder layer.
1303    ///
1304    /// # Arguments
1305    ///
1306    /// - `d_model` - The model dimension (embedding size).
1307    /// - `num_heads` - Number of attention heads.
1308    /// - `d_ff` - Hidden dimension of the SwiGLU feedforward network.
1309    /// - `dropout_p` - Dropout probability.
1310    /// - `layer_norm_eps` - Epsilon for layer normalization.
1311    /// - `bias` - Whether to use bias in attention and FFN projections.
1312    pub fn new(
1313        d_model: usize,
1314        num_heads: usize,
1315        d_ff: usize,
1316        dropout_p: f64,
1317        layer_norm_eps: f64,
1318        bias: bool,
1319    ) -> FerrotorchResult<Self> {
1320        let self_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1321        let cross_attn = MultiheadAttention::new(d_model, num_heads, bias)?;
1322        let ffn = SwiGLU::new(d_model, d_ff, bias)?;
1323        let norm1 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1324        let norm2 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1325        let norm3 = LayerNorm::new(vec![d_model], layer_norm_eps, true)?;
1326        let dropout = Dropout::new(dropout_p)?;
1327
1328        Ok(Self {
1329            self_attn,
1330            cross_attn,
1331            ffn,
1332            norm1,
1333            norm2,
1334            norm3,
1335            dropout,
1336            training: true,
1337        })
1338    }
1339
1340    /// Forward pass.
1341    ///
1342    /// # Arguments
1343    ///
1344    /// - `input` - Decoder input: `[batch, tgt_seq, d_model]`.
1345    /// - `memory` - Encoder output: `[batch, src_seq, d_model]`.
1346    ///
1347    /// # Returns
1348    ///
1349    /// Output tensor of shape `[batch, tgt_seq, d_model]`.
1350    pub fn forward_with_memory(
1351        &self,
1352        input: &Tensor<T>,
1353        memory: &Tensor<T>,
1354    ) -> FerrotorchResult<Tensor<T>> {
1355        if input.ndim() != 3 || memory.ndim() != 3 {
1356            return Err(FerrotorchError::InvalidArgument {
1357                message: format!(
1358                    "TransformerDecoderLayer expects 3-D inputs, \
1359                     got input {:?}, memory {:?}",
1360                    input.shape(),
1361                    memory.shape()
1362                ),
1363            });
1364        }
1365
1366        // Pre-norm causal self-attention.
1367        let normed1 = self.norm1.forward(input)?;
1368        let self_attn_out = self
1369            .self_attn
1370            .forward_qkv(&normed1, &normed1, &normed1, true)?;
1371        let self_attn_out = self.dropout.forward(&self_attn_out)?;
1372        let residual1 = add(input, &self_attn_out)?;
1373
1374        // Pre-norm cross-attention.
1375        let normed2 = self.norm2.forward(&residual1)?;
1376        let cross_attn_out = self
1377            .cross_attn
1378            .forward_qkv(&normed2, memory, memory, false)?;
1379        let cross_attn_out = self.dropout.forward(&cross_attn_out)?;
1380        let residual2 = add(&residual1, &cross_attn_out)?;
1381
1382        // Pre-norm feedforward.
1383        let normed3 = self.norm3.forward(&residual2)?;
1384        let ffn_out = self.ffn.forward(&normed3)?;
1385        let ffn_out = self.dropout.forward(&ffn_out)?;
1386        let residual3 = add(&residual2, &ffn_out)?;
1387
1388        Ok(residual3)
1389    }
1390}
1391
1392impl<T: Float> Module<T> for TransformerDecoderLayer<T> {
1393    /// Forward pass using `input` as both decoder input and memory.
1394    ///
1395    /// For the typical decoder use case with separate encoder output, call
1396    /// [`forward_with_memory`](Self::forward_with_memory) directly.
1397    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1398        self.forward_with_memory(input, input)
1399    }
1400
1401    fn parameters(&self) -> Vec<&Parameter<T>> {
1402        let mut params = Vec::new();
1403        params.extend(self.self_attn.parameters());
1404        params.extend(self.cross_attn.parameters());
1405        params.extend(self.ffn.parameters());
1406        params.extend(self.norm1.parameters());
1407        params.extend(self.norm2.parameters());
1408        params.extend(self.norm3.parameters());
1409        params
1410    }
1411
1412    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1413        let mut params = Vec::new();
1414        params.extend(self.self_attn.parameters_mut());
1415        params.extend(self.cross_attn.parameters_mut());
1416        params.extend(self.ffn.parameters_mut());
1417        params.extend(self.norm1.parameters_mut());
1418        params.extend(self.norm2.parameters_mut());
1419        params.extend(self.norm3.parameters_mut());
1420        params
1421    }
1422
1423    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1424        let mut params = Vec::new();
1425        for (name, param) in self.self_attn.named_parameters() {
1426            params.push((format!("self_attn.{name}"), param));
1427        }
1428        for (name, param) in self.cross_attn.named_parameters() {
1429            params.push((format!("cross_attn.{name}"), param));
1430        }
1431        for (name, param) in self.ffn.named_parameters() {
1432            params.push((format!("ffn.{name}"), param));
1433        }
1434        for (name, param) in self.norm1.named_parameters() {
1435            params.push((format!("norm1.{name}"), param));
1436        }
1437        for (name, param) in self.norm2.named_parameters() {
1438            params.push((format!("norm2.{name}"), param));
1439        }
1440        for (name, param) in self.norm3.named_parameters() {
1441            params.push((format!("norm3.{name}"), param));
1442        }
1443        params
1444    }
1445
1446    fn train(&mut self) {
1447        self.training = true;
1448        self.self_attn.train();
1449        self.cross_attn.train();
1450        self.ffn.train();
1451        self.norm1.train();
1452        self.norm2.train();
1453        self.norm3.train();
1454        self.dropout.train();
1455    }
1456
1457    fn eval(&mut self) {
1458        self.training = false;
1459        self.self_attn.eval();
1460        self.cross_attn.eval();
1461        self.ffn.eval();
1462        self.norm1.eval();
1463        self.norm2.eval();
1464        self.norm3.eval();
1465        self.dropout.eval();
1466    }
1467
1468    fn is_training(&self) -> bool {
1469        self.training
1470    }
1471}
1472
1473// ===========================================================================
1474// TransformerEncoder
1475// ===========================================================================
1476
1477/// A stack of N [`TransformerEncoderLayer`] modules with an optional final
1478/// layer normalization.
1479///
1480/// This mirrors `torch.nn.TransformerEncoder`: it iterates the input through
1481/// `num_layers` identical (but independently parameterized) encoder layers,
1482/// then optionally applies a final `LayerNorm`.
1483///
1484/// # Shape contract
1485///
1486/// - Input: `[batch, seq_len, d_model]`
1487/// - Output: `[batch, seq_len, d_model]`
1488#[derive(Debug)]
1489pub struct TransformerEncoder<T: Float> {
1490    layers: Vec<TransformerEncoderLayer<T>>,
1491    norm: Option<LayerNorm<T>>,
1492    training: bool,
1493}
1494
1495impl<T: Float> TransformerEncoder<T> {
1496    /// Create a new transformer encoder.
1497    ///
1498    /// Each layer is constructed fresh with the same hyperparameters (not
1499    /// cloned), so they have independent initial weights.
1500    ///
1501    /// # Arguments
1502    ///
1503    /// - `d_model` - The model dimension (embedding size).
1504    /// - `num_heads` - Number of attention heads.
1505    /// - `num_layers` - Number of encoder layers to stack.
1506    /// - `d_ff` - Hidden dimension of the SwiGLU feedforward network.
1507    /// - `dropout_p` - Dropout probability.
1508    /// - `layer_norm_eps` - Epsilon for layer normalization.
1509    /// - `bias` - Whether to use bias in attention and FFN projections.
1510    /// - `final_norm` - Whether to add a final `LayerNorm` after the last layer.
1511    #[allow(clippy::too_many_arguments)]
1512    pub fn new(
1513        d_model: usize,
1514        num_heads: usize,
1515        num_layers: usize,
1516        d_ff: usize,
1517        dropout_p: f64,
1518        layer_norm_eps: f64,
1519        bias: bool,
1520        final_norm: bool,
1521    ) -> FerrotorchResult<Self> {
1522        if num_layers == 0 {
1523            return Err(FerrotorchError::InvalidArgument {
1524                message: "TransformerEncoder: num_layers must be > 0".into(),
1525            });
1526        }
1527
1528        let mut layers = Vec::with_capacity(num_layers);
1529        for _ in 0..num_layers {
1530            layers.push(TransformerEncoderLayer::new(
1531                d_model,
1532                num_heads,
1533                d_ff,
1534                dropout_p,
1535                layer_norm_eps,
1536                bias,
1537            )?);
1538        }
1539
1540        let norm = if final_norm {
1541            Some(LayerNorm::new(vec![d_model], layer_norm_eps, true)?)
1542        } else {
1543            None
1544        };
1545
1546        Ok(Self {
1547            layers,
1548            norm,
1549            training: true,
1550        })
1551    }
1552
1553    /// The number of stacked encoder layers.
1554    #[inline]
1555    pub fn num_layers(&self) -> usize {
1556        self.layers.len()
1557    }
1558}
1559
1560impl<T: Float> Module<T> for TransformerEncoder<T> {
1561    /// Forward pass: iterate through all encoder layers, then apply final norm.
1562    ///
1563    /// Input shape: `[batch, seq_len, d_model]`.
1564    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1565        let mut output = input.clone();
1566        for layer in &self.layers {
1567            output = layer.forward(&output)?;
1568        }
1569        if let Some(ref norm) = self.norm {
1570            output = norm.forward(&output)?;
1571        }
1572        Ok(output)
1573    }
1574
1575    fn parameters(&self) -> Vec<&Parameter<T>> {
1576        let mut params = Vec::new();
1577        for (i, layer) in self.layers.iter().enumerate() {
1578            let _ = i; // layer index not needed for unnamed params
1579            params.extend(layer.parameters());
1580        }
1581        if let Some(ref norm) = self.norm {
1582            params.extend(norm.parameters());
1583        }
1584        params
1585    }
1586
1587    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1588        let mut params = Vec::new();
1589        for layer in &mut self.layers {
1590            params.extend(layer.parameters_mut());
1591        }
1592        if let Some(ref mut norm) = self.norm {
1593            params.extend(norm.parameters_mut());
1594        }
1595        params
1596    }
1597
1598    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1599        let mut params = Vec::new();
1600        for (i, layer) in self.layers.iter().enumerate() {
1601            for (name, param) in layer.named_parameters() {
1602                params.push((format!("layers.{i}.{name}"), param));
1603            }
1604        }
1605        if let Some(ref norm) = self.norm {
1606            for (name, param) in norm.named_parameters() {
1607                params.push((format!("norm.{name}"), param));
1608            }
1609        }
1610        params
1611    }
1612
1613    fn train(&mut self) {
1614        self.training = true;
1615        for layer in &mut self.layers {
1616            layer.train();
1617        }
1618        if let Some(ref mut norm) = self.norm {
1619            norm.train();
1620        }
1621    }
1622
1623    fn eval(&mut self) {
1624        self.training = false;
1625        for layer in &mut self.layers {
1626            layer.eval();
1627        }
1628        if let Some(ref mut norm) = self.norm {
1629            norm.eval();
1630        }
1631    }
1632
1633    fn is_training(&self) -> bool {
1634        self.training
1635    }
1636}
1637
1638// ===========================================================================
1639// TransformerDecoder
1640// ===========================================================================
1641
1642/// A stack of N [`TransformerDecoderLayer`] modules with an optional final
1643/// layer normalization.
1644///
1645/// This mirrors `torch.nn.TransformerDecoder`: it iterates the target input
1646/// through `num_layers` identical (but independently parameterized) decoder
1647/// layers, each attending to the encoder `memory`, then optionally applies a
1648/// final `LayerNorm`.
1649///
1650/// # Shape contract
1651///
1652/// - `input` (decoder): `[batch, tgt_seq, d_model]`
1653/// - `memory` (encoder output): `[batch, src_seq, d_model]`
1654/// - Output: `[batch, tgt_seq, d_model]`
1655#[derive(Debug)]
1656pub struct TransformerDecoder<T: Float> {
1657    layers: Vec<TransformerDecoderLayer<T>>,
1658    norm: Option<LayerNorm<T>>,
1659    training: bool,
1660}
1661
1662impl<T: Float> TransformerDecoder<T> {
1663    /// Create a new transformer decoder.
1664    ///
1665    /// Each layer is constructed fresh with the same hyperparameters (not
1666    /// cloned), so they have independent initial weights.
1667    ///
1668    /// # Arguments
1669    ///
1670    /// - `d_model` - The model dimension (embedding size).
1671    /// - `num_heads` - Number of attention heads.
1672    /// - `num_layers` - Number of decoder layers to stack.
1673    /// - `d_ff` - Hidden dimension of the SwiGLU feedforward network.
1674    /// - `dropout_p` - Dropout probability.
1675    /// - `layer_norm_eps` - Epsilon for layer normalization.
1676    /// - `bias` - Whether to use bias in attention and FFN projections.
1677    /// - `final_norm` - Whether to add a final `LayerNorm` after the last layer.
1678    #[allow(clippy::too_many_arguments)]
1679    pub fn new(
1680        d_model: usize,
1681        num_heads: usize,
1682        num_layers: usize,
1683        d_ff: usize,
1684        dropout_p: f64,
1685        layer_norm_eps: f64,
1686        bias: bool,
1687        final_norm: bool,
1688    ) -> FerrotorchResult<Self> {
1689        if num_layers == 0 {
1690            return Err(FerrotorchError::InvalidArgument {
1691                message: "TransformerDecoder: num_layers must be > 0".into(),
1692            });
1693        }
1694
1695        let mut layers = Vec::with_capacity(num_layers);
1696        for _ in 0..num_layers {
1697            layers.push(TransformerDecoderLayer::new(
1698                d_model,
1699                num_heads,
1700                d_ff,
1701                dropout_p,
1702                layer_norm_eps,
1703                bias,
1704            )?);
1705        }
1706
1707        let norm = if final_norm {
1708            Some(LayerNorm::new(vec![d_model], layer_norm_eps, true)?)
1709        } else {
1710            None
1711        };
1712
1713        Ok(Self {
1714            layers,
1715            norm,
1716            training: true,
1717        })
1718    }
1719
1720    /// Forward pass with encoder memory.
1721    ///
1722    /// # Arguments
1723    ///
1724    /// - `input` - Decoder input: `[batch, tgt_seq, d_model]`.
1725    /// - `memory` - Encoder output: `[batch, src_seq, d_model]`.
1726    ///
1727    /// # Returns
1728    ///
1729    /// Output tensor of shape `[batch, tgt_seq, d_model]`.
1730    pub fn forward_with_memory(
1731        &self,
1732        input: &Tensor<T>,
1733        memory: &Tensor<T>,
1734    ) -> FerrotorchResult<Tensor<T>> {
1735        let mut output = input.clone();
1736        for layer in &self.layers {
1737            output = layer.forward_with_memory(&output, memory)?;
1738        }
1739        if let Some(ref norm) = self.norm {
1740            output = norm.forward(&output)?;
1741        }
1742        Ok(output)
1743    }
1744
1745    /// The number of stacked decoder layers.
1746    #[inline]
1747    pub fn num_layers(&self) -> usize {
1748        self.layers.len()
1749    }
1750}
1751
1752impl<T: Float> Module<T> for TransformerDecoder<T> {
1753    /// Forward pass using `input` as both decoder input and memory.
1754    ///
1755    /// For the typical decoder use case with separate encoder output, call
1756    /// [`forward_with_memory`](Self::forward_with_memory) directly.
1757    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1758        self.forward_with_memory(input, input)
1759    }
1760
1761    fn parameters(&self) -> Vec<&Parameter<T>> {
1762        let mut params = Vec::new();
1763        for layer in &self.layers {
1764            params.extend(layer.parameters());
1765        }
1766        if let Some(ref norm) = self.norm {
1767            params.extend(norm.parameters());
1768        }
1769        params
1770    }
1771
1772    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1773        let mut params = Vec::new();
1774        for layer in &mut self.layers {
1775            params.extend(layer.parameters_mut());
1776        }
1777        if let Some(ref mut norm) = self.norm {
1778            params.extend(norm.parameters_mut());
1779        }
1780        params
1781    }
1782
1783    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1784        let mut params = Vec::new();
1785        for (i, layer) in self.layers.iter().enumerate() {
1786            for (name, param) in layer.named_parameters() {
1787                params.push((format!("layers.{i}.{name}"), param));
1788            }
1789        }
1790        if let Some(ref norm) = self.norm {
1791            for (name, param) in norm.named_parameters() {
1792                params.push((format!("norm.{name}"), param));
1793            }
1794        }
1795        params
1796    }
1797
1798    fn train(&mut self) {
1799        self.training = true;
1800        for layer in &mut self.layers {
1801            layer.train();
1802        }
1803        if let Some(ref mut norm) = self.norm {
1804            norm.train();
1805        }
1806    }
1807
1808    fn eval(&mut self) {
1809        self.training = false;
1810        for layer in &mut self.layers {
1811            layer.eval();
1812        }
1813        if let Some(ref mut norm) = self.norm {
1814            norm.eval();
1815        }
1816    }
1817
1818    fn is_training(&self) -> bool {
1819        self.training
1820    }
1821}
1822
1823// ===========================================================================
1824// Transformer
1825// ===========================================================================
1826
1827/// Full encoder-decoder transformer model.
1828///
1829/// Combines a [`TransformerEncoder`] and [`TransformerDecoder`] into a single
1830/// module, matching `torch.nn.Transformer`.
1831///
1832/// # Shape contract
1833///
1834/// - `src` (encoder input): `[batch, src_seq, d_model]`
1835/// - `tgt` (decoder input): `[batch, tgt_seq, d_model]`
1836/// - Output: `[batch, tgt_seq, d_model]`
1837///
1838/// # Example
1839///
1840/// ```ignore
1841/// let transformer = Transformer::<f32>::new(64, 4, 3, 3, 128, 0.1, 1e-5, true)?;
1842/// let src = ferrotorch_core::randn::<f32>(&[2, 10, 64])?;
1843/// let tgt = ferrotorch_core::randn::<f32>(&[2, 5, 64])?;
1844/// let output = transformer.forward_transformer(&src, &tgt)?;
1845/// assert_eq!(output.shape(), &[2, 5, 64]);
1846/// ```
1847#[derive(Debug)]
1848pub struct Transformer<T: Float> {
1849    encoder: TransformerEncoder<T>,
1850    decoder: TransformerDecoder<T>,
1851    training: bool,
1852}
1853
1854impl<T: Float> Transformer<T> {
1855    /// Create a new full encoder-decoder transformer.
1856    ///
1857    /// # Arguments
1858    ///
1859    /// - `d_model` - The model dimension (embedding size).
1860    /// - `num_heads` - Number of attention heads.
1861    /// - `num_encoder_layers` - Number of encoder layers (default: 6).
1862    /// - `num_decoder_layers` - Number of decoder layers (default: 6).
1863    /// - `d_ff` - Hidden dimension of the SwiGLU feedforward network (default: 2048).
1864    /// - `dropout_p` - Dropout probability (default: 0.1).
1865    /// - `layer_norm_eps` - Epsilon for layer normalization.
1866    /// - `bias` - Whether to use bias in attention and FFN projections.
1867    #[allow(clippy::too_many_arguments)]
1868    pub fn new(
1869        d_model: usize,
1870        num_heads: usize,
1871        num_encoder_layers: usize,
1872        num_decoder_layers: usize,
1873        d_ff: usize,
1874        dropout_p: f64,
1875        layer_norm_eps: f64,
1876        bias: bool,
1877    ) -> FerrotorchResult<Self> {
1878        let encoder = TransformerEncoder::new(
1879            d_model,
1880            num_heads,
1881            num_encoder_layers,
1882            d_ff,
1883            dropout_p,
1884            layer_norm_eps,
1885            bias,
1886            true, // final norm on encoder
1887        )?;
1888        let decoder = TransformerDecoder::new(
1889            d_model,
1890            num_heads,
1891            num_decoder_layers,
1892            d_ff,
1893            dropout_p,
1894            layer_norm_eps,
1895            bias,
1896            true, // final norm on decoder
1897        )?;
1898
1899        Ok(Self {
1900            encoder,
1901            decoder,
1902            training: true,
1903        })
1904    }
1905
1906    /// Forward pass: encode `src`, then decode `tgt` using the encoded memory.
1907    ///
1908    /// # Arguments
1909    ///
1910    /// - `src` - Encoder input: `[batch, src_seq, d_model]`.
1911    /// - `tgt` - Decoder input: `[batch, tgt_seq, d_model]`.
1912    ///
1913    /// # Returns
1914    ///
1915    /// Output tensor of shape `[batch, tgt_seq, d_model]`.
1916    pub fn forward_transformer(
1917        &self,
1918        src: &Tensor<T>,
1919        tgt: &Tensor<T>,
1920    ) -> FerrotorchResult<Tensor<T>> {
1921        let memory = self.encoder.forward(src)?;
1922        self.decoder.forward_with_memory(tgt, &memory)
1923    }
1924
1925    /// The number of encoder layers.
1926    #[inline]
1927    pub fn num_encoder_layers(&self) -> usize {
1928        self.encoder.num_layers()
1929    }
1930
1931    /// The number of decoder layers.
1932    #[inline]
1933    pub fn num_decoder_layers(&self) -> usize {
1934        self.decoder.num_layers()
1935    }
1936}
1937
1938impl<T: Float> Module<T> for Transformer<T> {
1939    /// Forward pass using `input` as both source and target.
1940    ///
1941    /// For the typical encoder-decoder use case with separate src/tgt, call
1942    /// [`forward_transformer`](Self::forward_transformer) directly.
1943    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1944        self.forward_transformer(input, input)
1945    }
1946
1947    fn parameters(&self) -> Vec<&Parameter<T>> {
1948        let mut params = self.encoder.parameters();
1949        params.extend(self.decoder.parameters());
1950        params
1951    }
1952
1953    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1954        let mut params = self.encoder.parameters_mut();
1955        params.extend(self.decoder.parameters_mut());
1956        params
1957    }
1958
1959    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1960        let mut params = Vec::new();
1961        for (name, param) in self.encoder.named_parameters() {
1962            params.push((format!("encoder.{name}"), param));
1963        }
1964        for (name, param) in self.decoder.named_parameters() {
1965            params.push((format!("decoder.{name}"), param));
1966        }
1967        params
1968    }
1969
1970    fn train(&mut self) {
1971        self.training = true;
1972        self.encoder.train();
1973        self.decoder.train();
1974    }
1975
1976    fn eval(&mut self) {
1977        self.training = false;
1978        self.encoder.eval();
1979        self.decoder.eval();
1980    }
1981
1982    fn is_training(&self) -> bool {
1983        self.training
1984    }
1985}
1986
1987// ===========================================================================
1988// Tests
1989// ===========================================================================
1990
1991#[cfg(test)]
1992mod tests {
1993    use super::*;
1994
1995    // -----------------------------------------------------------------------
1996    // RoPE
1997    // -----------------------------------------------------------------------
1998
1999    #[test]
2000    fn test_rope_construction() {
2001        let rope = RotaryPositionEmbedding::<f32>::new(64, 512, 10000.0);
2002        assert!(rope.is_ok());
2003        let rope = rope.unwrap();
2004        assert_eq!(rope.dim(), 64);
2005        assert_eq!(rope.max_seq_len(), 512);
2006        assert_eq!(rope.base(), 10000.0);
2007    }
2008
2009    #[test]
2010    fn test_rope_odd_dim_rejected() {
2011        assert!(RotaryPositionEmbedding::<f32>::new(63, 512, 10000.0).is_err());
2012    }
2013
2014    #[test]
2015    fn test_rope_zero_dim_rejected() {
2016        assert!(RotaryPositionEmbedding::<f32>::new(0, 512, 10000.0).is_err());
2017    }
2018
2019    #[test]
2020    fn test_rope_zero_seq_rejected() {
2021        assert!(RotaryPositionEmbedding::<f32>::new(64, 0, 10000.0).is_err());
2022    }
2023
2024    #[test]
2025    fn test_rope_output_shape_2d() {
2026        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2027        // Input: [seq_len=4, dim=8]
2028        let x = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2029        let y = rope.apply(&x, 0).unwrap();
2030        assert_eq!(y.shape(), &[4, 8]);
2031    }
2032
2033    #[test]
2034    fn test_rope_output_shape_3d() {
2035        let rope = RotaryPositionEmbedding::<f32>::new(16, 256, 10000.0).unwrap();
2036        // Input: [batch=2, seq_len=10, dim=16]
2037        let x = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
2038        let y = rope.apply(&x, 0).unwrap();
2039        assert_eq!(y.shape(), &[2, 10, 16]);
2040    }
2041
2042    #[test]
2043    fn test_rope_output_shape_4d() {
2044        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2045        // Input: [batch=2, heads=4, seq=6, head_dim=8]
2046        let x = ferrotorch_core::zeros::<f32>(&[2, 4, 6, 8]).unwrap();
2047        let y = rope.apply(&x, 0).unwrap();
2048        assert_eq!(y.shape(), &[2, 4, 6, 8]);
2049    }
2050
2051    #[test]
2052    fn test_rope_with_offset() {
2053        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2054        let x = ferrotorch_core::ones::<f32>(&[4, 8]).unwrap();
2055        // Offset 10, seq_len 4 -> positions 10..14, fine since 14 <= 128.
2056        let y = rope.apply(&x, 10).unwrap();
2057        assert_eq!(y.shape(), &[4, 8]);
2058    }
2059
2060    #[test]
2061    fn test_rope_offset_overflow_rejected() {
2062        let rope = RotaryPositionEmbedding::<f32>::new(8, 16, 10000.0).unwrap();
2063        // seq_len=10, offset=10 -> 10+10=20 > 16 -> error.
2064        let x = ferrotorch_core::zeros::<f32>(&[10, 8]).unwrap();
2065        assert!(rope.apply(&x, 10).is_err());
2066    }
2067
2068    #[test]
2069    fn test_rope_position_zero_is_identity() {
2070        // At position 0, cos(0) = 1, sin(0) = 0, so rotation is identity.
2071        let rope = RotaryPositionEmbedding::<f64>::new(4, 64, 10000.0).unwrap();
2072        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2073        let y = rope.apply(&x, 0).unwrap();
2074        let y_data = y.data().unwrap();
2075        let x_data = x.data().unwrap();
2076        for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
2077            assert!(
2078                (xv - yv).abs() < 1e-10,
2079                "position 0 should be identity, index {i}: x={xv}, y={yv}"
2080            );
2081        }
2082    }
2083
2084    #[test]
2085    fn test_rope_values_are_finite() {
2086        let rope = RotaryPositionEmbedding::<f32>::new(16, 512, 10000.0).unwrap();
2087        let x = ferrotorch_core::ones::<f32>(&[2, 4, 10, 16]).unwrap();
2088        let y = rope.apply(&x, 0).unwrap();
2089        for &v in y.data().unwrap() {
2090            assert!(v.is_finite(), "RoPE produced non-finite value: {v}");
2091        }
2092    }
2093
2094    #[test]
2095    fn test_rope_wrong_dim_rejected() {
2096        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2097        let x = ferrotorch_core::zeros::<f32>(&[4, 10]).unwrap(); // dim=10 != 8
2098        assert!(rope.apply(&x, 0).is_err());
2099    }
2100
2101    // -- RoPE scaling tests (#515) -----------------------------------------
2102
2103    #[test]
2104    fn test_rope_scaling_default_is_none() {
2105        let rope = RotaryPositionEmbedding::<f32>::new(16, 128, 10000.0).unwrap();
2106        assert_eq!(rope.scaling(), RoPEScaling::None);
2107    }
2108
2109    #[test]
2110    fn test_rope_scaling_none_matches_classical() {
2111        // with_scaling(RoPEScaling::None) must produce the same caches as new().
2112        let a = RotaryPositionEmbedding::<f64>::new(16, 32, 10000.0).unwrap();
2113        let b = RotaryPositionEmbedding::<f64>::with_scaling(
2114            16,
2115            32,
2116            10000.0,
2117            RoPEConvention::default(),
2118            RoPEScaling::None,
2119        )
2120        .unwrap();
2121        let x = ferrotorch_core::from_slice(
2122            &(0..16).map(|i| i as f64 * 0.1).collect::<Vec<_>>(),
2123            &[1, 16],
2124        )
2125        .unwrap();
2126        let ya = a.apply(&x, 7).unwrap();
2127        let yb = b.apply(&x, 7).unwrap();
2128        for (va, vb) in ya.data().unwrap().iter().zip(yb.data().unwrap().iter()) {
2129            assert!((va - vb).abs() < 1e-12);
2130        }
2131    }
2132
2133    #[test]
2134    fn test_rope_scaling_linear_halves_angles() {
2135        // Linear factor=2: all angles at position p under the scaled
2136        // schedule equal angles at p/2 under the unscaled schedule.
2137        let scaled = RotaryPositionEmbedding::<f64>::with_scaling(
2138            8,
2139            64,
2140            10000.0,
2141            RoPEConvention::default(),
2142            RoPEScaling::Linear { factor: 2.0 },
2143        )
2144        .unwrap();
2145        let plain = RotaryPositionEmbedding::<f64>::new(8, 64, 10000.0).unwrap();
2146
2147        // All-ones probe; applying at pos=8 on scaled should equal
2148        // applying at pos=4 on plain.
2149        let x = ferrotorch_core::ones::<f64>(&[1, 8]).unwrap();
2150        let y_scaled = scaled.apply(&x, 8).unwrap();
2151        let y_plain = plain.apply(&x, 4).unwrap();
2152        for (a, b) in y_scaled
2153            .data()
2154            .unwrap()
2155            .iter()
2156            .zip(y_plain.data().unwrap().iter())
2157        {
2158            assert!(
2159                (a - b).abs() < 1e-6,
2160                "scaled(pos=8) should match plain(pos=4): {a} vs {b}"
2161            );
2162        }
2163    }
2164
2165    #[test]
2166    fn test_rope_scaling_ntk_inv_freq() {
2167        // NTK-aware scaling: base' = base * factor^(dim / (dim - 2)).
2168        // For i=0, inv_freq[0] = 1 / base'^0 = 1 exactly, matching the
2169        // unscaled schedule. For i = dim/2 - 1, NTK stretches by
2170        // approximately 1/factor (matches linear PI at the long end).
2171        use super::compute_scaled_inv_freq;
2172
2173        let dim = 64;
2174        let base = 10000.0;
2175        let factor = 4.0;
2176        let ntk = compute_scaled_inv_freq(
2177            dim,
2178            base,
2179            RoPEScaling::NtkAware {
2180                factor,
2181                original_max_pos_embeddings: 2048,
2182            },
2183        );
2184        let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
2185        assert_eq!(ntk.len(), 32);
2186        assert_eq!(plain.len(), 32);
2187
2188        // High-frequency dim (i=0) must round-trip bit-identically.
2189        assert!(
2190            (ntk[0] - plain[0]).abs() < 1e-15,
2191            "NTK inv_freq[0] should equal plain inv_freq[0]: ntk={}, plain={}",
2192            ntk[0],
2193            plain[0]
2194        );
2195
2196        // Lowest-frequency dim (i = dim/2 - 1 = 31) should approach the
2197        // linear-PI scaling of 1/factor of the plain frequency.
2198        let ratio = ntk[31] / plain[31];
2199        let expected = 1.0 / factor;
2200        assert!(
2201            (ratio - expected).abs() < 0.05,
2202            "NTK inv_freq[31]/plain ratio should be ~{expected}: got {ratio}"
2203        );
2204    }
2205
2206    #[test]
2207    fn test_rope_scaling_linear_inv_freq_halved() {
2208        use super::compute_scaled_inv_freq;
2209        let lin = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::Linear { factor: 2.0 });
2210        let plain = compute_scaled_inv_freq(8, 10000.0, RoPEScaling::None);
2211        for (a, b) in lin.iter().zip(plain.iter()) {
2212            assert!(
2213                (a - b / 2.0).abs() < 1e-15,
2214                "linear should halve: {a} vs {b}/2"
2215            );
2216        }
2217    }
2218
2219    #[test]
2220    fn test_rope_scaling_yarn_inv_freq_piecewise() {
2221        // YARN mixes extrapolation (no scale) at the highest frequencies
2222        // with interpolation (1/factor) at the lowest frequencies.
2223        use super::compute_scaled_inv_freq;
2224        let dim = 64;
2225        let base = 10000.0;
2226        let factor = 4.0;
2227        let yarn = compute_scaled_inv_freq(dim, base, RoPEScaling::yarn_default(factor, 2048));
2228        let plain = compute_scaled_inv_freq(dim, base, RoPEScaling::None);
2229
2230        // Highest-frequency dim: extrapolation regime (value matches plain).
2231        assert!(
2232            (yarn[0] - plain[0]).abs() < 1e-12,
2233            "YARN[0] (extrapolation) should equal plain[0]: {} vs {}",
2234            yarn[0],
2235            plain[0]
2236        );
2237        // Lowest-frequency dim: interpolation regime (value matches plain/factor).
2238        let expected_low = plain[dim / 2 - 1] / factor;
2239        let ratio = yarn[dim / 2 - 1] / expected_low;
2240        assert!(
2241            (ratio - 1.0).abs() < 0.1,
2242            "YARN[dim/2-1] (interpolation) should approx equal plain/factor: {} vs {}",
2243            yarn[dim / 2 - 1],
2244            expected_low
2245        );
2246    }
2247
2248    #[test]
2249    fn test_rope_scaling_yarn_constructs() {
2250        let rope = RotaryPositionEmbedding::<f32>::with_scaling(
2251            64,
2252            256,
2253            10000.0,
2254            RoPEConvention::default(),
2255            RoPEScaling::yarn_default(2.0, 2048),
2256        )
2257        .unwrap();
2258        assert!(matches!(rope.scaling(), RoPEScaling::Yarn { .. }));
2259        let x = ferrotorch_core::ones::<f32>(&[1, 64]).unwrap();
2260        for &v in rope.apply(&x, 0).unwrap().data().unwrap() {
2261            assert!(v.is_finite());
2262        }
2263    }
2264
2265    #[test]
2266    fn test_rope_scaling_rejects_zero_factor() {
2267        let r = RotaryPositionEmbedding::<f32>::with_scaling(
2268            8,
2269            16,
2270            10000.0,
2271            RoPEConvention::default(),
2272            RoPEScaling::Linear { factor: 0.0 },
2273        );
2274        assert!(r.is_err());
2275    }
2276
2277    #[test]
2278    fn test_rope_scaling_rejects_negative_factor() {
2279        let r = RotaryPositionEmbedding::<f32>::with_scaling(
2280            8,
2281            16,
2282            10000.0,
2283            RoPEConvention::default(),
2284            RoPEScaling::NtkAware {
2285                factor: -2.0,
2286                original_max_pos_embeddings: 2048,
2287            },
2288        );
2289        assert!(r.is_err());
2290    }
2291
2292    #[test]
2293    fn test_rope_scaling_accessor() {
2294        let rope = RotaryPositionEmbedding::<f32>::with_scaling(
2295            16,
2296            64,
2297            10000.0,
2298            RoPEConvention::default(),
2299            RoPEScaling::Linear { factor: 4.0 },
2300        )
2301        .unwrap();
2302        assert_eq!(rope.scaling(), RoPEScaling::Linear { factor: 4.0 });
2303    }
2304
2305    // -----------------------------------------------------------------------
2306    // RoPE — HalfRotation convention
2307    // -----------------------------------------------------------------------
2308
2309    #[test]
2310    fn test_rope_half_rotation_construction() {
2311        let rope = RotaryPositionEmbedding::<f32>::with_convention(
2312            8,
2313            128,
2314            10000.0,
2315            RoPEConvention::HalfRotation,
2316        )
2317        .unwrap();
2318        assert_eq!(rope.convention(), RoPEConvention::HalfRotation);
2319    }
2320
2321    #[test]
2322    fn test_rope_half_rotation_output_shape() {
2323        let rope = RotaryPositionEmbedding::<f32>::with_convention(
2324            8,
2325            128,
2326            10000.0,
2327            RoPEConvention::HalfRotation,
2328        )
2329        .unwrap();
2330        let x = ferrotorch_core::zeros::<f32>(&[2, 4, 8]).unwrap();
2331        let y = rope.apply(&x, 0).unwrap();
2332        assert_eq!(y.shape(), &[2, 4, 8]);
2333    }
2334
2335    #[test]
2336    fn test_rope_half_rotation_position_zero_is_identity() {
2337        // At position 0, cos(0)=1, sin(0)=0 → identity regardless of convention.
2338        let rope = RotaryPositionEmbedding::<f64>::with_convention(
2339            4,
2340            64,
2341            10000.0,
2342            RoPEConvention::HalfRotation,
2343        )
2344        .unwrap();
2345        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2346        let y = rope.apply(&x, 0).unwrap();
2347        let x_data = x.data().unwrap();
2348        let y_data = y.data().unwrap();
2349        for (i, (&xv, &yv)) in x_data.iter().zip(y_data.iter()).enumerate() {
2350            assert!(
2351                (xv - yv).abs() < 1e-10,
2352                "half-rot pos 0 should be identity, index {i}: x={xv}, y={yv}"
2353            );
2354        }
2355    }
2356
2357    #[test]
2358    fn test_rope_half_rotation_correctness() {
2359        // dim=4, so half_dim=2. For half-rotation:
2360        //   x_rot[0] = x[0]*cos0 - x[2]*sin0
2361        //   x_rot[1] = x[1]*cos1 - x[3]*sin1
2362        //   x_rot[2] = x[0]*sin0 + x[2]*cos0
2363        //   x_rot[3] = x[1]*sin1 + x[3]*cos1
2364        let rope = RotaryPositionEmbedding::<f64>::with_convention(
2365            4,
2366            64,
2367            10000.0,
2368            RoPEConvention::HalfRotation,
2369        )
2370        .unwrap();
2371
2372        // Use position 1 so sin != 0.
2373        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2374        let y = rope.apply(&x, 1).unwrap();
2375
2376        // Get the cached cos/sin at position 1.
2377        let cos_data = rope.cos_cache.data().unwrap();
2378        let sin_data = rope.sin_cache.data().unwrap();
2379        // Position 1 → row offset = 1 * half_dim = 2
2380        let c0 = cos_data[2];
2381        let c1 = cos_data[3];
2382        let s0 = sin_data[2];
2383        let s1 = sin_data[3];
2384
2385        let expected = [
2386            1.0 * c0 - 3.0 * s0,
2387            2.0 * c1 - 4.0 * s1,
2388            1.0 * s0 + 3.0 * c0,
2389            2.0 * s1 + 4.0 * c1,
2390        ];
2391
2392        let y_data = y.data().unwrap();
2393        for (i, (&actual, &exp)) in y_data.iter().zip(expected.iter()).enumerate() {
2394            assert!(
2395                (actual - exp).abs() < 1e-10,
2396                "half-rot index {i}: actual={actual}, expected={exp}"
2397            );
2398        }
2399    }
2400
2401    #[test]
2402    fn test_rope_interleaved_vs_half_rotation_differ() {
2403        // Same input at position > 0 should produce different outputs.
2404        let rope_il = RotaryPositionEmbedding::<f64>::with_convention(
2405            4,
2406            64,
2407            10000.0,
2408            RoPEConvention::Interleaved,
2409        )
2410        .unwrap();
2411        let rope_hr = RotaryPositionEmbedding::<f64>::with_convention(
2412            4,
2413            64,
2414            10000.0,
2415            RoPEConvention::HalfRotation,
2416        )
2417        .unwrap();
2418
2419        let x = ferrotorch_core::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
2420        let y_il = rope_il.apply(&x, 1).unwrap();
2421        let y_hr = rope_hr.apply(&x, 1).unwrap();
2422
2423        // They should differ (different pairing).
2424        let il_data = y_il.data().unwrap();
2425        let hr_data = y_hr.data().unwrap();
2426        let any_differ = il_data
2427            .iter()
2428            .zip(hr_data.iter())
2429            .any(|(&a, &b)| (a - b).abs() > 1e-10);
2430        assert!(
2431            any_differ,
2432            "interleaved and half-rotation should produce different outputs at pos > 0"
2433        );
2434    }
2435
2436    #[test]
2437    fn test_rope_default_convention_is_interleaved() {
2438        let rope = RotaryPositionEmbedding::<f32>::new(8, 128, 10000.0).unwrap();
2439        assert_eq!(rope.convention(), RoPEConvention::Interleaved);
2440    }
2441
2442    // -----------------------------------------------------------------------
2443    // SwiGLU
2444    // -----------------------------------------------------------------------
2445
2446    #[test]
2447    fn test_swiglu_construction() {
2448        let swiglu = SwiGLU::<f32>::new(64, 128, true);
2449        assert!(swiglu.is_ok());
2450    }
2451
2452    #[test]
2453    fn test_swiglu_forward_shape_2d() {
2454        let swiglu = SwiGLU::<f32>::new(16, 32, true).unwrap();
2455        let input = ferrotorch_core::zeros::<f32>(&[4, 16]).unwrap();
2456        let output = swiglu.forward(&input).unwrap();
2457        assert_eq!(output.shape(), &[4, 16]);
2458    }
2459
2460    #[test]
2461    fn test_swiglu_forward_shape_3d() {
2462        let swiglu = SwiGLU::<f32>::new(16, 32, false).unwrap();
2463        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2464        let output = swiglu.forward(&input).unwrap();
2465        assert_eq!(output.shape(), &[2, 5, 16]);
2466    }
2467
2468    #[test]
2469    fn test_swiglu_forward_values_finite() {
2470        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
2471        let input = ferrotorch_core::ones::<f32>(&[2, 3, 8]).unwrap();
2472        let output = swiglu.forward(&input).unwrap();
2473        for &v in output.data().unwrap() {
2474            assert!(v.is_finite(), "SwiGLU produced non-finite value: {v}");
2475        }
2476    }
2477
2478    #[test]
2479    fn test_swiglu_1d_rejected() {
2480        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2481        let input = ferrotorch_core::zeros::<f32>(&[8]).unwrap();
2482        assert!(swiglu.forward(&input).is_err());
2483    }
2484
2485    #[test]
2486    fn test_swiglu_parameters() {
2487        let swiglu = SwiGLU::<f32>::new(8, 16, true).unwrap();
2488        let params = swiglu.parameters();
2489        // w1: weight + bias, w2: weight + bias, w3: weight + bias = 6
2490        assert_eq!(params.len(), 6);
2491
2492        let named = swiglu.named_parameters();
2493        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
2494        assert!(names.contains(&"w1.weight"));
2495        assert!(names.contains(&"w1.bias"));
2496        assert!(names.contains(&"w2.weight"));
2497        assert!(names.contains(&"w2.bias"));
2498        assert!(names.contains(&"w3.weight"));
2499        assert!(names.contains(&"w3.bias"));
2500    }
2501
2502    #[test]
2503    fn test_swiglu_parameters_no_bias() {
2504        let swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2505        let params = swiglu.parameters();
2506        // w1: weight, w2: weight, w3: weight = 3
2507        assert_eq!(params.len(), 3);
2508    }
2509
2510    #[test]
2511    fn test_swiglu_train_eval() {
2512        let mut swiglu = SwiGLU::<f32>::new(8, 16, false).unwrap();
2513        assert!(swiglu.is_training());
2514        swiglu.eval();
2515        assert!(!swiglu.is_training());
2516        swiglu.train();
2517        assert!(swiglu.is_training());
2518    }
2519
2520    // -----------------------------------------------------------------------
2521    // KVCache
2522    // -----------------------------------------------------------------------
2523
2524    #[test]
2525    fn test_kv_cache_new_empty() {
2526        let cache = KVCache::<f32>::new(1024);
2527        assert!(cache.is_empty());
2528        assert_eq!(cache.seq_len(), 0);
2529        assert_eq!(cache.max_seq_len(), 1024);
2530    }
2531
2532    #[test]
2533    fn test_kv_cache_single_update() {
2534        let mut cache = KVCache::<f32>::new(128);
2535        // [B=1, heads=2, seq=3, dim=4]
2536        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2537        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2538        let (fk, fv) = cache.update(k, v).unwrap();
2539        assert_eq!(fk.shape(), &[1, 2, 3, 4]);
2540        assert_eq!(fv.shape(), &[1, 2, 3, 4]);
2541        assert_eq!(cache.seq_len(), 3);
2542    }
2543
2544    #[test]
2545    fn test_kv_cache_append() {
2546        let mut cache = KVCache::<f32>::new(128);
2547        // First: [1, 2, 3, 4]
2548        let k1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2549        let v1 = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2550        cache.update(k1, v1).unwrap();
2551        assert_eq!(cache.seq_len(), 3);
2552
2553        // Append: [1, 2, 2, 4]
2554        let k2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
2555        let v2 = ferrotorch_core::ones::<f32>(&[1, 2, 2, 4]).unwrap();
2556        let (fk, fv) = cache.update(k2, v2).unwrap();
2557        assert_eq!(fk.shape(), &[1, 2, 5, 4]); // 3 + 2 = 5
2558        assert_eq!(fv.shape(), &[1, 2, 5, 4]);
2559        assert_eq!(cache.seq_len(), 5);
2560    }
2561
2562    #[test]
2563    fn test_kv_cache_reset() {
2564        let mut cache = KVCache::<f32>::new(128);
2565        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2566        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2567        cache.update(k, v).unwrap();
2568        assert_eq!(cache.seq_len(), 3);
2569
2570        cache.reset();
2571        assert!(cache.is_empty());
2572        assert_eq!(cache.seq_len(), 0);
2573    }
2574
2575    #[test]
2576    fn test_kv_cache_overflow_rejected() {
2577        let mut cache = KVCache::<f32>::new(4);
2578        let k = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
2579        let v = ferrotorch_core::ones::<f32>(&[1, 1, 5, 2]).unwrap();
2580        // seq=5 > max_seq_len=4 -> error.
2581        assert!(cache.update(k, v).is_err());
2582    }
2583
2584    #[test]
2585    fn test_kv_cache_shape_mismatch_rejected() {
2586        let mut cache = KVCache::<f32>::new(128);
2587        let k = ferrotorch_core::ones::<f32>(&[1, 2, 3, 4]).unwrap();
2588        let v = ferrotorch_core::ones::<f32>(&[1, 2, 3, 8]).unwrap(); // dim mismatch
2589        assert!(cache.update(k, v).is_err());
2590    }
2591
2592    #[test]
2593    fn test_kv_cache_values_preserved() {
2594        let mut cache = KVCache::<f64>::new(128);
2595        // First update: all 1s.
2596        let k1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
2597        let v1 = ferrotorch_core::ones::<f64>(&[1, 1, 2, 3]).unwrap();
2598        cache.update(k1, v1).unwrap();
2599
2600        // Second update: all 2s.
2601        let k2_data = vec![2.0f64; 3];
2602        let k2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
2603        let v2 = ferrotorch_core::from_slice(&k2_data, &[1, 1, 1, 3]).unwrap();
2604        let (fk, _fv) = cache.update(k2, v2).unwrap();
2605
2606        assert_eq!(fk.shape(), &[1, 1, 3, 3]); // 2 + 1 = 3
2607        let fk_data = fk.data().unwrap();
2608        // First 2 rows should be 1.0, last row should be 2.0.
2609        for &v in &fk_data[..6] {
2610            assert!((v - 1.0).abs() < 1e-10, "expected 1.0, got {v}");
2611        }
2612        for &v in &fk_data[6..9] {
2613            assert!((v - 2.0).abs() < 1e-10, "expected 2.0, got {v}");
2614        }
2615    }
2616
2617    // -- GQA KVCache tests (#506) -------------------------------------------
2618
2619    #[test]
2620    fn test_kv_cache_gqa_stores_at_kv_head_granularity() {
2621        // Llama 3 8B: 8 KV heads, not 32. Cache dim 1 must be num_kv_heads.
2622        let mut cache = KVCache::<f32>::new(8192);
2623        let k = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
2624        let v = ferrotorch_core::zeros::<f32>(&[1, 8, 3, 128]).unwrap();
2625        let (fk, _) = cache.update(k, v).unwrap();
2626        assert_eq!(fk.shape(), &[1, 8, 3, 128]);
2627        assert_eq!(cache.num_kv_heads(), Some(8));
2628        assert_eq!(cache.head_dim(), Some(128));
2629        assert_eq!(cache.batch_size(), Some(1));
2630    }
2631
2632    #[test]
2633    fn test_kv_cache_with_dims_pre_declares_shape() {
2634        let cache = KVCache::<f32>::with_dims(8192, 1, 8, 128);
2635        assert_eq!(cache.num_kv_heads(), Some(8));
2636        assert_eq!(cache.head_dim(), Some(128));
2637        assert_eq!(cache.batch_size(), Some(1));
2638        assert!(cache.is_empty());
2639    }
2640
2641    #[test]
2642    fn test_kv_cache_with_dims_rejects_first_update_mismatch() {
2643        // Pre-declare num_kv_heads=8, then try to push num_kv_heads=4.
2644        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2645        let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
2646        let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 16]).unwrap();
2647        assert!(cache.update(k, v).is_err());
2648    }
2649
2650    #[test]
2651    fn test_kv_cache_with_dims_rejects_head_dim_mismatch() {
2652        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2653        let k = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap(); // dim=32 != 16
2654        let v = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 32]).unwrap();
2655        assert!(cache.update(k, v).is_err());
2656    }
2657
2658    #[test]
2659    fn test_kv_cache_with_dims_rejects_batch_mismatch() {
2660        let mut cache = KVCache::<f32>::with_dims(128, 2, 4, 8);
2661        let k = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap(); // B=1 != 2
2662        let v = ferrotorch_core::zeros::<f32>(&[1, 4, 2, 8]).unwrap();
2663        assert!(cache.update(k, v).is_err());
2664    }
2665
2666    #[test]
2667    fn test_kv_cache_with_dims_accepts_matching_update() {
2668        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2669        let k = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
2670        let v = ferrotorch_core::ones::<f32>(&[1, 8, 3, 16]).unwrap();
2671        assert!(cache.update(k, v).is_ok());
2672        assert_eq!(cache.seq_len(), 3);
2673    }
2674
2675    #[test]
2676    fn test_kv_cache_inferred_dims_reject_subsequent_mismatch() {
2677        // First push defines dims; second push with different num_kv_heads must fail.
2678        let mut cache = KVCache::<f32>::new(128);
2679        let k1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
2680        let v1 = ferrotorch_core::zeros::<f32>(&[1, 8, 2, 16]).unwrap();
2681        cache.update(k1, v1).unwrap();
2682        assert_eq!(cache.num_kv_heads(), Some(8));
2683
2684        let k2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap(); // 4 != 8
2685        let v2 = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
2686        assert!(cache.update(k2, v2).is_err());
2687    }
2688
2689    #[test]
2690    fn test_kv_cache_dims_not_yet_pinned_on_fresh_new() {
2691        let cache = KVCache::<f32>::new(128);
2692        assert_eq!(cache.num_kv_heads(), None);
2693        assert_eq!(cache.head_dim(), None);
2694        assert_eq!(cache.batch_size(), None);
2695    }
2696
2697    #[test]
2698    fn test_kv_cache_reset_preserves_pinned_dims() {
2699        let mut cache = KVCache::<f32>::with_dims(128, 1, 8, 16);
2700        let k = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
2701        let v = ferrotorch_core::ones::<f32>(&[1, 8, 2, 16]).unwrap();
2702        cache.update(k, v).unwrap();
2703        cache.reset();
2704        assert!(cache.is_empty());
2705        // Dims are retained so the cache still validates the next push.
2706        assert_eq!(cache.num_kv_heads(), Some(8));
2707        let bad = ferrotorch_core::zeros::<f32>(&[1, 4, 1, 16]).unwrap();
2708        assert!(cache.update(bad.clone(), bad).is_err());
2709    }
2710
2711    #[test]
2712    fn test_kv_cache_gqa_prefill_then_decode_preserves_all_positions() {
2713        // Acceptance: "Decoder step using this cache produces outputs
2714        // matching un-cached GQA attention on the same inputs." We prove
2715        // the cache round-trips data faithfully by:
2716        //   (1) prefilling 4 tokens, then pushing 1 decode token
2717        //   (2) verifying every (batch, head, seq, dim) position in the
2718        //       returned full tensor matches the source tensors at the
2719        //       corresponding index.
2720        let build = |seed: u64, shape: &[usize]| {
2721            let numel: usize = shape.iter().product();
2722            let data: Vec<f32> = (0..numel)
2723                .map(|i| ((i as u64).wrapping_mul(seed) % 997) as f32 * 0.001)
2724                .collect();
2725            ferrotorch_core::from_slice(&data, shape).unwrap()
2726        };
2727
2728        // Llama-8B-ish: 1 batch, 8 KV heads, head_dim=16 (scaled down).
2729        let (b, h, s_prefill, s_decode, d) = (1usize, 8usize, 4usize, 1usize, 16usize);
2730        let s_full = s_prefill + s_decode;
2731
2732        let k_prefill = build(7, &[b, h, s_prefill, d]);
2733        let v_prefill = build(11, &[b, h, s_prefill, d]);
2734        let k_decode = build(13, &[b, h, s_decode, d]);
2735        let v_decode = build(17, &[b, h, s_decode, d]);
2736
2737        let mut cache = KVCache::<f32>::with_dims(16, b, h, d);
2738        cache.update(k_prefill.clone(), v_prefill.clone()).unwrap();
2739        let (fk, fv) = cache.update(k_decode.clone(), v_decode.clone()).unwrap();
2740        assert_eq!(fk.shape(), &[b, h, s_full, d]);
2741        assert_eq!(fv.shape(), &[b, h, s_full, d]);
2742
2743        let fk_data = fk.data_vec().unwrap();
2744        let fv_data = fv.data_vec().unwrap();
2745        let kp = k_prefill.data_vec().unwrap();
2746        let vp = v_prefill.data_vec().unwrap();
2747        let kd = k_decode.data_vec().unwrap();
2748        let vd = v_decode.data_vec().unwrap();
2749
2750        // Row-major [B, H, S, D] stride.
2751        let full_idx = |bi, hi, si, di| ((bi * h + hi) * s_full + si) * d + di;
2752        let src_idx = |bi, hi, si, di, s_len| ((bi * h + hi) * s_len + si) * d + di;
2753
2754        for bi in 0..b {
2755            for hi in 0..h {
2756                for si in 0..s_full {
2757                    for di in 0..d {
2758                        let out = full_idx(bi, hi, si, di);
2759                        let (exp_k, exp_v) = if si < s_prefill {
2760                            let src = src_idx(bi, hi, si, di, s_prefill);
2761                            (kp[src], vp[src])
2762                        } else {
2763                            let src = src_idx(bi, hi, si - s_prefill, di, s_decode);
2764                            (kd[src], vd[src])
2765                        };
2766                        assert!(
2767                            (fk_data[out] - exp_k).abs() < 1e-6,
2768                            "k mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_k}",
2769                            fk_data[out]
2770                        );
2771                        assert!(
2772                            (fv_data[out] - exp_v).abs() < 1e-6,
2773                            "v mismatch at [b={bi}, h={hi}, s={si}, d={di}]: got {}, want {exp_v}",
2774                            fv_data[out]
2775                        );
2776                    }
2777                }
2778            }
2779        }
2780    }
2781
2782    // -----------------------------------------------------------------------
2783    // TransformerEncoderLayer
2784    // -----------------------------------------------------------------------
2785
2786    #[test]
2787    fn test_encoder_layer_construction() {
2788        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
2789        assert!(layer.is_ok());
2790    }
2791
2792    #[test]
2793    fn test_encoder_layer_forward_shape() {
2794        let layer = TransformerEncoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
2795        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2796        let output = layer.forward(&input).unwrap();
2797        assert_eq!(output.shape(), &[2, 5, 16]);
2798    }
2799
2800    #[test]
2801    fn test_encoder_layer_forward_values_finite() {
2802        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2803        let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2804        let output = layer.forward(&input).unwrap();
2805        for &v in output.data().unwrap() {
2806            assert!(
2807                v.is_finite(),
2808                "TransformerEncoderLayer produced non-finite value: {v}"
2809            );
2810        }
2811    }
2812
2813    #[test]
2814    fn test_encoder_layer_2d_rejected() {
2815        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
2816        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2817        assert!(layer.forward(&input).is_err());
2818    }
2819
2820    #[test]
2821    fn test_encoder_layer_parameters_count() {
2822        let layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2823        let params = layer.parameters();
2824        // self_attn: 4 weights + 4 biases = 8
2825        // ffn (SwiGLU): 3 weights + 3 biases = 6
2826        // norm1: weight + bias = 2
2827        // norm2: weight + bias = 2
2828        // Total: 18
2829        assert_eq!(params.len(), 18);
2830    }
2831
2832    #[test]
2833    fn test_encoder_layer_train_eval() {
2834        let mut layer = TransformerEncoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
2835        assert!(layer.is_training());
2836        layer.eval();
2837        assert!(!layer.is_training());
2838        layer.train();
2839        assert!(layer.is_training());
2840    }
2841
2842    #[test]
2843    fn test_encoder_layer_is_send_sync() {
2844        fn assert_send_sync<T: Send + Sync>() {}
2845        assert_send_sync::<TransformerEncoderLayer<f32>>();
2846        assert_send_sync::<TransformerEncoderLayer<f64>>();
2847    }
2848
2849    // -----------------------------------------------------------------------
2850    // TransformerDecoderLayer
2851    // -----------------------------------------------------------------------
2852
2853    #[test]
2854    fn test_decoder_layer_construction() {
2855        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, true);
2856        assert!(layer.is_ok());
2857    }
2858
2859    #[test]
2860    fn test_decoder_layer_forward_shape() {
2861        let layer = TransformerDecoderLayer::<f32>::new(16, 4, 32, 0.0, 1e-5, false).unwrap();
2862        // decoder input: [2, 4, 16], encoder memory: [2, 6, 16]
2863        let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
2864        let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
2865        let output = layer.forward_with_memory(&tgt, &memory).unwrap();
2866        assert_eq!(output.shape(), &[2, 4, 16]);
2867    }
2868
2869    #[test]
2870    fn test_decoder_layer_self_forward_shape() {
2871        // Module::forward uses input as both decoder input and memory.
2872        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2873        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2874        let output = layer.forward(&input).unwrap();
2875        assert_eq!(output.shape(), &[1, 3, 8]);
2876    }
2877
2878    #[test]
2879    fn test_decoder_layer_forward_values_finite() {
2880        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2881        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2882        let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
2883        let output = layer.forward_with_memory(&tgt, &mem).unwrap();
2884        for &v in output.data().unwrap() {
2885            assert!(
2886                v.is_finite(),
2887                "TransformerDecoderLayer produced non-finite value: {v}"
2888            );
2889        }
2890    }
2891
2892    #[test]
2893    fn test_decoder_layer_2d_rejected() {
2894        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, false).unwrap();
2895        let input = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2896        let memory = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
2897        assert!(layer.forward_with_memory(&input, &memory).is_err());
2898    }
2899
2900    #[test]
2901    fn test_decoder_layer_parameters_count() {
2902        let layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.0, 1e-5, true).unwrap();
2903        let params = layer.parameters();
2904        // self_attn: 4 weights + 4 biases = 8
2905        // cross_attn: 4 weights + 4 biases = 8
2906        // ffn (SwiGLU): 3 weights + 3 biases = 6
2907        // norm1: weight + bias = 2
2908        // norm2: weight + bias = 2
2909        // norm3: weight + bias = 2
2910        // Total: 28
2911        assert_eq!(params.len(), 28);
2912    }
2913
2914    #[test]
2915    fn test_decoder_layer_train_eval() {
2916        let mut layer = TransformerDecoderLayer::<f32>::new(8, 2, 16, 0.1, 1e-5, false).unwrap();
2917        assert!(layer.is_training());
2918        layer.eval();
2919        assert!(!layer.is_training());
2920        layer.train();
2921        assert!(layer.is_training());
2922    }
2923
2924    #[test]
2925    fn test_decoder_layer_is_send_sync() {
2926        fn assert_send_sync<T: Send + Sync>() {}
2927        assert_send_sync::<TransformerDecoderLayer<f32>>();
2928        assert_send_sync::<TransformerDecoderLayer<f64>>();
2929    }
2930
2931    // -----------------------------------------------------------------------
2932    // TransformerEncoder
2933    // -----------------------------------------------------------------------
2934
2935    #[test]
2936    fn test_encoder_construction() {
2937        let enc = TransformerEncoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
2938        assert!(enc.is_ok());
2939        assert_eq!(enc.unwrap().num_layers(), 3);
2940    }
2941
2942    #[test]
2943    fn test_encoder_zero_layers_rejected() {
2944        assert!(TransformerEncoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
2945    }
2946
2947    #[test]
2948    fn test_encoder_forward_shape() {
2949        let enc = TransformerEncoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
2950        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
2951        let output = enc.forward(&input).unwrap();
2952        assert_eq!(output.shape(), &[2, 5, 16]);
2953    }
2954
2955    #[test]
2956    fn test_encoder_forward_no_final_norm() {
2957        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, false, false).unwrap();
2958        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2959        let output = enc.forward(&input).unwrap();
2960        assert_eq!(output.shape(), &[1, 3, 8]);
2961    }
2962
2963    #[test]
2964    fn test_encoder_forward_values_finite() {
2965        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2966        let input = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
2967        let output = enc.forward(&input).unwrap();
2968        for &v in output.data().unwrap() {
2969            assert!(
2970                v.is_finite(),
2971                "TransformerEncoder produced non-finite value: {v}"
2972            );
2973        }
2974    }
2975
2976    #[test]
2977    fn test_encoder_parameters_with_final_norm() {
2978        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2979        // Each encoder layer: 18 params (see test_encoder_layer_parameters_count)
2980        // Final norm: 2 params (weight + bias)
2981        // Total: 2 * 18 + 2 = 38
2982        assert_eq!(enc.parameters().len(), 38);
2983    }
2984
2985    #[test]
2986    fn test_encoder_named_parameters_have_layer_prefix() {
2987        let enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
2988        let named = enc.named_parameters();
2989        // Verify layer indexing in names.
2990        let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
2991        let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
2992        let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
2993        assert!(has_layer_0, "missing layers.0.* in named_parameters");
2994        assert!(has_layer_1, "missing layers.1.* in named_parameters");
2995        assert!(has_norm, "missing norm.* in named_parameters");
2996    }
2997
2998    #[test]
2999    fn test_encoder_train_eval() {
3000        let mut enc = TransformerEncoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
3001        assert!(enc.is_training());
3002        enc.eval();
3003        assert!(!enc.is_training());
3004        enc.train();
3005        assert!(enc.is_training());
3006    }
3007
3008    #[test]
3009    fn test_encoder_is_send_sync() {
3010        fn assert_send_sync<T: Send + Sync>() {}
3011        assert_send_sync::<TransformerEncoder<f32>>();
3012        assert_send_sync::<TransformerEncoder<f64>>();
3013    }
3014
3015    // -----------------------------------------------------------------------
3016    // TransformerDecoder
3017    // -----------------------------------------------------------------------
3018
3019    #[test]
3020    fn test_decoder_construction() {
3021        let dec = TransformerDecoder::<f32>::new(16, 4, 3, 32, 0.0, 1e-5, true, true);
3022        assert!(dec.is_ok());
3023        assert_eq!(dec.unwrap().num_layers(), 3);
3024    }
3025
3026    #[test]
3027    fn test_decoder_zero_layers_rejected() {
3028        assert!(TransformerDecoder::<f32>::new(16, 4, 0, 32, 0.0, 1e-5, true, true).is_err());
3029    }
3030
3031    #[test]
3032    fn test_decoder_forward_with_memory_shape() {
3033        let dec = TransformerDecoder::<f32>::new(16, 4, 2, 32, 0.0, 1e-5, false, true).unwrap();
3034        let tgt = ferrotorch_core::zeros::<f32>(&[2, 4, 16]).unwrap();
3035        let memory = ferrotorch_core::zeros::<f32>(&[2, 6, 16]).unwrap();
3036        let output = dec.forward_with_memory(&tgt, &memory).unwrap();
3037        assert_eq!(output.shape(), &[2, 4, 16]);
3038    }
3039
3040    #[test]
3041    fn test_decoder_forward_values_finite() {
3042        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3043        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
3044        let mem = ferrotorch_core::ones::<f32>(&[1, 5, 8]).unwrap();
3045        let output = dec.forward_with_memory(&tgt, &mem).unwrap();
3046        for &v in output.data().unwrap() {
3047            assert!(
3048                v.is_finite(),
3049                "TransformerDecoder produced non-finite value: {v}"
3050            );
3051        }
3052    }
3053
3054    #[test]
3055    fn test_decoder_parameters_with_final_norm() {
3056        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3057        // Each decoder layer: 28 params (see test_decoder_layer_parameters_count)
3058        // Final norm: 2 params (weight + bias)
3059        // Total: 2 * 28 + 2 = 58
3060        assert_eq!(dec.parameters().len(), 58);
3061    }
3062
3063    #[test]
3064    fn test_decoder_named_parameters_have_layer_prefix() {
3065        let dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.0, 1e-5, true, true).unwrap();
3066        let named = dec.named_parameters();
3067        let has_layer_0 = named.iter().any(|(n, _)| n.starts_with("layers.0."));
3068        let has_layer_1 = named.iter().any(|(n, _)| n.starts_with("layers.1."));
3069        let has_norm = named.iter().any(|(n, _)| n.starts_with("norm."));
3070        assert!(has_layer_0, "missing layers.0.* in named_parameters");
3071        assert!(has_layer_1, "missing layers.1.* in named_parameters");
3072        assert!(has_norm, "missing norm.* in named_parameters");
3073    }
3074
3075    #[test]
3076    fn test_decoder_train_eval() {
3077        let mut dec = TransformerDecoder::<f32>::new(8, 2, 2, 16, 0.1, 1e-5, false, false).unwrap();
3078        assert!(dec.is_training());
3079        dec.eval();
3080        assert!(!dec.is_training());
3081        dec.train();
3082        assert!(dec.is_training());
3083    }
3084
3085    #[test]
3086    fn test_decoder_is_send_sync() {
3087        fn assert_send_sync<T: Send + Sync>() {}
3088        assert_send_sync::<TransformerDecoder<f32>>();
3089        assert_send_sync::<TransformerDecoder<f64>>();
3090    }
3091
3092    // -----------------------------------------------------------------------
3093    // Transformer (full encoder-decoder)
3094    // -----------------------------------------------------------------------
3095
3096    #[test]
3097    fn test_transformer_construction() {
3098        let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, true);
3099        assert!(t.is_ok());
3100        let t = t.unwrap();
3101        assert_eq!(t.num_encoder_layers(), 2);
3102        assert_eq!(t.num_decoder_layers(), 2);
3103    }
3104
3105    #[test]
3106    fn test_transformer_forward_shape() {
3107        let t = Transformer::<f32>::new(16, 4, 2, 2, 32, 0.0, 1e-5, false).unwrap();
3108        let src = ferrotorch_core::zeros::<f32>(&[2, 10, 16]).unwrap();
3109        let tgt = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
3110        let output = t.forward_transformer(&src, &tgt).unwrap();
3111        assert_eq!(output.shape(), &[2, 5, 16]);
3112    }
3113
3114    #[test]
3115    fn test_transformer_self_forward_shape() {
3116        // Module::forward uses input as both src and tgt.
3117        let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, false).unwrap();
3118        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
3119        let output = t.forward(&input).unwrap();
3120        assert_eq!(output.shape(), &[1, 3, 8]);
3121    }
3122
3123    #[test]
3124    fn test_transformer_forward_values_finite() {
3125        let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
3126        let src = ferrotorch_core::ones::<f32>(&[1, 4, 8]).unwrap();
3127        let tgt = ferrotorch_core::ones::<f32>(&[1, 3, 8]).unwrap();
3128        let output = t.forward_transformer(&src, &tgt).unwrap();
3129        for &v in output.data().unwrap() {
3130            assert!(v.is_finite(), "Transformer produced non-finite value: {v}");
3131        }
3132    }
3133
3134    #[test]
3135    fn test_transformer_parameters_count() {
3136        let t = Transformer::<f32>::new(8, 2, 2, 2, 16, 0.0, 1e-5, true).unwrap();
3137        // Encoder: 2 layers * 18 params + 2 (final norm) = 38
3138        // Decoder: 2 layers * 28 params + 2 (final norm) = 58
3139        // Total: 96
3140        assert_eq!(t.parameters().len(), 96);
3141    }
3142
3143    #[test]
3144    fn test_transformer_named_parameters_prefixed() {
3145        let t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.0, 1e-5, true).unwrap();
3146        let named = t.named_parameters();
3147        let has_encoder = named.iter().any(|(n, _)| n.starts_with("encoder."));
3148        let has_decoder = named.iter().any(|(n, _)| n.starts_with("decoder."));
3149        assert!(has_encoder, "missing encoder.* in named_parameters");
3150        assert!(has_decoder, "missing decoder.* in named_parameters");
3151    }
3152
3153    #[test]
3154    fn test_transformer_train_eval() {
3155        let mut t = Transformer::<f32>::new(8, 2, 1, 1, 16, 0.1, 1e-5, false).unwrap();
3156        assert!(t.is_training());
3157        t.eval();
3158        assert!(!t.is_training());
3159        t.train();
3160        assert!(t.is_training());
3161    }
3162
3163    #[test]
3164    fn test_transformer_is_send_sync() {
3165        fn assert_send_sync<T: Send + Sync>() {}
3166        assert_send_sync::<Transformer<f32>>();
3167        assert_send_sync::<Transformer<f64>>();
3168    }
3169}