1#[cfg(feature = "cuda")]
36use trueno_gpu::driver::{CudaStream, GpuBuffer};
37
38#[cfg(feature = "cuda")]
39use crate::autograd::cuda_backward::{gemm_backward_a, gemm_backward_b, rms_norm_backward};
40#[cfg(feature = "cuda")]
41use crate::autograd::cuda_forward::{gemm_forward, pre_warm_forward_kernels, rms_norm_forward};
42#[cfg(feature = "cuda")]
43use crate::autograd::cuda_optim::{
44 adamw_step_cuda, clip_scale_reduce_cuda, fused_cross_entropy_cuda, gradient_clip_cuda,
45 gradient_clip_gpu_scale_cuda, squared_sum_collect, squared_sum_cuda, squared_sum_launch_cuda,
46 squared_sum_launch_into, FusedClipState,
47};
48#[cfg(feature = "cuda")]
49use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
50#[cfg(feature = "cuda")]
51use crate::autograd::precision::GradScaler;
52#[cfg(feature = "cuda")]
53use crate::autograd::Tensor;
54#[cfg(feature = "cuda")]
55use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
56#[cfg(feature = "cuda")]
57use crate::optim::{AdamW, Optimizer};
58#[cfg(feature = "cuda")]
59use crate::train::MetricsTracker;
60#[cfg(feature = "cuda")]
61use crate::transformer::{
62 CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
63 GpuBlockOptimizerState, GpuLoraOptimizerState, Transformer,
64};
65
66#[cfg(feature = "cuda")]
67use super::batch::LMBatch;
68#[cfg(feature = "cuda")]
69use super::config::TransformerTrainConfig;
70#[cfg(feature = "cuda")]
71use super::step_profiler::StepProfiler;
72
#[cfg(feature = "cuda")]
/// Computes the combined L2 norm of all nine gradient buffers in `ws` and the
/// factor that would clip that norm to `max_norm`.
///
/// A squared-sum reduction is launched for every non-empty buffer, `stream` is
/// synchronized once, and the collected results are added up in `f64`.
/// Buffers whose launch or collection fails are left out of the total instead
/// of aborting the computation.
///
/// Returns `(scale, grad_norm)`: `scale` is `max_norm / grad_norm` when the
/// norm exceeds `max_norm`, otherwise `1.0`. If the stream fails to
/// synchronize the result is `(1.0, 0.0)`.
fn compute_workspace_clip_scale_gpu(
    ws: &CudaGradWorkspace,
    max_norm: f32,
    stream: &CudaStream,
) -> (f32, f32) {
    use crate::autograd::cuda_optim::PendingSquaredSum;

    let grads: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    // Phase 1: queue every reduction before waiting on any of them.
    let in_flight: Vec<PendingSquaredSum> = grads
        .iter()
        .filter_map(|grad| {
            let len = grad.len() as u32;
            if len == 0 {
                return None;
            }
            squared_sum_launch_cuda(grad, len, stream).ok()
        })
        .collect();

    // Phase 2: a single synchronization covers all queued reductions.
    if stream.synchronize().is_err() {
        return (1.0, 0.0);
    }

    // Phase 3: gather the per-buffer squared sums, accumulating in f64.
    let sum_of_squares = in_flight
        .iter()
        .filter_map(|handle| squared_sum_collect(handle).ok())
        .fold(0.0f64, |acc, partial| acc + f64::from(partial));

    let grad_norm = sum_of_squares.sqrt() as f32;
    if grad_norm > max_norm {
        (max_norm / grad_norm, grad_norm)
    } else {
        (1.0, grad_norm)
    }
}
131
#[cfg(feature = "cuda")]
/// Clips the gradients in `ws` to a global L2 norm of at most `max_norm` and
/// returns the norm measured before clipping.
///
/// The scale comes from `compute_workspace_clip_scale_gpu`. When it is
/// (within `1e-7` of) `1.0` no kernel is launched. Otherwise every buffer is
/// multiplied by the scale on `stream`; individual launch errors are ignored.
fn clip_workspace_gradients(ws: &mut CudaGradWorkspace, max_norm: f32, stream: &CudaStream) -> f32 {
    let (scale, grad_norm) = compute_workspace_clip_scale_gpu(ws, max_norm, stream);

    let scale_is_identity = (scale - 1.0).abs() < 1e-7;
    if !scale_is_identity {
        for grad in [
            &mut ws.grad_w_q,
            &mut ws.grad_w_k,
            &mut ws.grad_w_v,
            &mut ws.grad_w_o,
            &mut ws.grad_gate,
            &mut ws.grad_up,
            &mut ws.grad_down,
            &mut ws.grad_input_norm,
            &mut ws.grad_post_attn_norm,
        ] {
            let len = grad.len() as u32;
            let _ = gradient_clip_cuda(grad, scale, len, stream);
        }
    }

    grad_norm
}
163
#[cfg(feature = "cuda")]
/// Clips the gradients in `ws` using only work queued on `stream`; nothing is
/// read back to the host in this function.
///
/// Steps, all enqueued in order:
/// 1. each non-empty buffer writes its squared-sum partials into its own
///    region of `state.partials_buf` (region start = `state.offsets[i]`
///    elements, i.e. `* 4` bytes),
/// 2. `clip_scale_reduce_cuda` turns the partials and `max_norm` into a scale
///    stored in `state.scale_buf`,
/// 3. each non-empty buffer is rescaled using that device-side scale.
///
/// Errors from the individual launches are ignored.
fn fused_clip_workspace_gradients(
    ws: &mut CudaGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    let readable: [&GpuBuffer<f32>; 9] = [
        &ws.grad_w_q,
        &ws.grad_w_k,
        &ws.grad_w_v,
        &ws.grad_w_o,
        &ws.grad_gate,
        &ws.grad_up,
        &ws.grad_down,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    for (slot, grad) in readable.iter().enumerate() {
        let len = grad.len() as u32;
        if len != 0 {
            // Offsets are counted in f32 elements; the destination is a byte address.
            let byte_offset = u64::from(state.offsets[slot]) * 4;
            let partials_dst = state.partials_buf.as_ptr() + byte_offset;
            let _ = squared_sum_launch_into(grad, len, partials_dst, stream);
        }
    }

    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    let device_scale = state.scale_buf.as_ptr();
    let writable: [&mut GpuBuffer<f32>; 9] = [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in writable {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, device_scale, len, stream);
        }
    }
}
236
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// Returns the global L2 norm of the gradients in `ws` without modifying them.
///
/// Delegates to `compute_workspace_clip_scale_gpu` with `f32::MAX` as the
/// threshold and discards the scale half of its result.
fn compute_workspace_grad_norm(ws: &CudaGradWorkspace, stream: &CudaStream) -> f32 {
    compute_workspace_clip_scale_gpu(ws, f32::MAX, stream).1
}
246
#[cfg(feature = "cuda")]
#[allow(dead_code)]
/// Multiplies every gradient buffer in `ws` by `inv_scale` on `stream`.
///
/// Does nothing when `inv_scale` is within `1e-7` of `1.0`. The scaling reuses
/// `gradient_clip_cuda`; errors from individual launches are ignored.
fn unscale_workspace_gradients(ws: &mut CudaGradWorkspace, inv_scale: f32, stream: &CudaStream) {
    let is_identity = (inv_scale - 1.0).abs() < 1e-7;
    if is_identity {
        return;
    }

    for grad in [
        &mut ws.grad_w_q,
        &mut ws.grad_w_k,
        &mut ws.grad_w_v,
        &mut ws.grad_w_o,
        &mut ws.grad_gate,
        &mut ws.grad_up,
        &mut ws.grad_down,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ] {
        let len = grad.len() as u32;
        let _ = gradient_clip_cuda(grad, inv_scale, len, stream);
    }
}
283
#[cfg(feature = "cuda")]
/// Device buffers and per-block optimizer state that `CudaTransformerTrainer`
/// allocates once in `with_model` and reuses on every step.
struct GpuPretrainState {
    /// One buffer per transformer layer, each `max_seq_len * hidden_size`
    /// elements. `gpu_forward` copies a layer's input here when the matching
    /// `saved_layer_mask` entry is `true`; `recompute_segment` writes
    /// recomputed inputs into the entries that follow a saved layer.
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Per-layer flag telling whether that layer's input is saved during the
    /// forward pass. All `true` when activation checkpointing is disabled;
    /// otherwise only layers at a multiple of the segment size are `true`.
    saved_layer_mask: Vec<bool>,
    /// Scratch buffer used by `recompute_segment`; `Some` only when activation
    /// checkpointing is enabled.
    recompute_buf: Option<GpuBuffer<f32>>,
    /// Device copy of the model's final norm weight (`model.norm.weight`).
    final_norm_weight: GpuBuffer<f32>,
    /// Copy of the last block's output, taken in `gpu_forward` before the
    /// final norm is applied.
    blocks_output: GpuBuffer<f32>,
    /// `max_seq_len * hidden_size` buffer.
    /// NOTE(review): presumably one half of a gradient ping-pong pair used by
    /// the backward pass — its use is outside this view; confirm.
    grad_buf_a: GpuBuffer<f32>,
    /// `max_seq_len * hidden_size` buffer.
    /// NOTE(review): presumably the other half of the gradient ping-pong pair —
    /// confirm against the backward pass.
    grad_buf_b: GpuBuffer<f32>,
    /// `hidden_size` elements.
    /// NOTE(review): presumably the gradient of `final_norm_weight` — confirm.
    grad_final_norm_weight: GpuBuffer<f32>,
    /// Output of the final RMS norm in `gpu_forward`; input to the LM-head GEMM.
    norm_output: GpuBuffer<f32>,
    /// `max_seq_len * vocab_size` buffer written by the LM-head GEMM in
    /// `gpu_forward` and handed mutably to `fused_cross_entropy_cuda` in
    /// `train_step_inner`.
    logits_buf: GpuBuffer<f32>,
    /// `max_seq_len * hidden_size` buffer written by `gemm_backward_a` in
    /// `gpu_backward`.
    lm_head_grad_hidden: GpuBuffer<f32>,
    /// One optimizer state per block, created with `init_optimizer_state`;
    /// left empty on the NF4 path.
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Initialised to `0` in `with_model`.
    /// NOTE(review): presumably the GPU optimizer step counter — confirm.
    step: u32,
}
324
#[cfg(feature = "cuda")]
/// Transformer trainer that keeps the transformer blocks, LM head and training
/// buffers on a CUDA device.
///
/// Construct it with [`Self::new`], [`Self::with_model`] or
/// [`Self::for_inference`].
pub struct CudaTransformerTrainer {
    /// Host-side model; its embedding lookup runs on the CPU in `gpu_forward`
    /// and the result is then copied to the device.
    model: Transformer,
    /// Owner of the CUDA context and stream used by every kernel launch here.
    cuda_trainer: CudaTrainer,
    /// One uploaded block per transformer layer (`CudaBlock::Fp32` or
    /// `CudaBlock::Nf4`).
    cuda_blocks: Vec<CudaBlock>,
    /// Single gradient workspace allocated once in `with_model`.
    cuda_grad_workspace: CudaGradWorkspace,
    /// Scratch passed to every block's `forward`; `Some` only on the NF4 + LoRA
    /// path.
    nf4_shared_scratch: Option<CudaBlockScratch>,
    /// LoRA gradient workspace; `Some` only on the NF4 + LoRA path.
    nf4_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
    /// Per-block LoRA optimizer state; `Some` only on the NF4 + LoRA path.
    nf4_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
    /// Activation, gradient and optimizer buffers shared across steps.
    gpu_training: GpuPretrainState,
    /// Device copy of the LM-head weight: `model.lm_head` when present,
    /// otherwise the embedding weight.
    lm_head_weight_gpu: GpuBuffer<f32>,
    /// `vocab_size * hidden_size` buffer written by `gemm_backward_b` in
    /// `gpu_backward`.
    lm_head_grad_gpu: GpuBuffer<f32>,
    /// Zero-initialised `vocab_size * hidden_size` buffer.
    /// NOTE(review): presumably the optimizer's first-moment state for the LM
    /// head — confirm.
    lm_head_m: GpuBuffer<f32>,
    /// Zero-initialised `vocab_size * hidden_size` buffer.
    /// NOTE(review): presumably the optimizer's second-moment state for the LM
    /// head — confirm.
    lm_head_v: GpuBuffer<f32>,
    /// Zero-initialised `hidden_size` buffer.
    /// NOTE(review): presumably first-moment state for the final norm — confirm.
    final_norm_m: GpuBuffer<f32>,
    /// Zero-initialised `hidden_size` buffer.
    /// NOTE(review): presumably second-moment state for the final norm — confirm.
    final_norm_v: GpuBuffer<f32>,
    /// Host `AdamW` built from `config.lr`, `beta1`, `beta2`, epsilon `1e-8`
    /// and `weight_decay`.
    /// NOTE(review): by its name it updates the embedding table — confirm.
    embed_optimizer: AdamW,
    /// Training configuration this trainer was built with.
    config: TransformerTrainConfig,
    /// Metrics tracker, publicly accessible.
    pub metrics: MetricsTracker,
    /// Initialised to `0` in `with_model`.
    step: usize,
    /// Initialised to `0.0` in `with_model`.
    accumulated_loss: f32,
    /// Initialised to `0` in `with_model`.
    accumulated_batches: usize,
    /// Set in `gpu_backward` to the L2 norm of the LM-head gradient.
    last_grad_norm: f32,
    /// Initialised to `0.0` in `with_model`.
    last_embed_grad_norm: f32,
    /// Host-side per-block gradient accumulator; `Some` when
    /// `config.accumulation_steps > 1`.
    grad_accum: Option<super::grad_accumulator::PerBlockGradientAccumulator>,
    /// Device-side gradient accumulator; `Some` when
    /// `config.accumulation_steps > 1` and its allocation succeeded.
    gpu_grad_accum: Option<super::gpu_grad_accumulator::GpuGradientAccumulator>,
    /// Gradient scaler built from `config.precision_config`.
    grad_scaler: GradScaler,
    /// Step profiler; enabled only when `config.profile_interval > 0`.
    profiler: StepProfiler,
    /// First of two `max_seq_len * hidden_size` buffers that `gpu_forward`
    /// alternates between as block input and output.
    fwd_scratch_a: GpuBuffer<f32>,
    /// Second buffer of the forward input/output pair.
    fwd_scratch_b: GpuBuffer<f32>,
    /// Host staging area of `max_seq_len * hidden_size` floats used to upload
    /// the zero-padded embedding output.
    h2d_staging: Vec<f32>,
    /// Host staging area; non-empty only when accumulation is enabled and the
    /// GPU accumulator is unavailable.
    d2h_staging: Vec<f32>,
    /// Fused clipping state; `Some` when `max_grad_norm` is configured and the
    /// allocation succeeded.
    fused_clip: Option<FusedClipState>,
    /// `hidden_size` zeros kept on the host.
    final_norm_zero_buf: Vec<f32>,
}
414
415#[cfg(feature = "cuda")]
416impl CudaTransformerTrainer {
417 pub fn new(config: TransformerTrainConfig) -> crate::Result<Self> {
424 let model = Transformer::new(&config.model_config);
425 Self::with_model(model, config)
426 }
427
428 pub fn for_inference(
442 checkpoint_dir: impl AsRef<std::path::Path>,
443 model_config: crate::transformer::TransformerConfig,
444 ) -> crate::Result<Self> {
445 let dir = checkpoint_dir.as_ref();
446
447 let model = if let Some((Some(m), _step)) =
449 crate::config::try_load_apr_for_inference(dir, &model_config)
450 {
451 m
452 } else {
453 Transformer::from_safetensors(dir, &model_config)?
454 };
455
456 let mut config = TransformerTrainConfig::new(model_config);
457 config.max_seq_len = config.model_config.max_position_embeddings;
458 Self::with_model(model, config)
459 }
460
    /// Builds a trainer around an already constructed `model`, uploading it to
    /// the GPU and allocating every buffer the training loop needs.
    ///
    /// Setup order: CUDA availability check, `CudaTrainer` creation, kernel
    /// pre-warming, cuBLAS stream binding (non-fatal), block upload, gradient
    /// workspace and activation buffers, per-block optimizer state, LM-head
    /// buffers, optional NF4/LoRA infrastructure, gradient accumulators,
    /// fused-clip state and the gradient scaler.
    ///
    /// # Errors
    /// Returns `Error::ConfigError` when CUDA is unavailable or when any
    /// initialization, allocation, upload or the stream synchronization fails.
    ///
    /// # Panics
    /// Panics if the final-norm or LM-head weight data cannot be viewed as a
    /// contiguous slice (`expect("contiguous")`).
    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> crate::Result<Self> {
        if !cuda_training_available() {
            return Err(crate::error::Error::ConfigError("CUDA not available".into()));
        }

        let mc = &config.model_config;
        let max_seq_len = config.max_seq_len;
        let hidden_size = mc.hidden_size;
        let vocab_size = mc.vocab_size;
        let num_layers = mc.num_hidden_layers;

        let cuda_trainer = CudaTrainer::new().map_err(|e| {
            crate::error::Error::ConfigError(format!("CUDA trainer init failed: {e:?}"))
        })?;

        println!(
            " GPU: {} ({:.1} GB)",
            cuda_trainer.device_name(),
            cuda_trainer.total_memory() as f64 / 1e9
        );

        let ctx = cuda_trainer.context().clone();
        let stream = cuda_trainer.stream();

        // Pre-warm the forward kernels for this model's dimensions; a failure
        // here aborts construction.
        pre_warm_forward_kernels(
            hidden_size,
            mc.intermediate_size,
            mc.num_attention_heads,
            mc.num_kv_heads,
            mc.head_dim(),
            max_seq_len,
        )
        .map_err(|e| crate::error::Error::ConfigError(format!("Kernel pre-warm failed: {e:?}")))?;

        {
            // Same for the backward kernels. The LoRA rank falls back to 0 here
            // when `config.lora_rank` is unset.
            use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels;
            let head_dim = mc.head_dim();
            pre_warm_lora_backward_kernels(
                hidden_size,
                mc.num_attention_heads * head_dim,
                mc.num_kv_heads * head_dim,
                max_seq_len,
                config.lora_rank.unwrap_or(0),
                mc.intermediate_size,
                mc.num_attention_heads,
                config.quantize_nf4 && config.is_lora(),
            )
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("Backward kernel pre-warm failed: {e:?}"))
            })?;
            eprintln!(" ✓ Backward kernels pre-warmed (silu_backward, rms_norm_backward, etc.)");
        }

        // Binding cuBLAS to the training stream is best-effort: a failure is
        // only logged and construction continues.
        if let Err(e) = crate::autograd::cuda_forward::set_forward_cublas_stream(stream) {
            println!("[WARN] cuBLAS forward stream bind failed: {e:?} — falling back to PTX");
        }
        if let Err(e) = crate::autograd::cuda_backward::set_backward_cublas_stream(stream) {
            println!("[WARN] cuBLAS backward stream bind failed: {e:?} — falling back to PTX");
        }

        // The NF4 path is taken only when quantization and LoRA are both enabled.
        let use_nf4 = config.quantize_nf4 && config.is_lora();
        let cuda_blocks = Self::upload_blocks(
            &model,
            mc,
            &config,
            &ctx,
            use_nf4,
            num_layers,
            hidden_size,
            max_seq_len,
        )?;

        let cuda_grad_workspace = CudaGradWorkspace::new(&ctx, mc).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad workspace alloc failed: {e:?}"))
        })?;

        // Element counts for one full-length activation and one full logits matrix.
        let buf_size = max_seq_len * hidden_size;
        let logits_size = max_seq_len * vocab_size;

        // With checkpointing, only every `segment_size`-th layer input is
        // marked as saved; without it every layer is marked.
        let checkpointing = config.checkpoint_config.enabled;
        let segment_size = if checkpointing {
            let ns = config.checkpoint_config.num_segments.max(1);
            num_layers.div_ceil(ns)
        } else {
            1
        };
        let saved_layer_mask: Vec<bool> =
            (0..num_layers).map(|i| !checkpointing || i % segment_size == 0).collect();

        // A buffer is allocated for every layer regardless of the mask.
        let mut layer_inputs = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            layer_inputs.push(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
                crate::error::Error::ConfigError(format!("Layer input alloc failed: {e:?}"))
            })?);
        }

        let recompute_buf = if checkpointing {
            Some(GpuBuffer::new(&ctx, buf_size).map_err(|e| {
                crate::error::Error::ConfigError(format!("Recompute buf alloc failed: {e:?}"))
            })?)
        } else {
            None
        };

        if checkpointing {
            let saved_count = saved_layer_mask.iter().filter(|&&x| x).count();
            println!(
                " ✓ Activation checkpointing: {} segments, saving {}/{} layer inputs",
                config.checkpoint_config.num_segments, saved_count, num_layers
            );
        }

        let norm_slice = model.norm.weight.data().as_slice().expect("contiguous");
        let final_norm_weight = GpuBuffer::from_host(&ctx, norm_slice).map_err(|e| {
            crate::error::Error::ConfigError(format!("Norm weight upload failed: {e:?}"))
        })?;

        let blocks_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Blocks output alloc failed: {e:?}"))
        })?;
        let grad_buf_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad buf A alloc failed: {e:?}"))
        })?;
        let grad_buf_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad buf B alloc failed: {e:?}"))
        })?;
        let grad_final_norm_weight = GpuBuffer::new(&ctx, hidden_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Grad norm alloc failed: {e:?}"))
        })?;
        let norm_output = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Norm output alloc failed: {e:?}"))
        })?;
        let logits_buf = GpuBuffer::new(&ctx, logits_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Logits buf alloc failed: {e:?}"))
        })?;
        let lm_head_grad_hidden = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
        })?;

        // Full per-block optimizer state is only created on the non-NF4 path;
        // the NF4 path builds LoRA optimizer state further down instead.
        let mut optimizer_states = Vec::new();
        if !use_nf4 {
            optimizer_states.reserve(num_layers);
            for (i, block) in cuda_blocks.iter().enumerate() {
                optimizer_states.push(block.init_optimizer_state().map_err(|e| {
                    crate::error::Error::ConfigError(format!("Block {i} opt state failed: {e:?}"))
                })?);
            }
        }

        let gpu_training = GpuPretrainState {
            layer_inputs,
            saved_layer_mask,
            recompute_buf,
            final_norm_weight,
            blocks_output,
            grad_buf_a,
            grad_buf_b,
            grad_final_norm_weight,
            norm_output,
            logits_buf,
            lm_head_grad_hidden,
            optimizer_states,
            step: 0,
        };

        // The LM head falls back to the embedding weight when the model has no
        // separate `lm_head`.
        let lm_head_data = model.lm_head.as_ref().unwrap_or(&model.embed_tokens.weight).data();
        let lm_head_slice = lm_head_data.as_slice().expect("contiguous");
        let lm_head_weight_gpu = GpuBuffer::from_host(&ctx, lm_head_slice).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head upload failed: {e:?}"))
        })?;
        let lm_head_grad_gpu = GpuBuffer::new(&ctx, vocab_size * hidden_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("LM head grad alloc failed: {e:?}"))
        })?;
        // These buffers are uploaded from explicit host zeros rather than
        // created with `GpuBuffer::new`.
        let lm_head_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("LM head m alloc failed: {e:?}"))
            })?;
        let lm_head_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; vocab_size * hidden_size])
            .map_err(|e| {
                crate::error::Error::ConfigError(format!("LM head v alloc failed: {e:?}"))
            })?;

        let final_norm_m = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
            crate::error::Error::ConfigError(format!("Final norm m alloc failed: {e:?}"))
        })?;
        let final_norm_v = GpuBuffer::from_host(&ctx, &vec![0.0f32; hidden_size]).map_err(|e| {
            crate::error::Error::ConfigError(format!("Final norm v alloc failed: {e:?}"))
        })?;

        // Same value as the `buf_size` computed above; re-declared here.
        let buf_size = max_seq_len * hidden_size;
        let fwd_scratch_a = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Fwd scratch A alloc failed: {e:?}"))
        })?;
        let fwd_scratch_b = GpuBuffer::new(&ctx, buf_size).map_err(|e| {
            crate::error::Error::ConfigError(format!("Fwd scratch B alloc failed: {e:?}"))
        })?;

        // Wait for all work queued so far before reporting success.
        stream
            .synchronize()
            .map_err(|e| crate::error::Error::ConfigError(format!("Stream sync failed: {e:?}")))?;

        println!(
            " ✓ GPU training state allocated (LM head: {:.1} MB)",
            (vocab_size * hidden_size * 4) as f64 / 1e6
        );

        // NF4 + LoRA extras: shared scratch, LoRA gradient workspace and one
        // LoRA optimizer state per block. The rank defaults to 16 here.
        let (nf4_shared_scratch, nf4_lora_grad_workspace, nf4_lora_optimizer_states) = if use_nf4 {
            let lora_rank = config.lora_rank.unwrap_or(16);

            let scratch = CudaBlockScratch::new(mc, max_seq_len, &ctx, lora_rank).map_err(|e| {
                crate::error::Error::ConfigError(format!("NF4 shared scratch alloc failed: {e:?}"))
            })?;

            let grad_ws = CudaLoraGradWorkspace::new(&ctx, mc, lora_rank).map_err(|e| {
                crate::error::Error::ConfigError(format!(
                    "NF4 LoRA grad workspace alloc failed: {e:?}"
                ))
            })?;

            let mut lora_opt_states = Vec::with_capacity(num_layers);
            for (i, block) in cuda_blocks.iter().enumerate() {
                lora_opt_states.push(block.init_lora_optimizer_state().map_err(|e| {
                    crate::error::Error::ConfigError(format!(
                        "Block {i} LoRA opt state failed: {e:?}"
                    ))
                })?);
            }

            println!(
                " ✓ NF4 training infrastructure allocated (shared scratch + LoRA optimizer × {num_layers})"
            );
            (Some(scratch), Some(grad_ws), Some(lora_opt_states))
        } else {
            (None, None, None)
        };

        let embed_optimizer =
            AdamW::new(config.lr, config.beta1, config.beta2, 1e-8, config.weight_decay);

        // Host-side accumulator: created whenever accumulation is enabled, even
        // if the GPU accumulator below also succeeds.
        let grad_accum = if config.accumulation_steps > 1 {
            let kv_hidden = mc.num_kv_heads * mc.head_dim();
            let block_sizes =
                super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
                    hidden_size,
                    kv_hidden,
                    mc.intermediate_size,
                );
            let accum = super::grad_accumulator::PerBlockGradientAccumulator::new(
                num_layers,
                block_sizes,
                vocab_size,
                hidden_size,
            );
            println!(
                " ✓ Gradient accumulation: {} steps, CPU buffers ({:.1} MB)",
                config.accumulation_steps,
                (accum
                    .block_grads
                    .iter()
                    .map(super::grad_accumulator::BlockGradientSet::total_elements)
                    .sum::<usize>()
                    + accum.lm_head_grad.len()
                    + accum.final_norm_grad.len()
                    + accum.embedding_grad.len()) as f64
                    * 4.0
                    / 1e6,
            );
            Some(accum)
        } else {
            None
        };

        // Device-side accumulator: an allocation failure only logs a warning
        // and leaves this as `None`.
        let gpu_grad_accum = if config.accumulation_steps > 1 {
            match super::gpu_grad_accumulator::GpuGradientAccumulator::new(&ctx, mc) {
                Ok(accum) => {
                    println!(" ✓ GPU gradient accumulation enabled (ALB-091)");
                    Some(accum)
                }
                Err(e) => {
                    eprintln!(
                        " [WARN] GPU gradient accumulation failed ({e}), using CPU fallback"
                    );
                    None
                }
            }
        } else {
            None
        };

        // The device-to-host staging buffer is only sized when accumulation is
        // on and the GPU accumulator is unavailable; it holds the larger of
        // `hidden_size * intermediate_size` and `vocab_size * hidden_size` floats.
        let d2h_staging = if config.accumulation_steps > 1 && gpu_grad_accum.is_none() {
            let ws_max = hidden_size * mc.intermediate_size;
            let lm_max = vocab_size * hidden_size;
            vec![0.0f32; ws_max.max(lm_max)]
        } else {
            Vec::new()
        };

        let kv_hidden = mc.num_kv_heads * mc.head_dim();
        let fused_clip = Self::init_fused_clip(&ctx, &config, hidden_size, kv_hidden, mc);

        let grad_scaler = GradScaler::from_config(&config.precision_config);
        if config.precision_config.is_mixed() {
            println!(
                " ✓ Mixed precision: {} (loss scale={}, dynamic={})",
                config.precision_config.compute_precision,
                grad_scaler.scale(),
                grad_scaler.is_dynamic(),
            );
        }

        Ok(Self {
            model,
            cuda_trainer,
            cuda_blocks,
            cuda_grad_workspace,
            nf4_shared_scratch,
            nf4_lora_grad_workspace,
            nf4_lora_optimizer_states,
            gpu_training,
            lm_head_weight_gpu,
            lm_head_grad_gpu,
            lm_head_m,
            lm_head_v,
            final_norm_m,
            final_norm_v,
            embed_optimizer,
            // `profiler` reads `config` and therefore has to be built before
            // `config` is moved into the struct on the next field.
            profiler: if config.profile_interval > 0 {
                StepProfiler::new(true, config.profile_interval)
            } else {
                StepProfiler::disabled()
            },
            config,
            metrics: MetricsTracker::new(),
            step: 0,
            accumulated_loss: 0.0,
            accumulated_batches: 0,
            last_grad_norm: 0.0,
            last_embed_grad_norm: 0.0,
            grad_accum,
            gpu_grad_accum,
            grad_scaler,
            fwd_scratch_a,
            fwd_scratch_b,
            h2d_staging: vec![0.0f32; max_seq_len * hidden_size],
            d2h_staging,
            fused_clip,
            final_norm_zero_buf: vec![0.0f32; hidden_size],
        })
    }
