Skip to main content

entrenar/train/transformer_trainer/
cuda_trainer.rs

1//! GPU-resident transformer trainer (ALB-040)
2//!
3//! Wires the existing `CudaTransformerBlock` forward/backward/optimizer_step
4//! into the pretraining path. Follows the proven `classify_pipeline.rs` pattern.
5//!
6//! # Architecture
7//!
8//! ```text
9//! CudaTransformerTrainer
10//! ├── model: Transformer                 (CPU — embed + save)
11//! ├── cuda_trainer: CudaTrainer          (GPU device context)
12//! ├── cuda_blocks: Vec<CudaBlock>            (fp32 or NF4)
13//! ├── cuda_grad_workspace: CudaGradWorkspace
14//! ├── gpu_training: GpuPretrainState     (layer_inputs, grad bufs, opt states)
15//! ├── lm_head_weight_gpu: GpuBuffer      (V × H on GPU)
16//! ├── lm_head_grad_gpu: GpuBuffer        (V × H gradient scratch)
17//! ├── lm_head_m/v: GpuBuffer             (AdamW moment states)
18//! └── config: TransformerTrainConfig
19//! ```
20//!
21//! # Transfer budget (C-GPUTRAIN-002, updated KAIZEN-050/052)
22//!
//! One bulk PCIe transfer per training step, plus two tiny control transfers:
24//! 1. H2D: hidden states after embedding (seq×H×4 bytes)
25//! 2. H2D: target_ids for fused cross-entropy (seq×4 bytes — ~512B)
26//! 3. D2H: loss_partials from fused cross-entropy (seq×4 bytes — ~512B)
27//!
28//! Eliminated by KAIZEN-050:
29//! - D2H logits (was seq×V×4 = 77.8MB for Qwen3-4B)
30//! - H2D grad_logits (was seq×V×4 = 77.8MB)
31//!
32//! Eliminated by KAIZEN-052:
33//! - grad_gpu buffer allocation (was seq×V×4 = 77.8MB per step)
34
35#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42    gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46    adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47    gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48    squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64    CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65    GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
/// Compute the clip scale and pre-clip L2 norm of the shared gradient workspace
/// (KAIZEN-054 / KAIZEN-055).
///
/// A `squared_sum_launch_cuda` reduction is enqueued on `stream` for every non-empty
/// gradient buffer. The stream is then synchronized exactly once, and each buffer's small
/// partial-sum result is fetched with `squared_sum_collect` and accumulated in `f64`. Whole
/// gradient buffers are never downloaded to the host (that path was disabled in ALB-067).
///
/// Returns `(scale, grad_norm)`:
/// - `scale` is `max_norm / grad_norm` when `grad_norm > max_norm`, otherwise `1.0`.
/// - `grad_norm` is the L2 norm over every buffer that was successfully reduced.
///
/// Best effort: a buffer whose launch or collection fails is left out of the sum. If the
/// stream synchronization itself fails, `(1.0, 0.0)` is returned, meaning "do not clip".
///
/// This is a free function rather than a method to avoid borrow conflicts with `&mut self`.
#[cfg(feature = "cuda")]
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // KAIZEN-055: enqueue every reduction before waiting, so the stream is flushed once per
    // block instead of once per buffer. Empty buffers and failed launches are skipped.
    let launched: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|grad| {
            let len = grad.len() as u32;
            if len == 0 {
                return None;
            }
            squared_sum_launch_cuda(grad, len, stream).ok()
        })
        .collect();

    // The single host/device sync point for all launches above.
    let Ok(_) = stream.synchronize() else {
        return (1.0, 0.0);
    };

    // C-CLIP-001: each collected value is already sum(x²) = ||g||², so it is added as-is.
    // Squaring it again was the entrenar#311 bug. `fold` from +0.0 keeps the exact
    // accumulation the sequential loop had.
    let total_sq = launched
        .iter()
        .filter_map(|pending| squared_sum_collect(pending).ok())
        .fold(0.0_f64, |acc, sq| acc + f64::from(sq));

    // Global L2 norm = sqrt of the summed per-buffer squared norms.
    let grad_norm = total_sq.sqrt() as f32;
    // Written with `>` so a NaN norm yields scale 1.0 (no clipping), as before.
    let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
    (scale, grad_norm)
}
133
/// Clip every gradient buffer of the shared workspace to a global L2 norm of `max_norm`
/// (KAIZEN-054).
///
/// The norm and scale come from `compute_workspace_clip_scale_gpu`. When the scale is within
/// `1e-7` of `1.0`, clipping would be a no-op and no kernel is launched. Otherwise each of the
/// nine gradient buffers is scaled in place with `gradient_clip_cuda`. Kernel errors are
/// ignored (best effort).
///
/// R-004: returns the pre-clip gradient L2 norm for observability logging.
#[cfg(feature = "cuda")]
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    // Scale is effectively 1.0: nothing to rescale.
    if (scale - 1.0).abs() < 1e-7 {
        return grad_norm;
    }

    // Disjoint mutable borrows of the nine workspace gradients, in a fixed order.
    let mut grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in &mut grads {
        let len = grad.len() as u32;
        // Errors deliberately ignored, matching the best-effort style of this path.
        let _ = gradient_clip_cuda(grad, scale, len, stream);
    }

    grad_norm
}
165
/// ALB-078: Fused gradient clipping that keeps the whole pipeline on the GPU.
///
/// An alternative to `clip_workspace_gradients` with no `stream.synchronize()` and no D2H
/// download of partial sums. Three kernel phases are enqueued on `stream`:
///
/// 1. 9× SquaredSumKernel — each non-empty gradient buffer writes its partial sums into
///    `state.partials_buf`, starting at element `state.offsets[i]`.
/// 2. 1× ClipScaleReduceKernel — reduces `state.total_partials` partials together with
///    `max_norm` and writes the result into `state.scale_buf`.
/// 3. 9× GradientClipGpuScaleKernel — each non-empty gradient buffer is scaled by the value
///    read directly from `state.scale_buf`, so the scale never reaches the host.
///
/// Ordering between the phases relies only on in-stream ordering. Launch errors are ignored,
/// and no gradient norm is returned.
///
/// NOTE(review): if a phase-1 launch fails, its slice of `partials_buf` is not rewritten, but
/// phase 2 still reduces all `total_partials` entries. Confirm that stale partials are
/// acceptable in that case.
#[cfg(feature = "cuda")]
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Byte size of one f32 partial; `offsets` are in elements, device pointers in bytes.
    const F32_BYTES: u64 = std::mem::size_of::<f32>() as u64;

    // Phase 1: squared-sum partials for every non-empty buffer.
    {
        let grads: [&GpuBuffer<f32>; 9] = [
            &ws.grad_w_q,
            &ws.grad_w_k,
            &ws.grad_w_v,
            &ws.grad_w_o,
            &ws.grad_gate,
            &ws.grad_up,
            &ws.grad_down,
            &ws.grad_input_norm,
            &ws.grad_post_attn_norm,
        ];
        let partials_base = state.partials_buf.as_ptr();
        for (slot, grad) in grads.iter().enumerate() {
            let len = grad.len() as u32;
            if len == 0 {
                continue;
            }
            let dst = partials_base + u64::from(state.offsets[slot]) * F32_BYTES;
            let _ = squared_sum_launch_into(grad, len, dst, stream);
        }
    }

    // Phase 2: reduce every partial and derive the clip scale on the device.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Phase 3: apply the device-resident scale (element 0 of `scale_buf`) to each gradient.
    let scale_ptr = state.scale_buf.as_ptr();
    let mut grads_mut: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in &mut grads_mut {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, scale_ptr, len, stream);
        }
    }
}
238
/// R-004: Pre-clip gradient L2 norm of the shared workspace, for observability only.
///
/// Delegates to `compute_workspace_clip_scale_gpu` with `max_norm = f32::MAX` and discards
/// the returned scale. No gradient buffer is modified. The reduction runs on the GPU
/// (KAIZEN-054); only the small per-buffer partial sums (~9 KB total) are downloaded.
#[cfg(feature = "cuda")]
#[allow(dead_code)] // may have no callers; retained as an R-004 diagnostics helper
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
/// ALB-072: Multiply every gradient buffer of the shared workspace by `inv_scale`.
///
/// Under fp16 AMP, the fused cross-entropy kernel multiplies `loss_scale` into the gradient
/// it emits, so every gradient produced downstream in the backward pass carries that factor.
/// The GPU block optimizer (AdamW) must receive unscaled gradients; otherwise its second
/// moment `v` overflows f32 and early layers turn NaN.
///
/// This is the GPU-side counterpart of `GradScaler::unscale_and_check()`, which handles the
/// CPU embedding gradients. Returns immediately when `inv_scale` is within `1e-7` of `1.0`.
/// Each buffer is rescaled with the `gradient_clip_cuda` scaling kernel. No non-finite check
/// is done here, and kernel errors are ignored.
#[cfg(feature = "cuda")]
#[allow(dead_code)] // may have no callers in the current path; kept for fp16 AMP unscaling
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    // Multiplying by ~1.0 would be a no-op; skip the nine launches.
    if (inv_scale - 1.0).abs() < 1e-7 {
        return;
    }

    let mut grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads.iter_mut() {
        let len = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, len, stream);
    }
}
285
/// GPU-resident training state for the pretraining path.
///
/// # Contract (C-GPUTRAIN-001)
///
/// - `layer_inputs.len() == num_layers`
/// - Every buffer is allocated once at init; training performs no further GPU allocations
/// - `step` only ever increases
///
/// Per-sequence buffers are sized for `max_seq_len` (the trainer's configured maximum), not
/// for the length of any particular batch.
#[cfg(feature = "cuda")]
struct GpuPretrainState {
    /// Block inputs kept for the backward pass: one `[max_seq_len * hidden_size]` buffer per
    /// layer.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Which entries of `layer_inputs` are saved during forward. With activation
    /// checkpointing, only segment-boundary layers are saved; the others are recomputed from
    /// the nearest saved checkpoint before backward. Without checkpointing, every entry is
    /// `true`.
    saved_layer_mask: Vec<bool>,
    /// Initial input `[max_seq_len * hidden_size]` for recomputing activations from a
    /// checkpoint boundary. `Some` only when activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Final RMSNorm weight `[hidden_size]`.
    final_norm_weight: GpuBuffer<f32>,
    /// Output of the last block, before the final norm; kept for RMSNorm backward
    /// `[max_seq_len * hidden_size]`.
    blocks_output: GpuBuffer<f32>,
    /// First of the two alternating activation-gradient buffers `[max_seq_len * hidden_size]`.
    grad_buf_a: GpuBuffer<f32>,
    /// Second of the two alternating activation-gradient buffers `[max_seq_len * hidden_size]`.
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient of the final norm weight `[hidden_size]`.
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Final RMSNorm output, reused every step `[max_seq_len * hidden_size]`.
    norm_output: GpuBuffer<f32>,
    /// Logits, reused every step `[max_seq_len * vocab_size]`.
    logits_buf: GpuBuffer<f32>,
    /// Gradient w.r.t. the normed hidden states, produced by LM head backward
    /// `[max_seq_len * hidden_size]`.
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// One optimizer state per transformer block. Populated on the fp32 path; left empty when
    /// NF4 blocks are used (they keep LoRA optimizer states instead).
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter.
    step: u32,
}
326
/// GPU-resident transformer trainer for pretraining.
///
/// Runs `CudaTransformerBlock` forward/backward/optimizer_step on the GPU. Cross-entropy is
/// computed by a fused GPU kernel (KAIZEN-050), so logits never leave the device. The CPU
/// side keeps only the embedding lookup and its AdamW optimizer (`embed_optimizer`).
///
/// # Contract (C-GPUTRAIN-002)
///
/// - Exactly 3 PCIe transfers per training step: one bulk H2D upload of the embedded hidden
///   states, plus two tiny control transfers (target ids H2D, loss partials D2H)
/// - Constructors return `Err` on any CUDA failure, so the caller can fall back to the CPU
///   `TransformerTrainer`
/// - Weight sync via `sync_weights_to_cpu()` before save
#[cfg(feature = "cuda")]
pub struct CudaTransformerTrainer {
    /// CPU model (embedding lookup, saving, and the source of the initial GPU weights)
    model: Transformer,
    /// CUDA device context
    cuda_trainer: CudaTrainer,
    /// GPU-resident transformer blocks (fp32 or NF4 via the `CudaBlock` enum)
    cuda_blocks: Vec<CudaBlock>,
    /// Shared gradient workspace (one set, reused across layers). Used by the fp32 path;
    /// allocated regardless of precision.
    cuda_grad_workspace: CudaGradWorkspace,
    /// ENT-263: Shared scratch for NF4 blocks (C-SCRATCH-001). `None` when fp32.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// ENT-263: Shared LoRA gradient workspace for NF4 backward. `None` when fp32.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// ENT-263: Per-block LoRA optimizer states for NF4. `None` when fp32.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// GPU training state (layer inputs, grad buffers, per-block optimizer states)
    gpu_training: GpuPretrainState,
    /// LM head weight on GPU `[vocab_size * hidden_size]`, uploaded from `model.lm_head`, or
    /// from `model.embed_tokens` when the weights are tied
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// LM head weight gradient on GPU `[vocab_size * hidden_size]`
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// LM head AdamW first moment `[vocab_size * hidden_size]` (zero-initialized)
    lm_head_m: GpuBuffer<f32>,
    /// LM head AdamW second moment `[vocab_size * hidden_size]` (zero-initialized)
    lm_head_v: GpuBuffer<f32>,
    /// Final norm weight AdamW first moment `[hidden_size]`
    final_norm_m: GpuBuffer<f32>,
    /// Final norm weight AdamW second moment `[hidden_size]`
    final_norm_v: GpuBuffer<f32>,
    /// CPU optimizer for the embedding weights only
    embed_optimizer: AdamW,
    /// Training configuration
    config: TransformerTrainConfig,
    /// Metrics tracker
    pub metrics: MetricsTracker,
    /// Current optimizer step
    step: usize,
    /// Accumulated loss (for gradient accumulation)
    accumulated_loss: f32,
    /// Accumulated batch count
    accumulated_batches: usize,
    /// R-004: Last observed LM head gradient L2 norm (proxy for the global grad norm)
    last_grad_norm: f32,
    /// R-040: Last observed embedding activation gradient L2 norm
    last_embed_grad_norm: f32,
    /// R-038: Per-block gradient accumulation for true multi-step gradient accumulation.
    /// Only allocated when `accumulation_steps > 1`. CPU-side buffers (~335 MB for 350M).
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// ALB-091: GPU-resident gradient accumulation (replaces the CPU accumulator when
    /// available). Eliminates 24 × ga `stream.synchronize()` + D2H transfers per optimizer step.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// R-002: Gradient scaler for mixed-precision training.
    /// For BF16: no-op (scale = 1.0, dynamic = false).
    /// For FP16: dynamic loss scaling to prevent gradient underflow.
    grad_scaler: GradScaler,
    /// KAIZEN-047: Per-step wall-clock profiler.
    /// Reports a timing breakdown for each training phase.
    profiler: StepProfiler,
    /// KAIZEN-053: Pre-allocated forward scratch buffers `[max_seq_len * hidden_size]`.
    /// Reused every step; eliminates 2 × cuMemAlloc/Free per training step.
    fwd_scratch_a: GpuBuffer<f32>,
    fwd_scratch_b: GpuBuffer<f32>,
    /// KAIZEN-056: Pre-allocated CPU staging buffer for the H2D hidden-state upload.
    /// Eliminates a `vec![0.0; max_seq_len * hidden_size]` allocation per step.
    h2d_staging: Vec<f32>,
    /// KAIZEN-059: Pre-allocated CPU staging buffer for D2H gradient downloads during
    /// gradient accumulation. Sized to max(h * intermediate, vocab * h).
    /// Eliminates ~15 GB of per-step heap churn (36 × `vec![0.0; h*i]` + `vec![0.0; vocab*h]`
    /// per micro-batch × accumulation_steps).
    d2h_staging: Vec<f32>,
    /// ALB-078: Pre-allocated state for the fused gradient clipping pipeline.
    /// Eliminates 24 `stream.synchronize()` calls per step.
    fused_clip: Option<FusedClipState>,
    /// Pre-allocated host zero buffer for zeroing the final norm grad `[hidden_size]`.
    /// BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd, so the buffer must
    /// be zeroed before each `rms_norm_backward` call.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419    /// Create a new GPU-resident trainer.
420    ///
421    /// # Errors
422    ///
423    /// Returns `Err` if CUDA initialization, kernel pre-warming, or block upload fails.
424    /// Caller should fall back to CPU `TransformerTrainer` on error.
425    pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426        let model = Transformer::new(&config.model_config);
427        Self::with_model(model, config)
428    }
429
430    /// ALB-089: Load SafeTensors checkpoint for GPU inference (forward-only).
431    ///
432    /// Creates a `CudaTransformerTrainer` in inference mode. The optimizer
433    /// state is allocated (wasteful but simple), but `forward_logits()` only
434    /// uses the forward path. Call `forward_logits(&tokens)` to generate.
435    ///
436    /// # Arguments
437    /// * `checkpoint_dir` - Directory containing model.safetensors + config.json
438    /// * `model_config` - Transformer architecture config
439    ///
440    /// # Errors
441    ///
442    /// Returns `Err` if SafeTensors loading or CUDA initialization fails.
443    pub fn for_inference(
444        checkpoint_dir: impl AsRef<std::path::Path>,
445        model_config: crate::transformer::TransformerConfig,
446    ) -> crate::Result<Self> {
447        let dir = checkpoint_dir.as_ref();
448
449        // ALB-089: Try APR format first (our native checkpoint format), then SafeTensors
450        let model = if let Some((Some(m), _step)) =
451            crate::config::try_load_apr_for_inference(dir, &model_config)
452        {
453            m
454        } else {
455            Transformer::from_safetensors(dir, &model_config)?
456        };
457
458        let mut config = TransformerTrainConfig::new(model_config);
459        config.max_seq_len = config.model_config.max_position_embeddings;
460        Self::with_model(model, config)
461    }
462
463    /// Create a GPU-resident trainer from an existing model.
464    ///
465    /// # Errors
466    ///
467    /// Returns `Err` if CUDA initialization fails.
468    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
469        if !cuda_training_available() {
470            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
471        }
472
473        let mc = &config.model_config;
474        let max_seq_len = config.max_seq_len;
475        let hidden_size = mc.hidden_size;
476        let vocab_size = mc.vocab_size;
477        let num_layers = mc.num_hidden_layers;
478
479        // Step 1: Create CUDA trainer (initializes kernel caches)
480        let cuda_trainer = CudaTrainer::new().map_err(|e| {
481            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
482        })?;
483
484        println!(
485            "  GPU: {} ({:.1} GB)",
486            cuda_trainer.device_name(),
487            cuda_trainer.total_memory() as f64 / 1e9
488        );
489
490        let ctx = cuda_trainer.context().clone();
491        let stream = cuda_trainer.stream();
492
493        // Step 2: Pre-warm forward kernels (C-PREWARM-001)
494        // Must happen before block upload — JIT compilation needs free VRAM
495        pre_warm_forward_kernels(
496            hidden_size,
497            mc.intermediate_size,
498            mc.num_attention_heads,
499            mc.num_kv_heads,
500            mc.head_dim(),
501            max_seq_len,
502        )
503        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
504
505        // Step 2a: Pre-warm backward kernels (trueno#200)
506        // MUST happen before any GPU work — Blackwell's cuModuleLoadData fails
507        // with ILLEGAL_ADDRESS when called during active GPU computation.
508        {
509            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
510            let head_dim = mc.head_dim();
511            pre_warm_lora_backward_kernels(
512                hidden_size,
513                mc.num_attention_heads * head_dim,
514                mc.num_kv_heads * head_dim,
515                max_seq_len,
516                config.lora_rank.unwrap_or(0),
517                mc.intermediate_size,
518                mc.num_attention_heads,
519                config.quantize_nf4 && config.is_lora(),
520            )
521            .map_err(|e| {
522                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
523            })?;
524            eprintln!("  ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
525        }
526
527        // Step 2b: Bind cuBLAS handles to training stream (ALB-075)
528        // Must happen after kernel cache init, before any GEMM calls.
529        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
530            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
531        }
532        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
533            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
534        }
535
536        // Step 3: Upload transformer blocks to GPU
537        let use_nf4 = config.quantize_nf4 && config.is_lora();
538        let cuda_blocks = Self::upload_blocks(
539            &model,
540            mc,
541            &config,
542            &ctx,
543            use_nf4,
544            num_layers,
545            hidden_size,
546            max_seq_len,
547        )?;
548
549        // Step 4: Allocate shared gradient workspace
550        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
551            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
552        })?;
553
554        // Step 5: Allocate GPU training state
555        let buf_size = max_seq_len * hidden_size;
556        let logits_size = max_seq_len * vocab_size;
557
558        // Activation checkpointing: determine which layers save their inputs.
559        // Checkpoint boundary layers (every segment_size layers) are always saved.
560        // Non-boundary layers are recomputed from the nearest checkpoint during backward.
561        let checkpointing = config.checkpoint_config.enabled;
562        let segment_size = if checkpointing {
563            let ns = config.checkpoint_config.num_segments.max(1);
564            num_layers.div_ceil(ns)
565        } else {
566            1 // Every layer is a checkpoint (no recomputation)
567        };
568        let saved_layer_mask: Vec<bool> =
569            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
570
571        let mut layer_inputs = Vec::with_capacity(num_layers);
572        for _ in 0..num_layers {
573            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
574                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
575            })?);
576        }
577
578        // Allocate recompute buffer if checkpointing is enabled
579        let recompute_buf = if checkpointing {
580            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
581                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
582            })?)
583        } else {
584            None
585        };
586
587        if checkpointing {
588            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
589            println!(
590                "  ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
591                config.checkpoint_config.num_segments, saved_count, num_layers
592            );
593        }
594
595        // Upload final RMSNorm weight
596        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
597        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
598            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
599        })?;
600
601        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
602            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
603        })?;
604        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
605            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
606        })?;
607        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
608            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
609        })?;
610        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
611            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
612        })?;
613        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
614            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
615        })?;
616        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
617            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
618        })?;
619        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
620            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
621        })?;
622
623        // Initialize per-block optimizer states (fp32 path only; NF4 uses LoRA states)
624        let mut optimizer_states = Vec::new();
625        if !use_nf4 {
626            optimizer_states.reserve(num_layers);
627            for (i, block) in cuda_blocks.iter().enumerate() {
628                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
629                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
630                })?);
631            }
632        }
633
634        let gpu_training = GpuPretrainState {
635            layer_inputs,
636            saved_layer_mask,
637            recompute_buf,
638            final_norm_weight,
639            blocks_output,
640            grad_buf_a,
641            grad_buf_b,
642            grad_final_norm_weight,
643            norm_output,
644            logits_buf,
645            lm_head_grad_hidden,
646            optimizer_states,
647            step: 0,
648        };
649
650        // Step 6: Upload LM head weight to GPU
651        // Use tied weights (embed_tokens.weight) or separate lm_head
652        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
653        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
654        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
655            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
656        })?;
657        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
658            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
659        })?;
660        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
661        // zero memory (cuMemAlloc returns uninitialized VRAM).
662        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
663            .map_err(|e| {
664                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
665            })?;
666        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
667            .map_err(|e| {
668                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
669            })?;
670
671        // Final norm optimizer states
672        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
673            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
674        })?;
675        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
676            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
677        })?;
678
679        // KAIZEN-053: Pre-allocate forward scratch buffers (reused every step)
680        let buf_size = max_seq_len * hidden_size;
681        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
682            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
683        })?;
684        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
685            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
686        })?;
687
688        // Sync to ensure all uploads completed
689        stream
690            .synchronize()
691            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
692
693        println!(
694            "  ✓ GPU training state allocated (LM head: {:.1} MB)",
695            (vocab_size * hidden_size * 4) as f64 / 1e6
696        );
697
698        // ENT-263: Allocate NF4 infrastructure (shared scratch, LoRA grad workspace, optimizer states)
699        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
700            let lora_rank = config.lora_rank.unwrap_or(16);
701
702            // C-SCRATCH-001: Shared scratch for NF4 blocks (reused across all layers)
703            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
704                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
705            })?;
706
707            // LoRA gradient workspace (shared, reused per-block like CudaGradWorkspace)
708            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
709                crate::error::Error::ConfigError(format!(
710                    "NF4 LoRA grad workspace alloc failed: {e:?}"
711                ))
712            })?;
713
714            // Per-block LoRA optimizer states
715            let mut lora_opt_states = Vec::with_capacity(num_layers);
716            for (i, block) in cuda_blocks.iter().enumerate() {
717                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
718                    crate::error::Error::ConfigError(format!(
719                        "Block {i} LoRA opt state failed: {e:?}"
720                    ))
721                })?);
722            }
723
724            println!(
725                "  ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
726            );
727            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
728        } else {
729            (None, None, None)
730        };
731
732        // KAIZEN-050: loss_fn removed — cross-entropy computed by fused GPU kernel
733        // C-EMBED-GRAD-001: CPU optimizer must match YAML hyperparams (not defaults)
734        let embed_optimizer =
735            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
736
737        // R-038: Allocate per-block gradient accumulation buffers (CPU-side)
738        // when accumulation_steps > 1 for true gradient accumulation.
739        let grad_accum = if config.accumulation_steps > 1 {
740            let kv_hidden = mc.num_kv_heads * mc.head_dim();
741            let block_sizes =
742                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
743                    hidden_size,
744                    kv_hidden,
745                    mc.intermediate_size,
746                );
747            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
748                num_layers,
749                block_sizes,
750                vocab_size,
751                hidden_size,
752            );
753            println!(
754                "  ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
755                config.accumulation_steps,
756                (accum
757                    .block_grads
758                    .iter()
759                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
760                    .sum::<usize>()
761                    + accum.lm_head_grad.len()
762                    + accum.final_norm_grad.len()
763                    + accum.embedding_grad.len()) as f64
764                    * 4.0
765                    / 1e6,
766            );
767            Some(accum)
768        } else {
769            None
770        };
771
772        // ALB-091: GPU-resident gradient accumulation (eliminates D2H bottleneck).
773        // Falls back to CPU accum if GPU allocation fails.
774        let gpu_grad_accum = if config.accumulation_steps > 1 {
775            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
776                Ok(accum) => {
777                    println!("  ✓ GPU gradient accumulation enabled (ALB-091)");
778                    Some(accum)
779                }
780                Err(e) => {
781                    eprintln!(
782                        "  [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
783                    );
784                    None
785                }
786            }
787        } else {
788            None
789        };
790
791        // KAIZEN-059: Pre-allocate D2H staging buffer for gradient accumulation
792        // downloads. Only needed when GPU accum is unavailable (CPU fallback path).
793        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
794            let ws_max = hidden_size * mc.intermediate_size;
795            let lm_max = vocab_size * hidden_size;
796            vec![0.0f32; ws_max.max(lm_max)]
797        } else {
798            Vec::new()
799        };
800
801        // ALB-078: Pre-allocate fused gradient clipping state.
802        // Eliminates 24 stream syncs per step by keeping norm+clip on GPU.
803        let kv_hidden = mc.num_kv_heads * mc.head_dim();
804        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
805
806        // R-002: Initialize gradient scaler from precision config
807        let grad_scaler = GradScaler::from_config(&config.precision_config);
808        if config.precision_config.is_mixed() {
809            println!(
810                "  ✓ Mixed precision: {} (loss scale={}, dynamic={})",
811                config.precision_config.compute_precision,
812                grad_scaler.scale(),
813                grad_scaler.is_dynamic(),
814            );
815        }
816
817        Ok(Self {
818            model,
819            cuda_trainer,
820            cuda_blocks,
821            cuda_grad_workspace,
822            nf4_shared_scratch,
823            nf4_lora_grad_workspace,
824            nf4_lora_optimizer_states,
825            gpu_training,
826            lm_head_weight_gpu,
827            lm_head_grad_gpu,
828            lm_head_m,
829            lm_head_v,
830            final_norm_m,
831            final_norm_v,
832            embed_optimizer,
833            // KAIZEN-047: Read profile_interval before moving config into struct.
834            profiler: if config.profile_interval > 0 {
835                StepProfiler::new(true, config.profile_interval)
836            } else {
837                StepProfiler::disabled()
838            },
839            config,
840            metrics: MetricsTracker::new(),
841            step: 0,
842            accumulated_loss: 0.0,
843            accumulated_batches: 0,
844            last_grad_norm: 0.0,
845            last_embed_grad_norm: 0.0,
846            grad_accum,
847            gpu_grad_accum,
848            grad_scaler,
849            fwd_scratch_a,
850            fwd_scratch_b,
851            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
852            d2h_staging,
853            fused_clip,
854            final_norm_zero_buf: vec![0.0f32; hidden_size],
855        })
856    }
857
858    /// Upload transformer blocks to GPU (NF4 or fp32 path).
859    #[allow(clippy::too_many_arguments)]
860    fn upload_blocks(
861        model: &Transformer,
862        mc: &crate::transformer::TransformerConfig,
863        config: &TransformerTrainConfig,
864        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
865        use_nf4: bool,
866        num_layers: usize,
867        hidden_size: usize,
868        max_seq_len: usize,
869    ) -> crate::Result<Vec<CudaBlock>> {
870        let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
871
872        if use_nf4 {
873            let lora_rank = config.lora_rank.unwrap_or(16);
874            let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
875            let lora_scale = lora_alpha / lora_rank as f32;
876            let head_dim = mc.head_dim();
877            let q_dim = mc.num_attention_heads * head_dim;
878            let kv_hidden = mc.num_kv_heads * head_dim;
879
880            for (i, layer) in model.layers.iter().enumerate() {
881                let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
882                    .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
883                    .collect();
884                let lora_b_q = vec![0.0f32; lora_rank * q_dim];
885                let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
886                    .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
887                    .collect();
888                let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
889
890                let q_norm_data = layer
891                    .self_attn
892                    .q_norm
893                    .as_ref()
894                    .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
895                let k_norm_data = layer
896                    .self_attn
897                    .k_norm
898                    .as_ref()
899                    .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
900
901                let block = crate::transformer::CudaNf4TransformerBlock::new(
902                    mc,
903                    i,
904                    ctx.clone(),
905                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
906                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
907                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
908                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
909                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
910                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
911                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
912                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
913                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
914                    max_seq_len,
915                    Some((&lora_a_q, &lora_b_q)),
916                    Some((&lora_a_v, &lora_b_v)),
917                    lora_scale,
918                    lora_rank,
919                    q_norm_data.as_deref(),
920                    k_norm_data.as_deref(),
921                )
922                .map_err(|e| {
923                    crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
924                })?;
925                cuda_blocks.push(CudaBlock::Nf4(block));
926            }
927            println!("  ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
928        } else {
929            for (i, layer) in model.layers.iter().enumerate() {
930                // FALSIFY-CUDA-FORWARD-PARITY-002 thread Q/K/V biases
931                // through to the CudaTransformerBlock when present
932                // (Qwen2 family use_bias=true). Pre-fix these were
933                // silently dropped → val_loss > ln(vocab) on Qwen.
934                let b_q = layer
935                    .self_attn
936                    .b_q
937                    .as_ref()
938                    .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
939                let b_k = layer
940                    .self_attn
941                    .b_k
942                    .as_ref()
943                    .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
944                let b_v = layer
945                    .self_attn
946                    .b_v
947                    .as_ref()
948                    .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
949                let block = CudaTransformerBlock::new(
950                    mc,
951                    i,
952                    ctx.clone(),
953                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
954                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
955                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
956                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
957                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
958                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
959                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
960                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
961                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
962                    max_seq_len,
963                    b_q.as_deref(),
964                    b_k.as_deref(),
965                    b_v.as_deref(),
966                )
967                .map_err(|e| {
968                    crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
969                })?;
970                cuda_blocks.push(CudaBlock::Fp32(block));
971            }
972            println!("  ✓ {num_layers} transformer blocks uploaded to GPU");
973        }
974
975        Ok(cuda_blocks)
976    }
977
978    /// ALB-078: Initialize fused gradient clipping state (extracted for complexity).
979    fn init_fused_clip(
980        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981        config: &TransformerTrainConfig,
982        hidden_size: usize,
983        kv_hidden: usize,
984        mc: &crate::transformer::TransformerConfig,
985    ) -> Option<FusedClipState> {
986        config.base.max_grad_norm?;
987        let grad_sizes: [u32; 9] = [
988            (hidden_size * hidden_size) as u32,
989            (hidden_size * kv_hidden) as u32,
990            (hidden_size * kv_hidden) as u32,
991            (hidden_size * hidden_size) as u32,
992            (hidden_size * mc.intermediate_size) as u32,
993            (hidden_size * mc.intermediate_size) as u32,
994            (mc.intermediate_size * hidden_size) as u32,
995            hidden_size as u32,
996            hidden_size as u32,
997        ];
998        match FusedClipState::new(ctx, &grad_sizes) {
999            Ok(state) => {
1000                println!(
1001                    "  ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002                    state.total_partials,
1003                    f64::from(state.total_partials) * 4.0 / 1024.0,
1004                );
1005                Some(state)
1006            }
1007            Err(e) => {
1008                println!("  ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009                None
1010            }
1011        }
1012    }
1013
1014    /// Run one forward+backward step for a single sequence.
1015    ///
1016    /// # Contract (C-GPUSTEP-001)
1017    ///
1018    /// - Precondition: `input_ids.len() == target_ids.len() <= max_seq_len`
1019    /// - Postcondition: If `accumulate_only` is false, all GPU weights updated.
1020    ///   If true, gradients accumulated into CPU buffers (no weight updates).
1021    /// - Transfer count: 1 PCIe H2D + ~1KB control (KAIZEN-050, + 24×9 D2H if accumulating)
1022    fn train_step_single(
1023        &mut self,
1024        input_ids: &[u32],
1025        target_ids: &[u32],
1026        accumulate_only: bool,
1027    ) -> Option<f32> {
1028        self.profiler.begin_step();
1029        let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030        self.profiler.finish_step();
1031        result
1032    }
1033
1034    /// Inner training step — separated so profiler always records the step.
1035    fn train_step_inner(
1036        &mut self,
1037        input_ids: &[u32],
1038        target_ids: &[u32],
1039        accumulate_only: bool,
1040    ) -> Option<f32> {
1041        let hidden_size = self.config.model_config.hidden_size;
1042        let vocab_size = self.config.model_config.vocab_size;
1043
1044        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
1045        let max_sl = self.config.max_seq_len;
1046        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048        let seq_len = input_ids.len();
1049
1050        // Steps 1-6: GPU forward pass — logits stay GPU-resident (KAIZEN-050)
1051        // (sub-phases embed, h2d, forward, norm_lm instrumented inside gpu_forward)
1052        if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053            eprintln!(
1054                "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055                 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056            );
1057            return None;
1058        }
1059
1060        // Step 7: Fused GPU cross-entropy loss + softmax backward (KAIZEN-050)
1061        // Eliminates: logits D2H (77.8MB) + CPU softmax (40ms) + grad H2D (77.8MB)
1062        self.profiler.begin(StepProfiler::LOSS);
1063        let stream = self.cuda_trainer.stream();
1064
1065        // Compute combined scale: (1/seq_len) * (1/accum_steps)
1066        //
1067        // ALB-072: Do NOT multiply by grad_scaler.scale() here. All backward
1068        // computation uses f32 GpuBuffers — there is no fp16 gradient underflow
1069        // risk. The 65536x loss scaling caused gradient overflow in early layers
1070        // (blocks 0-1 went NaN). The GradScaler remains active for the CPU
1071        // embedding path (unscale_and_check in optimizer_step) as a safety check,
1072        // but it operates with scale=1.0 effective for GPU gradients.
1073        let mut loss_scale = 1.0 / seq_len as f32;
1074        if self.config.accumulation_steps > 1 {
1075            loss_scale /= self.config.accumulation_steps as f32;
1076        }
1077
1078        // KAIZEN-052: In-place — gradient written directly to logits_buf.
1079        let loss_val = fused_cross_entropy_cuda(
1080            &mut self.gpu_training.logits_buf,
1081            target_ids,
1082            seq_len as u32,
1083            vocab_size as u32,
1084            loss_scale,
1085            stream,
1086        )
1087        .ok()?;
1088
1089        // NaN guard (replaces logits NaN check — NaN logits → NaN loss via kernel)
1090        if !loss_val.is_finite() {
1091            return None;
1092        }
1093        self.profiler.end(StepProfiler::LOSS);
1094
1095        // Steps 8-11: GPU backward pass (with or without optimizer)
1096        // (sub-phases lm_bwd, norm_bwd, blk_bwd instrumented inside gpu_backward)
1097        // KAIZEN-050: grad_logits on GPU. KAIZEN-052: grad lives in logits_buf (in-place).
1098        //
1099        // ENT-263 fix: Capture loss regardless of backward success. The NF4 backward
1100        // path may fail (e.g., gemm_nf4_backward_a stub) but the loss was already
1101        // computed by fused_cross_entropy_cuda. Dropping the loss silently causes
1102        // loss=0.0 reporting despite valid forward passes.
1103        if let Some(grad_output_is_a) =
1104            self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105        {
1106            // Step 12: Embedding backward (CPU scatter-add always accumulates)
1107            self.profiler.begin(StepProfiler::EMBED_BWD);
1108            self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110            self.profiler.end(StepProfiler::EMBED_BWD);
1111        }
1112
1113        Some(loss_val)
1114    }
1115
    /// GPU forward pass: embed → blocks → norm → LM head.
    ///
    /// Logits stay GPU-resident in `self.gpu_training.logits_buf` (KAIZEN-050).
    /// Transfers: 1 H2D (hidden states). No D2H — logits consumed by fused kernel.
    ///
    /// Pipeline:
    /// 1. CPU embedding lookup of `input_ids`.
    /// 2. Copy the embeddings into the pre-allocated `h2d_staging`, zero-fill
    ///    the rest of it, and upload the whole staging buffer to `fwd_scratch_a`.
    /// 3. Run every CUDA block, ping-ponging between `fwd_scratch_a` and
    ///    `fwd_scratch_b`. A block's input is also copied into
    ///    `gpu_training.layer_inputs[i]`, but only when `saved_layer_mask[i]`
    ///    is set (presumably the other inputs are rebuilt during backward by
    ///    `recompute_segment`).
    /// 4. Copy the last block's output to `gpu_training.blocks_output`, apply the
    ///    final RMSNorm into `norm_output`, then the LM-head GEMM into
    ///    `logits_buf`.
    ///
    /// This function never calls `stream.synchronize()`; callers that read GPU
    /// results on the host must synchronize first (as `forward_logits` does).
    ///
    /// Preconditions (not checked here): callers pass
    /// `seq_len == input_ids.len()` and `seq_len <= max_seq_len`. The embedding
    /// output is written into `h2d_staging` by slice index, so an over-long
    /// input panics rather than returning `None`.
    ///
    /// Returns `None` if the embedding output is not contiguous, or if the H2D
    /// upload or any GPU copy/kernel call fails. NOTE(review): on those early
    /// returns the profiler phase already begun (and `begin_layer` inside the
    /// block loop) is not ended.
    #[allow(unsafe_code)]
    fn gpu_forward(
        &mut self,
        input_ids: &[u32],
        seq_len: usize,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Option<()> {
        // Contract precondition hook (macro defined elsewhere in the crate).
        contract_pre_gpu_forward!();
        let stream = self.cuda_trainer.stream();

        // Embedding lookup (CPU)
        self.profiler.begin(StepProfiler::EMBED);
        let hidden = self.model.embed_tokens.forward(input_ids);
        let hidden_slice = hidden.data().as_slice()?;
        self.profiler.end(StepProfiler::EMBED);

        // Upload hidden states to GPU (Transfer 1: H2D)
        // Pad to max_seq_len so D2D copies to pre-allocated layer_inputs match.
        // KAIZEN-053: Reuse pre-allocated scratch buffers instead of cuMemAlloc per step.
        // KAIZEN-056: Reuse pre-allocated h2d_staging instead of alloc per step.
        self.profiler.begin(StepProfiler::H2D);
        // Panics if hidden_slice is longer than h2d_staging (seq_len > max_seq_len).
        self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
        // Zero the tail so stale data from a longer previous step is not uploaded.
        self.h2d_staging[hidden_slice.len()..].fill(0.0);
        if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
            eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
            return None;
        }
        self.profiler.end(StepProfiler::H2D);

        // Forward through CUDA blocks using pre-allocated ping-pong buffers.
        // KAIZEN-053: fwd_scratch_a/b are top-level fields (not in gpu_training)
        // so borrowing them doesn't conflict with gpu_training.layer_inputs.
        self.profiler.begin(StepProfiler::FORWARD);
        let mut input_is_a = true; // Track which scratch buffer is "input"
        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
            // Use raw pointers for the ping-pong to avoid borrow conflicts
            // with self.gpu_training.layer_inputs
            // Layer 0 reads fwd_scratch_a (the H2D target); each layer then
            // writes the other buffer, which becomes the next layer's input.
            let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
                if input_is_a {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_a),
                        std::ptr::from_mut(&mut self.fwd_scratch_b),
                    )
                } else {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_b),
                        std::ptr::from_mut(&mut self.fwd_scratch_a),
                    )
                };
            // Only checkpointed layers keep a copy of their input for backward.
            if self.gpu_training.saved_layer_mask[i] {
                // SAFETY: Both buffers are valid GPU allocations with matching max_seq_len size.
                // Copy completes before block.forward() reads from input (same stream ordering).
                unsafe {
                    self.gpu_training.layer_inputs[i]
                        .copy_from_buffer_async(&*input_ptr, stream)
                        .ok()?;
                }
            }
            // SAFETY: input_ptr and output_ptr point to disjoint fwd_scratch_{a,b}.
            // ENT-263: Pass shared scratch for NF4 blocks (C-SCRATCH-001).
            self.profiler.begin_layer();
            unsafe {
                block
                    .forward(
                        &*input_ptr,
                        &mut *output_ptr,
                        seq_len,
                        stream,
                        self.nf4_shared_scratch.as_mut(),
                    )
                    .ok()?;
            }
            self.profiler.end_layer_fwd(i);
            input_is_a = !input_is_a;
        }
        self.profiler.end(StepProfiler::FORWARD);

        // After the loop, input_is_a tells us which buffer has the final output
        // (it was flipped after the last write, so it names the last output buffer).
        let final_output: &GpuBuffer<f32> =
            if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };

        // Save blocks output for final norm backward
        // SAFETY: Disjoint GPU buffers with matching max_seq_len sizes.
        self.profiler.begin(StepProfiler::NORM_LM);
        unsafe {
            self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
        }

        // Final RMSNorm forward (GPU)
        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: thread `config.rms_norm_eps`
        // through so Qwen2 (1e-6) gets the right epsilon. Pre-fix the
        // legacy `rms_norm_forward` hardcoded 1e-5 (Llama default).
        rms_norm_forward_with_eps(
            final_output,
            &self.gpu_training.final_norm_weight,
            &mut self.gpu_training.norm_output,
            seq_len as u32,
            hidden_size as u32,
            self.config.model_config.rms_norm_eps,
            stream,
        )
        .ok()?;

        // LM head GEMM forward (GPU)
        // gemm_forward treats flat (V,H) memory as (H,V) row-major, which
        // implicitly transposes — matching the CPU matmul's tied-weight behavior.
        gemm_forward(
            &self.gpu_training.norm_output,
            &self.lm_head_weight_gpu,
            &mut self.gpu_training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        )
        .ok()?;

        // KAIZEN-050: Logits stay GPU-resident — no D2H transfer.
        // Fused cross-entropy kernel reads logits_buf directly on GPU.
        self.profiler.end(StepProfiler::NORM_LM);

        Some(())
    }
