Skip to main content

entrenar/transformer/
cuda_block.rs

1//! CUDA-accelerated Transformer Block (ENT-147 through ENT-152)
2//!
3//! This module provides a fully GPU-accelerated transformer block using trueno-gpu kernels.
4//! All operations run on CUDA to achieve >70% GPU utilization.
5//!
6//! # Phase 22 Implementation Status
7//!
8//! - ENT-147: CUDA RMSNorm integration ✅
9//! - ENT-148: CUDA Softmax integration ✅
10//! - ENT-149: CUDA SiLU activation ✅
11//! - ENT-150: Fused SwiGLU kernel ✅
12//! - ENT-151: CUDA backward pass ✅
13//! - ENT-152: CudaTransformer wrapper ✅
14
15#![allow(dead_code)]
16// SAFETY: This module performs GPU memory transfers via CUDA driver FFI.
17// The unsafe blocks are limited to copy_from_host_async / copy_to_host_async
18// where we guarantee the host buffer outlives the async operation by syncing
19// the stream before the buffer goes out of scope.
20#![allow(unsafe_code)]
21
22// PMAT-483/entrenar#328: Per-operation profiling indices for forward pass.
23// Must match StepProfiler OP_* constants in step_profiler.rs.
#[cfg(feature = "cuda")]
const OP_RMSNORM_ATTN: usize = 0;
#[cfg(feature = "cuda")]
const OP_QKV_GEMM: usize = OP_RMSNORM_ATTN + 1; // 1
#[cfg(feature = "cuda")]
const OP_ATTENTION: usize = OP_QKV_GEMM + 1; // 2
#[cfg(feature = "cuda")]
const OP_O_PROJ: usize = OP_ATTENTION + 1; // 3
#[cfg(feature = "cuda")]
const OP_RMSNORM_FFN: usize = OP_O_PROJ + 1; // 4
#[cfg(feature = "cuda")]
const OP_GATE_UP_GEMM: usize = OP_RMSNORM_FFN + 1; // 5
#[cfg(feature = "cuda")]
const OP_SILU: usize = OP_GATE_UP_GEMM + 1; // 6
#[cfg(feature = "cuda")]
const OP_DOWN_GEMM: usize = OP_SILU + 1; // 7

// PMAT-483: Per-operation profiling indices for backward pass.
// Slots continue directly after the forward range; the trailing comments keep the
// literal values visible for cross-checking against step_profiler.rs.
#[cfg(feature = "cuda")]
const OP_LORA_FWD: usize = OP_DOWN_GEMM + 1; // 8
#[cfg(feature = "cuda")]
const OP_DOWN_BWD: usize = OP_LORA_FWD + 1; // 9
#[cfg(feature = "cuda")]
const OP_SWIGLU_BWD: usize = OP_DOWN_BWD + 1; // 10
#[cfg(feature = "cuda")]
const OP_GATE_UP_BWD: usize = OP_SWIGLU_BWD + 1; // 11
#[cfg(feature = "cuda")]
const OP_ATTN_BWD: usize = OP_GATE_UP_BWD + 1; // 12
#[cfg(feature = "cuda")]
const OP_QKV_BWD: usize = OP_ATTN_BWD + 1; // 13
#[cfg(feature = "cuda")]
const OP_NORM_BWD: usize = OP_QKV_BWD + 1; // 14
#[cfg(feature = "cuda")]
const OP_LORA_BWD: usize = OP_NORM_BWD + 1; // 15
58
59#[cfg(feature = "cuda")]
60use std::sync::Arc;
61
/// Narrow a `usize` to `u32`, clamping anything above `u32::MAX` to `u32::MAX`
/// (kernel launch parameters are 32-bit).
#[cfg(feature = "cuda")]
#[inline]
fn saturating_u32(v: usize) -> u32 {
    u32::try_from(v).unwrap_or(u32::MAX)
}
67
/// Take ownership of `val` and deliberately never run its destructor.
///
/// Used to hand off values whose `Drop` would free GPU memory a second time
/// (prevents GPU double-free). The memory backing `val` is intentionally leaked.
#[cfg(feature = "cuda")]
#[inline]
fn leak<T>(val: T) {
    std::mem::forget(val);
}
74
75#[cfg(feature = "cuda")]
76use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
77
78#[cfg(feature = "cuda")]
79use crate::autograd::cuda_backward::{
80    batched_softmax_backward, gemm_backward_a, gemm_backward_a_fp16_dispatch,
81    gemm_backward_a_fp16_dispatch_accumulate, gemm_backward_b, rms_norm_backward, silu_backward,
82};
83#[cfg(feature = "cuda")]
84use crate::autograd::cuda_forward::{
85    batched_4d_gemm_forward, batched_rope_neox_backward, batched_rope_neox_forward,
86    batched_softmax_forward, batched_to_interleaved_forward, batched_transpose_forward,
87    cast_f32_to_f16_gpu, elementwise_mul_forward, expand_kv_heads, fused_residual_rmsnorm_forward,
88    fused_swiglu_forward, gemm_f16_to_f32_forward, gemm_forward, interleaved_to_batched_forward,
89    per_head_rmsnorm_forward, residual_add_forward, rms_norm_forward, scale_forward, silu_forward,
90};
91#[cfg(feature = "cuda")]
92use crate::autograd::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, squared_sum_cuda};
93#[cfg(feature = "cuda")]
94use crate::autograd::cuda_tensor::Result;
95
96#[cfg(feature = "cuda")]
97use super::config::TransformerConfig;
98
/// CUDA-accelerated transformer block
///
/// All operations run on GPU with minimal CPU<->GPU transfers.
///
/// Field order is significant: Rust drops struct fields in declaration order, so
/// `ctx` is declared last. Every device buffer owned by this block (weights,
/// scratch, QK-norm weights) is therefore released before this block's handle on
/// the CUDA context. This matters in case dropping the last `Arc<CudaContext>`
/// tears the context down.
#[cfg(feature = "cuda")]
pub struct CudaTransformerBlock {
    /// Configuration
    config: TransformerConfig,
    /// Layer index
    layer_idx: usize,
    /// Input RMSNorm weight (gamma), hidden_size elements
    input_norm_weight: GpuBuffer<f32>,
    /// Post-attention RMSNorm weight (gamma), hidden_size elements
    post_attn_norm_weight: GpuBuffer<f32>,
    /// Query projection weight (hidden_size x q_dim); q_dim = num_heads * head_dim
    /// and may differ from hidden_size
    w_q: GpuBuffer<f32>,
    /// Key projection weight (hidden_size x kv_hidden_size)
    w_k: GpuBuffer<f32>,
    /// Value projection weight (hidden_size x kv_hidden_size)
    w_v: GpuBuffer<f32>,
    /// Output projection weight (q_dim x hidden_size)
    w_o: GpuBuffer<f32>,
    /// FFN gate projection (hidden_size x intermediate_size)
    w_gate: GpuBuffer<f32>,
    /// FFN up projection (hidden_size x intermediate_size)
    w_up: GpuBuffer<f32>,
    /// FFN down projection (intermediate_size x hidden_size)
    w_down: GpuBuffer<f32>,
    /// Scratch buffers for intermediate results
    scratch: CudaBlockScratch,
    /// Pre-allocated host zero buffer for zeroing norm grad buffers [hidden_size]
    norm_zero_buf: Vec<f32>,
    /// ENT-270: QK-norm weights (per-head RMSNorm, shape=[head_dim])
    q_norm_weight: Option<GpuBuffer<f32>>,
    k_norm_weight: Option<GpuBuffer<f32>>,
    /// CUDA context. Keep this field LAST so it is dropped after every GPU buffer above.
    ctx: Arc<CudaContext>,
}
136
/// Preallocated scratch buffers for transformer forward/backward pass.
///
/// For fp32 blocks: per-layer (backward reads forward activations).
/// For NF4 blocks: shared across all layers (forward-only, no backward).
///
/// # Contract (C-SCRATCH-001)
///
/// - **Precondition**: Allocated with matching `config` and `max_seq_len`
/// - **Postcondition**: All buffers sized for worst-case `max_seq_len`
///   (except the LoRA buffers, which fp32 blocks allocate with a single element)
/// - **Invariant**: NF4 layers run sequentially — one shared scratch is safe
#[cfg(feature = "cuda")]
pub(crate) struct CudaBlockScratch {
    /// After input RMSNorm (seq_len * hidden_size)
    norm1_out: GpuBuffer<f32>,
    /// Q projection output (seq_len * q_dim); q_dim = num_heads * head_dim may
    /// differ from hidden_size
    q: GpuBuffer<f32>,
    /// K projection output (seq_len * kv_hidden_size)
    k: GpuBuffer<f32>,
    /// V projection output (seq_len * kv_hidden_size)
    v: GpuBuffer<f32>,
    /// Attention scores (num_heads * seq_len * seq_len)
    attn_scores: GpuBuffer<f32>,
    /// Attention output (seq_len * q_dim)
    attn_out: GpuBuffer<f32>,
    /// Output projection result (seq_len * hidden_size)
    o_proj_out: GpuBuffer<f32>,
    /// Residual after attention (seq_len * hidden_size)
    residual1: GpuBuffer<f32>,
    /// After post-attention RMSNorm (seq_len * hidden_size)
    norm2_out: GpuBuffer<f32>,
    /// FFN gate output (seq_len * intermediate_size)
    gate_out: GpuBuffer<f32>,
    /// FFN up output (seq_len * intermediate_size)
    up_out: GpuBuffer<f32>,
    /// FFN fused SwiGLU output: SiLU(gate) * up (seq_len * intermediate_size)
    swiglu_out: GpuBuffer<f32>,
    /// FFN down projection output (seq_len * hidden_size)
    ffn_out: GpuBuffer<f32>,
    // === FP16 activation cast buffers for FP16 GEMM dispatch (PMAT-470) ===
    // All four start as `None` in the constructors and are allocated lazily on first
    // use when FP16_GEMM=1.
    /// FP16 copy of `norm1_out`
    norm1_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 copy of `attn_out`
    attn_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 copy of `norm2_out`
    norm2_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 copy of `swiglu_out`
    swiglu_out_f16: Option<GpuBuffer<u16>>,
    // === Seq-dependent backward scratch (per-layer for activation reuse) ===
    /// Gradient accumulator for hidden states (seq_len * hidden_size)
    grad_hidden: GpuBuffer<f32>,
    /// Gradient for SwiGLU intermediate (seq_len * intermediate_size)
    grad_swiglu: GpuBuffer<f32>,
    // === Attention layout scratch buffers (GPU-only attention pipeline) ===
    // All three are sized num_heads * seq_len * head_dim so they also hold
    // GQA-expanded K/V.
    /// Q in batched layout [num_heads, seq_len, head_dim]
    attn_q_batched: GpuBuffer<f32>,
    /// K/V layout temp buffer [num_heads, seq_len, head_dim]
    attn_kv_temp: GpuBuffer<f32>,
    /// K transposed / second temp [num_heads, head_dim, seq_len]
    attn_kv_temp2: GpuBuffer<f32>,
    // === Attention backward scratch (seq-dependent) ===
    /// Gradient for attention scores, sized num_heads * seq_len * max(seq_len, head_dim)
    /// so it can also be reused for [num_heads, seq_len, head_dim] intermediates.
    /// Kept separate from attn_scores because softmax backward reads y while writing grad_x
    grad_attn_scores: GpuBuffer<f32>,
    // === LoRA scratch buffers (ENT-153: QLoRA) ===
    /// LoRA intermediate: x @ A, sized [max_seq_len * max_lora_rank] (min 1 element)
    lora_inter: GpuBuffer<f32>,
    /// LoRA temp for scaled addition, sized [max_seq_len * max_proj_dim] (min 1 element)
    /// (reuses largest projection dimension for Q/V LoRA output)
    lora_temp: GpuBuffer<f32>,
    /// Sequential position indices [0, 1, ..., max_seq_len-1] for batched RoPE
    rope_positions: GpuBuffer<u32>,
    /// PMAT-420: Contiguous causal mask for current seq_len [seq_len * seq_len]
    causal_mask_contiguous: GpuBuffer<f32>,
    /// PMAT-420: seq_len that causal_mask_contiguous was generated for (cache key)
    pub(crate) causal_mask_cached_seq_len: usize,
    /// PMAT-483/entrenar#328: Per-operation timing accumulator (microseconds).
    /// Index matches StepProfiler OP_* constants. Accumulated across layers per step.
    /// Zero-overhead when profiling disabled (check op_profiling_enabled first).
    pub(crate) op_us: [u64; 16],
    /// Whether per-op profiling is active this step
    pub(crate) op_profiling_enabled: bool,
}
216
#[cfg(feature = "cuda")]
impl CudaBlockScratch {
    /// PMAT-483: Start timing an operation.
    ///
    /// Returns `Some(Instant::now())` only when per-op profiling is enabled for this
    /// step; otherwise returns `None` without reading the clock.
    #[inline]
    pub(crate) fn op_begin(&self) -> Option<std::time::Instant> {
        self.op_profiling_enabled.then(std::time::Instant::now)
    }

    /// PMAT-483: Add the time elapsed since `start` to the accumulator slot `op`.
    ///
    /// No-op when `start` is `None` (profiling disabled) or when `op` is not a valid
    /// index into `op_us`. The accumulator saturates rather than overflowing.
    #[inline]
    pub(crate) fn op_end(&mut self, start: Option<std::time::Instant>, op: usize) {
        if let Some(t) = start {
            // Bounds come from the array itself rather than a duplicated literal.
            if let Some(slot) = self.op_us.get_mut(op) {
                let elapsed_us = u64::try_from(t.elapsed().as_micros()).unwrap_or(u64::MAX);
                *slot = slot.saturating_add(elapsed_us);
            }
        }
    }

    /// Max sequence length this scratch was allocated for.
    ///
    /// Derived from `norm1_out`, which holds `max_seq_len * hidden_size` elements.
    /// A `hidden_size` of 0 is treated as 1 to avoid dividing by zero.
    pub(crate) fn max_seq_len(&self, hidden_size: usize) -> usize {
        self.norm1_out.len() / hidden_size.max(1)
    }

    /// Zero the f32 scratch buffers to prevent backward gradient contamination.
    ///
    /// Covers the forward activations, the backward gradient scratch, the attention
    /// layout temps and the LoRA temps.
    /// entrenar#318 Tier 1: GPU-side memset via cuMemsetD32Async (no PCIe transfer),
    /// issued on `stream`.
    ///
    /// This is best effort: an individual memset failure is ignored.
    /// `rope_positions`, `causal_mask_contiguous` and the optional FP16 buffers are not
    /// touched. The causal-mask cache key is reset to 0, so the next
    /// `prepare_causal_mask` call with a non-zero `seq_len` regenerates the mask.
    pub(crate) fn zero_forward_buffers(&mut self, stream: &CudaStream) {
        let buffers = [
            &mut self.norm1_out,
            &mut self.q,
            &mut self.k,
            &mut self.v,
            &mut self.attn_scores,
            &mut self.attn_out,
            &mut self.o_proj_out,
            &mut self.residual1,
            &mut self.norm2_out,
            &mut self.gate_out,
            &mut self.up_out,
            &mut self.swiglu_out,
            &mut self.ffn_out,
            &mut self.attn_q_batched,
            &mut self.attn_kv_temp,
            &mut self.attn_kv_temp2,
            &mut self.grad_hidden,
            &mut self.grad_swiglu,
            &mut self.grad_attn_scores,
            &mut self.lora_inter,
            &mut self.lora_temp,
        ];
        for buf in buffers {
            // Best effort by design: this method has no error channel.
            let _ = buf.zero_async(stream);
        }
        self.causal_mask_cached_seq_len = 0;
    }

    /// Build a host-side causal mask of `seq_len * seq_len` elements (row-major by
    /// query position).
    ///
    /// Entry `(row, col)` is `0.0` when `col <= row` and `f32::NEG_INFINITY` otherwise.
    /// The mask is shared across heads and applied per-head in `compute_attention_cuda`.
    fn causal_mask_host(seq_len: usize) -> Vec<f32> {
        let mut mask = Vec::with_capacity(seq_len * seq_len);
        for row in 0..seq_len {
            mask.extend((0..seq_len).map(|col| if col <= row { 0.0f32 } else { f32::NEG_INFINITY }));
        }
        mask
    }

    /// Allocate scratch buffers for a given model config and max sequence length.
    ///
    /// `lora_rank` sizes `lora_inter` (`max_seq_len * lora_rank`, minimum 1 element).
    ///
    /// # Contract (C-SCRATCH-001)
    ///
    /// All buffer sizes are deterministic from (config, max_seq_len).
    ///
    /// # Errors
    /// Propagates the first failing device allocation or host-to-device upload.
    pub(crate) fn new(
        config: &TransformerConfig,
        max_seq_len: usize,
        ctx: &Arc<CudaContext>,
        lora_rank: usize,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;
        let num_heads = config.num_attention_heads;
        let head_dim = config.head_dim();

        // LoRA scratch: max(q_dim, kv_hidden) for the largest projection output
        let max_proj_dim = q_dim.max(kv_hidden_size);
        // Minimum 1 element to avoid zero-size GPU allocation
        let lora_inter_size = (max_seq_len * lora_rank).max(1);
        let lora_temp_size = (max_seq_len * max_proj_dim).max(1);

        // C-CAUSAL-001: Precompute causal mask [seq × seq] — shared across all heads
        // 4 MB for seq=1024. Applied per-head in compute_attention_cuda.
        let causal_mask_data = Self::causal_mask_host(max_seq_len);
        Ok(Self {
            norm1_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            q: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            k: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            v: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            attn_scores: GpuBuffer::new(ctx, num_heads * max_seq_len * max_seq_len)?,
            attn_out: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            o_proj_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            residual1: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            norm2_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            gate_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            up_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            swiglu_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            ffn_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            norm1_out_f16: None,
            attn_out_f16: None,
            norm2_out_f16: None,
            swiglu_out_f16: None,
            grad_hidden: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            grad_swiglu: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            attn_q_batched: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp2: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            // Sized max(H*S*S, H*S*hd) so it can be reused for head_dim-wide intermediates.
            grad_attn_scores: GpuBuffer::new(
                ctx,
                num_heads * max_seq_len * max_seq_len.max(head_dim),
            )?,
            lora_inter: GpuBuffer::new(ctx, lora_inter_size)?,
            lora_temp: GpuBuffer::new(ctx, lora_temp_size)?,
            rope_positions: {
                let positions: Vec<u32> = (0..max_seq_len as u32).collect();
                let mut buf = GpuBuffer::new(ctx, max_seq_len)?;
                buf.copy_from_host(&positions)?;
                buf
            },
            causal_mask_contiguous: GpuBuffer::from_host(ctx, &causal_mask_data)?,
            causal_mask_cached_seq_len: max_seq_len,
            op_us: [0u64; 16],
            op_profiling_enabled: false,
        })
    }

    /// PMAT-420: Prepare a contiguous [seq_len * seq_len] causal mask.
    ///
    /// Cached: the mask is only regenerated (and re-uploaded into a freshly allocated
    /// device buffer) when `seq_len` differs from the cached key.
    /// Cost: O(seq_len^2) CPU + one H2D upload (~0.01ms for seq=256).
    ///
    /// # Errors
    /// Propagates the device allocation/upload error; the cache key is left unchanged
    /// on failure.
    pub(crate) fn prepare_causal_mask(
        &mut self,
        seq_len: usize,
        ctx: &Arc<CudaContext>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        if seq_len == self.causal_mask_cached_seq_len {
            return Ok(());
        }
        let mask_data = Self::causal_mask_host(seq_len);
        self.causal_mask_contiguous = GpuBuffer::from_host(ctx, &mask_data)?;
        self.causal_mask_cached_seq_len = seq_len;
        Ok(())
    }
}
361
/// Weight-gradient workspace shared by every layer of one model (one per model, NOT per layer).
///
/// # Contract (C-GRADWS-001)
///
/// Backward walks the layers one at a time, so only a single layer's weight
/// gradients are live at any moment. Sharing one workspace across layers saves
/// `(L-1) * per_layer_grad_weight_elements * 4` bytes of VRAM
/// (for Qwen3-4B: 35 * 372 MB = 13.0 GB).
///
/// - **Precondition**: Allocated once before training loop starts
/// - **Postcondition**: After backward() for layer i, contains layer i's weight gradients
/// - **Invariant**: Buffer sizes match model config; never reallocated during training
#[cfg(feature = "cuda")]
pub struct CudaGradWorkspace {
    // --- RMSNorm gamma gradients ---
    /// Input-norm weight gradient, `hidden_size` elements.
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// Post-attention-norm weight gradient, `hidden_size` elements.
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,

    // --- Attention projection gradients ---
    /// `w_q` gradient, `q_dim * hidden_size` elements.
    pub(crate) grad_w_q: GpuBuffer<f32>,
    /// `w_k` gradient, `hidden_size * kv_hidden_size` elements.
    pub(crate) grad_w_k: GpuBuffer<f32>,
    /// `w_v` gradient, `hidden_size * kv_hidden_size` elements.
    pub(crate) grad_w_v: GpuBuffer<f32>,
    /// `w_o` gradient, `hidden_size * q_dim` elements.
    pub(crate) grad_w_o: GpuBuffer<f32>,

    // --- FFN projection gradients ---
    /// `w_gate` gradient, `hidden_size * intermediate_size` elements.
    pub(crate) grad_gate: GpuBuffer<f32>,
    /// `w_up` gradient, `hidden_size * intermediate_size` elements.
    pub(crate) grad_up: GpuBuffer<f32>,
    /// `w_down` gradient, `intermediate_size * hidden_size` elements.
    pub(crate) grad_down: GpuBuffer<f32>,
}
396
#[cfg(feature = "cuda")]
impl CudaGradWorkspace {
    /// Allocate shared gradient workspace for the given model config.
    ///
    /// Called once per training run. GEMM weight gradients are fully overwritten
    /// by each backward pass. Norm gradients use atomicAdd accumulation and MUST
    /// be zeroed before each rms_norm_backward call (see `zero_norm_grads`).
    ///
    /// # Errors
    /// Propagates the first failing device allocation.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> Result<Self> {
        let h = config.hidden_size;
        let q = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;

        Ok(Self {
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
            grad_gate: GpuBuffer::new(ctx, h * i)?,
            grad_up: GpuBuffer::new(ctx, h * i)?,
            grad_down: GpuBuffer::new(ctx, i * h)?,
            grad_w_q: GpuBuffer::new(ctx, q * h)?,
            grad_w_k: GpuBuffer::new(ctx, h * kv)?,
            grad_w_v: GpuBuffer::new(ctx, h * kv)?,
            grad_w_o: GpuBuffer::new(ctx, h * q)?,
        })
    }

    /// Zero norm gradient buffers before rms_norm_backward calls.
    ///
    /// The BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
    /// so these buffers MUST be zeroed before each backward pass. Without this,
    /// grad_gamma accumulates across steps → exploding norm gradients.
    ///
    /// `zero_buf` is a host slice whose first `hidden_size` elements are copied into
    /// both norm-gradient buffers. The caller is responsible for it holding zeros.
    ///
    /// # Errors
    /// Returns `CudaTensorError::TransferFailed` in two cases:
    /// - `zero_buf` is shorter than the norm-gradient buffers.
    /// - Either host-to-device copy fails.
    pub fn zero_norm_grads(&mut self, zero_buf: &[f32]) -> Result<()> {
        use crate::autograd::cuda_tensor::CudaTensorError;

        let n = self.grad_input_norm.len();
        // Report a short host buffer as an error instead of panicking on the slice index.
        let zeros = zero_buf.get(..n).ok_or_else(|| {
            CudaTensorError::TransferFailed(format!(
                "zero_norm_grads: zero buffer has {} elements, need {n}",
                zero_buf.len()
            ))
        })?;

        let targets = [
            ("grad_input_norm", &mut self.grad_input_norm),
            ("grad_post_attn_norm", &mut self.grad_post_attn_norm),
        ];
        for (label, buf) in targets {
            buf.copy_from_host(zeros).map_err(|e| {
                CudaTensorError::TransferFailed(format!("Failed to zero {label}: {e:?}"))
            })?;
        }
        Ok(())
    }
}
443
/// AdamW optimizer state for one transformer block, kept entirely on the GPU.
///
/// Holds a first-moment (`m_*`) and second-moment (`v_*`) buffer for each of the
/// 9 trainable tensors: 7 matmul weights plus 2 RMSNorm weights. Keeping them on
/// the device avoids CPU↔GPU transfers during training.
///
/// # Contract (C-OPTSTATE-001)
///
/// - **Precondition**: CUDA context valid, all buffers allocated to match weight dimensions
/// - **Postcondition**: m and v buffers initialized to zero (unbiased start)
/// - **Invariant**: Buffer sizes immutable after creation; m/v never reallocated
#[cfg(feature = "cuda")]
pub struct GpuBlockOptimizerState {
    // ---- Attention projections ----
    /// First moment for `w_q`.
    m_w_q: GpuBuffer<f32>,
    /// Second moment for `w_q`.
    v_w_q: GpuBuffer<f32>,
    /// First moment for `w_k`.
    m_w_k: GpuBuffer<f32>,
    /// Second moment for `w_k`.
    v_w_k: GpuBuffer<f32>,
    /// First moment for `w_v`.
    m_w_v: GpuBuffer<f32>,
    /// Second moment for `w_v`.
    v_w_v: GpuBuffer<f32>,
    /// First moment for `w_o`.
    m_w_o: GpuBuffer<f32>,
    /// Second moment for `w_o`.
    v_w_o: GpuBuffer<f32>,

