1#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42 gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46 adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47 gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48 squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64 CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65 GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
#[cfg(feature = "cuda")]
/// Compute the global L2 norm of the nine per-block gradient buffers in `ws`
/// and the factor that scales it down to `max_norm`.
///
/// One squared-sum reduction is launched per non-empty buffer. `stream` is
/// synchronized once, and the per-buffer results are summed on the host in
/// `f64`.
///
/// Returns `(scale, grad_norm)`:
/// - `scale` is `max_norm / grad_norm` when `grad_norm > max_norm`, and `1.0`
///   otherwise.
/// - If any launch, the stream synchronization, or any result collection
///   fails, the norm cannot be trusted. In that case `(1.0, 0.0)` is returned
///   (no clipping) rather than a partial, under-estimated norm.
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    // Result used whenever the norm could not be computed reliably.
    const NO_CLIP: (f32, f32) = (1.0, 0.0);

    let all_bufs: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    let mut pending: Vec<PendingSquaredSum> = Vec::with_capacity(all_bufs.len());
    let mut launch_failed = false;
    for &buf in &all_bufs {
        let n = buf.len() as u32;
        if n == 0 {
            continue;
        }
        match squared_sum_launch_cuda(buf, n, stream) {
            Ok(p) => pending.push(p),
            // Skipping a buffer would silently shrink the norm; remember the
            // failure and bail out after the sync below.
            Err(_) => launch_failed = true,
        }
    }

    // Always synchronize, so that reductions already launched finish before
    // their pending results are dropped.
    let synced = stream.synchronize().is_ok();
    if !synced || launch_failed {
        eprintln!("[clip] gradient norm reduction failed — skipping gradient clipping this step");
        return NO_CLIP;
    }

    let mut total_sq = 0.0f64;
    for p in &pending {
        match squared_sum_collect(p) {
            Ok(sq_norm) => total_sq += f64::from(sq_norm),
            Err(_) => {
                eprintln!("[clip] gradient norm collect failed — skipping gradient clipping this step");
                return NO_CLIP;
            }
        }
    }

    let grad_norm = total_sq.sqrt() as f32;
    let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
    (scale, grad_norm)
}
133
#[cfg(feature = "cuda")]
/// Clip the nine per-block gradients in `ws` in place to a global L2 norm of
/// `max_norm`, and return the norm measured *before* clipping.
///
/// The scale comes from `compute_workspace_clip_scale_gpu`. When it is
/// (within `1e-7` of) `1.0`, no kernels are launched. Otherwise every buffer
/// is multiplied by the scale via `gradient_clip_cuda`. Launch errors are
/// ignored (best effort).
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    let within_bound = (scale - 1.0).abs() < 1e-7;
    if !within_bound {
        let targets: [&mut GpuBuffer<f32>; 9] = [
            &mut ws.grad_w_q,
            &mut ws.grad_w_k,
            &mut ws.grad_w_v,
            &mut ws.grad_w_o,
            &mut ws.grad_gate,
            &mut ws.grad_up,
            &mut ws.grad_down,
            &mut ws.grad_input_norm,
            &mut ws.grad_post_attn_norm,
        ];
        for target in targets {
            let elems = target.len() as u32;
            let _ = gradient_clip_cuda(target, scale, elems, stream);
        }
    }

    grad_norm
}
165
#[cfg(feature = "cuda")]
/// Clip the nine per-block gradients in `ws` without a host synchronization.
///
/// Three phases, all enqueued on `stream`:
/// 1. For each non-empty gradient buffer, squared-sum partials are written
///    into its slot of `state.partials_buf`. Slot `i` starts at element
///    `state.offsets[i]`.
/// 2. `clip_scale_reduce_cuda` reduces all `state.total_partials` partials,
///    with `max_norm`, into `state.scale_buf`.
/// 3. Every non-empty buffer is rescaled in place using the device-resident
///    value at `state.scale_buf`.
///
/// Launch errors are ignored (best effort). The buffer order below must match
/// the `grad_sizes` order used when `state` was built (`init_fused_clip`).
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // `offsets` are element offsets; each partial is one f32 (4 bytes).
    const PARTIAL_BYTES: u64 = 4;

    // Phase 1: per-buffer partial squared sums.
    let sources: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];
    for (slot, src) in sources.iter().enumerate() {
        let elems = src.len() as u32;
        if elems == 0 {
            continue;
        }
        let slot_ptr =
            state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * PARTIAL_BYTES;
        let _ = squared_sum_launch_into(src, elems, slot_ptr, stream);
    }

    // Phase 2: reduce the partials into the device-side scale.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Phase 3: apply the device-side scale in place.
    let scale_ptr = state.scale_buf.as_ptr();
    let targets: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for dst in targets {
        let elems = dst.len() as u32;
        if elems == 0 {
            continue;
        }
        let _ = gradient_clip_gpu_scale_cuda(dst, scale_ptr, elems, stream);
    }
}
238
#[cfg(feature = "cuda")]
/// Global L2 norm of the per-block gradients in `ws`, without clipping.
///
/// Delegates to `compute_workspace_clip_scale_gpu` with `max_norm = f32::MAX`
/// and discards the scale.
#[allow(dead_code)] // may be unused depending on the active training path
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
#[cfg(feature = "cuda")]
/// Multiply each of the nine per-block gradients in `ws` by `inv_scale`, in place.
///
/// Does nothing when `inv_scale` is within `1e-7` of `1.0`. Kernel launch
/// errors from `gradient_clip_cuda` are ignored (best effort).
#[allow(dead_code)] // may be unused depending on the active precision path
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    let is_identity = (inv_scale - 1.0).abs() < 1e-7;
    if is_identity {
        return;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads {
        let count = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, count, stream);
    }
}
285
#[cfg(feature = "cuda")]
/// Device-resident activations, gradients and per-block optimizer state used
/// by the GPU training loop.
///
/// Sequence-sized buffers are allocated for `max_seq_len × hidden_size` f32
/// values; each step only uses the first `seq_len` rows.
struct GpuPretrainState {
    /// One buffer per layer for that layer's block input. The forward pass
    /// only fills entries whose `saved_layer_mask` flag is set. Under
    /// activation checkpointing the others are rebuilt on demand.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// `true` for layers whose input is stored during the forward pass
    /// (every layer when activation checkpointing is off).
    saved_layer_mask: Vec<bool>,
    /// Scratch copy of a saved segment input used while recomputing;
    /// `Some` only when activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Weight of the final RMSNorm (`hidden_size` elements).
    final_norm_weight: GpuBuffer<f32>,
    /// Copy of the last transformer block's output (the final norm's input).
    blocks_output: GpuBuffer<f32>,
    /// First of two sequence-sized gradient buffers used by the backward pass;
    /// `gpu_backward` reports which one holds the final gradient.
    grad_buf_a: GpuBuffer<f32>,
    /// Second backward-pass gradient buffer (see `grad_buf_a`).
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient of the final RMSNorm weight (`hidden_size` elements).
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Output of the final RMSNorm; input to the LM-head GEMM.
    norm_output: GpuBuffer<f32>,
    /// `max_seq_len × vocab_size` logits. After the fused cross-entropy, the
    /// backward pass reads this buffer as the logits gradient.
    logits_buf: GpuBuffer<f32>,
    /// Gradient with respect to `norm_output`, produced by the LM-head backward GEMM.
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// Per-block optimizer state for FP32 blocks; empty in NF4 mode.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Step counter, starts at 0. NOTE(review): confirm where it is advanced.
    step: u32,
}
326
#[cfg(feature = "cuda")]
/// Transformer trainer that keeps the transformer blocks, final norm and LM
/// head on the GPU.
///
/// Token embeddings are looked up on the host `model` and copied to the device
/// each step. Blocks are either FP32, or NF4-quantized with trainable LoRA
/// adapters when `quantize_nf4` is combined with LoRA.
pub struct CudaTransformerTrainer {
    /// Host-side model; its embedding table is used for the forward lookup.
    model: Transformer,
    /// Owns the CUDA context and the stream all work is enqueued on.
    cuda_trainer: CudaTrainer,
    /// One GPU block per transformer layer (`CudaBlock::Fp32` or `CudaBlock::Nf4`).
    cuda_blocks: Vec<CudaBlock>,
    /// Single per-block gradient workspace (allocated once, not per layer).
    cuda_grad_workspace: CudaGradWorkspace,
    /// Scratch shared by NF4 blocks; `Some` only in NF4 mode.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// LoRA gradient workspace; `Some` only in NF4 mode.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// Per-layer LoRA optimizer state; `Some` only in NF4 mode.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// Activations, gradients and per-block optimizer state.
    gpu_training: GpuPretrainState,
    /// LM-head weight (`vocab_size × hidden_size`). Taken from `lm_head`, or
    /// from the embedding table when the model has no separate LM head.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// Gradient of the LM-head weight.
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// LM-head optimizer first moment, zero-initialized (presumably AdamW `m`).
    lm_head_m: GpuBuffer<f32>,
    /// LM-head optimizer second moment, zero-initialized (presumably AdamW `v`).
    lm_head_v: GpuBuffer<f32>,
    /// Final-norm optimizer first moment, zero-initialized.
    final_norm_m: GpuBuffer<f32>,
    /// Final-norm optimizer second moment, zero-initialized.
    final_norm_v: GpuBuffer<f32>,
    /// Host-side AdamW optimizer for the embedding table.
    embed_optimizer: AdamW,
    /// Training configuration this trainer was built with.
    config: TransformerTrainConfig,
    /// Training metrics exposed to callers.
    pub metrics: MetricsTracker,
    /// Trainer-level step counter (starts at 0).
    step: usize,
    /// Running loss sum for gradient accumulation (starts at 0).
    accumulated_loss: f32,
    /// Number of micro-batches accumulated so far (starts at 0).
    accumulated_batches: usize,
    /// Most recently recorded gradient norm (starts at 0).
    last_grad_norm: f32,
    /// Most recently recorded embedding gradient norm (starts at 0).
    last_embed_grad_norm: f32,
    /// CPU-side per-block gradient accumulator; `Some` when `accumulation_steps > 1`.
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// GPU-side gradient accumulator; `Some` when `accumulation_steps > 1`
    /// and its allocation succeeded.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// Loss scaler built from `precision_config` (mixed precision).
    grad_scaler: GradScaler,
    /// Step-phase profiler; disabled unless `profile_interval > 0`.
    profiler: StepProfiler,
    /// First forward ping-pong activation buffer (the embedding upload target).
    fwd_scratch_a: GpuBuffer<f32>,
    /// Second forward ping-pong activation buffer.
    fwd_scratch_b: GpuBuffer<f32>,
    /// Host buffer (`max_seq_len × hidden_size`) staging the embedding upload.
    h2d_staging: Vec<f32>,
    /// Host buffer presumably used for device→host gradient copies; non-empty
    /// only when CPU gradient accumulation is the fallback.
    d2h_staging: Vec<f32>,
    /// Sync-free clipping state; `Some` when `max_grad_norm` is set and the
    /// allocation succeeded.
    fused_clip: Option<FusedClipState>,
    /// `hidden_size` host zeros. NOTE(review): presumably used to reset the
    /// final-norm gradient — confirm.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419 pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426 let model = Transformer::new(&config.model_config);
427 Self::with_model(model, config)
428 }
429
430 pub fn for_inference(
444 checkpoint_dir: impl AsRef<std::path::Path>,
445 model_config: crate::transformer::TransformerConfig,
446 ) -> crate::Result<Self> {
447 let dir = checkpoint_dir.as_ref();
448
449 let model = if let Some((Some(m), _step)) =
451 crate::config::try_load_apr_for_inference(dir, &model_config)
452 {
453 m
454 } else {
455 Transformer::from_safetensors(dir, &model_config)?
456 };
457
458 let mut config = TransformerTrainConfig::new(model_config);
459 config.max_seq_len = config.model_config.max_position_embeddings;
460 Self::with_model(model, config)
461 }
462
463 pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
469 if !cuda_training_available() {
470 return Err(crate::error::Error::ConfigError("CUDA not available".into()));
471 }
472
473 let mc = &config.model_config;
474 let max_seq_len = config.max_seq_len;
475 let hidden_size = mc.hidden_size;
476 let vocab_size = mc.vocab_size;
477 let num_layers = mc.num_hidden_layers;
478
479 let cuda_trainer = CudaTrainer::new().map_err(|e| {
481 crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
482 })?;
483
484 println!(
485 " GPU: {} ({:.1} GB)",
486 cuda_trainer.device_name(),
487 cuda_trainer.total_memory() as f64 / 1e9
488 );
489
490 let ctx = cuda_trainer.context().clone();
491 let stream = cuda_trainer.stream();
492
493 pre_warm_forward_kernels(
496 hidden_size,
497 mc.intermediate_size,
498 mc.num_attention_heads,
499 mc.num_kv_heads,
500 mc.head_dim(),
501 max_seq_len,
502 )
503 .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
504
505 {
509 use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
510 let head_dim = mc.head_dim();
511 pre_warm_lora_backward_kernels(
512 hidden_size,
513 mc.num_attention_heads * head_dim,
514 mc.num_kv_heads * head_dim,
515 max_seq_len,
516 config.lora_rank.unwrap_or(0),
517 mc.intermediate_size,
518 mc.num_attention_heads,
519 config.quantize_nf4 && config.is_lora(),
520 )
521 .map_err(|e| {
522 crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
523 })?;
524 eprintln!(" ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
525 }
526
527 if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
530 println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
531 }
532 if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
533 println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
534 }
535
536 let use_nf4 = config.quantize_nf4 && config.is_lora();
538 let cuda_blocks = Self::upload_blocks(
539 &model,
540 mc,
541 &config,
542 &ctx,
543 use_nf4,
544 num_layers,
545 hidden_size,
546 max_seq_len,
547 )?;
548
549 let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
551 crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
552 })?;
553
554 let buf_size = max_seq_len * hidden_size;
556 let logits_size = max_seq_len * vocab_size;
557
558 let checkpointing = config.checkpoint_config.enabled;
562 let segment_size = if checkpointing {
563 let ns = config.checkpoint_config.num_segments.max(1);
564 num_layers.div_ceil(ns)
565 } else {
566 1 };
568 let saved_layer_mask: Vec<bool> =
569 (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
570
571 let mut layer_inputs = Vec::with_capacity(num_layers);
572 for _ in 0..num_layers {
573 layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
574 crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
575 })?);
576 }
577
578 let recompute_buf = if checkpointing {
580 Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
581 crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
582 })?)
583 } else {
584 None
585 };
586
587 if checkpointing {
588 let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
589 println!(
590 " ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
591 config.checkpoint_config.num_segments, saved_count, num_layers
592 );
593 }
594
595 let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
597 let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
598 crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
599 })?;
600
601 let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
602 crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
603 })?;
604 let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
605 crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
606 })?;
607 let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
608 crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
609 })?;
610 let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
611 crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
612 })?;
613 let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
614 crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
615 })?;
616 let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
617 crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
618 })?;
619 let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
620 crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
621 })?;
622
623 let mut optimizer_states = Vec::new();
625 if !use_nf4 {
626 optimizer_states.reserve(num_layers);
627 for (i, block) in cuda_blocks.iter().enumerate() {
628 optimizer_states.push(block.init_optimizer_state().map_err(|e| {
629 crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
630 })?);
631 }
632 }
633
634 let gpu_training = GpuPretrainState {
635 layer_inputs,
636 saved_layer_mask,
637 recompute_buf,
638 final_norm_weight,
639 blocks_output,
640 grad_buf_a,
641 grad_buf_b,
642 grad_final_norm_weight,
643 norm_output,
644 logits_buf,
645 lm_head_grad_hidden,
646 optimizer_states,
647 step: 0,
648 };
649
650 let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
653 let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
654 let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
655 crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
656 })?;
657 let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
658 crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
659 })?;
660 let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
663 .map_err(|e| {
664 crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
665 })?;
666 let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
667 .map_err(|e| {
668 crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
669 })?;
670
671 let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
673 crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
674 })?;
675 let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
676 crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
677 })?;
678
679 let buf_size = max_seq_len * hidden_size;
681 let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
682 crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
683 })?;
684 let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
685 crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
686 })?;
687
688 stream
690 .synchronize()
691 .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
692
693 println!(
694 " ✓ GPU training state allocated (LM head: {:.1} MB)",
695 (vocab_size * hidden_size * 4) as f64 / 1e6
696 );
697
698 let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
700 let lora_rank = config.lora_rank.unwrap_or(16);
701
702 let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
704 crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
705 })?;
706
707 let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
709 crate::error::Error::ConfigError(format!(
710 "NF4 LoRA grad workspace alloc failed: {e:?}"
711 ))
712 })?;
713
714 let mut lora_opt_states = Vec::with_capacity(num_layers);
716 for (i, block) in cuda_blocks.iter().enumerate() {
717 lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
718 crate::error::Error::ConfigError(format!(
719 "Block {i} LoRA opt state failed: {e:?}"
720 ))
721 })?);
722 }
723
724 println!(
725 " ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
726 );
727 (Some(scratch), Some(grad_ws), Some(lora_opt_states))
728 } else {
729 (None, None, None)
730 };
731
732 let embed_optimizer =
735 AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
736
737 let grad_accum = if config.accumulation_steps > 1 {
740 let kv_hidden = mc.num_kv_heads * mc.head_dim();
741 let block_sizes =
742 super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
743 hidden_size,
744 kv_hidden,
745 mc.intermediate_size,
746 );
747 let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
748 num_layers,
749 block_sizes,
750 vocab_size,
751 hidden_size,
752 );
753 println!(
754 " ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
755 config.accumulation_steps,
756 (accum
757 .block_grads
758 .iter()
759 .map(super::grad_accumulator::BlockGradientSet::total_elements)
760 .sum::<usize>()
761 + accum.lm_head_grad.len()
762 + accum.final_norm_grad.len()
763 + accum.embedding_grad.len()) as f64
764 * 4.0
765 / 1e6,
766 );
767 Some(accum)
768 } else {
769 None
770 };
771
772 let gpu_grad_accum = if config.accumulation_steps > 1 {
775 match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
776 Ok(accum) => {
777 println!(" ✓ GPU gradient accumulation enabled (ALB-091)");
778 Some(accum)
779 }
780 Err(e) => {
781 eprintln!(
782 " [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
783 );
784 None
785 }
786 }
787 } else {
788 None
789 };
790
791 let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
794 let ws_max = hidden_size * mc.intermediate_size;
795 let lm_max = vocab_size * hidden_size;
796 vec![0.0f32; ws_max.max(lm_max)]
797 } else {
798 Vec::new()
799 };
800
801 let kv_hidden = mc.num_kv_heads * mc.head_dim();
804 let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
805
806 let grad_scaler = GradScaler::from_config(&config.precision_config);
808 if config.precision_config.is_mixed() {
809 println!(
810 " ✓ Mixed precision: {} (loss scale={}, dynamic={})",
811 config.precision_config.compute_precision,
812 grad_scaler.scale(),
813 grad_scaler.is_dynamic(),
814 );
815 }
816
817 Ok(Self {
818 model,
819 cuda_trainer,
820 cuda_blocks,
821 cuda_grad_workspace,
822 nf4_shared_scratch,
823 nf4_lora_grad_workspace,
824 nf4_lora_optimizer_states,
825 gpu_training,
826 lm_head_weight_gpu,
827 lm_head_grad_gpu,
828 lm_head_m,
829 lm_head_v,
830 final_norm_m,
831 final_norm_v,
832 embed_optimizer,
833 profiler: if config.profile_interval > 0 {
835 StepProfiler::new(true, config.profile_interval)
836 } else {
837 StepProfiler::disabled()
838 },
839 config,
840 metrics: MetricsTracker::new(),
841 step: 0,
842 accumulated_loss: 0.0,
843 accumulated_batches: 0,
844 last_grad_norm: 0.0,
845 last_embed_grad_norm: 0.0,
846 grad_accum,
847 gpu_grad_accum,
848 grad_scaler,
849 fwd_scratch_a,
850 fwd_scratch_b,
851 h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
852 d2h_staging,
853 fused_clip,
854 final_norm_zero_buf: vec![0.0f32; hidden_size],
855 })
856 }
857
858 #[allow(clippy::too_many_arguments)]
860 fn upload_blocks(
861 model: &Transformer,
862 mc: &crate::transformer::TransformerConfig,
863 config: &TransformerTrainConfig,
864 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
865 use_nf4: bool,
866 num_layers: usize,
867 hidden_size: usize,
868 max_seq_len: usize,
869 ) -> crate::Result<Vec<CudaBlock>> {
870 let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
871
872 if use_nf4 {
873 let lora_rank = config.lora_rank.unwrap_or(16);
874 let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
875 let lora_scale = lora_alpha / lora_rank as f32;
876 let head_dim = mc.head_dim();
877 let q_dim = mc.num_attention_heads * head_dim;
878 let kv_hidden = mc.num_kv_heads * head_dim;
879
880 for (i, layer) in model.layers.iter().enumerate() {
881 let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
882 .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
883 .collect();
884 let lora_b_q = vec![0.0f32; lora_rank * q_dim];
885 let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
886 .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
887 .collect();
888 let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
889
890 let q_norm_data = layer
891 .self_attn
892 .q_norm
893 .as_ref()
894 .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
895 let k_norm_data = layer
896 .self_attn
897 .k_norm
898 .as_ref()
899 .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
900
901 let block = crate::transformer::CudaNf4TransformerBlock::new(
902 mc,
903 i,
904 ctx.clone(),
905 layer.input_norm.weight.data().as_slice().expect("contiguous"),
906 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
907 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
908 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
909 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
910 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
911 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
912 layer.ffn.w_up.data().as_slice().expect("contiguous"),
913 layer.ffn.w_down.data().as_slice().expect("contiguous"),
914 max_seq_len,
915 Some((&lora_a_q, &lora_b_q)),
916 Some((&lora_a_v, &lora_b_v)),
917 lora_scale,
918 lora_rank,
919 q_norm_data.as_deref(),
920 k_norm_data.as_deref(),
921 )
922 .map_err(|e| {
923 crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
924 })?;
925 cuda_blocks.push(CudaBlock::Nf4(block));
926 }
927 println!(" ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
928 } else {
929 for (i, layer) in model.layers.iter().enumerate() {
930 let b_q = layer
935 .self_attn
936 .b_q
937 .as_ref()
938 .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
939 let b_k = layer
940 .self_attn
941 .b_k
942 .as_ref()
943 .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
944 let b_v = layer
945 .self_attn
946 .b_v
947 .as_ref()
948 .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
949 let block = CudaTransformerBlock::new(
950 mc,
951 i,
952 ctx.clone(),
953 layer.input_norm.weight.data().as_slice().expect("contiguous"),
954 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
955 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
956 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
957 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
958 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
959 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
960 layer.ffn.w_up.data().as_slice().expect("contiguous"),
961 layer.ffn.w_down.data().as_slice().expect("contiguous"),
962 max_seq_len,
963 b_q.as_deref(),
964 b_k.as_deref(),
965 b_v.as_deref(),
966 )
967 .map_err(|e| {
968 crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
969 })?;
970 cuda_blocks.push(CudaBlock::Fp32(block));
971 }
972 println!(" ✓ {num_layers} transformer blocks uploaded to GPU");
973 }
974
975 Ok(cuda_blocks)
976 }
977
978 fn init_fused_clip(
980 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981 config: &TransformerTrainConfig,
982 hidden_size: usize,
983 kv_hidden: usize,
984 mc: &crate::transformer::TransformerConfig,
985 ) -> Option<FusedClipState> {
986 config.base.max_grad_norm?;
987 let grad_sizes: [u32; 9] = [
988 (hidden_size * hidden_size) as u32,
989 (hidden_size * kv_hidden) as u32,
990 (hidden_size * kv_hidden) as u32,
991 (hidden_size * hidden_size) as u32,
992 (hidden_size * mc.intermediate_size) as u32,
993 (hidden_size * mc.intermediate_size) as u32,
994 (mc.intermediate_size * hidden_size) as u32,
995 hidden_size as u32,
996 hidden_size as u32,
997 ];
998 match FusedClipState::new(ctx, &grad_sizes) {
999 Ok(state) => {
1000 println!(
1001 " ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002 state.total_partials,
1003 f64::from(state.total_partials) * 4.0 / 1024.0,
1004 );
1005 Some(state)
1006 }
1007 Err(e) => {
1008 println!(" ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009 None
1010 }
1011 }
1012 }
1013
1014 fn train_step_single(
1023 &mut self,
1024 input_ids: &[u32],
1025 target_ids: &[u32],
1026 accumulate_only: bool,
1027 ) -> Option<f32> {
1028 self.profiler.begin_step();
1029 let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030 self.profiler.finish_step();
1031 result
1032 }
1033
1034 fn train_step_inner(
1036 &mut self,
1037 input_ids: &[u32],
1038 target_ids: &[u32],
1039 accumulate_only: bool,
1040 ) -> Option<f32> {
1041 let hidden_size = self.config.model_config.hidden_size;
1042 let vocab_size = self.config.model_config.vocab_size;
1043
1044 let max_sl = self.config.max_seq_len;
1046 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048 let seq_len = input_ids.len();
1049
1050 if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053 eprintln!(
1054 "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056 );
1057 return None;
1058 }
1059
1060 self.profiler.begin(StepProfiler::LOSS);
1063 let stream = self.cuda_trainer.stream();
1064
1065 let mut loss_scale = 1.0 / seq_len as f32;
1074 if self.config.accumulation_steps > 1 {
1075 loss_scale /= self.config.accumulation_steps as f32;
1076 }
1077
1078 let loss_val = fused_cross_entropy_cuda(
1080 &mut self.gpu_training.logits_buf,
1081 target_ids,
1082 seq_len as u32,
1083 vocab_size as u32,
1084 loss_scale,
1085 stream,
1086 )
1087 .ok()?;
1088
1089 if !loss_val.is_finite() {
1091 return None;
1092 }
1093 self.profiler.end(StepProfiler::LOSS);
1094
1095 if let Some(grad_output_is_a) =
1104 self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105 {
1106 self.profiler.begin(StepProfiler::EMBED_BWD);
1108 self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110 self.profiler.end(StepProfiler::EMBED_BWD);
1111 }
1112
1113 Some(loss_val)
1114 }
1115
1116 #[allow(unsafe_code)]
1121 fn gpu_forward(
1122 &mut self,
1123 input_ids: &[u32],
1124 seq_len: usize,
1125 hidden_size: usize,
1126 vocab_size: usize,
1127 ) -> Option<()> {
1128 contract_pre_gpu_forward!();
1129 let stream = self.cuda_trainer.stream();
1130
1131 self.profiler.begin(StepProfiler::EMBED);
1133 let hidden = self.model.embed_tokens.forward(input_ids);
1134 let hidden_slice = hidden.data().as_slice()?;
1135 self.profiler.end(StepProfiler::EMBED);
1136
1137 self.profiler.begin(StepProfiler::H2D);
1142 self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
1143 self.h2d_staging[hidden_slice.len()..].fill(0.0);
1144 if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
1145 eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
1146 return None;
1147 }
1148 self.profiler.end(StepProfiler::H2D);
1149
1150 self.profiler.begin(StepProfiler::FORWARD);
1154 let mut input_is_a = true; for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
1156 let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
1159 if input_is_a {
1160 (
1161 std::ptr::from_ref(&self.fwd_scratch_a),
1162 std::ptr::from_mut(&mut self.fwd_scratch_b),
1163 )
1164 } else {
1165 (
1166 std::ptr::from_ref(&self.fwd_scratch_b),
1167 std::ptr::from_mut(&mut self.fwd_scratch_a),
1168 )
1169 };
1170 if self.gpu_training.saved_layer_mask[i] {
1171 unsafe {
1174 self.gpu_training.layer_inputs[i]
1175 .copy_from_buffer_async(&*input_ptr, stream)
1176 .ok()?;
1177 }
1178 }
1179 self.profiler.begin_layer();
1182 unsafe {
1183 block
1184 .forward(
1185 &*input_ptr,
1186 &mut *output_ptr,
1187 seq_len,
1188 stream,
1189 self.nf4_shared_scratch.as_mut(),
1190 )
1191 .ok()?;
1192 }
1193 self.profiler.end_layer_fwd(i);
1194 input_is_a = !input_is_a;
1195 }
1196 self.profiler.end(StepProfiler::FORWARD);
1197
1198 let final_output: &GpuBuffer<f32> =
1200 if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };
1201
1202 self.profiler.begin(StepProfiler::NORM_LM);
1205 unsafe {
1206 self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
1207 }
1208
1209 rms_norm_forward_with_eps(
1214 final_output,
1215 &self.gpu_training.final_norm_weight,
1216 &mut self.gpu_training.norm_output,
1217 seq_len as u32,
1218 hidden_size as u32,
1219 self.config.model_config.rms_norm_eps,
1220 stream,
1221 )
1222 .ok()?;
1223
1224 gemm_forward(
1228 &self.gpu_training.norm_output,
1229 &self.lm_head_weight_gpu,
1230 &mut self.gpu_training.logits_buf,
1231 seq_len as u32,
1232 hidden_size as u32,
1233 vocab_size as u32,
1234 stream,
1235 )
1236 .ok()?;
1237
1238 self.profiler.end(StepProfiler::NORM_LM);
1241
1242 Some(())
1243 }
1244
1245 pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1257 let seq_len = input_ids.len();
1258 let hidden_size = self.config.model_config.hidden_size;
1259 let vocab_size = self.config.model_config.vocab_size;
1260
1261 if seq_len == 0 || seq_len > self.config.max_seq_len {
1262 return None;
1263 }
1264
1265 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1267
1268 let stream = self.cuda_trainer.stream();
1270 stream.synchronize().ok()?;
1271
1272 let offset = (seq_len - 1) * vocab_size;
1274 let mut logits = vec![0.0f32; vocab_size];
1275 self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1276
1277 Some(logits)
1278 }
1279
    /// Re-materializes the input activation of `target_layer` for activation checkpointing.
    ///
    /// Scans backwards from `target_layer` (inclusive) to the nearest layer whose input was
    /// kept (`saved_layer_mask[i] == true`). That saved input is copied into `recompute_buf`,
    /// and the forward pass of every block in `seg_start..target_layer` is re-run. Each
    /// block's output is written into `layer_inputs[i + 1]`, so on success
    /// `layer_inputs[target_layer]` holds a freshly recomputed input.
    ///
    /// Returns `None` in any of these cases:
    /// - no saved layer exists at or before `target_layer`;
    /// - `recompute_buf` is not allocated;
    /// - a copy or forward launch fails.
    ///
    /// All work is enqueued on `stream`; this function does not synchronize it.
    #[allow(unsafe_code)]
    fn recompute_segment(
        gpu_training: &mut GpuPretrainState,
        cuda_blocks: &mut [CudaBlock],
        nf4_shared_scratch: &mut Option<CudaBlockScratch>,
        target_layer: usize,
        seq_len: usize,
        stream: &CudaStream,
    ) -> Option<()> {
        // Nearest checkpointed layer at or before the target.
        let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;

        if seg_start == target_layer {
            // The target's input was kept during the forward pass; nothing to recompute.
            return Some(());
        }

        let recompute_buf = gpu_training.recompute_buf.as_mut()?;
        // SAFETY (review note): async device-to-device copy enqueued on `stream`. The source
        // buffer must stay alive and unmodified until the stream executes it — TODO confirm
        // the exact contract of `copy_from_buffer_async`.
        unsafe {
            recompute_buf
                .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
                .ok()?;
        }

        for i in seg_start..target_layer {
            if i == seg_start {
                // The first block of the segment reads the scratch copy in `recompute_buf`.
                // A raw pointer hands out a shared reference to it next to the mutable borrow
                // of `layer_inputs`. The two are distinct fields of `gpu_training`.
                let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
                let li = &mut gpu_training.layer_inputs;
                unsafe {
                    cuda_blocks[i]
                        .forward(
                            &*recompute_ptr,
                            &mut li[i + 1],
                            seq_len,
                            stream,
                            nf4_shared_scratch.as_mut(),
                        )
                        .ok()?;
                }
            } else {
                // Later blocks read layer_inputs[i] and write layer_inputs[i + 1].
                // split_at_mut yields those two disjoint borrows safely.
                let li = &mut gpu_training.layer_inputs;
                let (left, right) = li.split_at_mut(i + 1);
                cuda_blocks[i]
                    .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
                    .ok()?;
            }
        }

        Some(())
    }
