Skip to main content

entrenar/transformer/
cuda_block.rs

1//! CUDA-accelerated Transformer Block (ENT-147 through ENT-152)
2//!
3//! This module provides a fully GPU-accelerated transformer block using trueno-gpu kernels.
4//! All operations run on CUDA to achieve >70% GPU utilization.
5//!
6//! # Phase 22 Implementation Status
7//!
8//! - ENT-147: CUDA RMSNorm integration ✅
9//! - ENT-148: CUDA Softmax integration ✅
10//! - ENT-149: CUDA SiLU activation ✅
11//! - ENT-150: Fused SwiGLU kernel ✅
12//! - ENT-151: CUDA backward pass ✅
13//! - ENT-152: CudaTransformer wrapper ✅
14
15#![allow(dead_code)]
16// SAFETY: This module performs GPU memory transfers via CUDA driver FFI.
17// The unsafe blocks are limited to copy_from_host_async / copy_to_host_async
18// where we guarantee the host buffer outlives the async operation by syncing
19// the stream before the buffer goes out of scope.
20#![allow(unsafe_code)]
21
// PMAT-483/entrenar#328: Per-operation profiling indices for forward pass.
// Must match StepProfiler OP_* constants in step_profiler.rs.
// Each constant is a slot index into `CudaBlockScratch::op_us` (16 slots);
// `CudaBlockScratch::op_end` adds the elapsed microseconds to `op_us[op]`.
#[cfg(feature = "cuda")]
const OP_RMSNORM_ATTN: usize = 0;
#[cfg(feature = "cuda")]
const OP_QKV_GEMM: usize = 1;
#[cfg(feature = "cuda")]
const OP_ATTENTION: usize = 2;
#[cfg(feature = "cuda")]
const OP_O_PROJ: usize = 3;
#[cfg(feature = "cuda")]
const OP_RMSNORM_FFN: usize = 4;
#[cfg(feature = "cuda")]
const OP_GATE_UP_GEMM: usize = 5;
#[cfg(feature = "cuda")]
const OP_SILU: usize = 6;
#[cfg(feature = "cuda")]
const OP_DOWN_GEMM: usize = 7;

// PMAT-483: Per-operation profiling indices for backward pass.
// These continue the forward numbering (8..=15), so forward and backward
// together occupy all 16 slots of `op_us`. Adding a new index requires growing
// that array, because `op_end` ignores any index >= 16.
#[cfg(feature = "cuda")]
const OP_LORA_FWD: usize = 8;
#[cfg(feature = "cuda")]
const OP_DOWN_BWD: usize = 9;
#[cfg(feature = "cuda")]
const OP_SWIGLU_BWD: usize = 10;
#[cfg(feature = "cuda")]
const OP_GATE_UP_BWD: usize = 11;
#[cfg(feature = "cuda")]
const OP_ATTN_BWD: usize = 12;
#[cfg(feature = "cuda")]
const OP_QKV_BWD: usize = 13;
#[cfg(feature = "cuda")]
const OP_NORM_BWD: usize = 14;
#[cfg(feature = "cuda")]
const OP_LORA_BWD: usize = 15;
58
59#[cfg(feature = "cuda")]
60use std::sync::Arc;
61
62#[cfg(feature = "cuda")]
/// Convert a `usize` to `u32`, clamping anything above `u32::MAX` to `u32::MAX`.
#[inline]
fn saturating_u32(v: usize) -> u32 {
    let ceiling = u32::MAX as usize;
    if v > ceiling {
        u32::MAX
    } else {
        // In range, so the narrowing cast is lossless.
        v as u32
    }
}
67
68/// Consume a value without running its destructor (prevents GPU double-free).
69#[cfg(feature = "cuda")]
#[inline]
fn leak<T>(val: T) {
    // Give up ownership without ever invoking `T`'s `Drop` implementation;
    // the resources held by `val` are intentionally never released here.
    std::mem::forget(val);
}
74
75#[cfg(feature = "cuda")]
76use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
77
78#[cfg(feature = "cuda")]
79use crate::autograd::cuda_backward::{
80    batched_softmax_backward, gemm_backward_a, gemm_backward_a_fp16_dispatch,
81    gemm_backward_a_fp16_dispatch_accumulate, gemm_backward_b, rms_norm_backward, silu_backward,
82};
83#[cfg(feature = "cuda")]
84use crate::autograd::cuda_forward::{
85    batched_4d_gemm_forward, batched_rope_neox_backward, batched_rope_neox_forward,
86    batched_softmax_forward, batched_to_interleaved_forward, batched_transpose_forward,
87    cast_f32_to_f16_gpu, elementwise_mul_forward, expand_kv_heads, fused_residual_rmsnorm_forward,
88    fused_swiglu_forward, gemm_f16_to_f32_forward, gemm_forward, interleaved_to_batched_forward,
89    per_head_rmsnorm_forward, residual_add_forward, rms_norm_forward, rms_norm_forward_with_eps,
90    scale_forward, silu_forward,
91};
92#[cfg(feature = "cuda")]
93use crate::autograd::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, squared_sum_cuda};
94#[cfg(feature = "cuda")]
95use crate::autograd::cuda_tensor::Result;
96
97#[cfg(feature = "cuda")]
98use super::config::TransformerConfig;
99
/// CUDA-accelerated transformer block
///
/// All operations run on GPU with minimal CPU<->GPU transfers.
///
/// Owns the device-resident weights of a single layer together with that
/// layer's scratch buffers. Constructed by `CudaTransformerBlock::new`, which
/// uploads the host weight slices it is given.
#[cfg(feature = "cuda")]
pub struct CudaTransformerBlock {
    /// Configuration (a clone of the config passed to `new`)
    config: TransformerConfig,
    /// Layer index
    layer_idx: usize,
    /// Input RMSNorm weight (gamma)
    input_norm_weight: GpuBuffer<f32>,
    /// Post-attention RMSNorm weight (gamma)
    post_attn_norm_weight: GpuBuffer<f32>,
    /// Query projection weight.
    /// NOTE(review): previously documented as (hidden_size x hidden_size), but
    /// the Q scratch output and `CudaGradWorkspace::grad_w_q` are sized with
    /// `q_dim`, which may differ from hidden_size — presumably
    /// (hidden_size x q_dim); confirm against the loader.
    w_q: GpuBuffer<f32>,
    /// Key projection weight (hidden_size x kv_hidden_size)
    w_k: GpuBuffer<f32>,
    /// Value projection weight (hidden_size x kv_hidden_size)
    w_v: GpuBuffer<f32>,
    /// Output projection weight.
    /// NOTE(review): previously documented as (hidden_size x hidden_size);
    /// `CudaGradWorkspace::grad_w_o` is sized `hidden_size * q_dim`, so this is
    /// presumably (q_dim x hidden_size) — confirm against the loader.
    w_o: GpuBuffer<f32>,
    /// FFN gate projection (hidden_size x intermediate_size)
    w_gate: GpuBuffer<f32>,
    /// FFN up projection (hidden_size x intermediate_size)
    w_up: GpuBuffer<f32>,
    /// FFN down projection (intermediate_size x hidden_size)
    w_down: GpuBuffer<f32>,
    /// CUDA context
    ctx: Arc<CudaContext>,
    /// Scratch buffers for intermediate results
    scratch: CudaBlockScratch,
    /// Pre-allocated host zero buffer for zeroing norm grad buffers [hidden_size]
    norm_zero_buf: Vec<f32>,
    /// ENT-270: Q-side QK-norm weight (per-head RMSNorm, shape=[head_dim]).
    /// `None` until `set_qk_norm` is called.
    q_norm_weight: Option<GpuBuffer<f32>>,
    /// ENT-270: K-side QK-norm weight (per-head RMSNorm, shape=[head_dim]).
    /// `None` until `set_qk_norm` is called.
    k_norm_weight: Option<GpuBuffer<f32>>,
    /// FALSIFY-CUDA-FORWARD-PARITY-002: Q projection bias replicated
    /// across max_seq_len rows. Allocated when `config.use_bias == true`
    /// (Qwen2 family). Pre-replicated so `cuda_add_inplace` can apply
    /// the bias broadcast in a single kernel call. Memory: max_seq_len
    /// × q_dim × 4 bytes per layer (e.g., 512 × 896 × 4 = 1.75 MB on
    /// Qwen 0.5B). Pre-fix this field did not exist; Qwen Q/K/V biases
    /// silently dropped → val_loss > ln(vocab) per
    /// `apr-pretrain-cuda-forward-parity-v1.yaml`.
    b_q_replicated: Option<GpuBuffer<f32>>,
    /// K projection bias replicated across max_seq_len rows
    /// (max_seq_len × kv_hidden_size); `None` when no K bias was supplied.
    b_k_replicated: Option<GpuBuffer<f32>>,
    /// V projection bias replicated across max_seq_len rows
    /// (max_seq_len × kv_hidden_size); `None` when no V bias was supplied.
    b_v_replicated: Option<GpuBuffer<f32>>,
}
148
/// Preallocated scratch buffers for transformer forward/backward pass.
///
/// For fp32 blocks: per-layer (backward reads forward activations).
/// For NF4 blocks: shared across all layers (forward-only, no backward).
///
/// Sizes quoted below use `seq_len` for readability; the buffers are actually
/// allocated for the worst case `max_seq_len` passed to `CudaBlockScratch::new`.
///
/// # Contract (C-SCRATCH-001)
///
/// - **Precondition**: Allocated with matching `config` and `max_seq_len`
/// - **Postcondition**: All buffers sized for worst-case `max_seq_len`
/// - **Invariant**: NF4 layers run sequentially — one shared scratch is safe
#[cfg(feature = "cuda")]
pub(crate) struct CudaBlockScratch {
    /// After input RMSNorm (seq_len * hidden_size)
    norm1_out: GpuBuffer<f32>,
    /// Q projection output (seq_len * q_dim; q_dim may differ from hidden_size)
    q: GpuBuffer<f32>,
    /// K projection output (seq_len * kv_hidden_size)
    k: GpuBuffer<f32>,
    /// V projection output (seq_len * kv_hidden_size)
    v: GpuBuffer<f32>,
    /// Attention scores (num_heads * seq_len * seq_len)
    attn_scores: GpuBuffer<f32>,
    /// Attention output (seq_len * q_dim)
    attn_out: GpuBuffer<f32>,
    /// Output projection result (seq_len * hidden_size)
    o_proj_out: GpuBuffer<f32>,
    /// Residual after attention (seq_len * hidden_size)
    residual1: GpuBuffer<f32>,
    /// After post-attention RMSNorm (seq_len * hidden_size)
    norm2_out: GpuBuffer<f32>,
    /// FFN gate output (seq_len * intermediate_size)
    gate_out: GpuBuffer<f32>,
    /// FFN up output (seq_len * intermediate_size)
    up_out: GpuBuffer<f32>,
    /// FFN fused SwiGLU output: SiLU(gate) * up (seq_len * intermediate_size)
    swiglu_out: GpuBuffer<f32>,
    /// FFN down projection output (seq_len * hidden_size)
    ffn_out: GpuBuffer<f32>,
    /// FP16 activation cast buffers for FP16 GEMM dispatch (PMAT-470)
    /// Allocated lazily on first use when FP16_GEMM=1.
    /// The same applies to the three `*_f16` fields that follow; all four start
    /// as `None` in `new`.
    norm1_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 cast of `attn_out` (lazily allocated, see `norm1_out_f16`)
    attn_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 cast of `norm2_out` (lazily allocated, see `norm1_out_f16`)
    norm2_out_f16: Option<GpuBuffer<u16>>,
    /// FP16 cast of `swiglu_out` (lazily allocated, see `norm1_out_f16`)
    swiglu_out_f16: Option<GpuBuffer<u16>>,
    // === Seq-dependent backward scratch (per-layer for activation reuse) ===
    /// Gradient accumulator for hidden states (seq_len * hidden_size)
    grad_hidden: GpuBuffer<f32>,
    /// Gradient for SwiGLU intermediate (seq_len * intermediate_size)
    grad_swiglu: GpuBuffer<f32>,
    // === Attention layout scratch buffers (GPU-only attention pipeline) ===
    /// Q in batched layout [num_heads, seq_len, head_dim]
    attn_q_batched: GpuBuffer<f32>,
    /// K/V layout temp buffer [num_heads, seq_len, head_dim]
    attn_kv_temp: GpuBuffer<f32>,
    /// K transposed / second temp [num_heads, head_dim, seq_len]
    attn_kv_temp2: GpuBuffer<f32>,
    // === Attention backward scratch (seq-dependent) ===
    /// Gradient for attention scores [num_heads * seq_len * seq_len]
    /// Kept separate from attn_scores because softmax backward reads y while writing grad_x.
    /// Allocated as num_heads * seq_len * max(seq_len, head_dim), i.e. it may be
    /// larger than the score matrix when head_dim > seq_len.
    grad_attn_scores: GpuBuffer<f32>,
    // === LoRA scratch buffers (ENT-153: QLoRA) ===
    /// LoRA intermediate: x @ A, sized [max_seq_len * max_lora_rank]
    /// (at least 1 element so the allocation is never zero-sized)
    lora_inter: GpuBuffer<f32>,
    /// LoRA temp for scaled addition, sized [max_seq_len * max_proj_dim]
    /// (reuses largest projection dimension for Q/V LoRA output)
    lora_temp: GpuBuffer<f32>,
    /// Sequential position indices [0, 1, ..., max_seq_len-1] for batched RoPE
    rope_positions: GpuBuffer<u32>,
    /// PMAT-420: Contiguous causal mask for current seq_len [seq_len * seq_len];
    /// entries are 0.0 where col <= row and -inf elsewhere.
    causal_mask_contiguous: GpuBuffer<f32>,
    /// PMAT-420: seq_len that causal_mask_contiguous was generated for (cache key)
    pub(crate) causal_mask_cached_seq_len: usize,
    /// PMAT-483/entrenar#328: Per-operation timing accumulator (microseconds).
    /// Index matches StepProfiler OP_* constants. Accumulated across layers per step.
    /// Zero-overhead when profiling disabled (check op_profiling_enabled first).
    pub(crate) op_us: [u64; 16],
    /// Whether per-op profiling is active this step
    pub(crate) op_profiling_enabled: bool,
}
228
#[cfg(feature = "cuda")]
impl CudaBlockScratch {
    /// PMAT-483: Start timing an operation (no-op if profiling disabled).
    ///
    /// Returns `Some(now)` when `op_profiling_enabled` is set and `None`
    /// otherwise; hand the result to [`Self::op_end`].
    #[inline]
    pub(crate) fn op_begin(&self) -> Option<std::time::Instant> {
        self.op_profiling_enabled.then(std::time::Instant::now)
    }

    /// PMAT-483: Record elapsed time for an operation.
    ///
    /// Adds the microseconds elapsed since `start` to `op_us[op]`. Does nothing
    /// when `start` is `None` (profiling disabled) or when `op` is not a valid
    /// index into `op_us`.
    #[inline]
    pub(crate) fn op_end(&mut self, start: Option<std::time::Instant>, op: usize) {
        // `get_mut` ties the bounds check to the array itself instead of a
        // hard-coded slot count.
        if let (Some(t), Some(slot)) = (start, self.op_us.get_mut(op)) {
            *slot += t.elapsed().as_micros() as u64;
        }
    }

    /// Max sequence length this scratch was allocated for.
    ///
    /// Derived from `norm1_out`, which `new` sizes as `max_seq_len * hidden_size`.
    /// A `hidden_size` of 0 is treated as 1 to avoid dividing by zero.
    pub(crate) fn max_seq_len(&self, hidden_size: usize) -> usize {
        self.norm1_out.len() / hidden_size.max(1)
    }

    /// Zero all forward scratch buffers to prevent backward gradient contamination.
    /// entrenar#318 Tier 1: GPU-side memset via cuMemsetD32Async (no PCIe transfer).
    ///
    /// The backward and LoRA scratch buffers are zeroed as well. Zeroing is
    /// best-effort: a failure on one buffer is ignored and the remaining
    /// buffers are still processed. The causal-mask cache key is reset to 0 so
    /// the next `prepare_causal_mask` call with a non-zero `seq_len` rebuilds
    /// the mask.
    pub(crate) fn zero_forward_buffers(&mut self, stream: &CudaStream) {
        let buffers = [
            &mut self.norm1_out,
            &mut self.q,
            &mut self.k,
            &mut self.v,
            &mut self.attn_scores,
            &mut self.attn_out,
            &mut self.o_proj_out,
            &mut self.residual1,
            &mut self.norm2_out,
            &mut self.gate_out,
            &mut self.up_out,
            &mut self.swiglu_out,
            &mut self.ffn_out,
            &mut self.attn_q_batched,
            &mut self.attn_kv_temp,
            &mut self.attn_kv_temp2,
            &mut self.grad_hidden,
            &mut self.grad_swiglu,
            &mut self.grad_attn_scores,
            &mut self.lora_inter,
            &mut self.lora_temp,
        ];
        for buf in buffers {
            buf.zero_async(stream).ok();
        }
        self.causal_mask_cached_seq_len = 0;
    }

    /// Build a row-major `[seq_len * seq_len]` additive causal mask on the host.
    ///
    /// Entry `(row, col)` is `0.0` when `col <= row` and `-inf` otherwise.
    /// `seq_len == 0` yields an empty vector.
    fn causal_mask_host(seq_len: usize) -> Vec<f32> {
        (0..seq_len * seq_len)
            .map(|idx| {
                let row = idx / seq_len;
                let col = idx % seq_len;
                if col <= row {
                    0.0f32
                } else {
                    f32::NEG_INFINITY
                }
            })
            .collect()
    }

    /// Allocate scratch buffers for a given model config and max sequence length.
    ///
    /// `lora_rank` sizes the LoRA intermediate buffer; both LoRA buffers are
    /// clamped to at least one element.
    ///
    /// # Contract (C-SCRATCH-001)
    ///
    /// All buffer sizes are deterministic from (config, max_seq_len).
    ///
    /// # Errors
    ///
    /// Returns an error if any GPU allocation or host-to-device upload fails.
    pub(crate) fn new(
        config: &TransformerConfig,
        max_seq_len: usize,
        ctx: &Arc<CudaContext>,
        lora_rank: usize,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;
        let num_heads = config.num_attention_heads;
        let head_dim = config.head_dim();

        // LoRA scratch: max(q_dim, kv_hidden) for the largest projection output
        let max_proj_dim = q_dim.max(kv_hidden_size);
        // Minimum 1 element to avoid zero-size GPU allocation
        let lora_inter_size = (max_seq_len * lora_rank).max(1);
        let lora_temp_size = (max_seq_len * max_proj_dim).max(1);

        // C-CAUSAL-001: Precompute causal mask [seq × seq] — shared across all heads
        // 4 MB for seq=1024. Applied per-head in compute_attention_cuda.
        let causal_mask_data = Self::causal_mask_host(max_seq_len);
        Ok(Self {
            norm1_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            q: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            k: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            v: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            attn_scores: GpuBuffer::new(ctx, num_heads * max_seq_len * max_seq_len)?,
            attn_out: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            o_proj_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            residual1: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            norm2_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            gate_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            up_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            swiglu_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            ffn_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            // FP16 cast buffers are created on demand, not here.
            norm1_out_f16: None,
            attn_out_f16: None,
            norm2_out_f16: None,
            swiglu_out_f16: None,
            grad_hidden: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            grad_swiglu: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            attn_q_batched: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp2: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            // Sized for the larger of H*S*S and H*S*head_dim.
            grad_attn_scores: GpuBuffer::new(
                ctx,
                num_heads * max_seq_len * max_seq_len.max(head_dim),
            )?,
            lora_inter: GpuBuffer::new(ctx, lora_inter_size)?,
            lora_temp: GpuBuffer::new(ctx, lora_temp_size)?,
            rope_positions: {
                let positions: Vec<u32> = (0..max_seq_len as u32).collect();
                let mut buf = GpuBuffer::new(ctx, max_seq_len)?;
                buf.copy_from_host(&positions)?;
                buf
            },
            causal_mask_contiguous: GpuBuffer::from_host(ctx, &causal_mask_data)?,
            causal_mask_cached_seq_len: max_seq_len,
            op_us: [0u64; 16],
            op_profiling_enabled: false,
        })
    }

    /// PMAT-420: Prepare a contiguous [seq_len * seq_len] causal mask. Cached: only regenerates
    /// when seq_len changes. Cost: O(seq_len^2) CPU + one H2D upload (~0.01ms for seq=256).
    ///
    /// On regeneration the previous mask buffer is replaced by a freshly
    /// uploaded one and the cache key is updated.
    ///
    /// # Errors
    ///
    /// Returns an error if the upload of the new mask fails; the cached mask
    /// and cache key are left untouched in that case.
    pub(crate) fn prepare_causal_mask(
        &mut self,
        seq_len: usize,
        ctx: &Arc<CudaContext>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        if seq_len == self.causal_mask_cached_seq_len {
            return Ok(());
        }
        let mask_data = Self::causal_mask_host(seq_len);
        self.causal_mask_contiguous = GpuBuffer::from_host(ctx, &mask_data)?;
        self.causal_mask_cached_seq_len = seq_len;
        Ok(())
    }
}
373
/// Shared gradient workspace for weight gradients (one per model, NOT per layer).
///
/// # Contract (C-GRADWS-001)
///
/// Backward processes layers sequentially — only one layer's weight gradients
/// are computed at a time. Sharing this workspace across layers saves
/// `(L-1) * per_layer_grad_weight_elements * 4` bytes of VRAM.
///
/// For Qwen3-4B: saves 35 * 372 MB = 13.0 GB.
///
/// - **Precondition**: Allocated once before training loop starts
/// - **Postcondition**: After backward() for layer i, contains layer i's weight gradients
/// - **Invariant**: Buffer sizes match model config; never reallocated during training
///
/// Sizes in brackets are element counts, as allocated by `CudaGradWorkspace::new`
/// (where `kv_hidden_size = num_kv_heads * head_dim`).
#[cfg(feature = "cuda")]
pub struct CudaGradWorkspace {
    /// Gradient for input norm weight [hidden_size].
    /// Must be zeroed before each backward pass (see `zero_norm_grads`).
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// Gradient for post-attention norm weight [hidden_size].
    /// Must be zeroed before each backward pass (see `zero_norm_grads`).
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
    /// Gradient for FFN gate projection [hidden_size * intermediate_size]
    pub(crate) grad_gate: GpuBuffer<f32>,
    /// Gradient for FFN up projection [hidden_size * intermediate_size]
    pub(crate) grad_up: GpuBuffer<f32>,
    /// Gradient for FFN down projection [intermediate_size * hidden_size]
    pub(crate) grad_down: GpuBuffer<f32>,
    /// Gradient for Q projection weight [q_dim * hidden_size]
    pub(crate) grad_w_q: GpuBuffer<f32>,
    /// Gradient for K projection weight [hidden_size * kv_hidden_size]
    pub(crate) grad_w_k: GpuBuffer<f32>,
    /// Gradient for V projection weight [hidden_size * kv_hidden_size]
    pub(crate) grad_w_v: GpuBuffer<f32>,
    /// Gradient for output projection weight [hidden_size * q_dim]
    pub(crate) grad_w_o: GpuBuffer<f32>,
}
408
#[cfg(feature = "cuda")]
impl CudaGradWorkspace {
    /// Allocate shared gradient workspace for the given model config.
    ///
    /// Called once per training run. GEMM weight gradients are fully overwritten
    /// by each backward pass. Norm gradients use atomicAdd accumulation and MUST
    /// be zeroed before each rms_norm_backward call (see `zero_norm_grads`).
    ///
    /// # Errors
    ///
    /// Returns an error if any GPU allocation fails.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> Result<Self> {
        let h = config.hidden_size;
        let q = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;