    // ---- FFN projections ----
    /// First moment for `w_gate`.
    m_w_gate: GpuBuffer<f32>,
    /// Second moment for `w_gate`.
    v_w_gate: GpuBuffer<f32>,
    /// First moment for `w_up`.
    m_w_up: GpuBuffer<f32>,
    /// Second moment for `w_up`.
    v_w_up: GpuBuffer<f32>,
    /// First moment for `w_down`.
    m_w_down: GpuBuffer<f32>,
    /// Second moment for `w_down`.
    v_w_down: GpuBuffer<f32>,

    // ---- RMSNorm weights ----
    /// First moment for the input-norm gamma.
    m_input_norm: GpuBuffer<f32>,
    /// Second moment for the input-norm gamma.
    v_input_norm: GpuBuffer<f32>,
    /// First moment for the post-attention-norm gamma.
    m_post_attn_norm: GpuBuffer<f32>,
    /// Second moment for the post-attention-norm gamma.
    v_post_attn_norm: GpuBuffer<f32>,
}
479
/// ALB-118: Host <-> device transfer of GPU optimizer state for checkpointing.
#[cfg(feature = "cuda")]
impl GpuBlockOptimizerState {
    /// Copy all 18 moment buffers back to the host.
    ///
    /// Returns `(suffix, data)` pairs in a fixed order. For each tensor, the `m.` entry
    /// comes first, then its `v.` entry. The tensor order is w_q, w_k, w_v, w_o, w_gate,
    /// w_up, w_down, input_norm, post_attn_norm. Suffixes look like `"m.w_q"` or
    /// `"v.w_gate"`; the caller prefixes them with the layer index.
    ///
    /// # Errors
    /// Stops at the first failed device-to-host copy and returns `TransferFailed`
    /// naming that buffer.
    pub fn download_to_host(
        &self,
    ) -> crate::autograd::cuda_tensor::Result<Vec<(String, Vec<f32>)>> {
        let sources: [(&str, &GpuBuffer<f32>); 18] = [
            ("m.w_q", &self.m_w_q),
            ("v.w_q", &self.v_w_q),
            ("m.w_k", &self.m_w_k),
            ("v.w_k", &self.v_w_k),
            ("m.w_v", &self.m_w_v),
            ("v.w_v", &self.v_w_v),
            ("m.w_o", &self.m_w_o),
            ("v.w_o", &self.v_w_o),
            ("m.w_gate", &self.m_w_gate),
            ("v.w_gate", &self.v_w_gate),
            ("m.w_up", &self.m_w_up),
            ("v.w_up", &self.v_w_up),
            ("m.w_down", &self.m_w_down),
            ("v.w_down", &self.v_w_down),
            ("m.input_norm", &self.m_input_norm),
            ("v.input_norm", &self.v_input_norm),
            ("m.post_attn_norm", &self.m_post_attn_norm),
            ("v.post_attn_norm", &self.v_post_attn_norm),
        ];
        sources
            .into_iter()
            .map(
                |(name, buf)| -> crate::autograd::cuda_tensor::Result<(String, Vec<f32>)> {
                    let mut host = vec![0.0f32; buf.len()];
                    buf.copy_to_host(&mut host).map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "optimizer D2H {name}: {e}"
                        ))
                    })?;
                    Ok((name.to_owned(), host))
                },
            )
            .collect()
    }

    /// ALB-118: Upload host optimizer state to GPU (checkpoint resume).
    ///
    /// Entries are looked up by the same suffixes `download_to_host` produces. Some
    /// buffers are not written and keep their current contents (zero on a freshly
    /// created state):
    /// - its key is missing from `data`;
    /// - its stored vector's length differs from the buffer's length.
    ///
    /// # Errors
    /// Stops at the first failed host-to-device copy and returns `TransferFailed`
    /// naming that buffer.
    pub fn restore_from_host(
        &mut self,
        data: &std::collections::HashMap<String, Vec<f32>>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        let targets: [(&str, &mut GpuBuffer<f32>); 18] = [
            ("m.w_q", &mut self.m_w_q),
            ("v.w_q", &mut self.v_w_q),
            ("m.w_k", &mut self.m_w_k),
            ("v.w_k", &mut self.v_w_k),
            ("m.w_v", &mut self.m_w_v),
            ("v.w_v", &mut self.v_w_v),
            ("m.w_o", &mut self.m_w_o),
            ("v.w_o", &mut self.v_w_o),
            ("m.w_gate", &mut self.m_w_gate),
            ("v.w_gate", &mut self.v_w_gate),
            ("m.w_up", &mut self.m_w_up),
            ("v.w_up", &mut self.v_w_up),
            ("m.w_down", &mut self.m_w_down),
            ("v.w_down", &mut self.v_w_down),
            ("m.input_norm", &mut self.m_input_norm),
            ("v.input_norm", &mut self.v_input_norm),
            ("m.post_attn_norm", &mut self.m_post_attn_norm),
            ("v.post_attn_norm", &mut self.v_post_attn_norm),
        ];
        for (name, buf) in targets {
            match data.get(name) {
                Some(host_data) if host_data.len() == buf.len() => {
                    buf.copy_from_host(host_data).map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "optimizer H2D {name}: {e}"
                        ))
                    })?;
                }
                // Absent or wrongly sized entries are skipped on purpose.
                _ => {}
            }
        }
        Ok(())
    }
}
563
564#[cfg(feature = "cuda")]
565impl CudaTransformerBlock {
566    /// Create a new CUDA transformer block from CPU tensors
567    ///
568    /// Uploads all weights to GPU memory.
569    pub fn new(
570        config: &TransformerConfig,
571        layer_idx: usize,
572        ctx: Arc<CudaContext>,
573        input_norm_weight: &[f32],
574        post_attn_norm_weight: &[f32],
575        w_q: &[f32],
576        w_k: &[f32],
577        w_v: &[f32],
578        w_o: &[f32],
579        w_gate: &[f32],
580        w_up: &[f32],
581        w_down: &[f32],
582        max_seq_len: usize,
583    ) -> Result<Self> {
584        let hidden_size = config.hidden_size;
585        let q_dim = config.q_dim(); // num_heads * head_dim (may differ from hidden_size)
586        let kv_hidden_size = config.num_kv_heads * config.head_dim();
587        let intermediate_size = config.intermediate_size;
588        let num_heads = config.num_attention_heads;
589
590        // Upload weights to GPU
591        let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
592        let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
593        let w_q = GpuBuffer::from_host(&ctx, w_q)?;
594        let w_k = GpuBuffer::from_host(&ctx, w_k)?;
595        let w_v = GpuBuffer::from_host(&ctx, w_v)?;
596        let w_o = GpuBuffer::from_host(&ctx, w_o)?;
597        let w_gate = GpuBuffer::from_host(&ctx, w_gate)?;
598        let w_up = GpuBuffer::from_host(&ctx, w_up)?;
599        let w_down = GpuBuffer::from_host(&ctx, w_down)?;
600
601        // C-CAUSAL-001: Precompute causal mask for NF4 path
602        let single_mask: Vec<f32> = (0..max_seq_len * max_seq_len)
603            .map(|idx| {
604                let row = idx / max_seq_len;
605                let col = idx % max_seq_len;
606                if col <= row {
607                    0.0f32
608                } else {
609                    f32::NEG_INFINITY
610                }
611            })
612            .collect();
613        // Allocate scratch buffers — Q and attn_out need q_dim, not hidden_size
614        let scratch = CudaBlockScratch {
615            norm1_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
616            q: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
617            k: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
618            v: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
619            attn_scores: GpuBuffer::new(&ctx, num_heads * max_seq_len * max_seq_len)?,
620            attn_out: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
621            o_proj_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
622            residual1: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
623            norm2_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
624            gate_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
625            up_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
626            swiglu_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
627            ffn_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
628            norm1_out_f16: None,
629            attn_out_f16: None,
630            norm2_out_f16: None,
631            swiglu_out_f16: None,
632            // Seq-dependent backward scratch
633            grad_hidden: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
634            grad_swiglu: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
635            // Attention layout scratch (all sized for num_heads, handles GQA expansion)
636            attn_q_batched: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
637            attn_kv_temp: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
638            attn_kv_temp2: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
639            // Attention backward gradient buffers (ENT-151b)
640            // grad_attn_scores needs max(H*S*S, H*S*hd) for buffer reuse safety
641            grad_attn_scores: GpuBuffer::new(
642                &ctx,
643                num_heads * max_seq_len * max_seq_len.max(config.head_dim()),
644            )?,
645            // LoRA scratch (unused for fp32 blocks, minimum allocation)
646            lora_inter: GpuBuffer::new(&ctx, 1)?,
647            lora_temp: GpuBuffer::new(&ctx, 1)?,
648            rope_positions: {
649                let positions: Vec<u32> = (0..max_seq_len as u32).collect();
650                let mut buf = GpuBuffer::new(&ctx, max_seq_len)?;
651                buf.copy_from_host(&positions)?;
652                buf
653            },
654            causal_mask_contiguous: GpuBuffer::from_host(&ctx, &single_mask)?,
655            causal_mask_cached_seq_len: max_seq_len,
656            op_us: [0u64; 16],
657            op_profiling_enabled: false,
658        };
659
660        Ok(Self {
661            config: config.clone(),
662            layer_idx,
663            input_norm_weight,
664            post_attn_norm_weight,
665            w_q,
666            w_k,
667            w_v,
668            w_o,
669            w_gate,
670            w_up,
671            w_down,
672            ctx,
673            scratch,
674            norm_zero_buf: vec![0.0f32; hidden_size],
675            q_norm_weight: None, // ENT-270: set via set_qk_norm() after construction
676            k_norm_weight: None,
677        })
678    }
679
680    /// Set QK-norm weights (ENT-270). Called after construction when loading Qwen3 models.
681    #[allow(dead_code)]
682    pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()> {
683        self.q_norm_weight = Some(GpuBuffer::from_host(&self.ctx, q_norm)?);
684        self.k_norm_weight = Some(GpuBuffer::from_host(&self.ctx, k_norm)?);
685        Ok(())
686    }
687
688    /// Forward pass - all operations on GPU
689    ///
690    /// # Arguments
691    /// * `input` - Input tensor on GPU (seq_len * hidden_size)
692    /// * `output` - Output tensor on GPU (seq_len * hidden_size)
693    /// * `seq_len` - Sequence length
694    /// * `stream` - CUDA stream for async execution
695    pub fn forward(
696        &mut self,
697        input: &GpuBuffer<f32>,
698        output: &mut GpuBuffer<f32>,
699        seq_len: usize,
700        stream: &CudaStream,
701    ) -> Result<()> {
702        let hidden_size = self.config.hidden_size;
703        let q_dim = self.config.q_dim();
704        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
705        let intermediate_size = self.config.intermediate_size;
706
707        // === Pre-attention RMSNorm (ENT-147) ===
708        rms_norm_forward(
709            input,
710            &self.input_norm_weight,
711            &mut self.scratch.norm1_out,
712            saturating_u32(seq_len),
713            saturating_u32(hidden_size),
714            stream,
715        )?;
716
717        // === Q, K, V Projections (CUDA GEMM) ===
718        // C[seq,q_dim] = A[seq,hidden] @ B[hidden,q_dim]
719        gemm_forward(
720            &self.scratch.norm1_out,
721            &self.w_q,
722            &mut self.scratch.q,
723            saturating_u32(seq_len),
724            saturating_u32(hidden_size),
725            saturating_u32(q_dim),
726            stream,
727        )?;
728
729        gemm_forward(
730            &self.scratch.norm1_out,
731            &self.w_k,
732            &mut self.scratch.k,
733            saturating_u32(seq_len),
734            saturating_u32(hidden_size),
735            saturating_u32(kv_hidden_size),
736            stream,
737        )?;
738
739        gemm_forward(
740            &self.scratch.norm1_out,
741            &self.w_v,
742            &mut self.scratch.v,
743            saturating_u32(seq_len),
744            saturating_u32(hidden_size),
745            saturating_u32(kv_hidden_size),
746            stream,
747        )?;
748
749        // === Multi-Head Attention (GPU-only, zero CPU transfers) ===
750        self.compute_attention_cuda(seq_len, stream)?;
751
752        // === Output Projection ===
753        // C[seq,hidden] = A[seq,q_dim] @ B[q_dim,hidden]
754        gemm_forward(
755            &self.scratch.attn_out,
756            &self.w_o,
757            &mut self.scratch.o_proj_out,
758            saturating_u32(seq_len),
759            saturating_u32(q_dim),
760            saturating_u32(hidden_size),
761            stream,
762        )?;
763
764        // === Residual Add (input + attention_output) ===
765        cuda_add(
766            input,
767            &self.scratch.o_proj_out,
768            &mut self.scratch.residual1,
769            seq_len * hidden_size,
770            stream,
771        )?;
772
773        // === Post-attention RMSNorm ===
774        rms_norm_forward(
775            &self.scratch.residual1,
776            &self.post_attn_norm_weight,
777            &mut self.scratch.norm2_out,
778            saturating_u32(seq_len),
779            saturating_u32(hidden_size),
780            stream,
781        )?;
782
783        // === FFN: Gate + Up Projections ===
784        gemm_forward(
785            &self.scratch.norm2_out,
786            &self.w_gate,
787            &mut self.scratch.gate_out,
788            saturating_u32(seq_len),
789            saturating_u32(hidden_size),
790            saturating_u32(intermediate_size),
791            stream,
792        )?;
793
794        gemm_forward(
795            &self.scratch.norm2_out,
796            &self.w_up,
797            &mut self.scratch.up_out,
798            saturating_u32(seq_len),
799            saturating_u32(hidden_size),
800            saturating_u32(intermediate_size),
801            stream,
802        )?;
803
804        // === FFN: Fused SwiGLU (ENT-150) - SiLU(gate) * up in single kernel ===
805        fused_swiglu_forward(
806            &self.scratch.gate_out,
807            &self.scratch.up_out,
808            &mut self.scratch.swiglu_out,
809            saturating_u32(seq_len * intermediate_size),
810            stream,
811        )?;
812
813        // === FFN: Down Projection ===
814        gemm_forward(
815            &self.scratch.swiglu_out,
816            &self.w_down,
817            &mut self.scratch.ffn_out,
818            saturating_u32(seq_len),
819            saturating_u32(intermediate_size),
820            saturating_u32(hidden_size),
821            stream,
822        )?;
823
824        // === Final Residual Add (residual1 + ffn_output) ===
825        cuda_add(
826            &self.scratch.residual1,
827            &self.scratch.ffn_out,
828            output,
829            seq_len * hidden_size,
830            stream,
831        )?;
832
833        Ok(())
834    }
835
836    /// Compute multi-head attention entirely on GPU (zero CPU transfers)
837    ///
838    /// # Contract (C-ATTN-001)
839    ///
840    /// - **Precondition**: Q [seq, hidden], K [seq, kv_hidden], V [seq, kv_hidden] on GPU
841    /// - **Postcondition**: attn_out [seq, hidden] = concat(head_0..head_H) where
842    ///   head_h = softmax(Q_h @ K_{kv(h)}^T / √d_k) @ V_{kv(h)}
843    /// - **Invariant**: Zero gpu_to_vec / vec_to_gpu calls; numerically equivalent to CPU
844    ///
845    /// Uses existing trueno-gpu kernels:
846    /// - `InterleavedToBatchedKernel` for Q/K/V layout conversion
847    /// - `BatchedTransposeKernel` for K^T
848    /// - `Batched4DGemmKernel` for Q@K^T and attn@V
849    /// - `ScaleKernel` for 1/√d_k scaling
850    /// - `BatchedSoftmaxKernel` for row-wise softmax
851    /// - `BatchedToInterleavedKernel` for output layout conversion
852    /// - D2D copies for GQA head expansion
853    fn compute_attention_cuda(&mut self, seq_len: usize, stream: &CudaStream) -> Result<()> {
854        let num_heads = self.config.num_attention_heads;
855        let num_kv_heads = self.config.num_kv_heads;
856        let head_dim = self.config.head_dim();
857        let heads_per_kv = num_heads / num_kv_heads;
858        let scale = 1.0 / (head_dim as f32).sqrt();
859
860        let seq = saturating_u32(seq_len);
861        let nh = saturating_u32(num_heads);
862        let nkv = saturating_u32(num_kv_heads);
863        let hd = saturating_u32(head_dim);
864
865        // PMAT-420: Ensure causal mask is contiguous for current seq_len.
866        self.scratch.prepare_causal_mask(seq_len, &self.ctx)?;
867
868        // ── ENT-270: Apply QK-norm (per-head RMSNorm) on Q and K ──────────
869        // SAFETY: In-place GPU operations — CUDA kernels read all input before writing output.
870        // Rust borrow checker cannot verify GPU kernel memory access patterns, so we use
871        // raw pointer reborrow to allow the same buffer as both input and output.
872        if let Some(ref q_norm) = self.q_norm_weight {
873            for pos in 0..seq_len {
874                let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
875                per_head_rmsnorm_forward(q_ref, q_norm, &mut self.scratch.q, nh, hd, pos, stream)?;
876            }
877        }
878        if let Some(ref k_norm) = self.k_norm_weight {
879            for pos in 0..seq_len {
880                let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
881                per_head_rmsnorm_forward(k_ref, k_norm, &mut self.scratch.k, nkv, hd, pos, stream)?;
882            }
883        }
884
885        // ── ENT-270: Apply RoPE (NeoX half-rotation) on Q and K ──────────
886        // ALB-119: Batched launch (2 kernels) replaces per-position loop (2*seq_len kernels)
887        let rope_theta = self.config.rope_theta;
888        {
889            let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
890            batched_rope_neox_forward(
891                q_ref,
892                &mut self.scratch.q,
893                &self.scratch.rope_positions,
894                nh,
895                hd,
896                seq,
897                rope_theta,
898                stream,
899            )?;
900            let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
901            batched_rope_neox_forward(
902                k_ref,
903                &mut self.scratch.k,
904                &self.scratch.rope_positions,
905                nkv,
906                hd,
907                seq,
908                rope_theta,
909                stream,
910            )?;
911        }
912
913        // Step 1: Q interleaved [seq, num_heads * head_dim] → batched [num_heads, seq, head_dim]
914        interleaved_to_batched_forward(
915            &self.scratch.q,
916            &mut self.scratch.attn_q_batched,
917            seq,
918            nh,
919            hd,
920            stream,
921        )?;
922
923        // Step 2: K interleaved [seq, num_kv_heads * head_dim] → batched [num_kv_heads, seq, head_dim]
924        interleaved_to_batched_forward(
925            &self.scratch.k,
926            &mut self.scratch.attn_kv_temp,
927            seq,
928            nkv,
929            hd,
930            stream,
931        )?;
932
933        // Step 3: GQA expansion + transpose for K
934        if heads_per_kv == 1 {
935            // MHA: transpose directly [num_heads, seq, head_dim] → [num_heads, head_dim, seq]
936            batched_transpose_forward(
937                &self.scratch.attn_kv_temp,
938                &mut self.scratch.attn_kv_temp2,
939                nh,
940                seq,
941                hd,
942                stream,
943            )?;
944        } else {
945            // GQA: expand [num_kv_heads, seq, hd] → [num_heads, seq, hd] in attn_kv_temp2
946            expand_kv_heads(
947                &self.scratch.attn_kv_temp,
948                &mut self.scratch.attn_kv_temp2,
949                num_kv_heads,
950                heads_per_kv,
951                seq_len * head_dim,
952                stream,
953            )?;
954            // Transpose expanded K: [num_heads, seq, hd] → [num_heads, hd, seq] in attn_kv_temp
955            batched_transpose_forward(
956                &self.scratch.attn_kv_temp2,
957                &mut self.scratch.attn_kv_temp,
958                nh,
959                seq,
960                hd,
961                stream,
962            )?;
963            // Move K^T to attn_kv_temp2 for consistent naming below
964            // (swap pointers via D2D copy — attn_kv_temp → attn_kv_temp2)
965            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
966            unsafe {
967                self.scratch
968                    .attn_kv_temp2
969                    .copy_from_buffer_async(&self.scratch.attn_kv_temp, stream)
970                    .map_err(|e| {
971                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
972                            "K^T buffer copy failed: {e}"
973                        ))
974                    })?;
975            }
976        }
977
978        // Step 4: Q @ K^T → attn_scores [num_heads, seq, seq]
979        // attn_q_batched: [1, num_heads, seq, head_dim]
980        // attn_kv_temp2:  [1, num_heads, head_dim, seq] (K transposed)
981        // attn_scores:    [1, num_heads, seq, seq]
982        batched_4d_gemm_forward(
983            &self.scratch.attn_q_batched,
984            &self.scratch.attn_kv_temp2,
985            &mut self.scratch.attn_scores,
986            1,
987            nh,
988            seq,
989            seq,
990            hd,
991            stream,
992        )?;
993
994        // Step 5: Scale scores by 1/√d_k (in-place)
995        let total_scores = nh * seq * seq;
996        {
997            // SAFETY: In-place aliasing is safe for element-wise operations where each
998            // element is read before being written. ScaleKernel processes elements
999            // independently. The view is forgotten to prevent double-free.
1000            let scores_view = unsafe {
1001                GpuBuffer::<f32>::from_raw_parts(
1002                    self.scratch.attn_scores.as_ptr(),
1003                    self.scratch.attn_scores.len(),
1004                )
1005            };
1006            scale_forward(
1007                &scores_view,
1008                &mut self.scratch.attn_scores,
1009                scale,
1010                total_scores,
1011                stream,
1012            )?;
1013            leak(scores_view);
1014        }
1015
1016        // Step 5.5 (C-CAUSAL-001): Apply causal mask — add -inf to future positions
1017        // Loop over heads, adding [seq, seq] mask to each head's scores slice.
1018        // PMAT-420: Use causal_mask_contiguous (correctly strided for seq_len)
1019        // instead of causal_mask (strided at max_seq_len, causes row misalignment
1020        // when seq_len < max_seq_len, leading to NaN after deep layers).
1021        {
1022            let seq_sq = (seq * seq) as usize;
1023            let mask_ptr = self.scratch.causal_mask_contiguous.as_ptr();
1024            let scores_base = self.scratch.attn_scores.as_ptr();
1025            for head in 0..nh as usize {
1026                let byte_offset = (head * seq_sq * 4) as u64; // f32 = 4 bytes
1027                let head_ptr = scores_base + byte_offset;
1028                // SAFETY: mask and scores_slice are non-overlapping GPU regions.
1029                // output aliases scores_slice — safe for element-wise add (read before write).
1030                // Views are leaked to prevent double-free of GPU memory.
1031                let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
1032                let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1033                let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1034                residual_add_forward(&mask_view, &scores_view, &mut out_view, seq * seq, stream)?;
1035                leak(mask_view);
1036                leak(scores_view);
1037                leak(out_view);
1038            }
1039        }
1040
1041        // Step 6: Row-wise softmax → attn_weights [num_heads * seq, seq] (in-place)
1042        let total_rows = nh * seq;
1043        {
1044            // SAFETY: In-place aliasing is safe for BatchedSoftmaxKernel which uses
1045            // shared memory for row-wise reduction. Each row is fully read into shared
1046            // memory before any output is written. The view is forgotten to prevent double-free.
1047            let scores_view = unsafe {
1048                GpuBuffer::<f32>::from_raw_parts(
1049                    self.scratch.attn_scores.as_ptr(),
1050                    self.scratch.attn_scores.len(),
1051                )
1052            };
1053            batched_softmax_forward(
1054                &scores_view,
1055                &mut self.scratch.attn_scores,
1056                total_rows,
1057                seq,
1058                stream,
1059            )?;
1060            leak(scores_view);
1061        }
1062
1063        // Step 7: V layout conversion + GQA expansion
1064        interleaved_to_batched_forward(
1065            &self.scratch.v,
1066            &mut self.scratch.attn_kv_temp,
1067            seq,
1068            nkv,
1069            hd,
1070            stream,
1071        )?;
1072
1073        if heads_per_kv == 1 {
1074            // MHA: V already in [num_heads, seq, head_dim] in attn_kv_temp
1075        } else {
1076            // GQA: expand V [num_kv_heads, seq, hd] → [num_heads, seq, hd]
1077            expand_kv_heads(
1078                &self.scratch.attn_kv_temp,
1079                &mut self.scratch.attn_kv_temp2,
1080                num_kv_heads,
1081                heads_per_kv,
1082                seq_len * head_dim,
1083                stream,
1084            )?;
1085            // Copy expanded V back to attn_kv_temp for the GEMM
1086            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1087            unsafe {
1088                self.scratch
1089                    .attn_kv_temp
1090                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1091                    .map_err(|e| {
1092                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1093                            "V expanded buffer copy failed: {e}"
1094                        ))
1095                    })?;
1096            }
1097        }
1098
1099        // Step 8: attn_weights @ V → attn_result [num_heads, seq, head_dim]
1100        // attn_scores:   [1, num_heads, seq, seq]
1101        // attn_kv_temp:  [1, num_heads, seq, head_dim]
1102        // → attn_q_batched: [1, num_heads, seq, head_dim] (reuse Q buffer)
1103        batched_4d_gemm_forward(
1104            &self.scratch.attn_scores,
1105            &self.scratch.attn_kv_temp,
1106            &mut self.scratch.attn_q_batched,
1107            1,
1108            nh,
1109            seq,
1110            hd,
1111            seq,
1112            stream,
1113        )?;
1114
1115        // Step 9: Convert back to interleaved [seq, num_heads * head_dim] → attn_out
1116        batched_to_interleaved_forward(
1117            &self.scratch.attn_q_batched,
1118            &mut self.scratch.attn_out,
1119            seq,
1120            nh,
1121            hd,
1122            stream,
1123        )?;
1124
1125        Ok(())
1126    }
1127
1128    /// Get layer index
1129    pub fn layer_idx(&self) -> usize {
1130        self.layer_idx
1131    }
1132
1133    /// Get configuration
1134    pub fn config(&self) -> &TransformerConfig {
1135        &self.config
1136    }
1137
    /// Backward pass - gradient computation on GPU (ENT-151)
    ///
    /// Computes gradients for all parameters given upstream gradient.
    ///
    /// # Arguments
    /// * `input` - Original input from forward pass (seq_len * hidden_size)
    /// * `grad_output` - Gradient from upstream layer (seq_len * hidden_size)
    /// * `grad_input` - Output: gradient w.r.t. input (seq_len * hidden_size)
    /// * `seq_len` - Sequence length
    /// * `stream` - CUDA stream for async execution
    /// * `grad_ws` - Gradient workspace that receives the parameter gradients
    ///
    /// # Preconditions
    /// Relies on the activations `forward` left in `self.scratch` for this same input
    /// (e.g. `residual1`, `norm2_out`, `gate_out`, `up_out`, `swiglu_out`, `v`,
    /// `attn_scores`, `attn_out`). Several of those buffers are overwritten as scratch
    /// space during this call, so `forward` must be re-run before `backward` is called
    /// again.
    ///
    /// # Outputs
    /// Parameter gradients are written into `grad_ws`, not into `self.scratch`:
    /// - `grad_ws.grad_post_attn_norm` - post-attention RMSNorm weight
    /// - `grad_ws.grad_gate` / `grad_ws.grad_up` / `grad_ws.grad_down` - FFN weights
    /// - `grad_ws.grad_w_o` - attention output projection
    /// - Q/K/V projection and input-RMSNorm weight gradients come from
    ///   `backward_attention` / `backward_residual_and_input_norm` (presumably also in
    ///   `grad_ws` — confirm against those helpers)
    ///
    /// The norm-weight gradient buffers in `grad_ws` are zeroed at the start of every call.
    #[provable_contracts_macros::contract("backward-pass-v1", equation = "backward")]
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        let hidden_size = self.config.hidden_size;
        let intermediate_size = self.config.intermediate_size;
        // RMSNorm epsilon for the backward kernels. The forward RMSNorm calls pass no
        // epsilon, so this presumably matches the forward kernel's built-in value —
        // NOTE(review): confirm the two agree.
        let eps = 1e-5_f32;

        // Zero norm gradient buffers before backward pass.
        // BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
        // so buffers must be zeroed before each call to prevent cross-step accumulation.
        grad_ws.zero_norm_grads(&self.norm_zero_buf)?;

        // Backward through final residual: output = residual1 + ffn_output
        // grad_output flows to BOTH residual1 (identity skip) and ffn_output path.
        // On return, scratch.norm2_out holds ∂L/∂(post-attention norm output).
        self.backward_ffn(grad_output, seq_len, hidden_size, intermediate_size, stream, grad_ws)?;

        // Backward through post-attention RMSNorm (FFN path gradient only);
        // writes the norm-path part of ∂L/∂residual1 into grad_input.
        self.backward_post_attn_norm(grad_input, seq_len, hidden_size, eps, stream, grad_ws)?;

        // C-RESIDUAL-001 / entrenar#313: Second residual skip gradient.
        // Forward: output = residual1 + ffn_output
        // The identity skip grad_residual1 = grad_output must bypass the
        // post-attention RMSNorm backward entirely — add AFTER norm backward.
        // After this, grad_input holds the full ∂L/∂residual1.
        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;

        // Backward through attention: output projection, attention weights, Q/K/V projections
        // (ENT-151b: previously missing — attention params received no gradients)
        self.backward_attention(grad_input, seq_len, stream, grad_ws)?;

        // Backward through first residual connection and input RMSNorm.
        // NOTE(review): this receives both the raw `grad_output` and `grad_input`
        // (= ∂L/∂residual1 at this point); confirm which one the helper uses for the
        // first residual's identity-skip term.
        self.backward_residual_and_input_norm(
            input,
            grad_output,
            grad_input,
            seq_len,
            hidden_size,
            eps,
            stream,
            grad_ws,
        )?;

        Ok(())
    }