1359
1360 #[allow(unsafe_code)]
1373 fn gpu_backward(
1374 &mut self,
1375 seq_len: usize,
1376 hidden_size: usize,
1377 vocab_size: usize,
1378 accumulate_only: bool,
1379 ) -> Option<bool> {
1380 let stream = self.cuda_trainer.stream();
1381 let max_grad_norm = self.config.base.max_grad_norm;
1382 let lr = self.current_lr();
1383 let beta1 = self.config.beta1;
1385 let beta2 = self.config.beta2;
1386 let weight_decay = self.config.weight_decay;
1387
1388 self.profiler.begin(StepProfiler::LM_BWD);
1393 gemm_backward_a(
1394 &self.gpu_training.logits_buf,
1395 &self.lm_head_weight_gpu,
1396 &mut self.gpu_training.lm_head_grad_hidden,
1397 seq_len as u32,
1398 hidden_size as u32,
1399 vocab_size as u32,
1400 stream,
1401 )
1402 .ok()?;
1403
1404 gemm_backward_b(
1405 &self.gpu_training.norm_output,
1406 &self.gpu_training.logits_buf,
1407 &mut self.lm_head_grad_gpu,
1408 seq_len as u32,
1409 hidden_size as u32,
1410 vocab_size as u32,
1411 stream,
1412 )
1413 .ok()?;
1414
1415 let lm_sq_norm =
1421 squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1422 .unwrap_or(0.0);
1423 let lm_norm = lm_sq_norm.sqrt(); self.last_grad_norm = lm_norm; if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1427 eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1428 let gh_sq = squared_sum_cuda(
1430 &self.gpu_training.lm_head_grad_hidden,
1431 self.gpu_training.lm_head_grad_hidden.len() as u32,
1432 stream,
1433 )
1434 .unwrap_or(0.0);
1435 eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1436 }
1437 if let Some(max_norm) = max_grad_norm {
1438 let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1439 let n = self.lm_head_grad_gpu.len() as u32;
1440 let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1441 }
1442 self.profiler.end(StepProfiler::LM_BWD);
1443
1444 self.profiler.begin(StepProfiler::NORM_BWD);
1446 self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1448 rms_norm_backward(
1449 &self.gpu_training.blocks_output,
1450 &self.gpu_training.final_norm_weight,
1451 &self.gpu_training.lm_head_grad_hidden,
1452 &mut self.gpu_training.grad_buf_a,
1453 &mut self.gpu_training.grad_final_norm_weight,
1454 seq_len as u32,
1455 hidden_size as u32,
1456 1e-5_f32,
1457 stream,
1458 )
1459 .ok()?;
1460
1461 if let Some(max_norm) = max_grad_norm {
1464 let (scale, _) = Self::compute_clip_scale_with_norm(
1465 &self.gpu_training.grad_final_norm_weight,
1466 max_norm,
1467 stream,
1468 );
1469 let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1470 let _ =
1471 gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1472 }
1473 self.profiler.end(StepProfiler::NORM_BWD);
1474
1475 if accumulate_only {
1477 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1479 let _ = gpu_accum.accumulate_nonblock(
1480 &self.lm_head_grad_gpu,
1481 &self.gpu_training.grad_final_norm_weight,
1482 stream,
1483 );
1484 } else {
1485 stream.synchronize().ok()?;
1486 Self::download_nonblock_grads_to_accum(
1487 &self.lm_head_grad_gpu,
1488 &self.gpu_training.grad_final_norm_weight,
1489 &mut self.grad_accum,
1490 &mut self.d2h_staging,
1491 )?;
1492 }
1493 } else {
1494 Self::run_nonblock_optimizer_step(
1495 &mut self.gpu_training,
1496 Some(&mut self.lm_head_weight_gpu),
1497 &self.lm_head_grad_gpu,
1498 &mut self.lm_head_m,
1499 &mut self.lm_head_v,
1500 &mut self.final_norm_m,
1501 &mut self.final_norm_v,
1502 lr,
1503 beta1,
1504 beta2,
1505 weight_decay,
1506 stream,
1507 );
1508 }
1509
1510 self.profiler.begin(StepProfiler::BLK_BWD);
1516 let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1517 let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1518 let mut grad_output_is_a = true;
1519 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1520
1521 for layer_idx in (0..self.cuda_blocks.len()).rev() {
1522 if !self.gpu_training.saved_layer_mask[layer_idx] {
1525 Self::recompute_segment(
1526 &mut self.gpu_training,
1527 &mut self.cuda_blocks,
1528 &mut self.nf4_shared_scratch,
1529 layer_idx,
1530 seq_len,
1531 stream,
1532 )?;
1533 }
1534
1535 let (grad_output, grad_input) = unsafe {
1536 if grad_output_is_a {
1537 (&*grad_a_ptr, &mut *grad_b_ptr)
1538 } else {
1539 (&*grad_b_ptr, &mut *grad_a_ptr)
1540 }
1541 };
1542
1543 self.profiler.begin_layer();
1544 if use_nf4 {
1545 let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1549 grad_b_ptr } else {
1551 grad_a_ptr
1552 };
1553 match self.cuda_blocks[layer_idx].backward_nf4(
1556 &self.gpu_training.layer_inputs[layer_idx],
1557 grad_output,
1558 grad_input,
1559 &mut self.gpu_training.blocks_output, seq_len,
1561 stream,
1562 self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1563 self.nf4_lora_grad_workspace
1564 .as_mut()
1565 .expect("NF4 requires LoRA grad workspace"),
1566 ) {
1567 Ok(()) => {}
1568 Err(e) => {
1569 eprintln!(
1570 "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1571 layer_idx, e, seq_len, self.config.model_config.hidden_size
1572 );
1573 return None;
1574 }
1575 }
1576
1577 if let Some(max_norm) = max_grad_norm {
1581 self.nf4_lora_grad_workspace
1582 .as_mut()
1583 .expect("NF4 requires LoRA grad ws")
1584 .clip_gradients(max_norm, stream);
1585 }
1586
1587 {
1593 let step = self.gpu_training.step;
1594 let effective_lr = if accumulate_only {
1595 lr / self.config.accumulation_steps as f32
1596 } else {
1597 lr
1598 };
1599 if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1600 let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1601 &mut opt_states[layer_idx],
1602 step,
1603 effective_lr,
1604 beta1,
1605 beta2,
1606 1e-8,
1607 weight_decay,
1608 stream,
1609 self.nf4_lora_grad_workspace
1610 .as_ref()
1611 .expect("NF4 requires LoRA grad ws"),
1612 );
1613 }
1614 }
1615 } else {
1616 self.cuda_blocks[layer_idx]
1618 .backward(
1619 &self.gpu_training.layer_inputs[layer_idx],
1620 grad_output,
1621 grad_input,
1622 seq_len,
1623 stream,
1624 &mut self.cuda_grad_workspace,
1625 )
1626 .ok()?;
1627
1628 if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1634 let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1635 &self.cuda_grad_workspace,
1636 f32::MAX,
1637 stream,
1638 );
1639 let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1641 .unwrap_or(0.0);
1642 let act_gnorm = act_sq.sqrt();
1643 eprintln!(
1644 "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1645 );
1646 }
1647
1648 if accumulate_only {
1650 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1652 let _ = gpu_accum.accumulate_block(
1653 &self.cuda_grad_workspace,
1654 layer_idx,
1655 stream,
1656 );
1657 } else {
1658 stream.synchronize().ok()?;
1660 if let Some(accum) = &mut self.grad_accum {
1661 Self::download_workspace_to_accum(
1662 &self.cuda_grad_workspace,
1663 accum,
1664 layer_idx,
1665 &mut self.d2h_staging,
1666 )?;
1667 }
1668 }
1669 } else {
1670 let step = self.gpu_training.step;
1672 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1673 &mut self.gpu_training.optimizer_states[layer_idx],
1674 step,
1675 lr,
1676 beta1,
1677 beta2,
1678 1e-8,
1679 weight_decay,
1680 stream,
1681 &self.cuda_grad_workspace,
1682 );
1683 }
1684 }
1685
1686 self.profiler.end_layer_bwd(layer_idx);
1687 grad_output_is_a = !grad_output_is_a;
1688 }
1689
1690 stream.synchronize().ok()?;
1691 self.profiler.end(StepProfiler::BLK_BWD);
1692
1693 Some(grad_output_is_a)
1694 }
1695
1696 fn download_nonblock_grads_to_accum(
1702 lm_head_grad: &GpuBuffer<f32>,
1703 final_norm_grad: &GpuBuffer<f32>,
1704 grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1705 host: &mut [f32],
1706 ) -> Option<()> {
1707 let accum = grad_accum.as_mut()?;
1708
1709 let lm_slice = &mut host[..lm_head_grad.len()];
1710 lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1711 for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1712 *d += s;
1713 }
1714
1715 let norm_slice = &mut host[..final_norm_grad.len()];
1716 final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1717 for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1718 *d += s;
1719 }
1720 Some(())
1721 }
1722
1723 #[allow(clippy::too_many_arguments)]
1726 fn run_nonblock_optimizer_step(
1727 gpu_training: &mut GpuPretrainState,
1728 lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1729 lm_head_grad_gpu: &GpuBuffer<f32>,
1730 lm_head_m: &mut GpuBuffer<f32>,
1731 lm_head_v: &mut GpuBuffer<f32>,
1732 final_norm_m: &mut GpuBuffer<f32>,
1733 final_norm_v: &mut GpuBuffer<f32>,
1734 lr: f32,
1735 beta1: f32,
1736 beta2: f32,
1737 weight_decay: f32,
1738 stream: &CudaStream,
1739 ) {
1740 gpu_training.step += 1;
1741 let step = gpu_training.step;
1742
1743 if let Some(lm_head_weight) = lm_head_weight_gpu {
1744 let n_lm = lm_head_weight.len() as u32;
1745 let _ = adamw_step_cuda(
1746 lm_head_weight,
1747 lm_head_grad_gpu,
1748 lm_head_m,
1749 lm_head_v,
1750 lr,
1751 beta1,
1752 beta2,
1753 1e-8,
1754 weight_decay,
1755 step,
1756 n_lm,
1757 stream,
1758 );
1759 }
1760
1761 let n_norm = gpu_training.final_norm_weight.len() as u32;
1762 let _ = adamw_step_cuda(
1763 &mut gpu_training.final_norm_weight,
1764 &gpu_training.grad_final_norm_weight,
1765 final_norm_m,
1766 final_norm_v,
1767 lr,
1768 beta1,
1769 beta2,
1770 1e-8,
1771 weight_decay,
1772 step,
1773 n_norm,
1774 stream,
1775 );
1776 }
1777
1778 fn download_workspace_to_accum(
1786 ws: &CudaGradWorkspace,
1787 accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1788 layer_idx: usize,
1789 host: &mut [f32],
1790 ) -> Option<()> {
1791 let bg = &mut accum.block_grads[layer_idx];
1792
1793 use super::grad_accumulator::component;
1794 let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1795 (&ws.grad_w_q, component::W_Q),
1796 (&ws.grad_w_k, component::W_K),
1797 (&ws.grad_w_v, component::W_V),
1798 (&ws.grad_w_o, component::W_O),
1799 (&ws.grad_gate, component::GATE),
1800 (&ws.grad_up, component::UP),
1801 (&ws.grad_down, component::DOWN),
1802 (&ws.grad_input_norm, component::INPUT_NORM),
1803 (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1804 ];
1805
1806 for (gpu_buf, comp_idx) in &bufs_and_components {
1807 let slice = &mut host[..gpu_buf.len()];
1808 gpu_buf.copy_to_host_at(slice, 0).ok()?;
1809 for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1810 *d += s;
1811 }
1812 }
1813 Some(())
1814 }
1815
    /// Applies one optimizer step using gradients held in the GPU-resident accumulator.
    ///
    /// Steps, in order:
    /// 1. Synchronizes the stream and increments `gpu_training.step`.
    /// 2. For every block, uploads that block's accumulated gradients into
    ///    `cuda_grad_workspace` and runs the block's optimizer step.
    /// 3. Uploads the LM-head and final-norm gradients and updates them with `adamw_step_cuda`.
    /// 4. Synchronizes the stream again and zeroes the accumulator.
    ///
    /// Returns `None` if there is no GPU accumulator (after `step` was already incremented)
    /// or if any upload/synchronize fails. In that case the accumulator is NOT zeroed.
    /// Per-kernel optimizer errors are ignored.
    ///
    /// NOTE(review): unlike `gpu_optimizer_from_accum`, no averaging or NaN/Inf check happens
    /// here. Presumably `upload_to_workspace` / `upload_nonblock` handle averaging — verify.
    fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
        let stream = self.cuda_trainer.stream();
        let lr = self.current_lr();
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let weight_decay = self.config.weight_decay;

        // Make sure all pending accumulate kernels have finished.
        stream.synchronize().ok()?;

        self.gpu_training.step += 1;
        let step = self.gpu_training.step;

        let gpu_accum = self.gpu_grad_accum.as_ref()?;
        for layer_idx in 0..self.cuda_blocks.len() {
            // The single workspace is reused for every layer; all work is on one stream.
            gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;

            let _ = self.cuda_blocks[layer_idx].optimizer_step(
                &mut self.gpu_training.optimizer_states[layer_idx],
                step,
                lr,
                beta1,
                beta2,
                1e-8,
                weight_decay,
                stream,
                &self.cuda_grad_workspace,
            );
        }

        gpu_accum
            .upload_nonblock(
                &mut self.lm_head_grad_gpu,
                &mut self.gpu_training.grad_final_norm_weight,
            )
            .ok()?;

        let n_lm = self.lm_head_weight_gpu.len() as u32;
        let _ = adamw_step_cuda(
            &mut self.lm_head_weight_gpu,
            &self.lm_head_grad_gpu,
            &mut self.lm_head_m,
            &mut self.lm_head_v,
            lr,
            beta1,
            beta2,
            1e-8,
            weight_decay,
            step,
            n_lm,
            stream,
        );

        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
        let _ = adamw_step_cuda(
            &mut self.gpu_training.final_norm_weight,
            &self.gpu_training.grad_final_norm_weight,
            &mut self.final_norm_m,
            &mut self.final_norm_v,
            lr,
            beta1,
            beta2,
            1e-8,
            weight_decay,
            step,
            n_norm,
            stream,
        );

        // Wait for the updates before clearing the accumulator for the next window.
        stream.synchronize().ok()?;

        if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
            let _ = gpu_accum.zero_all();
        }

        Some(())
    }