1244
1245    /// ALB-089: Forward-only pass that returns last-position logits on CPU.
1246    ///
1247    /// Runs the same GPU forward as training but downloads only the last
1248    /// SPEC-DISTILL-001 Phase 2d (PMAT-697): forward + caller-supplied
1249    /// logit-gradient backward + optimizer step.
1250    ///
1251    /// Unlike `forward_backward_batch` (which computes the gradient from
1252    /// CE loss internally), this method takes a precomputed last-position
1253    /// logit gradient — useful for knowledge distillation where the
1254    /// gradient is computed externally as the KD logit gradient
1255    /// `α·(softmax(s) - one_hot(label)) + (1-α)·T·(softmax(s/T) - softmax(t/T))`
1256    /// (per `aprender-train-distill::kd_step::kd_logit_gradient`).
1257    ///
1258    /// Flow:
1259    /// 1. `gpu_forward(input_ids)` — produces last-position logits in
1260    ///    `gpu_training.logits_buf`.
1261    /// 2. Upload `logit_gradient` into the last-position slice of
1262    ///    `logits_buf`, OVERWRITING what gpu_forward produced (matching
1263    ///    the in-place gradient convention `fused_cross_entropy_cuda`
1264    ///    uses for the CE path).
1265    /// 3. `gpu_backward` — back-props from the uploaded gradient through
1266    ///    the transformer stack, accumulating weight gradients.
1267    /// 4. `embed_backward` — embedding-table scatter-add.
1268    ///
1269    /// **Limitations** (Phase 2d):
1270    /// - The gradient applies to the LAST POSITION only (this is the KD
1271    ///   training objective for next-token-prediction). Sequence-wise KD
1272    ///   (every position) is a Phase 2e enhancement.
1273    /// - Returns `Some(())` on success, `None` on CUDA failure. Loss is
1274    ///   not computed (caller computes from kd_loss separately).
1275    ///
1276    /// # Errors
1277    ///
1278    /// Returns `None` if `gpu_forward`, the gradient upload, or
1279    /// `gpu_backward` fails. The CUDA stream may be in a poisoned state
1280    /// after such a failure; subsequent training steps should be
1281    /// considered unreliable.
1282    pub fn forward_backward_with_grad(
1283        &mut self,
1284        input_ids: &[u32],
1285        logit_gradient: &[f32],
1286    ) -> Option<()> {
1287        let seq_len = input_ids.len();
1288        let hidden_size = self.config.model_config.hidden_size;
1289        let vocab_size = self.config.model_config.vocab_size;
1290
1291        if seq_len == 0 || seq_len > self.config.max_seq_len {
1292            return None;
1293        }
1294        if logit_gradient.len() != vocab_size {
1295            eprintln!(
1296                "[forward_backward_with_grad] gradient len {} != vocab_size {}",
1297                logit_gradient.len(),
1298                vocab_size
1299            );
1300            return None;
1301        }
1302
1303        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1304
1305        // Upload the KD gradient into the last-position slice of logits_buf,
1306        // replacing whatever gpu_forward wrote there. This matches the
1307        // KAIZEN-052 in-place gradient convention that gpu_backward expects.
1308        let offset = (seq_len - 1) * vocab_size;
1309        self.gpu_training.logits_buf.copy_from_host_at(logit_gradient, offset).ok()?;
1310        let stream = self.cuda_trainer.stream();
1311        stream.synchronize().ok()?;
1312
1313        // Back-prop from the uploaded gradient through the transformer.
1314        // accumulate_only=false → run the optimizer step at the end.
1315        let grad_output_is_a = self.gpu_backward(seq_len, hidden_size, vocab_size, false)?;
1316        // Embedding backward (CPU scatter-add). Pre-condition: grad_output_is_a
1317        // is the buffer-flip flag from gpu_backward (per existing
1318        // `train_step_inner` pattern at line ~1108).
1319        self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1320
1321        Some(())
1322    }
1323
1324    /// position's logits (vocab_size floats) for token sampling. No backward
1325    /// pass, no loss computation.
1326    ///
1327    /// # Contract (C-CUDA-INF-001)
1328    ///
1329    /// - Same forward path as `gpu_forward()` — identical logits
1330    /// - Only downloads `logits[seq_len-1, :]` (128 KB for 32K vocab)
1331    /// - stream.synchronize() before D2H (C-STREAMSYNC-001)
1332    pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1333        let seq_len = input_ids.len();
1334        let hidden_size = self.config.model_config.hidden_size;
1335        let vocab_size = self.config.model_config.vocab_size;
1336
1337        if seq_len == 0 || seq_len > self.config.max_seq_len {
1338            return None;
1339        }
1340
1341        // Reuse gpu_forward for the actual computation
1342        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1343
1344        // C-STREAMSYNC-001: synchronize before D2H
1345        let stream = self.cuda_trainer.stream();
1346        stream.synchronize().ok()?;
1347
1348        // Download last position logits only: logits_buf[seq_len-1, :]
1349        let offset = (seq_len - 1) * vocab_size;
1350        let mut logits = vec![0.0f32; vocab_size];
1351        self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1352
1353        Some(logits)
1354    }
1355
1356    /// GPU backward pass with interleaved per-block optimizer step.
1357    ///
1358    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
1359    /// Recompute layer inputs for a segment during backward (activation checkpointing).
1360    ///
1361    /// When checkpointing is enabled, non-checkpoint layers don't save their inputs
1362    /// during forward. Before their backward pass, we recompute from the nearest
1363    /// checkpoint by re-running forward through intermediate blocks.
1364    ///
1365    /// This recomputes the entire segment [checkpoint..=target_layer], storing
1366    /// intermediate layer_inputs so subsequent layers in the same segment don't
1367    /// need redundant recomputation.
1368    ///
1369    /// # Contract (R-021)
1370    ///
1371    /// After this call, `layer_inputs[i]` is valid for all i in [checkpoint..=target_layer].
1372    #[allow(unsafe_code)]
1373    fn recompute_segment(
1374        gpu_training: &mut GpuPretrainState,
1375        cuda_blocks: &mut [CudaBlock],
1376        nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1377        target_layer: usize,
1378        seq_len: usize,
1379        stream: &CudaStream,
1380    ) -> Option<()> {
1381        // Find nearest saved checkpoint at or before target
1382        let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1383
1384        if seg_start == target_layer {
1385            return Some(()); // Already saved
1386        }
1387
1388        // Copy checkpoint input to recompute_buf as starting point.
1389        // SAFETY: recompute_buf and layer_inputs are disjoint allocations.
1390        let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1391        unsafe {
1392            recompute_buf
1393                .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1394                .ok()?;
1395        }
1396
1397        // Forward through blocks [seg_start..target_layer], saving intermediate inputs.
1398        // For block i, input → block i → output becomes input for block i+1.
1399        // We save output (= input to block i+1) in layer_inputs[i+1].
1400        //
1401        // Buffer pattern:
1402        //   i == seg_start: input = recompute_buf, output = layer_inputs[seg_start+1]
1403        //   i > seg_start:  input = layer_inputs[i], output = layer_inputs[i+1]
1404        //
1405        // SAFETY: split_at_mut ensures non-overlapping borrows of layer_inputs.
1406        // recompute_buf is separate from layer_inputs.
1407        for i in seg_start..target_layer {
1408            if i == seg_start {
1409                // Input is in recompute_buf, output goes to layer_inputs[i+1]
1410                let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1411                let li = &mut gpu_training.layer_inputs;
1412                unsafe {
1413                    cuda_blocks[i]
1414                        .forward(
1415                            &*recompute_ptr,
1416                            &mut li[i + 1],
1417                            seq_len,
1418                            stream,
1419                            nf4_shared_scratch.as_mut(),
1420                        )
1421                        .ok()?;
1422                }
1423            } else {
1424                // Input = layer_inputs[i], output = layer_inputs[i+1]
1425                let li = &mut gpu_training.layer_inputs;
1426                let (left, right) = li.split_at_mut(i + 1);
1427                cuda_blocks[i]
1428                    .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1429                    .ok()?;
1430            }
1431        }
1432
1433        Some(())
1434    }
1435
1436    /// Since `gemm_backward_b` overwrites (not accumulates), we must run each block's
1437    /// optimizer step immediately after its backward, before the next block overwrites
1438    /// the workspace. This also enables per-block gradient clipping.
1439    ///
1440    /// When `accumulate_only` is true (R-038 gradient accumulation), the per-block
1441    /// optimizer steps are skipped and workspace gradients are downloaded to CPU-side
1442    /// `PerBlockGradientAccumulator` instead. LM head and final norm gradients are
1443    /// also downloaded and accumulated. The optimizer step is deferred until
1444    /// `gpu_optimizer_from_accum()` is called.
1445    ///
1446    /// Returns `grad_output_is_a` flag for embedding backward.
1447    /// Transfer: 0 H2D (KAIZEN-050/052: grad in logits_buf) + 24×9 D2H if accumulating.
1448    #[allow(unsafe_code)]
1449    fn gpu_backward(
1450        &mut self,
1451        seq_len: usize,
1452        hidden_size: usize,
1453        vocab_size: usize,
1454        accumulate_only: bool,
1455    ) -> Option<bool> {
1456        let stream = self.cuda_trainer.stream();
1457        let max_grad_norm = self.config.base.max_grad_norm;
1458        let lr = self.current_lr();
1459        // ALB-072: No inv_scale needed — loss_scale no longer includes grad_scaler.
1460        let beta1 = self.config.beta1;
1461        let beta2 = self.config.beta2;
1462        let weight_decay = self.config.weight_decay;
1463
1464        // KAIZEN-050: grad_logits GPU-resident. KAIZEN-052: grad lives in logits_buf (in-place).
1465        // No separate grad buffer. No GRAD_H2D transfer.
1466
1467        // LM head GEMM backward
1468        self.profiler.begin(StepProfiler::LM_BWD);
1469        gemm_backward_a(
1470            &self.gpu_training.logits_buf,
1471            &self.lm_head_weight_gpu,
1472            &mut self.gpu_training.lm_head_grad_hidden,
1473            seq_len as u32,
1474            hidden_size as u32,
1475            vocab_size as u32,
1476            stream,
1477        )
1478        .ok()?;
1479
1480        gemm_backward_b(
1481            &self.gpu_training.norm_output,
1482            &self.gpu_training.logits_buf,
1483            &mut self.lm_head_grad_gpu,
1484            seq_len as u32,
1485            hidden_size as u32,
1486            vocab_size as u32,
1487            stream,
1488        )
1489        .ok()?;
1490
1491        // Clip LM head weight gradient
1492        // KAIZEN-049: GPU norm reduction.
1493        // KAIZEN-051: No explicit sync needed — same stream ordering.
1494        // ALB-071: Always compute LM head grad norm for observability (R-004).
1495        // C-CLIP-001: squared_sum_cuda returns ||g||². Take sqrt for L2 norm (entrenar#311).
1496        let lm_sq_norm =
1497            squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1498                .unwrap_or(0.0);
1499        let lm_norm = lm_sq_norm.sqrt(); // L2 norm, NOT squared
1500        self.last_grad_norm = lm_norm; // R-004: capture for observability
1501                                       // C-BACKPARITY-001: LM head gradient norm tracing (pre-clip).
1502        if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1503            eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1504            // Also trace the grad_hidden flowing to blocks
1505            let gh_sq = squared_sum_cuda(
1506                &self.gpu_training.lm_head_grad_hidden,
1507                self.gpu_training.lm_head_grad_hidden.len() as u32,
1508                stream,
1509            )
1510            .unwrap_or(0.0);
1511            eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1512        }
1513        if let Some(max_norm) = max_grad_norm {
1514            let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1515            let n = self.lm_head_grad_gpu.len() as u32;
1516            let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1517        }
1518        self.profiler.end(StepProfiler::LM_BWD);
1519
1520        // Final RMSNorm backward
1521        self.profiler.begin(StepProfiler::NORM_BWD);
1522        // Zero grad_final_norm_weight before backward — kernel accumulates via atomicAdd
1523        self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1524        rms_norm_backward(
1525            &self.gpu_training.blocks_output,
1526            &self.gpu_training.final_norm_weight,
1527            &self.gpu_training.lm_head_grad_hidden,
1528            &mut self.gpu_training.grad_buf_a,
1529            &mut self.gpu_training.grad_final_norm_weight,
1530            seq_len as u32,
1531            hidden_size as u32,
1532            1e-5_f32,
1533            stream,
1534        )
1535        .ok()?;
1536
1537        // Clip final norm weight gradient
1538        // KAIZEN-051: No explicit sync needed — same stream ordering as LM head clip.
1539        if let Some(max_norm) = max_grad_norm {
1540            let (scale, _) = Self::compute_clip_scale_with_norm(
1541                &self.gpu_training.grad_final_norm_weight,
1542                max_norm,
1543                stream,
1544            );
1545            let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1546            let _ =
1547                gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1548        }
1549        self.profiler.end(StepProfiler::NORM_BWD);
1550
1551        // R-038: Either accumulate non-block grads or run non-block optimizer.
1552        if accumulate_only {
1553            // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1554            if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1555                let _ = gpu_accum.accumulate_nonblock(
1556                    &self.lm_head_grad_gpu,
1557                    &self.gpu_training.grad_final_norm_weight,
1558                    stream,
1559                );
1560            } else {
1561                stream.synchronize().ok()?;
1562                Self::download_nonblock_grads_to_accum(
1563                    &self.lm_head_grad_gpu,
1564                    &self.gpu_training.grad_final_norm_weight,
1565                    &mut self.grad_accum,
1566                    &mut self.d2h_staging,
1567                )?;
1568            }
1569        } else {
1570            Self::run_nonblock_optimizer_step(
1571                &mut self.gpu_training,
1572                Some(&mut self.lm_head_weight_gpu),
1573                &self.lm_head_grad_gpu,
1574                &mut self.lm_head_m,
1575                &mut self.lm_head_v,
1576                &mut self.final_norm_m,
1577                &mut self.final_norm_v,
1578                lr,
1579                beta1,
1580                beta2,
1581                weight_decay,
1582                stream,
1583            );
1584        }
1585
1586        // Backward through blocks in reverse, with interleaved clip + optimizer.
1587        // Each block's backward writes weight gradients to shared CudaGradWorkspace.
1588        //
1589        // SAFETY: grad_buf_a and grad_buf_b are disjoint fields. Raw pointers
1590        // allow alternating read/write without violating aliasing rules.
1591        self.profiler.begin(StepProfiler::BLK_BWD);
1592        let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1593        let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1594        let mut grad_output_is_a = true;
1595        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1596
1597        for layer_idx in (0..self.cuda_blocks.len()).rev() {
1598            // Activation checkpointing: if this layer's input wasn't saved during
1599            // forward, recompute the segment from the nearest checkpoint.
1600            if !self.gpu_training.saved_layer_mask[layer_idx] {
1601                Self::recompute_segment(
1602                    &mut self.gpu_training,
1603                    &mut self.cuda_blocks,
1604                    &mut self.nf4_shared_scratch,
1605                    layer_idx,
1606                    seq_len,
1607                    stream,
1608                )?;
1609            }
1610
1611            let (grad_output, grad_input) = unsafe {
1612                if grad_output_is_a {
1613                    (&*grad_a_ptr, &mut *grad_b_ptr)
1614                } else {
1615                    (&*grad_b_ptr, &mut *grad_a_ptr)
1616                }
1617            };
1618
1619            self.profiler.begin_layer();
1620            if use_nf4 {
1621                // ENT-263: NF4 backward — LoRA gradient computation
1622                // Uses backward_nf4() which computes gradients for LoRA weights and norms only.
1623                // output_scratch reuses grad_buf_a/b as temporary storage for recomputed forward.
1624                let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1625                    grad_b_ptr // grad_input is in b, use as output_scratch too (will be overwritten)
1626                } else {
1627                    grad_a_ptr
1628                };
1629                // We need a separate output_scratch. Reuse blocks_output as scratch since
1630                // it was already consumed for norm backward above.
1631                match self.cuda_blocks[layer_idx].backward_nf4(
1632                    &self.gpu_training.layer_inputs[layer_idx],
1633                    grad_output,
1634                    grad_input,
1635                    &mut self.gpu_training.blocks_output, // reuse as output_scratch
1636                    seq_len,
1637                    stream,
1638                    self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1639                    self.nf4_lora_grad_workspace
1640                        .as_mut()
1641                        .expect("NF4 requires LoRA grad workspace"),
1642                ) {
1643                    Ok(()) => {}
1644                    Err(e) => {
1645                        eprintln!(
1646                            "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1647                            layer_idx, e, seq_len, self.config.model_config.hidden_size
1648                        );
1649                        return None;
1650                    }
1651                }
1652
1653                // ENT-265: Clip LoRA gradients before optimizer step.
1654                // Without this, NF4 LoRA grads are unbounded — causes weight
1655                // divergence and embedding grad explosion (Run 7c: 26M at step 225).
1656                if let Some(max_norm) = max_grad_norm {
1657                    self.nf4_lora_grad_workspace
1658                        .as_mut()
1659                        .expect("NF4 requires LoRA grad ws")
1660                        .clip_gradients(max_norm, stream);
1661                }
1662
1663                // NF4 LoRA optimizer step — always runs, even during accumulation.
1664                //
1665                // BUG FIX (entrenar#264): Previously gated by `if !accumulate_only`.
1666                // Design: NF4 LoRA has ~6M params, so we scale lr by 1/accum_steps
1667                // for micro-batches instead of accumulating gradients.
1668                {
1669                    let step = self.gpu_training.step;
1670                    let effective_lr = if accumulate_only {
1671                        lr / self.config.accumulation_steps as f32
1672                    } else {
1673                        lr
1674                    };
1675                    if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1676                        let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1677                            &mut opt_states[layer_idx],
1678                            step,
1679                            effective_lr,
1680                            beta1,
1681                            beta2,
1682                            1e-8,
1683                            weight_decay,
1684                            stream,
1685                            self.nf4_lora_grad_workspace
1686                                .as_ref()
1687                                .expect("NF4 requires LoRA grad ws"),
1688                        );
1689                    }
1690                }
1691            } else {
1692                // Standard fp32 backward path
1693                self.cuda_blocks[layer_idx]
1694                    .backward(
1695                        &self.gpu_training.layer_inputs[layer_idx],
1696                        grad_output,
1697                        grad_input,
1698                        seq_len,
1699                        stream,
1700                        &mut self.cuda_grad_workspace,
1701                    )
1702                    .ok()?;
1703
1704                // C-CLIP-001 / entrenar#312: DISABLED per-block gradient clipping.
1705                // Per-block clipping distorts gradient flow across layers.
1706
1707                // C-BACKPARITY-001: Per-block gradient norm tracing for parity testing.
1708                // Only runs when ENTRENAR_TRACE_GRADIENTS=1 — zero overhead in production.
1709                if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1710                    let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1711                        &self.cuda_grad_workspace,
1712                        f32::MAX,
1713                        stream,
1714                    );
1715                    // Also trace the activation gradient (flows between blocks)
1716                    let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1717                        .unwrap_or(0.0);
1718                    let act_gnorm = act_sq.sqrt();
1719                    eprintln!(
1720                        "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1721                    );
1722                }
1723
1724                // R-038: Either accumulate workspace grads or run optimizer per-block.
1725                if accumulate_only {
1726                    // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1727                    if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1728                        let _ = gpu_accum.accumulate_block(
1729                            &self.cuda_grad_workspace,
1730                            layer_idx,
1731                            stream,
1732                        );
1733                    } else {
1734                        // CPU fallback: SYNC + D2H (ALB-065 / Rule 6).
1735                        stream.synchronize().ok()?;
1736                        if let Some(accum) = &mut self.grad_accum {
1737                            Self::download_workspace_to_accum(
1738                                &self.cuda_grad_workspace,
1739                                accum,
1740                                layer_idx,
1741                                &mut self.d2h_staging,
1742                            )?;
1743                        }
1744                    }
1745                } else {
1746                    // Per-block optimizer step: consume workspace gradients before next block overwrites
1747                    let step = self.gpu_training.step;
1748                    let _ = self.cuda_blocks[layer_idx].optimizer_step(
1749                        &mut self.gpu_training.optimizer_states[layer_idx],
1750                        step,
1751                        lr,
1752                        beta1,
1753                        beta2,
1754                        1e-8,
1755                        weight_decay,
1756                        stream,
1757                        &self.cuda_grad_workspace,
1758                    );
1759                }
1760            }
1761
1762            self.profiler.end_layer_bwd(layer_idx);
1763            grad_output_is_a = !grad_output_is_a;
1764        }
1765
1766        stream.synchronize().ok()?;
1767        self.profiler.end(StepProfiler::BLK_BWD);
1768
1769        Some(grad_output_is_a)
1770    }
1771
1772    /// R-038: Download non-block (LM head + final norm) gradients to CPU accumulator.
1773    /// Static method to avoid borrow conflicts.
1774    // KAIZEN-044: Pre-allocate single buffer for LM head + norm D2H downloads.
1775    // lm_head_grad is vocab×hidden (389M elements = 1.5 GB for Qwen3-4B).
1776    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1777    fn download_nonblock_grads_to_accum(
1778        lm_head_grad: &GpuBuffer<f32>,
1779        final_norm_grad: &GpuBuffer<f32>,
1780        grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1781        host: &mut [f32],
1782    ) -> Option<()> {
1783        let accum = grad_accum.as_mut()?;
1784
1785        let lm_slice = &mut host[..lm_head_grad.len()];
1786        lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1787        for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1788            *d += s;
1789        }
1790
1791        let norm_slice = &mut host[..final_norm_grad.len()];
1792        final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1793        for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1794            *d += s;
1795        }
1796        Some(())
1797    }
1798
1799    /// Run LM head + final norm optimizer step (non-accumulating path).
1800    /// Static method to avoid borrow conflicts with `stream`.
1801    #[allow(clippy::too_many_arguments)]
1802    fn run_nonblock_optimizer_step(
1803        gpu_training: &mut GpuPretrainState,
1804        lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1805        lm_head_grad_gpu: &GpuBuffer<f32>,
1806        lm_head_m: &mut GpuBuffer<f32>,
1807        lm_head_v: &mut GpuBuffer<f32>,
1808        final_norm_m: &mut GpuBuffer<f32>,
1809        final_norm_v: &mut GpuBuffer<f32>,
1810        lr: f32,
1811        beta1: f32,
1812        beta2: f32,
1813        weight_decay: f32,
1814        stream: &CudaStream,
1815    ) {
1816        gpu_training.step += 1;
1817        let step = gpu_training.step;
1818
1819        if let Some(lm_head_weight) = lm_head_weight_gpu {
1820            let n_lm = lm_head_weight.len() as u32;
1821            let _ = adamw_step_cuda(
1822                lm_head_weight,
1823                lm_head_grad_gpu,
1824                lm_head_m,
1825                lm_head_v,
1826                lr,
1827                beta1,
1828                beta2,
1829                1e-8,
1830                weight_decay,
1831                step,
1832                n_lm,
1833                stream,
1834            );
1835        }
1836
1837        let n_norm = gpu_training.final_norm_weight.len() as u32;
1838        let _ = adamw_step_cuda(
1839            &mut gpu_training.final_norm_weight,
1840            &gpu_training.grad_final_norm_weight,
1841            final_norm_m,
1842            final_norm_v,
1843            lr,
1844            beta1,
1845            beta2,
1846            1e-8,
1847            weight_decay,
1848            step,
1849            n_norm,
1850            stream,
1851        );
1852    }
1853
1854    /// R-038: Download shared CudaGradWorkspace to CPU per-block accumulation buffers.
1855    ///
1856    /// Static method to avoid borrow conflicts with `stream` (same pattern as
1857    /// `recompute_segment`). Must be called after stream.synchronize() (ALB-065 / Rule 6).
1858    // KAIZEN-044: Pre-allocate a single host buffer for all D2H downloads
1859    // in download_workspace_to_accum. Was allocating vec![0.0f32; len] × 9 buffers.
1860    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1861    fn download_workspace_to_accum(
1862        ws: &CudaGradWorkspace,
1863        accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1864        layer_idx: usize,
1865        host: &mut [f32],
1866    ) -> Option<()> {
1867        let bg = &mut accum.block_grads[layer_idx];
1868
1869        use super::grad_accumulator::component;
1870        let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1871            (&ws.grad_w_q, component::W_Q),
1872            (&ws.grad_w_k, component::W_K),
1873            (&ws.grad_w_v, component::W_V),
1874            (&ws.grad_w_o, component::W_O),
1875            (&ws.grad_gate, component::GATE),
1876            (&ws.grad_up, component::UP),
1877            (&ws.grad_down, component::DOWN),
1878            (&ws.grad_input_norm, component::INPUT_NORM),
1879            (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1880        ];
1881
1882        for (gpu_buf, comp_idx) in &bufs_and_components {
1883            let slice = &mut host[..gpu_buf.len()];
1884            gpu_buf.copy_to_host_at(slice, 0).ok()?;
1885            for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1886                *d += s;
1887            }
1888        }
1889        Some(())
1890    }
1891
1892    /// R-038: Upload averaged CPU accumulation buffers to GPU workspace and run
1893    /// optimizer step for all blocks + LM head + final norm.
1894    ///
1895    /// Called once after `accumulation_steps` micro-batches have been accumulated.
1896    /// ALB-091: Run optimizer step from GPU-resident accumulated gradients.
1897    /// D2D copy accum → workspace, then run per-block optimizer. Zero accum after.
1898    fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1899        let stream = self.cuda_trainer.stream();
1900        let lr = self.current_lr();
1901        let beta1 = self.config.beta1;
1902        let beta2 = self.config.beta2;
1903        let weight_decay = self.config.weight_decay;
1904
1905        // Sync once to ensure all accumulation kernels complete
1906        stream.synchronize().ok()?;
1907
1908        self.gpu_training.step += 1;
1909        let step = self.gpu_training.step;
1910
1911        // Upload GPU accum → workspace (D2D) and run optimizer per block
1912        let gpu_accum = self.gpu_grad_accum.as_ref()?;
1913        for layer_idx in 0..self.cuda_blocks.len() {
1914            gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1915
1916            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1917                &mut self.gpu_training.optimizer_states[layer_idx],
1918                step,
1919                lr,
1920                beta1,
1921                beta2,
1922                1e-8,
1923                weight_decay,
1924                stream,
1925                &self.cuda_grad_workspace,
1926            );
1927        }
1928
1929        // LM head: D2D copy accum → grad buffer, then optimizer step
1930        gpu_accum
1931            .upload_nonblock(
1932                &mut self.lm_head_grad_gpu,
1933                &mut self.gpu_training.grad_final_norm_weight,
1934            )
1935            .ok()?;
1936
1937        let n_lm = self.lm_head_weight_gpu.len() as u32;
1938        let _ = adamw_step_cuda(
1939            &mut self.lm_head_weight_gpu,
1940            &self.lm_head_grad_gpu,
1941            &mut self.lm_head_m,
1942            &mut self.lm_head_v,
1943            lr,
1944            beta1,
1945            beta2,
1946            1e-8,
1947            weight_decay,
1948            step,
1949            n_lm,
1950            stream,
1951        );
1952
1953        // Final norm optimizer step
1954        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1955        let _ = adamw_step_cuda(
1956            &mut self.gpu_training.final_norm_weight,
1957            &self.gpu_training.grad_final_norm_weight,
1958            &mut self.final_norm_m,
1959            &mut self.final_norm_v,
1960            lr,
1961            beta1,
1962            beta2,
1963            1e-8,
1964            weight_decay,
1965            step,
1966            n_norm,
1967            stream,
1968        );
1969
1970        stream.synchronize().ok()?;
1971
1972        // Zero accum for next window
1973        if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1974            let _ = gpu_accum.zero_all();
1975        }
1976
1977        Some(())
1978    }
1979
1980    #[allow(unsafe_code)]
1981    fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1982        let stream = self.cuda_trainer.stream();
1983        let lr = self.current_lr();
1984        let beta1 = self.config.beta1;
1985        let beta2 = self.config.beta2;
1986        let weight_decay = self.config.weight_decay;
1987
1988        // Average accumulated gradients
1989        let accum = self.grad_accum.as_mut()?;
1990        accum.average();
1991
1992        // Jidoka: check for NaN/Inf before applying
1993        if accum.has_non_finite() {
1994            println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1995            accum.zero_all();
1996            return Some(());
1997        }
1998
1999        self.gpu_training.step += 1;
2000        let step = self.gpu_training.step;
2001
2002        // Upload accumulated gradients and run optimizer for each block
2003        use super::grad_accumulator::component;
2004        for layer_idx in 0..self.cuda_blocks.len() {
2005            let bg = &accum.block_grads[layer_idx];
2006
2007            // Upload accumulated gradients to shared workspace
2008            // SAFETY: async host-to-device copies within the training stream; host buffers
2009            // (bg.components) are stable for the duration of the stream operations.
2010            unsafe {
2011                self.cuda_grad_workspace
2012                    .grad_w_q
2013                    .copy_from_host_async(&bg.components[component::W_Q], stream)
2014                    .ok()?;
2015                self.cuda_grad_workspace
2016                    .grad_w_k
2017                    .copy_from_host_async(&bg.components[component::W_K], stream)
2018                    .ok()?;
2019                self.cuda_grad_workspace
2020                    .grad_w_v
2021                    .copy_from_host_async(&bg.components[component::W_V], stream)
2022                    .ok()?;
2023                self.cuda_grad_workspace
2024                    .grad_w_o
2025                    .copy_from_host_async(&bg.components[component::W_O], stream)
2026                    .ok()?;
2027                self.cuda_grad_workspace
2028                    .grad_gate
2029                    .copy_from_host_async(&bg.components[component::GATE], stream)
2030                    .ok()?;
2031                self.cuda_grad_workspace
2032                    .grad_up
2033                    .copy_from_host_async(&bg.components[component::UP], stream)
2034                    .ok()?;
2035                self.cuda_grad_workspace
2036                    .grad_down
2037                    .copy_from_host_async(&bg.components[component::DOWN], stream)
2038                    .ok()?;
2039                self.cuda_grad_workspace
2040                    .grad_input_norm
2041                    .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
2042                    .ok()?;
2043                self.cuda_grad_workspace
2044                    .grad_post_attn_norm
2045                    .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
2046                    .ok()?;
2047            }
2048
2049            // Run optimizer step with uploaded averaged gradients
2050            let _ = self.cuda_blocks[layer_idx].optimizer_step(
2051                &mut self.gpu_training.optimizer_states[layer_idx],
2052                step,
2053                lr,
2054                beta1,
2055                beta2,
2056                1e-8,
2057                weight_decay,
2058                stream,
2059                &self.cuda_grad_workspace,
2060            );
2061        }
2062
2063        // Upload accumulated LM head gradients and run AdamW step
2064        // entrenar#314: Skip GPU LM head optimizer for tied weights.
2065        // SAFETY: async host-to-device copy; host buffer (accum.lm_head_grad) is stable.
2066        unsafe {
2067            self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
2068        }
2069        let n_lm = self.lm_head_weight_gpu.len() as u32;
2070        let _ = adamw_step_cuda(
2071            &mut self.lm_head_weight_gpu,
2072            &self.lm_head_grad_gpu,
2073            &mut self.lm_head_m,
2074            &mut self.lm_head_v,
2075            lr,
2076            beta1,
2077            beta2,
2078            1e-8,
2079            weight_decay,
2080            step,
2081            n_lm,
2082            stream,
2083        );
2084
2085        // Upload accumulated final norm gradients and run AdamW step
2086        // SAFETY: async host-to-device copy; host buffer (accum.final_norm_grad) is stable.
2087        unsafe {
2088            self.gpu_training
2089                .grad_final_norm_weight
2090                .copy_from_host_async(&accum.final_norm_grad, stream)
2091                .ok()?;
2092        }
2093        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
2094        let _ = adamw_step_cuda(
2095            &mut self.gpu_training.final_norm_weight,
2096            &self.gpu_training.grad_final_norm_weight,
2097            &mut self.final_norm_m,
2098            &mut self.final_norm_v,
2099            lr,
2100            beta1,
2101            beta2,
2102            1e-8,
2103            weight_decay,
2104            step,
2105            n_norm,
2106            stream,
2107        );
2108
2109        stream.synchronize().ok()?;
2110
2111        // Zero accum for next window
2112        accum.zero_all();
2113        Some(())
2114    }
2115
2116    /// Compute gradient L2 norm via GPU reduction kernel (KAIZEN-049).
2117    ///
2118    /// Runs `SquaredSumKernel` on GPU, downloads only `num_blocks` partial sums (~1KB)
2119    /// instead of the full buffer (128MB for lm_head). Falls back to CPU download on error.
2120    ///
2121    /// # Contract (C-CLIPNORM-GPU-001)
2122    ///
2123    /// - **Precondition**: `buf.len() > 0`, stream is synchronized with prior kernel
2124    /// - **Postcondition**: `grad_norm ≈ sqrt(sum(buf[i]^2))`, `scale = min(1, max_norm/norm)`
2125    /// - **Transfer**: ~1KB D2H (num_blocks × 4B) vs n×4B (128MB for 32M elements)
2126    ///
2127    /// R-004: Returns `(clip_scale, grad_norm)` for observability.
2128    fn compute_clip_scale_with_norm(
2129        buf: &GpuBuffer<f32>,
2130        max_norm: f32,
2131        stream: &CudaStream,
2132    ) -> (f32, f32) {
2133        let n = buf.len() as u32;
2134        // Try GPU reduction first — ~1KB D2H instead of n×4 bytes
2135        let grad_norm = match squared_sum_cuda(buf, n, stream) {
2136            Ok(norm) => norm,
2137            Err(_) => {
2138                // Fallback: full D2H (original path)
2139                let mut host = vec![0.0f32; buf.len()];
2140                if buf.copy_to_host_at(&mut host, 0).is_err() {
2141                    return (1.0, 0.0);
2142                }
2143                let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2144                sq_sum.sqrt() as f32
2145            }
2146        };
2147        let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2148        (scale, grad_norm)
2149    }
2150
2151    /// Download embedding gradient from GPU, clip, and scatter-add into CPU weight.
2152    ///
2153    /// # Contract (C-EMBED-GRAD-001)
2154    ///
2155    /// The activation gradient from block[0]'s backward is unclipped (per-block clipping
2156    /// only applies to weight gradients in the shared workspace). For deep networks with
2157    /// random init, this gradient can overflow f32, producing NaN in the CPU AdamW.
2158    /// We clip the activation gradient to max_grad_norm before scatter-adding.
2159    #[allow(unsafe_code)]
2160    fn embed_backward(
2161        &mut self,
2162        input_ids: &[u32],
2163        _seq_len: usize,
2164        hidden_size: usize,
2165        vocab_size: usize,
2166        grad_output_is_a: bool,
2167    ) -> Option<()> {
2168        // The final backward output is in whichever buffer was last written
2169        let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2170        let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2171        let embed_grad_buf = unsafe {
2172            if grad_output_is_a {
2173                &*grad_a_ptr
2174            } else {
2175                &*grad_b_ptr
2176            }
2177        };
2178        let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2179
2180        // C-EMBED-GRAD-001: ALWAYS clip activation gradient before scatter-add.
2181        // Without this, 24-layer random-init backward amplifies gradients to ~1e35,
2182        // which overflows the CPU AdamW's second moment buffer.
2183        //
2184        // ALB-071: Decoupled from general grad_clip config. Embed activation gradient
2185        // clipping is a SAFETY constraint (prevents NaN), not a training hyperparameter.
2186        // Uses dedicated max_embed_grad_norm (default 1.0) independent of weight grad_clip.
2187        let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2188        {
2189            let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2190            let grad_norm = sq_sum.sqrt() as f32;
2191            self.last_embed_grad_norm = grad_norm; // R-040: per-parameter-group tracking
2192            if grad_norm > embed_clip_norm {
2193                let scale = embed_clip_norm / grad_norm;
2194                for g in &mut embed_grad_data {
2195                    *g *= scale;
2196                }
2197            }
2198        }
2199
2200        // KAIZEN-048: In-place scatter-add via grad_cell().borrow_mut().
2201        // Before: 3 × 128MB clones per step (grad() deep-copies Array1).
2202        // After: zero clones — mutate existing gradient buffer directly.
2203        let embed_weight = &mut self.model.embed_tokens.weight;
2204        let grad_cell = embed_weight.grad_cell();
2205        let mut grad_ref = grad_cell.borrow_mut();
2206        if grad_ref.is_none() {
2207            *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2208        }
2209        if let Some(grad) = grad_ref.as_mut() {
2210            for (pos, &token_id) in input_ids.iter().enumerate() {
2211                let tid = token_id as usize;
2212                if tid < vocab_size {
2213                    let src = pos * hidden_size;
2214                    let dst = tid * hidden_size;
2215                    for h in 0..hidden_size {
2216                        grad[dst + h] += embed_grad_data[src + h];
2217                    }
2218                }
2219            }
2220        }
2221        Some(())
2222    }
2223
2224    /// Apply optimizer step to CPU embedding and update metrics.
2225    ///
2226    /// GPU block optimizer steps now run interleaved with backward in `gpu_backward()`.
2227    /// LM head and final norm optimizer steps also run in `gpu_backward()`.
2228    /// This method handles only CPU embedding and bookkeeping.
2229    fn optimizer_step(&mut self) {
2230        // ALB-072: Gradients are no longer scaled by grad_scaler (loss_scale excludes
2231        // grad_scaler.scale()). All backward computation uses f32 — no fp16 underflow
2232        // risk. Skip unscaling; just update scaler as successful.
2233        self.grad_scaler.update(true);
2234
2235        // ALB-079: Sync CPU embedding optimizer lr with cosine schedule
2236        self.embed_optimizer.set_lr(self.current_lr());
2237        // CPU optimizer step for embedding weight
2238        let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2239        self.embed_optimizer.step_refs(&mut embed_params);
2240
2241        self.step += 1;
2242        self.metrics.losses.push(self.accumulated_loss);
2243        self.metrics.increment_step();
2244
2245        self.accumulated_loss = 0.0;
2246        self.accumulated_batches = 0;
2247    }
2248
2249    /// Process a batch (forward + backward + optimizer step with accumulation).
2250    ///
2251    /// R-038: When `accumulation_steps > 1`, runs forward+backward without optimizer
2252    /// for each micro-batch, downloading per-block weight gradients to CPU-side
2253    /// `PerBlockGradientAccumulator`. After `accumulation_steps` batches, averages
2254    /// the accumulated gradients, uploads them to GPU, and runs a single optimizer step.
2255    ///
2256    /// When `accumulation_steps == 1` (default), runs forward+backward+optimizer
2257    /// immediately per sequence (original behavior).
2258    ///
2259    /// Returns average loss for the batch.
2260    pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2261        if batch.batch_size == 0 {
2262            return 0.0;
2263        }
2264
2265        let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2266
2267        if self.accumulated_batches == 0 {
2268            // Zero embedding gradients at start of accumulation window
2269            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2270        }
2271
2272        let mut total_loss = 0.0;
2273        let mut valid_count = 0;
2274
2275        for i in 0..batch.batch_size {
2276            let Some(input_ids) = batch.get_input(i) else {
2277                continue;
2278            };
2279            let Some(target_ids) = batch.get_target(i) else {
2280                continue;
2281            };
2282
2283            // R-038: When accumulating, run backward without optimizer (accumulate_only=true).
2284            // Gradients are downloaded to CPU per-block accum buffers. Embedding grads are
2285            // scatter-added normally (they're already on CPU).
2286            if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2287                total_loss += loss;
2288                valid_count += 1;
2289                if accumulating {
2290                    if let Some(accum) = &mut self.gpu_grad_accum {
2291                        accum.accumulated_count += 1;
2292                    } else if let Some(accum) = &mut self.grad_accum {
2293                        accum.accumulated_count += 1;
2294                    }
2295                }
2296            }
2297        }
2298
2299        let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2300
2301        // Debug: help diagnose loss=0.0 when gradients are non-zero
2302        if avg_loss == 0.0 && valid_count > 0 {
2303            eprintln!(
2304                "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2305                valid_count, total_loss, batch.batch_size
2306            );
2307        }
2308
2309        self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2310        self.accumulated_batches += 1;
2311
2312        if self.accumulated_batches >= self.config.accumulation_steps {
2313            if accumulating {
2314                // ALB-091: Prefer GPU-resident accum path (zero D2H), fall back to CPU.
2315                if self.gpu_grad_accum.is_some() {
2316                    self.gpu_optimizer_from_gpu_accum();
2317                } else {
2318                    self.gpu_optimizer_from_accum();
2319                }
2320            }
2321            self.optimizer_step();
2322        }
2323
2324        avg_loss
2325    }
2326
2327    /// R-005: Evaluate a batch without backward pass or weight updates.
2328    /// Returns average cross-entropy loss, or 0.0 if no valid items.
2329    /// KAIZEN-050: Uses fused GPU cross-entropy (no logits D2H).
2330    pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2331        let hidden_size = self.config.model_config.hidden_size;
2332        let vocab_size = self.config.model_config.vocab_size;
2333        let max_sl = self.config.max_seq_len;
2334        let mut total_loss = 0.0;
2335        let mut valid_count = 0;
2336        for i in 0..batch.batch_size {
2337            if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2338            {
2339                total_loss += loss;
2340                valid_count += 1;
2341            }
2342        }
2343        if valid_count > 0 {
2344            total_loss / valid_count as f32
2345        } else {
2346            0.0
2347        }
2348    }
2349
2350    /// Evaluate a single sequence from a batch. Returns None if invalid.
2351    fn eval_single_sequence(
2352        &mut self,
2353        batch: &LMBatch,
2354        i: usize,
2355        max_sl: usize,
2356        hidden_size: usize,
2357        vocab_size: usize,
2358    ) -> Option<f32> {
2359        let input_ids = batch.get_input(i)?;
2360        let target_ids = batch.get_target(i)?;
2361        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
2362        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2363        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2364        let seq_len = input_ids.len();
2365        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2366        let stream = self.cuda_trainer.stream();
2367        let scale = 1.0 / seq_len as f32;
2368        let loss = fused_cross_entropy_cuda(
2369            &mut self.gpu_training.logits_buf,
2370            target_ids,
2371            seq_len as u32,
2372            vocab_size as u32,
2373            scale,
2374            stream,
2375        )
2376        .ok()?;
2377        if loss.is_finite() {
2378            Some(loss)
2379        } else {
2380            None
2381        }
2382    }
2383
2384    /// Train for one epoch over batches.
2385    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2386        self.train_epoch_with_callback(batches, |_, _, _| {})
2387    }
2388
2389    /// Train for one epoch with a per-step callback.
2390    ///
2391    /// Stops early if `max_steps` is set and reached.
2392    pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2393    where
2394        F: FnMut(usize, f32, &Self),
2395    {
2396        if batches.is_empty() {
2397            return 0.0;
2398        }
2399
2400        let mut total_loss = 0.0;
2401        let mut batches_processed = 0;
2402
2403        for (i, batch) in batches.iter().enumerate() {
2404            if let Some(max) = self.config.max_steps {
2405                if self.step >= max {
2406                    break;
2407                }
2408            }
2409
2410            let batch_loss = self.train_batch(batch);
2411            total_loss += batch_loss;
2412            batches_processed += 1;
2413            on_batch(i, batch_loss, self);
2414        }
2415
2416        // KAIZEN-047: Print profiler summary at end of epoch
2417        if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2418            self.profiler.print_report();
2419        }
2420
2421        total_loss / batches_processed.max(1) as f32
2422    }
2423
2424    // --- DDP (data-parallel) support methods ---
2425
2426    /// Ensure the per-block gradient accumulator exists.
2427    ///
2428    /// For DDP, we always need accumulation buffers (even with accumulation_steps=1)
2429    /// because gradients must be downloaded to CPU for AllReduce before optimizer step.
2430    pub(crate) fn ensure_grad_accum(&mut self) {
2431        if self.grad_accum.is_some() {
2432            return;
2433        }
2434        let mc = &self.config.model_config;
2435        let hidden_size = mc.hidden_size;
2436        let kv_hidden = mc.num_kv_heads * mc.head_dim();
2437        let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2438            hidden_size,
2439            kv_hidden,
2440            mc.intermediate_size,
2441        );
2442        self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2443            self.cuda_blocks.len(),
2444            block_sizes,
2445            mc.vocab_size,
2446            hidden_size,
2447        ));
2448    }
2449
2450    /// Forward + backward for one batch, always accumulating (no optimizer step).
2451    ///
2452    /// Used by `DistributedCudaTrainer` to compute local gradients before AllReduce.
2453    /// Returns average loss for the batch.
2454    pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2455        if batch.batch_size == 0 {
2456            return 0.0;
2457        }
2458
2459        if self.accumulated_batches == 0 {
2460            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2461        }
2462
2463        let mut total_loss = 0.0;
2464        let mut valid_count = 0;
2465
2466        for i in 0..batch.batch_size {
2467            let Some(input_ids) = batch.get_input(i) else { continue };
2468            let Some(target_ids) = batch.get_target(i) else { continue };
2469
2470            // Always accumulate_only=true: gradients go to CPU accum buffers
2471            if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2472                total_loss += loss;
2473                valid_count += 1;
2474                if let Some(accum) = &mut self.grad_accum {
2475                    accum.accumulated_count += 1;
2476                }
2477            }
2478        }
2479
2480        if valid_count > 0 {
2481            total_loss / valid_count as f32
2482        } else {
2483            0.0
2484        }
2485    }
2486
2487    /// Apply DDP-averaged gradients: upload to GPU and run optimizer step.
2488    ///
2489    /// Called after AllReduce has written averaged gradients into the grad_accum.
2490    /// Runs gpu_optimizer_from_accum() for blocks + LM head + final norm,
2491    /// then optimizer_step() for embedding.
2492    pub(crate) fn apply_ddp_gradients(&mut self) {
2493        self.accumulated_loss = 0.0;
2494        self.accumulated_batches = 0;
2495        self.gpu_optimizer_from_accum();
2496        self.optimizer_step();
2497    }
2498
2499    /// Get a reference to the gradient accumulator (for DDP AllReduce).
2500    pub(crate) fn grad_accum_ref(
2501        &self,
2502    ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2503        self.grad_accum.as_ref()
2504    }
2505
2506    /// Get a mutable reference to the gradient accumulator (for DDP AllReduce).
2507    pub(crate) fn grad_accum_mut(
2508        &mut self,
2509    ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2510        self.grad_accum.as_mut()
2511    }
2512
2513    /// Get the training config.
2514    pub(crate) fn config(&self) -> &TransformerTrainConfig {
2515        &self.config
2516    }
2517
2518    /// Get CPU embedding gradient as flat Vec for AllReduce.
2519    pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2520        self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2521    }
2522
2523    /// Set CPU embedding gradient from AllReduced flat Vec.
2524    pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2525        self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2526    }
2527
2528    /// Returns true if max_steps has been reached.
2529    pub fn reached_max_steps(&self) -> bool {
2530        self.config.max_steps.is_some_and(|max| self.step >= max)
2531    }
2532
2533    /// Get current step count.
2534    pub fn step(&self) -> usize {
2535        self.step
2536    }
2537
2538    /// Set initial step for resume from checkpoint.
2539    ///
2540    /// Updates both the outer step counter (LR schedule, logging) and the
2541    /// GPU-side AdamW step counter (bias correction). Must be called before
2542    /// any `train_batch()` calls.
2543    pub fn set_initial_step(&mut self, step: usize) {
2544        self.step = step;
2545        self.gpu_training.step = step as u32;
2546    }
2547
2548    /// Set max_steps for cosine LR scheduler (ENT-275).
2549    ///
2550    /// Called by `train_loop_cuda` when `max_steps` is not explicitly set in
2551    /// the YAML config — auto-computes `epochs × batches_per_epoch` so cosine
2552    /// decay activates instead of falling back to constant lr.
2553    pub fn set_max_steps(&mut self, max_steps: usize) {
2554        self.config.max_steps = Some(max_steps);
2555    }
2556
2557    /// Get current learning rate (warmup + cosine decay).
2558    ///
2559    /// ALB-079: Phase 1 = linear warmup (0 → lr_max), Phase 2 = cosine decay
2560    /// (lr_max → 0) over remaining steps. Requires `max_steps` for decay;
2561    /// without it, falls back to constant lr after warmup.
2562    pub fn current_lr(&self) -> f32 {
2563        let base_lr = self.config.lr;
2564        if self.step < self.config.warmup_steps {
2565            // Phase 1: Linear warmup
2566            base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2567        } else if let Some(max_steps) = self.config.max_steps {
2568            // Phase 2: Cosine decay from lr_max to 0
2569            let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2570            if decay_steps == 0 {
2571                return base_lr;
2572            }
2573            let decay_step = self.step - self.config.warmup_steps;
2574            let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2575            0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2576        } else {
2577            // No max_steps: constant lr (legacy behavior)
2578            base_lr
2579        }
2580    }
2581
2582    /// KAIZEN-047: Enable step profiling with a report every `interval` steps.
2583    ///
2584    /// When enabled, prints a table of wall-clock timings per training phase
2585    /// every `interval` training steps. Use interval=0 for manual-only reporting.
2586    ///
2587    /// # Contract (C-STEPPROF-001)
2588    ///
2589    /// - No additional GPU synchronization points (relies on existing syncs)
2590    /// - Overhead: ~11 `Instant::now()` calls per step (~1µs total on Linux)
2591    /// - Timings include async dispatch overhead (not pure kernel time)
2592    pub fn enable_profiler(&mut self, interval: usize) {
2593        self.profiler = StepProfiler::new(true, interval);
2594    }
2595
2596    /// Print the profiler report (if profiling is enabled).
2597    pub fn print_profiler_report(&self) {
2598        self.profiler.print_report();
2599    }
2600
2601    /// R-004: Get last observed gradient L2 norm (LM head proxy).
2602    pub fn last_grad_norm(&self) -> f32 {
2603        self.last_grad_norm
2604    }
2605
2606    /// R-040: Get per-parameter-group gradient norms.
2607    /// Returns (lm_head_grad_norm, embed_grad_norm).
2608    pub fn param_grad_norms(&self) -> (f32, f32) {
2609        (self.last_grad_norm, self.last_embed_grad_norm)
2610    }
2611
2612    /// R-012: Get total trainable parameter count for MFU calculation.
2613    pub fn num_params(&self) -> usize {
2614        self.model.parameters().iter().map(|t| t.len()).sum()
2615    }
2616
2617    /// R-013: Query GPU memory usage (used_mb, total_mb).
2618    pub fn gpu_memory_mb(&self) -> (u64, u64) {
2619        match self.cuda_trainer.context().memory_info() {
2620            Ok((free, total)) => {
2621                let total_mb = (total / (1024 * 1024)) as u64;
2622                let used_mb = ((total - free) / (1024 * 1024)) as u64;
2623                (used_mb, total_mb)
2624            }
2625            Err(_) => (0, 0),
2626        }
2627    }
2628
2629    /// Sync all GPU weights back to CPU model.
2630    ///
2631    /// # Contract (C-SYNCWT-001)
2632    ///
2633    /// Must be called before save or any CPU model access after training.
2634    pub fn sync_weights_to_cpu(&mut self) {
2635        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2636
2637        if use_nf4 {
2638            // ENT-263: NF4 blocks are frozen — base weights don't change.
2639            // Only download LoRA adapter weights for checkpoint saving.
2640            // The base model on CPU stays as-is (original pretrained weights).
2641            // LoRA weights are saved separately (adapter_config.json + adapter.safetensors).
2642            // For now, skip per-layer sync — base weights are unchanged.
2643        } else {
2644            for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2645                if let Ok(weights) = block.download_weights() {
2646                    let layer = &mut self.model.layers[layer_idx];
2647
2648                    layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2649                    layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2650                    layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2651                    layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2652
2653                    layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2654                    layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2655                    layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2656
2657                    layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2658                    layer.post_attn_norm.weight =
2659                        Tensor::from_vec(weights.post_attn_norm_weight, false);
2660                }
2661            }
2662        }
2663
2664        // Sync final norm weight
2665        if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2666            self.model.norm.weight = Tensor::from_vec(norm_data, false);
2667        }
2668
2669        // Sync LM head weight
2670        // ALB-097: ALWAYS save GPU-trained LM head, even for tied-weight models.
2671        // During GPU training, lm_head diverges from embed_tokens because they have
2672        // separate optimizers (GPU AdamW vs CPU AdamW). If we skip the sync for tied
2673        // weights, the checkpoint loses 500+ steps of GPU LM head training → random-init
2674        // loss on resume (Five Whys root cause of ALB-097).
2675        if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2676            self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2677        }
2678    }
2679
2680    /// Get reference to model (syncs weights first).
2681    pub fn model(&self) -> &Transformer {
2682        &self.model
2683    }
2684
2685    /// Get mutable reference to model.
2686    pub fn model_mut(&mut self) -> &mut Transformer {
2687        &mut self.model
2688    }
2689
2690    /// Check if using mixed precision.
2691    pub fn is_mixed_precision(&self) -> bool {
2692        self.config.precision_config.is_mixed()
2693    }
2694
2695    /// Get the gradient scaler (R-002: loss scaling for mixed precision).
2696    pub fn grad_scaler(&self) -> &GradScaler {
2697        &self.grad_scaler
2698    }
2699
2700    /// Check if using gradient checkpointing.
2701    pub fn is_checkpointing(&self) -> bool {
2702        self.config.checkpoint_config.enabled
2703    }
2704
2705    /// Save model weights (syncs GPU→CPU first).
2706    pub fn save(
2707        &mut self,
2708        path: impl AsRef<std::path::Path>,
2709        name: &str,
2710        architecture: &str,
2711    ) -> crate::Result<()> {
2712        self.sync_weights_to_cpu();
2713
2714        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2715        let params: Vec<(String, Tensor)> = self
2716            .model
2717            .named_parameters()
2718            .into_iter()
2719            .map(|(name, tensor)| (name, tensor.clone()))
2720            .collect();
2721
2722        let metadata = ModelMetadata::new(name, architecture);
2723        let model = Model::new(metadata, params);
2724        let config = SaveConfig::new(ModelFormat::SafeTensors);
2725
2726        save_model(&model, path, &config)
2727    }
2728
2729    /// R-011: Prepare checkpoint data for async save.
2730    /// Syncs GPU weights to CPU and snapshots tensor data as Send-able Vec<f32>.
2731    /// Returns a closure that writes the checkpoint file from another thread.
2732    pub fn prepare_async_save(
2733        &mut self,
2734        name: &str,
2735        architecture: &str,
2736    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2737        self.sync_weights_to_cpu();
2738
2739        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2740        let param_data: Vec<(String, Vec<f32>)> = self
2741            .model
2742            .named_parameters()
2743            .into_iter()
2744            .map(|(n, t)| (n, t.data().to_vec()))
2745            .collect();
2746
2747        let name = name.to_string();
2748        let architecture = architecture.to_string();
2749
2750        Box::new(move |path: &std::path::Path| {
2751            let params: Vec<(String, Tensor)> =
2752                param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2753            let metadata = ModelMetadata::new(&name, &architecture);
2754            let model = Model::new(metadata, params);
2755            let config = SaveConfig::new(ModelFormat::SafeTensors);
2756            save_model(&model, path, &config)
2757        })
2758    }
2759
2760    /// ALB-096: Save model weights as APR checkpoint (syncs GPU→CPU first).
2761    ///
2762    /// Single atomic file containing all model weights. Use `save_apr_checkpoint()`
2763    /// to include optimizer state and training metadata in the same file.
2764    pub fn save_apr(
2765        &mut self,
2766        path: impl AsRef<std::path::Path>,
2767        name: &str,
2768        architecture: &str,
2769    ) -> crate::Result<()> {
2770        self.save_apr_with_tokenizer(path, name, architecture, None)
2771    }
2772
2773    /// SPEC-SHIP-TWO-001 §81 P0-D + P0-E: save APR checkpoint with arch
2774    /// metadata keys AND optionally embed the source tokenizer.json.
2775    ///
2776    /// When `tokenizer_dir` is `Some`, reads `<dir>/tokenizer.json` and
2777    /// embeds the vocabulary + merges + BOS/EOS IDs as well-known
2778    /// metadata keys. This makes the resulting .apr file standalone for
2779    /// `apr qa`, `apr run`, etc. — no `--tokenizer` flag required at
2780    /// downstream tool dispatch.
2781    pub fn save_apr_with_tokenizer(
2782        &mut self,
2783        path: impl AsRef<std::path::Path>,
2784        name: &str,
2785        architecture: &str,
2786        tokenizer_dir: Option<&std::path::Path>,
2787    ) -> crate::Result<()> {
2788        self.sync_weights_to_cpu();
2789
2790        let params: Vec<(String, Tensor)> = self
2791            .model
2792            .named_parameters()
2793            .into_iter()
2794            .map(|(name, tensor)| (name, tensor.clone()))
2795            .collect();
2796
2797        // SPEC-SHIP-TWO-001 §81 P0-E: write individual arch metadata keys
2798        // so downstream tools (apr qa C-03, apr bench, realizar) can read them
2799        // via AprV2Metadata's typed fields. The legacy save_model() path only
2800        // carries `name + architecture + format + version` which fails C-03.
2801        use crate::io::save::infer_all_tensor_shapes;
2802        use aprender::serialization::apr::AprWriter;
2803        use serde_json::Value as Jv;
2804
2805        let mc = &self.config.model_config;
2806        let mut writer = AprWriter::new();
2807
2808        // Identity / version metadata (preserves save_model behavior)
2809        writer.set_metadata("model_name", Jv::String(name.to_string()));
2810        writer.set_metadata("architecture", Jv::String(architecture.to_string()));
2811        writer.set_metadata("version", Jv::String("0.1.0".into()));
2812        writer.set_metadata("format", Jv::String("entrenar-checkpoint".into()));
2813
2814        // Arch dim keys (well-known to AprWriter::build_v2_metadata,
2815        // map to AprV2Metadata typed fields).
2816        writer.set_metadata(
2817            "hidden_size",
2818            Jv::Number(serde_json::Number::from(mc.hidden_size as u64)),
2819        );
2820        writer.set_metadata(
2821            "num_hidden_layers",
2822            Jv::Number(serde_json::Number::from(mc.num_hidden_layers as u64)),
2823        );
2824        writer.set_metadata(
2825            "num_attention_heads",
2826            Jv::Number(serde_json::Number::from(mc.num_attention_heads as u64)),
2827        );
2828        writer.set_metadata(
2829            "num_kv_heads",
2830            Jv::Number(serde_json::Number::from(mc.num_kv_heads as u64)),
2831        );
2832        writer.set_metadata(
2833            "intermediate_size",
2834            Jv::Number(serde_json::Number::from(mc.intermediate_size as u64)),
2835        );
2836        writer
2837            .set_metadata("vocab_size", Jv::Number(serde_json::Number::from(mc.vocab_size as u64)));
2838        writer.set_metadata(
2839            "max_position_embeddings",
2840            Jv::Number(serde_json::Number::from(mc.max_position_embeddings as u64)),
2841        );
2842        if let Some(rope) = serde_json::Number::from_f64(mc.rope_theta as f64) {
2843            writer.set_metadata("rope_theta", Jv::Number(rope));
2844        }
2845        if let Some(eps) = serde_json::Number::from_f64(mc.rms_norm_eps as f64) {
2846            writer.set_metadata("rms_norm_eps", Jv::Number(eps));
2847        }
2848
2849        // SPEC-SHIP-TWO-001 §81 P0-D: embed tokenizer.json from
2850        // `tokenizer_dir/tokenizer.json` so `apr qa` (which requires
2851        // an embedded tokenizer) accepts the resulting .apr file.
2852        // ALB-130 style: parse vocab + merges + special token IDs and
2853        // set as well-known metadata keys.
2854        if let Some(dir) = tokenizer_dir {
2855            let tok_path = dir.join("tokenizer.json");
2856            if let Ok(json_bytes) = std::fs::read(&tok_path) {
2857                if let Ok(tok) = serde_json::from_slice::<Jv>(&json_bytes) {
2858                    if let Some(model) = tok.get("model") {
2859                        if let Some(vocab_obj) = model.get("vocab").and_then(|v| v.as_object()) {
2860                            let mut vocab_pairs: Vec<(String, u64)> = vocab_obj
2861                                .iter()
2862                                .filter_map(|(k, v)| Some((k.clone(), v.as_u64()?)))
2863                                .collect();
2864                            vocab_pairs.sort_by_key(|(_, id)| *id);
2865                            let vocab: Vec<Jv> =
2866                                vocab_pairs.into_iter().map(|(k, _)| Jv::String(k)).collect();
2867                            writer.set_metadata("tokenizer.vocabulary", Jv::Array(vocab));
2868                        }
2869                        if let Some(merges_arr) = model.get("merges").and_then(|m| m.as_array()) {
2870                            let merges: Vec<Jv> = merges_arr
2871                                .iter()
2872                                .filter_map(|v| v.as_str().map(|s| Jv::String(s.to_string())))
2873                                .collect();
2874                            writer.set_metadata("tokenizer.merges", Jv::Array(merges));
2875                        }
2876                    }
2877                    // BOS / EOS from added_tokens (HF format).
2878                    if let Some(added) = tok.get("added_tokens").and_then(|a| a.as_array()) {
2879                        for entry in added {
2880                            let content =
2881                                entry.get("content").and_then(|c| c.as_str()).unwrap_or("");
2882                            let id = entry.get("id").and_then(|i| i.as_u64());
2883                            if let Some(id) = id {
2884                                match content {
2885                                    "<s>" | "<|im_start|>" | "<|begin_of_text|>" => {
2886                                        writer.set_metadata(
2887                                            "tokenizer.bos_token_id",
2888                                            Jv::Number(serde_json::Number::from(id)),
2889                                        );
2890                                    }
2891                                    "</s>" | "<|im_end|>" | "<|end_of_text|>" | "<|endoftext|>" => {
2892                                        writer.set_metadata(
2893                                            "tokenizer.eos_token_id",
2894                                            Jv::Number(serde_json::Number::from(id)),
2895                                        );
2896                                    }
2897                                    _ => {}
2898                                }
2899                            }
2900                        }
2901                    }
2902                }
2903            }
2904        }
2905
2906        // Tensors — reuse io::save's shape inference for 2D weight handling.
2907        let shapes = infer_all_tensor_shapes(&params);
2908        for (tname, tensor) in &params {
2909            let data = tensor.data();
2910            let slice = data.as_slice().expect("tensor data must be contiguous");
2911            let shape = shapes.get(tname).cloned().unwrap_or_else(|| vec![tensor.len()]);
2912            writer.add_tensor_f32(tname, shape, slice);
2913        }
2914
2915        writer
2916            .write(path)
2917            .map_err(|e| crate::error::Error::Serialization(format!("APR write failed: {e}")))
2918    }
2919
2920    /// ALB-096: Prepare APR checkpoint data for async save.
2921    ///
2922    /// Syncs GPU weights to CPU and snapshots tensor data + optimizer state as
2923    /// Send-able `Vec<f32>`. Returns a closure that writes a single atomic APR
2924    /// file from another thread. Includes model weights + CPU embedding optimizer
2925    /// state + training metadata — all in one file.
2926    fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2927        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2928        if use_nf4 {
2929            let frozen_suffixes = [
2930                "q_proj.weight",
2931                "k_proj.weight",
2932                "v_proj.weight",
2933                "o_proj.weight",
2934                "gate_proj.weight",
2935                "up_proj.weight",
2936                "down_proj.weight",
2937            ];
2938            self.model
2939                .named_parameters()
2940                .into_iter()
2941                .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2942                .map(|(n, t)| (n, t.data().to_vec()))
2943                .collect()
2944        } else {
2945            self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2946        }
2947    }
2948
2949    fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2950        if self.config.quantize_nf4 && self.config.is_lora() {
2951            self.cuda_blocks
2952                .iter()
2953                .enumerate()
2954                .filter_map(|(i, block)| {
2955                    block
2956                        .download_lora_weights()
2957                        .ok()
2958                        .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2959                })
2960                .collect()
2961        } else {
2962            Vec::new()
2963        }
2964    }
2965
2966    pub fn prepare_async_apr_save(
2967        &mut self,
2968        name: &str,
2969        architecture: &str,
2970        step: usize,
2971        loss: f64,
2972        lr: f64,
2973    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2974        self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2975    }
2976
2977    /// ALB-130: Prepare APR checkpoint with embedded tokenizer for inference.
2978    ///
2979    /// Training checkpoints must be self-contained for eval (`apr eval --task humaneval`).
2980    /// Without embedded tokenizer, inference falls back to structural validation (fake 100%).
2981    /// The tokenizer path comes from `spec.data.tokenizer` in the training YAML.
2982    pub fn prepare_async_apr_save_with_tokenizer(
2983        &mut self,
2984        name: &str,
2985        architecture: &str,
2986        step: usize,
2987        loss: f64,
2988        lr: f64,
2989        tokenizer_path: Option<&std::path::Path>,
2990    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2991        self.sync_weights_to_cpu();
2992
2993        let param_data = self.snapshot_param_data();
2994        let lora_data = self.snapshot_lora_data();
2995
2996        // Snapshot CPU embedding optimizer state
2997        let embed_m: Vec<Vec<f32>> = self
2998            .embed_optimizer
2999            .first_moments()
3000            .iter()
3001            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3002            .collect();
3003        let embed_v: Vec<Vec<f32>> = self
3004            .embed_optimizer
3005            .second_moments()
3006            .iter()
3007            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3008            .collect();
3009        let embed_step = self.embed_optimizer.step_count();
3010
3011        // ALB-118: Download GPU block optimizer states (m/v moments) for checkpointing.
3012        // Without this, resume re-initializes all 24 blocks' AdamW state to zero,
3013        // causing loss spikes and convergence failure (v10/v11/v12 post-mortems).
3014        // Transfer cost: ~2.3 GB D2H, <6ms on PCIe4/5.
3015        let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
3016            .gpu_training
3017            .optimizer_states
3018            .iter()
3019            .map(|state| state.download_to_host().unwrap_or_default())
3020            .collect();
3021
3022        // ALB-118: Download LM head and final norm optimizer states
3023        let lm_head_m_host = {
3024            let mut buf = vec![0.0f32; self.lm_head_m.len()];
3025            let _ = self.lm_head_m.copy_to_host(&mut buf);
3026            buf
3027        };
3028        let lm_head_v_host = {
3029            let mut buf = vec![0.0f32; self.lm_head_v.len()];
3030            let _ = self.lm_head_v.copy_to_host(&mut buf);
3031            buf
3032        };
3033        let final_norm_m_host = {
3034            let mut buf = vec![0.0f32; self.final_norm_m.len()];
3035            let _ = self.final_norm_m.copy_to_host(&mut buf);
3036            buf
3037        };
3038        let final_norm_v_host = {
3039            let mut buf = vec![0.0f32; self.final_norm_v.len()];
3040            let _ = self.final_norm_v.copy_to_host(&mut buf);
3041            buf
3042        };
3043
3044        let name = name.to_string();
3045        let architecture = architecture.to_string();
3046        let model_config_json = serde_json::to_string(&self.config.model_config).ok();
3047        let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
3048
3049        // SPEC-SHIP-TWO-001 §81 P0-E: extract individual arch metadata keys
3050        // so downstream tools (apr qa, apr bench, apr export) can read them
3051        // via AprV2Metadata's typed fields. The `model_config` JSON blob is
3052        // unrecognized by AprWriter::build_v2_metadata and goes into the
3053        // `custom` map — which `realizar::gguf::config::from_apr` does NOT
3054        // read (it requires `apr.metadata.hidden_size` etc. to be Some).
3055        let arch_hidden_size = self.config.model_config.hidden_size;
3056        let arch_num_layers = self.config.model_config.num_hidden_layers;
3057        let arch_num_heads = self.config.model_config.num_attention_heads;
3058        let arch_num_kv_heads = self.config.model_config.num_kv_heads;
3059        let arch_intermediate_size = self.config.model_config.intermediate_size;
3060        let arch_vocab_size = self.config.model_config.vocab_size;
3061        let arch_max_position_embeddings = self.config.model_config.max_position_embeddings;
3062        let arch_rope_theta = self.config.model_config.rope_theta;
3063        let arch_rms_norm_eps = self.config.model_config.rms_norm_eps;
3064
3065        // ALB-130: Pre-read tokenizer.json for embedding in checkpoint.
3066        // Parse HuggingFace tokenizer format → extract vocab + merges + special token IDs.
3067        let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
3068            tokenizer_path.and_then(|p| {
3069                let json_bytes = std::fs::read(p).ok()?;
3070                let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
3071                let model = tok.get("model")?;
3072                let vocab_obj = model.get("vocab")?.as_object()?;
3073                // Build sorted-by-id vocab list
3074                let mut vocab_pairs: Vec<(String, u64)> =
3075                    vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
3076                vocab_pairs.sort_by_key(|(_, id)| *id);
3077                let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
3078                // Merges as "token1 token2" strings
3079                let merges: Vec<String> = model
3080                    .get("merges")?
3081                    .as_array()?
3082                    .iter()
3083                    .filter_map(|v| v.as_str().map(String::from))
3084                    .collect();
3085                // Special tokens: BOS=<s>=1, EOS=</s>=2 (from added_tokens)
3086                let added = tok.get("added_tokens").and_then(|a| a.as_array());
3087                let bos_id = added.and_then(|arr| {
3088                    arr.iter()
3089                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
3090                        .and_then(|t| t.get("id")?.as_u64())
3091                });
3092                let eos_id = added.and_then(|arr| {
3093                    arr.iter()
3094                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
3095                        .and_then(|t| t.get("id")?.as_u64())
3096                });
3097                if vocab.is_empty() {
3098                    return None;
3099                }
3100                println!(
3101                    "  [ALB-130] Embedding tokenizer: {} vocab, {} merges",
3102                    vocab.len(),
3103                    merges.len()
3104                );
3105                Some((vocab, merges, bos_id, eos_id))
3106            });
3107
3108        Box::new(move |path: &std::path::Path| {
3109            use aprender::serialization::apr::AprWriter;
3110            use serde_json::Value as Jv;
3111
3112            let mut writer = AprWriter::new();
3113
3114            // Metadata
3115            writer.set_metadata("model_name", Jv::String(name));
3116            writer.set_metadata("architecture", Jv::String(architecture));
3117            writer.set_metadata(
3118                "format",
3119                Jv::String(if is_delta_checkpoint {
3120                    "entrenar-delta-checkpoint".into()
3121                } else {
3122                    "entrenar-checkpoint".into()
3123                }),
3124            );
3125            writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
3126            writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
3127            writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
3128            writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
3129            if let Some(cfg) = model_config_json {
3130                writer.set_metadata("model_config", Jv::String(cfg));
3131            }
3132
3133            // SPEC-SHIP-TWO-001 §81 P0-E: write individual arch metadata keys
3134            // so realizar's `from_apr` (C-03 gate) accepts the checkpoint.
3135            // `serde_json::Number::from(u as u64)` converts usize losslessly.
3136            writer.set_metadata(
3137                "hidden_size",
3138                Jv::Number(serde_json::Number::from(arch_hidden_size as u64)),
3139            );
3140            writer.set_metadata(
3141                "num_hidden_layers",
3142                Jv::Number(serde_json::Number::from(arch_num_layers as u64)),
3143            );
3144            writer.set_metadata(
3145                "num_attention_heads",
3146                Jv::Number(serde_json::Number::from(arch_num_heads as u64)),
3147            );
3148            writer.set_metadata(
3149                "num_kv_heads",
3150                Jv::Number(serde_json::Number::from(arch_num_kv_heads as u64)),
3151            );
3152            writer.set_metadata(
3153                "intermediate_size",
3154                Jv::Number(serde_json::Number::from(arch_intermediate_size as u64)),
3155            );
3156            writer.set_metadata(
3157                "vocab_size",
3158                Jv::Number(serde_json::Number::from(arch_vocab_size as u64)),
3159            );
3160            writer.set_metadata(
3161                "max_position_embeddings",
3162                Jv::Number(serde_json::Number::from(arch_max_position_embeddings as u64)),
3163            );
3164            if let Some(rope) = serde_json::Number::from_f64(arch_rope_theta as f64) {
3165                writer.set_metadata("rope_theta", Jv::Number(rope));
3166            }
3167            if let Some(eps) = serde_json::Number::from_f64(arch_rms_norm_eps as f64) {
3168                writer.set_metadata("rms_norm_eps", Jv::Number(eps));
3169            }
3170
3171            // ALB-130: Embed tokenizer vocab + merges for standalone inference
3172            if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
3173                writer.set_metadata(
3174                    "tokenizer.vocabulary",
3175                    Jv::Array(vocab.into_iter().map(Jv::String).collect()),
3176                );
3177                writer.set_metadata(
3178                    "tokenizer.merges",
3179                    Jv::Array(merges.into_iter().map(Jv::String).collect()),
3180                );
3181                if let Some(bos) = bos_id {
3182                    writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
3183                }
3184                if let Some(eos) = eos_id {
3185                    writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
3186                }
3187            }
3188
3189            // Find hidden_size from norm weights for shape inference
3190            let hidden_size = param_data
3191                .iter()
3192                .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
3193                .map_or(0, |(_, d)| d.len());
3194
3195            // Model weight tensors
3196            for (tensor_name, data) in &param_data {
3197                let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
3198                writer.add_tensor_f32(tensor_name.clone(), shape, data);
3199            }
3200
3201            // Optimizer state tensors
3202            for (i, m_data) in embed_m.iter().enumerate() {
3203                let len = m_data.len();
3204                writer.add_tensor_f32(
3205                    format!("__training__.embed_optimizer.m.{i}"),
3206                    vec![len],
3207                    m_data,
3208                );
3209            }
3210            for (i, v_data) in embed_v.iter().enumerate() {
3211                let len = v_data.len();
3212                writer.add_tensor_f32(
3213                    format!("__training__.embed_optimizer.v.{i}"),
3214                    vec![len],
3215                    v_data,
3216                );
3217            }
3218
3219            // ALB-118: Save GPU block optimizer states (m/v moments for all 24 blocks)
3220            for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
3221                for (suffix, data) in buffers {
3222                    let len = data.len();
3223                    writer.add_tensor_f32(
3224                        format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
3225                        vec![len],
3226                        data,
3227                    );
3228                }
3229            }
3230
3231            // ALB-118: Save LM head and final norm optimizer states
3232            if !lm_head_m_host.is_empty() {
3233                let len = lm_head_m_host.len();
3234                writer.add_tensor_f32(
3235                    "__training__.lm_head_optimizer.m".to_string(),
3236                    vec![len],
3237                    &lm_head_m_host,
3238                );
3239                let len = lm_head_v_host.len();
3240                writer.add_tensor_f32(
3241                    "__training__.lm_head_optimizer.v".to_string(),
3242                    vec![len],
3243                    &lm_head_v_host,
3244                );
3245            }
3246            if !final_norm_m_host.is_empty() {
3247                let len = final_norm_m_host.len();
3248                writer.add_tensor_f32(
3249                    "__training__.final_norm_optimizer.m".to_string(),
3250                    vec![len],
3251                    &final_norm_m_host,
3252                );
3253                let len = final_norm_v_host.len();
3254                writer.add_tensor_f32(
3255                    "__training__.final_norm_optimizer.v".to_string(),
3256                    vec![len],
3257                    &final_norm_v_host,
3258                );
3259            }
3260
3261            // ENT-276: Save LoRA adapter weights (QLoRA checkpoint resume)
3262            for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
3263                if !a_q.is_empty() {
3264                    writer.add_tensor_f32(
3265                        format!("lora.{layer_idx}.q_proj.lora_a"),
3266                        vec![a_q.len()],
3267                        a_q,
3268                    );
3269                    writer.add_tensor_f32(
3270                        format!("lora.{layer_idx}.q_proj.lora_b"),
3271                        vec![b_q.len()],
3272                        b_q,
3273                    );
3274                }
3275                if !a_v.is_empty() {
3276                    writer.add_tensor_f32(
3277                        format!("lora.{layer_idx}.v_proj.lora_a"),
3278                        vec![a_v.len()],
3279                        a_v,
3280                    );
3281                    writer.add_tensor_f32(
3282                        format!("lora.{layer_idx}.v_proj.lora_b"),
3283                        vec![b_v.len()],
3284                        b_v,
3285                    );
3286                }
3287            }
3288
3289            // Write APR checkpoint to file
3290            writer
3291                .write(path)
3292                .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3293
3294            Ok(())
3295        })
3296    }
3297
    /// GPU device name.
    ///
    /// Delegates to `device_name()` on this trainer's `cuda_trainer` device
    /// context and returns the result unchanged.
    pub fn gpu_name(&self) -> String {
        self.cuda_trainer.device_name()
    }