1205
    /// Backward through FFN: down projection, SwiGLU, gate+up projections.
    ///
    /// SwiGLU(gate, up) = silu(gate) * up
    /// ∂L/∂gate = ∂L/∂swiglu * up * silu'(gate)
    /// ∂L/∂up   = ∂L/∂swiglu * silu(gate)
    ///
    /// `grad_output` is ∂L/∂ffn_out [S,H]. The weight gradients are written to
    /// `grad_ws.grad_down`, `grad_ws.grad_gate` and `grad_ws.grad_up`. The gradient
    /// w.r.t. the FFN input (post-attention norm output) is left in `scratch.norm2_out`.
    ///
    /// The forward activations are consumed as scratch space, so after this returns
    /// `swiglu_out`, `up_out`, `gate_out`, `ffn_out`, `grad_hidden` and `norm2_out` no
    /// longer hold forward values.
    ///
    /// Buffer reuse plan (all [S,I] unless noted; step numbers match the code below):
    ///   grad_swiglu = ∂L/∂swiglu               (written step 1, read steps 3/6)
    ///   swiglu_out  → grad_swiglu * up (step 3) → silu(gate) (step 5)
    ///                 (its forward value is read one last time in step 2)
    ///   up_out      → grad_gate (step 4)
    ///   gate_out    → grad_up (step 6)
    ///   ffn_out     → grad_norm2_gate [S,H] (step 8)
    ///   grad_hidden → grad_norm2_up [S,H]   (step 9)
    ///   norm2_out   → accumulated grad [S,H] (step 10)
    fn backward_ffn(
        &mut self,
        grad_output: &GpuBuffer<f32>,
        seq_len: usize,
        hidden_size: usize,
        intermediate_size: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        let n_inter = saturating_u32(seq_len * intermediate_size);
        let n_hidden = saturating_u32(seq_len * hidden_size);

        // Step 1: grad_swiglu = grad_ffn_out @ w_down^T  [S,I]
        gemm_backward_a(
            grad_output,
            &self.w_down,
            &mut self.scratch.grad_swiglu,
            saturating_u32(seq_len),
            saturating_u32(intermediate_size),
            saturating_u32(hidden_size),
            stream,
        )?;

        // Step 2: grad_w_down = swiglu_out^T @ grad_ffn_out  [I,H]
        // (swiglu_out free after this)
        gemm_backward_b(
            &self.scratch.swiglu_out,
            grad_output,
            &mut grad_ws.grad_down,
            saturating_u32(seq_len),
            saturating_u32(intermediate_size),
            saturating_u32(hidden_size),
            stream,
        )?;

        // === SwiGLU backward: swiglu = silu(gate) * up ===

        // Step 3: temp1 = grad_swiglu * up_out → swiglu_out [S,I]
        elementwise_mul_forward(
            &self.scratch.grad_swiglu,
            &self.scratch.up_out,
            &mut self.scratch.swiglu_out,
            n_inter,
            stream,
        )?;

        // Step 4: grad_gate = silu_backward(gate_out, temp1) → up_out [S,I]
        // Computes: (grad_swiglu * up_out) * silu'(gate_out) = correct ∂L/∂gate
        // (up_out's forward value was last needed in step 3, so it can be overwritten here.)
        silu_backward(
            &self.scratch.gate_out,
            &self.scratch.swiglu_out,
            &mut self.scratch.up_out,
            stream,
        )?;
        // up_out now holds grad_gate [S,I]

        // Step 5: silu_gate = silu(gate_out) → swiglu_out [S,I]
        // Recomputed from gate_out rather than stored during forward.
        silu_forward(&self.scratch.gate_out, &mut self.scratch.swiglu_out, n_inter, stream)?;

        // Step 6: grad_up = grad_swiglu * silu_gate → gate_out [S,I]
        // (gate_out's forward value was last needed in step 5.)
        elementwise_mul_forward(
            &self.scratch.grad_swiglu,
            &self.scratch.swiglu_out,
            &mut self.scratch.gate_out,
            n_inter,
            stream,
        )?;
        // gate_out now holds grad_up [S,I]

        // === Weight gradients ===
        // norm2_out must still hold the forward activation here; it is only
        // overwritten in step 10.

        // Step 7a: grad_w_gate = norm2_out^T @ grad_gate (in up_out)  [H,I]
        gemm_backward_b(
            &self.scratch.norm2_out,
            &self.scratch.up_out,
            &mut grad_ws.grad_gate,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // Step 7b: grad_w_up = norm2_out^T @ grad_up (in gate_out)  [H,I]
        gemm_backward_b(
            &self.scratch.norm2_out,
            &self.scratch.gate_out,
            &mut grad_ws.grad_up,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // === Input gradient (accumulate gate + up paths) ===

        // Step 8: grad_norm2_gate = grad_gate @ w_gate^T → ffn_out [S,H]
        gemm_backward_a(
            &self.scratch.up_out,
            &self.w_gate,
            &mut self.scratch.ffn_out,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // Step 9: grad_norm2_up = grad_up @ w_up^T → grad_hidden [S,H]
        gemm_backward_a(
            &self.scratch.gate_out,
            &self.w_up,
            &mut self.scratch.grad_hidden,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // Step 10: norm2_out = grad_norm2_gate + grad_norm2_up  [S,H]
        // (norm2_out's forward value was last needed in step 7b.)
        residual_add_forward(
            &self.scratch.ffn_out,
            &self.scratch.grad_hidden,
            &mut self.scratch.norm2_out,
            n_hidden,
            stream,
        )?;

        Ok(())
    }
1348
1349    /// Backward through post-attention RMSNorm.
1350    fn backward_post_attn_norm(
1351        &mut self,
1352        grad_input: &mut GpuBuffer<f32>,
1353        seq_len: usize,
1354        hidden_size: usize,
1355        eps: f32,
1356        stream: &CudaStream,
1357        grad_ws: &mut CudaGradWorkspace,
1358    ) -> Result<()> {
1359        // D2D copy norm2_out → grad_hidden (avoids D2H + H2D round-trip)
1360        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1361        unsafe {
1362            self.scratch
1363                .grad_hidden
1364                .copy_from_buffer_async(&self.scratch.norm2_out, stream)
1365                .map_err(|e| {
1366                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1367                        "Backward norm D2D copy failed: {e}"
1368                    ))
1369                })?;
1370        }
1371
1372        rms_norm_backward(
1373            &self.scratch.residual1,
1374            &self.post_attn_norm_weight,
1375            &self.scratch.grad_hidden,
1376            grad_input,
1377            &mut grad_ws.grad_post_attn_norm,
1378            saturating_u32(seq_len),
1379            saturating_u32(hidden_size),
1380            eps,
1381            stream,
1382        )
1383    }
1384
1385    /// Backward through multi-head attention (ENT-151b)
1386    ///
1387    /// Reverses the forward attention pipeline:
1388    /// output_proj → layout → attn_weights@V → softmax → scale → Q@K^T → layout → Q/K/V proj
1389    ///
1390    /// # Contract (C-ATTN-BACK-001)
1391    ///
1392    /// - **Precondition**: grad_input contains gradient from post-attention norm backward,
1393    ///   scratch.{q, k, v, attn_scores, attn_out, norm1_out} contain forward pass values
1394    /// - **Postcondition**: grad_hidden contains gradient w.r.t. norm1_out (input to Q/K/V proj),
1395    ///   grad_w_{q,k,v,o} contain weight gradients for attention projections
1396    /// - **Invariant**: Zero CPU-side data transfers; all operations on GPU
1397    fn backward_attention(
1398        &mut self,
1399        grad_input: &mut GpuBuffer<f32>,
1400        seq_len: usize,
1401        stream: &CudaStream,
1402        grad_ws: &mut CudaGradWorkspace,
1403    ) -> Result<()> {
1404        let hidden_size = self.config.hidden_size;
1405        let q_dim = self.config.q_dim();
1406        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
1407        let num_heads = self.config.num_attention_heads;
1408        let num_kv_heads = self.config.num_kv_heads;
1409        let head_dim = self.config.head_dim();
1410        let heads_per_kv = num_heads / num_kv_heads;
1411        let scale = 1.0 / (head_dim as f32).sqrt();
1412
1413        let seq = saturating_u32(seq_len);
1414        let nh = saturating_u32(num_heads);
1415        let nkv = saturating_u32(num_kv_heads);
1416        let hd = saturating_u32(head_dim);
1417
1418        // === Step 4.1: Output projection backward ===
1419        // Forward: o_proj_out[seq,hidden] = attn_out[seq,q_dim] @ w_o[q_dim,hidden]
1420        //   m=seq, k=q_dim, n=hidden
1421        // grad_attn_out[seq,q_dim] = grad_o_proj[seq,hidden] @ w_o^T[hidden,q_dim]
1422        gemm_backward_a(
1423            grad_input,
1424            &self.w_o,
1425            &mut self.scratch.grad_hidden,
1426            seq,
1427            saturating_u32(q_dim),
1428            saturating_u32(hidden_size),
1429            stream,
1430        )?;
1431
1432        // grad_w_o[q_dim,hidden] = attn_out^T[q_dim,seq] @ grad_o_proj[seq,hidden]
1433        gemm_backward_b(
1434            &self.scratch.attn_out,
1435            grad_input,
1436            &mut grad_ws.grad_w_o,
1437            seq,
1438            saturating_u32(q_dim),
1439            saturating_u32(hidden_size),
1440            stream,
1441        )?;
1442
1443        // === Step 4.2: Layout conversion ===
1444        // grad_attn_out [seq, q_dim] → grad_attn_batched [num_heads, seq, head_dim]
1445        // Reuse attn_q_batched for grad_attn_batched
1446        interleaved_to_batched_forward(
1447            &self.scratch.grad_hidden,
1448            &mut self.scratch.attn_q_batched,
1449            seq,
1450            nh,
1451            hd,
1452            stream,
1453        )?;
1454
1455        // === Step 4.3: Backward through attn_weights @ V ===
1456        // Forward was: attn_result = attn_weights @ V_batched
1457        // Reconstruct V_batched from preserved v
1458        interleaved_to_batched_forward(
1459            &self.scratch.v,
1460            &mut self.scratch.attn_kv_temp,
1461            seq,
1462            nkv,
1463            hd,
1464            stream,
1465        )?;
1466
1467        // GQA expand V if needed
1468        if heads_per_kv > 1 {
1469            expand_kv_heads(
1470                &self.scratch.attn_kv_temp,
1471                &mut self.scratch.attn_kv_temp2,
1472                num_kv_heads,
1473                heads_per_kv,
1474                seq_len * head_dim,
1475                stream,
1476            )?;
1477            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1478            unsafe {
1479                self.scratch
1480                    .attn_kv_temp
1481                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1482                    .map_err(|e| {
1483                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1484                            "Attn backward V expand D2D copy failed: {e}"
1485                        ))
1486                    })?;
1487            }
1488        }
1489        // attn_kv_temp now has V_batched [num_heads, seq, head_dim]
1490
1491        // Transpose V: [num_heads, seq, head_dim] → [num_heads, head_dim, seq]
1492        batched_transpose_forward(
1493            &self.scratch.attn_kv_temp,
1494            &mut self.scratch.attn_kv_temp2,
1495            nh,
1496            seq,
1497            hd,
1498            stream,
1499        )?;
1500        // attn_kv_temp2 = V^T [num_heads, head_dim, seq]
1501
1502        // grad_attn_weights = grad_attn_batched @ V^T → grad_attn_scores [H, seq, seq]
1503        batched_4d_gemm_forward(
1504            &self.scratch.attn_q_batched,
1505            &self.scratch.attn_kv_temp2,
1506            &mut self.scratch.grad_attn_scores,
1507            1,
1508            nh,
1509            seq,
1510            seq,
1511            hd,
1512            stream,
1513        )?;
1514
1515        // grad_V = attn_weights^T @ grad_attn_batched → attn_kv_temp [H, seq, hd]
1516        //
1517        // BUG FIX: Cannot transpose attn_scores [H,S,S] into attn_kv_temp2 [H,S,hd]
1518        // because H*S*S >> H*S*hd when S > hd (e.g. 350M: 4.2M vs 524K = 8× overflow).
1519        //
1520        // Use identity: grad_V = (grad_attn_batched^T @ attn_scores)^T
1521        // All intermediates are [H, hd, S] = [H, S, hd] size — no H*S*S buffer needed.
1522
1523        // Step A: transpose grad_attn_batched [H,S,hd] → [H,hd,S]
1524        batched_transpose_forward(
1525            &self.scratch.attn_q_batched,   // grad_attn_batched [H, S, hd]
1526            &mut self.scratch.attn_kv_temp, // temp: grad_attn_batched^T [H, hd, S]
1527            nh,
1528            seq,
1529            hd,
1530            stream,
1531        )?;
1532
1533        // Step B: GEMM [H,hd,S] @ [H,S,S] → [H,hd,S] (= grad_V^T)
1534        batched_4d_gemm_forward(
1535            &self.scratch.attn_kv_temp,      // grad_attn_batched^T [H, hd, S]
1536            &self.scratch.attn_scores,       // attn_weights [H, S, S]
1537            &mut self.scratch.attn_kv_temp2, // grad_V^T [H, hd, S]
1538            1,
1539            nh,
1540            hd,  // m
1541            seq, // n
1542            seq, // k
1543            stream,
1544        )?;
1545
1546        // Step C: transpose grad_V^T [H,hd,S] → grad_V [H,S,hd]
1547        batched_transpose_forward(
1548            &self.scratch.attn_kv_temp2,    // grad_V^T [H, hd, S]
1549            &mut self.scratch.attn_kv_temp, // grad_V [H, S, hd]
1550            nh,
1551            hd,
1552            seq,
1553            stream,
1554        )?;
1555        // attn_kv_temp = grad_V [num_heads, seq, head_dim]
1556
1557        // === Step 4.4: Softmax backward ===
1558        // attn_scores contains softmax output from forward pass
1559        // In-place: grad_attn_scores is both input (grad_output) and output (grad_input)
1560        // This is safe because the kernel reads all elements in pass 1 before writing in pass 2.
1561        let total_rows = nh * seq;
1562        {
1563            // SAFETY: In-place aliasing is safe for BatchedSoftmaxBackwardKernel which uses
1564            // a two-pass approach: pass 1 reads all y[i]*gy[i] to compute dot product,
1565            // pass 2 writes grad_x[i]. The view is forgotten to prevent double-free.
1566            let grad_scores_view = unsafe {
1567                GpuBuffer::<f32>::from_raw_parts(
1568                    self.scratch.grad_attn_scores.as_ptr(),
1569                    self.scratch.grad_attn_scores.len(),
1570                )
1571            };
1572            batched_softmax_backward(
1573                &self.scratch.attn_scores,
1574                &grad_scores_view,
1575                &mut self.scratch.grad_attn_scores,
1576                total_rows,
1577                seq,
1578                stream,
1579            )?;
1580            leak(grad_scores_view);
1581        }
1582        // grad_attn_scores now contains gradient through softmax
1583
1584        // === Step 4.5: Scale backward ===
1585        // Forward scaled by 1/√d_k, backward is same scale (linear operation)
1586        let total_scores = nh * seq * seq;
1587        {
1588            // SAFETY: In-place aliasing safe for element-wise scale (independent elements).
1589            let scores_view = unsafe {
1590                GpuBuffer::<f32>::from_raw_parts(
1591                    self.scratch.grad_attn_scores.as_ptr(),
1592                    self.scratch.grad_attn_scores.len(),
1593                )
1594            };
1595            scale_forward(
1596                &scores_view,
1597                &mut self.scratch.grad_attn_scores,
1598                scale,
1599                total_scores,
1600                stream,
1601            )?;
1602            leak(scores_view);
1603        }
1604
1605        // === Step 4.6: Backward through Q @ K^T ===
1606        // Forward was: scores = Q_batched @ K^T
1607        // Reconstruct K_batched and expand for GQA
1608        interleaved_to_batched_forward(
1609            &self.scratch.k,
1610            &mut self.scratch.attn_kv_temp2,
1611            seq,
1612            nkv,
1613            hd,
1614            stream,
1615        )?;
1616
1617        if heads_per_kv > 1 {
1618            // SAFETY: Both buffers are valid GPU allocations; attn_q_batched is about to be
1619            // overwritten anyway.
1620            unsafe {
1621                self.scratch
1622                    .attn_q_batched
1623                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1624                    .map_err(|e| {
1625                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1626                            "Attn backward K copy for GQA expand failed: {e}"
1627                        ))
1628                    })?;
1629            }
1630            expand_kv_heads(
1631                &self.scratch.attn_q_batched,
1632                &mut self.scratch.attn_kv_temp2,
1633                num_kv_heads,
1634                heads_per_kv,
1635                seq_len * head_dim,
1636                stream,
1637            )?;
1638        }
1639        // attn_kv_temp2 = K_expanded [num_heads, seq, head_dim]
1640
1641        // grad_Q = grad_scores @ K_expanded → attn_q_batched [H, seq, hd]
1642        batched_4d_gemm_forward(
1643            &self.scratch.grad_attn_scores,
1644            &self.scratch.attn_kv_temp2,
1645            &mut self.scratch.attn_q_batched,
1646            1,
1647            nh,
1648            seq,
1649            hd,
1650            seq,
1651            stream,
1652        )?;
1653
1654        // grad_K^T = Q^T @ grad_scores
1655        // First reconstruct Q_batched from preserved q
1656        // Reconstruct Q_batched into o_proj_out (attn_q_batched already overwritten by grad_Q).
1657        interleaved_to_batched_forward(
1658            &self.scratch.q,
1659            &mut self.scratch.o_proj_out, // temp buffer for Q_batched
1660            seq,
1661            nh,
1662            hd,
1663            stream,
1664        )?;
1665
1666        // Transpose Q: [H, seq, hd] → [H, hd, seq]
1667        batched_transpose_forward(
1668            &self.scratch.o_proj_out,
1669            &mut self.scratch.attn_kv_temp2, // reuse for Q^T
1670            nh,
1671            seq,
1672            hd,
1673            stream,
1674        )?;
1675
1676        // grad_K^T = Q^T @ grad_scores → ffn_out as temp [H, hd, seq]
1677        batched_4d_gemm_forward(
1678            &self.scratch.attn_kv_temp2,
1679            &self.scratch.grad_attn_scores,
1680            &mut self.scratch.ffn_out, // reuse as temp for grad_K^T [H, hd, seq]
1681            1,
1682            nh,
1683            hd,
1684            seq,
1685            seq,
1686            stream,
1687        )?;
1688
1689        // Transpose grad_K^T → grad_K: [H, hd, seq] → [H, seq, hd]
1690        batched_transpose_forward(
1691            &self.scratch.ffn_out,
1692            &mut self.scratch.attn_kv_temp2, // grad_K [H, seq, hd]
1693            nh,
1694            hd,
1695            seq,
1696            stream,
1697        )?;
1698
1699        // === Step 4.7: GQA gradient reduction ===
1700        // grad_K and grad_V are in [num_heads, seq, hd], need to reduce to [num_kv_heads, seq, hd]
1701        if heads_per_kv > 1 {
1702            self.reduce_gqa_gradients(num_kv_heads, heads_per_kv, seq_len, head_dim, stream)?;
1703        }
1704
1705        // === Step 4.8: Convert gradients back to interleaved layout ===
1706        // grad_Q: attn_q_batched [H, seq, hd] → o_proj_out [seq, hidden] (interleaved)
1707        batched_to_interleaved_forward(
1708            &self.scratch.attn_q_batched,
1709            &mut self.scratch.o_proj_out,
1710            seq,
1711            nh,
1712            hd,
1713            stream,
1714        )?;
1715
1716        // grad_K: attn_kv_temp2 [nkv, seq, hd] → norm2_out [seq, kv_hidden] (interleaved)
1717        batched_to_interleaved_forward(
1718            &self.scratch.attn_kv_temp2,
1719            &mut self.scratch.norm2_out,
1720            seq,
1721            nkv,
1722            hd,
1723            stream,
1724        )?;
1725
1726        // grad_V: attn_kv_temp [nkv, seq, hd] → ffn_out [seq, kv_hidden] (interleaved)
1727        batched_to_interleaved_forward(
1728            &self.scratch.attn_kv_temp,
1729            &mut self.scratch.ffn_out,
1730            seq,
1731            nkv,
1732            hd,
1733            stream,
1734        )?;
1735
1736        // === Step 4.8b: RoPE backward (inverse rotation) ===
1737        // Forward applied RoPE to Q and K before attention. Backward must undo
1738        // the rotation so projection backward (step 4.9) gets unrotated gradients.
1739        // R^T(-θ): new_x0 = x0*cos + x1*sin, new_x1 = x1*cos - x0*sin
1740        let rope_theta = self.config.rope_theta;
1741        {
1742            // grad_Q in o_proj_out [seq, q_dim] — apply inverse rotation in-place
1743            let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.o_proj_out)) };
1744            batched_rope_neox_backward(
1745                q_ref,
1746                &mut self.scratch.o_proj_out,
1747                &self.scratch.rope_positions,
1748                nh,
1749                hd,
1750                seq,
1751                rope_theta,
1752                stream,
1753            )?;
1754            // grad_K in norm2_out [seq, kv_hidden] — apply inverse rotation in-place
1755            let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.norm2_out)) };
1756            batched_rope_neox_backward(
1757                k_ref,
1758                &mut self.scratch.norm2_out,
1759                &self.scratch.rope_positions,
1760                nkv,
1761                hd,
1762                seq,
1763                rope_theta,
1764                stream,
1765            )?;
1766        }
1767
1768        // === Step 4.9: Q/K/V projection backward ===
1769        // Forward: q[seq,q_dim] = norm1[seq,hidden] @ w_q[hidden,q_dim]
1770        //   m=seq, k=hidden, n=q_dim
1771        // grad_norm1[seq,hidden] = grad_q[seq,q_dim] @ w_q^T[q_dim,hidden]
1772        gemm_backward_a(
1773            &self.scratch.o_proj_out, // grad_q interleaved [seq, q_dim]
1774            &self.w_q,
1775            &mut self.scratch.grad_hidden,
1776            seq,
1777            saturating_u32(hidden_size),
1778            saturating_u32(q_dim),
1779            stream,
1780        )?;
1781
1782        // grad_norm1 += grad_k @ w_k^T
1783        // Forward: k[seq,kv_hidden] = norm1[seq,hidden] @ w_k[hidden,kv_hidden]
1784        //   m=seq, k=hidden, n=kv_hidden
1785        // KAIZEN-057: cuda_add_inplace replaces residual_add_forward + D2D copy
1786        gemm_backward_a(
1787            &self.scratch.norm2_out, // grad_k interleaved
1788            &self.w_k,
1789            &mut self.scratch.grad_attn_scores, // temp for grad_k @ w_k^T
1790            seq,
1791            saturating_u32(hidden_size),
1792            saturating_u32(kv_hidden_size),
1793            stream,
1794        )?;
1795        cuda_add_inplace(
1796            &mut self.scratch.grad_hidden,
1797            &self.scratch.grad_attn_scores,
1798            seq_len * hidden_size,
1799            stream,
1800        )?;
1801
1802        // grad_norm1 += grad_v @ w_v^T
1803        // Forward: v[seq,kv_hidden] = norm1[seq,hidden] @ w_v[hidden,kv_hidden]
1804        //   m=seq, k=hidden, n=kv_hidden
1805        gemm_backward_a(
1806            &self.scratch.ffn_out, // grad_v interleaved
1807            &self.w_v,
1808            &mut self.scratch.grad_attn_scores, // temp for grad_v @ w_v^T
1809            seq,
1810            saturating_u32(hidden_size),
1811            saturating_u32(kv_hidden_size),
1812            stream,
1813        )?;
1814        cuda_add_inplace(
1815            &mut self.scratch.grad_hidden,
1816            &self.scratch.grad_attn_scores,
1817            seq_len * hidden_size,
1818            stream,
1819        )?;
1820
1821        // Weight gradients: grad_w_q[hidden,q_dim] = norm1_out^T[hidden,seq] @ grad_q[seq,q_dim]
1822        gemm_backward_b(
1823            &self.scratch.norm1_out,
1824            &self.scratch.o_proj_out, // grad_q [seq, q_dim]
1825            &mut grad_ws.grad_w_q,
1826            seq,
1827            saturating_u32(hidden_size),
1828            saturating_u32(q_dim),
1829            stream,
1830        )?;
1831
1832        // grad_w_k = norm1_out^T @ grad_k
1833        gemm_backward_b(
1834            &self.scratch.norm1_out,
1835            &self.scratch.norm2_out, // grad_k
1836            &mut grad_ws.grad_w_k,
1837            seq,
1838            saturating_u32(hidden_size),
1839            saturating_u32(kv_hidden_size),
1840            stream,
1841        )?;
1842
1843        // grad_w_v = norm1_out^T @ grad_v
1844        gemm_backward_b(
1845            &self.scratch.norm1_out,
1846            &self.scratch.ffn_out, // grad_v
1847            &mut grad_ws.grad_w_v,
1848            seq,
1849            saturating_u32(hidden_size),
1850            saturating_u32(kv_hidden_size),
1851            stream,
1852        )?;
1853
1854        // Copy grad_hidden → grad_input for downstream (residual backward)
1855        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1856        unsafe {
1857            grad_input.copy_from_buffer_async(&self.scratch.grad_hidden, stream).map_err(|e| {
1858                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1859                    "Attn backward grad_hidden → grad_input D2D copy failed: {e}"
1860                ))
1861            })?;
1862        }
1863
1864        Ok(())
1865    }
1866
1867    /// Reduce GQA head gradients from [num_heads] to [num_kv_heads] by summing groups.
1868    ///
1869    /// Reads grad_K from `attn_kv_temp2` and grad_V from `attn_kv_temp` (both [H, seq, hd]).
1870    /// Writes reduced grad_K to `attn_kv_temp2` and reduced grad_V to `attn_kv_temp`
1871    /// (both [nkv, seq, hd]).
1872    ///
1873    /// Uses `grad_attn_scores`, `ffn_out`, `o_proj_out`, `grad_hidden` as scratch.
1874    fn reduce_gqa_gradients(
1875        &mut self,
1876        num_kv_heads: usize,
1877        heads_per_kv: usize,
1878        seq_len: usize,
1879        head_dim: usize,
1880        stream: &CudaStream,
1881    ) -> Result<()> {
1882        let elems_per_head = seq_len * head_dim;
1883
1884        // Reduce grad_K: attn_kv_temp2 [H] → grad_attn_scores [nkv]
1885        self.reduce_single_gqa_gradient(true, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1886
1887        // Reduce grad_V: attn_kv_temp [H] → ffn_out [nkv]
1888        self.reduce_single_gqa_gradient(false, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1889
1890        // Copy reduced results to known locations for step 4.8
1891        let kv_elems = num_kv_heads * elems_per_head;
1892        // SAFETY: Valid GPU allocations with sufficient size.
1893        unsafe {
1894            self.scratch
1895                .attn_kv_temp2
1896                .copy_from_buffer_at_async(&self.scratch.grad_attn_scores, 0, 0, kv_elems, stream)
1897                .map_err(|e| {
1898                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1899                        "GQA grad_K reduced final copy failed: {e}"
1900                    ))
1901                })?;
1902            self.scratch
1903                .attn_kv_temp
1904                .copy_from_buffer_at_async(&self.scratch.ffn_out, 0, 0, kv_elems, stream)
1905                .map_err(|e| {
1906                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1907                        "GQA grad_V reduced final copy failed: {e}"
1908                    ))
1909                })?;
1910        }
1911        Ok(())
1912    }
1913
1914    /// Reduce one gradient tensor from [num_heads] to [num_kv_heads] by summing groups.
1915    ///
1916    /// When `is_k=true`: reads from `attn_kv_temp2`, writes to `grad_attn_scores`.
1917    /// When `is_k=false`: reads from `attn_kv_temp`, writes to `ffn_out`.
1918    /// Uses `o_proj_out` and `grad_hidden` as scratch.
1919    fn reduce_single_gqa_gradient(
1920        &mut self,
1921        is_k: bool,
1922        num_kv_heads: usize,
1923        heads_per_kv: usize,
1924        elems_per_head: usize,
1925        stream: &CudaStream,
1926    ) -> Result<()> {
1927        let label = if is_k { "K" } else { "V" };
1928
1929        for kv_h in 0..num_kv_heads {
1930            let dst_offset = kv_h * elems_per_head;
1931            let first_h = kv_h * heads_per_kv;
1932            let src_offset = first_h * elems_per_head;
1933
1934            // Copy first head of group as base
1935            // SAFETY: All offsets are within buffer bounds.
1936            unsafe {
1937                let (dst, src) = if is_k {
1938                    (&mut self.scratch.grad_attn_scores, &self.scratch.attn_kv_temp2)
1939                } else {
1940                    (&mut self.scratch.ffn_out, &self.scratch.attn_kv_temp)
1941                };
1942                dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
1943                    .map_err(|e| {
1944                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1945                            "GQA grad_{label} reduce base copy failed: {e}"
1946                        ))
1947                    })?;
1948            }
1949
1950            // Add remaining heads in group
1951            for rep in 1..heads_per_kv {
1952                let h = kv_h * heads_per_kv + rep;
1953                let h_offset = h * elems_per_head;
1954
1955                // Head extraction into o_proj_out buffer
1956                // SAFETY: Valid GPU allocations with sufficient size.
1957                unsafe {
1958                    let src =
1959                        if is_k { &self.scratch.attn_kv_temp2 } else { &self.scratch.attn_kv_temp };
1960                    self.scratch
1961                        .o_proj_out
1962                        .copy_from_buffer_at_async(src, 0, h_offset, elems_per_head, stream)
1963                        .map_err(|e| {
1964                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1965                                "GQA grad_{label} reduce head copy failed: {e}"
1966                            ))
1967                        })?;
1968                }
1969
1970                // Add: dst[dst_offset..] += o_proj_out[0..elems_per_head]
1971                // SAFETY: Creating non-owning views for arithmetic; forgotten to prevent double-free.
1972                unsafe {
1973                    let dst_buf =
1974                        if is_k { &self.scratch.grad_attn_scores } else { &self.scratch.ffn_out };
1975                    let dst_view = GpuBuffer::<f32>::from_raw_parts(
1976                        dst_buf.as_ptr() + (dst_offset as u64 * 4),
1977                        elems_per_head,
1978                    );
1979                    let src_view = GpuBuffer::<f32>::from_raw_parts(
1980                        self.scratch.o_proj_out.as_ptr(),
1981                        elems_per_head,
1982                    );
1983                    let mut sum_view = GpuBuffer::<f32>::from_raw_parts(
1984                        self.scratch.grad_hidden.as_ptr(),
1985                        elems_per_head,
1986                    );
1987                    residual_add_forward(
1988                        &dst_view,
1989                        &src_view,
1990                        &mut sum_view,
1991                        saturating_u32(elems_per_head),
1992                        stream,
1993                    )?;
1994                    // Copy sum back to dst at dst_offset
1995                    let dst_buf = if is_k {
1996                        &mut self.scratch.grad_attn_scores
1997                    } else {
1998                        &mut self.scratch.ffn_out
1999                    };
2000                    dst_buf
2001                        .copy_from_buffer_at_async(
2002                            &self.scratch.grad_hidden,
2003                            dst_offset,
2004                            0,
2005                            elems_per_head,
2006                            stream,
2007                        )
2008                        .map_err(|e| {
2009                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2010                                "GQA grad_{label} reduce sum copy failed: {e}"
2011                            ))
2012                        })?;
2013                    leak(dst_view);
2014                    leak(src_view);
2015                    leak(sum_view);
2016                }
2017            }
2018        }
2019        Ok(())
2020    }
2021
2022    /// Backward through first residual connection and input RMSNorm.
2023    fn backward_residual_and_input_norm(
2024        &mut self,
2025        input: &GpuBuffer<f32>,
2026        grad_output: &GpuBuffer<f32>,
2027        grad_input: &mut GpuBuffer<f32>,
2028        seq_len: usize,
2029        hidden_size: usize,
2030        eps: f32,
2031        stream: &CudaStream,
2032        grad_ws: &mut CudaGradWorkspace,
2033    ) -> Result<()> {
2034        // Forward was: norm_out = RMSNorm(input); attn_out = Attn(norm_out); residual1 = input + attn_out
2035        // Backward: grad_input = RMSNorm_backward(grad_through_attention) + grad_output
2036        //
2037        // C-RESIDUAL-001 / entrenar#313: The residual skip (grad_output) must be added
2038        // AFTER RMSNorm backward, not before. RMSNorm backward should only transform
2039        // the gradient that flows through the norm/attention path. The identity skip
2040        // bypasses the norm entirely.
2041
2042        // D2D copy grad_input (attention path gradient) to grad_hidden
2043        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
2044        unsafe {
2045            self.scratch.grad_hidden.copy_from_buffer_async(grad_input, stream).map_err(|e| {
2046                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2047                    "Backward residual grad_hidden D2D copy failed: {e}"
2048                ))
2049            })?;
2050        }
2051
2052        // RMSNorm backward: only applied to the attention path gradient
2053        rms_norm_backward(
2054            input,
2055            &self.input_norm_weight,
2056            &self.scratch.grad_hidden,
2057            grad_input,
2058            &mut grad_ws.grad_input_norm,
2059            saturating_u32(seq_len),
2060            saturating_u32(hidden_size),
2061            eps,
2062            stream,
2063        )?;
2064
2065        // NOW add the residual skip: grad_input += grad_output (identity connection)
2066        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)
2067    }
2068
2069    /// Initialize GPU-resident AdamW optimizer state for all block weights.
2070    ///
2071    /// Allocates zero-initialized first and second moment buffers for each of the
2072    /// 9 weight tensors (4 attention projections + 3 FFN projections + 2 RMSNorm).
2073    ///
2074    /// # Contract (C-OPTINIT-001)
2075    ///
2076    /// - **Precondition**: CUDA context is valid, sufficient GPU memory available
2077    /// - **Postcondition**: All m/v buffers are zero-initialized with dimensions
2078    ///   matching the corresponding weight tensors
2079    /// - **Invariant**: Total GPU memory for optimizer state = 2 × sum(weight_sizes) × 4 bytes
2080    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
2081        let hidden = self.config.hidden_size;
2082        let q_dim = self.config.q_dim();
2083        let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
2084        let intermediate = self.config.intermediate_size;
2085
2086        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
2087        // zero memory (cuMemAlloc returns uninitialized VRAM). Uninitialized m/v
2088        // causes v_new = beta2 * GARBAGE which can be negative → sqrt(neg) → NaN.
2089        let z = |n: usize| -> Result<GpuBuffer<f32>> {
2090            Ok(GpuBuffer::from_host(&self.ctx, &vec![0.0f32; n])?)
2091        };
2092        Ok(GpuBlockOptimizerState {
2093            m_w_q: z(q_dim * hidden)?,
2094            v_w_q: z(q_dim * hidden)?,
2095            m_w_k: z(hidden * kv_hidden)?,
2096            v_w_k: z(hidden * kv_hidden)?,
2097            m_w_v: z(hidden * kv_hidden)?,
2098            v_w_v: z(hidden * kv_hidden)?,
2099            m_w_o: z(hidden * q_dim)?,
2100            v_w_o: z(hidden * q_dim)?,
2101            m_w_gate: z(hidden * intermediate)?,
2102            v_w_gate: z(hidden * intermediate)?,
2103            m_w_up: z(hidden * intermediate)?,
2104            v_w_up: z(hidden * intermediate)?,
2105            m_w_down: z(intermediate * hidden)?,
2106            v_w_down: z(intermediate * hidden)?,
2107            m_input_norm: z(hidden)?,
2108            v_input_norm: z(hidden)?,
2109            m_post_attn_norm: z(hidden)?,
2110            v_post_attn_norm: z(hidden)?,
2111        })
2112    }
2113
2114    /// Run GPU-resident AdamW optimizer step on all block weights.
2115    ///
2116    /// Updates weights in-place using gradients computed by `backward()`.
2117    /// All operations run on GPU — zero CPU↔GPU data transfers.
2118    ///
2119    /// # Contract (C-OPTSTEP-001)
2120    ///
2121    /// - **Precondition**: `backward()` completed for this block (scratch grad buffers valid),
2122    ///   `state` initialized via `init_optimizer_state()`, `step > 0`
2123    /// - **Postcondition**: All 9 weight tensors updated by AdamW rule,
2124    ///   m/v states updated with current gradient statistics
2125    /// - **Invariant**: Weight dimensions unchanged; no GPU memory allocated or freed
2126    pub fn optimizer_step(
2127        &mut self,
2128        state: &mut GpuBlockOptimizerState,
2129        step: u32,
2130        lr: f32,
2131        beta1: f32,
2132        beta2: f32,
2133        eps: f32,
2134        weight_decay: f32,
2135        stream: &CudaStream,
2136        grad_ws: &CudaGradWorkspace,
2137    ) -> Result<()> {
2138        debug_assert!(step > 0, "C-OPTSTEP-001: step must be > 0 for bias adjust");
2139
2140        // Pre-capture lengths to avoid borrow conflicts (len is immutable borrow,
2141        // adamw_step_cuda takes mutable borrow on same buffer)
2142        let n_wq = self.w_q.len() as u32;
2143        let n_wk = self.w_k.len() as u32;
2144        let n_wv = self.w_v.len() as u32;
2145        let n_wo = self.w_o.len() as u32;
2146        let n_gate = self.w_gate.len() as u32;
2147        let n_up = self.w_up.len() as u32;
2148        let n_down = self.w_down.len() as u32;
2149        let n_inorm = self.input_norm_weight.len() as u32;
2150        let n_panorm = self.post_attn_norm_weight.len() as u32;
2151
2152        // Attention projection weights
2153        adamw_step_cuda(
2154            &mut self.w_q,
2155            &grad_ws.grad_w_q,
2156            &mut state.m_w_q,
2157            &mut state.v_w_q,
2158            lr,
2159            beta1,
2160            beta2,
2161            eps,
2162            weight_decay,
2163            step,
2164            n_wq,
2165            stream,
2166        )?;
2167        adamw_step_cuda(
2168            &mut self.w_k,
2169            &grad_ws.grad_w_k,
2170            &mut state.m_w_k,
2171            &mut state.v_w_k,
2172            lr,
2173            beta1,
2174            beta2,
2175            eps,
2176            weight_decay,
2177            step,
2178            n_wk,
2179            stream,
2180        )?;
2181        adamw_step_cuda(
2182            &mut self.w_v,
2183            &grad_ws.grad_w_v,
2184            &mut state.m_w_v,
2185            &mut state.v_w_v,
2186            lr,
2187            beta1,
2188            beta2,
2189            eps,
2190            weight_decay,
2191            step,
2192            n_wv,
2193            stream,
2194        )?;
2195        adamw_step_cuda(
2196            &mut self.w_o,
2197            &grad_ws.grad_w_o,
2198            &mut state.m_w_o,
2199            &mut state.v_w_o,
2200            lr,
2201            beta1,
2202            beta2,
2203            eps,
2204            weight_decay,
2205            step,
2206            n_wo,
2207            stream,
2208        )?;
2209
2210        // FFN projection weights
2211        adamw_step_cuda(
2212            &mut self.w_gate,
2213            &grad_ws.grad_gate,
2214            &mut state.m_w_gate,
2215            &mut state.v_w_gate,
2216            lr,
2217            beta1,
2218            beta2,
2219            eps,
2220            weight_decay,
2221            step,
2222            n_gate,
2223            stream,
2224        )?;
2225        adamw_step_cuda(
2226            &mut self.w_up,
2227            &grad_ws.grad_up,
2228            &mut state.m_w_up,
2229            &mut state.v_w_up,
2230            lr,
2231            beta1,
2232            beta2,
2233            eps,
2234            weight_decay,
2235            step,
2236            n_up,
2237            stream,
2238        )?;
2239        adamw_step_cuda(
2240            &mut self.w_down,
2241            &grad_ws.grad_down,
2242            &mut state.m_w_down,
2243            &mut state.v_w_down,
2244            lr,
2245            beta1,
2246            beta2,
2247            eps,
2248            weight_decay,
2249            step,
2250            n_down,
2251            stream,
2252        )?;
2253
2254        // RMSNorm weights
2255        adamw_step_cuda(
2256            &mut self.input_norm_weight,
2257            &grad_ws.grad_input_norm,
2258            &mut state.m_input_norm,
2259            &mut state.v_input_norm,
2260            lr,
2261            beta1,
2262            beta2,
2263            eps,
2264            weight_decay,
2265            step,
2266            n_inorm,
2267            stream,
2268        )?;
2269        adamw_step_cuda(
2270            &mut self.post_attn_norm_weight,
2271            &grad_ws.grad_post_attn_norm,
2272            &mut state.m_post_attn_norm,
2273            &mut state.v_post_attn_norm,
2274            lr,
2275            beta1,
2276            beta2,
2277            eps,
2278            weight_decay,
2279            step,
2280            n_panorm,
2281            stream,
2282        )?;
2283
2284        Ok(())
2285    }
2286
2287    /// Download all weight data from GPU to host vectors.
2288    ///
2289    /// Used to synchronize GPU-updated weights back to CPU model for checkpointing.
2290    ///
2291    /// # Contract (C-DLWEIGHTS-001)
2292    ///
2293    /// - **Precondition**: Block weights are valid GPU allocations
2294    /// - **Postcondition**: Returned vectors have exact same length and content as GPU buffers
2295    /// - **Invariant**: GPU buffers are not modified
2296    pub fn download_weights(&self) -> Result<BlockWeights> {
2297        let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
2298            let mut host = vec![0.0f32; buf.len()];
2299            buf.copy_to_host(&mut host).map_err(|e| {
2300                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2301                    "Weight download failed: {e}"
2302                ))
2303            })?;
2304            Ok(host)
2305        };
2306
2307        Ok(BlockWeights {
2308            w_q: download(&self.w_q)?,
2309            w_k: download(&self.w_k)?,
2310            w_v: download(&self.w_v)?,
2311            w_o: download(&self.w_o)?,
2312            w_gate: download(&self.w_gate)?,
2313            w_up: download(&self.w_up)?,
2314            w_down: download(&self.w_down)?,
2315            input_norm_weight: download(&self.input_norm_weight)?,
2316            post_attn_norm_weight: download(&self.post_attn_norm_weight)?,
2317        })
2318    }
2319}
2320
/// Host-side copy of every weight buffer of a CUDA transformer block.
///
/// Produced by `CudaTransformerBlock::download_weights` (used for checkpointing);
/// each vector mirrors one GPU buffer element-for-element.
///
/// # Contract (C-BLOCKWT-001)
///
/// - **Invariant**: Vector lengths match original weight dimensions
#[cfg(feature = "cuda")]
pub struct BlockWeights {
    /// Contents of the block's `w_q` buffer.
    pub w_q: Vec<f32>,
    /// Contents of the block's `w_k` buffer.
    pub w_k: Vec<f32>,
    /// Contents of the block's `w_v` buffer.
    pub w_v: Vec<f32>,
    /// Contents of the block's `w_o` buffer.
    pub w_o: Vec<f32>,
    /// Contents of the block's `w_gate` buffer.
    pub w_gate: Vec<f32>,
    /// Contents of the block's `w_up` buffer.
    pub w_up: Vec<f32>,
    /// Contents of the block's `w_down` buffer.
    pub w_down: Vec<f32>,
    /// Contents of the pre-attention RMSNorm weight buffer.
    pub input_norm_weight: Vec<f32>,
    /// Contents of the post-attention RMSNorm weight buffer.
    pub post_attn_norm_weight: Vec<f32>,
}
2338
/// Element-wise sum of `a` and `b` into `output` over `n` elements, computed on the GPU.
///
/// Issues a single `ResidualAddKernel` launch on `stream`; nothing is copied between
/// host and device. The element count is narrowed to `u32` via `saturating_u32`.
#[cfg(feature = "cuda")]
fn cuda_add(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let element_count = saturating_u32(n);
    residual_add_forward(a, b, output, element_count, stream)
}
2352
/// In-place add: `target += source` using residual add with aliased output.
///
/// Launches the `ResidualAdd` kernel with `target` passed both as the first input and
/// as the output, so callers do not need a separate output buffer.
///
/// # Safety
///
/// The ResidualAdd kernel reads `a[i]` and `b[i]` then writes `output[i] = a[i] + b[i]`.
/// When `a` and `output` alias the same GPU buffer, each element is read before written
/// (no inter-element dependency), so this is safe for elementwise operations.
///
/// NOTE(review): the argument above only covers *device*-memory aliasing. On the host,
/// the body materializes a `&GpuBuffer<f32>` to `*target` and passes it alongside
/// `target` itself (a `&mut`) in the same call. A live shared reference coexisting with
/// a used `&mut` to the same object is likely undefined behavior under Rust's aliasing
/// model (Stacked/Tree Borrows), regardless of what the kernel does. A dedicated in-place
/// kernel entry point (or one taking raw device pointers) would remove the need for this
/// cast. TODO: confirm whether such a variant of `residual_add_forward` exists.
#[cfg(feature = "cuda")]
pub(crate) fn cuda_add_inplace(
    target: &mut GpuBuffer<f32>,
    source: &GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    // SAFETY: ResidualAdd kernel is elementwise (output[i] = a[i] + b[i]).
    // Aliasing target as both input and output is safe because each element is
    // independent — the GPU reads a[i] before writing output[i] at the same address.
    // NOTE(review): this reasoning is about GPU memory only; the simultaneous `&` and
    // `&mut` to the host-side `GpuBuffer` created here is a separate Rust aliasing concern.
    let target_ref: &GpuBuffer<f32> = unsafe { &*std::ptr::from_ref::<GpuBuffer<f32>>(target) };
    residual_add_forward(target_ref, source, target, saturating_u32(n), stream)
}
2373
/// Element-wise product of `a` and `b` into `output` over `n` elements, computed on the GPU.
///
/// Issues a single `ElementwiseMulKernel` launch on `stream` with no host/device copies;
/// the element count is narrowed to `u32` via `saturating_u32`.
#[cfg(feature = "cuda")]
fn cuda_mul(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    use crate::autograd::cuda_forward::elementwise_mul_forward;

    let element_count = saturating_u32(n);
    elementwise_mul_forward(a, b, output, element_count, stream)
}
2387
/// Placeholder for the CUDA transformer block in builds without the `cuda` feature.
///
/// It carries no state; it only exists so code naming `CudaTransformerBlock`
/// still compiles in CPU-only builds.
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerBlock;