1903
    /// Applies one optimizer step using gradients accumulated in host memory.
    ///
    /// Steps, in order:
    /// 1. Averages the accumulator. If it contains NaN/Inf, prints a warning, zeroes it, and
    ///    returns `Some(())` without stepping.
    /// 2. Increments `gpu_training.step`.
    /// 3. For each block, copies its nine gradient components host-to-device into
    ///    `cuda_grad_workspace` and runs the block's optimizer step.
    /// 4. Updates the LM head and final norm the same way.
    /// 5. Synchronizes the stream and zeroes the accumulator.
    ///
    /// Returns `None` if no host accumulator exists or a copy/synchronize fails.
    ///
    /// NOTE(review): on such an early return the stream is not synchronized and the
    /// accumulator is not zeroed.
    #[allow(unsafe_code)]
    fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
        let stream = self.cuda_trainer.stream();
        let lr = self.current_lr();
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let weight_decay = self.config.weight_decay;

        let accum = self.grad_accum.as_mut()?;
        // Presumably divides the sums by `accumulated_count` — TODO confirm.
        accum.average();

        if accum.has_non_finite() {
            println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
            accum.zero_all();
            return Some(());
        }

        self.gpu_training.step += 1;
        let step = self.gpu_training.step;

        use super::grad_accumulator::component;
        for layer_idx in 0..self.cuda_blocks.len() {
            let bg = &accum.block_grads[layer_idx];

            // SAFETY: async host-to-device copies on `stream`. The host source (`accum`) is
            // not modified until after the `stream.synchronize()` at the end. The workspace is
            // reused per layer, but every copy and kernel runs on the same stream, in order.
            unsafe {
                self.cuda_grad_workspace
                    .grad_w_q
                    .copy_from_host_async(&bg.components[component::W_Q], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_w_k
                    .copy_from_host_async(&bg.components[component::W_K], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_w_v
                    .copy_from_host_async(&bg.components[component::W_V], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_w_o
                    .copy_from_host_async(&bg.components[component::W_O], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_gate
                    .copy_from_host_async(&bg.components[component::GATE], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_up
                    .copy_from_host_async(&bg.components[component::UP], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_down
                    .copy_from_host_async(&bg.components[component::DOWN], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_input_norm
                    .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
                    .ok()?;
                self.cuda_grad_workspace
                    .grad_post_attn_norm
                    .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
                    .ok()?;
            }

            let _ = self.cuda_blocks[layer_idx].optimizer_step(
                &mut self.gpu_training.optimizer_states[layer_idx],
                step,
                lr,
                beta1,
                beta2,
                1e-8,
                weight_decay,
                stream,
                &self.cuda_grad_workspace,
            );
        }

        // SAFETY: same reasoning as above — `accum` outlives the final synchronize.
        unsafe {
            self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
        }
        let n_lm = self.lm_head_weight_gpu.len() as u32;
        let _ = adamw_step_cuda(
            &mut self.lm_head_weight_gpu,
            &self.lm_head_grad_gpu,
            &mut self.lm_head_m,
            &mut self.lm_head_v,
            lr,
            beta1,
            beta2,
            1e-8,
            weight_decay,
            step,
            n_lm,
            stream,
        );

        // SAFETY: same reasoning as above.
        unsafe {
            self.gpu_training
                .grad_final_norm_weight
                .copy_from_host_async(&accum.final_norm_grad, stream)
                .ok()?;
        }
        let n_norm = self.gpu_training.final_norm_weight.len() as u32;
        let _ = adamw_step_cuda(
            &mut self.gpu_training.final_norm_weight,
            &self.gpu_training.grad_final_norm_weight,
            &mut self.final_norm_m,
            &mut self.final_norm_v,
            lr,
            beta1,
            beta2,
            1e-8,
            weight_decay,
            step,
            n_norm,
            stream,
        );

        // All async copies must finish reading `accum` before it is zeroed.
        stream.synchronize().ok()?;

        accum.zero_all();
        Some(())
    }
2039
2040 fn compute_clip_scale_with_norm(
2053 buf: &GpuBuffer<f32>,
2054 max_norm: f32,
2055 stream: &CudaStream,
2056 ) -> (f32, f32) {
2057 let n = buf.len() as u32;
2058 let grad_norm = match squared_sum_cuda(buf, n, stream) {
2060 Ok(norm) => norm,
2061 Err(_) => {
2062 let mut host = vec![0.0f32; buf.len()];
2064 if buf.copy_to_host_at(&mut host, 0).is_err() {
2065 return (1.0, 0.0);
2066 }
2067 let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2068 sq_sum.sqrt() as f32
2069 }
2070 };
2071 let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2072 (scale, grad_norm)
2073 }
2074
2075 #[allow(unsafe_code)]
2084 fn embed_backward(
2085 &mut self,
2086 input_ids: &[u32],
2087 _seq_len: usize,
2088 hidden_size: usize,
2089 vocab_size: usize,
2090 grad_output_is_a: bool,
2091 ) -> Option<()> {
2092 let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2094 let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2095 let embed_grad_buf = unsafe {
2096 if grad_output_is_a {
2097 &*grad_a_ptr
2098 } else {
2099 &*grad_b_ptr
2100 }
2101 };
2102 let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2103
2104 let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2112 {
2113 let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2114 let grad_norm = sq_sum.sqrt() as f32;
2115 self.last_embed_grad_norm = grad_norm; if grad_norm > embed_clip_norm {
2117 let scale = embed_clip_norm / grad_norm;
2118 for g in &mut embed_grad_data {
2119 *g *= scale;
2120 }
2121 }
2122 }
2123
2124 let embed_weight = &mut self.model.embed_tokens.weight;
2128 let grad_cell = embed_weight.grad_cell();
2129 let mut grad_ref = grad_cell.borrow_mut();
2130 if grad_ref.is_none() {
2131 *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2132 }
2133 if let Some(grad) = grad_ref.as_mut() {
2134 for (pos, &token_id) in input_ids.iter().enumerate() {
2135 let tid = token_id as usize;
2136 if tid < vocab_size {
2137 let src = pos * hidden_size;
2138 let dst = tid * hidden_size;
2139 for h in 0..hidden_size {
2140 grad[dst + h] += embed_grad_data[src + h];
2141 }
2142 }
2143 }
2144 }
2145 Some(())
2146 }
2147
2148 fn optimizer_step(&mut self) {
2154 self.grad_scaler.update(true);
2158
2159 self.embed_optimizer.set_lr(self.current_lr());
2161 let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2163 self.embed_optimizer.step_refs(&mut embed_params);
2164
2165 self.step += 1;
2166 self.metrics.losses.push(self.accumulated_loss);
2167 self.metrics.increment_step();
2168
2169 self.accumulated_loss = 0.0;
2170 self.accumulated_batches = 0;
2171 }
2172
2173 pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2185 if batch.batch_size == 0 {
2186 return 0.0;
2187 }
2188
2189 let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2190
2191 if self.accumulated_batches == 0 {
2192 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2194 }
2195
2196 let mut total_loss = 0.0;
2197 let mut valid_count = 0;
2198
2199 for i in 0..batch.batch_size {
2200 let Some(input_ids) = batch.get_input(i) else {
2201 continue;
2202 };
2203 let Some(target_ids) = batch.get_target(i) else {
2204 continue;
2205 };
2206
2207 if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2211 total_loss += loss;
2212 valid_count += 1;
2213 if accumulating {
2214 if let Some(accum) = &mut self.gpu_grad_accum {
2215 accum.accumulated_count += 1;
2216 } else if let Some(accum) = &mut self.grad_accum {
2217 accum.accumulated_count += 1;
2218 }
2219 }
2220 }
2221 }
2222
2223 let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2224
2225 if avg_loss == 0.0 && valid_count > 0 {
2227 eprintln!(
2228 "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2229 valid_count, total_loss, batch.batch_size
2230 );
2231 }
2232
2233 self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2234 self.accumulated_batches += 1;
2235
2236 if self.accumulated_batches >= self.config.accumulation_steps {
2237 if accumulating {
2238 if self.gpu_grad_accum.is_some() {
2240 self.gpu_optimizer_from_gpu_accum();
2241 } else {
2242 self.gpu_optimizer_from_accum();
2243 }
2244 }
2245 self.optimizer_step();
2246 }
2247
2248 avg_loss
2249 }
2250
2251 pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2255 let hidden_size = self.config.model_config.hidden_size;
2256 let vocab_size = self.config.model_config.vocab_size;
2257 let max_sl = self.config.max_seq_len;
2258 let mut total_loss = 0.0;
2259 let mut valid_count = 0;
2260 for i in 0..batch.batch_size {
2261 if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2262 {
2263 total_loss += loss;
2264 valid_count += 1;
2265 }
2266 }
2267 if valid_count > 0 {
2268 total_loss / valid_count as f32
2269 } else {
2270 0.0
2271 }
2272 }
2273
2274 fn eval_single_sequence(
2276 &mut self,
2277 batch: &LMBatch,
2278 i: usize,
2279 max_sl: usize,
2280 hidden_size: usize,
2281 vocab_size: usize,
2282 ) -> Option<f32> {
2283 let input_ids = batch.get_input(i)?;
2284 let target_ids = batch.get_target(i)?;
2285 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2287 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2288 let seq_len = input_ids.len();
2289 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2290 let stream = self.cuda_trainer.stream();
2291 let scale = 1.0 / seq_len as f32;
2292 let loss = fused_cross_entropy_cuda(
2293 &mut self.gpu_training.logits_buf,
2294 target_ids,
2295 seq_len as u32,
2296 vocab_size as u32,
2297 scale,
2298 stream,
2299 )
2300 .ok()?;
2301 if loss.is_finite() {
2302 Some(loss)
2303 } else {
2304 None
2305 }
2306 }
2307
    /// Trains on `batches` in order and returns the mean batch loss.
    ///
    /// Stops early once `max_steps` is reached. Equivalent to `train_epoch_with_callback`
    /// with a no-op callback.
    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
        self.train_epoch_with_callback(batches, |_, _, _| {})
    }
2312
2313 pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2317 where
2318 F: FnMut(usize, f32, &Self),
2319 {
2320 if batches.is_empty() {
2321 return 0.0;
2322 }
2323
2324 let mut total_loss = 0.0;
2325 let mut batches_processed = 0;
2326
2327 for (i, batch) in batches.iter().enumerate() {
2328 if let Some(max) = self.config.max_steps {
2329 if self.step >= max {
2330 break;
2331 }
2332 }
2333
2334 let batch_loss = self.train_batch(batch);
2335 total_loss += batch_loss;
2336 batches_processed += 1;
2337 on_batch(i, batch_loss, self);
2338 }
2339
2340 if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2342 self.profiler.print_report();
2343 }
2344
2345 total_loss / batches_processed.max(1) as f32
2346 }
2347
2348 pub(crate) fn ensure_grad_accum(&mut self) {
2355 if self.grad_accum.is_some() {
2356 return;
2357 }
2358 let mc = &self.config.model_config;
2359 let hidden_size = mc.hidden_size;
2360 let kv_hidden = mc.num_kv_heads * mc.head_dim();
2361 let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2362 hidden_size,
2363 kv_hidden,
2364 mc.intermediate_size,
2365 );
2366 self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2367 self.cuda_blocks.len(),
2368 block_sizes,
2369 mc.vocab_size,
2370 hidden_size,
2371 ));
2372 }
2373
2374 pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2379 if batch.batch_size == 0 {
2380 return 0.0;
2381 }
2382
2383 if self.accumulated_batches == 0 {
2384 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2385 }
2386
2387 let mut total_loss = 0.0;
2388 let mut valid_count = 0;
2389
2390 for i in 0..batch.batch_size {
2391 let Some(input_ids) = batch.get_input(i) else { continue };
2392 let Some(target_ids) = batch.get_target(i) else { continue };
2393
2394 if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2396 total_loss += loss;
2397 valid_count += 1;
2398 if let Some(accum) = &mut self.grad_accum {
2399 accum.accumulated_count += 1;
2400 }
2401 }
2402 }
2403
2404 if valid_count > 0 {
2405 total_loss / valid_count as f32
2406 } else {
2407 0.0
2408 }
2409 }
2410
2411 pub(crate) fn apply_ddp_gradients(&mut self) {
2417 self.accumulated_loss = 0.0;
2418 self.accumulated_batches = 0;
2419 self.gpu_optimizer_from_accum();
2420 self.optimizer_step();
2421 }
2422
    /// Shared access to the host-side per-block gradient accumulator, if allocated.
    pub(crate) fn grad_accum_ref(
        &self,
    ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
        self.grad_accum.as_ref()
    }

    /// Mutable access to the host-side per-block gradient accumulator, if allocated
    /// (used, for example, to write back reduced gradients).
    pub(crate) fn grad_accum_mut(
        &mut self,
    ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
        self.grad_accum.as_mut()
    }

    /// The trainer's configuration.
    pub(crate) fn config(&self) -> &TransformerTrainConfig {
        &self.config
    }

    /// Copy of the CPU-side embedding weight gradient, or `None` if no gradient is set.
    pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
        self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
    }

    /// Sets the CPU-side embedding weight gradient to `grad` via `set_grad`.
    pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
        self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
    }