3302
3303    /// ENT-269: Save LoRA adapter weights as PEFT-compatible files.
3304    ///
3305    /// Downloads LoRA A/B matrices from GPU, un-scales B (divide by lora_scale),
3306    /// transposes to PEFT convention (A=[rank, d_in], B=[d_out, rank]),
3307    /// and writes `adapter_model.safetensors` + `adapter_config.json`.
3308    ///
3309    /// # Contract: C-QLORA-SAVE-001
3310    ///
3311    /// NF4 QLoRA training MUST produce `adapter_model.safetensors` in output_dir.
3312    pub fn save_cuda_lora_adapter(
3313        &self,
3314        output_dir: &std::path::Path,
3315        base_model_name: Option<&str>,
3316    ) -> crate::Result<()> {
3317        if !self.config.quantize_nf4 || !self.config.is_lora() {
3318            return Ok(()); // Not a QLoRA run, nothing to save
3319        }
3320
3321        let lora_rank = self.config.lora_rank.unwrap_or(16);
3322        let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3323        let lora_scale = lora_alpha / lora_rank as f32;
3324        let hidden_size = self.config.model_config.hidden_size;
3325        let head_dim = self.config.model_config.head_dim();
3326        let q_dim = self.config.model_config.num_attention_heads * head_dim;
3327        let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3328
3329        let lora_config =
3330            crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3331
3332        let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3333
3334        for (i, block) in self.cuda_blocks.iter().enumerate() {
3335            let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3336                Ok(weights) => weights,
3337                Err(_) => continue, // Skip non-NF4 blocks
3338            };
3339
3340            if a_q.is_empty() && a_v.is_empty() {
3341                continue;
3342            }
3343
3344            // Q projection LoRA
3345            if !a_q.is_empty() {
3346                // GPU stores A_q as [hidden, rank] row-major, PEFT expects [rank, hidden]
3347                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3348                for r in 0..hidden_size {
3349                    for c in 0..lora_rank {
3350                        a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3351                    }
3352                }
3353
3354                // GPU stores B_q as [rank, q_dim] pre-scaled by lora_scale
3355                // PEFT expects [q_dim, rank] un-scaled
3356                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3357                let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3358                for r in 0..lora_rank {
3359                    for c in 0..q_dim {
3360                        b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3361                    }
3362                }
3363
3364                let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3365                let mut layer = crate::lora::LoRALayer::new(
3366                    base_weight,
3367                    q_dim,
3368                    hidden_size,
3369                    lora_rank,
3370                    lora_alpha,
3371                );
3372                // Overwrite the A and B data with trained weights
3373                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3374                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3375
3376                adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3377            }
3378
3379            // V projection LoRA
3380            if !a_v.is_empty() {
3381                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3382                for r in 0..hidden_size {
3383                    for c in 0..lora_rank {
3384                        a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3385                    }
3386                }
3387
3388                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3389                let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3390                for r in 0..lora_rank {
3391                    for c in 0..kv_hidden {
3392                        b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3393                    }
3394                }
3395
3396                let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3397                let mut layer = crate::lora::LoRALayer::new(
3398                    base_weight,
3399                    kv_hidden,
3400                    hidden_size,
3401                    lora_rank,
3402                    lora_alpha,
3403                );
3404                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3405                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3406
3407                adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3408            }
3409        }
3410
3411        if adapters.is_empty() {
3412            println!("  [WARN] No LoRA adapters found to save");
3413            return Ok(());
3414        }
3415
3416        let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3417            adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3418
3419        std::fs::create_dir_all(output_dir).ok();
3420        crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3421            .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3422
3423        let adapter_path = output_dir.join("adapter_model.safetensors");
3424        let size_mb =
3425            std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3426        println!(
3427            "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3428            adapters.len(),
3429            size_mb,
3430            output_dir.display()
3431        );
3432
3433        Ok(())
3434    }
3435
3436    /// R-001: Save CPU embedding optimizer state (m/v buffers + step counter).
3437    ///
3438    /// Writes `optimizer_state.json` to the given directory. GPU block optimizer
3439    /// states remain on-device (D2H for 20 buffers × N blocks is deferred).
3440    pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3441        let path = dir.join("optimizer_state.json");
3442        let m_data: Vec<Option<Vec<f32>>> = self
3443            .embed_optimizer
3444            .first_moments()
3445            .iter()
3446            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3447            .collect();
3448        let v_data: Vec<Option<Vec<f32>>> = self
3449            .embed_optimizer
3450            .second_moments()
3451            .iter()
3452            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3453            .collect();
3454        let state = serde_json::json!({
3455            "type": "adamw_cpu_embed",
3456            "step": self.embed_optimizer.step_count(),
3457            "m": m_data,
3458            "v": v_data,
3459        });
3460        let json_str = serde_json::to_string(&state).map_err(|e| {
3461            crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3462        })?;
3463        std::fs::write(&path, json_str)
3464            .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3465        Ok(())
3466    }
3467
3468    /// ENT-276: Restore LoRA adapter weights from APR checkpoint.
3469    ///
3470    /// Reads `lora.{layer}.{q,v}_proj.lora_{a,b}` tensors from the APR file
3471    /// and uploads them to the NF4 CUDA blocks, replacing the fresh random init.
3472    /// Returns (layers_restored, layers_total).
3473    pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3474        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3475            Ok(r) => r,
3476            Err(_) => return (0, self.cuda_blocks.len()),
3477        };
3478
3479        let mut restored = 0usize;
3480        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3481            let a_q =
3482                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3483            let b_q =
3484                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3485            let a_v =
3486                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3487            let b_v =
3488                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3489
3490            if a_q.is_empty() {
3491                continue; // No LoRA data for this layer in checkpoint
3492            }
3493
3494            if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3495                eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3496                continue;
3497            }
3498            restored += 1;
3499        }
3500
3501        (restored, self.cuda_blocks.len())
3502    }
3503
3504    /// ALB-096: Load CPU embedding optimizer state from APR checkpoint.
3505    ///
3506    /// Reads `__training__.embed_optimizer.{m,v}.*` tensors from the APR file.
3507    /// Returns true if state was loaded.
3508    pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3509        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3510            Ok(r) => r,
3511            Err(_) => return false,
3512        };
3513
3514        // Restore step count from metadata
3515        if let Some(step_val) = reader.get_metadata("optimizer_step") {
3516            if let Some(step_str) = step_val.as_str() {
3517                if let Ok(step) = step_str.parse::<u64>() {
3518                    self.embed_optimizer.set_step_count(step);
3519                }
3520            }
3521        }
3522
3523        // Restore first moments (m)
3524        for i in 0..128 {
3525            let name = format!("__training__.embed_optimizer.m.{i}");
3526            match reader.read_tensor_f32(&name) {
3527                Ok(data) if !data.is_empty() => {
3528                    self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3529                }
3530                _ => break,
3531            }
3532        }
3533
3534        // Restore second moments (v)
3535        for i in 0..128 {
3536            let name = format!("__training__.embed_optimizer.v.{i}");
3537            match reader.read_tensor_f32(&name) {
3538                Ok(data) if !data.is_empty() => {
3539                    self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3540                }
3541                _ => break,
3542            }
3543        }
3544
3545        // ALB-118: Restore GPU block optimizer states (m/v moments for all blocks)
3546        let suffixes = [
3547            "m.w_q",
3548            "v.w_q",
3549            "m.w_k",
3550            "v.w_k",
3551            "m.w_v",
3552            "v.w_v",
3553            "m.w_o",
3554            "v.w_o",
3555            "m.w_gate",
3556            "v.w_gate",
3557            "m.w_up",
3558            "v.w_up",
3559            "m.w_down",
3560            "v.w_down",
3561            "m.input_norm",
3562            "v.input_norm",
3563            "m.post_attn_norm",
3564            "v.post_attn_norm",
3565        ];
3566        let mut blocks_restored = 0usize;
3567        for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3568            let mut data = std::collections::HashMap::new();
3569            for suffix in &suffixes {
3570                let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3571                if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3572                    if !tensor_data.is_empty() {
3573                        data.insert(suffix.to_string(), tensor_data);
3574                    }
3575                }
3576            }
3577            if !data.is_empty() {
3578                let _ = state.restore_from_host(&data);
3579                blocks_restored += 1;
3580            }
3581        }
3582
3583        // ALB-118: Restore LM head optimizer state
3584        if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3585            if m_data.len() == self.lm_head_m.len() {
3586                let _ = self.lm_head_m.copy_from_host(&m_data);
3587            }
3588        }
3589        if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3590            if v_data.len() == self.lm_head_v.len() {
3591                let _ = self.lm_head_v.copy_from_host(&v_data);
3592            }
3593        }
3594
3595        // ALB-118: Restore final norm optimizer state
3596        if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3597            if m_data.len() == self.final_norm_m.len() {
3598                let _ = self.final_norm_m.copy_from_host(&m_data);
3599            }
3600        }
3601        if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3602            if v_data.len() == self.final_norm_v.len() {
3603                let _ = self.final_norm_v.copy_from_host(&v_data);
3604            }
3605        }
3606
3607        // ALB-132: Report restore results — don't silently swallow failures
3608        if blocks_restored > 0 {
3609            println!(
3610                "  ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3611                self.gpu_training.optimizer_states.len()
3612            );
3613        } else if !self.gpu_training.optimizer_states.is_empty() {
3614            println!(
3615                "  [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3616                self.gpu_training.optimizer_states.len()
3617            );
3618        }
3619
3620        true
3621    }
3622
3623    /// R-001: Load CPU embedding optimizer state from `optimizer_state.json`.
3624    ///
3625    /// Returns true if state was loaded, false if file doesn't exist.
3626    pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3627        let path = dir.join("optimizer_state.json");
3628        let data = match std::fs::read_to_string(&path) {
3629            Ok(d) => d,
3630            Err(_) => return false,
3631        };
3632        let state: serde_json::Value = match serde_json::from_str(&data) {
3633            Ok(v) => v,
3634            Err(_) => return false,
3635        };
3636        if let Some(step) = state["step"].as_u64() {
3637            self.embed_optimizer.set_step_count(step);
3638        }
3639        restore_moment_buffers(&state["m"], |idx, arr| {
3640            self.embed_optimizer.set_first_moment(idx, arr);
3641        });
3642        restore_moment_buffers(&state["v"], |idx, arr| {
3643            self.embed_optimizer.set_second_moment(idx, arr);
3644        });
3645        true
3646    }
3647}
3648
3649/// ALB-096: Infer 2D tensor shape from name and element count.
3650///
3651/// Same logic as `infer_all_tensor_shapes` in `io/save.rs` but for a single tensor.
3652#[cfg(feature = "cuda")]
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    // Norm weights are always reported as 1-D.
    let is_norm = name.ends_with("layernorm.weight") || name == "model.norm.weight";
    // So is anything that does not tile evenly by a non-zero hidden size.
    if is_norm || hidden_size == 0 || !numel.is_multiple_of(hidden_size) {
        return vec![numel];
    }

    let other_dim = numel / hidden_size;
    // down_proj weights put hidden_size first; every other 2-D tensor puts it last.
    match name.ends_with("down_proj.weight") {
        true => vec![hidden_size, other_dim],
        false => vec![other_dim, hidden_size],
    }
}
3667
/// Parse a JSON array of moment buffers and apply each via `set_fn(index, buffer)`.
///
/// `json_arr` is expected to be an array whose `idx`-th element is an array of
/// numbers. If `json_arr` is not an array, nothing happens. Elements that are
/// not arrays, and empty arrays, are skipped.
///
/// A buffer that contains any non-numeric element is skipped with a warning.
/// For example, `null` is how `serde_json` writes NaN/±inf floats. Dropping
/// only the bad elements would shorten the buffer and misalign it with its
/// parameter.
#[cfg(feature = "cuda")]
fn restore_moment_buffers(
    json_arr: &serde_json::Value,
    mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
) {
    let Some(arr) = json_arr.as_array() else { return };
    for (idx, val) in arr.iter().enumerate() {
        let Some(inner) = val.as_array() else { continue };
        // All-or-nothing: `None` if any element is not a number.
        let parsed: Option<Vec<f32>> =
            inner.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();
        match parsed {
            Some(floats) if !floats.is_empty() => {
                set_fn(idx, ndarray::Array1::from_vec(floats));
            }
            Some(_) => {}
            None => {
                eprintln!("Warning: moment buffer {idx} contains non-numeric values — skipped");
            }
        }
    }
}
3683
3684// ── Non-CUDA stub ──
3685
/// Placeholder for the GPU-resident transformer trainer in builds compiled
/// without the `cuda` feature.
///
/// Its constructors (`new`, `with_model`) always return a `ConfigError`, so no
/// working trainer is ever produced in this configuration.
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerTrainer;
3688
3689#[cfg(not(feature = "cuda"))]
3690impl CudaTransformerTrainer {
3691    pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3692        Err(crate::error::Error::ConfigError(
3693            "CUDA not available (compiled without cuda feature)".into(),
3694        ))
3695    }
3696
3697    pub fn with_model(
3698        _model: crate::transformer::Transformer,
3699        _config: super::config::TransformerTrainConfig,
3700    ) -> crate::Result<Self> {
3701        Err(crate::error::Error::ConfigError(
3702            "CUDA not available (compiled without cuda feature)".into(),
3703        ))
3704    }
3705
3706    pub fn gpu_name(&self) -> String {
3707        unreachable!("CudaTransformerTrainer stub should never be instantiated")
3708    }
3709}
3710
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, the stub constructor must fail with a
    /// `ConfigError` that names the missing feature. A bare `is_err()` would
    /// accept any failure.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let mc = TransformerConfig::tiny();
        let config = TransformerTrainConfig::new(mc);
        let result = super::CudaTransformerTrainer::new(config);
        assert!(
            matches!(
                &result,
                Err(crate::error::Error::ConfigError(msg)) if msg.contains("cuda feature")
            ),
            "expected a ConfigError about the missing cuda feature"
        );
    }
}