#[cfg(not(feature = "cuda"))]
impl CudaTransformerBlock {
    /// Layer index of the placeholder block. Always `0`, since no real layer backs it.
    pub fn layer_idx(&self) -> usize {
        usize::MIN
    }
}
2398
2399// =============================================================================
2400// CudaBlock — enum dispatching fp32 or NF4 transformer blocks
2401// =============================================================================
2402
/// A CUDA transformer block holding either full-precision or NF4-quantized weights.
///
/// The classify pipeline keeps a `Vec<CudaBlock>` and drives `forward()` on every
/// element uniformly, independent of how the frozen weights are stored.
#[cfg(feature = "cuda")]
pub enum CudaBlock {
    /// fp32 weights (full precision; roughly 16 GB for Qwen3-4B).
    Fp32(CudaTransformerBlock),
    /// NF4-quantized weights (roughly 2 GB for Qwen3-4B, ~8x smaller).
    Nf4(CudaNf4TransformerBlock),
}
2414
#[cfg(feature = "cuda")]
impl CudaBlock {
    /// Builds the error returned when an operation is invoked on a block variant that
    /// cannot perform it, or with arguments that violate its contract.
    fn kernel_error(message: &str) -> crate::autograd::cuda_tensor::CudaTensorError {
        crate::autograd::cuda_tensor::CudaTensorError::KernelError(message.into())
    }

