entrenar/train/transformer_trainer/
wgpu_attention.rs1#[cfg(feature = "gpu")]
9use trueno::backends::gpu::GpuDevice;
10
/// Normalises every head segment of `buf` in place to unit root-mean-square.
///
/// `buf` is row-major with `seq_len` rows of `total_dim` values; the segment for head `h` of
/// row `r` starts at `r * total_dim + h * head_dim` and is `head_dim` values long. Each segment
/// is divided by `sqrt(mean(x²) + 1e-6)`; no learned scale is applied.
///
/// # Panics
/// Panics if any segment extends past the end of `buf`.
#[cfg(feature = "gpu")]
fn head_rms_norm(
    buf: &mut [f32],
    seq_len: usize,
    n_heads: usize,
    total_dim: usize,
    head_dim: usize,
) {
    const EPS: f32 = 1e-6;
    // Start offset of every (row, head) segment, in row-major order.
    let segment_starts = (0..seq_len)
        .flat_map(|row| (0..n_heads).map(move |head| row * total_dim + head * head_dim));
    for start in segment_starts {
        let segment = &mut buf[start..start + head_dim];
        let mean_square = segment.iter().map(|x| x * x).sum::<f32>() / head_dim as f32;
        let rms = (mean_square + EPS).sqrt();
        for x in segment.iter_mut() {
            *x /= rms;
        }
    }
}
34
/// Caps the L2 norm of `output` relative to that of `reference`.
///
/// When `‖output‖ > max_ratio · ‖reference‖` and `‖reference‖ > 1e-6`, every element of
/// `output` is multiplied by `‖reference‖ / ‖output‖`, bringing its norm down to that of
/// `reference`. Otherwise `output` is left unchanged (this includes the case where either
/// norm is NaN, since the comparisons are then false).
#[cfg(feature = "gpu")]
fn norm_guard(output: &mut [f32], reference: &[f32], max_ratio: f32) {
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    let out_norm = l2_norm(output);
    let ref_norm = l2_norm(reference);
    let exceeds_cap = out_norm > ref_norm * max_ratio;
    if exceeds_cap && ref_norm > 1e-6 {
        let scale = ref_norm / out_norm;
        output.iter_mut().for_each(|v| *v *= scale);
    }
}
47
/// Intermediate buffers produced by [`attention_forward`].
///
/// All buffers are row-major `f32`. NOTE(review): presumably consumed by the attention
/// backward pass — confirm against its caller.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone, Default)]
pub struct AttentionCache {
    /// Queries after LoRA, per-head RMS norm and RoPE: `[seq_len, num_heads * head_dim]`.
    pub q: Vec<f32>,
    /// Keys after per-head RMS norm and RoPE: `[seq_len, num_kv_heads * head_dim]`.
    pub k: Vec<f32>,
    /// Values after LoRA (no normalisation or RoPE): `[seq_len, num_kv_heads * head_dim]`.
    pub v: Vec<f32>,
    /// Causal softmax probabilities: `[num_heads, seq_len, seq_len]`, zero above the diagonal.
    pub attn_weights: Vec<f32>,
    /// Attention output before the O projection: `[seq_len, num_heads * head_dim]`.
    pub context: Vec<f32>,
    /// `hidden · Aᵀ` for the Q adapter (`[seq_len, rank]`); empty when no Q LoRA was applied.
    pub lora_q_h: Vec<f32>,
    /// `hidden · Aᵀ` for the V adapter (`[seq_len, rank]`); empty when it was not recorded.
    pub lora_v_h: Vec<f32>,
}
59
/// Runs one self-attention layer forward pass with LoRA adapters on the Q and V projections.
///
/// Pipeline:
/// 1. Base projections `hidden → Q, K, V` via `GpuDevice::matmul`.
/// 2. LoRA updates `(lora_alpha / rank) · (hidden · Aᵀ) · Bᵀ` added to Q (`lora_q`) and to
///    V (`lora_v`). Each adapter is gated on its *own* rank; a rank-0 adapter is skipped.
/// 3. Per-head RMS normalisation of Q and K (no learned scale, eps = 1e-6).
/// 4. Rotary position embeddings (base 10000) on adjacent `(even, odd)` pairs of each Q/K head.
/// 5. Causal softmax attention on the CPU; query head `h` reads KV head
///    `h / (num_heads / num_kv_heads)`.
/// 6. Output projection via `GpuDevice::matmul`, after which `norm_guard` rescales the result
///    if its L2 norm exceeds 10× that of `hidden`.
///
/// `hidden` is row-major `[seq_len, hidden_size]`. LoRA `A` matrices are read as
/// `[rank, hidden_size]` and `B` matrices as `[out_dim, rank]`. The projection weights are
/// passed straight to `GpuDevice::matmul(a, w, out, m, k, n)`. NOTE(review): their layout is
/// whatever that call expects (presumably `[k, n]` row-major); confirm against `trueno`.
///
/// # Returns
/// The attention output `[seq_len, hidden_size]` and an [`AttentionCache`] holding the
/// intermediates.
///
/// # Errors
/// Returns `Err` if `num_kv_heads` is zero or does not divide `num_heads`, or if any GPU
/// matmul fails.
#[cfg(feature = "gpu")]
// The public signature mirrors the layer configuration and is relied on by callers.
#[allow(clippy::too_many_arguments)]
pub fn attention_forward(
    device: &GpuDevice,
    hidden: &[f32],
    q_weight: &[f32],
    k_weight: &[f32],
    v_weight: &[f32],
    o_weight: &[f32],
    lora_q: &super::wgpu_nf4::LoraAdapter,
    lora_v: &super::wgpu_nf4::LoraAdapter,
    lora_alpha: f32,
    seq_len: u32,
    hidden_size: u32,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
) -> Result<(Vec<f32>, AttentionCache), String> {
    // Grouped-query attention maps each query head onto one KV head. That mapping is only
    // well-defined (and in bounds) when the KV head count is non-zero and divides the query
    // head count.
    if num_kv_heads == 0 || num_heads % num_kv_heads != 0 {
        return Err(format!(
            "attention_forward: num_kv_heads ({num_kv_heads}) must be non-zero and divide num_heads ({num_heads})"
        ));
    }

    let s = seq_len as usize;
    let h = hidden_size as usize;
    let hd = head_dim as usize;
    let nh = num_heads as usize;
    let nkv = num_kv_heads as usize;
    // Multiply in usize so large configurations cannot overflow u32.
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;

    let mut q = vec![0.0f32; s * q_dim];
    device.matmul(hidden, q_weight, &mut q, s, h, q_dim)?;
    let mut k = vec![0.0f32; s * kv_dim];
    device.matmul(hidden, k_weight, &mut k, s, h, kv_dim)?;
    let mut v = vec![0.0f32; s * kv_dim];
    device.matmul(hidden, v_weight, &mut v, s, h, kv_dim)?;

    // An empty activation buffer in the cache means "this adapter was not applied".
    let lora_q_h =
        apply_lora(hidden, lora_q, lora_alpha, &mut q, s, h, q_dim).unwrap_or_default();
    let lora_v_h =
        apply_lora(hidden, lora_v, lora_alpha, &mut v, s, h, kv_dim).unwrap_or_default();

    head_rms_norm(&mut q, s, nh, q_dim, hd);
    head_rms_norm(&mut k, s, nkv, kv_dim, hd);

    apply_rope(&mut q, s, nh, q_dim, hd);
    apply_rope(&mut k, s, nkv, kv_dim, hd);

    let (context, attn_weights) = causal_attention(&q, &k, &v, s, nh, nkv, hd);

    let mut output = vec![0.0f32; s * h];
    device.matmul(&context, o_weight, &mut output, s, q_dim, h)?;
    // Cap the output's L2 norm at 10x the input's.
    norm_guard(&mut output, hidden, 10.0);

    let cache = AttentionCache {
        q,
        k,
        v,
        attn_weights,
        context,
        lora_q_h,
        lora_v_h,
    };
    Ok((output, cache))
}