        Ok(Self {
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
            grad_gate: GpuBuffer::new(ctx, h * i)?,
            grad_up: GpuBuffer::new(ctx, h * i)?,
            grad_down: GpuBuffer::new(ctx, i * h)?,
            grad_w_q: GpuBuffer::new(ctx, q * h)?,
            grad_w_k: GpuBuffer::new(ctx, h * kv)?,
            grad_w_v: GpuBuffer::new(ctx, h * kv)?,
            grad_w_o: GpuBuffer::new(ctx, h * q)?,
        })
    }

    /// Zero norm gradient buffers before rms_norm_backward calls.
    ///
    /// The BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
    /// so these buffers MUST be zeroed before each backward pass. Without this,
    /// grad_gamma accumulates across steps → exploding norm gradients.
    ///
    /// `zero_buf` is a host slice of zeros supplied by the caller; only the
    /// first `len` elements are used for each buffer, so it may be longer than
    /// needed.
    ///
    /// # Errors
    ///
    /// Returns `TransferFailed` if `zero_buf` is shorter than either norm
    /// gradient buffer (previously this panicked on the slice index) or if a
    /// host-to-device copy fails.
    pub fn zero_norm_grads(&mut self, zero_buf: &[f32]) -> Result<()> {
        Self::overwrite_from_host(&mut self.grad_input_norm, zero_buf, "grad_input_norm")?;
        Self::overwrite_from_host(&mut self.grad_post_attn_norm, zero_buf, "grad_post_attn_norm")
    }

    /// Copy the leading `buf.len()` elements of `zero_buf` into `buf`,
    /// reporting a short source slice or a failed copy as `TransferFailed`.
    fn overwrite_from_host(
        buf: &mut GpuBuffer<f32>,
        zero_buf: &[f32],
        name: &str,
    ) -> Result<()> {
        // Each buffer is measured on its own rather than assuming both norm
        // gradients share one length.
        let needed = buf.len();
        let zeros = zero_buf.get(..needed).ok_or_else(|| {
            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                "Failed to zero {name}: zero buffer has {} elements, need {needed}",
                zero_buf.len()
            ))
        })?;
        buf.copy_from_host(zeros).map_err(|e| {
            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                "Failed to zero {name}: {e:?}"
            ))
        })?;
        Ok(())
    }
}
455
/// GPU-resident AdamW optimizer state for one transformer block.
///
/// Stores first (m) and second (v) moment estimates for all 9 weight tensors:
/// 7 matmul weights + 2 RMSNorm weights. All buffers live on GPU to avoid
/// CPU↔GPU transfers during training.
///
/// Field naming: `m_<weight>` is the first-moment buffer and `v_<weight>` the
/// second-moment buffer for that weight. Checkpointing uses the dotted form
/// of the same names (e.g. `"m.w_q"`, `"v.post_attn_norm"`).
///
/// # Contract (C-OPTSTATE-001)
///
/// - **Precondition**: CUDA context valid, all buffers allocated to match weight dimensions
/// - **Postcondition**: m and v buffers initialized to zero (unbiased start)
/// - **Invariant**: Buffer sizes immutable after creation; m/v never reallocated
#[cfg(feature = "cuda")]
pub struct GpuBlockOptimizerState {
    // Attention projection optimizer states (Q, K, V, O)
    m_w_q: GpuBuffer<f32>,
    v_w_q: GpuBuffer<f32>,
    m_w_k: GpuBuffer<f32>,
    v_w_k: GpuBuffer<f32>,
    m_w_v: GpuBuffer<f32>,
    v_w_v: GpuBuffer<f32>,
    m_w_o: GpuBuffer<f32>,
    v_w_o: GpuBuffer<f32>,
    // FFN projection optimizer states (gate, up, down)
    m_w_gate: GpuBuffer<f32>,
    v_w_gate: GpuBuffer<f32>,
    m_w_up: GpuBuffer<f32>,
    v_w_up: GpuBuffer<f32>,
    m_w_down: GpuBuffer<f32>,
    v_w_down: GpuBuffer<f32>,
    // RMSNorm weight optimizer states (input norm, post-attention norm)
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}
491
/// ALB-118: Host <-> GPU transfer of optimizer state for checkpointing.
#[cfg(feature = "cuda")]
impl GpuBlockOptimizerState {
    /// Copy every moment buffer to the host.
    ///
    /// Returns (suffix, data) pairs for all 18 m/v buffers, in a fixed order.
    /// Suffix is e.g. "m.w_q", "v.w_gate" — caller prefixes with layer index.
    ///
    /// # Errors
    ///
    /// Stops at the first failed device-to-host copy and returns
    /// `TransferFailed` naming the offending buffer.
    pub fn download_to_host(
        &self,
    ) -> crate::autograd::cuda_tensor::Result<Vec<(String, Vec<f32>)>> {
        let sources: [(&str, &GpuBuffer<f32>); 18] = [
            ("m.w_q", &self.m_w_q),
            ("v.w_q", &self.v_w_q),
            ("m.w_k", &self.m_w_k),
            ("v.w_k", &self.v_w_k),
            ("m.w_v", &self.m_w_v),
            ("v.w_v", &self.v_w_v),
            ("m.w_o", &self.m_w_o),
            ("v.w_o", &self.v_w_o),
            ("m.w_gate", &self.m_w_gate),
            ("v.w_gate", &self.v_w_gate),
            ("m.w_up", &self.m_w_up),
            ("v.w_up", &self.v_w_up),
            ("m.w_down", &self.m_w_down),
            ("v.w_down", &self.v_w_down),
            ("m.input_norm", &self.m_input_norm),
            ("v.input_norm", &self.v_input_norm),
            ("m.post_attn_norm", &self.m_post_attn_norm),
            ("v.post_attn_norm", &self.v_post_attn_norm),
        ];

        let mut downloaded = Vec::with_capacity(sources.len());
        for (name, buf) in sources {
            let mut host = vec![0.0f32; buf.len()];
            buf.copy_to_host(&mut host).map_err(|e| {
                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                    "optimizer D2H {name}: {e}"
                ))
            })?;
            downloaded.push((name.to_string(), host));
        }
        Ok(downloaded)
    }

    /// ALB-118: Upload host optimizer state to GPU (checkpoint resume).
    ///
    /// Keys use the same suffixes that `download_to_host` produces. Missing
    /// keys are silently skipped (buffer keeps its current contents).
    /// NOTE(review): an entry whose length differs from the GPU buffer is
    /// skipped just as silently — confirm that a shape mismatch on resume
    /// should not be reported to the caller.
    ///
    /// # Errors
    ///
    /// Stops at the first failed host-to-device copy and returns
    /// `TransferFailed` naming the offending buffer.
    pub fn restore_from_host(
        &mut self,
        data: &std::collections::HashMap<String, Vec<f32>>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        let targets: [(&str, &mut GpuBuffer<f32>); 18] = [
            ("m.w_q", &mut self.m_w_q),
            ("v.w_q", &mut self.v_w_q),
            ("m.w_k", &mut self.m_w_k),
            ("v.w_k", &mut self.v_w_k),
            ("m.w_v", &mut self.m_w_v),
            ("v.w_v", &mut self.v_w_v),
            ("m.w_o", &mut self.m_w_o),
            ("v.w_o", &mut self.v_w_o),
            ("m.w_gate", &mut self.m_w_gate),
            ("v.w_gate", &mut self.v_w_gate),
            ("m.w_up", &mut self.m_w_up),
            ("v.w_up", &mut self.v_w_up),
            ("m.w_down", &mut self.m_w_down),
            ("v.w_down", &mut self.v_w_down),
            ("m.input_norm", &mut self.m_input_norm),
            ("v.input_norm", &mut self.v_input_norm),
            ("m.post_attn_norm", &mut self.m_post_attn_norm),
            ("v.post_attn_norm", &mut self.v_post_attn_norm),
        ];

        for (name, buf) in targets {
            match data.get(name) {
                // Only upload when the checkpoint entry matches the buffer size.
                Some(host_data) if host_data.len() == buf.len() => {
                    buf.copy_from_host(host_data).map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "optimizer H2D {name}: {e}"
                        ))
                    })?;
                }
                // Absent or wrongly sized entry: leave the buffer untouched.
                _ => {}
            }
        }
        Ok(())
    }
}
575
576#[cfg(feature = "cuda")]
577impl CudaTransformerBlock {
578    /// Create a new CUDA transformer block from CPU tensors
579    ///
580    /// Uploads all weights to GPU memory.
581    #[allow(clippy::too_many_arguments)]
582    pub fn new(
583        config: &TransformerConfig,
584        layer_idx: usize,
585        ctx: Arc<CudaContext>,
586        input_norm_weight: &[f32],
587        post_attn_norm_weight: &[f32],
588        w_q: &[f32],
589        w_k: &[f32],
590        w_v: &[f32],
591        w_o: &[f32],
592        w_gate: &[f32],
593        w_up: &[f32],
594        w_down: &[f32],
595        max_seq_len: usize,
596        // FALSIFY-CUDA-FORWARD-PARITY-002: Optional Q/K/V projection
597        // biases (Qwen2 family has them; Llama doesn't). When present,
598        // each is replicated across max_seq_len rows so the existing
599        // `cuda_add_inplace` (residual_add) can apply the broadcast
600        // in one kernel call.
601        b_q: Option<&[f32]>,
602        b_k: Option<&[f32]>,
603        b_v: Option<&[f32]>,
604    ) -> Result<Self> {
605        let hidden_size = config.hidden_size;
606        let q_dim = config.q_dim(); // num_heads * head_dim (may differ from hidden_size)
607        let kv_hidden_size = config.num_kv_heads * config.head_dim();
608        let intermediate_size = config.intermediate_size;
609        let num_heads = config.num_attention_heads;
610
611        // Upload weights to GPU
612        let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
613        let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
614        let w_q = GpuBuffer::from_host(&ctx, w_q)?;
615        let w_k = GpuBuffer::from_host(&ctx, w_k)?;
616        let w_v = GpuBuffer::from_host(&ctx, w_v)?;
617        let w_o = GpuBuffer::from_host(&ctx, w_o)?;
618        let w_gate = GpuBuffer::from_host(&ctx, w_gate)?;
619        let w_up = GpuBuffer::from_host(&ctx, w_up)?;
620        let w_down = GpuBuffer::from_host(&ctx, w_down)?;
621
622        // C-CAUSAL-001: Precompute causal mask for NF4 path
623        let single_mask: Vec<f32> = (0..max_seq_len * max_seq_len)
624            .map(|idx| {
625                let row = idx / max_seq_len;
626                let col = idx % max_seq_len;
627                if col <= row {
628                    0.0f32
629                } else {
630                    f32::NEG_INFINITY
631                }
632            })
633            .collect();
634        // Allocate scratch buffers — Q and attn_out need q_dim, not hidden_size
635        let scratch = CudaBlockScratch {
636            norm1_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
637            q: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
638            k: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
639            v: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
640            attn_scores: GpuBuffer::new(&ctx, num_heads * max_seq_len * max_seq_len)?,
641            attn_out: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
642            o_proj_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
643            residual1: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
644            norm2_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
645            gate_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
646            up_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
647            swiglu_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
648            ffn_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
649            norm1_out_f16: None,
650            attn_out_f16: None,
651            norm2_out_f16: None,
652            swiglu_out_f16: None,
653            // Seq-dependent backward scratch
654            grad_hidden: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
655            grad_swiglu: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
656            // Attention layout scratch (all sized for num_heads, handles GQA expansion)
657            attn_q_batched: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
658            attn_kv_temp: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
659            attn_kv_temp2: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
660            // Attention backward gradient buffers (ENT-151b)
661            // grad_attn_scores needs max(H*S*S, H*S*hd) for buffer reuse safety
662            grad_attn_scores: GpuBuffer::new(
663                &ctx,
664                num_heads * max_seq_len * max_seq_len.max(config.head_dim()),
665            )?,
666            // LoRA scratch (unused for fp32 blocks, minimum allocation)
667            lora_inter: GpuBuffer::new(&ctx, 1)?,
668            lora_temp: GpuBuffer::new(&ctx, 1)?,
669            rope_positions: {
670                let positions: Vec<u32> = (0..max_seq_len as u32).collect();
671                let mut buf = GpuBuffer::new(&ctx, max_seq_len)?;
672                buf.copy_from_host(&positions)?;
673                buf
674            },
675            causal_mask_contiguous: GpuBuffer::from_host(&ctx, &single_mask)?,
676            causal_mask_cached_seq_len: max_seq_len,
677            op_us: [0u64; 16],
678            op_profiling_enabled: false,
679        };
680
681        // FALSIFY-CUDA-FORWARD-PARITY-002: replicate Q/K/V biases across
682        // max_seq_len rows when use_bias=true. Each replicated buffer
683        // is used by `cuda_add_inplace` after the corresponding gemm to
684        // apply the broadcast bias. Pre-fix this allocation didn't
685        // exist; biases dropped silently on Qwen-init paths.
686        let replicate = |bias: Option<&[f32]>, dim: usize| -> Result<Option<GpuBuffer<f32>>> {
687            match bias {
688                Some(slice) => {
689                    debug_assert_eq!(
690                        slice.len(),
691                        dim,
692                        "bias slice len {} != expected dim {dim}",
693                        slice.len()
694                    );
695                    let mut repl: Vec<f32> = Vec::with_capacity(max_seq_len * dim);
696                    for _ in 0..max_seq_len {
697                        repl.extend_from_slice(slice);
698                    }
699                    Ok(Some(GpuBuffer::from_host(&ctx, &repl)?))
700                }
701                None => Ok(None),
702            }
703        };
704        let b_q_replicated = replicate(b_q, q_dim)?;
705        let b_k_replicated = replicate(b_k, kv_hidden_size)?;
706        let b_v_replicated = replicate(b_v, kv_hidden_size)?;
707
708        Ok(Self {
709            config: config.clone(),
710            layer_idx,
711            input_norm_weight,
712            post_attn_norm_weight,
713            w_q,
714            w_k,
715            w_v,
716            w_o,
717            w_gate,
718            w_up,
719            w_down,
720            ctx,
721            scratch,
722            norm_zero_buf: vec![0.0f32; hidden_size],
723            q_norm_weight: None, // ENT-270: set via set_qk_norm() after construction
724            k_norm_weight: None,
725            b_q_replicated,
726            b_k_replicated,
727            b_v_replicated,
728        })
729    }
730
731    /// Set QK-norm weights (ENT-270). Called after construction when loading Qwen3 models.
732    #[allow(dead_code)]
733    pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()> {
734        self.q_norm_weight = Some(GpuBuffer::from_host(&self.ctx, q_norm)?);
735        self.k_norm_weight = Some(GpuBuffer::from_host(&self.ctx, k_norm)?);
736        Ok(())
737    }
738
739    /// Forward pass - all operations on GPU
740    ///
741    /// # Arguments
742    /// * `input` - Input tensor on GPU (seq_len * hidden_size)
743    /// * `output` - Output tensor on GPU (seq_len * hidden_size)
744    /// * `seq_len` - Sequence length
745    /// * `stream` - CUDA stream for async execution
746    pub fn forward(
747        &mut self,
748        input: &GpuBuffer<f32>,
749        output: &mut GpuBuffer<f32>,
750        seq_len: usize,
751        stream: &CudaStream,
752    ) -> Result<()> {
753        let hidden_size = self.config.hidden_size;
754        let q_dim = self.config.q_dim();
755        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
756        let intermediate_size = self.config.intermediate_size;
757
758        // === Pre-attention RMSNorm (ENT-147) ===
759        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: thread `config.rms_norm_eps`
760        // through to the kernel so Qwen2 (1e-6) and Llama (1e-5) get the
761        // correct epsilon. Pre-fix this hardcoded 1e-5 for everyone.
762        rms_norm_forward_with_eps(
763            input,
764            &self.input_norm_weight,
765            &mut self.scratch.norm1_out,
766            saturating_u32(seq_len),
767            saturating_u32(hidden_size),
768            self.config.rms_norm_eps,
769            stream,
770        )?;
771
772        // === Q, K, V Projections (CUDA GEMM) ===
773        // C[seq,q_dim] = A[seq,hidden] @ B[hidden,q_dim]
774        gemm_forward(
775            &self.scratch.norm1_out,
776            &self.w_q,
777            &mut self.scratch.q,
778            saturating_u32(seq_len),
779            saturating_u32(hidden_size),
780            saturating_u32(q_dim),
781            stream,
782        )?;
783        // FALSIFY-CUDA-FORWARD-PARITY-003: apply Q bias broadcast when
784        // use_bias=true (Qwen2 family). The replicated bias buffer is
785        // allocated at block-construction time; here we add the first
786        // `seq_len * q_dim` elements element-wise.
787        if let Some(b_q_repl) = self.b_q_replicated.as_ref() {
788            cuda_add_inplace(&mut self.scratch.q, b_q_repl, seq_len * q_dim, stream)?;
789        }
790
791        gemm_forward(
792            &self.scratch.norm1_out,
793            &self.w_k,
794            &mut self.scratch.k,
795            saturating_u32(seq_len),
796            saturating_u32(hidden_size),
797            saturating_u32(kv_hidden_size),
798            stream,
799        )?;
800        if let Some(b_k_repl) = self.b_k_replicated.as_ref() {
801            cuda_add_inplace(&mut self.scratch.k, b_k_repl, seq_len * kv_hidden_size, stream)?;
802        }
803
804        gemm_forward(
805            &self.scratch.norm1_out,
806            &self.w_v,
807            &mut self.scratch.v,
808            saturating_u32(seq_len),
809            saturating_u32(hidden_size),
810            saturating_u32(kv_hidden_size),
811            stream,
812        )?;
813        if let Some(b_v_repl) = self.b_v_replicated.as_ref() {
814            cuda_add_inplace(&mut self.scratch.v, b_v_repl, seq_len * kv_hidden_size, stream)?;
815        }
816
817        // === Multi-Head Attention (GPU-only, zero CPU transfers) ===
818        self.compute_attention_cuda(seq_len, stream)?;
819
820        // === Output Projection ===
821        // C[seq,hidden] = A[seq,q_dim] @ B[q_dim,hidden]
822        gemm_forward(
823            &self.scratch.attn_out,
824            &self.w_o,
825            &mut self.scratch.o_proj_out,
826            saturating_u32(seq_len),
827            saturating_u32(q_dim),
828            saturating_u32(hidden_size),
829            stream,
830        )?;
831
832        // === Residual Add (input + attention_output) ===
833        cuda_add(
834            input,
835            &self.scratch.o_proj_out,
836            &mut self.scratch.residual1,
837            seq_len * hidden_size,
838            stream,
839        )?;
840
841        // === Post-attention RMSNorm ===
842        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: see pre-attn note above.
843        rms_norm_forward_with_eps(
844            &self.scratch.residual1,
845            &self.post_attn_norm_weight,
846            &mut self.scratch.norm2_out,
847            saturating_u32(seq_len),
848            saturating_u32(hidden_size),
849            self.config.rms_norm_eps,
850            stream,
851        )?;
852
853        // === FFN: Gate + Up Projections ===
854        gemm_forward(
855            &self.scratch.norm2_out,
856            &self.w_gate,
857            &mut self.scratch.gate_out,
858            saturating_u32(seq_len),
859            saturating_u32(hidden_size),
860            saturating_u32(intermediate_size),
861            stream,
862        )?;
863
864        gemm_forward(
865            &self.scratch.norm2_out,
866            &self.w_up,
867            &mut self.scratch.up_out,
868            saturating_u32(seq_len),
869            saturating_u32(hidden_size),
870            saturating_u32(intermediate_size),
871            stream,
872        )?;
873
874        // === FFN: Fused SwiGLU (ENT-150) - SiLU(gate) * up in single kernel ===
875        fused_swiglu_forward(
876            &self.scratch.gate_out,
877            &self.scratch.up_out,
878            &mut self.scratch.swiglu_out,
879            saturating_u32(seq_len * intermediate_size),
880            stream,
881        )?;
882
883        // === FFN: Down Projection ===
884        gemm_forward(
885            &self.scratch.swiglu_out,
886            &self.w_down,
887            &mut self.scratch.ffn_out,
888            saturating_u32(seq_len),
889            saturating_u32(intermediate_size),
890            saturating_u32(hidden_size),
891            stream,
892        )?;
893
894        // === Final Residual Add (residual1 + ffn_output) ===
895        cuda_add(
896            &self.scratch.residual1,
897            &self.scratch.ffn_out,
898            output,
899            seq_len * hidden_size,
900            stream,
901        )?;
902
903        Ok(())
904    }
905
906    /// Compute multi-head attention entirely on GPU (zero CPU transfers)
907    ///
908    /// # Contract (C-ATTN-001)
909    ///
910    /// - **Precondition**: Q [seq, hidden], K [seq, kv_hidden], V [seq, kv_hidden] on GPU
911    /// - **Postcondition**: attn_out [seq, hidden] = concat(head_0..head_H) where
912    ///   head_h = softmax(Q_h @ K_{kv(h)}^T / √d_k) @ V_{kv(h)}
913    /// - **Invariant**: Zero gpu_to_vec / vec_to_gpu calls; numerically equivalent to CPU
914    ///
915    /// Uses existing trueno-gpu kernels:
916    /// - `InterleavedToBatchedKernel` for Q/K/V layout conversion
917    /// - `BatchedTransposeKernel` for K^T
918    /// - `Batched4DGemmKernel` for Q@K^T and attn@V
919    /// - `ScaleKernel` for 1/√d_k scaling
920    /// - `BatchedSoftmaxKernel` for row-wise softmax
921    /// - `BatchedToInterleavedKernel` for output layout conversion
922    /// - D2D copies for GQA head expansion
923    fn compute_attention_cuda(&mut self, seq_len: usize, stream: &CudaStream) -> Result<()> {
924        let num_heads = self.config.num_attention_heads;
925        let num_kv_heads = self.config.num_kv_heads;
926        let head_dim = self.config.head_dim();
927        let heads_per_kv = num_heads / num_kv_heads;
928        let scale = 1.0 / (head_dim as f32).sqrt();
929
930        let seq = saturating_u32(seq_len);
931        let nh = saturating_u32(num_heads);
932        let nkv = saturating_u32(num_kv_heads);
933        let hd = saturating_u32(head_dim);
934
935        // PMAT-420: Ensure causal mask is contiguous for current seq_len.
936        self.scratch.prepare_causal_mask(seq_len, &self.ctx)?;
937
938        // ── ENT-270: Apply QK-norm (per-head RMSNorm) on Q and K ──────────
939        // SAFETY: In-place GPU operations — CUDA kernels read all input before writing output.
940        // Rust borrow checker cannot verify GPU kernel memory access patterns, so we use
941        // raw pointer reborrow to allow the same buffer as both input and output.
942        if let Some(ref q_norm) = self.q_norm_weight {
943            for pos in 0..seq_len {
944                let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
945                per_head_rmsnorm_forward(q_ref, q_norm, &mut self.scratch.q, nh, hd, pos, stream)?;
946            }
947        }
948        if let Some(ref k_norm) = self.k_norm_weight {
949            for pos in 0..seq_len {
950                let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
951                per_head_rmsnorm_forward(k_ref, k_norm, &mut self.scratch.k, nkv, hd, pos, stream)?;
952            }
953        }
954
955        // ── ENT-270: Apply RoPE (NeoX half-rotation) on Q and K ──────────
956        // ALB-119: Batched launch (2 kernels) replaces per-position loop (2*seq_len kernels)
957        let rope_theta = self.config.rope_theta;
958        {
959            let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
960            batched_rope_neox_forward(
961                q_ref,
962                &mut self.scratch.q,
963                &self.scratch.rope_positions,
964                nh,
965                hd,
966                seq,
967                rope_theta,
968                stream,
969            )?;
970            let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
971            batched_rope_neox_forward(
972                k_ref,
973                &mut self.scratch.k,
974                &self.scratch.rope_positions,
975                nkv,
976                hd,
977                seq,
978                rope_theta,
979                stream,
980            )?;
981        }
982
983        // Step 1: Q interleaved [seq, num_heads * head_dim] → batched [num_heads, seq, head_dim]
984        interleaved_to_batched_forward(
985            &self.scratch.q,
986            &mut self.scratch.attn_q_batched,
987            seq,
988            nh,
989            hd,
990            stream,
991        )?;
992
993        // Step 2: K interleaved [seq, num_kv_heads * head_dim] → batched [num_kv_heads, seq, head_dim]
994        interleaved_to_batched_forward(
995            &self.scratch.k,
996            &mut self.scratch.attn_kv_temp,
997            seq,
998            nkv,
999            hd,
1000            stream,
1001        )?;
1002
1003        // Step 3: GQA expansion + transpose for K
1004        if heads_per_kv == 1 {
1005            // MHA: transpose directly [num_heads, seq, head_dim] → [num_heads, head_dim, seq]
1006            batched_transpose_forward(
1007                &self.scratch.attn_kv_temp,
1008                &mut self.scratch.attn_kv_temp2,
1009                nh,
1010                seq,
1011                hd,
1012                stream,
1013            )?;
1014        } else {
1015            // GQA: expand [num_kv_heads, seq, hd] → [num_heads, seq, hd] in attn_kv_temp2
1016            expand_kv_heads(
1017                &self.scratch.attn_kv_temp,
1018                &mut self.scratch.attn_kv_temp2,
1019                num_kv_heads,
1020                heads_per_kv,
1021                seq_len * head_dim,
1022                stream,
1023            )?;
1024            // Transpose expanded K: [num_heads, seq, hd] → [num_heads, hd, seq] in attn_kv_temp
1025            batched_transpose_forward(
1026                &self.scratch.attn_kv_temp2,
1027                &mut self.scratch.attn_kv_temp,
1028                nh,
1029                seq,
1030                hd,
1031                stream,
1032            )?;
1033            // Move K^T to attn_kv_temp2 for consistent naming below
1034            // (swap pointers via D2D copy — attn_kv_temp → attn_kv_temp2)
1035            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1036            unsafe {
1037                self.scratch
1038                    .attn_kv_temp2
1039                    .copy_from_buffer_async(&self.scratch.attn_kv_temp, stream)
1040                    .map_err(|e| {
1041                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1042                            "K^T buffer copy failed: {e}"
1043                        ))
1044                    })?;
1045            }
1046        }
1047
1048        // Step 4: Q @ K^T → attn_scores [num_heads, seq, seq]
1049        // attn_q_batched: [1, num_heads, seq, head_dim]
1050        // attn_kv_temp2:  [1, num_heads, head_dim, seq] (K transposed)
1051        // attn_scores:    [1, num_heads, seq, seq]
1052        batched_4d_gemm_forward(
1053            &self.scratch.attn_q_batched,
1054            &self.scratch.attn_kv_temp2,
1055            &mut self.scratch.attn_scores,
1056            1,
1057            nh,
1058            seq,
1059            seq,
1060            hd,
1061            stream,
1062        )?;
1063
1064        // Step 5: Scale scores by 1/√d_k (in-place)
1065        let total_scores = nh * seq * seq;
1066        {
1067            // SAFETY: In-place aliasing is safe for element-wise operations where each
1068            // element is read before being written. ScaleKernel processes elements
1069            // independently. The view is forgotten to prevent double-free.
1070            let scores_view = unsafe {
1071                GpuBuffer::<f32>::from_raw_parts(
1072                    self.scratch.attn_scores.as_ptr(),
1073                    self.scratch.attn_scores.len(),
1074                )
1075            };
1076            scale_forward(
1077                &scores_view,
1078                &mut self.scratch.attn_scores,
1079                scale,
1080                total_scores,
1081                stream,
1082            )?;
1083            leak(scores_view);
1084        }
1085
1086        // Step 5.5 (C-CAUSAL-001): Apply causal mask — add -inf to future positions
1087        // Loop over heads, adding [seq, seq] mask to each head's scores slice.
1088        // PMAT-420: Use causal_mask_contiguous (correctly strided for seq_len)
1089        // instead of causal_mask (strided at max_seq_len, causes row misalignment
1090        // when seq_len < max_seq_len, leading to NaN after deep layers).
1091        {
1092            let seq_sq = (seq * seq) as usize;
1093            let mask_ptr = self.scratch.causal_mask_contiguous.as_ptr();
1094            let scores_base = self.scratch.attn_scores.as_ptr();
1095            for head in 0..nh as usize {
1096                let byte_offset = (head * seq_sq * 4) as u64; // f32 = 4 bytes
1097                let head_ptr = scores_base + byte_offset;
1098                // SAFETY: mask and scores_slice are non-overlapping GPU regions.
1099                // output aliases scores_slice — safe for element-wise add (read before write).
1100                // Views are leaked to prevent double-free of GPU memory.
1101                let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
1102                let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1103                let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1104                residual_add_forward(&mask_view, &scores_view, &mut out_view, seq * seq, stream)?;
1105                leak(mask_view);
1106                leak(scores_view);
1107                leak(out_view);
1108            }
1109        }
1110
1111        // Step 6: Row-wise softmax → attn_weights [num_heads * seq, seq] (in-place)
1112        let total_rows = nh * seq;
1113        {
1114            // SAFETY: In-place aliasing is safe for BatchedSoftmaxKernel which uses
1115            // shared memory for row-wise reduction. Each row is fully read into shared
1116            // memory before any output is written. The view is forgotten to prevent double-free.
1117            let scores_view = unsafe {
1118                GpuBuffer::<f32>::from_raw_parts(
1119                    self.scratch.attn_scores.as_ptr(),
1120                    self.scratch.attn_scores.len(),
1121                )
1122            };
1123            batched_softmax_forward(
1124                &scores_view,
1125                &mut self.scratch.attn_scores,
1126                total_rows,
1127                seq,
1128                stream,
1129            )?;
1130            leak(scores_view);
1131        }
1132
1133        // Step 7: V layout conversion + GQA expansion
1134        interleaved_to_batched_forward(
1135            &self.scratch.v,
1136            &mut self.scratch.attn_kv_temp,
1137            seq,
1138            nkv,
1139            hd,
1140            stream,
1141        )?;
1142
1143        if heads_per_kv == 1 {
1144            // MHA: V already in [num_heads, seq, head_dim] in attn_kv_temp
1145        } else {
1146            // GQA: expand V [num_kv_heads, seq, hd] → [num_heads, seq, hd]
1147            expand_kv_heads(
1148                &self.scratch.attn_kv_temp,
1149                &mut self.scratch.attn_kv_temp2,
1150                num_kv_heads,
1151                heads_per_kv,
1152                seq_len * head_dim,
1153                stream,
1154            )?;
1155            // Copy expanded V back to attn_kv_temp for the GEMM
1156            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1157            unsafe {
1158                self.scratch
1159                    .attn_kv_temp
1160                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1161                    .map_err(|e| {
1162                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1163                            "V expanded buffer copy failed: {e}"
1164                        ))
1165                    })?;
1166            }
1167        }
1168
1169        // Step 8: attn_weights @ V → attn_result [num_heads, seq, head_dim]
1170        // attn_scores:   [1, num_heads, seq, seq]
1171        // attn_kv_temp:  [1, num_heads, seq, head_dim]
1172        // → attn_q_batched: [1, num_heads, seq, head_dim] (reuse Q buffer)
1173        batched_4d_gemm_forward(
1174            &self.scratch.attn_scores,
1175            &self.scratch.attn_kv_temp,
1176            &mut self.scratch.attn_q_batched,
1177            1,
1178            nh,
1179            seq,
1180            hd,
1181            seq,
1182            stream,
1183        )?;
1184
1185        // Step 9: Convert back to interleaved [seq, num_heads * head_dim] → attn_out
1186        batched_to_interleaved_forward(
1187            &self.scratch.attn_q_batched,
1188            &mut self.scratch.attn_out,
1189            seq,
1190            nh,
1191            hd,
1192            stream,
1193        )?;
1194
1195        Ok(())
1196    }
1197
1198    /// Get layer index
1199    pub fn layer_idx(&self) -> usize {
1200        self.layer_idx
1201    }
1202
1203    /// Get configuration
1204    pub fn config(&self) -> &TransformerConfig {
1205        &self.config
1206    }
1207
1208    /// Backward pass - gradient computation on GPU (ENT-151)
1209    ///
1210    /// Computes gradients for all parameters given upstream gradient.
1211    ///
1212    /// # Arguments
1213    /// * `input` - Original input from forward pass (seq_len * hidden_size)
1214    /// * `grad_output` - Gradient from upstream layer (seq_len * hidden_size)
1215    /// * `grad_input` - Output: gradient w.r.t. input (seq_len * hidden_size)
1216    /// * `seq_len` - Sequence length
1217    /// * `stream` - CUDA stream for async execution
1218    ///
1219    /// # Returns
1220    /// Gradients are accumulated into the scratch buffers:
1221    /// - `scratch.grad_input_norm` - Gradient for input RMSNorm weight
1222    /// - `scratch.grad_post_attn_norm` - Gradient for post-attention RMSNorm weight
1223    /// - `scratch.grad_gate/up/down` - Gradients for FFN weights
1224    /// - `scratch.grad_w_q/w_k/w_v/w_o` - Gradients for attention projection weights
1225    #[provable_contracts_macros::contract("backward-pass-v1", equation = "backward")]
1226    pub fn backward(
1227        &mut self,
1228        input: &GpuBuffer<f32>,
1229        grad_output: &GpuBuffer<f32>,
1230        grad_input: &mut GpuBuffer<f32>,
1231        seq_len: usize,
1232        stream: &CudaStream,
1233        grad_ws: &mut CudaGradWorkspace,
1234    ) -> Result<()> {
1235        let hidden_size = self.config.hidden_size;
1236        let intermediate_size = self.config.intermediate_size;
1237        let eps = 1e-5_f32;
1238
1239        // Zero norm gradient buffers before backward pass.
1240        // BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
1241        // so buffers must be zeroed before each call to prevent cross-step accumulation.
1242        grad_ws.zero_norm_grads(&self.norm_zero_buf)?;
1243
1244        // Backward through final residual: output = residual1 + ffn_output
1245        // grad_output flows to BOTH residual1 (identity skip) and ffn_output path.
1246        self.backward_ffn(grad_output, seq_len, hidden_size, intermediate_size, stream, grad_ws)?;
1247
1248        // Backward through post-attention RMSNorm (FFN path gradient only)
1249        self.backward_post_attn_norm(grad_input, seq_len, hidden_size, eps, stream, grad_ws)?;
1250
1251        // C-RESIDUAL-001 / entrenar#313: Second residual skip gradient.
1252        // Forward: output = residual1 + ffn_output
1253        // The identity skip grad_residual1 = grad_output must bypass the
1254        // post-attention RMSNorm backward entirely — add AFTER norm backward.
1255        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
1256
1257        // Backward through attention: output projection, attention weights, Q/K/V projections
1258        // (ENT-151b: previously missing — attention params received no gradients)
1259        self.backward_attention(grad_input, seq_len, stream, grad_ws)?;
1260
1261        // Backward through first residual connection and input RMSNorm
1262        self.backward_residual_and_input_norm(
1263            input,
1264            grad_output,
1265            grad_input,
1266            seq_len,
1267            hidden_size,
1268            eps,
1269            stream,
1270            grad_ws,
1271        )?;
1272
1273        Ok(())
1274    }
1275
1276    /// Backward through FFN: down projection, SwiGLU, gate+up projections.
1277    ///
1278    /// SwiGLU(gate, up) = silu(gate) * up
1279    /// ∂L/∂gate = ∂L/∂swiglu * up * silu'(gate)
1280    /// ∂L/∂up   = ∂L/∂swiglu * silu(gate)
1281    ///
1282    /// Buffer reuse plan (all [S,I] unless noted):
1283    ///   grad_swiglu  = ∂L/∂swiglu          (computed step 1, read steps 2/4/6)
1284    ///   swiglu_out  → temp1 (step 2)       → silu(gate) (step 4)
1285    ///   up_out      → grad_gate (step 3)
1286    ///   gate_out    → grad_up (step 6)
1287    ///   ffn_out     → grad_norm2_gate [S,H] (step 8)
1288    ///   grad_hidden → grad_norm2_up [S,H]   (step 9)
1289    ///   norm2_out   → accumulated grad [S,H] (step 10)
1290    fn backward_ffn(
1291        &mut self,
1292        grad_output: &GpuBuffer<f32>,
1293        seq_len: usize,
1294        hidden_size: usize,
1295        intermediate_size: usize,
1296        stream: &CudaStream,
1297        grad_ws: &mut CudaGradWorkspace,
1298    ) -> Result<()> {
1299        let n_inter = saturating_u32(seq_len * intermediate_size);
1300        let n_hidden = saturating_u32(seq_len * hidden_size);
1301
1302        // Step 1: grad_swiglu = grad_ffn_out @ w_down^T  [S,I]
1303        gemm_backward_a(
1304            grad_output,
1305            &self.w_down,
1306            &mut self.scratch.grad_swiglu,
1307            saturating_u32(seq_len),
1308            saturating_u32(intermediate_size),
1309            saturating_u32(hidden_size),
1310            stream,
1311        )?;
1312
1313        // Step 2: grad_w_down = swiglu_out^T @ grad_ffn_out  [I,H]
1314        // (swiglu_out free after this)
1315        gemm_backward_b(
1316            &self.scratch.swiglu_out,
1317            grad_output,
1318            &mut grad_ws.grad_down,
1319            saturating_u32(seq_len),
1320            saturating_u32(intermediate_size),
1321            saturating_u32(hidden_size),
1322            stream,
1323        )?;
1324
1325        // === SwiGLU backward: swiglu = silu(gate) * up ===
1326
1327        // Step 3: temp1 = grad_swiglu * up_out → swiglu_out [S,I]
1328        elementwise_mul_forward(
1329            &self.scratch.grad_swiglu,
1330            &self.scratch.up_out,
1331            &mut self.scratch.swiglu_out,
1332            n_inter,
1333            stream,
1334        )?;
1335
1336        // Step 4: grad_gate = silu_backward(gate_out, temp1) → up_out [S,I]
1337        // Computes: (grad_swiglu * up_out) * silu'(gate_out) = correct ∂L/∂gate
1338        silu_backward(
1339            &self.scratch.gate_out,
1340            &self.scratch.swiglu_out,
1341            &mut self.scratch.up_out,
1342            stream,
1343        )?;
1344        // up_out now holds grad_gate [S,I]
1345
1346        // Step 5: silu_gate = silu(gate_out) → swiglu_out [S,I]
1347        silu_forward(&self.scratch.gate_out, &mut self.scratch.swiglu_out, n_inter, stream)?;
1348
1349        // Step 6: grad_up = grad_swiglu * silu_gate → gate_out [S,I]
1350        elementwise_mul_forward(
1351            &self.scratch.grad_swiglu,
1352            &self.scratch.swiglu_out,
1353            &mut self.scratch.gate_out,
1354            n_inter,
1355            stream,
1356        )?;
1357        // gate_out now holds grad_up [S,I]
1358
1359        // === Weight gradients ===
1360
1361        // Step 7a: grad_w_gate = norm2_out^T @ grad_gate (in up_out)  [H,I]
1362        gemm_backward_b(
1363            &self.scratch.norm2_out,
1364            &self.scratch.up_out,
1365            &mut grad_ws.grad_gate,
1366            saturating_u32(seq_len),
1367            saturating_u32(hidden_size),
1368            saturating_u32(intermediate_size),
1369            stream,
1370        )?;
1371
1372        // Step 7b: grad_w_up = norm2_out^T @ grad_up (in gate_out)  [H,I]
1373        gemm_backward_b(
1374            &self.scratch.norm2_out,
1375            &self.scratch.gate_out,
1376            &mut grad_ws.grad_up,
1377            saturating_u32(seq_len),
1378            saturating_u32(hidden_size),
1379            saturating_u32(intermediate_size),
1380            stream,
1381        )?;
1382
1383        // === Input gradient (accumulate gate + up paths) ===
1384
1385        // Step 8: grad_norm2_gate = grad_gate @ w_gate^T → ffn_out [S,H]
1386        gemm_backward_a(
1387            &self.scratch.up_out,
1388            &self.w_gate,
1389            &mut self.scratch.ffn_out,
1390            saturating_u32(seq_len),
1391            saturating_u32(hidden_size),
1392            saturating_u32(intermediate_size),
1393            stream,
1394        )?;
1395
1396        // Step 9: grad_norm2_up = grad_up @ w_up^T → grad_hidden [S,H]
1397        gemm_backward_a(
1398            &self.scratch.gate_out,
1399            &self.w_up,
1400            &mut self.scratch.grad_hidden,
1401            saturating_u32(seq_len),
1402            saturating_u32(hidden_size),
1403            saturating_u32(intermediate_size),
1404            stream,
1405        )?;
1406
1407        // Step 10: norm2_out = grad_norm2_gate + grad_norm2_up  [S,H]
1408        residual_add_forward(
1409            &self.scratch.ffn_out,
1410            &self.scratch.grad_hidden,
1411            &mut self.scratch.norm2_out,
1412            n_hidden,
1413            stream,
1414        )?;
1415
1416        Ok(())
1417    }
1418
1419    /// Backward through post-attention RMSNorm.
1420    fn backward_post_attn_norm(
1421        &mut self,
1422        grad_input: &mut GpuBuffer<f32>,
1423        seq_len: usize,
1424        hidden_size: usize,
1425        eps: f32,
1426        stream: &CudaStream,
1427        grad_ws: &mut CudaGradWorkspace,
1428    ) -> Result<()> {
1429        // D2D copy norm2_out → grad_hidden (avoids D2H + H2D round-trip)
1430        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1431        unsafe {
1432            self.scratch
1433                .grad_hidden
1434                .copy_from_buffer_async(&self.scratch.norm2_out, stream)
1435                .map_err(|e| {
1436                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1437                        "Backward norm D2D copy failed: {e}"
1438                    ))
1439                })?;
1440        }
1441
1442        rms_norm_backward(
1443            &self.scratch.residual1,
1444            &self.post_attn_norm_weight,
1445            &self.scratch.grad_hidden,
1446            grad_input,
1447            &mut grad_ws.grad_post_attn_norm,
1448            saturating_u32(seq_len),
1449            saturating_u32(hidden_size),
1450            eps,
1451            stream,
1452        )
1453    }
1454
1455    /// Backward through multi-head attention (ENT-151b)
1456    ///
1457    /// Reverses the forward attention pipeline:
1458    /// output_proj → layout → attn_weights@V → softmax → scale → Q@K^T → layout → Q/K/V proj
1459    ///
1460    /// # Contract (C-ATTN-BACK-001)
1461    ///
1462    /// - **Precondition**: grad_input contains gradient from post-attention norm backward,
1463    ///   scratch.{q, k, v, attn_scores, attn_out, norm1_out} contain forward pass values
1464    /// - **Postcondition**: grad_hidden contains gradient w.r.t. norm1_out (input to Q/K/V proj),
1465    ///   grad_w_{q,k,v,o} contain weight gradients for attention projections
1466    /// - **Invariant**: Zero CPU-side data transfers; all operations on GPU
1467    fn backward_attention(
1468        &mut self,
1469        grad_input: &mut GpuBuffer<f32>,
1470        seq_len: usize,
1471        stream: &CudaStream,
1472        grad_ws: &mut CudaGradWorkspace,
1473    ) -> Result<()> {
1474        let hidden_size = self.config.hidden_size;
1475        let q_dim = self.config.q_dim();
1476        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
1477        let num_heads = self.config.num_attention_heads;
1478        let num_kv_heads = self.config.num_kv_heads;
1479        let head_dim = self.config.head_dim();
1480        let heads_per_kv = num_heads / num_kv_heads;
1481        let scale = 1.0 / (head_dim as f32).sqrt();
1482
1483        let seq = saturating_u32(seq_len);
1484        let nh = saturating_u32(num_heads);
1485        let nkv = saturating_u32(num_kv_heads);
1486        let hd = saturating_u32(head_dim);
1487
1488        // === Step 4.1: Output projection backward ===
1489        // Forward: o_proj_out[seq,hidden] = attn_out[seq,q_dim] @ w_o[q_dim,hidden]
1490        //   m=seq, k=q_dim, n=hidden
1491        // grad_attn_out[seq,q_dim] = grad_o_proj[seq,hidden] @ w_o^T[hidden,q_dim]
1492        gemm_backward_a(
1493            grad_input,
1494            &self.w_o,
1495            &mut self.scratch.grad_hidden,
1496            seq,
1497            saturating_u32(q_dim),
1498            saturating_u32(hidden_size),
1499            stream,
1500        )?;
1501
1502        // grad_w_o[q_dim,hidden] = attn_out^T[q_dim,seq] @ grad_o_proj[seq,hidden]
1503        gemm_backward_b(
1504            &self.scratch.attn_out,
1505            grad_input,
1506            &mut grad_ws.grad_w_o,
1507            seq,
1508            saturating_u32(q_dim),
1509            saturating_u32(hidden_size),
1510            stream,
1511        )?;
1512
1513        // === Step 4.2: Layout conversion ===
1514        // grad_attn_out [seq, q_dim] → grad_attn_batched [num_heads, seq, head_dim]
1515        // Reuse attn_q_batched for grad_attn_batched
1516        interleaved_to_batched_forward(
1517            &self.scratch.grad_hidden,
1518            &mut self.scratch.attn_q_batched,
1519            seq,
1520            nh,
1521            hd,
1522            stream,
1523        )?;
1524
1525        // === Step 4.3: Backward through attn_weights @ V ===
1526        // Forward was: attn_result = attn_weights @ V_batched
1527        // Reconstruct V_batched from preserved v
1528        interleaved_to_batched_forward(
1529            &self.scratch.v,
1530            &mut self.scratch.attn_kv_temp,
1531            seq,
1532            nkv,
1533            hd,
1534            stream,
1535        )?;
1536
1537        // GQA expand V if needed
1538        if heads_per_kv > 1 {
1539            expand_kv_heads(
1540                &self.scratch.attn_kv_temp,
1541                &mut self.scratch.attn_kv_temp2,
1542                num_kv_heads,
1543                heads_per_kv,
1544                seq_len * head_dim,
1545                stream,
1546            )?;
1547            // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1548            unsafe {
1549                self.scratch
1550                    .attn_kv_temp
1551                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1552                    .map_err(|e| {
1553                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1554                            "Attn backward V expand D2D copy failed: {e}"
1555                        ))
1556                    })?;
1557            }
1558        }
1559        // attn_kv_temp now has V_batched [num_heads, seq, head_dim]
1560
1561        // Transpose V: [num_heads, seq, head_dim] → [num_heads, head_dim, seq]
1562        batched_transpose_forward(
1563            &self.scratch.attn_kv_temp,
1564            &mut self.scratch.attn_kv_temp2,
1565            nh,
1566            seq,
1567            hd,
1568            stream,
1569        )?;
1570        // attn_kv_temp2 = V^T [num_heads, head_dim, seq]
1571
1572        // grad_attn_weights = grad_attn_batched @ V^T → grad_attn_scores [H, seq, seq]
1573        batched_4d_gemm_forward(
1574            &self.scratch.attn_q_batched,
1575            &self.scratch.attn_kv_temp2,
1576            &mut self.scratch.grad_attn_scores,
1577            1,
1578            nh,
1579            seq,
1580            seq,
1581            hd,
1582            stream,
1583        )?;
1584
1585        // grad_V = attn_weights^T @ grad_attn_batched → attn_kv_temp [H, seq, hd]
1586        //
1587        // BUG FIX: Cannot transpose attn_scores [H,S,S] into attn_kv_temp2 [H,S,hd]
1588        // because H*S*S >> H*S*hd when S > hd (e.g. 350M: 4.2M vs 524K = 8× overflow).
1589        //
1590        // Use identity: grad_V = (grad_attn_batched^T @ attn_scores)^T
1591        // All intermediates are [H, hd, S] = [H, S, hd] size — no H*S*S buffer needed.
1592
1593        // Step A: transpose grad_attn_batched [H,S,hd] → [H,hd,S]
1594        batched_transpose_forward(
1595            &self.scratch.attn_q_batched,   // grad_attn_batched [H, S, hd]
1596            &mut self.scratch.attn_kv_temp, // temp: grad_attn_batched^T [H, hd, S]
1597            nh,
1598            seq,
1599            hd,
1600            stream,
1601        )?;
1602
1603        // Step B: GEMM [H,hd,S] @ [H,S,S] → [H,hd,S] (= grad_V^T)
1604        batched_4d_gemm_forward(
1605            &self.scratch.attn_kv_temp,      // grad_attn_batched^T [H, hd, S]
1606            &self.scratch.attn_scores,       // attn_weights [H, S, S]
1607            &mut self.scratch.attn_kv_temp2, // grad_V^T [H, hd, S]
1608            1,
1609            nh,
1610            hd,  // m
1611            seq, // n
1612            seq, // k
1613            stream,
1614        )?;
1615
1616        // Step C: transpose grad_V^T [H,hd,S] → grad_V [H,S,hd]
1617        batched_transpose_forward(
1618            &self.scratch.attn_kv_temp2,    // grad_V^T [H, hd, S]
1619            &mut self.scratch.attn_kv_temp, // grad_V [H, S, hd]
1620            nh,
1621            hd,
1622            seq,
1623            stream,
1624        )?;
1625        // attn_kv_temp = grad_V [num_heads, seq, head_dim]
1626
1627        // === Step 4.4: Softmax backward ===
1628        // attn_scores contains softmax output from forward pass
1629        // In-place: grad_attn_scores is both input (grad_output) and output (grad_input)
1630        // This is safe because the kernel reads all elements in pass 1 before writing in pass 2.
1631        let total_rows = nh * seq;
1632        {
1633            // SAFETY: In-place aliasing is safe for BatchedSoftmaxBackwardKernel which uses
1634            // a two-pass approach: pass 1 reads all y[i]*gy[i] to compute dot product,
1635            // pass 2 writes grad_x[i]. The view is forgotten to prevent double-free.
1636            let grad_scores_view = unsafe {
1637                GpuBuffer::<f32>::from_raw_parts(
1638                    self.scratch.grad_attn_scores.as_ptr(),
1639                    self.scratch.grad_attn_scores.len(),
1640                )
1641            };
1642            batched_softmax_backward(
1643                &self.scratch.attn_scores,
1644                &grad_scores_view,
1645                &mut self.scratch.grad_attn_scores,
1646                total_rows,
1647                seq,
1648                stream,
1649            )?;
1650            leak(grad_scores_view);
1651        }
1652        // grad_attn_scores now contains gradient through softmax
1653
1654        // === Step 4.5: Scale backward ===
1655        // Forward scaled by 1/√d_k, backward is same scale (linear operation)
1656        let total_scores = nh * seq * seq;
1657        {
1658            // SAFETY: In-place aliasing safe for element-wise scale (independent elements).
1659            let scores_view = unsafe {
1660                GpuBuffer::<f32>::from_raw_parts(
1661                    self.scratch.grad_attn_scores.as_ptr(),
1662                    self.scratch.grad_attn_scores.len(),
1663                )
1664            };
1665            scale_forward(
1666                &scores_view,
1667                &mut self.scratch.grad_attn_scores,
1668                scale,
1669                total_scores,
1670                stream,
1671            )?;
1672            leak(scores_view);
1673        }
1674
1675        // === Step 4.6: Backward through Q @ K^T ===
1676        // Forward was: scores = Q_batched @ K^T
1677        // Reconstruct K_batched and expand for GQA
1678        interleaved_to_batched_forward(
1679            &self.scratch.k,
1680            &mut self.scratch.attn_kv_temp2,
1681            seq,
1682            nkv,
1683            hd,
1684            stream,
1685        )?;
1686
1687        if heads_per_kv > 1 {
1688            // SAFETY: Both buffers are valid GPU allocations; attn_q_batched is about to be
1689            // overwritten anyway.
1690            unsafe {
1691                self.scratch
1692                    .attn_q_batched
1693                    .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1694                    .map_err(|e| {
1695                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1696                            "Attn backward K copy for GQA expand failed: {e}"
1697                        ))
1698                    })?;
1699            }
1700            expand_kv_heads(
1701                &self.scratch.attn_q_batched,
1702                &mut self.scratch.attn_kv_temp2,
1703                num_kv_heads,
1704                heads_per_kv,
1705                seq_len * head_dim,
1706                stream,
1707            )?;
1708        }
1709        // attn_kv_temp2 = K_expanded [num_heads, seq, head_dim]
1710
1711        // grad_Q = grad_scores @ K_expanded → attn_q_batched [H, seq, hd]
1712        batched_4d_gemm_forward(
1713            &self.scratch.grad_attn_scores,
1714            &self.scratch.attn_kv_temp2,
1715            &mut self.scratch.attn_q_batched,
1716            1,
1717            nh,
1718            seq,
1719            hd,
1720            seq,
1721            stream,
1722        )?;
1723
1724        // grad_K^T = Q^T @ grad_scores
1725        // First reconstruct Q_batched from preserved q
1726        // Reconstruct Q_batched into o_proj_out (attn_q_batched already overwritten by grad_Q).
1727        interleaved_to_batched_forward(
1728            &self.scratch.q,
1729            &mut self.scratch.o_proj_out, // temp buffer for Q_batched
1730            seq,
1731            nh,
1732            hd,
1733            stream,
1734        )?;
1735
1736        // Transpose Q: [H, seq, hd] → [H, hd, seq]
1737        batched_transpose_forward(
1738            &self.scratch.o_proj_out,
1739            &mut self.scratch.attn_kv_temp2, // reuse for Q^T
1740            nh,
1741            seq,
1742            hd,
1743            stream,
1744        )?;
1745
1746        // grad_K^T = Q^T @ grad_scores → ffn_out as temp [H, hd, seq]
1747        batched_4d_gemm_forward(
1748            &self.scratch.attn_kv_temp2,
1749            &self.scratch.grad_attn_scores,
1750            &mut self.scratch.ffn_out, // reuse as temp for grad_K^T [H, hd, seq]
1751            1,
1752            nh,
1753            hd,
1754            seq,
1755            seq,
1756            stream,
1757        )?;
1758
1759        // Transpose grad_K^T → grad_K: [H, hd, seq] → [H, seq, hd]
1760        batched_transpose_forward(
1761            &self.scratch.ffn_out,
1762            &mut self.scratch.attn_kv_temp2, // grad_K [H, seq, hd]
1763            nh,
1764            hd,
1765            seq,
1766            stream,
1767        )?;
1768
1769        // === Step 4.7: GQA gradient reduction ===
1770        // grad_K and grad_V are in [num_heads, seq, hd], need to reduce to [num_kv_heads, seq, hd]
1771        if heads_per_kv > 1 {
1772            self.reduce_gqa_gradients(num_kv_heads, heads_per_kv, seq_len, head_dim, stream)?;
1773        }
1774
1775        // === Step 4.8: Convert gradients back to interleaved layout ===
1776        // grad_Q: attn_q_batched [H, seq, hd] → o_proj_out [seq, hidden] (interleaved)
1777        batched_to_interleaved_forward(
1778            &self.scratch.attn_q_batched,
1779            &mut self.scratch.o_proj_out,
1780            seq,
1781            nh,
1782            hd,
1783            stream,
1784        )?;
1785
1786        // grad_K: attn_kv_temp2 [nkv, seq, hd] → norm2_out [seq, kv_hidden] (interleaved)
1787        batched_to_interleaved_forward(
1788            &self.scratch.attn_kv_temp2,
1789            &mut self.scratch.norm2_out,
1790            seq,
1791            nkv,
1792            hd,
1793            stream,
1794        )?;
1795
1796        // grad_V: attn_kv_temp [nkv, seq, hd] → ffn_out [seq, kv_hidden] (interleaved)
1797        batched_to_interleaved_forward(
1798            &self.scratch.attn_kv_temp,
1799            &mut self.scratch.ffn_out,
1800            seq,
1801            nkv,
1802            hd,
1803            stream,
1804        )?;
1805
1806        // === Step 4.8b: RoPE backward (inverse rotation) ===
1807        // Forward applied RoPE to Q and K before attention. Backward must undo
1808        // the rotation so projection backward (step 4.9) gets unrotated gradients.
1809        // R^T(-θ): new_x0 = x0*cos + x1*sin, new_x1 = x1*cos - x0*sin
1810        let rope_theta = self.config.rope_theta;
1811        {
1812            // grad_Q in o_proj_out [seq, q_dim] — apply inverse rotation in-place
1813            let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.o_proj_out)) };
1814            batched_rope_neox_backward(
1815                q_ref,
1816                &mut self.scratch.o_proj_out,
1817                &self.scratch.rope_positions,
1818                nh,
1819                hd,
1820                seq,
1821                rope_theta,
1822                stream,
1823            )?;
1824            // grad_K in norm2_out [seq, kv_hidden] — apply inverse rotation in-place
1825            let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.norm2_out)) };
1826            batched_rope_neox_backward(
1827                k_ref,
1828                &mut self.scratch.norm2_out,
1829                &self.scratch.rope_positions,
1830                nkv,
1831                hd,
1832                seq,
1833                rope_theta,
1834                stream,
1835            )?;
1836        }
1837
1838        // === Step 4.9: Q/K/V projection backward ===
1839        // Forward: q[seq,q_dim] = norm1[seq,hidden] @ w_q[hidden,q_dim]
1840        //   m=seq, k=hidden, n=q_dim
1841        // grad_norm1[seq,hidden] = grad_q[seq,q_dim] @ w_q^T[q_dim,hidden]
1842        gemm_backward_a(
1843            &self.scratch.o_proj_out, // grad_q interleaved [seq, q_dim]
1844            &self.w_q,
1845            &mut self.scratch.grad_hidden,
1846            seq,
1847            saturating_u32(hidden_size),
1848            saturating_u32(q_dim),
1849            stream,
1850        )?;
1851
1852        // grad_norm1 += grad_k @ w_k^T
1853        // Forward: k[seq,kv_hidden] = norm1[seq,hidden] @ w_k[hidden,kv_hidden]
1854        //   m=seq, k=hidden, n=kv_hidden
1855        // KAIZEN-057: cuda_add_inplace replaces residual_add_forward + D2D copy
1856        gemm_backward_a(
1857            &self.scratch.norm2_out, // grad_k interleaved
1858            &self.w_k,
1859            &mut self.scratch.grad_attn_scores, // temp for grad_k @ w_k^T
1860            seq,
1861            saturating_u32(hidden_size),
1862            saturating_u32(kv_hidden_size),
1863            stream,
1864        )?;
1865        cuda_add_inplace(
1866            &mut self.scratch.grad_hidden,
1867            &self.scratch.grad_attn_scores,
1868            seq_len * hidden_size,
1869            stream,
1870        )?;
1871
1872        // grad_norm1 += grad_v @ w_v^T
1873        // Forward: v[seq,kv_hidden] = norm1[seq,hidden] @ w_v[hidden,kv_hidden]
1874        //   m=seq, k=hidden, n=kv_hidden
1875        gemm_backward_a(
1876            &self.scratch.ffn_out, // grad_v interleaved
1877            &self.w_v,
1878            &mut self.scratch.grad_attn_scores, // temp for grad_v @ w_v^T
1879            seq,
1880            saturating_u32(hidden_size),
1881            saturating_u32(kv_hidden_size),
1882            stream,
1883        )?;
1884        cuda_add_inplace(
1885            &mut self.scratch.grad_hidden,
1886            &self.scratch.grad_attn_scores,
1887            seq_len * hidden_size,
1888            stream,
1889        )?;
1890
1891        // Weight gradients: grad_w_q[hidden,q_dim] = norm1_out^T[hidden,seq] @ grad_q[seq,q_dim]
1892        gemm_backward_b(
1893            &self.scratch.norm1_out,
1894            &self.scratch.o_proj_out, // grad_q [seq, q_dim]
1895            &mut grad_ws.grad_w_q,
1896            seq,
1897            saturating_u32(hidden_size),
1898            saturating_u32(q_dim),
1899            stream,
1900        )?;
1901
1902        // grad_w_k = norm1_out^T @ grad_k
1903        gemm_backward_b(
1904            &self.scratch.norm1_out,
1905            &self.scratch.norm2_out, // grad_k
1906            &mut grad_ws.grad_w_k,
1907            seq,
1908            saturating_u32(hidden_size),
1909            saturating_u32(kv_hidden_size),
1910            stream,
1911        )?;
1912
1913        // grad_w_v = norm1_out^T @ grad_v
1914        gemm_backward_b(
1915            &self.scratch.norm1_out,
1916            &self.scratch.ffn_out, // grad_v
1917            &mut grad_ws.grad_w_v,
1918            seq,
1919            saturating_u32(hidden_size),
1920            saturating_u32(kv_hidden_size),
1921            stream,
1922        )?;
1923
1924        // Copy grad_hidden → grad_input for downstream (residual backward)
1925        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
1926        unsafe {
1927            grad_input.copy_from_buffer_async(&self.scratch.grad_hidden, stream).map_err(|e| {
1928                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1929                    "Attn backward grad_hidden → grad_input D2D copy failed: {e}"
1930                ))
1931            })?;
1932        }
1933
1934        Ok(())
1935    }
1936
1937    /// Reduce GQA head gradients from [num_heads] to [num_kv_heads] by summing groups.
1938    ///
1939    /// Reads grad_K from `attn_kv_temp2` and grad_V from `attn_kv_temp` (both [H, seq, hd]).
1940    /// Writes reduced grad_K to `attn_kv_temp2` and reduced grad_V to `attn_kv_temp`
1941    /// (both [nkv, seq, hd]).
1942    ///
1943    /// Uses `grad_attn_scores`, `ffn_out`, `o_proj_out`, `grad_hidden` as scratch.
1944    fn reduce_gqa_gradients(
1945        &mut self,
1946        num_kv_heads: usize,
1947        heads_per_kv: usize,
1948        seq_len: usize,
1949        head_dim: usize,
1950        stream: &CudaStream,
1951    ) -> Result<()> {
1952        let elems_per_head = seq_len * head_dim;
1953
1954        // Reduce grad_K: attn_kv_temp2 [H] → grad_attn_scores [nkv]
1955        self.reduce_single_gqa_gradient(true, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1956
1957        // Reduce grad_V: attn_kv_temp [H] → ffn_out [nkv]
1958        self.reduce_single_gqa_gradient(false, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1959
1960        // Copy reduced results to known locations for step 4.8
1961        let kv_elems = num_kv_heads * elems_per_head;
1962        // SAFETY: Valid GPU allocations with sufficient size.
1963        unsafe {
1964            self.scratch
1965                .attn_kv_temp2
1966                .copy_from_buffer_at_async(&self.scratch.grad_attn_scores, 0, 0, kv_elems, stream)
1967                .map_err(|e| {
1968                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1969                        "GQA grad_K reduced final copy failed: {e}"
1970                    ))
1971                })?;
1972            self.scratch
1973                .attn_kv_temp
1974                .copy_from_buffer_at_async(&self.scratch.ffn_out, 0, 0, kv_elems, stream)
1975                .map_err(|e| {
1976                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1977                        "GQA grad_V reduced final copy failed: {e}"
1978                    ))
1979                })?;
1980        }
1981        Ok(())
1982    }
1983
1984    /// Reduce one gradient tensor from [num_heads] to [num_kv_heads] by summing groups.
1985    ///
1986    /// When `is_k=true`: reads from `attn_kv_temp2`, writes to `grad_attn_scores`.
1987    /// When `is_k=false`: reads from `attn_kv_temp`, writes to `ffn_out`.
1988    /// Uses `o_proj_out` and `grad_hidden` as scratch.
1989    fn reduce_single_gqa_gradient(
1990        &mut self,
1991        is_k: bool,
1992        num_kv_heads: usize,
1993        heads_per_kv: usize,
1994        elems_per_head: usize,
1995        stream: &CudaStream,
1996    ) -> Result<()> {
1997        let label = if is_k { "K" } else { "V" };
1998
1999        for kv_h in 0..num_kv_heads {
2000            let dst_offset = kv_h * elems_per_head;
2001            let first_h = kv_h * heads_per_kv;
2002            let src_offset = first_h * elems_per_head;
2003
2004            // Copy first head of group as base
2005            // SAFETY: All offsets are within buffer bounds.
2006            unsafe {
2007                let (dst, src) = if is_k {
2008                    (&mut self.scratch.grad_attn_scores, &self.scratch.attn_kv_temp2)
2009                } else {
2010                    (&mut self.scratch.ffn_out, &self.scratch.attn_kv_temp)
2011                };
2012                dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
2013                    .map_err(|e| {
2014                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2015                            "GQA grad_{label} reduce base copy failed: {e}"
2016                        ))
2017                    })?;
2018            }
2019
2020            // Add remaining heads in group
2021            for rep in 1..heads_per_kv {
2022                let h = kv_h * heads_per_kv + rep;
2023                let h_offset = h * elems_per_head;
2024
2025                // Head extraction into o_proj_out buffer
2026                // SAFETY: Valid GPU allocations with sufficient size.
2027                unsafe {
2028                    let src =
2029                        if is_k { &self.scratch.attn_kv_temp2 } else { &self.scratch.attn_kv_temp };
2030                    self.scratch
2031                        .o_proj_out
2032                        .copy_from_buffer_at_async(src, 0, h_offset, elems_per_head, stream)
2033                        .map_err(|e| {
2034                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2035                                "GQA grad_{label} reduce head copy failed: {e}"
2036                            ))
2037                        })?;
2038                }
2039
2040                // Add: dst[dst_offset..] += o_proj_out[0..elems_per_head]
2041                // SAFETY: Creating non-owning views for arithmetic; forgotten to prevent double-free.
2042                unsafe {
2043                    let dst_buf =
2044                        if is_k { &self.scratch.grad_attn_scores } else { &self.scratch.ffn_out };
2045                    let dst_view = GpuBuffer::<f32>::from_raw_parts(
2046                        dst_buf.as_ptr() + (dst_offset as u64 * 4),
2047                        elems_per_head,
2048                    );
2049                    let src_view = GpuBuffer::<f32>::from_raw_parts(
2050                        self.scratch.o_proj_out.as_ptr(),
2051                        elems_per_head,
2052                    );
2053                    let mut sum_view = GpuBuffer::<f32>::from_raw_parts(
2054                        self.scratch.grad_hidden.as_ptr(),
2055                        elems_per_head,
2056                    );
2057                    residual_add_forward(
2058                        &dst_view,
2059                        &src_view,
2060                        &mut sum_view,
2061                        saturating_u32(elems_per_head),
2062                        stream,
2063                    )?;
2064                    // Copy sum back to dst at dst_offset
2065                    let dst_buf = if is_k {
2066                        &mut self.scratch.grad_attn_scores
2067                    } else {
2068                        &mut self.scratch.ffn_out
2069                    };
2070                    dst_buf
2071                        .copy_from_buffer_at_async(
2072                            &self.scratch.grad_hidden,
2073                            dst_offset,
2074                            0,
2075                            elems_per_head,
2076                            stream,
2077                        )
2078                        .map_err(|e| {
2079                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2080                                "GQA grad_{label} reduce sum copy failed: {e}"
2081                            ))
2082                        })?;
2083                    leak(dst_view);
2084                    leak(src_view);
2085                    leak(sum_view);
2086                }
2087            }
2088        }
2089        Ok(())
2090    }
2091
2092    /// Backward through first residual connection and input RMSNorm.
2093    fn backward_residual_and_input_norm(
2094        &mut self,
2095        input: &GpuBuffer<f32>,
2096        grad_output: &GpuBuffer<f32>,
2097        grad_input: &mut GpuBuffer<f32>,
2098        seq_len: usize,
2099        hidden_size: usize,
2100        eps: f32,
2101        stream: &CudaStream,
2102        grad_ws: &mut CudaGradWorkspace,
2103    ) -> Result<()> {
2104        // Forward was: norm_out = RMSNorm(input); attn_out = Attn(norm_out); residual1 = input + attn_out
2105        // Backward: grad_input = RMSNorm_backward(grad_through_attention) + grad_output
2106        //
2107        // C-RESIDUAL-001 / entrenar#313: The residual skip (grad_output) must be added
2108        // AFTER RMSNorm backward, not before. RMSNorm backward should only transform
2109        // the gradient that flows through the norm/attention path. The identity skip
2110        // bypasses the norm entirely.
2111
2112        // D2D copy grad_input (attention path gradient) to grad_hidden
2113        // SAFETY: Both buffers are valid GPU allocations with matching sizes.
2114        unsafe {
2115            self.scratch.grad_hidden.copy_from_buffer_async(grad_input, stream).map_err(|e| {
2116                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2117                    "Backward residual grad_hidden D2D copy failed: {e}"
2118                ))
2119            })?;
2120        }
2121
2122        // RMSNorm backward: only applied to the attention path gradient
2123        rms_norm_backward(
2124            input,
2125            &self.input_norm_weight,
2126            &self.scratch.grad_hidden,
2127            grad_input,
2128            &mut grad_ws.grad_input_norm,
2129            saturating_u32(seq_len),
2130            saturating_u32(hidden_size),
2131            eps,
2132            stream,
2133        )?;
2134
2135        // NOW add the residual skip: grad_input += grad_output (identity connection)
2136        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)
2137    }
2138
2139    /// Initialize GPU-resident AdamW optimizer state for all block weights.
2140    ///
2141    /// Allocates zero-initialized first and second moment buffers for each of the
2142    /// 9 weight tensors (4 attention projections + 3 FFN projections + 2 RMSNorm).
2143    ///
2144    /// # Contract (C-OPTINIT-001)
2145    ///
2146    /// - **Precondition**: CUDA context is valid, sufficient GPU memory available
2147    /// - **Postcondition**: All m/v buffers are zero-initialized with dimensions
2148    ///   matching the corresponding weight tensors
2149    /// - **Invariant**: Total GPU memory for optimizer state = 2 × sum(weight_sizes) × 4 bytes
2150    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
2151        let hidden = self.config.hidden_size;
2152        let q_dim = self.config.q_dim();
2153        let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
2154        let intermediate = self.config.intermediate_size;
2155
2156        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
2157        // zero memory (cuMemAlloc returns uninitialized VRAM). Uninitialized m/v
2158        // causes v_new = beta2 * GARBAGE which can be negative → sqrt(neg) → NaN.
2159        let z = |n: usize| -> Result<GpuBuffer<f32>> {
2160            Ok(GpuBuffer::from_host(&self.ctx, &vec![0.0f32; n])?)
2161        };
2162        Ok(GpuBlockOptimizerState {
2163            m_w_q: z(q_dim * hidden)?,
2164            v_w_q: z(q_dim * hidden)?,
2165            m_w_k: z(hidden * kv_hidden)?,
2166            v_w_k: z(hidden * kv_hidden)?,
2167            m_w_v: z(hidden * kv_hidden)?,
2168            v_w_v: z(hidden * kv_hidden)?,
2169            m_w_o: z(hidden * q_dim)?,
2170            v_w_o: z(hidden * q_dim)?,
2171            m_w_gate: z(hidden * intermediate)?,
2172            v_w_gate: z(hidden * intermediate)?,
2173            m_w_up: z(hidden * intermediate)?,
2174            v_w_up: z(hidden * intermediate)?,
2175            m_w_down: z(intermediate * hidden)?,
2176            v_w_down: z(intermediate * hidden)?,
2177            m_input_norm: z(hidden)?,
2178            v_input_norm: z(hidden)?,
2179            m_post_attn_norm: z(hidden)?,
2180            v_post_attn_norm: z(hidden)?,
2181        })
2182    }
2183
2184    /// Run GPU-resident AdamW optimizer step on all block weights.
2185    ///
2186    /// Updates weights in-place using gradients computed by `backward()`.
2187    /// All operations run on GPU — zero CPU↔GPU data transfers.
2188    ///
2189    /// # Contract (C-OPTSTEP-001)
2190    ///
2191    /// - **Precondition**: `backward()` completed for this block (scratch grad buffers valid),
2192    ///   `state` initialized via `init_optimizer_state()`, `step > 0`
2193    /// - **Postcondition**: All 9 weight tensors updated by AdamW rule,
2194    ///   m/v states updated with current gradient statistics
2195    /// - **Invariant**: Weight dimensions unchanged; no GPU memory allocated or freed
2196    pub fn optimizer_step(
2197        &mut self,
2198        state: &mut GpuBlockOptimizerState,
2199        step: u32,
2200        lr: f32,
2201        beta1: f32,
2202        beta2: f32,
2203        eps: f32,
2204        weight_decay: f32,
2205        stream: &CudaStream,
2206        grad_ws: &CudaGradWorkspace,
2207    ) -> Result<()> {
2208        debug_assert!(step > 0, "C-OPTSTEP-001: step must be > 0 for bias adjust");
2209
2210        // Pre-capture lengths to avoid borrow conflicts (len is immutable borrow,
2211        // adamw_step_cuda takes mutable borrow on same buffer)
2212        let n_wq = self.w_q.len() as u32;
2213        let n_wk = self.w_k.len() as u32;
2214        let n_wv = self.w_v.len() as u32;
2215        let n_wo = self.w_o.len() as u32;
2216        let n_gate = self.w_gate.len() as u32;
2217        let n_up = self.w_up.len() as u32;
2218        let n_down = self.w_down.len() as u32;
2219        let n_inorm = self.input_norm_weight.len() as u32;
2220        let n_panorm = self.post_attn_norm_weight.len() as u32;
2221
2222        // Attention projection weights
2223        adamw_step_cuda(
2224            &mut self.w_q,
2225            &grad_ws.grad_w_q,
2226            &mut state.m_w_q,
2227            &mut state.v_w_q,
2228            lr,
2229            beta1,
2230            beta2,
2231            eps,
2232            weight_decay,
2233            step,
2234            n_wq,
2235            stream,
2236        )?;
2237        adamw_step_cuda(
2238            &mut self.w_k,
2239            &grad_ws.grad_w_k,
2240            &mut state.m_w_k,
2241            &mut state.v_w_k,
2242            lr,
2243            beta1,
2244            beta2,
2245            eps,
2246            weight_decay,
2247            step,
2248            n_wk,
2249            stream,
2250        )?;
2251        adamw_step_cuda(
2252            &mut self.w_v,
2253            &grad_ws.grad_w_v,
2254            &mut state.m_w_v,
2255            &mut state.v_w_v,
2256            lr,
2257            beta1,
2258            beta2,
2259            eps,
2260            weight_decay,
2261            step,
2262            n_wv,
2263            stream,
2264        )?;
2265        adamw_step_cuda(
2266            &mut self.w_o,
2267            &grad_ws.grad_w_o,
2268            &mut state.m_w_o,
2269            &mut state.v_w_o,
2270            lr,
2271            beta1,
2272            beta2,
2273            eps,
2274            weight_decay,
2275            step,
2276            n_wo,
2277            stream,
2278        )?;
2279
2280        // FFN projection weights
2281        adamw_step_cuda(
2282            &mut self.w_gate,
2283            &grad_ws.grad_gate,
2284            &mut state.m_w_gate,
2285            &mut state.v_w_gate,
2286            lr,
2287            beta1,
2288            beta2,
2289            eps,
2290            weight_decay,
2291            step,
2292            n_gate,
2293            stream,
2294        )?;
2295        adamw_step_cuda(
2296            &mut self.w_up,
2297            &grad_ws.grad_up,
2298            &mut state.m_w_up,
2299            &mut state.v_w_up,
2300            lr,
2301            beta1,
2302            beta2,
2303            eps,
2304            weight_decay,
2305            step,
2306            n_up,
2307            stream,
2308        )?;
2309        adamw_step_cuda(
2310            &mut self.w_down,
2311            &grad_ws.grad_down,
2312            &mut state.m_w_down,
2313            &mut state.v_w_down,
2314            lr,
2315            beta1,
2316            beta2,
2317            eps,
2318            weight_decay,
2319            step,
2320            n_down,
2321            stream,
2322        )?;
2323
2324        // RMSNorm weights
2325        adamw_step_cuda(
2326            &mut self.input_norm_weight,
2327            &grad_ws.grad_input_norm,
2328            &mut state.m_input_norm,
2329            &mut state.v_input_norm,
2330            lr,
2331            beta1,
2332            beta2,
2333            eps,
2334            weight_decay,
2335            step,
2336            n_inorm,
2337            stream,
2338        )?;
2339        adamw_step_cuda(
2340            &mut self.post_attn_norm_weight,
2341            &grad_ws.grad_post_attn_norm,
2342            &mut state.m_post_attn_norm,
2343            &mut state.v_post_attn_norm,
2344            lr,
2345            beta1,
2346            beta2,
2347            eps,
2348            weight_decay,
2349            step,
2350            n_panorm,
2351            stream,
2352        )?;
2353
2354        Ok(())
2355    }
2356
2357    /// Download all weight data from GPU to host vectors.
2358    ///
2359    /// Used to synchronize GPU-updated weights back to CPU model for checkpointing.
2360    ///
2361    /// # Contract (C-DLWEIGHTS-001)
2362    ///
2363    /// - **Precondition**: Block weights are valid GPU allocations
2364    /// - **Postcondition**: Returned vectors have exact same length and content as GPU buffers
2365    /// - **Invariant**: GPU buffers are not modified
2366    pub fn download_weights(&self) -> Result<BlockWeights> {
2367        let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
2368            let mut host = vec![0.0f32; buf.len()];
2369            buf.copy_to_host(&mut host).map_err(|e| {
2370                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2371                    "Weight download failed: {e}"
2372                ))
2373            })?;
2374            Ok(host)
2375        };
2376
2377        Ok(BlockWeights {
2378            w_q: download(&self.w_q)?,
2379            w_k: download(&self.w_k)?,
2380            w_v: download(&self.w_v)?,
2381            w_o: download(&self.w_o)?,
2382            w_gate: download(&self.w_gate)?,
2383            w_up: download(&self.w_up)?,
2384            w_down: download(&self.w_down)?,
2385            input_norm_weight: download(&self.input_norm_weight)?,
2386            post_attn_norm_weight: download(&self.post_attn_norm_weight)?,
2387        })
2388    }
2389}
2390
/// Downloaded weight data from a CUDA transformer block.
///
/// Returned by `download_weights()`, which fills each field from the GPU buffer
/// of the same name. All values are host-side `f32` copies; the struct holds no
/// GPU resources.
///
/// # Contract (C-BLOCKWT-001)
///
/// - **Invariant**: Vector lengths match original weight dimensions
#[cfg(feature = "cuda")]
pub struct BlockWeights {
    /// Host copy of the query projection weight buffer (`w_q`).
    pub w_q: Vec<f32>,
    /// Host copy of the key projection weight buffer (`w_k`).
    pub w_k: Vec<f32>,
    /// Host copy of the value projection weight buffer (`w_v`).
    pub w_v: Vec<f32>,
    /// Host copy of the attention output projection weight buffer (`w_o`).
    pub w_o: Vec<f32>,
    /// Host copy of the FFN gate projection weight buffer (`w_gate`).
    pub w_gate: Vec<f32>,
    /// Host copy of the FFN up projection weight buffer (`w_up`).
    pub w_up: Vec<f32>,
    /// Host copy of the FFN down projection weight buffer (`w_down`).
    pub w_down: Vec<f32>,
    /// Host copy of the input RMSNorm weight buffer.
    pub input_norm_weight: Vec<f32>,
    /// Host copy of the post-attention RMSNorm weight buffer.
    pub post_attn_norm_weight: Vec<f32>,
}
2408
/// Element-wise `output = a + b` computed on the GPU.
///
/// Thin wrapper over `residual_add_forward` (`ResidualAddKernel`): one kernel
/// launch on `stream`, with no D2H/H2D transfers.
///
/// `n` is the number of elements to add; it is narrowed to the kernel's `u32`
/// count via `saturating_u32`.
#[cfg(feature = "cuda")]
fn cuda_add(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let element_count = saturating_u32(n);
    residual_add_forward(a, b, output, element_count, stream)
}
2422
/// In-place add: `target += source` using residual add with aliased output.
///
/// `n` is the number of elements to add; it is narrowed to the kernel's `u32`
/// count via `saturating_u32`. The work is issued through `residual_add_forward`
/// on `stream`, with `target` passed as both the first input and the output.
///
/// # Safety
///
/// The ResidualAdd kernel reads `a[i]` and `b[i]` then writes `output[i] = a[i] + b[i]`.
/// When `a` and `output` alias the same GPU buffer, each element is read before written
/// (no inter-element dependency), so this is safe for elementwise operations.
///
/// NOTE(review): the argument above covers the device-side data race only. On the
/// host side this function manufactures a `&GpuBuffer<f32>` that is live at the same
/// time as the exclusive `&mut GpuBuffer<f32>` it was derived from, which Rust's
/// reference aliasing rules do not allow regardless of what the callee does with
/// them. Consider an entry point that takes raw device pointers (or a dedicated
/// in-place kernel wrapper) so no aliased references are needed — confirm against
/// the `residual_add_forward` implementation.
#[cfg(feature = "cuda")]
pub(crate) fn cuda_add_inplace(
    target: &mut GpuBuffer<f32>,
    source: &GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    // SAFETY: ResidualAdd kernel is elementwise (output[i] = a[i] + b[i]).
    // Aliasing target as both input and output is safe because each element is
    // independent — the GPU reads a[i] before writing output[i] at the same address.
    // (See the NOTE(review) in the doc comment: this justifies the device access
    // pattern, not the coexistence of `target_ref` and `target` on the host.)
    let target_ref: &GpuBuffer<f32> = unsafe { &*std::ptr::from_ref::<GpuBuffer<f32>>(target) };
    residual_add_forward(target_ref, source, target, saturating_u32(n), stream)
}
2443
/// Element-wise `output = a * b` computed on the GPU.
///
/// Thin wrapper over `elementwise_mul_forward` (`ElementwiseMulKernel`): one
/// kernel launch on `stream`, with no D2H/H2D transfers.
///
/// `n` is the number of elements to multiply; it is narrowed to the kernel's
/// `u32` count via `saturating_u32`.
#[cfg(feature = "cuda")]
fn cuda_mul(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    use crate::autograd::cuda_forward::elementwise_mul_forward;

    let element_count = saturating_u32(n);
    elementwise_mul_forward(a, b, output, element_count, stream)
}
2457
// CPU fallback stub
/// Placeholder `CudaTransformerBlock` compiled only when the `cuda` feature is off.
///
/// A unit struct with no data; presumably it exists so that code naming the type
/// still builds without CUDA support.
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerBlock;

