Skip to main content

trueno/backends/gpu/device/linalg/
wgsl_forward.rs

1#![allow(dead_code, clippy::many_single_char_names)]
2//! PMAT-324: WGSL transformer forward pass — multi-pass single submission.
3//!
4//! Instead of one matmul per CPU call (2ms roundtrip each), this encodes
5//! ALL operations for one transformer layer into a single command encoder.
6//! Only one submit + one readback per layer (or per full forward pass).
7//!
8//! Architecture: separate WGSL kernels per operation type, dispatched
9//! sequentially within one command encoder. All intermediate data stays
10//! GPU-resident in persistent buffers.
11
12use std::collections::HashMap;
13
14/// Saved activations for one transformer layer's backward pass.
15///
16/// Contains the 7 tensors needed for LoRA gradient computation
17/// without replaying the forward pass (§26.11.5, falsification-verified).
18pub struct LayerActivations {
19    /// Input to Q/K/V projections (RMSNorm output). [seq, hidden]
20    pub attn_norm_out: wgpu::Buffer,
21    /// Input to O projection (attention output). [seq, q_dim]
22    pub attn_output: wgpu::Buffer,
23    /// Input to gate/up/down projections (FFN RMSNorm output). [seq, hidden]
24    pub ffn_norm_out: wgpu::Buffer,
25    /// Input to down projection (SiLU(gate)×up). [seq, intermediate]
26    pub silu_gate_output: wgpu::Buffer,
27    /// RMSNorm reciprocal std for attention norm. [seq]
28    pub rstd_attn: wgpu::Buffer,
29    /// RMSNorm reciprocal std for FFN norm. [seq]
30    pub rstd_ffn: wgpu::Buffer,
31    /// Softmax logsumexp for attention backward. [num_heads, seq]
32    pub softmax_logsumexp: wgpu::Buffer,
33}
34
35/// Optional LoRA buffers for Q/K/V projections in a layer's forward pass.
36pub struct QkvLoRA<'a> {
37    pub q_a: &'a wgpu::Buffer,
38    pub q_b: &'a wgpu::Buffer,
39    pub k_a: &'a wgpu::Buffer,
40    pub k_b: &'a wgpu::Buffer,
41    pub v_a: &'a wgpu::Buffer,
42    pub v_b: &'a wgpu::Buffer,
43    pub rank: u32,
44    pub scale: f32,
45    pub in_dim: u32,
46    pub q_dim: u32,
47    pub kv_dim: u32,
48    pub lora_pipeline: &'a wgpu::ComputePipeline,
49    pub lora_bgl: &'a wgpu::BindGroupLayout,
50}
51
52/// GPU-resident transformer layer state.
53/// All buffers persist across tokens — only input/output change per step.
54pub struct WgslForwardPass {
55    device: wgpu::Device,
56    queue: wgpu::Queue,
57
58    // Kernels (compiled once)
59    matmul_pipeline: wgpu::ComputePipeline,
60    /// CUTLASS-style tiled GEMM for M>=4 (training batch, prefill)
61    tiled_matmul_pipeline: wgpu::ComputePipeline,
62    /// PMAT-327: GEMV pipeline for M=1 decode (cooperative K-reduction)
63    gemv_pipeline: wgpu::ComputePipeline,
64    /// C-WGPU-Q4K-001: Q4K GEMV pipeline — dequantize-on-the-fly, no F32 weights
65    q4k_gemv_pipeline: wgpu::ComputePipeline,
66    /// Causal attention pipeline for training (full sequence, no KV cache)
67    attention_pipeline: wgpu::ComputePipeline,
68    attention_bgl: wgpu::BindGroupLayout,
69    rmsnorm_pipeline: wgpu::ComputePipeline,
70    silu_mul_pipeline: wgpu::ComputePipeline,
71    rope_pipeline: wgpu::ComputePipeline,
72    batch_rope_pipeline: wgpu::ComputePipeline,
73    batch_rope_bgl: wgpu::BindGroupLayout,
74    residual_pipeline: wgpu::ComputePipeline,
75
76    // Bind group layouts
77    matmul_bgl: wgpu::BindGroupLayout,
78    elementwise_bgl: wgpu::BindGroupLayout,
79
80    // Weight buffers (persistent, uploaded once)
81    weight_buffers: HashMap<String, wgpu::Buffer>,
82    /// GH-560: Raw Q4K weight buffers for fused dequant+GEMV.
83    q4k_weights: HashMap<String, wgpu::Buffer>,
84    /// PMAT-342: CPU-side bias data (small, not worth GPU dispatch)
85    cpu_biases: HashMap<String, Vec<f32>>,
86    /// GH-560: Per-layer GPU KV cache buffers.
87    kv_cache_k: Vec<wgpu::Buffer>,
88    /// GH-560: Per-layer GPU KV cache buffers (values).
89    kv_cache_v: Vec<wgpu::Buffer>,
90
91    // Intermediate buffers (persistent, reused across calls)
92    // For 1.5B: hidden=1536, kv=256, intermediate=8960
93    hidden_buf: wgpu::Buffer,   // [hidden_dim] working state
94    q_buf: wgpu::Buffer,        // [q_dim]
95    k_buf: wgpu::Buffer,        // [kv_dim]
96    v_buf: wgpu::Buffer,        // [kv_dim]
97    attn_out_buf: wgpu::Buffer, // [hidden_dim]
98    ffn_gate_buf: wgpu::Buffer, // [intermediate_dim]
99    ffn_up_buf: wgpu::Buffer,   // [intermediate_dim]
100    ffn_silu_buf: wgpu::Buffer, // [intermediate_dim] — SiLU(gate)×up output (can't alias inputs)
101    ffn_out_buf: wgpu::Buffer,  // [hidden_dim]
102    norm_buf: wgpu::Buffer,     // [hidden_dim] for RMSNorm output
103    staging_buf: wgpu::Buffer,  // readback
104
105    // Config
106    hidden_dim: u32,
107    num_heads: u32,
108    num_kv_heads: u32,
109    head_dim: u32,
110    intermediate_dim: u32,
111}
112
// WGSL shader source for RMSNorm (multi-row via workgroup_id.y)
// Dispatch: (1, seq_len, 1) — one workgroup per row.
//
// Per row r of width dim = params.x:
//   output[r, i] = input[r, i] / sqrt(mean_i(input[r, i]^2) + 1e-6) * weight[i]
// Each of the 256 threads accumulates a strided partial sum of squares, then
// a shared-memory tree reduction (starting at stride 128) combines them. The
// reduction is hard-wired to @workgroup_size(256).
// NOTE(review): epsilon is fixed at 1e-6 inside the shader; it does not follow
// any model-configured eps passed elsewhere.
const RMSNORM_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

var<workgroup> shared_sum: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
        @builtin(workgroup_id) wg_id: vec3<u32>) {
    let dim = params.x;
    let row = wg_id.y;
    let base = row * dim;
    let tid = lid.x;

    // Compute sum of squares (reduction) for this row
    var local_sum: f32 = 0.0;
    var i = tid;
    while (i < dim) {
        let val = input[base + i];
        local_sum += val * val;
        i += 256u;
    }
    shared_sum[tid] = local_sum;
    workgroupBarrier();

    // Tree reduction
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride) {
            shared_sum[tid] += shared_sum[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }

    let rms = sqrt(shared_sum[0] / f32(dim) + 1e-6);

    // Normalize and scale
    i = tid;
    while (i < dim) {
        output[base + i] = (input[base + i] / rms) * weight[i];
        i += 256u;
    }
}
"#;
162
// WGSL shader for SiLU(gate) * up
//
// output[i] = SiLU(gate[i]) * up[i], with SiLU(g) = g / (1 + exp(-g)).
// One thread per element, 256 threads per workgroup; threads with
// idx >= params.x (the element count) return immediately, so dispatch at
// least ceil(dim / 256) workgroups along x. `gate` and `up` are bound
// read-only and `output` is a separate read-write binding.
const SILU_MUL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.x) { return; }
    let g = gate[idx];
    let silu_g = g / (1.0 + exp(-g));
    output[idx] = silu_g * up[idx];
}
"#;
179
// WGSL shader for residual add: output = a + b
//
// One thread per element, 256 threads per workgroup; threads with
// idx >= params.x (the element count) return immediately, so dispatch at
// least ceil(len / 256) workgroups along x.
const RESIDUAL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.x) { return; }
    output[idx] = a[idx] + b[idx];
}
"#;
194
// Batch RoPE shader — applies RoPE to all positions in a sequence at once.
// PMAT-509: Training forward path was missing RoPE entirely, causing loss > random.
// Input: qk[seq_len * num_heads * head_dim], applies position-dependent rotation.
//
// Layout is [position][head][head_dim], rotated in place. The position of each
// element is its row index, so the sequence is assumed to start at position 0.
// `num_heads` is whatever head count the buffer holds (e.g. the KV head count
// when rotating K).
// Pairing is rotate-half: element d < head_dim/2 of a head is rotated together
// with element d + head_dim/2, at angle position * 1e6^(-2d / head_dim).
// Dispatch one thread per element (seq_len * num_heads * head_dim threads,
// 256 per workgroup); threads in the upper half of each head return early.
const BATCH_ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;

struct RopeParams {
    seq_len: u32,
    num_heads: u32,
    head_dim: u32,
    _pad: u32,
}

@group(0) @binding(1) var<uniform> params: RopeParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let total = params.seq_len * params.num_heads * params.head_dim;
    if (idx >= total) { return; }

    let head_dim = params.head_dim;
    let half_hd = head_dim / 2u;

    // Decompose idx into (position, head, pos_in_head)
    let elements_per_pos = params.num_heads * head_dim;
    let position = idx / elements_per_pos;
    let within_pos = idx % elements_per_pos;
    let head_idx = within_pos / head_dim;
    let pos_in_head = within_pos % head_dim;

    // Only process the first half of each head (pairs with second half)
    if (pos_in_head >= half_hd) { return; }

    let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
    let angle = f32(position) * theta;
    let cos_a = cos(angle);
    let sin_a = sin(angle);

    let base = position * elements_per_pos + head_idx * head_dim;
    let i0 = base + pos_in_head;
    let i1 = i0 + half_hd;

    let x0 = qk[i0];
    let x1 = qk[i1];
    qk[i0] = x0 * cos_a - x1 * sin_a;
    qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;
