1#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42 gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46 adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47 gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48 squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64 CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65 GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
#[cfg(feature = "cuda")]
/// Computes the global L2 norm over all nine per-block gradient buffers of `ws` and the
/// factor that would bring that norm down to `max_norm`.
///
/// Returns `(scale, grad_norm)`, where `scale` is `max_norm / grad_norm` when
/// `grad_norm > max_norm` and `1.0` otherwise.
///
/// Best effort: empty buffers are skipped, and a buffer whose squared-sum launch or
/// collect fails simply contributes nothing to the norm. If the stream cannot be
/// synchronized the function gives up and returns `(1.0, 0.0)`.
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // Phase 1: enqueue one squared-sum reduction per non-empty buffer.
    let launched: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|&grad| {
            let len = grad.len() as u32;
            if len == 0 {
                None
            } else {
                squared_sum_launch_cuda(grad, len, stream).ok()
            }
        })
        .collect();

    // Phase 2: wait for every reduction before reading any result back.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // Phase 3: accumulate in f64 (explicit 0.0 seed, sequential order).
    let sum_of_squares = launched
        .iter()
        .filter_map(|pending| squared_sum_collect(pending).ok())
        .fold(0.0f64, |acc, sq| acc + f64::from(sq));

    let grad_norm = sum_of_squares.sqrt() as f32;
    let scale = if grad_norm > max_norm {
        max_norm / grad_norm
    } else {
        1.0
    };
    (scale, grad_norm)
}
133
#[cfg(feature = "cuda")]
/// Clips all gradients in `ws` so their global L2 norm is at most `max_norm`.
///
/// Computes the norm with a host synchronization (`compute_workspace_clip_scale_gpu`),
/// then, unless the scale is within `1e-7` of `1.0`, multiplies each of the nine
/// gradient buffers by it. Kernel launch errors are ignored (best effort).
///
/// Returns the pre-clip global gradient norm.
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    let needs_scaling = !((scale - 1.0).abs() < 1e-7);
    if needs_scaling {
        let grads: [&mut GpuBuffer<f32>; 9] = [
            &mut ws.grad_w_q,
            &mut ws.grad_w_k,
            &mut ws.grad_w_v,
            &mut ws.grad_w_o,
            &mut ws.grad_gate,
            &mut ws.grad_up,
            &mut ws.grad_down,
            &mut ws.grad_input_norm,
            &mut ws.grad_post_attn_norm,
        ];
        for grad in grads {
            let len = grad.len() as u32;
            let _ = gradient_clip_cuda(grad, scale, len, stream);
        }
    }

    grad_norm
}
165
#[cfg(feature = "cuda")]
/// Clips all gradients in `ws` against `max_norm` without reading the norm back to the host.
///
/// Everything is enqueued on `stream`:
/// 1. each non-empty gradient buffer writes its squared-sum partials into
///    `state.partials_buf`, starting at element `state.offsets[slot]`;
/// 2. `clip_scale_reduce_cuda` reduces all `state.total_partials` partials together with
///    `max_norm` into `state.scale_buf`;
/// 3. each non-empty buffer is scaled by the value at `state.scale_buf`'s device pointer.
///
/// Kernel launch errors are ignored (best effort).
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    for (slot, grad) in grads.iter().enumerate() {
        let len = grad.len() as u32;
        if len != 0 {
            // Device address of this buffer's first partial: element offset * sizeof(f32).
            let partials_dst = state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * 4;
            let _ = squared_sum_launch_into(grad, len, partials_dst, stream);
        }
    }

    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    let device_scale = state.scale_buf.as_ptr();
    for grad in [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ] {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, device_scale, len, stream);
        }
    }
}
238
#[cfg(feature = "cuda")]
/// Returns the global L2 norm of all gradients in `ws`, without clipping anything.
///
/// Passing `f32::MAX` as the bound means the computed scale is ignored; only the norm
/// component of `compute_workspace_clip_scale_gpu`'s result is returned.
#[allow(dead_code)]
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
#[cfg(feature = "cuda")]
/// Multiplies every gradient buffer in `ws` by `inv_scale` (e.g. the reciprocal of a
/// mixed-precision loss scale).
///
/// `gradient_clip_cuda` is reused as a plain element-wise scale kernel. Nothing is
/// launched when `inv_scale` is within `1e-7` of `1.0`. Launch errors are ignored.
#[allow(dead_code)]
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    let is_identity = (inv_scale - 1.0).abs() < 1e-7;
    if is_identity {
        return;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads {
        let len = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, len, stream);
    }
}
285
#[cfg(feature = "cuda")]
/// GPU-resident activations, gradient scratch and per-block optimizer state used by the
/// full-GPU training path of `CudaTransformerTrainer`.
struct GpuPretrainState {
    /// One `max_seq_len * hidden_size` buffer per layer. `gpu_forward` copies layer `i`'s
    /// input here when `saved_layer_mask[i]` is set; `recompute_segment` refills the others.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Which layers save their input during the forward pass. All `true` without activation
    /// checkpointing; otherwise only the first layer of each segment.
    saved_layer_mask: Vec<bool>,
    /// Scratch copy of a segment's checkpointed input used by `recompute_segment`;
    /// `Some` only when activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Final RMSNorm weight, uploaded from `model.norm.weight`.
    final_norm_weight: GpuBuffer<f32>,
    /// Copy of the last transformer block's output, taken before the final RMSNorm.
    blocks_output: GpuBuffer<f32>,
    /// Gradient scratch (`max_seq_len * hidden_size`); presumably one half of a ping-pong
    /// pair during backward — TODO confirm.
    grad_buf_a: GpuBuffer<f32>,
    /// Gradient scratch (`max_seq_len * hidden_size`); presumably the other ping-pong half.
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient of the final RMSNorm weight (`hidden_size` elements).
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Output of the final RMSNorm; the input of the LM-head GEMM.
    norm_output: GpuBuffer<f32>,
    /// LM-head logits (`max_seq_len * vocab_size`). `fused_cross_entropy_cuda` takes it
    /// mutably and backward then feeds it to the LM-head GEMM gradients, so it presumably
    /// holds dLoss/dLogits after the loss step — TODO confirm.
    logits_buf: GpuBuffer<f32>,
    /// Output of `gemm_backward_a` on the LM head, i.e. the gradient w.r.t. `norm_output`.
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// Per-block optimizer state for FP32 blocks; left empty in NF4 (LoRA) mode.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Step counter; initialised to 0.
    step: u32,
}
326
#[cfg(feature = "cuda")]
/// Transformer trainer that keeps the transformer blocks, final norm and LM head on the GPU.
///
/// The embedding lookup runs on the host-side `model`; its output is uploaded each step.
///
/// NOTE(review): Rust drops fields in declaration order, and several fields own GPU
/// resources — review CUDA context lifetime before reordering fields.
pub struct CudaTransformerTrainer {
    /// Host-side model; `embed_tokens.forward` is run on it in `gpu_forward`.
    model: Transformer,
    /// Owns the CUDA context and the stream used for all work.
    cuda_trainer: CudaTrainer,
    /// One uploaded block per layer (NF4 + LoRA or FP32).
    cuda_blocks: Vec<CudaBlock>,
    /// Per-block gradient workspace (nine weight/norm gradient buffers).
    cuda_grad_workspace: CudaGradWorkspace,
    /// Scratch shared by all NF4 blocks in their forward passes; `Some` only in NF4 mode.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// LoRA gradient workspace; `Some` only in NF4 mode.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// Per-block LoRA optimizer state; `Some` only in NF4 mode.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// GPU activations, gradient scratch and FP32 optimizer state.
    gpu_training: GpuPretrainState,
    /// LM-head weight (`vocab_size * hidden_size`); the embedding matrix when the model
    /// has no separate `lm_head`.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// LM-head weight gradient, written by `gemm_backward_b`.
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// Zero-initialised; presumably the AdamW first moment for the LM head.
    lm_head_m: GpuBuffer<f32>,
    /// Zero-initialised; presumably the AdamW second moment for the LM head.
    lm_head_v: GpuBuffer<f32>,
    /// Zero-initialised; presumably the AdamW first moment for the final norm weight.
    final_norm_m: GpuBuffer<f32>,
    /// Zero-initialised; presumably the AdamW second moment for the final norm weight.
    final_norm_v: GpuBuffer<f32>,
    /// AdamW optimizer for the embedding table.
    embed_optimizer: AdamW,
    /// Training configuration.
    config: TransformerTrainConfig,
    /// Training metrics.
    pub metrics: MetricsTracker,
    /// Step counter; initialised to 0.
    step: usize,
    /// Initialised to 0.0; presumably the loss summed over accumulation micro-batches.
    accumulated_loss: f32,
    /// Initialised to 0; presumably the number of accumulated micro-batches.
    accumulated_batches: usize,
    /// Most recent gradient norm; initialised to 0.0.
    last_grad_norm: f32,
    /// Most recent embedding gradient norm; initialised to 0.0.
    last_embed_grad_norm: f32,
    /// CPU gradient accumulator; `Some` when `accumulation_steps > 1`.
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// GPU gradient accumulator; `Some` when `accumulation_steps > 1` and its allocation
    /// succeeded.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// Loss scaler built from `config.precision_config`.
    grad_scaler: GradScaler,
    /// Per-phase step profiler; enabled when `config.profile_interval > 0`.
    profiler: StepProfiler,
    /// Forward activation ping-pong buffer A; also receives the uploaded embeddings.
    fwd_scratch_a: GpuBuffer<f32>,
    /// Forward activation ping-pong buffer B.
    fwd_scratch_b: GpuBuffer<f32>,
    /// Host staging buffer (`max_seq_len * hidden_size`) for the zero-padded embedding upload.
    h2d_staging: Vec<f32>,
    /// Host staging buffer for device-to-host gradient copies; non-empty only when CPU
    /// gradient accumulation is used as a fallback.
    d2h_staging: Vec<f32>,
    /// Fused clipping state; `Some` when `max_grad_norm` is set and allocation succeeded.
    fused_clip: Option<FusedClipState>,
    /// `hidden_size` zeros; usage not visible here — presumably used to reset a norm gradient.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419 pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426 let model = Transformer::new(&config.model_config);
427 Self::with_model(model, config)
428 }
429
430 pub fn for_inference(
444 checkpoint_dir: impl AsRef<std::path::Path>,
445 model_config: crate::transformer::TransformerConfig,
446 ) -> crate::Result<Self> {
447 let dir = checkpoint_dir.as_ref();
448
449 let model = if let Some((Some(m), _step)) =
451 crate::config::try_load_apr_for_inference(dir, &model_config)
452 {
453 m
454 } else {
455 Transformer::from_safetensors(dir, &model_config)?
456 };
457
458 let mut config = TransformerTrainConfig::new(model_config);
459 config.max_seq_len = config.model_config.max_position_embeddings;
460 Self::with_model(model, config)
461 }
462
    /// Builds a CUDA trainer around an already-constructed `model`.
    ///
    /// In order: checks CUDA availability, creates the `CudaTrainer`, pre-warms forward and
    /// backward kernels for this model geometry, binds cuBLAS to the trainer stream, uploads
    /// every transformer block (NF4 + LoRA when `config.quantize_nf4 && config.is_lora()`,
    /// FP32 otherwise), then allocates all GPU training buffers and optimizer state.
    /// Progress lines are printed to stdout/stderr.
    ///
    /// # Errors
    ///
    /// Returns `Error::ConfigError` if CUDA is unavailable or if any initialisation, kernel
    /// pre-warm, upload, allocation or stream synchronisation step fails.
    ///
    /// # Panics
    ///
    /// Panics if the final norm or LM-head weight tensors are not contiguous.
    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
        if !cuda_training_available() {
            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
        }

        let mc = &config.model_config;
        let max_seq_len = config.max_seq_len;
        let hidden_size = mc.hidden_size;
        let vocab_size = mc.vocab_size;
        let num_layers = mc.num_hidden_layers;

        let cuda_trainer = CudaTrainer::new().map_err(|e| {
            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
        })?;

        println!(
            " GPU: {} ({:.1} GB)",
            cuda_trainer.device_name(),
            cuda_trainer.total_memory() as f64 / 1e9
        );

        let ctx = cuda_trainer.context().clone();
        let stream = cuda_trainer.stream();

        // Prepare forward kernels for this model's dimensions before training starts.
        pre_warm_forward_kernels(
            hidden_size,
            mc.intermediate_size,
            mc.num_attention_heads,
            mc.num_kv_heads,
            mc.head_dim(),
            max_seq_len,
        )
        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;

        {
            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
            let head_dim = mc.head_dim();
            // NOTE(review): the LoRA rank defaults to 0 here but to 16 in the NF4 allocation
            // paths below; confirm the two cannot disagree when NF4 LoRA is active.
            pre_warm_lora_backward_kernels(
                hidden_size,
                mc.num_attention_heads * head_dim,
                mc.num_kv_heads * head_dim,
                max_seq_len,
                config.lora_rank.unwrap_or(0),
                mc.intermediate_size,
                mc.num_attention_heads,
                config.quantize_nf4 && config.is_lora(),
            )
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
            })?;
            eprintln!(" ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
        }

        // Bind cuBLAS to the trainer stream; failure is only a warning.
        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
        }
        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
        }

        let use_nf4 = config.quantize_nf4 && config.is_lora();
        let cuda_blocks = Self::upload_blocks(
            &model,
            mc,
            &config,
            &ctx,
            use_nf4,
            num_layers,
            hidden_size,
            max_seq_len,
        )?;

        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
        })?;

        // Activation-sized buffers hold one full sequence: max_seq_len * hidden_size floats.
        let buf_size = max_seq_len * hidden_size;
        let logits_size = max_seq_len * vocab_size;

        // With checkpointing, only the first layer of each segment saves its input.
        // Without it, every layer saves its input (segment_size is then unused).
        let checkpointing = config.checkpoint_config.enabled;
        let segment_size = if checkpointing {
            let ns = config.checkpoint_config.num_segments.max(1);
            num_layers.div_ceil(ns)
        } else {
            1 };
        let saved_layer_mask: Vec<bool> =
            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();

        // NOTE(review): a buffer is allocated for every layer, including unsaved ones
        // (recompute_segment writes recomputed inputs into them), so checkpointing does
        // not shrink this allocation.
        let mut layer_inputs = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
            })?);
        }

        let recompute_buf = if checkpointing {
            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
            })?)
        } else {
            None
        };

        if checkpointing {
            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
            println!(
                " ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
                config.checkpoint_config.num_segments, saved_count, num_layers
            );
        }

        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
        })?;

        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
        })?;
        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
        })?;
        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
        })?;
        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
        })?;
        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
        })?;
        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
        })?;
        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
        })?;

        // Full per-block optimizer state exists only for FP32 blocks; NF4 blocks get LoRA
        // optimizer state further below.
        let mut optimizer_states = Vec::new();
        if !use_nf4 {
            optimizer_states.reserve(num_layers);
            for (i, block) in cuda_blocks.iter().enumerate() {
                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
                })?);
            }
        }

        let gpu_training = GpuPretrainState {
            layer_inputs,
            saved_layer_mask,
            recompute_buf,
            final_norm_weight,
            blocks_output,
            grad_buf_a,
            grad_buf_b,
            grad_final_norm_weight,
            norm_output,
            logits_buf,
            lm_head_grad_hidden,
            optimizer_states,
            step: 0,
        };

        // Use the separate `lm_head` when present, otherwise the (tied) embedding matrix.
        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
        })?;
        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
        })?;
        // Moment buffers are uploaded from explicit zero-filled host vectors.
        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
            })?;
        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
            })?;

        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
        })?;
        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
        })?;

        // Forward ping-pong buffers (same size as buf_size above).
        let buf_size = max_seq_len * hidden_size;
        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
        })?;
        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
        })?;

        stream
            .synchronize()
            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;

        println!(
            " ✓ GPU training state allocated (LM head: {:.1} MB)",
            (vocab_size * hidden_size * 4) as f64 / 1e6
        );

        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
            let lora_rank = config.lora_rank.unwrap_or(16);

            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
            })?;

            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
                crate::error::Error::ConfigError(format!(
                    "NF4 LoRA grad workspace alloc failed: {e:?}"
                ))
            })?;

            let mut lora_opt_states = Vec::with_capacity(num_layers);
            for (i, block) in cuda_blocks.iter().enumerate() {
                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
                    crate::error::Error::ConfigError(format!(
                        "Block {i} LoRA opt state failed: {e:?}"
                    ))
                })?);
            }

            println!(
                " ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
            );
            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
        } else {
            (None, None, None)
        };

        // Optimizer for the embedding table; epsilon is fixed at 1e-8.
        let embed_optimizer =
            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);

        // NOTE(review): the CPU accumulator is allocated whenever accumulation is enabled,
        // even if the GPU accumulator below succeeds — confirm it is still needed then.
        let grad_accum = if config.accumulation_steps > 1 {
            let kv_hidden = mc.num_kv_heads * mc.head_dim();
            let block_sizes =
                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
                    hidden_size,
                    kv_hidden,
                    mc.intermediate_size,
                );
            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
                num_layers,
                block_sizes,
                vocab_size,
                hidden_size,
            );
            println!(
                " ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
                config.accumulation_steps,
                (accum
                    .block_grads
                    .iter()
                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
                    .sum::<usize>()
                    + accum.lm_head_grad.len()
                    + accum.final_norm_grad.len()
                    + accum.embedding_grad.len()) as f64
                    * 4.0
                    / 1e6,
            );
            Some(accum)
        } else {
            None
        };

        // GPU accumulation is attempted when accumulation is enabled; failure falls back
        // to the CPU path instead of aborting.
        let gpu_grad_accum = if config.accumulation_steps > 1 {
            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
                Ok(accum) => {
                    println!(" ✓ GPU gradient accumulation enabled (ALB-091)");
                    Some(accum)
                }
                Err(e) => {
                    eprintln!(
                        " [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
                    );
                    None
                }
            }
        } else {
            None
        };

        // Host staging for D2H gradient copies, only needed on the CPU accumulation
        // fallback; sized for the larger of an FFN weight gradient and the LM-head gradient.
        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
            let ws_max = hidden_size * mc.intermediate_size;
            let lm_max = vocab_size * hidden_size;
            vec![0.0f32; ws_max.max(lm_max)]
        } else {
            Vec::new()
        };

        let kv_hidden = mc.num_kv_heads * mc.head_dim();
        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);

        let grad_scaler = GradScaler::from_config(&config.precision_config);
        if config.precision_config.is_mixed() {
            println!(
                " ✓ Mixed precision: {} (loss scale={}, dynamic={})",
                config.precision_config.compute_precision,
                grad_scaler.scale(),
                grad_scaler.is_dynamic(),
            );
        }

        // `profiler` is initialised before `config` is moved into the struct.
        Ok(Self {
            model,
            cuda_trainer,
            cuda_blocks,
            cuda_grad_workspace,
            nf4_shared_scratch,
            nf4_lora_grad_workspace,
            nf4_lora_optimizer_states,
            gpu_training,
            lm_head_weight_gpu,
            lm_head_grad_gpu,
            lm_head_m,
            lm_head_v,
            final_norm_m,
            final_norm_v,
            embed_optimizer,
            profiler: if config.profile_interval > 0 {
                StepProfiler::new(true, config.profile_interval)
            } else {
                StepProfiler::disabled()
            },
            config,
            metrics: MetricsTracker::new(),
            step: 0,
            accumulated_loss: 0.0,
            accumulated_batches: 0,
            last_grad_norm: 0.0,
            last_embed_grad_norm: 0.0,
            grad_accum,
            gpu_grad_accum,
            grad_scaler,
            fwd_scratch_a,
            fwd_scratch_b,
            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
            d2h_staging,
            fused_clip,
            final_norm_zero_buf: vec![0.0f32; hidden_size],
        })
    }
