// mlx_nn/attention.rs
1//! Multi-head attention module.
2
3use mlx_core::{Result, Tensor};
4
5use crate::{Linear, Module};
6
/// Multi-head attention.
///
/// Projects input through Q, K, V linear layers, splits into `n_heads` heads,
/// runs per-head scaled dot-product attention (via the fused `Attention` op),
/// concatenates heads, and applies an output projection.
pub struct MultiHeadAttention {
    // Number of attention heads the projections are split into.
    n_heads: usize,
    // Per-head feature dimension; inferred in `new` as wq's output dim / n_heads.
    head_dim: usize,
    // Query projection: weight `[n_heads * head_dim, model_dim]`.
    wq: Linear,
    // Key projection: same shape as `wq`.
    wk: Linear,
    // Value projection: same shape as `wq`.
    wv: Linear,
    // Output projection: weight `[model_dim, n_heads * head_dim]`.
    wo: Linear,
}
20
21impl MultiHeadAttention {
22 /// Create a new MultiHeadAttention from pre-built linear layers.
23 ///
24 /// - `wq`, `wk`, `wv`: projection layers with weight `[n_heads * head_dim, model_dim]`
25 /// - `wo`: output projection `[model_dim, n_heads * head_dim]`
26 /// - `n_heads`: number of attention heads
27 pub fn new(wq: Linear, wk: Linear, wv: Linear, wo: Linear, n_heads: usize) -> Self {
28 // Infer head_dim from wq weight shape: [n_heads * head_dim, model_dim]
29 let total_dim = wq.weight().shape().0[0] as usize;
30 let head_dim = total_dim / n_heads;
31 Self {
32 n_heads,
33 head_dim,
34 wq,
35 wk,
36 wv,
37 wo,
38 }
39 }
40
41 // TODO: add cross-attention variant that accepts separate key/value inputs.
42
43 /// Forward pass with causal masking (self-attention, auto-regressive).
44 ///
45 /// `x` has shape `[seq_len, model_dim]`.
46 /// Returns `[seq_len, model_dim]`.
47 pub fn forward_causal(&self, x: &Tensor) -> Result<Tensor> {
48 let seq_len = x.shape().0[0] as usize;
49 let scale = 1.0 / (self.head_dim as f32).sqrt();
50
51 // Project Q, K, V: [seq, model_dim] -> [seq, n_heads * head_dim]
52 let q = self.wq.forward(x)?;
53 let k = self.wk.forward(x)?;
54 let v = self.wv.forward(x)?;
55
56 // Reshape to [n_heads, seq, head_dim]
57 let q = q.reshape(&mlx_core::Shape::new(vec![
58 seq_len as i64,
59 self.n_heads as i64,
60 self.head_dim as i64,
61 ]))?;
62 let q = q.transpose(Some(&[1, 0, 2]))?; // [n_heads, seq, head_dim]
63
64 let k = k.reshape(&mlx_core::Shape::new(vec![
65 seq_len as i64,
66 self.n_heads as i64,
67 self.head_dim as i64,
68 ]))?;
69 let k = k.transpose(Some(&[1, 0, 2]))?;
70
71 let v = v.reshape(&mlx_core::Shape::new(vec![
72 seq_len as i64,
73 self.n_heads as i64,
74 self.head_dim as i64,
75 ]))?;
76 let v = v.transpose(Some(&[1, 0, 2]))?;
77
78 // Per-head attention using narrow to extract each head, then the fused Attention op.
79 // TODO: replace with a batched attention op to avoid O(n_heads) graph nodes.
80 let mut head_outputs = Vec::with_capacity(self.n_heads);
81 for h in 0..self.n_heads {
82 let q_h = q.narrow(0, h as i64, 1)?; // [1, seq, head_dim]
83 let q_h = q_h.reshape(&mlx_core::Shape::new(vec![
84 seq_len as i64,
85 self.head_dim as i64,
86 ]))?; // [seq, head_dim]
87
88 let k_h = k.narrow(0, h as i64, 1)?;
89 let k_h = k_h.reshape(&mlx_core::Shape::new(vec![
90 seq_len as i64,
91 self.head_dim as i64,
92 ]))?;
93
94 let v_h = v.narrow(0, h as i64, 1)?;
95 let v_h = v_h.reshape(&mlx_core::Shape::new(vec![
96 seq_len as i64,
97 self.head_dim as i64,
98 ]))?;
99
100 let attn_h = q_h.attention(&k_h, &v_h, scale, true)?; // [seq, head_dim]
101 head_outputs.push(attn_h);
102 }
103
104 // Concatenate heads: collect [seq, head_dim] tensors -> [seq, n_heads * head_dim]
105 let head_refs: Vec<&Tensor> = head_outputs.iter().collect();
106 let concat = Tensor::cat(&head_refs, 1)?; // [seq, n_heads * head_dim]
107
108 // Output projection
109 // wo expects [seq, n_heads * head_dim] -> [seq, model_dim]
110 self.wo.forward(&concat)
111 }
112}