#[cfg(not(feature = "cuda"))]
impl CudaTransformerBlock {
    /// Layer index of the stub block; always `0` because the stub stores no state.
    pub fn layer_idx(&self) -> usize {
        0
    }
}
2468
2469// =============================================================================
2470// CudaBlock — enum dispatching fp32 or NF4 transformer blocks
2471// =============================================================================
2472
/// Unified enum for CUDA transformer blocks (fp32 or NF4-quantized).
///
/// The classify pipeline stores `Vec<CudaBlock>` and calls `forward()` without
/// caring which quantization format the frozen weights use.
///
/// Methods on this enum forward to the wrapped block. Operations that only one
/// variant supports return a `KernelError` when invoked on the other variant.
#[cfg(feature = "cuda")]
pub enum CudaBlock {
    /// Standard fp32 weights (full precision, ~16 GB for Qwen3-4B)
    Fp32(CudaTransformerBlock),
    /// NF4 quantized weights (~2 GB for Qwen3-4B, ~8x compression)
    Nf4(CudaNf4TransformerBlock),
}
2484
#[cfg(feature = "cuda")]
impl CudaBlock {
    /// Builds the `KernelError` returned when a method is called on a variant
    /// that does not support it.
    ///
    /// Keeps the error construction in one place so every dispatch arm only
    /// carries its message.
    fn unsupported<T>(message: &'static str) -> Result<T> {
        Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(message.into()))
    }

    /// Forward pass through the transformer block.
    ///
    /// For NF4 blocks, `shared_scratch` must be `Some` — shared across all layers (C-SCRATCH-001).
    /// For fp32 blocks, `shared_scratch` is ignored (each block owns its scratch for backward).
    ///
    /// # Panics
    ///
    /// Panics if this is an NF4 block and `shared_scratch` is `None`.
    pub(crate) fn forward(
        &mut self,
        input: &GpuBuffer<f32>,
        output: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: Option<&mut CudaBlockScratch>,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => block.forward(input, output, seq_len, stream),
            Self::Nf4(block) => {
                let scratch =
                    shared_scratch.expect("C-SCRATCH-001: NF4 blocks require shared scratch");
                block.forward(input, output, seq_len, stream, scratch)
            }
        }
    }

    /// Get the layer index of this block.
    pub fn layer_idx(&self) -> usize {
        match self {
            Self::Fp32(block) => block.layer_idx(),
            Self::Nf4(block) => block.layer_idx,
        }
    }

    /// Backward pass (only supported for fp32 blocks).
    ///
    /// NF4 blocks keep their base weights frozen and return a `KernelError` here;
    /// their LoRA gradient path is exposed separately as `backward_nf4`.
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => {
                block.backward(input, grad_output, grad_input, seq_len, stream, grad_ws)
            }
            Self::Nf4(_) => {
                Self::unsupported("backward not supported on NF4 blocks (frozen weights)")
            }
        }
    }

    /// Initialize optimizer state (only supported for fp32 blocks).
    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
        match self {
            Self::Fp32(block) => block.init_optimizer_state(),
            Self::Nf4(_) => Self::unsupported("init_optimizer_state not supported on NF4 blocks"),
        }
    }

    /// Download weights from GPU (only supported for fp32 blocks).
    pub fn download_weights(&self) -> Result<BlockWeights> {
        match self {
            Self::Fp32(block) => block.download_weights(),
            Self::Nf4(_) => Self::unsupported("download_weights not supported on NF4 blocks"),
        }
    }

    /// Optimizer step (only supported for fp32 blocks).
    ///
    /// All arguments are passed through unchanged to the fp32 block's
    /// `optimizer_step`; NF4 blocks return a `KernelError`.
    #[allow(clippy::too_many_arguments)]
    pub fn optimizer_step(
        &mut self,
        state: &mut GpuBlockOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_ws: &CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => block.optimizer_step(
                state,
                step,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                stream,
                grad_ws,
            ),
            Self::Nf4(_) => {
                Self::unsupported("optimizer_step not supported on NF4 blocks (frozen weights)")
            }
        }
    }

    /// NF4 backward pass with LoRA gradient computation (ENT-153).
    ///
    /// Only callable on NF4 blocks. Returns error for fp32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn backward_nf4(
        &self,
        layer_input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        output_scratch: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: &mut CudaBlockScratch,
        grad_lora: &mut CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.backward(
                layer_input,
                grad_output,
                grad_input,
                output_scratch,
                seq_len,
                stream,
                shared_scratch,
                grad_lora,
            ),
            Self::Fp32(_) => Self::unsupported("backward_nf4 only supported on NF4 blocks"),
        }
    }

    /// Initialize LoRA optimizer state for NF4 blocks.
    ///
    /// Returns error for fp32 blocks.
    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
        match self {
            Self::Nf4(block) => block.init_lora_optimizer_state(),
            Self::Fp32(_) => {
                Self::unsupported("init_lora_optimizer_state only supported on NF4 blocks")
            }
        }
    }

    /// LoRA optimizer step for NF4 blocks.
    ///
    /// Returns error for fp32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn lora_optimizer_step(
        &mut self,
        state: &mut GpuLoraOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_lora: &CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.lora_optimizer_step(
                state,
                step,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                stream,
                grad_lora,
            ),
            Self::Fp32(_) => Self::unsupported("lora_optimizer_step only supported on NF4 blocks"),
        }
    }

    /// Download LoRA weights from NF4 blocks.
    ///
    /// Returns error for fp32 blocks.
    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
        match self {
            Self::Nf4(block) => block.download_lora_weights(),
            Self::Fp32(_) => Self::unsupported("download_lora_weights only supported on NF4 blocks"),
        }
    }

    /// Upload LoRA weights to NF4 blocks for checkpoint resume (ENT-276).
    ///
    /// Returns error for fp32 blocks.
    pub fn upload_lora_weights(
        &mut self,
        a_q: &[f32],
        b_q: &[f32],
        a_v: &[f32],
        b_v: &[f32],
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.upload_lora_weights(a_q, b_q, a_v, b_v),
            Self::Fp32(_) => Self::unsupported("upload_lora_weights only supported on NF4 blocks"),
        }
    }
}
2683
/// CPU fallback stub for CudaBlock.
///
/// Compiled only when the `cuda` feature is disabled. It has a single `Fp32`
/// variant wrapping the stub `CudaTransformerBlock`; there is no NF4 variant in
/// this configuration.
#[cfg(not(feature = "cuda"))]
pub enum CudaBlock {
    /// Wraps the non-CUDA placeholder block.
    Fp32(CudaTransformerBlock),
}
2689
2690// =============================================================================
2691// NF4 Quantized Transformer Block (trueno#108: QLoRA support)
2692// =============================================================================
2693
/// CUDA-accelerated transformer block with NF4-quantized frozen weights.
///
/// Stores the 7 projection weights as packed NF4 (4-bit) + per-block scales instead
/// of fp32, achieving ~8x compression. Norm weights remain fp32 (negligible size).
///
/// NOTE(review): the constructor also uploads a dequantized, transposed fp32 copy
/// of every projection (`w_*_fp32`, ENT-287), and `set_fp16_weights()` can add
/// fp16 copies. The compression figures in this comment therefore describe the
/// NF4 buffers alone, not the block's total footprint — confirm and update.
///
/// # VRAM Savings (Qwen3-4B example)
///
/// | Component | fp32 | NF4 |
/// |-----------|------|-----|
/// | Frozen weights (36L × 7 projections) | 16.0 GB | 2.1 GB |
///
/// # Forward Only
///
/// NF4 blocks are frozen — no backward pass needed. LoRA adapters (fp32) handle
/// the trainable parameters separately. The forward pass uses fused dequant+GEMM
/// kernels that read NF4 directly without materializing fp32 weights.
///
/// NOTE(review): this section looks stale. `CudaBlock::backward_nf4` calls a
/// `backward` method on this type for LoRA gradients, and the `w_*_fp32` fields
/// are documented as feeding cuBLAS GEMM — verify against `forward()`.
#[cfg(feature = "cuda")]
pub struct CudaNf4TransformerBlock {
    // Copy of the transformer configuration supplied to `new` (cloned at construction).
    config: TransformerConfig,
    // Index of this layer, as supplied to `new`; returned by `CudaBlock::layer_idx`.
    layer_idx: usize,
    // Norm weights stay fp32 (tiny: 2 × hidden_size floats)
    input_norm_weight: GpuBuffer<f32>,
    post_attn_norm_weight: GpuBuffer<f32>,
    // Projection weights: NF4 quantized (packed data + per-block scales)
    w_q_nf4: GpuBuffer<u8>,
    w_q_scales: GpuBuffer<f32>,
    w_k_nf4: GpuBuffer<u8>,
    w_k_scales: GpuBuffer<f32>,
    w_v_nf4: GpuBuffer<u8>,
    w_v_scales: GpuBuffer<f32>,
    w_o_nf4: GpuBuffer<u8>,
    w_o_scales: GpuBuffer<f32>,
    w_gate_nf4: GpuBuffer<u8>,
    w_gate_scales: GpuBuffer<f32>,
    w_up_nf4: GpuBuffer<u8>,
    w_up_scales: GpuBuffer<f32>,
    w_down_nf4: GpuBuffer<u8>,
    w_down_scales: GpuBuffer<f32>,
    // ENT-287: Pre-dequantized fp32 weights for cuBLAS GEMM (correct weight layout)
    // Built in `new` by dequantizing the NF4 data and transposing [N,K] -> [K,N].
    // `set_fp16_weights()` later replaces these with 1-element dummy buffers.
    w_q_fp32: GpuBuffer<f32>,
    w_k_fp32: GpuBuffer<f32>,
    w_v_fp32: GpuBuffer<f32>,
    w_o_fp32: GpuBuffer<f32>,
    w_gate_fp32: GpuBuffer<f32>,
    w_up_fp32: GpuBuffer<f32>,
    w_down_fp32: GpuBuffer<f32>,
    // LoRA adapters for Q and V projections (ENT-153: QLoRA backward)
    // None when LoRA is not active (inference-only or non-QLoRA training)
    // The B matrices are stored pre-multiplied by `lora_scale` (done in `new`).
    lora_a_q: Option<GpuBuffer<f32>>, // [hidden_size, rank]
    lora_b_q: Option<GpuBuffer<f32>>, // [rank, q_dim]
    lora_a_v: Option<GpuBuffer<f32>>, // [hidden_size, rank]
    lora_b_v: Option<GpuBuffer<f32>>, // [rank, kv_hidden]
    // Scale and rank exactly as passed to `new`.
    lora_scale: f32,
    lora_rank: usize,
    // QK-norm weights (ENT-270: per-head RMSNorm on Q and K, shape=[head_dim])
    q_norm_weight: Option<GpuBuffer<f32>>,
    k_norm_weight: Option<GpuBuffer<f32>>,
    // FP16 weight buffers for Tier 2 parity (PMAT-470): halve memory BW
    // When set, forward uses gemm_f16_to_f32 (fp16 weights × fp16 activations → fp32 output)
    // All `None` after `new`; populated by `set_fp16_weights()`.
    w_q_fp16: Option<GpuBuffer<u16>>,
    w_k_fp16: Option<GpuBuffer<u16>>,
    w_v_fp16: Option<GpuBuffer<u16>>,
    w_o_fp16: Option<GpuBuffer<u16>>,
    w_gate_fp16: Option<GpuBuffer<u16>>,
    w_up_fp16: Option<GpuBuffer<u16>>,
    w_down_fp16: Option<GpuBuffer<u16>>,
    // CUDA context the buffers were created on; reused for later allocations.
    ctx: Arc<CudaContext>,
    // NF4 blocks do NOT own scratch — shared across all layers (C-SCRATCH-001)
}
2763
2764#[cfg(feature = "cuda")]
2765impl CudaNf4TransformerBlock {
2766    /// Create a new NF4 transformer block from fp32 CPU tensors.
2767    ///
2768    /// Quantizes all 7 projection weights to NF4 on CPU, then uploads the packed
2769    /// data and scales to GPU. Norm weights are uploaded as fp32.
2770    #[allow(clippy::too_many_arguments)]
2771    pub fn new(
2772        config: &TransformerConfig,
2773        layer_idx: usize,
2774        ctx: Arc<CudaContext>,
2775        input_norm_weight: &[f32],
2776        post_attn_norm_weight: &[f32],
2777        w_q: &[f32],
2778        w_k: &[f32],
2779        w_v: &[f32],
2780        w_o: &[f32],
2781        w_gate: &[f32],
2782        w_up: &[f32],
2783        w_down: &[f32],
2784        _max_seq_len: usize, // NF4 blocks use shared scratch (C-SCRATCH-001)
2785        // ENT-153: Optional LoRA adapters for Q and V projections
2786        q_lora: Option<(&[f32], &[f32])>,
2787        v_lora: Option<(&[f32], &[f32])>,
2788        lora_scale: f32,
2789        lora_rank: usize,
2790        // ENT-270: Optional QK-norm weights (per-head RMSNorm, shape=[head_dim])
2791        q_norm: Option<&[f32]>,
2792        k_norm: Option<&[f32]>,
2793    ) -> Result<Self> {
2794        use trueno_gpu::kernels::{quantize_nf4, NF4_BLOCK_SIZE};
2795
2796        let hidden_size = config.hidden_size;
2797        let q_dim = config.q_dim(); // num_heads * head_dim (may differ from hidden_size)
2798        let kv_hidden_size = config.num_kv_heads * config.head_dim();
2799        let intermediate_size = config.intermediate_size;
2800
2801        // ── C-NF4SHAPE-001: Weight shape contracts ──────────────────────
2802        // Ground truth: PMAT-331 validation in attention.rs from_pretrained()
2803        //   Q: [q_dim, hidden], K: [kv_hidden, hidden], V: [kv_hidden, hidden], O: [hidden, q_dim]
2804        //   gate: [intermediate, hidden], up: [intermediate, hidden], down: [hidden, intermediate]
2805        assert_eq!(
2806            w_q.len(),
2807            q_dim * hidden_size,
2808            "C-NF4SHAPE-001: w_q expected {}, got {} (q_dim={q_dim}, hidden={hidden_size})",
2809            q_dim * hidden_size,
2810            w_q.len()
2811        );
2812        assert_eq!(
2813            w_k.len(),
2814            kv_hidden_size * hidden_size,
2815            "C-NF4SHAPE-001: w_k expected {}, got {}",
2816            kv_hidden_size * hidden_size,
2817            w_k.len()
2818        );
2819        assert_eq!(
2820            w_v.len(),
2821            kv_hidden_size * hidden_size,
2822            "C-NF4SHAPE-001: w_v expected {}, got {}",
2823            kv_hidden_size * hidden_size,
2824            w_v.len()
2825        );
2826        assert_eq!(
2827            w_o.len(),
2828            hidden_size * q_dim,
2829            "C-NF4SHAPE-001: w_o expected {}, got {}",
2830            hidden_size * q_dim,
2831            w_o.len()
2832        );
2833        assert_eq!(
2834            w_gate.len(),
2835            intermediate_size * hidden_size,
2836            "C-NF4SHAPE-001: w_gate expected {}, got {}",
2837            intermediate_size * hidden_size,
2838            w_gate.len()
2839        );
2840        assert_eq!(
2841            w_up.len(),
2842            intermediate_size * hidden_size,
2843            "C-NF4SHAPE-001: w_up expected {}, got {}",
2844            intermediate_size * hidden_size,
2845            w_up.len()
2846        );
2847        assert_eq!(
2848            w_down.len(),
2849            hidden_size * intermediate_size,
2850            "C-NF4SHAPE-001: w_down expected {}, got {}",
2851            hidden_size * intermediate_size,
2852            w_down.len()
2853        );
2854
2855        // Upload norm weights as fp32
2856        let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
2857        let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
2858
2859        // Helper: quantize fp32 weight to NF4, upload packed data + scales to GPU
2860        // Returns (gpu_nf4, gpu_scales, cpu_quantized) — the CPU struct is retained
2861        // for dequantization into the cuBLAS fp32 buffer.
2862        let quantize_and_upload = |weights: &[f32],
2863                                   total: usize|
2864         -> Result<(
2865            GpuBuffer<u8>,
2866            GpuBuffer<f32>,
2867            trueno_gpu::kernels::Nf4Quantized,
2868        )> {
2869            assert_eq!(weights.len(), total, "weight length mismatch");
2870            assert!(
2871                total.is_multiple_of(NF4_BLOCK_SIZE),
2872                "weight count {total} not divisible by NF4 block size {NF4_BLOCK_SIZE}"
2873            );
2874
2875            let q = quantize_nf4(weights, total / NF4_BLOCK_SIZE, NF4_BLOCK_SIZE);
2876            let nf4_buf = GpuBuffer::from_host(&ctx, &q.data)?;
2877            let scales_buf = GpuBuffer::from_host(&ctx, &q.scales)?;
2878            Ok((nf4_buf, scales_buf, q))
2879        };
2880
2881        // Quantize all 7 projection weights (shape contracts already verified above)
2882        let (w_q_nf4, w_q_scales, w_q_nf4_q) = quantize_and_upload(w_q, q_dim * hidden_size)?;
2883        let (w_k_nf4, w_k_scales, w_k_nf4_q) =
2884            quantize_and_upload(w_k, kv_hidden_size * hidden_size)?;
2885        let (w_v_nf4, w_v_scales, w_v_nf4_q) =
2886            quantize_and_upload(w_v, kv_hidden_size * hidden_size)?;
2887        let (w_o_nf4, w_o_scales, w_o_nf4_q) = quantize_and_upload(w_o, hidden_size * q_dim)?;
2888        let (w_gate_nf4, w_gate_scales, w_gate_nf4_q) =
2889            quantize_and_upload(w_gate, intermediate_size * hidden_size)?;
2890        let (w_up_nf4, w_up_scales, w_up_nf4_q) =
2891            quantize_and_upload(w_up, intermediate_size * hidden_size)?;
2892        let (w_down_nf4, w_down_scales, w_down_nf4_q) =
2893            quantize_and_upload(w_down, hidden_size * intermediate_size)?;
2894
2895        // ENT-287: Dequantize NF4 weights and TRANSPOSE to [K,N] for standard cuBLAS GEMM.
2896        //
2897        // bitsandbytes does: F.dequantize_4bit(B).to(dtype).t()
2898        // The .t() transposes [N,K] → [K,N] so torch.nn.functional.linear
2899        // can use standard C = A @ B^T internally.
2900        //
2901        // We replicate this exactly:
2902        // 1. dequantize_nf4(q) → flat [N*K] in [N,K] row-major order
2903        // 2. CPU transpose [N,K] → [K,N]
2904        // 3. Upload transposed buffer to GPU
2905        // 4. Use standard gemm_forward (NoTrans, NoTrans) — same as LoRA weights
2906        use trueno_gpu::kernels::dequantize_nf4;
2907        let dequant_transpose_upload = |q: &trueno_gpu::kernels::Nf4Quantized,
2908                                        n: usize,
2909                                        k: usize|
2910         -> std::result::Result<
2911            GpuBuffer<f32>,
2912            crate::autograd::cuda_tensor::CudaTensorError,
2913        > {
2914            let deq = dequantize_nf4(q); // [N*K] in [N,K] row-major
2915            let nonzero = deq.iter().filter(|&&x| x != 0.0).count();
2916            eprintln!(
2917                "[TRACE] dequant n={n} k={k} len={} nonzero={nonzero} first5={:?}",
2918                deq.len(),
2919                &deq[..5.min(deq.len())]
2920            );
2921            assert_eq!(deq.len(), n * k, "dequant size mismatch: {} vs {}x{}", deq.len(), n, k);
2922            // Transpose [N,K] → [K,N]
2923            let mut transposed = vec![0.0f32; n * k];
2924            for row in 0..n {
2925                for col in 0..k {
2926                    transposed[col * n + row] = deq[row * k + col];
2927                }
2928            }
2929            let buf = GpuBuffer::from_host(&ctx, &transposed).map_err(|e| {
2930                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2931                    "dequant transpose upload: {e:?}"
2932                ))
2933            })?;
2934            // Verify upload: must read FULL buffer then slice
2935            let mut verify_full = vec![0.0f32; buf.len()];
2936            let verify_ok = buf.copy_to_host(&mut verify_full).is_ok();
2937            let verify5: Vec<f32> = verify_full.iter().copied().take(5).collect();
2938            let nz = verify_full.iter().filter(|&&x| x != 0.0).count();
2939            eprintln!("[TRACE] uploaded ptr={:?} len={} copy_ok={verify_ok} nonzero={nz} verify[:5]={verify5:?}", buf.as_ptr(), buf.len());
2940            Ok(buf)
2941        };
2942        // Each weight is [out_features, in_features] = [N, K].
2943        // After transpose: [K, N] — standard cuBLAS B layout.
2944        let w_q_fp32 = dequant_transpose_upload(&w_q_nf4_q, q_dim, hidden_size)?;
2945        let w_k_fp32 = dequant_transpose_upload(&w_k_nf4_q, kv_hidden_size, hidden_size)?;
2946        let w_v_fp32 = dequant_transpose_upload(&w_v_nf4_q, kv_hidden_size, hidden_size)?;
2947        let w_o_fp32 = dequant_transpose_upload(&w_o_nf4_q, hidden_size, q_dim)?;
2948        let w_gate_fp32 = dequant_transpose_upload(&w_gate_nf4_q, intermediate_size, hidden_size)?;
2949        let w_up_fp32 = dequant_transpose_upload(&w_up_nf4_q, intermediate_size, hidden_size)?;
2950        let w_down_fp32 = dequant_transpose_upload(&w_down_nf4_q, hidden_size, intermediate_size)?;
2951
2952        // NF4 blocks do NOT allocate scratch — shared across all layers (C-SCRATCH-001).
2953        // Pipeline allocates one CudaBlockScratch and passes &mut to each forward() call.
2954        // Saves (L-1) * 214 MB = 7.5 GB for Qwen3-4B (36 layers).
2955
2956        // Upload LoRA adapters to GPU (ENT-153)
2957        // B matrices are pre-scaled by lora_scale to avoid a separate scale kernel in forward.
2958        let (lora_a_q, lora_b_q) = match q_lora {
2959            Some((a_data, b_data)) => {
2960                let a = GpuBuffer::from_host(&ctx, a_data)?;
2961                let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2962                let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2963                (Some(a), Some(b))
2964            }
2965            None => (None, None),
2966        };
2967        let (lora_a_v, lora_b_v) = match v_lora {
2968            Some((a_data, b_data)) => {
2969                let a = GpuBuffer::from_host(&ctx, a_data)?;
2970                let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2971                let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2972                (Some(a), Some(b))
2973            }
2974            None => (None, None),
2975        };
2976
2977        // ENT-270: Upload QK-norm weights if present
2978        let q_norm_weight = match q_norm {
2979            Some(w) => {
2980                assert_eq!(
2981                    w.len(),
2982                    config.head_dim(),
2983                    "ENT-270: q_norm weight expected [head_dim={}], got [{}]",
2984                    config.head_dim(),
2985                    w.len()
2986                );
2987                Some(GpuBuffer::from_host(&ctx, w)?)
2988            }
2989            None => None,
2990        };
2991        let k_norm_weight = match k_norm {
2992            Some(w) => {
2993                assert_eq!(
2994                    w.len(),
2995                    config.head_dim(),
2996                    "ENT-270: k_norm weight expected [head_dim={}], got [{}]",
2997                    config.head_dim(),
2998                    w.len()
2999                );
3000                Some(GpuBuffer::from_host(&ctx, w)?)
3001            }
3002            None => None,
3003        };
3004
3005        Ok(Self {
3006            config: config.clone(),
3007            layer_idx,
3008            input_norm_weight,
3009            post_attn_norm_weight,
3010            w_q_nf4,
3011            w_q_scales,
3012            w_k_nf4,
3013            w_k_scales,
3014            w_v_nf4,
3015            w_v_scales,
3016            w_o_nf4,
3017            w_o_scales,
3018            w_gate_nf4,
3019            w_gate_scales,
3020            w_up_nf4,
3021            w_up_scales,
3022            w_down_nf4,
3023            w_down_scales,
3024            w_q_fp32,
3025            w_k_fp32,
3026            w_v_fp32,
3027            w_o_fp32,
3028            w_gate_fp32,
3029            w_up_fp32,
3030            w_down_fp32,
3031            lora_a_q,
3032            lora_b_q,
3033            lora_a_v,
3034            lora_b_v,
3035            lora_scale,
3036            lora_rank,
3037            q_norm_weight,
3038            k_norm_weight,
3039            // FP16 weights: None by default, populated by set_fp16_weights() (PMAT-470)
3040            w_q_fp16: None,
3041            w_k_fp16: None,
3042            w_v_fp16: None,
3043            w_o_fp16: None,
3044            w_gate_fp16: None,
3045            w_up_fp16: None,
3046            w_down_fp16: None,
3047            ctx,
3048        })
3049    }
3050
3051    /// Cast fp32→fp16 weights + drop fp32 (PMAT-470/472). Frees ~2.6 GB VRAM.
3052    pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()> {
3053        let cast_weight = |w_fp32: &GpuBuffer<f32>, ctx: &CudaContext| -> Result<GpuBuffer<u16>> {
3054            let n = w_fp32.len();
3055            let mut w_fp16 = GpuBuffer::<u16>::new(ctx, n)?;
3056            cast_f32_to_f16_gpu(w_fp32, &mut w_fp16, n as u32, stream)?;
3057            Ok(w_fp16)
3058        };
3059
3060        self.w_q_fp16 = Some(cast_weight(&self.w_q_fp32, &self.ctx)?);
3061        self.w_k_fp16 = Some(cast_weight(&self.w_k_fp32, &self.ctx)?);
3062        self.w_v_fp16 = Some(cast_weight(&self.w_v_fp32, &self.ctx)?);
3063        self.w_o_fp16 = Some(cast_weight(&self.w_o_fp32, &self.ctx)?);
3064        self.w_gate_fp16 = Some(cast_weight(&self.w_gate_fp32, &self.ctx)?);
3065        self.w_up_fp16 = Some(cast_weight(&self.w_up_fp32, &self.ctx)?);
3066        self.w_down_fp16 = Some(cast_weight(&self.w_down_fp32, &self.ctx)?);
3067
3068        stream.synchronize().map_err(|e| {
3069            crate::autograd::cuda_tensor::CudaTensorError::KernelError(format!(
3070                "FP16 weight cast sync failed: {e:?}"
3071            ))
3072        })?;
3073        // PMAT-472: Drop fp32 weights — backward now uses fp16 via gemm_backward_a_fp16_dispatch.
3074        // Frees ~2.6 GB VRAM on yoga 8GB, allowing GPU embeddings to fit.
3075        let dummy = |ctx: &CudaContext| GpuBuffer::<f32>::new(ctx, 1).unwrap();
3076        self.w_q_fp32 = dummy(&self.ctx);
3077        self.w_k_fp32 = dummy(&self.ctx);
3078        self.w_v_fp32 = dummy(&self.ctx);
3079        self.w_o_fp32 = dummy(&self.ctx);
3080        self.w_gate_fp32 = dummy(&self.ctx);
3081        self.w_up_fp32 = dummy(&self.ctx);
3082        self.w_down_fp32 = dummy(&self.ctx);
3083        eprintln!("[FP16] Weights cast + fp32 dropped (~2.6 GB freed)");
3084
3085        Ok(())
3086    }
3087
3088    /// Forward pass: cuBLAS GEMM with pre-dequantized weights (ENT-287, C-SCRATCH-001).
        ///
        /// Enqueues one transformer block on `stream`, in this order: pre-attention RMSNorm,
        /// Q/K/V projections (plus optional LoRA deltas on Q and V), attention, output
        /// projection, fused residual-add + RMSNorm, gate/up projections, fused SwiGLU,
        /// down projection, and a final residual add that writes `output`.
        ///
        /// Every projection picks its GEMM backend in the same priority order:
        /// `FP16_GEMM=1` (only when the matching fp16 weight is present), then
        /// `NF4_TC_GEMM=1`, then `NF4_FUSED_GEMM=1`, otherwise the fp32 `gemm_forward` path.
        ///
        /// # Parameters
        /// - `input`: block input activations (presumably `[seq_len, hidden_size]` — TODO confirm).
        /// - `output`: receives `residual1 + ffn_out` over `seq_len * hidden_size` elements.
        /// - `seq_len`: number of positions processed by this call.
        /// - `stream`: CUDA stream on which every kernel is launched.
        /// - `scratch`: shared intermediate buffers and per-op profiling hooks.
        ///
        /// # Errors
        /// Propagates the first error returned by any kernel helper or buffer allocation.
        ///
        /// NOTE(review): the fp32 fallback branches read `w_*_fp32`. `set_fp16_weights`
        /// replaces those with 1-element buffers, so after that call this function relies
        /// on `FP16_GEMM=1` (or an NF4 flag) being set — confirm callers guarantee this.
