impl MultiHeadAttention {
/// Create a new Multi-Head Attention layer with a configurable number of key/value heads
///
/// # Arguments
///
/// * `hidden_dim` - Total hidden dimension (must be divisible by `num_heads`)
/// * `num_heads` - Number of query heads
/// * `num_kv_heads` - Number of key/value heads (must divide `num_heads`)
///
/// # Modes
///
/// - MHA: `num_kv_heads = num_heads` (standard multi-head)
/// - MQA: `num_kv_heads = 1` (all heads share K/V)
/// - GQA: `1 < num_kv_heads < num_heads` (grouped heads)
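///
/// Reducing `num_kv_heads` shrinks the K and V projections (and any KV
/// cache built from them) by a factor of `num_heads / num_kv_heads`, at
/// the cost of query heads sharing the same keys and values.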
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if:
/// - `hidden_dim` is zero or not divisible by `num_heads`
/// - `num_heads` is zero or not divisible by `num_kv_heads`
/// - `num_kv_heads` is zero or greater than `num_heads`
///
/// # Examples
///
/// ```rust,ignore
/// // Standard Multi-Head Attention (MHA)
/// let mha = MultiHeadAttention::new(512, 8, 8)?;
///
/// // Multi-Query Attention (MQA)
/// let mqa = MultiHeadAttention::new(512, 8, 1)?;
///
/// // Grouped-Query Attention (GQA) - 4 heads per group
/// let gqa = MultiHeadAttention::new(512, 8, 2)?;
/// ```
pub fn new(hidden_dim: usize, num_heads: usize, num_kv_heads: usize) -> Result<Self> {
if hidden_dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "hidden_dim must be > 0".to_string(),
});
}
if num_heads == 0 {
return Err(RealizarError::InvalidShape {
reason: "num_heads must be > 0".to_string(),
});
}
if num_kv_heads == 0 {
return Err(RealizarError::InvalidShape {
reason: "num_kv_heads must be > 0".to_string(),
});
}
if num_kv_heads > num_heads {
return Err(RealizarError::InvalidShape {
reason: format!(
"num_kv_heads {num_kv_heads} cannot be greater than num_heads {num_heads}"
),
});
}
if !hidden_dim.is_multiple_of(num_heads) {
return Err(RealizarError::InvalidShape {
reason: format!(
"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
),
});
}
if !num_heads.is_multiple_of(num_kv_heads) {
return Err(RealizarError::InvalidShape {
reason: format!(
"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}"
),
});
}
let head_dim = hidden_dim / num_heads;
// Q projection: always hidden_dim -> hidden_dim (all query heads)
let q_proj = Linear::new(hidden_dim, hidden_dim)?;
// K/V projections: hidden_dim -> num_kv_heads * head_dim
let kv_dim = num_kv_heads * head_dim;
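// e.g. hidden_dim=512, num_heads=8, num_kv_heads=2: head_dim=64, kv_dim=128,
// so K and V each project 512 -> 128 rather than MHA's 512 -> 512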
let k_proj = Linear::new(hidden_dim, kv_dim)?;
let v_proj = Linear::new(hidden_dim, kv_dim)?;
// Output projection: hidden_dim -> hidden_dim
let o_proj = Linear::new(hidden_dim, hidden_dim)?;
// Per-head attention mechanism
let attention = Attention::new(head_dim)?;
Ok(Self {
num_heads,
num_kv_heads,
head_dim,
hidden_dim,
q_proj,
k_proj,
v_proj,
o_proj,
attention,
})
}
/// Create standard Multi-Head Attention (MHA) - each head has separate K/V
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if:
/// - `hidden_dim` is 0
/// - `num_heads` is 0
/// - `hidden_dim` is not divisible by `num_heads`
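///
/// # Examples
///
/// ```rust,ignore
/// // 8 query heads, 8 KV heads: equivalent to `new(512, 8, 8)`
/// let attn = MultiHeadAttention::mha(512, 8)?;
/// assert!(attn.is_mha());
/// ```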
pub fn mha(hidden_dim: usize, num_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, num_heads)
}
/// Create Multi-Query Attention (MQA) - all heads share K/V
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if:
/// - `hidden_dim` is 0
/// - `num_heads` is 0
/// - `hidden_dim` is not divisible by `num_heads`
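///
/// # Examples
///
/// ```rust,ignore
/// // 8 query heads sharing a single KV head: equivalent to `new(512, 8, 1)`
/// let attn = MultiHeadAttention::mqa(512, 8)?;
/// assert!(attn.is_mqa());
/// ```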
pub fn mqa(hidden_dim: usize, num_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, 1)
}
/// Create Grouped-Query Attention (GQA) - heads grouped to share K/V
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if:
/// - `hidden_dim` is 0
/// - `num_heads` is 0
/// - `num_kv_heads` is 0
/// - `num_kv_heads` is greater than `num_heads`
/// - `hidden_dim` is not divisible by `num_heads`
/// - `num_heads` is not divisible by `num_kv_heads`
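///
/// # Examples
///
/// ```rust,ignore
/// // 8 query heads in 2 groups of 4: equivalent to `new(512, 8, 2)`
/// let attn = MultiHeadAttention::gqa(512, 8, 2)?;
/// assert!(attn.is_gqa());
/// ```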
pub fn gqa(hidden_dim: usize, num_heads: usize, num_kv_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, num_kv_heads)
}
/// Forward pass through multi-head attention
///
/// # Arguments
///
/// * `input` - Input tensor `[seq_len, hidden_dim]`
///
/// # Returns
///
/// Output tensor `[seq_len, hidden_dim]`
///
/// # Errors
///
/// Returns `RealizarError::InvalidShape` if the input is not 2D or its
/// second dimension does not equal `hidden_dim`
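///
/// # Examples
///
/// A minimal sketch using the `Tensor::from_vec(shape, data)` constructor
/// already used in this module:
///
/// ```rust,ignore
/// let attn = MultiHeadAttention::gqa(64, 8, 2)?;
/// // 4 positions, 64 features each
/// let input = Tensor::from_vec(vec![4, 64], vec![0.1; 4 * 64])?;
/// let output = attn.forward(&input)?; // [4, 64]
/// ```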
pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let shape = input.shape();
if shape.len() != 2 {
return Err(RealizarError::InvalidShape {
reason: format!("Expected 2D tensor [seq_len, hidden_dim], got shape {shape:?}"),
});
}
let seq_len = shape[0];
let input_dim = shape[1];
if input_dim != self.hidden_dim {
return Err(RealizarError::InvalidShape {
reason: format!("Expected hidden_dim={}, got {}", self.hidden_dim, input_dim),
});
}
// Project Q, K, V
let q = self.q_proj.forward(input)?; // [seq_len, hidden_dim]
let k = self.k_proj.forward(input)?; // [seq_len, kv_dim]
let v = self.v_proj.forward(input)?; // [seq_len, kv_dim]
// Work on raw buffers; per-head slices (Q is conceptually
// [seq_len, num_heads, head_dim]) are extracted in the loop below
let q_data = q.data();
let k_data = k.data();
let v_data = v.data();
// Heads per query-head group (GQA) and the packed K/V row width
let heads_per_group = self.num_heads / self.num_kv_heads;
let kv_dim = self.num_kv_heads * self.head_dim;
// Process each query head
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head_idx in 0..self.num_heads {
// Extract Q for this head
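// Q is row-major [seq_len, hidden_dim]; head `head_idx` occupies
// columns head_idx*head_dim .. (head_idx+1)*head_dim of every row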
let mut q_head_data = Vec::with_capacity(seq_len * self.head_dim);
for seq_idx in 0..seq_len {
let q_row_start = seq_idx * self.hidden_dim;
let head_start = q_row_start + head_idx * self.head_dim;
for offset in 0..self.head_dim {
q_head_data.push(q_data[head_start + offset]);
}
}
let q_head = Tensor::from_vec(vec![seq_len, self.head_dim], q_head_data)?;
// Determine which KV head this Q head uses (for GQA/MQA/MHA)
let kv_head_idx = head_idx / heads_per_group;
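// e.g. num_heads=8, num_kv_heads=2 => heads_per_group=4:
// query heads 0..=3 read KV head 0, query heads 4..=7 read KV head 1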
// Extract K, V for the corresponding KV head
let mut k_head_data = Vec::with_capacity(seq_len * self.head_dim);
let mut v_head_data = Vec::with_capacity(seq_len * self.head_dim);
for seq_idx in 0..seq_len {
let kv_row_start = seq_idx * kv_dim;
let kv_head_start = kv_row_start + kv_head_idx * self.head_dim;
for offset in 0..self.head_dim {
k_head_data.push(k_data[kv_head_start + offset]);
v_head_data.push(v_data[kv_head_start + offset]);
}
}
let k_head = Tensor::from_vec(vec![seq_len, self.head_dim], k_head_data)?;
let v_head = Tensor::from_vec(vec![seq_len, self.head_dim], v_head_data)?;
// Compute attention for this head
let head_output = self.attention.forward(&q_head, &k_head, &v_head)?;
head_outputs.push(head_output);
}
// Concatenate all head outputs: [seq_len, hidden_dim]
let mut concat_data = Vec::with_capacity(seq_len * self.hidden_dim);
for seq_idx in 0..seq_len {
for head_output in &head_outputs {
let head_output_data = head_output.data();
let head_row_start = seq_idx * self.head_dim;
for offset in 0..self.head_dim {
concat_data.push(head_output_data[head_row_start + offset]);
}
}
}
let concat = Tensor::from_vec(vec![seq_len, self.hidden_dim], concat_data)?;
// Output projection
self.o_proj.forward(&concat)
}
/// Get number of query heads
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Get number of key/value heads
#[must_use]
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
/// Get head dimension
#[must_use]
pub fn head_dim(&self) -> usize {
self.head_dim
}
/// Get hidden dimension
#[must_use]
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
/// Check if using Multi-Query Attention (MQA)
#[must_use]
pub fn is_mqa(&self) -> bool {
self.num_kv_heads == 1
}
/// Check if using Grouped-Query Attention (GQA)
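///
/// # Examples
///
/// ```rust,ignore
/// let attn = MultiHeadAttention::gqa(512, 8, 2)?;
/// assert!(attn.is_gqa() && !attn.is_mha() && !attn.is_mqa());
/// ```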
#[must_use]
pub fn is_gqa(&self) -> bool {
self.num_kv_heads > 1 && self.num_kv_heads < self.num_heads
}
/// Check if using standard Multi-Head Attention (MHA)
#[must_use]
pub fn is_mha(&self) -> bool {
self.num_kv_heads == self.num_heads
}
}