857
    /// Uploads every layer of `model` to the GPU as a `CudaBlock`.
    ///
    /// With `use_nf4`, each layer becomes a `CudaNf4TransformerBlock` with LoRA adapters on
    /// the Q and V projections; otherwise a `CudaTransformerBlock` (FP32) with optional
    /// Q/K/V biases.
    ///
    /// # Errors
    ///
    /// Returns `Error::ConfigError` naming the layer index if a block upload fails.
    ///
    /// # Panics
    ///
    /// Panics if any weight, norm or bias tensor is not contiguous.
    #[allow(clippy::too_many_arguments)]
    fn upload_blocks(
        model: &Transformer,
        mc: &crate::transformer::TransformerConfig,
        config: &TransformerTrainConfig,
        ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
        use_nf4: bool,
        num_layers: usize,
        hidden_size: usize,
        max_seq_len: usize,
    ) -> crate::Result<Vec<CudaBlock>> {
        let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);

        if use_nf4 {
            let lora_rank = config.lora_rank.unwrap_or(16);
            // Default alpha is 2 * rank, giving a default LoRA scale of 2.0.
            let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
            let lora_scale = lora_alpha / lora_rank as f32;
            let head_dim = mc.head_dim();
            let q_dim = mc.num_attention_heads * head_dim;
            let kv_hidden = mc.num_kv_heads * head_dim;

            for (i, layer) in model.layers.iter().enumerate() {
                // Deterministic sinusoidal LoRA-A init (amplitude 0.01) with a per-layer phase
                // offset; LoRA-B starts at zero, presumably so the adapter delta is zero at
                // step 0.
                let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
                    .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
                    .collect();
                let lora_b_q = vec![0.0f32; lora_rank * q_dim];
                let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
                    .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
                    .collect();
                let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];

                // Optional per-head Q/K norms, forwarded only when the layer has them.
                let q_norm_data = layer
                    .self_attn
                    .q_norm
                    .as_ref()
                    .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
                let k_norm_data = layer
                    .self_attn
                    .k_norm
                    .as_ref()
                    .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());

                // NOTE(review): Q/K/V biases are not passed on this NF4 path — confirm
                // NF4 models never carry attention biases.
                let block = crate::transformer::CudaNf4TransformerBlock::new(
                    mc,
                    i,
                    ctx.clone(),
                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
                    max_seq_len,
                    Some((&lora_a_q, &lora_b_q)),
                    Some((&lora_a_v, &lora_b_v)),
                    lora_scale,
                    lora_rank,
                    q_norm_data.as_deref(),
                    k_norm_data.as_deref(),
                )
                .map_err(|e| {
                    crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
                })?;
                cuda_blocks.push(CudaBlock::Nf4(block));
            }
            println!(" ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
        } else {
            for (i, layer) in model.layers.iter().enumerate() {
                // Optional attention biases, copied to owned vectors when present.
                let b_q = layer
                    .self_attn
                    .b_q
                    .as_ref()
                    .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
                let b_k = layer
                    .self_attn
                    .b_k
                    .as_ref()
                    .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
                let b_v = layer
                    .self_attn
                    .b_v
                    .as_ref()
                    .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
                // NOTE(review): unlike the NF4 path, q_norm/k_norm are not passed here —
                // confirm the FP32 block handles models with QK-norm.
                let block = CudaTransformerBlock::new(
                    mc,
                    i,
                    ctx.clone(),
                    layer.input_norm.weight.data().as_slice().expect("contiguous"),
                    layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_q.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_k.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_v.data().as_slice().expect("contiguous"),
                    layer.self_attn.w_o.data().as_slice().expect("contiguous"),
                    layer.ffn.w_gate.data().as_slice().expect("contiguous"),
                    layer.ffn.w_up.data().as_slice().expect("contiguous"),
                    layer.ffn.w_down.data().as_slice().expect("contiguous"),
                    max_seq_len,
                    b_q.as_deref(),
                    b_k.as_deref(),
                    b_v.as_deref(),
                )
                .map_err(|e| {
                    crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
                })?;
                cuda_blocks.push(CudaBlock::Fp32(block));
            }
            println!(" ✓ {num_layers} transformer blocks uploaded to GPU");
        }

        Ok(cuda_blocks)
    }
