Skip to main content

entrenar/train/transformer_trainer/
cuda_trainer.rs

1//! GPU-resident transformer trainer (ALB-040)
2//!
3//! Wires the existing `CudaTransformerBlock` forward/backward/optimizer_step
4//! into the pretraining path. Follows the proven `classify_pipeline.rs` pattern.
5//!
6//! # Architecture
7//!
8//! ```text
9//! CudaTransformerTrainer
10//! ├── model: Transformer                 (CPU — embed + save)
11//! ├── cuda_trainer: CudaTrainer          (GPU device context)
12//! ├── cuda_blocks: Vec<CudaBlock>            (fp32 or NF4)
13//! ├── cuda_grad_workspace: CudaGradWorkspace
14//! ├── gpu_training: GpuPretrainState     (layer_inputs, grad bufs, opt states)
15//! ├── lm_head_weight_gpu: GpuBuffer      (V × H on GPU)
16//! ├── lm_head_grad_gpu: GpuBuffer        (V × H gradient scratch)
17//! ├── lm_head_m/v: GpuBuffer             (AdamW moment states)
18//! └── config: TransformerTrainConfig
19//! ```
20//!
21//! # Transfer budget (C-GPUTRAIN-002, updated KAIZEN-050/052)
22//!
//! One bulk PCIe transfer per training step, plus two tiny control transfers:
24//! 1. H2D: hidden states after embedding (seq×H×4 bytes)
25//! 2. H2D: target_ids for fused cross-entropy (seq×4 bytes — ~512B)
26//! 3. D2H: loss_partials from fused cross-entropy (seq×4 bytes — ~512B)
27//!
28//! Eliminated by KAIZEN-050:
29//! - D2H logits (was seq×V×4 = 77.8MB for Qwen3-4B)
30//! - H2D grad_logits (was seq×V×4 = 77.8MB)
31//!
32//! Eliminated by KAIZEN-052:
33//! - grad_gpu buffer allocation (was seq×V×4 = 77.8MB per step)
34
35#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{gemm_forward, pre_warm_forward_kernels, rms_norm_forward};
42#[cfg(feature = "cuda")]
43use crate::autograd::cuda_optim::{
44    adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
45    gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
46    squared_sum_launch_into, FusedClipState,
47};
48#[cfg(feature = "cuda")]
49use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
50#[cfg(feature = "cuda")]
51use crate::autograd::precision::GradScaler;
52#[cfg(feature = "cuda")]
53use crate::autograd::Tensor;
54#[cfg(feature = "cuda")]
55use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
56#[cfg(feature = "cuda")]
57use crate::optim::{AdamW, Optimizer};
58#[cfg(feature = "cuda")]
59use crate::train::MetricsTracker;
60#[cfg(feature = "cuda")]
61use crate::transformer::{
62    CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
63    GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
64};
65
66#[cfg(feature = "cuda")]
67use super::batch::LMBatch;
68#[cfg(feature = "cuda")]
69use super::config::TransformerTrainConfig;
70#[cfg(feature = "cuda")]
71use super::step_profiler::StepProfiler;
72
/// Derive the gradient-clipping factor for the shared workspace from a GPU-side
/// L2-norm reduction (KAIZEN-054).
///
/// One squared-sum kernel is launched (`squared_sum_launch_cuda`) for each of the
/// nine workspace gradient buffers, and only the reduction results are read back
/// (`squared_sum_collect`) instead of the gradient buffers themselves (full
/// downloads were 58 MB+ per block and were disabled in ALB-067).
///
/// # Returns
///
/// `(scale, grad_norm)`:
/// - `grad_norm` is the L2 norm over every buffer whose reduction succeeded.
/// - `scale` is `max_norm / grad_norm` when `grad_norm > max_norm`, else `1.0`.
///
/// A failed stream synchronize yields `(1.0, 0.0)`. A buffer whose launch or
/// collect returns an error is left out of the norm; no error is reported.
///
/// Free function to avoid borrow conflicts with `&mut self`.
#[cfg(feature = "cuda")]
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // KAIZEN-055: queue every reduction back-to-back with no sync in between, so
    // the stream is flushed once per block rather than once per buffer.
    // Empty buffers are skipped; failed launches are dropped.
    let in_flight: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|grad| {
            let len = grad.len() as u32;
            if len == 0 {
                None
            } else {
                squared_sum_launch_cuda(grad, len, stream).ok()
            }
        })
        .collect();

    // The one and only sync point for all launches above.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // C-CLIP-001: each collected value is already sum(x²) = ||g||² for its buffer,
    // so the values are added as-is and never squared again (entrenar#311 fix).
    // Accumulation is done in f64.
    let total_sq = in_flight
        .iter()
        .filter_map(|reduction| squared_sum_collect(reduction).ok())
        .fold(0.0f64, |acc, sq_norm| acc + f64::from(sq_norm));

    // Global L2 norm = sqrt of the summed per-buffer squared norms.
    let grad_norm = total_sq.sqrt() as f32;
    let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
    (scale, grad_norm)
}
131
/// Rescale every gradient buffer in the shared workspace so that the global L2
/// norm (computed on the GPU, KAIZEN-054) does not exceed `max_norm`.
///
/// The clip factor comes from `compute_workspace_clip_scale_gpu`. When that factor
/// is within `1e-7` of `1.0` no scaling kernels are launched.
///
/// R-004: the return value is the pre-clip gradient L2 norm, intended for
/// observability logging.
#[cfg(feature = "cuda")]
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    let is_identity = (scale - 1.0).abs() < 1e-7;
    if !is_identity {
        let grads: [&mut GpuBuffer<f32>; 9] = [
            &mut ws.grad_w_q,
            &mut ws.grad_w_k,
            &mut ws.grad_w_v,
            &mut ws.grad_w_o,
            &mut ws.grad_gate,
            &mut ws.grad_up,
            &mut ws.grad_down,
            &mut ws.grad_input_norm,
            &mut ws.grad_post_attn_norm,
        ];

        for grad in grads {
            let len = grad.len() as u32;
            // The launch result is discarded; a failed launch leaves that buffer unscaled.
            let _ = gradient_clip_cuda(grad, scale, len, stream);
        }
    }

    grad_norm
}
163
/// ALB-078: gradient clipping with the whole pipeline kept on the GPU.
///
/// Stands in for `clip_workspace_gradients` without that function's
/// `stream.synchronize()` and D2H partial-sum download. Three phases are queued on
/// `stream`:
///
/// 1. 9× squared-sum kernel, each writing its partials into its slot of
///    `state.partials_buf`
/// 2. 1× clip-scale reduce kernel, which reduces the partials and writes the clip
///    scale into `state.scale_buf`
/// 3. 9× gradient-clip kernel, each reading the scale from `state.scale_buf`
///
/// This function issues no sync and no D2H transfer itself. Every launch result is
/// discarded, so a failed launch is not reported to the caller.
#[cfg(feature = "cuda")]
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Phase 1: one squared-sum launch per non-empty buffer, each targeting its
    // pre-computed offset inside the contiguous partials buffer.
    {
        let grads: [&GpuBuffer<f32>; 9] = [
            &ws.grad_w_q,
            &ws.grad_w_k,
            &ws.grad_w_v,
            &ws.grad_w_o,
            &ws.grad_gate,
            &ws.grad_up,
            &ws.grad_down,
            &ws.grad_input_norm,
            &ws.grad_post_attn_norm,
        ];

        for (slot, grad) in grads.iter().enumerate() {
            let len = grad.len() as u32;
            if len != 0 {
                // Offsets are multiplied by 4 to turn them into a byte offset
                // (presumably counting 4-byte f32 elements — confirm against
                // `FusedClipState`).
                let dst = state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * 4;
                let _ = squared_sum_launch_into(grad, len, dst, stream);
            }
        }
    }

    // Phase 2: reduce every partial and compute clip_scale on the GPU. Being queued
    // on the same stream orders this after the squared-sum kernels above.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Phase 3: apply the scale to each non-empty gradient buffer. The kernels read
    // the scale through this device pointer (element 0 of scale_buf), so no D2H
    // round-trip is needed.
    let scale_ptr = state.scale_buf.as_ptr();
    let grads_mut: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads_mut {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, scale_ptr, len, stream);
        }
    }
}
236
/// R-004: report the workspace gradient L2 norm without modifying any gradient
/// (observability only).
///
/// Delegates to `compute_workspace_clip_scale_gpu` (GPU reduction, KAIZEN-054)
/// with `max_norm = f32::MAX` and keeps only the norm half of its result.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
246
/// ALB-072: multiply every gradient buffer in the shared workspace by `inv_scale`.
///
/// Under fp16 AMP the fused cross-entropy kernel folds loss_scale into its
/// gradient output, and every later backward gradient inherits that factor. The
/// GPU block optimizer (AdamW) has to see unscaled gradients; otherwise its second
/// moment `v` overflows f32 and early layers go NaN.
///
/// GPU-side counterpart of `GradScaler::unscale_and_check()`, which covers the CPU
/// embedding gradients.
///
/// Does nothing when `inv_scale` is within `1e-7` of `1.0`. Launch results are
/// discarded.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    let is_identity = (inv_scale - 1.0).abs() < 1e-7;
    if is_identity {
        return;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];

    // The scaling kernel used for clipping is reused here with `inv_scale` as the factor.
    for grad in grads {
        let len = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, len, stream);
    }
}
283
/// Device-side buffers and optimizer state used by the pretraining loop.
///
/// # Contract (C-GPUTRAIN-001)
///
/// - `layer_inputs.len() == num_layers`
/// - Every buffer is allocated once at init; training performs no GPU allocations
/// - `step` only ever increases
#[cfg(feature = "cuda")]
struct GpuPretrainState {
    /// Per-layer inputs kept for the backward pass: [num_layers][seq_len * hidden_size].
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Marks which layer inputs the forward pass saves (activation checkpointing).
    /// With checkpointing on, only checkpoint-boundary layers are saved; the others
    /// are recomputed from the nearest checkpoint before their backward pass.
    saved_layer_mask: Vec<bool>,
    /// Scratch buffer for activation recomputation [seq_len * hidden_size], used as
    /// the starting input when recomputing from a checkpoint boundary.
    /// `None` unless activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Weight of the final RMSNorm, resident on GPU [hidden_size].
    final_norm_weight: GpuBuffer<f32>,
    /// Output of the last block (pre-norm), needed by RMSNorm backward
    /// [seq_len * hidden_size].
    blocks_output: GpuBuffer<f32>,
    /// Gradient buffer A of the alternating A/B pair [seq_len * hidden_size].
    grad_buf_a: GpuBuffer<f32>,
    /// Gradient buffer B of the alternating A/B pair [seq_len * hidden_size].
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient of the final norm weight [hidden_size].
    grad_final_norm_weight: GpuBuffer<f32>,
    /// RMSNorm output, overwritten every step [seq_len * hidden_size].
    norm_output: GpuBuffer<f32>,
    /// Logits, overwritten every step [seq_len * vocab_size].
    logits_buf: GpuBuffer<f32>,
    /// Gradient w.r.t. the normed hidden state coming back from the LM head
    /// [seq_len * hidden_size].
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// One optimizer state per block (left empty on the NF4 path, which uses LoRA
    /// optimizer states instead).
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter.
    step: u32,
}
324
/// Transformer pretraining driver whose blocks live on the GPU.
///
/// Block forward/backward/optimizer_step run through `CudaTransformerBlock` on the
/// device, while a CPU `Transformer` is retained for embedding, saving and fallback.
///
/// NOTE(review): this header previously listed cross-entropy loss as CPU work. The
/// module-level transfer budget and the KAIZEN-050 note in `with_model` describe a
/// fused GPU cross-entropy kernel instead — confirm which is current.
///
/// # Contract (C-GPUTRAIN-002)
///
/// - Exactly 3 PCIe transfers per training step
/// - Graceful fallback to CPU `TransformerTrainer` on any CUDA failure
/// - Weight sync via `sync_weights_to_cpu()` before save
#[cfg(feature = "cuda")]
pub struct CudaTransformerTrainer {
    /// Host-side model (embedding, saving, fallback).
    model: Transformer,
    /// CUDA device context.
    cuda_trainer: CudaTrainer,
    /// Transformer blocks resident on GPU (fp32 or NF4, via the `CudaBlock` enum).
    cuda_blocks: Vec<CudaBlock>,
    /// Single gradient workspace shared by all layers (fp32 path only).
    cuda_grad_workspace: CudaGradWorkspace,
    /// ENT-263: scratch shared by NF4 blocks (C-SCRATCH-001). `None` for fp32.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// ENT-263: shared LoRA gradient workspace for NF4 backward. `None` for fp32.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// ENT-263: one LoRA optimizer state per block for NF4. `None` for fp32.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// GPU training state: layer inputs, gradient buffers, optimizer states.
    gpu_training: GpuPretrainState,
    /// LM head weight on GPU [vocab_size * hidden_size].
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// LM head weight gradient on GPU [vocab_size * hidden_size].
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// AdamW first moment for the LM head [vocab_size * hidden_size].
    lm_head_m: GpuBuffer<f32>,
    /// AdamW second moment for the LM head [vocab_size * hidden_size].
    lm_head_v: GpuBuffer<f32>,
    /// AdamW first moment for the final norm weight [hidden_size].
    final_norm_m: GpuBuffer<f32>,
    /// AdamW second moment for the final norm weight [hidden_size].
    final_norm_v: GpuBuffer<f32>,
    /// CPU optimizer, used for the embedding weights only.
    embed_optimizer: AdamW,
    /// Training configuration.
    config: TransformerTrainConfig,
    /// Metrics tracker.
    pub metrics: MetricsTracker,
    /// Current optimizer step.
    step: usize,
    /// Loss summed across micro-batches (gradient accumulation).
    accumulated_loss: f32,
    /// Number of micro-batches accumulated so far.
    accumulated_batches: usize,
    /// R-004: most recent LM head gradient L2 norm (proxy for the global grad norm).
    last_grad_norm: f32,
    /// R-040: most recent embedding activation gradient L2 norm.
    last_embed_grad_norm: f32,
    /// R-038: per-block gradient accumulation for true multi-step accumulation.
    /// Allocated only when accumulation_steps > 1. CPU-side buffers (~335 MB for 350M).
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// ALB-091: GPU-resident gradient accumulation (takes over from the CPU
    /// accumulator when available). Removes 24 × ga stream.synchronize() + D2H
    /// transfers per optimizer step.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// R-002: gradient scaler for mixed-precision training.
    /// BF16: no-op (scale=1.0, dynamic=false).
    /// FP16: dynamic loss scaling to prevent gradient underflow.
    grad_scaler: GradScaler,
    /// KAIZEN-047: per-step wall-clock profiler reporting a timing breakdown for
    /// each training phase.
    profiler: StepProfiler,
    /// KAIZEN-053: forward scratch buffers allocated up front
    /// [max_seq_len * hidden_size] and reused every step, removing
    /// 2 × cuMemAlloc/Free per training step.
    fwd_scratch_a: GpuBuffer<f32>,
    fwd_scratch_b: GpuBuffer<f32>,
    /// KAIZEN-056: host staging buffer for the H2D hidden-state upload, allocated
    /// up front to avoid a vec![0.0; max_seq_len * hidden_size] per step.
    h2d_staging: Vec<f32>,
    /// KAIZEN-059: host staging buffer for D2H gradient downloads during gradient
    /// accumulation, sized to max(h*intermediate, vocab*h). Avoids ~15GB of
    /// per-step heap churn (36 × vec![0.0; h*i] + vec![0.0; vocab*h] per
    /// micro-batch × accumulation_steps).
    d2h_staging: Vec<f32>,
    /// ALB-078: state allocated up front for the fused gradient clipping pipeline.
    /// Removes 24 stream.synchronize() calls per step.
    fused_clip: Option<FusedClipState>,
    /// Host buffer of zeros used to clear the final norm gradient [hidden_size].
    /// BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd, so the
    /// gradient buffer has to be zeroed before each rms_norm_backward call.
    final_norm_zero_buf: Vec<f32>,
}
414
415#[cfg(feature = "cuda")]
416impl CudaTransformerTrainer {
417    /// Create a new GPU-resident trainer.
418    ///
419    /// # Errors
420    ///
421    /// Returns `Err` if CUDA initialization, kernel pre-warming, or block upload fails.
422    /// Caller should fall back to CPU `TransformerTrainer` on error.
423    pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
424        let model = Transformer::new(&config.model_config);
425        Self::with_model(model, config)
426    }
427
428    /// ALB-089: Load SafeTensors checkpoint for GPU inference (forward-only).
429    ///
430    /// Creates a `CudaTransformerTrainer` in inference mode. The optimizer
431    /// state is allocated (wasteful but simple), but `forward_logits()` only
432    /// uses the forward path. Call `forward_logits(&tokens)` to generate.
433    ///
434    /// # Arguments
435    /// * `checkpoint_dir` - Directory containing model.safetensors + config.json
436    /// * `model_config` - Transformer architecture config
437    ///
438    /// # Errors
439    ///
440    /// Returns `Err` if SafeTensors loading or CUDA initialization fails.
441    pub fn for_inference(
442        checkpoint_dir: impl AsRef<std::path::Path>,
443        model_config: crate::transformer::TransformerConfig,
444    ) -> crate::Result<Self> {
445        let dir = checkpoint_dir.as_ref();
446
447        // ALB-089: Try APR format first (our native checkpoint format), then SafeTensors
448        let model = if let Some((Some(m), _step)) =
449            crate::config::try_load_apr_for_inference(dir, &model_config)
450        {
451            m
452        } else {
453            Transformer::from_safetensors(dir, &model_config)?
454        };
455
456        let mut config = TransformerTrainConfig::new(model_config);
457        config.max_seq_len = config.model_config.max_position_embeddings;
458        Self::with_model(model, config)
459    }
460
461    /// Create a GPU-resident trainer from an existing model.
462    ///
463    /// # Errors
464    ///
465    /// Returns `Err` if CUDA initialization fails.
466    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
467        if !cuda_training_available() {
468            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
469        }
470
471        let mc = &config.model_config;
472        let max_seq_len = config.max_seq_len;
473        let hidden_size = mc.hidden_size;
474        let vocab_size = mc.vocab_size;
475        let num_layers = mc.num_hidden_layers;
476
477        // Step 1: Create CUDA trainer (initializes kernel caches)
478        let cuda_trainer = CudaTrainer::new().map_err(|e| {
479            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
480        })?;
481
482        println!(
483            "  GPU: {} ({:.1} GB)",
484            cuda_trainer.device_name(),
485            cuda_trainer.total_memory() as f64 / 1e9
486        );
487
488        let ctx = cuda_trainer.context().clone();
489        let stream = cuda_trainer.stream();
490
491        // Step 2: Pre-warm forward kernels (C-PREWARM-001)
492        // Must happen before block upload — JIT compilation needs free VRAM
493        pre_warm_forward_kernels(
494            hidden_size,
495            mc.intermediate_size,
496            mc.num_attention_heads,
497            mc.num_kv_heads,
498            mc.head_dim(),
499            max_seq_len,
500        )
501        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
502
503        // Step 2a: Pre-warm backward kernels (trueno#200)
504        // MUST happen before any GPU work — Blackwell's cuModuleLoadData fails
505        // with ILLEGAL_ADDRESS when called during active GPU computation.
506        {
507            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
508            let head_dim = mc.head_dim();
509            pre_warm_lora_backward_kernels(
510                hidden_size,
511                mc.num_attention_heads * head_dim,
512                mc.num_kv_heads * head_dim,
513                max_seq_len,
514                config.lora_rank.unwrap_or(0),
515                mc.intermediate_size,
516                mc.num_attention_heads,
517                config.quantize_nf4 && config.is_lora(),
518            )
519            .map_err(|e| {
520                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
521            })?;
522            eprintln!("  ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
523        }
524
525        // Step 2b: Bind cuBLAS handles to training stream (ALB-075)
526        // Must happen after kernel cache init, before any GEMM calls.
527        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
528            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
529        }
530        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
531            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
532        }
533
534        // Step 3: Upload transformer blocks to GPU
535        let use_nf4 = config.quantize_nf4 && config.is_lora();
536        let cuda_blocks = Self::upload_blocks(
537            &model,
538            mc,
539            &config,
540            &ctx,
541            use_nf4,
542            num_layers,
543            hidden_size,
544            max_seq_len,
545        )?;
546
547        // Step 4: Allocate shared gradient workspace
548        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
549            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
550        })?;
551
552        // Step 5: Allocate GPU training state
553        let buf_size = max_seq_len * hidden_size;
554        let logits_size = max_seq_len * vocab_size;
555
556        // Activation checkpointing: determine which layers save their inputs.
557        // Checkpoint boundary layers (every segment_size layers) are always saved.
558        // Non-boundary layers are recomputed from the nearest checkpoint during backward.
559        let checkpointing = config.checkpoint_config.enabled;
560        let segment_size = if checkpointing {
561            let ns = config.checkpoint_config.num_segments.max(1);
562            num_layers.div_ceil(ns)
563        } else {
564            1 // Every layer is a checkpoint (no recomputation)
565        };
566        let saved_layer_mask: Vec<bool> =
567            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
568
569        let mut layer_inputs = Vec::with_capacity(num_layers);
570        for _ in 0..num_layers {
571            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
572                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
573            })?);
574        }
575
576        // Allocate recompute buffer if checkpointing is enabled
577        let recompute_buf = if checkpointing {
578            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
579                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
580            })?)
581        } else {
582            None
583        };
584
585        if checkpointing {
586            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
587            println!(
588                "  ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
589                config.checkpoint_config.num_segments, saved_count, num_layers
590            );
591        }
592
593        // Upload final RMSNorm weight
594        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
595        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
596            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
597        })?;
598
599        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
600            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
601        })?;
602        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
603            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
604        })?;
605        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
606            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
607        })?;
608        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
609            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
610        })?;
611        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
612            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
613        })?;
614        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
615            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
616        })?;
617        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
618            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
619        })?;
620
621        // Initialize per-block optimizer states (fp32 path only; NF4 uses LoRA states)
622        let mut optimizer_states = Vec::new();
623        if !use_nf4 {
624            optimizer_states.reserve(num_layers);
625            for (i, block) in cuda_blocks.iter().enumerate() {
626                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
627                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
628                })?);
629            }
630        }
631
632        let gpu_training = GpuPretrainState {
633            layer_inputs,
634            saved_layer_mask,
635            recompute_buf,
636            final_norm_weight,
637            blocks_output,
638            grad_buf_a,
639            grad_buf_b,
640            grad_final_norm_weight,
641            norm_output,
642            logits_buf,
643            lm_head_grad_hidden,
644            optimizer_states,
645            step: 0,
646        };
647
648        // Step 6: Upload LM head weight to GPU
649        // Use tied weights (embed_tokens.weight) or separate lm_head
650        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
651        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
652        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
653            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
654        })?;
655        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
656            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
657        })?;
658        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
659        // zero memory (cuMemAlloc returns uninitialized VRAM).
660        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
661            .map_err(|e| {
662                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
663            })?;
664        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
665            .map_err(|e| {
666                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
667            })?;
668
669        // Final norm optimizer states
670        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
671            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
672        })?;
673        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
674            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
675        })?;
676
677        // KAIZEN-053: Pre-allocate forward scratch buffers (reused every step)
678        let buf_size = max_seq_len * hidden_size;
679        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
680            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
681        })?;
682        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
683            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
684        })?;
685
686        // Sync to ensure all uploads completed
687        stream
688            .synchronize()
689            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
690
691        println!(
692            "  ✓ GPU training state allocated (LM head: {:.1} MB)",
693            (vocab_size * hidden_size * 4) as f64 / 1e6
694        );
695
696        // ENT-263: Allocate NF4 infrastructure (shared scratch, LoRA grad workspace, optimizer states)
697        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
698            let lora_rank = config.lora_rank.unwrap_or(16);
699
700            // C-SCRATCH-001: Shared scratch for NF4 blocks (reused across all layers)
701            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
702                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
703            })?;
704
705            // LoRA gradient workspace (shared, reused per-block like CudaGradWorkspace)
706            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
707                crate::error::Error::ConfigError(format!(
708                    "NF4 LoRA grad workspace alloc failed: {e:?}"
709                ))
710            })?;
711
712            // Per-block LoRA optimizer states
713            let mut lora_opt_states = Vec::with_capacity(num_layers);
714            for (i, block) in cuda_blocks.iter().enumerate() {
715                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
716                    crate::error::Error::ConfigError(format!(
717                        "Block {i} LoRA opt state failed: {e:?}"
718                    ))
719                })?);
720            }
721
722            println!(
723                "  ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
724            );
725            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
726        } else {
727            (None, None, None)
728        };
729
730        // KAIZEN-050: loss_fn removed — cross-entropy computed by fused GPU kernel
731        // C-EMBED-GRAD-001: CPU optimizer must match YAML hyperparams (not defaults)
732        let embed_optimizer =
733            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
734
735        // R-038: Allocate per-block gradient accumulation buffers (CPU-side)
736        // when accumulation_steps > 1 for true gradient accumulation.
737        let grad_accum = if config.accumulation_steps > 1 {
738            let kv_hidden = mc.num_kv_heads * mc.head_dim();
739            let block_sizes =
740                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
741                    hidden_size,
742                    kv_hidden,
743                    mc.intermediate_size,
744                );
745            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
746                num_layers,
747                block_sizes,
748                vocab_size,
749                hidden_size,
750            );
751            println!(
752                "  ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
753                config.accumulation_steps,
754                (accum
755                    .block_grads
756                    .iter()
757                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
758                    .sum::<usize>()
759                    + accum.lm_head_grad.len()
760                    + accum.final_norm_grad.len()
761                    + accum.embedding_grad.len()) as f64
762                    * 4.0
763                    / 1e6,
764            );
765            Some(accum)
766        } else {
767            None
768        };
769
770        // ALB-091: GPU-resident gradient accumulation (eliminates D2H bottleneck).
771        // Falls back to CPU accum if GPU allocation fails.
772        let gpu_grad_accum = if config.accumulation_steps > 1 {
773            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
774                Ok(accum) => {
775                    println!("  ✓ GPU gradient accumulation enabled (ALB-091)");
776                    Some(accum)
777                }
778                Err(e) => {
779                    eprintln!(
780                        "  [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
781                    );
782                    None
783                }
784            }
785        } else {
786            None
787        };
788
789        // KAIZEN-059: Pre-allocate D2H staging buffer for gradient accumulation
790        // downloads. Only needed when GPU accum is unavailable (CPU fallback path).
791        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
792            let ws_max = hidden_size * mc.intermediate_size;
793            let lm_max = vocab_size * hidden_size;
794            vec![0.0f32; ws_max.max(lm_max)]
795        } else {
796            Vec::new()
797        };
798
799        // ALB-078: Pre-allocate fused gradient clipping state.
800        // Eliminates 24 stream syncs per step by keeping norm+clip on GPU.
801        let kv_hidden = mc.num_kv_heads * mc.head_dim();
802        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
803
804        // R-002: Initialize gradient scaler from precision config
805        let grad_scaler = GradScaler::from_config(&config.precision_config);
806        if config.precision_config.is_mixed() {
807            println!(
808                "  ✓ Mixed precision: {} (loss scale={}, dynamic={})",
809                config.precision_config.compute_precision,
810                grad_scaler.scale(),
811                grad_scaler.is_dynamic(),
812            );
813        }
814
815        Ok(Self {
816            model,
817            cuda_trainer,
818            cuda_blocks,
819            cuda_grad_workspace,
820            nf4_shared_scratch,
821            nf4_lora_grad_workspace,
822            nf4_lora_optimizer_states,
823            gpu_training,
824            lm_head_weight_gpu,
825            lm_head_grad_gpu,
826            lm_head_m,
827            lm_head_v,
828            final_norm_m,
829            final_norm_v,
830            embed_optimizer,
831            // KAIZEN-047: Read profile_interval before moving config into struct.
832            profiler: if config.profile_interval > 0 {
833                StepProfiler::new(true, config.profile_interval)
834            } else {
835                StepProfiler::disabled()
836            },
837            config,
838            metrics: MetricsTracker::new(),
839            step: 0,
840            accumulated_loss: 0.0,
841            accumulated_batches: 0,
842            last_grad_norm: 0.0,
843            last_embed_grad_norm: 0.0,
844            grad_accum,
845            gpu_grad_accum,
846            grad_scaler,
847            fwd_scratch_a,
848            fwd_scratch_b,
849            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
850            d2h_staging,
851            fused_clip,
852            final_norm_zero_buf: vec![0.0f32; hidden_size],
853        })
854    }
855
856    /// Upload transformer blocks to GPU (NF4 or fp32 path).
857    #[allow(clippy::too_many_arguments)]
858    fn upload_blocks(
859        model: &Transformer,
860        mc: &crate::transformer::TransformerConfig,
861        config: &TransformerTrainConfig,
862        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
863        use_nf4: bool,
864        num_layers: usize,
865        hidden_size: usize,
866        max_seq_len: usize,
867    ) -> crate::Result<Vec<CudaBlock>> {
868        let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
869
870        if use_nf4 {
871            let lora_rank = config.lora_rank.unwrap_or(16);
872            let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
873            let lora_scale = lora_alpha / lora_rank as f32;
874            let head_dim = mc.head_dim();
875            let q_dim = mc.num_attention_heads * head_dim;
876            let kv_hidden = mc.num_kv_heads * head_dim;
877
878            for (i, layer) in model.layers.iter().enumerate() {
879                let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
880                    .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
881                    .collect();
882                let lora_b_q = vec![0.0f32; lora_rank * q_dim];
883                let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
884                    .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
885                    .collect();
886                let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
887
888                let q_norm_data = layer
889                    .self_attn
890                    .q_norm
891                    .as_ref()
892                    .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
893                let k_norm_data = layer
894                    .self_attn
895                    .k_norm
896                    .as_ref()
897                    .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
898
899                let block = crate::transformer::CudaNf4TransformerBlock::new(
900                    mc,
901                    i,
902                    ctx.clone(),
903                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
904                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
905                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
906                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
907                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
908                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
909                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
910                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
911                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
912                    max_seq_len,
913                    Some((&lora_a_q, &lora_b_q)),
914                    Some((&lora_a_v, &lora_b_v)),
915                    lora_scale,
916                    lora_rank,
917                    q_norm_data.as_deref(),
918                    k_norm_data.as_deref(),
919                )
920                .map_err(|e| {
921                    crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
922                })?;
923                cuda_blocks.push(CudaBlock::Nf4(block));
924            }
925            println!("  ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
926        } else {
927            for (i, layer) in model.layers.iter().enumerate() {
928                let block = CudaTransformerBlock::new(
929                    mc,
930                    i,
931                    ctx.clone(),
932                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
933                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
934                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
935                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
936                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
937                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
938                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
939                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
940                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
941                    max_seq_len,
942                )
943                .map_err(|e| {
944                    crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
945                })?;
946                cuda_blocks.push(CudaBlock::Fp32(block));
947            }
948            println!("  ✓ {num_layers} transformer blocks uploaded to GPU");
949        }
950
951        Ok(cuda_blocks)
952    }
953
954    /// ALB-078: Initialize fused gradient clipping state (extracted for complexity).
955    fn init_fused_clip(
956        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
957        config: &TransformerTrainConfig,
958        hidden_size: usize,
959        kv_hidden: usize,
960        mc: &crate::transformer::TransformerConfig,
961    ) -> Option<FusedClipState> {
962        config.base.max_grad_norm?;
963        let grad_sizes: [u32; 9] = [
964            (hidden_size * hidden_size) as u32,
965            (hidden_size * kv_hidden) as u32,
966            (hidden_size * kv_hidden) as u32,
967            (hidden_size * hidden_size) as u32,
968            (hidden_size * mc.intermediate_size) as u32,
969            (hidden_size * mc.intermediate_size) as u32,
970            (mc.intermediate_size * hidden_size) as u32,
971            hidden_size as u32,
972            hidden_size as u32,
973        ];
974        match FusedClipState::new(ctx, &grad_sizes) {
975            Ok(state) => {
976                println!(
977                    "  ✓ Fused gradient clipping: {} partials ({:.1} KB)",
978                    state.total_partials,
979                    f64::from(state.total_partials) * 4.0 / 1024.0,
980                );
981                Some(state)
982            }
983            Err(e) => {
984                println!("  ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
985                None
986            }
987        }
988    }
989
990    /// Run one forward+backward step for a single sequence.
991    ///
992    /// # Contract (C-GPUSTEP-001)
993    ///
994    /// - Precondition: `input_ids.len() == target_ids.len() <= max_seq_len`
995    /// - Postcondition: If `accumulate_only` is false, all GPU weights updated.
996    ///   If true, gradients accumulated into CPU buffers (no weight updates).
997    /// - Transfer count: 1 PCIe H2D + ~1KB control (KAIZEN-050, + 24×9 D2H if accumulating)
998    fn train_step_single(
999        &mut self,
1000        input_ids: &[u32],
1001        target_ids: &[u32],
1002        accumulate_only: bool,
1003    ) -> Option<f32> {
1004        self.profiler.begin_step();
1005        let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1006        self.profiler.finish_step();
1007        result
1008    }
1009
1010    /// Inner training step — separated so profiler always records the step.
1011    fn train_step_inner(
1012        &mut self,
1013        input_ids: &[u32],
1014        target_ids: &[u32],
1015        accumulate_only: bool,
1016    ) -> Option<f32> {
1017        let hidden_size = self.config.model_config.hidden_size;
1018        let vocab_size = self.config.model_config.vocab_size;
1019
1020        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
1021        let max_sl = self.config.max_seq_len;
1022        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1023        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1024        let seq_len = input_ids.len();
1025
1026        // Steps 1-6: GPU forward pass — logits stay GPU-resident (KAIZEN-050)
1027        // (sub-phases embed, h2d, forward, norm_lm instrumented inside gpu_forward)
1028        if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1029            eprintln!(
1030                "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1031                 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1032            );
1033            return None;
1034        }
1035
1036        // Step 7: Fused GPU cross-entropy loss + softmax backward (KAIZEN-050)
1037        // Eliminates: logits D2H (77.8MB) + CPU softmax (40ms) + grad H2D (77.8MB)
1038        self.profiler.begin(StepProfiler::LOSS);
1039        let stream = self.cuda_trainer.stream();
1040
1041        // Compute combined scale: (1/seq_len) * (1/accum_steps)
1042        //
1043        // ALB-072: Do NOT multiply by grad_scaler.scale() here. All backward
1044        // computation uses f32 GpuBuffers — there is no fp16 gradient underflow
1045        // risk. The 65536x loss scaling caused gradient overflow in early layers
1046        // (blocks 0-1 went NaN). The GradScaler remains active for the CPU
1047        // embedding path (unscale_and_check in optimizer_step) as a safety check,
1048        // but it operates with scale=1.0 effective for GPU gradients.
1049        let mut loss_scale = 1.0 / seq_len as f32;
1050        if self.config.accumulation_steps > 1 {
1051            loss_scale /= self.config.accumulation_steps as f32;
1052        }
1053
1054        // KAIZEN-052: In-place — gradient written directly to logits_buf.
1055        let loss_val = fused_cross_entropy_cuda(
1056            &mut self.gpu_training.logits_buf,
1057            target_ids,
1058            seq_len as u32,
1059            vocab_size as u32,
1060            loss_scale,
1061            stream,
1062        )
1063        .ok()?;
1064
1065        // NaN guard (replaces logits NaN check — NaN logits → NaN loss via kernel)
1066        if !loss_val.is_finite() {
1067            return None;
1068        }
1069        self.profiler.end(StepProfiler::LOSS);
1070
1071        // Steps 8-11: GPU backward pass (with or without optimizer)
1072        // (sub-phases lm_bwd, norm_bwd, blk_bwd instrumented inside gpu_backward)
1073        // KAIZEN-050: grad_logits on GPU. KAIZEN-052: grad lives in logits_buf (in-place).
1074        //
1075        // ENT-263 fix: Capture loss regardless of backward success. The NF4 backward
1076        // path may fail (e.g., gemm_nf4_backward_a stub) but the loss was already
1077        // computed by fused_cross_entropy_cuda. Dropping the loss silently causes
1078        // loss=0.0 reporting despite valid forward passes.
1079        if let Some(grad_output_is_a) =
1080            self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1081        {
1082            // Step 12: Embedding backward (CPU scatter-add always accumulates)
1083            self.profiler.begin(StepProfiler::EMBED_BWD);
1084            self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1085
1086            self.profiler.end(StepProfiler::EMBED_BWD);
1087        }
1088
1089        Some(loss_val)
1090    }
1091
    /// GPU forward pass: embed → blocks → norm → LM head.
    ///
    /// Logits stay GPU-resident in `self.gpu_training.logits_buf` (KAIZEN-050).
    /// Transfers: 1 H2D (hidden states). No D2H — logits consumed by fused kernel.
    ///
    /// Steps, in order:
    /// 1. CPU embedding lookup for `input_ids`.
    /// 2. Copy the embedding into `h2d_staging`, zero the rest of the staging
    ///    buffer, and upload the whole staging buffer into `fwd_scratch_a`.
    /// 3. Run every block, alternating `fwd_scratch_a` / `fwd_scratch_b` as
    ///    input and output. Before a block runs, its input is copied into
    ///    `layer_inputs[i]` when `saved_layer_mask[i]` is set.
    /// 4. Copy the last block's output into `blocks_output`, then run the final
    ///    RMSNorm into `norm_output` and the LM head GEMM into `logits_buf`.
    ///
    /// Returns `Some(())` on success. Returns `None` when the embedding data is
    /// not available as a slice or when any copy / kernel call returns an error.
    ///
    /// # Panics
    ///
    /// Panics if the embedding output is longer than `h2d_staging` (slice
    /// indexing). Presumably the embedding yields `seq_len × hidden_size` values
    /// and callers keep `seq_len <= max_seq_len` — verify against callers.
    #[allow(unsafe_code)]
    fn gpu_forward(
        &mut self,
        input_ids: &[u32],
        seq_len: usize,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Option<()> {
        contract_pre_gpu_forward!();
        let stream = self.cuda_trainer.stream();

        // Embedding lookup (CPU)
        self.profiler.begin(StepProfiler::EMBED);
        let hidden = self.model.embed_tokens.forward(input_ids);
        // `?` bails out with None when the data cannot be viewed as a slice.
        let hidden_slice = hidden.data().as_slice()?;
        self.profiler.end(StepProfiler::EMBED);

        // Upload hidden states to GPU (Transfer 1: H2D)
        // Pad to max_seq_len so D2D copies to pre-allocated layer_inputs match.
        // KAIZEN-053: Reuse pre-allocated scratch buffers instead of cuMemAlloc per step.
        // KAIZEN-056: Reuse pre-allocated h2d_staging instead of alloc per step.
        self.profiler.begin(StepProfiler::H2D);
        self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
        // Zero the tail so stale values from a previous, longer step are not uploaded.
        self.h2d_staging[hidden_slice.len()..].fill(0.0);
        if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
            eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
            return None;
        }
        self.profiler.end(StepProfiler::H2D);

        // Forward through CUDA blocks using pre-allocated ping-pong buffers.
        // KAIZEN-053: fwd_scratch_a/b are top-level fields (not in gpu_training)
        // so borrowing them doesn't conflict with gpu_training.layer_inputs.
        self.profiler.begin(StepProfiler::FORWARD);
        let mut input_is_a = true; // Track which scratch buffer is "input"
        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
            // Use raw pointers for the ping-pong to avoid borrow conflicts
            // with self.gpu_training.layer_inputs
            let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
                if input_is_a {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_a),
                        std::ptr::from_mut(&mut self.fwd_scratch_b),
                    )
                } else {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_b),
                        std::ptr::from_mut(&mut self.fwd_scratch_a),
                    )
                };
            // Only layers flagged in saved_layer_mask keep a copy of their input.
            if self.gpu_training.saved_layer_mask[i] {
                // SAFETY: Both buffers are valid GPU allocations with matching max_seq_len size.
                // Copy completes before block.forward() reads from input (same stream ordering).
                unsafe {
                    self.gpu_training.layer_inputs[i]
                        .copy_from_buffer_async(&*input_ptr, stream)
                        .ok()?;
                }
            }
            // ENT-263: Pass shared scratch for NF4 blocks (C-SCRATCH-001).
            self.profiler.begin_layer();
            // SAFETY: input_ptr and output_ptr point to disjoint fwd_scratch_{a,b},
            // both of which are fields of `self` and outlive this call.
            unsafe {
                block
                    .forward(
                        &*input_ptr,
                        &mut *output_ptr,
                        seq_len,
                        stream,
                        self.nf4_shared_scratch.as_mut(),
                    )
                    .ok()?;
            }
            self.profiler.end_layer_fwd(i);
            // This block's output becomes the next block's input.
            input_is_a = !input_is_a;
        }
        self.profiler.end(StepProfiler::FORWARD);

        // After the loop, input_is_a tells us which buffer has the final output
        let final_output: &GpuBuffer<f32> =
            if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };

        // Save blocks output for final norm backward
        // SAFETY: Disjoint GPU buffers with matching max_seq_len sizes.
        self.profiler.begin(StepProfiler::NORM_LM);
        unsafe {
            self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
        }

        // Final RMSNorm forward (GPU)
        rms_norm_forward(
            final_output,
            &self.gpu_training.final_norm_weight,
            &mut self.gpu_training.norm_output,
            seq_len as u32,
            hidden_size as u32,
            stream,
        )
        .ok()?;

        // LM head GEMM forward (GPU)
        // gemm_forward treats flat (V,H) memory as (H,V) row-major, which
        // implicitly transposes — matching the CPU matmul's tied-weight behavior.
        gemm_forward(
            &self.gpu_training.norm_output,
            &self.lm_head_weight_gpu,
            &mut self.gpu_training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        )
        .ok()?;

        // KAIZEN-050: Logits stay GPU-resident — no D2H transfer.
        // Fused cross-entropy kernel reads logits_buf directly on GPU.
        self.profiler.end(StepProfiler::NORM_LM);

        Some(())
    }
1216
1217    /// ALB-089: Forward-only pass that returns last-position logits on CPU.
1218    ///
1219    /// Runs the same GPU forward as training but downloads only the last
1220    /// position's logits (vocab_size floats) for token sampling. No backward
1221    /// pass, no loss computation.
1222    ///
1223    /// # Contract (C-CUDA-INF-001)
1224    ///
1225    /// - Same forward path as `gpu_forward()` — identical logits
1226    /// - Only downloads `logits[seq_len-1, :]` (128 KB for 32K vocab)
1227    /// - stream.synchronize() before D2H (C-STREAMSYNC-001)
1228    pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1229        let seq_len = input_ids.len();
1230        let hidden_size = self.config.model_config.hidden_size;
1231        let vocab_size = self.config.model_config.vocab_size;
1232
1233        if seq_len == 0 || seq_len > self.config.max_seq_len {
1234            return None;
1235        }
1236
1237        // Reuse gpu_forward for the actual computation
1238        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1239
1240        // C-STREAMSYNC-001: synchronize before D2H
1241        let stream = self.cuda_trainer.stream();
1242        stream.synchronize().ok()?;
1243
1244        // Download last position logits only: logits_buf[seq_len-1, :]
1245        let offset = (seq_len - 1) * vocab_size;
1246        let mut logits = vec![0.0f32; vocab_size];
1247        self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1248
1249        Some(logits)
1250    }
1251
1252    /// GPU backward pass with interleaved per-block optimizer step.
1253    ///
1254    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
1255    /// Recompute layer inputs for a segment during backward (activation checkpointing).
1256    ///
1257    /// When checkpointing is enabled, non-checkpoint layers don't save their inputs
1258    /// during forward. Before their backward pass, we recompute from the nearest
1259    /// checkpoint by re-running forward through intermediate blocks.
1260    ///
1261    /// This recomputes the entire segment [checkpoint..=target_layer], storing
1262    /// intermediate layer_inputs so subsequent layers in the same segment don't
1263    /// need redundant recomputation.
1264    ///
1265    /// # Contract (R-021)
1266    ///
1267    /// After this call, `layer_inputs[i]` is valid for all i in [checkpoint..=target_layer].
1268    #[allow(unsafe_code)]
1269    fn recompute_segment(
1270        gpu_training: &mut GpuPretrainState,
1271        cuda_blocks: &mut [CudaBlock],
1272        nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1273        target_layer: usize,
1274        seq_len: usize,
1275        stream: &CudaStream,
1276    ) -> Option<()> {
1277        // Find nearest saved checkpoint at or before target
1278        let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1279
1280        if seg_start == target_layer {
1281            return Some(()); // Already saved
1282        }
1283
1284        // Copy checkpoint input to recompute_buf as starting point.
1285        // SAFETY: recompute_buf and layer_inputs are disjoint allocations.
1286        let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1287        unsafe {
1288            recompute_buf
1289                .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1290                .ok()?;
1291        }
1292
1293        // Forward through blocks [seg_start..target_layer], saving intermediate inputs.
1294        // For block i, input → block i → output becomes input for block i+1.
1295        // We save output (= input to block i+1) in layer_inputs[i+1].
1296        //
1297        // Buffer pattern:
1298        //   i == seg_start: input = recompute_buf, output = layer_inputs[seg_start+1]
1299        //   i > seg_start:  input = layer_inputs[i], output = layer_inputs[i+1]
1300        //
1301        // SAFETY: split_at_mut ensures non-overlapping borrows of layer_inputs.
1302        // recompute_buf is separate from layer_inputs.
1303        for i in seg_start..target_layer {
1304            if i == seg_start {
1305                // Input is in recompute_buf, output goes to layer_inputs[i+1]
1306                let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1307                let li = &mut gpu_training.layer_inputs;
1308                unsafe {
1309                    cuda_blocks[i]
1310                        .forward(
1311                            &*recompute_ptr,
1312                            &mut li[i + 1],
1313                            seq_len,
1314                            stream,
1315                            nf4_shared_scratch.as_mut(),
1316                        )
1317                        .ok()?;
1318                }
1319            } else {
1320                // Input = layer_inputs[i], output = layer_inputs[i+1]
1321                let li = &mut gpu_training.layer_inputs;
1322                let (left, right) = li.split_at_mut(i + 1);
1323                cuda_blocks[i]
1324                    .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1325                    .ok()?;
1326            }
1327        }
1328
1329        Some(())
1330    }
1331
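    /// GPU backward pass with interleaved per-block optimizer step.
    ///
    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
    /// Since `gemm_backward_b` overwrites (not accumulates), we must run each block's
    /// optimizer step immediately after its backward, before the next block overwrites
    /// the workspace. This also enables per-block gradient clipping.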
1335    ///
1336    /// When `accumulate_only` is true (R-038 gradient accumulation), the per-block
1337    /// optimizer steps are skipped and workspace gradients are downloaded to CPU-side
1338    /// `PerBlockGradientAccumulator` instead. LM head and final norm gradients are
1339    /// also downloaded and accumulated. The optimizer step is deferred until
1340    /// `gpu_optimizer_from_accum()` is called.
1341    ///
1342    /// Returns `grad_output_is_a` flag for embedding backward.
1343    /// Transfer: 0 H2D (KAIZEN-050/052: grad in logits_buf) + 24×9 D2H if accumulating.
1344    #[allow(unsafe_code)]
1345    fn gpu_backward(
1346        &mut self,
1347        seq_len: usize,
1348        hidden_size: usize,
1349        vocab_size: usize,
1350        accumulate_only: bool,
1351    ) -> Option<bool> {
1352        let stream = self.cuda_trainer.stream();
1353        let max_grad_norm = self.config.base.max_grad_norm;
1354        let lr = self.current_lr();
1355        // ALB-072: No inv_scale needed — loss_scale no longer includes grad_scaler.
1356        let beta1 = self.config.beta1;
1357        let beta2 = self.config.beta2;
1358        let weight_decay = self.config.weight_decay;
1359
1360        // KAIZEN-050: grad_logits GPU-resident. KAIZEN-052: grad lives in logits_buf (in-place).
1361        // No separate grad buffer. No GRAD_H2D transfer.
1362
1363        // LM head GEMM backward
1364        self.profiler.begin(StepProfiler::LM_BWD);
1365        gemm_backward_a(
1366            &self.gpu_training.logits_buf,
1367            &self.lm_head_weight_gpu,
1368            &mut self.gpu_training.lm_head_grad_hidden,
1369            seq_len as u32,
1370            hidden_size as u32,
1371            vocab_size as u32,
1372            stream,
1373        )
1374        .ok()?;
1375
1376        gemm_backward_b(
1377            &self.gpu_training.norm_output,
1378            &self.gpu_training.logits_buf,
1379            &mut self.lm_head_grad_gpu,
1380            seq_len as u32,
1381            hidden_size as u32,
1382            vocab_size as u32,
1383            stream,
1384        )
1385        .ok()?;
1386
1387        // Clip LM head weight gradient
1388        // KAIZEN-049: GPU norm reduction.
1389        // KAIZEN-051: No explicit sync needed — same stream ordering.
1390        // ALB-071: Always compute LM head grad norm for observability (R-004).
1391        // C-CLIP-001: squared_sum_cuda returns ||g||². Take sqrt for L2 norm (entrenar#311).
1392        let lm_sq_norm =
1393            squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1394                .unwrap_or(0.0);
1395        let lm_norm = lm_sq_norm.sqrt(); // L2 norm, NOT squared
1396        self.last_grad_norm = lm_norm; // R-004: capture for observability
1397                                       // C-BACKPARITY-001: LM head gradient norm tracing (pre-clip).
1398        if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1399            eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1400            // Also trace the grad_hidden flowing to blocks
1401            let gh_sq = squared_sum_cuda(
1402                &self.gpu_training.lm_head_grad_hidden,
1403                self.gpu_training.lm_head_grad_hidden.len() as u32,
1404                stream,
1405            )
1406            .unwrap_or(0.0);
1407            eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1408        }
1409        if let Some(max_norm) = max_grad_norm {
1410            let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1411            let n = self.lm_head_grad_gpu.len() as u32;
1412            let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1413        }
1414        self.profiler.end(StepProfiler::LM_BWD);
1415
1416        // Final RMSNorm backward
1417        self.profiler.begin(StepProfiler::NORM_BWD);
1418        // Zero grad_final_norm_weight before backward — kernel accumulates via atomicAdd
1419        self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1420        rms_norm_backward(
1421            &self.gpu_training.blocks_output,
1422            &self.gpu_training.final_norm_weight,
1423            &self.gpu_training.lm_head_grad_hidden,
1424            &mut self.gpu_training.grad_buf_a,
1425            &mut self.gpu_training.grad_final_norm_weight,
1426            seq_len as u32,
1427            hidden_size as u32,
1428            1e-5_f32,
1429            stream,
1430        )
1431        .ok()?;
1432
1433        // Clip final norm weight gradient
1434        // KAIZEN-051: No explicit sync needed — same stream ordering as LM head clip.
1435        if let Some(max_norm) = max_grad_norm {
1436            let (scale, _) = Self::compute_clip_scale_with_norm(
1437                &self.gpu_training.grad_final_norm_weight,
1438                max_norm,
1439                stream,
1440            );
1441            let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1442            let _ =
1443                gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1444        }
1445        self.profiler.end(StepProfiler::NORM_BWD);
1446
1447        // R-038: Either accumulate non-block grads or run non-block optimizer.
1448        if accumulate_only {
1449            // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1450            if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1451                let _ = gpu_accum.accumulate_nonblock(
1452                    &self.lm_head_grad_gpu,
1453                    &self.gpu_training.grad_final_norm_weight,
1454                    stream,
1455                );
1456            } else {
1457                stream.synchronize().ok()?;
1458                Self::download_nonblock_grads_to_accum(
1459                    &self.lm_head_grad_gpu,
1460                    &self.gpu_training.grad_final_norm_weight,
1461                    &mut self.grad_accum,
1462                    &mut self.d2h_staging,
1463                )?;
1464            }
1465        } else {
1466            Self::run_nonblock_optimizer_step(
1467                &mut self.gpu_training,
1468                Some(&mut self.lm_head_weight_gpu),
1469                &self.lm_head_grad_gpu,
1470                &mut self.lm_head_m,
1471                &mut self.lm_head_v,
1472                &mut self.final_norm_m,
1473                &mut self.final_norm_v,
1474                lr,
1475                beta1,
1476                beta2,
1477                weight_decay,
1478                stream,
1479            );
1480        }
1481
1482        // Backward through blocks in reverse, with interleaved clip + optimizer.
1483        // Each block's backward writes weight gradients to shared CudaGradWorkspace.
1484        //
1485        // SAFETY: grad_buf_a and grad_buf_b are disjoint fields. Raw pointers
1486        // allow alternating read/write without violating aliasing rules.
1487        self.profiler.begin(StepProfiler::BLK_BWD);
1488        let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1489        let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1490        let mut grad_output_is_a = true;
1491        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1492
1493        for layer_idx in (0..self.cuda_blocks.len()).rev() {
1494            // Activation checkpointing: if this layer's input wasn't saved during
1495            // forward, recompute the segment from the nearest checkpoint.
1496            if !self.gpu_training.saved_layer_mask[layer_idx] {
1497                Self::recompute_segment(
1498                    &mut self.gpu_training,
1499                    &mut self.cuda_blocks,
1500                    &mut self.nf4_shared_scratch,
1501                    layer_idx,
1502                    seq_len,
1503                    stream,
1504                )?;
1505            }
1506
1507            let (grad_output, grad_input) = unsafe {
1508                if grad_output_is_a {
1509                    (&*grad_a_ptr, &mut *grad_b_ptr)
1510                } else {
1511                    (&*grad_b_ptr, &mut *grad_a_ptr)
1512                }
1513            };
1514
1515            self.profiler.begin_layer();
1516            if use_nf4 {
1517                // ENT-263: NF4 backward — LoRA gradient computation
1518                // Uses backward_nf4() which computes gradients for LoRA weights and norms only.
1519                // output_scratch reuses grad_buf_a/b as temporary storage for recomputed forward.
1520                let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1521                    grad_b_ptr // grad_input is in b, use as output_scratch too (will be overwritten)
1522                } else {
1523                    grad_a_ptr
1524                };
1525                // We need a separate output_scratch. Reuse blocks_output as scratch since
1526                // it was already consumed for norm backward above.
1527                match self.cuda_blocks[layer_idx].backward_nf4(
1528                    &self.gpu_training.layer_inputs[layer_idx],
1529                    grad_output,
1530                    grad_input,
1531                    &mut self.gpu_training.blocks_output, // reuse as output_scratch
1532                    seq_len,
1533                    stream,
1534                    self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1535                    self.nf4_lora_grad_workspace
1536                        .as_mut()
1537                        .expect("NF4 requires LoRA grad workspace"),
1538                ) {
1539                    Ok(()) => {}
1540                    Err(e) => {
1541                        eprintln!(
1542                            "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1543                            layer_idx, e, seq_len, self.config.model_config.hidden_size
1544                        );
1545                        return None;
1546                    }
1547                }
1548
1549                // ENT-265: Clip LoRA gradients before optimizer step.
1550                // Without this, NF4 LoRA grads are unbounded — causes weight
1551                // divergence and embedding grad explosion (Run 7c: 26M at step 225).
1552                if let Some(max_norm) = max_grad_norm {
1553                    self.nf4_lora_grad_workspace
1554                        .as_mut()
1555                        .expect("NF4 requires LoRA grad ws")
1556                        .clip_gradients(max_norm, stream);
1557                }
1558
1559                // NF4 LoRA optimizer step — always runs, even during accumulation.
1560                //
1561                // BUG FIX (entrenar#264): Previously gated by `if !accumulate_only`.
1562                // Design: NF4 LoRA has ~6M params, so we scale lr by 1/accum_steps
1563                // for micro-batches instead of accumulating gradients.
1564                {
1565                    let step = self.gpu_training.step;
1566                    let effective_lr = if accumulate_only {
1567                        lr / self.config.accumulation_steps as f32
1568                    } else {
1569                        lr
1570                    };
1571                    if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1572                        let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1573                            &mut opt_states[layer_idx],
1574                            step,
1575                            effective_lr,
1576                            beta1,
1577                            beta2,
1578                            1e-8,
1579                            weight_decay,
1580                            stream,
1581                            self.nf4_lora_grad_workspace
1582                                .as_ref()
1583                                .expect("NF4 requires LoRA grad ws"),
1584                        );
1585                    }
1586                }
1587            } else {
1588                // Standard fp32 backward path
1589                self.cuda_blocks[layer_idx]
1590                    .backward(
1591                        &self.gpu_training.layer_inputs[layer_idx],
1592                        grad_output,
1593                        grad_input,
1594                        seq_len,
1595                        stream,
1596                        &mut self.cuda_grad_workspace,
1597                    )
1598                    .ok()?;
1599
1600                // C-CLIP-001 / entrenar#312: DISABLED per-block gradient clipping.
1601                // Per-block clipping distorts gradient flow across layers.
1602
1603                // C-BACKPARITY-001: Per-block gradient norm tracing for parity testing.
1604                // Only runs when ENTRENAR_TRACE_GRADIENTS=1 — zero overhead in production.
1605                if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1606                    let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1607                        &self.cuda_grad_workspace,
1608                        f32::MAX,
1609                        stream,
1610                    );
1611                    // Also trace the activation gradient (flows between blocks)
1612                    let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1613                        .unwrap_or(0.0);
1614                    let act_gnorm = act_sq.sqrt();
1615                    eprintln!(
1616                        "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1617                    );
1618                }
1619
1620                // R-038: Either accumulate workspace grads or run optimizer per-block.
1621                if accumulate_only {
1622                    // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1623                    if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1624                        let _ = gpu_accum.accumulate_block(
1625                            &self.cuda_grad_workspace,
1626                            layer_idx,
1627                            stream,
1628                        );
1629                    } else {
1630                        // CPU fallback: SYNC + D2H (ALB-065 / Rule 6).
1631                        stream.synchronize().ok()?;
1632                        if let Some(accum) = &mut self.grad_accum {
1633                            Self::download_workspace_to_accum(
1634                                &self.cuda_grad_workspace,
1635                                accum,
1636                                layer_idx,
1637                                &mut self.d2h_staging,
1638                            )?;
1639                        }
1640                    }
1641                } else {
1642                    // Per-block optimizer step: consume workspace gradients before next block overwrites
1643                    let step = self.gpu_training.step;
1644                    let _ = self.cuda_blocks[layer_idx].optimizer_step(
1645                        &mut self.gpu_training.optimizer_states[layer_idx],
1646                        step,
1647                        lr,
1648                        beta1,
1649                        beta2,
1650                        1e-8,
1651                        weight_decay,
1652                        stream,
1653                        &self.cuda_grad_workspace,
1654                    );
1655                }
1656            }
1657
1658            self.profiler.end_layer_bwd(layer_idx);
1659            grad_output_is_a = !grad_output_is_a;
1660        }
1661
1662        stream.synchronize().ok()?;
1663        self.profiler.end(StepProfiler::BLK_BWD);
1664
1665        Some(grad_output_is_a)
1666    }
1667
1668    /// R-038: Download non-block (LM head + final norm) gradients to CPU accumulator.
1669    /// Static method to avoid borrow conflicts.
1670    // KAIZEN-044: Pre-allocate single buffer for LM head + norm D2H downloads.
1671    // lm_head_grad is vocab×hidden (389M elements = 1.5 GB for Qwen3-4B).
1672    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1673    fn download_nonblock_grads_to_accum(
1674        lm_head_grad: &GpuBuffer<f32>,
1675        final_norm_grad: &GpuBuffer<f32>,
1676        grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1677        host: &mut [f32],
1678    ) -> Option<()> {
1679        let accum = grad_accum.as_mut()?;
1680
1681        let lm_slice = &mut host[..lm_head_grad.len()];
1682        lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1683        for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1684            *d += s;
1685        }
1686
1687        let norm_slice = &mut host[..final_norm_grad.len()];
1688        final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1689        for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1690            *d += s;
1691        }
1692        Some(())
1693    }
1694
1695    /// Run LM head + final norm optimizer step (non-accumulating path).
1696    /// Static method to avoid borrow conflicts with `stream`.
1697    #[allow(clippy::too_many_arguments)]
1698    fn run_nonblock_optimizer_step(
1699        gpu_training: &mut GpuPretrainState,
1700        lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1701        lm_head_grad_gpu: &GpuBuffer<f32>,
1702        lm_head_m: &mut GpuBuffer<f32>,
1703        lm_head_v: &mut GpuBuffer<f32>,
1704        final_norm_m: &mut GpuBuffer<f32>,
1705        final_norm_v: &mut GpuBuffer<f32>,
1706        lr: f32,
1707        beta1: f32,
1708        beta2: f32,
1709        weight_decay: f32,
1710        stream: &CudaStream,
1711    ) {
1712        gpu_training.step += 1;
1713        let step = gpu_training.step;
1714
1715        if let Some(lm_head_weight) = lm_head_weight_gpu {
1716            let n_lm = lm_head_weight.len() as u32;
1717            let _ = adamw_step_cuda(
1718                lm_head_weight,
1719                lm_head_grad_gpu,
1720                lm_head_m,
1721                lm_head_v,
1722                lr,
1723                beta1,
1724                beta2,
1725                1e-8,
1726                weight_decay,
1727                step,
1728                n_lm,
1729                stream,
1730            );
1731        }
1732
1733        let n_norm = gpu_training.final_norm_weight.len() as u32;
1734        let _ = adamw_step_cuda(
1735            &mut gpu_training.final_norm_weight,
1736            &gpu_training.grad_final_norm_weight,
1737            final_norm_m,
1738            final_norm_v,
1739            lr,
1740            beta1,
1741            beta2,
1742            1e-8,
1743            weight_decay,
1744            step,
1745            n_norm,
1746            stream,
1747        );
1748    }
1749
1750    /// R-038: Download shared CudaGradWorkspace to CPU per-block accumulation buffers.
1751    ///
1752    /// Static method to avoid borrow conflicts with `stream` (same pattern as
1753    /// `recompute_segment`). Must be called after stream.synchronize() (ALB-065 / Rule 6).
1754    // KAIZEN-044: Pre-allocate a single host buffer for all D2H downloads
1755    // in download_workspace_to_accum. Was allocating vec![0.0f32; len] × 9 buffers.
1756    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1757    fn download_workspace_to_accum(
1758        ws: &CudaGradWorkspace,
1759        accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1760        layer_idx: usize,
1761        host: &mut [f32],
1762    ) -> Option<()> {
1763        let bg = &mut accum.block_grads[layer_idx];
1764
1765        use super::grad_accumulator::component;
1766        let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1767            (&ws.grad_w_q, component::W_Q),
1768            (&ws.grad_w_k, component::W_K),
1769            (&ws.grad_w_v, component::W_V),
1770            (&ws.grad_w_o, component::W_O),
1771            (&ws.grad_gate, component::GATE),
1772            (&ws.grad_up, component::UP),
1773            (&ws.grad_down, component::DOWN),
1774            (&ws.grad_input_norm, component::INPUT_NORM),
1775            (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1776        ];
1777
1778        for (gpu_buf, comp_idx) in &bufs_and_components {
1779            let slice = &mut host[..gpu_buf.len()];
1780            gpu_buf.copy_to_host_at(slice, 0).ok()?;
1781            for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1782                *d += s;
1783            }
1784        }
1785        Some(())
1786    }
1787
1788    /// R-038: Upload averaged CPU accumulation buffers to GPU workspace and run
1789    /// optimizer step for all blocks + LM head + final norm.
1790    ///
1791    /// Called once after `accumulation_steps` micro-batches have been accumulated.
1792    /// ALB-091: Run optimizer step from GPU-resident accumulated gradients.
1793    /// D2D copy accum → workspace, then run per-block optimizer. Zero accum after.
1794    fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1795        let stream = self.cuda_trainer.stream();
1796        let lr = self.current_lr();
1797        let beta1 = self.config.beta1;
1798        let beta2 = self.config.beta2;
1799        let weight_decay = self.config.weight_decay;
1800
1801        // Sync once to ensure all accumulation kernels complete
1802        stream.synchronize().ok()?;
1803
1804        self.gpu_training.step += 1;
1805        let step = self.gpu_training.step;
1806
1807        // Upload GPU accum → workspace (D2D) and run optimizer per block
1808        let gpu_accum = self.gpu_grad_accum.as_ref()?;
1809        for layer_idx in 0..self.cuda_blocks.len() {
1810            gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1811
1812            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1813                &mut self.gpu_training.optimizer_states[layer_idx],
1814                step,
1815                lr,
1816                beta1,
1817                beta2,
1818                1e-8,
1819                weight_decay,
1820                stream,
1821                &self.cuda_grad_workspace,
1822            );
1823        }
1824
1825        // LM head: D2D copy accum → grad buffer, then optimizer step
1826        gpu_accum
1827            .upload_nonblock(
1828                &mut self.lm_head_grad_gpu,
1829                &mut self.gpu_training.grad_final_norm_weight,
1830            )
1831            .ok()?;
1832
1833        let n_lm = self.lm_head_weight_gpu.len() as u32;
1834        let _ = adamw_step_cuda(
1835            &mut self.lm_head_weight_gpu,
1836            &self.lm_head_grad_gpu,
1837            &mut self.lm_head_m,
1838            &mut self.lm_head_v,
1839            lr,
1840            beta1,
1841            beta2,
1842            1e-8,
1843            weight_decay,
1844            step,
1845            n_lm,
1846            stream,
1847        );
1848
1849        // Final norm optimizer step
1850        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1851        let _ = adamw_step_cuda(
1852            &mut self.gpu_training.final_norm_weight,
1853            &self.gpu_training.grad_final_norm_weight,
1854            &mut self.final_norm_m,
1855            &mut self.final_norm_v,
1856            lr,
1857            beta1,
1858            beta2,
1859            1e-8,
1860            weight_decay,
1861            step,
1862            n_norm,
1863            stream,
1864        );
1865
1866        stream.synchronize().ok()?;
1867
1868        // Zero accum for next window
1869        if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1870            let _ = gpu_accum.zero_all();
1871        }
1872
1873        Some(())
1874    }
1875
1876    #[allow(unsafe_code)]
1877    fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1878        let stream = self.cuda_trainer.stream();
1879        let lr = self.current_lr();
1880        let beta1 = self.config.beta1;
1881        let beta2 = self.config.beta2;
1882        let weight_decay = self.config.weight_decay;
1883
1884        // Average accumulated gradients
1885        let accum = self.grad_accum.as_mut()?;
1886        accum.average();
1887
1888        // Jidoka: check for NaN/Inf before applying
1889        if accum.has_non_finite() {
1890            println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1891            accum.zero_all();
1892            return Some(());
1893        }
1894
1895        self.gpu_training.step += 1;
1896        let step = self.gpu_training.step;
1897
1898        // Upload accumulated gradients and run optimizer for each block
1899        use super::grad_accumulator::component;
1900        for layer_idx in 0..self.cuda_blocks.len() {
1901            let bg = &accum.block_grads[layer_idx];
1902
1903            // Upload accumulated gradients to shared workspace
1904            // SAFETY: async host-to-device copies within the training stream; host buffers
1905            // (bg.components) are stable for the duration of the stream operations.
1906            unsafe {
1907                self.cuda_grad_workspace
1908                    .grad_w_q
1909                    .copy_from_host_async(&bg.components[component::W_Q], stream)
1910                    .ok()?;
1911                self.cuda_grad_workspace
1912                    .grad_w_k
1913                    .copy_from_host_async(&bg.components[component::W_K], stream)
1914                    .ok()?;
1915                self.cuda_grad_workspace
1916                    .grad_w_v
1917                    .copy_from_host_async(&bg.components[component::W_V], stream)
1918                    .ok()?;
1919                self.cuda_grad_workspace
1920                    .grad_w_o
1921                    .copy_from_host_async(&bg.components[component::W_O], stream)
1922                    .ok()?;
1923                self.cuda_grad_workspace
1924                    .grad_gate
1925                    .copy_from_host_async(&bg.components[component::GATE], stream)
1926                    .ok()?;
1927                self.cuda_grad_workspace
1928                    .grad_up
1929                    .copy_from_host_async(&bg.components[component::UP], stream)
1930                    .ok()?;
1931                self.cuda_grad_workspace
1932                    .grad_down
1933                    .copy_from_host_async(&bg.components[component::DOWN], stream)
1934                    .ok()?;
1935                self.cuda_grad_workspace
1936                    .grad_input_norm
1937                    .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
1938                    .ok()?;
1939                self.cuda_grad_workspace
1940                    .grad_post_attn_norm
1941                    .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
1942                    .ok()?;
1943            }
1944
1945            // Run optimizer step with uploaded averaged gradients
1946            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1947                &mut self.gpu_training.optimizer_states[layer_idx],
1948                step,
1949                lr,
1950                beta1,
1951                beta2,
1952                1e-8,
1953                weight_decay,
1954                stream,
1955                &self.cuda_grad_workspace,
1956            );
1957        }
1958
1959        // Upload accumulated LM head gradients and run AdamW step
1960        // entrenar#314: Skip GPU LM head optimizer for tied weights.
1961        // SAFETY: async host-to-device copy; host buffer (accum.lm_head_grad) is stable.
1962        unsafe {
1963            self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
1964        }
1965        let n_lm = self.lm_head_weight_gpu.len() as u32;
1966        let _ = adamw_step_cuda(
1967            &mut self.lm_head_weight_gpu,
1968            &self.lm_head_grad_gpu,
1969            &mut self.lm_head_m,
1970            &mut self.lm_head_v,
1971            lr,
1972            beta1,
1973            beta2,
1974            1e-8,
1975            weight_decay,
1976            step,
1977            n_lm,
1978            stream,
1979        );
1980
1981        // Upload accumulated final norm gradients and run AdamW step
1982        // SAFETY: async host-to-device copy; host buffer (accum.final_norm_grad) is stable.
1983        unsafe {
1984            self.gpu_training
1985                .grad_final_norm_weight
1986                .copy_from_host_async(&accum.final_norm_grad, stream)
1987                .ok()?;
1988        }
1989        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1990        let _ = adamw_step_cuda(
1991            &mut self.gpu_training.final_norm_weight,
1992            &self.gpu_training.grad_final_norm_weight,
1993            &mut self.final_norm_m,
1994            &mut self.final_norm_v,
1995            lr,
1996            beta1,
1997            beta2,
1998            1e-8,
1999            weight_decay,
2000            step,
2001            n_norm,
2002            stream,
2003        );
2004
2005        stream.synchronize().ok()?;
2006
2007        // Zero accum for next window
2008        accum.zero_all();
2009        Some(())
2010    }
2011
2012    /// Compute gradient L2 norm via GPU reduction kernel (KAIZEN-049).
2013    ///
2014    /// Runs `SquaredSumKernel` on GPU, downloads only `num_blocks` partial sums (~1KB)
2015    /// instead of the full buffer (128MB for lm_head). Falls back to CPU download on error.
2016    ///
2017    /// # Contract (C-CLIPNORM-GPU-001)
2018    ///
2019    /// - **Precondition**: `buf.len() > 0`, stream is synchronized with prior kernel
2020    /// - **Postcondition**: `grad_norm ≈ sqrt(sum(buf[i]^2))`, `scale = min(1, max_norm/norm)`
2021    /// - **Transfer**: ~1KB D2H (num_blocks × 4B) vs n×4B (128MB for 32M elements)
2022    ///
2023    /// R-004: Returns `(clip_scale, grad_norm)` for observability.
2024    fn compute_clip_scale_with_norm(
2025        buf: &GpuBuffer<f32>,
2026        max_norm: f32,
2027        stream: &CudaStream,
2028    ) -> (f32, f32) {
2029        let n = buf.len() as u32;
2030        // Try GPU reduction first — ~1KB D2H instead of n×4 bytes
2031        let grad_norm = match squared_sum_cuda(buf, n, stream) {
2032            Ok(norm) => norm,
2033            Err(_) => {
2034                // Fallback: full D2H (original path)
2035                let mut host = vec![0.0f32; buf.len()];
2036                if buf.copy_to_host_at(&mut host, 0).is_err() {
2037                    return (1.0, 0.0);
2038                }
2039                let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2040                sq_sum.sqrt() as f32
2041            }
2042        };
2043        let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2044        (scale, grad_norm)
2045    }
2046
2047    /// Download embedding gradient from GPU, clip, and scatter-add into CPU weight.
2048    ///
2049    /// # Contract (C-EMBED-GRAD-001)
2050    ///
2051    /// The activation gradient from block[0]'s backward is unclipped (per-block clipping
2052    /// only applies to weight gradients in the shared workspace). For deep networks with
2053    /// random init, this gradient can overflow f32, producing NaN in the CPU AdamW.
2054    /// We clip the activation gradient to max_grad_norm before scatter-adding.
2055    #[allow(unsafe_code)]
2056    fn embed_backward(
2057        &mut self,
2058        input_ids: &[u32],
2059        _seq_len: usize,
2060        hidden_size: usize,
2061        vocab_size: usize,
2062        grad_output_is_a: bool,
2063    ) -> Option<()> {
2064        // The final backward output is in whichever buffer was last written
2065        let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2066        let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2067        let embed_grad_buf = unsafe {
2068            if grad_output_is_a {
2069                &*grad_a_ptr
2070            } else {
2071                &*grad_b_ptr
2072            }
2073        };
2074        let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2075
2076        // C-EMBED-GRAD-001: ALWAYS clip activation gradient before scatter-add.
2077        // Without this, 24-layer random-init backward amplifies gradients to ~1e35,
2078        // which overflows the CPU AdamW's second moment buffer.
2079        //
2080        // ALB-071: Decoupled from general grad_clip config. Embed activation gradient
2081        // clipping is a SAFETY constraint (prevents NaN), not a training hyperparameter.
2082        // Uses dedicated max_embed_grad_norm (default 1.0) independent of weight grad_clip.
2083        let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2084        {
2085            let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2086            let grad_norm = sq_sum.sqrt() as f32;
2087            self.last_embed_grad_norm = grad_norm; // R-040: per-parameter-group tracking
2088            if grad_norm > embed_clip_norm {
2089                let scale = embed_clip_norm / grad_norm;
2090                for g in &mut embed_grad_data {
2091                    *g *= scale;
2092                }
2093            }
2094        }
2095
2096        // KAIZEN-048: In-place scatter-add via grad_cell().borrow_mut().
2097        // Before: 3 × 128MB clones per step (grad() deep-copies Array1).
2098        // After: zero clones — mutate existing gradient buffer directly.
2099        let embed_weight = &mut self.model.embed_tokens.weight;
2100        let grad_cell = embed_weight.grad_cell();
2101        let mut grad_ref = grad_cell.borrow_mut();
2102        if grad_ref.is_none() {
2103            *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2104        }
2105        if let Some(grad) = grad_ref.as_mut() {
2106            for (pos, &token_id) in input_ids.iter().enumerate() {
2107                let tid = token_id as usize;
2108                if tid < vocab_size {
2109                    let src = pos * hidden_size;
2110                    let dst = tid * hidden_size;
2111                    for h in 0..hidden_size {
2112                        grad[dst + h] += embed_grad_data[src + h];
2113                    }
2114                }
2115            }
2116        }
2117        Some(())
2118    }
2119
2120    /// Apply optimizer step to CPU embedding and update metrics.
2121    ///
2122    /// GPU block optimizer steps now run interleaved with backward in `gpu_backward()`.
2123    /// LM head and final norm optimizer steps also run in `gpu_backward()`.
2124    /// This method handles only CPU embedding and bookkeeping.
2125    fn optimizer_step(&mut self) {
2126        // ALB-072: Gradients are no longer scaled by grad_scaler (loss_scale excludes
2127        // grad_scaler.scale()). All backward computation uses f32 — no fp16 underflow
2128        // risk. Skip unscaling; just update scaler as successful.
2129        self.grad_scaler.update(true);
2130
2131        // ALB-079: Sync CPU embedding optimizer lr with cosine schedule
2132        self.embed_optimizer.set_lr(self.current_lr());
2133        // CPU optimizer step for embedding weight
2134        let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2135        self.embed_optimizer.step_refs(&mut embed_params);
2136
2137        self.step += 1;
2138        self.metrics.losses.push(self.accumulated_loss);
2139        self.metrics.increment_step();
2140
2141        self.accumulated_loss = 0.0;
2142        self.accumulated_batches = 0;
2143    }
2144
2145    /// Process a batch (forward + backward + optimizer step with accumulation).
2146    ///
2147    /// R-038: When `accumulation_steps > 1`, runs forward+backward without optimizer
2148    /// for each micro-batch, downloading per-block weight gradients to CPU-side
2149    /// `PerBlockGradientAccumulator`. After `accumulation_steps` batches, averages
2150    /// the accumulated gradients, uploads them to GPU, and runs a single optimizer step.
2151    ///
2152    /// When `accumulation_steps == 1` (default), runs forward+backward+optimizer
2153    /// immediately per sequence (original behavior).
2154    ///
2155    /// Returns average loss for the batch.
2156    pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2157        if batch.batch_size == 0 {
2158            return 0.0;
2159        }
2160
2161        let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2162
2163        if self.accumulated_batches == 0 {
2164            // Zero embedding gradients at start of accumulation window
2165            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2166        }
2167
2168        let mut total_loss = 0.0;
2169        let mut valid_count = 0;
2170
2171        for i in 0..batch.batch_size {
2172            let Some(input_ids) = batch.get_input(i) else {
2173                continue;
2174            };
2175            let Some(target_ids) = batch.get_target(i) else {
2176                continue;
2177            };
2178
2179            // R-038: When accumulating, run backward without optimizer (accumulate_only=true).
2180            // Gradients are downloaded to CPU per-block accum buffers. Embedding grads are
2181            // scatter-added normally (they're already on CPU).
2182            if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2183                total_loss += loss;
2184                valid_count += 1;
2185                if accumulating {
2186                    if let Some(accum) = &mut self.gpu_grad_accum {
2187                        accum.accumulated_count += 1;
2188                    } else if let Some(accum) = &mut self.grad_accum {
2189                        accum.accumulated_count += 1;
2190                    }
2191                }
2192            }
2193        }
2194
2195        let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2196
2197        // Debug: help diagnose loss=0.0 when gradients are non-zero
2198        if avg_loss == 0.0 && valid_count > 0 {
2199            eprintln!(
2200                "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2201                valid_count, total_loss, batch.batch_size
2202            );
2203        }
2204
2205        self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2206        self.accumulated_batches += 1;
2207
2208        if self.accumulated_batches >= self.config.accumulation_steps {
2209            if accumulating {
2210                // ALB-091: Prefer GPU-resident accum path (zero D2H), fall back to CPU.
2211                if self.gpu_grad_accum.is_some() {
2212                    self.gpu_optimizer_from_gpu_accum();
2213                } else {
2214                    self.gpu_optimizer_from_accum();
2215                }
2216            }
2217            self.optimizer_step();
2218        }
2219
2220        avg_loss
2221    }
2222
2223    /// R-005: Evaluate a batch without backward pass or weight updates.
2224    /// Returns average cross-entropy loss, or 0.0 if no valid items.
2225    /// KAIZEN-050: Uses fused GPU cross-entropy (no logits D2H).
2226    pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2227        let hidden_size = self.config.model_config.hidden_size;
2228        let vocab_size = self.config.model_config.vocab_size;
2229        let max_sl = self.config.max_seq_len;
2230        let mut total_loss = 0.0;
2231        let mut valid_count = 0;
2232        for i in 0..batch.batch_size {
2233            if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2234            {
2235                total_loss += loss;
2236                valid_count += 1;
2237            }
2238        }
2239        if valid_count > 0 {
2240            total_loss / valid_count as f32
2241        } else {
2242            0.0
2243        }
2244    }
2245
2246    /// Evaluate a single sequence from a batch. Returns None if invalid.
2247    fn eval_single_sequence(
2248        &mut self,
2249        batch: &LMBatch,
2250        i: usize,
2251        max_sl: usize,
2252        hidden_size: usize,
2253        vocab_size: usize,
2254    ) -> Option<f32> {
2255        let input_ids = batch.get_input(i)?;
2256        let target_ids = batch.get_target(i)?;
2257        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
2258        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2259        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2260        let seq_len = input_ids.len();
2261        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2262        let stream = self.cuda_trainer.stream();
2263        let scale = 1.0 / seq_len as f32;
2264        let loss = fused_cross_entropy_cuda(
2265            &mut self.gpu_training.logits_buf,
2266            target_ids,
2267            seq_len as u32,
2268            vocab_size as u32,
2269            scale,
2270            stream,
2271        )
2272        .ok()?;
2273        if loss.is_finite() {
2274            Some(loss)
2275        } else {
2276            None
2277        }
2278    }
2279
2280    /// Train for one epoch over batches.
2281    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2282        self.train_epoch_with_callback(batches, |_, _, _| {})
2283    }
2284
2285    /// Train for one epoch with a per-step callback.
2286    ///
2287    /// Stops early if `max_steps` is set and reached.
2288    pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2289    where
2290        F: FnMut(usize, f32, &Self),
2291    {
2292        if batches.is_empty() {
2293            return 0.0;
2294        }
2295
2296        let mut total_loss = 0.0;
2297        let mut batches_processed = 0;
2298
2299        for (i, batch) in batches.iter().enumerate() {
2300            if let Some(max) = self.config.max_steps {
2301                if self.step >= max {
2302                    break;
2303                }
2304            }
2305
2306            let batch_loss = self.train_batch(batch);
2307            total_loss += batch_loss;
2308            batches_processed += 1;
2309            on_batch(i, batch_loss, self);
2310        }
2311
2312        // KAIZEN-047: Print profiler summary at end of epoch
2313        if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2314            self.profiler.print_report();
2315        }
2316
2317        total_loss / batches_processed.max(1) as f32
2318    }
2319
2320    // --- DDP (data-parallel) support methods ---
2321
2322    /// Ensure the per-block gradient accumulator exists.
2323    ///
2324    /// For DDP, we always need accumulation buffers (even with accumulation_steps=1)
2325    /// because gradients must be downloaded to CPU for AllReduce before optimizer step.
2326    pub(crate) fn ensure_grad_accum(&mut self) {
2327        if self.grad_accum.is_some() {
2328            return;
2329        }
2330        let mc = &self.config.model_config;
2331        let hidden_size = mc.hidden_size;
2332        let kv_hidden = mc.num_kv_heads * mc.head_dim();
2333        let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2334            hidden_size,
2335            kv_hidden,
2336            mc.intermediate_size,
2337        );
2338        self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2339            self.cuda_blocks.len(),
2340            block_sizes,
2341            mc.vocab_size,
2342            hidden_size,
2343        ));
2344    }
2345
2346    /// Forward + backward for one batch, always accumulating (no optimizer step).
2347    ///
2348    /// Used by `DistributedCudaTrainer` to compute local gradients before AllReduce.
2349    /// Returns average loss for the batch.
2350    pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2351        if batch.batch_size == 0 {
2352            return 0.0;
2353        }
2354
2355        if self.accumulated_batches == 0 {
2356            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2357        }
2358
2359        let mut total_loss = 0.0;
2360        let mut valid_count = 0;
2361
2362        for i in 0..batch.batch_size {
2363            let Some(input_ids) = batch.get_input(i) else { continue };
2364            let Some(target_ids) = batch.get_target(i) else { continue };
2365
2366            // Always accumulate_only=true: gradients go to CPU accum buffers
2367            if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2368                total_loss += loss;
2369                valid_count += 1;
2370                if let Some(accum) = &mut self.grad_accum {
2371                    accum.accumulated_count += 1;
2372                }
2373            }
2374        }
2375
2376        if valid_count > 0 {
2377            total_loss / valid_count as f32
2378        } else {
2379            0.0
2380        }
2381    }
2382
2383    /// Apply DDP-averaged gradients: upload to GPU and run optimizer step.
2384    ///
2385    /// Called after AllReduce has written averaged gradients into the grad_accum.
2386    /// Runs gpu_optimizer_from_accum() for blocks + LM head + final norm,
2387    /// then optimizer_step() for embedding.
2388    pub(crate) fn apply_ddp_gradients(&mut self) {
2389        self.accumulated_loss = 0.0;
2390        self.accumulated_batches = 0;
2391        self.gpu_optimizer_from_accum();
2392        self.optimizer_step();
2393    }
2394
2395    /// Get a reference to the gradient accumulator (for DDP AllReduce).
2396    pub(crate) fn grad_accum_ref(
2397        &self,
2398    ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2399        self.grad_accum.as_ref()
2400    }
2401
2402    /// Get a mutable reference to the gradient accumulator (for DDP AllReduce).
2403    pub(crate) fn grad_accum_mut(
2404        &mut self,
2405    ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2406        self.grad_accum.as_mut()
2407    }
2408
2409    /// Get the training config.
2410    pub(crate) fn config(&self) -> &TransformerTrainConfig {
2411        &self.config
2412    }
2413
2414    /// Get CPU embedding gradient as flat Vec for AllReduce.
2415    pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2416        self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2417    }
2418
2419    /// Set CPU embedding gradient from AllReduced flat Vec.
2420    pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2421        self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2422    }
2423
2424    /// Returns true if max_steps has been reached.
2425    pub fn reached_max_steps(&self) -> bool {
2426        self.config.max_steps.is_some_and(|max| self.step >= max)
2427    }
2428
2429    /// Get current step count.
2430    pub fn step(&self) -> usize {
2431        self.step
2432    }
2433
2434    /// Set initial step for resume from checkpoint.
2435    ///
2436    /// Updates both the outer step counter (LR schedule, logging) and the
2437    /// GPU-side AdamW step counter (bias correction). Must be called before
2438    /// any `train_batch()` calls.
2439    pub fn set_initial_step(&mut self, step: usize) {
2440        self.step = step;
2441        self.gpu_training.step = step as u32;
2442    }
2443
2444    /// Set max_steps for cosine LR scheduler (ENT-275).
2445    ///
2446    /// Called by `train_loop_cuda` when `max_steps` is not explicitly set in
2447    /// the YAML config — auto-computes `epochs × batches_per_epoch` so cosine
2448    /// decay activates instead of falling back to constant lr.
2449    pub fn set_max_steps(&mut self, max_steps: usize) {
2450        self.config.max_steps = Some(max_steps);
2451    }
2452
2453    /// Get current learning rate (warmup + cosine decay).
2454    ///
2455    /// ALB-079: Phase 1 = linear warmup (0 → lr_max), Phase 2 = cosine decay
2456    /// (lr_max → 0) over remaining steps. Requires `max_steps` for decay;
2457    /// without it, falls back to constant lr after warmup.
2458    pub fn current_lr(&self) -> f32 {
2459        let base_lr = self.config.lr;
2460        if self.step < self.config.warmup_steps {
2461            // Phase 1: Linear warmup
2462            base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2463        } else if let Some(max_steps) = self.config.max_steps {
2464            // Phase 2: Cosine decay from lr_max to 0
2465            let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2466            if decay_steps == 0 {
2467                return base_lr;
2468            }
2469            let decay_step = self.step - self.config.warmup_steps;
2470            let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2471            0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2472        } else {
2473            // No max_steps: constant lr (legacy behavior)
2474            base_lr
2475        }
2476    }
2477
2478    /// KAIZEN-047: Enable step profiling with a report every `interval` steps.
2479    ///
2480    /// When enabled, prints a table of wall-clock timings per training phase
2481    /// every `interval` training steps. Use interval=0 for manual-only reporting.
2482    ///
2483    /// # Contract (C-STEPPROF-001)
2484    ///
2485    /// - No additional GPU synchronization points (relies on existing syncs)
2486    /// - Overhead: ~11 `Instant::now()` calls per step (~1µs total on Linux)
2487    /// - Timings include async dispatch overhead (not pure kernel time)
2488    pub fn enable_profiler(&mut self, interval: usize) {
2489        self.profiler = StepProfiler::new(true, interval);
2490    }
2491
2492    /// Print the profiler report (if profiling is enabled).
2493    pub fn print_profiler_report(&self) {
2494        self.profiler.print_report();
2495    }
2496
2497    /// R-004: Get last observed gradient L2 norm (LM head proxy).
2498    pub fn last_grad_norm(&self) -> f32 {
2499        self.last_grad_norm
2500    }
2501
2502    /// R-040: Get per-parameter-group gradient norms.
2503    /// Returns (lm_head_grad_norm, embed_grad_norm).
2504    pub fn param_grad_norms(&self) -> (f32, f32) {
2505        (self.last_grad_norm, self.last_embed_grad_norm)
2506    }
2507
2508    /// R-012: Get total trainable parameter count for MFU calculation.
2509    pub fn num_params(&self) -> usize {
2510        self.model.parameters().iter().map(|t| t.len()).sum()
2511    }
2512
2513    /// R-013: Query GPU memory usage (used_mb, total_mb).
2514    pub fn gpu_memory_mb(&self) -> (u64, u64) {
2515        match self.cuda_trainer.context().memory_info() {
2516            Ok((free, total)) => {
2517                let total_mb = (total / (1024 * 1024)) as u64;
2518                let used_mb = ((total - free) / (1024 * 1024)) as u64;
2519                (used_mb, total_mb)
2520            }
2521            Err(_) => (0, 0),
2522        }
2523    }
2524
2525    /// Sync all GPU weights back to CPU model.
2526    ///
2527    /// # Contract (C-SYNCWT-001)
2528    ///
2529    /// Must be called before save or any CPU model access after training.
2530    pub fn sync_weights_to_cpu(&mut self) {
2531        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2532
2533        if use_nf4 {
2534            // ENT-263: NF4 blocks are frozen — base weights don't change.
2535            // Only download LoRA adapter weights for checkpoint saving.
2536            // The base model on CPU stays as-is (original pretrained weights).
2537            // LoRA weights are saved separately (adapter_config.json + adapter.safetensors).
2538            // For now, skip per-layer sync — base weights are unchanged.
2539        } else {
2540            for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2541                if let Ok(weights) = block.download_weights() {
2542                    let layer = &mut self.model.layers[layer_idx];
2543
2544                    layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2545                    layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2546                    layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2547                    layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2548
2549                    layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2550                    layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2551                    layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2552
2553                    layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2554                    layer.post_attn_norm.weight =
2555                        Tensor::from_vec(weights.post_attn_norm_weight, false);
2556                }
2557            }
2558        }
2559
2560        // Sync final norm weight
2561        if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2562            self.model.norm.weight = Tensor::from_vec(norm_data, false);
2563        }
2564
2565        // Sync LM head weight
2566        // ALB-097: ALWAYS save GPU-trained LM head, even for tied-weight models.
2567        // During GPU training, lm_head diverges from embed_tokens because they have
2568        // separate optimizers (GPU AdamW vs CPU AdamW). If we skip the sync for tied
2569        // weights, the checkpoint loses 500+ steps of GPU LM head training → random-init
2570        // loss on resume (Five Whys root cause of ALB-097).
2571        if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2572            self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2573        }
2574    }
2575
2576    /// Get reference to model (syncs weights first).
2577    pub fn model(&self) -> &Transformer {
2578        &self.model
2579    }
2580
2581    /// Get mutable reference to model.
2582    pub fn model_mut(&mut self) -> &mut Transformer {
2583        &mut self.model
2584    }
2585
2586    /// Check if using mixed precision.
2587    pub fn is_mixed_precision(&self) -> bool {
2588        self.config.precision_config.is_mixed()
2589    }
2590
2591    /// Get the gradient scaler (R-002: loss scaling for mixed precision).
2592    pub fn grad_scaler(&self) -> &GradScaler {
2593        &self.grad_scaler
2594    }
2595
2596    /// Check if using gradient checkpointing.
2597    pub fn is_checkpointing(&self) -> bool {
2598        self.config.checkpoint_config.enabled
2599    }
2600
2601    /// Save model weights (syncs GPU→CPU first).
2602    pub fn save(
2603        &mut self,
2604        path: impl AsRef<std::path::Path>,
2605        name: &str,
2606        architecture: &str,
2607    ) -> crate::Result<()> {
2608        self.sync_weights_to_cpu();
2609
2610        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2611        let params: Vec<(String, Tensor)> = self
2612            .model
2613            .named_parameters()
2614            .into_iter()
2615            .map(|(name, tensor)| (name, tensor.clone()))
2616            .collect();
2617
2618        let metadata = ModelMetadata::new(name, architecture);
2619        let model = Model::new(metadata, params);
2620        let config = SaveConfig::new(ModelFormat::SafeTensors);
2621
2622        save_model(&model, path, &config)
2623    }
2624
2625    /// R-011: Prepare checkpoint data for async save.
2626    /// Syncs GPU weights to CPU and snapshots tensor data as Send-able Vec<f32>.
2627    /// Returns a closure that writes the checkpoint file from another thread.
2628    pub fn prepare_async_save(
2629        &mut self,
2630        name: &str,
2631        architecture: &str,
2632    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2633        self.sync_weights_to_cpu();
2634
2635        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2636        let param_data: Vec<(String, Vec<f32>)> = self
2637            .model
2638            .named_parameters()
2639            .into_iter()
2640            .map(|(n, t)| (n, t.data().to_vec()))
2641            .collect();
2642
2643        let name = name.to_string();
2644        let architecture = architecture.to_string();
2645
2646        Box::new(move |path: &std::path::Path| {
2647            let params: Vec<(String, Tensor)> =
2648                param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2649            let metadata = ModelMetadata::new(&name, &architecture);
2650            let model = Model::new(metadata, params);
2651            let config = SaveConfig::new(ModelFormat::SafeTensors);
2652            save_model(&model, path, &config)
2653        })
2654    }
2655
2656    /// ALB-096: Save model weights as APR checkpoint (syncs GPU→CPU first).
2657    ///
2658    /// Single atomic file containing all model weights. Use `save_apr_checkpoint()`
2659    /// to include optimizer state and training metadata in the same file.
2660    pub fn save_apr(
2661        &mut self,
2662        path: impl AsRef<std::path::Path>,
2663        name: &str,
2664        architecture: &str,
2665    ) -> crate::Result<()> {
2666        self.sync_weights_to_cpu();
2667
2668        let params: Vec<(String, Tensor)> = self
2669            .model
2670            .named_parameters()
2671            .into_iter()
2672            .map(|(name, tensor)| (name, tensor.clone()))
2673            .collect();
2674
2675        let metadata = ModelMetadata::new(name, architecture);
2676        let model = Model::new(metadata, params);
2677        let config = SaveConfig::new(ModelFormat::Apr);
2678
2679        save_model(&model, path, &config)
2680    }
2681
2682    /// ALB-096: Prepare APR checkpoint data for async save.
2683    ///
2684    /// Syncs GPU weights to CPU and snapshots tensor data + optimizer state as
2685    /// Send-able `Vec<f32>`. Returns a closure that writes a single atomic APR
2686    /// file from another thread. Includes model weights + CPU embedding optimizer
2687    /// state + training metadata — all in one file.
2688    fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2689        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2690        if use_nf4 {
2691            let frozen_suffixes = [
2692                "q_proj.weight",
2693                "k_proj.weight",
2694                "v_proj.weight",
2695                "o_proj.weight",
2696                "gate_proj.weight",
2697                "up_proj.weight",
2698                "down_proj.weight",
2699            ];
2700            self.model
2701                .named_parameters()
2702                .into_iter()
2703                .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2704                .map(|(n, t)| (n, t.data().to_vec()))
2705                .collect()
2706        } else {
2707            self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2708        }
2709    }
2710
2711    fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2712        if self.config.quantize_nf4 && self.config.is_lora() {
2713            self.cuda_blocks
2714                .iter()
2715                .enumerate()
2716                .filter_map(|(i, block)| {
2717                    block
2718                        .download_lora_weights()
2719                        .ok()
2720                        .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2721                })
2722                .collect()
2723        } else {
2724            Vec::new()
2725        }
2726    }
2727
2728    pub fn prepare_async_apr_save(
2729        &mut self,
2730        name: &str,
2731        architecture: &str,
2732        step: usize,
2733        loss: f64,
2734        lr: f64,
2735    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2736        self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2737    }
2738
2739    /// ALB-130: Prepare APR checkpoint with embedded tokenizer for inference.
2740    ///
2741    /// Training checkpoints must be self-contained for eval (`apr eval --task humaneval`).
2742    /// Without embedded tokenizer, inference falls back to structural validation (fake 100%).
2743    /// The tokenizer path comes from `spec.data.tokenizer` in the training YAML.
2744    pub fn prepare_async_apr_save_with_tokenizer(
2745        &mut self,
2746        name: &str,
2747        architecture: &str,
2748        step: usize,
2749        loss: f64,
2750        lr: f64,
2751        tokenizer_path: Option<&std::path::Path>,
2752    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2753        self.sync_weights_to_cpu();
2754
2755        let param_data = self.snapshot_param_data();
2756        let lora_data = self.snapshot_lora_data();
2757
2758        // Snapshot CPU embedding optimizer state
2759        let embed_m: Vec<Vec<f32>> = self
2760            .embed_optimizer
2761            .first_moments()
2762            .iter()
2763            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2764            .collect();
2765        let embed_v: Vec<Vec<f32>> = self
2766            .embed_optimizer
2767            .second_moments()
2768            .iter()
2769            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2770            .collect();
2771        let embed_step = self.embed_optimizer.step_count();
2772
2773        // ALB-118: Download GPU block optimizer states (m/v moments) for checkpointing.
2774        // Without this, resume re-initializes all 24 blocks' AdamW state to zero,
2775        // causing loss spikes and convergence failure (v10/v11/v12 post-mortems).
2776        // Transfer cost: ~2.3 GB D2H, <6ms on PCIe4/5.
2777        let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2778            .gpu_training
2779            .optimizer_states
2780            .iter()
2781            .map(|state| state.download_to_host().unwrap_or_default())
2782            .collect();
2783
2784        // ALB-118: Download LM head and final norm optimizer states
2785        let lm_head_m_host = {
2786            let mut buf = vec![0.0f32; self.lm_head_m.len()];
2787            let _ = self.lm_head_m.copy_to_host(&mut buf);
2788            buf
2789        };
2790        let lm_head_v_host = {
2791            let mut buf = vec![0.0f32; self.lm_head_v.len()];
2792            let _ = self.lm_head_v.copy_to_host(&mut buf);
2793            buf
2794        };
2795        let final_norm_m_host = {
2796            let mut buf = vec![0.0f32; self.final_norm_m.len()];
2797            let _ = self.final_norm_m.copy_to_host(&mut buf);
2798            buf
2799        };
2800        let final_norm_v_host = {
2801            let mut buf = vec![0.0f32; self.final_norm_v.len()];
2802            let _ = self.final_norm_v.copy_to_host(&mut buf);
2803            buf
2804        };
2805
2806        let name = name.to_string();
2807        let architecture = architecture.to_string();
2808        let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2809        let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2810
2811        // ALB-130: Pre-read tokenizer.json for embedding in checkpoint.
2812        // Parse HuggingFace tokenizer format → extract vocab + merges + special token IDs.
2813        let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2814            tokenizer_path.and_then(|p| {
2815                let json_bytes = std::fs::read(p).ok()?;
2816                let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2817                let model = tok.get("model")?;
2818                let vocab_obj = model.get("vocab")?.as_object()?;
2819                // Build sorted-by-id vocab list
2820                let mut vocab_pairs: Vec<(String, u64)> =
2821                    vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
2822                vocab_pairs.sort_by_key(|(_, id)| *id);
2823                let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
2824                // Merges as "token1 token2" strings
2825                let merges: Vec<String> = model
2826                    .get("merges")?
2827                    .as_array()?
2828                    .iter()
2829                    .filter_map(|v| v.as_str().map(String::from))
2830                    .collect();
2831                // Special tokens: BOS=<s>=1, EOS=</s>=2 (from added_tokens)
2832                let added = tok.get("added_tokens").and_then(|a| a.as_array());
2833                let bos_id = added.and_then(|arr| {
2834                    arr.iter()
2835                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
2836                        .and_then(|t| t.get("id")?.as_u64())
2837                });
2838                let eos_id = added.and_then(|arr| {
2839                    arr.iter()
2840                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
2841                        .and_then(|t| t.get("id")?.as_u64())
2842                });
2843                if vocab.is_empty() {
2844                    return None;
2845                }
2846                println!(
2847                    "  [ALB-130] Embedding tokenizer: {} vocab, {} merges",
2848                    vocab.len(),
2849                    merges.len()
2850                );
2851                Some((vocab, merges, bos_id, eos_id))
2852            });
2853
2854        Box::new(move |path: &std::path::Path| {
2855            use aprender::serialization::apr::AprWriter;
2856            use serde_json::Value as Jv;
2857
2858            let mut writer = AprWriter::new();
2859
2860            // Metadata
2861            writer.set_metadata("model_name", Jv::String(name));
2862            writer.set_metadata("architecture", Jv::String(architecture));
2863            writer.set_metadata(
2864                "format",
2865                Jv::String(if is_delta_checkpoint {
2866                    "entrenar-delta-checkpoint".into()
2867                } else {
2868                    "entrenar-checkpoint".into()
2869                }),
2870            );
2871            writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
2872            writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
2873            writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
2874            writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
2875            if let Some(cfg) = model_config_json {
2876                writer.set_metadata("model_config", Jv::String(cfg));
2877            }
2878
2879            // ALB-130: Embed tokenizer vocab + merges for standalone inference
2880            if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
2881                writer.set_metadata(
2882                    "tokenizer.vocabulary",
2883                    Jv::Array(vocab.into_iter().map(Jv::String).collect()),
2884                );
2885                writer.set_metadata(
2886                    "tokenizer.merges",
2887                    Jv::Array(merges.into_iter().map(Jv::String).collect()),
2888                );
2889                if let Some(bos) = bos_id {
2890                    writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
2891                }
2892                if let Some(eos) = eos_id {
2893                    writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
2894                }
2895            }
2896
2897            // Find hidden_size from norm weights for shape inference
2898            let hidden_size = param_data
2899                .iter()
2900                .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
2901                .map_or(0, |(_, d)| d.len());
2902
2903            // Model weight tensors
2904            for (tensor_name, data) in &param_data {
2905                let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
2906                writer.add_tensor_f32(tensor_name.clone(), shape, data);
2907            }
2908
2909            // Optimizer state tensors
2910            for (i, m_data) in embed_m.iter().enumerate() {
2911                let len = m_data.len();
2912                writer.add_tensor_f32(
2913                    format!("__training__.embed_optimizer.m.{i}"),
2914                    vec![len],
2915                    m_data,
2916                );
2917            }
2918            for (i, v_data) in embed_v.iter().enumerate() {
2919                let len = v_data.len();
2920                writer.add_tensor_f32(
2921                    format!("__training__.embed_optimizer.v.{i}"),
2922                    vec![len],
2923                    v_data,
2924                );
2925            }
2926
2927            // ALB-118: Save GPU block optimizer states (m/v moments for all 24 blocks)
2928            for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
2929                for (suffix, data) in buffers {
2930                    let len = data.len();
2931                    writer.add_tensor_f32(
2932                        format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
2933                        vec![len],
2934                        data,
2935                    );
2936                }
2937            }
2938
2939            // ALB-118: Save LM head and final norm optimizer states
2940            if !lm_head_m_host.is_empty() {
2941                let len = lm_head_m_host.len();
2942                writer.add_tensor_f32(
2943                    "__training__.lm_head_optimizer.m".to_string(),
2944                    vec![len],
2945                    &lm_head_m_host,
2946                );
2947                let len = lm_head_v_host.len();
2948                writer.add_tensor_f32(
2949                    "__training__.lm_head_optimizer.v".to_string(),
2950                    vec![len],
2951                    &lm_head_v_host,
2952                );
2953            }
2954            if !final_norm_m_host.is_empty() {
2955                let len = final_norm_m_host.len();
2956                writer.add_tensor_f32(
2957                    "__training__.final_norm_optimizer.m".to_string(),
2958                    vec![len],
2959                    &final_norm_m_host,
2960                );
2961                let len = final_norm_v_host.len();
2962                writer.add_tensor_f32(
2963                    "__training__.final_norm_optimizer.v".to_string(),
2964                    vec![len],
2965                    &final_norm_v_host,
2966                );
2967            }
2968
2969            // ENT-276: Save LoRA adapter weights (QLoRA checkpoint resume)
2970            for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
2971                if !a_q.is_empty() {
2972                    writer.add_tensor_f32(
2973                        format!("lora.{layer_idx}.q_proj.lora_a"),
2974                        vec![a_q.len()],
2975                        a_q,
2976                    );
2977                    writer.add_tensor_f32(
2978                        format!("lora.{layer_idx}.q_proj.lora_b"),
2979                        vec![b_q.len()],
2980                        b_q,
2981                    );
2982                }
2983                if !a_v.is_empty() {
2984                    writer.add_tensor_f32(
2985                        format!("lora.{layer_idx}.v_proj.lora_a"),
2986                        vec![a_v.len()],
2987                        a_v,
2988                    );
2989                    writer.add_tensor_f32(
2990                        format!("lora.{layer_idx}.v_proj.lora_b"),
2991                        vec![b_v.len()],
2992                        b_v,
2993                    );
2994                }
2995            }
2996
2997            // Write APR checkpoint to file
2998            writer
2999                .write(path)
3000                .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3001
3002            Ok(())
3003        })
3004    }
3005
3006    /// GPU device name.
3007    pub fn gpu_name(&self) -> String {
3008        self.cuda_trainer.device_name()
3009    }
3010
3011    /// ENT-269: Save LoRA adapter weights as PEFT-compatible files.
3012    ///
3013    /// Downloads LoRA A/B matrices from GPU, un-scales B (divide by lora_scale),
3014    /// transposes to PEFT convention (A=[rank, d_in], B=[d_out, rank]),
3015    /// and writes `adapter_model.safetensors` + `adapter_config.json`.
3016    ///
3017    /// # Contract: C-QLORA-SAVE-001
3018    ///
3019    /// NF4 QLoRA training MUST produce `adapter_model.safetensors` in output_dir.
3020    pub fn save_cuda_lora_adapter(
3021        &self,
3022        output_dir: &std::path::Path,
3023        base_model_name: Option<&str>,
3024    ) -> crate::Result<()> {
3025        if !self.config.quantize_nf4 || !self.config.is_lora() {
3026            return Ok(()); // Not a QLoRA run, nothing to save
3027        }
3028
3029        let lora_rank = self.config.lora_rank.unwrap_or(16);
3030        let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3031        let lora_scale = lora_alpha / lora_rank as f32;
3032        let hidden_size = self.config.model_config.hidden_size;
3033        let head_dim = self.config.model_config.head_dim();
3034        let q_dim = self.config.model_config.num_attention_heads * head_dim;
3035        let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3036
3037        let lora_config =
3038            crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3039
3040        let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3041
3042        for (i, block) in self.cuda_blocks.iter().enumerate() {
3043            let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3044                Ok(weights) => weights,
3045                Err(_) => continue, // Skip non-NF4 blocks
3046            };
3047
3048            if a_q.is_empty() && a_v.is_empty() {
3049                continue;
3050            }
3051
3052            // Q projection LoRA
3053            if !a_q.is_empty() {
3054                // GPU stores A_q as [hidden, rank] row-major, PEFT expects [rank, hidden]
3055                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3056                for r in 0..hidden_size {
3057                    for c in 0..lora_rank {
3058                        a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3059                    }
3060                }
3061
3062                // GPU stores B_q as [rank, q_dim] pre-scaled by lora_scale
3063                // PEFT expects [q_dim, rank] un-scaled
3064                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3065                let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3066                for r in 0..lora_rank {
3067                    for c in 0..q_dim {
3068                        b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3069                    }
3070                }
3071
3072                let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3073                let mut layer = crate::lora::LoRALayer::new(
3074                    base_weight,
3075                    q_dim,
3076                    hidden_size,
3077                    lora_rank,
3078                    lora_alpha,
3079                );
3080                // Overwrite the A and B data with trained weights
3081                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3082                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3083
3084                adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3085            }
3086
3087            // V projection LoRA
3088            if !a_v.is_empty() {
3089                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3090                for r in 0..hidden_size {
3091                    for c in 0..lora_rank {
3092                        a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3093                    }
3094                }
3095
3096                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3097                let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3098                for r in 0..lora_rank {
3099                    for c in 0..kv_hidden {
3100                        b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3101                    }
3102                }
3103
3104                let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3105                let mut layer = crate::lora::LoRALayer::new(
3106                    base_weight,
3107                    kv_hidden,
3108                    hidden_size,
3109                    lora_rank,
3110                    lora_alpha,
3111                );
3112                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3113                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3114
3115                adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3116            }
3117        }
3118
3119        if adapters.is_empty() {
3120            println!("  [WARN] No LoRA adapters found to save");
3121            return Ok(());
3122        }
3123
3124        let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3125            adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3126
3127        std::fs::create_dir_all(output_dir).ok();
3128        crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3129            .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3130
3131        let adapter_path = output_dir.join("adapter_model.safetensors");
3132        let size_mb =
3133            std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3134        println!(
3135            "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3136            adapters.len(),
3137            size_mb,
3138            output_dir.display()
3139        );
3140
3141        Ok(())
3142    }
3143
3144    /// R-001: Save CPU embedding optimizer state (m/v buffers + step counter).
3145    ///
3146    /// Writes `optimizer_state.json` to the given directory. GPU block optimizer
3147    /// states remain on-device (D2H for 20 buffers × N blocks is deferred).
3148    pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3149        let path = dir.join("optimizer_state.json");
3150        let m_data: Vec<Option<Vec<f32>>> = self
3151            .embed_optimizer
3152            .first_moments()
3153            .iter()
3154            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3155            .collect();
3156        let v_data: Vec<Option<Vec<f32>>> = self
3157            .embed_optimizer
3158            .second_moments()
3159            .iter()
3160            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3161            .collect();
3162        let state = serde_json::json!({
3163            "type": "adamw_cpu_embed",
3164            "step": self.embed_optimizer.step_count(),
3165            "m": m_data,
3166            "v": v_data,
3167        });
3168        let json_str = serde_json::to_string(&state).map_err(|e| {
3169            crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3170        })?;
3171        std::fs::write(&path, json_str)
3172            .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3173        Ok(())
3174    }
3175
3176    /// ENT-276: Restore LoRA adapter weights from APR checkpoint.
3177    ///
3178    /// Reads `lora.{layer}.{q,v}_proj.lora_{a,b}` tensors from the APR file
3179    /// and uploads them to the NF4 CUDA blocks, replacing the fresh random init.
3180    /// Returns (layers_restored, layers_total).
3181    pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3182        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3183            Ok(r) => r,
3184            Err(_) => return (0, self.cuda_blocks.len()),
3185        };
3186
3187        let mut restored = 0usize;
3188        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3189            let a_q =
3190                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3191            let b_q =
3192                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3193            let a_v =
3194                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3195            let b_v =
3196                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3197
3198            if a_q.is_empty() {
3199                continue; // No LoRA data for this layer in checkpoint
3200            }
3201
3202            if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3203                eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3204                continue;
3205            }
3206            restored += 1;
3207        }
3208
3209        (restored, self.cuda_blocks.len())
3210    }
3211
3212    /// ALB-096: Load CPU embedding optimizer state from APR checkpoint.
3213    ///
3214    /// Reads `__training__.embed_optimizer.{m,v}.*` tensors from the APR file.
3215    /// Returns true if state was loaded.
3216    pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3217        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3218            Ok(r) => r,
3219            Err(_) => return false,
3220        };
3221
3222        // Restore step count from metadata
3223        if let Some(step_val) = reader.get_metadata("optimizer_step") {
3224            if let Some(step_str) = step_val.as_str() {
3225                if let Ok(step) = step_str.parse::<u64>() {
3226                    self.embed_optimizer.set_step_count(step);
3227                }
3228            }
3229        }
3230
3231        // Restore first moments (m)
3232        for i in 0..128 {
3233            let name = format!("__training__.embed_optimizer.m.{i}");
3234            match reader.read_tensor_f32(&name) {
3235                Ok(data) if !data.is_empty() => {
3236                    self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3237                }
3238                _ => break,
3239            }
3240        }
3241
3242        // Restore second moments (v)
3243        for i in 0..128 {
3244            let name = format!("__training__.embed_optimizer.v.{i}");
3245            match reader.read_tensor_f32(&name) {
3246                Ok(data) if !data.is_empty() => {
3247                    self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3248                }
3249                _ => break,
3250            }
3251        }
3252
3253        // ALB-118: Restore GPU block optimizer states (m/v moments for all blocks)
3254        let suffixes = [
3255            "m.w_q",
3256            "v.w_q",
3257            "m.w_k",
3258            "v.w_k",
3259            "m.w_v",
3260            "v.w_v",
3261            "m.w_o",
3262            "v.w_o",
3263            "m.w_gate",
3264            "v.w_gate",
3265            "m.w_up",
3266            "v.w_up",
3267            "m.w_down",
3268            "v.w_down",
3269            "m.input_norm",
3270            "v.input_norm",
3271            "m.post_attn_norm",
3272            "v.post_attn_norm",
3273        ];
3274        let mut blocks_restored = 0usize;
3275        for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3276            let mut data = std::collections::HashMap::new();
3277            for suffix in &suffixes {
3278                let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3279                if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3280                    if !tensor_data.is_empty() {
3281                        data.insert(suffix.to_string(), tensor_data);
3282                    }
3283                }
3284            }
3285            if !data.is_empty() {
3286                let _ = state.restore_from_host(&data);
3287                blocks_restored += 1;
3288            }
3289        }
3290
3291        // ALB-118: Restore LM head optimizer state
3292        if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3293            if m_data.len() == self.lm_head_m.len() {
3294                let _ = self.lm_head_m.copy_from_host(&m_data);
3295            }
3296        }
3297        if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3298            if v_data.len() == self.lm_head_v.len() {
3299                let _ = self.lm_head_v.copy_from_host(&v_data);
3300            }
3301        }
3302
3303        // ALB-118: Restore final norm optimizer state
3304        if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3305            if m_data.len() == self.final_norm_m.len() {
3306                let _ = self.final_norm_m.copy_from_host(&m_data);
3307            }
3308        }
3309        if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3310            if v_data.len() == self.final_norm_v.len() {
3311                let _ = self.final_norm_v.copy_from_host(&v_data);
3312            }
3313        }
3314
3315        // ALB-132: Report restore results — don't silently swallow failures
3316        if blocks_restored > 0 {
3317            println!(
3318                "  ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3319                self.gpu_training.optimizer_states.len()
3320            );
3321        } else if !self.gpu_training.optimizer_states.is_empty() {
3322            println!(
3323                "  [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3324                self.gpu_training.optimizer_states.len()
3325            );
3326        }
3327
3328        true
3329    }
3330
3331    /// R-001: Load CPU embedding optimizer state from `optimizer_state.json`.
3332    ///
3333    /// Returns true if state was loaded, false if file doesn't exist.
3334    pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3335        let path = dir.join("optimizer_state.json");
3336        let data = match std::fs::read_to_string(&path) {
3337            Ok(d) => d,
3338            Err(_) => return false,
3339        };
3340        let state: serde_json::Value = match serde_json::from_str(&data) {
3341            Ok(v) => v,
3342            Err(_) => return false,
3343        };
3344        if let Some(step) = state["step"].as_u64() {
3345            self.embed_optimizer.set_step_count(step);
3346        }
3347        restore_moment_buffers(&state["m"], |idx, arr| {
3348            self.embed_optimizer.set_first_moment(idx, arr);
3349        });
3350        restore_moment_buffers(&state["v"], |idx, arr| {
3351            self.embed_optimizer.set_second_moment(idx, arr);
3352        });
3353        true
3354    }
3355}
3356
3357/// ALB-096: Infer 2D tensor shape from name and element count.
3358///
3359/// Same logic as `infer_all_tensor_shapes` in `io/save.rs` but for a single tensor.
3360#[cfg(feature = "cuda")]
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    // Norm weights are always stored flat.
    let is_norm = name == "model.norm.weight" || name.ends_with("layernorm.weight");
    if is_norm {
        return vec![numel];
    }

    // `checked_rem` is `None` for hidden_size == 0 and `Some(0)` exactly when
    // hidden_size is a non-zero divisor of numel; anything else stays 1D.
    match numel.checked_rem(hidden_size) {
        Some(0) => {
            let other_dim = numel / hidden_size;
            // down_proj maps back into the hidden dimension, so hidden_size
            // is its leading axis; every other matrix has it trailing.
            if name.ends_with("down_proj.weight") {
                vec![hidden_size, other_dim]
            } else {
                vec![other_dim, hidden_size]
            }
        }
        _ => vec![numel],
    }
}
3375
/// Parse a JSON array of moment buffers and apply each via callback.
///
/// `json_arr` is expected to be an array with one entry per optimizer slot.
/// Entries that are arrays of numbers are converted to `f32` and passed to
/// `set_fn` together with their slot index; any other entry (e.g. `null`) is
/// skipped. If `json_arr` is not an array, nothing happens.
///
/// A buffer is applied only if it is non-empty and *every* element is a JSON
/// number. serde_json writes non-finite floats as `null`; dropping just those
/// elements would shorten the buffer and misalign the remaining values, so
/// such a buffer is skipped as a whole and a warning is printed.
#[cfg(feature = "cuda")]
fn restore_moment_buffers(
    json_arr: &serde_json::Value,
    mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
) {
    let Some(slots) = json_arr.as_array() else { return };
    for (idx, slot) in slots.iter().enumerate() {
        let Some(values) = slot.as_array() else { continue };
        if values.is_empty() {
            continue;
        }
        // Collecting into `Option` yields `None` as soon as one element is
        // not a number, rejecting the buffer instead of truncating it.
        let parsed: Option<Vec<f32>> =
            values.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();
        match parsed {
            Some(floats) => set_fn(idx, ndarray::Array1::from_vec(floats)),
            None => eprintln!(
                "  [WARN] optimizer moment buffer {idx} contains non-numeric entries; skipping"
            ),
        }
    }
}
3391
3392// ── Non-CUDA stub ──
3393
3394#[cfg(not(feature = "cuda"))]
3395pub struct CudaTransformerTrainer;
3396
3397#[cfg(not(feature = "cuda"))]
3398impl CudaTransformerTrainer {
3399    pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3400        Err(crate::error::Error::ConfigError(
3401            "CUDA not available (compiled without cuda feature)".into(),
3402        ))
3403    }
3404
3405    pub fn with_model(
3406        _model: crate::transformer::Transformer,
3407        _config: super::config::TransformerTrainConfig,
3408    ) -> crate::Result<Self> {
3409        Err(crate::error::Error::ConfigError(
3410            "CUDA not available (compiled without cuda feature)".into(),
3411        ))
3412    }
3413
3414    pub fn gpu_name(&self) -> String {
3415        unreachable!("CudaTransformerTrainer stub should never be instantiated")
3416    }
3417}
3418
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature the stub constructor must report an error
    /// instead of handing out a trainer.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        assert!(super::CudaTransformerTrainer::new(train_config).is_err());
    }
}