Skip to main content

entrenar/finetune/
wgpu_pipeline.rs

1//! WgpuInstructPipeline — GPU-only training pipeline (§26.11.7).
2//! Bypasses entrenar's `Transformer` (20-min CPU dequant of 28 GB Q4K→F32).
3//! Uses `WgslForwardPass` with pre-uploaded GPU weights from `dequant_model_weights()`.
4//! No `Transformer` object. No CPU F32 projection weights. No SATD.
5//!
6//! # Architecture (Unsloth pattern)
7//!
8//! ```text
9//! OwnedQuantizedModel (Q4K, seconds to load)
10//!   → dequant_model_weights() (streaming, ~2 min)
11//!   → WgslForwardPass.upload_weight() (persistent GPU buffers)
12//!   → WgpuInstructPipeline.train_step() (all GPU)
13//! ```
14//!
15//! # Contract: wgsl-training-pipeline-v1
16//!
17//! - `fast_load`: load_time < 5 min on GB10
18//! - `no_transformer`: does not construct Transformer
19
20#[cfg(feature = "gpu")]
21use crate::{
22    autograd::{wgpu_cross_entropy::WgslCrossEntropy, wgpu_training::WgpuTrainer},
23    finetune::instruct_pipeline::InstructStepResult,
24    tokenizer::HfTokenizer,
25};
26#[cfg(feature = "gpu")]
27use trueno::backends::gpu::{wgpu, WgslForwardPass};
28
/// LoRA adapters for one transformer layer.
///
/// `WgpuInstructPipeline::new` fills this with one [`LoRAProjection`] for each of the seven
/// projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`),
/// whether or not that projection is in the trainable target set.
#[cfg(feature = "gpu")]
pub struct LayerLoRA {
    /// Per-projection adapter weights and AdamW state, in allocation order.
    pub projections: Vec<LoRAProjection>,
}
35
/// One GPU-resident LoRA adapter plus its AdamW optimizer state.
///
/// The adapter contributes `(x @ a) @ b * scale` to the output of the projection it wraps.
#[cfg(feature = "gpu")]
pub struct LoRAProjection {
    /// Down-projection `A`, shape `[in_dim, rank]`.
    pub a: wgpu::Buffer,
    /// Up-projection `B`, shape `[rank, out_dim]`.
    pub b: wgpu::Buffer,
    /// AdamW first moment for `A`.
    pub m_a: wgpu::Buffer,
    /// AdamW second moment for `A`.
    pub v_a: wgpu::Buffer,
    /// AdamW first moment for `B`.
    pub m_b: wgpu::Buffer,
    /// AdamW second moment for `B`.
    pub v_b: wgpu::Buffer,
    /// Input feature dimension of the wrapped projection.
    pub in_dim: u32,
    /// Output feature dimension of the wrapped projection.
    pub out_dim: u32,
    /// Projection name, e.g. `"q_proj"`; compared against the trainable target set.
    pub name: String,
}
48
/// GPU-only instruct training pipeline. No `Transformer` object.
///
/// Owns the forward pass, loss, optimizer, LoRA adapters and every per-step scratch buffer;
/// all GPU work is issued through `trainer`'s device and queue.
#[cfg(feature = "gpu")]
pub struct WgpuInstructPipeline {
    /// GPU forward pass with persistent weight buffers.
    fwd: WgslForwardPass,
    /// Fused cross-entropy loss on GPU.
    cross_entropy: WgslCrossEntropy,
    /// GPU optimizer + backward GEMM; also provides the device/queue used for all dispatches.
    trainer: WgpuTrainer,
    /// Pre-uploaded lm_head chunks `(chunk_buf, chunk_n)` for the forward pass.
    /// Pre-chunked to avoid a per-step download.
    lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
    /// Pre-uploaded lm_head chunks `(chunk_buf, chunk_n)` for the backward pass.
    lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
    /// LoRA addmm pipeline: `output += (input @ A) @ B * scale`.
    lora_addmm_pipeline: wgpu::ComputePipeline,
    /// Layout for `lora_addmm_pipeline`: input, A, B (read-only storage), output
    /// (read-write storage), params (uniform).
    lora_addmm_bgl: wgpu::BindGroupLayout,
    /// KAIZEN: column scatter `[seq, chunk_n] → [seq, full_n]` in one dispatch.
    scatter_pipeline: wgpu::ComputePipeline,
    /// KAIZEN: column gather `[seq, full_n] → [seq, chunk_n]` in one dispatch.
    gather_pipeline: wgpu::ComputePipeline,
    /// Layout shared by the scatter and gather pipelines: src (read-only), dst (read-write),
    /// params (uniform).
    scatter_bgl: wgpu::BindGroupLayout,
    /// Scaled transpose: `dst[j, i] = scale * src[i, j]`.
    transpose_pipeline: wgpu::ComputePipeline,
    /// Layout for `transpose_pipeline`: src (read-only), dst (read-write), params (uniform).
    transpose_bgl: wgpu::BindGroupLayout,
    /// Logits scratch, `max_seq_len × vocab_size` 4-byte elements.
    logits_buf: wgpu::Buffer,
    /// Target labels, one 4-byte entry per position.
    labels_buf: wgpu::Buffer,
    /// Per-position loss, one 4-byte entry per position.
    losses_buf: wgpu::Buffer,
    /// Per-position logsumexp, one 4-byte entry per position.
    logsumexp_buf: wgpu::Buffer,
    /// LoRA adapters per layer: 7 projections × (A, B, m_A, v_A, m_B, v_B).
    /// A: `[in_dim, rank]`, B: `[rank, out_dim]`, m/v: optimizer states.
    lora: Vec<LayerLoRA>,
    /// LoRA rank `r`.
    lora_rank: usize,
    /// LoRA scale, `alpha / rank`.
    lora_scale: f32,
    /// Optimizer step counter.
    lora_step: u32,
    /// Learning rate (PMAT-497: was hardcoded, now from config).
    learning_rate: f32,
    /// Names of the projections that are trained (and exported).
    lora_target_set: Vec<String>,
    /// Number of transformer layers.
    num_layers: usize,
    /// Model hidden size.
    hidden_dim: usize,
    /// Vocabulary size.
    vocab_size: usize,
    /// Maximum sequence length the scratch buffers are sized for.
    max_seq_len: usize,
    /// Tokenizer.
    tokenizer: HfTokenizer,
    /// Embedding table kept on the CPU for token lookup (vocab × hidden × 4 bytes).
    embed_weights: Vec<f32>,
    /// Output norm weights — GPU-resident (contract: gpu-output-norm-v1).
    output_norm_gpu: wgpu::Buffer,
    /// Normed hidden state — GPU-resident, `max_seq_len × hidden_dim` 4-byte elements.
    normed_buf: wgpu::Buffer,
    /// RMSNorm epsilon.
    eps: f32,
}
99
100#[cfg(feature = "gpu")]
101impl WgpuInstructPipeline {
102    /// Create from pre-uploaded WgslForwardPass.
103    ///
104    /// Caller is responsible for:
105    /// 1. Loading OwnedQuantizedModel
106    /// 2. Calling dequant_model_weights()
107    /// 3. Uploading weights to WgslForwardPass
108    /// 4. Uploading lm_head to GPU buffers
109    ///
110    /// This constructor does NOT touch Transformer or from_apr().
111    /// Contract: wgsl-training-pipeline-v1 / no_transformer
112    pub fn new(
113        fwd: WgslForwardPass,
114        trainer: WgpuTrainer,
115        tokenizer: HfTokenizer,
116        embed_weights: Vec<f32>,
117        output_norm: Vec<f32>,
118        lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
119        lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
120        num_layers: usize,
121        hidden_dim: usize,
122        vocab_size: usize,
123        max_seq_len: usize,
124        num_heads: usize,
125        num_kv_heads: usize,
126        intermediate_dim: usize,
127        lora_rank: usize,
128        lora_alpha: f32,
129        lora_targets: &[&str],
130        eps: f32,
131        learning_rate: f32,
132    ) -> Self {
133        let ce = WgslCrossEntropy::new(trainer.device_ref().clone(), trainer.queue_ref().clone());
134
135        let seq = max_seq_len as u32;
136        let vocab = vocab_size as u32;
137        let make_buf = |size: u64, label: &str| -> wgpu::Buffer {
138            trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
139                label: Some(label),
140                size: size * 4,
141                usage: wgpu::BufferUsages::STORAGE
142                    | wgpu::BufferUsages::COPY_SRC
143                    | wgpu::BufferUsages::COPY_DST,
144                mapped_at_creation: false,
145            })
146        };
147
148        // KAIZEN: scatter/gather pipelines (one dispatch replaces 1024 copies)
149        let scatter_bgl =
150            trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
151                label: Some("scatter_bgl"),
152                entries: &[
153                    wgpu::BindGroupLayoutEntry {
154                        binding: 0,
155                        visibility: wgpu::ShaderStages::COMPUTE,
156                        ty: wgpu::BindingType::Buffer {
157                            ty: wgpu::BufferBindingType::Storage { read_only: true },
158                            has_dynamic_offset: false,
159                            min_binding_size: None,
160                        },
161                        count: None,
162                    },
163                    wgpu::BindGroupLayoutEntry {
164                        binding: 1,
165                        visibility: wgpu::ShaderStages::COMPUTE,
166                        ty: wgpu::BindingType::Buffer {
167                            ty: wgpu::BufferBindingType::Storage { read_only: false },
168                            has_dynamic_offset: false,
169                            min_binding_size: None,
170                        },
171                        count: None,
172                    },
173                    wgpu::BindGroupLayoutEntry {
174                        binding: 2,
175                        visibility: wgpu::ShaderStages::COMPUTE,
176                        ty: wgpu::BindingType::Buffer {
177                            ty: wgpu::BufferBindingType::Uniform,
178                            has_dynamic_offset: false,
179                            min_binding_size: None,
180                        },
181                        count: None,
182                    },
183                ],
184            });
185        let scatter_pl =
186            trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
187                label: Some("scatter_pl"),
188                bind_group_layouts: &[&scatter_bgl],
189                push_constant_ranges: &[],
190            });
191        let scatter_shader =
192            trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
193                label: Some("scatter"),
194                source: wgpu::ShaderSource::Wgsl(
195                    trueno::backends::gpu::shaders::COLUMN_SCATTER_SHADER.into(),
196                ),
197            });
198        let scatter_pipeline =
199            trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
200                label: Some("scatter_pipe"),
201                layout: Some(&scatter_pl),
202                module: &scatter_shader,
203                entry_point: Some("main"),
204                compilation_options: Default::default(),
205                cache: None,
206            });
207        let gather_shader =
208            trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
209                label: Some("gather"),
210                source: wgpu::ShaderSource::Wgsl(
211                    trueno::backends::gpu::shaders::COLUMN_GATHER_SHADER.into(),
212                ),
213            });
214        let gather_pipeline =
215            trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
216                label: Some("gather_pipe"),
217                layout: Some(&scatter_pl),
218                module: &gather_shader,
219                entry_point: Some("main"),
220                compilation_options: Default::default(),
221                cache: None,
222            });
223
224        // Transpose pipeline (same BGL as scatter: src read, dst read-write, params uniform)
225        let transpose_bgl =
226            trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
227                label: Some("transpose_bgl"),
228                entries: &[
229                    wgpu::BindGroupLayoutEntry {
230                        binding: 0,
231                        visibility: wgpu::ShaderStages::COMPUTE,
232                        ty: wgpu::BindingType::Buffer {
233                            ty: wgpu::BufferBindingType::Storage { read_only: true },
234                            has_dynamic_offset: false,
235                            min_binding_size: None,
236                        },
237                        count: None,
238                    },
239                    wgpu::BindGroupLayoutEntry {
240                        binding: 1,
241                        visibility: wgpu::ShaderStages::COMPUTE,
242                        ty: wgpu::BindingType::Buffer {
243                            ty: wgpu::BufferBindingType::Storage { read_only: false },
244                            has_dynamic_offset: false,
245                            min_binding_size: None,
246                        },
247                        count: None,
248                    },
249                    wgpu::BindGroupLayoutEntry {
250                        binding: 2,
251                        visibility: wgpu::ShaderStages::COMPUTE,
252                        ty: wgpu::BindingType::Buffer {
253                            ty: wgpu::BufferBindingType::Uniform,
254                            has_dynamic_offset: false,
255                            min_binding_size: None,
256                        },
257                        count: None,
258                    },
259                ],
260            });
261        let transpose_pl =
262            trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
263                label: Some("transpose_pl"),
264                bind_group_layouts: &[&transpose_bgl],
265                push_constant_ranges: &[],
266            });
267        let transpose_shader =
268            trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
269                label: Some("transpose"),
270                source: wgpu::ShaderSource::Wgsl(
271                    trueno::backends::gpu::shaders::TRANSPOSE_SHADER.into(),
272                ),
273            });
274        let transpose_pipeline =
275            trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
276                label: Some("transpose_pipe"),
277                layout: Some(&transpose_pl),
278                module: &transpose_shader,
279                entry_point: Some("main"),
280                compilation_options: Default::default(),
281                cache: None,
282            });
283
284        // LoRA addmm pipeline: output += (input @ A) @ B * scale
285        let lora_bgl =
286            trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
287                label: Some("lora_bgl"),
288                entries: &[
289                    wgpu::BindGroupLayoutEntry {
290                        binding: 0,
291                        visibility: wgpu::ShaderStages::COMPUTE,
292                        ty: wgpu::BindingType::Buffer {
293                            ty: wgpu::BufferBindingType::Storage { read_only: true },
294                            has_dynamic_offset: false,
295                            min_binding_size: None,
296                        },
297                        count: None,
298                    },
299                    wgpu::BindGroupLayoutEntry {
300                        binding: 1,
301                        visibility: wgpu::ShaderStages::COMPUTE,
302                        ty: wgpu::BindingType::Buffer {
303                            ty: wgpu::BufferBindingType::Storage { read_only: true },
304                            has_dynamic_offset: false,
305                            min_binding_size: None,
306                        },
307                        count: None,
308                    },
309                    wgpu::BindGroupLayoutEntry {
310                        binding: 2,
311                        visibility: wgpu::ShaderStages::COMPUTE,
312                        ty: wgpu::BindingType::Buffer {
313                            ty: wgpu::BufferBindingType::Storage { read_only: true },
314                            has_dynamic_offset: false,
315                            min_binding_size: None,
316                        },
317                        count: None,
318                    },
319                    wgpu::BindGroupLayoutEntry {
320                        binding: 3,
321                        visibility: wgpu::ShaderStages::COMPUTE,
322                        ty: wgpu::BindingType::Buffer {
323                            ty: wgpu::BufferBindingType::Storage { read_only: false },
324                            has_dynamic_offset: false,
325                            min_binding_size: None,
326                        },
327                        count: None,
328                    },
329                    wgpu::BindGroupLayoutEntry {
330                        binding: 4,
331                        visibility: wgpu::ShaderStages::COMPUTE,
332                        ty: wgpu::BindingType::Buffer {
333                            ty: wgpu::BufferBindingType::Uniform,
334                            has_dynamic_offset: false,
335                            min_binding_size: None,
336                        },
337                        count: None,
338                    },
339                ],
340            });
341        let lora_pl =
342            trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
343                label: Some("lora_pl"),
344                bind_group_layouts: &[&lora_bgl],
345                push_constant_ranges: &[],
346            });
347        let lora_shader = trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
348            label: Some("lora_addmm"),
349            source: wgpu::ShaderSource::Wgsl(
350                trueno::backends::gpu::shaders::LORA_ADDMM_SHADER.into(),
351            ),
352        });
353        let lora_addmm_pipeline =
354            trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
355                label: Some("lora_addmm_pipe"),
356                layout: Some(&lora_pl),
357                module: &lora_shader,
358                entry_point: Some("main"),
359                compilation_options: Default::default(),
360                cache: None,
361            });
362        let lora_addmm_bgl = lora_bgl;
363
364        // Contract: lora-algebra-v1/lora_shape — A[in,rank], B[rank,out]
365        // Contract: lora-gradient-flow-v1 — B initialized to zero, A Kaiming
366        let r = lora_rank;
367        let scale = lora_alpha / r as f32;
368        let h = hidden_dim;
369        let q_dim = num_heads * (hidden_dim / num_heads);
370        let kv_dim = num_kv_heads * (hidden_dim / num_heads);
371        let inter = intermediate_dim;
372
373        let all_proj_dims: &[(&str, usize, usize)] = &[
374            ("q_proj", h, q_dim),
375            ("k_proj", h, kv_dim),
376            ("v_proj", h, kv_dim),
377            ("o_proj", q_dim, h),
378            ("gate_proj", h, inter),
379            ("up_proj", h, inter),
380            ("down_proj", inter, h),
381        ];
382
383        // PMAT-497 FIX: Forward pass only applies LoRA to Q/K/V via QkvLoRA.
384        // Backward MUST match — training o/gate/up/down produces spurious gradients
385        // that corrupt convergence (loss > random from step 1).
386        // When full 7-projection LoRA forward is implemented, this can be "all".
387        let _use_all = true; // Always create all 7 for QkvLoRA forward
388        let proj_dims: Vec<(&str, usize, usize)> = all_proj_dims.to_vec();
389        let num_targets = proj_dims.len();
390        let qkv_only = vec!["q_proj".to_string(), "k_proj".to_string(), "v_proj".to_string()];
391        let lora_target_set: Vec<String> =
392            if lora_targets.is_empty() || lora_targets.contains(&"all") {
393                // Forward only applies QkvLoRA (Q/K/V) — backward must match
394                qkv_only
395            } else {
396                lora_targets.iter().map(std::string::ToString::to_string).collect()
397            };
398
399        let mut lora = Vec::with_capacity(num_layers);
400        for layer_idx in 0..num_layers {
401            let mut projections = Vec::with_capacity(num_targets);
402            for &(name, in_d, out_d) in &proj_dims {
403                // Kaiming init for A: std = sqrt(2/fan_in)
404                let std = (2.0 / in_d as f32).sqrt();
405                let a_data: Vec<f32> = (0..in_d * r)
406                    .map(|i| ((i as f32 * 0.013 + layer_idx as f32 * 7.0).sin() * std))
407                    .collect();
408                // Zero init for B (contract: lora-gradient-flow-v1)
409                let b_data = vec![0.0f32; r * out_d];
410                let zeros_a = vec![0.0f32; in_d * r];
411                let zeros_b = vec![0.0f32; r * out_d];
412
413                projections.push(LoRAProjection {
414                    a: trainer.upload(&a_data),
415                    b: trainer.upload(&b_data),
416                    m_a: trainer.upload(&zeros_a),
417                    v_a: trainer.upload(&zeros_a),
418                    m_b: trainer.upload(&zeros_b),
419                    v_b: trainer.upload(&zeros_b),
420                    in_dim: in_d as u32,
421                    out_dim: out_d as u32,
422                    name: name.to_string(),
423                });
424            }
425            lora.push(LayerLoRA { projections });
426        }
427
428        eprintln!(
429            "[wgpu] LoRA initialized: {num_layers} layers × {num_targets} projections, rank={r}, scale={scale:.2}",
430        );
431
432        // Pre-allocate buffers before moving trainer into Self
433        let logits_buf = make_buf(u64::from(seq) * u64::from(vocab), "logits");
434        let labels_buf = make_buf(u64::from(seq), "labels");
435        let losses_buf = make_buf(u64::from(seq), "losses");
436        let logsumexp_buf = make_buf(u64::from(seq), "logsumexp");
437        let normed_buf_alloc = make_buf(u64::from(seq) * hidden_dim as u64, "normed");
438        let output_norm_gpu_buf = trainer.upload(&output_norm);
439
440        Self {
441            fwd,
442            cross_entropy: ce,
443            logits_buf,
444            labels_buf,
445            losses_buf,
446            logsumexp_buf,
447            lm_head_t_chunks,
448            lm_head_chunks,
449            lora_addmm_pipeline,
450            lora_addmm_bgl,
451            scatter_pipeline,
452            gather_pipeline,
453            transpose_pipeline,
454            transpose_bgl,
455            scatter_bgl,
456            trainer,
457            lora,
458            lora_rank: r,
459            lora_scale: scale,
460            lora_step: 0,
461            learning_rate,
462            lora_target_set: lora_target_set,
463            num_layers,
464            hidden_dim,
465            vocab_size,
466            max_seq_len,
467            tokenizer,
468            embed_weights,
469            output_norm_gpu: output_norm_gpu_buf,
470            normed_buf: normed_buf_alloc,
471            eps,
472        }
473    }
474
475    /// LoRA addmm: output += (input @ A) @ B * scale. One GPU dispatch.
476    /// Contract: lora-algebra-v1/lora_shape
477    fn dispatch_lora_addmm(
478        &self,
479        input: &wgpu::Buffer,
480        lora_a: &wgpu::Buffer,
481        lora_b: &wgpu::Buffer,
482        output: &wgpu::Buffer,
483        seq_len: u32,
484        in_dim: u32,
485        rank: u32,
486        out_dim: u32,
487    ) {
488        #[repr(C)]
489        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
490        struct P {
491            seq_len: u32,
492            in_dim: u32,
493            rank: u32,
494            out_dim: u32,
495            scale: f32,
496            _p0: u32,
497            _p1: u32,
498            _p2: u32,
499        }
500        let params =
501            P { seq_len, in_dim, rank, out_dim, scale: self.lora_scale, _p0: 0, _p1: 0, _p2: 0 };
502        let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
503            label: None,
504            size: 32,
505            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
506            mapped_at_creation: false,
507        });
508        self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(&params));
509        let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
510            label: None,
511            layout: &self.lora_addmm_bgl,
512            entries: &[
513                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
514                wgpu::BindGroupEntry { binding: 1, resource: lora_a.as_entire_binding() },
515                wgpu::BindGroupEntry { binding: 2, resource: lora_b.as_entire_binding() },
516                wgpu::BindGroupEntry { binding: 3, resource: output.as_entire_binding() },
517                wgpu::BindGroupEntry { binding: 4, resource: pbuf.as_entire_binding() },
518            ],
519        });
520        let total = seq_len * out_dim;
521        let wg = total.div_ceil(256);
522        let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
523        let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
524        {
525            let mut pass = encoder.begin_compute_pass(&Default::default());
526            pass.set_pipeline(&self.lora_addmm_pipeline);
527            pass.set_bind_group(0, &bg, &[]);
528            pass.dispatch_workgroups(x, y, 1);
529        }
530        self.trainer.queue_ref().submit(Some(encoder.finish()));
531    }
532
533    /// GPU scatter: copy [seq, chunk_n] into [seq, full_n] at col_offset. One dispatch.
534    fn dispatch_scatter(
535        &self,
536        src: &wgpu::Buffer,
537        dst: &wgpu::Buffer,
538        seq_len: u32,
539        chunk_n: u32,
540        full_n: u32,
541        col_offset: u32,
542    ) {
543        #[repr(C)]
544        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
545        struct P {
546            seq_len: u32,
547            chunk_n: u32,
548            full_n: u32,
549            col_offset: u32,
550        }
551        let params = P { seq_len, chunk_n, full_n, col_offset };
552        let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
553            label: None,
554            size: 16,
555            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
556            mapped_at_creation: false,
557        });
558        self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(&params));
559        let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
560            label: None,
561            layout: &self.scatter_bgl,
562            entries: &[
563                wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
564                wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
565                wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
566            ],
567        });
568        let total = seq_len * chunk_n;
569        let wg = total.div_ceil(256);
570        let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
571        let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
572        {
573            let mut pass = encoder.begin_compute_pass(&Default::default());
574            pass.set_pipeline(&self.scatter_pipeline);
575            pass.set_bind_group(0, &bg, &[]);
576            pass.dispatch_workgroups(x, y, 1);
577        }
578        self.trainer.queue_ref().submit(Some(encoder.finish()));
579    }
580
581    /// GPU gather: extract [seq, chunk_n] from [seq, full_n] at col_offset. One dispatch.
582    fn dispatch_gather(
583        &self,
584        src: &wgpu::Buffer,
585        dst: &wgpu::Buffer,
586        seq_len: u32,
587        chunk_n: u32,
588        full_n: u32,
589        col_offset: u32,
590    ) {
591        #[repr(C)]
592        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
593        struct P {
594            seq_len: u32,
595            chunk_n: u32,
596            full_n: u32,
597            col_offset: u32,
598        }
599        let params = P { seq_len, chunk_n, full_n, col_offset };
600        let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
601            label: None,
602            size: 16,
603            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
604            mapped_at_creation: false,
605        });
606        self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(&params));
607        let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
608            label: None,
609            layout: &self.scatter_bgl,
610            entries: &[
611                wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
612                wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
613                wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
614            ],
615        });
616        let total = seq_len * chunk_n;
617        let wg = total.div_ceil(256);
618        let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
619        let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
620        {
621            let mut pass = encoder.begin_compute_pass(&Default::default());
622            pass.set_pipeline(&self.gather_pipeline);
623            pass.set_bind_group(0, &bg, &[]);
624            pass.dispatch_workgroups(x, y, 1);
625        }
626        self.trainer.queue_ref().submit(Some(encoder.finish()));
627    }
628
629    /// Encode text to token IDs using the tokenizer.
630    /// GPU scaled transpose: dst[j,i] = scale * src[i,j]
631    fn dispatch_transpose(
632        &self,
633        src: &wgpu::Buffer,
634        dst: &wgpu::Buffer,
635        m: u32,
636        n: u32,
637        scale: f32,
638    ) {
639        #[repr(C)]
640        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
641        struct P {
642            m: u32,
643            n: u32,
644            scale: f32,
645            _pad: u32,
646        }
647        let params = P { m, n, scale, _pad: 0 };
648        let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
649            label: None,
650            size: 16,
651            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
652            mapped_at_creation: false,
653        });
654        self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(&params));
655        let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
656            label: None,
657            layout: &self.transpose_bgl,
658            entries: &[
659                wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
660                wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
661                wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
662            ],
663        });
664        let total = m * n;
665        let wg = total.div_ceil(256);
666        let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
667        let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
668        {
669            let mut pass = encoder.begin_compute_pass(&Default::default());
670            pass.set_pipeline(&self.transpose_pipeline);
671            pass.set_bind_group(0, &bg, &[]);
672            pass.dispatch_workgroups(x, y, 1);
673        }
674        self.trainer.queue_ref().submit(Some(encoder.finish()));
675    }
676
677    pub fn encode(&self, text: &str) -> Vec<u32> {
678        self.tokenizer.encode(text)
679    }
680
681    /// Export trained LoRA adapter as safetensors file.
682    /// Downloads all A/B weights from GPU and saves with naming convention
683    /// matching `apr finetune --merge` expectations: `layer.{i}.{proj}.lora_a/b`.
684    pub fn export_adapter(
685        &self,
686        output_path: &std::path::Path,
687        lora_alpha: f32,
688    ) -> Result<(), String> {
689        use safetensors::tensor::{serialize_to_file, Dtype, TensorView};
690
691        // Collect all tensors
692        let mut tensors: Vec<(String, Vec<f32>, Vec<usize>)> = Vec::new();
693
694        for (layer_idx, layer_lora) in self.lora.iter().enumerate() {
695            for proj in &layer_lora.projections {
696                if !self.lora_target_set.iter().any(|t| t == &proj.name) {
697                    continue;
698                }
699                let a = self.trainer.download(&proj.a);
700                let b = self.trainer.download(&proj.b);
701                // Naming: layer.{i}.{proj_name}.lora_a  (matches apr merge convention)
702                let base = format!("layer.{layer_idx}.{}", proj.name);
703                tensors.push((
704                    format!("{base}.lora_a"),
705                    a,
706                    vec![proj.in_dim as usize, self.lora_rank],
707                ));
708                tensors.push((
709                    format!("{base}.lora_b"),
710                    b,
711                    vec![self.lora_rank, proj.out_dim as usize],
712                ));
713            }
714        }
715
716        // Write safetensors file
717        let byte_tensors: Vec<(String, Vec<u8>, Vec<usize>)> = tensors
718            .into_iter()
719            .map(|(name, data, shape)| (name, bytemuck::cast_slice(&data).to_vec(), shape))
720            .collect();
721
722        let views: Vec<(&str, TensorView<'_>)> = byte_tensors
723            .iter()
724            .map(|(name, bytes, shape)| {
725                let view =
726                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid F32 tensor");
727                (name.as_str(), view)
728            })
729            .collect();
730
731        // Create output directory if needed
732        if let Some(parent) = output_path.parent() {
733            std::fs::create_dir_all(parent).map_err(|e| format!("mkdir: {e}"))?;
734        }
735
736        // Save as .safetensors file (with metadata for rank/alpha)
737        let st_path = if output_path.extension().is_some() {
738            output_path.to_path_buf()
739        } else {
740            output_path.join("adapter.safetensors")
741        };
742
743        let metadata: Option<std::collections::HashMap<String, String>> =
744            Some(std::collections::HashMap::from([
745                ("lora_rank".to_string(), self.lora_rank.to_string()),
746                ("lora_alpha".to_string(), lora_alpha.to_string()),
747            ]));
748        serialize_to_file(views, metadata, &st_path)
749            .map_err(|e| format!("safetensors write: {e}"))?;
750
751        eprintln!(
752            "[wgpu] {} LoRA tensors saved ({} layers × 7 projections × A/B)",
753            byte_tensors.len(),
754            self.num_layers
755        );
756        Ok(())
757    }
758
759    /// Training step: forward → loss → backward → optimizer. All GPU.
760    ///
761    /// Contract: qlora-training-loop-v1 / lora_forward_wgsl
762    pub fn train_step(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> InstructStepResult {
763        let t0 = std::time::Instant::now();
764
765        let full_ids: Vec<u32> = prompt_ids.iter().chain(response_ids).copied().collect();
766        let seq_len = full_ids.len().min(self.max_seq_len);
767        let full_ids = &full_ids[..seq_len];
768        let prompt_len = prompt_ids.len().min(seq_len);
769
770        let loss_start = prompt_len.saturating_sub(1);
771        let loss_end = seq_len - 1;
772        let num_loss_tokens = loss_end.saturating_sub(loss_start);
773
774        if num_loss_tokens == 0 {
775            return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
776        }
777
778        // 1. Embed tokens (CPU lookup, small: seq × hidden)
779        let mut hidden = Vec::with_capacity(seq_len * self.hidden_dim);
780        for &tok in full_ids {
781            let offset = (tok as usize) * self.hidden_dim;
782            let end = offset + self.hidden_dim;
783            if end <= self.embed_weights.len() {
784                hidden.extend_from_slice(&self.embed_weights[offset..end]);
785            } else {
786                hidden.extend(std::iter::repeat_n(0.0f32, self.hidden_dim));
787            }
788        }
789
790        let t1 = std::time::Instant::now();
791
792        // PMAT-509: Diagnostic — check embedding norm on first step
793        if self.lora_step == 0 {
794            let h_norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
795            let h_mean: f32 = hidden.iter().sum::<f32>() / hidden.len() as f32;
796            eprintln!(
797                "[DIAG-509] embed: norm={h_norm:.4}, mean={h_mean:.6}, len={}, seq={seq_len}",
798                hidden.len()
799            );
800            // Dump first 5 values of token 264's embedding (compare vs PyTorch)
801            let tok264_offset = 264 * self.hidden_dim;
802            if tok264_offset + 5 < self.embed_weights.len() {
803                let tok264: Vec<f32> =
804                    self.embed_weights[tok264_offset..tok264_offset + 5].to_vec();
805                eprintln!("[DIAG-509] embed[264,:5]={tok264:?} (PyTorch: [-0.0295, 0.0035, 0.0193, 0.0020, 0.0049])");
806            }
807        }
808
809        // 2. GPU forward through 28 transformer layers with LoRA contribution
810        // Contract: lora-algebra-v1/lora_shape — h = W_base @ x + (x @ A) @ B * scale
811        // Per-layer forward: base GEMM (via WgslForwardPass) + LoRA addmm (via pipeline shader)
812        self.fwd.queue_ref().write_buffer(
813            self.fwd.hidden_buffer(),
814            0,
815            bytemuck::cast_slice(&hidden),
816        );
817
818        let mut _saved_activations = Vec::with_capacity(self.num_layers);
819        for layer_idx in 0..self.num_layers {
820            let prefix = format!("layer.{layer_idx}");
821            // Build QkvLoRA for this layer's Q/K/V projections
822            let qkv_lora = if layer_idx < self.lora.len() {
823                let lp = &self.lora[layer_idx].projections;
824                // Find Q, K, V projections by name
825                let q = lp.iter().find(|p| p.name == "q_proj");
826                let k = lp.iter().find(|p| p.name == "k_proj");
827                let v = lp.iter().find(|p| p.name == "v_proj");
828                match (q, k, v) {
829                    (Some(qp), Some(kp), Some(vp)) => Some(trueno::backends::gpu::QkvLoRA {
830                        q_a: &qp.a,
831                        q_b: &qp.b,
832                        k_a: &kp.a,
833                        k_b: &kp.b,
834                        v_a: &vp.a,
835                        v_b: &vp.b,
836                        rank: self.lora_rank as u32,
837                        scale: self.lora_scale,
838                        in_dim: qp.in_dim,
839                        q_dim: qp.out_dim,
840                        kv_dim: kp.out_dim,
841                        lora_pipeline: &self.lora_addmm_pipeline,
842                        lora_bgl: &self.lora_addmm_bgl,
843                    }),
844                    _ => None,
845                }
846            } else {
847                None
848            };
849
850            // Forward with inline LoRA on Q/K/V (before attention consumes them)
851            let saved = self.fwd.alloc_layer_activations(seq_len as u32);
852
853            // Per-operation tracing on layer 0 of first step
854            if self.lora_step == 0 && layer_idx == 0 {
855                if let Err(e) = self.fwd.forward_layer_traced(
856                    seq_len as u32,
857                    &prefix,
858                    &saved,
859                    qkv_lora.as_ref(),
860                ) {
861                    eprintln!("[wgpu] traced forward failed: {e}");
862                }
863            } else {
864                let mut encoder = self.fwd.device_ref().create_command_encoder(&Default::default());
865                if let Err(e) = self.fwd.encode_forward_layer_training(
866                    &mut encoder,
867                    seq_len as u32,
868                    &prefix,
869                    &saved,
870                    qkv_lora.as_ref(),
871                ) {
872                    eprintln!("[wgpu] GPU forward layer {layer_idx} failed: {e}");
873                    return InstructStepResult {
874                        loss: 100.0,
875                        num_response_tokens: num_loss_tokens,
876                        perplexity: 1e6,
877                    };
878                }
879                self.fwd.queue_ref().submit(Some(encoder.finish()));
880            }
881            _saved_activations.push(saved);
882
883            // LoRA addmm for Q/K/V now happens INLINE in encode_forward_layer_training
884            // (before attention consumes Q/K/V buffers)
885
886            // PMAT-509: Diagnostic — check hidden after layers 0, 1, 27
887            if self.lora_step == 0
888                && (layer_idx == 0 || layer_idx == 1 || layer_idx == self.num_layers - 1)
889            {
890                let n_floats = seq_len * self.hidden_dim;
891                let h = self.fwd.download_hidden(n_floats);
892                let norm: f32 = h.iter().map(|x| x * x).sum::<f32>().sqrt();
893                let nan_c = h.iter().filter(|x| x.is_nan()).count();
894                let first5: Vec<f32> = h.iter().take(5).copied().collect();
895                eprintln!("[DIAG-509] after layer {layer_idx}: norm={norm:.4}, nan={nan_c}, first5={first5:?}");
896            }
897        }
898
899        let t2 = std::time::Instant::now();
900
901        // PMAT-509: Diagnostic — check hidden state after all layers on first step
902        if self.lora_step == 0 {
903            let n_floats = seq_len * self.hidden_dim;
904            let h_data = self.fwd.download_hidden(n_floats);
905            let h_norm: f32 = h_data.iter().map(|x| x * x).sum::<f32>().sqrt();
906            let h_mean: f32 = h_data.iter().sum::<f32>() / h_data.len() as f32;
907            let nan_count = h_data.iter().filter(|x| x.is_nan()).count();
908            let inf_count = h_data.iter().filter(|x| x.is_infinite()).count();
909            let first5: Vec<f32> = h_data.iter().take(5).copied().collect();
910            eprintln!("[DIAG-509] post-layers: norm={h_norm:.4}, mean={h_mean:.6}, nan={nan_count}, inf={inf_count}, first5={first5:?}");
911        }
912
913        // 3. GPU RMSNorm + lm_head — hidden stays on GPU (contract: gpu-output-norm-v1)
914        let _t2a = std::time::Instant::now();
915        self.fwd.gpu_rmsnorm(&self.output_norm_gpu, &self.normed_buf, seq_len as u32);
916        let t2b = std::time::Instant::now();
917        let _t2c = t2b;
918        // lm_head: chunked GEMM + GPU scatter
919        let labels: Vec<u32> = (0..seq_len)
920            .map(|i| if i + 1 < full_ids.len() { full_ids[i + 1] } else { 0 })
921            .collect();
922
923        let mut col_offset = 0u64;
924        for (chunk_buf, chunk_n) in &self.lm_head_t_chunks {
925            let cn = u64::from(*chunk_n);
926            let c_chunk = self.trainer.zeros((seq_len as u64 * cn) as usize);
927            self.trainer.matmul_forward(
928                &self.normed_buf,
929                chunk_buf,
930                &c_chunk,
931                seq_len as u32,
932                self.hidden_dim as u32,
933                *chunk_n,
934            );
935            // GPU scatter: one dispatch replaces 512 copy_buffer_to_buffer calls
936            self.dispatch_scatter(
937                &c_chunk,
938                &self.logits_buf,
939                seq_len as u32,
940                *chunk_n,
941                self.vocab_size as u32,
942                col_offset as u32,
943            );
944            col_offset += cn;
945        }
946
947        let t3 = std::time::Instant::now();
948
949        // PMAT-509: Diagnostic — check logits after lm_head on first step
950        if self.lora_step == 0 {
951            let logits_data = self.trainer.download(&self.logits_buf);
952            let l_norm: f32 =
953                logits_data.iter().take(self.vocab_size).map(|x| x * x).sum::<f32>().sqrt();
954            let l_max =
955                logits_data.iter().take(self.vocab_size).cloned().fold(f32::NEG_INFINITY, f32::max);
956            let l_min =
957                logits_data.iter().take(self.vocab_size).cloned().fold(f32::INFINITY, f32::min);
958            let nan_count = logits_data.iter().take(self.vocab_size).filter(|x| x.is_nan()).count();
959            // Check if logits are all zero (would indicate lm_head failed)
960            let zero_count =
961                logits_data.iter().take(self.vocab_size).filter(|x| **x == 0.0).count();
962            eprintln!("[DIAG-509] logits[0]: norm={l_norm:.4}, min={l_min:.4}, max={l_max:.4}, nan={nan_count}, zeros={zero_count}/{}", self.vocab_size);
963            // Check argmax of first position
964            let argmax = logits_data
965                .iter()
966                .take(self.vocab_size)
967                .enumerate()
968                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
969                .map(|(i, v)| (i, *v));
970            let target = labels[0];
971            let target_logit = if (target as usize) < self.vocab_size {
972                logits_data[target as usize]
973            } else {
974                f32::NAN
975            };
976            eprintln!("[DIAG-509] pos0: argmax={argmax:?}, target={target}, target_logit={target_logit:.4}, loss_range=[{loss_start},{loss_end})");
977        }
978
979        // Fused CE on full logits_buf (assembled via GPU scatter, no CPU download)
980        self.trainer.queue_ref().write_buffer(&self.labels_buf, 0, bytemuck::cast_slice(&labels));
981
982        // KAIZEN: async CE forward — dispatch compute without blocking.
983        // Loss is read at the end of the step (after LoRA backward) to avoid
984        // the 10.7s GPU sync that blocks on 28-layer forward compute.
985        let t3a = std::time::Instant::now();
986        self.cross_entropy.forward_async(
987            &self.logits_buf,
988            &self.labels_buf,
989            &self.losses_buf,
990            &self.logsumexp_buf,
991            seq_len as u32,
992            self.vocab_size as u32,
993            loss_start as u32,
994            loss_end as u32,
995        );
996
997        let t3b = std::time::Instant::now();
998        // Fused CE backward (in-place into logits_buf)
999        self.cross_entropy.backward(
1000            &self.logits_buf,
1001            &self.labels_buf,
1002            &self.logsumexp_buf,
1003            seq_len as u32,
1004            self.vocab_size as u32,
1005            loss_start as u32,
1006            loss_end as u32,
1007        );
1008
1009        let t3c = std::time::Instant::now();
1010
1011        // 6. lm_head backward — fully GPU-resident (no CPU download)
1012        // grad_hidden = grad_logits @ lm_head, chunked along vocab dimension.
1013        // KAIZEN: old code downloaded each chunk to CPU (11.6s sync). Now accumulates on GPU.
1014        let grad_hidden_buf = self.trainer.zeros(seq_len * self.hidden_dim);
1015        let mut row_offset = 0u64;
1016        for (chunk_buf, chunk_k) in &self.lm_head_chunks {
1017            let ck = u64::from(*chunk_k);
1018            let gl_chunk = self.trainer.zeros((seq_len as u64 * ck) as usize);
1019            self.dispatch_gather(
1020                &self.logits_buf,
1021                &gl_chunk,
1022                seq_len as u32,
1023                *chunk_k,
1024                self.vocab_size as u32,
1025                row_offset as u32,
1026            );
1027            // GEMM: gl_chunk[seq, chunk_k] @ lm_head_chunk[chunk_k, hidden] → temp
1028            let gh_chunk = self.trainer.zeros(seq_len * self.hidden_dim);
1029            self.trainer.matmul_forward(
1030                &gl_chunk,
1031                chunk_buf,
1032                &gh_chunk,
1033                seq_len as u32,
1034                *chunk_k,
1035                self.hidden_dim as u32,
1036            );
1037            // GPU accumulate: grad_hidden += gh_chunk (via residual add + copy back)
1038            let sum_buf = self.trainer.zeros(seq_len * self.hidden_dim);
1039            self.fwd.gpu_residual_add(
1040                &grad_hidden_buf,
1041                &gh_chunk,
1042                &sum_buf,
1043                (seq_len * self.hidden_dim) as u32,
1044            );
1045            // Copy sum back to grad_hidden_buf
1046            let mut enc = self.fwd.device_ref().create_command_encoder(&Default::default());
1047            enc.copy_buffer_to_buffer(
1048                &sum_buf,
1049                0,
1050                &grad_hidden_buf,
1051                0,
1052                (seq_len * self.hidden_dim * 4) as u64,
1053            );
1054            self.fwd.queue_ref().submit(Some(enc.finish()));
1055            row_offset += ck;
1056        }
1057
1058        let t4 = std::time::Instant::now();
1059
1060        // FALSIFY-LORA-UPD-001: verify B_norm > 0 after step 1 (one-shot check)
1061        if self.lora_step == 1 {
1062            let b0 = self.trainer.download(&self.lora[0].projections[0].b);
1063            let b_norm: f32 = b0.iter().map(|x| x * x).sum::<f32>().sqrt();
1064            eprintln!("[FALSIFY] step=1 B[0].q_proj norm={b_norm:.6}");
1065        }
1066
1067        // 7. LoRA gradient computation + AdamW step
1068        // Contract: wgpu-production-training-v1/C-WGPU-LORA-BWD-001
1069        //   dL/dB = (α/r) * (X @ A)^T @ G   [rank, out]
1070        //   dL/dA = (α/r) * X^T @ (G @ B^T)  [in, rank]
1071        // Contract: adamw-kernel-v1/weight_update
1072        // Contract: lora-gradient-flow-v1 — B_norm > 0 after step 1
1073        self.lora_step += 1;
1074        // PMAT-497: Was hardcoded to 2e-4, ignoring user's --learning-rate.
1075        // Use the learning rate from InstructConfig passed during construction.
1076        let lr = self.learning_rate;
1077        let s = seq_len as u32;
1078        let rank = self.lora_rank as u32;
1079
1080        // PMAT-510 FIX: Per-layer gradient propagation.
1081        // Previous bug: all 28 layers received the same grad_hidden_buf from the loss.
1082        // Fix: iterate layers in REVERSE order (27→0), propagating gradient backward
1083        // through each layer's frozen base weights (residual + FFN path).
1084        //
1085        // Backward through one transformer layer:
1086        //   grad_silu = grad @ W_down^T       [seq, inter]
1087        //   grad_ffn  = grad_silu @ W_gate^T  [seq, hidden]
1088        //   grad_input = grad + grad_ffn      (residual connection)
1089        //
1090        // This is simplified (skips SiLU backward, RMSNorm backward, up_proj path,
1091        // and attention backward) but provides DIFFERENT gradients per layer because
1092        // W_down and W_gate are different for each of the 28 layers.
1093        let h = self.hidden_dim as u32;
1094        let mut grad_buf = grad_hidden_buf;
1095
1096        for layer_idx in (0..self.lora.len()).rev() {
1097            let layer_lora = &self.lora[layer_idx];
1098            let saved = &_saved_activations[layer_idx];
1099
1100            for proj in &layer_lora.projections {
1101                if !self.lora_target_set.iter().any(|t| t == &proj.name) {
1102                    continue;
1103                }
1104                // Select saved input based on projection name
1105                let input_buf = match proj.name.as_str() {
1106                    "q_proj" | "k_proj" | "v_proj" => &saved.attn_norm_out,
1107                    "o_proj" => &saved.attn_output,
1108                    "gate_proj" | "up_proj" => &saved.ffn_norm_out,
1109                    "down_proj" => &saved.silu_gate_output,
1110                    _ => continue,
1111                };
1112
1113                // GPU-only LoRA backward: transpose + matmul_forward (zero CPU downloads)
1114                // dB = scale * XA^T @ G,  dA = scale * X^T @ (G @ B^T)
1115                let scale = self.lora_scale;
1116
1117                // Step 1: XA = X @ A  [seq, rank]
1118                let xa = self.trainer.zeros((s * rank) as usize);
1119                self.trainer.matmul_forward(input_buf, &proj.a, &xa, s, proj.in_dim, rank);
1120
1121                // Step 2: dB = (scale * XA)^T @ G — GPU transpose + GEMM
1122                let xa_t = self.trainer.zeros((s * rank) as usize);
1123                self.dispatch_transpose(&xa, &xa_t, s, rank, scale);
1124                let db = self.trainer.zeros((rank * proj.out_dim) as usize);
1125                self.trainer.matmul_forward(&xa_t, &grad_buf, &db, rank, s, proj.out_dim);
1126
1127                // Step 3: dA (skip if B=0 — first step optimization)
1128                let da = if self.lora_step <= 1 {
1129                    self.trainer.zeros((proj.in_dim * rank) as usize)
1130                } else {
1131                    let bt = self.trainer.zeros((rank * proj.out_dim) as usize);
1132                    self.dispatch_transpose(&proj.b, &bt, rank, proj.out_dim, 1.0);
1133                    let d_xa = self.trainer.zeros((s * rank) as usize);
1134                    self.trainer.matmul_forward(&grad_buf, &bt, &d_xa, s, proj.out_dim, rank);
1135                    let xt = self.trainer.zeros((s * proj.in_dim) as usize);
1136                    self.dispatch_transpose(input_buf, &xt, s, proj.in_dim, scale);
1137                    let da_buf = self.trainer.zeros((proj.in_dim * rank) as usize);
1138                    self.trainer.matmul_forward(&xt, &d_xa, &da_buf, proj.in_dim, s, rank);
1139                    da_buf
1140                };
1141
1142                // AdamW step: update A and B
1143                // Contract: adamw-kernel-v1/weight_update
1144                self.trainer
1145                    .adamw_step(&proj.a, &da, &proj.m_a, &proj.v_a, lr, 0.9, 0.999, 1e-8, 0.01);
1146                self.trainer
1147                    .adamw_step(&proj.b, &db, &proj.m_b, &proj.v_b, lr, 0.9, 0.999, 1e-8, 0.01);
1148            }
1149
1150            // PMAT-511 REVERTED: Per-layer backward through W_down^T @ W_gate^T was a dead end.
1151            // Without SiLU backward derivative, the simplified FFN backward injects WRONG-DIRECTION
1152            // gradient that makes training WORSE (3.63→17.77 vs 2.97→16.11 without).
1153            // Path B (cuBLAS hybrid via --gpu-backend cuda) replaces this entirely.
1154            // The CUDA backward path in backward.rs has proper per-layer backward with cuBLAS.
1155        }
1156
1157        let t5 = std::time::Instant::now();
1158
1159        eprintln!(
1160            "[PROFILE] step: {:.0}ms (embed={:.0} fwd={:.0} lm={:.0} ce={:.0}[fwd={:.0} bwd={:.0}] lm_bwd={:.0} lora_bwd={:.0})",
1161            t5.duration_since(t0).as_millis(),
1162            t1.duration_since(t0).as_millis(),
1163            t2.duration_since(t1).as_millis(),
1164            t3.duration_since(t2).as_millis(),
1165            t3c.duration_since(t3).as_millis(),
1166            t3b.duration_since(t3a).as_millis(),
1167            t3c.duration_since(t3b).as_millis(),
1168            t4.duration_since(t3c).as_millis(),
1169            t5.duration_since(t4).as_millis(),
1170        );
1171
1172        // Read loss from GPU AFTER all backward + AdamW work is dispatched.
1173        // This is the only GPU sync point — blocks until all work completes.
1174        let avg_loss = self.cross_entropy.read_loss(
1175            &self.losses_buf,
1176            seq_len as u32,
1177            loss_start as u32,
1178            loss_end as u32,
1179        );
1180
1181        InstructStepResult {
1182            loss: if avg_loss.is_finite() { avg_loss } else { 100.0 },
1183            num_response_tokens: num_loss_tokens,
1184            perplexity: if avg_loss.is_finite() { avg_loss.exp().min(1e6) } else { 1e6 },
1185        }
1186    }
1187}
1188
#[cfg(feature = "gpu")]
impl WgpuInstructPipeline {
    /// DPO-style preference step on one (prompt, chosen, rejected) triple; returns the loss.
    ///
    /// Contract: dpo-alignment-v1 / dpo_loss
    /// Lean theorem: ProvableContracts.DPO.dpo_loss_nonneg
    ///
    /// Reference objective:
    /// L_DPO = -log σ(β * (log_ratio_chosen - log_ratio_rejected))
    /// where log_ratio = log π_θ(y|x) - log π_ref(y|x)
    ///
    /// This implementation has no separate reference model. It uses
    /// `Δ = log π_θ(chosen|x) - log π_θ(rejected|x)` directly and returns
    /// `-log σ(β Δ) = softplus(-β Δ)`, evaluated in a numerically stable form (no overflow
    /// and no artificial cap for large margins).
    ///
    /// NOTE(review): both log-probabilities come from `train_step`, which also runs the SFT
    /// backward pass and an AdamW update. One call therefore applies two supervised updates
    /// (toward the chosen AND the rejected response) and advances the LoRA step counter
    /// twice. The DPO loss itself is not back-propagated. Confirm this is intended before
    /// relying on this as a preference-optimisation loop.
    pub fn dpo_step(
        &mut self,
        prompt_ids: &[u32],
        chosen_ids: &[u32],
        rejected_ids: &[u32],
        beta: f32,
    ) -> f32 {
        let chosen_logprob = self.compute_sequence_logprob(prompt_ids, chosen_ids);
        let rejected_logprob = self.compute_sequence_logprob(prompt_ids, rejected_ids);

        // -ln σ(m) = softplus(-m) = max(-m, 0) + ln(1 + e^{-|m|}); the exponent is never
        // positive, so this cannot overflow, and both terms are non-negative.
        let margin = beta * (chosen_logprob - rejected_logprob);
        let loss = (-margin).max(0.0) + (-margin.abs()).exp().ln_1p();

        // Contract: FALSIFY-DPO-001 — loss must be non-negative
        debug_assert!(loss >= 0.0, "DPO loss must be non-negative: {loss}");

        loss
    }

    /// Approximate total log-probability of `response_ids` given `prompt_ids`.
    ///
    /// `train_step` reports the average cross-entropy over the response positions that
    /// survived truncation to `max_seq_len`. The summed log-probability is therefore
    /// `-loss * num_response_tokens`. The count reported by `train_step` is used, not
    /// `response_ids.len()`, so truncated responses are not over-counted.
    ///
    /// NOTE(review): this runs a full training step (backward + AdamW). When the forward pass
    /// fails or the loss is non-finite, `train_step` reports a sentinel loss of 100.0, which
    /// shows up here as a large negative log-probability.
    fn compute_sequence_logprob(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> f32 {
        let result = self.train_step(prompt_ids, response_ids);
        -result.loss * result.num_response_tokens as f32
    }
}