3089    #[rustfmt::skip]
3090    pub(crate) fn forward(
3091        &self,
3092        input: &GpuBuffer<f32>,
3093        output: &mut GpuBuffer<f32>,
3094        seq_len: usize,
3095        stream: &CudaStream,
3096        scratch: &mut CudaBlockScratch,
3097    ) -> Result<()> {
3098        use crate::autograd::cuda_forward::{gemm_forward, gemm_nf4_forward, gemm_nf4_tc_forward};
3099
3100        let hidden_size = self.config.hidden_size;
3101        let q_dim = self.config.q_dim();
3102        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3103        let intermediate_size = self.config.intermediate_size;
3104
3105        // entrenar#318: scratch zeroing moved to forward_cuda_training (once per step, not per layer)
3106        scratch.prepare_causal_mask(seq_len, &self.ctx)?;
3107
3108        // === Pre-attention RMSNorm === (PMAT-483: per-op profiling)
3109        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: thread config eps so Qwen2
3110        // (1e-6) and Llama (1e-5) get the right epsilon (was hardcoded 1e-5).
            // Profiling pattern used throughout: `op_begin()` returns a token that is handed
            // back to `op_end()` together with the OP_* index of the stage just enqueued.
3111        let _t = scratch.op_begin();
3112        rms_norm_forward_with_eps(
3113            input,
3114            &self.input_norm_weight,
3115            &mut scratch.norm1_out,
3116            saturating_u32(seq_len),
3117            saturating_u32(hidden_size),
3118            self.config.rms_norm_eps,
3119            stream,
3120        )?;
3121        scratch.op_end(_t, OP_RMSNORM_ATTN);
3122
3123        // === Q, K, V Projections ===
3124        // Backend selection:
3125        //   FP16_GEMM=1: fp16 tensor core GEMM (Tier 2 parity, 2x BW savings)
3126        //   NF4_FUSED_GEMM=1: fused dequant+GEMM (8x less BW, 100% GPU, but naive PTX)
3127        //   Default: cuBLAS fp32 (197 tok/s, 7% GPU, memory-BW bound)
            // Each flag is a function-local static: the environment is read on first use and
            // the cached value is shared by every later call. Only the exact value "1" enables it.
3128        static USE_NF4_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3129        let nf4_gemm = *USE_NF4_GEMM.get_or_init(|| std::env::var("NF4_FUSED_GEMM").as_deref() == Ok("1"));
3130        static USE_NF4_TC_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3131        let nf4_tc_gemm = *USE_NF4_TC_GEMM.get_or_init(|| std::env::var("NF4_TC_GEMM").as_deref() == Ok("1"));
3132        static USE_FP16_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3133        let fp16_gemm = *USE_FP16_GEMM.get_or_init(|| std::env::var("FP16_GEMM").as_deref() == Ok("1"));
3134
3135        // FP16 path: cast activation once, reuse for all projections
            // NOTE(review): plain `as u32` here (wraps on overflow) while the GEMM dimensions
            // below go through `saturating_u32` — confirm which behaviour is intended.
3136        let act_n = (seq_len * hidden_size) as u32;
3137        if fp16_gemm && self.w_q_fp16.is_some() {
3138            // Lazy-allocate fp16 activation buffer
                // NOTE(review): the buffer is sized from the first call's `seq_len` and is
                // never reallocated — confirm `seq_len` cannot grow between calls.
3139            if scratch.norm1_out_f16.is_none() {
3140                scratch.norm1_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3141            }
3142            let f16_buf = scratch.norm1_out_f16.as_mut().unwrap();
                // This cast is enqueued before `op_begin()` below, so it falls outside the
                // OP_QKV_GEMM timing window.
3143            cast_f32_to_f16_gpu(&scratch.norm1_out, f16_buf, act_n, stream)?;
3144        }
3145
3146        let _t = scratch.op_begin(); // QKV GEMM timing
            // Q projection: norm1_out -> scratch.q, dimensions (seq_len, hidden_size, q_dim).
3147        if fp16_gemm && self.w_q_fp16.is_some() {
3148            let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3149            gemm_f16_to_f32_forward(f16_act, self.w_q_fp16.as_ref().unwrap(), &mut scratch.q,
3150                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3151        } else if nf4_tc_gemm {
3152            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3153                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3154        } else if nf4_gemm {
3155            gemm_nf4_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3156                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3157        } else {
3158            gemm_forward(&scratch.norm1_out, &self.w_q_fp32, &mut scratch.q,
3159                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3160        }
3161
3162        // ENT-153: Q LoRA: q += (norm1_out @ A_q) @ B_q  (B_q pre-scaled by lora_scale)
            // Applied only when both A_q and B_q are present; `lora_inter` and `lora_temp`
            // are scratch buffers that the V LoRA step below reuses.
3163        if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
3164            let s = saturating_u32(seq_len);
3165            let h = saturating_u32(hidden_size);
3166            let r = saturating_u32(self.lora_rank);
3167            let qd = saturating_u32(q_dim);
3168            // lora_inter[seq, rank] = norm1_out[seq, hidden] @ A_q[hidden, rank]
3169            gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
3170            // lora_temp[seq, q_dim] = lora_inter[seq, rank] @ B_q[rank, q_dim]
3171            gemm_forward(&scratch.lora_inter, b_q, &mut scratch.lora_temp, s, r, qd, stream)?;
3172            // q += lora_temp (in-place add)
3173            cuda_add_inplace(&mut scratch.q, &scratch.lora_temp, seq_len * q_dim, stream)?;
3174        }
3175
            // K and V projections: norm1_out -> scratch.k / scratch.v.
            // NOTE(review): the fp16 branch tests only `w_k_fp16` but also unwraps
            // `w_v_fp16`; this relies on both being populated together.
3176        if fp16_gemm && self.w_k_fp16.is_some() {
3177            let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3178            gemm_f16_to_f32_forward(f16_act, self.w_k_fp16.as_ref().unwrap(), &mut scratch.k,
3179                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3180            gemm_f16_to_f32_forward(f16_act, self.w_v_fp16.as_ref().unwrap(), &mut scratch.v,
3181                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3182        } else if nf4_tc_gemm {
3183            // PMAT-481: NF4 tensor core GEMM for K and V projections (separate)
3184            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_k_nf4, &self.w_k_scales, &mut scratch.k,
3185                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3186            gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_v_nf4, &self.w_v_scales, &mut scratch.v,
3187                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3188        } else if nf4_gemm {
3189            // PMAT-478: Fused K+V — shared input load (same pattern as Gate+Up)
3190            crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3191                &scratch.norm1_out,
3192                &self.w_k_nf4, &self.w_k_scales,
3193                &self.w_v_nf4, &self.w_v_scales,
3194                &mut scratch.k, &mut scratch.v,
3195                saturating_u32(seq_len), saturating_u32(hidden_size),
3196                saturating_u32(kv_hidden_size), stream,
3197            )?;
3198        } else {
3199            gemm_forward(&scratch.norm1_out, &self.w_k_fp32, &mut scratch.k,
3200                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3201            gemm_forward(&scratch.norm1_out, &self.w_v_fp32, &mut scratch.v,
3202                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3203        }
3204
3205        scratch.op_end(_t, OP_QKV_GEMM); // End QKV timing (includes Q/K/V GEMMs + Q LoRA)
3206
3207        // ENT-153: V LoRA: v += (norm1_out @ A_v) @ B_v  (B_v pre-scaled by lora_scale)
            // NOTE(review): this runs after `op_end(.., OP_QKV_GEMM)` and before the attention
            // `op_begin()`, so the V LoRA work is not attributed to any profiled op.
3208        if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
3209            let s = saturating_u32(seq_len);
3210            let h = saturating_u32(hidden_size);
3211            let r = saturating_u32(self.lora_rank);
3212            let vd = saturating_u32(kv_hidden_size);
3213            // lora_inter[seq, rank] = norm1_out[seq, hidden] @ A_v[hidden, rank]
3214            gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
3215            // lora_temp[seq, kv_hidden] = lora_inter[seq, rank] @ B_v[rank, kv_hidden]
3216            gemm_forward(&scratch.lora_inter, b_v, &mut scratch.lora_temp, s, r, vd, stream)?;
3217            // v += lora_temp (in-place add)
3218            cuda_add_inplace(&mut scratch.v, &scratch.lora_temp, seq_len * kv_hidden_size, stream)?;
3219        }
3220
3221        // === Multi-Head Attention (GPU-only, zero CPU transfers) ===
            // Consumes scratch.q / scratch.k / scratch.v and writes scratch.attn_out.
3222        let _t = scratch.op_begin();
3223        self.compute_attention_cuda(seq_len, stream, scratch)?;
3224        scratch.op_end(_t, OP_ATTENTION);
3225
3226        // === Output Projection ===
            // attn_out -> o_proj_out, dimensions (seq_len, q_dim, hidden_size). On the fp16
            // path the activation cast happens inside this timing window.
3227        let _t = scratch.op_begin();
3228        if fp16_gemm && self.w_o_fp16.is_some() {
3229            if scratch.attn_out_f16.is_none() {
3230                scratch.attn_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * q_dim)?);
3231            }
3232            let f16_buf = scratch.attn_out_f16.as_mut().unwrap();
3233            cast_f32_to_f16_gpu(&scratch.attn_out, f16_buf, (seq_len * q_dim) as u32, stream)?;
3234            gemm_f16_to_f32_forward(f16_buf, self.w_o_fp16.as_ref().unwrap(), &mut scratch.o_proj_out,
3235                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3236        } else if nf4_tc_gemm {
3237            gemm_nf4_tc_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3238                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3239        } else if nf4_gemm {
3240            gemm_nf4_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3241                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3242        } else {
3243            gemm_forward(&scratch.attn_out, &self.w_o_fp32, &mut scratch.o_proj_out,
3244                saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3245        }
3246
3247        scratch.op_end(_t, OP_O_PROJ);
3248
3249        // === Fused Residual Add + RMSNorm (entrenar#321: eliminates NaN cascade) ===
3250        let _t = scratch.op_begin();
3251        // The separate cuda_add + rms_norm_forward allows activation explosion between
3252        // the two operations. Fusing them prevents NaN in layers 24-27 because RMSNorm
3253        // normalizes the residual sum immediately, before it can propagate.
            // Writes both `residual1` and `norm2_out` (both are passed as `&mut`).
            // NOTE(review): no epsilon argument is passed here, unlike the pre-attention
            // norm above — confirm the fused kernel uses `config.rms_norm_eps`.
3254        fused_residual_rmsnorm_forward(
3255            input,
3256            &scratch.o_proj_out,
3257            &mut scratch.residual1,
3258            &mut scratch.norm2_out,
3259            &self.post_attn_norm_weight,
3260            saturating_u32(seq_len),
3261            saturating_u32(hidden_size),
3262            stream,
3263        )?;
3264
3265        scratch.op_end(_t, OP_RMSNORM_FFN); // Fused residual + RMSNorm
3266
3267        // === FFN: Gate + Up + SwiGLU + Down ===
            // norm2_out -> gate_out and up_out, dimensions (seq_len, hidden_size, intermediate_size).
            // NOTE(review): the fp16 branch tests only `w_gate_fp16` but also unwraps `w_up_fp16`.
3268        let _t = scratch.op_begin(); // Gate+Up GEMM timing
3269        if fp16_gemm && self.w_gate_fp16.is_some() {
3270            if scratch.norm2_out_f16.is_none() {
3271                scratch.norm2_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3272            }
3273            let f16_buf = scratch.norm2_out_f16.as_mut().unwrap();
3274            cast_f32_to_f16_gpu(&scratch.norm2_out, f16_buf, (seq_len * hidden_size) as u32, stream)?;
3275            gemm_f16_to_f32_forward(f16_buf, self.w_gate_fp16.as_ref().unwrap(), &mut scratch.gate_out,
3276                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3277            gemm_f16_to_f32_forward(f16_buf, self.w_up_fp16.as_ref().unwrap(), &mut scratch.up_out,
3278                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3279        } else if nf4_tc_gemm {
3280            // PMAT-481: NF4 tensor core GEMM for Gate and Up projections (separate)
3281            gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_gate_nf4, &self.w_gate_scales, &mut scratch.gate_out,
3282                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3283            gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_up_nf4, &self.w_up_scales, &mut scratch.up_out,
3284                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3285        } else if nf4_gemm {
3286            // PMAT-475: Fused gate+up — shared input load, saves M×K×4 bytes DRAM
3287            crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3288                &scratch.norm2_out,
3289                &self.w_gate_nf4, &self.w_gate_scales,
3290                &self.w_up_nf4, &self.w_up_scales,
3291                &mut scratch.gate_out, &mut scratch.up_out,
3292                saturating_u32(seq_len), saturating_u32(hidden_size),
3293                saturating_u32(intermediate_size), stream,
3294            )?;
3295        } else {
3296            gemm_forward(&scratch.norm2_out, &self.w_gate_fp32, &mut scratch.gate_out,
3297                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3298            gemm_forward(&scratch.norm2_out, &self.w_up_fp32, &mut scratch.up_out,
3299                saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3300        }
3301
3302        scratch.op_end(_t, OP_GATE_UP_GEMM);
3303
3304        // === FFN: Fused SwiGLU ===
            // Combines gate_out and up_out into swiglu_out over seq_len * intermediate_size elements.
3305        let _t = scratch.op_begin();
3306        fused_swiglu_forward(&scratch.gate_out, &scratch.up_out, &mut scratch.swiglu_out,
3307            saturating_u32(seq_len * intermediate_size), stream)?;
3308        scratch.op_end(_t, OP_SILU);
3309
3310        // === FFN: Down Projection ===
            // swiglu_out -> ffn_out, dimensions (seq_len, intermediate_size, hidden_size).
3311        let _t = scratch.op_begin();
3312        if fp16_gemm && self.w_down_fp16.is_some() {
3313            if scratch.swiglu_out_f16.is_none() {
3314                scratch.swiglu_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * intermediate_size)?);
3315            }
3316            let f16_buf = scratch.swiglu_out_f16.as_mut().unwrap();
3317            cast_f32_to_f16_gpu(&scratch.swiglu_out, f16_buf, (seq_len * intermediate_size) as u32, stream)?;
3318            gemm_f16_to_f32_forward(f16_buf, self.w_down_fp16.as_ref().unwrap(), &mut scratch.ffn_out,
3319                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3320        } else if nf4_tc_gemm {
3321            gemm_nf4_tc_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3322                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3323        } else if nf4_gemm {
3324            gemm_nf4_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3325                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3326        } else {
3327            gemm_forward(&scratch.swiglu_out, &self.w_down_fp32, &mut scratch.ffn_out,
3328                saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3329        }
3330
3331        scratch.op_end(_t, OP_DOWN_GEMM);
3332
3333        // === Final Residual Add ===
            // output = residual1 + ffn_out over seq_len * hidden_size elements. This step is
            // not wrapped in op_begin/op_end, so it is not attributed to any profiled op.
3334        cuda_add(&scratch.residual1, &scratch.ffn_out, output, seq_len * hidden_size, stream)?;
3335
3336        Ok(())
3337    }
3338
3339    /// Layer index accessor.
        ///
        /// Returns the `layer_idx` value stored on this block (presumably the block's
        /// position in the layer stack — confirm against the constructor's caller).
3340    pub fn layer_idx(&self) -> usize {
3341        self.layer_idx
3342    }
3343}
3344
3345/// Helper: delegate attention computation using shared scratch buffers.
3346///
3347/// `CudaNf4TransformerBlock` reuses the same attention pipeline as the fp32 block
3348/// since attention operates on fp32 activations (Q/K/V are already dequantized by GEMM).
3349#[cfg(feature = "cuda")]
3350impl CudaNf4TransformerBlock {
        /// Enqueues multi-head attention on `stream`, reading `scratch.q`, `scratch.k` and
        /// `scratch.v` and writing the result to `scratch.attn_out`.
        ///
        /// Steps, in launch order:
        /// 1. optional per-head RMSNorm on Q and K (one launch per position),
        /// 2. batched NeoX RoPE on Q and K (in place),
        /// 3. Q and K moved from interleaved to batched layout; K is expanded to
        ///    `num_attention_heads` heads (or copied when no expansion is needed) and transposed,
        /// 4. scores = Q @ K^T, scaled by `1/sqrt(head_dim)`,
        /// 5. causal mask added per head, then a batched softmax (both in place),
        /// 6. V moved to batched layout and expanded the same way as K,
        /// 7. scores @ V, then conversion back to interleaved layout.
        ///
        /// `scratch.attn_kv_temp`, `scratch.attn_kv_temp2` and `scratch.attn_q_batched` are
        /// reused across steps, so the statement order below matters.
        ///
        /// # Errors
        /// Propagates the first error returned by a kernel helper or device-to-device copy.
        ///
        /// NOTE(review): the in-place steps obtain a shared reference to a scratch buffer via
        /// `addr_of!` and pass it together with `&mut` to the same buffer. The existing SAFETY
        /// comments justify this by kernel read/write ordering; the simultaneous `&`/`&mut`
        /// pair itself is worth confirming against Rust's aliasing rules.
3351    fn compute_attention_cuda(
3352        &self,
3353        seq_len: usize,
3354        stream: &CudaStream,
3355        scratch: &mut CudaBlockScratch,
3356    ) -> Result<()> {
3357        let num_heads = self.config.num_attention_heads;
3358        let num_kv_heads = self.config.num_kv_heads;
3359        let head_dim = self.config.head_dim();
            // Integer division: panics if `num_kv_heads` is 0 and truncates if `num_heads` is
            // not a multiple of `num_kv_heads` — presumably ruled out by config validation.
3360        let heads_per_kv = num_heads / num_kv_heads;
3361
3362        let s = saturating_u32(seq_len);
3363        let nh = saturating_u32(num_heads);
3364        let nkv = saturating_u32(num_kv_heads);
3365        let hd = saturating_u32(head_dim);
3366
3367        // ── ENT-270: Apply QK-norm (per-head RMSNorm) on Q and K ──────────
3368        // Must happen BEFORE RoPE, matching CPU path ordering:
3369        //   projection → QK-norm → RoPE → attention
3370        // SAFETY: In-place GPU operations — CUDA kernels read all input before writing output.
            // Each loop issues `seq_len` launches, one per position `pos`.
3371        if let Some(ref q_norm) = self.q_norm_weight {
3372            for pos in 0..seq_len {
3373                let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
3374                per_head_rmsnorm_forward(q_ref, q_norm, &mut scratch.q, nh, hd, pos, stream)?;
3375            }
3376        }
3377        if let Some(ref k_norm) = self.k_norm_weight {
3378            for pos in 0..seq_len {
3379                let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
3380                per_head_rmsnorm_forward(k_ref, k_norm, &mut scratch.k, nkv, hd, pos, stream)?;
3381            }
3382        }
3383
3384        // ── ENT-270: Apply RoPE (NeoX half-rotation) on Q and K ──────────
3385        // ALB-119: Batched launch (2 kernels) replaces per-position loop (2*seq_len kernels)
            // Both launches are in place (input and output are the same scratch buffer) and
            // read positions from `scratch.rope_positions`.
3386        let rope_theta = self.config.rope_theta;
3387        {
3388            let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
3389            batched_rope_neox_forward(
3390                q_ref,
3391                &mut scratch.q,
3392                &scratch.rope_positions,
3393                nh,
3394                hd,
3395                s,
3396                rope_theta,
3397                stream,
3398            )?;
3399            let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
3400            batched_rope_neox_forward(
3401                k_ref,
3402                &mut scratch.k,
3403                &scratch.rope_positions,
3404                nkv,
3405                hd,
3406                s,
3407                rope_theta,
3408                stream,
3409            )?;
3410        }
3411
3412        // Q: interleaved → batched layout
3413        interleaved_to_batched_forward(&scratch.q, &mut scratch.attn_q_batched, s, nh, hd, stream)?;
3414
3415        // K: interleaved → batched, then GQA expand if needed
3416        interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
3417
            // Either way the K heads end up in `attn_kv_temp2`: expanded when several query
            // heads share one KV head, otherwise copied unchanged.
3418        if heads_per_kv > 1 {
3419            expand_kv_heads(
3420                &scratch.attn_kv_temp,
3421                &mut scratch.attn_kv_temp2,
3422                num_kv_heads,
3423                heads_per_kv,
3424                seq_len * head_dim,
3425                stream,
3426            )?;
3427        } else {
3428            // SAFETY: D2D copy with matching buffer sizes
3429            unsafe {
3430                scratch
3431                    .attn_kv_temp2
3432                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
3433                    .map_err(|e| {
3434                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
3435                            "K copy failed: {e:?}"
3436                        ))
3437                    })?;
3438            }
3439        }
3440
3441        // K^T: transpose for attention scores
            // Reads `attn_kv_temp2` and overwrites `attn_kv_temp`, whose earlier contents
            // (un-expanded K) are no longer needed.
3442        batched_transpose_forward(
3443            &scratch.attn_kv_temp2,
3444            &mut scratch.attn_kv_temp,
3445            nh,
3446            s,
3447            hd,
3448            stream,
3449        )?;
3450
3451        // Q @ K^T → attention scores
            // Arguments after the buffers are (1, nh, s, s, hd) — presumably
            // (batch, heads, m, n, k); confirm against `batched_4d_gemm_forward`.
3452        batched_4d_gemm_forward(
3453            &scratch.attn_q_batched,
3454            &scratch.attn_kv_temp,
3455            &mut scratch.attn_scores,
3456            1,
3457            nh,
3458            s,
3459            s,
3460            hd,
3461            stream,
3462        )?;
3463
3464        // Scale by 1/sqrt(head_dim)
3465        let scale_factor = 1.0 / (head_dim as f32).sqrt();
3466        let total_scores = num_heads * seq_len * seq_len;
            // Non-owning view over `attn_scores` so the same memory can be passed as both
            // input and output; it is handed to `leak` afterwards so that it is not dropped.
            // NOTE(review): if `scale_forward` fails, `?` returns before `leak` runs and the
            // view is dropped normally — confirm that cannot free the scratch allocation.
3467        let scores_view = unsafe {
3468            GpuBuffer::<f32>::from_raw_parts(
3469                scratch.attn_scores.as_ptr(),
3470                scratch.attn_scores.len(),
3471            )
3472        };
3473        scale_forward(
3474            &scores_view,
3475            &mut scratch.attn_scores,
3476            scale_factor,
3477            saturating_u32(total_scores),
3478            stream,
3479        )?;
3480        leak(scores_view);
3481
3482        // Softmax (in-place: input aliased with output via unsafe view)
3483        // C-CAUSAL-001: Apply causal mask before softmax (NF4 path)
3484        // PMAT-420: Use causal_mask_contiguous (correctly strided for seq_len)
3485        // instead of causal_mask (strided at max_seq_len, causes row misalignment
3486        // when seq_len < max_seq_len, leading to NaN after deep layers).
3487        {
3488            let seq_sq = seq_len * seq_len;
3489            let mask_ptr = scratch.causal_mask_contiguous.as_ptr();
3490            let scores_base = scratch.attn_scores.as_ptr();
                // One launch per head: scores[head] = mask + scores[head], in place. The same
                // `seq_sq`-element mask is reused for every head.
3491            for head in 0..num_heads {
                    // Byte offset of this head's score matrix; 4 is the size of one f32.
3492                let byte_offset = (head * seq_sq * 4) as u64;
3493                let head_ptr = scores_base + byte_offset;
3494                let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
3495                let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
3496                let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
3497                residual_add_forward(
3498                    &mask_view,
3499                    &scores_view,
3500                    &mut out_view,
3501                    saturating_u32(seq_sq),
3502                    stream,
3503                )?;
                    // All three views are non-owning and are leaked rather than dropped.
3504                leak(mask_view);
3505                leak(scores_view);
3506                leak(out_view);
3507            }
3508        }
3509
3510        // SAFETY: The softmax kernel reads each row completely into shared memory / registers
3511        // before writing output. The view is forgotten to prevent double-free.
3512        let scores_view = unsafe {
3513            GpuBuffer::<f32>::from_raw_parts(
3514                scratch.attn_scores.as_ptr(),
3515                scratch.attn_scores.len(),
3516            )
3517        };
            // Called with `num_heads * seq_len` as the row count and `s` as the row length.
3518        batched_softmax_forward(
3519            &scores_view,
3520            &mut scratch.attn_scores,
3521            saturating_u32(num_heads * seq_len),
3522            s,
3523            stream,
3524        )?;
3525        leak(scores_view);
3526
3527        // V: interleaved → batched, then GQA expand
            // Reuses `attn_kv_temp` / `attn_kv_temp2`; the K^T data they held was already
            // consumed by the scores GEMM enqueued above on the same stream.
3528        interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
3529
3530        if heads_per_kv > 1 {
3531            expand_kv_heads(
3532                &scratch.attn_kv_temp,
3533                &mut scratch.attn_kv_temp2,
3534                num_kv_heads,
3535                heads_per_kv,
3536                seq_len * head_dim,
3537                stream,
3538            )?;
3539        } else {
3540            // SAFETY: async GPU buffer copy within same CUDA stream; both buffers are
3541            // pre-allocated scratch with matching sizes, and stream ordering guarantees
3542            // the source is fully written before this copy executes.
3543            unsafe {
3544                scratch
3545                    .attn_kv_temp2
3546                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
3547                    .map_err(|e| {
3548                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
3549                            "V copy failed: {e:?}"
3550                        ))
3551                    })?;
3552            }
3553        }
3554
3555        // attn_scores @ V → attention output
            // The result overwrites `attn_q_batched`; the Q data it held is no longer needed.
3556        batched_4d_gemm_forward(
3557            &scratch.attn_scores,
3558            &scratch.attn_kv_temp2,
3559            &mut scratch.attn_q_batched,
3560            1,
3561            nh,
3562            s,
3563            hd,
3564            s,
3565            stream,
3566        )?;
3567
3568        // Batched → interleaved layout
3569        batched_to_interleaved_forward(
3570            &scratch.attn_q_batched,
3571            &mut scratch.attn_out,
3572            s,
3573            nh,
3574            hd,
3575            stream,
3576        )?;
3577
3578        Ok(())
3579    }
3580}
3581
3582// =============================================================================
3583// QLoRA Backward Pass Types (ENT-153)
3584// =============================================================================
3585
3586/// Shared gradient workspace for LoRA weight gradients (one per model, NOT per layer).
3587///
3588/// Backward processes layers sequentially — only one layer's LoRA gradients
3589/// are computed at a time. Sharing this workspace saves
3590/// `(L-1) * per_layer_lora_grad_elements * 4` bytes of VRAM.
3591///
    /// Besides the four LoRA matrices, the workspace also holds the gradients of the two
    /// RMSNorm weight vectors, so six buffers in total.
    ///
3592/// # Contract (C-LORAGRADWS-001)
3593///
3594/// - **Precondition**: Allocated once before training loop starts
3595/// - **Postcondition**: After backward() for layer i, contains layer i's LoRA gradients
3596/// - **Invariant**: Buffer sizes match model config; never reallocated during training
3597#[cfg(feature = "cuda")]
3598pub(crate) struct CudaLoraGradWorkspace {
3599    /// Gradient for LoRA A_q [hidden_size, rank]
3600    pub(crate) grad_lora_a_q: GpuBuffer<f32>,
3601    /// Gradient for LoRA B_q [rank, q_dim]
3602    pub(crate) grad_lora_b_q: GpuBuffer<f32>,
3603    /// Gradient for LoRA A_v [hidden_size, rank]
3604    pub(crate) grad_lora_a_v: GpuBuffer<f32>,
3605    /// Gradient for LoRA B_v [rank, kv_hidden]
        /// (`kv_hidden` here is `num_kv_heads * head_dim`.)
3606    pub(crate) grad_lora_b_v: GpuBuffer<f32>,
3607    /// Gradient for input norm weight [hidden_size]
3608    pub(crate) grad_input_norm: GpuBuffer<f32>,
3609    /// Gradient for post-attention norm weight [hidden_size]
3610    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
3611}
3612
#[cfg(feature = "cuda")]
impl CudaLoraGradWorkspace {
    /// Creates the workspace, allocating one gradient buffer per trainable tensor.
    ///
    /// Element counts come from `config` and `lora_rank`:
    /// `hidden_size * rank` for both A matrices, `rank * q_dim` for B_q,
    /// `rank * (num_kv_heads * head_dim)` for B_v, and `hidden_size` for each norm weight.
    ///
    /// # Errors
    ///
    /// Returns the first allocation error; buffers allocated before it are dropped.
    pub(crate) fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        let hidden = config.hidden_size;
        let q_width = config.q_dim();
        let kv_width = config.num_kv_heads * config.head_dim();

        // Allocate in field order so a failure surfaces at the same point as before.
        let grad_lora_a_q = GpuBuffer::<f32>::new(ctx, hidden * lora_rank)?;
        let grad_lora_b_q = GpuBuffer::<f32>::new(ctx, lora_rank * q_width)?;
        let grad_lora_a_v = GpuBuffer::<f32>::new(ctx, hidden * lora_rank)?;
        let grad_lora_b_v = GpuBuffer::<f32>::new(ctx, lora_rank * kv_width)?;
        let grad_input_norm = GpuBuffer::<f32>::new(ctx, hidden)?;
        let grad_post_attn_norm = GpuBuffer::<f32>::new(ctx, hidden)?;

        Ok(Self {
            grad_lora_a_q,
            grad_lora_b_q,
            grad_lora_a_v,
            grad_lora_b_v,
            grad_input_norm,
            grad_post_attn_norm,
        })
    }

    /// ENT-265: Clip all six gradient buffers by their joint (global) L2 norm.
    ///
    /// The norm is taken over A_q, B_q, A_v, B_v, input_norm and post_attn_norm together.
    /// When it is not `<= max_norm`, every buffer is multiplied by
    /// `max_norm / (total_norm + 1e-6)`; otherwise nothing is touched.
    ///
    /// This is best-effort: a failed squared-sum reduction contributes `0.0` to the norm,
    /// and a failed scaling launch is ignored.
    ///
    /// The work is split into a read-only pass (norm) followed by a mutating pass
    /// (scaling) so that shared and mutable borrows of the buffers never overlap.
    pub(crate) fn clip_gradients(&mut self, max_norm: f32, stream: &CudaStream) {
        // Pass 1: squared L2 norm of one buffer, falling back to 0.0 on failure.
        let squared_norm = |grad: &GpuBuffer<f32>| {
            squared_sum_cuda(grad, grad.len() as u32, stream).unwrap_or(0.0)
        };
        let total_norm = (squared_norm(&self.grad_lora_a_q)
            + squared_norm(&self.grad_lora_b_q)
            + squared_norm(&self.grad_lora_a_v)
            + squared_norm(&self.grad_lora_b_v)
            + squared_norm(&self.grad_input_norm)
            + squared_norm(&self.grad_post_attn_norm))
        .sqrt();

        // Written as `<=` on purpose: a NaN norm compares false and falls through to scaling.
        if total_norm <= max_norm {
            return;
        }

        // Pass 2: apply the same scale factor to every buffer, ignoring launch errors.
        let clip_scale = max_norm / (total_norm + 1e-6);
        for grad in [
            &mut self.grad_lora_a_q,
            &mut self.grad_lora_b_q,
            &mut self.grad_lora_a_v,
            &mut self.grad_lora_b_v,
            &mut self.grad_input_norm,
            &mut self.grad_post_attn_norm,
        ] {
            let elem_count = grad.len() as u32;
            let _ = gradient_clip_cuda(grad, clip_scale, elem_count, stream);
        }
    }
}
3686
/// AdamW moment buffers, resident on the GPU, for the trainable parameters of a
/// single NF4 transformer block.
///
/// For every trainable tensor there is a first-moment (`m_*`) and a
/// second-moment (`v_*`) buffer:
/// - the four LoRA matrices (A_q, B_q, A_v, B_v)
/// - the two RMSNorm weight vectors (input_norm, post_attn_norm)
///
/// # Contract (C-LORAOPT-001)
///
/// - **Precondition**: CUDA context valid, buffers match weight dimensions
/// - **Postcondition**: m and v initialized to zero
/// - **Invariant**: Buffer sizes immutable after creation
#[cfg(feature = "cuda")]
pub(crate) struct GpuLoraOptimizerState {
    // LoRA A for the Q projection: hidden_size * rank elements.
    m_lora_a_q: GpuBuffer<f32>,
    v_lora_a_q: GpuBuffer<f32>,
    // LoRA B for the Q projection: rank * q_dim elements.
    m_lora_b_q: GpuBuffer<f32>,
    v_lora_b_q: GpuBuffer<f32>,
    // LoRA A for the V projection: hidden_size * rank elements.
    m_lora_a_v: GpuBuffer<f32>,
    v_lora_a_v: GpuBuffer<f32>,
    // LoRA B for the V projection: rank * (num_kv_heads * head_dim) elements.
    m_lora_b_v: GpuBuffer<f32>,
    v_lora_b_v: GpuBuffer<f32>,
    // Input RMSNorm weight: hidden_size elements.
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    // Post-attention RMSNorm weight: hidden_size elements.
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}