/// Row-major `x · wᵀ` on the CPU.
///
/// `x` is `[rows, inner]` and `w` is `[cols, inner]`; the result is `[rows, cols]`. Each output
/// element is accumulated sequentially over `inner`.
///
/// # Panics
/// Panics if `x` or `w` is shorter than the stated shape.
#[cfg(feature = "gpu")]
fn matmul_transposed(x: &[f32], w: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        let x_row = &x[r * inner..(r + 1) * inner];
        for c in 0..cols {
            let w_row = &w[c * inner..(c + 1) * inner];
            out.push(x_row.iter().zip(w_row).map(|(a, b)| a * b).sum::<f32>());
        }
    }
    out
}

/// Adds the LoRA update `(alpha / rank) · (x · Aᵀ) · Bᵀ` to `out` in place.
///
/// `x` is `[rows, in_dim]` and `out` is `[rows, out_dim]`. The adapter's `a` is read as
/// `[rank, in_dim]` and its `b` as `[out_dim, rank]`. Returns the intermediate `x · Aᵀ`
/// (`[rows, rank]`) so it can be stored in the cache. Returns `None` and leaves `out`
/// untouched when the adapter has rank 0; dividing `alpha` by a zero rank would otherwise
/// turn `out` into NaN.
#[cfg(feature = "gpu")]
fn apply_lora(
    x: &[f32],
    adapter: &super::wgpu_nf4::LoraAdapter,
    alpha: f32,
    out: &mut [f32],
    rows: usize,
    in_dim: usize,
    out_dim: usize,
) -> Option<Vec<f32>> {
    let rank = adapter.rank as usize;
    if rank == 0 {
        return None;
    }
    let scaling = alpha / adapter.rank as f32;
    let xa = matmul_transposed(x, &adapter.a, rows, in_dim, rank);
    let delta = matmul_transposed(&xa, &adapter.b, rows, rank, out_dim);
    for (o, d) in out.iter_mut().zip(&delta) {
        *o += scaling * d;
    }
    Some(xa)
}

