1#![allow(unsafe_code)]
23#![allow(trivial_casts)]
24#![allow(clippy::borrow_as_ptr)]
25#![allow(clippy::ref_as_ptr)]
26
27#[cfg(feature = "cuda")]
28use std::collections::HashMap;
29#[cfg(feature = "cuda")]
30use std::sync::{Mutex, OnceLock};
31
32#[cfg(feature = "cuda")]
33use trueno_gpu::driver::{CudaContext, CudaModule, CudaStream, GpuBuffer, LaunchConfig};
34#[cfg(feature = "cuda")]
35use trueno_gpu::kernels::backward::{FusedCausalCrossEntropyKernel, FusedCrossEntropyKernel};
36use trueno_gpu::kernels::{
37 AdamStepKernel, AdamWStepKernel, ClipScaleReduceKernel, GradientClipGpuScaleKernel,
38 GradientClipKernel, Kernel, SquaredSumKernel,
39};
40
41use super::cuda_tensor::{CudaTensorError, Result};
42
/// Process-wide cache of compiled optimizer kernel modules.
///
/// The cell is filled once by `init_optim_kernel_cache`. Every launcher in
/// this file locks the mutex to look up (or compile and insert) the module it
/// needs, and returns `CudaTensorError::DeviceNotInitialized` while the cell
/// is still empty.
#[cfg(feature = "cuda")]
static OPTIM_KERNEL_CACHE: OnceLock<Mutex<OptimKernelCache>> = OnceLock::new();
46
/// Compiled-module cache shared by all optimizer kernel launchers in this file.
#[cfg(feature = "cuda")]
struct OptimKernelCache {
    /// Context handed to `CudaModule::from_ptx` when compiling; launchers also
    /// clone it to allocate scratch `GpuBuffer`s.
    ctx: std::sync::Arc<CudaContext>,
    /// Compiled modules keyed by a caller-chosen name (e.g. `adamw_step_{n}`).
    modules: HashMap<String, CudaModule>,
    /// Target string passed to `emit_ptx_for_target`; falls back to `"sm_70"`
    /// when the context cannot report one (see `new`).
    sm_target: String,
}
54
#[cfg(feature = "cuda")]
impl OptimKernelCache {
    /// Creates an empty cache bound to `ctx`.
    ///
    /// The PTX target is queried from the context once; if the query fails the
    /// cache falls back to `"sm_70"`.
    fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());
        Self { ctx, modules: HashMap::new(), sm_target }
    }

    /// Returns the PTX target string kernels should be emitted for.
    fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Returns the module cached under `name`, if any, without compiling.
    fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Returns the module cached under `name`, compiling `ptx` and caching the
    /// result first when the name is not present yet.
    ///
    /// `ptx` is ignored when `name` is already cached.
    ///
    /// # Errors
    /// Returns `CudaTensorError::KernelError` when `CudaModule::from_ptx`
    /// fails; nothing is inserted in that case.
    fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        // A single entry lookup replaces the contains_key/insert/get_mut
        // sequence and removes the `expect` on the freshly inserted value.
        match self.modules.entry(name.to_string()) {
            std::collections::hash_map::Entry::Occupied(slot) => Ok(slot.into_mut()),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let module = CudaModule::from_ptx(&self.ctx, ptx).map_err(|e| {
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {e:?}"))
                })?;
                Ok(slot.insert(module))
            }
        }
    }
}
81
/// Installs the process-wide optimizer kernel cache for `ctx`.
///
/// The cache is created at most once. If it already exists this call leaves
/// it untouched (the `ctx` passed here is simply dropped) and still reports
/// success.
///
/// # Errors
/// Currently never fails; the `Result` is part of the public signature.
#[cfg(feature = "cuda")]
pub fn init_optim_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let build_cache = move || Mutex::new(OptimKernelCache::new(ctx));
    let _cache = OPTIM_KERNEL_CACHE.get_or_init(build_cache);
    Ok(())
}
88
/// Compiles the AdamW step kernels for the buffer sizes a LoRA run is expected
/// to use, so the first optimizer step does not pay the compilation cost.
///
/// One `adamw_step_{n}` module is compiled per distinct element count `n`:
///
/// * always: `hidden_size * lora_rank`, `lora_rank * q_dim`,
///   `lora_rank * kv_hidden_size`, `hidden_size`;
/// * when `quantize_nf4` is `false`: `hidden_size * hidden_size`,
///   `hidden_size * kv_hidden_size`, `hidden_size * intermediate_size`;
/// * when `num_classes > 0`: `num_classes * hidden_size`, `num_classes`.
///
/// Does nothing when `lora_rank == 0`.
///
/// Element counts that overflow or do not fit in `u32` are skipped rather
/// than truncated: `adamw_step_cuda` takes `n: u32`, so such a size can never
/// be requested, and a truncated value would compile a kernel for the wrong
/// size.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned or a module
///   fails to compile.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_adamw_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    lora_rank: usize,
    num_classes: usize,
    intermediate_size: usize,
    quantize_nf4: bool,
) -> Result<()> {
    if lora_rank == 0 {
        return Ok(());
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire optim kernel cache lock".to_string())
    })?;

    // Owned copy: `cache` is mutably borrowed by `get_or_compile` in the loop.
    let target = cache.sm_target().to_string();

    // `rows * cols` as a kernel element count, or `None` if it cannot be
    // represented as `u32`.
    let element_count = |rows: usize, cols: usize| -> Option<u32> {
        rows.checked_mul(cols).and_then(|count| u32::try_from(count).ok())
    };

    let mut shapes: Vec<(usize, usize)> = vec![
        (hidden_size, lora_rank),
        (lora_rank, q_dim),
        (lora_rank, kv_hidden_size),
        (hidden_size, 1),
    ];

    if !quantize_nf4 {
        shapes.extend([
            (hidden_size, hidden_size),
            (hidden_size, kv_hidden_size),
            (hidden_size, intermediate_size),
        ]);
    }

    if num_classes > 0 {
        shapes.extend([(num_classes, hidden_size), (num_classes, 1)]);
    }

    let mut sizes: Vec<u32> =
        shapes.into_iter().filter_map(|(rows, cols)| element_count(rows, cols)).collect();

    // Several shapes can share an element count; compile each size once.
    sizes.sort_unstable();
    sizes.dedup();

    for n in sizes {
        let kernel = AdamWStepKernel::new(n);
        let ptx = kernel.emit_ptx_for_target(&target);
        let key = format!("adamw_step_{n}");
        cache.get_or_compile(&key, &ptx)?;
    }

    Ok(())
}
163
/// Enqueues one AdamW update over `n` elements on `stream`.
///
/// The bias-correction factors `1 / (1 - beta^step)` are computed on the host
/// and passed to the `adamw_step` kernel together with the raw device pointers
/// of `params`, `grads`, `m` and `v`. The kernel module is taken from the
/// shared cache (key `adamw_step_{n}`) and compiled on first use.
///
/// This function does not synchronize `stream`.
///
/// # Arguments
/// * `params`, `m`, `v` - buffers the kernel updates in place.
/// * `grads` - gradient buffer read by the kernel.
/// * `lr`, `beta1`, `beta2`, `eps`, `weight_decay` - AdamW hyper-parameters,
///   forwarded unchanged.
/// * `step` - 1-based optimizer step used for bias correction.
/// * `n` - number of elements to process; assumed not to exceed the length of
///   any of the buffers (not checked here — TODO confirm callers).
///
/// # Errors
/// * `CudaTensorError::KernelError` if `step == 0` (the bias correction would
///   divide by zero and turn every parameter into NaN), if the cache lock is
///   poisoned, or if compilation or launch fails.
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
#[cfg(feature = "cuda")]
pub fn adamw_step_cuda(
    params: &mut GpuBuffer<f32>,
    grads: &GpuBuffer<f32>,
    m: &mut GpuBuffer<f32>,
    v: &mut GpuBuffer<f32>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    step: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if step == 0 {
        return Err(CudaTensorError::KernelError(
            "AdamW step must be >= 1 (bias correction divides by 1 - beta^step)".to_string(),
        ));
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("adamw_step_{n}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            // PTX is only generated on a cache miss.
            let kernel = AdamWStepKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    // Saturate instead of wrapping: a wrapped (negative) exponent would flip
    // the sign of the correction.
    let step_exp = i32::try_from(step).unwrap_or(i32::MAX);
    let bias_adjust1 = 1.0 / (1.0 - beta1.powi(step_exp));
    let bias_adjust2 = 1.0 / (1.0 - beta2.powi(step_exp));

    let params_ptr = params.as_ptr();
    let grads_ptr = grads.as_ptr();
    let m_ptr = m.as_ptr();
    let v_ptr = v.as_ptr();

    // Each entry is the address of a host-side local holding one kernel parameter.
    let mut args: [*mut std::ffi::c_void; 12] = [
        &params_ptr as *const _ as *mut _,
        &grads_ptr as *const _ as *mut _,
        &m_ptr as *const _ as *mut _,
        &v_ptr as *const _ as *mut _,
        &lr as *const _ as *mut _,
        &beta1 as *const _ as *mut _,
        &beta2 as *const _ as *mut _,
        &eps as *const _ as *mut _,
        &weight_decay as *const _ as *mut _,
        &bias_adjust1 as *const _ as *mut _,
        &bias_adjust2 as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local that outlives the
    // `launch_kernel` call. The parameter order/types are assumed to match the
    // `adamw_step` kernel emitted by `AdamWStepKernel`, and `n` is assumed to
    // be within the bounds of all four buffers; neither is verified here.
    unsafe {
        stream.launch_kernel(module, "adamw_step", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("AdamW step launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
247
/// Enqueues one Adam update (no weight decay) over `n` elements on `stream`.
///
/// The bias-correction factors `1 / (1 - beta^step)` are computed on the host
/// and passed to the `adam_step` kernel together with the raw device pointers
/// of `params`, `grads`, `m` and `v`. The kernel module is taken from the
/// shared cache (key `adam_step_{n}`) and compiled on first use.
///
/// This function does not synchronize `stream`.
///
/// # Arguments
/// * `params`, `m`, `v` - buffers the kernel updates in place.
/// * `grads` - gradient buffer read by the kernel.
/// * `lr`, `beta1`, `beta2`, `eps` - Adam hyper-parameters, forwarded unchanged.
/// * `step` - 1-based optimizer step used for bias correction.
/// * `n` - number of elements to process; assumed not to exceed the length of
///   any of the buffers (not checked here — TODO confirm callers).
///
/// # Errors
/// * `CudaTensorError::KernelError` if `step == 0` (the bias correction would
///   divide by zero and turn every parameter into NaN), if the cache lock is
///   poisoned, or if compilation or launch fails.
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
#[cfg(feature = "cuda")]
pub fn adam_step_cuda(
    params: &mut GpuBuffer<f32>,
    grads: &GpuBuffer<f32>,
    m: &mut GpuBuffer<f32>,
    v: &mut GpuBuffer<f32>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    step: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if step == 0 {
        return Err(CudaTensorError::KernelError(
            "Adam step must be >= 1 (bias correction divides by 1 - beta^step)".to_string(),
        ));
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("adam_step_{n}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            // PTX is only generated on a cache miss.
            let kernel = AdamStepKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    // Saturate instead of wrapping: a wrapped (negative) exponent would flip
    // the sign of the correction.
    let step_exp = i32::try_from(step).unwrap_or(i32::MAX);
    let bias_adjust1 = 1.0 / (1.0 - beta1.powi(step_exp));
    let bias_adjust2 = 1.0 / (1.0 - beta2.powi(step_exp));

    let params_ptr = params.as_ptr();
    let grads_ptr = grads.as_ptr();
    let m_ptr = m.as_ptr();
    let v_ptr = v.as_ptr();

    // Each entry is the address of a host-side local holding one kernel parameter.
    let mut args: [*mut std::ffi::c_void; 11] = [
        &params_ptr as *const _ as *mut _,
        &grads_ptr as *const _ as *mut _,
        &m_ptr as *const _ as *mut _,
        &v_ptr as *const _ as *mut _,
        &lr as *const _ as *mut _,
        &beta1 as *const _ as *mut _,
        &beta2 as *const _ as *mut _,
        &eps as *const _ as *mut _,
        &bias_adjust1 as *const _ as *mut _,
        &bias_adjust2 as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local that outlives the
    // `launch_kernel` call. The parameter order/types are assumed to match the
    // `adam_step` kernel emitted by `AdamStepKernel`, and `n` is assumed to be
    // within the bounds of all four buffers; neither is verified here.
    unsafe {
        stream
            .launch_kernel(module, "adam_step", &config, &mut args)
            .map_err(|e| CudaTensorError::KernelError(format!("Adam step launch failed: {e:?}")))?;
    }

    Ok(())
}
315
/// Multiplies the first `n` gradient values by a host-side `scale`, in place.
///
/// When `scale` is within `1e-7` of `1.0` nothing is launched and the kernel
/// cache is not touched. Otherwise the `gradient_clip` kernel (cache key
/// `gradient_clip_{n}`) is enqueued on `stream`; the stream is not
/// synchronized here.
///
/// `n` is assumed not to exceed the length of `grads` (not checked here).
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned, or if
///   compilation or launch fails.
#[cfg(feature = "cuda")]
pub fn gradient_clip_cuda(
    grads: &mut GpuBuffer<f32>,
    scale: f32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Erases a reference to a kernel parameter into the untyped pointer the
    /// launcher expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let scale_is_identity = (scale - 1.0).abs() < 1e-7;
    if scale_is_identity {
        return Ok(());
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let Ok(mut kernels) = cache_cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    let module_key = format!("gradient_clip_{n}");
    let module = if let Some(hit) = kernels.get_cached(&module_key) {
        hit
    } else {
        // Cache miss: emit PTX for the cached target and compile it.
        let ptx = GradientClipKernel::new(n).emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&module_key, &ptx)?
    };

    let launch = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let device_grads = grads.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] =
        [kernel_arg(&device_grads), kernel_arg(&scale), kernel_arg(&n)];

    // SAFETY: the pointers in `args` refer to locals that outlive the call.
    // The argument layout is assumed to match the `gradient_clip` kernel and
    // `n` is assumed to be within `grads`; neither is verified here.
    let outcome = unsafe { stream.launch_kernel(module, "gradient_clip", &launch, &mut args) };
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("Gradient clip launch failed: {e:?}"))
    })?;

    Ok(())
}
379
/// Blocking convenience wrapper: launches the squared-sum reduction over the
/// first `n` elements of `input`, waits for `stream`, then returns whatever
/// `squared_sum_collect` computes from the per-block partial sums.
///
/// # Errors
/// Propagates errors from `squared_sum_launch_cuda` and
/// `squared_sum_collect`, and returns `CudaTensorError::KernelError` if the
/// stream synchronization fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_cuda(input: &GpuBuffer<f32>, n: u32, stream: &CudaStream) -> Result<f32> {
    let in_flight = squared_sum_launch_cuda(input, n, stream)?;

    // The partials may only be read back once the kernel has finished.
    if let Err(e) = stream.synchronize() {
        return Err(CudaTensorError::KernelError(format!("Stream sync failed: {e:?}")));
    }

    squared_sum_collect(&in_flight)
}
402
/// Handle to a squared-sum reduction that has been launched but not yet read
/// back; produced by `squared_sum_launch_cuda` and consumed by
/// `squared_sum_collect`.
#[cfg(feature = "cuda")]
pub struct PendingSquaredSum {
    /// Device buffer of `num_blocks` values that the kernel writes to.
    output: GpuBuffer<f32>,
    /// Number of blocks launched, i.e. number of partial sums to download.
    num_blocks: u32,
}
411
/// Enqueues the `squared_sum_reduce` kernel over the first `n` elements of
/// `input` without waiting for it.
///
/// A fresh device buffer with one slot per block is allocated for the partial
/// results and returned inside the `PendingSquaredSum`. The caller must
/// synchronize `stream` before passing the handle to `squared_sum_collect`.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned, or if
///   compilation, allocation or launch fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_launch_cuda(
    input: &GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<PendingSquaredSum> {
    /// Erases a reference to a kernel parameter into the untyped pointer the
    /// launcher expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let Ok(mut kernels) = cache_cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    let reducer = SquaredSumKernel::new(n);
    let num_blocks = reducer.num_blocks();

    // Cloned up front so the context stays usable while `module` keeps the
    // cache mutably borrowed.
    let device_ctx = std::sync::Arc::clone(&kernels.ctx);

    let module_key = format!("squared_sum_{n}");
    let module = if let Some(hit) = kernels.get_cached(&module_key) {
        hit
    } else {
        let ptx = reducer.emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&module_key, &ptx)?
    };

    let output = match GpuBuffer::<f32>::new(&device_ctx, num_blocks as usize) {
        Ok(buffer) => buffer,
        Err(e) => {
            return Err(CudaTensorError::KernelError(format!(
                "Failed to allocate squared_sum output: {e:?}"
            )));
        }
    };

    let launch = LaunchConfig {
        grid: (num_blocks, 1, 1),
        block: (reducer.block_size(), 1, 1),
        shared_mem: 8 * 4,
    };

    let device_input = input.as_ptr();
    let device_output = output.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] =
        [kernel_arg(&device_input), kernel_arg(&device_output), kernel_arg(&n)];

    // SAFETY: the pointers in `args` refer to locals that outlive the call.
    // The argument layout is assumed to match the `squared_sum_reduce` kernel
    // and `n` is assumed to be within `input`; neither is verified here.
    let outcome =
        unsafe { stream.launch_kernel(module, "squared_sum_reduce", &launch, &mut args) };
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("Squared sum launch failed: {e:?}"))
    })?;

    Ok(PendingSquaredSum { output, num_blocks })
}
481
/// Downloads the per-block partial sums of a launched reduction and combines
/// them on the host.
///
/// The partials are accumulated in `f64` and the **square root** of the total
/// is returned, narrowed to `f32`. This function does not synchronize; the
/// caller must have waited for the stream that ran the kernel.
///
/// # Errors
/// Returns `CudaTensorError::KernelError` if the device-to-host copy fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_collect(pending: &PendingSquaredSum) -> Result<f32> {
    let block_count = pending.num_blocks as usize;
    let mut host_partials = vec![0.0f32; block_count];

    if let Err(e) = pending.output.copy_to_host(&mut host_partials) {
        return Err(CudaTensorError::KernelError(format!(
            "Failed to download partial sums: {e:?}"
        )));
    }

    // Widen before summing so many small partials do not lose precision.
    let sum_of_squares = host_partials.iter().copied().map(f64::from).sum::<f64>();
    Ok(sum_of_squares.sqrt() as f32)
}
496
/// Enqueues the `squared_sum_reduce` kernel over the first `n` elements of
/// `input`, writing the partial results to a caller-provided location.
///
/// `output_ptr` is forwarded to the kernel as its output argument; it is
/// presumably a device address with room for at least the returned number of
/// `f32` values — that is the caller's responsibility and is not checked
/// here. The stream is not synchronized.
///
/// # Returns
/// The number of blocks launched (`SquaredSumKernel::num_blocks` for `n`).
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned, or if
///   compilation or launch fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_launch_into(
    input: &GpuBuffer<f32>,
    n: u32,
    output_ptr: u64,
    stream: &CudaStream,
) -> Result<u32> {
    /// Erases a reference to a kernel parameter into the untyped pointer the
    /// launcher expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let Ok(mut kernels) = cache_cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    let reducer = SquaredSumKernel::new(n);
    let block_count = reducer.num_blocks();

    let module_key = format!("squared_sum_{n}");
    let module = if let Some(hit) = kernels.get_cached(&module_key) {
        hit
    } else {
        let ptx = reducer.emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&module_key, &ptx)?
    };

    let launch = LaunchConfig {
        grid: (block_count, 1, 1),
        block: (reducer.block_size(), 1, 1),
        shared_mem: 8 * 4,
    };

    let device_input = input.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] =
        [kernel_arg(&device_input), kernel_arg(&output_ptr), kernel_arg(&n)];

    // SAFETY: the pointers in `args` refer to locals that outlive the call.
    // The argument layout is assumed to match the `squared_sum_reduce` kernel,
    // and `output_ptr` is assumed to be a valid, large enough device address;
    // neither is verified here.
    let outcome =
        unsafe { stream.launch_kernel(module, "squared_sum_reduce", &launch, &mut args) };
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("Squared sum launch_into failed: {e:?}"))
    })?;

    Ok(block_count)
}
554
/// Enqueues the `clip_scale_reduce` kernel with a single thread (grid and
/// block are both `(1, 1, 1)`).
///
/// The kernel receives, in this order: the device pointer of `partials`,
/// `total_n`, `max_norm`, and the device pointer of `output`. What it writes
/// to `output` is defined by `ClipScaleReduceKernel` — presumably a clip
/// scale derived from the partial sums and `max_norm`; confirm against the
/// kernel source. The stream is not synchronized.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned, or if
///   compilation or launch fails.
#[cfg(feature = "cuda")]
pub fn clip_scale_reduce_cuda(
    partials: &GpuBuffer<f32>,
    total_n: u32,
    max_norm: f32,
    output: &GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    /// Erases a reference to a kernel parameter into the untyped pointer the
    /// launcher expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    // The kernel is not size-specialised, so one cache entry serves all calls.
    const MODULE_KEY: &str = "clip_scale_reduce";

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let Ok(mut kernels) = cache_cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    let module = if let Some(hit) = kernels.get_cached(MODULE_KEY) {
        hit
    } else {
        let ptx = ClipScaleReduceKernel.emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(MODULE_KEY, &ptx)?
    };

    let launch = LaunchConfig { grid: (1, 1, 1), block: (1, 1, 1), shared_mem: 0 };

    let device_partials = partials.as_ptr();
    let device_output = output.as_ptr();
    let mut args: [*mut std::ffi::c_void; 4] = [
        kernel_arg(&device_partials),
        kernel_arg(&total_n),
        kernel_arg(&max_norm),
        kernel_arg(&device_output),
    ];

    // SAFETY: the pointers in `args` refer to locals that outlive the call.
    // The argument layout is assumed to match the `clip_scale_reduce` kernel
    // and the buffers are assumed to be large enough; neither is verified here.
    let outcome =
        unsafe { stream.launch_kernel(module, "clip_scale_reduce", &launch, &mut args) };
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("Clip scale reduce launch failed: {e:?}"))
    })?;

    Ok(())
}
609
/// Enqueues the `gradient_clip_gpu_scale` kernel over the first `n` elements
/// of `grads`.
///
/// Unlike `gradient_clip_cuda`, the scale is not a host value: `scale_ptr` is
/// forwarded to the kernel as-is, presumably as the device address the kernel
/// reads the scale from (confirm against `GradientClipGpuScaleKernel`). The
/// module is cached under `gradient_clip_gpu_scale_{n}`. The stream is not
/// synchronized.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
/// * `CudaTensorError::KernelError` if the cache lock is poisoned, or if
///   compilation or launch fails.
#[cfg(feature = "cuda")]
pub fn gradient_clip_gpu_scale_cuda(
    grads: &mut GpuBuffer<f32>,
    scale_ptr: u64,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Erases a reference to a kernel parameter into the untyped pointer the
    /// launcher expects.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let Ok(mut kernels) = cache_cell.lock() else {
        return Err(CudaTensorError::KernelError(
            "Failed to acquire kernel cache lock".to_string(),
        ));
    };

    let module_key = format!("gradient_clip_gpu_scale_{n}");
    let module = if let Some(hit) = kernels.get_cached(&module_key) {
        hit
    } else {
        let ptx = GradientClipGpuScaleKernel::new(n).emit_ptx_for_target(kernels.sm_target());
        kernels.get_or_compile(&module_key, &ptx)?
    };

    let launch = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let device_grads = grads.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] =
        [kernel_arg(&device_grads), kernel_arg(&scale_ptr), kernel_arg(&n)];

    // SAFETY: the pointers in `args` refer to locals that outlive the call.
    // The argument layout is assumed to match the `gradient_clip_gpu_scale`
    // kernel, `scale_ptr` is assumed to be a valid device address and `n` is
    // assumed to be within `grads`; none of this is verified here.
    let outcome =
        unsafe { stream.launch_kernel(module, "gradient_clip_gpu_scale", &launch, &mut args) };
    outcome.map_err(|e| {
        CudaTensorError::KernelError(format!("Gradient clip GPU scale launch failed: {e:?}"))
    })?;

    Ok(())
}
661
/// Pre-allocated device buffers and layout information for reducing nine
/// gradient groups into one shared partial-sum buffer (see `new`).
#[cfg(feature = "cuda")]
pub struct FusedClipState {
    /// Device buffer with `total_partials` slots; group `i` owns the range
    /// starting at `offsets[i]` with `num_blocks[i]` entries.
    pub partials_buf: GpuBuffer<f32>,
    /// Two-element device buffer. NOTE(review): the meaning of the two slots
    /// is not visible in this file — confirm against the consuming kernel.
    pub scale_buf: GpuBuffer<f32>,
    /// Start index of each group's partials inside `partials_buf`.
    pub offsets: [u32; 9],
    /// `SquaredSumKernel::num_blocks()` for each group's element count.
    pub num_blocks: [u32; 9],
    /// Sum of `num_blocks`, i.e. the length of `partials_buf`.
    pub total_partials: u32,
}
679
#[cfg(feature = "cuda")]
impl FusedClipState {
    /// Lays out and allocates the buffers for nine gradient groups.
    ///
    /// For each entry of `grad_sizes` the number of reduction blocks is taken
    /// from `SquaredSumKernel::new(n).num_blocks()`. The groups are packed
    /// back to back, so `offsets[i]` is the sum of the block counts of all
    /// earlier groups.
    ///
    /// # Errors
    /// Returns `CudaTensorError::KernelError` if either device allocation
    /// fails.
    pub fn new(ctx: &std::sync::Arc<CudaContext>, grad_sizes: &[u32; 9]) -> Result<Self> {
        let mut offsets = [0u32; 9];
        let mut num_blocks = [0u32; 9];
        let mut running_total = 0u32;

        let slots = offsets.iter_mut().zip(num_blocks.iter_mut());
        for ((offset_slot, blocks_slot), &elements) in slots.zip(grad_sizes) {
            // Each group starts where the previous one ended.
            *offset_slot = running_total;
            let blocks = SquaredSumKernel::new(elements).num_blocks();
            *blocks_slot = blocks;
            running_total += blocks;
        }

        let partials_buf = match GpuBuffer::<f32>::new(ctx, running_total as usize) {
            Ok(buffer) => buffer,
            Err(e) => {
                return Err(CudaTensorError::KernelError(format!(
                    "Failed to allocate partials buffer: {e:?}"
                )));
            }
        };

        let scale_buf = match GpuBuffer::<f32>::new(ctx, 2) {
            Ok(buffer) => buffer,
            Err(e) => {
                return Err(CudaTensorError::KernelError(format!(
                    "Failed to allocate scale buffer: {e:?}"
                )));
            }
        };

        Ok(Self { partials_buf, scale_buf, offsets, num_blocks, total_partials: running_total })
    }
}
719
/// Runs the fused cross-entropy kernel and returns the mean per-position loss.
///
/// One block is launched per sequence position (`grid = seq_len`). The kernel
/// receives the device pointer of `logits_buf`, the uploaded targets, a
/// per-position loss buffer, `vocab_size` and `scale`. `logits_buf` is taken
/// mutably; presumably the kernel overwrites it in place with the scaled
/// gradient — confirm against `FusedCrossEntropyKernel`.
///
/// This call blocks: it synchronizes `stream`, downloads the `seq_len`
/// per-position losses, sums them in `f64` and divides by `seq_len`.
///
/// # Arguments
/// * `target_ids` - target ids; only the first `seq_len` entries are used.
/// * `seq_len` - number of positions (blocks) to process.
/// * `vocab_size` - forwarded to the kernel and used in the cache key
///   `fused_xent_{vocab_size}`.
/// * `scale` - forwarded to the kernel unchanged.
///
/// # Errors
/// * `CudaTensorError::KernelError` if `target_ids` has fewer than `seq_len`
///   entries (previously a slice-index panic), if the cache lock is poisoned,
///   or if compilation, upload, allocation, launch, synchronization or
///   download fails.
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
#[cfg(feature = "cuda")]
pub fn fused_cross_entropy_cuda(
    logits_buf: &mut GpuBuffer<f32>,
    target_ids: &[u32],
    seq_len: u32,
    vocab_size: u32,
    scale: f32,
    stream: &CudaStream,
) -> Result<f32> {
    // Validate on the host before touching the device: indexing past the end
    // of `target_ids` would panic instead of returning an error.
    let targets = target_ids.get(..seq_len as usize).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "target_ids has {} entries but seq_len is {seq_len}",
            target_ids.len()
        ))
    })?;

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = FusedCrossEntropyKernel::new(vocab_size);

    // Cloned up front so the context stays usable while `module` keeps the
    // cache mutably borrowed.
    let ctx = std::sync::Arc::clone(&cache.ctx);

    let key = format!("fused_xent_{vocab_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Upload straight from the validated slice; no intermediate copy needed.
    let targets_gpu = GpuBuffer::<u32>::from_host(&ctx, targets)
        .map_err(|e| CudaTensorError::KernelError(format!("Failed to upload targets: {e:?}")))?;

    // One loss value per sequence position.
    let loss_gpu = GpuBuffer::<f32>::new(&ctx, seq_len as usize).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to allocate loss partials: {e:?}"))
    })?;

    let config =
        LaunchConfig { grid: (seq_len, 1, 1), block: (kernel.block_size(), 1, 1), shared_mem: 72 };

    let logits_grad_ptr = logits_buf.as_ptr();
    let targets_ptr = targets_gpu.as_ptr();
    let loss_ptr = loss_gpu.as_ptr();

    let mut args: [*mut std::ffi::c_void; 5] = [
        &logits_grad_ptr as *const _ as *mut _,
        &targets_ptr as *const _ as *mut _,
        &loss_ptr as *const _ as *mut _,
        &vocab_size as *const _ as *mut _,
        &scale as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local that outlives the
    // call, and `targets_gpu`/`loss_gpu` stay alive until after the
    // synchronization below. The argument layout is assumed to match the
    // `fused_cross_entropy` kernel and `logits_buf` is assumed to hold at
    // least `seq_len * vocab_size` values; neither is verified here.
    unsafe {
        stream.launch_kernel(module, "fused_cross_entropy", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Fused cross-entropy launch failed: {e:?}"))
        })?;
    }

    // The losses are read back immediately, so wait for the kernel first.
    stream
        .synchronize()
        .map_err(|e| CudaTensorError::KernelError(format!("Stream sync failed: {e:?}")))?;

    let mut loss_partials = vec![0.0f32; seq_len as usize];
    loss_gpu.copy_to_host(&mut loss_partials).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to download loss partials: {e:?}"))
    })?;

    // Accumulate in f64 to limit rounding error over long sequences.
    let total_loss: f64 = loss_partials.iter().map(|&x| f64::from(x)).sum();
    let avg_loss = (total_loss / f64::from(seq_len)) as f32;

    Ok(avg_loss)
}
819
/// Runs the fused causal cross-entropy kernel and returns the mean loss over
/// the positions `loss_start..loss_end`.
///
/// One block is launched per sequence position (`grid = seq_len`). The kernel
/// receives the device pointer of `logits_buf`, the uploaded targets, a
/// per-position loss buffer, `vocab_size`, `scale`, `loss_start` and
/// `loss_end`. `logits_buf` is taken mutably; presumably the kernel
/// overwrites it in place with the scaled gradient — confirm against
/// `FusedCausalCrossEntropyKernel`.
///
/// This call blocks: it synchronizes `stream`, downloads the per-position
/// losses, and averages the entries in `loss_start..loss_end` (summed in
/// `f64`). When the range is empty (`loss_end <= loss_start`) the kernel is
/// still launched and `0.0` is returned.
///
/// # Errors
/// * `CudaTensorError::KernelError` if `target_ids` has fewer than `seq_len`
///   entries or a non-empty loss range reaches past `seq_len` (both were
///   slice-index panics before), if the cache lock is poisoned, or if
///   compilation, upload, allocation, launch, synchronization or download
///   fails.
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been
///   initialised.
#[cfg(feature = "cuda")]
pub fn fused_causal_cross_entropy_cuda(
    logits_buf: &mut GpuBuffer<f32>,
    target_ids: &[u32],
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32,
    loss_end: u32,
    scale: f32,
    stream: &CudaStream,
) -> Result<f32> {
    // Validate on the host before touching the device: both checks replace
    // slice-index panics, the second of which fired only after the kernel
    // had already run.
    let targets = target_ids.get(..seq_len as usize).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "target_ids has {} entries but seq_len is {seq_len}",
            target_ids.len()
        ))
    })?;
    if loss_start < loss_end && loss_end > seq_len {
        return Err(CudaTensorError::KernelError(format!(
            "loss range {loss_start}..{loss_end} exceeds seq_len {seq_len}"
        )));
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = FusedCausalCrossEntropyKernel::new(vocab_size);

    // Cloned up front so the context stays usable while `module` keeps the
    // cache mutably borrowed.
    let ctx = std::sync::Arc::clone(&cache.ctx);

    let key = format!("fused_causal_xent_{vocab_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Upload straight from the validated slice; no intermediate copy needed.
    let targets_gpu = GpuBuffer::<u32>::from_host(&ctx, targets)
        .map_err(|e| CudaTensorError::KernelError(format!("Failed to upload targets: {e:?}")))?;

    // One loss value per sequence position.
    let loss_gpu = GpuBuffer::<f32>::new(&ctx, seq_len as usize).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to allocate loss partials: {e:?}"))
    })?;

    let config =
        LaunchConfig { grid: (seq_len, 1, 1), block: (kernel.block_size(), 1, 1), shared_mem: 72 };

    let logits_grad_ptr = logits_buf.as_ptr();
    let targets_ptr = targets_gpu.as_ptr();
    let loss_ptr = loss_gpu.as_ptr();

    let mut args: [*mut std::ffi::c_void; 7] = [
        &logits_grad_ptr as *const _ as *mut _,
        &targets_ptr as *const _ as *mut _,
        &loss_ptr as *const _ as *mut _,
        &vocab_size as *const _ as *mut _,
        &scale as *const _ as *mut _,
        &loss_start as *const _ as *mut _,
        &loss_end as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local that outlives the
    // call, and `targets_gpu`/`loss_gpu` stay alive until after the
    // synchronization below. The argument layout is assumed to match the
    // `fused_causal_cross_entropy` kernel and `logits_buf` is assumed to hold
    // at least `seq_len * vocab_size` values; neither is verified here.
    unsafe {
        stream.launch_kernel(module, "fused_causal_cross_entropy", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused causal cross-entropy launch failed: {e:?}"
                ))
            },
        )?;
    }

    // The losses are read back immediately, so wait for the kernel first.
    stream
        .synchronize()
        .map_err(|e| CudaTensorError::KernelError(format!("Stream sync failed: {e:?}")))?;

    let mut loss_partials = vec![0.0f32; seq_len as usize];
    loss_gpu.copy_to_host(&mut loss_partials).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to download loss partials: {e:?}"))
    })?;

    // An inverted or empty range contributes no tokens; avoid dividing by zero.
    let num_loss_tokens = loss_end.saturating_sub(loss_start) as usize;
    if num_loss_tokens == 0 {
        return Ok(0.0);
    }
    // In bounds: the range was checked against `seq_len` above.
    let total_loss: f64 =
        loss_partials[loss_start as usize..loss_end as usize].iter().map(|&x| f64::from(x)).sum();
    let avg_loss = (total_loss / num_loss_tokens as f64) as f32;

    Ok(avg_loss)
}
927
928#[cfg(test)]
929mod tests {
930 use super::*;
931
/// Smoke test: passes as long as this module and its test harness compile,
/// with or without the `cuda` feature.
#[test]
fn test_cuda_optim_module_compiles() {
    // Intentionally empty: successful compilation is the whole check, so no
    // constant assertion is needed.
}
938
/// Initialising the kernel cache on device 0 must succeed. The test is a
/// no-op on machines where `cuda_available()` is false.
#[test]
#[cfg(feature = "cuda")]
fn test_optim_kernel_cache_initialization() {
    use trueno_gpu::driver::cuda_available;

    if cuda_available() {
        let device = CudaContext::new(0).expect("operation should succeed");
        let outcome = init_optim_kernel_cache(std::sync::Arc::new(device));
        assert!(outcome.is_ok());
    }
}
953
/// Returns a shared context for device 0, or `None` when CUDA is unavailable
/// or the context cannot be created, so GPU tests can bail out early.
#[cfg(feature = "cuda")]
fn get_test_gpu_context() -> Option<std::sync::Arc<CudaContext>> {
    use trueno_gpu::driver::cuda_available;

    if !cuda_available() {
        return None;
    }

    match CudaContext::new(0) {
        Ok(device) => Some(std::sync::Arc::new(device)),
        Err(_) => None,
    }
}
967
/// CPU reference for one AdamW step, used to check the GPU kernel.
///
/// For every index of `params` it updates the moment estimates in `m` and
/// `v`, applies the bias corrections `1 / (1 - beta^step)`, and then applies
/// decoupled weight decay together with the Adam update:
/// `p = p * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)`.
///
/// Panics if `grads`, `m` or `v` is shorter than `params`.
fn adamw_step_cpu(
    params: &mut [f32],
    grads: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    step: u32,
) {
    let exponent = step as i32;
    let correction1 = 1.0 / (1.0 - beta1.powi(exponent));
    let correction2 = 1.0 / (1.0 - beta2.powi(exponent));
    let decay_factor = 1.0 - lr * weight_decay;

    for idx in 0..params.len() {
        // Exponential moving averages of the gradient and its square.
        let first_moment = beta1 * m[idx] + (1.0 - beta1) * grads[idx];
        m[idx] = first_moment;
        let second_moment = beta2 * v[idx] + (1.0 - beta2) * grads[idx] * grads[idx];
        v[idx] = second_moment;

        // Bias-corrected estimates.
        let m_hat = first_moment * correction1;
        let v_hat = second_moment * correction2;

        let decayed = params[idx] * decay_factor;
        params[idx] = decayed - lr * m_hat / (v_hat.sqrt() + eps);
    }
}
998
/// CPU reference for one Adam step (no weight decay), used to check the GPU
/// kernel.
///
/// For every index of `params` it updates the moment estimates in `m` and
/// `v`, applies the bias corrections `1 / (1 - beta^step)`, and subtracts
/// `lr * m_hat / (sqrt(v_hat) + eps)` from the parameter.
///
/// Panics if `grads`, `m` or `v` is shorter than `params`.
fn adam_step_cpu(
    params: &mut [f32],
    grads: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    step: u32,
) {
    let exponent = step as i32;
    let correction1 = 1.0 / (1.0 - beta1.powi(exponent));
    let correction2 = 1.0 / (1.0 - beta2.powi(exponent));

    for idx in 0..params.len() {
        // Exponential moving averages of the gradient and its square.
        let first_moment = beta1 * m[idx] + (1.0 - beta1) * grads[idx];
        m[idx] = first_moment;
        let second_moment = beta2 * v[idx] + (1.0 - beta2) * grads[idx] * grads[idx];
        v[idx] = second_moment;

        // Bias-corrected estimates.
        let m_hat = first_moment * correction1;
        let v_hat = second_moment * correction2;

        params[idx] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}
1028
/// CPU reference for gradient scaling: multiplies every element of `grads`
/// by `scale`, in place.
fn gradient_clip_cpu(grads: &mut [f32], scale: f32) {
    grads.iter_mut().for_each(|value| *value *= scale);
}
1035
/// One GPU AdamW step over four elements must match the CPU reference
/// (`adamw_step_cpu`) for the parameters and both moment buffers.
/// Skipped when no GPU context is available.
#[test]
#[cfg(feature = "cuda")]
fn test_adamw_step_basic() {
    let ctx = match get_test_gpu_context() {
        Some(c) => c,
        None => return,
    };
    init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
    let stream = CudaStream::new(&ctx).expect("operation should succeed");

    let n = 4u32;
    let lr = 0.001f32;
    let beta1 = 0.9f32;
    let beta2 = 0.999f32;
    let eps = 1e-8f32;
    let weight_decay = 0.01f32;
    let step = 1u32;

    let mut params_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let grads_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
    let mut m_data: Vec<f32> = vec![0.0; n as usize];
    let mut v_data: Vec<f32> = vec![0.0; n as usize];

    // Expected values from the CPU reference, computed on copies of the inputs.
    let mut cpu_params = params_data.clone();
    let mut cpu_m = m_data.clone();
    let mut cpu_v = v_data.clone();
    adamw_step_cpu(
        &mut cpu_params,
        &grads_data,
        &mut cpu_m,
        &mut cpu_v,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
    );

    let mut params =
        GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
    let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
    let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
    let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");

    adamw_step_cuda(
        &mut params,
        &grads,
        &mut m,
        &mut v,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
        n,
        &stream,
    )
    .expect("operation should succeed");
    // The launch is asynchronous; wait before reading results back.
    stream.synchronize().expect("operation should succeed");

    params.copy_to_host(&mut params_data).expect("operation should succeed");
    m.copy_to_host(&mut m_data).expect("operation should succeed");
    v.copy_to_host(&mut v_data).expect("operation should succeed");

    for i in 0..n as usize {
        assert!(
            (params_data[i] - cpu_params[i]).abs() < 1e-4,
            "AdamW params mismatch at {i}: GPU={}, CPU={}",
            params_data[i],
            cpu_params[i]
        );
        assert!(
            (m_data[i] - cpu_m[i]).abs() < 1e-5,
            "AdamW m mismatch at {i}: GPU={}, CPU={}",
            m_data[i],
            cpu_m[i]
        );
        assert!(
            (v_data[i] - cpu_v[i]).abs() < 1e-5,
            "AdamW v mismatch at {i}: GPU={}, CPU={}",
            v_data[i],
            cpu_v[i]
        );
    }
}
1127
/// Guards against a kernel that leaves the parameters untouched: after one
/// AdamW step with strictly positive gradients every parameter must have
/// changed and decreased. Skipped when no GPU context is available.
#[test]
#[cfg(feature = "cuda")]
fn test_adamw_step_not_hardcoded() {
    let Some(ctx) = get_test_gpu_context() else {
        return;
    };
    init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
    let stream = CudaStream::new(&ctx).expect("operation should succeed");

    let n = 4u32;
    let count = n as usize;
    let initial_params: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let positive_grads: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5];
    let zero_m: Vec<f32> = vec![0.0; count];
    let zero_v: Vec<f32> = vec![0.0; count];

    let mut params =
        GpuBuffer::from_host(&ctx, &initial_params).expect("operation should succeed");
    let grads = GpuBuffer::from_host(&ctx, &positive_grads).expect("operation should succeed");
    let mut m = GpuBuffer::from_host(&ctx, &zero_m).expect("operation should succeed");
    let mut v = GpuBuffer::from_host(&ctx, &zero_v).expect("operation should succeed");

    // lr, beta1, beta2, eps, weight_decay, step
    let outcome = adamw_step_cuda(
        &mut params,
        &grads,
        &mut m,
        &mut v,
        0.01,
        0.9,
        0.999,
        1e-8,
        0.01,
        1,
        n,
        &stream,
    );
    outcome.expect("operation should succeed");
    stream.synchronize().expect("operation should succeed");

    let mut result_params = vec![0.0f32; count];
    params.copy_to_host(&mut result_params).expect("operation should succeed");

    assert_ne!(result_params, initial_params, "mutant: AdamW params unchanged after step");
    for (i, (&new, &old)) in result_params.iter().zip(initial_params.iter()).enumerate() {
        assert!(new < old, "AdamW params[{i}] should decrease with positive gradients");
    }
}
1178
/// With all-zero gradients the only change AdamW makes is the decoupled
/// weight decay, so every parameter must end up at
/// `p * (1 - lr * weight_decay)`. Skipped when no GPU context is available.
#[test]
#[cfg(feature = "cuda")]
fn test_adamw_weight_decay() {
    let ctx = match get_test_gpu_context() {
        Some(c) => c,
        None => return,
    };
    init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
    let stream = CudaStream::new(&ctx).expect("operation should succeed");

    let n = 4u32;
    let params_data: Vec<f32> = vec![10.0, 10.0, 10.0, 10.0];
    // Zero gradients isolate the weight-decay term.
    let grads_data: Vec<f32> = vec![0.0, 0.0, 0.0, 0.0];
    let m_data: Vec<f32> = vec![0.0; n as usize];
    let v_data: Vec<f32> = vec![0.0; n as usize];

    let mut params =
        GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
    let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
    let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
    let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");

    // lr, beta1, beta2, eps, weight_decay, step
    adamw_step_cuda(
        &mut params,
        &grads,
        &mut m,
        &mut v,
        0.01,
        0.9,
        0.999,
        1e-8,
        0.1,
        1,
        n,
        &stream,
    )
    .expect("operation should succeed");
    stream.synchronize().expect("operation should succeed");

    let mut result = vec![0.0f32; n as usize];
    params.copy_to_host(&mut result).expect("operation should succeed");

    // 10.0 * (1 - lr * weight_decay)
    let expected = 10.0 * (1.0 - 0.01 * 0.1);
    for (i, &p) in result.iter().enumerate() {
        assert!(
            (p - expected).abs() < 1e-3,
            "Weight decay not applied correctly at {i}: got {p}, expected {expected}"
        );
    }
}
1232
/// One GPU Adam step over four elements must match the CPU reference
/// (`adam_step_cpu`) for the parameters. Skipped when no GPU context is
/// available.
#[test]
#[cfg(feature = "cuda")]
fn test_adam_step_basic() {
    let ctx = match get_test_gpu_context() {
        Some(c) => c,
        None => return,
    };
    init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
    let stream = CudaStream::new(&ctx).expect("operation should succeed");

    let n = 4u32;
    let lr = 0.001f32;
    let beta1 = 0.9f32;
    let beta2 = 0.999f32;
    let eps = 1e-8f32;
    let step = 1u32;

    let mut params_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let grads_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
    let mut m_data: Vec<f32> = vec![0.0; n as usize];
    let mut v_data: Vec<f32> = vec![0.0; n as usize];

    // Expected values from the CPU reference, computed on copies of the inputs.
    let mut cpu_params = params_data.clone();
    let mut cpu_m = m_data.clone();
    let mut cpu_v = v_data.clone();
    adam_step_cpu(
        &mut cpu_params,
        &grads_data,
        &mut cpu_m,
        &mut cpu_v,
        lr,
        beta1,
        beta2,
        eps,
        step,
    );

    let mut params =
        GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
    let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
    let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
    let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");

    adam_step_cuda(
        &mut params,
        &grads,
        &mut m,
        &mut v,
        lr,
        beta1,
        beta2,
        eps,
        step,
        n,
        &stream,
    )
    .expect("operation should succeed");
    // The launch is asynchronous; wait before reading results back.
    stream.synchronize().expect("operation should succeed");

    params.copy_to_host(&mut params_data).expect("operation should succeed");
    m.copy_to_host(&mut m_data).expect("operation should succeed");
    v.copy_to_host(&mut v_data).expect("operation should succeed");

    for i in 0..n as usize {
        assert!(
            (params_data[i] - cpu_params[i]).abs() < 1e-4,
            "Adam params mismatch at {i}: GPU={}, CPU={}",
            params_data[i],
            cpu_params[i]
        );
    }
}
1309
/// Ten consecutive GPU Adam steps (steps 1 to 10) with constant positive
/// gradients must move every parameter below its start value of 1.0 while
/// keeping it positive. Skipped when no GPU context is available.
#[test]
#[cfg(feature = "cuda")]
fn test_adam_step_multiple_iterations() {
    let ctx = match get_test_gpu_context() {
        Some(c) => c,
        None => return,
    };
    init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
    let stream = CudaStream::new(&ctx).expect("operation should succeed");

    let n = 4u32;
    let lr = 0.01f32;
    let beta1 = 0.9f32;
    let beta2 = 0.999f32;
    let eps = 1e-8f32;

    let mut params_data: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0];
    let grads_data: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5];
    let m_data: Vec<f32> = vec![0.0; n as usize];
    let v_data: Vec<f32> = vec![0.0; n as usize];

    let mut params =
        GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
    let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
    let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
    let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");

    // All steps are queued on the same stream; one synchronize at the end
    // covers them.
    for step in 1..=10 {
        adam_step_cuda(
            &mut params,
            &grads,
            &mut m,
            &mut v,
            lr,
            beta1,
            beta2,
            eps,
            step,
            n,
            &stream,
        )
        .expect("operation should succeed");
    }
    stream.synchronize().expect("operation should succeed");

    params.copy_to_host(&mut params_data).expect("operation should succeed");

    for &p in &params_data {
        assert!(p < 1.0, "Params should decrease after multiple Adam steps");
        assert!(p > 0.0, "Params should remain positive");
    }
}
1364
1365 #[test]
1366 #[cfg(feature = "cuda")]
1367 fn test_gradient_clip_basic() {
1368 let ctx = match get_test_gpu_context() {
1369 Some(c) => c,
1370 None => return,
1371 };
1372 init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1373 let stream = CudaStream::new(&ctx).expect("operation should succeed");
1374
1375 let n = 4u32;
1376 let grads_data: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0];
1377 let scale = 0.5f32; let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1380
1381 gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1382 stream.synchronize().expect("operation should succeed");
1383
1384 let mut result = vec![0.0f32; n as usize];
1385 grads.copy_to_host(&mut result).expect("operation should succeed");
1386
1387 let mut expected = grads_data.clone();
1389 gradient_clip_cpu(&mut expected, scale);
1390
1391 for (i, (&got, &exp)) in result.iter().zip(expected.iter()).enumerate() {
1392 assert!(
1393 (got - exp).abs() < 1e-5,
1394 "Gradient clip mismatch at {i}: got {got}, expected {exp}"
1395 );
1396 }
1397 }
1398
1399 #[test]
1400 #[cfg(feature = "cuda")]
1401 fn test_gradient_clip_no_op() {
1402 let ctx = match get_test_gpu_context() {
1404 Some(c) => c,
1405 None => return,
1406 };
1407 init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1408 let stream = CudaStream::new(&ctx).expect("operation should succeed");
1409
1410 let n = 4u32;
1411 let grads_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1412 let scale = 1.0f32; let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1415
1416 gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1418 stream.synchronize().expect("operation should succeed");
1419
1420 let mut result = vec![0.0f32; n as usize];
1421 grads.copy_to_host(&mut result).expect("operation should succeed");
1422
1423 for (i, (&got, &exp)) in result.iter().zip(grads_data.iter()).enumerate() {
1425 assert!(
1426 (got - exp).abs() < 1e-6,
1427 "Gradient clip with scale=1 should not modify values at {i}"
1428 );
1429 }
1430 }
1431
1432 #[test]
1433 #[cfg(feature = "cuda")]
1434 fn test_gradient_clip_not_hardcoded() {
1435 let ctx = match get_test_gpu_context() {
1437 Some(c) => c,
1438 None => return,
1439 };
1440 init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1441 let stream = CudaStream::new(&ctx).expect("operation should succeed");
1442
1443 let n = 4u32;
1444 let grads_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
1445 let scale = 0.1f32;
1446
1447 let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1448
1449 gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1450 stream.synchronize().expect("operation should succeed");
1451
1452 let mut result = vec![0.0f32; n as usize];
1453 grads.copy_to_host(&mut result).expect("operation should succeed");
1454
1455 assert_ne!(result, grads_data, "mutant: gradient clip had no effect");
1457
1458 assert!((result[0] - 1.0).abs() < 1e-5);
1460 assert!((result[1] - 2.0).abs() < 1e-5);
1461 assert!((result[2] - 3.0).abs() < 1e-5);
1462 assert!((result[3] - 4.0).abs() < 1e-5);
1463 }
1464
1465 #[test]
1466 #[cfg(feature = "cuda")]
1467 fn test_optimizer_large_scale() {
1468 let ctx = match get_test_gpu_context() {
1470 Some(c) => c,
1471 None => return,
1472 };
1473 init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1474 let stream = CudaStream::new(&ctx).expect("operation should succeed");
1475
1476 let n = 1024u32;
1477 let params_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
1478 let grads_data: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.01).sin()).collect();
1479 let m_data: Vec<f32> = vec![0.0; n as usize];
1480 let v_data: Vec<f32> = vec![0.0; n as usize];
1481
1482 let mut params =
1483 GpuBuffer::from_host(&ctx, ¶ms_data).expect("operation should succeed");
1484 let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1485 let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1486 let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1487
1488 adamw_step_cuda(
1489 &mut params,
1490 &grads,
1491 &mut m,
1492 &mut v,
1493 0.001,
1494 0.9,
1495 0.999,
1496 1e-8,
1497 0.01,
1498 1,
1499 n,
1500 &stream,
1501 )
1502 .expect("operation should succeed");
1503 stream.synchronize().expect("operation should succeed");
1504
1505 let mut result = vec![0.0f32; n as usize];
1506 params.copy_to_host(&mut result).expect("operation should succeed");
1507
1508 assert!(
1510 !result.iter().any(|x| x.is_nan() || x.is_infinite()),
1511 "Large-scale optimizer should not produce NaN/Inf"
1512 );
1513 }
1514}