2451
    /// `true` once the step counter has reached the configured `max_steps`; always `false`
    /// when no limit is configured.
    pub fn reached_max_steps(&self) -> bool {
        self.config.max_steps.is_some_and(|max| self.step >= max)
    }

    /// Number of completed optimizer steps (incremented in `optimizer_step`).
    pub fn step(&self) -> usize {
        self.step
    }

    /// Sets both the trainer step counter (which drives `current_lr`) and the GPU AdamW step
    /// counter, e.g. when resuming.
    /// NOTE(review): `step as u32` silently truncates values above `u32::MAX`.
    pub fn set_initial_step(&mut self, step: usize) {
        self.step = step;
        self.gpu_training.step = step as u32;
    }

    /// Overrides `max_steps`. This also moves the end of the cosine decay in `current_lr`.
    pub fn set_max_steps(&mut self, max_steps: usize) {
        self.config.max_steps = Some(max_steps);
    }
2480
2481 pub fn current_lr(&self) -> f32 {
2487 let base_lr = self.config.lr;
2488 if self.step < self.config.warmup_steps {
2489 base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2491 } else if let Some(max_steps) = self.config.max_steps {
2492 let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2494 if decay_steps == 0 {
2495 return base_lr;
2496 }
2497 let decay_step = self.step - self.config.warmup_steps;
2498 let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2499 0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2500 } else {
2501 base_lr
2503 }
2504 }
2505
    /// Replaces the step profiler with a new, enabled one.
    /// `interval` is passed straight to `StepProfiler::new` (presumably a reporting period in
    /// steps — TODO confirm).
    pub fn enable_profiler(&mut self, interval: usize) {
        self.profiler = StepProfiler::new(true, interval);
    }

    /// Prints the step profiler's report.
    pub fn print_profiler_report(&self) {
        self.profiler.print_report();
    }

    /// L2 norm of the LM-head weight gradient from the most recent backward pass,
    /// measured before clipping.
    pub fn last_grad_norm(&self) -> f32 {
        self.last_grad_norm
    }

    /// `(lm_head_grad_norm, embedding_grad_norm)` from the most recent backward pass,
    /// both measured before clipping.
    pub fn param_grad_norms(&self) -> (f32, f32) {
        (self.last_grad_norm, self.last_embed_grad_norm)
    }

    /// Total number of scalar parameters in the CPU-side model.
    pub fn num_params(&self) -> usize {
        self.model.parameters().iter().map(|t| t.len()).sum()
    }
