1#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{
42 gemm_forward, pre_warm_forward_kernels, rms_norm_forward, rms_norm_forward_with_eps,
43};
44#[cfg(feature = "cuda")]
45use crate::autograd::cuda_optim::{
46 adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
47 gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
48 squared_sum_launch_into, FusedClipState,
49};
50#[cfg(feature = "cuda")]
51use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
52#[cfg(feature = "cuda")]
53use crate::autograd::precision::GradScaler;
54#[cfg(feature = "cuda")]
55use crate::autograd::Tensor;
56#[cfg(feature = "cuda")]
57use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
58#[cfg(feature = "cuda")]
59use crate::optim::{AdamW, Optimizer};
60#[cfg(feature = "cuda")]
61use crate::train::MetricsTracker;
62#[cfg(feature = "cuda")]
63use crate::transformer::{
64 CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
65 GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
66};
67
68#[cfg(feature = "cuda")]
69use super::batch::LMBatch;
70#[cfg(feature = "cuda")]
71use super::config::TransformerTrainConfig;
72#[cfg(feature = "cuda")]
73use super::step_profiler::StepProfiler;
74
/// Computes the global L2 norm over all nine gradient buffers of `ws` and the
/// multiplicative factor that would clip that norm to `max_norm`.
///
/// One squared-sum reduction is launched per non-empty buffer, the stream is
/// synchronized once, and the partial results are then collected and summed in
/// `f64`.
///
/// Returns `(scale, grad_norm)`; `scale` is `max_norm / grad_norm` when the norm
/// exceeds `max_norm` and `1.0` otherwise.
///
/// This is best-effort: a buffer whose launch or collection fails is left out
/// of the norm, and a failed stream synchronization yields `(1.0, 0.0)`.
#[cfg(feature = "cuda")]
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // Phase 1: queue every reduction before waiting on any of them.
    let in_flight: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|buf| {
            let n = buf.len() as u32;
            if n == 0 {
                return None;
            }
            squared_sum_launch_cuda(buf, n, stream).ok()
        })
        .collect();

    // Phase 2: a single synchronization covers all queued reductions.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // Phase 3: accumulate in f64, starting from +0.0, in launch order.
    let total_sq = in_flight
        .iter()
        .filter_map(|p| squared_sum_collect(p).ok())
        .fold(0.0f64, |acc, sq_norm| acc + f64::from(sq_norm));

    let grad_norm = total_sq.sqrt() as f32;
    let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
    (scale, grad_norm)
}
133
/// Clips the gradients in `ws` so that their global L2 norm does not exceed
/// `max_norm`, and returns the norm measured before clipping.
///
/// When the computed scale is within `1e-7` of `1.0` no kernels are launched.
/// Otherwise every buffer is multiplied by the scale; errors from the
/// individual scaling launches are ignored.
#[cfg(feature = "cuda")]
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);
    if (scale - 1.0).abs() < 1e-7 {
        return grad_norm;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads {
        let n = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, scale, n, stream);
    }
    grad_norm
}
165
/// Gradient clipping that never reads the norm back to the host.
///
/// Three groups of kernels are queued on `stream`, in this order:
/// 1. one squared-sum launch per non-empty buffer, each writing into
///    `state.partials_buf` at the element offset `state.offsets[i]`;
/// 2. a reduction that turns the partials and `max_norm` into a clip scale
///    stored in `state.scale_buf`;
/// 3. one scaling launch per non-empty buffer that reads the scale through the
///    pointer of `state.scale_buf`.
///
/// Every launch error is ignored and nothing is returned.
#[cfg(feature = "cuda")]
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Width of one `f32` partial, used to turn an element offset into bytes.
    const F32_BYTES: u64 = 4;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    for (slot, grad) in grads.iter().enumerate() {
        let n = grad.len() as u32;
        if n != 0 {
            let dst = state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * F32_BYTES;
            let _ = squared_sum_launch_into(grad, n, dst, stream);
        }
    }

    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    let scale_ptr = state.scale_buf.as_ptr();
    let grads_mut: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads_mut {
        let n = grad.len() as u32;
        if n != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, scale_ptr, n, stream);
        }
    }
}
238
/// Returns the global L2 norm of the gradients in `ws` without modifying them.
///
/// Delegates to `compute_workspace_clip_scale_gpu` with `f32::MAX` as the
/// threshold and discards the resulting scale.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
248
/// Multiplies every gradient buffer in `ws` by `inv_scale`.
///
/// Nothing is launched when `inv_scale` is within `1e-7` of `1.0`. The scaling
/// reuses `gradient_clip_cuda`; errors from the individual launches are
/// ignored.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    if (inv_scale - 1.0).abs() < 1e-7 {
        return;
    }

    let grads: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads {
        let n = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, n, stream);
    }
}
285
/// GPU-resident buffers and optimizer state used by the training loop.
///
/// Unless noted otherwise, the activation-sized buffers are allocated in
/// `CudaTransformerTrainer::with_model` with `max_seq_len * hidden_size`
/// elements.
#[cfg(feature = "cuda")]
struct GpuPretrainState {
    /// One buffer per transformer layer. During the forward pass the input of
    /// layer `i` is copied into `layer_inputs[i]` when `saved_layer_mask[i]`
    /// is set.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Per-layer flag: `true` for every layer when activation checkpointing is
    /// disabled, otherwise only for layers whose index is a multiple of the
    /// segment size.
    saved_layer_mask: Vec<bool>,
    /// Staging buffer used when recomputing a checkpointed segment; allocated
    /// only when activation checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Device copy of the model's final norm weight.
    final_norm_weight: GpuBuffer<f32>,
    /// Copy of the last block's output, taken before the final RMSNorm.
    blocks_output: GpuBuffer<f32>,
    /// Activation-sized gradient scratch buffer.
    /// NOTE(review): presumably ping-ponged with `grad_buf_b` during backward —
    /// the backward pass is not visible here.
    grad_buf_a: GpuBuffer<f32>,
    /// Second activation-sized gradient scratch buffer (see `grad_buf_a`).
    grad_buf_b: GpuBuffer<f32>,
    /// Buffer of `hidden_size` elements; by its name, the gradient of the
    /// final norm weight — TODO confirm against the backward pass.
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Output of the final RMSNorm; input to the LM-head GEMM.
    norm_output: GpuBuffer<f32>,
    /// Buffer of `max_seq_len * vocab_size` elements that receives the LM-head
    /// logits; it is also handed mutably to the fused cross-entropy kernel.
    logits_buf: GpuBuffer<f32>,
    /// Activation-sized buffer; by its name, the gradient flowing from the LM
    /// head back into the hidden state — TODO confirm.
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// Per-block optimizer state. Left empty when the NF4 + LoRA path is
    /// active.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Step counter, initialised to 0.
    /// NOTE(review): presumably the optimizer step used for bias correction.
    step: u32,
}
326
/// Transformer trainer that keeps the transformer blocks, the final norm and
/// the LM head on the GPU, while the token embedding lookup runs on the
/// CPU-side `model`.
#[cfg(feature = "cuda")]
pub struct CudaTransformerTrainer {
    /// CPU-side model; its `embed_tokens` produces the hidden states that are
    /// uploaded at the start of every forward pass.
    model: Transformer,
    /// Provides the CUDA context and stream used by this trainer.
    cuda_trainer: CudaTrainer,
    /// GPU copy of every transformer layer (FP32 or NF4 variant), in layer order.
    cuda_blocks: Vec<CudaBlock>,
    /// Gradient workspace holding the nine per-block weight gradient buffers.
    cuda_grad_workspace: CudaGradWorkspace,
    /// Scratch space handed to every block's forward call; `Some` only in
    /// NF4 + LoRA mode.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// LoRA gradient workspace; `Some` only in NF4 + LoRA mode.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// One LoRA optimizer state per block; `Some` only in NF4 + LoRA mode.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// Activation, gradient and optimizer buffers for the GPU training loop.
    gpu_training: GpuPretrainState,
    /// Device copy of the LM head weight; falls back to the embedding weight
    /// when the model has no separate `lm_head`.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// Buffer of `vocab_size * hidden_size` elements for the LM head gradient.
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// Zero-initialised `vocab_size * hidden_size` buffer.
    /// NOTE(review): presumably the AdamW first moment of the LM head.
    lm_head_m: GpuBuffer<f32>,
    /// Zero-initialised `vocab_size * hidden_size` buffer.
    /// NOTE(review): presumably the AdamW second moment of the LM head.
    lm_head_v: GpuBuffer<f32>,
    /// Zero-initialised `hidden_size` buffer.
    /// NOTE(review): presumably the AdamW first moment of the final norm.
    final_norm_m: GpuBuffer<f32>,
    /// Zero-initialised `hidden_size` buffer.
    /// NOTE(review): presumably the AdamW second moment of the final norm.
    final_norm_v: GpuBuffer<f32>,
    /// CPU AdamW built from the configured lr, betas and weight decay with an
    /// epsilon of `1e-8`. By its name it updates the embedding — TODO confirm.
    embed_optimizer: AdamW,
    /// Training configuration this trainer was built with.
    config: TransformerTrainConfig,
    /// Metrics tracker exposed to callers.
    pub metrics: MetricsTracker,
    /// Step counter, initialised to 0.
    step: usize,
    /// Running loss accumulator, initialised to 0.
    accumulated_loss: f32,
    /// Number of batches folded into `accumulated_loss`, initialised to 0.
    accumulated_batches: usize,
    /// Most recent block gradient norm, initialised to 0.
    last_grad_norm: f32,
    /// Most recent embedding gradient norm, initialised to 0.
    last_embed_grad_norm: f32,
    /// Host-side gradient accumulator; created when `accumulation_steps > 1`.
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// Device-side gradient accumulator; attempted when
    /// `accumulation_steps > 1` and `None` if its allocation failed (the
    /// host-side accumulator is then the fallback).
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// Loss scaler built from the precision configuration.
    grad_scaler: GradScaler,
    /// Per-step profiler; enabled when `profile_interval > 0`.
    profiler: StepProfiler,
    /// First of two forward scratch buffers; block inputs and outputs
    /// alternate between `fwd_scratch_a` and `fwd_scratch_b`.
    fwd_scratch_a: GpuBuffer<f32>,
    /// Second forward scratch buffer (see `fwd_scratch_a`).
    fwd_scratch_b: GpuBuffer<f32>,
    /// Host staging buffer of `max_seq_len * hidden_size` elements used to
    /// zero-pad the embedded input before it is uploaded.
    h2d_staging: Vec<f32>,
    /// Host staging buffer for device-to-host copies. Only sized when
    /// accumulation is active and the GPU accumulator is unavailable;
    /// otherwise empty.
    d2h_staging: Vec<f32>,
    /// State for the fused clipping path; `Some` when a max grad norm is
    /// configured and the allocation succeeded.
    fused_clip: Option<FusedClipState>,
    /// Host vector of `hidden_size` zeros.
    /// NOTE(review): presumably used to reset the final norm gradient buffer.
    final_norm_zero_buf: Vec<f32>,
}
416
417#[cfg(feature = "cuda")]
418impl CudaTransformerTrainer {
419 pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
426 let model = Transformer::new(&config.model_config);
427 Self::with_model(model, config)
428 }
429
430 pub fn for_inference(
444 checkpoint_dir: impl AsRef<std::path::Path>,
445 model_config: crate::transformer::TransformerConfig,
446 ) -> crate::Result<Self> {
447 let dir = checkpoint_dir.as_ref();
448
449 let model = if let Some((Some(m), _step)) =
451 crate::config::try_load_apr_for_inference(dir, &model_config)
452 {
453 m
454 } else {
455 Transformer::from_safetensors(dir, &model_config)?
456 };
457
458 let mut config = TransformerTrainConfig::new(model_config);
459 config.max_seq_len = config.model_config.max_position_embeddings;
460 Self::with_model(model, config)
461 }
462
463 pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
469 if !cuda_training_available() {
470 return Err(crate::error::Error::ConfigError("CUDA not available".into()));
471 }
472
473 let mc = &config.model_config;
474 let max_seq_len = config.max_seq_len;
475 let hidden_size = mc.hidden_size;
476 let vocab_size = mc.vocab_size;
477 let num_layers = mc.num_hidden_layers;
478
479 let cuda_trainer = CudaTrainer::new().map_err(|e| {
481 crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
482 })?;
483
484 println!(
485 " GPU: {} ({:.1} GB)",
486 cuda_trainer.device_name(),
487 cuda_trainer.total_memory() as f64 / 1e9
488 );
489
490 let ctx = cuda_trainer.context().clone();
491 let stream = cuda_trainer.stream();
492
493 pre_warm_forward_kernels(
496 hidden_size,
497 mc.intermediate_size,
498 mc.num_attention_heads,
499 mc.num_kv_heads,
500 mc.head_dim(),
501 max_seq_len,
502 )
503 .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;
504
505 {
509 use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
510 let head_dim = mc.head_dim();
511 pre_warm_lora_backward_kernels(
512 hidden_size,
513 mc.num_attention_heads * head_dim,
514 mc.num_kv_heads * head_dim,
515 max_seq_len,
516 config.lora_rank.unwrap_or(0),
517 mc.intermediate_size,
518 mc.num_attention_heads,
519 config.quantize_nf4 && config.is_lora(),
520 )
521 .map_err(|e| {
522 crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
523 })?;
524 eprintln!(" ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
525 }
526
527 if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
530 println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
531 }
532 if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
533 println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
534 }
535
536 let use_nf4 = config.quantize_nf4 && config.is_lora();
538 let cuda_blocks = Self::upload_blocks(
539 &model,
540 mc,
541 &config,
542 &ctx,
543 use_nf4,
544 num_layers,
545 hidden_size,
546 max_seq_len,
547 )?;
548
549 let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
551 crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
552 })?;
553
554 let buf_size = max_seq_len * hidden_size;
556 let logits_size = max_seq_len * vocab_size;
557
558 let checkpointing = config.checkpoint_config.enabled;
562 let segment_size = if checkpointing {
563 let ns = config.checkpoint_config.num_segments.max(1);
564 num_layers.div_ceil(ns)
565 } else {
566 1 };
568 let saved_layer_mask: Vec<bool> =
569 (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();
570
571 let mut layer_inputs = Vec::with_capacity(num_layers);
572 for _ in 0..num_layers {
573 layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
574 crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
575 })?);
576 }
577
578 let recompute_buf = if checkpointing {
580 Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
581 crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
582 })?)
583 } else {
584 None
585 };
586
587 if checkpointing {
588 let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
589 println!(
590 " ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
591 config.checkpoint_config.num_segments, saved_count, num_layers
592 );
593 }
594
595 let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
597 let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
598 crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
599 })?;
600
601 let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
602 crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
603 })?;
604 let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
605 crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
606 })?;
607 let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
608 crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
609 })?;
610 let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
611 crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
612 })?;
613 let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
614 crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
615 })?;
616 let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
617 crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
618 })?;
619 let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
620 crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
621 })?;
622
623 let mut optimizer_states = Vec::new();
625 if !use_nf4 {
626 optimizer_states.reserve(num_layers);
627 for (i, block) in cuda_blocks.iter().enumerate() {
628 optimizer_states.push(block.init_optimizer_state().map_err(|e| {
629 crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
630 })?);
631 }
632 }
633
634 let gpu_training = GpuPretrainState {
635 layer_inputs,
636 saved_layer_mask,
637 recompute_buf,
638 final_norm_weight,
639 blocks_output,
640 grad_buf_a,
641 grad_buf_b,
642 grad_final_norm_weight,
643 norm_output,
644 logits_buf,
645 lm_head_grad_hidden,
646 optimizer_states,
647 step: 0,
648 };
649
650 let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
653 let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
654 let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
655 crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
656 })?;
657 let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
658 crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
659 })?;
660 let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
663 .map_err(|e| {
664 crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
665 })?;
666 let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
667 .map_err(|e| {
668 crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
669 })?;
670
671 let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
673 crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
674 })?;
675 let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
676 crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
677 })?;
678
679 let buf_size = max_seq_len * hidden_size;
681 let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
682 crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
683 })?;
684 let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
685 crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
686 })?;
687
688 stream
690 .synchronize()
691 .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;
692
693 println!(
694 " ✓ GPU training state allocated (LM head: {:.1} MB)",
695 (vocab_size * hidden_size * 4) as f64 / 1e6
696 );
697
698 let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
700 let lora_rank = config.lora_rank.unwrap_or(16);
701
702 let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
704 crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
705 })?;
706
707 let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
709 crate::error::Error::ConfigError(format!(
710 "NF4 LoRA grad workspace alloc failed: {e:?}"
711 ))
712 })?;
713
714 let mut lora_opt_states = Vec::with_capacity(num_layers);
716 for (i, block) in cuda_blocks.iter().enumerate() {
717 lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
718 crate::error::Error::ConfigError(format!(
719 "Block {i} LoRA opt state failed: {e:?}"
720 ))
721 })?);
722 }
723
724 println!(
725 " ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
726 );
727 (Some(scratch), Some(grad_ws), Some(lora_opt_states))
728 } else {
729 (None, None, None)
730 };
731
732 let embed_optimizer =
735 AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);
736
737 let grad_accum = if config.accumulation_steps > 1 {
740 let kv_hidden = mc.num_kv_heads * mc.head_dim();
741 let block_sizes =
742 super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
743 hidden_size,
744 kv_hidden,
745 mc.intermediate_size,
746 );
747 let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
748 num_layers,
749 block_sizes,
750 vocab_size,
751 hidden_size,
752 );
753 println!(
754 " ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
755 config.accumulation_steps,
756 (accum
757 .block_grads
758 .iter()
759 .map(super::grad_accumulator::BlockGradientSet::total_elements)
760 .sum::<usize>()
761 + accum.lm_head_grad.len()
762 + accum.final_norm_grad.len()
763 + accum.embedding_grad.len()) as f64
764 * 4.0
765 / 1e6,
766 );
767 Some(accum)
768 } else {
769 None
770 };
771
772 let gpu_grad_accum = if config.accumulation_steps > 1 {
775 match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
776 Ok(accum) => {
777 println!(" ✓ GPU gradient accumulation enabled (ALB-091)");
778 Some(accum)
779 }
780 Err(e) => {
781 eprintln!(
782 " [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
783 );
784 None
785 }
786 }
787 } else {
788 None
789 };
790
791 let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
794 let ws_max = hidden_size * mc.intermediate_size;
795 let lm_max = vocab_size * hidden_size;
796 vec![0.0f32; ws_max.max(lm_max)]
797 } else {
798 Vec::new()
799 };
800
801 let kv_hidden = mc.num_kv_heads * mc.head_dim();
804 let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);
805
806 let grad_scaler = GradScaler::from_config(&config.precision_config);
808 if config.precision_config.is_mixed() {
809 println!(
810 " ✓ Mixed precision: {} (loss scale={}, dynamic={})",
811 config.precision_config.compute_precision,
812 grad_scaler.scale(),
813 grad_scaler.is_dynamic(),
814 );
815 }
816
817 Ok(Self {
818 model,
819 cuda_trainer,
820 cuda_blocks,
821 cuda_grad_workspace,
822 nf4_shared_scratch,
823 nf4_lora_grad_workspace,
824 nf4_lora_optimizer_states,
825 gpu_training,
826 lm_head_weight_gpu,
827 lm_head_grad_gpu,
828 lm_head_m,
829 lm_head_v,
830 final_norm_m,
831 final_norm_v,
832 embed_optimizer,
833 profiler: if config.profile_interval > 0 {
835 StepProfiler::new(true, config.profile_interval)
836 } else {
837 StepProfiler::disabled()
838 },
839 config,
840 metrics: MetricsTracker::new(),
841 step: 0,
842 accumulated_loss: 0.0,
843 accumulated_batches: 0,
844 last_grad_norm: 0.0,
845 last_embed_grad_norm: 0.0,
846 grad_accum,
847 gpu_grad_accum,
848 grad_scaler,
849 fwd_scratch_a,
850 fwd_scratch_b,
851 h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
852 d2h_staging,
853 fused_clip,
854 final_norm_zero_buf: vec![0.0f32; hidden_size],
855 })
856 }
857
858 #[allow(clippy::too_many_arguments)]
860 fn upload_blocks(
861 model: &Transformer,
862 mc: &crate::transformer::TransformerConfig,
863 config: &TransformerTrainConfig,
864 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
865 use_nf4: bool,
866 num_layers: usize,
867 hidden_size: usize,
868 max_seq_len: usize,
869 ) -> crate::Result<Vec<CudaBlock>> {
870 let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
871
872 if use_nf4 {
873 let lora_rank = config.lora_rank.unwrap_or(16);
874 let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
875 let lora_scale = lora_alpha / lora_rank as f32;
876 let head_dim = mc.head_dim();
877 let q_dim = mc.num_attention_heads * head_dim;
878 let kv_hidden = mc.num_kv_heads * head_dim;
879
880 for (i, layer) in model.layers.iter().enumerate() {
881 let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
882 .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
883 .collect();
884 let lora_b_q = vec![0.0f32; lora_rank * q_dim];
885 let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
886 .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
887 .collect();
888 let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
889
890 let q_norm_data = layer
891 .self_attn
892 .q_norm
893 .as_ref()
894 .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
895 let k_norm_data = layer
896 .self_attn
897 .k_norm
898 .as_ref()
899 .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
900
901 let block = crate::transformer::CudaNf4TransformerBlock::new(
902 mc,
903 i,
904 ctx.clone(),
905 layer.input_norm.weight.data().as_slice().expect("contiguous"),
906 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
907 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
908 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
909 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
910 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
911 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
912 layer.ffn.w_up.data().as_slice().expect("contiguous"),
913 layer.ffn.w_down.data().as_slice().expect("contiguous"),
914 max_seq_len,
915 Some((&lora_a_q, &lora_b_q)),
916 Some((&lora_a_v, &lora_b_v)),
917 lora_scale,
918 lora_rank,
919 q_norm_data.as_deref(),
920 k_norm_data.as_deref(),
921 )
922 .map_err(|e| {
923 crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
924 })?;
925 cuda_blocks.push(CudaBlock::Nf4(block));
926 }
927 println!(" ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
928 } else {
929 for (i, layer) in model.layers.iter().enumerate() {
930 let b_q = layer
935 .self_attn
936 .b_q
937 .as_ref()
938 .map(|t| t.data().as_slice().expect("contiguous b_q").to_vec());
939 let b_k = layer
940 .self_attn
941 .b_k
942 .as_ref()
943 .map(|t| t.data().as_slice().expect("contiguous b_k").to_vec());
944 let b_v = layer
945 .self_attn
946 .b_v
947 .as_ref()
948 .map(|t| t.data().as_slice().expect("contiguous b_v").to_vec());
949 let block = CudaTransformerBlock::new(
950 mc,
951 i,
952 ctx.clone(),
953 layer.input_norm.weight.data().as_slice().expect("contiguous"),
954 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
955 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
956 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
957 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
958 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
959 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
960 layer.ffn.w_up.data().as_slice().expect("contiguous"),
961 layer.ffn.w_down.data().as_slice().expect("contiguous"),
962 max_seq_len,
963 b_q.as_deref(),
964 b_k.as_deref(),
965 b_v.as_deref(),
966 )
967 .map_err(|e| {
968 crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
969 })?;
970 cuda_blocks.push(CudaBlock::Fp32(block));
971 }
972 println!(" ✓ {num_layers} transformer blocks uploaded to GPU");
973 }
974
975 Ok(cuda_blocks)
976 }
977
978 fn init_fused_clip(
980 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
981 config: &TransformerTrainConfig,
982 hidden_size: usize,
983 kv_hidden: usize,
984 mc: &crate::transformer::TransformerConfig,
985 ) -> Option<FusedClipState> {
986 config.base.max_grad_norm?;
987 let grad_sizes: [u32; 9] = [
988 (hidden_size * hidden_size) as u32,
989 (hidden_size * kv_hidden) as u32,
990 (hidden_size * kv_hidden) as u32,
991 (hidden_size * hidden_size) as u32,
992 (hidden_size * mc.intermediate_size) as u32,
993 (hidden_size * mc.intermediate_size) as u32,
994 (mc.intermediate_size * hidden_size) as u32,
995 hidden_size as u32,
996 hidden_size as u32,
997 ];
998 match FusedClipState::new(ctx, &grad_sizes) {
999 Ok(state) => {
1000 println!(
1001 " ✓ Fused gradient clipping: {} partials ({:.1} KB)",
1002 state.total_partials,
1003 f64::from(state.total_partials) * 4.0 / 1024.0,
1004 );
1005 Some(state)
1006 }
1007 Err(e) => {
1008 println!(" ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
1009 None
1010 }
1011 }
1012 }
1013
1014 fn train_step_single(
1023 &mut self,
1024 input_ids: &[u32],
1025 target_ids: &[u32],
1026 accumulate_only: bool,
1027 ) -> Option<f32> {
1028 self.profiler.begin_step();
1029 let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1030 self.profiler.finish_step();
1031 result
1032 }
1033
1034 fn train_step_inner(
1036 &mut self,
1037 input_ids: &[u32],
1038 target_ids: &[u32],
1039 accumulate_only: bool,
1040 ) -> Option<f32> {
1041 let hidden_size = self.config.model_config.hidden_size;
1042 let vocab_size = self.config.model_config.vocab_size;
1043
1044 let max_sl = self.config.max_seq_len;
1046 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1047 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1048 let seq_len = input_ids.len();
1049
1050 if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1053 eprintln!(
1054 "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1055 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1056 );
1057 return None;
1058 }
1059
1060 self.profiler.begin(StepProfiler::LOSS);
1063 let stream = self.cuda_trainer.stream();
1064
1065 let mut loss_scale = 1.0 / seq_len as f32;
1074 if self.config.accumulation_steps > 1 {
1075 loss_scale /= self.config.accumulation_steps as f32;
1076 }
1077
1078 let loss_val = fused_cross_entropy_cuda(
1080 &mut self.gpu_training.logits_buf,
1081 target_ids,
1082 seq_len as u32,
1083 vocab_size as u32,
1084 loss_scale,
1085 stream,
1086 )
1087 .ok()?;
1088
1089 if !loss_val.is_finite() {
1091 return None;
1092 }
1093 self.profiler.end(StepProfiler::LOSS);
1094
1095 if let Some(grad_output_is_a) =
1104 self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1105 {
1106 self.profiler.begin(StepProfiler::EMBED_BWD);
1108 self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1109
1110 self.profiler.end(StepProfiler::EMBED_BWD);
1111 }
1112
1113 Some(loss_val)
1114 }
1115
1116 #[allow(unsafe_code)]
1121 fn gpu_forward(
1122 &mut self,
1123 input_ids: &[u32],
1124 seq_len: usize,
1125 hidden_size: usize,
1126 vocab_size: usize,
1127 ) -> Option<()> {
1128 contract_pre_gpu_forward!();
1129 let stream = self.cuda_trainer.stream();
1130
1131 self.profiler.begin(StepProfiler::EMBED);
1133 let hidden = self.model.embed_tokens.forward(input_ids);
1134 let hidden_slice = hidden.data().as_slice()?;
1135 self.profiler.end(StepProfiler::EMBED);
1136
1137 self.profiler.begin(StepProfiler::H2D);
1142 self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
1143 self.h2d_staging[hidden_slice.len()..].fill(0.0);
1144 if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
1145 eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
1146 return None;
1147 }
1148 self.profiler.end(StepProfiler::H2D);
1149
1150 self.profiler.begin(StepProfiler::FORWARD);
1154 let mut input_is_a = true; for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
1156 let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
1159 if input_is_a {
1160 (
1161 std::ptr::from_ref(&self.fwd_scratch_a),
1162 std::ptr::from_mut(&mut self.fwd_scratch_b),
1163 )
1164 } else {
1165 (
1166 std::ptr::from_ref(&self.fwd_scratch_b),
1167 std::ptr::from_mut(&mut self.fwd_scratch_a),
1168 )
1169 };
1170 if self.gpu_training.saved_layer_mask[i] {
1171 unsafe {
1174 self.gpu_training.layer_inputs[i]
1175 .copy_from_buffer_async(&*input_ptr, stream)
1176 .ok()?;
1177 }
1178 }
1179 self.profiler.begin_layer();
1182 unsafe {
1183 block
1184 .forward(
1185 &*input_ptr,
1186 &mut *output_ptr,
1187 seq_len,
1188 stream,
1189 self.nf4_shared_scratch.as_mut(),
1190 )
1191 .ok()?;
1192 }
1193 self.profiler.end_layer_fwd(i);
1194 input_is_a = !input_is_a;
1195 }
1196 self.profiler.end(StepProfiler::FORWARD);
1197
1198 let final_output: &GpuBuffer<f32> =
1200 if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };
1201
1202 self.profiler.begin(StepProfiler::NORM_LM);
1205 unsafe {
1206 self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
1207 }
1208
1209 rms_norm_forward_with_eps(
1214 final_output,
1215 &self.gpu_training.final_norm_weight,
1216 &mut self.gpu_training.norm_output,
1217 seq_len as u32,
1218 hidden_size as u32,
1219 self.config.model_config.rms_norm_eps,
1220 stream,
1221 )
1222 .ok()?;
1223
1224 gemm_forward(
1228 &self.gpu_training.norm_output,
1229 &self.lm_head_weight_gpu,
1230 &mut self.gpu_training.logits_buf,
1231 seq_len as u32,
1232 hidden_size as u32,
1233 vocab_size as u32,
1234 stream,
1235 )
1236 .ok()?;
1237
1238 self.profiler.end(StepProfiler::NORM_LM);
1241
1242 Some(())
1243 }
1244
1245 pub fn forward_backward_with_grad(
1283 &mut self,
1284 input_ids: &[u32],
1285 logit_gradient: &[f32],
1286 ) -> Option<()> {
1287 let seq_len = input_ids.len();
1288 let hidden_size = self.config.model_config.hidden_size;
1289 let vocab_size = self.config.model_config.vocab_size;
1290
1291 if seq_len == 0 || seq_len > self.config.max_seq_len {
1292 return None;
1293 }
1294 if logit_gradient.len() != vocab_size {
1295 eprintln!(
1296 "[forward_backward_with_grad] gradient len {} != vocab_size {}",
1297 logit_gradient.len(),
1298 vocab_size
1299 );
1300 return None;
1301 }
1302
1303 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1304
1305 let offset = (seq_len - 1) * vocab_size;
1309 self.gpu_training.logits_buf.copy_from_host_at(logit_gradient, offset).ok()?;
1310 let stream = self.cuda_trainer.stream();
1311 stream.synchronize().ok()?;
1312
1313 let grad_output_is_a = self.gpu_backward(seq_len, hidden_size, vocab_size, false)?;
1316 self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1320
1321 Some(())
1322 }
1323
1324 pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1333 let seq_len = input_ids.len();
1334 let hidden_size = self.config.model_config.hidden_size;
1335 let vocab_size = self.config.model_config.vocab_size;
1336
1337 if seq_len == 0 || seq_len > self.config.max_seq_len {
1338 return None;
1339 }
1340
1341 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1343
1344 let stream = self.cuda_trainer.stream();
1346 stream.synchronize().ok()?;
1347
1348 let offset = (seq_len - 1) * vocab_size;
1350 let mut logits = vec![0.0f32; vocab_size];
1351 self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1352
1353 Some(logits)
1354 }
1355
1356 #[allow(unsafe_code)]
1373 fn recompute_segment(
1374 gpu_training: &mut GpuPretrainState,
1375 cuda_blocks: &mut [CudaBlock],
1376 nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1377 target_layer: usize,
1378 seq_len: usize,
1379 stream: &CudaStream,
1380 ) -> Option<()> {
1381 let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1383
1384 if seg_start == target_layer {
1385 return Some(()); }
1387
1388 let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1391 unsafe {
1392 recompute_buf
1393 .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1394 .ok()?;
1395 }
1396
1397 for i in seg_start..target_layer {
1408 if i == seg_start {
1409 let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1411 let li = &mut gpu_training.layer_inputs;
1412 unsafe {
1413 cuda_blocks[i]
1414 .forward(
1415 &*recompute_ptr,
1416 &mut li[i + 1],
1417 seq_len,
1418 stream,
1419 nf4_shared_scratch.as_mut(),
1420 )
1421 .ok()?;
1422 }
1423 } else {
1424 let li = &mut gpu_training.layer_inputs;
1426 let (left, right) = li.split_at_mut(i + 1);
1427 cuda_blocks[i]
1428 .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1429 .ok()?;
1430 }
1431 }
1432
1433 Some(())
1434 }
1435
1436 #[allow(unsafe_code)]
1449 fn gpu_backward(
1450 &mut self,
1451 seq_len: usize,
1452 hidden_size: usize,
1453 vocab_size: usize,
1454 accumulate_only: bool,
1455 ) -> Option<bool> {
1456 let stream = self.cuda_trainer.stream();
1457 let max_grad_norm = self.config.base.max_grad_norm;
1458 let lr = self.current_lr();
1459 let beta1 = self.config.beta1;
1461 let beta2 = self.config.beta2;
1462 let weight_decay = self.config.weight_decay;
1463
1464 self.profiler.begin(StepProfiler::LM_BWD);
1469 gemm_backward_a(
1470 &self.gpu_training.logits_buf,
1471 &self.lm_head_weight_gpu,
1472 &mut self.gpu_training.lm_head_grad_hidden,
1473 seq_len as u32,
1474 hidden_size as u32,
1475 vocab_size as u32,
1476 stream,
1477 )
1478 .ok()?;
1479
1480 gemm_backward_b(
1481 &self.gpu_training.norm_output,
1482 &self.gpu_training.logits_buf,
1483 &mut self.lm_head_grad_gpu,
1484 seq_len as u32,
1485 hidden_size as u32,
1486 vocab_size as u32,
1487 stream,
1488 )
1489 .ok()?;
1490
1491 let lm_sq_norm =
1497 squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1498 .unwrap_or(0.0);
1499 let lm_norm = lm_sq_norm.sqrt(); self.last_grad_norm = lm_norm; if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1503 eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1504 let gh_sq = squared_sum_cuda(
1506 &self.gpu_training.lm_head_grad_hidden,
1507 self.gpu_training.lm_head_grad_hidden.len() as u32,
1508 stream,
1509 )
1510 .unwrap_or(0.0);
1511 eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1512 }
1513 if let Some(max_norm) = max_grad_norm {
1514 let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1515 let n = self.lm_head_grad_gpu.len() as u32;
1516 let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1517 }
1518 self.profiler.end(StepProfiler::LM_BWD);
1519
1520 self.profiler.begin(StepProfiler::NORM_BWD);
1522 self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1524 rms_norm_backward(
1525 &self.gpu_training.blocks_output,
1526 &self.gpu_training.final_norm_weight,
1527 &self.gpu_training.lm_head_grad_hidden,
1528 &mut self.gpu_training.grad_buf_a,
1529 &mut self.gpu_training.grad_final_norm_weight,
1530 seq_len as u32,
1531 hidden_size as u32,
1532 1e-5_f32,
1533 stream,
1534 )
1535 .ok()?;
1536
1537 if let Some(max_norm) = max_grad_norm {
1540 let (scale, _) = Self::compute_clip_scale_with_norm(
1541 &self.gpu_training.grad_final_norm_weight,
1542 max_norm,
1543 stream,
1544 );
1545 let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1546 let _ =
1547 gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1548 }
1549 self.profiler.end(StepProfiler::NORM_BWD);
1550
1551 if accumulate_only {
1553 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1555 let _ = gpu_accum.accumulate_nonblock(
1556 &self.lm_head_grad_gpu,
1557 &self.gpu_training.grad_final_norm_weight,
1558 stream,
1559 );
1560 } else {
1561 stream.synchronize().ok()?;
1562 Self::download_nonblock_grads_to_accum(
1563 &self.lm_head_grad_gpu,
1564 &self.gpu_training.grad_final_norm_weight,
1565 &mut self.grad_accum,
1566 &mut self.d2h_staging,
1567 )?;
1568 }
1569 } else {
1570 Self::run_nonblock_optimizer_step(
1571 &mut self.gpu_training,
1572 Some(&mut self.lm_head_weight_gpu),
1573 &self.lm_head_grad_gpu,
1574 &mut self.lm_head_m,
1575 &mut self.lm_head_v,
1576 &mut self.final_norm_m,
1577 &mut self.final_norm_v,
1578 lr,
1579 beta1,
1580 beta2,
1581 weight_decay,
1582 stream,
1583 );
1584 }
1585
1586 self.profiler.begin(StepProfiler::BLK_BWD);
1592 let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1593 let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1594 let mut grad_output_is_a = true;
1595 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1596
1597 for layer_idx in (0..self.cuda_blocks.len()).rev() {
1598 if !self.gpu_training.saved_layer_mask[layer_idx] {
1601 Self::recompute_segment(
1602 &mut self.gpu_training,
1603 &mut self.cuda_blocks,
1604 &mut self.nf4_shared_scratch,
1605 layer_idx,
1606 seq_len,
1607 stream,
1608 )?;
1609 }
1610
1611 let (grad_output, grad_input) = unsafe {
1612 if grad_output_is_a {
1613 (&*grad_a_ptr, &mut *grad_b_ptr)
1614 } else {
1615 (&*grad_b_ptr, &mut *grad_a_ptr)
1616 }
1617 };
1618
1619 self.profiler.begin_layer();
1620 if use_nf4 {
1621 let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1625 grad_b_ptr } else {
1627 grad_a_ptr
1628 };
1629 match self.cuda_blocks[layer_idx].backward_nf4(
1632 &self.gpu_training.layer_inputs[layer_idx],
1633 grad_output,
1634 grad_input,
1635 &mut self.gpu_training.blocks_output, seq_len,
1637 stream,
1638 self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1639 self.nf4_lora_grad_workspace
1640 .as_mut()
1641 .expect("NF4 requires LoRA grad workspace"),
1642 ) {
1643 Ok(()) => {}
1644 Err(e) => {
1645 eprintln!(
1646 "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1647 layer_idx, e, seq_len, self.config.model_config.hidden_size
1648 );
1649 return None;
1650 }
1651 }
1652
1653 if let Some(max_norm) = max_grad_norm {
1657 self.nf4_lora_grad_workspace
1658 .as_mut()
1659 .expect("NF4 requires LoRA grad ws")
1660 .clip_gradients(max_norm, stream);
1661 }
1662
1663 {
1669 let step = self.gpu_training.step;
1670 let effective_lr = if accumulate_only {
1671 lr / self.config.accumulation_steps as f32
1672 } else {
1673 lr
1674 };
1675 if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1676 let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1677 &mut opt_states[layer_idx],
1678 step,
1679 effective_lr,
1680 beta1,
1681 beta2,
1682 1e-8,
1683 weight_decay,
1684 stream,
1685 self.nf4_lora_grad_workspace
1686 .as_ref()
1687 .expect("NF4 requires LoRA grad ws"),
1688 );
1689 }
1690 }
1691 } else {
1692 self.cuda_blocks[layer_idx]
1694 .backward(
1695 &self.gpu_training.layer_inputs[layer_idx],
1696 grad_output,
1697 grad_input,
1698 seq_len,
1699 stream,
1700 &mut self.cuda_grad_workspace,
1701 )
1702 .ok()?;
1703
1704 if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1710 let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1711 &self.cuda_grad_workspace,
1712 f32::MAX,
1713 stream,
1714 );
1715 let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1717 .unwrap_or(0.0);
1718 let act_gnorm = act_sq.sqrt();
1719 eprintln!(
1720 "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1721 );
1722 }
1723
1724 if accumulate_only {
1726 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1728 let _ = gpu_accum.accumulate_block(
1729 &self.cuda_grad_workspace,
1730 layer_idx,
1731 stream,
1732 );
1733 } else {
1734 stream.synchronize().ok()?;
1736 if let Some(accum) = &mut self.grad_accum {
1737 Self::download_workspace_to_accum(
1738 &self.cuda_grad_workspace,
1739 accum,
1740 layer_idx,
1741 &mut self.d2h_staging,
1742 )?;
1743 }
1744 }
1745 } else {
1746 let step = self.gpu_training.step;
1748 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1749 &mut self.gpu_training.optimizer_states[layer_idx],
1750 step,
1751 lr,
1752 beta1,
1753 beta2,
1754 1e-8,
1755 weight_decay,
1756 stream,
1757 &self.cuda_grad_workspace,
1758 );
1759 }
1760 }
1761
1762 self.profiler.end_layer_bwd(layer_idx);
1763 grad_output_is_a = !grad_output_is_a;
1764 }
1765
1766 stream.synchronize().ok()?;
1767 self.profiler.end(StepProfiler::BLK_BWD);
1768
1769 Some(grad_output_is_a)
1770 }
1771
1772 fn download_nonblock_grads_to_accum(
1778 lm_head_grad: &GpuBuffer<f32>,
1779 final_norm_grad: &GpuBuffer<f32>,
1780 grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1781 host: &mut [f32],
1782 ) -> Option<()> {
1783 let accum = grad_accum.as_mut()?;
1784
1785 let lm_slice = &mut host[..lm_head_grad.len()];
1786 lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1787 for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1788 *d += s;
1789 }
1790
1791 let norm_slice = &mut host[..final_norm_grad.len()];
1792 final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1793 for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1794 *d += s;
1795 }
1796 Some(())
1797 }
1798
1799 #[allow(clippy::too_many_arguments)]
1802 fn run_nonblock_optimizer_step(
1803 gpu_training: &mut GpuPretrainState,
1804 lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1805 lm_head_grad_gpu: &GpuBuffer<f32>,
1806 lm_head_m: &mut GpuBuffer<f32>,
1807 lm_head_v: &mut GpuBuffer<f32>,
1808 final_norm_m: &mut GpuBuffer<f32>,
1809 final_norm_v: &mut GpuBuffer<f32>,
1810 lr: f32,
1811 beta1: f32,
1812 beta2: f32,
1813 weight_decay: f32,
1814 stream: &CudaStream,
1815 ) {
1816 gpu_training.step += 1;
1817 let step = gpu_training.step;
1818
1819 if let Some(lm_head_weight) = lm_head_weight_gpu {
1820 let n_lm = lm_head_weight.len() as u32;
1821 let _ = adamw_step_cuda(
1822 lm_head_weight,
1823 lm_head_grad_gpu,
1824 lm_head_m,
1825 lm_head_v,
1826 lr,
1827 beta1,
1828 beta2,
1829 1e-8,
1830 weight_decay,
1831 step,
1832 n_lm,
1833 stream,
1834 );
1835 }
1836
1837 let n_norm = gpu_training.final_norm_weight.len() as u32;
1838 let _ = adamw_step_cuda(
1839 &mut gpu_training.final_norm_weight,
1840 &gpu_training.grad_final_norm_weight,
1841 final_norm_m,
1842 final_norm_v,
1843 lr,
1844 beta1,
1845 beta2,
1846 1e-8,
1847 weight_decay,
1848 step,
1849 n_norm,
1850 stream,
1851 );
1852 }
1853
1854 fn download_workspace_to_accum(
1862 ws: &CudaGradWorkspace,
1863 accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1864 layer_idx: usize,
1865 host: &mut [f32],
1866 ) -> Option<()> {
1867 let bg = &mut accum.block_grads[layer_idx];
1868
1869 use super::grad_accumulator::component;
1870 let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1871 (&ws.grad_w_q, component::W_Q),
1872 (&ws.grad_w_k, component::W_K),
1873 (&ws.grad_w_v, component::W_V),
1874 (&ws.grad_w_o, component::W_O),
1875 (&ws.grad_gate, component::GATE),
1876 (&ws.grad_up, component::UP),
1877 (&ws.grad_down, component::DOWN),
1878 (&ws.grad_input_norm, component::INPUT_NORM),
1879 (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1880 ];
1881
1882 for (gpu_buf, comp_idx) in &bufs_and_components {
1883 let slice = &mut host[..gpu_buf.len()];
1884 gpu_buf.copy_to_host_at(slice, 0).ok()?;
1885 for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1886 *d += s;
1887 }
1888 }
1889 Some(())
1890 }
1891
1892 fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1899 let stream = self.cuda_trainer.stream();
1900 let lr = self.current_lr();
1901 let beta1 = self.config.beta1;
1902 let beta2 = self.config.beta2;
1903 let weight_decay = self.config.weight_decay;
1904
1905 stream.synchronize().ok()?;
1907
1908 self.gpu_training.step += 1;
1909 let step = self.gpu_training.step;
1910
1911 let gpu_accum = self.gpu_grad_accum.as_ref()?;
1913 for layer_idx in 0..self.cuda_blocks.len() {
1914 gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1915
1916 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1917 &mut self.gpu_training.optimizer_states[layer_idx],
1918 step,
1919 lr,
1920 beta1,
1921 beta2,
1922 1e-8,
1923 weight_decay,
1924 stream,
1925 &self.cuda_grad_workspace,
1926 );
1927 }
1928
1929 gpu_accum
1931 .upload_nonblock(
1932 &mut self.lm_head_grad_gpu,
1933 &mut self.gpu_training.grad_final_norm_weight,
1934 )
1935 .ok()?;
1936
1937 let n_lm = self.lm_head_weight_gpu.len() as u32;
1938 let _ = adamw_step_cuda(
1939 &mut self.lm_head_weight_gpu,
1940 &self.lm_head_grad_gpu,
1941 &mut self.lm_head_m,
1942 &mut self.lm_head_v,
1943 lr,
1944 beta1,
1945 beta2,
1946 1e-8,
1947 weight_decay,
1948 step,
1949 n_lm,
1950 stream,
1951 );
1952
1953 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1955 let _ = adamw_step_cuda(
1956 &mut self.gpu_training.final_norm_weight,
1957 &self.gpu_training.grad_final_norm_weight,
1958 &mut self.final_norm_m,
1959 &mut self.final_norm_v,
1960 lr,
1961 beta1,
1962 beta2,
1963 1e-8,
1964 weight_decay,
1965 step,
1966 n_norm,
1967 stream,
1968 );
1969
1970 stream.synchronize().ok()?;
1971
1972 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1974 let _ = gpu_accum.zero_all();
1975 }
1976
1977 Some(())
1978 }
1979
1980 #[allow(unsafe_code)]
1981 fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1982 let stream = self.cuda_trainer.stream();
1983 let lr = self.current_lr();
1984 let beta1 = self.config.beta1;
1985 let beta2 = self.config.beta2;
1986 let weight_decay = self.config.weight_decay;
1987
1988 let accum = self.grad_accum.as_mut()?;
1990 accum.average();
1991
1992 if accum.has_non_finite() {
1994 println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1995 accum.zero_all();
1996 return Some(());
1997 }
1998
1999 self.gpu_training.step += 1;
2000 let step = self.gpu_training.step;
2001
2002 use super::grad_accumulator::component;
2004 for layer_idx in 0..self.cuda_blocks.len() {
2005 let bg = &accum.block_grads[layer_idx];
2006
2007 unsafe {
2011 self.cuda_grad_workspace
2012 .grad_w_q
2013 .copy_from_host_async(&bg.components[component::W_Q], stream)
2014 .ok()?;
2015 self.cuda_grad_workspace
2016 .grad_w_k
2017 .copy_from_host_async(&bg.components[component::W_K], stream)
2018 .ok()?;
2019 self.cuda_grad_workspace
2020 .grad_w_v
2021 .copy_from_host_async(&bg.components[component::W_V], stream)
2022 .ok()?;
2023 self.cuda_grad_workspace
2024 .grad_w_o
2025 .copy_from_host_async(&bg.components[component::W_O], stream)
2026 .ok()?;
2027 self.cuda_grad_workspace
2028 .grad_gate
2029 .copy_from_host_async(&bg.components[component::GATE], stream)
2030 .ok()?;
2031 self.cuda_grad_workspace
2032 .grad_up
2033 .copy_from_host_async(&bg.components[component::UP], stream)
2034 .ok()?;
2035 self.cuda_grad_workspace
2036 .grad_down
2037 .copy_from_host_async(&bg.components[component::DOWN], stream)
2038 .ok()?;
2039 self.cuda_grad_workspace
2040 .grad_input_norm
2041 .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
2042 .ok()?;
2043 self.cuda_grad_workspace
2044 .grad_post_attn_norm
2045 .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
2046 .ok()?;
2047 }
2048
2049 let _ = self.cuda_blocks[layer_idx].optimizer_step(
2051 &mut self.gpu_training.optimizer_states[layer_idx],
2052 step,
2053 lr,
2054 beta1,
2055 beta2,
2056 1e-8,
2057 weight_decay,
2058 stream,
2059 &self.cuda_grad_workspace,
2060 );
2061 }
2062
2063 unsafe {
2067 self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
2068 }
2069 let n_lm = self.lm_head_weight_gpu.len() as u32;
2070 let _ = adamw_step_cuda(
2071 &mut self.lm_head_weight_gpu,
2072 &self.lm_head_grad_gpu,
2073 &mut self.lm_head_m,
2074 &mut self.lm_head_v,
2075 lr,
2076 beta1,
2077 beta2,
2078 1e-8,
2079 weight_decay,
2080 step,
2081 n_lm,
2082 stream,
2083 );
2084
2085 unsafe {
2088 self.gpu_training
2089 .grad_final_norm_weight
2090 .copy_from_host_async(&accum.final_norm_grad, stream)
2091 .ok()?;
2092 }
2093 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
2094 let _ = adamw_step_cuda(
2095 &mut self.gpu_training.final_norm_weight,
2096 &self.gpu_training.grad_final_norm_weight,
2097 &mut self.final_norm_m,
2098 &mut self.final_norm_v,
2099 lr,
2100 beta1,
2101 beta2,
2102 1e-8,
2103 weight_decay,
2104 step,
2105 n_norm,
2106 stream,
2107 );
2108
2109 stream.synchronize().ok()?;
2110
2111 accum.zero_all();
2113 Some(())
2114 }
2115
2116 fn compute_clip_scale_with_norm(
2129 buf: &GpuBuffer<f32>,
2130 max_norm: f32,
2131 stream: &CudaStream,
2132 ) -> (f32, f32) {
2133 let n = buf.len() as u32;
2134 let grad_norm = match squared_sum_cuda(buf, n, stream) {
2136 Ok(norm) => norm,
2137 Err(_) => {
2138 let mut host = vec![0.0f32; buf.len()];
2140 if buf.copy_to_host_at(&mut host, 0).is_err() {
2141 return (1.0, 0.0);
2142 }
2143 let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2144 sq_sum.sqrt() as f32
2145 }
2146 };
2147 let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2148 (scale, grad_norm)
2149 }
2150
2151 #[allow(unsafe_code)]
2160 fn embed_backward(
2161 &mut self,
2162 input_ids: &[u32],
2163 _seq_len: usize,
2164 hidden_size: usize,
2165 vocab_size: usize,
2166 grad_output_is_a: bool,
2167 ) -> Option<()> {
2168 let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2170 let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2171 let embed_grad_buf = unsafe {
2172 if grad_output_is_a {
2173 &*grad_a_ptr
2174 } else {
2175 &*grad_b_ptr
2176 }
2177 };
2178 let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2179
2180 let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2188 {
2189 let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2190 let grad_norm = sq_sum.sqrt() as f32;
2191 self.last_embed_grad_norm = grad_norm; if grad_norm > embed_clip_norm {
2193 let scale = embed_clip_norm / grad_norm;
2194 for g in &mut embed_grad_data {
2195 *g *= scale;
2196 }
2197 }
2198 }
2199
2200 let embed_weight = &mut self.model.embed_tokens.weight;
2204 let grad_cell = embed_weight.grad_cell();
2205 let mut grad_ref = grad_cell.borrow_mut();
2206 if grad_ref.is_none() {
2207 *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2208 }
2209 if let Some(grad) = grad_ref.as_mut() {
2210 for (pos, &token_id) in input_ids.iter().enumerate() {
2211 let tid = token_id as usize;
2212 if tid < vocab_size {
2213 let src = pos * hidden_size;
2214 let dst = tid * hidden_size;
2215 for h in 0..hidden_size {
2216 grad[dst + h] += embed_grad_data[src + h];
2217 }
2218 }
2219 }
2220 }
2221 Some(())
2222 }
2223
2224 fn optimizer_step(&mut self) {
2230 self.grad_scaler.update(true);
2234
2235 self.embed_optimizer.set_lr(self.current_lr());
2237 let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2239 self.embed_optimizer.step_refs(&mut embed_params);
2240
2241 self.step += 1;
2242 self.metrics.losses.push(self.accumulated_loss);
2243 self.metrics.increment_step();
2244
2245 self.accumulated_loss = 0.0;
2246 self.accumulated_batches = 0;
2247 }
2248
2249 pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2261 if batch.batch_size == 0 {
2262 return 0.0;
2263 }
2264
2265 let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2266
2267 if self.accumulated_batches == 0 {
2268 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2270 }
2271
2272 let mut total_loss = 0.0;
2273 let mut valid_count = 0;
2274
2275 for i in 0..batch.batch_size {
2276 let Some(input_ids) = batch.get_input(i) else {
2277 continue;
2278 };
2279 let Some(target_ids) = batch.get_target(i) else {
2280 continue;
2281 };
2282
2283 if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2287 total_loss += loss;
2288 valid_count += 1;
2289 if accumulating {
2290 if let Some(accum) = &mut self.gpu_grad_accum {
2291 accum.accumulated_count += 1;
2292 } else if let Some(accum) = &mut self.grad_accum {
2293 accum.accumulated_count += 1;
2294 }
2295 }
2296 }
2297 }
2298
2299 let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2300
2301 if avg_loss == 0.0 && valid_count > 0 {
2303 eprintln!(
2304 "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2305 valid_count, total_loss, batch.batch_size
2306 );
2307 }
2308
2309 self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2310 self.accumulated_batches += 1;
2311
2312 if self.accumulated_batches >= self.config.accumulation_steps {
2313 if accumulating {
2314 if self.gpu_grad_accum.is_some() {
2316 self.gpu_optimizer_from_gpu_accum();
2317 } else {
2318 self.gpu_optimizer_from_accum();
2319 }
2320 }
2321 self.optimizer_step();
2322 }
2323
2324 avg_loss
2325 }
2326
2327 pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2331 let hidden_size = self.config.model_config.hidden_size;
2332 let vocab_size = self.config.model_config.vocab_size;
2333 let max_sl = self.config.max_seq_len;
2334 let mut total_loss = 0.0;
2335 let mut valid_count = 0;
2336 for i in 0..batch.batch_size {
2337 if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2338 {
2339 total_loss += loss;
2340 valid_count += 1;
2341 }
2342 }
2343 if valid_count > 0 {
2344 total_loss / valid_count as f32
2345 } else {
2346 0.0
2347 }
2348 }
2349
2350 fn eval_single_sequence(
2352 &mut self,
2353 batch: &LMBatch,
2354 i: usize,
2355 max_sl: usize,
2356 hidden_size: usize,
2357 vocab_size: usize,
2358 ) -> Option<f32> {
2359 let input_ids = batch.get_input(i)?;
2360 let target_ids = batch.get_target(i)?;
2361 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2363 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2364 let seq_len = input_ids.len();
2365 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2366 let stream = self.cuda_trainer.stream();
2367 let scale = 1.0 / seq_len as f32;
2368 let loss = fused_cross_entropy_cuda(
2369 &mut self.gpu_training.logits_buf,
2370 target_ids,
2371 seq_len as u32,
2372 vocab_size as u32,
2373 scale,
2374 stream,
2375 )
2376 .ok()?;
2377 if loss.is_finite() {
2378 Some(loss)
2379 } else {
2380 None
2381 }
2382 }
2383
2384 pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2386 self.train_epoch_with_callback(batches, |_, _, _| {})
2387 }
2388
2389 pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2393 where
2394 F: FnMut(usize, f32, &Self),
2395 {
2396 if batches.is_empty() {
2397 return 0.0;
2398 }
2399
2400 let mut total_loss = 0.0;
2401 let mut batches_processed = 0;
2402
2403 for (i, batch) in batches.iter().enumerate() {
2404 if let Some(max) = self.config.max_steps {
2405 if self.step >= max {
2406 break;
2407 }
2408 }
2409
2410 let batch_loss = self.train_batch(batch);
2411 total_loss += batch_loss;
2412 batches_processed += 1;
2413 on_batch(i, batch_loss, self);
2414 }
2415
2416 if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2418 self.profiler.print_report();
2419 }
2420
2421 total_loss / batches_processed.max(1) as f32
2422 }
2423
2424 pub(crate) fn ensure_grad_accum(&mut self) {
2431 if self.grad_accum.is_some() {
2432 return;
2433 }
2434 let mc = &self.config.model_config;
2435 let hidden_size = mc.hidden_size;
2436 let kv_hidden = mc.num_kv_heads * mc.head_dim();
2437 let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2438 hidden_size,
2439 kv_hidden,
2440 mc.intermediate_size,
2441 );
2442 self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2443 self.cuda_blocks.len(),
2444 block_sizes,
2445 mc.vocab_size,
2446 hidden_size,
2447 ));
2448 }
2449
2450 pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2455 if batch.batch_size == 0 {
2456 return 0.0;
2457 }
2458
2459 if self.accumulated_batches == 0 {
2460 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2461 }
2462
2463 let mut total_loss = 0.0;
2464 let mut valid_count = 0;
2465
2466 for i in 0..batch.batch_size {
2467 let Some(input_ids) = batch.get_input(i) else { continue };
2468 let Some(target_ids) = batch.get_target(i) else { continue };
2469
2470 if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2472 total_loss += loss;
2473 valid_count += 1;
2474 if let Some(accum) = &mut self.grad_accum {
2475 accum.accumulated_count += 1;
2476 }
2477 }
2478 }
2479
2480 if valid_count > 0 {
2481 total_loss / valid_count as f32
2482 } else {
2483 0.0
2484 }
2485 }
2486
2487 pub(crate) fn apply_ddp_gradients(&mut self) {
2493 self.accumulated_loss = 0.0;
2494 self.accumulated_batches = 0;
2495 self.gpu_optimizer_from_accum();
2496 self.optimizer_step();
2497 }
2498
2499 pub(crate) fn grad_accum_ref(
2501 &self,
2502 ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2503 self.grad_accum.as_ref()
2504 }
2505
2506 pub(crate) fn grad_accum_mut(
2508 &mut self,
2509 ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2510 self.grad_accum.as_mut()
2511 }
2512
2513 pub(crate) fn config(&self) -> &TransformerTrainConfig {
2515 &self.config
2516 }
2517
2518 pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2520 self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2521 }
2522
2523 pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2525 self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2526 }
2527
2528 pub fn reached_max_steps(&self) -> bool {
2530 self.config.max_steps.is_some_and(|max| self.step >= max)
2531 }
2532
2533 pub fn step(&self) -> usize {
2535 self.step
2536 }
2537
2538 pub fn set_initial_step(&mut self, step: usize) {
2544 self.step = step;
2545 self.gpu_training.step = step as u32;
2546 }
2547
2548 pub fn set_max_steps(&mut self, max_steps: usize) {
2554 self.config.max_steps = Some(max_steps);
2555 }
2556
2557 pub fn current_lr(&self) -> f32 {
2563 let base_lr = self.config.lr;
2564 if self.step < self.config.warmup_steps {
2565 base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2567 } else if let Some(max_steps) = self.config.max_steps {
2568 let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2570 if decay_steps == 0 {
2571 return base_lr;
2572 }
2573 let decay_step = self.step - self.config.warmup_steps;
2574 let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2575 0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2576 } else {
2577 base_lr
2579 }
2580 }
2581
2582 pub fn enable_profiler(&mut self, interval: usize) {
2593 self.profiler = StepProfiler::new(true, interval);
2594 }
2595
2596 pub fn print_profiler_report(&self) {
2598 self.profiler.print_report();
2599 }
2600
2601 pub fn last_grad_norm(&self) -> f32 {
2603 self.last_grad_norm
2604 }
2605
2606 pub fn param_grad_norms(&self) -> (f32, f32) {
2609 (self.last_grad_norm, self.last_embed_grad_norm)
2610 }
2611
2612 pub fn num_params(&self) -> usize {
2614 self.model.parameters().iter().map(|t| t.len()).sum()
2615 }
2616
2617 pub fn gpu_memory_mb(&self) -> (u64, u64) {
2619 match self.cuda_trainer.context().memory_info() {
2620 Ok((free, total)) => {
2621 let total_mb = (total / (1024 * 1024)) as u64;
2622 let used_mb = ((total - free) / (1024 * 1024)) as u64;
2623 (used_mb, total_mb)
2624 }
2625 Err(_) => (0, 0),
2626 }
2627 }
2628
2629 pub fn sync_weights_to_cpu(&mut self) {
2635 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2636
2637 if use_nf4 {
2638 } else {
2644 for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2645 if let Ok(weights) = block.download_weights() {
2646 let layer = &mut self.model.layers[layer_idx];
2647
2648 layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2649 layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2650 layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2651 layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2652
2653 layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2654 layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2655 layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2656
2657 layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2658 layer.post_attn_norm.weight =
2659 Tensor::from_vec(weights.post_attn_norm_weight, false);
2660 }
2661 }
2662 }
2663
2664 if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2666 self.model.norm.weight = Tensor::from_vec(norm_data, false);
2667 }
2668
2669 if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2676 self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2677 }
2678 }
2679
2680 pub fn model(&self) -> &Transformer {
2682 &self.model
2683 }
2684
2685 pub fn model_mut(&mut self) -> &mut Transformer {
2687 &mut self.model
2688 }
2689
2690 pub fn is_mixed_precision(&self) -> bool {
2692 self.config.precision_config.is_mixed()
2693 }
2694
2695 pub fn grad_scaler(&self) -> &GradScaler {
2697 &self.grad_scaler
2698 }
2699
2700 pub fn is_checkpointing(&self) -> bool {
2702 self.config.checkpoint_config.enabled
2703 }
2704
2705 pub fn save(
2707 &mut self,
2708 path: impl AsRef<std::path::Path>,
2709 name: &str,
2710 architecture: &str,
2711 ) -> crate::Result<()> {
2712 self.sync_weights_to_cpu();
2713
2714 let params: Vec<(String, Tensor)> = self
2716 .model
2717 .named_parameters()
2718 .into_iter()
2719 .map(|(name, tensor)| (name, tensor.clone()))
2720 .collect();
2721
2722 let metadata = ModelMetadata::new(name, architecture);
2723 let model = Model::new(metadata, params);
2724 let config = SaveConfig::new(ModelFormat::SafeTensors);
2725
2726 save_model(&model, path, &config)
2727 }
2728
2729 pub fn prepare_async_save(
2733 &mut self,
2734 name: &str,
2735 architecture: &str,
2736 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2737 self.sync_weights_to_cpu();
2738
2739 let param_data: Vec<(String, Vec<f32>)> = self
2741 .model
2742 .named_parameters()
2743 .into_iter()
2744 .map(|(n, t)| (n, t.data().to_vec()))
2745 .collect();
2746
2747 let name = name.to_string();
2748 let architecture = architecture.to_string();
2749
2750 Box::new(move |path: &std::path::Path| {
2751 let params: Vec<(String, Tensor)> =
2752 param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2753 let metadata = ModelMetadata::new(&name, &architecture);
2754 let model = Model::new(metadata, params);
2755 let config = SaveConfig::new(ModelFormat::SafeTensors);
2756 save_model(&model, path, &config)
2757 })
2758 }
2759
2760 pub fn save_apr(
2765 &mut self,
2766 path: impl AsRef<std::path::Path>,
2767 name: &str,
2768 architecture: &str,
2769 ) -> crate::Result<()> {
2770 self.save_apr_with_tokenizer(path, name, architecture, None)
2771 }
2772
2773 pub fn save_apr_with_tokenizer(
2782 &mut self,
2783 path: impl AsRef<std::path::Path>,
2784 name: &str,
2785 architecture: &str,
2786 tokenizer_dir: Option<&std::path::Path>,
2787 ) -> crate::Result<()> {
2788 self.sync_weights_to_cpu();
2789
2790 let params: Vec<(String, Tensor)> = self
2791 .model
2792 .named_parameters()
2793 .into_iter()
2794 .map(|(name, tensor)| (name, tensor.clone()))
2795 .collect();
2796
2797 use crate::io::save::infer_all_tensor_shapes;
2802 use aprender::serialization::apr::AprWriter;
2803 use serde_json::Value as Jv;
2804
2805 let mc = &self.config.model_config;
2806 let mut writer = AprWriter::new();
2807
2808 writer.set_metadata("model_name", Jv::String(name.to_string()));
2810 writer.set_metadata("architecture", Jv::String(architecture.to_string()));
2811 writer.set_metadata("version", Jv::String("0.1.0".into()));
2812 writer.set_metadata("format", Jv::String("entrenar-checkpoint".into()));
2813
2814 writer.set_metadata(
2817 "hidden_size",
2818 Jv::Number(serde_json::Number::from(mc.hidden_size as u64)),
2819 );
2820 writer.set_metadata(
2821 "num_hidden_layers",
2822 Jv::Number(serde_json::Number::from(mc.num_hidden_layers as u64)),
2823 );
2824 writer.set_metadata(
2825 "num_attention_heads",
2826 Jv::Number(serde_json::Number::from(mc.num_attention_heads as u64)),
2827 );
2828 writer.set_metadata(
2829 "num_kv_heads",
2830 Jv::Number(serde_json::Number::from(mc.num_kv_heads as u64)),
2831 );
2832 writer.set_metadata(
2833 "intermediate_size",
2834 Jv::Number(serde_json::Number::from(mc.intermediate_size as u64)),
2835 );
2836 writer
2837 .set_metadata("vocab_size", Jv::Number(serde_json::Number::from(mc.vocab_size as u64)));
2838 writer.set_metadata(
2839 "max_position_embeddings",
2840 Jv::Number(serde_json::Number::from(mc.max_position_embeddings as u64)),
2841 );
2842 if let Some(rope) = serde_json::Number::from_f64(mc.rope_theta as f64) {
2843 writer.set_metadata("rope_theta", Jv::Number(rope));
2844 }
2845 if let Some(eps) = serde_json::Number::from_f64(mc.rms_norm_eps as f64) {
2846 writer.set_metadata("rms_norm_eps", Jv::Number(eps));
2847 }
2848
2849 if let Some(dir) = tokenizer_dir {
2855 let tok_path = dir.join("tokenizer.json");
2856 if let Ok(json_bytes) = std::fs::read(&tok_path) {
2857 if let Ok(tok) = serde_json::from_slice::<Jv>(&json_bytes) {
2858 if let Some(model) = tok.get("model") {
2859 if let Some(vocab_obj) = model.get("vocab").and_then(|v| v.as_object()) {
2860 let mut vocab_pairs: Vec<(String, u64)> = vocab_obj
2861 .iter()
2862 .filter_map(|(k, v)| Some((k.clone(), v.as_u64()?)))
2863 .collect();
2864 vocab_pairs.sort_by_key(|(_, id)| *id);
2865 let vocab: Vec<Jv> =
2866 vocab_pairs.into_iter().map(|(k, _)| Jv::String(k)).collect();
2867 writer.set_metadata("tokenizer.vocabulary", Jv::Array(vocab));
2868 }
2869 if let Some(merges_arr) = model.get("merges").and_then(|m| m.as_array()) {
2870 let merges: Vec<Jv> = merges_arr
2871 .iter()
2872 .filter_map(|v| v.as_str().map(|s| Jv::String(s.to_string())))
2873 .collect();
2874 writer.set_metadata("tokenizer.merges", Jv::Array(merges));
2875 }
2876 }
2877 if let Some(added) = tok.get("added_tokens").and_then(|a| a.as_array()) {
2879 for entry in added {
2880 let content =
2881 entry.get("content").and_then(|c| c.as_str()).unwrap_or("");
2882 let id = entry.get("id").and_then(|i| i.as_u64());
2883 if let Some(id) = id {
2884 match content {
2885 "<s>" | "<|im_start|>" | "<|begin_of_text|>" => {
2886 writer.set_metadata(
2887 "tokenizer.bos_token_id",
2888 Jv::Number(serde_json::Number::from(id)),
2889 );
2890 }
2891 "</s>" | "<|im_end|>" | "<|end_of_text|>" | "<|endoftext|>" => {
2892 writer.set_metadata(
2893 "tokenizer.eos_token_id",
2894 Jv::Number(serde_json::Number::from(id)),
2895 );
2896 }
2897 _ => {}
2898 }
2899 }
2900 }
2901 }
2902 }
2903 }
2904 }
2905
2906 let shapes = infer_all_tensor_shapes(¶ms);
2908 for (tname, tensor) in ¶ms {
2909 let data = tensor.data();
2910 let slice = data.as_slice().expect("tensor data must be contiguous");
2911 let shape = shapes.get(tname).cloned().unwrap_or_else(|| vec![tensor.len()]);
2912 writer.add_tensor_f32(tname, shape, slice);
2913 }
2914
2915 writer
2916 .write(path)
2917 .map_err(|e| crate::error::Error::Serialization(format!("APR write failed: {e}")))
2918 }
2919
2920 fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2927 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2928 if use_nf4 {
2929 let frozen_suffixes = [
2930 "q_proj.weight",
2931 "k_proj.weight",
2932 "v_proj.weight",
2933 "o_proj.weight",
2934 "gate_proj.weight",
2935 "up_proj.weight",
2936 "down_proj.weight",
2937 ];
2938 self.model
2939 .named_parameters()
2940 .into_iter()
2941 .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2942 .map(|(n, t)| (n, t.data().to_vec()))
2943 .collect()
2944 } else {
2945 self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2946 }
2947 }
2948
2949 fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2950 if self.config.quantize_nf4 && self.config.is_lora() {
2951 self.cuda_blocks
2952 .iter()
2953 .enumerate()
2954 .filter_map(|(i, block)| {
2955 block
2956 .download_lora_weights()
2957 .ok()
2958 .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2959 })
2960 .collect()
2961 } else {
2962 Vec::new()
2963 }
2964 }
2965
2966 pub fn prepare_async_apr_save(
2967 &mut self,
2968 name: &str,
2969 architecture: &str,
2970 step: usize,
2971 loss: f64,
2972 lr: f64,
2973 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2974 self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2975 }
2976
2977 pub fn prepare_async_apr_save_with_tokenizer(
2983 &mut self,
2984 name: &str,
2985 architecture: &str,
2986 step: usize,
2987 loss: f64,
2988 lr: f64,
2989 tokenizer_path: Option<&std::path::Path>,
2990 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2991 self.sync_weights_to_cpu();
2992
2993 let param_data = self.snapshot_param_data();
2994 let lora_data = self.snapshot_lora_data();
2995
2996 let embed_m: Vec<Vec<f32>> = self
2998 .embed_optimizer
2999 .first_moments()
3000 .iter()
3001 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3002 .collect();
3003 let embed_v: Vec<Vec<f32>> = self
3004 .embed_optimizer
3005 .second_moments()
3006 .iter()
3007 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3008 .collect();
3009 let embed_step = self.embed_optimizer.step_count();
3010
3011 let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
3016 .gpu_training
3017 .optimizer_states
3018 .iter()
3019 .map(|state| state.download_to_host().unwrap_or_default())
3020 .collect();
3021
3022 let lm_head_m_host = {
3024 let mut buf = vec![0.0f32; self.lm_head_m.len()];
3025 let _ = self.lm_head_m.copy_to_host(&mut buf);
3026 buf
3027 };
3028 let lm_head_v_host = {
3029 let mut buf = vec![0.0f32; self.lm_head_v.len()];
3030 let _ = self.lm_head_v.copy_to_host(&mut buf);
3031 buf
3032 };
3033 let final_norm_m_host = {
3034 let mut buf = vec![0.0f32; self.final_norm_m.len()];
3035 let _ = self.final_norm_m.copy_to_host(&mut buf);
3036 buf
3037 };
3038 let final_norm_v_host = {
3039 let mut buf = vec![0.0f32; self.final_norm_v.len()];
3040 let _ = self.final_norm_v.copy_to_host(&mut buf);
3041 buf
3042 };
3043
3044 let name = name.to_string();
3045 let architecture = architecture.to_string();
3046 let model_config_json = serde_json::to_string(&self.config.model_config).ok();
3047 let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
3048
3049 let arch_hidden_size = self.config.model_config.hidden_size;
3056 let arch_num_layers = self.config.model_config.num_hidden_layers;
3057 let arch_num_heads = self.config.model_config.num_attention_heads;
3058 let arch_num_kv_heads = self.config.model_config.num_kv_heads;
3059 let arch_intermediate_size = self.config.model_config.intermediate_size;
3060 let arch_vocab_size = self.config.model_config.vocab_size;
3061 let arch_max_position_embeddings = self.config.model_config.max_position_embeddings;
3062 let arch_rope_theta = self.config.model_config.rope_theta;
3063 let arch_rms_norm_eps = self.config.model_config.rms_norm_eps;
3064
3065 let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
3068 tokenizer_path.and_then(|p| {
3069 let json_bytes = std::fs::read(p).ok()?;
3070 let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
3071 let model = tok.get("model")?;
3072 let vocab_obj = model.get("vocab")?.as_object()?;
3073 let mut vocab_pairs: Vec<(String, u64)> =
3075 vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
3076 vocab_pairs.sort_by_key(|(_, id)| *id);
3077 let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
3078 let merges: Vec<String> = model
3080 .get("merges")?
3081 .as_array()?
3082 .iter()
3083 .filter_map(|v| v.as_str().map(String::from))
3084 .collect();
3085 let added = tok.get("added_tokens").and_then(|a| a.as_array());
3087 let bos_id = added.and_then(|arr| {
3088 arr.iter()
3089 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
3090 .and_then(|t| t.get("id")?.as_u64())
3091 });
3092 let eos_id = added.and_then(|arr| {
3093 arr.iter()
3094 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
3095 .and_then(|t| t.get("id")?.as_u64())
3096 });
3097 if vocab.is_empty() {
3098 return None;
3099 }
3100 println!(
3101 " [ALB-130] Embedding tokenizer: {} vocab, {} merges",
3102 vocab.len(),
3103 merges.len()
3104 );
3105 Some((vocab, merges, bos_id, eos_id))
3106 });
3107
3108 Box::new(move |path: &std::path::Path| {
3109 use aprender::serialization::apr::AprWriter;
3110 use serde_json::Value as Jv;
3111
3112 let mut writer = AprWriter::new();
3113
3114 writer.set_metadata("model_name", Jv::String(name));
3116 writer.set_metadata("architecture", Jv::String(architecture));
3117 writer.set_metadata(
3118 "format",
3119 Jv::String(if is_delta_checkpoint {
3120 "entrenar-delta-checkpoint".into()
3121 } else {
3122 "entrenar-checkpoint".into()
3123 }),
3124 );
3125 writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
3126 writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
3127 writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
3128 writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
3129 if let Some(cfg) = model_config_json {
3130 writer.set_metadata("model_config", Jv::String(cfg));
3131 }
3132
3133 writer.set_metadata(
3137 "hidden_size",
3138 Jv::Number(serde_json::Number::from(arch_hidden_size as u64)),
3139 );
3140 writer.set_metadata(
3141 "num_hidden_layers",
3142 Jv::Number(serde_json::Number::from(arch_num_layers as u64)),
3143 );
3144 writer.set_metadata(
3145 "num_attention_heads",
3146 Jv::Number(serde_json::Number::from(arch_num_heads as u64)),
3147 );
3148 writer.set_metadata(
3149 "num_kv_heads",
3150 Jv::Number(serde_json::Number::from(arch_num_kv_heads as u64)),
3151 );
3152 writer.set_metadata(
3153 "intermediate_size",
3154 Jv::Number(serde_json::Number::from(arch_intermediate_size as u64)),
3155 );
3156 writer.set_metadata(
3157 "vocab_size",
3158 Jv::Number(serde_json::Number::from(arch_vocab_size as u64)),
3159 );
3160 writer.set_metadata(
3161 "max_position_embeddings",
3162 Jv::Number(serde_json::Number::from(arch_max_position_embeddings as u64)),
3163 );
3164 if let Some(rope) = serde_json::Number::from_f64(arch_rope_theta as f64) {
3165 writer.set_metadata("rope_theta", Jv::Number(rope));
3166 }
3167 if let Some(eps) = serde_json::Number::from_f64(arch_rms_norm_eps as f64) {
3168 writer.set_metadata("rms_norm_eps", Jv::Number(eps));
3169 }
3170
3171 if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
3173 writer.set_metadata(
3174 "tokenizer.vocabulary",
3175 Jv::Array(vocab.into_iter().map(Jv::String).collect()),
3176 );
3177 writer.set_metadata(
3178 "tokenizer.merges",
3179 Jv::Array(merges.into_iter().map(Jv::String).collect()),
3180 );
3181 if let Some(bos) = bos_id {
3182 writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
3183 }
3184 if let Some(eos) = eos_id {
3185 writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
3186 }
3187 }
3188
3189 let hidden_size = param_data
3191 .iter()
3192 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
3193 .map_or(0, |(_, d)| d.len());
3194
3195 for (tensor_name, data) in ¶m_data {
3197 let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
3198 writer.add_tensor_f32(tensor_name.clone(), shape, data);
3199 }
3200
3201 for (i, m_data) in embed_m.iter().enumerate() {
3203 let len = m_data.len();
3204 writer.add_tensor_f32(
3205 format!("__training__.embed_optimizer.m.{i}"),
3206 vec![len],
3207 m_data,
3208 );
3209 }
3210 for (i, v_data) in embed_v.iter().enumerate() {
3211 let len = v_data.len();
3212 writer.add_tensor_f32(
3213 format!("__training__.embed_optimizer.v.{i}"),
3214 vec![len],
3215 v_data,
3216 );
3217 }
3218
3219 for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
3221 for (suffix, data) in buffers {
3222 let len = data.len();
3223 writer.add_tensor_f32(
3224 format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
3225 vec![len],
3226 data,
3227 );
3228 }
3229 }
3230
3231 if !lm_head_m_host.is_empty() {
3233 let len = lm_head_m_host.len();
3234 writer.add_tensor_f32(
3235 "__training__.lm_head_optimizer.m".to_string(),
3236 vec![len],
3237 &lm_head_m_host,
3238 );
3239 let len = lm_head_v_host.len();
3240 writer.add_tensor_f32(
3241 "__training__.lm_head_optimizer.v".to_string(),
3242 vec![len],
3243 &lm_head_v_host,
3244 );
3245 }
3246 if !final_norm_m_host.is_empty() {
3247 let len = final_norm_m_host.len();
3248 writer.add_tensor_f32(
3249 "__training__.final_norm_optimizer.m".to_string(),
3250 vec![len],
3251 &final_norm_m_host,
3252 );
3253 let len = final_norm_v_host.len();
3254 writer.add_tensor_f32(
3255 "__training__.final_norm_optimizer.v".to_string(),
3256 vec![len],
3257 &final_norm_v_host,
3258 );
3259 }
3260
3261 for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
3263 if !a_q.is_empty() {
3264 writer.add_tensor_f32(
3265 format!("lora.{layer_idx}.q_proj.lora_a"),
3266 vec![a_q.len()],
3267 a_q,
3268 );
3269 writer.add_tensor_f32(
3270 format!("lora.{layer_idx}.q_proj.lora_b"),
3271 vec![b_q.len()],
3272 b_q,
3273 );
3274 }
3275 if !a_v.is_empty() {
3276 writer.add_tensor_f32(
3277 format!("lora.{layer_idx}.v_proj.lora_a"),
3278 vec![a_v.len()],
3279 a_v,
3280 );
3281 writer.add_tensor_f32(
3282 format!("lora.{layer_idx}.v_proj.lora_b"),
3283 vec![b_v.len()],
3284 b_v,
3285 );
3286 }
3287 }
3288
3289 writer
3291 .write(path)
3292 .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3293
3294 Ok(())
3295 })
3296 }
3297
3298 pub fn gpu_name(&self) -> String {
3300 self.cuda_trainer.device_name()
3301 }
3302
3303 pub fn save_cuda_lora_adapter(
3313 &self,
3314 output_dir: &std::path::Path,
3315 base_model_name: Option<&str>,
3316 ) -> crate::Result<()> {
3317 if !self.config.quantize_nf4 || !self.config.is_lora() {
3318 return Ok(()); }
3320
3321 let lora_rank = self.config.lora_rank.unwrap_or(16);
3322 let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3323 let lora_scale = lora_alpha / lora_rank as f32;
3324 let hidden_size = self.config.model_config.hidden_size;
3325 let head_dim = self.config.model_config.head_dim();
3326 let q_dim = self.config.model_config.num_attention_heads * head_dim;
3327 let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3328
3329 let lora_config =
3330 crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3331
3332 let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3333
3334 for (i, block) in self.cuda_blocks.iter().enumerate() {
3335 let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3336 Ok(weights) => weights,
3337 Err(_) => continue, };
3339
3340 if a_q.is_empty() && a_v.is_empty() {
3341 continue;
3342 }
3343
3344 if !a_q.is_empty() {
3346 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3348 for r in 0..hidden_size {
3349 for c in 0..lora_rank {
3350 a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3351 }
3352 }
3353
3354 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3357 let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3358 for r in 0..lora_rank {
3359 for c in 0..q_dim {
3360 b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3361 }
3362 }
3363
3364 let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3365 let mut layer = crate::lora::LoRALayer::new(
3366 base_weight,
3367 q_dim,
3368 hidden_size,
3369 lora_rank,
3370 lora_alpha,
3371 );
3372 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3374 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3375
3376 adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3377 }
3378
3379 if !a_v.is_empty() {
3381 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3382 for r in 0..hidden_size {
3383 for c in 0..lora_rank {
3384 a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3385 }
3386 }
3387
3388 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3389 let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3390 for r in 0..lora_rank {
3391 for c in 0..kv_hidden {
3392 b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3393 }
3394 }
3395
3396 let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3397 let mut layer = crate::lora::LoRALayer::new(
3398 base_weight,
3399 kv_hidden,
3400 hidden_size,
3401 lora_rank,
3402 lora_alpha,
3403 );
3404 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3405 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3406
3407 adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3408 }
3409 }
3410
3411 if adapters.is_empty() {
3412 println!(" [WARN] No LoRA adapters found to save");
3413 return Ok(());
3414 }
3415
3416 let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3417 adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3418
3419 std::fs::create_dir_all(output_dir).ok();
3420 crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3421 .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3422
3423 let adapter_path = output_dir.join("adapter_model.safetensors");
3424 let size_mb =
3425 std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3426 println!(
3427 "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3428 adapters.len(),
3429 size_mb,
3430 output_dir.display()
3431 );
3432
3433 Ok(())
3434 }
3435
3436 pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3441 let path = dir.join("optimizer_state.json");
3442 let m_data: Vec<Option<Vec<f32>>> = self
3443 .embed_optimizer
3444 .first_moments()
3445 .iter()
3446 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3447 .collect();
3448 let v_data: Vec<Option<Vec<f32>>> = self
3449 .embed_optimizer
3450 .second_moments()
3451 .iter()
3452 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3453 .collect();
3454 let state = serde_json::json!({
3455 "type": "adamw_cpu_embed",
3456 "step": self.embed_optimizer.step_count(),
3457 "m": m_data,
3458 "v": v_data,
3459 });
3460 let json_str = serde_json::to_string(&state).map_err(|e| {
3461 crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3462 })?;
3463 std::fs::write(&path, json_str)
3464 .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3465 Ok(())
3466 }
3467
3468 pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3474 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3475 Ok(r) => r,
3476 Err(_) => return (0, self.cuda_blocks.len()),
3477 };
3478
3479 let mut restored = 0usize;
3480 for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3481 let a_q =
3482 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3483 let b_q =
3484 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3485 let a_v =
3486 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3487 let b_v =
3488 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3489
3490 if a_q.is_empty() {
3491 continue; }
3493
3494 if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3495 eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3496 continue;
3497 }
3498 restored += 1;
3499 }
3500
3501 (restored, self.cuda_blocks.len())
3502 }
3503
3504 pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3509 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3510 Ok(r) => r,
3511 Err(_) => return false,
3512 };
3513
3514 if let Some(step_val) = reader.get_metadata("optimizer_step") {
3516 if let Some(step_str) = step_val.as_str() {
3517 if let Ok(step) = step_str.parse::<u64>() {
3518 self.embed_optimizer.set_step_count(step);
3519 }
3520 }
3521 }
3522
3523 for i in 0..128 {
3525 let name = format!("__training__.embed_optimizer.m.{i}");
3526 match reader.read_tensor_f32(&name) {
3527 Ok(data) if !data.is_empty() => {
3528 self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3529 }
3530 _ => break,
3531 }
3532 }
3533
3534 for i in 0..128 {
3536 let name = format!("__training__.embed_optimizer.v.{i}");
3537 match reader.read_tensor_f32(&name) {
3538 Ok(data) if !data.is_empty() => {
3539 self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3540 }
3541 _ => break,
3542 }
3543 }
3544
3545 let suffixes = [
3547 "m.w_q",
3548 "v.w_q",
3549 "m.w_k",
3550 "v.w_k",
3551 "m.w_v",
3552 "v.w_v",
3553 "m.w_o",
3554 "v.w_o",
3555 "m.w_gate",
3556 "v.w_gate",
3557 "m.w_up",
3558 "v.w_up",
3559 "m.w_down",
3560 "v.w_down",
3561 "m.input_norm",
3562 "v.input_norm",
3563 "m.post_attn_norm",
3564 "v.post_attn_norm",
3565 ];
3566 let mut blocks_restored = 0usize;
3567 for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3568 let mut data = std::collections::HashMap::new();
3569 for suffix in &suffixes {
3570 let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3571 if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3572 if !tensor_data.is_empty() {
3573 data.insert(suffix.to_string(), tensor_data);
3574 }
3575 }
3576 }
3577 if !data.is_empty() {
3578 let _ = state.restore_from_host(&data);
3579 blocks_restored += 1;
3580 }
3581 }
3582
3583 if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3585 if m_data.len() == self.lm_head_m.len() {
3586 let _ = self.lm_head_m.copy_from_host(&m_data);
3587 }
3588 }
3589 if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3590 if v_data.len() == self.lm_head_v.len() {
3591 let _ = self.lm_head_v.copy_from_host(&v_data);
3592 }
3593 }
3594
3595 if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3597 if m_data.len() == self.final_norm_m.len() {
3598 let _ = self.final_norm_m.copy_from_host(&m_data);
3599 }
3600 }
3601 if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3602 if v_data.len() == self.final_norm_v.len() {
3603 let _ = self.final_norm_v.copy_from_host(&v_data);
3604 }
3605 }
3606
3607 if blocks_restored > 0 {
3609 println!(
3610 " ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3611 self.gpu_training.optimizer_states.len()
3612 );
3613 } else if !self.gpu_training.optimizer_states.is_empty() {
3614 println!(
3615 " [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3616 self.gpu_training.optimizer_states.len()
3617 );
3618 }
3619
3620 true
3621 }
3622
3623 pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3627 let path = dir.join("optimizer_state.json");
3628 let data = match std::fs::read_to_string(&path) {
3629 Ok(d) => d,
3630 Err(_) => return false,
3631 };
3632 let state: serde_json::Value = match serde_json::from_str(&data) {
3633 Ok(v) => v,
3634 Err(_) => return false,
3635 };
3636 if let Some(step) = state["step"].as_u64() {
3637 self.embed_optimizer.set_step_count(step);
3638 }
3639 restore_moment_buffers(&state["m"], |idx, arr| {
3640 self.embed_optimizer.set_first_moment(idx, arr);
3641 });
3642 restore_moment_buffers(&state["v"], |idx, arr| {
3643 self.embed_optimizer.set_second_moment(idx, arr);
3644 });
3645 true
3646 }
3647}
3648
3649#[cfg(feature = "cuda")]
/// Guesses the shape of a flat tensor from its name, element count and the hidden size.
///
/// * Names ending in `layernorm.weight`, or equal to `model.norm.weight`, are 1-D.
/// * Otherwise, if `hidden_size` is non-zero and divides `numel`, the tensor is
///   2-D: `[hidden_size, numel / hidden_size]` for names ending in
///   `down_proj.weight`, and `[numel / hidden_size, hidden_size]` for all others.
/// * Anything else is 1-D.
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    let is_norm_weight = name == "model.norm.weight" || name.ends_with("layernorm.weight");

    // The zero check comes first so the remainder is never taken with a zero divisor.
    if is_norm_weight || hidden_size == 0 || numel % hidden_size != 0 {
        return vec![numel];
    }

    let other_dim = numel / hidden_size;
    match name.ends_with("down_proj.weight") {
        true => vec![hidden_size, other_dim],
        false => vec![other_dim, hidden_size],
    }
}
3667
3668#[cfg(feature = "cuda")]
3670fn restore_moment_buffers(
3671 json_arr: &serde_json::Value,
3672 mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
3673) {
3674 let Some(arr) = json_arr.as_array() else { return };
3675 for (idx, val) in arr.iter().enumerate() {
3676 let Some(inner) = val.as_array() else { continue };
3677 let floats: Vec<f32> = inner.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
3678 if !floats.is_empty() {
3679 set_fn(idx, ndarray::Array1::from_vec(floats));
3680 }
3681 }
3682}
3683
3684#[cfg(not(feature = "cuda"))]
3687pub struct CudaTransformerTrainer;
3688
3689#[cfg(not(feature = "cuda"))]
3690impl CudaTransformerTrainer {
3691 pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3692 Err(crate::error::Error::ConfigError(
3693 "CUDA not available (compiled without cuda feature)".into(),
3694 ))
3695 }
3696
3697 pub fn with_model(
3698 _model: crate::transformer::Transformer,
3699 _config: super::config::TransformerTrainConfig,
3700 ) -> crate::Result<Self> {
3701 Err(crate::error::Error::ConfigError(
3702 "CUDA not available (compiled without cuda feature)".into(),
3703 ))
3704 }
3705
3706 pub fn gpu_name(&self) -> String {
3707 unreachable!("CudaTransformerTrainer stub should never be instantiated")
3708 }
3709}
3710
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, constructing the stub trainer must fail.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        let result = super::CudaTransformerTrainer::new(train_config);
        assert!(result.is_err());
    }
}