Skip to main content

entrenar/train/transformer_trainer/
cuda_trainer.rs

1//! GPU-resident transformer trainer (ALB-040)
2//!
3//! Wires the existing `CudaTransformerBlock` forward/backward/optimizer_step
4//! into the pretraining path. Follows the proven `classify_pipeline.rs` pattern.
5//!
6//! # Architecture
7//!
8//! ```text
9//! CudaTransformerTrainer
10//! ├── model: Transformer                 (CPU — embed + save)
11//! ├── cuda_trainer: CudaTrainer          (GPU device context)
12//! ├── cuda_blocks: Vec<CudaBlock>            (fp32 or NF4)
13//! ├── cuda_grad_workspace: CudaGradWorkspace
14//! ├── gpu_training: GpuPretrainState     (layer_inputs, grad bufs, opt states)
15//! ├── lm_head_weight_gpu: GpuBuffer      (V × H on GPU)
16//! ├── lm_head_grad_gpu: GpuBuffer        (V × H gradient scratch)
17//! ├── lm_head_m/v: GpuBuffer             (AdamW moment states)
18//! └── config: TransformerTrainConfig
19//! ```
20//!
21//! # Transfer budget (C-GPUTRAIN-002, updated KAIZEN-050/052)
22//!
//! 1 bulk PCIe transfer per training step, plus 2 tiny control transfers (3 in total):
24//! 1. H2D: hidden states after embedding (seq×H×4 bytes)
25//! 2. H2D: target_ids for fused cross-entropy (seq×4 bytes — ~512B)
26//! 3. D2H: loss_partials from fused cross-entropy (seq×4 bytes — ~512B)
27//!
28//! Eliminated by KAIZEN-050:
29//! - D2H logits (was seq×V×4 = 77.8MB for Qwen3-4B)
30//! - H2D grad_logits (was seq×V×4 = 77.8MB)
31//!
32//! Eliminated by KAIZEN-052:
33//! - grad_gpu buffer allocation (was seq×V×4 = 77.8MB per step)
34
35#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42    gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46    adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47    gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48    squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64    CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65    GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
/// Derive the clipping scale for the shared gradient workspace from a GPU-side
/// L2-norm reduction (KAIZEN-054).
///
/// A squared-sum kernel is launched for each non-empty gradient buffer and only
/// the small partial sums (~1KB each) are read back, instead of downloading whole
/// gradient buffers to the CPU (was 58 MB+ per block, disabled in ALB-067).
///
/// Returns `(scale, grad_norm)`:
/// - `grad_norm` is the combined L2 norm over all nine buffers.
/// - `scale` is `max_norm / grad_norm` when the norm exceeds `max_norm`, else `1.0`.
///
/// Failure handling is best-effort: a buffer whose launch or collection fails is
/// left out of the norm, and a failed stream synchronize yields `(1.0, 0.0)`.
///
/// Kept as a free function so callers holding `&mut self` avoid borrow conflicts.
#[cfg(feature = "cuda")]
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grad_buffers: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // KAIZEN-055: queue every squared_sum kernel back-to-back with no sync in
    // between, so there is one pipeline flush per block instead of nine.
    // `collect()` is eager, so all launches are issued before the sync below.
    let in_flight: Vec<PendingSquaredSum> = grad_buffers
        .iter()
        .filter_map(|buf| {
            let n = buf.len() as u32;
            if n == 0 {
                None
            } else {
                squared_sum_launch_cuda(buf, n, stream).ok()
            }
        })
        .collect();

    // The one and only sync point for all of the launches above.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // C-CLIP-001: squared_sum_collect already yields sum(x²) = ||g||² per buffer,
    // so the values are added as-is — never squared again (entrenar#311 fix).
    // An explicit fold from 0.0 keeps the empty case at exactly +0.0.
    let sum_of_squares = in_flight
        .iter()
        .filter_map(|p| squared_sum_collect(p).ok())
        .fold(0.0f64, |acc, sq_norm| acc + f64::from(sq_norm));

    // L2 norm = sqrt(sum of per-buffer squared norms).
    let grad_norm = sum_of_squares.sqrt() as f32;
    if grad_norm > max_norm {
        (max_norm / grad_norm, grad_norm)
    } else {
        (1.0, grad_norm)
    }
}
133
/// Rescale every gradient buffer in the shared workspace so the combined L2 norm
/// does not exceed `max_norm`, using the GPU-computed norm (KAIZEN-054).
///
/// When the computed scale is within `1e-7` of `1.0` no kernels are launched.
/// Errors from individual scaling launches are deliberately ignored.
///
/// R-004: returns the pre-clip gradient L2 norm for observability logging.
#[cfg(feature = "cuda")]
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);
    if (scale - 1.0).abs() < 1e-7 {
        // Already inside the norm budget — nothing to rescale.
        return grad_norm;
    }

    // Same buffers, same order as the norm computation.
    let targets: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in targets {
        let n = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, scale, n, stream);
    }

    grad_norm
}
165
/// ALB-078: Fused gradient clipping — the whole pipeline is enqueued on the GPU.
///
/// Stands in for `clip_workspace_gradients` without its `stream.synchronize()`
/// and D2H partial-sum download. Three phases are queued on `stream`:
///
/// 1. 9× SquaredSumKernel → partials written into the pre-allocated contiguous buffer
/// 2. 1× ClipScaleReduceKernel → partials reduced and the clip scale computed on GPU
/// 3. 9× GradientClipGpuScaleKernel → scale read from GPU memory and applied
///
/// Zero sync points, zero D2H transfers per block. Launch errors are ignored
/// (best-effort), and empty buffers are skipped in phases 1 and 3.
#[cfg(feature = "cuda")]
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Converts an element offset into `partials_buf` to a byte offset.
    const BYTES_PER_F32: u64 = 4;

    // Phase 1: one squared_sum launch per buffer, each writing its partials at
    // that buffer's pre-computed offset inside state.partials_buf.
    {
        let sources: [&GpuBuffer<f32>; 9] = [
            &ws.grad_w_q,
            &ws.grad_w_k,
            &ws.grad_w_v,
            &ws.grad_w_o,
            &ws.grad_gate,
            &ws.grad_up,
            &ws.grad_down,
            &ws.grad_input_norm,
            &ws.grad_post_attn_norm,
        ];
        for (slot, grad) in sources.iter().enumerate() {
            let n = grad.len() as u32;
            if n != 0 {
                let partials_dst =
                    state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * BYTES_PER_F32;
                let _ = squared_sum_launch_into(grad, n, partials_dst, stream);
            }
        }
    }

    // Phase 2: reduce every partial and compute clip_scale on the GPU.
    // Stream ordering guarantees the squared_sum kernels finish before this runs.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Phase 3: apply the scale to all nine buffers. The kernel reads the scale
    // from GPU memory (scale_buf[0] = clip_scale), so no D2H is needed.
    let clip_scale_ptr = state.scale_buf.as_ptr();
    let mut targets: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in targets.iter_mut() {
        let n = grad.len() as u32;
        if n != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, clip_scale_ptr, n, stream);
        }
    }
}
238
/// R-004: Report the workspace gradient L2 norm without modifying any gradient
/// (observability only).
///
/// Delegates to the GPU reduction (KAIZEN-054) with `f32::MAX` as the threshold
/// and discards the returned scale. Only ~9KB D2H per call.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    // Tuple field 1 is the norm; field 0 (the clip scale) is not needed here.
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
/// ALB-072: Multiply every gradient buffer in the shared workspace by `inv_scale`.
///
/// Under fp16 AMP the fused cross-entropy kernel folds loss_scale into its
/// gradient output, so every later backward gradient carries that factor. The
/// GPU block optimizer (AdamW) needs unscaled gradients — otherwise the second
/// moment `v` overflows f32 and early layers go NaN.
///
/// GPU-side counterpart of `GradScaler::unscale_and_check()`, which handles the
/// CPU embedding gradients. A factor within `1e-7` of `1.0` is treated as a
/// no-op. Launch errors are ignored.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    if (inv_scale - 1.0).abs() < 1e-7 {
        return;
    }

    let targets: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    // The clip kernel is a plain in-place scale, so it doubles as the unscale step.
    for grad in targets {
        let n = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, n, stream);
    }
}
285
/// GPU-resident training state for pretraining.
///
/// Bundles the per-step device buffers (saved activations, gradient ping-pong
/// buffers, final-norm and LM-head scratch) together with the per-block
/// optimizer states. Built once in `CudaTransformerTrainer::with_model`.
///
/// # Contract (C-GPUTRAIN-001)
///
/// - `layer_inputs.len() == num_layers`
/// - All buffers preallocated at init; zero GPU allocations during training
/// - `step` increments monotonically
#[cfg(feature = "cuda")]
struct GpuPretrainState {
    /// Saved layer inputs for backward [num_layers][max_seq_len * hidden_size].
    /// One buffer is allocated per layer regardless of whether activation
    /// checkpointing is enabled.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Which layer inputs were saved during forward (activation checkpointing).
    /// Entry `i` is `true` when checkpointing is disabled or `i` is a segment
    /// boundary (`i % segment_size == 0`).
    /// Non-saved layers are recomputed from the nearest checkpoint before backward.
    saved_layer_mask: Vec<bool>,
    /// Temporary buffer for activation recomputation [max_seq_len * hidden_size].
    /// Used as the initial input when recomputing from a checkpoint boundary.
    /// `Some` only when activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Final RMSNorm weight on GPU [hidden_size], uploaded from the CPU model.
    final_norm_weight: GpuBuffer<f32>,
    /// Final block output (pre-norm) for RMSNorm backward [max_seq_len * hidden_size]
    blocks_output: GpuBuffer<f32>,
    /// Alternating gradient buffer A [max_seq_len * hidden_size]
    grad_buf_a: GpuBuffer<f32>,
    /// Alternating gradient buffer B [max_seq_len * hidden_size]
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient for final norm weight [hidden_size]
    grad_final_norm_weight: GpuBuffer<f32>,
    /// RMSNorm output buffer (reused each step) [max_seq_len * hidden_size]
    norm_output: GpuBuffer<f32>,
    /// Logits buffer (reused each step) [max_seq_len * vocab_size]
    logits_buf: GpuBuffer<f32>,
    /// LM head gradient buffer [max_seq_len * hidden_size] (grad w.r.t. normed hidden)
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// Per-block optimizer states. Left empty on the NF4 path, which keeps
    /// LoRA optimizer states on the trainer instead.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter (starts at 0)
    step: u32,
}
326
/// GPU-resident transformer trainer for pretraining.
///
/// Uses `CudaTransformerBlock` forward/backward/optimizer_step on GPU and keeps
/// the embedding lookup on the CPU model.
///
/// NOTE(review): this comment previously also listed cross-entropy loss as a
/// CPU-side step. The module-level transfer budget and the KAIZEN-050 note in
/// `with_model` describe a fused GPU cross-entropy kernel instead — confirm
/// against the training step and drop this note once verified.
///
/// # Contract (C-GPUTRAIN-002)
///
/// - Exactly 3 PCIe transfers per training step (H2D hidden states, H2D target
///   ids, D2H loss partials — see the module-level transfer budget)
/// - Graceful fallback to CPU `TransformerTrainer` on any CUDA failure
/// - Weight sync via `sync_weights_to_cpu()` before save
#[cfg(feature = "cuda")]
pub struct CudaTransformerTrainer {
    /// CPU model (for embedding, saving, fallback)
    model: Transformer,
    /// CUDA device context
    cuda_trainer: CudaTrainer,
    /// GPU-resident transformer blocks (fp32 or NF4 via CudaBlock enum)
    cuda_blocks: Vec<CudaBlock>,
    /// Shared gradient workspace (one set, reused across layers; fp32 path only)
    cuda_grad_workspace: CudaGradWorkspace,
    /// ENT-263: Shared scratch for NF4 blocks (C-SCRATCH-001). None when fp32.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// ENT-263: Shared LoRA gradient workspace for NF4 backward. None when fp32.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// ENT-263: Per-block LoRA optimizer states for NF4. None when fp32.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// GPU training state (layer inputs, grad bufs, optimizer states)
    gpu_training: GpuPretrainState,
    /// LM head weight on GPU [vocab_size * hidden_size].
    /// Uploaded from the separate `lm_head` when the model has one, otherwise
    /// from the tied embedding weight.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// LM head weight gradient on GPU [vocab_size * hidden_size]
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// LM head AdamW first moment [vocab_size * hidden_size] (zero-initialized)
    lm_head_m: GpuBuffer<f32>,
    /// LM head AdamW second moment [vocab_size * hidden_size] (zero-initialized)
    lm_head_v: GpuBuffer<f32>,
    /// Final norm weight AdamW first moment [hidden_size] (zero-initialized)
    final_norm_m: GpuBuffer<f32>,
    /// Final norm weight AdamW second moment [hidden_size] (zero-initialized)
    final_norm_v: GpuBuffer<f32>,
    /// CPU optimizer for embedding weights only.
    /// C-EMBED-GRAD-001: constructed from the configured lr/betas/weight decay.
    embed_optimizer: AdamW,
    /// Training configuration
    config: TransformerTrainConfig,
    /// Metrics tracker
    pub metrics: MetricsTracker,
    /// Current optimizer step
    step: usize,
    /// Accumulated loss (for gradient accumulation)
    accumulated_loss: f32,
    /// Accumulated batch count
    accumulated_batches: usize,
    /// R-004: Last observed LM head gradient L2 norm (proxy for global grad norm)
    last_grad_norm: f32,
    /// R-040: Last observed embedding activation gradient L2 norm
    last_embed_grad_norm: f32,
    /// R-038: Per-block gradient accumulation for true multi-step gradient accumulation.
    /// Only allocated when accumulation_steps > 1. CPU-side buffers (~335 MB for 350M).
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// ALB-091: GPU-resident gradient accumulation (replaces CPU accum when available).
    /// Eliminates 24 × ga stream.synchronize() + D2H transfers per optimizer step.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// R-002: Gradient scaler for mixed-precision training.
    /// For BF16: no-op (scale=1.0, dynamic=false).
    /// For FP16: dynamic loss scaling to prevent gradient underflow.
    grad_scaler: GradScaler,
    /// KAIZEN-047: Per-step wall-clock profiler.
    /// Reports timing breakdown for each training phase.
    profiler: StepProfiler,
    /// KAIZEN-053: Pre-allocated forward scratch buffer A [max_seq_len * hidden_size].
    /// Reused every step — together with `fwd_scratch_b` this eliminates
    /// 2 × cuMemAlloc/Free per training step.
    fwd_scratch_a: GpuBuffer<f32>,
    /// KAIZEN-053: Pre-allocated forward scratch buffer B [max_seq_len * hidden_size].
    fwd_scratch_b: GpuBuffer<f32>,
    /// KAIZEN-056: Pre-allocated CPU staging buffer for H2D hidden state upload.
    /// Eliminates vec![0.0; max_seq_len * hidden_size] allocation per step.
    h2d_staging: Vec<f32>,
    /// KAIZEN-059: Pre-allocated CPU staging buffer for D2H gradient downloads
    /// during gradient accumulation. Sized to max(h*intermediate, vocab*h).
    /// Eliminates ~15GB of per-step heap churn (36 × vec![0.0; h*i] + vec![0.0; vocab*h]
    /// per micro-batch × accumulation_steps).
    d2h_staging: Vec<f32>,
    /// ALB-078: Pre-allocated state for fused gradient clipping pipeline.
    /// Eliminates 24 stream.synchronize() calls per step.
    fused_clip: Option<FusedClipState>,
    /// Pre-allocated host zero buffer for zeroing final norm grad [hidden_size].
    /// BatchedRmsNormBackwardKernel accumulates grad_gamma via atomicAdd,
    /// so the buffer must be zeroed before each rms_norm_backward call.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419    /// Create a new GPU-resident trainer.
420    ///
421    /// # Errors
422    ///
423    /// Returns `Err` if CUDA initialization, kernel pre-warming, or block upload fails.
424    /// Caller should fall back to CPU `TransformerTrainer` on error.
425    pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426        let model = Transformer::new(&config.model_config);
427        Self::with_model(model, config)
428    }
429
430    /// ALB-089: Load SafeTensors checkpoint for GPU inference (forward-only).
431    ///
432    /// Creates a `CudaTransformerTrainer` in inference mode. The optimizer
433    /// state is allocated (wasteful but simple), but `forward_logits()` only
434    /// uses the forward path. Call `forward_logits(&tokens)` to generate.
435    ///
436    /// # Arguments
437    /// * `checkpoint_dir` - Directory containing model.safetensors + config.json
438    /// * `model_config` - Transformer architecture config
439    ///
440    /// # Errors
441    ///
442    /// Returns `Err` if SafeTensors loading or CUDA initialization fails.
443    pub fn for_inference(
444        checkpoint_dir: impl AsRef<std::path::Path>,
445        model_config: crate::transformer::TransformerConfig,
446    ) -> crate::Result<Self> {
447        let dir = checkpoint_dir.as_ref();
448
449        // ALB-089: Try APR format first (our native checkpoint format), then SafeTensors
450        let model = if let Some((Some(m), _step)) =
451            crate::config::try_load_apr_for_inference(dir, &model_config)
452        {
453            m
454        } else {
455            Transformer::from_safetensors(dir, &model_config)?
456        };
457
458        let mut config = TransformerTrainConfig::new(model_config);
459        config.max_seq_len = config.model_config.max_position_embeddings;
460        Self::with_model(model, config)
461    }
462
463    /// Create a GPU-resident trainer from an existing model.
464    ///
465    /// # Errors
466    ///
467    /// Returns `Err` if CUDA initialization fails.
468    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
469        if !cuda_training_available() {
470            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
471        }
472
473        let mc = &config.model_config;
474        let max_seq_len = config.max_seq_len;
475        let hidden_size = mc.hidden_size;
476        let vocab_size = mc.vocab_size;
477        let num_layers = mc.num_hidden_layers;
478
479        // Step 1: Create CUDA trainer (initializes kernel caches)
480        let cuda_trainer = CudaTrainer::new().map_err(|e| {
481            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
482        })?;
483
484        println!(
485            "  GPU: {} ({:.1} GB)",
486            cuda_trainer.device_name(),
487            cuda_trainer.total_memory() as f64 / 1e9
488        );
489
490        let ctx = cuda_trainer.context().clone();
491        let stream = cuda_trainer.stream();
492
493        // Step 2: Pre-warm forward kernels (C-PREWARM-001)
494        // Must happen before block upload — JIT compilation needs free VRAM
495        pre_warm_forward_kernels(
496            hidden_size,
497            mc.intermediate_size,
498            mc.num_attention_heads,
499            mc.num_kv_heads,
500            mc.head_dim(),
501            max_seq_len,
502        )
503        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
504
505        // Step 2a: Pre-warm backward kernels (trueno#200)
506        // MUST happen before any GPU work — Blackwell's cuModuleLoadData fails
507        // with ILLEGAL_ADDRESS when called during active GPU computation.
508        {
509            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
510            let head_dim = mc.head_dim();
511            pre_warm_lora_backward_kernels(
512                hidden_size,
513                mc.num_attention_heads * head_dim,
514                mc.num_kv_heads * head_dim,
515                max_seq_len,
516                config.lora_rank.unwrap_or(0),
517                mc.intermediate_size,
518                mc.num_attention_heads,
519                config.quantize_nf4 && config.is_lora(),
520            )
521            .map_err(|e| {
522                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
523            })?;
524            eprintln!("  ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
525        }
526
527        // Step 2b: Bind cuBLAS handles to training stream (ALB-075)
528        // Must happen after kernel cache init, before any GEMM calls.
529        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
530            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
531        }
532        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
533            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
534        }
535
536        // Step 3: Upload transformer blocks to GPU
537        let use_nf4 = config.quantize_nf4 && config.is_lora();
538        let cuda_blocks = Self::upload_blocks(
539            &model,
540            mc,
541            &config,
542            &ctx,
543            use_nf4,
544            num_layers,
545            hidden_size,
546            max_seq_len,
547        )?;
548
549        // Step 4: Allocate shared gradient workspace
550        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
551            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
552        })?;
553
554        // Step 5: Allocate GPU training state
555        let buf_size = max_seq_len * hidden_size;
556        let logits_size = max_seq_len * vocab_size;
557
558        // Activation checkpointing: determine which layers save their inputs.
559        // Checkpoint boundary layers (every segment_size layers) are always saved.
560        // Non-boundary layers are recomputed from the nearest checkpoint during backward.
561        let checkpointing = config.checkpoint_config.enabled;
562        let segment_size = if checkpointing {
563            let ns = config.checkpoint_config.num_segments.max(1);
564            num_layers.div_ceil(ns)
565        } else {
566            1 // Every layer is a checkpoint (no recomputation)
567        };
568        let saved_layer_mask: Vec<bool> =
569            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
570
571        let mut layer_inputs = Vec::with_capacity(num_layers);
572        for _ in 0..num_layers {
573            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
574                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
575            })?);
576        }
577
578        // Allocate recompute buffer if checkpointing is enabled
579        let recompute_buf = if checkpointing {
580            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
581                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
582            })?)
583        } else {
584            None
585        };
586
587        if checkpointing {
588            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
589            println!(
590                "  ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
591                config.checkpoint_config.num_segments, saved_count, num_layers
592            );
593        }
594
595        // Upload final RMSNorm weight
596        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
597        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
598            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
599        })?;
600
601        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
602            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
603        })?;
604        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
605            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
606        })?;
607        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
608            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
609        })?;
610        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
611            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
612        })?;
613        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
614            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
615        })?;
616        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
617            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
618        })?;
619        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
620            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
621        })?;
622
623        // Initialize per-block optimizer states (fp32 path only; NF4 uses LoRA states)
624        let mut optimizer_states = Vec::new();
625        if !use_nf4 {
626            optimizer_states.reserve(num_layers);
627            for (i, block) in cuda_blocks.iter().enumerate() {
628                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
629                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
630                })?);
631            }
632        }
633
634        let gpu_training = GpuPretrainState {
635            layer_inputs,
636            saved_layer_mask,
637            recompute_buf,
638            final_norm_weight,
639            blocks_output,
640            grad_buf_a,
641            grad_buf_b,
642            grad_final_norm_weight,
643            norm_output,
644            logits_buf,
645            lm_head_grad_hidden,
646            optimizer_states,
647            step: 0,
648        };
649
650        // Step 6: Upload LM head weight to GPU
651        // Use tied weights (embed_tokens.weight) or separate lm_head
652        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
653        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
654        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
655            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
656        })?;
657        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
658            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
659        })?;
660        // CRITICAL: Must zero-initialize m/v buffers. GpuBuffer::new() does NOT
661        // zero memory (cuMemAlloc returns uninitialized VRAM).
662        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
663            .map_err(|e| {
664                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
665            })?;
666        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
667            .map_err(|e| {
668                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
669            })?;
670
671        // Final norm optimizer states
672        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
673            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
674        })?;
675        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
676            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
677        })?;
678
679        // KAIZEN-053: Pre-allocate forward scratch buffers (reused every step)
680        let buf_size = max_seq_len * hidden_size;
681        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
682            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
683        })?;
684        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
685            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
686        })?;
687
688        // Sync to ensure all uploads completed
689        stream
690            .synchronize()
691            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
692
693        println!(
694            "  ✓ GPU training state allocated (LM head: {:.1} MB)",
695            (vocab_size * hidden_size * 4) as f64 / 1e6
696        );
697
698        // ENT-263: Allocate NF4 infrastructure (shared scratch, LoRA grad workspace, optimizer states)
699        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
700            let lora_rank = config.lora_rank.unwrap_or(16);
701
702            // C-SCRATCH-001: Shared scratch for NF4 blocks (reused across all layers)
703            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
704                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
705            })?;
706
707            // LoRA gradient workspace (shared, reused per-block like CudaGradWorkspace)
708            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
709                crate::error::Error::ConfigError(format!(
710                    "NF4 LoRA grad workspace alloc failed: {e:?}"
711                ))
712            })?;
713
714            // Per-block LoRA optimizer states
715            let mut lora_opt_states = Vec::with_capacity(num_layers);
716            for (i, block) in cuda_blocks.iter().enumerate() {
717                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
718                    crate::error::Error::ConfigError(format!(
719                        "Block {i} LoRA opt state failed: {e:?}"
720                    ))
721                })?);
722            }
723
724            println!(
725                "  ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
726            );
727            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
728        } else {
729            (None, None, None)
730        };
731
732        // KAIZEN-050: loss_fn removed — cross-entropy computed by fused GPU kernel
733        // C-EMBED-GRAD-001: CPU optimizer must match YAML hyperparams (not defaults)
734        let embed_optimizer =
735            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
736
737        // R-038: Allocate per-block gradient accumulation buffers (CPU-side)
738        // when accumulation_steps > 1 for true gradient accumulation.
739        let grad_accum = if config.accumulation_steps > 1 {
740            let kv_hidden = mc.num_kv_heads * mc.head_dim();
741            let block_sizes =
742                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
743                    hidden_size,
744                    kv_hidden,
745                    mc.intermediate_size,
746                );
747            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
748                num_layers,
749                block_sizes,
750                vocab_size,
751                hidden_size,
752            );
753            println!(
754                "  ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
755                config.accumulation_steps,
756                (accum
757                    .block_grads
758                    .iter()
759                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
760                    .sum::<usize>()
761                    + accum.lm_head_grad.len()
762                    + accum.final_norm_grad.len()
763                    + accum.embedding_grad.len()) as f64
764                    * 4.0
765                    / 1e6,
766            );
767            Some(accum)
768        } else {
769            None
770        };
771
772        // ALB-091: GPU-resident gradient accumulation (eliminates D2H bottleneck).
773        // Falls back to CPU accum if GPU allocation fails.
774        let gpu_grad_accum = if config.accumulation_steps > 1 {
775            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
776                Ok(accum) => {
777                    println!("  ✓ GPU gradient accumulation enabled (ALB-091)");
778                    Some(accum)
779                }
780                Err(e) => {
781                    eprintln!(
782                        "  [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
783                    );
784                    None
785                }
786            }
787        } else {
788            None
789        };
790
791        // KAIZEN-059: Pre-allocate D2H staging buffer for gradient accumulation
792        // downloads. Only needed when GPU accum is unavailable (CPU fallback path).
793        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
794            let ws_max = hidden_size * mc.intermediate_size;
795            let lm_max = vocab_size * hidden_size;
796            vec![0.0f32; ws_max.max(lm_max)]
797        } else {
798            Vec::new()
799        };
800
801        // ALB-078: Pre-allocate fused gradient clipping state.
802        // Eliminates 24 stream syncs per step by keeping norm+clip on GPU.
803        let kv_hidden = mc.num_kv_heads * mc.head_dim();
804        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
805
806        // R-002: Initialize gradient scaler from precision config
807        let grad_scaler = GradScaler::from_config(&config.precision_config);
808        if config.precision_config.is_mixed() {
809            println!(
810                "  ✓ Mixed precision: {} (loss scale={}, dynamic={})",
811                config.precision_config.compute_precision,
812                grad_scaler.scale(),
813                grad_scaler.is_dynamic(),
814            );
815        }
816
817        Ok(Self {
818            model,
819            cuda_trainer,
820            cuda_blocks,
821            cuda_grad_workspace,
822            nf4_shared_scratch,
823            nf4_lora_grad_workspace,
824            nf4_lora_optimizer_states,
825            gpu_training,
826            lm_head_weight_gpu,
827            lm_head_grad_gpu,
828            lm_head_m,
829            lm_head_v,
830            final_norm_m,
831            final_norm_v,
832            embed_optimizer,
833            // KAIZEN-047: Read profile_interval before moving config into struct.
834            profiler: if config.profile_interval > 0 {
835                StepProfiler::new(true, config.profile_interval)
836            } else {
837                StepProfiler::disabled()
838            },
839            config,
840            metrics: MetricsTracker::new(),
841            step: 0,
842            accumulated_loss: 0.0,
843            accumulated_batches: 0,
844            last_grad_norm: 0.0,
845            last_embed_grad_norm: 0.0,
846            grad_accum,
847            gpu_grad_accum,
848            grad_scaler,
849            fwd_scratch_a,
850            fwd_scratch_b,
851            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
852            d2h_staging,
853            fused_clip,
854            final_norm_zero_buf: vec![0.0f32; hidden_size],
855        })
856    }
857
858    /// Upload transformer blocks to GPU (NF4 or fp32 path).
859    #[allow(clippy::too_many_arguments)]
860    fn upload_blocks(
861        model: &Transformer,
862        mc: &crate::transformer::TransformerConfig,
863        config: &TransformerTrainConfig,
864        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
865        use_nf4: bool,
866        num_layers: usize,
867        hidden_size: usize,
868        max_seq_len: usize,
869    ) -> crate::Result<Vec<CudaBlock>> {
870        let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
871
872        if use_nf4 {
873            let lora_rank = config.lora_rank.unwrap_or(16);
874            let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
875            let lora_scale = lora_alpha / lora_rank as f32;
876            let head_dim = mc.head_dim();
877            let q_dim = mc.num_attention_heads * head_dim;
878            let kv_hidden = mc.num_kv_heads * head_dim;
879
880            for (i, layer) in model.layers.iter().enumerate() {
881                let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
882                    .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
883                    .collect();
884                let lora_b_q = vec![0.0f32; lora_rank * q_dim];
885                let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
886                    .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
887                    .collect();
888                let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
889
890                let q_norm_data = layer
891                    .self_attn
892                    .q_norm
893                    .as_ref()
894                    .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
895                let k_norm_data = layer
896                    .self_attn
897                    .k_norm
898                    .as_ref()
899                    .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
900
901                let block = crate::transformer::CudaNf4TransformerBlock::new(
902                    mc,
903                    i,
904                    ctx.clone(),
905                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
906                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
907                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
908                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
909                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
910                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
911                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
912                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
913                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
914                    max_seq_len,
915                    Some((&lora_a_q, &lora_b_q)),
916                    Some((&lora_a_v, &lora_b_v)),
917                    lora_scale,
918                    lora_rank,
919                    q_norm_data.as_deref(),
920                    k_norm_data.as_deref(),
921                )
922                .map_err(|e| {
923                    crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
924                })?;
925                cuda_blocks.push(CudaBlock::Nf4(block));
926            }
927            println!("  ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
928        } else {
929            for (i, layer) in model.layers.iter().enumerate() {
930                // FALSIFY-CUDA-FORWARD-PARITY-002 thread Q/K/V biases
931                // through to the CudaTransformerBlock when present
932                // (Qwen2 family use_bias=true). Pre-fix these were
933                // silently dropped → val_loss > ln(vocab) on Qwen.
934                let b_q = layer
935                    .self_attn
936                    .b_q
937                    .as_ref()
938                    .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
939                let b_k = layer
940                    .self_attn
941                    .b_k
942                    .as_ref()
943                    .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
944                let b_v = layer
945                    .self_attn
946                    .b_v
947                    .as_ref()
948                    .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
949                let block = CudaTransformerBlock::new(
950                    mc,
951                    i,
952                    ctx.clone(),
953                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
954                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
955                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
956                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
957                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
958                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
959                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
960                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
961                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
962                    max_seq_len,
963                    b_q.as_deref(),
964                    b_k.as_deref(),
965                    b_v.as_deref(),
966                )
967                .map_err(|e| {
968                    crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
969                })?;
970                cuda_blocks.push(CudaBlock::Fp32(block));
971            }
972            println!("  ✓ {num_layers} transformer blocks uploaded to GPU");
973        }
974
975        Ok(cuda_blocks)
976    }
977
978    /// ALB-078: Initialize fused gradient clipping state (extracted for complexity).
979    fn init_fused_clip(
980        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981        config: &TransformerTrainConfig,
982        hidden_size: usize,
983        kv_hidden: usize,
984        mc: &crate::transformer::TransformerConfig,
985    ) -> Option<FusedClipState> {
986        config.base.max_grad_norm?;
987        let grad_sizes: [u32; 9] = [
988            (hidden_size * hidden_size) as u32,
989            (hidden_size * kv_hidden) as u32,
990            (hidden_size * kv_hidden) as u32,
991            (hidden_size * hidden_size) as u32,
992            (hidden_size * mc.intermediate_size) as u32,
993            (hidden_size * mc.intermediate_size) as u32,
994            (mc.intermediate_size * hidden_size) as u32,
995            hidden_size as u32,
996            hidden_size as u32,
997        ];
998        match FusedClipState::new(ctx, &grad_sizes) {
999            Ok(state) => {
1000                println!(
1001                    "  ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002                    state.total_partials,
1003                    f64::from(state.total_partials) * 4.0 / 1024.0,
1004                );
1005                Some(state)
1006            }
1007            Err(e) => {
1008                println!("  ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009                None
1010            }
1011        }
1012    }
1013
1014    /// Run one forward+backward step for a single sequence.
1015    ///
1016    /// # Contract (C-GPUSTEP-001)
1017    ///
1018    /// - Precondition: `input_ids.len() == target_ids.len() <= max_seq_len`
1019    /// - Postcondition: If `accumulate_only` is false, all GPU weights updated.
1020    ///   If true, gradients accumulated into CPU buffers (no weight updates).
1021    /// - Transfer count: 1 PCIe H2D + ~1KB control (KAIZEN-050, + 24×9 D2H if accumulating)
1022    fn train_step_single(
1023        &mut self,
1024        input_ids: &[u32],
1025        target_ids: &[u32],
1026        accumulate_only: bool,
1027    ) -> Option<f32> {
1028        self.profiler.begin_step();
1029        let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030        self.profiler.finish_step();
1031        result
1032    }
1033
1034    /// Inner training step — separated so profiler always records the step.
1035    fn train_step_inner(
1036        &mut self,
1037        input_ids: &[u32],
1038        target_ids: &[u32],
1039        accumulate_only: bool,
1040    ) -> Option<f32> {
1041        let hidden_size = self.config.model_config.hidden_size;
1042        let vocab_size = self.config.model_config.vocab_size;
1043
1044        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
1045        let max_sl = self.config.max_seq_len;
1046        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048        let seq_len = input_ids.len();
1049
1050        // Steps 1-6: GPU forward pass — logits stay GPU-resident (KAIZEN-050)
1051        // (sub-phases embed, h2d, forward, norm_lm instrumented inside gpu_forward)
1052        if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053            eprintln!(
1054                "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055                 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056            );
1057            return None;
1058        }
1059
1060        // Step 7: Fused GPU cross-entropy loss + softmax backward (KAIZEN-050)
1061        // Eliminates: logits D2H (77.8MB) + CPU softmax (40ms) + grad H2D (77.8MB)
1062        self.profiler.begin(StepProfiler::LOSS);
1063        let stream = self.cuda_trainer.stream();
1064
1065        // Compute combined scale: (1/seq_len) * (1/accum_steps)
1066        //
1067        // ALB-072: Do NOT multiply by grad_scaler.scale() here. All backward
1068        // computation uses f32 GpuBuffers — there is no fp16 gradient underflow
1069        // risk. The 65536x loss scaling caused gradient overflow in early layers
1070        // (blocks 0-1 went NaN). The GradScaler remains active for the CPU
1071        // embedding path (unscale_and_check in optimizer_step) as a safety check,
1072        // but it operates with scale=1.0 effective for GPU gradients.
1073        let mut loss_scale = 1.0 / seq_len as f32;
1074        if self.config.accumulation_steps > 1 {
1075            loss_scale /= self.config.accumulation_steps as f32;
1076        }
1077
1078        // KAIZEN-052: In-place — gradient written directly to logits_buf.
1079        let loss_val = fused_cross_entropy_cuda(
1080            &mut self.gpu_training.logits_buf,
1081            target_ids,
1082            seq_len as u32,
1083            vocab_size as u32,
1084            loss_scale,
1085            stream,
1086        )
1087        .ok()?;
1088
1089        // NaN guard (replaces logits NaN check — NaN logits → NaN loss via kernel)
1090        if !loss_val.is_finite() {
1091            return None;
1092        }
1093        self.profiler.end(StepProfiler::LOSS);
1094
1095        // Steps 8-11: GPU backward pass (with or without optimizer)
1096        // (sub-phases lm_bwd, norm_bwd, blk_bwd instrumented inside gpu_backward)
1097        // KAIZEN-050: grad_logits on GPU. KAIZEN-052: grad lives in logits_buf (in-place).
1098        //
1099        // ENT-263 fix: Capture loss regardless of backward success. The NF4 backward
1100        // path may fail (e.g., gemm_nf4_backward_a stub) but the loss was already
1101        // computed by fused_cross_entropy_cuda. Dropping the loss silently causes
1102        // loss=0.0 reporting despite valid forward passes.
1103        if let Some(grad_output_is_a) =
1104            self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105        {
1106            // Step 12: Embedding backward (CPU scatter-add always accumulates)
1107            self.profiler.begin(StepProfiler::EMBED_BWD);
1108            self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110            self.profiler.end(StepProfiler::EMBED_BWD);
1111        }
1112
1113        Some(loss_val)
1114    }
1115
    /// GPU forward pass: embed → blocks → norm → LM head.
    ///
    /// Logits stay GPU-resident in `self.gpu_training.logits_buf` (KAIZEN-050).
    /// Transfers: 1 H2D (hidden states). No D2H — logits consumed by fused kernel.
    ///
    /// # Parameters
    ///
    /// - `input_ids`: token ids fed to the CPU embedding lookup.
    /// - `seq_len`: row count passed to every block and kernel. Both callers
    ///   visible in this file pass `input_ids.len()`.
    /// - `hidden_size` / `vocab_size`: dimensions passed to the final RMSNorm
    ///   and the LM-head GEMM.
    ///
    /// # Returns
    ///
    /// `Some(())` when every stage was issued successfully. `None` when the
    /// embedding output is not available as a slice, the H2D copy fails
    /// (logged to stderr), or any GPU copy / block forward / norm / GEMM call
    /// returns an error.
    ///
    /// NOTE(review): the staging copy below slices `h2d_staging` by the
    /// embedding output length and panics if that is longer than the staging
    /// buffer. Both callers visible here bound `seq_len` by `max_seq_len`
    /// first — confirm for any new caller.
    #[allow(unsafe_code)]
    fn gpu_forward(
        &mut self,
        input_ids: &[u32],
        seq_len: usize,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Option<()> {
        contract_pre_gpu_forward!();
        let stream = self.cuda_trainer.stream();

        // Embedding lookup (CPU)
        self.profiler.begin(StepProfiler::EMBED);
        let hidden = self.model.embed_tokens.forward(input_ids);
        let hidden_slice = hidden.data().as_slice()?;
        self.profiler.end(StepProfiler::EMBED);

        // Upload hidden states to GPU (Transfer 1: H2D)
        // Pad to max_seq_len so D2D copies to pre-allocated layer_inputs match.
        // KAIZEN-053: Reuse pre-allocated scratch buffers instead of cuMemAlloc per step.
        // KAIZEN-056: Reuse pre-allocated h2d_staging instead of alloc per step.
        self.profiler.begin(StepProfiler::H2D);
        // Real data first, then zero the tail so stale values from a previous,
        // longer sequence are never uploaded.
        self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
        self.h2d_staging[hidden_slice.len()..].fill(0.0);
        if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
            eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
            return None;
        }
        self.profiler.end(StepProfiler::H2D);

        // Forward through CUDA blocks using pre-allocated ping-pong buffers.
        // KAIZEN-053: fwd_scratch_a/b are top-level fields (not in gpu_training)
        // so borrowing them doesn't conflict with gpu_training.layer_inputs.
        self.profiler.begin(StepProfiler::FORWARD);
        let mut input_is_a = true; // Track which scratch buffer is "input"
        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
            // Use raw pointers for the ping-pong to avoid borrow conflicts
            // with self.gpu_training.layer_inputs
            let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
                if input_is_a {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_a),
                        std::ptr::from_mut(&mut self.fwd_scratch_b),
                    )
                } else {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_b),
                        std::ptr::from_mut(&mut self.fwd_scratch_a),
                    )
                };
            // Only layers flagged in saved_layer_mask keep a copy of their
            // input; `recompute_segment` looks for the nearest flagged layer
            // at or before its target.
            if self.gpu_training.saved_layer_mask[i] {
                // SAFETY: Both buffers are valid GPU allocations with matching max_seq_len size.
                // Copy completes before block.forward() reads from input (same stream ordering).
                unsafe {
                    self.gpu_training.layer_inputs[i]
                        .copy_from_buffer_async(&*input_ptr, stream)
                        .ok()?;
                }
            }
            // SAFETY: input_ptr and output_ptr point to disjoint fwd_scratch_{a,b}.
            // ENT-263: Pass shared scratch for NF4 blocks (C-SCRATCH-001).
            self.profiler.begin_layer();
            unsafe {
                block
                    .forward(
                        &*input_ptr,
                        &mut *output_ptr,
                        seq_len,
                        stream,
                        self.nf4_shared_scratch.as_mut(),
                    )
                    .ok()?;
            }
            self.profiler.end_layer_fwd(i);
            // This block's output is the next block's input: swap roles.
            input_is_a = !input_is_a;
        }
        self.profiler.end(StepProfiler::FORWARD);

        // After the loop, input_is_a tells us which buffer has the final output
        let final_output: &GpuBuffer<f32> =
            if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };

        // Save blocks output for final norm backward
        // SAFETY: Disjoint GPU buffers with matching max_seq_len sizes.
        self.profiler.begin(StepProfiler::NORM_LM);
        unsafe {
            self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
        }

        // Final RMSNorm forward (GPU)
        // FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: thread `config.rms_norm_eps`
        // through so Qwen2 (1e-6) gets the right epsilon. Pre-fix the
        // legacy `rms_norm_forward` hardcoded 1e-5 (Llama default).
        rms_norm_forward_with_eps(
            final_output,
            &self.gpu_training.final_norm_weight,
            &mut self.gpu_training.norm_output,
            seq_len as u32,
            hidden_size as u32,
            self.config.model_config.rms_norm_eps,
            stream,
        )
        .ok()?;

        // LM head GEMM forward (GPU)
        // gemm_forward treats flat (V,H) memory as (H,V) row-major, which
        // implicitly transposes — matching the CPU matmul's tied-weight behavior.
        gemm_forward(
            &self.gpu_training.norm_output,
            &self.lm_head_weight_gpu,
            &mut self.gpu_training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        )
        .ok()?;

        // KAIZEN-050: Logits stay GPU-resident — no D2H transfer.
        // Fused cross-entropy kernel reads logits_buf directly on GPU.
        self.profiler.end(StepProfiler::NORM_LM);

        Some(())
    }
1244
1245    /// ALB-089: Forward-only pass that returns last-position logits on CPU.
1246    ///
1247    /// Runs the same GPU forward as training but downloads only the last
1248    /// position's logits (vocab_size floats) for token sampling. No backward
1249    /// pass, no loss computation.
1250    ///
1251    /// # Contract (C-CUDA-INF-001)
1252    ///
1253    /// - Same forward path as `gpu_forward()` — identical logits
1254    /// - Only downloads `logits[seq_len-1, :]` (128 KB for 32K vocab)
1255    /// - stream.synchronize() before D2H (C-STREAMSYNC-001)
1256    pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1257        let seq_len = input_ids.len();
1258        let hidden_size = self.config.model_config.hidden_size;
1259        let vocab_size = self.config.model_config.vocab_size;
1260
1261        if seq_len == 0 || seq_len > self.config.max_seq_len {
1262            return None;
1263        }
1264
1265        // Reuse gpu_forward for the actual computation
1266        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1267
1268        // C-STREAMSYNC-001: synchronize before D2H
1269        let stream = self.cuda_trainer.stream();
1270        stream.synchronize().ok()?;
1271
1272        // Download last position logits only: logits_buf[seq_len-1, :]
1273        let offset = (seq_len - 1) * vocab_size;
1274        let mut logits = vec![0.0f32; vocab_size];
1275        self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1276
1277        Some(logits)
1278    }
1279
1280    /// GPU backward pass with interleaved per-block optimizer step.
1281    ///
1282    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
1283    /// Recompute layer inputs for a segment during backward (activation checkpointing).
1284    ///
1285    /// When checkpointing is enabled, non-checkpoint layers don't save their inputs
1286    /// during forward. Before their backward pass, we recompute from the nearest
1287    /// checkpoint by re-running forward through intermediate blocks.
1288    ///
1289    /// This recomputes the entire segment [checkpoint..=target_layer], storing
1290    /// intermediate layer_inputs so subsequent layers in the same segment don't
1291    /// need redundant recomputation.
1292    ///
1293    /// # Contract (R-021)
1294    ///
1295    /// After this call, `layer_inputs[i]` is valid for all i in [checkpoint..=target_layer].
1296    #[allow(unsafe_code)]
1297    fn recompute_segment(
1298        gpu_training: &mut GpuPretrainState,
1299        cuda_blocks: &mut [CudaBlock],
1300        nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1301        target_layer: usize,
1302        seq_len: usize,
1303        stream: &CudaStream,
1304    ) -> Option<()> {
1305        // Find nearest saved checkpoint at or before target
1306        let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1307
1308        if seg_start == target_layer {
1309            return Some(()); // Already saved
1310        }
1311
1312        // Copy checkpoint input to recompute_buf as starting point.
1313        // SAFETY: recompute_buf and layer_inputs are disjoint allocations.
1314        let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1315        unsafe {
1316            recompute_buf
1317                .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1318                .ok()?;
1319        }
1320
1321        // Forward through blocks [seg_start..target_layer], saving intermediate inputs.
1322        // For block i, input → block i → output becomes input for block i+1.
1323        // We save output (= input to block i+1) in layer_inputs[i+1].
1324        //
1325        // Buffer pattern:
1326        //   i == seg_start: input = recompute_buf, output = layer_inputs[seg_start+1]
1327        //   i > seg_start:  input = layer_inputs[i], output = layer_inputs[i+1]
1328        //
1329        // SAFETY: split_at_mut ensures non-overlapping borrows of layer_inputs.
1330        // recompute_buf is separate from layer_inputs.
1331        for i in seg_start..target_layer {
1332            if i == seg_start {
1333                // Input is in recompute_buf, output goes to layer_inputs[i+1]
1334                let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1335                let li = &mut gpu_training.layer_inputs;
1336                unsafe {
1337                    cuda_blocks[i]
1338                        .forward(
1339                            &*recompute_ptr,
1340                            &mut li[i + 1],
1341                            seq_len,
1342                            stream,
1343                            nf4_shared_scratch.as_mut(),
1344                        )
1345                        .ok()?;
1346                }
1347            } else {
1348                // Input = layer_inputs[i], output = layer_inputs[i+1]
1349                let li = &mut gpu_training.layer_inputs;
1350                let (left, right) = li.split_at_mut(i + 1);
1351                cuda_blocks[i]
1352                    .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1353                    .ok()?;
1354            }
1355        }
1356
1357        Some(())
1358    }
1359
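    /// GPU backward pass with interleaved per-block optimizer step.
    ///
    /// Each block's backward writes weight gradients to the shared `CudaGradWorkspace`.
    /// Since `gemm_backward_b` overwrites (not accumulates), we must run each block's
    /// optimizer step immediately after its backward, before the next block overwrites
    /// the workspace. This also enables per-block gradient clipping.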
1363    ///
1364    /// When `accumulate_only` is true (R-038 gradient accumulation), the per-block
1365    /// optimizer steps are skipped and workspace gradients are downloaded to CPU-side
1366    /// `PerBlockGradientAccumulator` instead. LM head and final norm gradients are
1367    /// also downloaded and accumulated. The optimizer step is deferred until
1368    /// `gpu_optimizer_from_accum()` is called.
1369    ///
1370    /// Returns `grad_output_is_a` flag for embedding backward.
1371    /// Transfer: 0 H2D (KAIZEN-050/052: grad in logits_buf) + 24×9 D2H if accumulating.
1372    #[allow(unsafe_code)]
1373    fn gpu_backward(
1374        &mut self,
1375        seq_len: usize,
1376        hidden_size: usize,
1377        vocab_size: usize,
1378        accumulate_only: bool,
1379    ) -> Option<bool> {
1380        let stream = self.cuda_trainer.stream();
1381        let max_grad_norm = self.config.base.max_grad_norm;
1382        let lr = self.current_lr();
1383        // ALB-072: No inv_scale needed — loss_scale no longer includes grad_scaler.
1384        let beta1 = self.config.beta1;
1385        let beta2 = self.config.beta2;
1386        let weight_decay = self.config.weight_decay;
1387
1388        // KAIZEN-050: grad_logits GPU-resident. KAIZEN-052: grad lives in logits_buf (in-place).
1389        // No separate grad buffer. No GRAD_H2D transfer.
1390
1391        // LM head GEMM backward
1392        self.profiler.begin(StepProfiler::LM_BWD);
1393        gemm_backward_a(
1394            &self.gpu_training.logits_buf,
1395            &self.lm_head_weight_gpu,
1396            &mut self.gpu_training.lm_head_grad_hidden,
1397            seq_len as u32,
1398            hidden_size as u32,
1399            vocab_size as u32,
1400            stream,
1401        )
1402        .ok()?;
1403
1404        gemm_backward_b(
1405            &self.gpu_training.norm_output,
1406            &self.gpu_training.logits_buf,
1407            &mut self.lm_head_grad_gpu,
1408            seq_len as u32,
1409            hidden_size as u32,
1410            vocab_size as u32,
1411            stream,
1412        )
1413        .ok()?;
1414
1415        // Clip LM head weight gradient
1416        // KAIZEN-049: GPU norm reduction.
1417        // KAIZEN-051: No explicit sync needed — same stream ordering.
1418        // ALB-071: Always compute LM head grad norm for observability (R-004).
1419        // C-CLIP-001: squared_sum_cuda returns ||g||². Take sqrt for L2 norm (entrenar#311).
1420        let lm_sq_norm =
1421            squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1422                .unwrap_or(0.0);
1423        let lm_norm = lm_sq_norm.sqrt(); // L2 norm, NOT squared
1424        self.last_grad_norm = lm_norm; // R-004: capture for observability
1425                                       // C-BACKPARITY-001: LM head gradient norm tracing (pre-clip).
1426        if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1427            eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1428            // Also trace the grad_hidden flowing to blocks
1429            let gh_sq = squared_sum_cuda(
1430                &self.gpu_training.lm_head_grad_hidden,
1431                self.gpu_training.lm_head_grad_hidden.len() as u32,
1432                stream,
1433            )
1434            .unwrap_or(0.0);
1435            eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1436        }
1437        if let Some(max_norm) = max_grad_norm {
1438            let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1439            let n = self.lm_head_grad_gpu.len() as u32;
1440            let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1441        }
1442        self.profiler.end(StepProfiler::LM_BWD);
1443
1444        // Final RMSNorm backward
1445        self.profiler.begin(StepProfiler::NORM_BWD);
1446        // Zero grad_final_norm_weight before backward — kernel accumulates via atomicAdd
1447        self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1448        rms_norm_backward(
1449            &self.gpu_training.blocks_output,
1450            &self.gpu_training.final_norm_weight,
1451            &self.gpu_training.lm_head_grad_hidden,
1452            &mut self.gpu_training.grad_buf_a,
1453            &mut self.gpu_training.grad_final_norm_weight,
1454            seq_len as u32,
1455            hidden_size as u32,
1456            1e-5_f32,
1457            stream,
1458        )
1459        .ok()?;
1460
1461        // Clip final norm weight gradient
1462        // KAIZEN-051: No explicit sync needed — same stream ordering as LM head clip.
1463        if let Some(max_norm) = max_grad_norm {
1464            let (scale, _) = Self::compute_clip_scale_with_norm(
1465                &self.gpu_training.grad_final_norm_weight,
1466                max_norm,
1467                stream,
1468            );
1469            let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1470            let _ =
1471                gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1472        }
1473        self.profiler.end(StepProfiler::NORM_BWD);
1474
1475        // R-038: Either accumulate non-block grads or run non-block optimizer.
1476        if accumulate_only {
1477            // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1478            if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1479                let _ = gpu_accum.accumulate_nonblock(
1480                    &self.lm_head_grad_gpu,
1481                    &self.gpu_training.grad_final_norm_weight,
1482                    stream,
1483                );
1484            } else {
1485                stream.synchronize().ok()?;
1486                Self::download_nonblock_grads_to_accum(
1487                    &self.lm_head_grad_gpu,
1488                    &self.gpu_training.grad_final_norm_weight,
1489                    &mut self.grad_accum,
1490                    &mut self.d2h_staging,
1491                )?;
1492            }
1493        } else {
1494            Self::run_nonblock_optimizer_step(
1495                &mut self.gpu_training,
1496                Some(&mut self.lm_head_weight_gpu),
1497                &self.lm_head_grad_gpu,
1498                &mut self.lm_head_m,
1499                &mut self.lm_head_v,
1500                &mut self.final_norm_m,
1501                &mut self.final_norm_v,
1502                lr,
1503                beta1,
1504                beta2,
1505                weight_decay,
1506                stream,
1507            );
1508        }
1509
1510        // Backward through blocks in reverse, with interleaved clip + optimizer.
1511        // Each block's backward writes weight gradients to shared CudaGradWorkspace.
1512        //
1513        // SAFETY: grad_buf_a and grad_buf_b are disjoint fields. Raw pointers
1514        // allow alternating read/write without violating aliasing rules.
1515        self.profiler.begin(StepProfiler::BLK_BWD);
1516        let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1517        let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1518        let mut grad_output_is_a = true;
1519        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1520
1521        for layer_idx in (0..self.cuda_blocks.len()).rev() {
1522            // Activation checkpointing: if this layer's input wasn't saved during
1523            // forward, recompute the segment from the nearest checkpoint.
1524            if !self.gpu_training.saved_layer_mask[layer_idx] {
1525                Self::recompute_segment(
1526                    &mut self.gpu_training,
1527                    &mut self.cuda_blocks,
1528                    &mut self.nf4_shared_scratch,
1529                    layer_idx,
1530                    seq_len,
1531                    stream,
1532                )?;
1533            }
1534
1535            let (grad_output, grad_input) = unsafe {
1536                if grad_output_is_a {
1537                    (&*grad_a_ptr, &mut *grad_b_ptr)
1538                } else {
1539                    (&*grad_b_ptr, &mut *grad_a_ptr)
1540                }
1541            };
1542
1543            self.profiler.begin_layer();
1544            if use_nf4 {
1545                // ENT-263: NF4 backward — LoRA gradient computation
1546                // Uses backward_nf4() which computes gradients for LoRA weights and norms only.
1547                // output_scratch reuses grad_buf_a/b as temporary storage for recomputed forward.
1548                let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1549                    grad_b_ptr // grad_input is in b, use as output_scratch too (will be overwritten)
1550                } else {
1551                    grad_a_ptr
1552                };
1553                // We need a separate output_scratch. Reuse blocks_output as scratch since
1554                // it was already consumed for norm backward above.
1555                match self.cuda_blocks[layer_idx].backward_nf4(
1556                    &self.gpu_training.layer_inputs[layer_idx],
1557                    grad_output,
1558                    grad_input,
1559                    &mut self.gpu_training.blocks_output, // reuse as output_scratch
1560                    seq_len,
1561                    stream,
1562                    self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1563                    self.nf4_lora_grad_workspace
1564                        .as_mut()
1565                        .expect("NF4 requires LoRA grad workspace"),
1566                ) {
1567                    Ok(()) => {}
1568                    Err(e) => {
1569                        eprintln!(
1570                            "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1571                            layer_idx, e, seq_len, self.config.model_config.hidden_size
1572                        );
1573                        return None;
1574                    }
1575                }
1576
1577                // ENT-265: Clip LoRA gradients before optimizer step.
1578                // Without this, NF4 LoRA grads are unbounded — causes weight
1579                // divergence and embedding grad explosion (Run 7c: 26M at step 225).
1580                if let Some(max_norm) = max_grad_norm {
1581                    self.nf4_lora_grad_workspace
1582                        .as_mut()
1583                        .expect("NF4 requires LoRA grad ws")
1584                        .clip_gradients(max_norm, stream);
1585                }
1586
1587                // NF4 LoRA optimizer step — always runs, even during accumulation.
1588                //
1589                // BUG FIX (entrenar#264): Previously gated by `if !accumulate_only`.
1590                // Design: NF4 LoRA has ~6M params, so we scale lr by 1/accum_steps
1591                // for micro-batches instead of accumulating gradients.
1592                {
1593                    let step = self.gpu_training.step;
1594                    let effective_lr = if accumulate_only {
1595                        lr / self.config.accumulation_steps as f32
1596                    } else {
1597                        lr
1598                    };
1599                    if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1600                        let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1601                            &mut opt_states[layer_idx],
1602                            step,
1603                            effective_lr,
1604                            beta1,
1605                            beta2,
1606                            1e-8,
1607                            weight_decay,
1608                            stream,
1609                            self.nf4_lora_grad_workspace
1610                                .as_ref()
1611                                .expect("NF4 requires LoRA grad ws"),
1612                        );
1613                    }
1614                }
1615            } else {
1616                // Standard fp32 backward path
1617                self.cuda_blocks[layer_idx]
1618                    .backward(
1619                        &self.gpu_training.layer_inputs[layer_idx],
1620                        grad_output,
1621                        grad_input,
1622                        seq_len,
1623                        stream,
1624                        &mut self.cuda_grad_workspace,
1625                    )
1626                    .ok()?;
1627
1628                // C-CLIP-001 / entrenar#312: DISABLED per-block gradient clipping.
1629                // Per-block clipping distorts gradient flow across layers.
1630
1631                // C-BACKPARITY-001: Per-block gradient norm tracing for parity testing.
1632                // Only runs when ENTRENAR_TRACE_GRADIENTS=1 — zero overhead in production.
1633                if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1634                    let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1635                        &self.cuda_grad_workspace,
1636                        f32::MAX,
1637                        stream,
1638                    );
1639                    // Also trace the activation gradient (flows between blocks)
1640                    let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1641                        .unwrap_or(0.0);
1642                    let act_gnorm = act_sq.sqrt();
1643                    eprintln!(
1644                        "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1645                    );
1646                }
1647
1648                // R-038: Either accumulate workspace grads or run optimizer per-block.
1649                if accumulate_only {
1650                    // ALB-091: GPU-resident accumulation (no sync, no D2H) or CPU fallback.
1651                    if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1652                        let _ = gpu_accum.accumulate_block(
1653                            &self.cuda_grad_workspace,
1654                            layer_idx,
1655                            stream,
1656                        );
1657                    } else {
1658                        // CPU fallback: SYNC + D2H (ALB-065 / Rule 6).
1659                        stream.synchronize().ok()?;
1660                        if let Some(accum) = &mut self.grad_accum {
1661                            Self::download_workspace_to_accum(
1662                                &self.cuda_grad_workspace,
1663                                accum,
1664                                layer_idx,
1665                                &mut self.d2h_staging,
1666                            )?;
1667                        }
1668                    }
1669                } else {
1670                    // Per-block optimizer step: consume workspace gradients before next block overwrites
1671                    let step = self.gpu_training.step;
1672                    let _ = self.cuda_blocks[layer_idx].optimizer_step(
1673                        &mut self.gpu_training.optimizer_states[layer_idx],
1674                        step,
1675                        lr,
1676                        beta1,
1677                        beta2,
1678                        1e-8,
1679                        weight_decay,
1680                        stream,
1681                        &self.cuda_grad_workspace,
1682                    );
1683                }
1684            }
1685
1686            self.profiler.end_layer_bwd(layer_idx);
1687            grad_output_is_a = !grad_output_is_a;
1688        }
1689
1690        stream.synchronize().ok()?;
1691        self.profiler.end(StepProfiler::BLK_BWD);
1692
1693        Some(grad_output_is_a)
1694    }
1695
1696    /// R-038: Download non-block (LM head + final norm) gradients to CPU accumulator.
1697    /// Static method to avoid borrow conflicts.
1698    // KAIZEN-044: Pre-allocate single buffer for LM head + norm D2H downloads.
1699    // lm_head_grad is vocab×hidden (389M elements = 1.5 GB for Qwen3-4B).
1700    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1701    fn download_nonblock_grads_to_accum(
1702        lm_head_grad: &GpuBuffer<f32>,
1703        final_norm_grad: &GpuBuffer<f32>,
1704        grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1705        host: &mut [f32],
1706    ) -> Option<()> {
1707        let accum = grad_accum.as_mut()?;
1708
1709        let lm_slice = &mut host[..lm_head_grad.len()];
1710        lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1711        for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1712            *d += s;
1713        }
1714
1715        let norm_slice = &mut host[..final_norm_grad.len()];
1716        final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1717        for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1718            *d += s;
1719        }
1720        Some(())
1721    }
1722
1723    /// Run LM head + final norm optimizer step (non-accumulating path).
1724    /// Static method to avoid borrow conflicts with `stream`.
1725    #[allow(clippy::too_many_arguments)]
1726    fn run_nonblock_optimizer_step(
1727        gpu_training: &mut GpuPretrainState,
1728        lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1729        lm_head_grad_gpu: &GpuBuffer<f32>,
1730        lm_head_m: &mut GpuBuffer<f32>,
1731        lm_head_v: &mut GpuBuffer<f32>,
1732        final_norm_m: &mut GpuBuffer<f32>,
1733        final_norm_v: &mut GpuBuffer<f32>,
1734        lr: f32,
1735        beta1: f32,
1736        beta2: f32,
1737        weight_decay: f32,
1738        stream: &CudaStream,
1739    ) {
1740        gpu_training.step += 1;
1741        let step = gpu_training.step;
1742
1743        if let Some(lm_head_weight) = lm_head_weight_gpu {
1744            let n_lm = lm_head_weight.len() as u32;
1745            let _ = adamw_step_cuda(
1746                lm_head_weight,
1747                lm_head_grad_gpu,
1748                lm_head_m,
1749                lm_head_v,
1750                lr,
1751                beta1,
1752                beta2,
1753                1e-8,
1754                weight_decay,
1755                step,
1756                n_lm,
1757                stream,
1758            );
1759        }
1760
1761        let n_norm = gpu_training.final_norm_weight.len() as u32;
1762        let _ = adamw_step_cuda(
1763            &mut gpu_training.final_norm_weight,
1764            &gpu_training.grad_final_norm_weight,
1765            final_norm_m,
1766            final_norm_v,
1767            lr,
1768            beta1,
1769            beta2,
1770            1e-8,
1771            weight_decay,
1772            step,
1773            n_norm,
1774            stream,
1775        );
1776    }
1777
1778    /// R-038: Download shared CudaGradWorkspace to CPU per-block accumulation buffers.
1779    ///
1780    /// Static method to avoid borrow conflicts with `stream` (same pattern as
1781    /// `recompute_segment`). Must be called after stream.synchronize() (ALB-065 / Rule 6).
1782    // KAIZEN-044: Pre-allocate a single host buffer for all D2H downloads
1783    // in download_workspace_to_accum. Was allocating vec![0.0f32; len] × 9 buffers.
1784    // KAIZEN-059: Host buffer now passed in (d2h_staging) — zero per-call allocations.
1785    fn download_workspace_to_accum(
1786        ws: &CudaGradWorkspace,
1787        accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1788        layer_idx: usize,
1789        host: &mut [f32],
1790    ) -> Option<()> {
1791        let bg = &mut accum.block_grads[layer_idx];
1792
1793        use super::grad_accumulator::component;
1794        let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1795            (&ws.grad_w_q, component::W_Q),
1796            (&ws.grad_w_k, component::W_K),
1797            (&ws.grad_w_v, component::W_V),
1798            (&ws.grad_w_o, component::W_O),
1799            (&ws.grad_gate, component::GATE),
1800            (&ws.grad_up, component::UP),
1801            (&ws.grad_down, component::DOWN),
1802            (&ws.grad_input_norm, component::INPUT_NORM),
1803            (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1804        ];
1805
1806        for (gpu_buf, comp_idx) in &bufs_and_components {
1807            let slice = &mut host[..gpu_buf.len()];
1808            gpu_buf.copy_to_host_at(slice, 0).ok()?;
1809            for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1810                *d += s;
1811            }
1812        }
1813        Some(())
1814    }
1815
1816    /// R-038: Upload averaged CPU accumulation buffers to GPU workspace and run
1817    /// optimizer step for all blocks + LM head + final norm.
1818    ///
1819    /// Called once after `accumulation_steps` micro-batches have been accumulated.
1820    /// ALB-091: Run optimizer step from GPU-resident accumulated gradients.
1821    /// D2D copy accum → workspace, then run per-block optimizer. Zero accum after.
1822    fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1823        let stream = self.cuda_trainer.stream();
1824        let lr = self.current_lr();
1825        let beta1 = self.config.beta1;
1826        let beta2 = self.config.beta2;
1827        let weight_decay = self.config.weight_decay;
1828
1829        // Sync once to ensure all accumulation kernels complete
1830        stream.synchronize().ok()?;
1831
1832        self.gpu_training.step += 1;
1833        let step = self.gpu_training.step;
1834
1835        // Upload GPU accum → workspace (D2D) and run optimizer per block
1836        let gpu_accum = self.gpu_grad_accum.as_ref()?;
1837        for layer_idx in 0..self.cuda_blocks.len() {
1838            gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1839
1840            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1841                &mut self.gpu_training.optimizer_states[layer_idx],
1842                step,
1843                lr,
1844                beta1,
1845                beta2,
1846                1e-8,
1847                weight_decay,
1848                stream,
1849                &self.cuda_grad_workspace,
1850            );
1851        }
1852
1853        // LM head: D2D copy accum → grad buffer, then optimizer step
1854        gpu_accum
1855            .upload_nonblock(
1856                &mut self.lm_head_grad_gpu,
1857                &mut self.gpu_training.grad_final_norm_weight,
1858            )
1859            .ok()?;
1860
1861        let n_lm = self.lm_head_weight_gpu.len() as u32;
1862        let _ = adamw_step_cuda(
1863            &mut self.lm_head_weight_gpu,
1864            &self.lm_head_grad_gpu,
1865            &mut self.lm_head_m,
1866            &mut self.lm_head_v,
1867            lr,
1868            beta1,
1869            beta2,
1870            1e-8,
1871            weight_decay,
1872            step,
1873            n_lm,
1874            stream,
1875        );
1876
1877        // Final norm optimizer step
1878        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1879        let _ = adamw_step_cuda(
1880            &mut self.gpu_training.final_norm_weight,
1881            &self.gpu_training.grad_final_norm_weight,
1882            &mut self.final_norm_m,
1883            &mut self.final_norm_v,
1884            lr,
1885            beta1,
1886            beta2,
1887            1e-8,
1888            weight_decay,
1889            step,
1890            n_norm,
1891            stream,
1892        );
1893
1894        stream.synchronize().ok()?;
1895
1896        // Zero accum for next window
1897        if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1898            let _ = gpu_accum.zero_all();
1899        }
1900
1901        Some(())
1902    }
1903
1904    #[allow(unsafe_code)]
1905    fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1906        let stream = self.cuda_trainer.stream();
1907        let lr = self.current_lr();
1908        let beta1 = self.config.beta1;
1909        let beta2 = self.config.beta2;
1910        let weight_decay = self.config.weight_decay;
1911
1912        // Average accumulated gradients
1913        let accum = self.grad_accum.as_mut()?;
1914        accum.average();
1915
1916        // Jidoka: check for NaN/Inf before applying
1917        if accum.has_non_finite() {
1918            println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1919            accum.zero_all();
1920            return Some(());
1921        }
1922
1923        self.gpu_training.step += 1;
1924        let step = self.gpu_training.step;
1925
1926        // Upload accumulated gradients and run optimizer for each block
1927        use super::grad_accumulator::component;
1928        for layer_idx in 0..self.cuda_blocks.len() {
1929            let bg = &accum.block_grads[layer_idx];
1930
1931            // Upload accumulated gradients to shared workspace
1932            // SAFETY: async host-to-device copies within the training stream; host buffers
1933            // (bg.components) are stable for the duration of the stream operations.
1934            unsafe {
1935                self.cuda_grad_workspace
1936                    .grad_w_q
1937                    .copy_from_host_async(&bg.components[component::W_Q], stream)
1938                    .ok()?;
1939                self.cuda_grad_workspace
1940                    .grad_w_k
1941                    .copy_from_host_async(&bg.components[component::W_K], stream)
1942                    .ok()?;
1943                self.cuda_grad_workspace
1944                    .grad_w_v
1945                    .copy_from_host_async(&bg.components[component::W_V], stream)
1946                    .ok()?;
1947                self.cuda_grad_workspace
1948                    .grad_w_o
1949                    .copy_from_host_async(&bg.components[component::W_O], stream)
1950                    .ok()?;
1951                self.cuda_grad_workspace
1952                    .grad_gate
1953                    .copy_from_host_async(&bg.components[component::GATE], stream)
1954                    .ok()?;
1955                self.cuda_grad_workspace
1956                    .grad_up
1957                    .copy_from_host_async(&bg.components[component::UP], stream)
1958                    .ok()?;
1959                self.cuda_grad_workspace
1960                    .grad_down
1961                    .copy_from_host_async(&bg.components[component::DOWN], stream)
1962                    .ok()?;
1963                self.cuda_grad_workspace
1964                    .grad_input_norm
1965                    .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
1966                    .ok()?;
1967                self.cuda_grad_workspace
1968                    .grad_post_attn_norm
1969                    .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
1970                    .ok()?;
1971            }
1972
1973            // Run optimizer step with uploaded averaged gradients
1974            let _ = self.cuda_blocks[layer_idx].optimizer_step(
1975                &mut self.gpu_training.optimizer_states[layer_idx],
1976                step,
1977                lr,
1978                beta1,
1979                beta2,
1980                1e-8,
1981                weight_decay,
1982                stream,
1983                &self.cuda_grad_workspace,
1984            );
1985        }
1986
1987        // Upload accumulated LM head gradients and run AdamW step
1988        // entrenar#314: Skip GPU LM head optimizer for tied weights.
1989        // SAFETY: async host-to-device copy; host buffer (accum.lm_head_grad) is stable.
1990        unsafe {
1991            self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
1992        }
1993        let n_lm = self.lm_head_weight_gpu.len() as u32;
1994        let _ = adamw_step_cuda(
1995            &mut self.lm_head_weight_gpu,
1996            &self.lm_head_grad_gpu,
1997            &mut self.lm_head_m,
1998            &mut self.lm_head_v,
1999            lr,
2000            beta1,
2001            beta2,
2002            1e-8,
2003            weight_decay,
2004            step,
2005            n_lm,
2006            stream,
2007        );
2008
2009        // Upload accumulated final norm gradients and run AdamW step
2010        // SAFETY: async host-to-device copy; host buffer (accum.final_norm_grad) is stable.
2011        unsafe {
2012            self.gpu_training
2013                .grad_final_norm_weight
2014                .copy_from_host_async(&accum.final_norm_grad, stream)
2015                .ok()?;
2016        }
2017        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
2018        let _ = adamw_step_cuda(
2019            &mut self.gpu_training.final_norm_weight,
2020            &self.gpu_training.grad_final_norm_weight,
2021            &mut self.final_norm_m,
2022            &mut self.final_norm_v,
2023            lr,
2024            beta1,
2025            beta2,
2026            1e-8,
2027            weight_decay,
2028            step,
2029            n_norm,
2030            stream,
2031        );
2032
2033        stream.synchronize().ok()?;
2034
2035        // Zero accum for next window
2036        accum.zero_all();
2037        Some(())
2038    }
2039
2040    /// Compute gradient L2 norm via GPU reduction kernel (KAIZEN-049).
2041    ///
2042    /// Runs `SquaredSumKernel` on GPU, downloads only `num_blocks` partial sums (~1KB)
2043    /// instead of the full buffer (128MB for lm_head). Falls back to CPU download on error.
2044    ///
2045    /// # Contract (C-CLIPNORM-GPU-001)
2046    ///
2047    /// - **Precondition**: `buf.len() > 0`, stream is synchronized with prior kernel
2048    /// - **Postcondition**: `grad_norm ≈ sqrt(sum(buf[i]^2))`, `scale = min(1, max_norm/norm)`
2049    /// - **Transfer**: ~1KB D2H (num_blocks × 4B) vs n×4B (128MB for 32M elements)
2050    ///
2051    /// R-004: Returns `(clip_scale, grad_norm)` for observability.
2052    fn compute_clip_scale_with_norm(
2053        buf: &GpuBuffer<f32>,
2054        max_norm: f32,
2055        stream: &CudaStream,
2056    ) -> (f32, f32) {
2057        let n = buf.len() as u32;
2058        // Try GPU reduction first — ~1KB D2H instead of n×4 bytes
2059        let grad_norm = match squared_sum_cuda(buf, n, stream) {
2060            Ok(norm) => norm,
2061            Err(_) => {
2062                // Fallback: full D2H (original path)
2063                let mut host = vec![0.0f32; buf.len()];
2064                if buf.copy_to_host_at(&mut host, 0).is_err() {
2065                    return (1.0, 0.0);
2066                }
2067                let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2068                sq_sum.sqrt() as f32
2069            }
2070        };
2071        let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2072        (scale, grad_norm)
2073    }
2074
2075    /// Download embedding gradient from GPU, clip, and scatter-add into CPU weight.
2076    ///
2077    /// # Contract (C-EMBED-GRAD-001)
2078    ///
2079    /// The activation gradient from block[0]'s backward is unclipped (per-block clipping
2080    /// only applies to weight gradients in the shared workspace). For deep networks with
2081    /// random init, this gradient can overflow f32, producing NaN in the CPU AdamW.
2082    /// We clip the activation gradient to max_grad_norm before scatter-adding.
2083    #[allow(unsafe_code)]
2084    fn embed_backward(
2085        &mut self,
2086        input_ids: &[u32],
2087        _seq_len: usize,
2088        hidden_size: usize,
2089        vocab_size: usize,
2090        grad_output_is_a: bool,
2091    ) -> Option<()> {
2092        // The final backward output is in whichever buffer was last written
2093        let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2094        let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2095        let embed_grad_buf = unsafe {
2096            if grad_output_is_a {
2097                &*grad_a_ptr
2098            } else {
2099                &*grad_b_ptr
2100            }
2101        };
2102        let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2103
2104        // C-EMBED-GRAD-001: ALWAYS clip activation gradient before scatter-add.
2105        // Without this, 24-layer random-init backward amplifies gradients to ~1e35,
2106        // which overflows the CPU AdamW's second moment buffer.
2107        //
2108        // ALB-071: Decoupled from general grad_clip config. Embed activation gradient
2109        // clipping is a SAFETY constraint (prevents NaN), not a training hyperparameter.
2110        // Uses dedicated max_embed_grad_norm (default 1.0) independent of weight grad_clip.
2111        let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2112        {
2113            let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2114            let grad_norm = sq_sum.sqrt() as f32;
2115            self.last_embed_grad_norm = grad_norm; // R-040: per-parameter-group tracking
2116            if grad_norm > embed_clip_norm {
2117                let scale = embed_clip_norm / grad_norm;
2118                for g in &mut embed_grad_data {
2119                    *g *= scale;
2120                }
2121            }
2122        }
2123
2124        // KAIZEN-048: In-place scatter-add via grad_cell().borrow_mut().
2125        // Before: 3 × 128MB clones per step (grad() deep-copies Array1).
2126        // After: zero clones — mutate existing gradient buffer directly.
2127        let embed_weight = &mut self.model.embed_tokens.weight;
2128        let grad_cell = embed_weight.grad_cell();
2129        let mut grad_ref = grad_cell.borrow_mut();
2130        if grad_ref.is_none() {
2131            *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2132        }
2133        if let Some(grad) = grad_ref.as_mut() {
2134            for (pos, &token_id) in input_ids.iter().enumerate() {
2135                let tid = token_id as usize;
2136                if tid < vocab_size {
2137                    let src = pos * hidden_size;
2138                    let dst = tid * hidden_size;
2139                    for h in 0..hidden_size {
2140                        grad[dst + h] += embed_grad_data[src + h];
2141                    }
2142                }
2143            }
2144        }
2145        Some(())
2146    }
2147
2148    /// Apply optimizer step to CPU embedding and update metrics.
2149    ///
2150    /// GPU block optimizer steps now run interleaved with backward in `gpu_backward()`.
2151    /// LM head and final norm optimizer steps also run in `gpu_backward()`.
2152    /// This method handles only CPU embedding and bookkeeping.
2153    fn optimizer_step(&mut self) {
2154        // ALB-072: Gradients are no longer scaled by grad_scaler (loss_scale excludes
2155        // grad_scaler.scale()). All backward computation uses f32 — no fp16 underflow
2156        // risk. Skip unscaling; just update scaler as successful.
2157        self.grad_scaler.update(true);
2158
2159        // ALB-079: Sync CPU embedding optimizer lr with cosine schedule
2160        self.embed_optimizer.set_lr(self.current_lr());
2161        // CPU optimizer step for embedding weight
2162        let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2163        self.embed_optimizer.step_refs(&mut embed_params);
2164
2165        self.step += 1;
2166        self.metrics.losses.push(self.accumulated_loss);
2167        self.metrics.increment_step();
2168
2169        self.accumulated_loss = 0.0;
2170        self.accumulated_batches = 0;
2171    }
2172
2173    /// Process a batch (forward + backward + optimizer step with accumulation).
2174    ///
2175    /// R-038: When `accumulation_steps > 1`, runs forward+backward without optimizer
2176    /// for each micro-batch, downloading per-block weight gradients to CPU-side
2177    /// `PerBlockGradientAccumulator`. After `accumulation_steps` batches, averages
2178    /// the accumulated gradients, uploads them to GPU, and runs a single optimizer step.
2179    ///
2180    /// When `accumulation_steps == 1` (default), runs forward+backward+optimizer
2181    /// immediately per sequence (original behavior).
2182    ///
2183    /// Returns average loss for the batch.
2184    pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2185        if batch.batch_size == 0 {
2186            return 0.0;
2187        }
2188
2189        let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2190
2191        if self.accumulated_batches == 0 {
2192            // Zero embedding gradients at start of accumulation window
2193            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2194        }
2195
2196        let mut total_loss = 0.0;
2197        let mut valid_count = 0;
2198
2199        for i in 0..batch.batch_size {
2200            let Some(input_ids) = batch.get_input(i) else {
2201                continue;
2202            };
2203            let Some(target_ids) = batch.get_target(i) else {
2204                continue;
2205            };
2206
2207            // R-038: When accumulating, run backward without optimizer (accumulate_only=true).
2208            // Gradients are downloaded to CPU per-block accum buffers. Embedding grads are
2209            // scatter-added normally (they're already on CPU).
2210            if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2211                total_loss += loss;
2212                valid_count += 1;
2213                if accumulating {
2214                    if let Some(accum) = &mut self.gpu_grad_accum {
2215                        accum.accumulated_count += 1;
2216                    } else if let Some(accum) = &mut self.grad_accum {
2217                        accum.accumulated_count += 1;
2218                    }
2219                }
2220            }
2221        }
2222
2223        let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2224
2225        // Debug: help diagnose loss=0.0 when gradients are non-zero
2226        if avg_loss == 0.0 && valid_count > 0 {
2227            eprintln!(
2228                "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2229                valid_count, total_loss, batch.batch_size
2230            );
2231        }
2232
2233        self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2234        self.accumulated_batches += 1;
2235
2236        if self.accumulated_batches >= self.config.accumulation_steps {
2237            if accumulating {
2238                // ALB-091: Prefer GPU-resident accum path (zero D2H), fall back to CPU.
2239                if self.gpu_grad_accum.is_some() {
2240                    self.gpu_optimizer_from_gpu_accum();
2241                } else {
2242                    self.gpu_optimizer_from_accum();
2243                }
2244            }
2245            self.optimizer_step();
2246        }
2247
2248        avg_loss
2249    }
2250
2251    /// R-005: Evaluate a batch without backward pass or weight updates.
2252    /// Returns average cross-entropy loss, or 0.0 if no valid items.
2253    /// KAIZEN-050: Uses fused GPU cross-entropy (no logits D2H).
2254    pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2255        let hidden_size = self.config.model_config.hidden_size;
2256        let vocab_size = self.config.model_config.vocab_size;
2257        let max_sl = self.config.max_seq_len;
2258        let mut total_loss = 0.0;
2259        let mut valid_count = 0;
2260        for i in 0..batch.batch_size {
2261            if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2262            {
2263                total_loss += loss;
2264                valid_count += 1;
2265            }
2266        }
2267        if valid_count > 0 {
2268            total_loss / valid_count as f32
2269        } else {
2270            0.0
2271        }
2272    }
2273
2274    /// Evaluate a single sequence from a batch. Returns None if invalid.
2275    fn eval_single_sequence(
2276        &mut self,
2277        batch: &LMBatch,
2278        i: usize,
2279        max_sl: usize,
2280        hidden_size: usize,
2281        vocab_size: usize,
2282    ) -> Option<f32> {
2283        let input_ids = batch.get_input(i)?;
2284        let target_ids = batch.get_target(i)?;
2285        // Truncate to max_seq_len — GPU buffers are pre-allocated for this size
2286        let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2287        let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2288        let seq_len = input_ids.len();
2289        self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2290        let stream = self.cuda_trainer.stream();
2291        let scale = 1.0 / seq_len as f32;
2292        let loss = fused_cross_entropy_cuda(
2293            &mut self.gpu_training.logits_buf,
2294            target_ids,
2295            seq_len as u32,
2296            vocab_size as u32,
2297            scale,
2298            stream,
2299        )
2300        .ok()?;
2301        if loss.is_finite() {
2302            Some(loss)
2303        } else {
2304            None
2305        }
2306    }
2307
2308    /// Train for one epoch over batches.
2309    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2310        self.train_epoch_with_callback(batches, |_, _, _| {})
2311    }
2312
2313    /// Train for one epoch with a per-step callback.
2314    ///
2315    /// Stops early if `max_steps` is set and reached.
2316    pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2317    where
2318        F: FnMut(usize, f32, &Self),
2319    {
2320        if batches.is_empty() {
2321            return 0.0;
2322        }
2323
2324        let mut total_loss = 0.0;
2325        let mut batches_processed = 0;
2326
2327        for (i, batch) in batches.iter().enumerate() {
2328            if let Some(max) = self.config.max_steps {
2329                if self.step >= max {
2330                    break;
2331                }
2332            }
2333
2334            let batch_loss = self.train_batch(batch);
2335            total_loss += batch_loss;
2336            batches_processed += 1;
2337            on_batch(i, batch_loss, self);
2338        }
2339
2340        // KAIZEN-047: Print profiler summary at end of epoch
2341        if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2342            self.profiler.print_report();
2343        }
2344
2345        total_loss / batches_processed.max(1) as f32
2346    }
2347
2348    // --- DDP (data-parallel) support methods ---
2349
2350    /// Ensure the per-block gradient accumulator exists.
2351    ///
2352    /// For DDP, we always need accumulation buffers (even with accumulation_steps=1)
2353    /// because gradients must be downloaded to CPU for AllReduce before optimizer step.
2354    pub(crate) fn ensure_grad_accum(&mut self) {
2355        if self.grad_accum.is_some() {
2356            return;
2357        }
2358        let mc = &self.config.model_config;
2359        let hidden_size = mc.hidden_size;
2360        let kv_hidden = mc.num_kv_heads * mc.head_dim();
2361        let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2362            hidden_size,
2363            kv_hidden,
2364            mc.intermediate_size,
2365        );
2366        self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2367            self.cuda_blocks.len(),
2368            block_sizes,
2369            mc.vocab_size,
2370            hidden_size,
2371        ));
2372    }
2373
2374    /// Forward + backward for one batch, always accumulating (no optimizer step).
2375    ///
2376    /// Used by `DistributedCudaTrainer` to compute local gradients before AllReduce.
2377    /// Returns average loss for the batch.
2378    pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2379        if batch.batch_size == 0 {
2380            return 0.0;
2381        }
2382
2383        if self.accumulated_batches == 0 {
2384            self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2385        }
2386
2387        let mut total_loss = 0.0;
2388        let mut valid_count = 0;
2389
2390        for i in 0..batch.batch_size {
2391            let Some(input_ids) = batch.get_input(i) else { continue };
2392            let Some(target_ids) = batch.get_target(i) else { continue };
2393
2394            // Always accumulate_only=true: gradients go to CPU accum buffers
2395            if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2396                total_loss += loss;
2397                valid_count += 1;
2398                if let Some(accum) = &mut self.grad_accum {
2399                    accum.accumulated_count += 1;
2400                }
2401            }
2402        }
2403
2404        if valid_count > 0 {
2405            total_loss / valid_count as f32
2406        } else {
2407            0.0
2408        }
2409    }
2410
2411    /// Apply DDP-averaged gradients: upload to GPU and run optimizer step.
2412    ///
2413    /// Called after AllReduce has written averaged gradients into the grad_accum.
2414    /// Runs gpu_optimizer_from_accum() for blocks + LM head + final norm,
2415    /// then optimizer_step() for embedding.
2416    pub(crate) fn apply_ddp_gradients(&mut self) {
2417        self.accumulated_loss = 0.0;
2418        self.accumulated_batches = 0;
2419        self.gpu_optimizer_from_accum();
2420        self.optimizer_step();
2421    }
2422
2423    /// Get a reference to the gradient accumulator (for DDP AllReduce).
2424    pub(crate) fn grad_accum_ref(
2425        &self,
2426    ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2427        self.grad_accum.as_ref()
2428    }
2429
2430    /// Get a mutable reference to the gradient accumulator (for DDP AllReduce).
2431    pub(crate) fn grad_accum_mut(
2432        &mut self,
2433    ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2434        self.grad_accum.as_mut()
2435    }
2436
2437    /// Get the training config.
2438    pub(crate) fn config(&self) -> &TransformerTrainConfig {
2439        &self.config
2440    }
2441
2442    /// Get CPU embedding gradient as flat Vec for AllReduce.
2443    pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2444        self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2445    }
2446
2447    /// Set CPU embedding gradient from AllReduced flat Vec.
2448    pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2449        self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2450    }
2451
2452    /// Returns true if max_steps has been reached.
2453    pub fn reached_max_steps(&self) -> bool {
2454        self.config.max_steps.is_some_and(|max| self.step >= max)
2455    }
2456
2457    /// Get current step count.
2458    pub fn step(&self) -> usize {
2459        self.step
2460    }
2461
2462    /// Set initial step for resume from checkpoint.
2463    ///
2464    /// Updates both the outer step counter (LR schedule, logging) and the
2465    /// GPU-side AdamW step counter (bias correction). Must be called before
2466    /// any `train_batch()` calls.
2467    pub fn set_initial_step(&mut self, step: usize) {
2468        self.step = step;
2469        self.gpu_training.step = step as u32;
2470    }
2471
2472    /// Set max_steps for cosine LR scheduler (ENT-275).
2473    ///
2474    /// Called by `train_loop_cuda` when `max_steps` is not explicitly set in
2475    /// the YAML config — auto-computes `epochs × batches_per_epoch` so cosine
2476    /// decay activates instead of falling back to constant lr.
2477    pub fn set_max_steps(&mut self, max_steps: usize) {
2478        self.config.max_steps = Some(max_steps);
2479    }
2480
2481    /// Get current learning rate (warmup + cosine decay).
2482    ///
2483    /// ALB-079: Phase 1 = linear warmup (0 → lr_max), Phase 2 = cosine decay
2484    /// (lr_max → 0) over remaining steps. Requires `max_steps` for decay;
2485    /// without it, falls back to constant lr after warmup.
2486    pub fn current_lr(&self) -> f32 {
2487        let base_lr = self.config.lr;
2488        if self.step < self.config.warmup_steps {
2489            // Phase 1: Linear warmup
2490            base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2491        } else if let Some(max_steps) = self.config.max_steps {
2492            // Phase 2: Cosine decay from lr_max to 0
2493            let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2494            if decay_steps == 0 {
2495                return base_lr;
2496            }
2497            let decay_step = self.step - self.config.warmup_steps;
2498            let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2499            0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2500        } else {
2501            // No max_steps: constant lr (legacy behavior)
2502            base_lr
2503        }
2504    }
2505
2506    /// KAIZEN-047: Enable step profiling with a report every `interval` steps.
2507    ///
2508    /// When enabled, prints a table of wall-clock timings per training phase
2509    /// every `interval` training steps. Use interval=0 for manual-only reporting.
2510    ///
2511    /// # Contract (C-STEPPROF-001)
2512    ///
2513    /// - No additional GPU synchronization points (relies on existing syncs)
2514    /// - Overhead: ~11 `Instant::now()` calls per step (~1µs total on Linux)
2515    /// - Timings include async dispatch overhead (not pure kernel time)
2516    pub fn enable_profiler(&mut self, interval: usize) {
2517        self.profiler = StepProfiler::new(true, interval);
2518    }
2519
2520    /// Print the profiler report (if profiling is enabled).
2521    pub fn print_profiler_report(&self) {
2522        self.profiler.print_report();
2523    }
2524
2525    /// R-004: Get last observed gradient L2 norm (LM head proxy).
2526    pub fn last_grad_norm(&self) -> f32 {
2527        self.last_grad_norm
2528    }
2529
2530    /// R-040: Get per-parameter-group gradient norms.
2531    /// Returns (lm_head_grad_norm, embed_grad_norm).
2532    pub fn param_grad_norms(&self) -> (f32, f32) {
2533        (self.last_grad_norm, self.last_embed_grad_norm)
2534    }
2535
2536    /// R-012: Get total trainable parameter count for MFU calculation.
2537    pub fn num_params(&self) -> usize {
2538        self.model.parameters().iter().map(|t| t.len()).sum()
2539    }
2540
2541    /// R-013: Query GPU memory usage (used_mb, total_mb).
2542    pub fn gpu_memory_mb(&self) -> (u64, u64) {
2543        match self.cuda_trainer.context().memory_info() {
2544            Ok((free, total)) => {
2545                let total_mb = (total / (1024 * 1024)) as u64;
2546                let used_mb = ((total - free) / (1024 * 1024)) as u64;
2547                (used_mb, total_mb)
2548            }
2549            Err(_) => (0, 0),
2550        }
2551    }
2552
2553    /// Sync all GPU weights back to CPU model.
2554    ///
2555    /// # Contract (C-SYNCWT-001)
2556    ///
2557    /// Must be called before save or any CPU model access after training.
2558    pub fn sync_weights_to_cpu(&mut self) {
2559        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2560
2561        if use_nf4 {
2562            // ENT-263: NF4 blocks are frozen — base weights don't change.
2563            // Only download LoRA adapter weights for checkpoint saving.
2564            // The base model on CPU stays as-is (original pretrained weights).
2565            // LoRA weights are saved separately (adapter_config.json + adapter.safetensors).
2566            // For now, skip per-layer sync — base weights are unchanged.
2567        } else {
2568            for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2569                if let Ok(weights) = block.download_weights() {
2570                    let layer = &mut self.model.layers[layer_idx];
2571
2572                    layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2573                    layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2574                    layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2575                    layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2576
2577                    layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2578                    layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2579                    layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2580
2581                    layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2582                    layer.post_attn_norm.weight =
2583                        Tensor::from_vec(weights.post_attn_norm_weight, false);
2584                }
2585            }
2586        }
2587
2588        // Sync final norm weight
2589        if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2590            self.model.norm.weight = Tensor::from_vec(norm_data, false);
2591        }
2592
2593        // Sync LM head weight
2594        // ALB-097: ALWAYS save GPU-trained LM head, even for tied-weight models.
2595        // During GPU training, lm_head diverges from embed_tokens because they have
2596        // separate optimizers (GPU AdamW vs CPU AdamW). If we skip the sync for tied
2597        // weights, the checkpoint loses 500+ steps of GPU LM head training → random-init
2598        // loss on resume (Five Whys root cause of ALB-097).
2599        if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2600            self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2601        }
2602    }
2603
2604    /// Get reference to model (syncs weights first).
2605    pub fn model(&self) -> &Transformer {
2606        &self.model
2607    }
2608
2609    /// Get mutable reference to model.
2610    pub fn model_mut(&mut self) -> &mut Transformer {
2611        &mut self.model
2612    }
2613
2614    /// Check if using mixed precision.
2615    pub fn is_mixed_precision(&self) -> bool {
2616        self.config.precision_config.is_mixed()
2617    }
2618
2619    /// Get the gradient scaler (R-002: loss scaling for mixed precision).
2620    pub fn grad_scaler(&self) -> &GradScaler {
2621        &self.grad_scaler
2622    }
2623
2624    /// Check if using gradient checkpointing.
2625    pub fn is_checkpointing(&self) -> bool {
2626        self.config.checkpoint_config.enabled
2627    }
2628
2629    /// Save model weights (syncs GPU→CPU first).
2630    pub fn save(
2631        &mut self,
2632        path: impl AsRef<std::path::Path>,
2633        name: &str,
2634        architecture: &str,
2635    ) -> crate::Result<()> {
2636        self.sync_weights_to_cpu();
2637
2638        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2639        let params: Vec<(String, Tensor)> = self
2640            .model
2641            .named_parameters()
2642            .into_iter()
2643            .map(|(name, tensor)| (name, tensor.clone()))
2644            .collect();
2645
2646        let metadata = ModelMetadata::new(name, architecture);
2647        let model = Model::new(metadata, params);
2648        let config = SaveConfig::new(ModelFormat::SafeTensors);
2649
2650        save_model(&model, path, &config)
2651    }
2652
2653    /// R-011: Prepare checkpoint data for async save.
2654    /// Syncs GPU weights to CPU and snapshots tensor data as Send-able Vec<f32>.
2655    /// Returns a closure that writes the checkpoint file from another thread.
2656    pub fn prepare_async_save(
2657        &mut self,
2658        name: &str,
2659        architecture: &str,
2660    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2661        self.sync_weights_to_cpu();
2662
2663        // Use named_parameters() for correct name mapping (handles attention biases etc.)
2664        let param_data: Vec<(String, Vec<f32>)> = self
2665            .model
2666            .named_parameters()
2667            .into_iter()
2668            .map(|(n, t)| (n, t.data().to_vec()))
2669            .collect();
2670
2671        let name = name.to_string();
2672        let architecture = architecture.to_string();
2673
2674        Box::new(move |path: &std::path::Path| {
2675            let params: Vec<(String, Tensor)> =
2676                param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2677            let metadata = ModelMetadata::new(&name, &architecture);
2678            let model = Model::new(metadata, params);
2679            let config = SaveConfig::new(ModelFormat::SafeTensors);
2680            save_model(&model, path, &config)
2681        })
2682    }
2683
2684    /// ALB-096: Save model weights as APR checkpoint (syncs GPU→CPU first).
2685    ///
2686    /// Single atomic file containing all model weights. Use `save_apr_checkpoint()`
2687    /// to include optimizer state and training metadata in the same file.
2688    pub fn save_apr(
2689        &mut self,
2690        path: impl AsRef<std::path::Path>,
2691        name: &str,
2692        architecture: &str,
2693    ) -> crate::Result<()> {
2694        self.save_apr_with_tokenizer(path, name, architecture, None)
2695    }
2696
2697    /// SPEC-SHIP-TWO-001 §81 P0-D + P0-E: save APR checkpoint with arch
2698    /// metadata keys AND optionally embed the source tokenizer.json.
2699    ///
2700    /// When `tokenizer_dir` is `Some`, reads `<dir>/tokenizer.json` and
2701    /// embeds the vocabulary + merges + BOS/EOS IDs as well-known
2702    /// metadata keys. This makes the resulting .apr file standalone for
2703    /// `apr qa`, `apr run`, etc. — no `--tokenizer` flag required at
2704    /// downstream tool dispatch.
2705    pub fn save_apr_with_tokenizer(
2706        &mut self,
2707        path: impl AsRef<std::path::Path>,
2708        name: &str,
2709        architecture: &str,
2710        tokenizer_dir: Option<&std::path::Path>,
2711    ) -> crate::Result<()> {
2712        self.sync_weights_to_cpu();
2713
2714        let params: Vec<(String, Tensor)> = self
2715            .model
2716            .named_parameters()
2717            .into_iter()
2718            .map(|(name, tensor)| (name, tensor.clone()))
2719            .collect();
2720
2721        // SPEC-SHIP-TWO-001 §81 P0-E: write individual arch metadata keys
2722        // so downstream tools (apr qa C-03, apr bench, realizar) can read them
2723        // via AprV2Metadata's typed fields. The legacy save_model() path only
2724        // carries `name + architecture + format + version` which fails C-03.
2725        use crate::io::save::infer_all_tensor_shapes;
2726        use aprender::serialization::apr::AprWriter;
2727        use serde_json::Value as Jv;
2728
2729        let mc = &self.config.model_config;
2730        let mut writer = AprWriter::new();
2731
2732        // Identity / version metadata (preserves save_model behavior)
2733        writer.set_metadata("model_name", Jv::String(name.to_string()));
2734        writer.set_metadata("architecture", Jv::String(architecture.to_string()));
2735        writer.set_metadata("version", Jv::String("0.1.0".into()));
2736        writer.set_metadata("format", Jv::String("entrenar-checkpoint".into()));
2737
2738        // Arch dim keys (well-known to AprWriter::build_v2_metadata,
2739        // map to AprV2Metadata typed fields).
2740        writer.set_metadata(
2741            "hidden_size",
2742            Jv::Number(serde_json::Number::from(mc.hidden_size as u64)),
2743        );
2744        writer.set_metadata(
2745            "num_hidden_layers",
2746            Jv::Number(serde_json::Number::from(mc.num_hidden_layers as u64)),
2747        );
2748        writer.set_metadata(
2749            "num_attention_heads",
2750            Jv::Number(serde_json::Number::from(mc.num_attention_heads as u64)),
2751        );
2752        writer.set_metadata(
2753            "num_kv_heads",
2754            Jv::Number(serde_json::Number::from(mc.num_kv_heads as u64)),
2755        );
2756        writer.set_metadata(
2757            "intermediate_size",
2758            Jv::Number(serde_json::Number::from(mc.intermediate_size as u64)),
2759        );
2760        writer
2761            .set_metadata("vocab_size", Jv::Number(serde_json::Number::from(mc.vocab_size as u64)));
2762        writer.set_metadata(
2763            "max_position_embeddings",
2764            Jv::Number(serde_json::Number::from(mc.max_position_embeddings as u64)),
2765        );
2766        if let Some(rope) = serde_json::Number::from_f64(mc.rope_theta as f64) {
2767            writer.set_metadata("rope_theta", Jv::Number(rope));
2768        }
2769        if let Some(eps) = serde_json::Number::from_f64(mc.rms_norm_eps as f64) {
2770            writer.set_metadata("rms_norm_eps", Jv::Number(eps));
2771        }
2772
2773        // SPEC-SHIP-TWO-001 §81 P0-D: embed tokenizer.json from
2774        // `tokenizer_dir/tokenizer.json` so `apr qa` (which requires
2775        // an embedded tokenizer) accepts the resulting .apr file.
2776        // ALB-130 style: parse vocab + merges + special token IDs and
2777        // set as well-known metadata keys.
2778        if let Some(dir) = tokenizer_dir {
2779            let tok_path = dir.join("tokenizer.json");
2780            if let Ok(json_bytes) = std::fs::read(&tok_path) {
2781                if let Ok(tok) = serde_json::from_slice::<Jv>(&json_bytes) {
2782                    if let Some(model) = tok.get("model") {
2783                        if let Some(vocab_obj) = model.get("vocab").and_then(|v| v.as_object()) {
2784                            let mut vocab_pairs: Vec<(String, u64)> = vocab_obj
2785                                .iter()
2786                                .filter_map(|(k, v)| Some((k.clone(), v.as_u64()?)))
2787                                .collect();
2788                            vocab_pairs.sort_by_key(|(_, id)| *id);
2789                            let vocab: Vec<Jv> =
2790                                vocab_pairs.into_iter().map(|(k, _)| Jv::String(k)).collect();
2791                            writer.set_metadata("tokenizer.vocabulary", Jv::Array(vocab));
2792                        }
2793                        if let Some(merges_arr) = model.get("merges").and_then(|m| m.as_array()) {
2794                            let merges: Vec<Jv> = merges_arr
2795                                .iter()
2796                                .filter_map(|v| v.as_str().map(|s| Jv::String(s.to_string())))
2797                                .collect();
2798                            writer.set_metadata("tokenizer.merges", Jv::Array(merges));
2799                        }
2800                    }
2801                    // BOS / EOS from added_tokens (HF format).
2802                    if let Some(added) = tok.get("added_tokens").and_then(|a| a.as_array()) {
2803                        for entry in added {
2804                            let content =
2805                                entry.get("content").and_then(|c| c.as_str()).unwrap_or("");
2806                            let id = entry.get("id").and_then(|i| i.as_u64());
2807                            if let Some(id) = id {
2808                                match content {
2809                                    "<s>" | "<|im_start|>" | "<|begin_of_text|>" => {
2810                                        writer.set_metadata(
2811                                            "tokenizer.bos_token_id",
2812                                            Jv::Number(serde_json::Number::from(id)),
2813                                        );
2814                                    }
2815                                    "</s>" | "<|im_end|>" | "<|end_of_text|>" | "<|endoftext|>" => {
2816                                        writer.set_metadata(
2817                                            "tokenizer.eos_token_id",
2818                                            Jv::Number(serde_json::Number::from(id)),
2819                                        );
2820                                    }
2821                                    _ => {}
2822                                }
2823                            }
2824                        }
2825                    }
2826                }
2827            }
2828        }
2829
2830        // Tensors — reuse io::save's shape inference for 2D weight handling.
2831        let shapes = infer_all_tensor_shapes(&params);
2832        for (tname, tensor) in &params {
2833            let data = tensor.data();
2834            let slice = data.as_slice().expect("tensor data must be contiguous");
2835            let shape = shapes.get(tname).cloned().unwrap_or_else(|| vec![tensor.len()]);
2836            writer.add_tensor_f32(tname, shape, slice);
2837        }
2838
2839        writer
2840            .write(path)
2841            .map_err(|e| crate::error::Error::Serialization(format!("APR write failed: {e}")))
2842    }
2843
2844    /// ALB-096: Prepare APR checkpoint data for async save.
2845    ///
2846    /// Syncs GPU weights to CPU and snapshots tensor data + optimizer state as
2847    /// Send-able `Vec<f32>`. Returns a closure that writes a single atomic APR
2848    /// file from another thread. Includes model weights + CPU embedding optimizer
2849    /// state + training metadata — all in one file.
2850    fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2851        let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2852        if use_nf4 {
2853            let frozen_suffixes = [
2854                "q_proj.weight",
2855                "k_proj.weight",
2856                "v_proj.weight",
2857                "o_proj.weight",
2858                "gate_proj.weight",
2859                "up_proj.weight",
2860                "down_proj.weight",
2861            ];
2862            self.model
2863                .named_parameters()
2864                .into_iter()
2865                .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2866                .map(|(n, t)| (n, t.data().to_vec()))
2867                .collect()
2868        } else {
2869            self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2870        }
2871    }
2872
2873    fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2874        if self.config.quantize_nf4 && self.config.is_lora() {
2875            self.cuda_blocks
2876                .iter()
2877                .enumerate()
2878                .filter_map(|(i, block)| {
2879                    block
2880                        .download_lora_weights()
2881                        .ok()
2882                        .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2883                })
2884                .collect()
2885        } else {
2886            Vec::new()
2887        }
2888    }
2889
2890    pub fn prepare_async_apr_save(
2891        &mut self,
2892        name: &str,
2893        architecture: &str,
2894        step: usize,
2895        loss: f64,
2896        lr: f64,
2897    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2898        self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2899    }
2900
2901    /// ALB-130: Prepare APR checkpoint with embedded tokenizer for inference.
2902    ///
2903    /// Training checkpoints must be self-contained for eval (`apr eval --task humaneval`).
2904    /// Without embedded tokenizer, inference falls back to structural validation (fake 100%).
2905    /// The tokenizer path comes from `spec.data.tokenizer` in the training YAML.
2906    pub fn prepare_async_apr_save_with_tokenizer(
2907        &mut self,
2908        name: &str,
2909        architecture: &str,
2910        step: usize,
2911        loss: f64,
2912        lr: f64,
2913        tokenizer_path: Option<&std::path::Path>,
2914    ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2915        self.sync_weights_to_cpu();
2916
2917        let param_data = self.snapshot_param_data();
2918        let lora_data = self.snapshot_lora_data();
2919
2920        // Snapshot CPU embedding optimizer state
2921        let embed_m: Vec<Vec<f32>> = self
2922            .embed_optimizer
2923            .first_moments()
2924            .iter()
2925            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2926            .collect();
2927        let embed_v: Vec<Vec<f32>> = self
2928            .embed_optimizer
2929            .second_moments()
2930            .iter()
2931            .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2932            .collect();
2933        let embed_step = self.embed_optimizer.step_count();
2934
2935        // ALB-118: Download GPU block optimizer states (m/v moments) for checkpointing.
2936        // Without this, resume re-initializes all 24 blocks' AdamW state to zero,
2937        // causing loss spikes and convergence failure (v10/v11/v12 post-mortems).
2938        // Transfer cost: ~2.3 GB D2H, <6ms on PCIe4/5.
2939        let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2940            .gpu_training
2941            .optimizer_states
2942            .iter()
2943            .map(|state| state.download_to_host().unwrap_or_default())
2944            .collect();
2945
2946        // ALB-118: Download LM head and final norm optimizer states
2947        let lm_head_m_host = {
2948            let mut buf = vec![0.0f32; self.lm_head_m.len()];
2949            let _ = self.lm_head_m.copy_to_host(&mut buf);
2950            buf
2951        };
2952        let lm_head_v_host = {
2953            let mut buf = vec![0.0f32; self.lm_head_v.len()];
2954            let _ = self.lm_head_v.copy_to_host(&mut buf);
2955            buf
2956        };
2957        let final_norm_m_host = {
2958            let mut buf = vec![0.0f32; self.final_norm_m.len()];
2959            let _ = self.final_norm_m.copy_to_host(&mut buf);
2960            buf
2961        };
2962        let final_norm_v_host = {
2963            let mut buf = vec![0.0f32; self.final_norm_v.len()];
2964            let _ = self.final_norm_v.copy_to_host(&mut buf);
2965            buf
2966        };
2967
2968        let name = name.to_string();
2969        let architecture = architecture.to_string();
2970        let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2971        let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2972
2973        // SPEC-SHIP-TWO-001 §81 P0-E: extract individual arch metadata keys
2974        // so downstream tools (apr qa, apr bench, apr export) can read them
2975        // via AprV2Metadata's typed fields. The `model_config` JSON blob is
2976        // unrecognized by AprWriter::build_v2_metadata and goes into the
2977        // `custom` map — which `realizar::gguf::config::from_apr` does NOT
2978        // read (it requires `apr.metadata.hidden_size` etc. to be Some).
2979        let arch_hidden_size = self.config.model_config.hidden_size;
2980        let arch_num_layers = self.config.model_config.num_hidden_layers;
2981        let arch_num_heads = self.config.model_config.num_attention_heads;
2982        let arch_num_kv_heads = self.config.model_config.num_kv_heads;
2983        let arch_intermediate_size = self.config.model_config.intermediate_size;
2984        let arch_vocab_size = self.config.model_config.vocab_size;
2985        let arch_max_position_embeddings = self.config.model_config.max_position_embeddings;
2986        let arch_rope_theta = self.config.model_config.rope_theta;
2987        let arch_rms_norm_eps = self.config.model_config.rms_norm_eps;
2988
2989        // ALB-130: Pre-read tokenizer.json for embedding in checkpoint.
2990        // Parse HuggingFace tokenizer format → extract vocab + merges + special token IDs.
2991        let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2992            tokenizer_path.and_then(|p| {
2993                let json_bytes = std::fs::read(p).ok()?;
2994                let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2995                let model = tok.get("model")?;
2996                let vocab_obj = model.get("vocab")?.as_object()?;
2997                // Build sorted-by-id vocab list
2998                let mut vocab_pairs: Vec<(String, u64)> =
2999                    vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
3000                vocab_pairs.sort_by_key(|(_, id)| *id);
3001                let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
3002                // Merges as "token1 token2" strings
3003                let merges: Vec<String> = model
3004                    .get("merges")?
3005                    .as_array()?
3006                    .iter()
3007                    .filter_map(|v| v.as_str().map(String::from))
3008                    .collect();
3009                // Special tokens: BOS=<s>=1, EOS=</s>=2 (from added_tokens)
3010                let added = tok.get("added_tokens").and_then(|a| a.as_array());
3011                let bos_id = added.and_then(|arr| {
3012                    arr.iter()
3013                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
3014                        .and_then(|t| t.get("id")?.as_u64())
3015                });
3016                let eos_id = added.and_then(|arr| {
3017                    arr.iter()
3018                        .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
3019                        .and_then(|t| t.get("id")?.as_u64())
3020                });
3021                if vocab.is_empty() {
3022                    return None;
3023                }
3024                println!(
3025                    "  [ALB-130] Embedding tokenizer: {} vocab, {} merges",
3026                    vocab.len(),
3027                    merges.len()
3028                );
3029                Some((vocab, merges, bos_id, eos_id))
3030            });
3031
3032        Box::new(move |path: &std::path::Path| {
3033            use aprender::serialization::apr::AprWriter;
3034            use serde_json::Value as Jv;
3035
3036            let mut writer = AprWriter::new();
3037
3038            // Metadata
3039            writer.set_metadata("model_name", Jv::String(name));
3040            writer.set_metadata("architecture", Jv::String(architecture));
3041            writer.set_metadata(
3042                "format",
3043                Jv::String(if is_delta_checkpoint {
3044                    "entrenar-delta-checkpoint".into()
3045                } else {
3046                    "entrenar-checkpoint".into()
3047                }),
3048            );
3049            writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
3050            writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
3051            writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
3052            writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
3053            if let Some(cfg) = model_config_json {
3054                writer.set_metadata("model_config", Jv::String(cfg));
3055            }
3056
3057            // SPEC-SHIP-TWO-001 §81 P0-E: write individual arch metadata keys
3058            // so realizar's `from_apr` (C-03 gate) accepts the checkpoint.
3059            // `serde_json::Number::from(u as u64)` converts usize losslessly.
3060            writer.set_metadata(
3061                "hidden_size",
3062                Jv::Number(serde_json::Number::from(arch_hidden_size as u64)),
3063            );
3064            writer.set_metadata(
3065                "num_hidden_layers",
3066                Jv::Number(serde_json::Number::from(arch_num_layers as u64)),
3067            );
3068            writer.set_metadata(
3069                "num_attention_heads",
3070                Jv::Number(serde_json::Number::from(arch_num_heads as u64)),
3071            );
3072            writer.set_metadata(
3073                "num_kv_heads",
3074                Jv::Number(serde_json::Number::from(arch_num_kv_heads as u64)),
3075            );
3076            writer.set_metadata(
3077                "intermediate_size",
3078                Jv::Number(serde_json::Number::from(arch_intermediate_size as u64)),
3079            );
3080            writer.set_metadata(
3081                "vocab_size",
3082                Jv::Number(serde_json::Number::from(arch_vocab_size as u64)),
3083            );
3084            writer.set_metadata(
3085                "max_position_embeddings",
3086                Jv::Number(serde_json::Number::from(arch_max_position_embeddings as u64)),
3087            );
3088            if let Some(rope) = serde_json::Number::from_f64(arch_rope_theta as f64) {
3089                writer.set_metadata("rope_theta", Jv::Number(rope));
3090            }
3091            if let Some(eps) = serde_json::Number::from_f64(arch_rms_norm_eps as f64) {
3092                writer.set_metadata("rms_norm_eps", Jv::Number(eps));
3093            }
3094
3095            // ALB-130: Embed tokenizer vocab + merges for standalone inference
3096            if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
3097                writer.set_metadata(
3098                    "tokenizer.vocabulary",
3099                    Jv::Array(vocab.into_iter().map(Jv::String).collect()),
3100                );
3101                writer.set_metadata(
3102                    "tokenizer.merges",
3103                    Jv::Array(merges.into_iter().map(Jv::String).collect()),
3104                );
3105                if let Some(bos) = bos_id {
3106                    writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
3107                }
3108                if let Some(eos) = eos_id {
3109                    writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
3110                }
3111            }
3112
3113            // Find hidden_size from norm weights for shape inference
3114            let hidden_size = param_data
3115                .iter()
3116                .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
3117                .map_or(0, |(_, d)| d.len());
3118
3119            // Model weight tensors
3120            for (tensor_name, data) in &param_data {
3121                let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
3122                writer.add_tensor_f32(tensor_name.clone(), shape, data);
3123            }
3124
3125            // Optimizer state tensors
3126            for (i, m_data) in embed_m.iter().enumerate() {
3127                let len = m_data.len();
3128                writer.add_tensor_f32(
3129                    format!("__training__.embed_optimizer.m.{i}"),
3130                    vec![len],
3131                    m_data,
3132                );
3133            }
3134            for (i, v_data) in embed_v.iter().enumerate() {
3135                let len = v_data.len();
3136                writer.add_tensor_f32(
3137                    format!("__training__.embed_optimizer.v.{i}"),
3138                    vec![len],
3139                    v_data,
3140                );
3141            }
3142
3143            // ALB-118: Save GPU block optimizer states (m/v moments for all 24 blocks)
3144            for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
3145                for (suffix, data) in buffers {
3146                    let len = data.len();
3147                    writer.add_tensor_f32(
3148                        format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
3149                        vec![len],
3150                        data,
3151                    );
3152                }
3153            }
3154
3155            // ALB-118: Save LM head and final norm optimizer states
3156            if !lm_head_m_host.is_empty() {
3157                let len = lm_head_m_host.len();
3158                writer.add_tensor_f32(
3159                    "__training__.lm_head_optimizer.m".to_string(),
3160                    vec![len],
3161                    &lm_head_m_host,
3162                );
3163                let len = lm_head_v_host.len();
3164                writer.add_tensor_f32(
3165                    "__training__.lm_head_optimizer.v".to_string(),
3166                    vec![len],
3167                    &lm_head_v_host,
3168                );
3169            }
3170            if !final_norm_m_host.is_empty() {
3171                let len = final_norm_m_host.len();
3172                writer.add_tensor_f32(
3173                    "__training__.final_norm_optimizer.m".to_string(),
3174                    vec![len],
3175                    &final_norm_m_host,
3176                );
3177                let len = final_norm_v_host.len();
3178                writer.add_tensor_f32(
3179                    "__training__.final_norm_optimizer.v".to_string(),
3180                    vec![len],
3181                    &final_norm_v_host,
3182                );
3183            }
3184
3185            // ENT-276: Save LoRA adapter weights (QLoRA checkpoint resume)
3186            for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
3187                if !a_q.is_empty() {
3188                    writer.add_tensor_f32(
3189                        format!("lora.{layer_idx}.q_proj.lora_a"),
3190                        vec![a_q.len()],
3191                        a_q,
3192                    );
3193                    writer.add_tensor_f32(
3194                        format!("lora.{layer_idx}.q_proj.lora_b"),
3195                        vec![b_q.len()],
3196                        b_q,
3197                    );
3198                }
3199                if !a_v.is_empty() {
3200                    writer.add_tensor_f32(
3201                        format!("lora.{layer_idx}.v_proj.lora_a"),
3202                        vec![a_v.len()],
3203                        a_v,
3204                    );
3205                    writer.add_tensor_f32(
3206                        format!("lora.{layer_idx}.v_proj.lora_b"),
3207                        vec![b_v.len()],
3208                        b_v,
3209                    );
3210                }
3211            }
3212
3213            // Write APR checkpoint to file
3214            writer
3215                .write(path)
3216                .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3217
3218            Ok(())
3219        })
3220    }
3221
3222    /// GPU device name.
3223    pub fn gpu_name(&self) -> String {
3224        self.cuda_trainer.device_name()
3225    }
3226
3227    /// ENT-269: Save LoRA adapter weights as PEFT-compatible files.
3228    ///
3229    /// Downloads LoRA A/B matrices from GPU, un-scales B (divide by lora_scale),
3230    /// transposes to PEFT convention (A=[rank, d_in], B=[d_out, rank]),
3231    /// and writes `adapter_model.safetensors` + `adapter_config.json`.
3232    ///
3233    /// # Contract: C-QLORA-SAVE-001
3234    ///
3235    /// NF4 QLoRA training MUST produce `adapter_model.safetensors` in output_dir.
3236    pub fn save_cuda_lora_adapter(
3237        &self,
3238        output_dir: &std::path::Path,
3239        base_model_name: Option<&str>,
3240    ) -> crate::Result<()> {
3241        if !self.config.quantize_nf4 || !self.config.is_lora() {
3242            return Ok(()); // Not a QLoRA run, nothing to save
3243        }
3244
3245        let lora_rank = self.config.lora_rank.unwrap_or(16);
3246        let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3247        let lora_scale = lora_alpha / lora_rank as f32;
3248        let hidden_size = self.config.model_config.hidden_size;
3249        let head_dim = self.config.model_config.head_dim();
3250        let q_dim = self.config.model_config.num_attention_heads * head_dim;
3251        let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3252
3253        let lora_config =
3254            crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3255
3256        let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3257
3258        for (i, block) in self.cuda_blocks.iter().enumerate() {
3259            let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3260                Ok(weights) => weights,
3261                Err(_) => continue, // Skip non-NF4 blocks
3262            };
3263
3264            if a_q.is_empty() && a_v.is_empty() {
3265                continue;
3266            }
3267
3268            // Q projection LoRA
3269            if !a_q.is_empty() {
3270                // GPU stores A_q as [hidden, rank] row-major, PEFT expects [rank, hidden]
3271                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3272                for r in 0..hidden_size {
3273                    for c in 0..lora_rank {
3274                        a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3275                    }
3276                }
3277
3278                // GPU stores B_q as [rank, q_dim] pre-scaled by lora_scale
3279                // PEFT expects [q_dim, rank] un-scaled
3280                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3281                let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3282                for r in 0..lora_rank {
3283                    for c in 0..q_dim {
3284                        b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3285                    }
3286                }
3287
3288                let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3289                let mut layer = crate::lora::LoRALayer::new(
3290                    base_weight,
3291                    q_dim,
3292                    hidden_size,
3293                    lora_rank,
3294                    lora_alpha,
3295                );
3296                // Overwrite the A and B data with trained weights
3297                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3298                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3299
3300                adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3301            }
3302
3303            // V projection LoRA
3304            if !a_v.is_empty() {
3305                let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3306                for r in 0..hidden_size {
3307                    for c in 0..lora_rank {
3308                        a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3309                    }
3310                }
3311
3312                let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3313                let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3314                for r in 0..lora_rank {
3315                    for c in 0..kv_hidden {
3316                        b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3317                    }
3318                }
3319
3320                let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3321                let mut layer = crate::lora::LoRALayer::new(
3322                    base_weight,
3323                    kv_hidden,
3324                    hidden_size,
3325                    lora_rank,
3326                    lora_alpha,
3327                );
3328                layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3329                layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3330
3331                adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3332            }
3333        }
3334
3335        if adapters.is_empty() {
3336            println!("  [WARN] No LoRA adapters found to save");
3337            return Ok(());
3338        }
3339
3340        let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3341            adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3342
3343        std::fs::create_dir_all(output_dir).ok();
3344        crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3345            .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3346
3347        let adapter_path = output_dir.join("adapter_model.safetensors");
3348        let size_mb =
3349            std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3350        println!(
3351            "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3352            adapters.len(),
3353            size_mb,
3354            output_dir.display()
3355        );
3356
3357        Ok(())
3358    }
3359
3360    /// R-001: Save CPU embedding optimizer state (m/v buffers + step counter).
3361    ///
3362    /// Writes `optimizer_state.json` to the given directory. GPU block optimizer
3363    /// states remain on-device (D2H for 20 buffers × N blocks is deferred).
3364    pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3365        let path = dir.join("optimizer_state.json");
3366        let m_data: Vec<Option<Vec<f32>>> = self
3367            .embed_optimizer
3368            .first_moments()
3369            .iter()
3370            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3371            .collect();
3372        let v_data: Vec<Option<Vec<f32>>> = self
3373            .embed_optimizer
3374            .second_moments()
3375            .iter()
3376            .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3377            .collect();
3378        let state = serde_json::json!({
3379            "type": "adamw_cpu_embed",
3380            "step": self.embed_optimizer.step_count(),
3381            "m": m_data,
3382            "v": v_data,
3383        });
3384        let json_str = serde_json::to_string(&state).map_err(|e| {
3385            crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3386        })?;
3387        std::fs::write(&path, json_str)
3388            .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3389        Ok(())
3390    }
3391
3392    /// ENT-276: Restore LoRA adapter weights from APR checkpoint.
3393    ///
3394    /// Reads `lora.{layer}.{q,v}_proj.lora_{a,b}` tensors from the APR file
3395    /// and uploads them to the NF4 CUDA blocks, replacing the fresh random init.
3396    /// Returns (layers_restored, layers_total).
3397    pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3398        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3399            Ok(r) => r,
3400            Err(_) => return (0, self.cuda_blocks.len()),
3401        };
3402
3403        let mut restored = 0usize;
3404        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3405            let a_q =
3406                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3407            let b_q =
3408                reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3409            let a_v =
3410                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3411            let b_v =
3412                reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3413
3414            if a_q.is_empty() {
3415                continue; // No LoRA data for this layer in checkpoint
3416            }
3417
3418            if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3419                eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3420                continue;
3421            }
3422            restored += 1;
3423        }
3424
3425        (restored, self.cuda_blocks.len())
3426    }
3427
3428    /// ALB-096: Load CPU embedding optimizer state from APR checkpoint.
3429    ///
3430    /// Reads `__training__.embed_optimizer.{m,v}.*` tensors from the APR file.
3431    /// Returns true if state was loaded.
3432    pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3433        let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3434            Ok(r) => r,
3435            Err(_) => return false,
3436        };
3437
3438        // Restore step count from metadata
3439        if let Some(step_val) = reader.get_metadata("optimizer_step") {
3440            if let Some(step_str) = step_val.as_str() {
3441                if let Ok(step) = step_str.parse::<u64>() {
3442                    self.embed_optimizer.set_step_count(step);
3443                }
3444            }
3445        }
3446
3447        // Restore first moments (m)
3448        for i in 0..128 {
3449            let name = format!("__training__.embed_optimizer.m.{i}");
3450            match reader.read_tensor_f32(&name) {
3451                Ok(data) if !data.is_empty() => {
3452                    self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3453                }
3454                _ => break,
3455            }
3456        }
3457
3458        // Restore second moments (v)
3459        for i in 0..128 {
3460            let name = format!("__training__.embed_optimizer.v.{i}");
3461            match reader.read_tensor_f32(&name) {
3462                Ok(data) if !data.is_empty() => {
3463                    self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3464                }
3465                _ => break,
3466            }
3467        }
3468
3469        // ALB-118: Restore GPU block optimizer states (m/v moments for all blocks)
3470        let suffixes = [
3471            "m.w_q",
3472            "v.w_q",
3473            "m.w_k",
3474            "v.w_k",
3475            "m.w_v",
3476            "v.w_v",
3477            "m.w_o",
3478            "v.w_o",
3479            "m.w_gate",
3480            "v.w_gate",
3481            "m.w_up",
3482            "v.w_up",
3483            "m.w_down",
3484            "v.w_down",
3485            "m.input_norm",
3486            "v.input_norm",
3487            "m.post_attn_norm",
3488            "v.post_attn_norm",
3489        ];
3490        let mut blocks_restored = 0usize;
3491        for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3492            let mut data = std::collections::HashMap::new();
3493            for suffix in &suffixes {
3494                let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3495                if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3496                    if !tensor_data.is_empty() {
3497                        data.insert(suffix.to_string(), tensor_data);
3498                    }
3499                }
3500            }
3501            if !data.is_empty() {
3502                let _ = state.restore_from_host(&data);
3503                blocks_restored += 1;
3504            }
3505        }
3506
3507        // ALB-118: Restore LM head optimizer state
3508        if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3509            if m_data.len() == self.lm_head_m.len() {
3510                let _ = self.lm_head_m.copy_from_host(&m_data);
3511            }
3512        }
3513        if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3514            if v_data.len() == self.lm_head_v.len() {
3515                let _ = self.lm_head_v.copy_from_host(&v_data);
3516            }
3517        }
3518
3519        // ALB-118: Restore final norm optimizer state
3520        if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3521            if m_data.len() == self.final_norm_m.len() {
3522                let _ = self.final_norm_m.copy_from_host(&m_data);
3523            }
3524        }
3525        if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3526            if v_data.len() == self.final_norm_v.len() {
3527                let _ = self.final_norm_v.copy_from_host(&v_data);
3528            }
3529        }
3530
3531        // ALB-132: Report restore results — don't silently swallow failures
3532        if blocks_restored > 0 {
3533            println!(
3534                "  ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3535                self.gpu_training.optimizer_states.len()
3536            );
3537        } else if !self.gpu_training.optimizer_states.is_empty() {
3538            println!(
3539                "  [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3540                self.gpu_training.optimizer_states.len()
3541            );
3542        }
3543
3544        true
3545    }
3546
3547    /// R-001: Load CPU embedding optimizer state from `optimizer_state.json`.
3548    ///
3549    /// Returns true if state was loaded, false if file doesn't exist.
3550    pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3551        let path = dir.join("optimizer_state.json");
3552        let data = match std::fs::read_to_string(&path) {
3553            Ok(d) => d,
3554            Err(_) => return false,
3555        };
3556        let state: serde_json::Value = match serde_json::from_str(&data) {
3557            Ok(v) => v,
3558            Err(_) => return false,
3559        };
3560        if let Some(step) = state["step"].as_u64() {
3561            self.embed_optimizer.set_step_count(step);
3562        }
3563        restore_moment_buffers(&state["m"], |idx, arr| {
3564            self.embed_optimizer.set_first_moment(idx, arr);
3565        });
3566        restore_moment_buffers(&state["v"], |idx, arr| {
3567            self.embed_optimizer.set_second_moment(idx, arr);
3568        });
3569        true
3570    }
3571}
3572
3573/// ALB-096: Infer 2D tensor shape from name and element count.
3574///
3575/// Same logic as `infer_all_tensor_shapes` in `io/save.rs` but for a single tensor.
3576#[cfg(feature = "cuda")]
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    // Norm weights stay 1-D, as does anything that cannot be split evenly by `hidden_size`
    // (including the degenerate `hidden_size == 0` case).
    let is_norm_weight = name.ends_with("layernorm.weight") || name == "model.norm.weight";
    if is_norm_weight || hidden_size == 0 || !numel.is_multiple_of(hidden_size) {
        return vec![numel];
    }

    // Everything else is 2-D with `hidden_size` as one axis. `down_proj` puts it first;
    // every other tensor puts it last.
    let other_dim = numel / hidden_size;
    if name.ends_with("down_proj.weight") {
        vec![hidden_size, other_dim]
    } else {
        vec![other_dim, hidden_size]
    }
}
3591
/// Parse a JSON array of moment buffers and apply each via callback.
///
/// `json_arr` is expected to be an array of arrays of numbers. For every inner array that
/// converts cleanly to a non-empty `Vec<f32>`, `set_fn` is called with the inner array's
/// position in the outer array and the values as an `ndarray::Array1<f32>`.
///
/// Nothing happens when `json_arr` is not an array. Outer entries that are not arrays, and
/// inner arrays that are empty, are skipped silently. An inner array containing any
/// non-numeric element (for example a `null`, which is how serde_json writes a non-finite
/// float) is skipped as a whole and reported with a warning: dropping only the offending
/// elements would shorten the buffer and shift every later value to the wrong position.
#[cfg(feature = "cuda")]
fn restore_moment_buffers(
    json_arr: &serde_json::Value,
    mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
) {
    let Some(arr) = json_arr.as_array() else { return };
    for (idx, val) in arr.iter().enumerate() {
        let Some(inner) = val.as_array() else { continue };

        // All-or-nothing conversion: collecting into `Option<Vec<_>>` yields `None` as soon
        // as one element is not a number, instead of silently producing a shorter vector.
        let Some(floats) = inner
            .iter()
            .map(|v| v.as_f64().map(|f| f as f32))
            .collect::<Option<Vec<f32>>>()
        else {
            println!(
                "  [WARN] optimizer moment buffer {idx} has non-numeric entries — not restored"
            );
            continue;
        };

        if !floats.is_empty() {
            set_fn(idx, ndarray::Array1::from_vec(floats));
        }
    }
}
3607
3608// ── Non-CUDA stub ──
3609
3610#[cfg(not(feature = "cuda"))]
3611pub struct CudaTransformerTrainer;
3612
3613#[cfg(not(feature = "cuda"))]
3614impl CudaTransformerTrainer {
3615    pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3616        Err(crate::error::Error::ConfigError(
3617            "CUDA not available (compiled without cuda feature)".into(),
3618        ))
3619    }
3620
3621    pub fn with_model(
3622        _model: crate::transformer::Transformer,
3623        _config: super::config::TransformerTrainConfig,
3624    ) -> crate::Result<Self> {
3625        Err(crate::error::Error::ConfigError(
3626            "CUDA not available (compiled without cuda feature)".into(),
3627        ))
3628    }
3629
3630    pub fn gpu_name(&self) -> String {
3631        unreachable!("CudaTransformerTrainer stub should never be instantiated")
3632    }
3633}
3634
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, constructing the trainer must report an error.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use super::CudaTransformerTrainer;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        assert!(CudaTransformerTrainer::new(train_config).is_err());
    }
}