855
856 #[allow(clippy::too_many_arguments)]
858 fn upload_blocks(
859 model: &Transformer,
860 mc: &crate::transformer::TransformerConfig,
861 config: &TransformerTrainConfig,
862 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
863 use_nf4: bool,
864 num_layers: usize,
865 hidden_size: usize,
866 max_seq_len: usize,
867 ) -> crate::Result<Vec<CudaBlock>> {
868 let mut cuda_blocks: Vec<CudaBlock> = Vec::with_capacity(num_layers);
869
870 if use_nf4 {
871 let lora_rank = config.lora_rank.unwrap_or(16);
872 let lora_alpha = config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
873 let lora_scale = lora_alpha / lora_rank as f32;
874 let head_dim = mc.head_dim();
875 let q_dim = mc.num_attention_heads * head_dim;
876 let kv_hidden = mc.num_kv_heads * head_dim;
877
878 for (i, layer) in model.layers.iter().enumerate() {
879 let lora_a_q: Vec<f32> = (0..hidden_size * lora_rank)
880 .map(|j| ((j as f32 + i as f32 * 1000.0) * 0.1).sin() * 0.01)
881 .collect();
882 let lora_b_q = vec![0.0f32; lora_rank * q_dim];
883 let lora_a_v: Vec<f32> = (0..hidden_size * lora_rank)
884 .map(|j| ((j as f32 + i as f32 * 2000.0 + 500.0) * 0.1).sin() * 0.01)
885 .collect();
886 let lora_b_v = vec![0.0f32; lora_rank * kv_hidden];
887
888 let q_norm_data = layer
889 .self_attn
890 .q_norm
891 .as_ref()
892 .map(|t| t.data().as_slice().expect("contiguous q_norm").to_vec());
893 let k_norm_data = layer
894 .self_attn
895 .k_norm
896 .as_ref()
897 .map(|t| t.data().as_slice().expect("contiguous k_norm").to_vec());
898
899 let block = crate::transformer::CudaNf4TransformerBlock::new(
900 mc,
901 i,
902 ctx.clone(),
903 layer.input_norm.weight.data().as_slice().expect("contiguous"),
904 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
905 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
906 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
907 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
908 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
909 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
910 layer.ffn.w_up.data().as_slice().expect("contiguous"),
911 layer.ffn.w_down.data().as_slice().expect("contiguous"),
912 max_seq_len,
913 Some((&lora_a_q, &lora_b_q)),
914 Some((&lora_a_v, &lora_b_v)),
915 lora_scale,
916 lora_rank,
917 q_norm_data.as_deref(),
918 k_norm_data.as_deref(),
919 )
920 .map_err(|e| {
921 crate::error::Error::ConfigError(format!("NF4 block {i} upload failed: {e:?}"))
922 })?;
923 cuda_blocks.push(CudaBlock::Nf4(block));
924 }
925 println!(" ✓ {num_layers} NF4 transformer blocks uploaded (LoRA rank={lora_rank}, alpha={lora_alpha})");
926 } else {
927 for (i, layer) in model.layers.iter().enumerate() {
928 let block = CudaTransformerBlock::new(
929 mc,
930 i,
931 ctx.clone(),
932 layer.input_norm.weight.data().as_slice().expect("contiguous"),
933 layer.post_attn_norm.weight.data().as_slice().expect("contiguous"),
934 layer.self_attn.w_q.data().as_slice().expect("contiguous"),
935 layer.self_attn.w_k.data().as_slice().expect("contiguous"),
936 layer.self_attn.w_v.data().as_slice().expect("contiguous"),
937 layer.self_attn.w_o.data().as_slice().expect("contiguous"),
938 layer.ffn.w_gate.data().as_slice().expect("contiguous"),
939 layer.ffn.w_up.data().as_slice().expect("contiguous"),
940 layer.ffn.w_down.data().as_slice().expect("contiguous"),
941 max_seq_len,
942 )
943 .map_err(|e| {
944 crate::error::Error::ConfigError(format!("Block {i} upload failed: {e:?}"))
945 })?;
946 cuda_blocks.push(CudaBlock::Fp32(block));
947 }
948 println!(" ✓ {num_layers} transformer blocks uploaded to GPU");
949 }
950
951 Ok(cuda_blocks)
952 }
953
954 fn init_fused_clip(
956 ctx: &std::sync::Arc<trueno_gpu::driver::CudaContext>,
957 config: &TransformerTrainConfig,
958 hidden_size: usize,
959 kv_hidden: usize,
960 mc: &crate::transformer::TransformerConfig,
961 ) -> Option<FusedClipState> {
962 config.base.max_grad_norm?;
963 let grad_sizes: [u32; 9] = [
964 (hidden_size * hidden_size) as u32,
965 (hidden_size * kv_hidden) as u32,
966 (hidden_size * kv_hidden) as u32,
967 (hidden_size * hidden_size) as u32,
968 (hidden_size * mc.intermediate_size) as u32,
969 (hidden_size * mc.intermediate_size) as u32,
970 (mc.intermediate_size * hidden_size) as u32,
971 hidden_size as u32,
972 hidden_size as u32,
973 ];
974 match FusedClipState::new(ctx, &grad_sizes) {
975 Ok(state) => {
976 println!(
977 " ✓ Fused gradient clipping: {} partials ({:.1} KB)",
978 state.total_partials,
979 f64::from(state.total_partials) * 4.0 / 1024.0,
980 );
981 Some(state)
982 }
983 Err(e) => {
984 println!(" ⚠ Fused clip alloc failed ({e:?}), using sync fallback");
985 None
986 }
987 }
988 }
989
990 fn train_step_single(
999 &mut self,
1000 input_ids: &[u32],
1001 target_ids: &[u32],
1002 accumulate_only: bool,
1003 ) -> Option<f32> {
1004 self.profiler.begin_step();
1005 let result = self.train_step_inner(input_ids, target_ids, accumulate_only);
1006 self.profiler.finish_step();
1007 result
1008 }
1009
1010 fn train_step_inner(
1012 &mut self,
1013 input_ids: &[u32],
1014 target_ids: &[u32],
1015 accumulate_only: bool,
1016 ) -> Option<f32> {
1017 let hidden_size = self.config.model_config.hidden_size;
1018 let vocab_size = self.config.model_config.vocab_size;
1019
1020 let max_sl = self.config.max_seq_len;
1022 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
1023 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
1024 let seq_len = input_ids.len();
1025
1026 if self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size).is_none() {
1029 eprintln!(
1030 "[train_step_inner] gpu_forward returned None (seq_len={seq_len}, \
1031 hidden={hidden_size}, vocab={vocab_size}) — CUDA context likely poisoned"
1032 );
1033 return None;
1034 }
1035
1036 self.profiler.begin(StepProfiler::LOSS);
1039 let stream = self.cuda_trainer.stream();
1040
1041 let mut loss_scale = 1.0 / seq_len as f32;
1050 if self.config.accumulation_steps > 1 {
1051 loss_scale /= self.config.accumulation_steps as f32;
1052 }
1053
1054 let loss_val = fused_cross_entropy_cuda(
1056 &mut self.gpu_training.logits_buf,
1057 target_ids,
1058 seq_len as u32,
1059 vocab_size as u32,
1060 loss_scale,
1061 stream,
1062 )
1063 .ok()?;
1064
1065 if !loss_val.is_finite() {
1067 return None;
1068 }
1069 self.profiler.end(StepProfiler::LOSS);
1070
1071 if let Some(grad_output_is_a) =
1080 self.gpu_backward(seq_len, hidden_size, vocab_size, accumulate_only)
1081 {
1082 self.profiler.begin(StepProfiler::EMBED_BWD);
1084 self.embed_backward(input_ids, seq_len, hidden_size, vocab_size, grad_output_is_a);
1085
1086 self.profiler.end(StepProfiler::EMBED_BWD);
1087 }
1088
1089 Some(loss_val)
1090 }
1091
    #[allow(unsafe_code)]
    /// Runs the forward pass for `input_ids` and leaves the logits in
    /// `gpu_training.logits_buf`.
    ///
    /// Stages: host embedding lookup, zero-padded upload into `fwd_scratch_a`,
    /// every block in order (alternating between the two scratch buffers),
    /// final RMS norm, then the LM-head GEMM. Inputs of layers flagged in
    /// `saved_layer_mask` are copied into `layer_inputs` on the way.
    ///
    /// Returns `None` if the embedding output is not contiguous or if any
    /// copy or kernel call fails.
    ///
    /// # Panics
    /// Panics (slice index out of range) if the embedding output has more
    /// elements than `h2d_staging`.
    fn gpu_forward(
        &mut self,
        input_ids: &[u32],
        seq_len: usize,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Option<()> {
        contract_pre_gpu_forward!();
        let stream = self.cuda_trainer.stream();

        // Embedding lookup runs on the host model.
        self.profiler.begin(StepProfiler::EMBED);
        let hidden = self.model.embed_tokens.forward(input_ids);
        let hidden_slice = hidden.data().as_slice()?;
        self.profiler.end(StepProfiler::EMBED);

        // Copy the embeddings into the staging buffer, zero the unused tail,
        // and upload the whole staging buffer into scratch A.
        self.profiler.begin(StepProfiler::H2D);
        self.h2d_staging[..hidden_slice.len()].copy_from_slice(hidden_slice);
        self.h2d_staging[hidden_slice.len()..].fill(0.0);
        if let Err(e) = self.fwd_scratch_a.copy_from_host(&self.h2d_staging) {
            eprintln!("[gpu_forward] H2D copy failed: {e:?} — CUDA context may be poisoned");
            return None;
        }
        self.profiler.end(StepProfiler::H2D);

        self.profiler.begin(StepProfiler::FORWARD);
        // `input_is_a` tracks which scratch buffer holds the current layer's
        // input; it flips after every block.
        let mut input_is_a = true;
        for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
            // The two pointers always refer to different fields
            // (`fwd_scratch_a` vs `fwd_scratch_b`), so input and output never alias.
            // NOTE(review): raw pointers are presumably used to get past the
            // borrow checker while `self.cuda_blocks` is mutably iterated — confirm.
            let (input_ptr, output_ptr): (*const GpuBuffer<f32>, *mut GpuBuffer<f32>) =
                if input_is_a {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_a),
                        std::ptr::from_mut(&mut self.fwd_scratch_b),
                    )
                } else {
                    (
                        std::ptr::from_ref(&self.fwd_scratch_b),
                        std::ptr::from_mut(&mut self.fwd_scratch_a),
                    )
                };
            // Save this layer's input if the checkpointing mask asks for it.
            if self.gpu_training.saved_layer_mask[i] {
                // SAFETY: `input_ptr` was just derived from a live field of `self`.
                // NOTE(review): also relies on the contract of
                // `copy_from_buffer_async`, which is not visible here — confirm.
                unsafe {
                    self.gpu_training.layer_inputs[i]
                        .copy_from_buffer_async(&*input_ptr, stream)
                        .ok()?;
                }
            }
            self.profiler.begin_layer();
            // SAFETY: both pointers come from live, distinct fields of `self`
            // and are only dereferenced within this iteration.
            unsafe {
                block
                    .forward(
                        &*input_ptr,
                        &mut *output_ptr,
                        seq_len,
                        stream,
                        self.nf4_shared_scratch.as_mut(),
                    )
                    .ok()?;
            }
            self.profiler.end_layer_fwd(i);
            input_is_a = !input_is_a;
        }
        self.profiler.end(StepProfiler::FORWARD);

        // After the final flip, `input_is_a` names the buffer holding the last
        // block's output.
        let final_output: &GpuBuffer<f32> =
            if input_is_a { &self.fwd_scratch_a } else { &self.fwd_scratch_b };

        self.profiler.begin(StepProfiler::NORM_LM);
        // Keep a copy of the pre-norm activations in `blocks_output`.
        // NOTE(review): the `unsafe` here relies on the contract of
        // `copy_from_buffer_async`, which is not visible in this file — confirm.
        unsafe {
            self.gpu_training.blocks_output.copy_from_buffer_async(final_output, stream).ok()?;
        }

        // Final RMS norm into `norm_output`.
        rms_norm_forward(
            final_output,
            &self.gpu_training.final_norm_weight,
            &mut self.gpu_training.norm_output,
            seq_len as u32,
            hidden_size as u32,
            stream,
        )
        .ok()?;

        // LM-head projection: `norm_output` and `lm_head_weight_gpu` produce
        // `logits_buf`.
        gemm_forward(
            &self.gpu_training.norm_output,
            &self.lm_head_weight_gpu,
            &mut self.gpu_training.logits_buf,
            seq_len as u32,
            hidden_size as u32,
            vocab_size as u32,
            stream,
        )
        .ok()?;

        self.profiler.end(StepProfiler::NORM_LM);

        Some(())
    }