2540
2541 pub fn gpu_memory_mb(&self) -> (u64, u64) {
2543 match self.cuda_trainer.context().memory_info() {
2544 Ok((free, total)) => {
2545 let total_mb = (total / (1024 * 1024)) as u64;
2546 let used_mb = ((total - free) / (1024 * 1024)) as u64;
2547 (used_mb, total_mb)
2548 }
2549 Err(_) => (0, 0),
2550 }
2551 }
2552
2553 pub fn sync_weights_to_cpu(&mut self) {
2559 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2560
2561 if use_nf4 {
2562 } else {
2568 for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2569 if let Ok(weights) = block.download_weights() {
2570 let layer = &mut self.model.layers[layer_idx];
2571
2572 layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2573 layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2574 layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2575 layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2576
2577 layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2578 layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2579 layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2580
2581 layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2582 layer.post_attn_norm.weight =
2583 Tensor::from_vec(weights.post_attn_norm_weight, false);
2584 }
2585 }
2586 }
2587
2588 if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2590 self.model.norm.weight = Tensor::from_vec(norm_data, false);
2591 }
2592
2593 if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2600 self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2601 }
2602 }
2603
    /// The CPU-side model. Its weights may lag behind the GPU copies until
    /// `sync_weights_to_cpu` is called.
    pub fn model(&self) -> &Transformer {
        &self.model
    }

    /// Mutable access to the CPU-side model.
    pub fn model_mut(&mut self) -> &mut Transformer {
        &mut self.model
    }

    /// Whether the precision config enables mixed precision.
    pub fn is_mixed_precision(&self) -> bool {
        self.config.precision_config.is_mixed()
    }

    /// The gradient scaler used for mixed-precision training.
    pub fn grad_scaler(&self) -> &GradScaler {
        &self.grad_scaler
    }

    /// Whether activation checkpointing is enabled in the config.
    pub fn is_checkpointing(&self) -> bool {
        self.config.checkpoint_config.enabled
    }