/// Applies rotary position embeddings in place to every head of every row of `buf`.
///
/// `buf` is row-major with `seq_len` rows of `row_dim` values, each row holding `n_heads`
/// heads of `head_dim` values. Adjacent pairs `(2i, 2i + 1)` inside a head are rotated by
/// the angle `pos · 10000^(-2i / head_dim)`, where `pos` is the row index. With an odd
/// `head_dim`, the trailing unpaired element of each head is left unrotated instead of
/// being paired with the next head's first element.
#[cfg(feature = "gpu")]
fn apply_rope(buf: &mut [f32], seq_len: usize, n_heads: usize, row_dim: usize, head_dim: usize) {
    const ROPE_BASE: f32 = 10000.0;
    // A pair's frequency depends only on its offset within the head, so compute each one once.
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|d| 1.0 / ROPE_BASE.powf(d as f32 / head_dim as f32))
        .collect();

    for row in 0..seq_len {
        let pos = row as f32;
        for head in 0..n_heads {
            let start = row * row_dim + head * head_dim;
            let head_vals = &mut buf[start..start + head_dim];
            for (pair, &freq) in head_vals.chunks_exact_mut(2).zip(&inv_freq) {
                let (sin_val, cos_val) = (pos * freq).sin_cos();
                let (x0, x1) = (pair[0], pair[1]);
                pair[0] = x0 * cos_val - x1 * sin_val;
                pair[1] = x0 * sin_val + x1 * cos_val;
            }
        }
    }
}

