Skip to main content

entrenar/transformer/
wgpu_block.rs

1//! wgpu-accelerated transformer forward pass
2//!
//! Provides [`WgpuForwardPass`], which runs the FFN block of each transformer
//! layer (gate/up/down matmuls + SwiGLU) as one batched GPU execution per layer,
//! eliminating per-operation CPU↔GPU round-trips. Attention and normalization
//! stay on the CPU.
//!
//! # Architecture
//!
//! - With `WgpuForwardPass::with_resident_weights`, FFN weights are uploaded once at
//!   construction and kept GPU-resident; other constructors upload them on every call
//! - Each layer's FFN: upload hidden → batch all ops → download result
//! - Uses `GpuCommandBatch` for deferred execution with persistent buffers
12//!
13//! # Contract (C-WGPU-FWD-001)
14//!
15//! - **Precondition**: Transformer model loaded with valid weights
16//! - **Postcondition**: Output hidden states numerically match CPU forward pass (within fp32 tolerance)
17//! - **Invariant**: GPU buffers remain valid across forward calls
18
19use crate::autograd::Tensor;
20use crate::lora::LoRALayer;
21use crate::transformer::config::TransformerConfig;
22use crate::transformer::model::Transformer;
23use std::cell::RefCell;
24use std::sync::Arc;
25use trueno::backends::gpu::{wgpu, GpuCommandBatch, GpuDevice, PipelineCache};
26
27/// Pre-uploaded FFN weight buffers for a single transformer layer (KAIZEN-015).
28///
29/// Each buffer is wrapped in `Arc` so it can be shared across multiple
30/// `GpuCommandBatch` executions via `import_buffer()` without re-uploading.
31/// The buffers remain GPU-resident for the lifetime of `WgpuForwardPass`.
32struct GpuResidentFfnWeights {
33    /// Gate projection: (hidden_size, intermediate_size)
34    w_gate: Arc<wgpu::Buffer>,
35    /// Up projection: (hidden_size, intermediate_size)
36    w_up: Arc<wgpu::Buffer>,
37    /// Down projection: (intermediate_size, hidden_size)
38    w_down: Arc<wgpu::Buffer>,
39    /// Size of gate/up buffers in f32 elements (hidden_size * intermediate_size)
40    gate_up_elements: usize,
41    /// Size of down buffer in f32 elements (intermediate_size * hidden_size)
42    down_elements: usize,
43}
44
45/// wgpu-accelerated transformer forward pass
46///
47/// Batches all matmul and activation operations across transformer layers
48/// into a single `GpuCommandBatch::execute()` call.
49///
50/// Current implementation accelerates FFN matmuls (gate/up/down projections)
51/// which dominate compute time (~60-70% of forward pass). Attention remains
52/// on CPU due to softmax/RoPE complexity (Phase 2 will add GPU attention).
53pub struct WgpuForwardPass {
54    device: GpuDevice,
55    config: TransformerConfig,
56    /// Number of transformer layers
57    num_layers: usize,
58    /// KAIZEN-015: GPU-resident FFN weights — uploaded once, reused every forward pass.
59    /// Each layer has (w_gate, w_up, w_down) as persistent `wgpu::Buffer`s.
60    /// Empty when constructed via `new()`/`new_default()`; populated by `with_resident_weights()`.
61    ffn_weights: Vec<GpuResidentFfnWeights>,
62    /// KAIZEN-023: Persistent pipeline cache across batch executions.
63    /// Shaders compiled in layer 1's batch are reused for layers 2-36.
64    /// Reduces 108 shader compilations per forward pass to just 3.
65    /// Uses `RefCell` because `forward_ffn_gpu()` is called via `&self`.
66    pipeline_cache: RefCell<PipelineCache>,
67}
68
69impl WgpuForwardPass {
70    /// Create a new wgpu forward pass from a transformer model
71    ///
72    /// # Arguments
73    /// * `model` - Transformer model with loaded weights
74    /// * `adapter_index` - wgpu adapter index to use
75    ///
76    /// # Errors
77    /// Returns error if GPU device creation fails
78    pub fn new(config: &TransformerConfig, adapter_index: u32) -> Result<Self, String> {
79        let device = GpuDevice::new_with_adapter_index(adapter_index)?;
80
81        Ok(Self {
82            device,
83            config: config.clone(),
84            num_layers: config.num_hidden_layers,
85            ffn_weights: Vec::new(),
86            pipeline_cache: RefCell::new(PipelineCache::new()),
87        })
88    }
89
90    /// Create from default GPU adapter
91    pub fn new_default(config: &TransformerConfig) -> Result<Self, String> {
92        let device = GpuDevice::new()?;
93
94        Ok(Self {
95            device,
96            config: config.clone(),
97            num_layers: config.num_hidden_layers,
98            ffn_weights: Vec::new(),
99            pipeline_cache: RefCell::new(PipelineCache::new()),
100        })
101    }
102
103    /// Create from model with GPU-resident FFN weights (KAIZEN-015).
104    ///
105    /// Uploads all FFN weights (gate, up, down projections for every layer) to
106    /// GPU memory at construction time. Subsequent forward passes use
107    /// `import_buffer()` to reference these persistent buffers, eliminating
108    /// the per-call H2D transfer overhead.
109    ///
110    /// For Qwen3-4B (36 layers, hidden=3584, intermediate=18944):
111    /// - Per-layer: 3 * 3584 * 18944 * 4 bytes = ~775 MB
112    /// - Total: 36 layers * 775 MB = ~27.2 GB GPU memory (once)
113    /// - Savings: 14.6 GB/step H2D transfer eliminated
114    ///
115    /// # Contract (C-WGPU-RESIDENT-001)
116    ///
117    /// - **Precondition**: Model has valid FFN weights for all layers
118    /// - **Postcondition**: All FFN weights GPU-resident, zero H2D transfers during forward
119    /// - **Invariant**: Weights are frozen (read-only) -- LoRA adapters remain CPU-resident
120    pub fn with_resident_weights(model: &Transformer) -> Result<Self, String> {
121        contract_pre_with_resident_weights!();
122        let device = GpuDevice::new()?;
123        let config = model.config.clone();
124        let num_layers = config.num_hidden_layers;
125        let hidden_size = config.hidden_size;
126        let intermediate_size = config.intermediate_size;
127        let gate_up_elements = hidden_size * intermediate_size;
128        let down_elements = intermediate_size * hidden_size;
129
130        let mut ffn_weights = Vec::with_capacity(num_layers);
131        let mut total_bytes: usize = 0;
132
133        for (i, layer) in model.layers.iter().enumerate() {
134            let gate_data = layer.ffn.w_gate.data();
135            let gate_slice = gate_data
136                .as_slice()
137                .ok_or_else(|| format!("Layer {i}: gate weight not contiguous"))?;
138            let up_data = layer.ffn.w_up.data();
139            let up_slice =
140                up_data.as_slice().ok_or_else(|| format!("Layer {i}: up weight not contiguous"))?;
141            let down_data = layer.ffn.w_down.data();
142            let down_slice = down_data
143                .as_slice()
144                .ok_or_else(|| format!("Layer {i}: down weight not contiguous"))?;
145
146            let w_gate = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
147                label: Some(&format!("ffn_gate_L{i}")),
148                size: (gate_slice.len() * 4) as u64,
149                usage: wgpu::BufferUsages::STORAGE
150                    | wgpu::BufferUsages::COPY_SRC
151                    | wgpu::BufferUsages::COPY_DST,
152                mapped_at_creation: false,
153            }));
154            device.queue.write_buffer(&w_gate, 0, bytemuck::cast_slice(gate_slice));
155
156            let w_up = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
157                label: Some(&format!("ffn_up_L{i}")),
158                size: (up_slice.len() * 4) as u64,
159                usage: wgpu::BufferUsages::STORAGE
160                    | wgpu::BufferUsages::COPY_SRC
161                    | wgpu::BufferUsages::COPY_DST,
162                mapped_at_creation: false,
163            }));
164            device.queue.write_buffer(&w_up, 0, bytemuck::cast_slice(up_slice));
165
166            let w_down = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
167                label: Some(&format!("ffn_down_L{i}")),
168                size: (down_slice.len() * 4) as u64,
169                usage: wgpu::BufferUsages::STORAGE
170                    | wgpu::BufferUsages::COPY_SRC
171                    | wgpu::BufferUsages::COPY_DST,
172                mapped_at_creation: false,
173            }));
174            device.queue.write_buffer(&w_down, 0, bytemuck::cast_slice(down_slice));
175
176            let layer_bytes = (gate_slice.len() + up_slice.len() + down_slice.len()) * 4;
177            total_bytes += layer_bytes;
178
179            ffn_weights.push(GpuResidentFfnWeights {
180                w_gate,
181                w_up,
182                w_down,
183                gate_up_elements,
184                down_elements,
185            });
186        }
187
188        let total_mb = total_bytes as f64 / (1024.0 * 1024.0);
189        eprintln!("[wgpu] GPU-resident FFN weights: {num_layers} layers, {total_mb:.1} MB");
190
191        Ok(Self {
192            device,
193            config,
194            num_layers,
195            ffn_weights,
196            pipeline_cache: RefCell::new(PipelineCache::new()),
197        })
198    }
199
200    /// Execute forward pass through all transformer layers on GPU
201    ///
202    /// Batches FFN matmuls (gate/up/down projections) per layer. Attention and
203    /// normalization remain on CPU for this phase.
204    ///
205    /// # Arguments
206    /// * `model` - Transformer model with weights
207    /// * `token_ids` - Input token IDs
208    ///
209    /// # Returns
210    /// Hidden states tensor (seq_len * hidden_size)
211    pub fn forward_hidden(&self, model: &Transformer, token_ids: &[u32]) -> Result<Tensor, String> {
212        let seq_len = token_ids.len();
213        let hidden_size = self.config.hidden_size;
214        let intermediate_size = self.config.intermediate_size;
215
216        // Step 1: Embed tokens on CPU (small operation)
217        let mut hidden = model.embed_tokens.forward(token_ids);
218
219        // Step 2: Process each transformer layer
220        // Attention stays on CPU; FFN matmuls go to GPU via batched execution
221        //
222        // KAIZEN-004: Suppress per-op wgpu during attention. Without this,
223        // each Q/K/V/O projection triggers a full buffer upload/compute/download
224        // cycle (144 per-op matmuls × ~3-5ms = 430-720ms overhead per sample).
225        // CPU SIMD is equally fast and doesn't contend for GPU bandwidth.
226        crate::autograd::suppress_per_op_wgpu();
227        for (layer_idx, layer) in model.layers.iter().enumerate() {
228            // --- Attention on CPU SIMD (includes RoPE, softmax, masking) ---
229            let norm1 = layer.input_norm.forward_batched(&hidden, seq_len, hidden_size);
230            let attn_out = layer.self_attn.forward(&norm1, seq_len);
231            let residual1 = crate::autograd::add(&hidden, &attn_out);
232
233            // --- FFN on GPU (3 large matmuls + SwiGLU) ---
234            let norm2 = layer.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
235
236            // KAIZEN-015: Use GPU-resident weights if available
237            let resident = self.ffn_weights.get(layer_idx);
238
239            let ffn_out = self.forward_ffn_gpu(
240                &norm2,
241                &layer.ffn.w_gate,
242                &layer.ffn.w_up,
243                &layer.ffn.w_down,
244                seq_len,
245                hidden_size,
246                intermediate_size,
247                resident,
248            )?;
249
250            // Residual connection
251            hidden = crate::autograd::add(&residual1, &ffn_out);
252        }
253        // KAIZEN-004: Re-enable per-op wgpu for backward pass / other operations
254        crate::autograd::unsuppress_per_op_wgpu();
255
256        // Step 3: Final normalization on CPU
257        let normalized = model.norm.forward_batched(&hidden, seq_len, hidden_size);
258
259        Ok(normalized)
260    }
261
262    /// Execute FFN forward pass on GPU using batched operations.
263    ///
264    /// Batches gate/up/down matmuls + SwiGLU into single GPU execution:
265    /// 1. Upload input (small: seq_len * hidden_size)
266    /// 2. Import or upload weights (gate, up, down)
267    /// 3. Compute: gate = input @ w_gate, up = input @ w_up, silu(gate) * up, down
268    /// 4. Download: ffn_output
269    ///
270    /// When `resident_weights` is `Some`, weight buffers are imported from
271    /// persistent GPU memory (KAIZEN-015: zero H2D transfer for weights).
272    /// When `None`, falls back to uploading from CPU tensors each call.
273    fn forward_ffn_gpu(
274        &self,
275        input: &Tensor,
276        w_gate: &Tensor,
277        w_up: &Tensor,
278        w_down: &Tensor,
279        seq_len: usize,
280        hidden_size: usize,
281        intermediate_size: usize,
282        resident_weights: Option<&GpuResidentFfnWeights>,
283    ) -> Result<Tensor, String> {
284        use trueno::backends::gpu::runtime;
285
286        runtime::block_on(async {
287            let mut batch = GpuCommandBatch::new(self.device.clone());
288
289            // Upload input (always from CPU — small: seq_len * hidden_size)
290            let input_data = input.data();
291            let input_slice = input_data.as_slice().ok_or("Input tensor not contiguous")?;
292            let buf_input = batch.upload(input_slice);
293
294            // KAIZEN-015: Import GPU-resident weights or fall back to CPU upload
295            let (buf_gate, buf_up, buf_down) = if let Some(rw) = resident_weights {
296                // Zero H2D transfer — Arc::clone is cheap (~1 atomic increment)
297                let g = batch.import_buffer(Arc::clone(&rw.w_gate), rw.gate_up_elements);
298                let u = batch.import_buffer(Arc::clone(&rw.w_up), rw.gate_up_elements);
299                let d = batch.import_buffer(Arc::clone(&rw.w_down), rw.down_elements);
300                (g, u, d)
301            } else {
302                // Fallback: upload weights from CPU tensors (original path)
303                let gate_data = w_gate.data();
304                let gate_slice = gate_data.as_slice().ok_or("Gate weight not contiguous")?;
305                let up_data = w_up.data();
306                let up_slice = up_data.as_slice().ok_or("Up weight not contiguous")?;
307                let down_data = w_down.data();
308                let down_slice = down_data.as_slice().ok_or("Down weight not contiguous")?;
309                let g = batch.upload(gate_slice);
310                let u = batch.upload(up_slice);
311                let d = batch.upload(down_slice);
312                (g, u, d)
313            };
314
315            // Gate projection: (seq_len, hidden) @ (hidden, intermediate)
316            let gate_out = batch.matmul(
317                buf_input,
318                buf_gate,
319                seq_len as u32,
320                hidden_size as u32,
321                intermediate_size as u32,
322            );
323
324            // Up projection: (seq_len, hidden) @ (hidden, intermediate)
325            let up_out = batch.matmul(
326                buf_input,
327                buf_up,
328                seq_len as u32,
329                hidden_size as u32,
330                intermediate_size as u32,
331            );
332
333            // SwiGLU: swish(gate) * up
334            let gate_activated = batch.swish(gate_out);
335            let swiglu_out = batch.mul(gate_activated, up_out);
336
337            // Down projection: (seq_len, intermediate) @ (intermediate, hidden)
338            let ffn_out = batch.matmul(
339                swiglu_out,
340                buf_down,
341                seq_len as u32,
342                intermediate_size as u32,
343                hidden_size as u32,
344            );
345
346            // Execute all ops in single batch — KAIZEN-023: persistent pipeline cache
347            batch.execute_with_cache(&mut self.pipeline_cache.borrow_mut()).await?;
348
349            // Download result
350            let result_data = batch.read(ffn_out).await?;
351
352            Ok(Tensor::from_vec(result_data, false))
353        })
354    }
355
356    /// Execute batched forward pass for multiple samples (KAIZEN-008).
357    ///
358    /// Processes all samples through each layer together, uploading FFN weights
359    /// ONCE per layer instead of once per sample. With batch_size=20 × 36 layers,
360    /// this reduces weight uploads from 720 to 36 (20× reduction, ~146 GB saved).
361    ///
362    /// Attention remains per-sample on CPU SIMD. FFN inputs are concatenated
363    /// across samples for a single large matmul per layer.
364    ///
365    /// # KAIZEN-010: LoRA integration
366    ///
367    /// When `lora_layers` is provided, attention uses `forward_with_lora()` to
368    /// apply LoRA adjusts to Q and V projections. Layout: `[Q_0, V_0, Q_1, V_1, ...]`
369    /// (2 LoRA layers per transformer layer). Without this, only the classifier
370    /// head trains on the wgpu path (5,122 params vs 5.9M with LoRA).
371    pub fn forward_hidden_batch(
372        &self,
373        model: &Transformer,
374        batch_token_ids: &[Vec<u32>],
375        lora_layers: Option<&[LoRALayer]>,
376    ) -> Result<Vec<Tensor>, String> {
377        let hidden_size = self.config.hidden_size;
378        let intermediate_size = self.config.intermediate_size;
379        let n = batch_token_ids.len();
380
381        // Step 1: Embed all samples on CPU
382        let mut hiddens: Vec<Tensor> =
383            batch_token_ids.iter().map(|ids| model.embed_tokens.forward(ids)).collect();
384
385        // Step 2: Layer-at-a-time processing
386        // Attention on CPU SIMD (per-sample), FFN on GPU (all samples concatenated)
387        //
388        // KAIZEN-017: Pre-compute total_tokens once (constant across layers).
389        let total_tokens: usize = batch_token_ids.iter().map(std::vec::Vec::len).sum();
390
391        crate::autograd::suppress_per_op_wgpu();
392        for (layer_idx, layer) in model.layers.iter().enumerate() {
393            // Attention on CPU (per-sample, independent)
394            // KAIZEN-017: Keep Tensor references instead of .to_vec() copies.
395            // Eliminates n per-sample Vec allocations per layer (360 MB total for Qwen3-4B).
396            let mut ffn_input_tensors: Vec<Tensor> = Vec::with_capacity(n);
397            let mut residuals: Vec<Tensor> = Vec::with_capacity(n);
398            for (i, hidden) in hiddens.iter().enumerate() {
399                let seq_len = batch_token_ids[i].len();
400                let norm1 = layer.input_norm.forward_batched(hidden, seq_len, hidden_size);
401
402                // KAIZEN-010: Use LoRA-enabled attention when adapters are available
403                let attn_out = match lora_layers {
404                    Some(loras) => {
405                        let q_idx = layer_idx * 2;
406                        let v_idx = layer_idx * 2 + 1;
407                        if v_idx < loras.len() {
408                            layer.self_attn.forward_with_lora(
409                                &norm1,
410                                seq_len,
411                                loras[q_idx].lora_a(),
412                                loras[q_idx].lora_b(),
413                                loras[v_idx].lora_a(),
414                                loras[v_idx].lora_b(),
415                                loras[q_idx].rank(),
416                                loras[q_idx].scale(),
417                            )
418                        } else {
419                            layer.self_attn.forward(&norm1, seq_len)
420                        }
421                    }
422                    None => layer.self_attn.forward(&norm1, seq_len),
423                };
424
425                let residual1 = crate::autograd::add(hidden, &attn_out);
426                let norm2 = layer.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
427                ffn_input_tensors.push(norm2);
428                residuals.push(residual1);
429            }
430
431            // Concatenate all samples' FFN inputs for single GPU batch
432            // KAIZEN-017: Borrow directly from Tensor instead of intermediate Vecs
433            let mut concat_input = Vec::with_capacity(total_tokens * hidden_size);
434            for norm2 in &ffn_input_tensors {
435                let data = norm2.data();
436                concat_input.extend_from_slice(data.as_slice().expect("norm2 contiguous"));
437            }
438            let concat_tensor = Tensor::from_vec(concat_input, false);
439
440            // FFN on GPU — KAIZEN-015: import GPU-resident weights (zero H2D)
441            let resident = self.ffn_weights.get(layer_idx);
442
443            let ffn_out = self.forward_ffn_gpu(
444                &concat_tensor,
445                &layer.ffn.w_gate,
446                &layer.ffn.w_up,
447                &layer.ffn.w_down,
448                total_tokens,
449                hidden_size,
450                intermediate_size,
451                resident,
452            )?;
453
454            // Split FFN output back into per-sample tensors + residual
455            let ffn_data = ffn_out.data();
456            let ffn_slice = ffn_data.as_slice().expect("ffn contiguous");
457            let mut offset = 0;
458            hiddens = residuals
459                .into_iter()
460                .enumerate()
461                .map(|(i, r)| {
462                    let len = batch_token_ids[i].len() * hidden_size;
463                    let sample_ffn =
464                        Tensor::from_vec(ffn_slice[offset..offset + len].to_vec(), false);
465                    offset += len;
466                    crate::autograd::add(&r, &sample_ffn)
467                })
468                .collect();
469        }
470        crate::autograd::unsuppress_per_op_wgpu();
471
472        // Step 3: Final normalization (per-sample)
473        let results: Vec<Tensor> = hiddens
474            .into_iter()
475            .enumerate()
476            .map(|(i, h)| {
477                let seq_len = batch_token_ids[i].len();
478                model.norm.forward_batched(&h, seq_len, hidden_size)
479            })
480            .collect();
481
482        Ok(results)
483    }
484
485    /// Get the adapter info for display
486    pub fn adapter_info(&self) -> String {
487        format!(
488            "wgpu device ({}x{} model, {} layers)",
489            self.config.hidden_size, self.config.intermediate_size, self.num_layers
490        )
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wgpu_forward_pass_creation() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 64;
        config.num_hidden_layers = 2;
        config.num_attention_heads = 4;
        config.num_kv_heads = 4;
        config.intermediate_size = 128;
        config.vocab_size = 100;

        let pass = WgpuForwardPass::new_default(&config);
        assert!(pass.is_ok(), "WgpuForwardPass creation failed: {:?}", pass.err());
    }

    /// Checks the GPU FFN (gate/up matmul → SwiGLU → down matmul) against a CPU
    /// reference, as required by contract C-WGPU-FWD-001 (fp32 tolerance).
    #[test]
    fn test_wgpu_ffn_numerical_correctness() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.num_hidden_layers = 1;
        config.num_attention_heads = 2;
        config.num_kv_heads = 2;
        config.intermediate_size = 16;
        config.vocab_size = 32;

        let pass =
            WgpuForwardPass::new_default(&config).expect("GPU available but creation failed");

        // Create small test tensors
        let input = Tensor::from_vec(vec![1.0; 8], false); // 1 token × 8 hidden
        let w_gate = Tensor::from_vec(vec![0.1; 8 * 16], false);
        let w_up = Tensor::from_vec(vec![0.1; 8 * 16], false);
        let w_down = Tensor::from_vec(vec![0.1; 16 * 8], false);

        let gpu_result = pass.forward_ffn_gpu(
            &input, &w_gate, &w_up, &w_down, 1, 8, 16,
            None, // No GPU-resident weights for this test
        );

        assert!(gpu_result.is_ok(), "GPU FFN failed: {:?}", gpu_result.err());

        let gpu_data = gpu_result.expect("checked above");
        assert_eq!(gpu_data.len(), 8, "Output should be 1 × 8");

        // CPU reference. With constant inputs and weights, every output element is
        // identical and independent of weight layout:
        //   gate = up = sum_{k<8} 1.0 * 0.1 = 0.8
        //   h         = silu(gate) * up, with silu(x) = x / (1 + e^-x)
        //   out       = sum_{j<16} h * 0.1
        let gate_up: f32 = 8.0 * 0.1;
        let silu = gate_up / (1.0 + (-gate_up).exp());
        let expected = 16.0 * 0.1 * silu * gate_up;

        for (i, &val) in gpu_data.data().iter().enumerate() {
            assert!(val.is_finite(), "NaN/Inf at index {i}: {val}");
            assert!(
                (val - expected).abs() < 1e-4,
                "index {i}: GPU {val} differs from CPU reference {expected}"
            );
        }
    }
}