#[cfg(feature = "cuda")]
impl GpuLoraOptimizerState {
    /// Allocates every moment buffer and fills it with zeros.
    ///
    /// Buffer lengths are derived from `config` (hidden size, Q width, KV width)
    /// and `lora_rank`.
    ///
    /// # Errors
    ///
    /// Returns the (converted) error of the first `GpuBuffer::from_host` upload
    /// that fails; buffers are created in field-declaration order.
    fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        /// Uploads `len` zeros from the host.
        ///
        /// CRITICAL: the moments must start at zero, and `GpuBuffer::new()` does
        /// NOT zero memory (cuMemAlloc returns uninitialized VRAM), hence the
        /// explicit upload of a zero-filled host vector.
        fn zeros(ctx: &Arc<CudaContext>, len: usize) -> Result<GpuBuffer<f32>> {
            Ok(GpuBuffer::from_host(ctx, &vec![0.0f32; len])?)
        }

        let hidden = config.hidden_size;
        let q_width = config.q_dim();
        let kv_width = config.num_kv_heads * config.head_dim();

        // Element counts shared by each (m, v) pair.
        let lora_a_len = hidden * lora_rank;
        let lora_b_q_len = lora_rank * q_width;
        let lora_b_v_len = lora_rank * kv_width;

        Ok(Self {
            m_lora_a_q: zeros(ctx, lora_a_len)?,
            v_lora_a_q: zeros(ctx, lora_a_len)?,
            m_lora_b_q: zeros(ctx, lora_b_q_len)?,
            v_lora_b_q: zeros(ctx, lora_b_q_len)?,
            m_lora_a_v: zeros(ctx, lora_a_len)?,
            v_lora_a_v: zeros(ctx, lora_a_len)?,
            m_lora_b_v: zeros(ctx, lora_b_v_len)?,
            v_lora_b_v: zeros(ctx, lora_b_v_len)?,
            m_input_norm: zeros(ctx, hidden)?,
            v_input_norm: zeros(ctx, hidden)?,
            m_post_attn_norm: zeros(ctx, hidden)?,
            v_post_attn_norm: zeros(ctx, hidden)?,
        })
    }
}
3747
3748// =============================================================================
3749// NF4 Block Backward Pass (ENT-153)
3750// =============================================================================
3751
3752#[cfg(feature = "cuda")]
3753impl CudaNf4TransformerBlock {
3754    /// Backward pass with activation checkpointing and LoRA gradient computation.
3755    ///
3756    /// # Activation Checkpointing
3757    ///
3758    /// Re-runs forward to regenerate intermediate activations. Only `layer_input`
3759    /// is saved per-layer (47 MB for 36 layers at seq_len=128). This is the standard
3760    /// NF4 QLoRA backward (C-QLORA-BWD-001): recompute activations, propagate gradients.
3761    #[allow(clippy::too_many_arguments)]
3762    pub(crate) fn backward(
3763        &self,
3764        layer_input: &GpuBuffer<f32>,
3765        grad_output: &GpuBuffer<f32>,
3766        grad_input: &mut GpuBuffer<f32>,
3767        output_scratch: &mut GpuBuffer<f32>,
3768        seq_len: usize,
3769        stream: &CudaStream,
3770        scratch: &mut CudaBlockScratch,
3771        grad_lora: &mut CudaLoraGradWorkspace,
3772    ) -> Result<()> {
3773        let hidden_size = self.config.hidden_size;
3774        let _q_dim = self.config.q_dim();
3775        let _kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3776        let intermediate_size = self.config.intermediate_size;
3777        let eps = 1e-5_f32;
3778
3779        // === Step 0: Activation checkpointing — re-run forward ===
3780        // This repopulates scratch with all intermediates needed for backward.
3781        self.forward(layer_input, output_scratch, seq_len, stream, scratch).map_err(|e| {
3782            eprintln!(
3783                "[backward] Layer {} activation-checkpoint forward FAILED: {e:?}",
3784                self.layer_idx
3785            );
3786            e
3787        })?;
3788
3789        // === Step 1: FFN backward (NF4 transpose, no weight grads for frozen projections) ===
3790        self.backward_nf4_ffn(
3791            grad_output,
3792            seq_len,
3793            hidden_size,
3794            intermediate_size,
3795            stream,
3796            scratch,
3797        )?;
3798
3799        // === Step 2: Post-attn norm backward ===
3800        let _t = scratch.op_begin(); // OP_NORM_BWD timing (both norms)
3801        rms_norm_backward(
3802            &scratch.residual1,
3803            &self.post_attn_norm_weight,
3804            &scratch.grad_hidden, // grad_from_ffn is accumulated in grad_hidden by backward_nf4_ffn
3805            grad_input,           // temporarily store post-attn-norm grad here
3806            &mut grad_lora.grad_post_attn_norm,
3807            saturating_u32(seq_len),
3808            saturating_u32(hidden_size),
3809            eps,
3810            stream,
3811        )?;
3812
3813        // Add residual connection: grad flows through both ffn and skip path
3814        // grad_residual1 = grad_input (from norm backward) + grad_output (from residual skip)
3815        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
3816
3817        // === Step 3: Attention backward (NF4 + LoRA for Q/V) ===
3818        self.backward_nf4_attention(
3819            grad_input, // grad coming into attention (from residual1)
3820            seq_len, stream, scratch, grad_lora,
3821        )?;
3822
3823        // === Step 4: Input norm backward + first residual ===
3824        // At this point, scratch.grad_hidden contains grad from attention block
3825        // (accumulated by backward_nf4_attention into norm1_out reusing grad_hidden)
3826        rms_norm_backward(
3827            layer_input,
3828            &self.input_norm_weight,
3829            &scratch.grad_hidden, // grad flowing into norm1
3830            grad_input,           // final grad_input for this layer
3831            &mut grad_lora.grad_input_norm,
3832            saturating_u32(seq_len),
3833            saturating_u32(hidden_size),
3834            eps,
3835            stream,
3836        )?;
3837
3838        scratch.op_end(_t, OP_NORM_BWD);
3839
3840        Ok(())
3841    }
3842
3843    /// FFN backward for NF4 blocks (ENT-287: cuBLAS fp32 GEMM).
3844    ///
3845    /// Propagates gradient through: down_proj → SwiGLU → gate/up projections.
3846    /// Uses cuBLAS GEMM with pre-dequantized fp32 weights for correct layout.
3847    /// No weight gradients for frozen NF4 weights.
3848    fn backward_nf4_ffn(
3849        &self,
3850        grad_output: &GpuBuffer<f32>,
3851        seq_len: usize,
3852        hidden_size: usize,
3853        intermediate_size: usize,
3854        stream: &CudaStream,
3855        scratch: &mut CudaBlockScratch,
3856    ) -> Result<()> {
3857        let s = saturating_u32(seq_len);
3858        let h = saturating_u32(hidden_size);
3859        let i_size = saturating_u32(intermediate_size);
3860        let n_inter = saturating_u32(seq_len * intermediate_size);
3861
3862        // PMAT-481: NF4 tensor core backward dispatch
3863        static USE_NF4_TC_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3864        let nf4_tc_bwd =
3865            *USE_NF4_TC_BWD.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
3866
3867        // Step 1: grad_swiglu[S,I] = grad_output[S,H] @ W_down^T (PMAT-472: fp16 dispatch)
3868        let _t = scratch.op_begin(); // OP_DOWN_BWD timing
3869        if nf4_tc_bwd {
3870            // NF4 TC backward: fused dequant+WMMA, no separate dequant kernel
3871            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3872                grad_output,
3873                &self.w_down_nf4,
3874                &self.w_down_scales,
3875                &mut scratch.grad_swiglu,
3876                s,
3877                h,      // N = hidden_size (grad_output cols)
3878                i_size, // K = intermediate_size (output cols = W_down rows)
3879                stream,
3880            )?;
3881        } else {
3882            gemm_backward_a_fp16_dispatch(
3883                grad_output,
3884                self.w_down_fp16.as_ref(),
3885                &self.w_down_fp32,
3886                &mut scratch.grad_swiglu,
3887                s,
3888                i_size,
3889                h,
3890                stream,
3891                &self.ctx,
3892            )?;
3893        }
3894
3895        scratch.op_end(_t, OP_DOWN_BWD);
3896
3897        // Step 2: SwiGLU backward: swiglu = silu(gate) * up
3898        // d_gate = d_swiglu * up * silu'(gate)
3899        // d_up   = d_swiglu * silu(gate)
3900        let _t = scratch.op_begin(); // OP_SWIGLU_BWD timing
3901
3902        // temp1 = d_swiglu * up_out → store in swiglu_out (reuse)
3903        elementwise_mul_forward(
3904            &scratch.grad_swiglu,
3905            &scratch.up_out,
3906            &mut scratch.swiglu_out,
3907            n_inter,
3908            stream,
3909        )?;
3910
3911        // silu_backward: d_gate_raw = temp1 * silu'(gate_out)
3912        // silu'(x) = silu(x) * (1 + x*(1-silu(x)))
3913        // Reuse up_out as storage for d_gate
3914        silu_backward(
3915            &scratch.gate_out,
3916            &scratch.swiglu_out,
3917            &mut scratch.up_out, // d_gate stored here
3918            stream,
3919        )?;
3920
3921        // d_up = d_swiglu * silu(gate) → store in gate_out (reuse)
3922        // Compute silu(gate) into swiglu_out (scratch) — NOT ffn_out which is [S,H] (too small)
3923        silu_forward(&scratch.gate_out, &mut scratch.swiglu_out, n_inter, stream)?;
3924        // d_up = d_swiglu * silu(gate)
3925        elementwise_mul_forward(
3926            &scratch.grad_swiglu,
3927            &scratch.swiglu_out,
3928            &mut scratch.gate_out, // d_up stored here
3929            n_inter,
3930            stream,
3931        )?;
3932
3933        scratch.op_end(_t, OP_SWIGLU_BWD);
3934
3935        // Step 3: gate/up backward (PMAT-472: fp16 dispatch, PMAT-481: TC dispatch)
3936        let _t = scratch.op_begin(); // OP_GATE_UP_BWD timing
3937                                     // PMAT-484: Fused backward — use cuBLAS beta=1.0 accumulate to eliminate
3938                                     // the separate cuda_add_inplace kernel launch (3 launches → 2).
3939        static USE_FUSED_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3940        let fused_bwd = *USE_FUSED_BWD
3941            .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
3942
3943        if nf4_tc_bwd {
3944            // PMAT-481: NF4 tensor core backward — fused dequant+WMMA per projection
3945            // Up backward: grad_up[S,H] = d_up[S,I] @ W_up^T
3946            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3947                &scratch.gate_out, // d_up stored in gate_out buffer
3948                &self.w_up_nf4,
3949                &self.w_up_scales,
3950                &mut scratch.grad_hidden,
3951                s,
3952                i_size, // N = intermediate_size (d_up cols)
3953                h,      // K = hidden_size (output cols)
3954                stream,
3955            )?;
3956            // Gate backward: grad_gate[S,H] = d_gate[S,I] @ W_gate^T
3957            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3958                &scratch.up_out, // d_gate stored in up_out buffer
3959                &self.w_gate_nf4,
3960                &self.w_gate_scales,
3961                &mut scratch.ffn_out,
3962                s,
3963                i_size, // N = intermediate_size
3964                h,      // K = hidden_size
3965                stream,
3966            )?;
3967            // Accumulate: grad_hidden += grad_gate
3968            cuda_add_inplace(
3969                &mut scratch.grad_hidden,
3970                &scratch.ffn_out,
3971                seq_len * hidden_size,
3972                stream,
3973            )?;
3974        } else if fused_bwd {
3975            // Fused path: compute up backward into grad_hidden, then accumulate gate
3976            gemm_backward_a_fp16_dispatch(
3977                &scratch.gate_out,
3978                self.w_up_fp16.as_ref(),
3979                &self.w_up_fp32,
3980                &mut scratch.grad_hidden,
3981                s,
3982                h,
3983                i_size,
3984                stream,
3985                &self.ctx,
3986            )?;
3987            // Accumulate gate backward into grad_hidden (beta=1.0)
3988            gemm_backward_a_fp16_dispatch_accumulate(
3989                &scratch.up_out,
3990                self.w_gate_fp16.as_ref(),
3991                &self.w_gate_fp32,
3992                &mut scratch.grad_hidden,
3993                s,
3994                h,
3995                i_size,
3996                stream,
3997                &self.ctx,
3998            )?;
3999        } else {
4000            // Unfused path: separate GEMMs + explicit add
4001            gemm_backward_a_fp16_dispatch(
4002                &scratch.up_out,
4003                self.w_gate_fp16.as_ref(),
4004                &self.w_gate_fp32,
4005                &mut scratch.ffn_out,
4006                s,
4007                h,
4008                i_size,
4009                stream,
4010                &self.ctx,
4011            )?;
4012            gemm_backward_a_fp16_dispatch(
4013                &scratch.gate_out,
4014                self.w_up_fp16.as_ref(),
4015                &self.w_up_fp32,
4016                &mut scratch.grad_hidden,
4017                s,
4018                h,
4019                i_size,
4020                stream,
4021                &self.ctx,
4022            )?;
4023
4024            // Accumulate: grad_hidden = grad_norm2_gate + grad_norm2_up
4025            cuda_add_inplace(
4026                &mut scratch.grad_hidden,
4027                &scratch.ffn_out,
4028                seq_len * hidden_size,
4029                stream,
4030            )?;
4031        }
4032        scratch.op_end(_t, OP_GATE_UP_BWD);
4033
4034        Ok(())
4035    }
4036
4037    /// Attention backward for NF4 blocks with LoRA gradient computation (ENT-287).
4038    ///
4039    /// Propagates gradient through O projection, attention mechanism, and Q/K/V projections.
4040    /// Computes LoRA weight gradients for Q and V projections.
4041    /// Uses cuBLAS GEMM with pre-dequantized fp32 weights.
4042    fn backward_nf4_attention(
4043        &self,
4044        grad_residual1: &GpuBuffer<f32>,
4045        seq_len: usize,
4046        stream: &CudaStream,
4047        scratch: &mut CudaBlockScratch,
4048        grad_lora: &mut CudaLoraGradWorkspace,
4049    ) -> Result<()> {
4050        use crate::autograd::cuda_forward::gemm_forward;
4051
4052        let hidden_size = self.config.hidden_size;
4053        let q_dim = self.config.q_dim();
4054        let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
4055        let num_heads = self.config.num_attention_heads;
4056        let head_dim = self.config.head_dim();
4057
4058        let s = saturating_u32(seq_len);
4059        let h = saturating_u32(hidden_size);
4060        let qd = saturating_u32(q_dim);
4061        let kvh = saturating_u32(kv_hidden_size);
4062
4063        // Step 1: O projection backward (PMAT-481: TC dispatch)
4064        static USE_NF4_TC_BWD_O: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4065        let nf4_tc_bwd_o = *USE_NF4_TC_BWD_O
4066            .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4067
4068        let _t = scratch.op_begin(); // OP_ATTN_BWD timing (O-proj + attention mechanism)
4069        if nf4_tc_bwd_o {
4070            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4071                grad_residual1,
4072                &self.w_o_nf4,
4073                &self.w_o_scales,
4074                &mut scratch.attn_out,
4075                s,
4076                h,  // N = hidden_size (grad_residual1 cols)
4077                qd, // K = q_dim (output cols = W_o rows)
4078                stream,
4079            )?;
4080        } else {
4081            gemm_backward_a_fp16_dispatch(
4082                grad_residual1,
4083                self.w_o_fp16.as_ref(),
4084                &self.w_o_fp32,
4085                &mut scratch.attn_out,
4086                s,
4087                qd,
4088                h,
4089                stream,
4090                &self.ctx,
4091            )?;
4092        }
4093
4094        // Step 2: Attention mechanism backward
4095        // This is complex (softmax backward, batched GEMMs) — reuse the fp32 attention backward
4096        // infrastructure since attention operates on fp32 activations.
4097        self.backward_nf4_attention_mechanism(seq_len, num_heads, head_dim, stream, scratch)?;
4098
4099        // After attention backward: scratch.norm1_out-related grads are accumulated.
4100        // grad_q is in scratch.q, grad_k in scratch.k, grad_v in scratch.v
4101
4102        // Step 2b: RoPE backward (inverse rotation) on grad_q and grad_k
4103        // Forward applied RoPE to Q and K. Undo rotation so projection backward
4104        // gets gradients in the unrotated coordinate frame.
4105        let rope_theta = self.config.rope_theta;
4106        let num_kv_heads = self.config.num_kv_heads;
4107        let nkv = saturating_u32(num_kv_heads);
4108        let nh = saturating_u32(num_heads);
4109        let hd = saturating_u32(head_dim);
4110        {
4111            let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
4112            batched_rope_neox_backward(
4113                q_ref,
4114                &mut scratch.q,
4115                &scratch.rope_positions,
4116                nh,
4117                hd,
4118                s,
4119                rope_theta,
4120                stream,
4121            )?;
4122            let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
4123            batched_rope_neox_backward(
4124                k_ref,
4125                &mut scratch.k,
4126                &scratch.rope_positions,
4127                nkv,
4128                hd,
4129                s,
4130                rope_theta,
4131                stream,
4132            )?;
4133        }
4134
4135        scratch.op_end(_t, OP_ATTN_BWD);
4136
4137        // Q/K/V backward (PMAT-472: fp16 dispatch, PMAT-481: TC dispatch)
4138        static USE_NF4_TC_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4139        let nf4_tc_bwd = *USE_NF4_TC_BWD_ATTN
4140            .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4141
4142        let _t = scratch.op_begin(); // OP_QKV_BWD timing
4143        if nf4_tc_bwd {
4144            // PMAT-481: NF4 tensor core backward for Q projection
4145            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4146                &scratch.q,
4147                &self.w_q_nf4,
4148                &self.w_q_scales,
4149                &mut scratch.o_proj_out,
4150                s,
4151                qd, // N = q_dim (grad_q cols)
4152                h,  // K = hidden_size (output cols)
4153                stream,
4154            )?;
4155        } else {
4156            gemm_backward_a_fp16_dispatch(
4157                &scratch.q,
4158                self.w_q_fp16.as_ref(),
4159                &self.w_q_fp32,
4160                &mut scratch.o_proj_out,
4161                s,
4162                h,
4163                qd,
4164                stream,
4165                &self.ctx,
4166            )?;
4167        }
4168
4169        // LoRA Q backward: compute grad_A_q, grad_B_q, and add to grad_norm1
4170        if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
4171            let r = saturating_u32(self.lora_rank);
4172
4173            // Recompute: lora_inter_q = norm1_out @ A_q  [S, rank]
4174            gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
4175
4176            // grad_B_q = lora_inter_q^T @ grad_q  [rank, q_dim]
4177            // (Note: B_q was pre-scaled, so grad_B_q includes the scale factor)
4178            gemm_backward_b(
4179                &scratch.lora_inter,
4180                &scratch.q,
4181                &mut grad_lora.grad_lora_b_q,
4182                s,
4183                r,
4184                qd,
4185                stream,
4186            )?;
4187
4188            // grad_lora_inter = grad_q @ B_q^T  [S, rank]
4189            gemm_backward_a(
4190                &scratch.q,
4191                b_q,
4192                &mut scratch.lora_inter, // reuse for grad_lora_inter
4193                s,
4194                qd,
4195                r,
4196                stream,
4197            )?;
4198
4199            // grad_A_q = norm1_out^T @ grad_lora_inter  [H, rank]
4200            gemm_backward_b(
4201                &scratch.norm1_out,
4202                &scratch.lora_inter,
4203                &mut grad_lora.grad_lora_a_q,
4204                s,
4205                h,
4206                r,
4207                stream,
4208            )?;
4209
4210            // Add LoRA's contribution to grad_norm1: += grad_lora_inter @ A_q^T  [S, H]
4211            gemm_backward_a(
4212                &scratch.lora_inter,
4213                a_q,
4214                &mut scratch.lora_temp, // [S, H]
4215                s,
4216                r,
4217                h,
4218                stream,
4219            )?;
4220            cuda_add_inplace(
4221                &mut scratch.o_proj_out,
4222                &scratch.lora_temp,
4223                seq_len * hidden_size,
4224                stream,
4225            )?;
4226        }
4227
4228        // K+V backward (PMAT-484: fused, PMAT-481: TC dispatch)
4229        static USE_FUSED_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4230        let fused_bwd = *USE_FUSED_BWD_ATTN
4231            .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
4232
4233        if nf4_tc_bwd {
4234            // PMAT-481: NF4 tensor core backward for K and V projections
4235            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4236                &scratch.k,
4237                &self.w_k_nf4,
4238                &self.w_k_scales,
4239                &mut scratch.ffn_out,
4240                s,
4241                kvh, // N = kv_hidden (grad_k cols)
4242                h,   // K = hidden_size (output cols)
4243                stream,
4244            )?;
4245            cuda_add_inplace(
4246                &mut scratch.o_proj_out,
4247                &scratch.ffn_out,
4248                seq_len * hidden_size,
4249                stream,
4250            )?;
4251            crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4252                &scratch.v,
4253                &self.w_v_nf4,
4254                &self.w_v_scales,
4255                &mut scratch.ffn_out,
4256                s,
4257                kvh, // N = kv_hidden
4258                h,   // K = hidden_size
4259                stream,
4260            )?;
4261            cuda_add_inplace(
4262                &mut scratch.o_proj_out,
4263                &scratch.ffn_out,
4264                seq_len * hidden_size,
4265                stream,
4266            )?;
4267        } else if fused_bwd {
4268            // Fused: K and V backward accumulate directly into o_proj_out
4269            gemm_backward_a_fp16_dispatch_accumulate(
4270                &scratch.k,
4271                self.w_k_fp16.as_ref(),
4272                &self.w_k_fp32,
4273                &mut scratch.o_proj_out,
4274                s,
4275                h,
4276                kvh,
4277                stream,
4278                &self.ctx,
4279            )?;
4280            gemm_backward_a_fp16_dispatch_accumulate(
4281                &scratch.v,
4282                self.w_v_fp16.as_ref(),
4283                &self.w_v_fp32,
4284                &mut scratch.o_proj_out,
4285                s,
4286                h,
4287                kvh,
4288                stream,
4289                &self.ctx,
4290            )?;
4291        } else {
4292            // Unfused: K backward → temp, accumulate; V backward → temp, accumulate
4293            gemm_backward_a_fp16_dispatch(
4294                &scratch.k,
4295                self.w_k_fp16.as_ref(),
4296                &self.w_k_fp32,
4297                &mut scratch.ffn_out,
4298                s,
4299                h,
4300                kvh,
4301                stream,
4302                &self.ctx,
4303            )?;
4304            cuda_add_inplace(
4305                &mut scratch.o_proj_out,
4306                &scratch.ffn_out,
4307                seq_len * hidden_size,
4308                stream,
4309            )?;
4310
4311            gemm_backward_a_fp16_dispatch(
4312                &scratch.v,
4313                self.w_v_fp16.as_ref(),
4314                &self.w_v_fp32,
4315                &mut scratch.ffn_out,
4316                s,
4317                h,
4318                kvh,
4319                stream,
4320                &self.ctx,
4321            )?;
4322            cuda_add_inplace(
4323                &mut scratch.o_proj_out,
4324                &scratch.ffn_out,
4325                seq_len * hidden_size,
4326                stream,
4327            )?;
4328        }
4329
4330        // LoRA V backward
4331        if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
4332            let r = saturating_u32(self.lora_rank);
4333
4334            // Recompute: lora_inter_v = norm1_out @ A_v  [S, rank]
4335            gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
4336
4337            // grad_B_v = lora_inter_v^T @ grad_v  [rank, kv_hidden]
4338            gemm_backward_b(
4339                &scratch.lora_inter,
4340                &scratch.v,
4341                &mut grad_lora.grad_lora_b_v,
4342                s,
4343                r,
4344                kvh,
4345                stream,
4346            )?;
4347
4348            // grad_lora_inter = grad_v @ B_v^T  [S, rank]
4349            gemm_backward_a(&scratch.v, b_v, &mut scratch.lora_inter, s, kvh, r, stream)?;
4350
4351            // grad_A_v = norm1_out^T @ grad_lora_inter  [H, rank]
4352            gemm_backward_b(
4353                &scratch.norm1_out,
4354                &scratch.lora_inter,
4355                &mut grad_lora.grad_lora_a_v,
4356                s,
4357                h,
4358                r,
4359                stream,
4360            )?;
4361
4362            // Add LoRA V's contribution to grad_norm1
4363            gemm_backward_a(&scratch.lora_inter, a_v, &mut scratch.lora_temp, s, r, h, stream)?;
4364            cuda_add_inplace(
4365                &mut scratch.o_proj_out,
4366                &scratch.lora_temp,
4367                seq_len * hidden_size,
4368                stream,
4369            )?;
4370        }
4371
4372        scratch.op_end(_t, OP_QKV_BWD);
4373
4374        // Step 6: Accumulated grad_norm1 is in scratch.o_proj_out → move to scratch.grad_hidden
4375        unsafe {
4376            scratch.grad_hidden.copy_from_buffer_async(&scratch.o_proj_out, stream).map_err(
4377                |e| {
4378                    crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4379                        "grad_norm1 copy failed: {e}"
4380                    ))
4381                },
4382            )?;
4383        }
4384
4385        Ok(())
4386    }
4387
4388    /// Attention mechanism backward (softmax, Q@K^T backward) for NF4 blocks.
4389    ///
4390    /// After this call:
4391    /// - scratch.q contains grad_q [S, q_dim]
4392    /// - scratch.k contains grad_k [S, kv_hidden]
4393    /// - scratch.v contains grad_v [S, kv_hidden]
4394    ///
4395    /// PMAT-486: Full attention backward (replaces previous no-op).
4396    /// Mirrors the FP32 backward in `backward_attention` (lines 1470-1810).
4397    ///
4398    /// Forward: attn_out = softmax(Q @ K^T / √d) @ V
4399    /// Backward:
4400    ///   1. grad_scores = grad_attn_batched @ V^T
4401    ///   2. grad_V = (grad_attn_batched^T @ attn_weights)^T  (buffer-safe identity)
4402    ///   3. softmax_backward(grad_scores, attn_weights) → grad_raw
4403    ///   4. grad_raw *= 1/√d
4404    ///   5. grad_Q = grad_raw @ K_expanded
4405    ///   6. grad_K = (Q^T @ grad_raw)^T
4406    ///   7. GQA reduction + batched→interleaved conversion
4407    ///
4408    /// Contract: attention-backward-v1.yaml
4409    fn backward_nf4_attention_mechanism(
4410        &self,
4411        seq_len: usize,
4412        num_heads: usize,
4413        head_dim: usize,
4414        stream: &CudaStream,
4415        scratch: &mut CudaBlockScratch,
4416    ) -> Result<()> {
4417        let num_kv_heads = self.config.num_kv_heads;
4418        let heads_per_kv = num_heads / num_kv_heads;
4419        let s = saturating_u32(seq_len);
4420        let nh = saturating_u32(num_heads);
4421        let nkv = saturating_u32(num_kv_heads);
4422        let hd = saturating_u32(head_dim);
4423        let scale = 1.0 / (head_dim as f32).sqrt();
4424
4425        // grad_attn_out is in scratch.attn_out [S, q_dim]
4426        // Convert to batched layout [NH, S, HD]
4427        interleaved_to_batched_forward(
4428            &scratch.attn_out,
4429            &mut scratch.attn_q_batched, // grad_attn_batched [NH, S, HD]
4430            s,
4431            nh,
4432            hd,
4433            stream,
4434        )?;
4435
4436        // === Step 1: Expand V for GQA and transpose ===
4437        // V is in scratch.v [S, kv_hidden] from activation checkpointing forward.
4438        // Convert to batched [NKV, S, HD], then GQA-expand to [NH, S, HD].
4439        interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
4440
4441        if heads_per_kv > 1 {
4442            expand_kv_heads(
4443                &scratch.attn_kv_temp,
4444                &mut scratch.attn_kv_temp2,
4445                num_kv_heads,
4446                heads_per_kv,
4447                seq_len * head_dim,
4448                stream,
4449            )?;
4450        } else {
4451            unsafe {
4452                scratch
4453                    .attn_kv_temp2
4454                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
4455                    .map_err(|e| {
4456                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4457                            "V copy for attn backward: {e:?}"
4458                        ))
4459                    })?;
4460            }
4461        }
4462        // attn_kv_temp2 = V_expanded [NH, S, HD]
4463
4464        // Transpose V: [NH, S, HD] → [NH, HD, S]
4465        batched_transpose_forward(
4466            &scratch.attn_kv_temp2,
4467            &mut scratch.attn_kv_temp, // V^T [NH, HD, S]
4468            nh,
4469            s,
4470            hd,
4471            stream,
4472        )?;
4473
4474        // === Step 2: grad_attn_scores = grad_attn_batched @ V^T ===
4475        // [NH, S, HD] @ [NH, HD, S] → [NH, S, S]
4476        batched_4d_gemm_forward(
4477            &scratch.attn_q_batched,
4478            &scratch.attn_kv_temp,
4479            &mut scratch.grad_attn_scores,
4480            1,
4481            nh,
4482            s,
4483            s,
4484            hd,
4485            stream,
4486        )?;
4487
4488        // === Step 3: grad_V = (grad_attn_batched^T @ attn_weights)^T ===
4489        // Uses the identity to avoid needing an [NH, S, S] transpose buffer.
4490        // attn_weights in scratch.attn_scores [NH, S, S] from forward checkpoint.
4491
4492        // Step 3a: transpose grad_attn_batched [NH, S, HD] → [NH, HD, S]
4493        batched_transpose_forward(
4494            &scratch.attn_q_batched,
4495            &mut scratch.attn_kv_temp, // reuse: grad_attn_batched^T [NH, HD, S]
4496            nh,
4497            s,
4498            hd,
4499            stream,
4500        )?;
4501
4502        // Step 3b: [NH, HD, S] @ [NH, S, S] → [NH, HD, S] (= grad_V^T)
4503        batched_4d_gemm_forward(
4504            &scratch.attn_kv_temp,
4505            &scratch.attn_scores,       // attn_weights from forward
4506            &mut scratch.attn_kv_temp2, // grad_V^T [NH, HD, S]
4507            1,
4508            nh,
4509            hd,
4510            s,
4511            s,
4512            stream,
4513        )?;
4514
4515        // Step 3c: transpose grad_V^T [NH, HD, S] → grad_V [NH, S, HD]
4516        batched_transpose_forward(
4517            &scratch.attn_kv_temp2,
4518            &mut scratch.attn_kv_temp, // grad_V [NH, S, HD]
4519            nh,
4520            hd,
4521            s,
4522            stream,
4523        )?;
4524        // attn_kv_temp = grad_V [NH, S, HD]
4525
4526        // === Step 4: Softmax backward ===
4527        // In-place: grad_attn_scores is both input and output.
4528        let total_rows = nh * s;
4529        {
4530            let grad_scores_view = unsafe {
4531                GpuBuffer::<f32>::from_raw_parts(
4532                    scratch.grad_attn_scores.as_ptr(),
4533                    scratch.grad_attn_scores.len(),
4534                )
4535            };
4536            batched_softmax_backward(
4537                &scratch.attn_scores,
4538                &grad_scores_view,
4539                &mut scratch.grad_attn_scores,
4540                total_rows,
4541                s,
4542                stream,
4543            )?;
4544            leak(grad_scores_view);
4545        }
4546
4547        // === Step 5: Scale backward (1/√d) ===
4548        let total_scores = saturating_u32(num_heads * seq_len * seq_len);
4549        {
4550            let scores_view = unsafe {
4551                GpuBuffer::<f32>::from_raw_parts(
4552                    scratch.grad_attn_scores.as_ptr(),
4553                    scratch.grad_attn_scores.len(),
4554                )
4555            };
4556            scale_forward(
4557                &scores_view,
4558                &mut scratch.grad_attn_scores,
4559                scale,
4560                total_scores,
4561                stream,
4562            )?;
4563            leak(scores_view);
4564        }
4565
4566        // === Step 6: Reconstruct K, GQA expand, compute grad_Q ===
4567        interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp2, s, nkv, hd, stream)?;
4568
4569        if heads_per_kv > 1 {
4570            unsafe {
4571                scratch
4572                    .attn_q_batched
4573                    .copy_from_buffer_async(&scratch.attn_kv_temp2, stream)
4574                    .map_err(|e| {
4575                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4576                            "K copy for GQA expand: {e}"
4577                        ))
4578                    })?;
4579            }
4580            expand_kv_heads(
4581                &scratch.attn_q_batched,
4582                &mut scratch.attn_kv_temp2,
4583                num_kv_heads,
4584                heads_per_kv,
4585                seq_len * head_dim,
4586                stream,
4587            )?;
4588        }
4589        // attn_kv_temp2 = K_expanded [NH, S, HD]
4590
4591        // grad_Q = grad_raw_scores @ K_expanded → attn_q_batched [NH, S, HD]
4592        batched_4d_gemm_forward(
4593            &scratch.grad_attn_scores,
4594            &scratch.attn_kv_temp2,
4595            &mut scratch.attn_q_batched,
4596            1,
4597            nh,
4598            s,
4599            hd,
4600            s,
4601            stream,
4602        )?;
4603
4604        // === Step 7: Compute grad_K ===
4605        // grad_K^T = Q^T @ grad_raw_scores
4606        // Reconstruct Q_batched into o_proj_out (attn_q_batched has grad_Q now)
4607        interleaved_to_batched_forward(
4608            &scratch.q,
4609            &mut scratch.o_proj_out, // temp for Q_batched
4610            s,
4611            nh,
4612            hd,
4613            stream,
4614        )?;
4615
4616        // Transpose Q: [NH, S, HD] → [NH, HD, S]
4617        batched_transpose_forward(
4618            &scratch.o_proj_out,
4619            &mut scratch.attn_kv_temp2, // Q^T [NH, HD, S]
4620            nh,
4621            s,
4622            hd,
4623            stream,
4624        )?;
4625
4626        // grad_K^T = Q^T @ grad_raw_scores → ffn_out as temp [NH, HD, S]
4627        batched_4d_gemm_forward(
4628            &scratch.attn_kv_temp2,
4629            &scratch.grad_attn_scores,
4630            &mut scratch.ffn_out, // grad_K^T [NH, HD, S]
4631            1,
4632            nh,
4633            hd,
4634            s,
4635            s,
4636            stream,
4637        )?;
4638
4639        // Transpose grad_K^T → grad_K: [NH, HD, S] → [NH, S, HD]
4640        batched_transpose_forward(
4641            &scratch.ffn_out,
4642            &mut scratch.attn_kv_temp2, // grad_K [NH, S, HD]
4643            nh,
4644            hd,
4645            s,
4646            stream,
4647        )?;
4648
4649        // === Step 8: GQA gradient reduction ===
4650        if heads_per_kv > 1 {
4651            self.reduce_gqa_gradients_nf4(
4652                num_kv_heads,
4653                heads_per_kv,
4654                seq_len,
4655                head_dim,
4656                stream,
4657                scratch,
4658            )?;
4659        }
4660
4661        // === Step 9: Convert batched gradients → interleaved, store in scratch.q/k/v ===
4662        // grad_Q: attn_q_batched [NH, S, HD] → scratch.q [S, q_dim]
4663        batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.q, s, nh, hd, stream)?;
4664
4665        // grad_K: attn_kv_temp2 [NKV, S, HD] → scratch.k [S, kv_hidden]
4666        batched_to_interleaved_forward(&scratch.attn_kv_temp2, &mut scratch.k, s, nkv, hd, stream)?;
4667
4668        // grad_V: attn_kv_temp [NKV, S, HD] → scratch.v [S, kv_hidden]
4669        // (After GQA reduction, grad_V was reduced from [NH] to [NKV] heads)
4670        batched_to_interleaved_forward(&scratch.attn_kv_temp, &mut scratch.v, s, nkv, hd, stream)?;
4671
4672        Ok(())
4673    }
4674
4675    /// GQA gradient reduction for NF4 blocks.
4676    /// Reduces grad_K and grad_V from [num_heads, S, HD] to [num_kv_heads, S, HD]
4677    /// by summing across Q heads sharing each KV head.
4678    fn reduce_gqa_gradients_nf4(
4679        &self,
4680        num_kv_heads: usize,
4681        heads_per_kv: usize,
4682        seq_len: usize,
4683        head_dim: usize,
4684        stream: &CudaStream,
4685        scratch: &mut CudaBlockScratch,
4686    ) -> Result<()> {
4687        let chunk = seq_len * head_dim;
4688        for g in 0..num_kv_heads {
4689            let dst_off = g * chunk;
4690            // First Q head in group → initialize destination
4691            let src_off = g * heads_per_kv * chunk;
4692            // Copy first head
4693            {
4694                let src = unsafe {
4695                    GpuBuffer::<f32>::from_raw_parts(
4696                        scratch.attn_kv_temp2.as_ptr() + (src_off * 4) as u64,
4697                        chunk,
4698                    )
4699                };
4700                let mut dst = unsafe {
4701                    GpuBuffer::<f32>::from_raw_parts(
4702                        scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4703                        chunk,
4704                    )
4705                };
4706                if src_off != dst_off {
4707                    unsafe {
4708                        dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4709                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4710                                "GQA K reduce copy: {e}"
4711                            ))
4712                        })?;
4713                    }
4714                }
4715                leak(src);
4716                leak(dst);
4717            }
4718            // Accumulate remaining heads
4719            for h in 1..heads_per_kv {
4720                let add_off = (g * heads_per_kv + h) * chunk;
4721                let src = unsafe {
4722                    GpuBuffer::<f32>::from_raw_parts(
4723                        scratch.attn_kv_temp2.as_ptr() + (add_off * 4) as u64,
4724                        chunk,
4725                    )
4726                };
4727                let mut dst = unsafe {
4728                    GpuBuffer::<f32>::from_raw_parts(
4729                        scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4730                        chunk,
4731                    )
4732                };
4733                cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4734                leak(src);
4735                leak(dst);
4736            }
4737            // Same for grad_V (in attn_kv_temp)
4738            {
4739                let src = unsafe {
4740                    GpuBuffer::<f32>::from_raw_parts(
4741                        scratch.attn_kv_temp.as_ptr() + (src_off * 4) as u64,
4742                        chunk,
4743                    )
4744                };
4745                let mut dst = unsafe {
4746                    GpuBuffer::<f32>::from_raw_parts(
4747                        scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4748                        chunk,
4749                    )
4750                };
4751                if src_off != dst_off {
4752                    unsafe {
4753                        dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4754                            crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4755                                "GQA V reduce copy: {e}"
4756                            ))
4757                        })?;
4758                    }
4759                }
4760                leak(src);
4761                leak(dst);
4762            }
4763            for h in 1..heads_per_kv {
4764                let add_off = (g * heads_per_kv + h) * chunk;
4765                let src = unsafe {
4766                    GpuBuffer::<f32>::from_raw_parts(
4767                        scratch.attn_kv_temp.as_ptr() + (add_off * 4) as u64,
4768                        chunk,
4769                    )
4770                };
4771                let mut dst = unsafe {
4772                    GpuBuffer::<f32>::from_raw_parts(
4773                        scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4774                        chunk,
4775                    )
4776                };
4777                cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4778                leak(src);
4779                leak(dst);
4780            }
4781        }
4782        Ok(())
4783    }
4784
4785    /// Initialize LoRA optimizer state for this block.
4786    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
4787        GpuLoraOptimizerState::new(&self.ctx, &self.config, self.lora_rank)
4788    }
4789
4790    /// LoRA optimizer step: update A_q, B_q, A_v, B_v and norm weights using AdamW.
4791    #[allow(clippy::too_many_arguments)]
4792    pub(crate) fn lora_optimizer_step(
4793        &mut self,
4794        state: &mut GpuLoraOptimizerState,
4795        step: u32,
4796        lr: f32,
4797        beta1: f32,
4798        beta2: f32,
4799        eps: f32,
4800        weight_decay: f32,
4801        stream: &CudaStream,
4802        grad_lora: &CudaLoraGradWorkspace,
4803    ) -> Result<()> {
4804        let h = self.config.hidden_size;
4805        let q_dim = self.config.q_dim();
4806        let kv = self.config.num_kv_heads * self.config.head_dim();
4807        let r = self.lora_rank;
4808
4809        // AdamW step for each LoRA weight
4810        if let Some(ref mut a_q) = self.lora_a_q {
4811            adamw_step_cuda(
4812                a_q,
4813                &grad_lora.grad_lora_a_q,
4814                &mut state.m_lora_a_q,
4815                &mut state.v_lora_a_q,
4816                lr,
4817                beta1,
4818                beta2,
4819                eps,
4820                weight_decay,
4821                step,
4822                saturating_u32(h * r),
4823                stream,
4824            )?;
4825        }
4826        if let Some(ref mut b_q) = self.lora_b_q {
4827            adamw_step_cuda(
4828                b_q,
4829                &grad_lora.grad_lora_b_q,
4830                &mut state.m_lora_b_q,
4831                &mut state.v_lora_b_q,
4832                lr,
4833                beta1,
4834                beta2,
4835                eps,
4836                weight_decay,
4837                step,
4838                saturating_u32(r * q_dim),
4839                stream,
4840            )?;
4841        }
4842        if let Some(ref mut a_v) = self.lora_a_v {
4843            adamw_step_cuda(
4844                a_v,
4845                &grad_lora.grad_lora_a_v,
4846                &mut state.m_lora_a_v,
4847                &mut state.v_lora_a_v,
4848                lr,
4849                beta1,
4850                beta2,
4851                eps,
4852                weight_decay,
4853                step,
4854                saturating_u32(h * r),
4855                stream,
4856            )?;
4857        }
4858        if let Some(ref mut b_v) = self.lora_b_v {
4859            adamw_step_cuda(
4860                b_v,
4861                &grad_lora.grad_lora_b_v,
4862                &mut state.m_lora_b_v,
4863                &mut state.v_lora_b_v,
4864                lr,
4865                beta1,
4866                beta2,
4867                eps,
4868                weight_decay,
4869                step,
4870                saturating_u32(r * kv),
4871                stream,
4872            )?;
4873        }
4874
4875        // AdamW step for norm weights
4876        adamw_step_cuda(
4877            &mut self.input_norm_weight,
4878            &grad_lora.grad_input_norm,
4879            &mut state.m_input_norm,
4880            &mut state.v_input_norm,
4881            lr,
4882            beta1,
4883            beta2,
4884            eps,
4885            weight_decay,
4886            step,
4887            saturating_u32(h),
4888            stream,
4889        )?;
4890        adamw_step_cuda(
4891            &mut self.post_attn_norm_weight,
4892            &grad_lora.grad_post_attn_norm,
4893            &mut state.m_post_attn_norm,
4894            &mut state.v_post_attn_norm,
4895            lr,
4896            beta1,
4897            beta2,
4898            eps,
4899            weight_decay,
4900            step,
4901            saturating_u32(h),
4902            stream,
4903        )?;
4904
4905        Ok(())
4906    }
4907
4908    /// Download LoRA weights from GPU to CPU for checkpoint saving.
4909    ///
4910    /// Returns (A_q, B_q, A_v, B_v) as flat f32 vectors.
4911    /// B matrices are returned WITH the baked-in scale (caller can divide by lora_scale
4912    /// if they need the unscaled version).
4913    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
4914        let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
4915            let mut host = vec![0.0f32; buf.len()];
4916            buf.copy_to_host(&mut host).map_err(|e| {
4917                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4918                    "LoRA weight download failed: {e}"
4919                ))
4920            })?;
4921            Ok(host)
4922        };
4923        let a_q = self.lora_a_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4924        let b_q = self.lora_b_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4925        let a_v = self.lora_a_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4926        let b_v = self.lora_b_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4927        Ok((a_q, b_q, a_v, b_v))
4928    }
4929
4930    /// Upload LoRA weights from CPU to GPU for checkpoint resume (ENT-276).
4931    ///
4932    /// Overwrites the current LoRA adapter buffers with trained weights
4933    /// restored from a checkpoint. Call after `new()` to replace the fresh
4934    /// random init with previously trained adapters.
4935    pub fn upload_lora_weights(
4936        &mut self,
4937        a_q: &[f32],
4938        b_q: &[f32],
4939        a_v: &[f32],
4940        b_v: &[f32],
4941    ) -> Result<()> {
4942        let upload = |buf: &mut GpuBuffer<f32>, data: &[f32], name: &str| -> Result<()> {
4943            if data.len() != buf.len() {
4944                return Err(crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(
4945                    format!(
4946                        "LoRA {name} size mismatch: checkpoint has {} but GPU buffer expects {}",
4947                        data.len(),
4948                        buf.len()
4949                    ),
4950                ));
4951            }
4952            buf.copy_from_host(data).map_err(|e| {
4953                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4954                    "LoRA {name} upload failed: {e}"
4955                ))
4956            })
4957        };
4958        if let Some(ref mut buf) = self.lora_a_q {
4959            upload(buf, a_q, "a_q")?;
4960        }
4961        if let Some(ref mut buf) = self.lora_b_q {
4962            upload(buf, b_q, "b_q")?;
4963        }
4964        if let Some(ref mut buf) = self.lora_a_v {
4965            upload(buf, a_v, "a_v")?;
4966        }
4967        if let Some(ref mut buf) = self.lora_b_v {
4968            upload(buf, b_v, "b_v")?;
4969        }
4970        Ok(())
4971    }
4972}
4973
#[cfg(test)]
mod tests {
    /// Smoke test: with the `cuda` feature enabled, both block types must be
    /// nameable and sized. Without the feature the body is empty and the test
    /// passes trivially.
    #[test]
    fn test_cuda_block_compiles() {
        #[cfg(feature = "cuda")]
        {
            use super::{CudaNf4TransformerBlock, CudaTransformerBlock};
            use std::mem::size_of;

            let _dense_block_bytes = size_of::<CudaTransformerBlock>();
            let _nf4_block_bytes = size_of::<CudaNf4TransformerBlock>();
        }
    }
}
4986}