Skip to main content

entrenar/train/transformer_trainer/
cuda_trainer.rs

1//! GPU-resident transformer trainer (ALB-040)
2//!
3//! Wires the existing `CudaTransformerBlock` forward/backward/optimizer_step
4//! into the pretraining path. Follows the proven `classify_pipeline.rs` pattern.
5//!
6//! # Architecture
7//!
8//! ```text
9//! CudaTransformerTrainer
10//! ├── model: Transformer                 (CPU — embed + save)
11//! ├── cuda_trainer: CudaTrainer          (GPU device context)
12//! ├── cuda_blocks: Vec<CudaBlock>            (fp32 or NF4)
13//! ├── cuda_grad_workspace: CudaGradWorkspace
14//! ├── gpu_training: GpuPretrainState     (layer_inputs, grad bufs, opt states)
15//! ├── lm_head_weight_gpu: GpuBuffer      (V × H on GPU)
16//! ├── lm_head_grad_gpu: GpuBuffer        (V × H gradient scratch)
17//! ├── lm_head_m/v: GpuBuffer             (AdamW moment states)
18//! └── config: TransformerTrainConfig
19//! ```
20//!
21//! # Transfer budget (C-GPUTRAIN-002, updated KAIZEN-050/052)
22//!
//! 1 bulk PCIe transfer per training step, plus 2 tiny control transfers (3 total):
24//! 1. H2D: hidden states after embedding (seq×H×4 bytes)
25//! 2. H2D: target_ids for fused cross-entropy (seq×4 bytes — ~512B)
26//! 3. D2H: loss_partials from fused cross-entropy (seq×4 bytes — ~512B)
27//!
28//! Eliminated by KAIZEN-050:
29//! - D2H logits (was seq×V×4 = 77.8MB for Qwen3-4B)
30//! - H2D grad_logits (was seq×V×4 = 77.8MB)
31//!
32//! Eliminated by KAIZEN-052:
33//! - grad_gpu buffer allocation (was seq×V×4 = 77.8MB per step)
34
35#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42    gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46    adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47    gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48    squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64    CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65    GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
/// Derive the gradient-clipping scale for the shared workspace from a GPU-side
/// L2-norm reduction (KAIZEN-054).
///
/// A squared-sum kernel is launched for each non-empty gradient buffer and only
/// the partial sums are read back (~1KB D2H each), instead of downloading entire
/// gradient buffers to the CPU (was 58 MB+ per block, disabled in ALB-067).
///
/// Returns `(scale, grad_norm)`:
/// - `grad_norm` is the L2 norm taken over all nine workspace buffers.
/// - `scale` is `max_norm / grad_norm` when `grad_norm > max_norm`, else `1.0`.
///
/// If the stream fails to synchronize, `(1.0, 0.0)` is returned. A buffer whose
/// launch or collect returns an error is left out of the sum.
///
/// Free function to avoid borrow conflicts with `&mut self`.
#[cfg(feature = "cuda")]
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // KAIZEN-055: enqueue every squared_sum kernel back-to-back with no sync in
    // between, so a block costs one pipeline flush instead of nine.
    let in_flight: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|grad| {
            let len = grad.len() as u32;
            if len == 0 {
                return None;
            }
            squared_sum_launch_cuda(grad, len, stream).ok()
        })
        .collect();

    // The one and only sync point for all launches above.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // C-CLIP-001: squared_sum_collect already yields sum(x²) = ||g||² per buffer,
    // so the values are added as-is — never squared again (entrenar#311 fix).
    let total_sq = in_flight
        .iter()
        .filter_map(|handle| squared_sum_collect(handle).ok())
        .fold(0.0f64, |acc, sq_norm| acc + f64::from(sq_norm));

    // Global L2 norm = sqrt of the summed per-buffer squared norms.
    let grad_norm = total_sq.sqrt() as f32;
    let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
    (scale, grad_norm)
}
133
/// Rescale every gradient buffer in the shared workspace so that the combined
/// L2 norm does not exceed `max_norm`, using the GPU-computed norm (KAIZEN-054).
///
/// R-004: the return value is the pre-clip gradient L2 norm, intended for
/// observability logging. When the computed scale is within `1e-7` of `1.0`
/// no scaling kernels are launched at all.
#[cfg(feature = "cuda")]
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    // Scale is effectively 1.0 — nothing to rescale.
    if (scale - 1.0).abs() < 1e-7 {
        return grad_norm;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];

    for grad in grads {
        let len = grad.len() as u32;
        // The launch result is deliberately discarded (best-effort scaling).
        let _ = gradient_clip_cuda(grad, scale, len, stream);
    }

    grad_norm
}
165
/// ALB-078: Fused gradient clipping — the whole pipeline is enqueued on the GPU.
///
/// Alternative to `clip_workspace_gradients` that needs neither a
/// `stream.synchronize()` nor a D2H download of partial sums. Three phases are
/// issued on `stream`, in this order:
///
/// 1. 9× SquaredSumKernel → partials written into a pre-allocated contiguous buffer
/// 2. 1× ClipScaleReduceKernel → partials reduced and the clip scale computed on GPU
/// 3. 9× GradientClipGpuScaleKernel → scale read from GPU memory and applied
///
/// Zero sync points, zero D2H transfers per block. Nothing is returned and the
/// result of every launch is discarded.
#[cfg(feature = "cuda")]
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Phase 1: one squared_sum launch per non-empty buffer, each writing into
    // `state.partials_buf` at the offset pre-computed for its slot.
    {
        let grads: [&GpuBuffer<f32>; 9] = [
            &ws.grad_w_q,
            &ws.grad_w_k,
            &ws.grad_w_v,
            &ws.grad_w_o,
            &ws.grad_gate,
            &ws.grad_up,
            &ws.grad_down,
            &ws.grad_input_norm,
            &ws.grad_post_attn_norm,
        ];

        for (slot, grad) in grads.iter().enumerate() {
            let len = grad.len() as u32;
            if len != 0 {
                // NOTE(review): `* 4` presumably turns an f32 element offset into
                // a byte offset — confirm against `FusedClipState`.
                let byte_offset = u64::from(state.offsets[slot]) * 4;
                let partials_ptr = state.partials_buf.as_ptr() + byte_offset;
                let _ = squared_sum_launch_into(grad, len, partials_ptr, stream);
            }
        }
    }

    // Phase 2: reduce every partial and compute clip_scale on the GPU.
    // Stream ordering guarantees the squared_sum kernels finish before this runs.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Phase 3: apply the scale to all nine buffers. The kernels read the scale
    // straight from GPU memory (output[0] = clip_scale), so no D2H is needed.
    let scale_ptr = state.scale_buf.as_ptr();
    let mut grads_mut: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads_mut.iter_mut() {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, scale_ptr, len, stream);
        }
    }
}
238
/// R-004: Report the workspace gradient L2 norm without clipping anything
/// (observability only).
///
/// Delegates to the GPU reduction (KAIZEN-054) with `max_norm = f32::MAX` and
/// keeps just the norm half of its result. Only ~9KB D2H per call.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    // Element 0 is the clip scale, which is ignored here; element 1 is the norm.
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
/// ALB-072: Multiply every gradient buffer in the shared workspace by `inv_scale`.
///
/// In fp16 AMP, the fused cross-entropy kernel folds loss_scale into the
/// gradient it emits, so all later backward gradients carry that factor. The
/// GPU block optimizer (AdamW) must see unscaled gradients — otherwise the
/// second moment `v` overflows f32, producing NaN in early layers.
///
/// This is the GPU-side counterpart of `GradScaler::unscale_and_check()`, which
/// handles the CPU embedding gradients. When `inv_scale` is within `1e-7` of
/// `1.0` the call does nothing.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    // Unit scale: skip all nine kernel launches.
    if (inv_scale - 1.0).abs() < 1e-7 {
        return;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];

    for grad in grads {
        let len = grad.len() as u32;
        // The scaling kernel is reused with `inv_scale` as the factor; the
        // launch result is deliberately discarded.
        let _ = gradient_clip_cuda(grad, inv_scale, len, stream);
    }
}
285
/// GPU-resident training state for pretraining.
///
/// Bundles the device buffers that `CudaTransformerTrainer::with_model`
/// allocates once up front: saved layer inputs, the alternating gradient
/// buffers, the final-norm weight and its gradient, the logits buffer, and the
/// per-block optimizer states.
///
/// # Contract (C-GPUTRAIN-001)
///
/// - `layer_inputs.len() == num_layers`
/// - All buffers preallocated at init; zero GPU allocations during training
/// - `step` increments monotonically
#[cfg(feature = "cuda")]
struct GpuPretrainState {
    /// Saved layer inputs for backward [num_layers][seq_len * hidden_size].
    /// One buffer per layer is allocated regardless of checkpointing.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Which layer inputs were saved during forward (activation checkpointing).
    /// When checkpointing is enabled, only checkpoint boundary layers
    /// (`i % segment_size == 0`) are marked `true`; when it is disabled every
    /// entry is `true`.
    /// Non-saved layers are recomputed from the nearest checkpoint before backward.
    saved_layer_mask: Vec<bool>,
    /// Temporary buffer for activation recomputation [seq_len * hidden_size].
    /// Used as the initial input when recomputing from a checkpoint boundary.
    /// Only allocated when activation checkpointing is enabled (`None` otherwise).
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Final RMSNorm weight on GPU [hidden_size], uploaded from the CPU model.
    final_norm_weight: GpuBuffer<f32>,
    /// Final block output (pre-norm) for RMSNorm backward [seq_len * hidden_size]
    blocks_output: GpuBuffer<f32>,
    /// Alternating gradient buffer A [seq_len * hidden_size]
    grad_buf_a: GpuBuffer<f32>,
    /// Alternating gradient buffer B [seq_len * hidden_size]
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient for final norm weight [hidden_size]
    grad_final_norm_weight: GpuBuffer<f32>,
    /// RMSNorm output buffer (reused each step) [seq_len * hidden_size]
    norm_output: GpuBuffer<f32>,
    /// Logits buffer (reused each step) [seq_len * vocab_size]
    logits_buf: GpuBuffer<f32>,
    /// LM head gradient buffer [seq_len * hidden_size] (grad w.r.t. normed hidden)
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// Per-block optimizer states. Populated on the fp32 path only; left empty
    /// on the NF4 path, which keeps LoRA optimizer states elsewhere.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter (starts at 0)
    step: u32,
}
326
/// GPU-resident transformer trainer for pretraining.
///
/// Uses `CudaTransformerBlock` forward/backward/optimizer_step on GPU,
/// keeping only the embedding lookup on CPU. Cross-entropy is computed by the
/// fused GPU kernel (KAIZEN-050), not by a CPU loss function.
///
/// # Contract (C-GPUTRAIN-002)
///
/// - Exactly 3 PCIe transfers per training step
/// - Graceful fallback to CPU `TransformerTrainer` on any CUDA failure
/// - Weight sync via `sync_weights_to_cpu()` before save
#[cfg(feature = "cuda")]
pub struct CudaTransformerTrainer {
    /// CPU model (for embedding, saving, fallback)
    model: Transformer,
    /// CUDA device context
    cuda_trainer: CudaTrainer,
    /// GPU-resident transformer blocks (fp32 or NF4 via CudaBlock enum)
    cuda_blocks: Vec<CudaBlock>,
    /// Shared gradient workspace (one set, reused across layers; fp32 path only)
    cuda_grad_workspace: CudaGradWorkspace,
    /// ENT-263: Shared scratch for NF4 blocks (C-SCRATCH-001). None when fp32.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// ENT-263: Shared LoRA gradient workspace for NF4 backward. None when fp32.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// ENT-263: Per-block LoRA optimizer states for NF4. None when fp32.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// GPU training state (layer inputs, grad bufs, optimizer states)
    gpu_training: GpuPretrainState,
    /// LM head weight on GPU [vocab_size * hidden_size].
    /// Uploaded from the separate `lm_head` when the model has one, otherwise
    /// from the tied `embed_tokens.weight`.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// LM head weight gradient on GPU [vocab_size * hidden_size]
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// LM head AdamW first moment [vocab_size * hidden_size] (zero-initialized)
    lm_head_m: GpuBuffer<f32>,
    /// LM head AdamW second moment [vocab_size * hidden_size] (zero-initialized)
    lm_head_v: GpuBuffer<f32>,
    /// Final norm weight AdamW first moment [hidden_size] (zero-initialized)
    final_norm_m: GpuBuffer<f32>,
    /// Final norm weight AdamW second moment [hidden_size] (zero-initialized)
    final_norm_v: GpuBuffer<f32>,
    /// CPU optimizer for embedding weights only
    embed_optimizer: AdamW,
    /// Training configuration
    config: TransformerTrainConfig,
    /// Metrics tracker
    pub metrics: MetricsTracker,
    /// Current optimizer step
    step: usize,
    /// Accumulated loss (for gradient accumulation)
    accumulated_loss: f32,
    /// Accumulated batch count
    accumulated_batches: usize,
    /// R-004: Last observed LM head gradient L2 norm (proxy for global grad norm)
    last_grad_norm: f32,
    /// R-040: Last observed embedding activation gradient L2 norm
    last_embed_grad_norm: f32,
    /// R-038: Per-block gradient accumulation for true multi-step gradient accumulation.
    /// Only allocated when accumulation_steps > 1. CPU-side buffers (~335 MB for 350M).
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// ALB-091: GPU-resident gradient accumulation (replaces CPU accum when available).
    /// Eliminates 24 × ga stream.synchronize() + D2H transfers per optimizer step.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// R-002: Gradient scaler for mixed-precision training.
    /// For BF16: no-op (scale=1.0, dynamic=false).
    /// For FP16: dynamic loss scaling to prevent gradient underflow.
    grad_scaler: GradScaler,
    /// KAIZEN-047: Per-step wall-clock profiler.
    /// Reports timing breakdown for each training phase.
    profiler: StepProfiler,
    /// KAIZEN-053: Pre-allocated forward scratch buffers [max_seq_len * hidden_size].
    /// Reused every step — eliminates 2 × cuMemAlloc/Free per training step.
    fwd_scratch_a: GpuBuffer<f32>,
    /// KAIZEN-053: Second forward scratch buffer, allocated with the same size
    /// as `fwd_scratch_a`.
    fwd_scratch_b: GpuBuffer<f32>,
    /// KAIZEN-056: Pre-allocated CPU staging buffer for H2D hidden state upload.
    /// Eliminates vec![0.0; max_seq_len * hidden_size] allocation per step.
    h2d_staging: Vec<f32>,
    /// KAIZEN-059: Pre-allocated CPU staging buffer for D2H gradient downloads
    /// during gradient accumulation. Sized to max(h*intermediate, vocab*h).
    /// Eliminates ~15GB of per-step heap churn (36 × vec![0.0; h*i] + vec![0.0; vocab*h]
    /// per micro-batch × accumulation_steps).
    d2h_staging: Vec<f32>,
    /// ALB-078: Pre-allocated state for fused gradient clipping pipeline.
    /// Eliminates 24 stream.synchronize() calls per step.
    fused_clip: Option<FusedClipState>,
    /// Pre-allocated host zero buffer for zeroing final norm grad [hidden_size].
    /// BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
    /// so the buffer must be zeroed before each rms_norm_backward call.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419    /// Create a new GPU-resident trainer.
420    ///
421    /// # Errors
422    ///
423    /// Returns `Err` if CUDA initialization, kernel pre-warming, or block upload fails.
424    /// Caller should fall back to CPU `TransformerTrainer` on error.
425    pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426        let model = Transformer::new(&config.model_config);
427        Self::with_model(model, config)
428    }
429
430    /// ALB-089: Load SafeTensors checkpoint for GPU inference (forward-only).
431    ///
432    /// Creates a `CudaTransformerTrainer` in inference mode. The optimizer
433    /// state is allocated (wasteful but simple), but `forward_logits()` only
434    /// uses the forward path. Call `forward_logits(&tokens)` to generate.
435    ///
436    /// # Arguments
437    /// * `checkpoint_dir` - Directory containing model.safetensors + config.json
438    /// * `model_config` - Transformer architecture config
439    ///
440    /// # Errors
441    ///
442    /// Returns `Err` if SafeTensors loading or CUDA initialization fails.
443    pub fn for_inference(
444        checkpoint_dir: impl AsRef<std::path::Path>,
445        model_config: crate::transformer::TransformerConfig,
446    ) -> crate::Result<Self> {
447        let dir = checkpoint_dir.as_ref();
448
449        // ALB-089: Try APR format first (our native checkpoint format), then SafeTensors
450        let model = if let Some((Some(m), _step)) =
451            crate::config::try_load_apr_for_inference(dir, &model_config)
452        {
453            m
454        } else {
455            Transformer::from_safetensors(dir, &model_config)?
456        };
457
458        let mut config = TransformerTrainConfig::new(model_config);
459        config.max_seq_len = config.model_config.max_position_embeddings;
460        Self::with_model(model, config)
461    }
462
463    /// Create a GPU-resident trainer from an existing model.
464    ///
465    /// # Errors
466    ///
467    /// Returns `Err` if CUDA initialization fails.
468    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
469        if !cuda_training_available() {
470            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
471        }
472
473        let mc = &config.model_config;
474        let max_seq_len = config.max_seq_len;
475        let hidden_size = mc.hidden_size;
476        let vocab_size = mc.vocab_size;
477        let num_layers = mc.num_hidden_layers;
478
479        // Step 1: Create CUDA trainer (initializes kernel caches)
480        let cuda_trainer = CudaTrainer::new().map_err(|e| {
481            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
482        })?;
483
484        println!(
485            "  GPU: {} ({:.1} GB)",
486            cuda_trainer.device_name(),
487            cuda_trainer.total_memory() as f64 / 1e9
488        );
489
490        let ctx = cuda_trainer.context().clone();
491        let stream = cuda_trainer.stream();
492
493        // Step 2: Pre-warm forward kernels (C-PREWARM-001)
494        // Must happen before block upload — JIT compilation needs free VRAM
495        pre_warm_forward_kernels(
496            hidden_size,
497            mc.intermediate_size,
498            mc.num_attention_heads,
499            mc.num_kv_heads,
500            mc.head_dim(),
501            max_seq_len,
502        )
503        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
504
505        // Step 2a: Pre-warm backward kernels (trueno#200)
506        // MUST happen before any GPU work — Blackwell's cuModuleLoadData fails
507        // with ILLEGAL_ADDRESS when called during active GPU computation.
508        {
509            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
510            let head_dim = mc.head_dim();
511            pre_warm_lora_backward_kernels(
512                hidden_size,
513                mc.num_attention_heads * head_dim,
514                mc.num_kv_heads * head_dim,
515                max_seq_len,
516                config.lora_rank.unwrap_or(0),
517                mc.intermediate_size,
518                mc.num_attention_heads,
519                config.quantize_nf4 && config.is_lora(),
520            )
521            .map_err(|e| {
522                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
523            })?;
524            eprintln!("  ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
525        }
526
527        // Step 2b: Bind cuBLAS handles to training stream (ALB-075)
528        // Must happen after kernel cache init, before any GEMM calls.
529        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
530            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
531        }
532        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
533            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
534        }
535
536        // Step 3: Upload transformer blocks to GPU
537        let use_nf4 = config.quantize_nf4 && config.is_lora();
538        let cuda_blocks = Self::upload_blocks(
539            &model,
540            mc,
541            &config,
542            &ctx,
543            use_nf4,
544            num_layers,
545            hidden_size,
546            max_seq_len,
547        )?;
548
549        // Step 4: Allocate shared gradient workspace
550        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
551            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
552        })?;
553
554        // Step 5: Allocate GPU training state
555        let buf_size = max_seq_len * hidden_size;
556        let logits_size = max_seq_len * vocab_size;
557
558        // Activation checkpointing: determine which layers save their inputs.
559        // Checkpoint boundary layers (every segment_size layers) are always saved.
560        // Non-boundary layers are recomputed from the nearest checkpoint during backward.
561        let checkpointing = config.checkpoint_config.enabled;
562        let segment_size = if checkpointing {
563            let ns = config.checkpoint_config.num_segments.max(1);
564            num_layers.div_ceil(ns)
565        } else {
566            1 // Every layer is a checkpoint (no recomputation)
567        };
568        let saved_layer_mask: Vec<bool> =
569            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
570
571        let mut layer_inputs = Vec::with_capacity(num_layers);
572        for _ in 0..num_layers {
573            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
574                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
575            })?);
576        }
577
578        // Allocate recompute buffer if checkpointing is enabled
579        let recompute_buf = if checkpointing {
580            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
581                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
582            })?)
583        } else {
584            None
585        };
586
587        if checkpointing {
588            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
589            println!(
590                "  ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
591                config.checkpoint_config.num_segments, saved_count, num_layers
592            );
593        }
594
595        // Upload final RMSNorm weight
596        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
597        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
598            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
599        })?;
600
601        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
602            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
603        })?;
604        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
605            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
606        })?;
607        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
608            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
609        })?;
610        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
611            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
612        })?;
613        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
614            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
615        })?;
616        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
617            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
618        })?;
619        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
620            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
621        })?;
622
623        // Initialize per-block optimizer states (fp32 path only; NF4 uses LoRA states)
624        let mut optimizer_states = Vec::new();
625        if !use_nf4 {
626            optimizer_states.reserve(num_layers);
627            for (i, block) in cuda_blocks.iter().enumerate() {
628                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
629                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
630                })?);
631            }
632        }
633
634        let gpu_training = GpuPretrainState {
635            layer_inputs,
636            saved_layer_mask,
637            recompute_buf,
638            final_norm_weight,
639            blocks_output,
640            grad_buf_a,
641            grad_buf_b,
642            grad_final_norm_weight,
643            norm_output,
644            logits_buf,
645            lm_head_grad_hidden,
646            optimizer_states,
647            step: 0,
648        };
649
650        // Step 6: Upload LM head weight to GPU
651        // Use tied weights (embed_tokens.weight) or separate lm_head
652        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
653        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
654        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
655            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
656        })?;
657        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
658            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
659        })?;
660        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
661        // zero memory (cuMemAlloc returns uninitialized VRAM).
662        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
663            .map_err(|e| {
664                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
665            })?;
666        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
667            .map_err(|e| {
668                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
669            })?;
670
671        // Final norm optimizer states
672        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
673            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
674        })?;
675        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
676            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
677        })?;
678
679        // KAIZEN-053: Pre-allocate forward scratch buffers (reused every step)
680        let buf_size = max_seq_len * hidden_size;
681        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
682            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
683        })?;
684        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
685            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
686        })?;
687
688        // Sync to ensure all uploads completed
689        stream
690            .synchronize()
691            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
692
693        println!(
694            "  ✓ GPU training state allocated (LM head: {:.1} MB)",
695            (vocab_size * hidden_size * 4) as f64 / 1e6
696        );
697
698        // ENT-263: Allocate NF4 infrastructure (shared scratch, LoRA grad workspace, optimizer states)
699        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
700            let lora_rank = config.lora_rank.unwrap_or(16);
701
702            // C-SCRATCH-001: Shared scratch for NF4 blocks (reused across all layers)
703            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
704                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
705            })?;
706
707            // LoRA gradient workspace (shared, reused per-block like CudaGradWorkspace)
708            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
709                crate::error::Error::ConfigError(format!(
710                    "NF4 LoRA grad workspace alloc failed: {e:?}"
711                ))
712            })?;
713
714            // Per-block LoRA optimizer states
715            let mut lora_opt_states = Vec::with_capacity(num_layers);
716            for (i, block) in cuda_blocks.iter().enumerate() {
717                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
718                    crate::error::Error::ConfigError(format!(
719                        "Block {i} LoRA opt state failed: {e:?}"
720                    ))
721                })?);
722            }
723
724            println!(
725                "  ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
726            );
727            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
728        } else {
729            (None, None, None)
730        };
731
732        // KAIZEN-050: loss_fn removed — cross-entropy computed by fused GPU kernel
733        // C-EMBED-GRAD-001: CPU optimizer must match YAML hyperparams (not defaults)
734        let embed_optimizer =
735            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
736
737        // R-038: Allocate per-block gradient accumulation buffers (CPU-side)
738        // when accumulation_steps > 1 for true gradient accumulation.
739        let grad_accum = if config.accumulation_steps > 1 {
740            let kv_hidden = mc.num_kv_heads * mc.head_dim();
741            let block_sizes =
742                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
743                    hidden_size,
744                    kv_hidden,
745                    mc.intermediate_size,
746                );
747            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
748                num_layers,
749                block_sizes,
750                vocab_size,
751                hidden_size,
752            );
753            println!(
754                "  ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
755                config.accumulation_steps,
756                (accum
757                    .block_grads
758                    .iter()
759                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
760                    .sum::<usize>()
761                    + accum.lm_head_grad.len()
762                    + accum.final_norm_grad.len()
763                    + accum.embedding_grad.len()) as f64
764                    * 4.0
765                    / 1e6,
766            );
767            Some(accum)
768        } else {
769            None
770        };
771
772        // ALB-091: GPU-resident gradient accumulation (eliminates D2H bottleneck).
773        // Falls back to CPU accum if GPU allocation fails.
774        let gpu_grad_accum = if config.accumulation_steps > 1 {
775            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
776                Ok(accum) => {
777                    println!("  ✓ GPU gradient accumulation enabled (ALB-091)");
778                    Some(accum)
779                }
780                Err(e) => {
781                    eprintln!(
782                        "  [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
783                    );
784                    None
785                }
786            }
787        } else {
788            None
789        };
790
791        // KAIZEN-059: Pre-allocate D2H staging buffer for gradient accumulation
792        // downloads. Only needed when GPU accum is unavailable (CPU fallback path).
793        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
794            let ws_max = hidden_size * mc.intermediate_size;
795            let lm_max = vocab_size * hidden_size;
796            vec![0.0f32; ws_max.max(lm_max)]
797        } else {
798            Vec::new()
799        };
800
801        // ALB-078: Pre-allocate fused gradient clipping state.
802        // Eliminates 24 stream syncs per step by keeping norm+clip on GPU.
803        let kv_hidden = mc.num_kv_heads * mc.head_dim();
804        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
805
806        // R-002: Initialize gradient scaler from precision config
807        let grad_scaler = GradScaler::from_config(&config.precision_config);
808        if config.precision_config.is_mixed() {
809            println!(
810                "  ✓ Mixed precision: {} (loss scale={}, dynamic={})",
811                config.precision_config.compute_precision,
812                grad_scaler.scale(),
813                grad_scaler.is_dynamic(),
814            );
815        }
816
817        Ok(Self {
818            model,
819            cuda_trainer,
820            cuda_blocks,
821            cuda_grad_workspace,
822            nf4_shared_scratch,
823            nf4_lora_grad_workspace,
824            nf4_lora_optimizer_states,
825            gpu_training,
826            lm_head_weight_gpu,
827            lm_head_grad_gpu,
828            lm_head_m,
829            lm_head_v,
830            final_norm_m,
831            final_norm_v,
832            embed_optimizer,
833            // KAIZEN-047: Read profile_interval before moving config into struct.
834            profiler: if config.profile_interval > 0 {
835                StepProfiler::new(true, config.profile_interval)
836            } else {
837                StepProfiler::disabled()
838            },
839            config,
840            metrics: MetricsTracker::new(),
841            step: 0,
842            accumulated_loss: 0.0,
843            accumulated_batches: 0,
844            last_grad_norm: 0.0,
845            last_embed_grad_norm: 0.0,
846            grad_accum,
847            gpu_grad_accum,
848            grad_scaler,
849            fwd_scratch_a,
850            fwd_scratch_b,
851            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
852            d2h_staging,
853            fused_clip,
854            final_norm_zero_buf: vec![0.0f32; hidden_size],
855        })
856    }
857
858    /// Upload transformer blocks to GPU (NF4 or fp32 path).
859    #[allow(clippy::too_many_arguments)]
860    fn upload_blocks(
861        model: &Transformer,
862        mc: &crate::transformer::TransformerConfig,
863        config: &TransformerTrainConfig,
864        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
865        use_nf4: bool,
866        num_layers: usize,
867        hidden_size: usize,
868        max_seq_len: usize,
869    ) -> crate::Result<Vec<CudaBlock>> {
870        let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
871
872        if use_nf4 {
873            let lora_rank = config.lora_rank.unwrap_or(16);
874            let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
875            let lora_scale = lora_alpha / lora_rank as f32;
876            let head_dim = mc.head_dim();
877            let q_dim = mc.num_attention_heads * head_dim;
878            let kv_hidden = mc.num_kv_heads * head_dim;
879
880            for (i, layer) in model.layers.iter().enumerate() {
881                let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
882                    .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
883                    .collect();
884                let lora_b_q = vec![0.0f32; lora_rank * q_dim];
885                let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
886                    .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
887                    .collect();
888                let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
889
890                let q_norm_data = layer
891                    .self_attn
892                    .q_norm
893                    .as_ref()
894                    .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
895                let k_norm_data = layer
896                    .self_attn
897                    .k_norm
898                    .as_ref()
899                    .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
900
901                let block = crate::transformer::CudaNf4TransformerBlock::new(
902                    mc,
903                    i,
904                    ctx.clone(),
905                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
906                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
907                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
908                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
909                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
910                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
911                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
912                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
913                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
914                    max_seq_len,
915                    Some((&lora_a_q, &lora_b_q)),
916                    Some((&lora_a_v, &lora_b_v)),
917                    lora_scale,
918                    lora_rank,
919                    q_norm_data.as_deref(),
920                    k_norm_data.as_deref(),
921                )
922                .map_err(|e| {
923                    crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
924                })?;
925                cuda_blocks.push(CudaBlock::Nf4(block));
926            }
927            println!("  ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
928        } else {
929            for (i, layer) in model.layers.iter().enumerate() {
930                // FALSIFY-CUDA-FORWARD-PARITY-002 thread Q/K/V biases
931                // through to the CudaTransformerBlock when present
932                // (Qwen2 family use_bias=true). Pre-fix these were
933                // silently dropped → val_loss > ln(vocab) on Qwen.
934                let b_q = layer
935                    .self_attn
936                    .b_q
937                    .as_ref()
938                    .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
939                let b_k = layer
940                    .self_attn
941                    .b_k
942                    .as_ref()
943                    .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
944                let b_v = layer
945                    .self_attn
946                    .b_v
947                    .as_ref()
948                    .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
949                let block = CudaTransformerBlock::new(
950                    mc,
951                    i,
952                    ctx.clone(),
953                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
954                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
955                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
956                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
957                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
958                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
959                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
960                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
961                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
962                    max_seq_len,
963                    b_q.as_deref(),
964                    b_k.as_deref(),
965                    b_v.as_deref(),
966                )
967                .map_err(|e| {
968                    crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
969                })?;
970                cuda_blocks.push(CudaBlock::Fp32(block));
971            }
972            println!("  ✓ {num_layers} transformer blocks uploaded to GPU");
973        }
974
975        Ok(cuda_blocks)
976    }
977
978    /// ALB-078: Initialize fused gradient clipping state (extracted for complexity).
979    fn init_fused_clip(
980        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981        config: &TransformerTrainConfig,
982        hidden_size: usize,
983        kv_hidden: usize,
984        mc: &crate::transformer::TransformerConfig,
985    ) -> Option<FusedClipState> {
986        config.base.max_grad_norm?;
987        let grad_sizes: [u32; 9] = [
988            (hidden_size * hidden_size) as u32,
989            (hidden_size * kv_hidden) as u32,
990            (hidden_size * kv_hidden) as u32,
991            (hidden_size * hidden_size) as u32,
992            (hidden_size * mc.intermediate_size) as u32,
993            (hidden_size * mc.intermediate_size) as u32,
994            (mc.intermediate_size * hidden_size) as u32,
995            hidden_size as u32,
996            hidden_size as u32,
997        ];
998        match FusedClipState::new(ctx, &grad_sizes) {
999            Ok(state) => {
1000                println!(
1001                    "  ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002                    state.total_partials,
1003                    f64::from(state.total_partials) * 4.0 / 1024.0,
1004                );
1005                Some(state)
1006            }
1007            Err(e) => {
1008                println!("  ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009                None
1010            }
1011        }
1012    }
1013
1014    /// Run one forward+backward step for a single sequence.
1015    ///
1016    /// # Contract (C-GPUSTEP-001)
1017    ///
1018    /// - Precondition: `input_ids.len() == target_ids.len() <= max_seq_len`
1019    /// - Postcondition: If `accumulate_only` is false, all GPU weights updated.
1020    ///   If true, gradients accumulated into CPU buffers (no weight updates).
1021    /// - Transfer count: 1 PCIe H2D + ~1KB control (KAIZEN-050, + 24×9 D2H if accumulating)
1022    fn train_step_single(
1023        &mut self,
1024        input_ids: &[u32],
1025        target_ids: &[u32],
1026        accumulate_only: bool,
1027    ) -> Option<f32> {
1028        self.profiler.begin_step();
1029        let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030        self.profiler.finish_step();
1031        result
1032    }
1033
1034    /// Inner training step — separated so profiler always records the step.
1035    fn train_step_inner(
1036        &mut self,
1037        input_ids: &[u32],
1038        target_ids: &[u32],
1039        accumulate_only: bool,
1040    ) -> Option<f32> {
1041        let hidden_size = self.config.model_config.hidden_size;
1042        let vocab_size = self.config.model_config.vocab_size;
1043
1044        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
1045        let max_sl = self.config.max_seq_len;
1046        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048        let seq_len = input_ids.len();
1049
1050        // Steps 1-6: GPU forward pass — logits stay GPU-resident (KAIZEN-050)
1051        // (sub-phases embed, h2d, forward, norm_lm instrumented inside gpu_forward)
1052        if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053            eprintln!(
1054                "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055                 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056            );
1057            return None;
1058        }
1059
1060        // Step 7: Fused GPU cross-entropy loss + softmax backward (KAIZEN-050)
1061        // Eliminates: logits D2H (77.8MB) + CPU softmax (40ms) + grad H2D (77.8MB)
1062        self.profiler.begin(StepProfiler::LOSS);
1063        let stream = self.cuda_trainer.stream();
1064
1065        // Compute combined scale: (1/seq_len) * (1/accum_steps)
1066        //
1067        // ALB-072: Do NOT multiply by grad_scaler.scale() here. All backward
1068        // computation uses f32 GpuBuffers — there is no fp16 gradient underflow
1069        // risk. The 65536x loss scaling caused gradient overflow in early layers
1070        // (blocks 0-1 went NaN). The GradScaler remains active for the CPU
1071        // embedding path (unscale_and_check in optimizer_step) as a safety check,
1072        // but it operates with scale=1.0 effective for GPU gradients.
1073        let mut loss_scale = 1.0 / seq_len as f32;
1074        if self.config.accumulation_steps > 1 {
1075            loss_scale /= self.config.accumulation_steps as f32;
1076        }
1077
1078        // KAIZEN-052: In-place — gradient written directly to logits_buf.
1079        let loss_val = fused_cross_entropy_cuda(
1080            &mut self.gpu_training.logits_buf,
1081            target_ids,
1082            seq_len as u32,
1083            vocab_size as u32,
1084            loss_scale,
1085            stream,
1086        )
1087        .ok()?;
1088
1089        // NaN guard (replaces logits NaN check — NaN logits → NaN loss via kernel)
1090        if !loss_val.is_finite() {
1091            return None;
1092        }
1093        self.profiler.end(StepProfiler::LOSS);
1094
1095        // Steps 8-11: GPU backward pass (with or without optimizer)
1096        // (sub-phases lm_bwd, norm_bwd, blk_bwd instrumented inside gpu_backward)
1097        // KAIZEN-050: grad_logits on GPU. KAIZEN-052: grad lives in logits_buf (in-place).
1098        //
1099        // ENT-263 fix: Capture loss regardless of backward success. The NF4 backward
1100        // path may fail (e.g., gemm_nf4_backward_a stub) but the loss was already
1101        // computed by fused_cross_entropy_cuda. Dropping the loss silently causes
1102        // loss=0.0 reporting despite valid forward passes.
1103        if let Some(grad_output_is_a) =
1104            self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105        {
1106            // Step 12: Embedding backward (CPU scatter-add always accumulates)
1107            self.profiler.begin(StepProfiler::EMBED_BWD);
1108            self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110            self.profiler.end(StepProfiler::EMBED_BWD);
1111        }
1112
1113        Some(loss_val)
1114    }
1115
1116    /// GPU forward pass: embed → blocks → norm → LM head.
1117    ///
1118    /// Logits stay GPU-resident in `self.gpu_training.logits_buf` (KAIZEN-050).
1119    /// Transfers: 1 H2D (hidden states). No D2H — logits consumed by fused kernel.
1120    #[allow(unsafe_code)]
1121    fn gpu_forward(
1122        &mut self,
1123        input_ids: &[u32],
1124        seq_len: usize,
1125        hidden_size: usize,
1126        vocab_size: usize,
1127    ) -> Option<()> {
1128        contract_pre_gpu_forward!();
1129        let stream = self.cuda_trainer.stream();
1130
1131        // Embedding lookup (CPU)
1132        self.profiler.begin(StepProfiler::EMBED);
1133        let hidden = self.model.embed_tokens.forward(input_ids);
1134        let hidden_slice = hidden.data().as_slice()?;
1135        self.profiler.end(StepProfiler::EMBED);
1136
1137        // Upload hidden states to GPU (Transfer 1: H2D)
1138        // Pad to max_seq_len so D2D copies to pre-allocated layer_inputs match.
1139        // KAIZEN-053: Reuse pre-allocated scratch buffers instead of cuMemAlloc per step.
1140        // KAIZEN-056: Reuse pre-allocated h2d_staging instead of alloc per step.
1141        self.profiler.begin(StepProfiler::H2D);
1142        self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
1143        self.h2d_staging[hidden_slice.len()..].fill(0.0);
1144        if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
1145            eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
1146            return None;
1147        }
1148        self.profiler.end(StepProfiler::H2D);
1149
1150        // Forward through CUDA blocks using pre-allocated ping-pong buffers.
1151        // KAIZEN-053: fwd_scratch_a/b are top-level fields (not in gpu_training)
1152        // so borrowing them doesn't conflict with gpu_training.layer_inputs.
1153        self.profiler.begin(StepProfiler::FORWARD);
1154        let mut input_is_a = true; // Track which scratch buffer is "input"
1155        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
1156            // Use raw pointers for the ping-pong to avoid borrow conflicts
1157            // with self.gpu_training.layer_inputs
1158            let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
1159                if input_is_a {
1160                    (
1161                        std::ptr::from_ref(&self.fwd_scratch_a),
1162                        std::ptr::from_mut(&mut self.fwd_scratch_b),
1163                    )
1164                } else {
1165                    (
1166                        std::ptr::from_ref(&self.fwd_scratch_b),
1167                        std::ptr::from_mut(&mut self.fwd_scratch_a),
1168                    )
1169                };
1170            if self.gpu_training.saved_layer_mask[i] {
1171                // SAFETY: Both buffers are valid GPU allocations with matching max_seq_len size.
1172                // Copy completes before block.forward() reads from input (same stream ordering).
1173                unsafe {
1174                    self.gpu_training.layer_inputs[i]
1175                        .copy_from_buffer_async(&*input_ptr, stream)
1176                        .ok()?;
1177                }
1178            }
1179            // SAFETY: input_ptr and output_ptr point to disjoint fwd_scratch_{a,b}.
1180            // ENT-263: Pass shared scratch for NF4 blocks (C-SCRATCH-001).
1181            self.profiler.begin_layer();
1182            unsafe {
1183                block
1184                    .forward(
1185                        &*input_ptr,
1186                        &mut *output_ptr,
1187                        seq_len,
1188                        stream,
1189                        self.nf4_shared_scratch.as_mut(),
1190                    )
1191                    .ok()?;
1192            }
1193            self.profiler.end_layer_fwd(i);
1194            input_is_a = !input_is_a;
1195        }
1196        self.profiler.end(StepProfiler::FORWARD);
1197
1198        // After the loop, input_is_a tells us which buffer has the final output
1199        let final_output: &GpuBuffer<f32> =
1200            if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };
1201
1202        // Save blocks output for final norm backward
1203        // SAFETY: Disjoint GPU buffers with matching max_seq_len sizes.
1204        self.profiler.begin(StepProfiler::NORM_LM);
1205        unsafe {
1206            self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
1207        }
1208
1209        // Final RMSNorm forward (GPU)
1210        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: thread `config.rms_norm_eps`
1211        // through so Qwen2 (1e-6) gets the right epsilon. Pre-fix the
1212        // legacy `rms_norm_forward` hardcoded 1e-5 (Llama default).
1213        rms_norm_forward_with_eps(
1214            final_output,
1215            &self.gpu_training.final_norm_weight,
1216            &mut self.gpu_training.norm_output,
1217            seq_len as u32,
1218            hidden_size as u32,
1219            self.config.model_config.rms_norm_eps,
1220            stream,
1221        )
1222        .ok()?;
1223
1224        // LM head GEMM forward (GPU)
1225        // gemm_forward treats flat (V,H) memory as (H,V) row-major, which
1226        // implicitly transposes — matching the CPU matmul's tied-weight behavior.
1227        gemm_forward(
1228            &self.gpu_training.norm_output,
1229            &self.lm_head_weight_gpu,
1230            &mut self.gpu_training.logits_buf,
1231            seq_len as u32,
1232            hidden_size as u32,
1233            vocab_size as u32,
1234            stream,
1235        )
1236        .ok()?;
1237
1238        // KAIZEN-050: Logits stay GPU-resident — no D2H transfer.
1239        // Fused cross-entropy kernel reads logits_buf directly on GPU.
1240        self.profiler.end(StepProfiler::NORM_LM);
1241
1242        Some(())
1243    }
1244
1245    /// ALB-089: Forward-only pass that returns last-position logits on CPU.
1246    ///
1247    /// Runs the same GPU forward as training but downloads only the last
1248    /// position's logits (vocab_size floats) for token sampling. No backward
1249    /// pass, no loss computation.
1250    ///
1251    /// # Contract (C-CUDA-INF-001)
1252    ///
1253    /// - Same forward path as `gpu_forward()` — identical logits
1254    /// - Only downloads `logits[seq_len-1, :]` (128 KB for 32K vocab)
1255    /// - stream.synchronize() before D2H (C-STREAMSYNC-001)
1256    pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1257        let seq_len = input_ids.len();
1258        let hidden_size = self.config.model_config.hidden_size;
1259        let vocab_size = self.config.model_config.vocab_size;
1260
1261        if seq_len == 0 || seq_len > self.config.max_seq_len {
1262            return None;
1263        }
1264
1265        // Reuse gpu_forward for the actual computation
1266        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1267
1268        // C-STREAMSYNC-001: synchronize before D2H
1269        let stream = self.cuda_trainer.stream();
1270        stream.synchronize().ok()?;
1271
1272        // Download last position logits only: logits_buf[seq_len-1, :]
1273        let offset = (seq_len - 1) * vocab_size;
1274        let mut logits = vec![0.0f32; vocab_size];
1275        self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1276
1277        Some(logits)
1278    }
1279
1280    /// GPU backward pass with interleaved per-block optimizer step.
1281    ///
1282    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
1283    /// Recompute layer inputs for a segment during backward (activation checkpointing).
1284    ///
1285    /// When checkpointing is enabled, non-checkpoint layers don't save their inputs
1286    /// during forward. Before their backward pass, we recompute from the nearest
1287    /// checkpoint by re-running forward through intermediate blocks.
1288    ///
1289    /// This recomputes the entire segment [checkpoint..=target_layer], storing
1290    /// intermediate layer_inputs so subsequent layers in the same segment don't
1291    /// need redundant recomputation.
1292    ///
1293    /// # Contract (R-021)
1294    ///
1295    /// After this call, `layer_inputs[i]` is valid for all i in [checkpoint..=target_layer].
1296    #[allow(unsafe_code)]
1297    fn recompute_segment(
1298        gpu_training: &mut GpuPretrainState,
1299        cuda_blocks: &mut [CudaBlock],
1300        nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1301        target_layer: usize,
1302        seq_len: usize,
1303        stream: &CudaStream,
1304    ) -> Option<()> {
1305        // Find nearest saved checkpoint at or before target
1306        let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1307
1308        if seg_start == target_layer {
1309            return Some(()); // Already saved
1310        }
1311
1312        // Copy checkpoint input to recompute_buf as starting point.
1313        // SAFETY: recompute_buf and layer_inputs are disjoint allocations.
1314        let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1315        unsafe {
1316            recompute_buf
1317                .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1318                .ok()?;
1319        }
1320
1321        // Forward through blocks [seg_start..target_layer], saving intermediate inputs.
1322        // For block i, input → block i → output becomes input for block i+1.
1323        // We save output (= input to block i+1) in layer_inputs[i+1].
1324        //
1325        // Buffer pattern:
1326        //   i == seg_start: input = recompute_buf, output = layer_inputs[seg_start+1]
1327        //   i > seg_start:  input = layer_inputs[i], output = layer_inputs[i+1]
1328        //
1329        // SAFETY: split_at_mut ensures non-overlapping borrows of layer_inputs.
1330        // recompute_buf is separate from layer_inputs.
1331        for i in seg_start..target_layer {
1332            if i == seg_start {
1333                // Input is in recompute_buf, output goes to layer_inputs[i+1]
1334                let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1335                let li = &mut gpu_training.layer_inputs;
1336                unsafe {
1337                    cuda_blocks[i]
1338                        .forward(
1339                            &*recompute_ptr,
1340                            &mut li[i + 1],
1341                            seq_len,
1342                            stream,
1343                            nf4_shared_scratch.as_mut(),
1344                        )
1345                        .ok()?;
1346                }
1347            } else {
1348                // Input = layer_inputs[i], output = layer_inputs[i+1]
1349                let li = &mut gpu_training.layer_inputs;
1350                let (left, right) = li.split_at_mut(i + 1);
1351                cuda_blocks[i]
1352                    .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1353                    .ok()?;
1354            }
1355        }
1356
1357        Some(())
1358    }
1359
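/// GPU backward pass with interleaved per-block optimizer step.
///
/// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
/// Since `gemm_backward_b` overwrites (not accumulates), we must run each block's
/// optimizer step immediately after its backward, before the next block overwrites
/// the workspace. This also enables per-block gradient clipping.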
1363    ///
1364    /// When `accumulate_only` is true (R-038 gradient accumulation), the per-block
1365    /// optimizer steps are skipped and workspace gradients are downloaded to CPU-side
1366    /// `PerBlockGradientAccumulator` instead. LM head and final norm gradients are
1367    /// also downloaded and accumulated. The optimizer step is deferred until
1368    /// `gpu_optimizer_from_accum()` is called.
1369    ///
1370    /// Returns `grad_output_is_a` flag for embedding backward.
1371    /// Transfer: 0 H2D (KAIZEN-050/052: grad in logits_buf) + 24×9 D2H if accumulating.
1372    #[allow(unsafe_code)]
1373    fn gpu_backward(
1374        &mut self,
1375        seq_len: usize,
1376        hidden_size: usize,
1377        vocab_size: usize,
1378        accumulate_only: bool,
1379    ) -> Option<bool> {
1380        let stream = self.cuda_trainer.stream();
1381        let max_grad_norm = self.config.base.max_grad_norm;
1382        let lr = self.current_lr();
1383        // ALB-072: No inv_scale needed — loss_scale no longer includes grad_scaler.
1384        let beta1 = self.config.beta1;
1385        let beta2 = self.config.beta2;
1386        let weight_decay = self.config.weight_decay;
1387
1388        // KAIZEN-050: grad_logits GPU-resident. KAIZEN-052: grad lives in logits_buf (in-place).
1389        // No separate grad buffer. No GRAD_H2D transfer.
1390
1391        // LM head GEMM backward
1392        self.profiler.begin(StepProfiler::LM_BWD);
1393        gemm_backward_a(
1394            &self.gpu_training.logits_buf,
1395            &self.lm_head_weight_gpu,
1396            &mut self.gpu_training.lm_head_grad_hidden,
1397            seq_len as u32,
1398            hidden_size as u32,
1399            vocab_size as u32,
1400            stream,
1401        )
1402        .ok()?;
1403
1404        gemm_backward_b(
1405            &self.gpu_training.norm_output,
1406            &self.gpu_training.logits_buf,
1407            &mut self.lm_head_grad_gpu,
1408            seq_len as u32,
1409            hidden_size as u32,
1410            vocab_size as u32,
1411            stream,
1412        )
1413        .ok()?;
1414
1415        // Clip LM head weight gradient
1416        // KAIZEN-049: GPU norm reduction.
1417        // KAIZEN-051: No explicit sync needed — same stream ordering.
1418        // ALB-071: Always compute LM head grad norm for observability (R-004).
1419        // C-CLIP-001: squared_sum_cuda returns ||g||². Take sqrt for L2 norm (entrenar#311).
1420        let lm_sq_norm =
1421            squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1422                .unwrap_or(0.0);
1423        let lm_norm = lm_sq_norm.sqrt(); // L2 norm, NOT squared
1424        self.last_grad_norm = lm_norm; // R-004: capture for observability
1425                                       // C-BACKPARITY-001: LM head gradient norm tracing (pre-clip).
1426        if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1427            eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1428            // Also trace the grad_hidden flowing to blocks
1429            let gh_sq = squared_sum_cuda(
1430                &self.gpu_training.lm_head_grad_hidden,
1431                self.gpu_training.lm_head_grad_hidden.len() as u32,
1432                stream,
1433            )
1434            .unwrap_or(0.0);
1435            eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1436        }
1437        if let Some(max_norm) = max_grad_norm {
1438            let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1439            let n = self.lm_head_grad_gpu.len() as u32;
1440            let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1441        }
1442        self.profiler.end(StepProfiler::LM_BWD);
1443
1444        // Final RMSNorm backward
1445        self.profiler.begin(StepProfiler::NORM_BWD);
1446        // Zero grad_final_norm_weight before backward — kernel accumulates via atomicAdd
1447        self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1448        rms_norm_backward(
1449            &self.gpu_training.blocks_output,
1450            &self.gpu_training.final_norm_weight,
1451            &self.gpu_training.lm_head_grad_hidden,
1452            &mut self.gpu_training.grad_buf_a,
1453            &mut self.gpu_training.grad_final_norm_weight,
1454            seq_len as u32,
1455            hidden_size as u32,
1456            1e-5_f32,
1457            stream,
1458        )
1459        .ok()?;
1460
1461        // Clip final norm weight gradient
1462        // KAIZEN-051: No explicit sync needed — same stream ordering as LM head clip.
1463        if let Some(max_norm) = max_grad_norm {
1464            let (scale, _) = Self::compute_clip_scale_with_norm(
1465                &self.gpu_training.grad_final_norm_weight,
1466                max_norm,
1467                stream,
1468            );
1469            let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1470            let _ =
1471                gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1472        }
1473        self.profiler.end(StepProfiler::NORM_BWD);
1474
1475        // R-038: Either accumulate non-block grads or run non-block optimizer.
1476        if accumulate_only {
1477            // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1478            if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1479                let _ = gpu_accum.accumulate_nonblock(
1480                    &self.lm_head_grad_gpu,
1481                    &self.gpu_training.grad_final_norm_weight,
1482                    stream,
1483                );
1484            } else {
1485                stream.synchronize().ok()?;
1486                Self::download_nonblock_grads_to_accum(
1487                    &self.lm_head_grad_gpu,
1488                    &self.gpu_training.grad_final_norm_weight,
1489                    &mut self.grad_accum,
1490                    &mut self.d2h_staging,
1491                )?;
1492            }
1493        } else {
1494            Self::run_nonblock_optimizer_step(
1495                &mut self.gpu_training,
1496                Some(&mut self.lm_head_weight_gpu),
1497                &self.lm_head_grad_gpu,
1498                &mut self.lm_head_m,
1499                &mut self.lm_head_v,
1500                &mut self.final_norm_m,
1501                &mut self.final_norm_v,
1502                lr,
1503                beta1,
1504                beta2,
1505                weight_decay,
1506                stream,
1507            );
1508        }
1509
1510        // Backward through blocks in reverse, with interleaved clip + optimizer.
1511        // Each block's backward writes weight gradients to shared CudaGradWorkspace.
1512        //
1513        // SAFETY: grad_buf_a and grad_buf_b are disjoint fields. Raw pointers
1514        // allow alternating read/write without violating aliasing rules.
1515        self.profiler.begin(StepProfiler::BLK_BWD);
1516        let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1517        let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1518        let mut grad_output_is_a = true;
1519        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1520
1521        for layer_idx in (0..self.cuda_blocks.len()).rev() {
1522            // Activation checkpointing: if this layer's input wasn't saved during
1523            // forward, recompute the segment from the nearest checkpoint.
1524            if !self.gpu_training.saved_layer_mask[layer_idx] {
1525                Self::recompute_segment(
1526                    &mut self.gpu_training,
1527                    &mut self.cuda_blocks,
1528                    &mut self.nf4_shared_scratch,
1529                    layer_idx,
1530                    seq_len,
1531                    stream,
1532                )?;
1533            }
1534
1535            let (grad_output, grad_input) = unsafe {
1536                if grad_output_is_a {
1537                    (&*grad_a_ptr, &mut *grad_b_ptr)
1538                } else {
1539                    (&*grad_b_ptr, &mut *grad_a_ptr)
1540                }
1541            };
1542
1543            self.profiler.begin_layer();
1544            if use_nf4 {
1545                // ENT-263: NF4 backward — LoRA gradient computation
1546                // Uses backward_nf4() which computes gradients for LoRA weights and norms only.
1547                // output_scratch reuses grad_buf_a/b as temporary storage for recomputed forward.
1548                let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1549                    grad_b_ptr // grad_input is in b, use as output_scratch too (will be overwritten)
1550                } else {
1551                    grad_a_ptr
1552                };
1553                // We need a separate output_scratch. Reuse blocks_output as scratch since
1554                // it was already consumed for norm backward above.
1555                match self.cuda_blocks[layer_idx].backward_nf4(
1556                    &self.gpu_training.layer_inputs[layer_idx],
1557                    grad_output,
1558                    grad_input,
1559                    &mut self.gpu_training.blocks_output, // reuse as output_scratch
1560                    seq_len,
1561                    stream,
1562                    self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1563                    self.nf4_lora_grad_workspace
1564                        .as_mut()
1565                        .expect("NF4 requires LoRA grad workspace"),
1566                ) {
1567                    Ok(()) => {}
1568                    Err(e) => {
1569                        eprintln!(
1570                            "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1571                            layer_idx, e, seq_len, self.config.model_config.hidden_size
1572                        );
1573                        return None;
1574                    }
1575                }
1576
1577                // ENT-265: Clip LoRA gradients before optimizer step.
1578                // Without this, NF4 LoRA grads are unbounded — causes weight
1579                // divergence and embedding grad explosion (Run 7c: 26M at step 225).
1580                if let Some(max_norm) = max_grad_norm {
1581                    self.nf4_lora_grad_workspace
1582                        .as_mut()
1583                        .expect("NF4 requires LoRA grad ws")
1584                        .clip_gradients(max_norm, stream);
1585                }
1586
1587                // NF4 LoRA optimizer step — always runs, even during accumulation.
1588                //
1589                // BUG FIX (entrenar#264): Previously gated by `if !accumulate_only`.
1590                // Design: NF4 LoRA has ~6M params, so we scale lr by 1/accum_steps
1591                // for micro-batches instead of accumulating gradients.
1592                {
1593                    let step = self.gpu_training.step;
1594                    let effective_lr = if accumulate_only {
1595                        lr / self.config.accumulation_steps as f32
1596                    } else {
1597                        lr
1598                    };
1599                    if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1600                        let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1601                            &mut opt_states[layer_idx],
1602                            step,
1603                            effective_lr,
1604                            beta1,
1605                            beta2,
1606                            1e-8,
1607                            weight_decay,
1608                            stream,
1609                            self.nf4_lora_grad_workspace
1610                                .as_ref()
1611                                .expect("NF4 requires LoRA grad ws"),
1612                        );
1613                    }
1614                }
1615            } else {
1616                // Standard fp32 backward path
1617                self.cuda_blocks[layer_idx]
1618                    .backward(
1619                        &self.gpu_training.layer_inputs[layer_idx],
1620                        grad_output,
1621                        grad_input,
1622                        seq_len,
1623                        stream,
1624                        &mut self.cuda_grad_workspace,
1625                    )
1626                    .ok()?;
1627
1628                // C-CLIP-001 / entrenar#312: DISABLED per-block gradient clipping.
1629                // Per-block clipping distorts gradient flow across layers.
1630
1631                // C-BACKPARITY-001: Per-block gradient norm tracing for parity testing.
1632                // Only runs when ENTRENAR_TRACE_GRADIENTS=1 — zero overhead in production.
1633                if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1634                    let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1635                        &self.cuda_grad_workspace,
1636                        f32::MAX,
1637                        stream,
1638                    );
1639                    // Also trace the activation gradient (flows between blocks)
1640                    let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1641                        .unwrap_or(0.0);
1642                    let act_gnorm = act_sq.sqrt();
1643                    eprintln!(
1644                        "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1645                    );
1646                }
1647
1648                // R-038: Either accumulate workspace grads or run optimizer per-block.
1649                if accumulate_only {
1650                    // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1651                    if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1652                        let _ = gpu_accum.accumulate_block(
1653                            &self.cuda_grad_workspace,
1654                            layer_idx,
1655                            stream,
1656                        );
1657                    } else {
1658                        // CPU fallback: SYNC + D2H (ALB-065 / Rule 6).
1659                        stream.synchronize().ok()?;
1660                        if let Some(accum) = &mut self.grad_accum {
1661                            Self::download_workspace_to_accum(
1662                                &self.cuda_grad_workspace,
1663                                accum,
1664                                layer_idx,
1665                                &mut self.d2h_staging,
1666                            )?;
1667                        }
1668                    }
1669                } else {
1670                    // Per-block optimizer step: consume workspace gradients before next block overwrites
1671                    let step = self.gpu_training.step;
1672                    let _ = self.cuda_blocks[layer_idx].optimizer_step(
1673                        &mut self.gpu_training.optimizer_states[layer_idx],
1674                        step,
1675                        lr,
1676                        beta1,
1677                        beta2,
1678                        1e-8,
1679                        weight_decay,
1680                        stream,
1681                        &self.cuda_grad_workspace,
1682                    );
1683                }
1684            }
1685
1686            self.profiler.end_layer_bwd(layer_idx);
1687            grad_output_is_a = !grad_output_is_a;
1688        }
1689
1690        stream.synchronize().ok()?;
1691        self.profiler.end(StepProfiler::BLK_BWD);
1692
1693        Some(grad_output_is_a)
1694    }
1695
1696    /// R-038: Download non-block (LM head + final norm) gradients to CPU accumulator.
1697    /// Static method to avoid borrow conflicts.
1698    // KAIZEN-044: Pre-allocate single buffer for LM head + norm D2H downloads.
1699    // lm_head_grad is vocab×hidden (389M elements = 1.5 GB for Qwen3-4B).
1700    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1701    fn download_nonblock_grads_to_accum(
1702        lm_head_grad: &GpuBuffer<f32>,
1703        final_norm_grad: &GpuBuffer<f32>,
1704        grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1705        host: &mut [f32],
1706    ) -> Option<()> {
1707        let accum = grad_accum.as_mut()?;
1708
1709        let lm_slice = &mut host[..lm_head_grad.len()];
1710        lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1711        for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1712            *d += s;
1713        }
1714
1715        let norm_slice = &mut host[..final_norm_grad.len()];
1716        final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1717        for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1718            *d += s;
1719        }
1720        Some(())
1721    }
1722
1723    /// Run LM head + final norm optimizer step (non-accumulating path).
1724    /// Static method to avoid borrow conflicts with `stream`.
1725    #[allow(clippy::too_many_arguments)]
1726    fn run_nonblock_optimizer_step(
1727        gpu_training: &mut GpuPretrainState,
1728        lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1729        lm_head_grad_gpu: &GpuBuffer<f32>,
1730        lm_head_m: &mut GpuBuffer<f32>,
1731        lm_head_v: &mut GpuBuffer<f32>,
1732        final_norm_m: &mut GpuBuffer<f32>,
1733        final_norm_v: &mut GpuBuffer<f32>,
1734        lr: f32,
1735        beta1: f32,
1736        beta2: f32,
1737        weight_decay: f32,
1738        stream: &CudaStream,
1739    ) {
1740        gpu_training.step += 1;
1741        let step = gpu_training.step;
1742
1743        if let Some(lm_head_weight) = lm_head_weight_gpu {
1744            let n_lm = lm_head_weight.len() as u32;
1745            let _ = adamw_step_cuda(
1746                lm_head_weight,
1747                lm_head_grad_gpu,
1748                lm_head_m,
1749                lm_head_v,
1750                lr,
1751                beta1,
1752                beta2,
1753                1e-8,
1754                weight_decay,
1755                step,
1756                n_lm,
1757                stream,
1758            );
1759        }
1760
1761        let n_norm = gpu_training.final_norm_weight.len() as u32;
1762        let _ = adamw_step_cuda(
1763            &mut gpu_training.final_norm_weight,
1764            &gpu_training.grad_final_norm_weight,
1765            final_norm_m,
1766            final_norm_v,
1767            lr,
1768            beta1,
1769            beta2,
1770            1e-8,
1771            weight_decay,
1772            step,
1773            n_norm,
1774            stream,
1775        );
1776    }
1777
1778    /// R-038: Download shared CudaGradWorkspace to CPU per-block accumulation buffers.
1779    ///
1780    /// Static method to avoid borrow conflicts with `stream` (same pattern as
1781    /// `recompute_segment`). Must be called after stream.synchronize() (ALB-065 / Rule 6).
1782    // KAIZEN-044: Pre-allocate a single host buffer for all D2H downloads
1783    // in download_workspace_to_accum. Was allocating vec![0.0f32; len] × 9 buffers.
1784    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1785    fn download_workspace_to_accum(
1786        ws: &CudaGradWorkspace,
1787        accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1788        layer_idx: usize,
1789        host: &mut [f32],
1790    ) -> Option<()> {
1791        let bg = &mut accum.block_grads[layer_idx];
1792
1793        use super::grad_accumulator::component;
1794        let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1795            (&ws.grad_w_q, component::W_Q),
1796            (&ws.grad_w_k, component::W_K),
1797            (&ws.grad_w_v, component::W_V),
1798            (&ws.grad_w_o, component::W_O),
1799            (&ws.grad_gate, component::GATE),
1800            (&ws.grad_up, component::UP),
1801            (&ws.grad_down, component::DOWN),
1802            (&ws.grad_input_norm, component::INPUT_NORM),
1803            (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1804        ];
1805
1806        for (gpu_buf, comp_idx) in &bufs_and_components {
1807            let slice = &mut host[..gpu_buf.len()];
1808            gpu_buf.copy_to_host_at(slice, 0).ok()?;
1809            for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1810                *d += s;
1811            }
1812        }
1813        Some(())
1814    }
1815
1816    /// R-038: Upload averaged CPU accumulation buffers to GPU workspace and run
1817    /// optimizer step for all blocks + LM head + final norm.
1818    ///
1819    /// Called once after `accumulation_steps` micro-batches have been accumulated.
1820    /// ALB-091: Run optimizer step from GPU-resident accumulated gradients.
1821    /// D2D copy accum → workspace, then run per-block optimizer. Zero accum after.
1822    fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1823        let stream = self.cuda_trainer.stream();
1824        let lr = self.current_lr();
1825        let beta1 = self.config.beta1;
1826        let beta2 = self.config.beta2;
1827        let weight_decay = self.config.weight_decay;
1828
1829        // Sync once to ensure all accumulation kernels complete
1830        stream.synchronize().ok()?;
1831
1832        self.gpu_training.step += 1;
1833        let step = self.gpu_training.step;
1834
1835        // Upload GPU accum → workspace (D2D) and run optimizer per block
1836        let gpu_accum = self.gpu_grad_accum.as_ref()?;
1837        for layer_idx in 0..self.cuda_blocks.len() {
1838            gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1839
1840            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1841                &mut self.gpu_training.optimizer_states[layer_idx],
1842                step,
1843                lr,
1844                beta1,
1845                beta2,
1846                1e-8,
1847                weight_decay,
1848                stream,
1849                &self.cuda_grad_workspace,
1850            );
1851        }
1852
1853        // LM head: D2D copy accum → grad buffer, then optimizer step
1854        gpu_accum
1855            .upload_nonblock(
1856                &mut self.lm_head_grad_gpu,
1857                &mut self.gpu_training.grad_final_norm_weight,
1858            )
1859            .ok()?;
1860
1861        let n_lm = self.lm_head_weight_gpu.len() as u32;
1862        let _ = adamw_step_cuda(
1863            &mut self.lm_head_weight_gpu,
1864            &self.lm_head_grad_gpu,
1865            &mut self.lm_head_m,
1866            &mut self.lm_head_v,
1867            lr,
1868            beta1,
1869            beta2,
1870            1e-8,
1871            weight_decay,
1872            step,
1873            n_lm,
1874            stream,
1875        );
1876
1877        // Final norm optimizer step
1878        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1879        let _ = adamw_step_cuda(
1880            &mut self.gpu_training.final_norm_weight,
1881            &self.gpu_training.grad_final_norm_weight,
1882            &mut self.final_norm_m,
1883            &mut self.final_norm_v,
1884            lr,
1885            beta1,
1886            beta2,
1887            1e-8,
1888            weight_decay,
1889            step,
1890            n_norm,
1891            stream,
1892        );
1893
1894        stream.synchronize().ok()?;
1895
1896        // Zero accum for next window
1897        if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1898            let _ = gpu_accum.zero_all();
1899        }
1900
1901        Some(())
1902    }
1903
1904    #[allow(unsafe_code)]
1905    fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1906        let stream = self.cuda_trainer.stream();
1907        let lr = self.current_lr();
1908        let beta1 = self.config.beta1;
1909        let beta2 = self.config.beta2;
1910        let weight_decay = self.config.weight_decay;
1911
1912        // Average accumulated gradients
1913        let accum = self.grad_accum.as_mut()?;
1914        accum.average();
1915
1916        // Jidoka: check for NaN/Inf before applying
1917        if accum.has_non_finite() {
1918            println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1919            accum.zero_all();
1920            return Some(());
1921        }
1922
1923        self.gpu_training.step += 1;
1924        let step = self.gpu_training.step;
1925
1926        // Upload accumulated gradients and run optimizer for each block
1927        use super::grad_accumulator::component;
1928        for layer_idx in 0..self.cuda_blocks.len() {
1929            let bg = &accum.block_grads[layer_idx];
1930
1931            // Upload accumulated gradients to shared workspace
1932            // SAFETY: async host-to-device copies within the training stream; host buffers
1933            // (bg.components) are stable for the duration of the stream operations.
1934            unsafe {
1935                self.cuda_grad_workspace
1936                    .grad_w_q
1937                    .copy_from_host_async(&bg.components[component::W_Q], stream)
1938                    .ok()?;
1939                self.cuda_grad_workspace
1940                    .grad_w_k
1941                    .copy_from_host_async(&bg.components[component::W_K], stream)
1942                    .ok()?;
1943                self.cuda_grad_workspace
1944                    .grad_w_v
1945                    .copy_from_host_async(&bg.components[component::W_V], stream)
1946                    .ok()?;
1947                self.cuda_grad_workspace
1948                    .grad_w_o
1949                    .copy_from_host_async(&bg.components[component::W_O], stream)
1950                    .ok()?;
1951                self.cuda_grad_workspace
1952                    .grad_gate
1953                    .copy_from_host_async(&bg.components[component::GATE], stream)
1954                    .ok()?;
1955                self.cuda_grad_workspace
1956                    .grad_up
1957                    .copy_from_host_async(&bg.components[component::UP], stream)
1958                    .ok()?;
1959                self.cuda_grad_workspace
1960                    .grad_down
1961                    .copy_from_host_async(&bg.components[component::DOWN], stream)
1962                    .ok()?;
1963                self.cuda_grad_workspace
1964                    .grad_input_norm
1965                    .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
1966                    .ok()?;
1967                self.cuda_grad_workspace
1968                    .grad_post_attn_norm
1969                    .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
1970                    .ok()?;
1971            }
1972
1973            // Run optimizer step with uploaded averaged gradients
1974            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1975                &mut self.gpu_training.optimizer_states[layer_idx],
1976                step,
1977                lr,
1978                beta1,
1979                beta2,
1980                1e-8,
1981                weight_decay,
1982                stream,
1983                &self.cuda_grad_workspace,
1984            );
1985        }
1986
1987        // Upload accumulated LM head gradients and run AdamW step
1988        // entrenar#314: Skip GPU LM head optimizer for tied weights.
1989        // SAFETY: async host-to-device copy; host buffer (accum.lm_head_grad) is stable.
1990        unsafe {
1991            self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
1992        }
1993        let n_lm = self.lm_head_weight_gpu.len() as u32;
1994        let _ = adamw_step_cuda(
1995            &mut self.lm_head_weight_gpu,
1996            &self.lm_head_grad_gpu,
1997            &mut self.lm_head_m,
1998            &mut self.lm_head_v,
1999            lr,
2000            beta1,
2001            beta2,
2002            1e-8,
2003            weight_decay,
2004            step,
2005            n_lm,
2006            stream,
2007        );
2008
2009        // Upload accumulated final norm gradients and run AdamW step
2010        // SAFETY: async host-to-device copy; host buffer (accum.final_norm_grad) is stable.
2011        unsafe {
2012            self.gpu_training
2013                .grad_final_norm_weight
2014                .copy_from_host_async(&accum.final_norm_grad, stream)
2015                .ok()?;
2016        }
2017        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
2018        let _ = adamw_step_cuda(
2019            &mut self.gpu_training.final_norm_weight,
2020            &self.gpu_training.grad_final_norm_weight,
2021            &mut self.final_norm_m,
2022            &mut self.final_norm_v,
2023            lr,
2024            beta1,
2025            beta2,
2026            1e-8,
2027            weight_decay,
2028            step,
2029            n_norm,
2030            stream,
2031        );
2032
2033        stream.synchronize().ok()?;
2034
2035        // Zero accum for next window
2036        accum.zero_all();
2037        Some(())
2038    }
2039
2040    /// Compute gradient L2 norm via GPU reduction kernel (KAIZEN-049).
2041    ///
2042    /// Runs `SquaredSumKernel` on GPU, downloads only `num_blocks` partial sums (~1KB)
2043    /// instead of the full buffer (128MB for lm_head). Falls back to CPU download on error.
2044    ///
2045    /// # Contract (C-CLIPNORM-GPU-001)
2046    ///
2047    /// - **Precondition**: `buf.len() > 0`, stream is synchronized with prior kernel
2048    /// - **Postcondition**: `grad_norm ≈ sqrt(sum(buf[i]^2))`, `scale = min(1, max_norm/norm)`
2049    /// - **Transfer**: ~1KB D2H (num_blocks × 4B) vs n×4B (128MB for 32M elements)
2050    ///
2051    /// R-004: Returns `(clip_scale, grad_norm)` for observability.
2052    fn compute_clip_scale_with_norm(
2053        buf: &GpuBuffer<f32>,
2054        max_norm: f32,
2055        stream: &CudaStream,
2056    ) -> (f32, f32) {
2057        let n = buf.len() as u32;
2058        // Try GPU reduction first — ~1KB D2H instead of n×4 bytes
2059        let grad_norm = match squared_sum_cuda(buf, n, stream) {
2060            Ok(norm) => norm,
2061            Err(_) => {
2062                // Fallback: full D2H (original path)
2063                let mut host = vec![0.0f32; buf.len()];
2064                if buf.copy_to_host_at(&mut host, 0).is_err() {
2065                    return (1.0, 0.0);
2066                }
2067                let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2068                sq_sum.sqrt() as f32
2069            }
2070        };
2071        let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2072        (scale, grad_norm)
2073    }
2074
2075    /// Download embedding gradient from GPU, clip, and scatter-add into CPU weight.
2076    ///
2077    /// # Contract (C-EMBED-GRAD-001)
2078    ///
2079    /// The activation gradient from block[0]'s backward is unclipped (per-block clipping
2080    /// only applies to weight gradients in the shared workspace). For deep networks with
2081    /// random init, this gradient can overflow f32, producing NaN in the CPU AdamW.
2082    /// We clip the activation gradient to max_grad_norm before scatter-adding.
2083    #[allow(unsafe_code)]
2084    fn embed_backward(
2085        &mut self,
2086        input_ids: &[u32],
2087        _seq_len: usize,
2088        hidden_size: usize,
2089        vocab_size: usize,
2090        grad_output_is_a: bool,
2091    ) -> Option<()> {
2092        // The final backward output is in whichever buffer was last written
2093        let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2094        let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2095        let embed_grad_buf = unsafe {
2096            if grad_output_is_a {
2097                &*grad_a_ptr
2098            } else {
2099                &*grad_b_ptr
2100            }
2101        };
2102        let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2103
2104        // C-EMBED-GRAD-001: ALWAYS clip activation gradient before scatter-add.
2105        // Without this, 24-layer random-init backward amplifies gradients to ~1e35,
2106        // which overflows the CPU AdamW's second moment buffer.
2107        //
2108        // ALB-071: Decoupled from general grad_clip config. Embed activation gradient
2109        // clipping is a SAFETY constraint (prevents NaN), not a training hyperparameter.
2110        // Uses dedicated max_embed_grad_norm (default 1.0) independent of weight grad_clip.
2111        let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2112        {
2113            let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2114            let grad_norm = sq_sum.sqrt() as f32;
2115            self.last_embed_grad_norm = grad_norm; // R-040: per-parameter-group tracking
2116            if grad_norm > embed_clip_norm {
2117                let scale = embed_clip_norm / grad_norm;
2118                for g in &mut embed_grad_data {
2119                    *g *= scale;
2120                }
2121            }
2122        }
2123
2124        // KAIZEN-048: In-place scatter-add via grad_cell().borrow_mut().
2125        // Before: 3 × 128MB clones per step (grad() deep-copies Array1).
2126        // After: zero clones — mutate existing gradient buffer directly.
2127        let embed_weight = &mut self.model.embed_tokens.weight;
2128        let grad_cell = embed_weight.grad_cell();
2129        let mut grad_ref = grad_cell.borrow_mut();
2130        if grad_ref.is_none() {
2131            *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2132        }
2133        if let Some(grad) = grad_ref.as_mut() {
2134            for (pos, &token_id) in input_ids.iter().enumerate() {
2135                let tid = token_id as usize;
2136                if tid < vocab_size {
2137                    let src = pos * hidden_size;
2138                    let dst = tid * hidden_size;
2139                    for h in 0..hidden_size {
2140                        grad[dst + h] += embed_grad_data[src + h];
2141                    }
2142                }
2143            }
2144        }
2145        Some(())
2146    }
2147
2148    /// Apply optimizer step to CPU embedding and update metrics.
2149    ///
2150    /// GPU block optimizer steps now run interleaved with backward in `gpu_backward()`.
2151    /// LM head and final norm optimizer steps also run in `gpu_backward()`.
2152    /// This method handles only CPU embedding and bookkeeping.
2153    fn optimizer_step(&mut self) {
2154        // ALB-072: Gradients are no longer scaled by grad_scaler (loss_scale excludes
2155        // grad_scaler.scale()). All backward computation uses f32 — no fp16 underflow
2156        // risk. Skip unscaling; just update scaler as successful.
2157        self.grad_scaler.update(true);
2158
2159        // ALB-079: Sync CPU embedding optimizer lr with cosine schedule
2160        self.embed_optimizer.set_lr(self.current_lr());
2161        // CPU optimizer step for embedding weight
2162        let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2163        self.embed_optimizer.step_refs(&mut embed_params);
2164
2165        self.step += 1;
2166        self.metrics.losses.push(self.accumulated_loss);
2167        self.metrics.increment_step();
2168
2169        self.accumulated_loss = 0.0;
2170        self.accumulated_batches = 0;
2171    }
2172
2173    /// Process a batch (forward + backward + optimizer step with accumulation).
2174    ///
2175    /// R-038: When `accumulation_steps > 1`, runs forward+backward without optimizer
2176    /// for each micro-batch, downloading per-block weight gradients to CPU-side
2177    /// `PerBlockGradientAccumulator`. After `accumulation_steps` batches, averages
2178    /// the accumulated gradients, uploads them to GPU, and runs a single optimizer step.
2179    ///
2180    /// When `accumulation_steps == 1` (default), runs forward+backward+optimizer
2181    /// immediately per sequence (original behavior).
2182    ///
2183    /// Returns average loss for the batch.
2184    pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2185        if batch.batch_size == 0 {
2186            return 0.0;
2187        }
2188
2189        let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2190
2191        if self.accumulated_batches == 0 {
2192            // Zero embedding gradients at start of accumulation window
2193            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2194        }
2195
2196        let mut total_loss = 0.0;
2197        let mut valid_count = 0;
2198
2199        for i in 0..batch.batch_size {
2200            let Some(input_ids) = batch.get_input(i) else {
2201                continue;
2202            };
2203            let Some(target_ids) = batch.get_target(i) else {
2204                continue;
2205            };
2206
2207            // R-038: When accumulating, run backward without optimizer (accumulate_only=true).
2208            // Gradients are downloaded to CPU per-block accum buffers. Embedding grads are
2209            // scatter-added normally (they're already on CPU).
2210            if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2211                total_loss += loss;
2212                valid_count += 1;
2213                if accumulating {
2214                    if let Some(accum) = &mut self.gpu_grad_accum {
2215                        accum.accumulated_count += 1;
2216                    } else if let Some(accum) = &mut self.grad_accum {
2217                        accum.accumulated_count += 1;
2218                    }
2219                }
2220            }
2221        }
2222
2223        let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2224
2225        // Debug: help diagnose loss=0.0 when gradients are non-zero
2226        if avg_loss == 0.0 && valid_count > 0 {
2227            eprintln!(
2228                "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2229                valid_count, total_loss, batch.batch_size
2230            );
2231        }
2232
2233        self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2234        self.accumulated_batches += 1;
2235
2236        if self.accumulated_batches >= self.config.accumulation_steps {
2237            if accumulating {
2238                // ALB-091: Prefer GPU-resident accum path (zero D2H), fall back to CPU.
2239                if self.gpu_grad_accum.is_some() {
2240                    self.gpu_optimizer_from_gpu_accum();
2241                } else {
2242                    self.gpu_optimizer_from_accum();
2243                }
2244            }
2245            self.optimizer_step();
2246        }
2247
2248        avg_loss
2249    }
2250
2251    /// R-005: Evaluate a batch without backward pass or weight updates.
2252    /// Returns average cross-entropy loss, or 0.0 if no valid items.
2253    /// KAIZEN-050: Uses fused GPU cross-entropy (no logits D2H).
2254    pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2255        let hidden_size = self.config.model_config.hidden_size;
2256        let vocab_size = self.config.model_config.vocab_size;
2257        let max_sl = self.config.max_seq_len;
2258        let mut total_loss = 0.0;
2259        let mut valid_count = 0;
2260        for i in 0..batch.batch_size {
2261            if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2262            {
2263                total_loss += loss;
2264                valid_count += 1;
2265            }
2266        }
2267        if valid_count > 0 {
2268            total_loss / valid_count as f32
2269        } else {
2270            0.0
2271        }
2272    }
2273
2274    /// Evaluate a single sequence from a batch. Returns None if invalid.
2275    fn eval_single_sequence(
2276        &mut self,
2277        batch: &LMBatch,
2278        i: usize,
2279        max_sl: usize,
2280        hidden_size: usize,
2281        vocab_size: usize,
2282    ) -> Option<f32> {
2283        let input_ids = batch.get_input(i)?;
2284        let target_ids = batch.get_target(i)?;
2285        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
2286        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2287        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2288        let seq_len = input_ids.len();
2289        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2290        let stream = self.cuda_trainer.stream();
2291        let scale = 1.0 / seq_len as f32;
2292        let loss = fused_cross_entropy_cuda(
2293            &mut self.gpu_training.logits_buf,
2294            target_ids,
2295            seq_len as u32,
2296            vocab_size as u32,
2297            scale,
2298            stream,
2299        )
2300        .ok()?;
2301        if loss.is_finite() {
2302            Some(loss)
2303        } else {
2304            None
2305        }
2306    }
2307
2308    /// Train for one epoch over batches.
2309    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2310        self.train_epoch_with_callback(batches, |_, _, _| {})
2311    }
2312
2313    /// Train for one epoch with a per-step callback.
2314    ///
2315    /// Stops early if `max_steps` is set and reached.
2316    pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2317    where
2318        F: FnMut(usize, f32, &Self),
2319    {
2320        if batches.is_empty() {
2321            return 0.0;
2322        }
2323
2324        let mut total_loss = 0.0;
2325        let mut batches_processed = 0;
2326
2327        for (i, batch) in batches.iter().enumerate() {
2328            if let Some(max) = self.config.max_steps {
2329                if self.step >= max {
2330                    break;
2331                }
2332            }
2333
2334            let batch_loss = self.train_batch(batch);
2335            total_loss += batch_loss;
2336            batches_processed += 1;
2337            on_batch(i, batch_loss, self);
2338        }
2339
2340        // KAIZEN-047: Print profiler summary at end of epoch
2341        if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2342            self.profiler.print_report();
2343        }
2344
2345        total_loss / batches_processed.max(1) as f32
2346    }
2347
2348    // --- DDP (data-parallel) support methods ---
2349
2350    /// Ensure the per-block gradient accumulator exists.
2351    ///
2352    /// For DDP, we always need accumulation buffers (even with accumulation_steps=1)
2353    /// because gradients must be downloaded to CPU for AllReduce before optimizer step.
2354    pub(crate) fn ensure_grad_accum(&mut self) {
2355        if self.grad_accum.is_some() {
2356            return;
2357        }
2358        let mc = &self.config.model_config;
2359        let hidden_size = mc.hidden_size;
2360        let kv_hidden = mc.num_kv_heads * mc.head_dim();
2361        let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2362            hidden_size,
2363            kv_hidden,
2364            mc.intermediate_size,
2365        );
2366        self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2367            self.cuda_blocks.len(),
2368            block_sizes,
2369            mc.vocab_size,
2370            hidden_size,
2371        ));
2372    }
2373
2374    /// Forward + backward for one batch, always accumulating (no optimizer step).
2375    ///
2376    /// Used by `DistributedCudaTrainer` to compute local gradients before AllReduce.
2377    /// Returns average loss for the batch.
2378    pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2379        if batch.batch_size == 0 {
2380            return 0.0;
2381        }
2382
2383        if self.accumulated_batches == 0 {
2384            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2385        }
2386
2387        let mut total_loss = 0.0;
2388        let mut valid_count = 0;
2389
2390        for i in 0..batch.batch_size {
2391            let Some(input_ids) = batch.get_input(i) else { continue };
2392            let Some(target_ids) = batch.get_target(i) else { continue };
2393
2394            // Always accumulate_only=true: gradients go to CPU accum buffers
2395            if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2396                total_loss += loss;
2397                valid_count += 1;
2398                if let Some(accum) = &mut self.grad_accum {
2399                    accum.accumulated_count += 1;
2400                }
2401            }
2402        }
2403
2404        if valid_count > 0 {
2405            total_loss / valid_count as f32
2406        } else {
2407            0.0
2408        }
2409    }
2410
2411    /// Apply DDP-averaged gradients: upload to GPU and run optimizer step.
2412    ///
2413    /// Called after AllReduce has written averaged gradients into the grad_accum.
2414    /// Runs gpu_optimizer_from_accum() for blocks + LM head + final norm,
2415    /// then optimizer_step() for embedding.
2416    pub(crate) fn apply_ddp_gradients(&mut self) {
2417        self.accumulated_loss = 0.0;
2418        self.accumulated_batches = 0;
2419        self.gpu_optimizer_from_accum();
2420        self.optimizer_step();
2421    }
2422
2423    /// Get a reference to the gradient accumulator (for DDP AllReduce).
2424    pub(crate) fn grad_accum_ref(
2425        &self,
2426    ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2427        self.grad_accum.as_ref()
2428    }
2429
2430    /// Get a mutable reference to the gradient accumulator (for DDP AllReduce).
2431    pub(crate) fn grad_accum_mut(
2432        &mut self,
2433    ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2434        self.grad_accum.as_mut()
2435    }
2436
2437    /// Get the training config.
2438    pub(crate) fn config(&self) -> &TransformerTrainConfig {
2439        &self.config
2440    }
2441
2442    /// Get CPU embedding gradient as flat Vec for AllReduce.
2443    pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2444        self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2445    }
2446
2447    /// Set CPU embedding gradient from AllReduced flat Vec.
2448    pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2449        self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2450    }
2451
2452    /// Returns true if max_steps has been reached.
2453    pub fn reached_max_steps(&self) -> bool {
2454        self.config.max_steps.is_some_and(|max| self.step >= max)
2455    }
2456
2457    /// Get current step count.
2458    pub fn step(&self) -> usize {
2459        self.step
2460    }
2461
2462    /// Set initial step for resume from checkpoint.
2463    ///
2464    /// Updates both the outer step counter (LR schedule, logging) and the
2465    /// GPU-side AdamW step counter (bias correction). Must be called before
2466    /// any `train_batch()` calls.
2467    pub fn set_initial_step(&mut self, step: usize) {
2468        self.step = step;
2469        self.gpu_training.step = step as u32;
2470    }
2471
2472    /// Set max_steps for cosine LR scheduler (ENT-275).
2473    ///
2474    /// Called by `train_loop_cuda` when `max_steps` is not explicitly set in
2475    /// the YAML config — auto-computes `epochs × batches_per_epoch` so cosine
2476    /// decay activates instead of falling back to constant lr.
2477    pub fn set_max_steps(&mut self, max_steps: usize) {
2478        self.config.max_steps = Some(max_steps);
2479    }
2480
2481    /// Get current learning rate (warmup + cosine decay).
2482    ///
2483    /// ALB-079: Phase 1 = linear warmup (0 → lr_max), Phase 2 = cosine decay
2484    /// (lr_max → 0) over remaining steps. Requires `max_steps` for decay;
2485    /// without it, falls back to constant lr after warmup.
2486    pub fn current_lr(&self) -> f32 {
2487        let base_lr = self.config.lr;
2488        if self.step < self.config.warmup_steps {
2489            // Phase 1: Linear warmup
2490            base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2491        } else if let Some(max_steps) = self.config.max_steps {
2492            // Phase 2: Cosine decay from lr_max to 0
2493            let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2494            if decay_steps == 0 {
2495                return base_lr;
2496            }
2497            let decay_step = self.step - self.config.warmup_steps;
2498            let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2499            0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2500        } else {
2501            // No max_steps: constant lr (legacy behavior)
2502            base_lr
2503        }
2504    }
2505
2506    /// KAIZEN-047: Enable step profiling with a report every `interval` steps.
2507    ///
2508    /// When enabled, prints a table of wall-clock timings per training phase
2509    /// every `interval` training steps. Use interval=0 for manual-only reporting.
2510    ///
2511    /// # Contract (C-STEPPROF-001)
2512    ///
2513    /// - No additional GPU synchronization points (relies on existing syncs)
2514    /// - Overhead: ~11 `Instant::now()` calls per step (~1µs total on Linux)
2515    /// - Timings include async dispatch overhead (not pure kernel time)
2516    pub fn enable_profiler(&mut self, interval: usize) {
2517        self.profiler = StepProfiler::new(true, interval);
2518    }
2519
2520    /// Print the profiler report (if profiling is enabled).
2521    pub fn print_profiler_report(&self) {
2522        self.profiler.print_report();
2523    }
2524
2525    /// R-004: Get last observed gradient L2 norm (LM head proxy).
2526    pub fn last_grad_norm(&self) -> f32 {
2527        self.last_grad_norm
2528    }
2529
2530    /// R-040: Get per-parameter-group gradient norms.
2531    /// Returns (lm_head_grad_norm, embed_grad_norm).
2532    pub fn param_grad_norms(&self) -> (f32, f32) {
2533        (self.last_grad_norm, self.last_embed_grad_norm)
2534    }
2535
2536    /// R-012: Get total trainable parameter count for MFU calculation.
2537    pub fn num_params(&self) -> usize {
2538        self.model.parameters().iter().map(|t| t.len()).sum()
2539    }
2540
2541    /// R-013: Query GPU memory usage (used_mb, total_mb).
2542    pub fn gpu_memory_mb(&self) -> (u64, u64) {
2543        match self.cuda_trainer.context().memory_info() {
2544            Ok((free, total)) => {
2545                let total_mb = (total / (1024 * 1024)) as u64;
2546                let used_mb = ((total - free) / (1024 * 1024)) as u64;
2547                (used_mb, total_mb)
2548            }
2549            Err(_) => (0, 0),
2550        }
2551    }
2552
2553    /// Sync all GPU weights back to CPU model.
2554    ///
2555    /// # Contract (C-SYNCWT-001)
2556    ///
2557    /// Must be called before save or any CPU model access after training.
2558    pub fn sync_weights_to_cpu(&mut self) {
2559        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2560
2561        if use_nf4 {
2562            // ENT-263: NF4 blocks are frozen — base weights don't change.
2563            // Only download LoRA adapter weights for checkpoint saving.
2564            // The base model on CPU stays as-is (original pretrained weights).
2565            // LoRA weights are saved separately (adapter_config.json + adapter.safetensors).
2566            // For now, skip per-layer sync — base weights are unchanged.
2567        } else {
2568            for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2569                if let Ok(weights) = block.download_weights() {
2570                    let layer = &mut self.model.layers[layer_idx];
2571
2572                    layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2573                    layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2574                    layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2575                    layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2576
2577                    layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2578                    layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2579                    layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2580
2581                    layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2582                    layer.post_attn_norm.weight =
2583                        Tensor::from_vec(weights.post_attn_norm_weight, false);
2584                }
2585            }
2586        }
2587
2588        // Sync final norm weight
2589        if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2590            self.model.norm.weight = Tensor::from_vec(norm_data, false);
2591        }
2592
2593        // Sync LM head weight
2594        // ALB-097: ALWAYS save GPU-trained LM head, even for tied-weight models.
2595        // During GPU training, lm_head diverges from embed_tokens because they have
2596        // separate optimizers (GPU AdamW vs CPU AdamW). If we skip the sync for tied
2597        // weights, the checkpoint loses 500+ steps of GPU LM head training → random-init
2598        // loss on resume (Five Whys root cause of ALB-097).
2599        if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2600            self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2601        }
2602    }
2603
2604    /// Get reference to model (syncs weights first).
2605    pub fn model(&self) -> &Transformer {
2606        &self.model
2607    }
2608
2609    /// Get mutable reference to model.
2610    pub fn model_mut(&mut self) -> &mut Transformer {
2611        &mut self.model
2612    }
2613
2614    /// Check if using mixed precision.
2615    pub fn is_mixed_precision(&self) -> bool {
2616        self.config.precision_config.is_mixed()
2617    }
2618
2619    /// Get the gradient scaler (R-002: loss scaling for mixed precision).
2620    pub fn grad_scaler(&self) -> &GradScaler {
2621        &self.grad_scaler
2622    }
2623
2624    /// Check if using gradient checkpointing.
2625    pub fn is_checkpointing(&self) -> bool {
2626        self.config.checkpoint_config.enabled
2627    }
2628
2629    /// Save model weights (syncs GPU→CPU first).
2630    pub fn save(
2631        &mut self,
2632        path: impl AsRef<std::path::Path>,
2633        name: &str,
2634        architecture: &str,
2635    ) -> crate::Result<()> {
2636        self.sync_weights_to_cpu();
2637
2638        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2639        let params: Vec<(String, Tensor)> = self
2640            .model
2641            .named_parameters()
2642            .into_iter()
2643            .map(|(name, tensor)| (name, tensor.clone()))
2644            .collect();
2645
2646        let metadata = ModelMetadata::new(name, architecture);
2647        let model = Model::new(metadata, params);
2648        let config = SaveConfig::new(ModelFormat::SafeTensors);
2649
2650        save_model(&model, path, &config)
2651    }
2652
2653    /// R-011: Prepare checkpoint data for async save.
2654    /// Syncs GPU weights to CPU and snapshots tensor data as Send-able Vec<f32>.
2655    /// Returns a closure that writes the checkpoint file from another thread.
2656    pub fn prepare_async_save(
2657        &mut self,
2658        name: &str,
2659        architecture: &str,
2660    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2661        self.sync_weights_to_cpu();
2662
2663        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2664        let param_data: Vec<(String, Vec<f32>)> = self
2665            .model
2666            .named_parameters()
2667            .into_iter()
2668            .map(|(n, t)| (n, t.data().to_vec()))
2669            .collect();
2670
2671        let name = name.to_string();
2672        let architecture = architecture.to_string();
2673
2674        Box::new(move |path: &std::path::Path| {
2675            let params: Vec<(String, Tensor)> =
2676                param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2677            let metadata = ModelMetadata::new(&name, &architecture);
2678            let model = Model::new(metadata, params);
2679            let config = SaveConfig::new(ModelFormat::SafeTensors);
2680            save_model(&model, path, &config)
2681        })
2682    }
2683
2684    /// ALB-096: Save model weights as APR checkpoint (syncs GPU→CPU first).
2685    ///
2686    /// Single atomic file containing all model weights. Use `save_apr_checkpoint()`
2687    /// to include optimizer state and training metadata in the same file.
2688    pub fn save_apr(
2689        &mut self,
2690        path: impl AsRef<std::path::Path>,
2691        name: &str,
2692        architecture: &str,
2693    ) -> crate::Result<()> {
2694        self.sync_weights_to_cpu();
2695
2696        let params: Vec<(String, Tensor)> = self
2697            .model
2698            .named_parameters()
2699            .into_iter()
2700            .map(|(name, tensor)| (name, tensor.clone()))
2701            .collect();
2702
2703        let metadata = ModelMetadata::new(name, architecture);
2704        let model = Model::new(metadata, params);
2705        let config = SaveConfig::new(ModelFormat::Apr);
2706
2707        save_model(&model, path, &config)
2708    }
2709
2710    /// ALB-096: Prepare APR checkpoint data for async save.
2711    ///
2712    /// Syncs GPU weights to CPU and snapshots tensor data + optimizer state as
2713    /// Send-able `Vec<f32>`. Returns a closure that writes a single atomic APR
2714    /// file from another thread. Includes model weights + CPU embedding optimizer
2715    /// state + training metadata — all in one file.
2716    fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2717        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2718        if use_nf4 {
2719            let frozen_suffixes = [
2720                "q_proj.weight",
2721                "k_proj.weight",
2722                "v_proj.weight",
2723                "o_proj.weight",
2724                "gate_proj.weight",
2725                "up_proj.weight",
2726                "down_proj.weight",
2727            ];
2728            self.model
2729                .named_parameters()
2730                .into_iter()
2731                .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2732                .map(|(n, t)| (n, t.data().to_vec()))
2733                .collect()
2734        } else {
2735            self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2736        }
2737    }
2738
2739    fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2740        if self.config.quantize_nf4 && self.config.is_lora() {
2741            self.cuda_blocks
2742                .iter()
2743                .enumerate()
2744                .filter_map(|(i, block)| {
2745                    block
2746                        .download_lora_weights()
2747                        .ok()
2748                        .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2749                })
2750                .collect()
2751        } else {
2752            Vec::new()
2753        }
2754    }
2755
2756    pub fn prepare_async_apr_save(
2757        &mut self,
2758        name: &str,
2759        architecture: &str,
2760        step: usize,
2761        loss: f64,
2762        lr: f64,
2763    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2764        self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2765    }
2766
2767    /// ALB-130: Prepare APR checkpoint with embedded tokenizer for inference.
2768    ///
2769    /// Training checkpoints must be self-contained for eval (`apr eval --task humaneval`).
2770    /// Without embedded tokenizer, inference falls back to structural validation (fake 100%).
2771    /// The tokenizer path comes from `spec.data.tokenizer` in the training YAML.
2772    pub fn prepare_async_apr_save_with_tokenizer(
2773        &mut self,
2774        name: &str,
2775        architecture: &str,
2776        step: usize,
2777        loss: f64,
2778        lr: f64,
2779        tokenizer_path: Option<&std::path::Path>,
2780    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2781        self.sync_weights_to_cpu();
2782
2783        let param_data = self.snapshot_param_data();
2784        let lora_data = self.snapshot_lora_data();
2785
2786        // Snapshot CPU embedding optimizer state
2787        let embed_m: Vec<Vec<f32>> = self
2788            .embed_optimizer
2789            .first_moments()
2790            .iter()
2791            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2792            .collect();
2793        let embed_v: Vec<Vec<f32>> = self
2794            .embed_optimizer
2795            .second_moments()
2796            .iter()
2797            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2798            .collect();
2799        let embed_step = self.embed_optimizer.step_count();
2800
2801        // ALB-118: Download GPU block optimizer states (m/v moments) for checkpointing.
2802        // Without this, resume re-initializes all 24 blocks' AdamW state to zero,
2803        // causing loss spikes and convergence failure (v10/v11/v12 post-mortems).
2804        // Transfer cost: ~2.3 GB D2H, <6ms on PCIe4/5.
2805        let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2806            .gpu_training
2807            .optimizer_states
2808            .iter()
2809            .map(|state| state.download_to_host().unwrap_or_default())
2810            .collect();
2811
2812        // ALB-118: Download LM head and final norm optimizer states
2813        let lm_head_m_host = {
2814            let mut buf = vec![0.0f32; self.lm_head_m.len()];
2815            let _ = self.lm_head_m.copy_to_host(&mut buf);
2816            buf
2817        };
2818        let lm_head_v_host = {
2819            let mut buf = vec![0.0f32; self.lm_head_v.len()];
2820            let _ = self.lm_head_v.copy_to_host(&mut buf);
2821            buf
2822        };
2823        let final_norm_m_host = {
2824            let mut buf = vec![0.0f32; self.final_norm_m.len()];
2825            let _ = self.final_norm_m.copy_to_host(&mut buf);
2826            buf
2827        };
2828        let final_norm_v_host = {
2829            let mut buf = vec![0.0f32; self.final_norm_v.len()];
2830            let _ = self.final_norm_v.copy_to_host(&mut buf);
2831            buf
2832        };
2833
2834        let name = name.to_string();
2835        let architecture = architecture.to_string();
2836        let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2837        let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2838
2839        // ALB-130: Pre-read tokenizer.json for embedding in checkpoint.
2840        // Parse HuggingFace tokenizer format → extract vocab + merges + special token IDs.
2841        let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2842            tokenizer_path.and_then(|p| {
2843                let json_bytes = std::fs::read(p).ok()?;
2844                let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2845                let model = tok.get("model")?;
2846                let vocab_obj = model.get("vocab")?.as_object()?;
2847                // Build sorted-by-id vocab list
2848                let mut vocab_pairs: Vec<(String, u64)> =
2849                    vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
2850                vocab_pairs.sort_by_key(|(_, id)| *id);
2851                let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
2852                // Merges as "token1 token2" strings
2853                let merges: Vec<String> = model
2854                    .get("merges")?
2855                    .as_array()?
2856                    .iter()
2857                    .filter_map(|v| v.as_str().map(String::from))
2858                    .collect();
2859                // Special tokens: BOS=<s>=1, EOS=</s>=2 (from added_tokens)
2860                let added = tok.get("added_tokens").and_then(|a| a.as_array());
2861                let bos_id = added.and_then(|arr| {
2862                    arr.iter()
2863                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
2864                        .and_then(|t| t.get("id")?.as_u64())
2865                });
2866                let eos_id = added.and_then(|arr| {
2867                    arr.iter()
2868                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
2869                        .and_then(|t| t.get("id")?.as_u64())
2870                });
2871                if vocab.is_empty() {
2872                    return None;
2873                }
2874                println!(
2875                    "  [ALB-130] Embedding tokenizer: {} vocab, {} merges",
2876                    vocab.len(),
2877                    merges.len()
2878                );
2879                Some((vocab, merges, bos_id, eos_id))
2880            });
2881
2882        Box::new(move |path: &std::path::Path| {
2883            use aprender::serialization::apr::AprWriter;
2884            use serde_json::Value as Jv;
2885
2886            let mut writer = AprWriter::new();
2887
2888            // Metadata
2889            writer.set_metadata("model_name", Jv::String(name));
2890            writer.set_metadata("architecture", Jv::String(architecture));
2891            writer.set_metadata(
2892                "format",
2893                Jv::String(if is_delta_checkpoint {
2894                    "entrenar-delta-checkpoint".into()
2895                } else {
2896                    "entrenar-checkpoint".into()
2897                }),
2898            );
2899            writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
2900            writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
2901            writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
2902            writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
2903            if let Some(cfg) = model_config_json {
2904                writer.set_metadata("model_config", Jv::String(cfg));
2905            }
2906
2907            // ALB-130: Embed tokenizer vocab + merges for standalone inference
2908            if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
2909                writer.set_metadata(
2910                    "tokenizer.vocabulary",
2911                    Jv::Array(vocab.into_iter().map(Jv::String).collect()),
2912                );
2913                writer.set_metadata(
2914                    "tokenizer.merges",
2915                    Jv::Array(merges.into_iter().map(Jv::String).collect()),
2916                );
2917                if let Some(bos) = bos_id {
2918                    writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
2919                }
2920                if let Some(eos) = eos_id {
2921                    writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
2922                }
2923            }
2924
2925            // Find hidden_size from norm weights for shape inference
2926            let hidden_size = param_data
2927                .iter()
2928                .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
2929                .map_or(0, |(_, d)| d.len());
2930
2931            // Model weight tensors
2932            for (tensor_name, data) in &param_data {
2933                let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
2934                writer.add_tensor_f32(tensor_name.clone(), shape, data);
2935            }
2936
2937            // Optimizer state tensors
2938            for (i, m_data) in embed_m.iter().enumerate() {
2939                let len = m_data.len();
2940                writer.add_tensor_f32(
2941                    format!("__training__.embed_optimizer.m.{i}"),
2942                    vec![len],
2943                    m_data,
2944                );
2945            }
2946            for (i, v_data) in embed_v.iter().enumerate() {
2947                let len = v_data.len();
2948                writer.add_tensor_f32(
2949                    format!("__training__.embed_optimizer.v.{i}"),
2950                    vec![len],
2951                    v_data,
2952                );
2953            }
2954
2955            // ALB-118: Save GPU block optimizer states (m/v moments for all 24 blocks)
2956            for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
2957                for (suffix, data) in buffers {
2958                    let len = data.len();
2959                    writer.add_tensor_f32(
2960                        format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
2961                        vec![len],
2962                        data,
2963                    );
2964                }
2965            }
2966
2967            // ALB-118: Save LM head and final norm optimizer states
2968            if !lm_head_m_host.is_empty() {
2969                let len = lm_head_m_host.len();
2970                writer.add_tensor_f32(
2971                    "__training__.lm_head_optimizer.m".to_string(),
2972                    vec![len],
2973                    &lm_head_m_host,
2974                );
2975                let len = lm_head_v_host.len();
2976                writer.add_tensor_f32(
2977                    "__training__.lm_head_optimizer.v".to_string(),
2978                    vec![len],
2979                    &lm_head_v_host,
2980                );
2981            }
2982            if !final_norm_m_host.is_empty() {
2983                let len = final_norm_m_host.len();
2984                writer.add_tensor_f32(
2985                    "__training__.final_norm_optimizer.m".to_string(),
2986                    vec![len],
2987                    &final_norm_m_host,
2988                );
2989                let len = final_norm_v_host.len();
2990                writer.add_tensor_f32(
2991                    "__training__.final_norm_optimizer.v".to_string(),
2992                    vec![len],
2993                    &final_norm_v_host,
2994                );
2995            }
2996
2997            // ENT-276: Save LoRA adapter weights (QLoRA checkpoint resume)
2998            for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
2999                if !a_q.is_empty() {
3000                    writer.add_tensor_f32(
3001                        format!("lora.{layer_idx}.q_proj.lora_a"),
3002                        vec![a_q.len()],
3003                        a_q,
3004                    );
3005                    writer.add_tensor_f32(
3006                        format!("lora.{layer_idx}.q_proj.lora_b"),
3007                        vec![b_q.len()],
3008                        b_q,
3009                    );
3010                }
3011                if !a_v.is_empty() {
3012                    writer.add_tensor_f32(
3013                        format!("lora.{layer_idx}.v_proj.lora_a"),
3014                        vec![a_v.len()],
3015                        a_v,
3016                    );
3017                    writer.add_tensor_f32(
3018                        format!("lora.{layer_idx}.v_proj.lora_b"),
3019                        vec![b_v.len()],
3020                        b_v,
3021                    );
3022                }
3023            }
3024
3025            // Write APR checkpoint to file
3026            writer
3027                .write(path)
3028                .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3029
3030            Ok(())
3031        })
3032    }
3033
3034    /// GPU device name.
3035    pub fn gpu_name(&self) -> String {
3036        self.cuda_trainer.device_name()
3037    }
3038
3039    /// ENT-269: Save LoRA adapter weights as PEFT-compatible files.
3040    ///
3041    /// Downloads LoRA A/B matrices from GPU, un-scales B (divide by lora_scale),
3042    /// transposes to PEFT convention (A=[rank, d_in], B=[d_out, rank]),
3043    /// and writes `adapter_model.safetensors` + `adapter_config.json`.
3044    ///
3045    /// # Contract: C-QLORA-SAVE-001
3046    ///
3047    /// NF4 QLoRA training MUST produce `adapter_model.safetensors` in output_dir.
3048    pub fn save_cuda_lora_adapter(
3049        &self,
3050        output_dir: &std::path::Path,
3051        base_model_name: Option<&str>,
3052    ) -> crate::Result<()> {
3053        if !self.config.quantize_nf4 || !self.config.is_lora() {
3054            return Ok(()); // Not a QLoRA run, nothing to save
3055        }
3056
3057        let lora_rank = self.config.lora_rank.unwrap_or(16);
3058        let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3059        let lora_scale = lora_alpha / lora_rank as f32;
3060        let hidden_size = self.config.model_config.hidden_size;
3061        let head_dim = self.config.model_config.head_dim();
3062        let q_dim = self.config.model_config.num_attention_heads * head_dim;
3063        let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3064
3065        let lora_config =
3066            crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3067
3068        let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3069
3070        for (i, block) in self.cuda_blocks.iter().enumerate() {
3071            let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3072                Ok(weights) => weights,
3073                Err(_) => continue, // Skip non-NF4 blocks
3074            };
3075
3076            if a_q.is_empty() && a_v.is_empty() {
3077                continue;
3078            }
3079
3080            // Q projection LoRA
3081            if !a_q.is_empty() {
3082                // GPU stores A_q as [hidden, rank] row-major, PEFT expects [rank, hidden]
3083                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3084                for r in 0..hidden_size {
3085                    for c in 0..lora_rank {
3086                        a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3087                    }
3088                }
3089
3090                // GPU stores B_q as [rank, q_dim] pre-scaled by lora_scale
3091                // PEFT expects [q_dim, rank] un-scaled
3092                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3093                let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3094                for r in 0..lora_rank {
3095                    for c in 0..q_dim {
3096                        b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3097                    }
3098                }
3099
3100                let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3101                let mut layer = crate::lora::LoRALayer::new(
3102                    base_weight,
3103                    q_dim,
3104                    hidden_size,
3105                    lora_rank,
3106                    lora_alpha,
3107                );
3108                // Overwrite the A and B data with trained weights
3109                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3110                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3111
3112                adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3113            }
3114
3115            // V projection LoRA
3116            if !a_v.is_empty() {
3117                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3118                for r in 0..hidden_size {
3119                    for c in 0..lora_rank {
3120                        a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3121                    }
3122                }
3123
3124                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3125                let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3126                for r in 0..lora_rank {
3127                    for c in 0..kv_hidden {
3128                        b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3129                    }
3130                }
3131
3132                let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3133                let mut layer = crate::lora::LoRALayer::new(
3134                    base_weight,
3135                    kv_hidden,
3136                    hidden_size,
3137                    lora_rank,
3138                    lora_alpha,
3139                );
3140                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3141                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3142
3143                adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3144            }
3145        }
3146
3147        if adapters.is_empty() {
3148            println!("  [WARN] No LoRA adapters found to save");
3149            return Ok(());
3150        }
3151
3152        let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3153            adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3154
3155        std::fs::create_dir_all(output_dir).ok();
3156        crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3157            .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3158
3159        let adapter_path = output_dir.join("adapter_model.safetensors");
3160        let size_mb =
3161            std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3162        println!(
3163            "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3164            adapters.len(),
3165            size_mb,
3166            output_dir.display()
3167        );
3168
3169        Ok(())
3170    }
3171
3172    /// R-001: Save CPU embedding optimizer state (m/v buffers + step counter).
3173    ///
3174    /// Writes `optimizer_state.json` to the given directory. GPU block optimizer
3175    /// states remain on-device (D2H for 20 buffers × N blocks is deferred).
3176    pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3177        let path = dir.join("optimizer_state.json");
3178        let m_data: Vec<Option<Vec<f32>>> = self
3179            .embed_optimizer
3180            .first_moments()
3181            .iter()
3182            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3183            .collect();
3184        let v_data: Vec<Option<Vec<f32>>> = self
3185            .embed_optimizer
3186            .second_moments()
3187            .iter()
3188            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3189            .collect();
3190        let state = serde_json::json!({
3191            "type": "adamw_cpu_embed",
3192            "step": self.embed_optimizer.step_count(),
3193            "m": m_data,
3194            "v": v_data,
3195        });
3196        let json_str = serde_json::to_string(&state).map_err(|e| {
3197            crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3198        })?;
3199        std::fs::write(&path, json_str)
3200            .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3201        Ok(())
3202    }
3203
3204    /// ENT-276: Restore LoRA adapter weights from APR checkpoint.
3205    ///
3206    /// Reads `lora.{layer}.{q,v}_proj.lora_{a,b}` tensors from the APR file
3207    /// and uploads them to the NF4 CUDA blocks, replacing the fresh random init.
3208    /// Returns (layers_restored, layers_total).
3209    pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3210        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3211            Ok(r) => r,
3212            Err(_) => return (0, self.cuda_blocks.len()),
3213        };
3214
3215        let mut restored = 0usize;
3216        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3217            let a_q =
3218                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3219            let b_q =
3220                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3221            let a_v =
3222                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3223            let b_v =
3224                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3225
3226            if a_q.is_empty() {
3227                continue; // No LoRA data for this layer in checkpoint
3228            }
3229
3230            if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3231                eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3232                continue;
3233            }
3234            restored += 1;
3235        }
3236
3237        (restored, self.cuda_blocks.len())
3238    }
3239
3240    /// ALB-096: Load CPU embedding optimizer state from APR checkpoint.
3241    ///
3242    /// Reads `__training__.embed_optimizer.{m,v}.*` tensors from the APR file.
3243    /// Returns true if state was loaded.
3244    pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3245        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3246            Ok(r) => r,
3247            Err(_) => return false,
3248        };
3249
3250        // Restore step count from metadata
3251        if let Some(step_val) = reader.get_metadata("optimizer_step") {
3252            if let Some(step_str) = step_val.as_str() {
3253                if let Ok(step) = step_str.parse::<u64>() {
3254                    self.embed_optimizer.set_step_count(step);
3255                }
3256            }
3257        }
3258
3259        // Restore first moments (m)
3260        for i in 0..128 {
3261            let name = format!("__training__.embed_optimizer.m.{i}");
3262            match reader.read_tensor_f32(&name) {
3263                Ok(data) if !data.is_empty() => {
3264                    self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3265                }
3266                _ => break,
3267            }
3268        }
3269
3270        // Restore second moments (v)
3271        for i in 0..128 {
3272            let name = format!("__training__.embed_optimizer.v.{i}");
3273            match reader.read_tensor_f32(&name) {
3274                Ok(data) if !data.is_empty() => {
3275                    self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3276                }
3277                _ => break,
3278            }
3279        }
3280
3281        // ALB-118: Restore GPU block optimizer states (m/v moments for all blocks)
3282        let suffixes = [
3283            "m.w_q",
3284            "v.w_q",
3285            "m.w_k",
3286            "v.w_k",
3287            "m.w_v",
3288            "v.w_v",
3289            "m.w_o",
3290            "v.w_o",
3291            "m.w_gate",
3292            "v.w_gate",
3293            "m.w_up",
3294            "v.w_up",
3295            "m.w_down",
3296            "v.w_down",
3297            "m.input_norm",
3298            "v.input_norm",
3299            "m.post_attn_norm",
3300            "v.post_attn_norm",
3301        ];
3302        let mut blocks_restored = 0usize;
3303        for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3304            let mut data = std::collections::HashMap::new();
3305            for suffix in &suffixes {
3306                let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3307                if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3308                    if !tensor_data.is_empty() {
3309                        data.insert(suffix.to_string(), tensor_data);
3310                    }
3311                }
3312            }
3313            if !data.is_empty() {
3314                let _ = state.restore_from_host(&data);
3315                blocks_restored += 1;
3316            }
3317        }
3318
3319        // ALB-118: Restore LM head optimizer state
3320        if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3321            if m_data.len() == self.lm_head_m.len() {
3322                let _ = self.lm_head_m.copy_from_host(&m_data);
3323            }
3324        }
3325        if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3326            if v_data.len() == self.lm_head_v.len() {
3327                let _ = self.lm_head_v.copy_from_host(&v_data);
3328            }
3329        }
3330
3331        // ALB-118: Restore final norm optimizer state
3332        if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3333            if m_data.len() == self.final_norm_m.len() {
3334                let _ = self.final_norm_m.copy_from_host(&m_data);
3335            }
3336        }
3337        if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3338            if v_data.len() == self.final_norm_v.len() {
3339                let _ = self.final_norm_v.copy_from_host(&v_data);
3340            }
3341        }
3342
3343        // ALB-132: Report restore results — don't silently swallow failures
3344        if blocks_restored > 0 {
3345            println!(
3346                "  ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3347                self.gpu_training.optimizer_states.len()
3348            );
3349        } else if !self.gpu_training.optimizer_states.is_empty() {
3350            println!(
3351                "  [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3352                self.gpu_training.optimizer_states.len()
3353            );
3354        }
3355
3356        true
3357    }
3358
3359    /// R-001: Load CPU embedding optimizer state from `optimizer_state.json`.
3360    ///
3361    /// Returns true if state was loaded, false if file doesn't exist.
3362    pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3363        let path = dir.join("optimizer_state.json");
3364        let data = match std::fs::read_to_string(&path) {
3365            Ok(d) => d,
3366            Err(_) => return false,
3367        };
3368        let state: serde_json::Value = match serde_json::from_str(&data) {
3369            Ok(v) => v,
3370            Err(_) => return false,
3371        };
3372        if let Some(step) = state["step"].as_u64() {
3373            self.embed_optimizer.set_step_count(step);
3374        }
3375        restore_moment_buffers(&state["m"], |idx, arr| {
3376            self.embed_optimizer.set_first_moment(idx, arr);
3377        });
3378        restore_moment_buffers(&state["v"], |idx, arr| {
3379            self.embed_optimizer.set_second_moment(idx, arr);
3380        });
3381        true
3382    }
3383}
3384
3385/// ALB-096: Infer 2D tensor shape from name and element count.
3386///
3387/// Same logic as `infer_all_tensor_shapes` in `io/save.rs` but for a single tensor.
3388#[cfg(feature = "cuda")]
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    // Norm weights are always stored as flat 1-D vectors.
    let is_norm_weight = name == "model.norm.weight" || name.ends_with("layernorm.weight");
    // A 2-D shape needs a known hidden size that divides the element count evenly.
    let splits_evenly = hidden_size > 0 && numel.is_multiple_of(hidden_size);
    if is_norm_weight || !splits_evenly {
        return vec![numel];
    }

    let other_dim = numel / hidden_size;
    // `down_proj` is the one weight given [hidden, other]; everything else is [other, hidden].
    let (rows, cols) = if name.ends_with("down_proj.weight") {
        (hidden_size, other_dim)
    } else {
        (other_dim, hidden_size)
    };
    vec![rows, cols]
}
3403
/// Feed every non-empty moment buffer found in a JSON array to `set_fn`.
///
/// `json_arr` is expected to be an array whose elements are arrays of numbers.
/// If it is not an array, nothing happens; elements that are not arrays are
/// skipped. Within a buffer, entries that are not numbers are dropped and the
/// rest are narrowed to `f32`. `set_fn` receives the element's index in the outer
/// array and is not called for buffers that end up empty.
#[cfg(feature = "cuda")]
fn restore_moment_buffers(
    json_arr: &serde_json::Value,
    mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
) {
    let entries = match json_arr.as_array() {
        Some(entries) => entries,
        None => return,
    };
    for (slot, entry) in entries.iter().enumerate() {
        if let Some(values) = entry.as_array() {
            let mut buffer: Vec<f32> = Vec::new();
            buffer.extend(values.iter().filter_map(serde_json::Value::as_f64).map(|x| x as f32));
            if buffer.is_empty() {
                continue;
            }
            set_fn(slot, ndarray::Array1::from_vec(buffer));
        }
    }
}
3419
3420// ── Non-CUDA stub ──
3421
3422#[cfg(not(feature = "cuda"))]
3423pub struct CudaTransformerTrainer;
3424
3425#[cfg(not(feature = "cuda"))]
3426impl CudaTransformerTrainer {
3427    pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3428        Err(crate::error::Error::ConfigError(
3429            "CUDA not available (compiled without cuda feature)".into(),
3430        ))
3431    }
3432
3433    pub fn with_model(
3434        _model: crate::transformer::Transformer,
3435        _config: super::config::TransformerTrainConfig,
3436    ) -> crate::Result<Self> {
3437        Err(crate::error::Error::ConfigError(
3438            "CUDA not available (compiled without cuda feature)".into(),
3439        ))
3440    }
3441
3442    pub fn gpu_name(&self) -> String {
3443        unreachable!("CudaTransformerTrainer stub should never be instantiated")
3444    }
3445}
3446
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, constructing the trainer must report an error.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        assert!(super::CudaTransformerTrainer::new(train_config).is_err());
    }
}