Skip to main content

entrenar/train/transformer_trainer/
wgpu_attention.rs

1//! WGPU attention forward pass with LoRA for training
2//!
3//! QKV projection (with LoRA on Q, V) → RoPE → scaled dot-product attention
4//! → output projection. Causal mask for autoregressive training.
5//!
6//! # Contract: C-WGPU-TRAIN-008 (attention forward with LoRA)
7
8#[cfg(feature = "gpu")]
9use trueno::backends::gpu::GpuDevice;
10
11/// Per-head RMS normalization (QK-norm for Qwen3)
12#[cfg(feature = "gpu")]
///
/// `buf` is read as `seq_len` rows of `total_dim` values. Within each row the
/// first `n_heads * head_dim` values are split into `head_dim`-wide heads and
/// every head is divided in place by `sqrt(mean(x^2) + 1e-6)`. Values in a row
/// beyond `n_heads * head_dim` are left untouched, and no learned gain is
/// applied.
///
/// # Panics
///
/// Panics if any head slice reaches past the end of `buf`.
fn head_rms_norm(
    buf: &mut [f32],
    seq_len: usize,
    n_heads: usize,
    total_dim: usize,
    head_dim: usize,
) {
    // Keeps the divisor away from zero for all-zero heads.
    const EPS: f32 = 1e-6;
    let width = head_dim as f32;

    for row in 0..seq_len {
        let row_start = row * total_dim;
        for head in 0..n_heads {
            let start = row_start + head * head_dim;
            let lane = &mut buf[start..start + head_dim];

            let mean_sq = lane.iter().map(|x| x * x).sum::<f32>() / width;
            let rms = (mean_sq + EPS).sqrt();

            lane.iter_mut().for_each(|x| *x /= rms);
        }
    }
}
34
35/// Scale buffer to match target norm (prevents residual explosion)
36#[cfg(feature = "gpu")]
///
/// Compares the L2 norm of `output` with the L2 norm of `reference`. When the
/// output norm is greater than `max_ratio` times the reference norm (and the
/// reference norm is above `1e-6`), every element of `output` is multiplied by
/// `ref_norm / out_norm`. A triggered rescale therefore brings `output` down to
/// the reference norm itself, not to `max_ratio` times it. In every other case
/// `output` is left unchanged.
fn norm_guard(output: &mut [f32], reference: &[f32], max_ratio: f32) {
    /// Euclidean norm of a slice.
    fn l2(xs: &[f32]) -> f32 {
        xs.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    let out_norm = l2(output);
    let ref_norm = l2(reference);

    let too_large = out_norm > ref_norm * max_ratio;
    let reference_usable = ref_norm > 1e-6;
    if !(too_large && reference_usable) {
        return;
    }

    let shrink = ref_norm / out_norm;
    output.iter_mut().for_each(|v| *v *= shrink);
}
47
48/// Attention cache returned from forward for use in backward
49#[cfg(feature = "gpu")]
///
/// Every buffer is a flat, row-major `f32` vector in the layout produced by
/// `attention_forward`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AttentionCache {
    /// Query activations as used by the attention scores (after LoRA, QK-norm
    /// and RoPE), `[seq_len, num_heads * head_dim]`.
    pub q: Vec<f32>,
    /// Key activations as used by the attention scores (after QK-norm and
    /// RoPE), `[seq_len, num_kv_heads * head_dim]`.
    pub k: Vec<f32>,
    /// Value activations including the LoRA contribution,
    /// `[seq_len, num_kv_heads * head_dim]`.
    pub v: Vec<f32>,
    /// Softmax attention probabilities, `[num_heads, seq_len, seq_len]`;
    /// causally masked entries are `0.0`.
    pub attn_weights: Vec<f32>,
    /// Attention output before the output projection,
    /// `[seq_len, num_heads * head_dim]`.
    pub context: Vec<f32>,
    /// `hidden @ A_q^T`, `[seq_len, rank]`; empty when LoRA was not applied.
    pub lora_q_h: Vec<f32>,
    /// `hidden @ A_v^T`, `[seq_len, rank]`; empty when LoRA was not applied.
    pub lora_v_h: Vec<f32>,
}
59
/// Attention forward pass with LoRA on the Q and V projections.
///
/// Pipeline: Q/K/V projections via `device.matmul` → LoRA deltas added to Q and
/// V on the CPU → per-head RMS norm on Q and K → RoPE on Q and K → causal
/// scaled dot-product attention with grouped-query head sharing → output
/// projection via `device.matmul` → `norm_guard` against `hidden`.
///
/// Returns `(output, cache)`, where `output` is `[seq_len, hidden_size]` and
/// `cache` holds the intermediates needed by the backward pass.
///
/// Each LoRA adapter is applied according to its own rank; an adapter with
/// rank `0` contributes nothing and leaves its cache entry empty.
///
/// # Errors
///
/// Returns an error if `num_kv_heads` is zero or does not divide `num_heads`,
/// or if any `device.matmul` call fails.
///
/// # Panics
///
/// Panics if `hidden` or a LoRA adapter's `a`/`b` buffer is shorter than the
/// shapes implied by the dimension arguments.
#[cfg(feature = "gpu")]
#[allow(clippy::too_many_arguments)]
pub fn attention_forward(
    device: &GpuDevice,
    hidden: &[f32],   // [seq_len, hidden_size]
    q_weight: &[f32], // [q_dim, hidden_size]
    k_weight: &[f32], // [kv_dim, hidden_size]
    v_weight: &[f32], // [kv_dim, hidden_size]
    o_weight: &[f32], // [hidden_size, q_dim]
    lora_q: &super::wgpu_nf4::LoraAdapter,
    lora_v: &super::wgpu_nf4::LoraAdapter,
    lora_alpha: f32,
    seq_len: u32,
    hidden_size: u32,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
) -> Result<(Vec<f32>, AttentionCache), String> {
    /// Adds `(lora_alpha / rank) * hidden @ A^T @ B^T` to `target`
    /// (`[s, out_dim]`) and returns `hidden @ A^T` (`[s, rank]`).
    /// A rank of zero is a no-op that returns an empty vector.
    fn add_lora_delta(
        hidden: &[f32],
        adapter: &super::wgpu_nf4::LoraAdapter,
        lora_alpha: f32,
        target: &mut [f32],
        s: usize,
        h: usize,
        out_dim: usize,
    ) -> Vec<f32> {
        let rank = adapter.rank as usize;
        if rank == 0 {
            // Guard per adapter: alpha / 0 is infinite, and inf * 0.0 would
            // turn every element of `target` into NaN.
            return Vec::new();
        }
        let scaling = lora_alpha / adapter.rank as f32;

        // Down-projection: hidden @ A^T, with A laid out as [rank, h].
        let mut h_a = vec![0.0f32; s * rank];
        for si in 0..s {
            for ri in 0..rank {
                let mut sum = 0.0f32;
                for hi in 0..h {
                    sum += hidden[si * h + hi] * adapter.a[ri * h + hi];
                }
                h_a[si * rank + ri] = sum;
            }
        }

        // Up-projection: (hidden @ A^T) @ B^T, with B laid out as [out_dim, rank].
        for si in 0..s {
            for oi in 0..out_dim {
                let mut sum = 0.0f32;
                for ri in 0..rank {
                    sum += h_a[si * rank + ri] * adapter.b[oi * rank + ri];
                }
                target[si * out_dim + oi] += scaling * sum;
            }
        }
        h_a
    }

    /// Rotates adjacent element pairs of every head in one sequence row.
    /// `rot[p]` is the `(sin, cos)` for pair `p`; `rot.len()` must be `hd / 2`.
    fn rotate_row(buf: &mut [f32], row_start: usize, heads: usize, hd: usize, rot: &[(f32, f32)]) {
        for head in 0..heads {
            let base = row_start + head * hd;
            for (pair, &(sin_val, cos_val)) in rot.iter().enumerate() {
                let i0 = base + 2 * pair;
                let i1 = i0 + 1;
                let (x0, x1) = (buf[i0], buf[i1]);
                buf[i0] = x0 * cos_val - x1 * sin_val;
                buf[i1] = x0 * sin_val + x1 * cos_val;
            }
        }
    }

    let s = seq_len as usize;
    let h = hidden_size as usize;
    let hd = head_dim as usize;
    let nh = num_heads as usize;
    let nkv = num_kv_heads as usize;

    // GQA maps each query head onto a KV head by integer division, so the head
    // counts must divide evenly (and the divisor must be non-zero).
    if nkv == 0 || nh % nkv != 0 {
        return Err(format!(
            "attention_forward: num_heads ({num_heads}) must be a multiple of a non-zero num_kv_heads ({num_kv_heads})"
        ));
    }

    // Multiply in usize so large head counts cannot overflow u32.
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;

    // --- QKV projections (GPU matmul with pre-transposed weights) ---
    let mut q = vec![0.0f32; s * q_dim];
    device.matmul(hidden, q_weight, &mut q, s, h, q_dim)?;
    let mut k = vec![0.0f32; s * kv_dim];
    device.matmul(hidden, k_weight, &mut k, s, h, kv_dim)?;
    let mut v = vec![0.0f32; s * kv_dim];
    device.matmul(hidden, v_weight, &mut v, s, h, kv_dim)?;

    // --- LoRA contributions (CPU, small rank) ---
    let lora_q_h = add_lora_delta(hidden, lora_q, lora_alpha, &mut q, s, h, q_dim);
    let lora_v_h = add_lora_delta(hidden, lora_v, lora_alpha, &mut v, s, h, kv_dim);

    // QK-norm: per-head RMS normalization (prevents attention score explosion)
    head_rms_norm(&mut q, s, nh, q_dim, hd);
    head_rms_norm(&mut k, s, nkv, kv_dim, hd);

    // --- RoPE (sin/cos positional encoding) ---
    // Pair p rotates elements (2p, 2p + 1) of a head with frequency
    // 10000^(-2p / head_dim). Only complete pairs are rotated, so an odd
    // head_dim leaves the final element as-is instead of pairing it with the
    // next head's first element.
    let inv_freq: Vec<f32> = (0..hd / 2)
        .map(|pair| 1.0 / (10000.0f32).powf((2 * pair) as f32 / hd as f32))
        .collect();
    let mut rot = vec![(0.0f32, 0.0f32); inv_freq.len()];
    for si in 0..s {
        // The angles depend only on position and pair index, so compute them
        // once per position and reuse them for every Q and K head.
        let pos = si as f32;
        for (slot, freq) in rot.iter_mut().zip(&inv_freq) {
            *slot = (pos * freq).sin_cos();
        }
        rotate_row(&mut q, si * q_dim, nh, hd, &rot);
        rotate_row(&mut k, si * kv_dim, nkv, hd, &rot);
    }

    // --- Causal scaled dot-product attention with GQA ---
    // Each group of `heads_per_kv` consecutive query heads shares one KV head.
    let heads_per_kv = nh / nkv;
    let mut context = vec![0.0f32; s * q_dim];
    let mut attn_weights = vec![0.0f32; nh * s * s]; // cache for backward
    let scale = 1.0 / (hd as f32).sqrt();

    for head in 0..nh {
        let kv_col = (head / heads_per_kv) * hd;
        let q_col = head * hd;
        for qi in 0..s {
            let aw_off = head * s * s + qi * s;
            let weights = &mut attn_weights[aw_off..aw_off + s];
            let q_off = qi * q_dim + q_col;
            let q_row = &q[q_off..q_off + hd];
            // Causal mask: query `qi` sees keys 0..=qi; the remaining entries
            // of `weights` keep their zero initialisation.
            let visible = qi + 1;

            let mut max_score = f32::NEG_INFINITY;
            for ki in 0..visible {
                let k_off = ki * kv_dim + kv_col;
                let k_row = &k[k_off..k_off + hd];
                let dot = q_row.iter().zip(k_row).fold(0.0f32, |acc, (a, b)| acc + a * b);
                let score = dot * scale;
                weights[ki] = score;
                if score > max_score {
                    max_score = score;
                }
            }

            // Softmax with the row maximum subtracted for numerical stability.
            let mut sum_exp = 0.0f32;
            for w in &mut weights[..visible] {
                *w = (*w - max_score).exp();
                sum_exp += *w;
            }
            if sum_exp > 0.0 {
                for w in &mut weights[..visible] {
                    *w /= sum_exp;
                }
            }

            // Weighted sum over visible keys only, so values at masked
            // positions can never leak into the context.
            for d in 0..hd {
                let mut val = 0.0f32;
                for ki in 0..visible {
                    val += weights[ki] * v[ki * kv_dim + kv_col + d];
                }
                context[q_off + d] = val;
            }
        }
    }

    // Output projection: context[s,q_dim] @ O[q_dim,h] → [s,h] (pre-transposed)
    let mut output = vec![0.0f32; s * h];
    device.matmul(&context, o_weight, &mut output, s, q_dim, h)?;

    norm_guard(&mut output, hidden, 10.0); // relaxed from 2.0 — lets LoRA contribute more

    // q and k are not needed after this point, so move them into the cache.
    let cache = AttentionCache {
        q,
        k,
        v,
        attn_weights,
        context,
        lora_q_h,
        lora_v_h,
    };
    Ok((output, cache))
}
248
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use crate::train::transformer_trainer::wgpu_nf4::LoraAdapter;

    /// Deterministic ramp: element `i` is `(i - center) * step`.
    fn ramp(len: usize, center: f32, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - center) * step).collect()
    }

    /// FALSIFY: the forward pass yields finite values of the expected length,
    /// and a non-zero LoRA B matrix on Q changes the result.
    #[test]
    fn test_attention_forward_basic() {
        let device = GpuDevice::new().expect("GPU");

        // Small GQA configuration: 4 query heads sharing 2 KV heads.
        let (s, h, nh, nkv, hd) = (4u32, 16u32, 4u32, 2u32, 4u32);
        let hidden_cols = h as usize;
        let q_dim = (nh * hd) as usize;
        let kv_dim = (nkv * hd) as usize;

        let hidden = ramp((s * h) as usize, 32.0, 0.01);
        let q_w = ramp(q_dim * hidden_cols, 64.0, 0.005);
        let k_w = ramp(kv_dim * hidden_cols, 32.0, 0.005);
        let v_w = ramp(kv_dim * hidden_cols, 32.0, 0.005);
        let o_w = ramp(hidden_cols * q_dim, 64.0, 0.005);

        let lora_q = LoraAdapter::new(4, h, q_dim as u32);
        let lora_v = LoraAdapter::new(4, h, kv_dim as u32);

        // Everything except the Q adapter is shared between the two runs.
        let forward = |q_adapter: &LoraAdapter| {
            attention_forward(
                &device, &hidden, &q_w, &k_w, &v_w, &o_w, q_adapter, &lora_v, 32.0, s, h, nh,
                nkv, hd,
            )
        };

        // Baseline with freshly constructed adapters (presumably B starts at
        // zero, so LoRA contributes nothing here — verify against LoraAdapter::new).
        let (out_base, _cache) = forward(&lora_q).expect("attention_forward");

        assert_eq!(out_base.len(), (s * h) as usize);
        assert!(out_base.iter().all(|v| v.is_finite()), "All outputs finite");

        // Overwrite B of a second Q adapter with a non-zero constant; the
        // output must then differ from the baseline.
        let mut lora_q2 = LoraAdapter::new(4, h, q_dim as u32);
        lora_q2.b.iter_mut().for_each(|b| *b = 0.01);
        let (out_lora, _) = forward(&lora_q2).expect("attention_forward lora");

        let diff: f32 = out_base.iter().zip(out_lora.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 1e-6, "LoRA Q should change attention output, diff={diff}");

        let output_norm = out_base.iter().map(|v| v * v).sum::<f32>().sqrt();
        eprintln!(
            "Attention forward: output_norm={:.4}, lora_diff={diff:.6}",
            output_norm
        );
    }
}