2628
2629 pub fn save(
2631 &mut self,
2632 path: impl AsRef<std::path::Path>,
2633 name: &str,
2634 architecture: &str,
2635 ) -> crate::Result<()> {
2636 self.sync_weights_to_cpu();
2637
2638 let params: Vec<(String, Tensor)> = self
2640 .model
2641 .named_parameters()
2642 .into_iter()
2643 .map(|(name, tensor)| (name, tensor.clone()))
2644 .collect();
2645
2646 let metadata = ModelMetadata::new(name, architecture);
2647 let model = Model::new(metadata, params);
2648 let config = SaveConfig::new(ModelFormat::SafeTensors);
2649
2650 save_model(&model, path, &config)
2651 }
2652
2653 pub fn prepare_async_save(
2657 &mut self,
2658 name: &str,
2659 architecture: &str,
2660 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2661 self.sync_weights_to_cpu();
2662
2663 let param_data: Vec<(String, Vec<f32>)> = self
2665 .model
2666 .named_parameters()
2667 .into_iter()
2668 .map(|(n, t)| (n, t.data().to_vec()))
2669 .collect();
2670
2671 let name = name.to_string();
2672 let architecture = architecture.to_string();
2673
2674 Box::new(move |path: &std::path::Path| {
2675 let params: Vec<(String, Tensor)> =
2676 param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2677 let metadata = ModelMetadata::new(&name, &architecture);
2678 let model = Model::new(metadata, params);
2679 let config = SaveConfig::new(ModelFormat::SafeTensors);
2680 save_model(&model, path, &config)
2681 })
2682 }
2683
2684 pub fn save_apr(
2689 &mut self,
2690 path: impl AsRef<std::path::Path>,
2691 name: &str,
2692 architecture: &str,
2693 ) -> crate::Result<()> {
2694 self.sync_weights_to_cpu();
2695
2696 let params: Vec<(String, Tensor)> = self
2697 .model
2698 .named_parameters()
2699 .into_iter()
2700 .map(|(name, tensor)| (name, tensor.clone()))
2701 .collect();
2702
2703 let metadata = ModelMetadata::new(name, architecture);
2704 let model = Model::new(metadata, params);
2705 let config = SaveConfig::new(ModelFormat::Apr);
2706
2707 save_model(&model, path, &config)
2708 }
2709
2710 fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2717 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2718 if use_nf4 {
2719 let frozen_suffixes = [
2720 "q_proj.weight",
2721 "k_proj.weight",
2722 "v_proj.weight",
2723 "o_proj.weight",
2724 "gate_proj.weight",
2725 "up_proj.weight",
2726 "down_proj.weight",
2727 ];
2728 self.model
2729 .named_parameters()
2730 .into_iter()
2731 .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2732 .map(|(n, t)| (n, t.data().to_vec()))
2733 .collect()
2734 } else {
2735 self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2736 }
2737 }
2738
2739 fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2740 if self.config.quantize_nf4 && self.config.is_lora() {
2741 self.cuda_blocks
2742 .iter()
2743 .enumerate()
2744 .filter_map(|(i, block)| {
2745 block
2746 .download_lora_weights()
2747 .ok()
2748 .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2749 })
2750 .collect()
2751 } else {
2752 Vec::new()
2753 }
2754 }
2755
2756 pub fn prepare_async_apr_save(
2757 &mut self,
2758 name: &str,
2759 architecture: &str,
2760 step: usize,
2761 loss: f64,
2762 lr: f64,
2763 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2764 self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2765 }
2766
2767 pub fn prepare_async_apr_save_with_tokenizer(
2773 &mut self,
2774 name: &str,
2775 architecture: &str,
2776 step: usize,
2777 loss: f64,
2778 lr: f64,
2779 tokenizer_path: Option<&std::path::Path>,
2780 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2781 self.sync_weights_to_cpu();
2782
2783 let param_data = self.snapshot_param_data();
2784 let lora_data = self.snapshot_lora_data();
2785
2786 let embed_m: Vec<Vec<f32>> = self
2788 .embed_optimizer
2789 .first_moments()
2790 .iter()
2791 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2792 .collect();
2793 let embed_v: Vec<Vec<f32>> = self
2794 .embed_optimizer
2795 .second_moments()
2796 .iter()
2797 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2798 .collect();
2799 let embed_step = self.embed_optimizer.step_count();
2800
2801 let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2806 .gpu_training
2807 .optimizer_states
2808 .iter()
2809 .map(|state| state.download_to_host().unwrap_or_default())
2810 .collect();
2811
2812 let lm_head_m_host = {
2814 let mut buf = vec![0.0f32; self.lm_head_m.len()];
2815 let _ = self.lm_head_m.copy_to_host(&mut buf);
2816 buf
2817 };
2818 let lm_head_v_host = {
2819 let mut buf = vec![0.0f32; self.lm_head_v.len()];
2820 let _ = self.lm_head_v.copy_to_host(&mut buf);
2821 buf
2822 };
2823 let final_norm_m_host = {
2824 let mut buf = vec![0.0f32; self.final_norm_m.len()];
2825 let _ = self.final_norm_m.copy_to_host(&mut buf);
2826 buf
2827 };
2828 let final_norm_v_host = {
2829 let mut buf = vec![0.0f32; self.final_norm_v.len()];
2830 let _ = self.final_norm_v.copy_to_host(&mut buf);
2831 buf
2832 };
2833
2834 let name = name.to_string();
2835 let architecture = architecture.to_string();
2836 let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2837 let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2838
2839 let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2842 tokenizer_path.and_then(|p| {
2843 let json_bytes = std::fs::read(p).ok()?;
2844 let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2845 let model = tok.get("model")?;
2846 let vocab_obj = model.get("vocab")?.as_object()?;
2847 let mut vocab_pairs: Vec<(String, u64)> =
2849 vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
2850 vocab_pairs.sort_by_key(|(_, id)| *id);
2851 let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
2852 let merges: Vec<String> = model
2854 .get("merges")?
2855 .as_array()?
2856 .iter()
2857 .filter_map(|v| v.as_str().map(String::from))
2858 .collect();
2859 let added = tok.get("added_tokens").and_then(|a| a.as_array());
2861 let bos_id = added.and_then(|arr| {
2862 arr.iter()
2863 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
2864 .and_then(|t| t.get("id")?.as_u64())
2865 });
2866 let eos_id = added.and_then(|arr| {
2867 arr.iter()
2868 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
2869 .and_then(|t| t.get("id")?.as_u64())
2870 });
2871 if vocab.is_empty() {
2872 return None;
2873 }
2874 println!(
2875 " [ALB-130] Embedding tokenizer: {} vocab, {} merges",
2876 vocab.len(),
2877 merges.len()
2878 );
2879 Some((vocab, merges, bos_id, eos_id))
2880 });
2881
2882 Box::new(move |path: &std::path::Path| {
2883 use aprender::serialization::apr::AprWriter;
2884 use serde_json::Value as Jv;
2885
2886 let mut writer = AprWriter::new();
2887
2888 writer.set_metadata("model_name", Jv::String(name));
2890 writer.set_metadata("architecture", Jv::String(architecture));
2891 writer.set_metadata(
2892 "format",
2893 Jv::String(if is_delta_checkpoint {
2894 "entrenar-delta-checkpoint".into()
2895 } else {
2896 "entrenar-checkpoint".into()
2897 }),
2898 );
2899 writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
2900 writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
2901 writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
2902 writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
2903 if let Some(cfg) = model_config_json {
2904 writer.set_metadata("model_config", Jv::String(cfg));
2905 }
2906
2907 if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
2909 writer.set_metadata(
2910 "tokenizer.vocabulary",
2911 Jv::Array(vocab.into_iter().map(Jv::String).collect()),
2912 );
2913 writer.set_metadata(
2914 "tokenizer.merges",
2915 Jv::Array(merges.into_iter().map(Jv::String).collect()),
2916 );
2917 if let Some(bos) = bos_id {
2918 writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
2919 }
2920 if let Some(eos) = eos_id {
2921 writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
2922 }
2923 }
2924
2925 let hidden_size = param_data
2927 .iter()
2928 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
2929 .map_or(0, |(_, d)| d.len());
2930
2931 for (tensor_name, data) in ¶m_data {
2933 let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
2934 writer.add_tensor_f32(tensor_name.clone(), shape, data);
2935 }
2936
2937 for (i, m_data) in embed_m.iter().enumerate() {
2939 let len = m_data.len();
2940 writer.add_tensor_f32(
2941 format!("__training__.embed_optimizer.m.{i}"),
2942 vec![len],
2943 m_data,
2944 );
2945 }
2946 for (i, v_data) in embed_v.iter().enumerate() {
2947 let len = v_data.len();
2948 writer.add_tensor_f32(
2949 format!("__training__.embed_optimizer.v.{i}"),
2950 vec![len],
2951 v_data,
2952 );
2953 }
2954
2955 for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
2957 for (suffix, data) in buffers {
2958 let len = data.len();
2959 writer.add_tensor_f32(
2960 format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
2961 vec![len],
2962 data,
2963 );
2964 }
2965 }
2966
2967 if !lm_head_m_host.is_empty() {
2969 let len = lm_head_m_host.len();
2970 writer.add_tensor_f32(
2971 "__training__.lm_head_optimizer.m".to_string(),
2972 vec![len],
2973 &lm_head_m_host,
2974 );
2975 let len = lm_head_v_host.len();
2976 writer.add_tensor_f32(
2977 "__training__.lm_head_optimizer.v".to_string(),
2978 vec![len],
2979 &lm_head_v_host,
2980 );
2981 }
2982 if !final_norm_m_host.is_empty() {
2983 let len = final_norm_m_host.len();
2984 writer.add_tensor_f32(
2985 "__training__.final_norm_optimizer.m".to_string(),
2986 vec![len],
2987 &final_norm_m_host,
2988 );
2989 let len = final_norm_v_host.len();
2990 writer.add_tensor_f32(
2991 "__training__.final_norm_optimizer.v".to_string(),
2992 vec![len],
2993 &final_norm_v_host,
2994 );
2995 }
2996
2997 for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
2999 if !a_q.is_empty() {
3000 writer.add_tensor_f32(
3001 format!("lora.{layer_idx}.q_proj.lora_a"),
3002 vec![a_q.len()],
3003 a_q,
3004 );
3005 writer.add_tensor_f32(
3006 format!("lora.{layer_idx}.q_proj.lora_b"),
3007 vec![b_q.len()],
3008 b_q,
3009 );
3010 }
3011 if !a_v.is_empty() {
3012 writer.add_tensor_f32(
3013 format!("lora.{layer_idx}.v_proj.lora_a"),
3014 vec![a_v.len()],
3015 a_v,
3016 );
3017 writer.add_tensor_f32(
3018 format!("lora.{layer_idx}.v_proj.lora_b"),
3019 vec![b_v.len()],
3020 b_v,
3021 );
3022 }
3023 }
3024
3025 writer
3027 .write(path)
3028 .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3029
3030 Ok(())
3031 })
3032 }
3033
    /// Returns the device name reported by `self.cuda_trainer`.
    pub fn gpu_name(&self) -> String {
        self.cuda_trainer.device_name()
    }