244
// RoPE shader (NeoX-style rotate-half, not interleaved) — single position (inference).
//
// Element d < head_dim/2 of each head is rotated together with element
// d + head_dim/2 (not with d + 1), at angle position * 1e6^(-2d / head_dim).
// params = (dim, position, num_heads, head_dim); the shader reads x, y and w
// only — num_heads (params.z) is not used. One thread per element of
// qk[0..dim], 256 per workgroup; upper-half threads return early.
const ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;
@group(0) @binding(1) var<uniform> params: vec4<u32>; // (dim, position, num_heads, head_dim)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let dim = params.x;
    let position = params.y;
    let head_dim = params.w;

    if (idx >= dim) { return; }

    let half_hd = head_dim / 2u;
    let head_idx = idx / head_dim;
    let pos_in_head = idx % head_dim;

    if (pos_in_head >= half_hd) { return; }

    let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
    let angle = f32(position) * theta;
    let cos_a = cos(angle);
    let sin_a = sin(angle);

    let i0 = head_idx * head_dim + pos_in_head;
    let i1 = i0 + half_hd;

    let x0 = qk[i0];
    let x1 = qk[i1];
    qk[i0] = x0 * cos_a - x1 * sin_a;
    qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;
279
280impl WgslForwardPass {
281    /// Get the shader sources for external inspection/testing
282    pub fn rmsnorm_shader() -> &'static str {
283        RMSNORM_SHADER
284    }
285    pub fn silu_mul_shader() -> &'static str {
286        SILU_MUL_SHADER
287    }
288    pub fn residual_shader() -> &'static str {
289        RESIDUAL_SHADER
290    }
291    pub fn rope_shader() -> &'static str {
292        ROPE_SHADER
293    }
294
295    /// PMAT-325: Create a new WGSL forward pass context.
296    ///
297    /// Compiles all shader pipelines and allocates persistent intermediate buffers.
298    /// Call once at model init. All GPU resources persist until dropped.
299    pub fn new(
300        device: wgpu::Device,
301        queue: wgpu::Queue,
302        hidden_dim: usize,
303        num_heads: usize,
304        num_kv_heads: usize,
305        head_dim: usize,
306        intermediate_dim: usize,
307    ) -> Self {
308        let q_dim = num_heads * head_dim;
309        let kv_dim = num_kv_heads * head_dim;
310
311        // Compile shaders
312        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
313            label: Some("matmul"),
314            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
315        });
316        let rmsnorm_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
317            label: Some("rmsnorm"),
318            source: wgpu::ShaderSource::Wgsl(RMSNORM_SHADER.into()),
319        });
320        let silu_mul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
321            label: Some("silu_mul"),
322            source: wgpu::ShaderSource::Wgsl(SILU_MUL_SHADER.into()),
323        });
324        let rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
325            label: Some("rope"),
326            source: wgpu::ShaderSource::Wgsl(ROPE_SHADER.into()),
327        });
328        let residual_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
329            label: Some("residual"),
330            source: wgpu::ShaderSource::Wgsl(RESIDUAL_SHADER.into()),
331        });
332
333        // Bind group layouts
334        let matmul_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
335            label: Some("matmul_bgl"),
336            entries: &[
337                bgl_storage(0, true),
338                bgl_storage(1, true),
339                bgl_storage(2, false),
340                bgl_uniform(3),
341            ],
342        });
343        let elementwise_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
344            label: Some("ew_bgl"),
345            entries: &[
346                bgl_storage(0, true),
347                bgl_storage(1, true),
348                bgl_storage(2, false),
349                bgl_uniform(3),
350            ],
351        });
352
353        // Pipelines
354        let make_pipeline =
355            |shader: &wgpu::ShaderModule, bgl: &wgpu::BindGroupLayout, label: &str| {
356                let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
357                    label: Some(label),
358                    bind_group_layouts: &[bgl],
359                    push_constant_ranges: &[],
360                });
361                device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
362                    label: Some(label),
363                    layout: Some(&pl),
364                    module: shader,
365                    entry_point: Some("main"),
366                    compilation_options: Default::default(),
367                    cache: None,
368                })
369            };
370
371        let matmul_pipeline = make_pipeline(&matmul_shader, &matmul_bgl, "matmul_pipe");
372
373        // CUTLASS-style tiled GEMM for M>=4 (training batch, prefill)
374        let tiled_matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
375            label: Some("tiled_matmul"),
376            source: wgpu::ShaderSource::Wgsl(
377                crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
378            ),
379        });
380        let tiled_matmul_pipeline =
381            make_pipeline(&tiled_matmul_shader, &matmul_bgl, "tiled_matmul_pipe");
382
383        // Causal attention pipeline for training
384        let attention_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
385            label: Some("causal_attention"),
386            source: wgpu::ShaderSource::Wgsl(
387                crate::backends::gpu::shaders::CAUSAL_ATTENTION_SHADER.into(),
388            ),
389        });
390        let attention_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
391            label: Some("attn_bgl"),
392            entries: &[
393                bgl_storage(0, true),  // Q
394                bgl_storage(1, true),  // K
395                bgl_storage(2, true),  // V
396                bgl_storage(3, false), // output
397                bgl_uniform(4),        // params
398            ],
399        });
400        let attention_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
401            label: Some("attn_pl"),
402            bind_group_layouts: &[&attention_bgl],
403            push_constant_ranges: &[],
404        });
405        let attention_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
406            label: Some("attn_pipe"),
407            layout: Some(&attention_pl),
408            module: &attention_shader,
409            entry_point: Some("main"),
410            compilation_options: Default::default(),
411            cache: None,
412        });
413
414        // PMAT-327: GEMV pipeline — same bind group layout as matmul but cooperative reduction
415        let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
416            label: Some("gemv"),
417            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
418        });
419        let gemv_pipeline = make_pipeline(&gemv_shader, &matmul_bgl, "gemv_pipe");
420
421        // C-WGPU-Q4K-001: Q4K GEMV — dequantize on-the-fly, no F32 weight buffer
422        let q4k_gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
423            label: Some("q4k_gemv"),
424            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::Q4K_GEMV_SHADER.into()),
425        });
426        let q4k_gemv_pipeline = make_pipeline(&q4k_gemv_shader, &matmul_bgl, "q4k_gemv_pipe");
427
428        let rmsnorm_pipeline = make_pipeline(&rmsnorm_shader, &elementwise_bgl, "rmsnorm_pipe");
429        let silu_mul_pipeline = make_pipeline(&silu_mul_shader, &elementwise_bgl, "silu_pipe");
430        let residual_pipeline = make_pipeline(&residual_shader_mod, &elementwise_bgl, "res_pipe");
431
432        // RoPE has a 2-binding layout (in-place + uniform)
433        let rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
434            label: Some("rope_bgl"),
435            entries: &[bgl_storage(0, false), bgl_uniform(1)],
436        });
437        let rope_pipeline = {
438            let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
439                label: Some("rope_pl"),
440                bind_group_layouts: &[&rope_bgl],
441                push_constant_ranges: &[],
442            });
443            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
444                label: Some("rope_pipe"),
445                layout: Some(&pl),
446                module: &rope_shader_mod,
447                entry_point: Some("main"),
448                compilation_options: Default::default(),
449                cache: None,
450            })
451        };
452
453        // PMAT-509: Batch RoPE for training (all positions at once)
454        let batch_rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
455            label: Some("batch_rope"),
456            source: wgpu::ShaderSource::Wgsl(BATCH_ROPE_SHADER.into()),
457        });
458        let batch_rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
459            label: Some("batch_rope_bgl"),
460            entries: &[bgl_storage(0, false), bgl_uniform(1)],
461        });
462        let batch_rope_pipeline = {
463            let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
464                label: Some("batch_rope_pl"),
465                bind_group_layouts: &[&batch_rope_bgl],
466                push_constant_ranges: &[],
467            });
468            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
469                label: Some("batch_rope_pipe"),
470                layout: Some(&pl),
471                module: &batch_rope_shader_mod,
472                entry_point: Some("main"),
473                compilation_options: Default::default(),
474                cache: None,
475            })
476        };
477
478        // Allocate persistent intermediate buffers
479        let buf = |size: usize, label: &str| -> wgpu::Buffer {
480            device.create_buffer(&wgpu::BufferDescriptor {
481                label: Some(label),
482                size: (size * 4) as u64,
483                usage: wgpu::BufferUsages::STORAGE
484                    | wgpu::BufferUsages::COPY_SRC
485                    | wgpu::BufferUsages::COPY_DST,
486                mapped_at_creation: false,
487            })
488        };
489
490        // Buffer sizes: max_seq × dim for training, or 1 × dim for inference.
491        // Training calls forward_layer_training with seq_len > 1.
492        // Allocate for max_seq=2048 to support both.
493        let max_seq = 2048;
494        let hidden_buf = buf(max_seq * hidden_dim, "hidden");
495        let q_buf = buf(max_seq * q_dim, "q");
496        let k_buf = buf(max_seq * kv_dim, "k");
497        let v_buf = buf(max_seq * kv_dim, "v");
498        let attn_out_buf = buf(max_seq * hidden_dim, "attn_out");
499        let ffn_gate_buf = buf(max_seq * intermediate_dim, "ffn_gate");
500        let ffn_up_buf = buf(max_seq * intermediate_dim, "ffn_up");
501        let ffn_silu_buf = buf(max_seq * intermediate_dim, "ffn_silu");
502        let ffn_out_buf = buf(max_seq * hidden_dim, "ffn_out");
503        let norm_buf = buf(max_seq * hidden_dim, "norm");
504
505        let max_out = max_seq * hidden_dim.max(intermediate_dim);
506        let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
507            label: Some("staging"),
508            size: (max_out * 4) as u64,
509            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
510            mapped_at_creation: false,
511        });
512
513        Self {
514            device,
515            queue,
516            matmul_pipeline,
517            tiled_matmul_pipeline,
518            attention_pipeline,
519            attention_bgl,
520            gemv_pipeline,
521            q4k_gemv_pipeline,
522            rmsnorm_pipeline,
523            silu_mul_pipeline,
524            rope_pipeline,
525            batch_rope_pipeline,
526            batch_rope_bgl,
527            residual_pipeline,
528            matmul_bgl,
529            elementwise_bgl,
530            weight_buffers: HashMap::new(),
531            q4k_weights: HashMap::new(),
532            kv_cache_k: Vec::new(),
533            kv_cache_v: Vec::new(),
534            cpu_biases: HashMap::new(),
535            hidden_buf,
536            q_buf,
537            k_buf,
538            v_buf,
539            attn_out_buf,
540            ffn_gate_buf,
541            ffn_up_buf,
542            ffn_silu_buf,
543            ffn_out_buf,
544            norm_buf,
545            staging_buf,
546            hidden_dim: hidden_dim as u32,
547            num_heads: num_heads as u32,
548            num_kv_heads: num_kv_heads as u32,
549            head_dim: head_dim as u32,
550            intermediate_dim: intermediate_dim as u32,
551        }
552    }
553
554    /// Upload a weight matrix (call once per layer at init).
555    /// PMAT-342: Bias weights (name contains "bias") are stored CPU-side.
556    pub fn upload_weight(&mut self, name: &str, data: &[f32]) {
557        if name.contains("bias") {
558            // Biases are small, keep on CPU for easy access in attention
559            self.cpu_biases.insert(name.to_string(), data.to_vec());
560            return;
561        }
562        // Skip weights that exceed the device's max buffer binding size (e.g., lm_head > 2 GB)
563        let size_bytes = (data.len() * 4) as u64;
564        let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
565        if size_bytes > max_binding {
566            eprintln!(
567                "[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB limit) — CPU fallback",
568                name,
569                size_bytes as f64 / 1e6,
570                max_binding as f64 / 1e6
571            );
572            return;
573        }
574        use wgpu::util::DeviceExt;
575        let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
576            label: Some(name),
577            contents: bytemuck::cast_slice(data),
578            usage: wgpu::BufferUsages::STORAGE,
579        });
580        self.weight_buffers.insert(name.to_string(), buffer);
581    }
582
583    /// GH-560: Upload raw Q4K weight bytes for fused dequant+GEMV on GPU.
584    pub fn upload_q4k_weight(&mut self, name: &str, data: &[u8]) {
585        use wgpu::util::DeviceExt;
586        let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
587            label: Some(name),
588            contents: data,
589            usage: wgpu::BufferUsages::STORAGE,
590        });
591        self.q4k_weights.insert(name.to_string(), buffer);
592    }
593
594    /// GH-560: Initialize per-layer KV cache buffers on GPU.
595    pub fn init_kv_cache(&mut self, num_layers: usize) {
596        let kv_dim = (self.num_kv_heads * self.head_dim) as u64;
597        let max_seq = 2048u64;
598        for _ in 0..num_layers {
599            let k = self.device.create_buffer(&wgpu::BufferDescriptor {
600                label: Some("kv_cache_k"),
601                size: max_seq * kv_dim * 4,
602                usage: wgpu::BufferUsages::STORAGE
603                    | wgpu::BufferUsages::COPY_DST
604                    | wgpu::BufferUsages::COPY_SRC,
605                mapped_at_creation: false,
606            });
607            let v = self.device.create_buffer(&wgpu::BufferDescriptor {
608                label: Some("kv_cache_v"),
609                size: max_seq * kv_dim * 4,
610                usage: wgpu::BufferUsages::STORAGE
611                    | wgpu::BufferUsages::COPY_DST
612                    | wgpu::BufferUsages::COPY_SRC,
613                mapped_at_creation: false,
614            });
615            self.kv_cache_k.push(k);
616            self.kv_cache_v.push(v);
617        }
618    }
619
620    /// Number of uploaded weight buffers.
621    pub fn weight_count(&self) -> usize {
622        self.weight_buffers.len()
623    }
624
625    /// Access a dequantized weight buffer by name (e.g. "layer.0.down_proj").
626    /// Used by backward pass for gradient propagation through frozen base weights.
627    pub fn weight_buffer(&self, name: &str) -> Option<&wgpu::Buffer> {
628        self.weight_buffers.get(name)
629    }
630
631    /// Reference to the wgpu device.
632    pub fn device_ref(&self) -> &wgpu::Device {
633        &self.device
634    }
635
636    /// Reference to the wgpu queue.
637    pub fn queue_ref(&self) -> &wgpu::Queue {
638        &self.queue
639    }
640
641    /// Reference to the hidden state buffer (for writing input).
642    pub fn hidden_buffer(&self) -> &wgpu::Buffer {
643        &self.hidden_buf
644    }
645
646    /// Reference to Q buffer (for LoRA addmm after Q projection).
647    pub fn q_buffer(&self) -> &wgpu::Buffer {
648        &self.q_buf
649    }
650
651    /// Reference to K buffer.
652    pub fn k_buffer(&self) -> &wgpu::Buffer {
653        &self.k_buf
654    }
655
656    /// Reference to V buffer.
657    pub fn v_buffer(&self) -> &wgpu::Buffer {
658        &self.v_buf
659    }
660
661    /// Elementwise add: output = a + b. Dispatches residual add shader.
662    pub fn gpu_residual_add(
663        &self,
664        a: &wgpu::Buffer,
665        b: &wgpu::Buffer,
666        output: &wgpu::Buffer,
667        len: u32,
668    ) {
669        let mut encoder = self.device.create_command_encoder(&Default::default());
670        self.encode_residual(&mut encoder, a, b, output, len);
671        self.queue.submit(Some(encoder.finish()));
672    }
673
674    /// Apply RMSNorm on GPU: normed = rmsnorm(hidden_buf, weight) → output_buf.
675    /// Contract: gpu-output-norm-v1 / gpu_resident — hidden state never leaves GPU.
676    pub fn gpu_rmsnorm(&self, weight: &wgpu::Buffer, output: &wgpu::Buffer, _seq_len: u32) {
677        let mut encoder = self.device.create_command_encoder(&Default::default());
678        self.encode_rmsnorm(&mut encoder, &self.hidden_buf, weight, output, self.hidden_dim);
679        self.queue.submit(Some(encoder.finish()));
680    }
681
682    /// Download hidden state from GPU.
683    pub fn download_hidden(&self, len: usize) -> Vec<f32> {
684        let size = (len * 4) as u64;
685        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
686            label: Some("hidden_download"),
687            size,
688            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
689            mapped_at_creation: false,
690        });
691        let mut encoder = self.device.create_command_encoder(&Default::default());
692        encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &staging, 0, size);
693        self.queue.submit(Some(encoder.finish()));
694
695        let slice = staging.slice(..size);
696        let (tx, rx) = std::sync::mpsc::channel();
697        slice.map_async(wgpu::MapMode::Read, move |r| {
698            tx.send(r).ok();
699        });
700        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
701        rx.recv()
702            .expect("GPU map_async callback channel disconnected")
703            .expect("GPU buffer mapping failed");
704
705        let data = slice.get_mapped_range();
706        let result: Vec<f32> = bytemuck::cast_slice(&data)[..len].to_vec();
707        drop(data);
708        staging.unmap();
709        result
710    }
711
712    /// Total VRAM used by all buffers (bytes).
713    pub fn total_vram_bytes(&self) -> usize {
714        let weight_bytes: usize = self.weight_buffers.values().map(|b| b.size() as usize).sum();
715        let intermediate_bytes = (self.hidden_dim as usize * 4) * 4  // hidden, attn_out, ffn_out, norm
716            + (self.num_heads as usize * self.head_dim as usize * 4) // q
717            + (self.num_kv_heads as usize * self.head_dim as usize * 4) * 2 // k, v
718            + (self.intermediate_dim as usize * 4) * 2; // gate, up
719        weight_bytes + intermediate_bytes
720    }
721
722    /// PMAT-336: Full model forward — embedding + all layers + output norm + LM head.
723    ///
724    /// Returns logits [vocab_size] for the given token at the given position.
725    /// Embedding lookup and final LM head are CPU-side (not yet GPU-accelerated).
726    /// PMAT-344: Added kv_caches for multi-token context
727    #[provable_contracts_macros::contract("wgpu-forward-pass-v1", equation = "rmsnorm_correctness")]
728    pub fn forward_model(
729        &self,
730        token_id: u32,
731        position: usize,
732        num_layers: usize,
733        token_embedding: &[f32],
734        output_norm_weight: &[f32],
735        lm_head_weight: &[f32],
736        vocab_size: usize,
737        eps: f32,
738        kv_caches: &mut Vec<(Vec<f32>, Vec<f32>)>,
739    ) -> Result<Vec<f32>, String> {
740        let hd = self.hidden_dim as usize;
741
742        // 1. Embedding lookup (CPU)
743        let embed_start = token_id as usize * hd;
744        if embed_start + hd > token_embedding.len() {
745            return Err(format!(
746                "Token {} out of range (embedding size {})",
747                token_id,
748                token_embedding.len() / hd
749            ));
750        }
751        let mut hidden: Vec<f32> = token_embedding[embed_start..embed_start + hd].to_vec();
752
753        // 2. Transformer layers (GPU via forward_layer with KV cache)
754        // Initialize KV caches if empty
755        while kv_caches.len() < num_layers {
756            kv_caches.push((Vec::new(), Vec::new()));
757        }
758        for layer_idx in 0..num_layers {
759            let prefix = format!("layer.{layer_idx}");
760            let (ref mut k_cache, ref mut v_cache) = kv_caches[layer_idx];
761            self.forward_layer(&mut hidden, &prefix, position, k_cache, v_cache)?;
762        }
763
764        // 3. Output RMSNorm (CPU — small, not worth GPU dispatch)
765        let rms = (hidden.iter().map(|x| x * x).sum::<f32>() / hd as f32 + eps).sqrt();
766        for i in 0..hd {
767            hidden[i] = (hidden[i] / rms) * output_norm_weight[i];
768        }
769
770        // 4. LM head — CPU matmul
771        // PMAT-346: GPU tiled GEMM expects weight in [K,N] layout but lm_head is [N,K].
772        // CPU path reads weight[v * hd + j] which matches the [vocab, hidden] layout.
773        // GPU LM head via GEMV is blocked by vocab > 65535 dispatch limit.
774        // TODO: add upload_weight_transposed() for GPU-accelerated LM head.
775        let mut logits = vec![0.0f32; vocab_size];
776        for v in 0..vocab_size {
777            let mut sum = 0.0f32;
778            let row_start = v * hd;
779            for j in 0..hd {
780                sum += lm_head_weight[row_start + j] * hidden[j];
781            }
782            logits[v] = sum;
783        }
784        Ok(logits)
785    }
786
787    /// PMAT-325: Execute one transformer layer — 14 passes, 1 submit, 1 readback.
788    ///
789    /// Input: hidden state [hidden_dim] on CPU.
790    /// Output: updated hidden state [hidden_dim] on CPU.
791    /// All intermediate computation stays GPU-resident.
792    /// PMAT-344: KV cache parameters for multi-token context
793    pub fn forward_layer(
794        &self,
795        hidden: &mut [f32],
796        layer_prefix: &str,
797        _position: usize,
798        kv_cache_k: &mut Vec<f32>, // accumulated K: [seq_len * kv_dim]
799        kv_cache_v: &mut Vec<f32>, // accumulated V: [seq_len * kv_dim]
800    ) -> Result<(), String> {
801        let hd = self.hidden_dim;
802
803        // Upload hidden state
804        self.queue.write_buffer(&self.hidden_buf, 0, bytemuck::cast_slice(hidden));
805
806        let mut encoder = self.device.create_command_encoder(&Default::default());
807
808        // Pass 1: RMSNorm(hidden → norm_buf)
809        let norm_w = self
810            .weight_buffers
811            .get(&format!("{layer_prefix}.attn_norm"))
812            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
813        self.encode_rmsnorm(&mut encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);
814
815        // Passes 2-4: Q/K/V projections (norm_buf × W → q/k/v_buf)
816        let q_dim = self.num_heads * self.head_dim;
817        let kv_dim = self.num_kv_heads * self.head_dim;
818
819        self.encode_matmul(
820            &mut encoder,
821            &self.norm_buf,
822            layer_prefix,
823            "q_proj",
824            &self.q_buf,
825            1,
826            hd,
827            q_dim,
828        );
829        self.encode_matmul(
830            &mut encoder,
831            &self.norm_buf,
832            layer_prefix,
833            "k_proj",
834            &self.k_buf,
835            1,
836            hd,
837            kv_dim,
838        );
839        self.encode_matmul(
840            &mut encoder,
841            &self.norm_buf,
842            layer_prefix,
843            "v_proj",
844            &self.v_buf,
845            1,
846            hd,
847            kv_dim,
848        );
849
850        // PMAT-342: Submit Q/K/V projections, readback, do attention on CPU
851        // GPU handles the heavy matmuls; CPU handles attention (small at M=1)
852        let q_bytes = (q_dim * 4) as u64;
853        let kv_bytes = (kv_dim * 4) as u64;
854
855        // Readback Q/K/V from GPU
856        let q_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
857            label: Some("q_stg"),
858            size: q_bytes,
859            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
860            mapped_at_creation: false,
861        });
862        let k_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
863            label: Some("k_stg"),
864            size: kv_bytes,
865            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
866            mapped_at_creation: false,
867        });
868        let v_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
869            label: Some("v_stg"),
870            size: kv_bytes,
871            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
872            mapped_at_creation: false,
873        });
874        encoder.copy_buffer_to_buffer(&self.q_buf, 0, &q_staging, 0, q_bytes);
875        encoder.copy_buffer_to_buffer(&self.k_buf, 0, &k_staging, 0, kv_bytes);
876        encoder.copy_buffer_to_buffer(&self.v_buf, 0, &v_staging, 0, kv_bytes);
877        self.queue.submit(Some(encoder.finish()));
878
879        // Readback Q
880        let mut q_data = vec![0.0f32; q_dim as usize];
881        {
882            let slice = q_staging.slice(..q_bytes);
883            let (tx, rx) = std::sync::mpsc::channel();
884            slice.map_async(wgpu::MapMode::Read, move |r| {
885                tx.send(r).ok();
886            });
887            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
888            rx.recv().map_err(|e| format!("q recv: {e}"))?.map_err(|e| format!("q map: {e:?}"))?;
889            let data = slice.get_mapped_range();
890            q_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..q_dim as usize]);
891        }
892        q_staging.unmap();
893
894        // Readback K
895        let mut k_data = vec![0.0f32; kv_dim as usize];
896        {
897            let slice = k_staging.slice(..kv_bytes);
898            let (tx, rx) = std::sync::mpsc::channel();
899            slice.map_async(wgpu::MapMode::Read, move |r| {
900                tx.send(r).ok();
901            });
902            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
903            rx.recv().map_err(|e| format!("k recv: {e}"))?.map_err(|e| format!("k map: {e:?}"))?;
904            let data = slice.get_mapped_range();
905            k_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..kv_dim as usize]);
906        }
907        k_staging.unmap();
908
909        // Readback V
910        let mut v_data = vec![0.0f32; kv_dim as usize];
911        {
912            let slice = v_staging.slice(..kv_bytes);
913            let (tx, rx) = std::sync::mpsc::channel();
914            slice.map_async(wgpu::MapMode::Read, move |r| {
915                tx.send(r).ok();
916            });
917            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
918            rx.recv().map_err(|e| format!("v recv: {e}"))?.map_err(|e| format!("v map: {e:?}"))?;
919            let data = slice.get_mapped_range();
920            v_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..kv_dim as usize]);
921        }
922        v_staging.unmap();
923
924        // PMAT-342: Add QKV biases (required for Qwen2)
925        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
926            for (q, b) in q_data.iter_mut().zip(q_bias.iter()) {
927                *q += *b;
928            }
929        }
930        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
931            for (k, b) in k_data.iter_mut().zip(k_bias.iter()) {
932                *k += *b;
933            }
934        }
935        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
936            for (v, b) in v_data.iter_mut().zip(v_bias.iter()) {
937                *v += *b;
938            }
939        }
940
941        // PMAT-343: Apply RoPE (NeoX-style interleaved) to Q and K
942        let head_dim = self.head_dim as usize;
943        let position = _position; // Use the position parameter
944        let rope_theta = 1_000_000.0f64; // Qwen2 rope_theta
945
946        // RoPE on Q (num_heads × head_dim)
947        for h in 0..(self.num_heads as usize) {
948            let offset = h * head_dim;
949            let half = head_dim / 2;
950            for i in 0..half {
951                let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
952                let angle = position as f64 * theta;
953                let cos_a = angle.cos() as f32;
954                let sin_a = angle.sin() as f32;
955                let x0 = q_data[offset + i];
956                let x1 = q_data[offset + i + half];
957                q_data[offset + i] = x0 * cos_a - x1 * sin_a;
958                q_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
959            }
960        }
961
962        // RoPE on K (num_kv_heads × head_dim)
963        for h in 0..(self.num_kv_heads as usize) {
964            let offset = h * head_dim;
965            let half = head_dim / 2;
966            for i in 0..half {
967                let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
968                let angle = position as f64 * theta;
969                let cos_a = angle.cos() as f32;
970                let sin_a = angle.sin() as f32;
971                let x0 = k_data[offset + i];
972                let x1 = k_data[offset + i + half];
973                k_data[offset + i] = x0 * cos_a - x1 * sin_a;
974                k_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
975            }
976        }
977
978        // PMAT-344: Append K,V to cache and compute full attention
979        let head_dim = self.head_dim as usize;
980        let num_heads = self.num_heads as usize;
981        let num_kv_heads = self.num_kv_heads as usize;
982        let kv_dim_usize = kv_dim as usize;
983
984        kv_cache_k.extend_from_slice(&k_data);
985        kv_cache_v.extend_from_slice(&v_data);
986        let seq_len = kv_cache_k.len() / kv_dim_usize;
987
988        // Scaled dot-product attention with GQA
989        let kv_group = num_heads / num_kv_heads;
990        let scale = 1.0 / (head_dim as f32).sqrt();
991        let mut attn_out = vec![0.0f32; q_dim as usize];
992
993        for h in 0..num_heads {
994            let kv_h = h / kv_group;
995            let q_offset = h * head_dim;
996
997            // Compute attention scores: Q[h] · K[kv_h, :seq_len]^T / sqrt(d)
998            let mut scores = vec![0.0f32; seq_len];
999            for s in 0..seq_len {
1000                let k_offset = s * kv_dim_usize + kv_h * head_dim;
1001                let mut dot = 0.0f32;
1002                for d in 0..head_dim {
1003                    dot += q_data[q_offset + d] * kv_cache_k[k_offset + d];
1004                }
1005                scores[s] = dot * scale;
1006            }
1007
1008            // Softmax
1009            let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1010            let mut sum = 0.0f32;
1011            for s in scores.iter_mut() {
1012                *s = (*s - max_score).exp();
1013                sum += *s;
1014            }
1015            if sum > 0.0 {
1016                for s in scores.iter_mut() {
1017                    *s /= sum;
1018                }
1019            }
1020
1021            // Weighted sum of V
1022            let out_offset = h * head_dim;
1023            for d in 0..head_dim {
1024                let mut val = 0.0f32;
1025                for s in 0..seq_len {
1026                    let v_offset = s * kv_dim_usize + kv_h * head_dim;
1027                    val += scores[s] * kv_cache_v[v_offset + d];
1028                }
1029                attn_out[out_offset + d] = val;
1030            }
1031        }
1032
1033        // Upload attention output back to GPU for O projection
1034        self.queue.write_buffer(&self.q_buf, 0, bytemuck::cast_slice(&attn_out));
1035
1036        // New encoder for remaining passes
1037        let mut encoder = self.device.create_command_encoder(&Default::default());
1038
1039        // Pass 7: O projection (attn_out × W_o → attn_out_buf)
1040        self.encode_matmul(
1041            &mut encoder,
1042            &self.q_buf,
1043            layer_prefix,
1044            "o_proj",
1045            &self.attn_out_buf,
1046            1,
1047            q_dim,
1048            hd,
1049        );
1050
1051        // Pass 8: Residual(hidden + attn_out → hidden)
1052        self.encode_residual(
1053            &mut encoder,
1054            &self.hidden_buf,
1055            &self.attn_out_buf,
1056            &self.ffn_out_buf,
1057            hd,
1058        );
1059
1060        // Pass 9: FFN RMSNorm(ffn_out → norm_buf)
1061        let ffn_norm_w = self
1062            .weight_buffers
1063            .get(&format!("{layer_prefix}.ffn_norm"))
1064            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
1065        self.encode_rmsnorm(&mut encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);
1066
1067        // Passes 10-11: Gate + Up projections
1068        let inter = self.intermediate_dim;
1069        self.encode_matmul(
1070            &mut encoder,
1071            &self.norm_buf,
1072            layer_prefix,
1073            "gate_proj",
1074            &self.ffn_gate_buf,
1075            1,
1076            hd,
1077            inter,
1078        );
1079        self.encode_matmul(
1080            &mut encoder,
1081            &self.norm_buf,
1082            layer_prefix,
1083            "up_proj",
1084            &self.ffn_up_buf,
1085            1,
1086            hd,
1087            inter,
1088        );
1089
1090        // Pass 12: SiLU(gate) × up → ffn_silu_buf [intermediate_dim]
1091        // BUG FIX: was writing to attn_out_buf (hidden_dim=3584) but needs intermediate_dim=18944.
1092        // attn_out_buf is only hidden_dim — wgpu robustness silently drops OOB writes,
1093        // then down_proj reads zeros past hidden_dim → 81% of FFN truncated → garbage output.
1094        // Cannot alias gate/up buffers (WGSL read/write aliasing UB), so use dedicated buffer.
1095        self.encode_silu_mul(
1096            &mut encoder,
1097            &self.ffn_gate_buf,
1098            &self.ffn_up_buf,
1099            &self.ffn_silu_buf,
1100            inter,
1101        );
1102
1103        // Pass 13: Down projection (reads ffn_silu_buf [intermediate_dim] → norm_buf [hidden_dim])
1104        self.encode_matmul(
1105            &mut encoder,
1106            &self.ffn_silu_buf,
1107            layer_prefix,
1108            "down_proj",
1109            &self.norm_buf,
1110            1,
1111            inter,
1112            hd,
1113        );
1114
1115        // Pass 14: Residual(ffn_out + down → hidden)
1116        self.encode_residual(&mut encoder, &self.ffn_out_buf, &self.norm_buf, &self.hidden_buf, hd);
1117
1118        // Single readback
1119        encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &self.staging_buf, 0, (hd * 4) as u64);
1120        self.queue.submit(Some(encoder.finish()));
1121
1122        // Readback
1123        let slice = self.staging_buf.slice(..(hd as u64 * 4));
1124        let (tx, rx) = std::sync::mpsc::channel();
1125        slice.map_async(wgpu::MapMode::Read, move |r| {
1126            tx.send(r).ok();
1127        });
1128        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
1129        rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;
1130        {
1131            let data = slice.get_mapped_range();
1132            hidden.copy_from_slice(
1133                &bytemuck::cast_slice::<u8, f32>(&data)[..self.hidden_dim as usize],
1134            );
1135        }
1136        self.staging_buf.unmap();
1137
1138        Ok(())
1139    }
1140
1141    /// Training forward pass for a single transformer layer.
1142    ///
1143    /// Unlike `forward_layer` (M=1 decode), this processes the full sequence
1144    /// at once (M=seq_len) and keeps everything on GPU. No CPU readback.
1145    ///
1146    /// Saves `norm_output` (pre-projection activations) for backward pass.
1147    ///
1148    /// # Arguments
1149    /// - `seq_len`: number of tokens in the sequence
1150    /// - `layer_prefix`: e.g. "model.layers.0"
1151    /// - `saved_norm_attn`: OUTPUT — saved pre-attention norm for backward (wgpu::Buffer, [seq×hidden])
1152    /// - `saved_norm_ffn`: OUTPUT — saved pre-FFN norm for backward (wgpu::Buffer, [seq×hidden])
1153    ///
1154    /// Forward one layer into an EXISTING encoder (no submit).
1155    /// Caller batches multiple layers into one encoder, submits once.
1156    pub fn encode_forward_layer_training(
1157        &self,
1158        encoder: &mut wgpu::CommandEncoder,
1159        seq_len: u32,
1160        layer_prefix: &str,
1161        saved: &LayerActivations,
1162        lora: Option<&QkvLoRA<'_>>,
1163    ) -> Result<(), String> {
1164        let hd = self.hidden_dim;
1165        let q_dim = self.num_heads * self.head_dim;
1166        let kv_dim = self.num_kv_heads * self.head_dim;
1167        let inter = self.intermediate_dim;
1168        let s = seq_len as usize;
1169
1170        // Pass 1: RMSNorm
1171        let norm_w = self
1172            .weight_buffers
1173            .get(&format!("{layer_prefix}.attn_norm"))
1174            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
1175        self.encode_rmsnorm(encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);
1176
1177        // SAVE attn_norm_out
1178        encoder.copy_buffer_to_buffer(
1179            &self.norm_buf,
1180            0,
1181            &saved.attn_norm_out,
1182            0,
1183            (s * hd as usize * 4) as u64,
1184        );
1185
1186        // Q/K/V projections
1187        self.encode_matmul(
1188            encoder,
1189            &self.norm_buf,
1190            layer_prefix,
1191            "q_proj",
1192            &self.q_buf,
1193            seq_len,
1194            hd,
1195            q_dim,
1196        );
1197        self.encode_matmul(
1198            encoder,
1199            &self.norm_buf,
1200            layer_prefix,
1201            "k_proj",
1202            &self.k_buf,
1203            seq_len,
1204            hd,
1205            kv_dim,
1206        );
1207        self.encode_matmul(
1208            encoder,
1209            &self.norm_buf,
1210            layer_prefix,
1211            "v_proj",
1212            &self.v_buf,
1213            seq_len,
1214            hd,
1215            kv_dim,
1216        );
1217
1218        // LoRA addmm on Q/K/V: output += (saved_input @ A) @ B * scale
1219        // Must happen BEFORE attention consumes Q/K/V buffers.
1220        if let Some(lora) = lora {
1221            self.encode_lora_addmm(
1222                encoder,
1223                &saved.attn_norm_out,
1224                lora.q_a,
1225                lora.q_b,
1226                &self.q_buf,
1227                seq_len,
1228                lora.in_dim,
1229                lora.rank,
1230                lora.q_dim,
1231                lora.scale,
1232                lora.lora_pipeline,
1233                lora.lora_bgl,
1234            );
1235            self.encode_lora_addmm(
1236                encoder,
1237                &saved.attn_norm_out,
1238                lora.k_a,
1239                lora.k_b,
1240                &self.k_buf,
1241                seq_len,
1242                lora.in_dim,
1243                lora.rank,
1244                lora.kv_dim,
1245                lora.scale,
1246                lora.lora_pipeline,
1247                lora.lora_bgl,
1248            );
1249            self.encode_lora_addmm(
1250                encoder,
1251                &saved.attn_norm_out,
1252                lora.v_a,
1253                lora.v_b,
1254                &self.v_buf,
1255                seq_len,
1256                lora.in_dim,
1257                lora.rank,
1258                lora.kv_dim,
1259                lora.scale,
1260                lora.lora_pipeline,
1261                lora.lora_bgl,
1262            );
1263        }
1264
1265        // PMAT-509: Apply QKV biases (required for Qwen2)
1266        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
1267            self.encode_broadcast_bias(encoder, &self.q_buf, q_bias, seq_len);
1268        }
1269        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
1270            self.encode_broadcast_bias(encoder, &self.k_buf, k_bias, seq_len);
1271        }
1272        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
1273            self.encode_broadcast_bias(encoder, &self.v_buf, v_bias, seq_len);
1274        }
1275
1276        // PMAT-509: Apply RoPE to Q and K before attention.
1277        self.encode_batch_rope(encoder, &self.q_buf, seq_len, self.num_heads, self.head_dim);
1278        self.encode_batch_rope(encoder, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim);
1279
1280        // Attention — wgpu handles execution ordering within the encoder.
1281        self.encode_attention(encoder, seq_len);
1282
1283        // SAVE attn_output
1284        encoder.copy_buffer_to_buffer(
1285            &self.attn_out_buf,
1286            0,
1287            &saved.attn_output,
1288            0,
1289            (s * q_dim as usize * 4) as u64,
1290        );
1291
1292        // O projection
1293        self.encode_matmul(
1294            encoder,
1295            &self.attn_out_buf,
1296            layer_prefix,
1297            "o_proj",
1298            &self.q_buf,
1299            seq_len,
1300            q_dim,
1301            hd,
1302        );
1303
1304        // Residual
1305        self.encode_residual(
1306            encoder,
1307            &self.hidden_buf,
1308            &self.q_buf,
1309            &self.ffn_out_buf,
1310            hd * seq_len,
1311        );
1312
1313        // FFN RMSNorm
1314        let ffn_norm_w = self
1315            .weight_buffers
1316            .get(&format!("{layer_prefix}.ffn_norm"))
1317            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
1318        self.encode_rmsnorm(encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);
1319
1320        // SAVE ffn_norm_out
1321        encoder.copy_buffer_to_buffer(
1322            &self.norm_buf,
1323            0,
1324            &saved.ffn_norm_out,
1325            0,
1326            (s * hd as usize * 4) as u64,
1327        );
1328
1329        // Gate + Up
1330        self.encode_matmul(
1331            encoder,
1332            &self.norm_buf,
1333            layer_prefix,
1334            "gate_proj",
1335            &self.ffn_gate_buf,
1336            seq_len,
1337            hd,
1338            inter,
1339        );
1340        self.encode_matmul(
1341            encoder,
1342            &self.norm_buf,
1343            layer_prefix,
1344            "up_proj",
1345            &self.ffn_up_buf,
1346            seq_len,
1347            hd,
1348            inter,
1349        );
1350
1351        // SiLU
1352        self.encode_silu_mul(
1353            encoder,
1354            &self.ffn_gate_buf,
1355            &self.ffn_up_buf,
1356            &self.ffn_silu_buf,
1357            inter * seq_len,
1358        );
1359
1360        // SAVE silu_gate_output
1361        encoder.copy_buffer_to_buffer(
1362            &self.ffn_silu_buf,
1363            0,
1364            &saved.silu_gate_output,
1365            0,
1366            (s * inter as usize * 4) as u64,
1367        );
1368
1369        // Down projection
1370        self.encode_matmul(
1371            encoder,
1372            &self.ffn_silu_buf,
1373            layer_prefix,
1374            "down_proj",
1375            &self.norm_buf,
1376            seq_len,
1377            inter,
1378            hd,
1379        );
1380
1381        // Residual
1382        self.encode_residual(
1383            encoder,
1384            &self.ffn_out_buf,
1385            &self.norm_buf,
1386            &self.hidden_buf,
1387            hd * seq_len,
1388        );
1389
1390        Ok(())
1391    }
1392
1393    /// Run one layer with per-operation GPU timing (submit+poll between each op group).
1394    /// Contract: forward-pass-perf-v1 / bottleneck_identified
1395    pub fn forward_layer_traced(
1396        &self,
1397        seq_len: u32,
1398        layer_prefix: &str,
1399        saved: &LayerActivations,
1400        lora: Option<&QkvLoRA<'_>>,
1401    ) -> Result<(), String> {
1402        let hd = self.hidden_dim;
1403        let q_dim = self.num_heads * self.head_dim;
1404        let kv_dim = self.num_kv_heads * self.head_dim;
1405        let inter = self.intermediate_dim;
1406        let s = seq_len as usize;
1407
1408        let norm_w = self
1409            .weight_buffers
1410            .get(&format!("{layer_prefix}.attn_norm"))
1411            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
1412
1413        let mut trace = Vec::new();
1414        let mut run = |name: &str, f: &dyn Fn(&mut wgpu::CommandEncoder)| {
1415            let mut enc = self.device.create_command_encoder(&Default::default());
1416            f(&mut enc);
1417            self.queue.submit(Some(enc.finish()));
1418            let t = std::time::Instant::now();
1419            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
1420            trace.push((name.to_string(), t.elapsed().as_millis() as u64));
1421        };
1422
1423        run("rmsnorm1", &|e| self.encode_rmsnorm(e, &self.hidden_buf, norm_w, &self.norm_buf, hd));
1424        {
1425            let mut e = self.device.create_command_encoder(&Default::default());
1426            e.copy_buffer_to_buffer(
1427                &self.norm_buf,
1428                0,
1429                &saved.attn_norm_out,
1430                0,
1431                (s * hd as usize * 4) as u64,
1432            );
1433            self.queue.submit(Some(e.finish()));
1434        }
1435        run("q_proj", &|e| {
1436            self.encode_matmul(
1437                e,
1438                &self.norm_buf,
1439                layer_prefix,
1440                "q_proj",
1441                &self.q_buf,
1442                seq_len,
1443                hd,
1444                q_dim,
1445            )
1446        });
1447        run("k_proj", &|e| {
1448            self.encode_matmul(
1449                e,
1450                &self.norm_buf,
1451                layer_prefix,
1452                "k_proj",
1453                &self.k_buf,
1454                seq_len,
1455                hd,
1456                kv_dim,
1457            )
1458        });
1459        run("v_proj", &|e| {
1460            self.encode_matmul(
1461                e,
1462                &self.norm_buf,
1463                layer_prefix,
1464                "v_proj",
1465                &self.v_buf,
1466                seq_len,
1467                hd,
1468                kv_dim,
1469            )
1470        });
1471        if let Some(lr) = lora {
1472            run("lora_qkv", &|e| {
1473                self.encode_lora_addmm(
1474                    e,
1475                    &saved.attn_norm_out,
1476                    lr.q_a,
1477                    lr.q_b,
1478                    &self.q_buf,
1479                    seq_len,
1480                    lr.in_dim,
1481                    lr.rank,
1482                    lr.q_dim,
1483                    lr.scale,
1484                    lr.lora_pipeline,
1485                    lr.lora_bgl,
1486                );
1487                self.encode_lora_addmm(
1488                    e,
1489                    &saved.attn_norm_out,
1490                    lr.k_a,
1491                    lr.k_b,
1492                    &self.k_buf,
1493                    seq_len,
1494                    lr.in_dim,
1495                    lr.rank,
1496                    lr.kv_dim,
1497                    lr.scale,
1498                    lr.lora_pipeline,
1499                    lr.lora_bgl,
1500                );
1501                self.encode_lora_addmm(
1502                    e,
1503                    &saved.attn_norm_out,
1504                    lr.v_a,
1505                    lr.v_b,
1506                    &self.v_buf,
1507                    seq_len,
1508                    lr.in_dim,
1509                    lr.rank,
1510                    lr.kv_dim,
1511                    lr.scale,
1512                    lr.lora_pipeline,
1513                    lr.lora_bgl,
1514                );
1515            });
1516        }
1517        // PMAT-509: QKV biases + RoPE before attention
1518        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
1519            run("q_bias", &|e| self.encode_broadcast_bias(e, &self.q_buf, q_bias, seq_len));
1520        }
1521        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
1522            run("k_bias", &|e| self.encode_broadcast_bias(e, &self.k_buf, k_bias, seq_len));
1523        }
1524        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
1525            run("v_bias", &|e| self.encode_broadcast_bias(e, &self.v_buf, v_bias, seq_len));
1526        }
1527        run("rope_q", &|e| {
1528            self.encode_batch_rope(e, &self.q_buf, seq_len, self.num_heads, self.head_dim)
1529        });
1530        run("rope_k", &|e| {
1531            self.encode_batch_rope(e, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim)
1532        });
1533        run("attention", &|e| self.encode_attention(e, seq_len));
1534        {
1535            let mut e = self.device.create_command_encoder(&Default::default());
1536            e.copy_buffer_to_buffer(
1537                &self.attn_out_buf,
1538                0,
1539                &saved.attn_output,
1540                0,
1541                (s * q_dim as usize * 4) as u64,
1542            );
1543            self.queue.submit(Some(e.finish()));
1544        }
1545        run("o_proj", &|e| {
1546            self.encode_matmul(
1547                e,
1548                &self.attn_out_buf,
1549                layer_prefix,
1550                "o_proj",
1551                &self.q_buf,
1552                seq_len,
1553                q_dim,
1554                hd,
1555            )
1556        });
1557        run("residual1", &|e| {
1558            self.encode_residual(e, &self.hidden_buf, &self.q_buf, &self.ffn_out_buf, hd * seq_len)
1559        });
1560        let ffn_norm_w = self
1561            .weight_buffers
1562            .get(&format!("{layer_prefix}.ffn_norm"))
1563            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
1564        run("rmsnorm2", &|e| {
1565            self.encode_rmsnorm(e, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd)
1566        });
1567        {
1568            let mut e = self.device.create_command_encoder(&Default::default());
1569            e.copy_buffer_to_buffer(
1570                &self.norm_buf,
1571                0,
1572                &saved.ffn_norm_out,
1573                0,
1574                (s * hd as usize * 4) as u64,
1575            );
1576            self.queue.submit(Some(e.finish()));
1577        }
1578        run("gate_proj", &|e| {
1579            self.encode_matmul(
1580                e,
1581                &self.norm_buf,
1582                layer_prefix,
1583                "gate_proj",
1584                &self.ffn_gate_buf,
1585                seq_len,
1586                hd,
1587                inter,
1588            )
1589        });
1590        run("up_proj", &|e| {
1591            self.encode_matmul(
1592                e,
1593                &self.norm_buf,
1594                layer_prefix,
1595                "up_proj",
1596                &self.ffn_up_buf,
1597                seq_len,
1598                hd,
1599                inter,
1600            )
1601        });
1602        run("silu", &|e| {
1603            self.encode_silu_mul(
1604                e,
1605                &self.ffn_gate_buf,
1606                &self.ffn_up_buf,
1607                &self.ffn_silu_buf,
1608                inter * seq_len,
1609            )
1610        });
1611        {
1612            let mut e = self.device.create_command_encoder(&Default::default());
1613            e.copy_buffer_to_buffer(
1614                &self.ffn_silu_buf,
1615                0,
1616                &saved.silu_gate_output,
1617                0,
1618                (s * inter as usize * 4) as u64,
1619            );
1620            self.queue.submit(Some(e.finish()));
1621        }
1622        run("down_proj", &|e| {
1623            self.encode_matmul(
1624                e,
1625                &self.ffn_silu_buf,
1626                layer_prefix,
1627                "down_proj",
1628                &self.norm_buf,
1629                seq_len,
1630                inter,
1631                hd,
1632            )
1633        });
1634        run("residual2", &|e| {
1635            self.encode_residual(
1636                e,
1637                &self.ffn_out_buf,
1638                &self.norm_buf,
1639                &self.hidden_buf,
1640                hd * seq_len,
1641            )
1642        });
1643
1644        let total: u64 = trace.iter().map(|(_, ms)| ms).sum();
1645        let parts: Vec<String> = trace.iter().map(|(n, ms)| format!("{n}={ms}")).collect();
1646        eprintln!("[OP-TRACE] layer {} total={}ms: {}", layer_prefix, total, parts.join(" "));
1647        Ok(())
1648    }
1649
1650    /// Allocate saved activations for one layer.
1651    pub fn alloc_layer_activations(&self, seq_len: u32) -> LayerActivations {
1652        let s = seq_len as usize;
1653        let buf = |size: usize, label: &str| -> wgpu::Buffer {
1654            self.device.create_buffer(&wgpu::BufferDescriptor {
1655                label: Some(label),
1656                size: (size * 4) as u64,
1657                usage: wgpu::BufferUsages::STORAGE
1658                    | wgpu::BufferUsages::COPY_SRC
1659                    | wgpu::BufferUsages::COPY_DST,
1660                mapped_at_creation: false,
1661            })
1662        };
1663        LayerActivations {
1664            attn_norm_out: buf(s * self.hidden_dim as usize, "saved_attn_norm"),
1665            attn_output: buf(s * (self.num_heads * self.head_dim) as usize, "saved_attn_out"),
1666            ffn_norm_out: buf(s * self.hidden_dim as usize, "saved_ffn_norm"),
1667            silu_gate_output: buf(s * self.intermediate_dim as usize, "saved_silu"),
1668            rstd_attn: buf(s, "saved_rstd_attn"),
1669            rstd_ffn: buf(s, "saved_rstd_ffn"),
1670            softmax_logsumexp: buf(self.num_heads as usize * s, "saved_logsumexp"),
1671        }
1672    }
1673
1674    /// Forward one layer with its own encoder + submit (original API, kept for compat).
1675    pub fn forward_layer_training(
1676        &self,
1677        seq_len: u32,
1678        layer_prefix: &str,
1679    ) -> Result<LayerActivations, String> {
1680        let saved = self.alloc_layer_activations(seq_len);
1681        let mut encoder = self.device.create_command_encoder(&Default::default());
1682        self.encode_forward_layer_training(&mut encoder, seq_len, layer_prefix, &saved, None)?;
1683        self.queue.submit(Some(encoder.finish()));
1684        Ok(saved)
1685    }
1686
1687    /// Forward ALL layers in one encoder submit. 28 layers → 1 GPU sync.
1688    pub fn forward_all_layers_training(
1689        &self,
1690        seq_len: u32,
1691        num_layers: usize,
1692    ) -> Result<Vec<LayerActivations>, String> {
1693        let mut encoder = self.device.create_command_encoder(&Default::default());
1694        let mut all_saved = Vec::with_capacity(num_layers);
1695
1696        for layer_idx in 0..num_layers {
1697            let prefix = format!("layer.{layer_idx}");
1698            let saved = self.alloc_layer_activations(seq_len);
1699            self.encode_forward_layer_training(&mut encoder, seq_len, &prefix, &saved, None)?;
1700            all_saved.push(saved);
1701        }
1702
1703        // ONE submit for all 28 layers — eliminates 27 GPU sync barriers
1704        self.queue.submit(Some(encoder.finish()));
1705        Ok(all_saved)
1706    }
1707    // --- Encode helpers (add compute passes to an existing encoder) ---
1708
1709    /// Encode causal multi-head attention on GPU.
1710    /// Q: [seq_len, num_heads * head_dim], K/V: [seq_len, num_kv_heads * head_dim]
1711    /// Output written to q_buf (reused as attn output).
1712    /// PMAT-509: Add broadcast bias to a [seq_len, dim] buffer.
1713    /// bias has shape [dim], applied to each of seq_len rows.
1714    pub fn encode_broadcast_bias(
1715        &self,
1716        encoder: &mut wgpu::CommandEncoder,
1717        buf: &wgpu::Buffer,
1718        bias: &[f32],
1719        seq_len: u32,
1720    ) {
1721        let dim = bias.len();
1722        // Create a full-size bias buffer by repeating the bias per position
1723        let mut full_bias = Vec::with_capacity(seq_len as usize * dim);
1724        for _ in 0..seq_len {
1725            full_bias.extend_from_slice(bias);
1726        }
1727        let bias_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
1728            label: Some("broadcast_bias"),
1729            size: (full_bias.len() * 4) as u64,
1730            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1731            mapped_at_creation: false,
1732        });
1733        self.queue.write_buffer(&bias_buf, 0, bytemuck::cast_slice(&full_bias));
1734
1735        // Use existing residual: out = buf + bias_buf (into a temp, then copy back)
1736        let total = seq_len * dim as u32;
1737        let tmp = self.device.create_buffer(&wgpu::BufferDescriptor {
1738            label: Some("bias_tmp"),
1739            size: (total as usize * 4) as u64,
1740            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1741            mapped_at_creation: false,
1742        });
1743        self.encode_residual(encoder, buf, &bias_buf, &tmp, total);
1744        encoder.copy_buffer_to_buffer(&tmp, 0, buf, 0, (total as u64) * 4);
1745    }
1746
1747    /// PMAT-509: Encode batch RoPE for all positions in a sequence.
1748    /// Applies position-dependent rotation to Q or K buffer in-place.
1749    fn encode_batch_rope(
1750        &self,
1751        encoder: &mut wgpu::CommandEncoder,
1752        qk_buf: &wgpu::Buffer,
1753        seq_len: u32,
1754        num_heads: u32,
1755        head_dim: u32,
1756    ) {
1757        #[repr(C)]
1758        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1759        struct RopeParams {
1760            seq_len: u32,
1761            num_heads: u32,
1762            head_dim: u32,
1763            _pad: u32,
1764        }
1765        let params = RopeParams { seq_len, num_heads, head_dim, _pad: 0 };
1766        let params_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
1767            label: Some("batch_rope_params"),
1768            size: 16,
1769            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1770            mapped_at_creation: false,
1771        });
1772        self.queue.write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));
1773
1774        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1775            label: Some("batch_rope_bg"),
1776            layout: &self.batch_rope_bgl,
1777            entries: &[
1778                wgpu::BindGroupEntry { binding: 0, resource: qk_buf.as_entire_binding() },
1779                wgpu::BindGroupEntry { binding: 1, resource: params_buf.as_entire_binding() },
1780            ],
1781        });
1782        let total = seq_len * num_heads * head_dim;
1783        let wg = total.div_ceil(256);
1784        let mut pass = encoder.begin_compute_pass(&Default::default());
1785        pass.set_pipeline(&self.batch_rope_pipeline);
1786        pass.set_bind_group(0, &bg, &[]);
1787        pass.dispatch_workgroups(wg, 1, 1);
1788    }
1789
1790    fn encode_attention(&self, encoder: &mut wgpu::CommandEncoder, seq_len: u32) {
1791        let params = [seq_len, self.num_heads, self.num_kv_heads, self.head_dim];
1792        let params_buf = self.make_uniform(&params);
1793        let _q_dim = self.num_heads * self.head_dim;
1794
1795        // Attention reads Q and writes to attn_out_buf.
1796        // Then O projection reads attn_out_buf → writes to another buffer.
1797        // We can safely write to norm_buf here since it's not read during attention.
1798        // After attention, we'll copy norm_buf → q_buf for the O projection to read.
1799        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1800            label: None,
1801            layout: &self.attention_bgl,
1802            entries: &[
1803                wgpu::BindGroupEntry { binding: 0, resource: self.q_buf.as_entire_binding() },
1804                wgpu::BindGroupEntry { binding: 1, resource: self.k_buf.as_entire_binding() },
1805                wgpu::BindGroupEntry { binding: 2, resource: self.v_buf.as_entire_binding() },
1806                wgpu::BindGroupEntry {
1807                    binding: 3,
1808                    resource: self.attn_out_buf.as_entire_binding(),
1809                },
1810                wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
1811            ],
1812        });
1813        let mut pass = encoder.begin_compute_pass(&Default::default());
1814        pass.set_pipeline(&self.attention_pipeline);
1815        pass.set_bind_group(0, &bg, &[]);
1816        // One workgroup per (head, position)
1817        pass.dispatch_workgroups(self.num_heads, seq_len, 1);
1818    }
1819
1820    /// Encode LoRA addmm: output += (input @ A) @ B * scale
1821    ///
1822    /// KAIZEN: replaced fused shader (0.11 GFLOPS) with two tiled GEMM dispatches (1000+ GFLOPS).
1823    /// Step 1: temp = input @ A  [seq, rank] via tiled GEMM
1824    /// Step 2: output += scale * (temp @ B) [seq, out_dim] via tiled GEMM with alpha=scale
1825    ///
1826    /// The second GEMM uses alpha=scale in the tiled GEMM shader (C = alpha * A @ B).
1827    /// But we need ADD (+=), not overwrite (=). We use a temp buffer for the delta,
1828    /// then add to output via an elementwise shader.
1829    #[allow(clippy::too_many_arguments)]
1830    fn encode_lora_addmm(
1831        &self,
1832        encoder: &mut wgpu::CommandEncoder,
1833        input: &wgpu::Buffer,
1834        lora_a: &wgpu::Buffer,
1835        lora_b: &wgpu::Buffer,
1836        output: &wgpu::Buffer,
1837        seq_len: u32,
1838        in_dim: u32,
1839        rank: u32,
1840        out_dim: u32,
1841        scale: f32,
1842        _pipeline: &wgpu::ComputePipeline,
1843        _bgl: &wgpu::BindGroupLayout,
1844    ) {
1845        // Step 1: temp[seq, rank] = input[seq, in_dim] @ A[in_dim, rank]
1846        let temp_size = (seq_len * rank) as u64 * 4;
1847        let temp = self.device.create_buffer(&wgpu::BufferDescriptor {
1848            label: Some("lora_temp"),
1849            size: temp_size,
1850            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1851            mapped_at_creation: false,
1852        });
1853        self.encode_tiled_gemm(encoder, input, lora_a, &temp, seq_len, in_dim, rank, 1.0);
1854
1855        // Step 2: delta[seq, out_dim] = scale * temp[seq, rank] @ B[rank, out_dim]
1856        let delta_size = (seq_len * out_dim) as u64 * 4;
1857        let delta = self.device.create_buffer(&wgpu::BufferDescriptor {
1858            label: Some("lora_delta"),
1859            size: delta_size,
1860            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1861            mapped_at_creation: false,
1862        });
1863        self.encode_tiled_gemm(encoder, &temp, lora_b, &delta, seq_len, rank, out_dim, scale);
1864
1865        // Step 3: output += delta (elementwise add, via temp to avoid aliasing)
1866        let sum_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
1867            label: Some("lora_sum"),
1868            size: delta_size,
1869            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1870            mapped_at_creation: false,
1871        });
1872        self.encode_residual(encoder, output, &delta, &sum_buf, seq_len * out_dim);
1873        encoder.copy_buffer_to_buffer(&sum_buf, 0, output, 0, delta_size);
1874    }
1875
1876    /// Encode tiled GEMM: C = alpha * A[M,K] @ B[K,N]. Uses CUTLASS-style 64×64 tiles.
1877    fn encode_tiled_gemm(
1878        &self,
1879        encoder: &mut wgpu::CommandEncoder,
1880        a: &wgpu::Buffer,
1881        b: &wgpu::Buffer,
1882        c: &wgpu::Buffer,
1883        m: u32,
1884        k: u32,
1885        n: u32,
1886        alpha: f32,
1887    ) {
1888        let params = [m, k, n, alpha.to_bits()];
1889        let params_buf = self.make_uniform(&params);
1890        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1891            label: None,
1892            layout: &self.matmul_bgl,
1893            entries: &[
1894                wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
1895                wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
1896                wgpu::BindGroupEntry { binding: 2, resource: c.as_entire_binding() },
1897                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
1898            ],
1899        });
1900        let mut pass = encoder.begin_compute_pass(&Default::default());
1901        pass.set_pipeline(&self.tiled_matmul_pipeline);
1902        pass.set_bind_group(0, &bg, &[]);
1903        pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
1904    }
1905
1906    fn encode_rmsnorm(
1907        &self,
1908        encoder: &mut wgpu::CommandEncoder,
1909        input: &wgpu::Buffer,
1910        weight: &wgpu::Buffer,
1911        output: &wgpu::Buffer,
1912        dim: u32,
1913    ) {
1914        let params = [dim, 0u32, 0, 0];
1915        let params_buf = self.make_uniform(&params);
1916        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1917            label: None,
1918            layout: &self.elementwise_bgl,
1919            entries: &[
1920                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
1921                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
1922                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
1923                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
1924            ],
1925        });
1926        // Dispatch: (1, num_rows, 1). Each workgroup processes one row via wg_id.y.
1927        // For inference (M=1): dispatch (1,1,1). For training (M=seq_len): dispatch (1,seq_len,1).
1928        let num_rows = (input.size() / (dim as u64 * 4)).max(1) as u32;
1929        let mut pass = encoder.begin_compute_pass(&Default::default());
1930        pass.set_pipeline(&self.rmsnorm_pipeline);
1931        pass.set_bind_group(0, &bg, &[]);
1932        pass.dispatch_workgroups(1, num_rows, 1);
1933    }
1934
1935    fn encode_matmul(
1936        &self,
1937        encoder: &mut wgpu::CommandEncoder,
1938        input: &wgpu::Buffer,
1939        layer_prefix: &str,
1940        proj_name: &str,
1941        output: &wgpu::Buffer,
1942        m: u32,
1943        k: u32,
1944        n: u32,
1945    ) {
1946        // C-WGPU-Q4K-001: Try Q4K GEMV first for M=1 decode (7x less VRAM)
1947        if m == 1 && self.encode_q4k_gemv(encoder, input, output, layer_prefix, proj_name, n, k) {
1948            return;
1949        }
1950        let weight_key = format!("{layer_prefix}.{proj_name}");
1951        let weight = match self.weight_buffers.get(&weight_key) {
1952            Some(w) => w,
1953            None => return, // Skip missing weights silently
1954        };
1955        // PMAT-346: GEMV and matmul have different uniform struct layouts.
1956        // GEMV: Params { n (output dim), k (input dim), _, _ }
1957        // Matmul: Dimensions { M, K, N, _ }
1958        // Tiled GEMM: Dimensions { M, K, N, alpha_bits }
1959        let params = if m == 1 { [n, k, 0u32, 0u32] } else { [m, k, n, 1.0_f32.to_bits()] };
1960        let params_buf = self.make_uniform(&params);
1961        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
1962            label: None,
1963            layout: &self.matmul_bgl,
1964            entries: &[
1965                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
1966                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
1967                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
1968                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
1969            ],
1970        });
1971        let mut pass = encoder.begin_compute_pass(&Default::default());
1972        if m == 1 {
1973            // PMAT-327: GEMV for M=1 — cooperative K-reduction, N workgroups
1974            pass.set_pipeline(&self.gemv_pipeline);
1975            pass.set_bind_group(0, &bg, &[]);
1976            pass.dispatch_workgroups(n, 1, 1);
1977        } else if m >= 4 {
1978            // CUTLASS-style tiled GEMM for M>=4 (training batch, prefill)
1979            // 64×64 tiles, 4×4 thread micro-tiles, 10-30x faster than naive
1980            pass.set_pipeline(&self.tiled_matmul_pipeline);
1981            pass.set_bind_group(0, &bg, &[]);
1982            pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
1983        } else {
1984            // Naive 16×16 GEMM for small M (2-3)
1985            pass.set_pipeline(&self.matmul_pipeline);
1986            pass.set_bind_group(0, &bg, &[]);
1987            pass.dispatch_workgroups(m.div_ceil(16), n.div_ceil(16), 1);
1988        }
1989    }
1990
1991    /// C-WGPU-Q4K-001: Encode Q4K GEMV — reads raw Q4K weight bytes, dequantizes on-the-fly.
1992    /// Falls back to F32 GEMV if no Q4K weight found for this layer.
1993    /// Returns true if Q4K path was used.
1994    fn encode_q4k_gemv(
1995        &self,
1996        encoder: &mut wgpu::CommandEncoder,
1997        input: &wgpu::Buffer,
1998        output: &wgpu::Buffer,
1999        layer_prefix: &str,
2000        proj_name: &str,
2001        n: u32,
2002        k: u32,
2003    ) -> bool {
2004        let weight_key = format!("{layer_prefix}.{proj_name}");
2005        let weight = match self.q4k_weights.get(&weight_key) {
2006            Some(w) => w,
2007            None => return false,
2008        };
2009        let num_superblocks = (k + 255) / 256;
2010        let params = [n, k, num_superblocks, 0u32];
2011        let params_buf = self.make_uniform(&params);
2012        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
2013            label: None,
2014            layout: &self.matmul_bgl,
2015            entries: &[
2016                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
2017                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
2018                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
2019                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
2020            ],
2021        });
2022        let mut pass = encoder.begin_compute_pass(&Default::default());
2023        pass.set_pipeline(&self.q4k_gemv_pipeline);
2024        pass.set_bind_group(0, &bg, &[]);
2025        pass.dispatch_workgroups(n, 1, 1);
2026        true
2027    }
2028
2029    fn encode_silu_mul(
2030        &self,
2031        encoder: &mut wgpu::CommandEncoder,
2032        gate: &wgpu::Buffer,
2033        up: &wgpu::Buffer,
2034        output: &wgpu::Buffer,
2035        dim: u32,
2036    ) {
2037        let params = [dim, 0u32, 0, 0];
2038        let params_buf = self.make_uniform(&params);
2039        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
2040            label: None,
2041            layout: &self.elementwise_bgl,
2042            entries: &[
2043                wgpu::BindGroupEntry { binding: 0, resource: gate.as_entire_binding() },
2044                wgpu::BindGroupEntry { binding: 1, resource: up.as_entire_binding() },
2045                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
2046                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
2047            ],
2048        });
2049        let mut pass = encoder.begin_compute_pass(&Default::default());
2050        pass.set_pipeline(&self.silu_mul_pipeline);
2051        pass.set_bind_group(0, &bg, &[]);
2052        pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
2053    }
2054
2055    fn encode_residual(
2056        &self,
2057        encoder: &mut wgpu::CommandEncoder,
2058        a: &wgpu::Buffer,
2059        b: &wgpu::Buffer,
2060        output: &wgpu::Buffer,
2061        dim: u32,
2062    ) {
2063        let params = [dim, 0u32, 0, 0];
2064        let params_buf = self.make_uniform(&params);
2065        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
2066            label: None,
2067            layout: &self.elementwise_bgl,
2068            entries: &[
2069                wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
2070                wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
2071                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
2072                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
2073            ],
2074        });
2075        let mut pass = encoder.begin_compute_pass(&Default::default());
2076        pass.set_pipeline(&self.residual_pipeline);
2077        pass.set_bind_group(0, &bg, &[]);
2078        pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
2079    }
2080
2081    fn make_uniform(&self, data: &[u32; 4]) -> wgpu::Buffer {
2082        use wgpu::util::DeviceExt;
2083        self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
2084            label: None,
2085            contents: bytemuck::cast_slice(data),
2086            usage: wgpu::BufferUsages::UNIFORM,
2087        })
2088    }
2089}
2090
2091fn bgl_storage(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
2092    wgpu::BindGroupLayoutEntry {
2093        binding,
2094        visibility: wgpu::ShaderStages::COMPUTE,
2095        ty: wgpu::BindingType::Buffer {
2096            ty: wgpu::BufferBindingType::Storage { read_only },
2097            has_dynamic_offset: false,
2098            min_binding_size: None,
2099        },
2100        count: None,
2101    }
2102}
2103
2104fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
2105    wgpu::BindGroupLayoutEntry {
2106        binding,
2107        visibility: wgpu::ShaderStages::COMPUTE,
2108        ty: wgpu::BindingType::Buffer {
2109            ty: wgpu::BufferBindingType::Uniform,
2110            has_dynamic_offset: false,
2111            min_binding_size: None,
2112        },
2113        count: None,
2114    }
2115}