Skip to main content

ferrotorch_nn/
attention.rs

//! Multi-head attention layer.
//!
//! Implements scaled dot-product attention with multiple heads, following the
//! "Attention Is All You Need" paper (Vaswani et al., 2017). The forward pass
//! uses differentiable primitives from `ferrotorch_core`, so autograd handles
//! the backward pass automatically. The standalone reshape/repeat helpers at
//! the bottom of this module copy raw data and are not differentiable.
//!
//! # Grouped-Query Attention (GQA)
//!
//! When `num_kv_heads < num_heads`, keys and values share a smaller head
//! count than queries (Llama 3 uses 32 Q heads : 8 KV heads). The K and V
//! projections are sized `[num_kv_heads * head_dim, embed_dim]`, and each
//! KV head serves `group_size = num_heads / num_kv_heads` consecutive
//! Q heads. Inside `forward_qkv` each KV head is replicated with the
//! differentiable `expand` op before the attention matmul. The standalone
//! [`repeat_kv`] helper performs the same replication eagerly, without autograd.
//! Construct a GQA attention with [`MultiheadAttention::with_gqa`]; the default
//! [`MultiheadAttention::new`] preserves classical MHA (`num_kv_heads = num_heads`).
//!
//! ## REQ status (per `.design/ferrotorch-nn/attention.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | the `MultiheadAttention<T>` struct here mirrors upstream `activation.py:1089-1200`; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:194` + `ferrotorch-vision/src/models/vit.rs:20` |
//! | REQ-2 | SHIPPED | the `MultiheadAttention::new` constructor here (delegates to `with_gqa`) mirroring upstream `activation.py:1153-1188`; non-test consumer: re-export at `lib.rs:194` + `vit.rs:20` |
//! | REQ-3 | SHIPPED | the `MultiheadAttention::with_gqa` constructor here with `num_heads % num_kv_heads` validation; non-test consumer: re-export at `lib.rs:194` and `ferrotorch-llama/src/attention.rs:23` |
//! | REQ-4 | SHIPPED | the `forward_qkv` method here with 3-D / batch / seq shape validation; non-test consumer: re-export at `lib.rs:194` + `vit.rs:20` |
//! | REQ-5 | SHIPPED | the general-path attention body inside `forward_qkv` using `mm_differentiable`, `bmm_differentiable`, `softmax`, `mul`, `add` from `ferrotorch_core::grad_fns`; non-test consumer: `vit.rs:20`, `ferrotorch-llama/src/attention.rs:23` |
//! | REQ-6 | SHIPPED | the causal-mask construction inside `forward_qkv` (additive `-1e9` matrix `[1, seq_q, seq_k]` placed on the scores' device); non-test consumer: `vit.rs:20`, `ferrotorch-llama/src/attention.rs:23` |
//! | REQ-7 | SHIPPED | the `group_size > 1` branch inside `forward_qkv` using `expand` from `ferrotorch_core::grad_fns::shape`; non-test consumer: `ferrotorch-llama/src/attention.rs:23` |
//! | REQ-8 | SHIPPED | the `impl<T: Float> Module<T> for MultiheadAttention<T>` block here; non-test consumer: re-export at `lib.rs:194` |
//! | REQ-9 | SHIPPED | the `forward_2d` method here (GQA-rejection + fused-linear short circuit); non-test consumer: re-export at `lib.rs:194` |
//! | REQ-10 | SHIPPED | the `reshape_to_heads`, `transpose_heads_to_2d`, `repeat_kv` helpers here; non-test consumer: `ferrotorch-llama/src/attention.rs:23` imports `repeat_kv` and `reshape_to_heads` |
//! | REQ-11 | NOT-STARTED | parity-sweep runner arm for `nn.functional.scaled_dot_product_attention` not wired — blocker #1455 |
//! | REQ-12 | NOT-STARTED | parity-sweep runner arm for `nn.functional.multi_head_attention_forward` not wired — blocker #1455 |

35use ferrotorch_core::grad_fns::activation::softmax;
36use ferrotorch_core::grad_fns::arithmetic::{add, mul};
37use ferrotorch_core::grad_fns::linalg::{bmm_differentiable, mm_differentiable};
38use ferrotorch_core::grad_fns::shape::{expand, transpose_2d};
39use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
40
41use crate::init::{xavier_uniform, zeros};
42use crate::module::Module;
43use crate::parameter::Parameter;
44
/// Multi-head attention mechanism.
///
/// Computes scaled dot-product attention across `num_heads` parallel heads,
/// projecting queries, keys, and values through learned linear transformations.
/// With `num_kv_heads < num_heads` the layer performs grouped-query attention
/// (see [`MultiheadAttention::with_gqa`]).
///
/// # Shape contract
///
/// - Input: `[batch, seq_len, embed_dim]`
/// - Output: `[batch, seq_len, embed_dim]`
///
/// When keys/values are supplied separately (cross-attention via
/// [`MultiheadAttention::forward_qkv`]), they may have a different sequence
/// length from the query. The output always follows the query's length.
///
/// NOTE(review): the size fields below are `pub`, but the projection shapes
/// are derived from them at construction time. Mutating them afterwards
/// desynchronises the two.
///
/// # Example
///
/// ```ignore
/// let mha = MultiheadAttention::<f32>::new(64, 8, true)?;
/// let input = ferrotorch_core::randn::<f32>(&[2, 10, 64])?;
/// let output = mha.forward(&input)?;
/// assert_eq!(output.shape(), &[2, 10, 64]);
/// ```
#[derive(Debug)]
pub struct MultiheadAttention<T: Float> {
    /// Model width: the last dimension of every input and of the output.
    pub embed_dim: usize,
    /// Number of query heads. The constructors require `embed_dim` to be
    /// divisible by it.
    pub num_heads: usize,
    /// Number of key/value heads. Equals `num_heads` for classical MHA;
    /// less than `num_heads` (and a divisor thereof) for grouped-query
    /// attention.
    pub num_kv_heads: usize,
    /// Width of each head: `embed_dim / num_heads`.
    pub head_dim: usize,

    /// Query projection weight: `[embed_dim, embed_dim]`.
    ///
    /// Every projection weight in this struct is stored as
    /// `[out_features, in_features]` and applied as `x @ W^T`.
    pub q_proj: Parameter<T>,
    /// Key projection weight: `[num_kv_heads * head_dim, embed_dim]`.
    pub k_proj: Parameter<T>,
    /// Value projection weight: `[num_kv_heads * head_dim, embed_dim]`.
    pub v_proj: Parameter<T>,
    /// Output projection weight: `[embed_dim, embed_dim]`.
    pub out_proj: Parameter<T>,

    /// Optional query bias: `[embed_dim]`. All four biases are either present
    /// together or absent together, controlled by the constructor's `bias` flag.
    pub q_bias: Option<Parameter<T>>,
    /// Optional key bias: `[num_kv_heads * head_dim]`.
    pub k_bias: Option<Parameter<T>>,
    /// Optional value bias: `[num_kv_heads * head_dim]`.
    pub v_bias: Option<Parameter<T>>,
    /// Optional output bias: `[embed_dim]`.
    pub out_bias: Option<Parameter<T>>,

    /// Training-mode flag, toggled by `Module::train` / `Module::eval`.
    /// NOTE(review): the forward pass in this file does not read it (there is
    /// no dropout here). Confirm whether other code depends on it.
    pub training: bool,
}

91impl<T: Float> MultiheadAttention<T> {
92    /// Create a new classical multi-head attention layer
93    /// (`num_kv_heads == num_heads`).
94    ///
95    /// # Arguments
96    ///
97    /// - `embed_dim` - Total embedding dimension (must be divisible by `num_heads`).
98    /// - `num_heads` - Number of parallel attention heads.
99    /// - `bias` - Whether to include additive bias in projections.
100    ///
101    /// # Errors
102    ///
103    /// Returns `FerrotorchError::InvalidArgument` if `embed_dim % num_heads != 0`.
104    pub fn new(embed_dim: usize, num_heads: usize, bias: bool) -> FerrotorchResult<Self> {
105        Self::with_gqa(embed_dim, num_heads, num_heads, bias)
106    }
107
108    /// Create a grouped-query (or classical) attention layer.
109    ///
110    /// When `num_kv_heads < num_heads`, each KV head serves
111    /// `group_size = num_heads / num_kv_heads` consecutive query heads.
112    /// `num_kv_heads == num_heads` reproduces classical MHA.
113    ///
114    /// # Arguments
115    ///
116    /// - `embed_dim` - Total embedding dimension (must be divisible by `num_heads`).
117    /// - `num_heads` - Number of parallel query heads.
118    /// - `num_kv_heads` - Number of key/value heads. Must divide `num_heads` evenly.
119    /// - `bias` - Whether to include additive bias in projections.
120    ///
121    /// # Errors
122    ///
123    /// - `embed_dim == 0` or either head count is zero.
124    /// - `embed_dim % num_heads != 0`.
125    /// - `num_heads % num_kv_heads != 0`.
126    pub fn with_gqa(
127        embed_dim: usize,
128        num_heads: usize,
129        num_kv_heads: usize,
130        bias: bool,
131    ) -> FerrotorchResult<Self> {
132        if embed_dim == 0 || num_heads == 0 || num_kv_heads == 0 {
133            return Err(FerrotorchError::InvalidArgument {
134                message: "embed_dim, num_heads, num_kv_heads must be positive".into(),
135            });
136        }
137        if embed_dim % num_heads != 0 {
138            return Err(FerrotorchError::InvalidArgument {
139                message: format!(
140                    "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
141                ),
142            });
143        }
144        if num_heads % num_kv_heads != 0 {
145            return Err(FerrotorchError::InvalidArgument {
146                message: format!(
147                    "num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
148                ),
149            });
150        }
151
152        let head_dim = embed_dim / num_heads;
153        let kv_dim = num_kv_heads * head_dim;
154
155        let mut q_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
156        let mut k_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
157        let mut v_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
158        let mut out_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
159
160        xavier_uniform(&mut q_proj)?;
161        xavier_uniform(&mut k_proj)?;
162        xavier_uniform(&mut v_proj)?;
163        xavier_uniform(&mut out_proj)?;
164
165        let (q_bias, k_bias, v_bias, out_bias) = if bias {
166            let mut qb = Parameter::zeros(&[embed_dim])?;
167            let mut kb = Parameter::zeros(&[kv_dim])?;
168            let mut vb = Parameter::zeros(&[kv_dim])?;
169            let mut ob = Parameter::zeros(&[embed_dim])?;
170            zeros(&mut qb)?;
171            zeros(&mut kb)?;
172            zeros(&mut vb)?;
173            zeros(&mut ob)?;
174            (Some(qb), Some(kb), Some(vb), Some(ob))
175        } else {
176            (None, None, None, None)
177        };
178
179        Ok(Self {
180            embed_dim,
181            num_heads,
182            num_kv_heads,
183            head_dim,
184            q_proj,
185            k_proj,
186            v_proj,
187            out_proj,
188            q_bias,
189            k_bias,
190            v_bias,
191            out_bias,
192            training: true,
193        })
194    }
195
196    /// Forward pass with separate query, key, and value tensors (cross-attention).
197    ///
198    /// # Arguments
199    ///
200    /// - `query` - `[batch, seq_q, embed_dim]`
201    /// - `key` - `[batch, seq_k, embed_dim]`
202    /// - `value` - `[batch, seq_k, embed_dim]`
203    /// - `causal_mask` - If `true`, apply a causal (lower-triangular) mask so that
204    ///   position `i` cannot attend to positions `j > i`. Only valid when
205    ///   `seq_q == seq_k`.
206    ///
207    /// # Returns
208    ///
209    /// Output tensor of shape `[batch, seq_q, embed_dim]`.
210    pub fn forward_qkv(
211        &self,
212        query: &Tensor<T>,
213        key: &Tensor<T>,
214        value: &Tensor<T>,
215        causal_mask: bool,
216    ) -> FerrotorchResult<Tensor<T>> {
217        // --- Validate input shapes ---
218        if query.ndim() != 3 || key.ndim() != 3 || value.ndim() != 3 {
219            return Err(FerrotorchError::InvalidArgument {
220                message: format!(
221                    "MultiheadAttention expects 3-D inputs [batch, seq, embed_dim], \
222                     got query {:?}, key {:?}, value {:?}",
223                    query.shape(),
224                    key.shape(),
225                    value.shape()
226                ),
227            });
228        }
229
230        let batch = query.shape()[0];
231        let seq_q = query.shape()[1];
232        let seq_k = key.shape()[1];
233
234        if query.shape()[2] != self.embed_dim
235            || key.shape()[2] != self.embed_dim
236            || value.shape()[2] != self.embed_dim
237        {
238            return Err(FerrotorchError::ShapeMismatch {
239                message: format!(
240                    "embed_dim mismatch: expected {}, got query={}, key={}, value={}",
241                    self.embed_dim,
242                    query.shape()[2],
243                    key.shape()[2],
244                    value.shape()[2]
245                ),
246            });
247        }
248
249        if key.shape()[0] != batch || value.shape()[0] != batch {
250            return Err(FerrotorchError::ShapeMismatch {
251                message: format!(
252                    "batch size mismatch: query batch={}, key batch={}, value batch={}",
253                    batch,
254                    key.shape()[0],
255                    value.shape()[0]
256                ),
257            });
258        }
259
260        if key.shape()[1] != value.shape()[1] {
261            return Err(FerrotorchError::ShapeMismatch {
262                message: format!(
263                    "key and value seq_len must match: key={}, value={}",
264                    key.shape()[1],
265                    value.shape()[1]
266                ),
267            });
268        }
269
270        if causal_mask && seq_q != seq_k {
271            return Err(FerrotorchError::InvalidArgument {
272                message: format!(
273                    "causal mask requires seq_q == seq_k, got seq_q={seq_q}, seq_k={seq_k}"
274                ),
275            });
276        }
277
278        // ─── Fast path: self-attention with seq_len=1 ───────────────
279        // When seq_len=1, attention scores are [1,1] -> softmax = 1.0,
280        // so context = V identically. The whole MHA reduces to:
281        //   output = V_proj(input) @ W_O^T + bias
282        // We skip Q/K projections (saves 2 matmuls) and the per-head
283        // attention loop (saves reshape, transpose, softmax per head).
284        //
285        // Gated on `num_kv_heads == num_heads` — the GQA case needs a
286        // repeat_kv step between V_proj and O_proj (V_proj outputs kv_dim,
287        // O_proj expects embed_dim), so we fall through to the general path.
288        if seq_q == 1 && seq_k == 1 && !causal_mask && self.num_kv_heads == self.num_heads {
289            use ferrotorch_core::grad_fns::linalg::linear_fused;
290
291            // Squeeze [batch, 1, embed_dim] -> [batch, embed_dim]
292            let v_2d = value.reshape_t(&[batch as isize, self.embed_dim as isize])?;
293
294            // V_proj = v_2d @ W_V^T + v_bias -> [batch, embed_dim]
295            let v_proj = linear_fused(
296                &v_2d,
297                self.v_proj.tensor(),
298                self.v_bias.as_ref().map(|b| b.tensor()),
299            )?;
300
301            // O_proj = v_proj @ W_O^T + o_bias -> [batch, embed_dim]
302            let output = linear_fused(
303                &v_proj,
304                self.out_proj.tensor(),
305                self.out_bias.as_ref().map(|b| b.tensor()),
306            )?;
307
308            // Unsqueeze [batch, embed_dim] -> [batch, 1, embed_dim]
309            return output.reshape_t(&[batch as isize, 1, self.embed_dim as isize]);
310        }
311
312        // ─── General path: batched multi-head attention ────────────
313        //
314        // Fully differentiable, GPU-compatible. Uses reshape/permute
315        // (zero-copy metadata ops) instead of data shuffling, and
316        // bmm_differentiable for the full batch at once.
317
318        let nh = self.num_heads;
319        let nkv = self.num_kv_heads;
320        let hd = self.head_dim;
321        let group_size = nh / nkv;
322
323        // 1. Project Q/K/V. Flatten to 2-D for the matmul.
324        let wq_t = transpose_2d(self.q_proj.tensor())?;
325        let wk_t = transpose_2d(self.k_proj.tensor())?;
326        let wv_t = transpose_2d(self.v_proj.tensor())?;
327        let wo_t = transpose_2d(self.out_proj.tensor())?;
328
329        let flat_q = query.reshape_t(&[-1, self.embed_dim as isize])?;
330        let flat_k = key.reshape_t(&[-1, self.embed_dim as isize])?;
331        let flat_v = value.reshape_t(&[-1, self.embed_dim as isize])?;
332
333        let mut q_proj = mm_differentiable(&flat_q, &wq_t)?;
334        let mut k_proj = mm_differentiable(&flat_k, &wk_t)?;
335        let mut v_proj = mm_differentiable(&flat_v, &wv_t)?;
336
337        if let Some(ref qb) = self.q_bias {
338            let b = expand_bias_to_2d(qb.tensor(), batch * seq_q)?;
339            q_proj = add(&q_proj, &b)?;
340        }
341        if let Some(ref kb) = self.k_bias {
342            let b = expand_bias_to_2d(kb.tensor(), batch * seq_k)?;
343            k_proj = add(&k_proj, &b)?;
344        }
345        if let Some(ref vb) = self.v_bias {
346            let b = expand_bias_to_2d(vb.tensor(), batch * seq_k)?;
347            v_proj = add(&v_proj, &b)?;
348        }
349
350        // 2. Reshape to per-head layout via permute (zero-copy + contiguous).
351        //    Q: [B*Sq, D] → [B, Sq, H, Hd] → [B, H, Sq, Hd] → [B*H, Sq, Hd]
352        //    K/V: same but with Hkv instead of H.
353        let q = q_proj
354            .reshape_t(&[batch as isize, seq_q as isize, nh as isize, hd as isize])?
355            .permute(&[0, 2, 1, 3])?
356            .contiguous()?
357            .reshape_t(&[(batch * nh) as isize, seq_q as isize, hd as isize])?;
358
359        let mut k = k_proj
360            .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
361            .permute(&[0, 2, 1, 3])?
362            .contiguous()?;
363        let mut v = v_proj
364            .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
365            .permute(&[0, 2, 1, 3])?
366            .contiguous()?;
367
368        // 3. GQA repeat: expand each KV head to serve `group_size` Q heads.
369        if group_size > 1 {
370            // [B, Hkv, S, Hd] → [B, Hkv, 1, S, Hd] → expand [B, Hkv, G, S, Hd]
371            // → reshape [B, H, S, Hd]
372            k = k.reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
373            k = expand(&k, &[batch, nkv, group_size, seq_k, hd])?;
374            k = k.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
375
376            v = v.reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
377            v = expand(&v, &[batch, nkv, group_size, seq_k, hd])?;
378            v = v.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
379        }
380
381        let k = k.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
382        let v = v.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
383
384        // 4. Scaled dot-product attention.
385        //    scores = Q @ K^T → [B*H, Sq, Sk]
386        let k_t = k.permute(&[0, 2, 1])?.contiguous()?;
387        let scores = bmm_differentiable(&q, &k_t)?;
388
389        let scale_val = T::from(1.0 / (hd as f64).sqrt()).unwrap();
390        let scale_tensor = Tensor::from_storage(
391            TensorStorage::on_device(vec![scale_val], scores.device())?,
392            vec![1],
393            false,
394        )?;
395        let scaled = mul(&scores, &scale_tensor)?;
396
397        // 5. Causal mask: additive -1e9 for future positions.
398        let masked = if causal_mask {
399            let neg_inf = T::from(-1e9).unwrap();
400            let zero = <T as num_traits::Zero>::zero();
401            let mut mask_data = vec![zero; seq_q * seq_k];
402            for i in 0..seq_q {
403                for j in (i + 1)..seq_k {
404                    mask_data[i * seq_k + j] = neg_inf;
405                }
406            }
407            let mask =
408                Tensor::from_storage(TensorStorage::cpu(mask_data), vec![1, seq_q, seq_k], false)?;
409            let mask = if scaled.is_cuda() {
410                mask.to(scaled.device())?
411            } else {
412                mask
413            };
414            add(&scaled, &mask)?
415        } else {
416            scaled
417        };
418
419        // 6. Softmax + context.
420        let weights = softmax(&masked)?;
421        let context = bmm_differentiable(&weights, &v)?;
422
423        // 7. Reshape back: [B*H, Sq, Hd] → [B, H, Sq, Hd] → [B, Sq, H, Hd] → [B*Sq, D]
424        let context = context
425            .reshape_t(&[batch as isize, nh as isize, seq_q as isize, hd as isize])?
426            .permute(&[0, 2, 1, 3])?
427            .contiguous()?
428            .reshape_t(&[(batch * seq_q) as isize, self.embed_dim as isize])?;
429
430        // 8. Output projection.
431        let mut output = mm_differentiable(&context, &wo_t)?;
432        if let Some(ref ob) = self.out_bias {
433            let b = expand_bias_to_2d(ob.tensor(), batch * seq_q)?;
434            output = add(&output, &b)?;
435        }
436
437        output.reshape_t(&[batch as isize, seq_q as isize, self.embed_dim as isize])
438    }
439
440    /// The embedding dimension.
441    #[inline]
442    pub fn embed_dim(&self) -> usize {
443        self.embed_dim
444    }
445
446    /// The number of query attention heads.
447    #[inline]
448    pub fn num_heads(&self) -> usize {
449        self.num_heads
450    }
451
452    /// The number of key/value heads (equal to `num_heads` for classical MHA,
453    /// less for grouped-query attention).
454    #[inline]
455    pub fn num_kv_heads(&self) -> usize {
456        self.num_kv_heads
457    }
458
459    /// The dimension of each attention head.
460    #[inline]
461    pub fn head_dim(&self) -> usize {
462        self.head_dim
463    }
464
465    /// Whether this layer is configured for grouped-query attention
466    /// (`num_kv_heads < num_heads`).
467    #[inline]
468    pub fn is_gqa(&self) -> bool {
469        self.num_kv_heads != self.num_heads
470    }
471
472    /// Fast 2D self-attention for seq_len=1: [batch, embed_dim] -> [batch, embed_dim].
473    /// Avoids unsqueeze/squeeze overhead. For seq_len=1, attention is identity on V,
474    /// so this is just V_proj + O_proj (two fused linear ops).
475    ///
476    /// # Errors
477    ///
478    /// Returns `InvalidArgument` when called on a GQA layer
479    /// (`num_kv_heads != num_heads`): the V/O shapes no longer match and a
480    /// `repeat_kv` step is required. Use [`forward_qkv`] for GQA.
481    pub fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
482        use ferrotorch_core::grad_fns::linalg::linear_fused;
483
484        if self.is_gqa() {
485            return Err(FerrotorchError::InvalidArgument {
486                message:
487                    "forward_2d is MHA-only; use forward_qkv for GQA (num_kv_heads != num_heads)"
488                        .into(),
489            });
490        }
491
492        let v_proj = linear_fused(
493            input,
494            self.v_proj.tensor(),
495            self.v_bias.as_ref().map(|b| b.tensor()),
496        )?;
497        linear_fused(
498            &v_proj,
499            self.out_proj.tensor(),
500            self.out_bias.as_ref().map(|b| b.tensor()),
501        )
502    }
503}
504
505impl<T: Float> Module<T> for MultiheadAttention<T> {
506    /// Self-attention forward: query = key = value = input.
507    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
508        self.forward_qkv(input, input, input, false)
509    }
510
511    fn parameters(&self) -> Vec<&Parameter<T>> {
512        let mut params = vec![&self.q_proj, &self.k_proj, &self.v_proj, &self.out_proj];
513        if let Some(ref b) = self.q_bias {
514            params.push(b);
515        }
516        if let Some(ref b) = self.k_bias {
517            params.push(b);
518        }
519        if let Some(ref b) = self.v_bias {
520            params.push(b);
521        }
522        if let Some(ref b) = self.out_bias {
523            params.push(b);
524        }
525        params
526    }
527
528    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
529        let mut params: Vec<&mut Parameter<T>> = vec![
530            &mut self.q_proj,
531            &mut self.k_proj,
532            &mut self.v_proj,
533            &mut self.out_proj,
534        ];
535        if let Some(ref mut b) = self.q_bias {
536            params.push(b);
537        }
538        if let Some(ref mut b) = self.k_bias {
539            params.push(b);
540        }
541        if let Some(ref mut b) = self.v_bias {
542            params.push(b);
543        }
544        if let Some(ref mut b) = self.out_bias {
545            params.push(b);
546        }
547        params
548    }
549
550    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
551        let mut params = vec![
552            ("q_proj.weight".to_string(), &self.q_proj),
553            ("k_proj.weight".to_string(), &self.k_proj),
554            ("v_proj.weight".to_string(), &self.v_proj),
555            ("out_proj.weight".to_string(), &self.out_proj),
556        ];
557        if let Some(ref b) = self.q_bias {
558            params.push(("q_proj.bias".to_string(), b));
559        }
560        if let Some(ref b) = self.k_bias {
561            params.push(("k_proj.bias".to_string(), b));
562        }
563        if let Some(ref b) = self.v_bias {
564            params.push(("v_proj.bias".to_string(), b));
565        }
566        if let Some(ref b) = self.out_bias {
567            params.push(("out_proj.bias".to_string(), b));
568        }
569        params
570    }
571
572    fn train(&mut self) {
573        self.training = true;
574    }
575
576    fn eval(&mut self) {
577        self.training = false;
578    }
579
580    fn is_training(&self) -> bool {
581        self.training
582    }
583}
584
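// ---------------------------------------------------------------------------
// Helpers (`expand_bias_to_2d` is private to this module; `reshape_to_heads`,
// `transpose_heads_to_2d` and `repeat_kv` are public)
// ---------------------------------------------------------------------------
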
589/// Expand a 1-D bias `[dim]` to `[rows, dim]` by repeating it along rows.
590///
591/// Uses the differentiable `expand` primitive so that gradients flow back
592/// to the original bias parameter through `ExpandBackward`.
593fn expand_bias_to_2d<T: Float>(bias: &Tensor<T>, rows: usize) -> FerrotorchResult<Tensor<T>> {
594    let dim = bias.shape()[0];
595    // Reshape [dim] -> [1, dim], then expand to [rows, dim].
596    let bias_2d = bias.reshape_t(&[1, dim as isize])?;
597    expand(&bias_2d, &[rows, dim])
598}
599
600/// Reshape `[seq, embed_dim]` to `[num_heads, seq, head_dim]`.
601///
602/// Conceptually: `[seq, num_heads * head_dim]` -> `[seq, num_heads, head_dim]`
603/// -> transpose(0,1) -> `[num_heads, seq, head_dim]`.
604///
605/// Since we lack a general N-D transpose, we do this with explicit data shuffling.
606pub fn reshape_to_heads<T: Float>(
607    tensor: &Tensor<T>,
608    num_heads: usize,
609    seq_len: usize,
610    head_dim: usize,
611) -> FerrotorchResult<Tensor<T>> {
612    let data = tensor.data()?;
613    // data layout: [seq_len, embed_dim] where embed_dim = num_heads * head_dim
614    // Interpret as [seq_len, num_heads, head_dim], then transpose to [num_heads, seq_len, head_dim]
615    let mut result = vec![<T as num_traits::Zero>::zero(); num_heads * seq_len * head_dim];
616
617    for s in 0..seq_len {
618        for h in 0..num_heads {
619            for d in 0..head_dim {
620                let src_idx = s * (num_heads * head_dim) + h * head_dim + d;
621                let dst_idx = h * (seq_len * head_dim) + s * head_dim + d;
622                result[dst_idx] = data[src_idx];
623            }
624        }
625    }
626
627    Tensor::from_storage(
628        TensorStorage::cpu(result),
629        vec![num_heads, seq_len, head_dim],
630        tensor.requires_grad(),
631    )
632}
633
634/// Transpose [num_heads, seq, head_dim] → [seq, num_heads * head_dim] = [seq, embed_dim].
635///
636/// Inverse of `reshape_to_heads` for the batched attention output.
637pub fn transpose_heads_to_2d<T: Float>(
638    tensor: &Tensor<T>,
639    num_heads: usize,
640    seq_len: usize,
641    head_dim: usize,
642) -> FerrotorchResult<Tensor<T>> {
643    let embed_dim = num_heads * head_dim;
644    let data = tensor.data_vec()?;
645    let mut result = vec![<T as num_traits::Zero>::zero(); seq_len * embed_dim];
646
647    for h in 0..num_heads {
648        for s in 0..seq_len {
649            for d in 0..head_dim {
650                let src_idx = h * (seq_len * head_dim) + s * head_dim + d;
651                let dst_idx = s * embed_dim + h * head_dim + d;
652                result[dst_idx] = data[src_idx];
653            }
654        }
655    }
656
657    let device = tensor.device();
658    Tensor::from_storage(
659        TensorStorage::on_device(result, device)?,
660        vec![seq_len, embed_dim],
661        false,
662    )
663}
664
665/// Repeat each KV head `group_size` times along the head axis to match
666/// the query-head count for grouped-query attention.
667///
668/// Input:  `[num_kv_heads, seq, head_dim]`
669/// Output: `[num_kv_heads * group_size, seq, head_dim]`
670///
671/// For output head `h`, the slice is copied from input head `h / group_size`.
672/// This matches the standard GQA convention where each KV head serves
673/// `group_size` consecutive query heads.
674///
675/// `group_size == 1` is a fast no-op clone (classical MHA path pays
676/// nothing).
677///
678/// Note: like the other reshape helpers in this module, this breaks the
679/// autograd graph — it is correct for inference but training-broken. A
680/// fully-differentiable variant would require a `RepeatKvBackward` op
681/// that sums gradients across replicated groups.
682pub fn repeat_kv<T: Float>(kv: &Tensor<T>, group_size: usize) -> FerrotorchResult<Tensor<T>> {
683    if group_size == 1 {
684        return Ok(kv.clone());
685    }
686    let shape = kv.shape();
687    if shape.len() != 3 {
688        return Err(FerrotorchError::ShapeMismatch {
689            message: format!(
690                "repeat_kv expects 3-D [num_kv_heads, seq, head_dim], got {:?}",
691                shape
692            ),
693        });
694    }
695    let num_kv_heads = shape[0];
696    let seq = shape[1];
697    let head_dim = shape[2];
698    let num_q_heads = num_kv_heads * group_size;
699    let data = kv.data_vec()?;
700    let head_stride = seq * head_dim;
701    let mut out = vec![<T as num_traits::Zero>::zero(); num_q_heads * head_stride];
702    for h in 0..num_q_heads {
703        let kv_h = h / group_size;
704        let src_start = kv_h * head_stride;
705        let dst_start = h * head_stride;
706        out[dst_start..dst_start + head_stride]
707            .copy_from_slice(&data[src_start..src_start + head_stride]);
708    }
709    let device = kv.device();
710    Tensor::from_storage(
711        TensorStorage::on_device(out, device)?,
712        vec![num_q_heads, seq, head_dim],
713        kv.requires_grad(),
714    )
715}
716
717// ===========================================================================
718// Tests
719// ===========================================================================
720
#[cfg(test)]
mod tests {
    // Unit tests for `MultiheadAttention` (classical MHA and GQA) and the
    // `repeat_kv` helper used by the GQA path.
    use super::*;

    #[test]
    fn test_new_valid() {
        let mha = MultiheadAttention::<f32>::new(64, 8, true);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.embed_dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    #[test]
    fn test_new_invalid_divisibility() {
        // 65 is not divisible by 8 heads.
        let result = MultiheadAttention::<f32>::new(65, 8, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_zero_dims() {
        assert!(MultiheadAttention::<f32>::new(0, 4, false).is_err());
        assert!(MultiheadAttention::<f32>::new(64, 0, false).is_err());
    }

    #[test]
    fn test_parameter_count_with_bias() {
        let mha = MultiheadAttention::<f32>::new(16, 4, true).unwrap();
        let params = mha.parameters();
        // 4 weight matrices: 4 * 16 * 16 = 1024
        // 4 bias vectors: 4 * 16 = 64
        // Total params: 1088
        let total: usize = params.iter().map(|p| p.numel()).sum();
        let embed_dim = 16usize;
        let expected = 4 * embed_dim * embed_dim + 4 * embed_dim;
        assert_eq!(total, expected);
        assert_eq!(params.len(), 8); // 4 weights + 4 biases
    }

    #[test]
    fn test_parameter_count_without_bias() {
        let mha = MultiheadAttention::<f32>::new(16, 4, false).unwrap();
        let params = mha.parameters();
        let total: usize = params.iter().map(|p| p.numel()).sum();
        let embed_dim = 16usize;
        let expected = 4 * embed_dim * embed_dim;
        assert_eq!(total, expected);
        assert_eq!(params.len(), 4); // 4 weights only
    }

    #[test]
    fn test_named_parameters() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let named = mha.named_parameters();
        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
        assert!(names.contains(&"q_proj.weight"));
        assert!(names.contains(&"k_proj.weight"));
        assert!(names.contains(&"v_proj.weight"));
        assert!(names.contains(&"out_proj.weight"));
        assert!(names.contains(&"q_proj.bias"));
        assert!(names.contains(&"k_proj.bias"));
        assert!(names.contains(&"v_proj.bias"));
        assert!(names.contains(&"out_proj.bias"));
    }

    #[test]
    fn test_output_shape() {
        let mha = MultiheadAttention::<f32>::new(16, 4, true).unwrap();
        // Input: [batch=2, seq_len=5, embed_dim=16]
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = mha.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_output_shape_no_bias() {
        let mha = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = mha.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_self_attention_basic_forward() {
        // Use a small model to verify forward pass produces finite values.
        let mha = MultiheadAttention::<f64>::new(4, 2, true).unwrap();
        let input = ferrotorch_core::ones::<f64>(&[1, 2, 4]).unwrap();
        let output = mha.forward(&input).unwrap();

        assert_eq!(output.shape(), &[1, 2, 4]);
        let data = output.data().unwrap();
        // All values should be finite (not NaN, not Inf).
        for &v in data {
            assert!(v.is_finite(), "output contains non-finite value: {v}");
        }
    }

    #[test]
    fn test_cross_attention_shape() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        // query: [1, 3, 8], key/value: [1, 5, 8]
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        let output = mha.forward_qkv(&query, &kv, &kv, false).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_causal_mask_different_seq_lens_error() {
        let mha = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        // Causal mask requires seq_q == seq_k.
        let result = mha.forward_qkv(&query, &kv, &kv, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_train_eval_toggle() {
        let mut mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        assert!(mha.is_training());
        mha.eval();
        assert!(!mha.is_training());
        mha.train();
        assert!(mha.is_training());
    }

    #[test]
    fn test_wrong_embed_dim_input() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        // Wrong embed_dim: 4 instead of 8.
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 4]).unwrap();
        let result = mha.forward(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_2d_input_rejected() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap();
        let result = mha.forward(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MultiheadAttention<f32>>();
        assert_send_sync::<MultiheadAttention<f64>>();
    }

    // -- Grouped-Query Attention tests (#505) ---------------------------

    #[test]
    fn test_with_gqa_valid_construction() {
        // Llama 3 8B layout: 32 query heads, 8 KV heads, head_dim=128.
        let mha = MultiheadAttention::<f32>::with_gqa(4096, 32, 8, false).unwrap();
        assert_eq!(mha.embed_dim(), 4096);
        assert_eq!(mha.num_heads(), 32);
        assert_eq!(mha.num_kv_heads(), 8);
        assert_eq!(mha.head_dim(), 128);
        assert!(mha.is_gqa());
    }

    #[test]
    fn test_with_gqa_kv_proj_shapes() {
        // K/V projections must output `num_kv_heads * head_dim`, not `embed_dim`.
        let mha = MultiheadAttention::<f32>::with_gqa(64, 8, 2, true).unwrap();
        let kv_dim = 2 * (64 / 8); // num_kv_heads * head_dim = 16
        assert_eq!(mha.q_proj.shape(), &[64, 64]);
        assert_eq!(mha.k_proj.shape(), &[kv_dim, 64]);
        assert_eq!(mha.v_proj.shape(), &[kv_dim, 64]);
        assert_eq!(mha.out_proj.shape(), &[64, 64]);
        // Biases follow the same split.
        assert_eq!(mha.q_bias.as_ref().unwrap().shape(), &[64]);
        assert_eq!(mha.k_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(mha.v_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(mha.out_bias.as_ref().unwrap().shape(), &[64]);
    }

    #[test]
    fn test_with_gqa_rejects_non_divisible_kv_heads() {
        // num_heads=8, num_kv_heads=3 → 8 % 3 != 0.
        let result = MultiheadAttention::<f32>::with_gqa(64, 8, 3, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_with_gqa_rejects_more_kv_heads_than_q_heads() {
        // num_heads=8, num_kv_heads=16 → 8 % 16 != 0, so the
        // `num_heads % num_kv_heads` divisibility check must reject it.
        let result = MultiheadAttention::<f32>::with_gqa(64, 8, 16, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_with_gqa_rejects_zero_kv_heads() {
        let result = MultiheadAttention::<f32>::with_gqa(64, 8, 0, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_with_gqa_equivalent_to_new_when_kv_equals_q() {
        // Passing num_kv_heads == num_heads must reproduce classical MHA shapes.
        let gqa = MultiheadAttention::<f32>::with_gqa(32, 4, 4, true).unwrap();
        let mha = MultiheadAttention::<f32>::new(32, 4, true).unwrap();
        assert_eq!(gqa.num_kv_heads(), mha.num_kv_heads());
        assert_eq!(gqa.k_proj.shape(), mha.k_proj.shape());
        assert_eq!(gqa.v_proj.shape(), mha.v_proj.shape());
        assert!(!gqa.is_gqa());
    }

    #[test]
    fn test_repeat_kv_noop_on_group_size_1() {
        // group_size=1 is the classical MHA path: the output must have the
        // same shape and the same data as the input.
        let kv = ferrotorch_core::from_slice::<f32>(
            &(0..24).map(|i| i as f32).collect::<Vec<_>>(),
            &[2, 3, 4], // [num_kv_heads=2, seq=3, head_dim=4]
        )
        .unwrap();
        let out = repeat_kv(&kv, 1).unwrap();
        assert_eq!(out.shape(), kv.shape());
        assert_eq!(out.data_vec().unwrap(), kv.data_vec().unwrap());
    }

    #[test]
    fn test_repeat_kv_copies_correct_heads() {
        // Input: 2 KV heads, each a 2x3 block (seq=2, head_dim=3) with
        // distinct values per head. group_size=3 → 6 output heads.
        // Heads 0,1,2 should equal input head 0; heads 3,4,5 should equal
        // input head 1.
        let data: Vec<f32> = vec![
            10.0, 11.0, 12.0, // head 0, seq 0
            13.0, 14.0, 15.0, // head 0, seq 1
            20.0, 21.0, 22.0, // head 1, seq 0
            23.0, 24.0, 25.0, // head 1, seq 1
        ];
        let kv = ferrotorch_core::from_slice::<f32>(&data, &[2, 2, 3]).unwrap();
        let out = repeat_kv(&kv, 3).unwrap();
        assert_eq!(out.shape(), &[6, 2, 3]);
        let out_data = out.data_vec().unwrap();
        let head_stride = 2 * 3; // seq * head_dim
        // Heads 0, 1, 2 come from input head 0.
        for h in 0..3 {
            let start = h * head_stride;
            assert_eq!(&out_data[start..start + head_stride], &data[0..head_stride]);
        }
        // Heads 3, 4, 5 come from input head 1.
        for h in 3..6 {
            let start = h * head_stride;
            assert_eq!(
                &out_data[start..start + head_stride],
                &data[head_stride..2 * head_stride]
            );
        }
    }

    #[test]
    fn test_repeat_kv_rejects_wrong_rank() {
        let kv = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap(); // 2-D
        assert!(repeat_kv(&kv, 2).is_err());
    }

    #[test]
    fn test_gqa_forward_output_shape_preserved() {
        // GQA must return the same [batch, seq, embed_dim] shape as MHA.
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let out = mha.forward(&input).unwrap();
        assert_eq!(out.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_gqa_forward_produces_finite_values() {
        let mha = MultiheadAttention::<f64>::with_gqa(8, 4, 2, true).unwrap();
        let input = ferrotorch_core::ones::<f64>(&[1, 3, 8]).unwrap();
        let out = mha.forward(&input).unwrap();
        let data = out.data().unwrap();
        for &v in data {
            assert!(v.is_finite(), "GQA output non-finite: {v}");
        }
    }

    #[test]
    fn test_gqa_forward_decoder_style_single_token() {
        // Single-token forward (seq_q == seq_k == 1) must stay numerically
        // stable on the GQA path — this is the autoregressive generation
        // hot case for Llama inference.
        let mha = MultiheadAttention::<f32>::with_gqa(32, 8, 2, false).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[1, 1, 32]).unwrap();
        let out = mha.forward(&input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 32]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "GQA single-token output non-finite: {v}");
        }
    }

    #[test]
    fn test_gqa_forward_with_causal_mask() {
        // Causal masking must still work when K/V have fewer heads than Q.
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, false).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[1, 4, 16]).unwrap();
        let out = mha.forward_qkv(&x, &x, &x, true).unwrap();
        assert_eq!(out.shape(), &[1, 4, 16]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "GQA causal output non-finite: {v}");
        }
    }

    #[test]
    fn test_gqa_cross_attention_shape() {
        // Cross-attention (seq_q != seq_k) on the GQA path: the output
        // follows the query's sequence length, as in the MHA case.
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, true).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 16]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 16]).unwrap();
        let out = mha.forward_qkv(&query, &kv, &kv, false).unwrap();
        assert_eq!(out.shape(), &[1, 3, 16]);
    }

    #[test]
    fn test_gqa_causal_mask_different_seq_lens_error() {
        // The seq_q == seq_k requirement for causal masks applies to GQA too.
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, false).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 16]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 16]).unwrap();
        let result = mha.forward_qkv(&query, &kv, &kv, true);
        assert!(result.is_err());
    }
}