3038
3039 pub fn save_cuda_lora_adapter(
3049 &self,
3050 output_dir: &std::path::Path,
3051 base_model_name: Option<&str>,
3052 ) -> crate::Result<()> {
3053 if !self.config.quantize_nf4 || !self.config.is_lora() {
3054 return Ok(()); }
3056
3057 let lora_rank = self.config.lora_rank.unwrap_or(16);
3058 let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3059 let lora_scale = lora_alpha / lora_rank as f32;
3060 let hidden_size = self.config.model_config.hidden_size;
3061 let head_dim = self.config.model_config.head_dim();
3062 let q_dim = self.config.model_config.num_attention_heads * head_dim;
3063 let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3064
3065 let lora_config =
3066 crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3067
3068 let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3069
3070 for (i, block) in self.cuda_blocks.iter().enumerate() {
3071 let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3072 Ok(weights) => weights,
3073 Err(_) => continue, };
3075
3076 if a_q.is_empty() && a_v.is_empty() {
3077 continue;
3078 }
3079
3080 if !a_q.is_empty() {
3082 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3084 for r in 0..hidden_size {
3085 for c in 0..lora_rank {
3086 a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3087 }
3088 }
3089
3090 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3093 let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3094 for r in 0..lora_rank {
3095 for c in 0..q_dim {
3096 b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3097 }
3098 }
3099
3100 let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3101 let mut layer = crate::lora::LoRALayer::new(
3102 base_weight,
3103 q_dim,
3104 hidden_size,
3105 lora_rank,
3106 lora_alpha,
3107 );
3108 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3110 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3111
3112 adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3113 }
3114
3115 if !a_v.is_empty() {
3117 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3118 for r in 0..hidden_size {
3119 for c in 0..lora_rank {
3120 a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3121 }
3122 }
3123
3124 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3125 let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3126 for r in 0..lora_rank {
3127 for c in 0..kv_hidden {
3128 b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3129 }
3130 }
3131
3132 let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3133 let mut layer = crate::lora::LoRALayer::new(
3134 base_weight,
3135 kv_hidden,
3136 hidden_size,
3137 lora_rank,
3138 lora_alpha,
3139 );
3140 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3141 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3142
3143 adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3144 }
3145 }
3146
3147 if adapters.is_empty() {
3148 println!(" [WARN] No LoRA adapters found to save");
3149 return Ok(());
3150 }
3151
3152 let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3153 adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3154
3155 std::fs::create_dir_all(output_dir).ok();
3156 crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3157 .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3158
3159 let adapter_path = output_dir.join("adapter_model.safetensors");
3160 let size_mb =
3161 std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3162 println!(
3163 "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3164 adapters.len(),
3165 size_mb,
3166 output_dir.display()
3167 );
3168
3169 Ok(())
3170 }
3171
3172 pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3177 let path = dir.join("optimizer_state.json");
3178 let m_data: Vec<Option<Vec<f32>>> = self
3179 .embed_optimizer
3180 .first_moments()
3181 .iter()
3182 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3183 .collect();
3184 let v_data: Vec<Option<Vec<f32>>> = self
3185 .embed_optimizer
3186 .second_moments()
3187 .iter()
3188 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3189 .collect();
3190 let state = serde_json::json!({
3191 "type": "adamw_cpu_embed",
3192 "step": self.embed_optimizer.step_count(),
3193 "m": m_data,
3194 "v": v_data,
3195 });
3196 let json_str = serde_json::to_string(&state).map_err(|e| {
3197 crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3198 })?;
3199 std::fs::write(&path, json_str)
3200 .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3201 Ok(())
3202 }
3203
3204 pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3210 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3211 Ok(r) => r,
3212 Err(_) => return (0, self.cuda_blocks.len()),
3213 };
3214
3215 let mut restored = 0usize;
3216 for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3217 let a_q =
3218 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3219 let b_q =
3220 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3221 let a_v =
3222 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3223 let b_v =
3224 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3225
3226 if a_q.is_empty() {
3227 continue; }
3229
3230 if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3231 eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3232 continue;
3233 }
3234 restored += 1;
3235 }
3236
3237 (restored, self.cuda_blocks.len())
3238 }
3239
3240 pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3245 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3246 Ok(r) => r,
3247 Err(_) => return false,
3248 };
3249
3250 if let Some(step_val) = reader.get_metadata("optimizer_step") {
3252 if let Some(step_str) = step_val.as_str() {
3253 if let Ok(step) = step_str.parse::<u64>() {
3254 self.embed_optimizer.set_step_count(step);
3255 }
3256 }
3257 }
3258
3259 for i in 0..128 {
3261 let name = format!("__training__.embed_optimizer.m.{i}");
3262 match reader.read_tensor_f32(&name) {
3263 Ok(data) if !data.is_empty() => {
3264 self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3265 }
3266 _ => break,
3267 }
3268 }
3269
3270 for i in 0..128 {
3272 let name = format!("__training__.embed_optimizer.v.{i}");
3273 match reader.read_tensor_f32(&name) {
3274 Ok(data) if !data.is_empty() => {
3275 self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3276 }
3277 _ => break,
3278 }
3279 }
3280
3281 let suffixes = [
3283 "m.w_q",
3284 "v.w_q",
3285 "m.w_k",
3286 "v.w_k",
3287 "m.w_v",
3288 "v.w_v",
3289 "m.w_o",
3290 "v.w_o",
3291 "m.w_gate",
3292 "v.w_gate",
3293 "m.w_up",
3294 "v.w_up",
3295 "m.w_down",
3296 "v.w_down",
3297 "m.input_norm",
3298 "v.input_norm",
3299 "m.post_attn_norm",
3300 "v.post_attn_norm",
3301 ];
3302 let mut blocks_restored = 0usize;
3303 for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3304 let mut data = std::collections::HashMap::new();
3305 for suffix in &suffixes {
3306 let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3307 if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3308 if !tensor_data.is_empty() {
3309 data.insert(suffix.to_string(), tensor_data);
3310 }
3311 }
3312 }
3313 if !data.is_empty() {
3314 let _ = state.restore_from_host(&data);
3315 blocks_restored += 1;
3316 }
3317 }
3318
3319 if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3321 if m_data.len() == self.lm_head_m.len() {
3322 let _ = self.lm_head_m.copy_from_host(&m_data);
3323 }
3324 }
3325 if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3326 if v_data.len() == self.lm_head_v.len() {
3327 let _ = self.lm_head_v.copy_from_host(&v_data);
3328 }
3329 }
3330
3331 if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3333 if m_data.len() == self.final_norm_m.len() {
3334 let _ = self.final_norm_m.copy_from_host(&m_data);
3335 }
3336 }
3337 if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3338 if v_data.len() == self.final_norm_v.len() {
3339 let _ = self.final_norm_v.copy_from_host(&v_data);
3340 }
3341 }
3342
3343 if blocks_restored > 0 {
3345 println!(
3346 " ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3347 self.gpu_training.optimizer_states.len()
3348 );
3349 } else if !self.gpu_training.optimizer_states.is_empty() {
3350 println!(
3351 " [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3352 self.gpu_training.optimizer_states.len()
3353 );
3354 }
3355
3356 true
3357 }
3358
3359 pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3363 let path = dir.join("optimizer_state.json");
3364 let data = match std::fs::read_to_string(&path) {
3365 Ok(d) => d,
3366 Err(_) => return false,
3367 };
3368 let state: serde_json::Value = match serde_json::from_str(&data) {
3369 Ok(v) => v,
3370 Err(_) => return false,
3371 };
3372 if let Some(step) = state["step"].as_u64() {
3373 self.embed_optimizer.set_step_count(step);
3374 }
3375 restore_moment_buffers(&state["m"], |idx, arr| {
3376 self.embed_optimizer.set_first_moment(idx, arr);
3377 });
3378 restore_moment_buffers(&state["v"], |idx, arr| {
3379 self.embed_optimizer.set_second_moment(idx, arr);
3380 });
3381 true
3382 }
3383}
3384
3385#[cfg(feature = "cuda")]
/// Infers a shape for a flat parameter of `numel` elements from its name, so
/// checkpoint tensors carry their matrix layout.
///
/// * `[numel]` (1-D) in any of these cases:
///   - norm weights (any name ending in `norm.weight`, e.g. `input_layernorm`,
///     `model.norm`, `q_norm`);
///   - biases (names ending in `.bias`);
///   - tensors whose size is not a multiple of `hidden_size`;
///   - any tensor when `hidden_size == 0`.
/// * `[hidden_size, numel / hidden_size]` for `down_proj` and `o_proj` weights.
///   Their output dimension is the hidden size (`[out_features, in_features]`
///   convention).
/// * `[numel / hidden_size, hidden_size]` for every other weight.
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    // Vectors must stay 1-D even when their length happens to equal hidden_size.
    let is_vector = name.ends_with("norm.weight") || name.ends_with(".bias");
    if is_vector || hidden_size == 0 || !numel.is_multiple_of(hidden_size) {
        return vec![numel];
    }
    let other_dim = numel / hidden_size;
    // These projections write into the residual stream: [hidden, in_features].
    // o_proj was previously transposed whenever num_heads * head_dim != hidden.
    if name.ends_with("down_proj.weight") || name.ends_with("o_proj.weight") {
        vec![hidden_size, other_dim]
    } else {
        vec![other_dim, hidden_size]
    }
}
3403
3404#[cfg(feature = "cuda")]
3406fn restore_moment_buffers(
3407 json_arr: &serde_json::Value,
3408 mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
3409) {
3410 let Some(arr) = json_arr.as_array() else { return };
3411 for (idx, val) in arr.iter().enumerate() {
3412 let Some(inner) = val.as_array() else { continue };
3413 let floats: Vec<f32> = inner.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
3414 if !floats.is_empty() {
3415 set_fn(idx, ndarray::Array1::from_vec(floats));
3416 }
3417 }
3418}
3419
3420#[cfg(not(feature = "cuda"))]
3423pub struct CudaTransformerTrainer;
3424
3425#[cfg(not(feature = "cuda"))]
3426impl CudaTransformerTrainer {
3427 pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3428 Err(crate::error::Error::ConfigError(
3429 "CUDA not available (compiled without cuda feature)".into(),
3430 ))
3431 }
3432
3433 pub fn with_model(
3434 _model: crate::transformer::Transformer,
3435 _config: super::config::TransformerTrainConfig,
3436 ) -> crate::Result<Self> {
3437 Err(crate::error::Error::ConfigError(
3438 "CUDA not available (compiled without cuda feature)".into(),
3439 ))
3440 }
3441
3442 pub fn gpu_name(&self) -> String {
3443 unreachable!("CudaTransformerTrainer stub should never be instantiated")
3444 }
3445}
3446
3447#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, constructing the trainer must fail
    /// rather than yield a value.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        assert!(super::CudaTransformerTrainer::new(train_config).is_err());
    }
}