// mlx_nn/attention.rs
1//! Multi-head attention module.
2
3use mlx_core::{Result, Tensor};
4
5use crate::{Linear, Module};
6
/// Multi-head attention.
///
/// Projects input through Q, K, V linear layers, splits into `n_heads` heads,
/// runs per-head scaled dot-product attention (via the fused `Attention` op),
/// concatenates heads, and applies an output projection.
pub struct MultiHeadAttention {
    // Number of attention heads the Q/K/V projections are split into.
    n_heads: usize,
    // Per-head feature dimension; inferred in `new` as wq's output dim / n_heads.
    head_dim: usize,
    // Query projection, weight `[n_heads * head_dim, model_dim]`.
    wq: Linear,
    // Key projection, same shape convention as `wq`.
    wk: Linear,
    // Value projection, same shape convention as `wq`.
    wv: Linear,
    // Output projection, weight `[model_dim, n_heads * head_dim]`.
    wo: Linear,
}
20
21impl MultiHeadAttention {
22    /// Create a new MultiHeadAttention from pre-built linear layers.
23    ///
24    /// - `wq`, `wk`, `wv`: projection layers with weight `[n_heads * head_dim, model_dim]`
25    /// - `wo`: output projection `[model_dim, n_heads * head_dim]`
26    /// - `n_heads`: number of attention heads
27    pub fn new(wq: Linear, wk: Linear, wv: Linear, wo: Linear, n_heads: usize) -> Self {
28        // Infer head_dim from wq weight shape: [n_heads * head_dim, model_dim]
29        let total_dim = wq.weight().shape().0[0] as usize;
30        let head_dim = total_dim / n_heads;
31        Self {
32            n_heads,
33            head_dim,
34            wq,
35            wk,
36            wv,
37            wo,
38        }
39    }
40
41    // TODO: add cross-attention variant that accepts separate key/value inputs.
42
43    /// Forward pass with causal masking (self-attention, auto-regressive).
44    ///
45    /// `x` has shape `[seq_len, model_dim]`.
46    /// Returns `[seq_len, model_dim]`.
47    pub fn forward_causal(&self, x: &Tensor) -> Result<Tensor> {
48        let seq_len = x.shape().0[0] as usize;
49        let scale = 1.0 / (self.head_dim as f32).sqrt();
50
51        // Project Q, K, V: [seq, model_dim] -> [seq, n_heads * head_dim]
52        let q = self.wq.forward(x)?;
53        let k = self.wk.forward(x)?;
54        let v = self.wv.forward(x)?;
55
56        // Reshape to [n_heads, seq, head_dim]
57        let q = q.reshape(&mlx_core::Shape::new(vec![
58            seq_len as i64,
59            self.n_heads as i64,
60            self.head_dim as i64,
61        ]))?;
62        let q = q.transpose(Some(&[1, 0, 2]))?; // [n_heads, seq, head_dim]
63
64        let k = k.reshape(&mlx_core::Shape::new(vec![
65            seq_len as i64,
66            self.n_heads as i64,
67            self.head_dim as i64,
68        ]))?;
69        let k = k.transpose(Some(&[1, 0, 2]))?;
70
71        let v = v.reshape(&mlx_core::Shape::new(vec![
72            seq_len as i64,
73            self.n_heads as i64,
74            self.head_dim as i64,
75        ]))?;
76        let v = v.transpose(Some(&[1, 0, 2]))?;
77
78        // Per-head attention using narrow to extract each head, then the fused Attention op.
79        // TODO: replace with a batched attention op to avoid O(n_heads) graph nodes.
80        let mut head_outputs = Vec::with_capacity(self.n_heads);
81        for h in 0..self.n_heads {
82            let q_h = q.narrow(0, h as i64, 1)?; // [1, seq, head_dim]
83            let q_h = q_h.reshape(&mlx_core::Shape::new(vec![
84                seq_len as i64,
85                self.head_dim as i64,
86            ]))?; // [seq, head_dim]
87
88            let k_h = k.narrow(0, h as i64, 1)?;
89            let k_h = k_h.reshape(&mlx_core::Shape::new(vec![
90                seq_len as i64,
91                self.head_dim as i64,
92            ]))?;
93
94            let v_h = v.narrow(0, h as i64, 1)?;
95            let v_h = v_h.reshape(&mlx_core::Shape::new(vec![
96                seq_len as i64,
97                self.head_dim as i64,
98            ]))?;
99
100            let attn_h = q_h.attention(&k_h, &v_h, scale, true)?; // [seq, head_dim]
101            head_outputs.push(attn_h);
102        }
103
104        // Concatenate heads: collect [seq, head_dim] tensors -> [seq, n_heads * head_dim]
105        let head_refs: Vec<&Tensor> = head_outputs.iter().collect();
106        let concat = Tensor::cat(&head_refs, 1)?; // [seq, n_heads * head_dim]
107
108        // Output projection
109        // wo expects [seq, n_heads * head_dim] -> [seq, model_dim]
110        self.wo.forward(&concat)
111    }
112}