    /// Forward pass through the transformer block.
    ///
    /// For NF4 blocks, `shared_scratch` must be `Some` — shared across all layers (C-SCRATCH-001).
    /// For fp32 blocks, `shared_scratch` is ignored (each block owns its scratch for backward).
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` if an NF4 block is called with `shared_scratch == None`,
    /// or propagates any error from the underlying block's forward pass.
    pub(crate) fn forward(
        &mut self,
        input: &GpuBuffer<f32>,
        output: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: Option<&mut CudaBlockScratch>,
    ) -> Result<()> {
        match self {
            CudaBlock::Fp32(b) => b.forward(input, output, seq_len, stream),
            CudaBlock::Nf4(b) => {
                // Report the contract violation as an error (like every other
                // unsupported-call path here) instead of panicking mid-training.
                let scratch = shared_scratch.ok_or_else(|| {
                    Self::kernel_error("C-SCRATCH-001: NF4 blocks require shared scratch")
                })?;
                b.forward(input, output, seq_len, stream, scratch)
            }
        }
    }

    /// Get the layer index of this block.
    pub fn layer_idx(&self) -> usize {
        match self {
            CudaBlock::Fp32(b) => b.layer_idx(),
            CudaBlock::Nf4(b) => b.layer_idx,
        }
    }

    /// Backward pass (only supported for fp32 blocks).
    ///
    /// NF4 blocks are frozen — backward is never called when `quantize_nf4` is active
    /// because `gpu_training` is set to `None`. NF4 blocks use [`CudaBlock::backward_nf4`]
    /// for LoRA gradients instead.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for NF4 blocks.
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            CudaBlock::Fp32(b) => {
                b.backward(input, grad_output, grad_input, seq_len, stream, grad_ws)
            }
            CudaBlock::Nf4(_) => Err(Self::kernel_error(
                "backward not supported on NF4 blocks (frozen weights)",
            )),
        }
    }

    /// Initialize optimizer state (only supported for fp32 blocks).
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for NF4 blocks.
    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
        match self {
            CudaBlock::Fp32(b) => b.init_optimizer_state(),
            CudaBlock::Nf4(_) => {
                Err(Self::kernel_error("init_optimizer_state not supported on NF4 blocks"))
            }
        }
    }

    /// Download weights from GPU (only supported for fp32 blocks).
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for NF4 blocks.
    pub fn download_weights(&self) -> Result<BlockWeights> {
        match self {
            CudaBlock::Fp32(b) => b.download_weights(),
            CudaBlock::Nf4(_) => {
                Err(Self::kernel_error("download_weights not supported on NF4 blocks"))
            }
        }
    }

    /// Optimizer step (only supported for fp32 blocks).
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for NF4 blocks.
    #[allow(clippy::too_many_arguments)]
    pub fn optimizer_step(
        &mut self,
        state: &mut GpuBlockOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_ws: &CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            CudaBlock::Fp32(b) => {
                b.optimizer_step(state, step, lr, beta1, beta2, eps, weight_decay, stream, grad_ws)
            }
            CudaBlock::Nf4(_) => Err(Self::kernel_error(
                "optimizer_step not supported on NF4 blocks (frozen weights)",
            )),
        }
    }

    /// NF4 backward pass with LoRA gradient computation (ENT-153).
    ///
    /// Only callable on NF4 blocks. Returns error for fp32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn backward_nf4(
        &self,
        layer_input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        output_scratch: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: &mut CudaBlockScratch,
        grad_lora: &mut CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            CudaBlock::Nf4(b) => b.backward(
                layer_input,
                grad_output,
                grad_input,
                output_scratch,
                seq_len,
                stream,
                shared_scratch,
                grad_lora,
            ),
            CudaBlock::Fp32(_) => {
                Err(Self::kernel_error("backward_nf4 only supported on NF4 blocks"))
            }
        }
    }

    /// Initialize LoRA optimizer state for NF4 blocks.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for fp32 blocks.
    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
        match self {
            CudaBlock::Nf4(b) => b.init_lora_optimizer_state(),
            CudaBlock::Fp32(_) => {
                Err(Self::kernel_error("init_lora_optimizer_state only supported on NF4 blocks"))
            }
        }
    }

    /// LoRA optimizer step for NF4 blocks.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for fp32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn lora_optimizer_step(
        &mut self,
        state: &mut GpuLoraOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_lora: &CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            CudaBlock::Nf4(b) => b.lora_optimizer_step(
                state,
                step,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                stream,
                grad_lora,
            ),
            CudaBlock::Fp32(_) => {
                Err(Self::kernel_error("lora_optimizer_step only supported on NF4 blocks"))
            }
        }
    }

    /// Download LoRA weights from NF4 blocks.
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for fp32 blocks.
    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
        match self {
            CudaBlock::Nf4(b) => b.download_lora_weights(),
            CudaBlock::Fp32(_) => {
                Err(Self::kernel_error("download_lora_weights only supported on NF4 blocks"))
            }
        }
    }

    /// Upload LoRA weights to NF4 blocks for checkpoint resume (ENT-276).
    ///
    /// # Errors
    ///
    /// Returns a `KernelError` for fp32 blocks.
    pub fn upload_lora_weights(
        &mut self,
        a_q: &[f32],
        b_q: &[f32],
        a_v: &[f32],
        b_v: &[f32],
    ) -> Result<()> {
        match self {
            CudaBlock::Nf4(b) => b.upload_lora_weights(a_q, b_q, a_v, b_v),
            CudaBlock::Fp32(_) => {
                Err(Self::kernel_error("upload_lora_weights only supported on NF4 blocks"))
            }
        }
    }
}
2613
2614/// CPU fallback stub for CudaBlock.
2615#[cfg(not(feature = "cuda"))]
2616pub enum CudaBlock {
2617    Fp32(CudaTransformerBlock),
2618}
2619
2620// =============================================================================
2621// NF4 Quantized Transformer Block (trueno#108: QLoRA support)
2622// =============================================================================
2623
/// CUDA-accelerated transformer block with NF4-quantized frozen weights.
///
/// Stores the 7 projection weights as packed NF4 (4-bit) data + per-block scales.
/// Norm weights remain fp32 (negligible size).
///
/// # VRAM (Qwen3-4B example)
///
/// | Component | fp32 | NF4 |
/// |-----------|------|-----|
/// | Frozen weights (36L × 7 projections) | 16.0 GB | 2.1 GB |
///
/// The NF4 figure covers only the packed data and scales. `CudaNf4TransformerBlock::new`
/// additionally keeps a dequantized, transposed fp32 copy of every projection (ENT-287)
/// for the cuBLAS GEMM path. `set_fp16_weights` replaces those copies with fp16 buffers
/// and shrinks the fp32 fields to 1-element placeholders (PMAT-470/472).
///
/// # Training
///
/// The quantized base weights are frozen. Optional fp32 LoRA adapters on the Q and V
/// projections (ENT-153) carry the trainable parameters. A backward pass exists
/// (reached via `CudaBlock::backward_nf4`) that propagates input gradients and fills
/// a LoRA gradient workspace.
///
/// # GEMM backends
///
/// The forward pass picks its projection GEMM from environment variables:
/// - fp16 tensor-core GEMM when `FP16_GEMM=1` and fp16 weights are present;
/// - fused NF4 dequant+GEMM when `NF4_TC_GEMM=1` or `NF4_FUSED_GEMM=1`;
/// - otherwise the default path, cuBLAS on the pre-dequantized fp32 weights.
#[cfg(feature = "cuda")]
pub struct CudaNf4TransformerBlock {
    config: TransformerConfig,
    layer_idx: usize,
    // Norm weights stay fp32 (tiny: 2 × hidden_size floats)
    input_norm_weight: GpuBuffer<f32>,
    post_attn_norm_weight: GpuBuffer<f32>,
    // Projection weights: NF4 quantized (packed data + per-block scales)
    w_q_nf4: GpuBuffer<u8>,
    w_q_scales: GpuBuffer<f32>,
    w_k_nf4: GpuBuffer<u8>,
    w_k_scales: GpuBuffer<f32>,
    w_v_nf4: GpuBuffer<u8>,
    w_v_scales: GpuBuffer<f32>,
    w_o_nf4: GpuBuffer<u8>,
    w_o_scales: GpuBuffer<f32>,
    w_gate_nf4: GpuBuffer<u8>,
    w_gate_scales: GpuBuffer<f32>,
    w_up_nf4: GpuBuffer<u8>,
    w_up_scales: GpuBuffer<f32>,
    w_down_nf4: GpuBuffer<u8>,
    w_down_scales: GpuBuffer<f32>,
    // ENT-287: Pre-dequantized fp32 weights for cuBLAS GEMM, transposed to [in, out].
    // Replaced by 1-element placeholders once set_fp16_weights() has run.
    w_q_fp32: GpuBuffer<f32>,
    w_k_fp32: GpuBuffer<f32>,
    w_v_fp32: GpuBuffer<f32>,
    w_o_fp32: GpuBuffer<f32>,
    w_gate_fp32: GpuBuffer<f32>,
    w_up_fp32: GpuBuffer<f32>,
    w_down_fp32: GpuBuffer<f32>,
    // LoRA adapters for Q and V projections (ENT-153: QLoRA backward)
    // None when LoRA is not active (inference-only or non-QLoRA training)
    lora_a_q: Option<GpuBuffer<f32>>, // [hidden_size, rank]
    lora_b_q: Option<GpuBuffer<f32>>, // [rank, q_dim]
    lora_a_v: Option<GpuBuffer<f32>>, // [hidden_size, rank]
    lora_b_v: Option<GpuBuffer<f32>>, // [rank, kv_hidden]
    lora_scale: f32,
    lora_rank: usize,
    // QK-norm weights (ENT-270: per-head RMSNorm on Q and K, shape=[head_dim])
    q_norm_weight: Option<GpuBuffer<f32>>,
    k_norm_weight: Option<GpuBuffer<f32>>,
    // FP16 weight buffers for Tier 2 parity (PMAT-470): halve memory BW
    // When set, forward uses gemm_f16_to_f32 (fp16 weights × fp16 activations → fp32 output)
    w_q_fp16: Option<GpuBuffer<u16>>,
    w_k_fp16: Option<GpuBuffer<u16>>,
    w_v_fp16: Option<GpuBuffer<u16>>,
    w_o_fp16: Option<GpuBuffer<u16>>,
    w_gate_fp16: Option<GpuBuffer<u16>>,
    w_up_fp16: Option<GpuBuffer<u16>>,
    w_down_fp16: Option<GpuBuffer<u16>>,
    ctx: Arc<CudaContext>,
    // NF4 blocks do NOT own scratch — shared across all layers (C-SCRATCH-001)
}
2693
2694#[cfg(feature = "cuda")]
2695impl CudaNf4TransformerBlock {
2696    /// Create a new NF4 transformer block from fp32 CPU tensors.
2697    ///
2698    /// Quantizes all 7 projection weights to NF4 on CPU, then uploads the packed
2699    /// data and scales to GPU. Norm weights are uploaded as fp32.
2700    #[allow(clippy::too_many_arguments)]
2701    pub fn new(
2702        config: &TransformerConfig,
2703        layer_idx: usize,
2704        ctx: Arc<CudaContext>,
2705        input_norm_weight: &[f32],
2706        post_attn_norm_weight: &[f32],
2707        w_q: &[f32],
2708        w_k: &[f32],
2709        w_v: &[f32],
2710        w_o: &[f32],
2711        w_gate: &[f32],
2712        w_up: &[f32],
2713        w_down: &[f32],
2714        _max_seq_len: usize, // NF4 blocks use shared scratch (C-SCRATCH-001)
2715        // ENT-153: Optional LoRA adapters for Q and V projections
2716        q_lora: Option<(&[f32], &[f32])>,
2717        v_lora: Option<(&[f32], &[f32])>,
2718        lora_scale: f32,
2719        lora_rank: usize,
2720        // ENT-270: Optional QK-norm weights (per-head RMSNorm, shape=[head_dim])
2721        q_norm: Option<&[f32]>,
2722        k_norm: Option<&[f32]>,
2723    ) -> Result<Self> {
2724        use trueno_gpu::kernels::{quantize_nf4, NF4_BLOCK_SIZE};
2725
2726        let hidden_size = config.hidden_size;
2727        let q_dim = config.q_dim(); // num_heads * head_dim (may differ from hidden_size)
2728        let kv_hidden_size = config.num_kv_heads * config.head_dim();
2729        let intermediate_size = config.intermediate_size;
2730
2731        // ── C-NF4SHAPE-001: Weight shape contracts ──────────────────────
2732        // Ground truth: PMAT-331 validation in attention.rs from_pretrained()
2733        //   Q: [q_dim, hidden], K: [kv_hidden, hidden], V: [kv_hidden, hidden], O: [hidden, q_dim]
2734        //   gate: [intermediate, hidden], up: [intermediate, hidden], down: [hidden, intermediate]
2735        assert_eq!(
2736            w_q.len(),
2737            q_dim * hidden_size,
2738            "C-NF4SHAPE-001: w_q expected {}, got {} (q_dim={q_dim}, hidden={hidden_size})",
2739            q_dim * hidden_size,
2740            w_q.len()
2741        );
2742        assert_eq!(
2743            w_k.len(),
2744            kv_hidden_size * hidden_size,
2745            "C-NF4SHAPE-001: w_k expected {}, got {}",
2746            kv_hidden_size * hidden_size,
2747            w_k.len()
2748        );
2749        assert_eq!(
2750            w_v.len(),
2751            kv_hidden_size * hidden_size,
2752            "C-NF4SHAPE-001: w_v expected {}, got {}",
2753            kv_hidden_size * hidden_size,
2754            w_v.len()
2755        );
2756        assert_eq!(
2757            w_o.len(),
2758            hidden_size * q_dim,
2759            "C-NF4SHAPE-001: w_o expected {}, got {}",
2760            hidden_size * q_dim,
2761            w_o.len()
2762        );
2763        assert_eq!(
2764            w_gate.len(),
2765            intermediate_size * hidden_size,
2766            "C-NF4SHAPE-001: w_gate expected {}, got {}",
2767            intermediate_size * hidden_size,
2768            w_gate.len()
2769        );
2770        assert_eq!(
2771            w_up.len(),
2772            intermediate_size * hidden_size,
2773            "C-NF4SHAPE-001: w_up expected {}, got {}",
2774            intermediate_size * hidden_size,
2775            w_up.len()
2776        );
2777        assert_eq!(
2778            w_down.len(),
2779            hidden_size * intermediate_size,
2780            "C-NF4SHAPE-001: w_down expected {}, got {}",
2781            hidden_size * intermediate_size,
2782            w_down.len()
2783        );
2784
2785        // Upload norm weights as fp32
2786        let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
2787        let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
2788
2789        // Helper: quantize fp32 weight to NF4, upload packed data + scales to GPU
2790        // Returns (gpu_nf4, gpu_scales, cpu_quantized) — the CPU struct is retained
2791        // for dequantization into the cuBLAS fp32 buffer.
2792        let quantize_and_upload = |weights: &[f32],
2793                                   total: usize|
2794         -> Result<(
2795            GpuBuffer<u8>,
2796            GpuBuffer<f32>,
2797            trueno_gpu::kernels::Nf4Quantized,
2798        )> {
2799            assert_eq!(weights.len(), total, "weight length mismatch");
2800            assert!(
2801                total.is_multiple_of(NF4_BLOCK_SIZE),
2802                "weight count {total} not divisible by NF4 block size {NF4_BLOCK_SIZE}"
2803            );
2804
2805            let q = quantize_nf4(weights, total / NF4_BLOCK_SIZE, NF4_BLOCK_SIZE);
2806            let nf4_buf = GpuBuffer::from_host(&ctx, &q.data)?;
2807            let scales_buf = GpuBuffer::from_host(&ctx, &q.scales)?;
2808            Ok((nf4_buf, scales_buf, q))
2809        };
2810
2811        // Quantize all 7 projection weights (shape contracts already verified above)
2812        let (w_q_nf4, w_q_scales, w_q_nf4_q) = quantize_and_upload(w_q, q_dim * hidden_size)?;
2813        let (w_k_nf4, w_k_scales, w_k_nf4_q) =
2814            quantize_and_upload(w_k, kv_hidden_size * hidden_size)?;
2815        let (w_v_nf4, w_v_scales, w_v_nf4_q) =
2816            quantize_and_upload(w_v, kv_hidden_size * hidden_size)?;
2817        let (w_o_nf4, w_o_scales, w_o_nf4_q) = quantize_and_upload(w_o, hidden_size * q_dim)?;
2818        let (w_gate_nf4, w_gate_scales, w_gate_nf4_q) =
2819            quantize_and_upload(w_gate, intermediate_size * hidden_size)?;
2820        let (w_up_nf4, w_up_scales, w_up_nf4_q) =
2821            quantize_and_upload(w_up, intermediate_size * hidden_size)?;
2822        let (w_down_nf4, w_down_scales, w_down_nf4_q) =
2823            quantize_and_upload(w_down, hidden_size * intermediate_size)?;
2824
2825        // ENT-287: Dequantize NF4 weights and TRANSPOSE to [K,N] for standard cuBLAS GEMM.
2826        //
2827        // bitsandbytes does: F.dequantize_4bit(B).to(dtype).t()
2828        // The .t() transposes [N,K] → [K,N] so torch.nn.functional.linear
2829        // can use standard C = A @ B^T internally.
2830        //
2831        // We replicate this exactly:
2832        // 1. dequantize_nf4(q) → flat [N*K] in [N,K] row-major order
2833        // 2. CPU transpose [N,K] → [K,N]
2834        // 3. Upload transposed buffer to GPU
2835        // 4. Use standard gemm_forward (NoTrans, NoTrans) — same as LoRA weights
2836        use trueno_gpu::kernels::dequantize_nf4;
2837        let dequant_transpose_upload = |q: &trueno_gpu::kernels::Nf4Quantized,
2838                                        n: usize,
2839                                        k: usize|
2840         -> std::result::Result<
2841            GpuBuffer<f32>,
2842            crate::autograd::cuda_tensor::CudaTensorError,
2843        > {
2844            let deq = dequantize_nf4(q); // [N*K] in [N,K] row-major
2845            let nonzero = deq.iter().filter(|&&x| x != 0.0).count();
2846            eprintln!(
2847                "[TRACE] dequant n={n} k={k} len={} nonzero={nonzero} first5={:?}",
2848                deq.len(),
2849                &deq[..5.min(deq.len())]
2850            );
2851            assert_eq!(deq.len(), n * k, "dequant size mismatch: {} vs {}x{}", deq.len(), n, k);
2852            // Transpose [N,K] → [K,N]
2853            let mut transposed = vec![0.0f32; n * k];
2854            for row in 0..n {
2855                for col in 0..k {
2856                    transposed[col * n + row] = deq[row * k + col];
2857                }
2858            }
2859            let buf = GpuBuffer::from_host(&ctx, &transposed).map_err(|e| {
2860                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2861                    "dequant transpose upload: {e:?}"
2862                ))
2863            })?;
2864            // Verify upload: must read FULL buffer then slice
2865            let mut verify_full = vec![0.0f32; buf.len()];
2866            let verify_ok = buf.copy_to_host(&mut verify_full).is_ok();
2867            let verify5: Vec<f32> = verify_full.iter().copied().take(5).collect();
2868            let nz = verify_full.iter().filter(|&&x| x != 0.0).count();
2869            eprintln!("[TRACE] uploaded ptr={:?} len={} copy_ok={verify_ok} nonzero={nz} verify[:5]={verify5:?}", buf.as_ptr(), buf.len());
2870            Ok(buf)
2871        };
2872        // Each weight is [out_features, in_features] = [N, K].
2873        // After transpose: [K, N] — standard cuBLAS B layout.
2874        let w_q_fp32 = dequant_transpose_upload(&w_q_nf4_q, q_dim, hidden_size)?;
2875        let w_k_fp32 = dequant_transpose_upload(&w_k_nf4_q, kv_hidden_size, hidden_size)?;
2876        let w_v_fp32 = dequant_transpose_upload(&w_v_nf4_q, kv_hidden_size, hidden_size)?;
2877        let w_o_fp32 = dequant_transpose_upload(&w_o_nf4_q, hidden_size, q_dim)?;
2878        let w_gate_fp32 = dequant_transpose_upload(&w_gate_nf4_q, intermediate_size, hidden_size)?;
2879        let w_up_fp32 = dequant_transpose_upload(&w_up_nf4_q, intermediate_size, hidden_size)?;
2880        let w_down_fp32 = dequant_transpose_upload(&w_down_nf4_q, hidden_size, intermediate_size)?;
2881
2882        // NF4 blocks do NOT allocate scratch — shared across all layers (C-SCRATCH-001).
2883        // Pipeline allocates one CudaBlockScratch and passes &mut to each forward() call.
2884        // Saves (L-1) * 214 MB = 7.5 GB for Qwen3-4B (36 layers).
2885
2886        // Upload LoRA adapters to GPU (ENT-153)
2887        // B matrices are pre-scaled by lora_scale to avoid a separate scale kernel in forward.
2888        let (lora_a_q, lora_b_q) = match q_lora {
2889            Some((a_data, b_data)) => {
2890                let a = GpuBuffer::from_host(&ctx, a_data)?;
2891                let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2892                let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2893                (Some(a), Some(b))
2894            }
2895            None => (None, None),
2896        };
2897        let (lora_a_v, lora_b_v) = match v_lora {
2898            Some((a_data, b_data)) => {
2899                let a = GpuBuffer::from_host(&ctx, a_data)?;
2900                let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2901                let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2902                (Some(a), Some(b))
2903            }
2904            None => (None, None),
2905        };
2906
2907        // ENT-270: Upload QK-norm weights if present
2908        let q_norm_weight = match q_norm {
2909            Some(w) => {
2910                assert_eq!(
2911                    w.len(),
2912                    config.head_dim(),
2913                    "ENT-270: q_norm weight expected [head_dim={}], got [{}]",
2914                    config.head_dim(),
2915                    w.len()
2916                );
2917                Some(GpuBuffer::from_host(&ctx, w)?)
2918            }
2919            None => None,
2920        };
2921        let k_norm_weight = match k_norm {
2922            Some(w) => {
2923                assert_eq!(
2924                    w.len(),
2925                    config.head_dim(),
2926                    "ENT-270: k_norm weight expected [head_dim={}], got [{}]",
2927                    config.head_dim(),
2928                    w.len()
2929                );
2930                Some(GpuBuffer::from_host(&ctx, w)?)
2931            }
2932            None => None,
2933        };
2934
2935        Ok(Self {
2936            config: config.clone(),
2937            layer_idx,
2938            input_norm_weight,
2939            post_attn_norm_weight,
2940            w_q_nf4,
2941            w_q_scales,
2942            w_k_nf4,
2943            w_k_scales,
2944            w_v_nf4,
2945            w_v_scales,
2946            w_o_nf4,
2947            w_o_scales,
2948            w_gate_nf4,
2949            w_gate_scales,
2950            w_up_nf4,
2951            w_up_scales,
2952            w_down_nf4,
2953            w_down_scales,
2954            w_q_fp32,
2955            w_k_fp32,
2956            w_v_fp32,
2957            w_o_fp32,
2958            w_gate_fp32,
2959            w_up_fp32,
2960            w_down_fp32,
2961            lora_a_q,
2962            lora_b_q,
2963            lora_a_v,
2964            lora_b_v,
2965            lora_scale,
2966            lora_rank,
2967            q_norm_weight,
2968            k_norm_weight,
2969            // FP16 weights: None by default, populated by set_fp16_weights() (PMAT-470)
2970            w_q_fp16: None,
2971            w_k_fp16: None,
2972            w_v_fp16: None,
2973            w_o_fp16: None,
2974            w_gate_fp16: None,
2975            w_up_fp16: None,
2976            w_down_fp16: None,
2977            ctx,
2978        })
2979    }
2980
2981    /// Cast fp32→fp16 weights + drop fp32 (PMAT-470/472). Frees ~2.6 GB VRAM.
2982    pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()> {
2983        let cast_weight = |w_fp32: &GpuBuffer<f32>, ctx: &CudaContext| -> Result<GpuBuffer<u16>> {
2984            let n = w_fp32.len();
2985            let mut w_fp16 = GpuBuffer::<u16>::new(ctx, n)?;
2986            cast_f32_to_f16_gpu(w_fp32, &mut w_fp16, n as u32, stream)?;
2987            Ok(w_fp16)
2988        };
2989
2990        self.w_q_fp16 = Some(cast_weight(&self.w_q_fp32, &self.ctx)?);
2991        self.w_k_fp16 = Some(cast_weight(&self.w_k_fp32, &self.ctx)?);
2992        self.w_v_fp16 = Some(cast_weight(&self.w_v_fp32, &self.ctx)?);
2993        self.w_o_fp16 = Some(cast_weight(&self.w_o_fp32, &self.ctx)?);
2994        self.w_gate_fp16 = Some(cast_weight(&self.w_gate_fp32, &self.ctx)?);
2995        self.w_up_fp16 = Some(cast_weight(&self.w_up_fp32, &self.ctx)?);
2996        self.w_down_fp16 = Some(cast_weight(&self.w_down_fp32, &self.ctx)?);
2997
2998        stream.synchronize().map_err(|e| {
2999            crate::autograd::cuda_tensor::CudaTensorError::KernelError(format!(
3000                "FP16 weight cast sync failed: {e:?}"
3001            ))
3002        })?;
3003        // PMAT-472: Drop fp32 weights — backward now uses fp16 via gemm_backward_a_fp16_dispatch.
3004        // Frees ~2.6 GB VRAM on yoga 8GB, allowing GPU embeddings to fit.
3005        let dummy = |ctx: &CudaContext| GpuBuffer::<f32>::new(ctx, 1).unwrap();
3006        self.w_q_fp32 = dummy(&self.ctx);
3007        self.w_k_fp32 = dummy(&self.ctx);
3008        self.w_v_fp32 = dummy(&self.ctx);
3009        self.w_o_fp32 = dummy(&self.ctx);
3010        self.w_gate_fp32 = dummy(&self.ctx);
3011        self.w_up_fp32 = dummy(&self.ctx);
3012        self.w_down_fp32 = dummy(&self.ctx);
3013        eprintln!("[FP16] Weights cast + fp32 dropped (~2.6 GB freed)");
3014
3015        Ok(())
3016    }
3017
3018    /// Forward pass: cuBLAS GEMM with pre-dequantized weights (ENT-287, C-SCRATCH-001).
3019    #[rustfmt::skip]
3020    pub(crate) fn forward(
3021        &self,
3022        input: &GpuBuffer<f32>,
3023        output: &mut GpuBuffer<f32>,
3024        seq_len: usize,
3025        stream: &CudaStream,
3026        scratch: &mut CudaBlockScratch,
3027    ) -> Result<()> {
3028        use crate::autograd::cuda_forward::{gemm_forward, gemm_nf4_forward, gemm_nf4_tc_forward};
3029
3030        let hidden_size = self.config.hidden_size;
3031        let q_dim = self.config.q_dim();
3032        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3033        let intermediate_size = self.config.intermediate_size;
3034
3035        // entrenar#318: scratch zeroing moved to forward_cuda_training (once per step, not per layer)
3036        scratch.prepare_causal_mask(seq_len, &self.ctx)?;
3037
3038        // === Pre-attention RMSNorm === (PMAT-483: per-op profiling)
3039        let _t = scratch.op_begin();
3040        rms_norm_forward(
3041            input,
3042            &self.input_norm_weight,
3043            &mut scratch.norm1_out,
3044            saturating_u32(seq_len),
3045            saturating_u32(hidden_size),
3046            stream,
3047        )?;
3048        scratch.op_end(_t, OP_RMSNORM_ATTN);
3049
3050        // === Q, K, V Projections ===
3051        // Backend selection:
3052        //   FP16_GEMM=1: fp16 tensor core GEMM (Tier 2 parity, 2x BW savings)
3053        //   NF4_FUSED_GEMM=1: fused dequant+GEMM (8x less BW, 100% GPU, but naive PTX)
3054        //   Default: cuBLAS fp32 (197 tok/s, 7% GPU, memory-BW bound)
3055        static USE_NF4_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3056        let nf4_gemm = *USE_NF4_GEMM.get_or_init(|| std::env::var("NF4_FUSED_GEMM").as_deref() == Ok("1"));
3057        static USE_NF4_TC_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3058        let nf4_tc_gemm = *USE_NF4_TC_GEMM.get_or_init(|| std::env::var("NF4_TC_GEMM").as_deref() == Ok("1"));
3059        static USE_FP16_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3060        let fp16_gemm = *USE_FP16_GEMM.get_or_init(|| std::env::var("FP16_GEMM").as_deref() == Ok("1"));
3061
3062        // FP16 path: cast activation once, reuse for all projections
3063        let act_n = (seq_len * hidden_size) as u32;
3064        if fp16_gemm && self.w_q_fp16.is_some() {
3065            // Lazy-allocate fp16 activation buffer
3066            if scratch.norm1_out_f16.is_none() {
3067                scratch.norm1_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3068            }
3069            let f16_buf = scratch.norm1_out_f16.as_mut().unwrap();
3070            cast_f32_to_f16_gpu(&scratch.norm1_out, f16_buf, act_n, stream)?;
3071        }
3072
3073        let _t = scratch.op_begin(); // QKV GEMM timing
3074        if fp16_gemm && self.w_q_fp16.is_some() {
3075            let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3076            gemm_f16_to_f32_forward(f16_act, self.w_q_fp16.as_ref().unwrap(), &mut scratch.q,
3077                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3078        } else if nf4_tc_gemm {
3079            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3080                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3081        } else if nf4_gemm {
3082            gemm_nf4_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3083                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3084        } else {
3085            gemm_forward(&scratch.norm1_out, &self.w_q_fp32, &mut scratch.q,
3086                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3087        }
3088
3089        // ENT-153: Q LoRA: q += (norm1_out @ A_q) @ B_q  (B_q pre-scaled by lora_scale)
3090        if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
3091            let s = saturating_u32(seq_len);
3092            let h = saturating_u32(hidden_size);
3093            let r = saturating_u32(self.lora_rank);
3094            let qd = saturating_u32(q_dim);
3095            // lora_inter[seq, rank] = norm1_out[seq, hidden] @ A_q[hidden, rank]
3096            gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
3097            // lora_temp[seq, q_dim] = lora_inter[seq, rank] @ B_q[rank, q_dim]
3098            gemm_forward(&scratch.lora_inter, b_q, &mut scratch.lora_temp, s, r, qd, stream)?;
3099            // q += lora_temp (in-place add)
3100            cuda_add_inplace(&mut scratch.q, &scratch.lora_temp, seq_len * q_dim, stream)?;
3101        }
3102
3103        if fp16_gemm && self.w_k_fp16.is_some() {
3104            let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3105            gemm_f16_to_f32_forward(f16_act, self.w_k_fp16.as_ref().unwrap(), &mut scratch.k,
3106                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3107            gemm_f16_to_f32_forward(f16_act, self.w_v_fp16.as_ref().unwrap(), &mut scratch.v,
3108                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3109        } else if nf4_tc_gemm {
3110            // PMAT-481: NF4 tensor core GEMM for K and V projections (separate)
3111            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_k_nf4, &self.w_k_scales, &mut scratch.k,
3112                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3113            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_v_nf4, &self.w_v_scales, &mut scratch.v,
3114                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3115        } else if nf4_gemm {
3116            // PMAT-478: Fused K+V — shared input load (same pattern as Gate+Up)
3117            crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3118                &scratch.norm1_out,
3119                &self.w_k_nf4, &self.w_k_scales,
3120                &self.w_v_nf4, &self.w_v_scales,
3121                &mut scratch.k, &mut scratch.v,
3122                saturating_u32(seq_len), saturating_u32(hidden_size),
3123                saturating_u32(kv_hidden_size), stream,
3124            )?;
3125        } else {
3126            gemm_forward(&scratch.norm1_out, &self.w_k_fp32, &mut scratch.k,
3127                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3128            gemm_forward(&scratch.norm1_out, &self.w_v_fp32, &mut scratch.v,
3129                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3130        }
3131
3132        scratch.op_end(_t, OP_QKV_GEMM); // End QKV timing (includes Q/K/V GEMMs + Q LoRA)
3133
3134        // ENT-153: V LoRA: v += (norm1_out @ A_v) @ B_v  (B_v pre-scaled by lora_scale)
3135        if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
3136            let s = saturating_u32(seq_len);
3137            let h = saturating_u32(hidden_size);
3138            let r = saturating_u32(self.lora_rank);
3139            let vd = saturating_u32(kv_hidden_size);
3140            // lora_inter[seq, rank] = norm1_out[seq, hidden] @ A_v[hidden, rank]
3141            gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
3142            // lora_temp[seq, kv_hidden] = lora_inter[seq, rank] @ B_v[rank, kv_hidden]
3143            gemm_forward(&scratch.lora_inter, b_v, &mut scratch.lora_temp, s, r, vd, stream)?;
3144            // v += lora_temp (in-place add)
3145            cuda_add_inplace(&mut scratch.v, &scratch.lora_temp, seq_len * kv_hidden_size, stream)?;
3146        }
3147
3148        // === Multi-Head Attention (GPU-only, zero CPU transfers) ===
3149        let _t = scratch.op_begin();
3150        self.compute_attention_cuda(seq_len, stream, scratch)?;
3151        scratch.op_end(_t, OP_ATTENTION);
3152
3153        // === Output Projection ===
3154        let _t = scratch.op_begin();
3155        if fp16_gemm && self.w_o_fp16.is_some() {
3156            if scratch.attn_out_f16.is_none() {
3157                scratch.attn_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * q_dim)?);
3158            }
3159            let f16_buf = scratch.attn_out_f16.as_mut().unwrap();
3160            cast_f32_to_f16_gpu(&scratch.attn_out, f16_buf, (seq_len * q_dim) as u32, stream)?;
3161            gemm_f16_to_f32_forward(f16_buf, self.w_o_fp16.as_ref().unwrap(), &mut scratch.o_proj_out,
3162                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3163        } else if nf4_tc_gemm {
3164            gemm_nf4_tc_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3165                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3166        } else if nf4_gemm {
3167            gemm_nf4_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3168                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3169        } else {
3170            gemm_forward(&scratch.attn_out, &self.w_o_fp32, &mut scratch.o_proj_out,
3171                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3172        }
3173
3174        scratch.op_end(_t, OP_O_PROJ);
3175
3176        // === Fused Residual Add + RMSNorm (entrenar#321: eliminates NaN cascade) ===
3177        let _t = scratch.op_begin();
3178        // The separate cuda_add + rms_norm_forward allows activation explosion between
3179        // the two operations. Fusing them prevents NaN in layers 24-27 because RMSNorm
3180        // normalizes the residual sum immediately, before it can propagate.
3181        fused_residual_rmsnorm_forward(
3182            input,
3183            &scratch.o_proj_out,
3184            &mut scratch.residual1,
3185            &mut scratch.norm2_out,
3186            &self.post_attn_norm_weight,
3187            saturating_u32(seq_len),
3188            saturating_u32(hidden_size),
3189            stream,
3190        )?;
3191
3192        scratch.op_end(_t, OP_RMSNORM_FFN); // Fused residual + RMSNorm
3193
3194        // === FFN: Gate + Up + SwiGLU + Down ===
3195        let _t = scratch.op_begin(); // Gate+Up GEMM timing
3196        if fp16_gemm && self.w_gate_fp16.is_some() {
3197            if scratch.norm2_out_f16.is_none() {
3198                scratch.norm2_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3199            }
3200            let f16_buf = scratch.norm2_out_f16.as_mut().unwrap();
3201            cast_f32_to_f16_gpu(&scratch.norm2_out, f16_buf, (seq_len * hidden_size) as u32, stream)?;
3202            gemm_f16_to_f32_forward(f16_buf, self.w_gate_fp16.as_ref().unwrap(), &mut scratch.gate_out,
3203                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3204            gemm_f16_to_f32_forward(f16_buf, self.w_up_fp16.as_ref().unwrap(), &mut scratch.up_out,
3205                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3206        } else if nf4_tc_gemm {
3207            // PMAT-481: NF4 tensor core GEMM for Gate and Up projections (separate)
3208            gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_gate_nf4, &self.w_gate_scales, &mut scratch.gate_out,
3209                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3210            gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_up_nf4, &self.w_up_scales, &mut scratch.up_out,
3211                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3212        } else if nf4_gemm {
3213            // PMAT-475: Fused gate+up — shared input load, saves M×K×4 bytes DRAM
3214            crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3215                &scratch.norm2_out,
3216                &self.w_gate_nf4, &self.w_gate_scales,
3217                &self.w_up_nf4, &self.w_up_scales,
3218                &mut scratch.gate_out, &mut scratch.up_out,
3219                saturating_u32(seq_len), saturating_u32(hidden_size),
3220                saturating_u32(intermediate_size), stream,
3221            )?;
3222        } else {
3223            gemm_forward(&scratch.norm2_out, &self.w_gate_fp32, &mut scratch.gate_out,
3224                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3225            gemm_forward(&scratch.norm2_out, &self.w_up_fp32, &mut scratch.up_out,
3226                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3227        }
3228
3229        scratch.op_end(_t, OP_GATE_UP_GEMM);
3230
3231        // === FFN: Fused SwiGLU ===
3232        let _t = scratch.op_begin();
3233        fused_swiglu_forward(&scratch.gate_out, &scratch.up_out, &mut scratch.swiglu_out,
3234            saturating_u32(seq_len * intermediate_size), stream)?;
3235        scratch.op_end(_t, OP_SILU);
3236
3237        // === FFN: Down Projection ===
3238        let _t = scratch.op_begin();
3239        if fp16_gemm && self.w_down_fp16.is_some() {
3240            if scratch.swiglu_out_f16.is_none() {
3241                scratch.swiglu_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * intermediate_size)?);
3242            }
3243            let f16_buf = scratch.swiglu_out_f16.as_mut().unwrap();
3244            cast_f32_to_f16_gpu(&scratch.swiglu_out, f16_buf, (seq_len * intermediate_size) as u32, stream)?;
3245            gemm_f16_to_f32_forward(f16_buf, self.w_down_fp16.as_ref().unwrap(), &mut scratch.ffn_out,
3246                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3247        } else if nf4_tc_gemm {
3248            gemm_nf4_tc_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3249                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3250        } else if nf4_gemm {
3251            gemm_nf4_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3252                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3253        } else {
3254            gemm_forward(&scratch.swiglu_out, &self.w_down_fp32, &mut scratch.ffn_out,
3255                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3256        }
3257
3258        scratch.op_end(_t, OP_DOWN_GEMM);
3259
3260        // === Final Residual Add ===
3261        cuda_add(&scratch.residual1, &scratch.ffn_out, output, seq_len * hidden_size, stream)?;
3262
3263        Ok(())
3264    }
3265
3266    /// Layer index accessor.
3267    pub fn layer_idx(&self) -> usize {
3268        self.layer_idx
3269    }
3270}
3271
/// Helper: delegate attention computation using shared scratch buffers.
///
/// `CudaNf4TransformerBlock` reuses the same attention pipeline as the fp32 block
/// since attention operates on fp32 activations (Q/K/V are already dequantized by GEMM).
#[cfg(feature = "cuda")]
impl CudaNf4TransformerBlock {
    /// Run an in-place kernel: `op` receives a non-owning view of `buf` as its input
    /// and `buf` itself as its output.
    ///
    /// The view is always handed to `leak` before `op`'s result is returned. It is
    /// therefore never dropped, not even when `op` fails, and `buf` keeps sole
    /// ownership of its device allocation. Only use this with kernels that fully
    /// read their input before overwriting it.
    fn with_in_place_view(
        buf: &mut GpuBuffer<f32>,
        op: impl FnOnce(&GpuBuffer<f32>, &mut GpuBuffer<f32>) -> Result<()>,
    ) -> Result<()> {
        // SAFETY: the view covers exactly `buf`'s allocation (same pointer, same length)
        // and is leaked below instead of dropped, so it never frees that memory.
        let view = unsafe { GpuBuffer::<f32>::from_raw_parts(buf.as_ptr(), buf.len()) };
        let result = op(&view, buf);
        leak(view);
        result
    }

    /// Causal multi-head (GQA) attention over `scratch.q` / `scratch.k` / `scratch.v`,
    /// writing the interleaved result to `scratch.attn_out`.
    ///
    /// Pipeline: QK-norm (if configured) → RoPE → interleaved→batched → GQA expand of
    /// K/V → `Q·Kᵀ` → scale by `1/sqrt(head_dim)` → causal mask → softmax → `·V`
    /// → batched→interleaved.
    ///
    /// # Errors
    /// Returns the first kernel or copy error. Non-owning views created for in-place
    /// kernels are leaked *before* an error is propagated, so a failure never frees
    /// memory owned by `scratch`.
    fn compute_attention_cuda(
        &self,
        seq_len: usize,
        stream: &CudaStream,
        scratch: &mut CudaBlockScratch,
    ) -> Result<()> {
        let num_heads = self.config.num_attention_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim();
        let heads_per_kv = num_heads / num_kv_heads;

        let s = saturating_u32(seq_len);
        let nh = saturating_u32(num_heads);
        let nkv = saturating_u32(num_kv_heads);
        let hd = saturating_u32(head_dim);

        // ── ENT-270: Apply QK-norm (per-head RMSNorm) on Q and K ──────────
        // Must happen BEFORE RoPE, matching CPU path ordering:
        //   projection → QK-norm → RoPE → attention
        // The kernels run in place through a non-owning input view, rather than by
        // holding `&scratch.q` and `&mut scratch.q` at the same time.
        if let Some(q_norm) = &self.q_norm_weight {
            Self::with_in_place_view(&mut scratch.q, |src, dst| {
                for pos in 0..seq_len {
                    per_head_rmsnorm_forward(src, q_norm, dst, nh, hd, pos, stream)?;
                }
                Ok(())
            })?;
        }
        if let Some(k_norm) = &self.k_norm_weight {
            Self::with_in_place_view(&mut scratch.k, |src, dst| {
                for pos in 0..seq_len {
                    per_head_rmsnorm_forward(src, k_norm, dst, nkv, hd, pos, stream)?;
                }
                Ok(())
            })?;
        }

        // ── ENT-270: Apply RoPE (NeoX half-rotation) on Q and K ──────────
        // ALB-119: Batched launch (2 kernels) replaces per-position loop (2*seq_len kernels)
        let rope_theta = self.config.rope_theta;
        let rope_positions = &scratch.rope_positions;
        Self::with_in_place_view(&mut scratch.q, |src, dst| {
            batched_rope_neox_forward(src, dst, rope_positions, nh, hd, s, rope_theta, stream)?;
            Ok(())
        })?;
        Self::with_in_place_view(&mut scratch.k, |src, dst| {
            batched_rope_neox_forward(src, dst, rope_positions, nkv, hd, s, rope_theta, stream)?;
            Ok(())
        })?;

        // Q: interleaved → batched layout
        interleaved_to_batched_forward(&scratch.q, &mut scratch.attn_q_batched, s, nh, hd, stream)?;

        // K: interleaved → batched, then GQA expand if needed
        interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;

        if heads_per_kv > 1 {
            expand_kv_heads(
                &scratch.attn_kv_temp,
                &mut scratch.attn_kv_temp2,
                num_kv_heads,
                heads_per_kv,
                seq_len * head_dim,
                stream,
            )?;
        } else {
            // SAFETY: D2D copy with matching buffer sizes
            unsafe {
                scratch
                    .attn_kv_temp2
                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
                    .map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "K copy failed: {e:?}"
                        ))
                    })?;
            }
        }

        // K^T: transpose for attention scores
        batched_transpose_forward(&scratch.attn_kv_temp2, &mut scratch.attn_kv_temp, nh, s, hd, stream)?;

        // Q @ K^T → attention scores
        batched_4d_gemm_forward(
            &scratch.attn_q_batched,
            &scratch.attn_kv_temp,
            &mut scratch.attn_scores,
            1,
            nh,
            s,
            s,
            hd,
            stream,
        )?;

        // Scale by 1/sqrt(head_dim) (in place)
        let scale_factor = 1.0 / (head_dim as f32).sqrt();
        let total_scores = num_heads * seq_len * seq_len;
        Self::with_in_place_view(&mut scratch.attn_scores, |src, dst| {
            scale_forward(src, dst, scale_factor, saturating_u32(total_scores), stream)?;
            Ok(())
        })?;

        // C-CAUSAL-001: Apply causal mask before softmax (NF4 path)
        // PMAT-420: Use causal_mask_contiguous (correctly strided for seq_len)
        // instead of causal_mask (strided at max_seq_len, causes row misalignment
        // when seq_len < max_seq_len, leading to NaN after deep layers).
        {
            let seq_sq = seq_len * seq_len;
            let head_bytes = seq_sq * std::mem::size_of::<f32>();
            let scores_base = scratch.attn_scores.as_ptr();
            // SAFETY: non-owning view of the first `seq_sq` mask elements; it is leaked
            // after the loop (on success and on error) and never dropped.
            let mask_view = unsafe {
                GpuBuffer::<f32>::from_raw_parts(scratch.causal_mask_contiguous.as_ptr(), seq_sq)
            };
            let mut masked: Result<()> = Ok(());
            for head in 0..num_heads {
                let head_ptr = scores_base + (head * head_bytes) as u64;
                // SAFETY: both views cover this head's `seq_sq` scores inside
                // `attn_scores` (in-place add, as in the original pipeline). They are
                // leaked right after the launch, before any error is inspected.
                let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
                let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
                let launched = residual_add_forward(
                    &mask_view,
                    &scores_view,
                    &mut out_view,
                    saturating_u32(seq_sq),
                    stream,
                );
                leak(scores_view);
                leak(out_view);
                if let Err(e) = launched {
                    masked = Err(e.into());
                    break;
                }
            }
            leak(mask_view);
            masked?;
        }

        // Softmax in place. The softmax kernel reads each row completely into shared
        // memory / registers before writing output.
        Self::with_in_place_view(&mut scratch.attn_scores, |src, dst| {
            batched_softmax_forward(src, dst, saturating_u32(num_heads * seq_len), s, stream)?;
            Ok(())
        })?;

        // V: interleaved → batched, then GQA expand
        interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;

        if heads_per_kv > 1 {
            expand_kv_heads(
                &scratch.attn_kv_temp,
                &mut scratch.attn_kv_temp2,
                num_kv_heads,
                heads_per_kv,
                seq_len * head_dim,
                stream,
            )?;
        } else {
            // SAFETY: async GPU buffer copy within same CUDA stream; both buffers are
            // pre-allocated scratch with matching sizes, and stream ordering guarantees
            // the source is fully written before this copy executes.
            unsafe {
                scratch
                    .attn_kv_temp2
                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
                    .map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "V copy failed: {e:?}"
                        ))
                    })?;
            }
        }

        // attn_scores @ V → attention output
        batched_4d_gemm_forward(
            &scratch.attn_scores,
            &scratch.attn_kv_temp2,
            &mut scratch.attn_q_batched,
            1,
            nh,
            s,
            hd,
            s,
            stream,
        )?;

        // Batched → interleaved layout
        batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.attn_out, s, nh, hd, stream)?;

        Ok(())
    }
}
3508
3509// =============================================================================
3510// QLoRA Backward Pass Types (ENT-153)
3511// =============================================================================
3512
/// Model-wide buffer set holding one layer's trainable-parameter gradients at a time.
///
/// Backward visits the layers one after another, so only the current layer's LoRA
/// and RMSNorm gradients are ever live. A single workspace shared by all `L` layers,
/// instead of one per layer, saves `(L-1) * per_layer_lora_grad_elements * 4` bytes
/// of VRAM.
///
/// # Contract (C-LORAGRADWS-001)
///
/// - **Precondition**: allocated exactly once, before the training loop begins
/// - **Postcondition**: after `backward()` runs for layer i, the buffers hold layer i's LoRA gradients
/// - **Invariant**: buffer sizes follow the model config and are never reallocated during training
#[cfg(feature = "cuda")]
pub(crate) struct CudaLoraGradWorkspace {
    /// dLoss/dA_q, laid out as `[hidden_size, rank]`
    pub(crate) grad_lora_a_q: GpuBuffer<f32>,
    /// dLoss/dB_q, laid out as `[rank, q_dim]`
    pub(crate) grad_lora_b_q: GpuBuffer<f32>,
    /// dLoss/dA_v, laid out as `[hidden_size, rank]`
    pub(crate) grad_lora_a_v: GpuBuffer<f32>,
    /// dLoss/dB_v, laid out as `[rank, kv_hidden]`
    pub(crate) grad_lora_b_v: GpuBuffer<f32>,
    /// dLoss/d(input RMSNorm weight), `[hidden_size]`
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// dLoss/d(post-attention RMSNorm weight), `[hidden_size]`
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
}
3539
#[cfg(feature = "cuda")]
impl CudaLoraGradWorkspace {
    /// Allocate the shared LoRA gradient workspace, sized from `config` and `lora_rank`.
    ///
    /// # Errors
    /// Fails if any device allocation fails.
    pub(crate) fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        let h = config.hidden_size;
        let q_dim = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let r = lora_rank;

        Ok(Self {
            grad_lora_a_q: GpuBuffer::new(ctx, h * r)?,
            grad_lora_b_q: GpuBuffer::new(ctx, r * q_dim)?,
            grad_lora_a_v: GpuBuffer::new(ctx, h * r)?,
            grad_lora_b_v: GpuBuffer::new(ctx, r * kv)?,
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
        })
    }

    /// ENT-265: Clip all 6 gradient buffers by their global L2 norm.
    ///
    /// Computes the global L2 norm across A_q, B_q, A_v, B_v, input_norm and
    /// post_attn_norm. If it exceeds `max_norm`, every buffer is scaled by
    /// `max_norm / (total_norm + 1e-6)`.
    ///
    /// Best-effort: this never fails. A norm reduction that errors is logged and that
    /// buffer counts as 0. A scaling launch that errors is logged and the remaining
    /// buffers are still scaled.
    ///
    /// Two-phase design: phase 1 reads norms (shared borrows), phase 2 applies the
    /// scale (mutable borrows), which keeps the borrow checker satisfied.
    pub(crate) fn clip_gradients(&mut self, max_norm: f32, stream: &CudaStream) {
        // Phase 1: global squared L2 norm, accumulated in the same order as the fields.
        let mut squared_total = 0.0_f32;
        for (name, buf) in [
            ("grad_lora_a_q", &self.grad_lora_a_q),
            ("grad_lora_b_q", &self.grad_lora_b_q),
            ("grad_lora_a_v", &self.grad_lora_a_v),
            ("grad_lora_b_v", &self.grad_lora_b_v),
            ("grad_input_norm", &self.grad_input_norm),
            ("grad_post_attn_norm", &self.grad_post_attn_norm),
        ] {
            squared_total += squared_sum_cuda(buf, saturating_u32(buf.len()), stream)
                .unwrap_or_else(|e| {
                    eprintln!(
                        "[clip_gradients] squared_sum_cuda failed for {name}: {e:?} (counted as 0)"
                    );
                    0.0
                });
        }
        let total_norm = squared_total.sqrt();

        if total_norm <= max_norm {
            return;
        }

        // Phase 2: apply the clip scale to every buffer.
        let clip_scale = max_norm / (total_norm + 1e-6);
        for (name, buf) in [
            ("grad_lora_a_q", &mut self.grad_lora_a_q),
            ("grad_lora_b_q", &mut self.grad_lora_b_q),
            ("grad_lora_a_v", &mut self.grad_lora_a_v),
            ("grad_lora_b_v", &mut self.grad_lora_b_v),
            ("grad_input_norm", &mut self.grad_input_norm),
            ("grad_post_attn_norm", &mut self.grad_post_attn_norm),
        ] {
            let n = saturating_u32(buf.len());
            if let Err(e) = gradient_clip_cuda(buf, clip_scale, n, stream) {
                eprintln!("[clip_gradients] gradient_clip_cuda failed for {name}: {e:?}");
            }
        }
    }
}
3613
/// AdamW moment estimates for one NF4 block's trainable parameters, kept on the GPU.
///
/// Each trainable tensor has a first-moment buffer (`m_*`) and a second-moment
/// buffer (`v_*`). The trainable tensors are:
/// - the 4 LoRA adapter matrices (A_q, B_q, A_v, B_v)
/// - the 2 RMSNorm weights (input_norm, post_attn_norm)
///
/// # Contract (C-LORAOPT-001)
///
/// - **Precondition**: valid CUDA context; buffer sizes match the weight dimensions
/// - **Postcondition**: every m and v buffer starts out zeroed
/// - **Invariant**: buffer sizes never change after construction
#[cfg(feature = "cuda")]
pub(crate) struct GpuLoraOptimizerState {
    // LoRA A_q moments, [hidden_size * rank] elements
    m_lora_a_q: GpuBuffer<f32>,
    v_lora_a_q: GpuBuffer<f32>,
    // LoRA B_q moments, [rank * q_dim] elements
    m_lora_b_q: GpuBuffer<f32>,
    v_lora_b_q: GpuBuffer<f32>,
    // LoRA A_v moments, [hidden_size * rank] elements
    m_lora_a_v: GpuBuffer<f32>,
    v_lora_a_v: GpuBuffer<f32>,
    // LoRA B_v moments, [rank * kv_hidden] elements
    m_lora_b_v: GpuBuffer<f32>,
    v_lora_b_v: GpuBuffer<f32>,
    // Input RMSNorm weight moments, [hidden_size] elements
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    // Post-attention RMSNorm weight moments, [hidden_size] elements
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}
3640
#[cfg(feature = "cuda")]
impl GpuLoraOptimizerState {
    /// Allocate zero-filled AdamW moment buffers sized for one block's LoRA adapters
    /// and RMSNorm weights.
    ///
    /// # Errors
    /// Fails if any device allocation or host-to-device upload fails.
    fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        // Upload explicit zeros: GpuBuffer::new() hands back cuMemAlloc'd memory,
        // which is NOT zeroed, and stale VRAM would corrupt the moment estimates.
        fn zeroed(ctx: &Arc<CudaContext>, len: usize) -> Result<GpuBuffer<f32>> {
            let host_zeros = vec![0.0f32; len];
            Ok(GpuBuffer::from_host(ctx, &host_zeros)?)
        }

        let hidden = config.hidden_size;
        let lora_a_len = hidden * lora_rank;
        let lora_b_q_len = lora_rank * config.q_dim();
        let lora_b_v_len = lora_rank * (config.num_kv_heads * config.head_dim());

        Ok(Self {
            m_lora_a_q: zeroed(ctx, lora_a_len)?,
            v_lora_a_q: zeroed(ctx, lora_a_len)?,
            m_lora_b_q: zeroed(ctx, lora_b_q_len)?,
            v_lora_b_q: zeroed(ctx, lora_b_q_len)?,
            m_lora_a_v: zeroed(ctx, lora_a_len)?,
            v_lora_a_v: zeroed(ctx, lora_a_len)?,
            m_lora_b_v: zeroed(ctx, lora_b_v_len)?,
            v_lora_b_v: zeroed(ctx, lora_b_v_len)?,
            m_input_norm: zeroed(ctx, hidden)?,
            v_input_norm: zeroed(ctx, hidden)?,
            m_post_attn_norm: zeroed(ctx, hidden)?,
            v_post_attn_norm: zeroed(ctx, hidden)?,
        })
    }
}
3674
3675// =============================================================================
3676// NF4 Block Backward Pass (ENT-153)
3677// =============================================================================
3678
3679#[cfg(feature = "cuda")]
3680impl CudaNf4TransformerBlock {
3681    /// Backward pass with activation checkpointing and LoRA gradient computation.
3682    ///
3683    /// # Activation Checkpointing
3684    ///
3685    /// Re-runs forward to regenerate intermediate activations. Only `layer_input`
3686    /// is saved per-layer (47 MB for 36 layers at seq_len=128). This is the standard
3687    /// NF4 QLoRA backward (C-QLORA-BWD-001): recompute activations, propagate gradients.
3688    #[allow(clippy::too_many_arguments)]
3689    pub(crate) fn backward(
3690        &self,
3691        layer_input: &GpuBuffer<f32>,
3692        grad_output: &GpuBuffer<f32>,
3693        grad_input: &mut GpuBuffer<f32>,
3694        output_scratch: &mut GpuBuffer<f32>,
3695        seq_len: usize,
3696        stream: &CudaStream,
3697        scratch: &mut CudaBlockScratch,
3698        grad_lora: &mut CudaLoraGradWorkspace,
3699    ) -> Result<()> {
3700        let hidden_size = self.config.hidden_size;
3701        let _q_dim = self.config.q_dim();
3702        let _kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3703        let intermediate_size = self.config.intermediate_size;
3704        let eps = 1e-5_f32;
3705
3706        // === Step 0: Activation checkpointing — re-run forward ===
3707        // This repopulates scratch with all intermediates needed for backward.
3708        self.forward(layer_input, output_scratch, seq_len, stream, scratch).map_err(|e| {
3709            eprintln!(
3710                "[backward] Layer {} activation-checkpoint forward FAILED: {e:?}",
3711                self.layer_idx
3712            );
3713            e
3714        })?;
3715
3716        // === Step 1: FFN backward (NF4 transpose, no weight grads for frozen projections) ===
3717        self.backward_nf4_ffn(
3718            grad_output,
3719            seq_len,
3720            hidden_size,
3721            intermediate_size,
3722            stream,
3723            scratch,
3724        )?;
3725
3726        // === Step 2: Post-attn norm backward ===
3727        let _t = scratch.op_begin(); // OP_NORM_BWD timing (both norms)
3728        rms_norm_backward(
3729            &scratch.residual1,
3730            &self.post_attn_norm_weight,
3731            &scratch.grad_hidden, // grad_from_ffn is accumulated in grad_hidden by backward_nf4_ffn
3732            grad_input,           // temporarily store post-attn-norm grad here
3733            &mut grad_lora.grad_post_attn_norm,
3734            saturating_u32(seq_len),
3735            saturating_u32(hidden_size),
3736            eps,
3737            stream,
3738        )?;
3739
3740        // Add residual connection: grad flows through both ffn and skip path
3741        // grad_residual1 = grad_input (from norm backward) + grad_output (from residual skip)
3742        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
3743
3744        // === Step 3: Attention backward (NF4 + LoRA for Q/V) ===
3745        self.backward_nf4_attention(
3746            grad_input, // grad coming into attention (from residual1)
3747            seq_len, stream, scratch, grad_lora,
3748        )?;
3749
3750        // === Step 4: Input norm backward + first residual ===
3751        // At this point, scratch.grad_hidden contains grad from attention block
3752        // (accumulated by backward_nf4_attention into norm1_out reusing grad_hidden)
3753        rms_norm_backward(
3754            layer_input,
3755            &self.input_norm_weight,
3756            &scratch.grad_hidden, // grad flowing into norm1
3757            grad_input,           // final grad_input for this layer
3758            &mut grad_lora.grad_input_norm,
3759            saturating_u32(seq_len),
3760            saturating_u32(hidden_size),
3761            eps,
3762            stream,
3763        )?;
3764
3765        scratch.op_end(_t, OP_NORM_BWD);
3766
3767        Ok(())
3768    }
3769
    /// FFN backward for NF4 blocks (ENT-287: cuBLAS fp32 GEMM).
    ///
    /// Propagates the gradient through down_proj → SwiGLU → gate/up projections.
    /// `grad_output` is the gradient w.r.t. the FFN output (`[S, H]`). On return,
    /// `scratch.grad_hidden` holds the gradient w.r.t. the FFN input (the post-attention
    /// norm output, `[S, H]`), summed over the gate and up branches.
    ///
    /// Reads the forward activations in `scratch.gate_out` / `scratch.up_out` (as
    /// recomputed by the activation-checkpoint forward in `backward`) and then overwrites
    /// them: `up_out` ends up holding `d_gate` and `gate_out` holding `d_up`.
    /// `grad_swiglu`, `swiglu_out` and (on the TC and unfused paths) `ffn_out` are used as
    /// temporaries. Statement order matters because of this buffer reuse.
    ///
    /// No weight gradients are produced: the NF4 projections are frozen.
    ///
    /// GEMM backend, each selected by an env var cached in a function-local `OnceLock`
    /// (read once per process):
    /// - `NF4_TC_BWD_GEMM=1`: fused NF4 dequant + tensor-core kernels (PMAT-481); takes
    ///   precedence over the fused flag below.
    /// - `NF4_FUSED_BWD_GEMM=1`: gate backward accumulates straight into `grad_hidden`
    ///   (beta = 1.0) instead of a separate add (PMAT-484).
    /// - otherwise: fp16/fp32 dispatch with pre-dequantized weights (PMAT-472) followed
    ///   by an explicit `cuda_add_inplace`.
    ///
    /// # Errors
    ///
    /// Propagates any kernel launch or GEMM error.
    fn backward_nf4_ffn(
        &self,
        grad_output: &GpuBuffer<f32>,
        seq_len: usize,
        hidden_size: usize,
        intermediate_size: usize,
        stream: &CudaStream,
        scratch: &mut CudaBlockScratch,
    ) -> Result<()> {
        let s = saturating_u32(seq_len);
        let h = saturating_u32(hidden_size);
        let i_size = saturating_u32(intermediate_size);
        // Element count of the [S, I] intermediate buffers.
        let n_inter = saturating_u32(seq_len * intermediate_size);

        // PMAT-481: NF4 tensor core backward dispatch (env var read once per process)
        static USE_NF4_TC_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
        let nf4_tc_bwd =
            *USE_NF4_TC_BWD.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));

        // Step 1: grad_swiglu[S,I] = grad_output[S,H] @ W_down^T (PMAT-472: fp16 dispatch)
        let _t = scratch.op_begin(); // OP_DOWN_BWD timing
        if nf4_tc_bwd {
            // NF4 TC backward: fused dequant+WMMA, no separate dequant kernel
            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
                grad_output,
                &self.w_down_nf4,
                &self.w_down_scales,
                &mut scratch.grad_swiglu,
                s,
                h,      // N = hidden_size (grad_output cols)
                i_size, // K = intermediate_size (output cols = W_down rows)
                stream,
            )?;
        } else {
            // Dim order here is (S, output cols, grad cols).
            gemm_backward_a_fp16_dispatch(
                grad_output,
                self.w_down_fp16.as_ref(),
                &self.w_down_fp32,
                &mut scratch.grad_swiglu,
                s,
                i_size,
                h,
                stream,
                &self.ctx,
            )?;
        }

        scratch.op_end(_t, OP_DOWN_BWD);

        // Step 2: SwiGLU backward: swiglu = silu(gate) * up
        // d_gate = d_swiglu * up * silu'(gate)
        // d_up   = d_swiglu * silu(gate)
        let _t = scratch.op_begin(); // OP_SWIGLU_BWD timing

        // temp1 = d_swiglu * up_out → store in swiglu_out (reuse)
        elementwise_mul_forward(
            &scratch.grad_swiglu,
            &scratch.up_out,
            &mut scratch.swiglu_out,
            n_inter,
            stream,
        )?;

        // silu_backward: d_gate = temp1 * silu'(gate_out)
        // silu'(x) = σ(x) * (1 + x * (1 - σ(x))), where σ is the logistic sigmoid
        // Reuse up_out as storage for d_gate (up_out's forward value was consumed above).
        // NOTE(review): no element count is passed; presumably the kernel sizes itself
        // from the buffer lengths — confirm.
        silu_backward(
            &scratch.gate_out,
            &scratch.swiglu_out,
            &mut scratch.up_out, // d_gate stored here
            stream,
        )?;

        // d_up = d_swiglu * silu(gate) → store in gate_out (reuse)
        // Compute silu(gate) into swiglu_out (scratch) — NOT ffn_out which is [S,H] (too small)
        silu_forward(&scratch.gate_out, &mut scratch.swiglu_out, n_inter, stream)?;
        // d_up = d_swiglu * silu(gate); gate_out's forward value is no longer needed after this
        elementwise_mul_forward(
            &scratch.grad_swiglu,
            &scratch.swiglu_out,
            &mut scratch.gate_out, // d_up stored here
            n_inter,
            stream,
        )?;

        scratch.op_end(_t, OP_SWIGLU_BWD);

        // Step 3: gate/up backward (PMAT-472: fp16 dispatch, PMAT-481: TC dispatch)
        // grad_hidden[S,H] = d_up @ W_up^T + d_gate @ W_gate^T
        let _t = scratch.op_begin(); // OP_GATE_UP_BWD timing
        // PMAT-484: Fused backward — use cuBLAS beta=1.0 accumulate to eliminate
        // the separate cuda_add_inplace kernel launch (3 launches → 2).
        static USE_FUSED_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
        let fused_bwd = *USE_FUSED_BWD
            .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));

        if nf4_tc_bwd {
            // PMAT-481: NF4 tensor core backward — fused dequant+WMMA per projection
            // Up backward: grad_up[S,H] = d_up[S,I] @ W_up^T
            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
                &scratch.gate_out, // d_up stored in gate_out buffer
                &self.w_up_nf4,
                &self.w_up_scales,
                &mut scratch.grad_hidden,
                s,
                i_size, // N = intermediate_size (d_up cols)
                h,      // K = hidden_size (output cols)
                stream,
            )?;
            // Gate backward: grad_gate[S,H] = d_gate[S,I] @ W_gate^T (ffn_out as temp)
            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
                &scratch.up_out, // d_gate stored in up_out buffer
                &self.w_gate_nf4,
                &self.w_gate_scales,
                &mut scratch.ffn_out,
                s,
                i_size, // N = intermediate_size
                h,      // K = hidden_size
                stream,
            )?;
            // Accumulate: grad_hidden += grad_gate
            cuda_add_inplace(
                &mut scratch.grad_hidden,
                &scratch.ffn_out,
                seq_len * hidden_size,
                stream,
            )?;
        } else if fused_bwd {
            // Fused path: compute up backward into grad_hidden, then accumulate gate
            gemm_backward_a_fp16_dispatch(
                &scratch.gate_out,
                self.w_up_fp16.as_ref(),
                &self.w_up_fp32,
                &mut scratch.grad_hidden,
                s,
                h,
                i_size,
                stream,
                &self.ctx,
            )?;
            // Accumulate gate backward into grad_hidden (beta=1.0)
            gemm_backward_a_fp16_dispatch_accumulate(
                &scratch.up_out,
                self.w_gate_fp16.as_ref(),
                &self.w_gate_fp32,
                &mut scratch.grad_hidden,
                s,
                h,
                i_size,
                stream,
                &self.ctx,
            )?;
        } else {
            // Unfused path: separate GEMMs (gate → ffn_out temp, up → grad_hidden) + explicit add
            gemm_backward_a_fp16_dispatch(
                &scratch.up_out,
                self.w_gate_fp16.as_ref(),
                &self.w_gate_fp32,
                &mut scratch.ffn_out,
                s,
                h,
                i_size,
                stream,
                &self.ctx,
            )?;
            gemm_backward_a_fp16_dispatch(
                &scratch.gate_out,
                self.w_up_fp16.as_ref(),
                &self.w_up_fp32,
                &mut scratch.grad_hidden,
                s,
                h,
                i_size,
                stream,
                &self.ctx,
            )?;

            // Accumulate: grad_hidden = grad_norm2_gate + grad_norm2_up
            cuda_add_inplace(
                &mut scratch.grad_hidden,
                &scratch.ffn_out,
                seq_len * hidden_size,
                stream,
            )?;
        }
        scratch.op_end(_t, OP_GATE_UP_BWD);

        Ok(())
    }
3963
3964    /// Attention backward for NF4 blocks with LoRA gradient computation (ENT-287).
3965    ///
3966    /// Propagates gradient through O projection, attention mechanism, and Q/K/V projections.
3967    /// Computes LoRA weight gradients for Q and V projections.
3968    /// Uses cuBLAS GEMM with pre-dequantized fp32 weights.
3969    fn backward_nf4_attention(
3970        &self,
3971        grad_residual1: &GpuBuffer<f32>,
3972        seq_len: usize,
3973        stream: &CudaStream,
3974        scratch: &mut CudaBlockScratch,
3975        grad_lora: &mut CudaLoraGradWorkspace,
3976    ) -> Result<()> {
3977        use crate::autograd::cuda_forward::gemm_forward;
3978
3979        let hidden_size = self.config.hidden_size;
3980        let q_dim = self.config.q_dim();
3981        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3982        let num_heads = self.config.num_attention_heads;
3983        let head_dim = self.config.head_dim();
3984
3985        let s = saturating_u32(seq_len);
3986        let h = saturating_u32(hidden_size);
3987        let qd = saturating_u32(q_dim);
3988        let kvh = saturating_u32(kv_hidden_size);
3989
3990        // Step 1: O projection backward (PMAT-481: TC dispatch)
3991        static USE_NF4_TC_BWD_O: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3992        let nf4_tc_bwd_o = *USE_NF4_TC_BWD_O
3993            .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
3994
3995        let _t = scratch.op_begin(); // OP_ATTN_BWD timing (O-proj + attention mechanism)
3996        if nf4_tc_bwd_o {
3997            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3998                grad_residual1,
3999                &self.w_o_nf4,
4000                &self.w_o_scales,
4001                &mut scratch.attn_out,
4002                s,
4003                h,  // N = hidden_size (grad_residual1 cols)
4004                qd, // K = q_dim (output cols = W_o rows)
4005                stream,
4006            )?;
4007        } else {
4008            gemm_backward_a_fp16_dispatch(
4009                grad_residual1,
4010                self.w_o_fp16.as_ref(),
4011                &self.w_o_fp32,
4012                &mut scratch.attn_out,
4013                s,
4014                qd,
4015                h,
4016                stream,
4017                &self.ctx,
4018            )?;
4019        }
4020
4021        // Step 2: Attention mechanism backward
4022        // This is complex (softmax backward, batched GEMMs) — reuse the fp32 attention backward
4023        // infrastructure since attention operates on fp32 activations.
4024        self.backward_nf4_attention_mechanism(seq_len, num_heads, head_dim, stream, scratch)?;
4025
4026        // After attention backward: scratch.norm1_out-related grads are accumulated.
4027        // grad_q is in scratch.q, grad_k in scratch.k, grad_v in scratch.v
4028
4029        // Step 2b: RoPE backward (inverse rotation) on grad_q and grad_k
4030        // Forward applied RoPE to Q and K. Undo rotation so projection backward
4031        // gets gradients in the unrotated coordinate frame.
4032        let rope_theta = self.config.rope_theta;
4033        let num_kv_heads = self.config.num_kv_heads;
4034        let nkv = saturating_u32(num_kv_heads);
4035        let nh = saturating_u32(num_heads);
4036        let hd = saturating_u32(head_dim);
4037        {
4038            let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
4039            batched_rope_neox_backward(
4040                q_ref,
4041                &mut scratch.q,
4042                &scratch.rope_positions,
4043                nh,
4044                hd,
4045                s,
4046                rope_theta,
4047                stream,
4048            )?;
4049            let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
4050            batched_rope_neox_backward(
4051                k_ref,
4052                &mut scratch.k,
4053                &scratch.rope_positions,
4054                nkv,
4055                hd,
4056                s,
4057                rope_theta,
4058                stream,
4059            )?;
4060        }
4061
4062        scratch.op_end(_t, OP_ATTN_BWD);
4063
4064        // Q/K/V backward (PMAT-472: fp16 dispatch, PMAT-481: TC dispatch)
4065        static USE_NF4_TC_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4066        let nf4_tc_bwd = *USE_NF4_TC_BWD_ATTN
4067            .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4068
4069        let _t = scratch.op_begin(); // OP_QKV_BWD timing
4070        if nf4_tc_bwd {
4071            // PMAT-481: NF4 tensor core backward for Q projection
4072            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4073                &scratch.q,
4074                &self.w_q_nf4,
4075                &self.w_q_scales,
4076                &mut scratch.o_proj_out,
4077                s,
4078                qd, // N = q_dim (grad_q cols)
4079                h,  // K = hidden_size (output cols)
4080                stream,
4081            )?;
4082        } else {
4083            gemm_backward_a_fp16_dispatch(
4084                &scratch.q,
4085                self.w_q_fp16.as_ref(),
4086                &self.w_q_fp32,
4087                &mut scratch.o_proj_out,
4088                s,
4089                h,
4090                qd,
4091                stream,
4092                &self.ctx,
4093            )?;
4094        }
4095
4096        // LoRA Q backward: compute grad_A_q, grad_B_q, and add to grad_norm1
4097        if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
4098            let r = saturating_u32(self.lora_rank);
4099
4100            // Recompute: lora_inter_q = norm1_out @ A_q  [S, rank]
4101            gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
4102
4103            // grad_B_q = lora_inter_q^T @ grad_q  [rank, q_dim]
4104            // (Note: B_q was pre-scaled, so grad_B_q includes the scale factor)
4105            gemm_backward_b(
4106                &scratch.lora_inter,
4107                &scratch.q,
4108                &mut grad_lora.grad_lora_b_q,
4109                s,
4110                r,
4111                qd,
4112                stream,
4113            )?;
4114
4115            // grad_lora_inter = grad_q @ B_q^T  [S, rank]
4116            gemm_backward_a(
4117                &scratch.q,
4118                b_q,
4119                &mut scratch.lora_inter, // reuse for grad_lora_inter
4120                s,
4121                qd,
4122                r,
4123                stream,
4124            )?;
4125
4126            // grad_A_q = norm1_out^T @ grad_lora_inter  [H, rank]
4127            gemm_backward_b(
4128                &scratch.norm1_out,
4129                &scratch.lora_inter,
4130                &mut grad_lora.grad_lora_a_q,
4131                s,
4132                h,
4133                r,
4134                stream,
4135            )?;
4136
4137            // Add LoRA's contribution to grad_norm1: += grad_lora_inter @ A_q^T  [S, H]
4138            gemm_backward_a(
4139                &scratch.lora_inter,
4140                a_q,
4141                &mut scratch.lora_temp, // [S, H]
4142                s,
4143                r,
4144                h,
4145                stream,
4146            )?;
4147            cuda_add_inplace(
4148                &mut scratch.o_proj_out,
4149                &scratch.lora_temp,
4150                seq_len * hidden_size,
4151                stream,
4152            )?;
4153        }
4154
4155        // K+V backward (PMAT-484: fused, PMAT-481: TC dispatch)
4156        static USE_FUSED_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4157        let fused_bwd = *USE_FUSED_BWD_ATTN
4158            .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
4159
4160        if nf4_tc_bwd {
4161            // PMAT-481: NF4 tensor core backward for K and V projections
4162            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4163                &scratch.k,
4164                &self.w_k_nf4,
4165                &self.w_k_scales,
4166                &mut scratch.ffn_out,
4167                s,
4168                kvh, // N = kv_hidden (grad_k cols)
4169                h,   // K = hidden_size (output cols)
4170                stream,
4171            )?;
4172            cuda_add_inplace(
4173                &mut scratch.o_proj_out,
4174                &scratch.ffn_out,
4175                seq_len * hidden_size,
4176                stream,
4177            )?;
4178            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4179                &scratch.v,
4180                &self.w_v_nf4,
4181                &self.w_v_scales,
4182                &mut scratch.ffn_out,
4183                s,
4184                kvh, // N = kv_hidden
4185                h,   // K = hidden_size
4186                stream,
4187            )?;
4188            cuda_add_inplace(
4189                &mut scratch.o_proj_out,
4190                &scratch.ffn_out,
4191                seq_len * hidden_size,
4192                stream,
4193            )?;
4194        } else if fused_bwd {
4195            // Fused: K and V backward accumulate directly into o_proj_out
4196            gemm_backward_a_fp16_dispatch_accumulate(
4197                &scratch.k,
4198                self.w_k_fp16.as_ref(),
4199                &self.w_k_fp32,
4200                &mut scratch.o_proj_out,
4201                s,
4202                h,
4203                kvh,
4204                stream,
4205                &self.ctx,
4206            )?;
4207            gemm_backward_a_fp16_dispatch_accumulate(
4208                &scratch.v,
4209                self.w_v_fp16.as_ref(),
4210                &self.w_v_fp32,
4211                &mut scratch.o_proj_out,
4212                s,
4213                h,
4214                kvh,
4215                stream,
4216                &self.ctx,
4217            )?;
4218        } else {
4219            // Unfused: K backward → temp, accumulate; V backward → temp, accumulate
4220            gemm_backward_a_fp16_dispatch(
4221                &scratch.k,
4222                self.w_k_fp16.as_ref(),
4223                &self.w_k_fp32,
4224                &mut scratch.ffn_out,
4225                s,
4226                h,
4227                kvh,
4228                stream,
4229                &self.ctx,
4230            )?;
4231            cuda_add_inplace(
4232                &mut scratch.o_proj_out,
4233                &scratch.ffn_out,
4234                seq_len * hidden_size,
4235                stream,
4236            )?;
4237
4238            gemm_backward_a_fp16_dispatch(
4239                &scratch.v,
4240                self.w_v_fp16.as_ref(),
4241                &self.w_v_fp32,
4242                &mut scratch.ffn_out,
4243                s,
4244                h,
4245                kvh,
4246                stream,
4247                &self.ctx,
4248            )?;
4249            cuda_add_inplace(
4250                &mut scratch.o_proj_out,
4251                &scratch.ffn_out,
4252                seq_len * hidden_size,
4253                stream,
4254            )?;
4255        }
4256
4257        // LoRA V backward
4258        if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
4259            let r = saturating_u32(self.lora_rank);
4260
4261            // Recompute: lora_inter_v = norm1_out @ A_v  [S, rank]
4262            gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
4263
4264            // grad_B_v = lora_inter_v^T @ grad_v  [rank, kv_hidden]
4265            gemm_backward_b(
4266                &scratch.lora_inter,
4267                &scratch.v,
4268                &mut grad_lora.grad_lora_b_v,
4269                s,
4270                r,
4271                kvh,
4272                stream,
4273            )?;
4274
4275            // grad_lora_inter = grad_v @ B_v^T  [S, rank]
4276            gemm_backward_a(&scratch.v, b_v, &mut scratch.lora_inter, s, kvh, r, stream)?;
4277
4278            // grad_A_v = norm1_out^T @ grad_lora_inter  [H, rank]
4279            gemm_backward_b(
4280                &scratch.norm1_out,
4281                &scratch.lora_inter,
4282                &mut grad_lora.grad_lora_a_v,
4283                s,
4284                h,
4285                r,
4286                stream,
4287            )?;
4288
4289            // Add LoRA V's contribution to grad_norm1
4290            gemm_backward_a(&scratch.lora_inter, a_v, &mut scratch.lora_temp, s, r, h, stream)?;
4291            cuda_add_inplace(
4292                &mut scratch.o_proj_out,
4293                &scratch.lora_temp,
4294                seq_len * hidden_size,
4295                stream,
4296            )?;
4297        }
4298
4299        scratch.op_end(_t, OP_QKV_BWD);
4300
4301        // Step 6: Accumulated grad_norm1 is in scratch.o_proj_out → move to scratch.grad_hidden
4302        unsafe {
4303            scratch.grad_hidden.copy_from_buffer_async(&scratch.o_proj_out, stream).map_err(
4304                |e| {
4305                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4306                        "grad_norm1 copy failed: {e}"
4307                    ))
4308                },
4309            )?;
4310        }
4311
4312        Ok(())
4313    }
4314
4315    /// Attention mechanism backward (softmax, Q@K^T backward) for NF4 blocks.
4316    ///
4317    /// After this call:
4318    /// - scratch.q contains grad_q [S, q_dim]
4319    /// - scratch.k contains grad_k [S, kv_hidden]
4320    /// - scratch.v contains grad_v [S, kv_hidden]
4321    ///
4322    /// PMAT-486: Full attention backward (replaces previous no-op).
4323    /// Mirrors the FP32 backward in `backward_attention` (lines 1470-1810).
4324    ///
4325    /// Forward: attn_out = softmax(Q @ K^T / √d) @ V
4326    /// Backward:
4327    ///   1. grad_scores = grad_attn_batched @ V^T
4328    ///   2. grad_V = (grad_attn_batched^T @ attn_weights)^T  (buffer-safe identity)
4329    ///   3. softmax_backward(grad_scores, attn_weights) → grad_raw
4330    ///   4. grad_raw *= 1/√d
4331    ///   5. grad_Q = grad_raw @ K_expanded
4332    ///   6. grad_K = (Q^T @ grad_raw)^T
4333    ///   7. GQA reduction + batched→interleaved conversion
4334    ///
4335    /// Contract: attention-backward-v1.yaml
4336    fn backward_nf4_attention_mechanism(
4337        &self,
4338        seq_len: usize,
4339        num_heads: usize,
4340        head_dim: usize,
4341        stream: &CudaStream,
4342        scratch: &mut CudaBlockScratch,
4343    ) -> Result<()> {
4344        let num_kv_heads = self.config.num_kv_heads;
4345        let heads_per_kv = num_heads / num_kv_heads;
4346        let s = saturating_u32(seq_len);
4347        let nh = saturating_u32(num_heads);
4348        let nkv = saturating_u32(num_kv_heads);
4349        let hd = saturating_u32(head_dim);
4350        let scale = 1.0 / (head_dim as f32).sqrt();
4351
4352        // grad_attn_out is in scratch.attn_out [S, q_dim]
4353        // Convert to batched layout [NH, S, HD]
4354        interleaved_to_batched_forward(
4355            &scratch.attn_out,
4356            &mut scratch.attn_q_batched, // grad_attn_batched [NH, S, HD]
4357            s,
4358            nh,
4359            hd,
4360            stream,
4361        )?;
4362
4363        // === Step 1: Expand V for GQA and transpose ===
4364        // V is in scratch.v [S, kv_hidden] from activation checkpointing forward.
4365        // Convert to batched [NKV, S, HD], then GQA-expand to [NH, S, HD].
4366        interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
4367
4368        if heads_per_kv > 1 {
4369            expand_kv_heads(
4370                &scratch.attn_kv_temp,
4371                &mut scratch.attn_kv_temp2,
4372                num_kv_heads,
4373                heads_per_kv,
4374                seq_len * head_dim,
4375                stream,
4376            )?;
4377        } else {
4378            unsafe {
4379                scratch
4380                    .attn_kv_temp2
4381                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
4382                    .map_err(|e| {
4383                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4384                            "V copy for attn backward: {e:?}"
4385                        ))
4386                    })?;
4387            }
4388        }
4389        // attn_kv_temp2 = V_expanded [NH, S, HD]
4390
4391        // Transpose V: [NH, S, HD] → [NH, HD, S]
4392        batched_transpose_forward(
4393            &scratch.attn_kv_temp2,
4394            &mut scratch.attn_kv_temp, // V^T [NH, HD, S]
4395            nh,
4396            s,
4397            hd,
4398            stream,
4399        )?;
4400
4401        // === Step 2: grad_attn_scores = grad_attn_batched @ V^T ===
4402        // [NH, S, HD] @ [NH, HD, S] → [NH, S, S]
4403        batched_4d_gemm_forward(
4404            &scratch.attn_q_batched,
4405            &scratch.attn_kv_temp,
4406            &mut scratch.grad_attn_scores,
4407            1,
4408            nh,
4409            s,
4410            s,
4411            hd,
4412            stream,
4413        )?;
4414
4415        // === Step 3: grad_V = (grad_attn_batched^T @ attn_weights)^T ===
4416        // Uses the identity to avoid needing an [NH, S, S] transpose buffer.
4417        // attn_weights in scratch.attn_scores [NH, S, S] from forward checkpoint.
4418
4419        // Step 3a: transpose grad_attn_batched [NH, S, HD] → [NH, HD, S]
4420        batched_transpose_forward(
4421            &scratch.attn_q_batched,
4422            &mut scratch.attn_kv_temp, // reuse: grad_attn_batched^T [NH, HD, S]
4423            nh,
4424            s,
4425            hd,
4426            stream,
4427        )?;
4428
4429        // Step 3b: [NH, HD, S] @ [NH, S, S] → [NH, HD, S] (= grad_V^T)
4430        batched_4d_gemm_forward(
4431            &scratch.attn_kv_temp,
4432            &scratch.attn_scores,       // attn_weights from forward
4433            &mut scratch.attn_kv_temp2, // grad_V^T [NH, HD, S]
4434            1,
4435            nh,
4436            hd,
4437            s,
4438            s,
4439            stream,
4440        )?;
4441
4442        // Step 3c: transpose grad_V^T [NH, HD, S] → grad_V [NH, S, HD]
4443        batched_transpose_forward(
4444            &scratch.attn_kv_temp2,
4445            &mut scratch.attn_kv_temp, // grad_V [NH, S, HD]
4446            nh,
4447            hd,
4448            s,
4449            stream,
4450        )?;
4451        // attn_kv_temp = grad_V [NH, S, HD]
4452
4453        // === Step 4: Softmax backward ===
4454        // In-place: grad_attn_scores is both input and output.
4455        let total_rows = nh * s;
4456        {
4457            let grad_scores_view = unsafe {
4458                GpuBuffer::<f32>::from_raw_parts(
4459                    scratch.grad_attn_scores.as_ptr(),
4460                    scratch.grad_attn_scores.len(),
4461                )
4462            };
4463            batched_softmax_backward(
4464                &scratch.attn_scores,
4465                &grad_scores_view,
4466                &mut scratch.grad_attn_scores,
4467                total_rows,
4468                s,
4469                stream,
4470            )?;
4471            leak(grad_scores_view);
4472        }
4473
4474        // === Step 5: Scale backward (1/√d) ===
4475        let total_scores = saturating_u32(num_heads * seq_len * seq_len);
4476        {
4477            let scores_view = unsafe {
4478                GpuBuffer::<f32>::from_raw_parts(
4479                    scratch.grad_attn_scores.as_ptr(),
4480                    scratch.grad_attn_scores.len(),
4481                )
4482            };
4483            scale_forward(
4484                &scores_view,
4485                &mut scratch.grad_attn_scores,
4486                scale,
4487                total_scores,
4488                stream,
4489            )?;
4490            leak(scores_view);
4491        }
4492
4493        // === Step 6: Reconstruct K, GQA expand, compute grad_Q ===
4494        interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp2, s, nkv, hd, stream)?;
4495
4496        if heads_per_kv > 1 {
4497            unsafe {
4498                scratch
4499                    .attn_q_batched
4500                    .copy_from_buffer_async(&scratch.attn_kv_temp2, stream)
4501                    .map_err(|e| {
4502                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4503                            "K copy for GQA expand: {e}"
4504                        ))
4505                    })?;
4506            }
4507            expand_kv_heads(
4508                &scratch.attn_q_batched,
4509                &mut scratch.attn_kv_temp2,
4510                num_kv_heads,
4511                heads_per_kv,
4512                seq_len * head_dim,
4513                stream,
4514            )?;
4515        }
4516        // attn_kv_temp2 = K_expanded [NH, S, HD]
4517
4518        // grad_Q = grad_raw_scores @ K_expanded → attn_q_batched [NH, S, HD]
4519        batched_4d_gemm_forward(
4520            &scratch.grad_attn_scores,
4521            &scratch.attn_kv_temp2,
4522            &mut scratch.attn_q_batched,
4523            1,
4524            nh,
4525            s,
4526            hd,
4527            s,
4528            stream,
4529        )?;
4530
4531        // === Step 7: Compute grad_K ===
4532        // grad_K^T = Q^T @ grad_raw_scores
4533        // Reconstruct Q_batched into o_proj_out (attn_q_batched has grad_Q now)
4534        interleaved_to_batched_forward(
4535            &scratch.q,
4536            &mut scratch.o_proj_out, // temp for Q_batched
4537            s,
4538            nh,
4539            hd,
4540            stream,
4541        )?;
4542
4543        // Transpose Q: [NH, S, HD] → [NH, HD, S]
4544        batched_transpose_forward(
4545            &scratch.o_proj_out,
4546            &mut scratch.attn_kv_temp2, // Q^T [NH, HD, S]
4547            nh,
4548            s,
4549            hd,
4550            stream,
4551        )?;
4552
4553        // grad_K^T = Q^T @ grad_raw_scores → ffn_out as temp [NH, HD, S]
4554        batched_4d_gemm_forward(
4555            &scratch.attn_kv_temp2,
4556            &scratch.grad_attn_scores,
4557            &mut scratch.ffn_out, // grad_K^T [NH, HD, S]
4558            1,
4559            nh,
4560            hd,
4561            s,
4562            s,
4563            stream,
4564        )?;
4565
4566        // Transpose grad_K^T → grad_K: [NH, HD, S] → [NH, S, HD]
4567        batched_transpose_forward(
4568            &scratch.ffn_out,
4569            &mut scratch.attn_kv_temp2, // grad_K [NH, S, HD]
4570            nh,
4571            hd,
4572            s,
4573            stream,
4574        )?;
4575
4576        // === Step 8: GQA gradient reduction ===
4577        if heads_per_kv > 1 {
4578            self.reduce_gqa_gradients_nf4(
4579                num_kv_heads,
4580                heads_per_kv,
4581                seq_len,
4582                head_dim,
4583                stream,
4584                scratch,
4585            )?;
4586        }
4587
4588        // === Step 9: Convert batched gradients → interleaved, store in scratch.q/k/v ===
4589        // grad_Q: attn_q_batched [NH, S, HD] → scratch.q [S, q_dim]
4590        batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.q, s, nh, hd, stream)?;
4591
4592        // grad_K: attn_kv_temp2 [NKV, S, HD] → scratch.k [S, kv_hidden]
4593        batched_to_interleaved_forward(&scratch.attn_kv_temp2, &mut scratch.k, s, nkv, hd, stream)?;
4594
4595        // grad_V: attn_kv_temp [NKV, S, HD] → scratch.v [S, kv_hidden]
4596        // (After GQA reduction, grad_V was reduced from [NH] to [NKV] heads)
4597        batched_to_interleaved_forward(&scratch.attn_kv_temp, &mut scratch.v, s, nkv, hd, stream)?;
4598
4599        Ok(())
4600    }
4601
4602    /// GQA gradient reduction for NF4 blocks.
4603    /// Reduces grad_K and grad_V from [num_heads, S, HD] to [num_kv_heads, S, HD]
4604    /// by summing across Q heads sharing each KV head.
4605    fn reduce_gqa_gradients_nf4(
4606        &self,
4607        num_kv_heads: usize,
4608        heads_per_kv: usize,
4609        seq_len: usize,
4610        head_dim: usize,
4611        stream: &CudaStream,
4612        scratch: &mut CudaBlockScratch,
4613    ) -> Result<()> {
4614        let chunk = seq_len * head_dim;
4615        for g in 0..num_kv_heads {
4616            let dst_off = g * chunk;
4617            // First Q head in group → initialize destination
4618            let src_off = g * heads_per_kv * chunk;
4619            // Copy first head
4620            {
4621                let src = unsafe {
4622                    GpuBuffer::<f32>::from_raw_parts(
4623                        scratch.attn_kv_temp2.as_ptr() + (src_off * 4) as u64,
4624                        chunk,
4625                    )
4626                };
4627                let mut dst = unsafe {
4628                    GpuBuffer::<f32>::from_raw_parts(
4629                        scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4630                        chunk,
4631                    )
4632                };
4633                if src_off != dst_off {
4634                    unsafe {
4635                        dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4636                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4637                                "GQA K reduce copy: {e}"
4638                            ))
4639                        })?;
4640                    }
4641                }
4642                leak(src);
4643                leak(dst);
4644            }
4645            // Accumulate remaining heads
4646            for h in 1..heads_per_kv {
4647                let add_off = (g * heads_per_kv + h) * chunk;
4648                let src = unsafe {
4649                    GpuBuffer::<f32>::from_raw_parts(
4650                        scratch.attn_kv_temp2.as_ptr() + (add_off * 4) as u64,
4651                        chunk,
4652                    )
4653                };
4654                let mut dst = unsafe {
4655                    GpuBuffer::<f32>::from_raw_parts(
4656                        scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4657                        chunk,
4658                    )
4659                };
4660                cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4661                leak(src);
4662                leak(dst);
4663            }
4664            // Same for grad_V (in attn_kv_temp)
4665            {
4666                let src = unsafe {
4667                    GpuBuffer::<f32>::from_raw_parts(
4668                        scratch.attn_kv_temp.as_ptr() + (src_off * 4) as u64,
4669                        chunk,
4670                    )
4671                };
4672                let mut dst = unsafe {
4673                    GpuBuffer::<f32>::from_raw_parts(
4674                        scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4675                        chunk,
4676                    )
4677                };
4678                if src_off != dst_off {
4679                    unsafe {
4680                        dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4681                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4682                                "GQA V reduce copy: {e}"
4683                            ))
4684                        })?;
4685                    }
4686                }
4687                leak(src);
4688                leak(dst);
4689            }
4690            for h in 1..heads_per_kv {
4691                let add_off = (g * heads_per_kv + h) * chunk;
4692                let src = unsafe {
4693                    GpuBuffer::<f32>::from_raw_parts(
4694                        scratch.attn_kv_temp.as_ptr() + (add_off * 4) as u64,
4695                        chunk,
4696                    )
4697                };
4698                let mut dst = unsafe {
4699                    GpuBuffer::<f32>::from_raw_parts(
4700                        scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4701                        chunk,
4702                    )
4703                };
4704                cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4705                leak(src);
4706                leak(dst);
4707            }
4708        }
4709        Ok(())
4710    }
4711
4712    /// Initialize LoRA optimizer state for this block.
4713    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
4714        GpuLoraOptimizerState::new(&self.ctx, &self.config, self.lora_rank)
4715    }
4716
4717    /// LoRA optimizer step: update A_q, B_q, A_v, B_v and norm weights using AdamW.
4718    #[allow(clippy::too_many_arguments)]
4719    pub(crate) fn lora_optimizer_step(
4720        &mut self,
4721        state: &mut GpuLoraOptimizerState,
4722        step: u32,
4723        lr: f32,
4724        beta1: f32,
4725        beta2: f32,
4726        eps: f32,
4727        weight_decay: f32,
4728        stream: &CudaStream,
4729        grad_lora: &CudaLoraGradWorkspace,
4730    ) -> Result<()> {
4731        let h = self.config.hidden_size;
4732        let q_dim = self.config.q_dim();
4733        let kv = self.config.num_kv_heads * self.config.head_dim();
4734        let r = self.lora_rank;
4735
4736        // AdamW step for each LoRA weight
4737        if let Some(ref mut a_q) = self.lora_a_q {
4738            adamw_step_cuda(
4739                a_q,
4740                &grad_lora.grad_lora_a_q,
4741                &mut state.m_lora_a_q,
4742                &mut state.v_lora_a_q,
4743                lr,
4744                beta1,
4745                beta2,
4746                eps,
4747                weight_decay,
4748                step,
4749                saturating_u32(h * r),
4750                stream,
4751            )?;
4752        }
4753        if let Some(ref mut b_q) = self.lora_b_q {
4754            adamw_step_cuda(
4755                b_q,
4756                &grad_lora.grad_lora_b_q,
4757                &mut state.m_lora_b_q,
4758                &mut state.v_lora_b_q,
4759                lr,
4760                beta1,
4761                beta2,
4762                eps,
4763                weight_decay,
4764                step,
4765                saturating_u32(r * q_dim),
4766                stream,
4767            )?;
4768        }
4769        if let Some(ref mut a_v) = self.lora_a_v {
4770            adamw_step_cuda(
4771                a_v,
4772                &grad_lora.grad_lora_a_v,
4773                &mut state.m_lora_a_v,
4774                &mut state.v_lora_a_v,
4775                lr,
4776                beta1,
4777                beta2,
4778                eps,
4779                weight_decay,
4780                step,
4781                saturating_u32(h * r),
4782                stream,
4783            )?;
4784        }
4785        if let Some(ref mut b_v) = self.lora_b_v {
4786            adamw_step_cuda(
4787                b_v,
4788                &grad_lora.grad_lora_b_v,
4789                &mut state.m_lora_b_v,
4790                &mut state.v_lora_b_v,
4791                lr,
4792                beta1,
4793                beta2,
4794                eps,
4795                weight_decay,
4796                step,
4797                saturating_u32(r * kv),
4798                stream,
4799            )?;
4800        }
4801
4802        // AdamW step for norm weights
4803        adamw_step_cuda(
4804            &mut self.input_norm_weight,
4805            &grad_lora.grad_input_norm,
4806            &mut state.m_input_norm,
4807            &mut state.v_input_norm,
4808            lr,
4809            beta1,
4810            beta2,
4811            eps,
4812            weight_decay,
4813            step,
4814            saturating_u32(h),
4815            stream,
4816        )?;
4817        adamw_step_cuda(
4818            &mut self.post_attn_norm_weight,
4819            &grad_lora.grad_post_attn_norm,
4820            &mut state.m_post_attn_norm,
4821            &mut state.v_post_attn_norm,
4822            lr,
4823            beta1,
4824            beta2,
4825            eps,
4826            weight_decay,
4827            step,
4828            saturating_u32(h),
4829            stream,
4830        )?;
4831
4832        Ok(())
4833    }
4834
4835    /// Download LoRA weights from GPU to CPU for checkpoint saving.
4836    ///
4837    /// Returns (A_q, B_q, A_v, B_v) as flat f32 vectors.
4838    /// B matrices are returned WITH the baked-in scale (caller can divide by lora_scale
4839    /// if they need the unscaled version).
4840    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
4841        let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
4842            let mut host = vec![0.0f32; buf.len()];
4843            buf.copy_to_host(&mut host).map_err(|e| {
4844                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4845                    "LoRA weight download failed: {e}"
4846                ))
4847            })?;
4848            Ok(host)
4849        };
4850        let a_q = self.lora_a_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4851        let b_q = self.lora_b_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4852        let a_v = self.lora_a_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4853        let b_v = self.lora_b_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4854        Ok((a_q, b_q, a_v, b_v))
4855    }
4856
4857    /// Upload LoRA weights from CPU to GPU for checkpoint resume (ENT-276).
4858    ///
4859    /// Overwrites the current LoRA adapter buffers with trained weights
4860    /// restored from a checkpoint. Call after `new()` to replace the fresh
4861    /// random init with previously trained adapters.
4862    pub fn upload_lora_weights(
4863        &mut self,
4864        a_q: &[f32],
4865        b_q: &[f32],
4866        a_v: &[f32],
4867        b_v: &[f32],
4868    ) -> Result<()> {
4869        let upload = |buf: &mut GpuBuffer<f32>, data: &[f32], name: &str| -> Result<()> {
4870            if data.len() != buf.len() {
4871                return Err(crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(
4872                    format!(
4873                        "LoRA {name} size mismatch: checkpoint has {} but GPU buffer expects {}",
4874                        data.len(),
4875                        buf.len()
4876                    ),
4877                ));
4878            }
4879            buf.copy_from_host(data).map_err(|e| {
4880                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4881                    "LoRA {name} upload failed: {e}"
4882                ))
4883            })
4884        };
4885        if let Some(ref mut buf) = self.lora_a_q {
4886            upload(buf, a_q, "a_q")?;
4887        }
4888        if let Some(ref mut buf) = self.lora_b_q {
4889            upload(buf, b_q, "b_q")?;
4890        }
4891        if let Some(ref mut buf) = self.lora_a_v {
4892            upload(buf, a_v, "a_v")?;
4893        }
4894        if let Some(ref mut buf) = self.lora_b_v {
4895            upload(buf, b_v, "b_v")?;
4896        }
4897        Ok(())
4898    }
4899}
4900
#[cfg(test)]
mod tests {
    /// Smoke test: with the `cuda` feature enabled, both CUDA block types must be
    /// nameable from this module and have a computable size. Without the feature the
    /// body compiles to nothing and the test passes trivially.
    #[test]
    fn test_cuda_block_compiles() {
        #[cfg(feature = "cuda")]
        {
            let _sizes = (
                std::mem::size_of::<super::CudaTransformerBlock>(),
                std::mem::size_of::<super::CudaNf4TransformerBlock>(),
            );
        }
    }
}