/// Causal (lower-triangular) scaled-dot-product attention with grouped KV heads.
///
/// `q` is `[seq_len, n_heads * head_dim]`; `k` and `v` are
/// `[seq_len, n_kv_heads * head_dim]`. Query head `h` reads KV head
/// `h / (n_heads / n_kv_heads)`. Returns `(context, attn_weights)`: `context` is
/// `[seq_len, n_heads * head_dim]`, and `attn_weights` holds the softmax probabilities as
/// `[n_heads, seq_len, seq_len]` with zeros above the diagonal.
///
/// Callers must ensure `n_kv_heads` is non-zero and divides `n_heads`.
#[cfg(feature = "gpu")]
fn causal_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq_len: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    let q_dim = n_heads * head_dim;
    let kv_dim = n_kv_heads * head_dim;
    let heads_per_kv = n_heads / n_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let mut context = vec![0.0f32; seq_len * q_dim];
    let mut attn_weights = vec![0.0f32; n_heads * seq_len * seq_len];

    for head in 0..n_heads {
        let kv_col = (head / heads_per_kv) * head_dim;
        for qi in 0..seq_len {
            let q_off = qi * q_dim + head * head_dim;
            let q_vec = &q[q_off..q_off + head_dim];
            let row_off = (head * seq_len + qi) * seq_len;
            // Causal mask: only keys 0..=qi are scored; later entries keep their zero weight.
            let visible = &mut attn_weights[row_off..=row_off + qi];

            for (ki, score) in visible.iter_mut().enumerate() {
                let k_off = ki * kv_dim + kv_col;
                let k_vec = &k[k_off..k_off + head_dim];
                let dot: f32 = q_vec.iter().zip(k_vec).map(|(a, b)| a * b).sum();
                *score = dot * scale;
            }

            // Subtract the row maximum before exponentiating so exp() cannot overflow.
            let max_score = visible.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum_exp = 0.0f32;
            for w in visible.iter_mut() {
                *w = (*w - max_score).exp();
                sum_exp += *w;
            }
            if sum_exp > 0.0 {
                for w in visible.iter_mut() {
                    *w /= sum_exp;
                }
            }

            let ctx = &mut context[q_off..q_off + head_dim];
            for (ki, &w) in visible.iter().enumerate() {
                let v_off = ki * kv_dim + kv_col;
                for (c, &val) in ctx.iter_mut().zip(&v[v_off..v_off + head_dim]) {
                    *c += w * val;
                }
            }
        }
    }
    (context, attn_weights)
}
248
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use crate::train::transformer_trainer::wgpu_nf4::LoraAdapter;

    /// Deterministic ramp `(i - center) * step` for `i in 0..len`, used as synthetic weights.
    fn ramp(len: usize, center: f32, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - center) * step).collect()
    }

    /// Smoke test for `attention_forward` on a small GQA configuration (4 query heads sharing
    /// 2 KV heads). The output must have length `seq_len * hidden_size` and be finite, and it
    /// must change when the Q adapter's `B` matrix is made non-zero.
    #[test]
    fn test_attention_forward_basic() {
        let device = GpuDevice::new().expect("GPU");
        let (s, h, nh, nkv, hd) = (4u32, 16u32, 4u32, 2u32, 4u32);
        let width = h as usize;
        let q_dim = (nh * hd) as usize;
        let kv_dim = (nkv * hd) as usize;

        let hidden = ramp((s * h) as usize, 32.0, 0.01);
        let q_w = ramp(q_dim * width, 64.0, 0.005);
        let k_w = ramp(kv_dim * width, 32.0, 0.005);
        let v_w = ramp(kv_dim * width, 32.0, 0.005);
        let o_w = ramp(width * q_dim, 64.0, 0.005);

        let lora_q = LoraAdapter::new(4, h, q_dim as u32);
        let lora_v = LoraAdapter::new(4, h, kv_dim as u32);

        let forward = |adapter_q: &LoraAdapter| {
            attention_forward(
                &device, &hidden, &q_w, &k_w, &v_w, &o_w, adapter_q, &lora_v, 32.0, s, h, nh, nkv,
                hd,
            )
        };

        let (out_base, _cache) = forward(&lora_q).expect("attention_forward");
        assert_eq!(out_base.len(), (s * h) as usize);
        assert!(out_base.iter().all(|v| v.is_finite()), "All outputs finite");

        let mut lora_q2 = LoraAdapter::new(4, h, q_dim as u32);
        for weight in &mut lora_q2.b {
            *weight = 0.01;
        }
        let (out_lora, _) = forward(&lora_q2).expect("attention_forward lora");

        let diff: f32 = out_base
            .iter()
            .zip(&out_lora)
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-6, "LoRA Q should change attention output, diff={diff}");

        let output_norm = out_base.iter().map(|v| v * v).sum::<f32>().sqrt();
        eprintln!(
            "Attention forward: output_norm={:.4}, lora_diff={diff:.6}",
            output_norm
        );
    }
}