977
978 fn init_fused_clip(
980 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981 config: &TransformerTrainConfig,
982 hidden_size: usize,
983 kv_hidden: usize,
984 mc: &crate::transformer::TransformerConfig,
985 ) -> Option<FusedClipState> {
986 config.base.max_grad_norm?;
987 let grad_sizes: [u32; 9] = [
988 (hidden_size * hidden_size) as u32,
989 (hidden_size * kv_hidden) as u32,
990 (hidden_size * kv_hidden) as u32,
991 (hidden_size * hidden_size) as u32,
992 (hidden_size * mc.intermediate_size) as u32,
993 (hidden_size * mc.intermediate_size) as u32,
994 (mc.intermediate_size * hidden_size) as u32,
995 hidden_size as u32,
996 hidden_size as u32,
997 ];
998 match FusedClipState::new(ctx, &grad_sizes) {
999 Ok(state) => {
1000 println!(
1001 " ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002 state.total_partials,
1003 f64::from(state.total_partials) * 4.0 / 1024.0,
1004 );
1005 Some(state)
1006 }
1007 Err(e) => {
1008 println!(" ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009 None
1010 }
1011 }
1012 }
1013
1014 fn train_step_single(
1023 &mut self,
1024 input_ids: &[u32],
1025 target_ids: &[u32],
1026 accumulate_only: bool,
1027 ) -> Option<f32> {
1028 self.profiler.begin_step();
1029 let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030 self.profiler.finish_step();
1031 result
1032 }
1033
1034 fn train_step_inner(
1036 &mut self,
1037 input_ids: &[u32],
1038 target_ids: &[u32],
1039 accumulate_only: bool,
1040 ) -> Option<f32> {
1041 let hidden_size = self.config.model_config.hidden_size;
1042 let vocab_size = self.config.model_config.vocab_size;
1043
1044 let max_sl = self.config.max_seq_len;
1046 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048 let seq_len = input_ids.len();
1049
1050 if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053 eprintln!(
1054 "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056 );
1057 return None;
1058 }
1059
1060 self.profiler.begin(StepProfiler::LOSS);
1063 let stream = self.cuda_trainer.stream();
1064
1065 let mut loss_scale = 1.0 / seq_len as f32;
1074 if self.config.accumulation_steps > 1 {
1075 loss_scale /= self.config.accumulation_steps as f32;
1076 }
1077
1078 let loss_val = fused_cross_entropy_cuda(
1080 &mut self.gpu_training.logits_buf,
1081 target_ids,
1082 seq_len as u32,
1083 vocab_size as u32,
1084 loss_scale,
1085 stream,
1086 )
1087 .ok()?;
1088
1089 if !loss_val.is_finite() {
1091 return None;
1092 }
1093 self.profiler.end(StepProfiler::LOSS);
1094
1095 if let Some(grad_output_is_a) =
1104 self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105 {
1106 self.profiler.begin(StepProfiler::EMBED_BWD);
1108 self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110 self.profiler.end(StepProfiler::EMBED_BWD);
1111 }
1112
1113 Some(loss_val)
1114 }
1115
1116 #[allow(unsafe_code)]
1121 fn gpu_forward(
1122 &mut self,
1123 input_ids: &[u32],
1124 seq_len: usize,
1125 hidden_size: usize,
1126 vocab_size: usize,
1127 ) -> Option<()> {
1128 contract_pre_gpu_forward!();
1129 let stream = self.cuda_trainer.stream();
1130
1131 self.profiler.begin(StepProfiler::EMBED);
1133 let hidden = self.model.embed_tokens.forward(input_ids);
1134 let hidden_slice = hidden.data().as_slice()?;
1135 self.profiler.end(StepProfiler::EMBED);
1136
1137 self.profiler.begin(StepProfiler::H2D);
1142 self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
1143 self.h2d_staging[hidden_slice.len()..].fill(0.0);
1144 if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
1145 eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
1146 return None;
1147 }
1148 self.profiler.end(StepProfiler::H2D);
1149
1150 self.profiler.begin(StepProfiler::FORWARD);
1154 let mut input_is_a = true; for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
1156 let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
1159 if input_is_a {
1160 (
1161 std::ptr::from_ref(&self.fwd_scratch_a),
1162 std::ptr::from_mut(&mut self.fwd_scratch_b),
1163 )
1164 } else {
1165 (
1166 std::ptr::from_ref(&self.fwd_scratch_b),
1167 std::ptr::from_mut(&mut self.fwd_scratch_a),
1168 )
1169 };
1170 if self.gpu_training.saved_layer_mask[i] {
1171 unsafe {
1174 self.gpu_training.layer_inputs[i]
1175 .copy_from_buffer_async(&*input_ptr, stream)
1176 .ok()?;
1177 }
1178 }
1179 self.profiler.begin_layer();
1182 unsafe {
1183 block
1184 .forward(
1185 &*input_ptr,
1186 &mut *output_ptr,
1187 seq_len,
1188 stream,
1189 self.nf4_shared_scratch.as_mut(),
1190 )
1191 .ok()?;
1192 }
1193 self.profiler.end_layer_fwd(i);
1194 input_is_a = !input_is_a;
1195 }
1196 self.profiler.end(StepProfiler::FORWARD);
1197
1198 let final_output: &GpuBuffer<f32> =
1200 if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };
1201
1202 self.profiler.begin(StepProfiler::NORM_LM);
1205 unsafe {
1206 self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
1207 }
1208
1209 rms_norm_forward_with_eps(
1214 final_output,
1215 &self.gpu_training.final_norm_weight,
1216 &mut self.gpu_training.norm_output,
1217 seq_len as u32,
1218 hidden_size as u32,
1219 self.config.model_config.rms_norm_eps,
1220 stream,
1221 )
1222 .ok()?;
1223
1224 gemm_forward(
1228 &self.gpu_training.norm_output,
1229 &self.lm_head_weight_gpu,
1230 &mut self.gpu_training.logits_buf,
1231 seq_len as u32,
1232 hidden_size as u32,
1233 vocab_size as u32,
1234 stream,
1235 )
1236 .ok()?;
1237
1238 self.profiler.end(StepProfiler::NORM_LM);
1241
1242 Some(())
1243 }
1244
1245 pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1257 let seq_len = input_ids.len();
1258 let hidden_size = self.config.model_config.hidden_size;
1259 let vocab_size = self.config.model_config.vocab_size;
1260
1261 if seq_len == 0 || seq_len > self.config.max_seq_len {
1262 return None;
1263 }
1264
1265 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1267
1268 let stream = self.cuda_trainer.stream();
1270 stream.synchronize().ok()?;
1271
1272 let offset = (seq_len - 1) * vocab_size;
1274 let mut logits = vec![0.0f32; vocab_size];
1275 self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1276
1277 Some(logits)
1278 }
1279
1280 #[allow(unsafe_code)]
1297 fn recompute_segment(
1298 gpu_training: &mut GpuPretrainState,
1299 cuda_blocks: &mut [CudaBlock],
1300 nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1301 target_layer: usize,
1302 seq_len: usize,
1303 stream: &CudaStream,
1304 ) -> Option<()> {
1305 let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1307
1308 if seg_start == target_layer {
1309 return Some(()); }
1311
1312 let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1315 unsafe {
1316 recompute_buf
1317 .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1318 .ok()?;
1319 }
1320
1321 for i in seg_start..target_layer {
1332 if i == seg_start {
1333 let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1335 let li = &mut gpu_training.layer_inputs;
1336 unsafe {
1337 cuda_blocks[i]
1338 .forward(
1339 &*recompute_ptr,
1340 &mut li[i + 1],
1341 seq_len,
1342 stream,
1343 nf4_shared_scratch.as_mut(),
1344 )
1345 .ok()?;
1346 }
1347 } else {
1348 let li = &mut gpu_training.layer_inputs;
1350 let (left, right) = li.split_at_mut(i + 1);
1351 cuda_blocks[i]
1352 .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1353 .ok()?;
1354 }
1355 }
1356
1357 Some(())
1358 }
1359
1360 #[allow(unsafe_code)]
1373 fn gpu_backward(
1374 &mut self,
1375 seq_len: usize,
1376 hidden_size: usize,
1377 vocab_size: usize,
1378 accumulate_only: bool,
1379 ) -> Option<bool> {
1380 let stream = self.cuda_trainer.stream();
1381 let max_grad_norm = self.config.base.max_grad_norm;
1382 let lr = self.current_lr();
1383 let beta1 = self.config.beta1;
1385 let beta2 = self.config.beta2;
1386 let weight_decay = self.config.weight_decay;
1387
1388 self.profiler.begin(StepProfiler::LM_BWD);
1393 gemm_backward_a(
1394 &self.gpu_training.logits_buf,
1395 &self.lm_head_weight_gpu,
1396 &mut self.gpu_training.lm_head_grad_hidden,
1397 seq_len as u32,
1398 hidden_size as u32,
1399 vocab_size as u32,
1400 stream,
1401 )
1402 .ok()?;
1403
1404 gemm_backward_b(
1405 &self.gpu_training.norm_output,
1406 &self.gpu_training.logits_buf,
1407 &mut self.lm_head_grad_gpu,
1408 seq_len as u32,
1409 hidden_size as u32,
1410 vocab_size as u32,
1411 stream,
1412 )
1413 .ok()?;
1414
1415 let lm_sq_norm =
1421 squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1422 .unwrap_or(0.0);
1423 let lm_norm = lm_sq_norm.sqrt(); self.last_grad_norm = lm_norm; if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1427 eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1428 let gh_sq = squared_sum_cuda(
1430 &self.gpu_training.lm_head_grad_hidden,
1431 self.gpu_training.lm_head_grad_hidden.len() as u32,
1432 stream,
1433 )
1434 .unwrap_or(0.0);
1435 eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1436 }
1437 if let Some(max_norm) = max_grad_norm {
1438 let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1439 let n = self.lm_head_grad_gpu.len() as u32;
1440 let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1441 }
1442 self.profiler.end(StepProfiler::LM_BWD);
1443
1444 self.profiler.begin(StepProfiler::NORM_BWD);
1446 self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1448 rms_norm_backward(
1449 &self.gpu_training.blocks_output,
1450 &self.gpu_training.final_norm_weight,
1451 &self.gpu_training.lm_head_grad_hidden,
1452 &mut self.gpu_training.grad_buf_a,
1453 &mut self.gpu_training.grad_final_norm_weight,
1454 seq_len as u32,
1455 hidden_size as u32,
1456 1e-5_f32,
1457 stream,
1458 )
1459 .ok()?;
1460
1461 if let Some(max_norm) = max_grad_norm {
1464 let (scale, _) = Self::compute_clip_scale_with_norm(
1465 &self.gpu_training.grad_final_norm_weight,
1466 max_norm,
1467 stream,
1468 );
1469 let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1470 let _ =
1471 gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1472 }
1473 self.profiler.end(StepProfiler::NORM_BWD);
1474
1475 if accumulate_only {
1477 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1479 let _ = gpu_accum.accumulate_nonblock(
1480 &self.lm_head_grad_gpu,
1481 &self.gpu_training.grad_final_norm_weight,
1482 stream,
1483 );
1484 } else {
1485 stream.synchronize().ok()?;
1486 Self::download_nonblock_grads_to_accum(
1487 &self.lm_head_grad_gpu,
1488 &self.gpu_training.grad_final_norm_weight,
1489 &mut self.grad_accum,
1490 &mut self.d2h_staging,
1491 )?;
1492 }
1493 } else {
1494 Self::run_nonblock_optimizer_step(
1495 &mut self.gpu_training,
1496 Some(&mut self.lm_head_weight_gpu),
1497 &self.lm_head_grad_gpu,
1498 &mut self.lm_head_m,
1499 &mut self.lm_head_v,
1500 &mut self.final_norm_m,
1501 &mut self.final_norm_v,
1502 lr,
1503 beta1,
1504 beta2,
1505 weight_decay,
1506 stream,
1507 );
1508 }
1509
1510 self.profiler.begin(StepProfiler::BLK_BWD);
1516 let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1517 let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1518 let mut grad_output_is_a = true;
1519 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1520
1521 for layer_idx in (0..self.cuda_blocks.len()).rev() {
1522 if !self.gpu_training.saved_layer_mask[layer_idx] {
1525 Self::recompute_segment(
1526 &mut self.gpu_training,
1527 &mut self.cuda_blocks,
1528 &mut self.nf4_shared_scratch,
1529 layer_idx,
1530 seq_len,
1531 stream,
1532 )?;
1533 }
1534
1535 let (grad_output, grad_input) = unsafe {
1536 if grad_output_is_a {
1537 (&*grad_a_ptr, &mut *grad_b_ptr)
1538 } else {
1539 (&*grad_b_ptr, &mut *grad_a_ptr)
1540 }
1541 };
1542
1543 self.profiler.begin_layer();
1544 if use_nf4 {
1545 let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1549 grad_b_ptr } else {
1551 grad_a_ptr
1552 };
1553 match self.cuda_blocks[layer_idx].backward_nf4(
1556 &self.gpu_training.layer_inputs[layer_idx],
1557 grad_output,
1558 grad_input,
1559 &mut self.gpu_training.blocks_output, seq_len,
1561 stream,
1562 self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1563 self.nf4_lora_grad_workspace
1564 .as_mut()
1565 .expect("NF4 requires LoRA grad workspace"),
1566 ) {
1567 Ok(()) => {}
1568 Err(e) => {
1569 eprintln!(
1570 "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1571 layer_idx, e, seq_len, self.config.model_config.hidden_size
1572 );
1573 return None;
1574 }
1575 }
1576
1577 if let Some(max_norm) = max_grad_norm {
1581 self.nf4_lora_grad_workspace
1582 .as_mut()
1583 .expect("NF4 requires LoRA grad ws")
1584 .clip_gradients(max_norm, stream);
1585 }
1586
1587 {
1593 let step = self.gpu_training.step;
1594 let effective_lr = if accumulate_only {
1595 lr / self.config.accumulation_steps as f32
1596 } else {
1597 lr
1598 };
1599 if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1600 let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1601 &mut opt_states[layer_idx],
1602 step,
1603 effective_lr,
1604 beta1,
1605 beta2,
1606 1e-8,
1607 weight_decay,
1608 stream,
1609 self.nf4_lora_grad_workspace
1610 .as_ref()
1611 .expect("NF4 requires LoRA grad ws"),
1612 );
1613 }
1614 }
1615 } else {
1616 self.cuda_blocks[layer_idx]
1618 .backward(
1619 &self.gpu_training.layer_inputs[layer_idx],
1620 grad_output,
1621 grad_input,
1622 seq_len,
1623 stream,
1624 &mut self.cuda_grad_workspace,
1625 )
1626 .ok()?;
1627
1628 if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1634 let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1635 &self.cuda_grad_workspace,
1636 f32::MAX,
1637 stream,
1638 );
1639 let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1641 .unwrap_or(0.0);
1642 let act_gnorm = act_sq.sqrt();
1643 eprintln!(
1644 "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1645 );
1646 }
1647
1648 if accumulate_only {
1650 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1652 let _ = gpu_accum.accumulate_block(
1653 &self.cuda_grad_workspace,
1654 layer_idx,
1655 stream,
1656 );
1657 } else {
1658 stream.synchronize().ok()?;
1660 if let Some(accum) = &mut self.grad_accum {
1661 Self::download_workspace_to_accum(
1662 &self.cuda_grad_workspace,
1663 accum,
1664 layer_idx,
1665 &mut self.d2h_staging,
1666 )?;
1667 }
1668 }
1669 } else {
1670 let step = self.gpu_training.step;
1672 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1673 &mut self.gpu_training.optimizer_states[layer_idx],
1674 step,
1675 lr,
1676 beta1,
1677 beta2,
1678 1e-8,
1679 weight_decay,
1680 stream,
1681 &self.cuda_grad_workspace,
1682 );
1683 }
1684 }
1685
1686 self.profiler.end_layer_bwd(layer_idx);
1687 grad_output_is_a = !grad_output_is_a;
1688 }
1689
1690 stream.synchronize().ok()?;
1691 self.profiler.end(StepProfiler::BLK_BWD);
1692
1693 Some(grad_output_is_a)
1694 }
1695
1696 fn download_nonblock_grads_to_accum(
1702 lm_head_grad: &GpuBuffer<f32>,
1703 final_norm_grad: &GpuBuffer<f32>,
1704 grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1705 host: &mut [f32],
1706 ) -> Option<()> {
1707 let accum = grad_accum.as_mut()?;
1708
1709 let lm_slice = &mut host[..lm_head_grad.len()];
1710 lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1711 for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1712 *d += s;
1713 }
1714
1715 let norm_slice = &mut host[..final_norm_grad.len()];
1716 final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1717 for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1718 *d += s;
1719 }
1720 Some(())
1721 }
1722
1723 #[allow(clippy::too_many_arguments)]
1726 fn run_nonblock_optimizer_step(
1727 gpu_training: &mut GpuPretrainState,
1728 lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1729 lm_head_grad_gpu: &GpuBuffer<f32>,
1730 lm_head_m: &mut GpuBuffer<f32>,
1731 lm_head_v: &mut GpuBuffer<f32>,
1732 final_norm_m: &mut GpuBuffer<f32>,
1733 final_norm_v: &mut GpuBuffer<f32>,
1734 lr: f32,
1735 beta1: f32,
1736 beta2: f32,
1737 weight_decay: f32,
1738 stream: &CudaStream,
1739 ) {
1740 gpu_training.step += 1;
1741 let step = gpu_training.step;
1742
1743 if let Some(lm_head_weight) = lm_head_weight_gpu {
1744 let n_lm = lm_head_weight.len() as u32;
1745 let _ = adamw_step_cuda(
1746 lm_head_weight,
1747 lm_head_grad_gpu,
1748 lm_head_m,
1749 lm_head_v,
1750 lr,
1751 beta1,
1752 beta2,
1753 1e-8,
1754 weight_decay,
1755 step,
1756 n_lm,
1757 stream,
1758 );
1759 }
1760
1761 let n_norm = gpu_training.final_norm_weight.len() as u32;
1762 let _ = adamw_step_cuda(
1763 &mut gpu_training.final_norm_weight,
1764 &gpu_training.grad_final_norm_weight,
1765 final_norm_m,
1766 final_norm_v,
1767 lr,
1768 beta1,
1769 beta2,
1770 1e-8,
1771 weight_decay,
1772 step,
1773 n_norm,
1774 stream,
1775 );
1776 }
1777
1778 fn download_workspace_to_accum(
1786 ws: &CudaGradWorkspace,
1787 accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1788 layer_idx: usize,
1789 host: &mut [f32],
1790 ) -> Option<()> {
1791 let bg = &mut accum.block_grads[layer_idx];
1792
1793 use super::grad_accumulator::component;
1794 let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1795 (&ws.grad_w_q, component::W_Q),
1796 (&ws.grad_w_k, component::W_K),
1797 (&ws.grad_w_v, component::W_V),
1798 (&ws.grad_w_o, component::W_O),
1799 (&ws.grad_gate, component::GATE),
1800 (&ws.grad_up, component::UP),
1801 (&ws.grad_down, component::DOWN),
1802 (&ws.grad_input_norm, component::INPUT_NORM),
1803 (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1804 ];
1805
1806 for (gpu_buf, comp_idx) in &bufs_and_components {
1807 let slice = &mut host[..gpu_buf.len()];
1808 gpu_buf.copy_to_host_at(slice, 0).ok()?;
1809 for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1810 *d += s;
1811 }
1812 }
1813 Some(())
1814 }
1815
1816 fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1823 let stream = self.cuda_trainer.stream();
1824 let lr = self.current_lr();
1825 let beta1 = self.config.beta1;
1826 let beta2 = self.config.beta2;
1827 let weight_decay = self.config.weight_decay;
1828
1829 stream.synchronize().ok()?;
1831
1832 self.gpu_training.step += 1;
1833 let step = self.gpu_training.step;
1834
1835 let gpu_accum = self.gpu_grad_accum.as_ref()?;
1837 for layer_idx in 0..self.cuda_blocks.len() {
1838 gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1839
1840 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1841 &mut self.gpu_training.optimizer_states[layer_idx],
1842 step,
1843 lr,
1844 beta1,
1845 beta2,
1846 1e-8,
1847 weight_decay,
1848 stream,
1849 &self.cuda_grad_workspace,
1850 );
1851 }
1852
1853 gpu_accum
1855 .upload_nonblock(
1856 &mut self.lm_head_grad_gpu,
1857 &mut self.gpu_training.grad_final_norm_weight,
1858 )
1859 .ok()?;
1860
1861 let n_lm = self.lm_head_weight_gpu.len() as u32;
1862 let _ = adamw_step_cuda(
1863 &mut self.lm_head_weight_gpu,
1864 &self.lm_head_grad_gpu,
1865 &mut self.lm_head_m,
1866 &mut self.lm_head_v,
1867 lr,
1868 beta1,
1869 beta2,
1870 1e-8,
1871 weight_decay,
1872 step,
1873 n_lm,
1874 stream,
1875 );
1876
1877 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1879 let _ = adamw_step_cuda(
1880 &mut self.gpu_training.final_norm_weight,
1881 &self.gpu_training.grad_final_norm_weight,
1882 &mut self.final_norm_m,
1883 &mut self.final_norm_v,
1884 lr,
1885 beta1,
1886 beta2,
1887 1e-8,
1888 weight_decay,
1889 step,
1890 n_norm,
1891 stream,
1892 );
1893
1894 stream.synchronize().ok()?;
1895
1896 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1898 let _ = gpu_accum.zero_all();
1899 }
1900
1901 Some(())
1902 }
1903
1904 #[allow(unsafe_code)]
1905 fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1906 let stream = self.cuda_trainer.stream();
1907 let lr = self.current_lr();
1908 let beta1 = self.config.beta1;
1909 let beta2 = self.config.beta2;
1910 let weight_decay = self.config.weight_decay;
1911
1912 let accum = self.grad_accum.as_mut()?;
1914 accum.average();
1915
1916 if accum.has_non_finite() {
1918 println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1919 accum.zero_all();
1920 return Some(());
1921 }
1922
1923 self.gpu_training.step += 1;
1924 let step = self.gpu_training.step;
1925
1926 use super::grad_accumulator::component;
1928 for layer_idx in 0..self.cuda_blocks.len() {
1929 let bg = &accum.block_grads[layer_idx];
1930
1931 unsafe {
1935 self.cuda_grad_workspace
1936 .grad_w_q
1937 .copy_from_host_async(&bg.components[component::W_Q], stream)
1938 .ok()?;
1939 self.cuda_grad_workspace
1940 .grad_w_k
1941 .copy_from_host_async(&bg.components[component::W_K], stream)
1942 .ok()?;
1943 self.cuda_grad_workspace
1944 .grad_w_v
1945 .copy_from_host_async(&bg.components[component::W_V], stream)
1946 .ok()?;
1947 self.cuda_grad_workspace
1948 .grad_w_o
1949 .copy_from_host_async(&bg.components[component::W_O], stream)
1950 .ok()?;
1951 self.cuda_grad_workspace
1952 .grad_gate
1953 .copy_from_host_async(&bg.components[component::GATE], stream)
1954 .ok()?;
1955 self.cuda_grad_workspace
1956 .grad_up
1957 .copy_from_host_async(&bg.components[component::UP], stream)
1958 .ok()?;
1959 self.cuda_grad_workspace
1960 .grad_down
1961 .copy_from_host_async(&bg.components[component::DOWN], stream)
1962 .ok()?;
1963 self.cuda_grad_workspace
1964 .grad_input_norm
1965 .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
1966 .ok()?;
1967 self.cuda_grad_workspace
1968 .grad_post_attn_norm
1969 .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
1970 .ok()?;
1971 }
1972
1973 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1975 &mut self.gpu_training.optimizer_states[layer_idx],
1976 step,
1977 lr,
1978 beta1,
1979 beta2,
1980 1e-8,
1981 weight_decay,
1982 stream,
1983 &self.cuda_grad_workspace,
1984 );
1985 }
1986
1987 unsafe {
1991 self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
1992 }
1993 let n_lm = self.lm_head_weight_gpu.len() as u32;
1994 let _ = adamw_step_cuda(
1995 &mut self.lm_head_weight_gpu,
1996 &self.lm_head_grad_gpu,
1997 &mut self.lm_head_m,
1998 &mut self.lm_head_v,
1999 lr,
2000 beta1,
2001 beta2,
2002 1e-8,
2003 weight_decay,
2004 step,
2005 n_lm,
2006 stream,
2007 );
2008
2009 unsafe {
2012 self.gpu_training
2013 .grad_final_norm_weight
2014 .copy_from_host_async(&accum.final_norm_grad, stream)
2015 .ok()?;
2016 }
2017 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
2018 let _ = adamw_step_cuda(
2019 &mut self.gpu_training.final_norm_weight,
2020 &self.gpu_training.grad_final_norm_weight,
2021 &mut self.final_norm_m,
2022 &mut self.final_norm_v,
2023 lr,
2024 beta1,
2025 beta2,
2026 1e-8,
2027 weight_decay,
2028 step,
2029 n_norm,
2030 stream,
2031 );
2032
2033 stream.synchronize().ok()?;
2034
2035 accum.zero_all();
2037 Some(())
2038 }
2039
2040 fn compute_clip_scale_with_norm(
2053 buf: &GpuBuffer<f32>,
2054 max_norm: f32,
2055 stream: &CudaStream,
2056 ) -> (f32, f32) {
2057 let n = buf.len() as u32;
2058 let grad_norm = match squared_sum_cuda(buf, n, stream) {
2060 Ok(norm) => norm,
2061 Err(_) => {
2062 let mut host = vec![0.0f32; buf.len()];
2064 if buf.copy_to_host_at(&mut host, 0).is_err() {
2065 return (1.0, 0.0);
2066 }
2067 let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2068 sq_sum.sqrt() as f32
2069 }
2070 };
2071 let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2072 (scale, grad_norm)
2073 }
2074
2075 #[allow(unsafe_code)]
2084 fn embed_backward(
2085 &mut self,
2086 input_ids: &[u32],
2087 _seq_len: usize,
2088 hidden_size: usize,
2089 vocab_size: usize,
2090 grad_output_is_a: bool,
2091 ) -> Option<()> {
2092 let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2094 let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2095 let embed_grad_buf = unsafe {
2096 if grad_output_is_a {
2097 &*grad_a_ptr
2098 } else {
2099 &*grad_b_ptr
2100 }
2101 };
2102 let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2103
2104 let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2112 {
2113 let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2114 let grad_norm = sq_sum.sqrt() as f32;
2115 self.last_embed_grad_norm = grad_norm; if grad_norm > embed_clip_norm {
2117 let scale = embed_clip_norm / grad_norm;
2118 for g in &mut embed_grad_data {
2119 *g *= scale;
2120 }
2121 }
2122 }
2123
2124 let embed_weight = &mut self.model.embed_tokens.weight;
2128 let grad_cell = embed_weight.grad_cell();
2129 let mut grad_ref = grad_cell.borrow_mut();
2130 if grad_ref.is_none() {
2131 *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2132 }
2133 if let Some(grad) = grad_ref.as_mut() {
2134 for (pos, &token_id) in input_ids.iter().enumerate() {
2135 let tid = token_id as usize;
2136 if tid < vocab_size {
2137 let src = pos * hidden_size;
2138 let dst = tid * hidden_size;
2139 for h in 0..hidden_size {
2140 grad[dst + h] += embed_grad_data[src + h];
2141 }
2142 }
2143 }
2144 }
2145 Some(())
2146 }
2147
2148 fn optimizer_step(&mut self) {
2154 self.grad_scaler.update(true);
2158
2159 self.embed_optimizer.set_lr(self.current_lr());
2161 let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2163 self.embed_optimizer.step_refs(&mut embed_params);
2164
2165 self.step += 1;
2166 self.metrics.losses.push(self.accumulated_loss);
2167 self.metrics.increment_step();
2168
2169 self.accumulated_loss = 0.0;
2170 self.accumulated_batches = 0;
2171 }
2172
2173 pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2185 if batch.batch_size == 0 {
2186 return 0.0;
2187 }
2188
2189 let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2190
2191 if self.accumulated_batches == 0 {
2192 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2194 }
2195
2196 let mut total_loss = 0.0;
2197 let mut valid_count = 0;
2198
2199 for i in 0..batch.batch_size {
2200 let Some(input_ids) = batch.get_input(i) else {
2201 continue;
2202 };
2203 let Some(target_ids) = batch.get_target(i) else {
2204 continue;
2205 };
2206
2207 if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2211 total_loss += loss;
2212 valid_count += 1;
2213 if accumulating {
2214 if let Some(accum) = &mut self.gpu_grad_accum {
2215 accum.accumulated_count += 1;
2216 } else if let Some(accum) = &mut self.grad_accum {
2217 accum.accumulated_count += 1;
2218 }
2219 }
2220 }
2221 }
2222
2223 let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2224
2225 if avg_loss == 0.0 && valid_count > 0 {
2227 eprintln!(
2228 "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2229 valid_count, total_loss, batch.batch_size
2230 );
2231 }
2232
2233 self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2234 self.accumulated_batches += 1;
2235
2236 if self.accumulated_batches >= self.config.accumulation_steps {
2237 if accumulating {
2238 if self.gpu_grad_accum.is_some() {
2240 self.gpu_optimizer_from_gpu_accum();
2241 } else {
2242 self.gpu_optimizer_from_accum();
2243 }
2244 }
2245 self.optimizer_step();
2246 }
2247
2248 avg_loss
2249 }
2250
2251 pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2255 let hidden_size = self.config.model_config.hidden_size;
2256 let vocab_size = self.config.model_config.vocab_size;
2257 let max_sl = self.config.max_seq_len;
2258 let mut total_loss = 0.0;
2259 let mut valid_count = 0;
2260 for i in 0..batch.batch_size {
2261 if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2262 {
2263 total_loss += loss;
2264 valid_count += 1;
2265 }
2266 }
2267 if valid_count > 0 {
2268 total_loss / valid_count as f32
2269 } else {
2270 0.0
2271 }
2272 }
2273
2274 fn eval_single_sequence(
2276 &mut self,
2277 batch: &LMBatch,
2278 i: usize,
2279 max_sl: usize,
2280 hidden_size: usize,
2281 vocab_size: usize,
2282 ) -> Option<f32> {
2283 let input_ids = batch.get_input(i)?;
2284 let target_ids = batch.get_target(i)?;
2285 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2287 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2288 let seq_len = input_ids.len();
2289 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2290 let stream = self.cuda_trainer.stream();
2291 let scale = 1.0 / seq_len as f32;
2292 let loss = fused_cross_entropy_cuda(
2293 &mut self.gpu_training.logits_buf,
2294 target_ids,
2295 seq_len as u32,
2296 vocab_size as u32,
2297 scale,
2298 stream,
2299 )
2300 .ok()?;
2301 if loss.is_finite() {
2302 Some(loss)
2303 } else {
2304 None
2305 }
2306 }
2307
2308 pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2310 self.train_epoch_with_callback(batches, |_, _, _| {})
2311 }
2312
2313 pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2317 where
2318 F: FnMut(usize, f32, &Self),
2319 {
2320 if batches.is_empty() {
2321 return 0.0;
2322 }
2323
2324 let mut total_loss = 0.0;
2325 let mut batches_processed = 0;
2326
2327 for (i, batch) in batches.iter().enumerate() {
2328 if let Some(max) = self.config.max_steps {
2329 if self.step >= max {
2330 break;
2331 }
2332 }
2333
2334 let batch_loss = self.train_batch(batch);
2335 total_loss += batch_loss;
2336 batches_processed += 1;
2337 on_batch(i, batch_loss, self);
2338 }
2339
2340 if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2342 self.profiler.print_report();
2343 }
2344
2345 total_loss / batches_processed.max(1) as f32
2346 }
2347
2348 pub(crate) fn ensure_grad_accum(&mut self) {
2355 if self.grad_accum.is_some() {
2356 return;
2357 }
2358 let mc = &self.config.model_config;
2359 let hidden_size = mc.hidden_size;
2360 let kv_hidden = mc.num_kv_heads * mc.head_dim();
2361 let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2362 hidden_size,
2363 kv_hidden,
2364 mc.intermediate_size,
2365 );
2366 self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2367 self.cuda_blocks.len(),
2368 block_sizes,
2369 mc.vocab_size,
2370 hidden_size,
2371 ));
2372 }
2373
2374 pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2379 if batch.batch_size == 0 {
2380 return 0.0;
2381 }
2382
2383 if self.accumulated_batches == 0 {
2384 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2385 }
2386
2387 let mut total_loss = 0.0;
2388 let mut valid_count = 0;
2389
2390 for i in 0..batch.batch_size {
2391 let Some(input_ids) = batch.get_input(i) else { continue };
2392 let Some(target_ids) = batch.get_target(i) else { continue };
2393
2394 if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2396 total_loss += loss;
2397 valid_count += 1;
2398 if let Some(accum) = &mut self.grad_accum {
2399 accum.accumulated_count += 1;
2400 }
2401 }
2402 }
2403
2404 if valid_count > 0 {
2405 total_loss / valid_count as f32
2406 } else {
2407 0.0
2408 }
2409 }
2410
2411 pub(crate) fn apply_ddp_gradients(&mut self) {
2417 self.accumulated_loss = 0.0;
2418 self.accumulated_batches = 0;
2419 self.gpu_optimizer_from_accum();
2420 self.optimizer_step();
2421 }
2422
2423 pub(crate) fn grad_accum_ref(
2425 &self,
2426 ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2427 self.grad_accum.as_ref()
2428 }
2429
2430 pub(crate) fn grad_accum_mut(
2432 &mut self,
2433 ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2434 self.grad_accum.as_mut()
2435 }
2436
2437 pub(crate) fn config(&self) -> &TransformerTrainConfig {
2439 &self.config
2440 }
2441
2442 pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2444 self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2445 }
2446
2447 pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2449 self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2450 }
2451
2452 pub fn reached_max_steps(&self) -> bool {
2454 self.config.max_steps.is_some_and(|max| self.step >= max)
2455 }
2456
2457 pub fn step(&self) -> usize {
2459 self.step
2460 }
2461
2462 pub fn set_initial_step(&mut self, step: usize) {
2468 self.step = step;
2469 self.gpu_training.step = step as u32;
2470 }
2471
2472 pub fn set_max_steps(&mut self, max_steps: usize) {
2478 self.config.max_steps = Some(max_steps);
2479 }
2480
2481 pub fn current_lr(&self) -> f32 {
2487 let base_lr = self.config.lr;
2488 if self.step < self.config.warmup_steps {
2489 base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2491 } else if let Some(max_steps) = self.config.max_steps {
2492 let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2494 if decay_steps == 0 {
2495 return base_lr;
2496 }
2497 let decay_step = self.step - self.config.warmup_steps;
2498 let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2499 0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2500 } else {
2501 base_lr
2503 }
2504 }
2505
2506 pub fn enable_profiler(&mut self, interval: usize) {
2517 self.profiler = StepProfiler::new(true, interval);
2518 }
2519
2520 pub fn print_profiler_report(&self) {
2522 self.profiler.print_report();
2523 }
2524
2525 pub fn last_grad_norm(&self) -> f32 {
2527 self.last_grad_norm
2528 }
2529
2530 pub fn param_grad_norms(&self) -> (f32, f32) {
2533 (self.last_grad_norm, self.last_embed_grad_norm)
2534 }
2535
2536 pub fn num_params(&self) -> usize {
2538 self.model.parameters().iter().map(|t| t.len()).sum()
2539 }
2540
2541 pub fn gpu_memory_mb(&self) -> (u64, u64) {
2543 match self.cuda_trainer.context().memory_info() {
2544 Ok((free, total)) => {
2545 let total_mb = (total / (1024 * 1024)) as u64;
2546 let used_mb = ((total - free) / (1024 * 1024)) as u64;
2547 (used_mb, total_mb)
2548 }
2549 Err(_) => (0, 0),
2550 }
2551 }
2552
2553 pub fn sync_weights_to_cpu(&mut self) {
2559 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2560
2561 if use_nf4 {
2562 } else {
2568 for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2569 if let Ok(weights) = block.download_weights() {
2570 let layer = &mut self.model.layers[layer_idx];
2571
2572 layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2573 layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2574 layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2575 layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2576
2577 layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2578 layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2579 layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2580
2581 layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2582 layer.post_attn_norm.weight =
2583 Tensor::from_vec(weights.post_attn_norm_weight, false);
2584 }
2585 }
2586 }
2587
2588 if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2590 self.model.norm.weight = Tensor::from_vec(norm_data, false);
2591 }
2592
2593 if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2600 self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2601 }
2602 }
2603
2604 pub fn model(&self) -> &Transformer {
2606 &self.model
2607 }
2608
2609 pub fn model_mut(&mut self) -> &mut Transformer {
2611 &mut self.model
2612 }
2613
2614 pub fn is_mixed_precision(&self) -> bool {
2616 self.config.precision_config.is_mixed()
2617 }
2618
2619 pub fn grad_scaler(&self) -> &GradScaler {
2621 &self.grad_scaler
2622 }
2623
2624 pub fn is_checkpointing(&self) -> bool {
2626 self.config.checkpoint_config.enabled
2627 }
2628
2629 pub fn save(
2631 &mut self,
2632 path: impl AsRef<std::path::Path>,
2633 name: &str,
2634 architecture: &str,
2635 ) -> crate::Result<()> {
2636 self.sync_weights_to_cpu();
2637
2638 let params: Vec<(String, Tensor)> = self
2640 .model
2641 .named_parameters()
2642 .into_iter()
2643 .map(|(name, tensor)| (name, tensor.clone()))
2644 .collect();
2645
2646 let metadata = ModelMetadata::new(name, architecture);
2647 let model = Model::new(metadata, params);
2648 let config = SaveConfig::new(ModelFormat::SafeTensors);
2649
2650 save_model(&model, path, &config)
2651 }
2652
2653 pub fn prepare_async_save(
2657 &mut self,
2658 name: &str,
2659 architecture: &str,
2660 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2661 self.sync_weights_to_cpu();
2662
2663 let param_data: Vec<(String, Vec<f32>)> = self
2665 .model
2666 .named_parameters()
2667 .into_iter()
2668 .map(|(n, t)| (n, t.data().to_vec()))
2669 .collect();
2670
2671 let name = name.to_string();
2672 let architecture = architecture.to_string();
2673
2674 Box::new(move |path: &std::path::Path| {
2675 let params: Vec<(String, Tensor)> =
2676 param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2677 let metadata = ModelMetadata::new(&name, &architecture);
2678 let model = Model::new(metadata, params);
2679 let config = SaveConfig::new(ModelFormat::SafeTensors);
2680 save_model(&model, path, &config)
2681 })
2682 }
2683
2684 pub fn save_apr(
2689 &mut self,
2690 path: impl AsRef<std::path::Path>,
2691 name: &str,
2692 architecture: &str,
2693 ) -> crate::Result<()> {
2694 self.save_apr_with_tokenizer(path, name, architecture, None)
2695 }
2696
2697 pub fn save_apr_with_tokenizer(
2706 &mut self,
2707 path: impl AsRef<std::path::Path>,
2708 name: &str,
2709 architecture: &str,
2710 tokenizer_dir: Option<&std::path::Path>,
2711 ) -> crate::Result<()> {
2712 self.sync_weights_to_cpu();
2713
2714 let params: Vec<(String, Tensor)> = self
2715 .model
2716 .named_parameters()
2717 .into_iter()
2718 .map(|(name, tensor)| (name, tensor.clone()))
2719 .collect();
2720
2721 use crate::io::save::infer_all_tensor_shapes;
2726 use aprender::serialization::apr::AprWriter;
2727 use serde_json::Value as Jv;
2728
2729 let mc = &self.config.model_config;
2730 let mut writer = AprWriter::new();
2731
2732 writer.set_metadata("model_name", Jv::String(name.to_string()));
2734 writer.set_metadata("architecture", Jv::String(architecture.to_string()));
2735 writer.set_metadata("version", Jv::String("0.1.0".into()));
2736 writer.set_metadata("format", Jv::String("entrenar-checkpoint".into()));
2737
2738 writer.set_metadata(
2741 "hidden_size",
2742 Jv::Number(serde_json::Number::from(mc.hidden_size as u64)),
2743 );
2744 writer.set_metadata(
2745 "num_hidden_layers",
2746 Jv::Number(serde_json::Number::from(mc.num_hidden_layers as u64)),
2747 );
2748 writer.set_metadata(
2749 "num_attention_heads",
2750 Jv::Number(serde_json::Number::from(mc.num_attention_heads as u64)),
2751 );
2752 writer.set_metadata(
2753 "num_kv_heads",
2754 Jv::Number(serde_json::Number::from(mc.num_kv_heads as u64)),
2755 );
2756 writer.set_metadata(
2757 "intermediate_size",
2758 Jv::Number(serde_json::Number::from(mc.intermediate_size as u64)),
2759 );
2760 writer
2761 .set_metadata("vocab_size", Jv::Number(serde_json::Number::from(mc.vocab_size as u64)));
2762 writer.set_metadata(
2763 "max_position_embeddings",
2764 Jv::Number(serde_json::Number::from(mc.max_position_embeddings as u64)),
2765 );
2766 if let Some(rope) = serde_json::Number::from_f64(mc.rope_theta as f64) {
2767 writer.set_metadata("rope_theta", Jv::Number(rope));
2768 }
2769 if let Some(eps) = serde_json::Number::from_f64(mc.rms_norm_eps as f64) {
2770 writer.set_metadata("rms_norm_eps", Jv::Number(eps));
2771 }
2772
2773 if let Some(dir) = tokenizer_dir {
2779 let tok_path = dir.join("tokenizer.json");
2780 if let Ok(json_bytes) = std::fs::read(&tok_path) {
2781 if let Ok(tok) = serde_json::from_slice::<Jv>(&json_bytes) {
2782 if let Some(model) = tok.get("model") {
2783 if let Some(vocab_obj) = model.get("vocab").and_then(|v| v.as_object()) {
2784 let mut vocab_pairs: Vec<(String, u64)> = vocab_obj
2785 .iter()
2786 .filter_map(|(k, v)| Some((k.clone(), v.as_u64()?)))
2787 .collect();
2788 vocab_pairs.sort_by_key(|(_, id)| *id);
2789 let vocab: Vec<Jv> =
2790 vocab_pairs.into_iter().map(|(k, _)| Jv::String(k)).collect();
2791 writer.set_metadata("tokenizer.vocabulary", Jv::Array(vocab));
2792 }
2793 if let Some(merges_arr) = model.get("merges").and_then(|m| m.as_array()) {
2794 let merges: Vec<Jv> = merges_arr
2795 .iter()
2796 .filter_map(|v| v.as_str().map(|s| Jv::String(s.to_string())))
2797 .collect();
2798 writer.set_metadata("tokenizer.merges", Jv::Array(merges));
2799 }
2800 }
2801 if let Some(added) = tok.get("added_tokens").and_then(|a| a.as_array()) {
2803 for entry in added {
2804 let content =
2805 entry.get("content").and_then(|c| c.as_str()).unwrap_or("");
2806 let id = entry.get("id").and_then(|i| i.as_u64());
2807 if let Some(id) = id {
2808 match content {
2809 "<s>" | "<|im_start|>" | "<|begin_of_text|>" => {
2810 writer.set_metadata(
2811 "tokenizer.bos_token_id",
2812 Jv::Number(serde_json::Number::from(id)),
2813 );
2814 }
2815 "</s>" | "<|im_end|>" | "<|end_of_text|>" | "<|endoftext|>" => {
2816 writer.set_metadata(
2817 "tokenizer.eos_token_id",
2818 Jv::Number(serde_json::Number::from(id)),
2819 );
2820 }
2821 _ => {}
2822 }
2823 }
2824 }
2825 }
2826 }
2827 }
2828 }
2829
2830 let shapes = infer_all_tensor_shapes(¶ms);
2832 for (tname, tensor) in ¶ms {
2833 let data = tensor.data();
2834 let slice = data.as_slice().expect("tensor data must be contiguous");
2835 let shape = shapes.get(tname).cloned().unwrap_or_else(|| vec![tensor.len()]);
2836 writer.add_tensor_f32(tname, shape, slice);
2837 }
2838
2839 writer
2840 .write(path)
2841 .map_err(|e| crate::error::Error::Serialization(format!("APR write failed: {e}")))
2842 }
2843
2844 fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2851 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2852 if use_nf4 {
2853 let frozen_suffixes = [
2854 "q_proj.weight",
2855 "k_proj.weight",
2856 "v_proj.weight",
2857 "o_proj.weight",
2858 "gate_proj.weight",
2859 "up_proj.weight",
2860 "down_proj.weight",
2861 ];
2862 self.model
2863 .named_parameters()
2864 .into_iter()
2865 .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2866 .map(|(n, t)| (n, t.data().to_vec()))
2867 .collect()
2868 } else {
2869 self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2870 }
2871 }
2872
2873 fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2874 if self.config.quantize_nf4 && self.config.is_lora() {
2875 self.cuda_blocks
2876 .iter()
2877 .enumerate()
2878 .filter_map(|(i, block)| {
2879 block
2880 .download_lora_weights()
2881 .ok()
2882 .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2883 })
2884 .collect()
2885 } else {
2886 Vec::new()
2887 }
2888 }
2889
2890 pub fn prepare_async_apr_save(
2891 &mut self,
2892 name: &str,
2893 architecture: &str,
2894 step: usize,
2895 loss: f64,
2896 lr: f64,
2897 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2898 self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2899 }
2900
2901 pub fn prepare_async_apr_save_with_tokenizer(
2907 &mut self,
2908 name: &str,
2909 architecture: &str,
2910 step: usize,
2911 loss: f64,
2912 lr: f64,
2913 tokenizer_path: Option<&std::path::Path>,
2914 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2915 self.sync_weights_to_cpu();
2916
2917 let param_data = self.snapshot_param_data();
2918 let lora_data = self.snapshot_lora_data();
2919
2920 let embed_m: Vec<Vec<f32>> = self
2922 .embed_optimizer
2923 .first_moments()
2924 .iter()
2925 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2926 .collect();
2927 let embed_v: Vec<Vec<f32>> = self
2928 .embed_optimizer
2929 .second_moments()
2930 .iter()
2931 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2932 .collect();
2933 let embed_step = self.embed_optimizer.step_count();
2934
2935 let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2940 .gpu_training
2941 .optimizer_states
2942 .iter()
2943 .map(|state| state.download_to_host().unwrap_or_default())
2944 .collect();
2945
2946 let lm_head_m_host = {
2948 let mut buf = vec![0.0f32; self.lm_head_m.len()];
2949 let _ = self.lm_head_m.copy_to_host(&mut buf);
2950 buf
2951 };
2952 let lm_head_v_host = {
2953 let mut buf = vec![0.0f32; self.lm_head_v.len()];
2954 let _ = self.lm_head_v.copy_to_host(&mut buf);
2955 buf
2956 };
2957 let final_norm_m_host = {
2958 let mut buf = vec![0.0f32; self.final_norm_m.len()];
2959 let _ = self.final_norm_m.copy_to_host(&mut buf);
2960 buf
2961 };
2962 let final_norm_v_host = {
2963 let mut buf = vec![0.0f32; self.final_norm_v.len()];
2964 let _ = self.final_norm_v.copy_to_host(&mut buf);
2965 buf
2966 };
2967
2968 let name = name.to_string();
2969 let architecture = architecture.to_string();
2970 let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2971 let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2972
2973 let arch_hidden_size = self.config.model_config.hidden_size;
2980 let arch_num_layers = self.config.model_config.num_hidden_layers;
2981 let arch_num_heads = self.config.model_config.num_attention_heads;
2982 let arch_num_kv_heads = self.config.model_config.num_kv_heads;
2983 let arch_intermediate_size = self.config.model_config.intermediate_size;
2984 let arch_vocab_size = self.config.model_config.vocab_size;
2985 let arch_max_position_embeddings = self.config.model_config.max_position_embeddings;
2986 let arch_rope_theta = self.config.model_config.rope_theta;
2987 let arch_rms_norm_eps = self.config.model_config.rms_norm_eps;
2988
2989 let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2992 tokenizer_path.and_then(|p| {
2993 let json_bytes = std::fs::read(p).ok()?;
2994 let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2995 let model = tok.get("model")?;
2996 let vocab_obj = model.get("vocab")?.as_object()?;
2997 let mut vocab_pairs: Vec<(String, u64)> =
2999 vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
3000 vocab_pairs.sort_by_key(|(_, id)| *id);
3001 let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
3002 let merges: Vec<String> = model
3004 .get("merges")?
3005 .as_array()?
3006 .iter()
3007 .filter_map(|v| v.as_str().map(String::from))
3008 .collect();
3009 let added = tok.get("added_tokens").and_then(|a| a.as_array());
3011 let bos_id = added.and_then(|arr| {
3012 arr.iter()
3013 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
3014 .and_then(|t| t.get("id")?.as_u64())
3015 });
3016 let eos_id = added.and_then(|arr| {
3017 arr.iter()
3018 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
3019 .and_then(|t| t.get("id")?.as_u64())
3020 });
3021 if vocab.is_empty() {
3022 return None;
3023 }
3024 println!(
3025 " [ALB-130] Embedding tokenizer: {} vocab, {} merges",
3026 vocab.len(),
3027 merges.len()
3028 );
3029 Some((vocab, merges, bos_id, eos_id))
3030 });
3031
3032 Box::new(move |path: &std::path::Path| {
3033 use aprender::serialization::apr::AprWriter;
3034 use serde_json::Value as Jv;
3035
3036 let mut writer = AprWriter::new();
3037
3038 writer.set_metadata("model_name", Jv::String(name));
3040 writer.set_metadata("architecture", Jv::String(architecture));
3041 writer.set_metadata(
3042 "format",
3043 Jv::String(if is_delta_checkpoint {
3044 "entrenar-delta-checkpoint".into()
3045 } else {
3046 "entrenar-checkpoint".into()
3047 }),
3048 );
3049 writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
3050 writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
3051 writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
3052 writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
3053 if let Some(cfg) = model_config_json {
3054 writer.set_metadata("model_config", Jv::String(cfg));
3055 }
3056
3057 writer.set_metadata(
3061 "hidden_size",
3062 Jv::Number(serde_json::Number::from(arch_hidden_size as u64)),
3063 );
3064 writer.set_metadata(
3065 "num_hidden_layers",
3066 Jv::Number(serde_json::Number::from(arch_num_layers as u64)),
3067 );
3068 writer.set_metadata(
3069 "num_attention_heads",
3070 Jv::Number(serde_json::Number::from(arch_num_heads as u64)),
3071 );
3072 writer.set_metadata(
3073 "num_kv_heads",
3074 Jv::Number(serde_json::Number::from(arch_num_kv_heads as u64)),
3075 );
3076 writer.set_metadata(
3077 "intermediate_size",
3078 Jv::Number(serde_json::Number::from(arch_intermediate_size as u64)),
3079 );
3080 writer.set_metadata(
3081 "vocab_size",
3082 Jv::Number(serde_json::Number::from(arch_vocab_size as u64)),
3083 );
3084 writer.set_metadata(
3085 "max_position_embeddings",
3086 Jv::Number(serde_json::Number::from(arch_max_position_embeddings as u64)),
3087 );
3088 if let Some(rope) = serde_json::Number::from_f64(arch_rope_theta as f64) {
3089 writer.set_metadata("rope_theta", Jv::Number(rope));
3090 }
3091 if let Some(eps) = serde_json::Number::from_f64(arch_rms_norm_eps as f64) {
3092 writer.set_metadata("rms_norm_eps", Jv::Number(eps));
3093 }
3094
3095 if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
3097 writer.set_metadata(
3098 "tokenizer.vocabulary",
3099 Jv::Array(vocab.into_iter().map(Jv::String).collect()),
3100 );
3101 writer.set_metadata(
3102 "tokenizer.merges",
3103 Jv::Array(merges.into_iter().map(Jv::String).collect()),
3104 );
3105 if let Some(bos) = bos_id {
3106 writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
3107 }
3108 if let Some(eos) = eos_id {
3109 writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
3110 }
3111 }
3112
3113 let hidden_size = param_data
3115 .iter()
3116 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
3117 .map_or(0, |(_, d)| d.len());
3118
3119 for (tensor_name, data) in ¶m_data {
3121 let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
3122 writer.add_tensor_f32(tensor_name.clone(), shape, data);
3123 }
3124
3125 for (i, m_data) in embed_m.iter().enumerate() {
3127 let len = m_data.len();
3128 writer.add_tensor_f32(
3129 format!("__training__.embed_optimizer.m.{i}"),
3130 vec![len],
3131 m_data,
3132 );
3133 }
3134 for (i, v_data) in embed_v.iter().enumerate() {
3135 let len = v_data.len();
3136 writer.add_tensor_f32(
3137 format!("__training__.embed_optimizer.v.{i}"),
3138 vec![len],
3139 v_data,
3140 );
3141 }
3142
3143 for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
3145 for (suffix, data) in buffers {
3146 let len = data.len();
3147 writer.add_tensor_f32(
3148 format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
3149 vec![len],
3150 data,
3151 );
3152 }
3153 }
3154
3155 if !lm_head_m_host.is_empty() {
3157 let len = lm_head_m_host.len();
3158 writer.add_tensor_f32(
3159 "__training__.lm_head_optimizer.m".to_string(),
3160 vec![len],
3161 &lm_head_m_host,
3162 );
3163 let len = lm_head_v_host.len();
3164 writer.add_tensor_f32(
3165 "__training__.lm_head_optimizer.v".to_string(),
3166 vec![len],
3167 &lm_head_v_host,
3168 );
3169 }
3170 if !final_norm_m_host.is_empty() {
3171 let len = final_norm_m_host.len();
3172 writer.add_tensor_f32(
3173 "__training__.final_norm_optimizer.m".to_string(),
3174 vec![len],
3175 &final_norm_m_host,
3176 );
3177 let len = final_norm_v_host.len();
3178 writer.add_tensor_f32(
3179 "__training__.final_norm_optimizer.v".to_string(),
3180 vec![len],
3181 &final_norm_v_host,
3182 );
3183 }
3184
3185 for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
3187 if !a_q.is_empty() {
3188 writer.add_tensor_f32(
3189 format!("lora.{layer_idx}.q_proj.lora_a"),
3190 vec![a_q.len()],
3191 a_q,
3192 );
3193 writer.add_tensor_f32(
3194 format!("lora.{layer_idx}.q_proj.lora_b"),
3195 vec![b_q.len()],
3196 b_q,
3197 );
3198 }
3199 if !a_v.is_empty() {
3200 writer.add_tensor_f32(
3201 format!("lora.{layer_idx}.v_proj.lora_a"),
3202 vec![a_v.len()],
3203 a_v,
3204 );
3205 writer.add_tensor_f32(
3206 format!("lora.{layer_idx}.v_proj.lora_b"),
3207 vec![b_v.len()],
3208 b_v,
3209 );
3210 }
3211 }
3212
3213 writer
3215 .write(path)
3216 .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3217
3218 Ok(())
3219 })
3220 }
3221
3222 pub fn gpu_name(&self) -> String {
3224 self.cuda_trainer.device_name()
3225 }
3226
3227 pub fn save_cuda_lora_adapter(
3237 &self,
3238 output_dir: &std::path::Path,
3239 base_model_name: Option<&str>,
3240 ) -> crate::Result<()> {
3241 if !self.config.quantize_nf4 || !self.config.is_lora() {
3242 return Ok(()); }
3244
3245 let lora_rank = self.config.lora_rank.unwrap_or(16);
3246 let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3247 let lora_scale = lora_alpha / lora_rank as f32;
3248 let hidden_size = self.config.model_config.hidden_size;
3249 let head_dim = self.config.model_config.head_dim();
3250 let q_dim = self.config.model_config.num_attention_heads * head_dim;
3251 let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3252
3253 let lora_config =
3254 crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3255
3256 let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3257
3258 for (i, block) in self.cuda_blocks.iter().enumerate() {
3259 let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3260 Ok(weights) => weights,
3261 Err(_) => continue, };
3263
3264 if a_q.is_empty() && a_v.is_empty() {
3265 continue;
3266 }
3267
3268 if !a_q.is_empty() {
3270 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3272 for r in 0..hidden_size {
3273 for c in 0..lora_rank {
3274 a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3275 }
3276 }
3277
3278 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3281 let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3282 for r in 0..lora_rank {
3283 for c in 0..q_dim {
3284 b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3285 }
3286 }
3287
3288 let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3289 let mut layer = crate::lora::LoRALayer::new(
3290 base_weight,
3291 q_dim,
3292 hidden_size,
3293 lora_rank,
3294 lora_alpha,
3295 );
3296 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3298 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3299
3300 adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3301 }
3302
3303 if !a_v.is_empty() {
3305 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3306 for r in 0..hidden_size {
3307 for c in 0..lora_rank {
3308 a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3309 }
3310 }
3311
3312 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3313 let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3314 for r in 0..lora_rank {
3315 for c in 0..kv_hidden {
3316 b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3317 }
3318 }
3319
3320 let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3321 let mut layer = crate::lora::LoRALayer::new(
3322 base_weight,
3323 kv_hidden,
3324 hidden_size,
3325 lora_rank,
3326 lora_alpha,
3327 );
3328 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3329 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3330
3331 adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3332 }
3333 }
3334
3335 if adapters.is_empty() {
3336 println!(" [WARN] No LoRA adapters found to save");
3337 return Ok(());
3338 }
3339
3340 let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3341 adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3342
3343 std::fs::create_dir_all(output_dir).ok();
3344 crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3345 .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3346
3347 let adapter_path = output_dir.join("adapter_model.safetensors");
3348 let size_mb =
3349 std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3350 println!(
3351 "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3352 adapters.len(),
3353 size_mb,
3354 output_dir.display()
3355 );
3356
3357 Ok(())
3358 }
3359
3360 pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3365 let path = dir.join("optimizer_state.json");
3366 let m_data: Vec<Option<Vec<f32>>> = self
3367 .embed_optimizer
3368 .first_moments()
3369 .iter()
3370 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3371 .collect();
3372 let v_data: Vec<Option<Vec<f32>>> = self
3373 .embed_optimizer
3374 .second_moments()
3375 .iter()
3376 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3377 .collect();
3378 let state = serde_json::json!({
3379 "type": "adamw_cpu_embed",
3380 "step": self.embed_optimizer.step_count(),
3381 "m": m_data,
3382 "v": v_data,
3383 });
3384 let json_str = serde_json::to_string(&state).map_err(|e| {
3385 crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3386 })?;
3387 std::fs::write(&path, json_str)
3388 .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3389 Ok(())
3390 }
3391
3392 pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3398 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3399 Ok(r) => r,
3400 Err(_) => return (0, self.cuda_blocks.len()),
3401 };
3402
3403 let mut restored = 0usize;
3404 for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3405 let a_q =
3406 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3407 let b_q =
3408 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3409 let a_v =
3410 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3411 let b_v =
3412 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3413
3414 if a_q.is_empty() {
3415 continue; }
3417
3418 if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3419 eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3420 continue;
3421 }
3422 restored += 1;
3423 }
3424
3425 (restored, self.cuda_blocks.len())
3426 }
3427
3428 pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3433 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3434 Ok(r) => r,
3435 Err(_) => return false,
3436 };
3437
3438 if let Some(step_val) = reader.get_metadata("optimizer_step") {
3440 if let Some(step_str) = step_val.as_str() {
3441 if let Ok(step) = step_str.parse::<u64>() {
3442 self.embed_optimizer.set_step_count(step);
3443 }
3444 }
3445 }
3446
3447 for i in 0..128 {
3449 let name = format!("__training__.embed_optimizer.m.{i}");
3450 match reader.read_tensor_f32(&name) {
3451 Ok(data) if !data.is_empty() => {
3452 self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3453 }
3454 _ => break,
3455 }
3456 }
3457
3458 for i in 0..128 {
3460 let name = format!("__training__.embed_optimizer.v.{i}");
3461 match reader.read_tensor_f32(&name) {
3462 Ok(data) if !data.is_empty() => {
3463 self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3464 }
3465 _ => break,
3466 }
3467 }
3468
3469 let suffixes = [
3471 "m.w_q",
3472 "v.w_q",
3473 "m.w_k",
3474 "v.w_k",
3475 "m.w_v",
3476 "v.w_v",
3477 "m.w_o",
3478 "v.w_o",
3479 "m.w_gate",
3480 "v.w_gate",
3481 "m.w_up",
3482 "v.w_up",
3483 "m.w_down",
3484 "v.w_down",
3485 "m.input_norm",
3486 "v.input_norm",
3487 "m.post_attn_norm",
3488 "v.post_attn_norm",
3489 ];
3490 let mut blocks_restored = 0usize;
3491 for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3492 let mut data = std::collections::HashMap::new();
3493 for suffix in &suffixes {
3494 let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3495 if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3496 if !tensor_data.is_empty() {
3497 data.insert(suffix.to_string(), tensor_data);
3498 }
3499 }
3500 }
3501 if !data.is_empty() {
3502 let _ = state.restore_from_host(&data);
3503 blocks_restored += 1;
3504 }
3505 }
3506
3507 if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3509 if m_data.len() == self.lm_head_m.len() {
3510 let _ = self.lm_head_m.copy_from_host(&m_data);
3511 }
3512 }
3513 if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3514 if v_data.len() == self.lm_head_v.len() {
3515 let _ = self.lm_head_v.copy_from_host(&v_data);
3516 }
3517 }
3518
3519 if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3521 if m_data.len() == self.final_norm_m.len() {
3522 let _ = self.final_norm_m.copy_from_host(&m_data);
3523 }
3524 }
3525 if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3526 if v_data.len() == self.final_norm_v.len() {
3527 let _ = self.final_norm_v.copy_from_host(&v_data);
3528 }
3529 }
3530
3531 if blocks_restored > 0 {
3533 println!(
3534 " ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3535 self.gpu_training.optimizer_states.len()
3536 );
3537 } else if !self.gpu_training.optimizer_states.is_empty() {
3538 println!(
3539 " [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3540 self.gpu_training.optimizer_states.len()
3541 );
3542 }
3543
3544 true
3545 }
3546
3547 pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3551 let path = dir.join("optimizer_state.json");
3552 let data = match std::fs::read_to_string(&path) {
3553 Ok(d) => d,
3554 Err(_) => return false,
3555 };
3556 let state: serde_json::Value = match serde_json::from_str(&data) {
3557 Ok(v) => v,
3558 Err(_) => return false,
3559 };
3560 if let Some(step) = state["step"].as_u64() {
3561 self.embed_optimizer.set_step_count(step);
3562 }
3563 restore_moment_buffers(&state["m"], |idx, arr| {
3564 self.embed_optimizer.set_first_moment(idx, arr);
3565 });
3566 restore_moment_buffers(&state["v"], |idx, arr| {
3567 self.embed_optimizer.set_second_moment(idx, arr);
3568 });
3569 true
3570 }
3571}
3572
3573#[cfg(feature = "cuda")]
/// Infers a 1-D or 2-D shape for a flat parameter tensor from its name.
///
/// Shapes follow the HuggingFace `[out_features, in_features]` layout:
///
/// - norm weights (`*norm.weight`, e.g. `input_layernorm`, `model.norm`,
///   `q_norm`) and biases (`*.bias`) are 1-D `[numel]`;
/// - `o_proj.weight` and `down_proj.weight` project back into the hidden
///   dimension, so they are `[hidden_size, numel / hidden_size]`;
/// - every other tensor whose size is a multiple of `hidden_size` (q/k/v,
///   gate/up, embeddings, lm_head) is `[numel / hidden_size, hidden_size]`.
///
/// Falls back to `[numel]` when `hidden_size` is 0 or does not divide `numel`.
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    let is_vector = name.ends_with("norm.weight") || name.ends_with(".bias");
    // `checked_rem` is `None` for `hidden_size == 0`, which also falls back to 1-D.
    if is_vector || numel.checked_rem(hidden_size) != Some(0) {
        return vec![numel];
    }
    let other_dim = numel / hidden_size;
    if name.ends_with("o_proj.weight") || name.ends_with("down_proj.weight") {
        vec![hidden_size, other_dim]
    } else {
        vec![other_dim, hidden_size]
    }
}
3591
3592#[cfg(feature = "cuda")]
3594fn restore_moment_buffers(
3595 json_arr: &serde_json::Value,
3596 mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
3597) {
3598 let Some(arr) = json_arr.as_array() else { return };
3599 for (idx, val) in arr.iter().enumerate() {
3600 let Some(inner) = val.as_array() else { continue };
3601 let floats: Vec<f32> = inner.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
3602 if !floats.is_empty() {
3603 set_fn(idx, ndarray::Array1::from_vec(floats));
3604 }
3605 }
3606}
3607
#[cfg(not(feature = "cuda"))]
/// Placeholder for the CUDA trainer when the crate is built without the
/// `cuda` feature.
///
/// Its `new` and `with_model` constructors always return an error, so code
/// that names the type still compiles on CPU-only builds.
pub struct CudaTransformerTrainer;
3612
3613#[cfg(not(feature = "cuda"))]
3614impl CudaTransformerTrainer {
3615 pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3616 Err(crate::error::Error::ConfigError(
3617 "CUDA not available (compiled without cuda feature)".into(),
3618 ))
3619 }
3620
3621 pub fn with_model(
3622 _model: crate::transformer::Transformer,
3623 _config: super::config::TransformerTrainConfig,
3624 ) -> crate::Result<Self> {
3625 Err(crate::error::Error::ConfigError(
3626 "CUDA not available (compiled without cuda feature)".into(),
3627 ))
3628 }
3629
3630 pub fn gpu_name(&self) -> String {
3631 unreachable!("CudaTransformerTrainer stub should never be instantiated")
3632 }
3633}
3634
#[cfg(test)]
mod tests {
    /// On builds without the `cuda` feature, constructing the trainer must
    /// fail with an error rather than panicking.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        let outcome = super::CudaTransformerTrainer::new(train_config);
        assert!(outcome.is_err(), "stub trainer must refuse to construct without CUDA");
    }
}