1216
1217 pub fn forward_logits(&mut self, input_ids: &[u32]) -> Option<Vec<f32>> {
1229 let seq_len = input_ids.len();
1230 let hidden_size = self.config.model_config.hidden_size;
1231 let vocab_size = self.config.model_config.vocab_size;
1232
1233 if seq_len == 0 || seq_len > self.config.max_seq_len {
1234 return None;
1235 }
1236
1237 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
1239
1240 let stream = self.cuda_trainer.stream();
1242 stream.synchronize().ok()?;
1243
1244 let offset = (seq_len - 1) * vocab_size;
1246 let mut logits = vec![0.0f32; vocab_size];
1247 self.gpu_training.logits_buf.copy_to_host_at(&mut logits, offset).ok()?;
1248
1249 Some(logits)
1250 }
1251
1252 #[allow(unsafe_code)]
1269 fn recompute_segment(
1270 gpu_training: &mut GpuPretrainState,
1271 cuda_blocks: &mut [CudaBlock],
1272 nf4_shared_scratch: &mut Option<CudaBlockScratch>,
1273 target_layer: usize,
1274 seq_len: usize,
1275 stream: &CudaStream,
1276 ) -> Option<()> {
1277 let seg_start = (0..=target_layer).rev().find(|&i| gpu_training.saved_layer_mask[i])?;
1279
1280 if seg_start == target_layer {
1281 return Some(()); }
1283
1284 let recompute_buf = gpu_training.recompute_buf.as_mut()?;
1287 unsafe {
1288 recompute_buf
1289 .copy_from_buffer_async(&gpu_training.layer_inputs[seg_start], stream)
1290 .ok()?;
1291 }
1292
1293 for i in seg_start..target_layer {
1304 if i == seg_start {
1305 let recompute_ptr: *const GpuBuffer<f32> = recompute_buf;
1307 let li = &mut gpu_training.layer_inputs;
1308 unsafe {
1309 cuda_blocks[i]
1310 .forward(
1311 &*recompute_ptr,
1312 &mut li[i + 1],
1313 seq_len,
1314 stream,
1315 nf4_shared_scratch.as_mut(),
1316 )
1317 .ok()?;
1318 }
1319 } else {
1320 let li = &mut gpu_training.layer_inputs;
1322 let (left, right) = li.split_at_mut(i + 1);
1323 cuda_blocks[i]
1324 .forward(&left[i], &mut right[0], seq_len, stream, nf4_shared_scratch.as_mut())
1325 .ok()?;
1326 }
1327 }
1328
1329 Some(())
1330 }
1331
1332 #[allow(unsafe_code)]
1345 fn gpu_backward(
1346 &mut self,
1347 seq_len: usize,
1348 hidden_size: usize,
1349 vocab_size: usize,
1350 accumulate_only: bool,
1351 ) -> Option<bool> {
1352 let stream = self.cuda_trainer.stream();
1353 let max_grad_norm = self.config.base.max_grad_norm;
1354 let lr = self.current_lr();
1355 let beta1 = self.config.beta1;
1357 let beta2 = self.config.beta2;
1358 let weight_decay = self.config.weight_decay;
1359
1360 self.profiler.begin(StepProfiler::LM_BWD);
1365 gemm_backward_a(
1366 &self.gpu_training.logits_buf,
1367 &self.lm_head_weight_gpu,
1368 &mut self.gpu_training.lm_head_grad_hidden,
1369 seq_len as u32,
1370 hidden_size as u32,
1371 vocab_size as u32,
1372 stream,
1373 )
1374 .ok()?;
1375
1376 gemm_backward_b(
1377 &self.gpu_training.norm_output,
1378 &self.gpu_training.logits_buf,
1379 &mut self.lm_head_grad_gpu,
1380 seq_len as u32,
1381 hidden_size as u32,
1382 vocab_size as u32,
1383 stream,
1384 )
1385 .ok()?;
1386
1387 let lm_sq_norm =
1393 squared_sum_cuda(&self.lm_head_grad_gpu, self.lm_head_grad_gpu.len() as u32, stream)
1394 .unwrap_or(0.0);
1395 let lm_norm = lm_sq_norm.sqrt(); self.last_grad_norm = lm_norm; if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1399 eprintln!("[grad-trace] lm_head gnorm={lm_norm:.6}");
1400 let gh_sq = squared_sum_cuda(
1402 &self.gpu_training.lm_head_grad_hidden,
1403 self.gpu_training.lm_head_grad_hidden.len() as u32,
1404 stream,
1405 )
1406 .unwrap_or(0.0);
1407 eprintln!("[grad-trace] lm_head_grad_hidden gnorm={:.6}", gh_sq.sqrt());
1408 }
1409 if let Some(max_norm) = max_grad_norm {
1410 let clip_scale = if lm_norm > max_norm { max_norm / lm_norm } else { 1.0 };
1411 let n = self.lm_head_grad_gpu.len() as u32;
1412 let _ = gradient_clip_cuda(&mut self.lm_head_grad_gpu, clip_scale, n, stream);
1413 }
1414 self.profiler.end(StepProfiler::LM_BWD);
1415
1416 self.profiler.begin(StepProfiler::NORM_BWD);
1418 self.gpu_training.grad_final_norm_weight.copy_from_host(&self.final_norm_zero_buf).ok()?;
1420 rms_norm_backward(
1421 &self.gpu_training.blocks_output,
1422 &self.gpu_training.final_norm_weight,
1423 &self.gpu_training.lm_head_grad_hidden,
1424 &mut self.gpu_training.grad_buf_a,
1425 &mut self.gpu_training.grad_final_norm_weight,
1426 seq_len as u32,
1427 hidden_size as u32,
1428 1e-5_f32,
1429 stream,
1430 )
1431 .ok()?;
1432
1433 if let Some(max_norm) = max_grad_norm {
1436 let (scale, _) = Self::compute_clip_scale_with_norm(
1437 &self.gpu_training.grad_final_norm_weight,
1438 max_norm,
1439 stream,
1440 );
1441 let n = self.gpu_training.grad_final_norm_weight.len() as u32;
1442 let _ =
1443 gradient_clip_cuda(&mut self.gpu_training.grad_final_norm_weight, scale, n, stream);
1444 }
1445 self.profiler.end(StepProfiler::NORM_BWD);
1446
1447 if accumulate_only {
1449 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1451 let _ = gpu_accum.accumulate_nonblock(
1452 &self.lm_head_grad_gpu,
1453 &self.gpu_training.grad_final_norm_weight,
1454 stream,
1455 );
1456 } else {
1457 stream.synchronize().ok()?;
1458 Self::download_nonblock_grads_to_accum(
1459 &self.lm_head_grad_gpu,
1460 &self.gpu_training.grad_final_norm_weight,
1461 &mut self.grad_accum,
1462 &mut self.d2h_staging,
1463 )?;
1464 }
1465 } else {
1466 Self::run_nonblock_optimizer_step(
1467 &mut self.gpu_training,
1468 Some(&mut self.lm_head_weight_gpu),
1469 &self.lm_head_grad_gpu,
1470 &mut self.lm_head_m,
1471 &mut self.lm_head_v,
1472 &mut self.final_norm_m,
1473 &mut self.final_norm_v,
1474 lr,
1475 beta1,
1476 beta2,
1477 weight_decay,
1478 stream,
1479 );
1480 }
1481
1482 self.profiler.begin(StepProfiler::BLK_BWD);
1488 let grad_a_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_a;
1489 let grad_b_ptr: *mut GpuBuffer<f32> = &raw mut self.gpu_training.grad_buf_b;
1490 let mut grad_output_is_a = true;
1491 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
1492
1493 for layer_idx in (0..self.cuda_blocks.len()).rev() {
1494 if !self.gpu_training.saved_layer_mask[layer_idx] {
1497 Self::recompute_segment(
1498 &mut self.gpu_training,
1499 &mut self.cuda_blocks,
1500 &mut self.nf4_shared_scratch,
1501 layer_idx,
1502 seq_len,
1503 stream,
1504 )?;
1505 }
1506
1507 let (grad_output, grad_input) = unsafe {
1508 if grad_output_is_a {
1509 (&*grad_a_ptr, &mut *grad_b_ptr)
1510 } else {
1511 (&*grad_b_ptr, &mut *grad_a_ptr)
1512 }
1513 };
1514
1515 self.profiler.begin_layer();
1516 if use_nf4 {
1517 let _output_scratch_ptr: *mut GpuBuffer<f32> = if grad_output_is_a {
1521 grad_b_ptr } else {
1523 grad_a_ptr
1524 };
1525 match self.cuda_blocks[layer_idx].backward_nf4(
1528 &self.gpu_training.layer_inputs[layer_idx],
1529 grad_output,
1530 grad_input,
1531 &mut self.gpu_training.blocks_output, seq_len,
1533 stream,
1534 self.nf4_shared_scratch.as_mut().expect("NF4 requires shared scratch"),
1535 self.nf4_lora_grad_workspace
1536 .as_mut()
1537 .expect("NF4 requires LoRA grad workspace"),
1538 ) {
1539 Ok(()) => {}
1540 Err(e) => {
1541 eprintln!(
1542 "[backward_nf4] Layer {} FAILED: {:?} (seq_len={}, hidden={})",
1543 layer_idx, e, seq_len, self.config.model_config.hidden_size
1544 );
1545 return None;
1546 }
1547 }
1548
1549 if let Some(max_norm) = max_grad_norm {
1553 self.nf4_lora_grad_workspace
1554 .as_mut()
1555 .expect("NF4 requires LoRA grad ws")
1556 .clip_gradients(max_norm, stream);
1557 }
1558
1559 {
1565 let step = self.gpu_training.step;
1566 let effective_lr = if accumulate_only {
1567 lr / self.config.accumulation_steps as f32
1568 } else {
1569 lr
1570 };
1571 if let Some(ref mut opt_states) = self.nf4_lora_optimizer_states {
1572 let _ = self.cuda_blocks[layer_idx].lora_optimizer_step(
1573 &mut opt_states[layer_idx],
1574 step,
1575 effective_lr,
1576 beta1,
1577 beta2,
1578 1e-8,
1579 weight_decay,
1580 stream,
1581 self.nf4_lora_grad_workspace
1582 .as_ref()
1583 .expect("NF4 requires LoRA grad ws"),
1584 );
1585 }
1586 }
1587 } else {
1588 self.cuda_blocks[layer_idx]
1590 .backward(
1591 &self.gpu_training.layer_inputs[layer_idx],
1592 grad_output,
1593 grad_input,
1594 seq_len,
1595 stream,
1596 &mut self.cuda_grad_workspace,
1597 )
1598 .ok()?;
1599
1600 if std::env::var("ENTRENAR_TRACE_GRADIENTS").is_ok() {
1606 let (_, block_gnorm) = compute_workspace_clip_scale_gpu(
1607 &self.cuda_grad_workspace,
1608 f32::MAX,
1609 stream,
1610 );
1611 let act_sq = squared_sum_cuda(grad_input, grad_input.len() as u32, stream)
1613 .unwrap_or(0.0);
1614 let act_gnorm = act_sq.sqrt();
1615 eprintln!(
1616 "[grad-trace] block={layer_idx} weight_gnorm={block_gnorm:.6} act_gnorm={act_gnorm:.6}"
1617 );
1618 }
1619
1620 if accumulate_only {
1622 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1624 let _ = gpu_accum.accumulate_block(
1625 &self.cuda_grad_workspace,
1626 layer_idx,
1627 stream,
1628 );
1629 } else {
1630 stream.synchronize().ok()?;
1632 if let Some(accum) = &mut self.grad_accum {
1633 Self::download_workspace_to_accum(
1634 &self.cuda_grad_workspace,
1635 accum,
1636 layer_idx,
1637 &mut self.d2h_staging,
1638 )?;
1639 }
1640 }
1641 } else {
1642 let step = self.gpu_training.step;
1644 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1645 &mut self.gpu_training.optimizer_states[layer_idx],
1646 step,
1647 lr,
1648 beta1,
1649 beta2,
1650 1e-8,
1651 weight_decay,
1652 stream,
1653 &self.cuda_grad_workspace,
1654 );
1655 }
1656 }
1657
1658 self.profiler.end_layer_bwd(layer_idx);
1659 grad_output_is_a = !grad_output_is_a;
1660 }
1661
1662 stream.synchronize().ok()?;
1663 self.profiler.end(StepProfiler::BLK_BWD);
1664
1665 Some(grad_output_is_a)
1666 }
1667
1668 fn download_nonblock_grads_to_accum(
1674 lm_head_grad: &GpuBuffer<f32>,
1675 final_norm_grad: &GpuBuffer<f32>,
1676 grad_accum: &mut Option<super::grad_accumulator::PerBlockGradientAccumulator>,
1677 host: &mut [f32],
1678 ) -> Option<()> {
1679 let accum = grad_accum.as_mut()?;
1680
1681 let lm_slice = &mut host[..lm_head_grad.len()];
1682 lm_head_grad.copy_to_host_at(lm_slice, 0).ok()?;
1683 for (d, s) in accum.lm_head_grad.iter_mut().zip(lm_slice.iter()) {
1684 *d += s;
1685 }
1686
1687 let norm_slice = &mut host[..final_norm_grad.len()];
1688 final_norm_grad.copy_to_host_at(norm_slice, 0).ok()?;
1689 for (d, s) in accum.final_norm_grad.iter_mut().zip(norm_slice.iter()) {
1690 *d += s;
1691 }
1692 Some(())
1693 }
1694
1695 #[allow(clippy::too_many_arguments)]
1698 fn run_nonblock_optimizer_step(
1699 gpu_training: &mut GpuPretrainState,
1700 lm_head_weight_gpu: Option<&mut GpuBuffer<f32>>,
1701 lm_head_grad_gpu: &GpuBuffer<f32>,
1702 lm_head_m: &mut GpuBuffer<f32>,
1703 lm_head_v: &mut GpuBuffer<f32>,
1704 final_norm_m: &mut GpuBuffer<f32>,
1705 final_norm_v: &mut GpuBuffer<f32>,
1706 lr: f32,
1707 beta1: f32,
1708 beta2: f32,
1709 weight_decay: f32,
1710 stream: &CudaStream,
1711 ) {
1712 gpu_training.step += 1;
1713 let step = gpu_training.step;
1714
1715 if let Some(lm_head_weight) = lm_head_weight_gpu {
1716 let n_lm = lm_head_weight.len() as u32;
1717 let _ = adamw_step_cuda(
1718 lm_head_weight,
1719 lm_head_grad_gpu,
1720 lm_head_m,
1721 lm_head_v,
1722 lr,
1723 beta1,
1724 beta2,
1725 1e-8,
1726 weight_decay,
1727 step,
1728 n_lm,
1729 stream,
1730 );
1731 }
1732
1733 let n_norm = gpu_training.final_norm_weight.len() as u32;
1734 let _ = adamw_step_cuda(
1735 &mut gpu_training.final_norm_weight,
1736 &gpu_training.grad_final_norm_weight,
1737 final_norm_m,
1738 final_norm_v,
1739 lr,
1740 beta1,
1741 beta2,
1742 1e-8,
1743 weight_decay,
1744 step,
1745 n_norm,
1746 stream,
1747 );
1748 }
1749
1750 fn download_workspace_to_accum(
1758 ws: &CudaGradWorkspace,
1759 accum: &mut super::grad_accumulator::PerBlockGradientAccumulator,
1760 layer_idx: usize,
1761 host: &mut [f32],
1762 ) -> Option<()> {
1763 let bg = &mut accum.block_grads[layer_idx];
1764
1765 use super::grad_accumulator::component;
1766 let bufs_and_components: [(&GpuBuffer<f32>, usize); 9] = [
1767 (&ws.grad_w_q, component::W_Q),
1768 (&ws.grad_w_k, component::W_K),
1769 (&ws.grad_w_v, component::W_V),
1770 (&ws.grad_w_o, component::W_O),
1771 (&ws.grad_gate, component::GATE),
1772 (&ws.grad_up, component::UP),
1773 (&ws.grad_down, component::DOWN),
1774 (&ws.grad_input_norm, component::INPUT_NORM),
1775 (&ws.grad_post_attn_norm, component::POST_ATTN_NORM),
1776 ];
1777
1778 for (gpu_buf, comp_idx) in &bufs_and_components {
1779 let slice = &mut host[..gpu_buf.len()];
1780 gpu_buf.copy_to_host_at(slice, 0).ok()?;
1781 for (d, s) in bg.components[*comp_idx].iter_mut().zip(slice.iter()) {
1782 *d += s;
1783 }
1784 }
1785 Some(())
1786 }
1787
1788 fn gpu_optimizer_from_gpu_accum(&mut self) -> Option<()> {
1795 let stream = self.cuda_trainer.stream();
1796 let lr = self.current_lr();
1797 let beta1 = self.config.beta1;
1798 let beta2 = self.config.beta2;
1799 let weight_decay = self.config.weight_decay;
1800
1801 stream.synchronize().ok()?;
1803
1804 self.gpu_training.step += 1;
1805 let step = self.gpu_training.step;
1806
1807 let gpu_accum = self.gpu_grad_accum.as_ref()?;
1809 for layer_idx in 0..self.cuda_blocks.len() {
1810 gpu_accum.upload_to_workspace(&mut self.cuda_grad_workspace, layer_idx).ok()?;
1811
1812 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1813 &mut self.gpu_training.optimizer_states[layer_idx],
1814 step,
1815 lr,
1816 beta1,
1817 beta2,
1818 1e-8,
1819 weight_decay,
1820 stream,
1821 &self.cuda_grad_workspace,
1822 );
1823 }
1824
1825 gpu_accum
1827 .upload_nonblock(
1828 &mut self.lm_head_grad_gpu,
1829 &mut self.gpu_training.grad_final_norm_weight,
1830 )
1831 .ok()?;
1832
1833 let n_lm = self.lm_head_weight_gpu.len() as u32;
1834 let _ = adamw_step_cuda(
1835 &mut self.lm_head_weight_gpu,
1836 &self.lm_head_grad_gpu,
1837 &mut self.lm_head_m,
1838 &mut self.lm_head_v,
1839 lr,
1840 beta1,
1841 beta2,
1842 1e-8,
1843 weight_decay,
1844 step,
1845 n_lm,
1846 stream,
1847 );
1848
1849 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1851 let _ = adamw_step_cuda(
1852 &mut self.gpu_training.final_norm_weight,
1853 &self.gpu_training.grad_final_norm_weight,
1854 &mut self.final_norm_m,
1855 &mut self.final_norm_v,
1856 lr,
1857 beta1,
1858 beta2,
1859 1e-8,
1860 weight_decay,
1861 step,
1862 n_norm,
1863 stream,
1864 );
1865
1866 stream.synchronize().ok()?;
1867
1868 if let Some(ref mut gpu_accum) = self.gpu_grad_accum {
1870 let _ = gpu_accum.zero_all();
1871 }
1872
1873 Some(())
1874 }
1875
1876 #[allow(unsafe_code)]
1877 fn gpu_optimizer_from_accum(&mut self) -> Option<()> {
1878 let stream = self.cuda_trainer.stream();
1879 let lr = self.current_lr();
1880 let beta1 = self.config.beta1;
1881 let beta2 = self.config.beta2;
1882 let weight_decay = self.config.weight_decay;
1883
1884 let accum = self.grad_accum.as_mut()?;
1886 accum.average();
1887
1888 if accum.has_non_finite() {
1890 println!("[WARN] R-038: NaN/Inf in accumulated gradients, skipping optimizer step");
1891 accum.zero_all();
1892 return Some(());
1893 }
1894
1895 self.gpu_training.step += 1;
1896 let step = self.gpu_training.step;
1897
1898 use super::grad_accumulator::component;
1900 for layer_idx in 0..self.cuda_blocks.len() {
1901 let bg = &accum.block_grads[layer_idx];
1902
1903 unsafe {
1907 self.cuda_grad_workspace
1908 .grad_w_q
1909 .copy_from_host_async(&bg.components[component::W_Q], stream)
1910 .ok()?;
1911 self.cuda_grad_workspace
1912 .grad_w_k
1913 .copy_from_host_async(&bg.components[component::W_K], stream)
1914 .ok()?;
1915 self.cuda_grad_workspace
1916 .grad_w_v
1917 .copy_from_host_async(&bg.components[component::W_V], stream)
1918 .ok()?;
1919 self.cuda_grad_workspace
1920 .grad_w_o
1921 .copy_from_host_async(&bg.components[component::W_O], stream)
1922 .ok()?;
1923 self.cuda_grad_workspace
1924 .grad_gate
1925 .copy_from_host_async(&bg.components[component::GATE], stream)
1926 .ok()?;
1927 self.cuda_grad_workspace
1928 .grad_up
1929 .copy_from_host_async(&bg.components[component::UP], stream)
1930 .ok()?;
1931 self.cuda_grad_workspace
1932 .grad_down
1933 .copy_from_host_async(&bg.components[component::DOWN], stream)
1934 .ok()?;
1935 self.cuda_grad_workspace
1936 .grad_input_norm
1937 .copy_from_host_async(&bg.components[component::INPUT_NORM], stream)
1938 .ok()?;
1939 self.cuda_grad_workspace
1940 .grad_post_attn_norm
1941 .copy_from_host_async(&bg.components[component::POST_ATTN_NORM], stream)
1942 .ok()?;
1943 }
1944
1945 let _ = self.cuda_blocks[layer_idx].optimizer_step(
1947 &mut self.gpu_training.optimizer_states[layer_idx],
1948 step,
1949 lr,
1950 beta1,
1951 beta2,
1952 1e-8,
1953 weight_decay,
1954 stream,
1955 &self.cuda_grad_workspace,
1956 );
1957 }
1958
1959 unsafe {
1963 self.lm_head_grad_gpu.copy_from_host_async(&accum.lm_head_grad, stream).ok()?;
1964 }
1965 let n_lm = self.lm_head_weight_gpu.len() as u32;
1966 let _ = adamw_step_cuda(
1967 &mut self.lm_head_weight_gpu,
1968 &self.lm_head_grad_gpu,
1969 &mut self.lm_head_m,
1970 &mut self.lm_head_v,
1971 lr,
1972 beta1,
1973 beta2,
1974 1e-8,
1975 weight_decay,
1976 step,
1977 n_lm,
1978 stream,
1979 );
1980
1981 unsafe {
1984 self.gpu_training
1985 .grad_final_norm_weight
1986 .copy_from_host_async(&accum.final_norm_grad, stream)
1987 .ok()?;
1988 }
1989 let n_norm = self.gpu_training.final_norm_weight.len() as u32;
1990 let _ = adamw_step_cuda(
1991 &mut self.gpu_training.final_norm_weight,
1992 &self.gpu_training.grad_final_norm_weight,
1993 &mut self.final_norm_m,
1994 &mut self.final_norm_v,
1995 lr,
1996 beta1,
1997 beta2,
1998 1e-8,
1999 weight_decay,
2000 step,
2001 n_norm,
2002 stream,
2003 );
2004
2005 stream.synchronize().ok()?;
2006
2007 accum.zero_all();
2009 Some(())
2010 }
2011
2012 fn compute_clip_scale_with_norm(
2025 buf: &GpuBuffer<f32>,
2026 max_norm: f32,
2027 stream: &CudaStream,
2028 ) -> (f32, f32) {
2029 let n = buf.len() as u32;
2030 let grad_norm = match squared_sum_cuda(buf, n, stream) {
2032 Ok(norm) => norm,
2033 Err(_) => {
2034 let mut host = vec![0.0f32; buf.len()];
2036 if buf.copy_to_host_at(&mut host, 0).is_err() {
2037 return (1.0, 0.0);
2038 }
2039 let sq_sum: f64 = host.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2040 sq_sum.sqrt() as f32
2041 }
2042 };
2043 let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
2044 (scale, grad_norm)
2045 }
2046
2047 #[allow(unsafe_code)]
2056 fn embed_backward(
2057 &mut self,
2058 input_ids: &[u32],
2059 _seq_len: usize,
2060 hidden_size: usize,
2061 vocab_size: usize,
2062 grad_output_is_a: bool,
2063 ) -> Option<()> {
2064 let grad_a_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_a;
2066 let grad_b_ptr: *const GpuBuffer<f32> = &raw const self.gpu_training.grad_buf_b;
2067 let embed_grad_buf = unsafe {
2068 if grad_output_is_a {
2069 &*grad_a_ptr
2070 } else {
2071 &*grad_b_ptr
2072 }
2073 };
2074 let mut embed_grad_data = self.cuda_trainer.download(embed_grad_buf).ok()?;
2075
2076 let embed_clip_norm = self.config.base.max_grad_norm.unwrap_or(1.0);
2084 {
2085 let sq_sum: f64 = embed_grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
2086 let grad_norm = sq_sum.sqrt() as f32;
2087 self.last_embed_grad_norm = grad_norm; if grad_norm > embed_clip_norm {
2089 let scale = embed_clip_norm / grad_norm;
2090 for g in &mut embed_grad_data {
2091 *g *= scale;
2092 }
2093 }
2094 }
2095
2096 let embed_weight = &mut self.model.embed_tokens.weight;
2100 let grad_cell = embed_weight.grad_cell();
2101 let mut grad_ref = grad_cell.borrow_mut();
2102 if grad_ref.is_none() {
2103 *grad_ref = Some(ndarray::Array1::zeros(embed_weight.len()));
2104 }
2105 if let Some(grad) = grad_ref.as_mut() {
2106 for (pos, &token_id) in input_ids.iter().enumerate() {
2107 let tid = token_id as usize;
2108 if tid < vocab_size {
2109 let src = pos * hidden_size;
2110 let dst = tid * hidden_size;
2111 for h in 0..hidden_size {
2112 grad[dst + h] += embed_grad_data[src + h];
2113 }
2114 }
2115 }
2116 }
2117 Some(())
2118 }
2119
2120 fn optimizer_step(&mut self) {
2126 self.grad_scaler.update(true);
2130
2131 self.embed_optimizer.set_lr(self.current_lr());
2133 let mut embed_params = vec![&mut self.model.embed_tokens.weight];
2135 self.embed_optimizer.step_refs(&mut embed_params);
2136
2137 self.step += 1;
2138 self.metrics.losses.push(self.accumulated_loss);
2139 self.metrics.increment_step();
2140
2141 self.accumulated_loss = 0.0;
2142 self.accumulated_batches = 0;
2143 }
2144
2145 pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
2157 if batch.batch_size == 0 {
2158 return 0.0;
2159 }
2160
2161 let accumulating = self.grad_accum.is_some() || self.gpu_grad_accum.is_some();
2162
2163 if self.accumulated_batches == 0 {
2164 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2166 }
2167
2168 let mut total_loss = 0.0;
2169 let mut valid_count = 0;
2170
2171 for i in 0..batch.batch_size {
2172 let Some(input_ids) = batch.get_input(i) else {
2173 continue;
2174 };
2175 let Some(target_ids) = batch.get_target(i) else {
2176 continue;
2177 };
2178
2179 if let Some(loss) = self.train_step_single(input_ids, target_ids, accumulating) {
2183 total_loss += loss;
2184 valid_count += 1;
2185 if accumulating {
2186 if let Some(accum) = &mut self.gpu_grad_accum {
2187 accum.accumulated_count += 1;
2188 } else if let Some(accum) = &mut self.grad_accum {
2189 accum.accumulated_count += 1;
2190 }
2191 }
2192 }
2193 }
2194
2195 let avg_loss = if valid_count > 0 { total_loss / valid_count as f32 } else { 0.0 };
2196
2197 if avg_loss == 0.0 && valid_count > 0 {
2199 eprintln!(
2200 "[train_batch DEBUG] avg_loss=0.0 but valid_count={}, total_loss={}, batch_size={}",
2201 valid_count, total_loss, batch.batch_size
2202 );
2203 }
2204
2205 self.accumulated_loss += avg_loss / self.config.accumulation_steps as f32;
2206 self.accumulated_batches += 1;
2207
2208 if self.accumulated_batches >= self.config.accumulation_steps {
2209 if accumulating {
2210 if self.gpu_grad_accum.is_some() {
2212 self.gpu_optimizer_from_gpu_accum();
2213 } else {
2214 self.gpu_optimizer_from_accum();
2215 }
2216 }
2217 self.optimizer_step();
2218 }
2219
2220 avg_loss
2221 }
2222
2223 pub fn eval_batch(&mut self, batch: &LMBatch) -> f32 {
2227 let hidden_size = self.config.model_config.hidden_size;
2228 let vocab_size = self.config.model_config.vocab_size;
2229 let max_sl = self.config.max_seq_len;
2230 let mut total_loss = 0.0;
2231 let mut valid_count = 0;
2232 for i in 0..batch.batch_size {
2233 if let Some(loss) = self.eval_single_sequence(batch, i, max_sl, hidden_size, vocab_size)
2234 {
2235 total_loss += loss;
2236 valid_count += 1;
2237 }
2238 }
2239 if valid_count > 0 {
2240 total_loss / valid_count as f32
2241 } else {
2242 0.0
2243 }
2244 }
2245
2246 fn eval_single_sequence(
2248 &mut self,
2249 batch: &LMBatch,
2250 i: usize,
2251 max_sl: usize,
2252 hidden_size: usize,
2253 vocab_size: usize,
2254 ) -> Option<f32> {
2255 let input_ids = batch.get_input(i)?;
2256 let target_ids = batch.get_target(i)?;
2257 let input_ids = if input_ids.len() > max_sl { &input_ids[..max_sl] } else { input_ids };
2259 let target_ids = if target_ids.len() > max_sl { &target_ids[..max_sl] } else { target_ids };
2260 let seq_len = input_ids.len();
2261 self.gpu_forward(input_ids, seq_len, hidden_size, vocab_size)?;
2262 let stream = self.cuda_trainer.stream();
2263 let scale = 1.0 / seq_len as f32;
2264 let loss = fused_cross_entropy_cuda(
2265 &mut self.gpu_training.logits_buf,
2266 target_ids,
2267 seq_len as u32,
2268 vocab_size as u32,
2269 scale,
2270 stream,
2271 )
2272 .ok()?;
2273 if loss.is_finite() {
2274 Some(loss)
2275 } else {
2276 None
2277 }
2278 }
2279
2280 pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
2282 self.train_epoch_with_callback(batches, |_, _, _| {})
2283 }
2284
2285 pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
2289 where
2290 F: FnMut(usize, f32, &Self),
2291 {
2292 if batches.is_empty() {
2293 return 0.0;
2294 }
2295
2296 let mut total_loss = 0.0;
2297 let mut batches_processed = 0;
2298
2299 for (i, batch) in batches.iter().enumerate() {
2300 if let Some(max) = self.config.max_steps {
2301 if self.step >= max {
2302 break;
2303 }
2304 }
2305
2306 let batch_loss = self.train_batch(batch);
2307 total_loss += batch_loss;
2308 batches_processed += 1;
2309 on_batch(i, batch_loss, self);
2310 }
2311
2312 if self.profiler.is_enabled() && self.profiler.step_count() > 0 {
2314 self.profiler.print_report();
2315 }
2316
2317 total_loss / batches_processed.max(1) as f32
2318 }
2319
2320 pub(crate) fn ensure_grad_accum(&mut self) {
2327 if self.grad_accum.is_some() {
2328 return;
2329 }
2330 let mc = &self.config.model_config;
2331 let hidden_size = mc.hidden_size;
2332 let kv_hidden = mc.num_kv_heads * mc.head_dim();
2333 let block_sizes = super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(
2334 hidden_size,
2335 kv_hidden,
2336 mc.intermediate_size,
2337 );
2338 self.grad_accum = Some(super::grad_accumulator::PerBlockGradientAccumulator::new(
2339 self.cuda_blocks.len(),
2340 block_sizes,
2341 mc.vocab_size,
2342 hidden_size,
2343 ));
2344 }
2345
2346 pub(crate) fn forward_backward_batch(&mut self, batch: &LMBatch) -> f32 {
2351 if batch.batch_size == 0 {
2352 return 0.0;
2353 }
2354
2355 if self.accumulated_batches == 0 {
2356 self.embed_optimizer.zero_grad_refs(&mut vec![&mut self.model.embed_tokens.weight]);
2357 }
2358
2359 let mut total_loss = 0.0;
2360 let mut valid_count = 0;
2361
2362 for i in 0..batch.batch_size {
2363 let Some(input_ids) = batch.get_input(i) else { continue };
2364 let Some(target_ids) = batch.get_target(i) else { continue };
2365
2366 if let Some(loss) = self.train_step_single(input_ids, target_ids, true) {
2368 total_loss += loss;
2369 valid_count += 1;
2370 if let Some(accum) = &mut self.grad_accum {
2371 accum.accumulated_count += 1;
2372 }
2373 }
2374 }
2375
2376 if valid_count > 0 {
2377 total_loss / valid_count as f32
2378 } else {
2379 0.0
2380 }
2381 }
2382
2383 pub(crate) fn apply_ddp_gradients(&mut self) {
2389 self.accumulated_loss = 0.0;
2390 self.accumulated_batches = 0;
2391 self.gpu_optimizer_from_accum();
2392 self.optimizer_step();
2393 }
2394
2395 pub(crate) fn grad_accum_ref(
2397 &self,
2398 ) -> Option<&super::grad_accumulator::PerBlockGradientAccumulator> {
2399 self.grad_accum.as_ref()
2400 }
2401
2402 pub(crate) fn grad_accum_mut(
2404 &mut self,
2405 ) -> Option<&mut super::grad_accumulator::PerBlockGradientAccumulator> {
2406 self.grad_accum.as_mut()
2407 }
2408
2409 pub(crate) fn config(&self) -> &TransformerTrainConfig {
2411 &self.config
2412 }
2413
2414 pub(crate) fn embed_grad_vec(&self) -> Option<Vec<f32>> {
2416 self.model.embed_tokens.weight.grad().map(|g| g.to_vec())
2417 }
2418
2419 pub(crate) fn set_embed_grad(&mut self, grad: Vec<f32>) {
2421 self.model.embed_tokens.weight.set_grad(ndarray::Array1::from(grad));
2422 }
2423
2424 pub fn reached_max_steps(&self) -> bool {
2426 self.config.max_steps.is_some_and(|max| self.step >= max)
2427 }
2428
2429 pub fn step(&self) -> usize {
2431 self.step
2432 }
2433
2434 pub fn set_initial_step(&mut self, step: usize) {
2440 self.step = step;
2441 self.gpu_training.step = step as u32;
2442 }
2443
2444 pub fn set_max_steps(&mut self, max_steps: usize) {
2450 self.config.max_steps = Some(max_steps);
2451 }
2452
2453 pub fn current_lr(&self) -> f32 {
2459 let base_lr = self.config.lr;
2460 if self.step < self.config.warmup_steps {
2461 base_lr * (self.step as f32 / self.config.warmup_steps.max(1) as f32)
2463 } else if let Some(max_steps) = self.config.max_steps {
2464 let decay_steps = max_steps.saturating_sub(self.config.warmup_steps);
2466 if decay_steps == 0 {
2467 return base_lr;
2468 }
2469 let decay_step = self.step - self.config.warmup_steps;
2470 let progress = (decay_step as f32 / decay_steps as f32).min(1.0);
2471 0.5 * base_lr * (1.0 + (std::f32::consts::PI * progress).cos())
2472 } else {
2473 base_lr
2475 }
2476 }
2477
2478 pub fn enable_profiler(&mut self, interval: usize) {
2489 self.profiler = StepProfiler::new(true, interval);
2490 }
2491
2492 pub fn print_profiler_report(&self) {
2494 self.profiler.print_report();
2495 }
2496
2497 pub fn last_grad_norm(&self) -> f32 {
2499 self.last_grad_norm
2500 }
2501
2502 pub fn param_grad_norms(&self) -> (f32, f32) {
2505 (self.last_grad_norm, self.last_embed_grad_norm)
2506 }
2507
2508 pub fn num_params(&self) -> usize {
2510 self.model.parameters().iter().map(|t| t.len()).sum()
2511 }
2512
2513 pub fn gpu_memory_mb(&self) -> (u64, u64) {
2515 match self.cuda_trainer.context().memory_info() {
2516 Ok((free, total)) => {
2517 let total_mb = (total / (1024 * 1024)) as u64;
2518 let used_mb = ((total - free) / (1024 * 1024)) as u64;
2519 (used_mb, total_mb)
2520 }
2521 Err(_) => (0, 0),
2522 }
2523 }
2524
2525 pub fn sync_weights_to_cpu(&mut self) {
2531 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2532
2533 if use_nf4 {
2534 } else {
2540 for (layer_idx, block) in self.cuda_blocks.iter().enumerate() {
2541 if let Ok(weights) = block.download_weights() {
2542 let layer = &mut self.model.layers[layer_idx];
2543
2544 layer.self_attn.w_q = Tensor::from_vec(weights.w_q, false);
2545 layer.self_attn.w_k = Tensor::from_vec(weights.w_k, false);
2546 layer.self_attn.w_v = Tensor::from_vec(weights.w_v, false);
2547 layer.self_attn.w_o = Tensor::from_vec(weights.w_o, false);
2548
2549 layer.ffn.w_gate = Tensor::from_vec(weights.w_gate, false);
2550 layer.ffn.w_up = Tensor::from_vec(weights.w_up, false);
2551 layer.ffn.w_down = Tensor::from_vec(weights.w_down, false);
2552
2553 layer.input_norm.weight = Tensor::from_vec(weights.input_norm_weight, false);
2554 layer.post_attn_norm.weight =
2555 Tensor::from_vec(weights.post_attn_norm_weight, false);
2556 }
2557 }
2558 }
2559
2560 if let Ok(norm_data) = self.cuda_trainer.download(&self.gpu_training.final_norm_weight) {
2562 self.model.norm.weight = Tensor::from_vec(norm_data, false);
2563 }
2564
2565 if let Ok(lm_data) = self.cuda_trainer.download(&self.lm_head_weight_gpu) {
2572 self.model.lm_head = Some(Tensor::from_vec(lm_data, false));
2573 }
2574 }
2575
2576 pub fn model(&self) -> &Transformer {
2578 &self.model
2579 }
2580
2581 pub fn model_mut(&mut self) -> &mut Transformer {
2583 &mut self.model
2584 }
2585
2586 pub fn is_mixed_precision(&self) -> bool {
2588 self.config.precision_config.is_mixed()
2589 }
2590
2591 pub fn grad_scaler(&self) -> &GradScaler {
2593 &self.grad_scaler
2594 }
2595
2596 pub fn is_checkpointing(&self) -> bool {
2598 self.config.checkpoint_config.enabled
2599 }
2600
2601 pub fn save(
2603 &mut self,
2604 path: impl AsRef<std::path::Path>,
2605 name: &str,
2606 architecture: &str,
2607 ) -> crate::Result<()> {
2608 self.sync_weights_to_cpu();
2609
2610 let params: Vec<(String, Tensor)> = self
2612 .model
2613 .named_parameters()
2614 .into_iter()
2615 .map(|(name, tensor)| (name, tensor.clone()))
2616 .collect();
2617
2618 let metadata = ModelMetadata::new(name, architecture);
2619 let model = Model::new(metadata, params);
2620 let config = SaveConfig::new(ModelFormat::SafeTensors);
2621
2622 save_model(&model, path, &config)
2623 }
2624
2625 pub fn prepare_async_save(
2629 &mut self,
2630 name: &str,
2631 architecture: &str,
2632 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2633 self.sync_weights_to_cpu();
2634
2635 let param_data: Vec<(String, Vec<f32>)> = self
2637 .model
2638 .named_parameters()
2639 .into_iter()
2640 .map(|(n, t)| (n, t.data().to_vec()))
2641 .collect();
2642
2643 let name = name.to_string();
2644 let architecture = architecture.to_string();
2645
2646 Box::new(move |path: &std::path::Path| {
2647 let params: Vec<(String, Tensor)> =
2648 param_data.into_iter().map(|(n, d)| (n, Tensor::from_vec(d, false))).collect();
2649 let metadata = ModelMetadata::new(&name, &architecture);
2650 let model = Model::new(metadata, params);
2651 let config = SaveConfig::new(ModelFormat::SafeTensors);
2652 save_model(&model, path, &config)
2653 })
2654 }
2655
2656 pub fn save_apr(
2661 &mut self,
2662 path: impl AsRef<std::path::Path>,
2663 name: &str,
2664 architecture: &str,
2665 ) -> crate::Result<()> {
2666 self.sync_weights_to_cpu();
2667
2668 let params: Vec<(String, Tensor)> = self
2669 .model
2670 .named_parameters()
2671 .into_iter()
2672 .map(|(name, tensor)| (name, tensor.clone()))
2673 .collect();
2674
2675 let metadata = ModelMetadata::new(name, architecture);
2676 let model = Model::new(metadata, params);
2677 let config = SaveConfig::new(ModelFormat::Apr);
2678
2679 save_model(&model, path, &config)
2680 }
2681
2682 fn snapshot_param_data(&self) -> Vec<(String, Vec<f32>)> {
2689 let use_nf4 = self.config.quantize_nf4 && self.config.is_lora();
2690 if use_nf4 {
2691 let frozen_suffixes = [
2692 "q_proj.weight",
2693 "k_proj.weight",
2694 "v_proj.weight",
2695 "o_proj.weight",
2696 "gate_proj.weight",
2697 "up_proj.weight",
2698 "down_proj.weight",
2699 ];
2700 self.model
2701 .named_parameters()
2702 .into_iter()
2703 .filter(|(n, _)| !frozen_suffixes.iter().any(|s| n.ends_with(s)))
2704 .map(|(n, t)| (n, t.data().to_vec()))
2705 .collect()
2706 } else {
2707 self.model.named_parameters().into_iter().map(|(n, t)| (n, t.data().to_vec())).collect()
2708 }
2709 }
2710
2711 fn snapshot_lora_data(&self) -> Vec<(usize, Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
2712 if self.config.quantize_nf4 && self.config.is_lora() {
2713 self.cuda_blocks
2714 .iter()
2715 .enumerate()
2716 .filter_map(|(i, block)| {
2717 block
2718 .download_lora_weights()
2719 .ok()
2720 .map(|(a_q, b_q, a_v, b_v)| (i, a_q, b_q, a_v, b_v))
2721 })
2722 .collect()
2723 } else {
2724 Vec::new()
2725 }
2726 }
2727
2728 pub fn prepare_async_apr_save(
2729 &mut self,
2730 name: &str,
2731 architecture: &str,
2732 step: usize,
2733 loss: f64,
2734 lr: f64,
2735 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2736 self.prepare_async_apr_save_with_tokenizer(name, architecture, step, loss, lr, None)
2737 }
2738
2739 pub fn prepare_async_apr_save_with_tokenizer(
2745 &mut self,
2746 name: &str,
2747 architecture: &str,
2748 step: usize,
2749 loss: f64,
2750 lr: f64,
2751 tokenizer_path: Option<&std::path::Path>,
2752 ) -> Box<dyn FnOnce(&std::path::Path) -> crate::Result<()> + Send> {
2753 self.sync_weights_to_cpu();
2754
2755 let param_data = self.snapshot_param_data();
2756 let lora_data = self.snapshot_lora_data();
2757
2758 let embed_m: Vec<Vec<f32>> = self
2760 .embed_optimizer
2761 .first_moments()
2762 .iter()
2763 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2764 .collect();
2765 let embed_v: Vec<Vec<f32>> = self
2766 .embed_optimizer
2767 .second_moments()
2768 .iter()
2769 .filter_map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
2770 .collect();
2771 let embed_step = self.embed_optimizer.step_count();
2772
2773 let block_optim_data: Vec<Vec<(String, Vec<f32>)>> = self
2778 .gpu_training
2779 .optimizer_states
2780 .iter()
2781 .map(|state| state.download_to_host().unwrap_or_default())
2782 .collect();
2783
2784 let lm_head_m_host = {
2786 let mut buf = vec![0.0f32; self.lm_head_m.len()];
2787 let _ = self.lm_head_m.copy_to_host(&mut buf);
2788 buf
2789 };
2790 let lm_head_v_host = {
2791 let mut buf = vec![0.0f32; self.lm_head_v.len()];
2792 let _ = self.lm_head_v.copy_to_host(&mut buf);
2793 buf
2794 };
2795 let final_norm_m_host = {
2796 let mut buf = vec![0.0f32; self.final_norm_m.len()];
2797 let _ = self.final_norm_m.copy_to_host(&mut buf);
2798 buf
2799 };
2800 let final_norm_v_host = {
2801 let mut buf = vec![0.0f32; self.final_norm_v.len()];
2802 let _ = self.final_norm_v.copy_to_host(&mut buf);
2803 buf
2804 };
2805
2806 let name = name.to_string();
2807 let architecture = architecture.to_string();
2808 let model_config_json = serde_json::to_string(&self.config.model_config).ok();
2809 let is_delta_checkpoint = self.config.quantize_nf4 && self.config.is_lora();
2810
2811 let tokenizer_data: Option<(Vec<String>, Vec<String>, Option<u64>, Option<u64>)> =
2814 tokenizer_path.and_then(|p| {
2815 let json_bytes = std::fs::read(p).ok()?;
2816 let tok: serde_json::Value = serde_json::from_slice(&json_bytes).ok()?;
2817 let model = tok.get("model")?;
2818 let vocab_obj = model.get("vocab")?.as_object()?;
2819 let mut vocab_pairs: Vec<(String, u64)> =
2821 vocab_obj.iter().filter_map(|(k, v)| Some((k.clone(), v.as_u64()?))).collect();
2822 vocab_pairs.sort_by_key(|(_, id)| *id);
2823 let vocab: Vec<String> = vocab_pairs.into_iter().map(|(k, _)| k).collect();
2824 let merges: Vec<String> = model
2826 .get("merges")?
2827 .as_array()?
2828 .iter()
2829 .filter_map(|v| v.as_str().map(String::from))
2830 .collect();
2831 let added = tok.get("added_tokens").and_then(|a| a.as_array());
2833 let bos_id = added.and_then(|arr| {
2834 arr.iter()
2835 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("<s>"))
2836 .and_then(|t| t.get("id")?.as_u64())
2837 });
2838 let eos_id = added.and_then(|arr| {
2839 arr.iter()
2840 .find(|t| t.get("content").and_then(|c| c.as_str()) == Some("</s>"))
2841 .and_then(|t| t.get("id")?.as_u64())
2842 });
2843 if vocab.is_empty() {
2844 return None;
2845 }
2846 println!(
2847 " [ALB-130] Embedding tokenizer: {} vocab, {} merges",
2848 vocab.len(),
2849 merges.len()
2850 );
2851 Some((vocab, merges, bos_id, eos_id))
2852 });
2853
2854 Box::new(move |path: &std::path::Path| {
2855 use aprender::serialization::apr::AprWriter;
2856 use serde_json::Value as Jv;
2857
2858 let mut writer = AprWriter::new();
2859
2860 writer.set_metadata("model_name", Jv::String(name));
2862 writer.set_metadata("architecture", Jv::String(architecture));
2863 writer.set_metadata(
2864 "format",
2865 Jv::String(if is_delta_checkpoint {
2866 "entrenar-delta-checkpoint".into()
2867 } else {
2868 "entrenar-checkpoint".into()
2869 }),
2870 );
2871 writer.set_metadata("checkpoint_step", Jv::String(step.to_string()));
2872 writer.set_metadata("loss", Jv::String(format!("{loss:.6}")));
2873 writer.set_metadata("learning_rate", Jv::String(format!("{lr:.6e}")));
2874 writer.set_metadata("optimizer_step", Jv::String(embed_step.to_string()));
2875 if let Some(cfg) = model_config_json {
2876 writer.set_metadata("model_config", Jv::String(cfg));
2877 }
2878
2879 if let Some((vocab, merges, bos_id, eos_id)) = tokenizer_data {
2881 writer.set_metadata(
2882 "tokenizer.vocabulary",
2883 Jv::Array(vocab.into_iter().map(Jv::String).collect()),
2884 );
2885 writer.set_metadata(
2886 "tokenizer.merges",
2887 Jv::Array(merges.into_iter().map(Jv::String).collect()),
2888 );
2889 if let Some(bos) = bos_id {
2890 writer.set_metadata("tokenizer.bos_token_id", Jv::Number(bos.into()));
2891 }
2892 if let Some(eos) = eos_id {
2893 writer.set_metadata("tokenizer.eos_token_id", Jv::Number(eos.into()));
2894 }
2895 }
2896
2897 let hidden_size = param_data
2899 .iter()
2900 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
2901 .map_or(0, |(_, d)| d.len());
2902
2903 for (tensor_name, data) in ¶m_data {
2905 let shape = infer_tensor_shape(tensor_name, data.len(), hidden_size);
2906 writer.add_tensor_f32(tensor_name.clone(), shape, data);
2907 }
2908
2909 for (i, m_data) in embed_m.iter().enumerate() {
2911 let len = m_data.len();
2912 writer.add_tensor_f32(
2913 format!("__training__.embed_optimizer.m.{i}"),
2914 vec![len],
2915 m_data,
2916 );
2917 }
2918 for (i, v_data) in embed_v.iter().enumerate() {
2919 let len = v_data.len();
2920 writer.add_tensor_f32(
2921 format!("__training__.embed_optimizer.v.{i}"),
2922 vec![len],
2923 v_data,
2924 );
2925 }
2926
2927 for (layer_idx, buffers) in block_optim_data.iter().enumerate() {
2929 for (suffix, data) in buffers {
2930 let len = data.len();
2931 writer.add_tensor_f32(
2932 format!("__training__.block_optimizer.{layer_idx}.{suffix}"),
2933 vec![len],
2934 data,
2935 );
2936 }
2937 }
2938
2939 if !lm_head_m_host.is_empty() {
2941 let len = lm_head_m_host.len();
2942 writer.add_tensor_f32(
2943 "__training__.lm_head_optimizer.m".to_string(),
2944 vec![len],
2945 &lm_head_m_host,
2946 );
2947 let len = lm_head_v_host.len();
2948 writer.add_tensor_f32(
2949 "__training__.lm_head_optimizer.v".to_string(),
2950 vec![len],
2951 &lm_head_v_host,
2952 );
2953 }
2954 if !final_norm_m_host.is_empty() {
2955 let len = final_norm_m_host.len();
2956 writer.add_tensor_f32(
2957 "__training__.final_norm_optimizer.m".to_string(),
2958 vec![len],
2959 &final_norm_m_host,
2960 );
2961 let len = final_norm_v_host.len();
2962 writer.add_tensor_f32(
2963 "__training__.final_norm_optimizer.v".to_string(),
2964 vec![len],
2965 &final_norm_v_host,
2966 );
2967 }
2968
2969 for (layer_idx, a_q, b_q, a_v, b_v) in &lora_data {
2971 if !a_q.is_empty() {
2972 writer.add_tensor_f32(
2973 format!("lora.{layer_idx}.q_proj.lora_a"),
2974 vec![a_q.len()],
2975 a_q,
2976 );
2977 writer.add_tensor_f32(
2978 format!("lora.{layer_idx}.q_proj.lora_b"),
2979 vec![b_q.len()],
2980 b_q,
2981 );
2982 }
2983 if !a_v.is_empty() {
2984 writer.add_tensor_f32(
2985 format!("lora.{layer_idx}.v_proj.lora_a"),
2986 vec![a_v.len()],
2987 a_v,
2988 );
2989 writer.add_tensor_f32(
2990 format!("lora.{layer_idx}.v_proj.lora_b"),
2991 vec![b_v.len()],
2992 b_v,
2993 );
2994 }
2995 }
2996
2997 writer
2999 .write(path)
3000 .map_err(|e| crate::error::Error::Serialization(format!("APR save failed: {e}")))?;
3001
3002 Ok(())
3003 })
3004 }
3005
3006 pub fn gpu_name(&self) -> String {
3008 self.cuda_trainer.device_name()
3009 }
3010
3011 pub fn save_cuda_lora_adapter(
3021 &self,
3022 output_dir: &std::path::Path,
3023 base_model_name: Option<&str>,
3024 ) -> crate::Result<()> {
3025 if !self.config.quantize_nf4 || !self.config.is_lora() {
3026 return Ok(()); }
3028
3029 let lora_rank = self.config.lora_rank.unwrap_or(16);
3030 let lora_alpha = self.config.lora_alpha.unwrap_or(2.0 * lora_rank as f32);
3031 let lora_scale = lora_alpha / lora_rank as f32;
3032 let hidden_size = self.config.model_config.hidden_size;
3033 let head_dim = self.config.model_config.head_dim();
3034 let q_dim = self.config.model_config.num_attention_heads * head_dim;
3035 let kv_hidden = self.config.model_config.num_kv_heads * head_dim;
3036
3037 let lora_config =
3038 crate::lora::LoRAConfig::new(lora_rank, lora_alpha).target_qv_projections();
3039
3040 let mut adapters: Vec<(String, crate::lora::LoRALayer)> = Vec::new();
3041
3042 for (i, block) in self.cuda_blocks.iter().enumerate() {
3043 let (a_q, b_q_scaled, a_v, b_v_scaled) = match block.download_lora_weights() {
3044 Ok(weights) => weights,
3045 Err(_) => continue, };
3047
3048 if a_q.is_empty() && a_v.is_empty() {
3049 continue;
3050 }
3051
3052 if !a_q.is_empty() {
3054 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3056 for r in 0..hidden_size {
3057 for c in 0..lora_rank {
3058 a_transposed[c * hidden_size + r] = a_q[r * lora_rank + c];
3059 }
3060 }
3061
3062 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3065 let mut b_transposed = vec![0.0f32; q_dim * lora_rank];
3066 for r in 0..lora_rank {
3067 for c in 0..q_dim {
3068 b_transposed[c * lora_rank + r] = b_q_scaled[r * q_dim + c] * inv_scale;
3069 }
3070 }
3071
3072 let base_weight = crate::autograd::Tensor::zeros(q_dim * hidden_size, false);
3073 let mut layer = crate::lora::LoRALayer::new(
3074 base_weight,
3075 q_dim,
3076 hidden_size,
3077 lora_rank,
3078 lora_alpha,
3079 );
3080 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3082 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3083
3084 adapters.push((format!("model.layers.{i}.self_attn.q_proj"), layer));
3085 }
3086
3087 if !a_v.is_empty() {
3089 let mut a_transposed = vec![0.0f32; lora_rank * hidden_size];
3090 for r in 0..hidden_size {
3091 for c in 0..lora_rank {
3092 a_transposed[c * hidden_size + r] = a_v[r * lora_rank + c];
3093 }
3094 }
3095
3096 let inv_scale = if lora_scale.abs() > 1e-10 { 1.0 / lora_scale } else { 1.0 };
3097 let mut b_transposed = vec![0.0f32; kv_hidden * lora_rank];
3098 for r in 0..lora_rank {
3099 for c in 0..kv_hidden {
3100 b_transposed[c * lora_rank + r] = b_v_scaled[r * kv_hidden + c] * inv_scale;
3101 }
3102 }
3103
3104 let base_weight = crate::autograd::Tensor::zeros(kv_hidden * hidden_size, false);
3105 let mut layer = crate::lora::LoRALayer::new(
3106 base_weight,
3107 kv_hidden,
3108 hidden_size,
3109 lora_rank,
3110 lora_alpha,
3111 );
3112 layer.lora_a_mut().data_mut().assign(&ndarray::Array1::from(a_transposed));
3113 layer.lora_b_mut().data_mut().assign(&ndarray::Array1::from(b_transposed));
3114
3115 adapters.push((format!("model.layers.{i}.self_attn.v_proj"), layer));
3116 }
3117 }
3118
3119 if adapters.is_empty() {
3120 println!(" [WARN] No LoRA adapters found to save");
3121 return Ok(());
3122 }
3123
3124 let adapter_refs: Vec<(&str, &crate::lora::LoRALayer)> =
3125 adapters.iter().map(|(name, layer)| (name.as_str(), layer)).collect();
3126
3127 std::fs::create_dir_all(output_dir).ok();
3128 crate::lora::save_adapter_peft(&adapter_refs, &lora_config, base_model_name, output_dir)
3129 .map_err(|e| crate::error::Error::Io(format!("Failed to save PEFT adapter: {e}")))?;
3130
3131 let adapter_path = output_dir.join("adapter_model.safetensors");
3132 let size_mb =
3133 std::fs::metadata(&adapter_path).map(|m| m.len()).unwrap_or(0) / (1024 * 1024);
3134 println!(
3135 "✓ LoRA adapter saved ({} layers, {} MB) to {}",
3136 adapters.len(),
3137 size_mb,
3138 output_dir.display()
3139 );
3140
3141 Ok(())
3142 }
3143
3144 pub fn save_optimizer_state(&self, dir: &std::path::Path) -> crate::Result<()> {
3149 let path = dir.join("optimizer_state.json");
3150 let m_data: Vec<Option<Vec<f32>>> = self
3151 .embed_optimizer
3152 .first_moments()
3153 .iter()
3154 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3155 .collect();
3156 let v_data: Vec<Option<Vec<f32>>> = self
3157 .embed_optimizer
3158 .second_moments()
3159 .iter()
3160 .map(|opt| opt.as_ref().map(ndarray::ArrayBase::to_vec))
3161 .collect();
3162 let state = serde_json::json!({
3163 "type": "adamw_cpu_embed",
3164 "step": self.embed_optimizer.step_count(),
3165 "m": m_data,
3166 "v": v_data,
3167 });
3168 let json_str = serde_json::to_string(&state).map_err(|e| {
3169 crate::error::Error::ConfigError(format!("serialize optimizer state: {e}"))
3170 })?;
3171 std::fs::write(&path, json_str)
3172 .map_err(|e| crate::error::Error::ConfigError(format!("write optimizer state: {e}")))?;
3173 Ok(())
3174 }
3175
3176 pub fn restore_lora_from_apr(&mut self, apr_path: &std::path::Path) -> (usize, usize) {
3182 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3183 Ok(r) => r,
3184 Err(_) => return (0, self.cuda_blocks.len()),
3185 };
3186
3187 let mut restored = 0usize;
3188 for (i, block) in self.cuda_blocks.iter_mut().enumerate() {
3189 let a_q =
3190 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_a")).unwrap_or_default();
3191 let b_q =
3192 reader.read_tensor_f32(&format!("lora.{i}.q_proj.lora_b")).unwrap_or_default();
3193 let a_v =
3194 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_a")).unwrap_or_default();
3195 let b_v =
3196 reader.read_tensor_f32(&format!("lora.{i}.v_proj.lora_b")).unwrap_or_default();
3197
3198 if a_q.is_empty() {
3199 continue; }
3201
3202 if let Err(e) = block.upload_lora_weights(&a_q, &b_q, &a_v, &b_v) {
3203 eprintln!("Warning: failed to restore LoRA for layer {i}: {e}");
3204 continue;
3205 }
3206 restored += 1;
3207 }
3208
3209 (restored, self.cuda_blocks.len())
3210 }
3211
3212 pub fn load_optimizer_state_apr(&mut self, apr_path: &std::path::Path) -> bool {
3217 let reader = match aprender::serialization::apr::AprReader::open(apr_path) {
3218 Ok(r) => r,
3219 Err(_) => return false,
3220 };
3221
3222 if let Some(step_val) = reader.get_metadata("optimizer_step") {
3224 if let Some(step_str) = step_val.as_str() {
3225 if let Ok(step) = step_str.parse::<u64>() {
3226 self.embed_optimizer.set_step_count(step);
3227 }
3228 }
3229 }
3230
3231 for i in 0..128 {
3233 let name = format!("__training__.embed_optimizer.m.{i}");
3234 match reader.read_tensor_f32(&name) {
3235 Ok(data) if !data.is_empty() => {
3236 self.embed_optimizer.set_first_moment(i, ndarray::Array1::from_vec(data));
3237 }
3238 _ => break,
3239 }
3240 }
3241
3242 for i in 0..128 {
3244 let name = format!("__training__.embed_optimizer.v.{i}");
3245 match reader.read_tensor_f32(&name) {
3246 Ok(data) if !data.is_empty() => {
3247 self.embed_optimizer.set_second_moment(i, ndarray::Array1::from_vec(data));
3248 }
3249 _ => break,
3250 }
3251 }
3252
3253 let suffixes = [
3255 "m.w_q",
3256 "v.w_q",
3257 "m.w_k",
3258 "v.w_k",
3259 "m.w_v",
3260 "v.w_v",
3261 "m.w_o",
3262 "v.w_o",
3263 "m.w_gate",
3264 "v.w_gate",
3265 "m.w_up",
3266 "v.w_up",
3267 "m.w_down",
3268 "v.w_down",
3269 "m.input_norm",
3270 "v.input_norm",
3271 "m.post_attn_norm",
3272 "v.post_attn_norm",
3273 ];
3274 let mut blocks_restored = 0usize;
3275 for (layer_idx, state) in self.gpu_training.optimizer_states.iter_mut().enumerate() {
3276 let mut data = std::collections::HashMap::new();
3277 for suffix in &suffixes {
3278 let name = format!("__training__.block_optimizer.{layer_idx}.{suffix}");
3279 if let Ok(tensor_data) = reader.read_tensor_f32(&name) {
3280 if !tensor_data.is_empty() {
3281 data.insert(suffix.to_string(), tensor_data);
3282 }
3283 }
3284 }
3285 if !data.is_empty() {
3286 let _ = state.restore_from_host(&data);
3287 blocks_restored += 1;
3288 }
3289 }
3290
3291 if let Ok(m_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.m") {
3293 if m_data.len() == self.lm_head_m.len() {
3294 let _ = self.lm_head_m.copy_from_host(&m_data);
3295 }
3296 }
3297 if let Ok(v_data) = reader.read_tensor_f32("__training__.lm_head_optimizer.v") {
3298 if v_data.len() == self.lm_head_v.len() {
3299 let _ = self.lm_head_v.copy_from_host(&v_data);
3300 }
3301 }
3302
3303 if let Ok(m_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.m") {
3305 if m_data.len() == self.final_norm_m.len() {
3306 let _ = self.final_norm_m.copy_from_host(&m_data);
3307 }
3308 }
3309 if let Ok(v_data) = reader.read_tensor_f32("__training__.final_norm_optimizer.v") {
3310 if v_data.len() == self.final_norm_v.len() {
3311 let _ = self.final_norm_v.copy_from_host(&v_data);
3312 }
3313 }
3314
3315 if blocks_restored > 0 {
3317 println!(
3318 " ✓ GPU block optimizer states restored ({blocks_restored}/{} blocks)",
3319 self.gpu_training.optimizer_states.len()
3320 );
3321 } else if !self.gpu_training.optimizer_states.is_empty() {
3322 println!(
3323 " [WARN] GPU block optimizer states NOT restored (0/{} blocks — zeroed m/v)",
3324 self.gpu_training.optimizer_states.len()
3325 );
3326 }
3327
3328 true
3329 }
3330
3331 pub fn load_optimizer_state(&mut self, dir: &std::path::Path) -> bool {
3335 let path = dir.join("optimizer_state.json");
3336 let data = match std::fs::read_to_string(&path) {
3337 Ok(d) => d,
3338 Err(_) => return false,
3339 };
3340 let state: serde_json::Value = match serde_json::from_str(&data) {
3341 Ok(v) => v,
3342 Err(_) => return false,
3343 };
3344 if let Some(step) = state["step"].as_u64() {
3345 self.embed_optimizer.set_step_count(step);
3346 }
3347 restore_moment_buffers(&state["m"], |idx, arr| {
3348 self.embed_optimizer.set_first_moment(idx, arr);
3349 });
3350 restore_moment_buffers(&state["v"], |idx, arr| {
3351 self.embed_optimizer.set_second_moment(idx, arr);
3352 });
3353 true
3354 }
3355}
3356
3357#[cfg(feature = "cuda")]
/// Infers a tensor shape from its name, element count and the model hidden size.
///
/// - Norm weights (`*layernorm.weight`, `model.norm.weight`) are 1-D: `[numel]`.
/// - If `hidden_size` is 0 or does not divide `numel`, the result is `[numel]`.
/// - Otherwise the tensor is 2-D with one dimension equal to `hidden_size`:
///   `[hidden_size, numel / hidden_size]` for `*down_proj.weight`, and
///   `[numel / hidden_size, hidden_size]` for everything else.
fn infer_tensor_shape(name: &str, numel: usize, hidden_size: usize) -> Vec<usize> {
    let is_norm_weight = name == "model.norm.weight" || name.ends_with("layernorm.weight");
    if is_norm_weight || hidden_size == 0 || !numel.is_multiple_of(hidden_size) {
        return vec![numel];
    }

    let other_dim = numel / hidden_size;
    match name.ends_with("down_proj.weight") {
        true => vec![hidden_size, other_dim],
        false => vec![other_dim, hidden_size],
    }
}
3375
/// Feeds each numeric array found in `json_arr` to `set_fn` as an `Array1<f32>`.
///
/// `json_arr` is expected to be a JSON array with one entry per optimizer slot.
/// `set_fn` receives the slot index and the decoded buffer. Nothing happens if
/// `json_arr` is not an array. Entries that are not arrays (for example `null`)
/// or are empty arrays are skipped.
///
/// A buffer containing any non-numeric element is skipped with a warning on
/// stderr. serde_json writes non-finite floats (NaN, ±inf) as `null`; dropping
/// only those elements would shorten the buffer and shift every later value to
/// the wrong position.
#[cfg(feature = "cuda")]
fn restore_moment_buffers(
    json_arr: &serde_json::Value,
    mut set_fn: impl FnMut(usize, ndarray::Array1<f32>),
) {
    let Some(slots) = json_arr.as_array() else { return };
    for (idx, slot) in slots.iter().enumerate() {
        let Some(values) = slot.as_array() else { continue };
        if values.is_empty() {
            continue;
        }

        // All-or-nothing decode: `None` as soon as one element is not a number.
        // The f64 -> f32 narrowing is intentional; the moments are f32 in memory.
        let decoded: Option<Vec<f32>> =
            values.iter().map(|v| v.as_f64().map(|f| f as f32)).collect();

        match decoded {
            Some(floats) => set_fn(idx, ndarray::Array1::from_vec(floats)),
            None => eprintln!(
                "Warning: optimizer moment buffer {idx} contains non-numeric values; skipping"
            ),
        }
    }
}
3391
3392#[cfg(not(feature = "cuda"))]
3395pub struct CudaTransformerTrainer;
3396
3397#[cfg(not(feature = "cuda"))]
3398impl CudaTransformerTrainer {
3399 pub fn new(_config: super::config::TransformerTrainConfig) -> crate::Result<Self> {
3400 Err(crate::error::Error::ConfigError(
3401 "CUDA not available (compiled without cuda feature)".into(),
3402 ))
3403 }
3404
3405 pub fn with_model(
3406 _model: crate::transformer::Transformer,
3407 _config: super::config::TransformerTrainConfig,
3408 ) -> crate::Result<Self> {
3409 Err(crate::error::Error::ConfigError(
3410 "CUDA not available (compiled without cuda feature)".into(),
3411 ))
3412 }
3413
3414 pub fn gpu_name(&self) -> String {
3415 unreachable!("CudaTransformerTrainer stub should never be instantiated")
3416 }
3417}
3418
#[cfg(test)]
mod tests {
    /// Without the `cuda` feature, constructing the trainer must fail.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cuda_trainer_stub_returns_error() {
        use super::super::config::TransformerTrainConfig;
        use crate::transformer::TransformerConfig;

        let train_config = TransformerTrainConfig::new(TransformerConfig::tiny());
        let outcome = super::CudaTransformerTrainer::new(train_config);
        assert!(outcome.is_err());
    }
}