Skip to main content

entrenar/autograd/
cuda_optim.rs

1//! CUDA-accelerated optimizer kernels for autograd
2//!
3//! This module wraps trueno-gpu optimizer kernels for GPU-resident weight updates.
4//! Eliminates CPU↔GPU synchronization by keeping all optimizer state on GPU.
5//!
6//! # Architecture (SPEC-FT-001 v3.1.0)
7//!
8//! ```text
9//! entrenar autograd
10//!     └── cuda_optim (this module)
11//!             └── trueno-gpu/kernels/optimizer
12//!                     └── AdamWStepKernel, AdamStepKernel, GradientClipKernel
13//! ```
14//!
15//! # Available Functions
16//!
17//! - `adamw_step_cuda` - Fused AdamW with weight decay
18//! - `adam_step_cuda` - Vanilla Adam without weight decay
19//! - `gradient_clip_cuda` - Apply gradient clipping scale
20//! - `squared_sum_cuda` - GPU-side sum-of-squares for L2 norm (KAIZEN-049)
21
22#![allow(unsafe_code)]
23#![allow(trivial_casts)]
24#![allow(clippy::borrow_as_ptr)]
25#![allow(clippy::ref_as_ptr)]
26
27#[cfg(feature = "cuda")]
28use std::collections::HashMap;
29#[cfg(feature = "cuda")]
30use std::sync::{Mutex, OnceLock};
31
32#[cfg(feature = "cuda")]
33use trueno_gpu::driver::{CudaContext, CudaModule, CudaStream, GpuBuffer, LaunchConfig};
34#[cfg(feature = "cuda")]
35use trueno_gpu::kernels::backward::{FusedCausalCrossEntropyKernel, FusedCrossEntropyKernel};
36use trueno_gpu::kernels::{
37    AdamStepKernel, AdamWStepKernel, ClipScaleReduceKernel, GradientClipGpuScaleKernel,
38    GradientClipKernel, Kernel, SquaredSumKernel,
39};
40
41use super::cuda_tensor::{CudaTensorError, Result};
42
/// Process-wide cache of compiled optimizer kernel modules.
///
/// Empty until [`init_optim_kernel_cache`] installs a CUDA context; every kernel wrapper in
/// this module reports `DeviceNotInitialized` while it is unset. All access goes through the
/// inner mutex.
#[cfg(feature = "cuda")]
static OPTIM_KERNEL_CACHE: OnceLock<Mutex<OptimKernelCache>> = OnceLock::new();
46
/// Compiled optimizer kernel modules for one CUDA context, keyed by kernel name.
#[cfg(feature = "cuda")]
struct OptimKernelCache {
    /// CUDA context the modules are compiled against.
    ctx: std::sync::Arc<CudaContext>,
    /// Compiled modules indexed by cache key (for example `adamw_step_{n}`).
    modules: HashMap<String, CudaModule>,
    /// PTX target string handed to the kernels' PTX emitters.
    sm_target: String,
}
54
#[cfg(feature = "cuda")]
impl OptimKernelCache {
    /// Build an empty cache bound to `ctx`.
    ///
    /// The PTX target is queried from the context once; `sm_70` is used when that query fails.
    fn new(ctx: std::sync::Arc<CudaContext>) -> Self {
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());
        Self { ctx, modules: HashMap::new(), sm_target }
    }

    /// PTX target string used when emitting kernels for this context.
    fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Look up a previously compiled module by key (KAIZEN-058).
    ///
    /// Callers use this as the allocation-free fast path and only fall back to
    /// [`Self::get_or_compile`] (which needs the PTX text) on a miss.
    fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Return the module cached under `name`, compiling it from `ptx` first if it is missing.
    ///
    /// Uses the map's entry API so the key is looked up once and there is no
    /// "just inserted" `expect` panic path.
    ///
    /// # Errors
    ///
    /// Returns `CudaTensorError::KernelError` when PTX compilation fails; nothing is inserted
    /// into the cache in that case.
    fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        match self.modules.entry(name.to_string()) {
            std::collections::hash_map::Entry::Occupied(slot) => Ok(slot.into_mut()),
            std::collections::hash_map::Entry::Vacant(slot) => {
                // `self.ctx` is a different field from `self.modules`, so borrowing it while
                // the entry is held is fine.
                let module = CudaModule::from_ptx(&self.ctx, ptx).map_err(|e| {
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {e:?}"))
                })?;
                Ok(slot.insert(module))
            }
        }
    }
}
81
/// Install the process-wide optimizer kernel cache for `ctx`.
///
/// Only the first call has an effect: once the cache exists, later calls leave it untouched
/// and their `ctx` is simply dropped. Always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_optim_kernel_cache(ctx: std::sync::Arc<CudaContext>) -> Result<()> {
    let build_cache = move || Mutex::new(OptimKernelCache::new(ctx));
    let _installed = OPTIM_KERNEL_CACHE.get_or_init(build_cache);
    Ok(())
}
88
/// Pre-warm AdamW optimizer kernels for all trainable parameter sizes (ENT-153).
///
/// The kernel key is `adamw_step_{n}` where `n` is parameter count.
///
/// ## LoRA mode (NF4 QLoRA)
/// - `hidden * rank` for A_q, A_v
/// - `rank * q_dim` for B_q
/// - `rank * kv_hidden` for B_v
/// - `hidden_size` for norm weights
///
/// ## Full fp32 mode (non-NF4)
/// - `hidden * hidden` for w_q, w_o
/// - `hidden * kv_hidden` for w_k, w_v
/// - `hidden * intermediate` for w_gate, w_up, w_down
/// - `hidden_size` for norm weights
///
/// ## Classifier head (both modes)
/// - `num_classes * hidden_size` for classifier weight
/// - `num_classes` for classifier bias
///
/// Must JIT-compile before block upload fills VRAM (C-PREWARM-001).
///
/// Nothing is pre-warmed when `lora_rank` is 0. A size whose element count overflows `usize`
/// or does not fit in `u32` is skipped rather than truncated: `adamw_step_cuda` takes
/// `n: u32`, so such a size can never be requested and a truncated key would only compile a
/// kernel for the wrong size.
///
/// # Errors
///
/// Returns `Err` if the kernel cache is not initialized, its lock cannot be acquired, or a
/// kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_adamw_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    lora_rank: usize,
    num_classes: usize,
    intermediate_size: usize,
    quantize_nf4: bool,
) -> Result<()> {
    if lora_rank == 0 {
        return Ok(());
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire optim kernel cache lock".to_string())
    })?;

    let target = cache.sm_target().to_string();

    // Element count of a `rows x cols` tensor, or `None` when it cannot be expressed as `u32`.
    let numel = |rows: usize, cols: usize| -> Option<u32> {
        rows.checked_mul(cols).and_then(|count| u32::try_from(count).ok())
    };

    let mut candidates: Vec<Option<u32>> = vec![
        numel(hidden_size, lora_rank),    // A_q, A_v
        numel(lora_rank, q_dim),          // B_q
        numel(lora_rank, kv_hidden_size), // B_v
        numel(hidden_size, 1),            // norm weights
    ];

    // Full fp32 weight sizes (non-NF4 mode: optimizer runs on all block weights)
    if !quantize_nf4 {
        candidates.push(numel(hidden_size, hidden_size)); // w_q, w_o
        candidates.push(numel(hidden_size, kv_hidden_size)); // w_k, w_v
        candidates.push(numel(hidden_size, intermediate_size)); // w_gate, w_up, w_down
    }

    // Classifier head sizes
    if num_classes > 0 {
        candidates.push(numel(num_classes, hidden_size));
        candidates.push(numel(num_classes, 1));
    }

    // Drop unrepresentable sizes, then compile each distinct size once.
    let mut sizes: Vec<u32> = candidates.into_iter().flatten().collect();
    sizes.sort_unstable();
    sizes.dedup();

    for n in sizes {
        let kernel = AdamWStepKernel::new(n);
        let ptx = kernel.emit_ptx_for_target(&target);
        let key = format!("adamw_step_{n}");
        cache.get_or_compile(&key, &ptx)?;
    }

    Ok(())
}
163
/// Fused AdamW optimizer step on GPU
///
/// Performs in-place weight update with momentum, adaptive learning rate, and weight decay.
///
/// # Arguments
/// - `params`: weight tensor (updated in-place)
/// - `grads`: gradient tensor
/// - `m`: first moment state (updated in-place)
/// - `v`: second moment state (updated in-place)
/// - `lr`: learning rate
/// - `beta1`: first moment decay (typically 0.9)
/// - `beta2`: second moment decay (typically 0.999)
/// - `eps`: numerical stability (typically 1e-8)
/// - `weight_decay`: L2 penalty coefficient
/// - `step`: current step number, starting at 1 (used for bias correction)
/// - `n`: number of parameters
/// - `stream`: CUDA stream
///
/// # Errors
///
/// Returns `Err` if `step` is 0, if the kernel cache is not initialized or its lock cannot be
/// acquired, if kernel compilation fails, or if the kernel launch fails.
#[cfg(feature = "cuda")]
pub fn adamw_step_cuda(
    params: &mut GpuBuffer<f32>,
    grads: &GpuBuffer<f32>,
    m: &mut GpuBuffer<f32>,
    v: &mut GpuBuffer<f32>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    step: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // At step 0 both `beta^0` terms equal 1, so `1 / (1 - beta^step)` divides by zero and the
    // kernel would be handed infinite correction factors. Reject it before touching the GPU.
    if step == 0 {
        return Err(CudaTensorError::KernelError(
            "AdamW step must be >= 1: bias correction is undefined at step 0".to_string(),
        ));
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("adamw_step_{n}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let kernel = AdamWStepKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    // Pre-compute bias correction factors on the host. `powi` takes an `i32`; saturate instead
    // of wrapping so a very large step count cannot become a negative exponent.
    let exponent = i32::try_from(step).unwrap_or(i32::MAX);
    let bias_adjust1 = 1.0 / (1.0 - beta1.powi(exponent));
    let bias_adjust2 = 1.0 / (1.0 - beta2.powi(exponent));

    let params_ptr = params.as_ptr();
    let grads_ptr = grads.as_ptr();
    let m_ptr = m.as_ptr();
    let v_ptr = v.as_ptr();

    // Each entry points at a host-side local holding one kernel argument; the locals outlive
    // the launch call below.
    let mut args: [*mut std::ffi::c_void; 12] = [
        &params_ptr as *const _ as *mut _,
        &grads_ptr as *const _ as *mut _,
        &m_ptr as *const _ as *mut _,
        &v_ptr as *const _ as *mut _,
        &lr as *const _ as *mut _,
        &beta1 as *const _ as *mut _,
        &beta2 as *const _ as *mut _,
        &eps as *const _ as *mut _,
        &weight_decay as *const _ as *mut _,
        &bias_adjust1 as *const _ as *mut _,
        &bias_adjust2 as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream.launch_kernel(module, "adamw_step", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("AdamW step launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
247
/// Fused Adam optimizer step on GPU (no weight decay)
///
/// Same as `adamw_step_cuda` but without the decoupled weight decay term.
///
/// `step` is the current step number and must start at 1: it is only used to pre-compute the
/// bias correction factors `1 / (1 - beta^step)`, which are undefined at step 0.
///
/// # Errors
///
/// Returns `Err` if `step` is 0, if the kernel cache is not initialized or its lock cannot be
/// acquired, if kernel compilation fails, or if the kernel launch fails.
#[cfg(feature = "cuda")]
pub fn adam_step_cuda(
    params: &mut GpuBuffer<f32>,
    grads: &GpuBuffer<f32>,
    m: &mut GpuBuffer<f32>,
    v: &mut GpuBuffer<f32>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    step: u32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // `beta^0 == 1` makes the bias correction divide by zero; refuse instead of launching a
    // kernel with infinite correction factors.
    if step == 0 {
        return Err(CudaTensorError::KernelError(
            "Adam step must be >= 1: bias correction is undefined at step 0".to_string(),
        ));
    }

    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("adam_step_{n}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let kernel = AdamStepKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    // Pre-compute bias correction factors on the host. `powi` takes an `i32`; saturate instead
    // of wrapping so a very large step count cannot become a negative exponent.
    let exponent = i32::try_from(step).unwrap_or(i32::MAX);
    let bias_adjust1 = 1.0 / (1.0 - beta1.powi(exponent));
    let bias_adjust2 = 1.0 / (1.0 - beta2.powi(exponent));

    let params_ptr = params.as_ptr();
    let grads_ptr = grads.as_ptr();
    let m_ptr = m.as_ptr();
    let v_ptr = v.as_ptr();

    // Each entry points at a host-side local holding one kernel argument; the locals outlive
    // the launch call below.
    let mut args: [*mut std::ffi::c_void; 11] = [
        &params_ptr as *const _ as *mut _,
        &grads_ptr as *const _ as *mut _,
        &m_ptr as *const _ as *mut _,
        &v_ptr as *const _ as *mut _,
        &lr as *const _ as *mut _,
        &beta1 as *const _ as *mut _,
        &beta2 as *const _ as *mut _,
        &eps as *const _ as *mut _,
        &bias_adjust1 as *const _ as *mut _,
        &bias_adjust2 as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream
            .launch_kernel(module, "adam_step", &config, &mut args)
            .map_err(|e| CudaTensorError::KernelError(format!("Adam step launch failed: {e:?}")))?;
    }

    Ok(())
}
315
/// Scale a gradient buffer in place on the GPU to enforce a maximum norm.
///
/// The clipping factor is computed by the caller; this function only applies it.
///
/// # Arguments
/// - `grads`: gradient tensor (updated in-place)
/// - `scale`: clipping scale factor (pre-computed as `min(1.0, max_norm / grad_norm)`)
/// - `n`: number of gradient elements
/// - `stream`: CUDA stream
///
/// When `scale` is within `1e-7` of `1.0` the function returns immediately without touching
/// the kernel cache or launching anything.
///
/// # Usage
/// ```ignore
/// // Compute gradient norm on host
/// let grad_norm = compute_l2_norm(&grads);
/// let scale = (max_norm / grad_norm).min(1.0);
///
/// // Apply clipping on GPU
/// gradient_clip_cuda(&mut grads, scale, n, &stream)?;
/// ```
#[cfg(feature = "cuda")]
pub fn gradient_clip_cuda(
    grads: &mut GpuBuffer<f32>,
    scale: f32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Type-erase a reference to a host-side kernel argument for the launch call.
    fn arg_ptr<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    // Nothing to do when the factor is (numerically) the identity.
    let scale_is_identity = (scale - 1.0).abs() < 1e-7;
    if scale_is_identity {
        return Ok(());
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let cache_key = format!("gradient_clip_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = GradientClipKernel::new(n).emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    let launch = LaunchConfig {
        grid: (n.div_ceil(256), 1, 1),
        block: (256, 1, 1),
        shared_mem: 0,
    };

    let device_grads = grads.as_ptr();
    let mut kernel_args: [*mut std::ffi::c_void; 3] =
        [arg_ptr(&device_grads), arg_ptr(&scale), arg_ptr(&n)];

    // SAFETY: Kernel launch requires FFI. All buffers are valid GPU allocations with
    // matching sizes, and the kernel parameters match the expected PTX signature.
    unsafe {
        stream
            .launch_kernel(module, "gradient_clip", &launch, &mut kernel_args)
            .map_err(|err| {
                CudaTensorError::KernelError(format!("Gradient clip launch failed: {err:?}"))
            })?;
    }

    Ok(())
}
379
/// Blocking GPU sum-of-squares reduction that yields an L2 norm (KAIZEN-049).
///
/// Launches the reduction, synchronizes `stream`, then downloads the per-block partial sums
/// (~1KB) and finishes on the host with f64 summation and a square root.
///
/// # Contract (C-SQSUM-002)
///
/// - **Precondition**: `n > 0`, `input` has at least `n` elements
/// - **Postcondition**: returned f32 = sqrt(sum(input[i]^2)) to within O(n × eps_f32)
/// - **Transfer**: ~1KB D2H instead of n×4 bytes (128MB for 32M elements)
///
/// # Errors
///
/// Returns `Err` if kernel cache not initialized, kernel compilation fails, the stream fails
/// to synchronize, or GPU transfer fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_cuda(input: &GpuBuffer<f32>, n: u32, stream: &CudaStream) -> Result<f32> {
    let in_flight = squared_sum_launch_cuda(input, n, stream)?;

    // The partial sums are only safe to download once the launch has completed.
    if let Err(sync_err) = stream.synchronize() {
        return Err(CudaTensorError::KernelError(format!("Stream sync failed: {sync_err:?}")));
    }

    squared_sum_collect(&in_flight)
}
402
/// Handle to a squared-sum reduction that has been launched but not yet read back.
///
/// Keeps the device buffer of partial sums alive until `squared_sum_collect` downloads it.
#[cfg(feature = "cuda")]
pub struct PendingSquaredSum {
    /// Device buffer passed to the reduction kernel as its output.
    output: GpuBuffer<f32>,
    /// Number of elements in `output` (the kernel's block count at launch time).
    num_blocks: u32,
}
411
/// Enqueue a squared-sum reduction on `stream` without waiting for it (KAIZEN-055).
///
/// Allocates a fresh device buffer for the per-block partial sums and returns it wrapped in a
/// `PendingSquaredSum`. The caller MUST call `stream.synchronize()` before passing the handle
/// to `squared_sum_collect()`.
///
/// Several reductions can share one sync point:
/// ```ignore
/// let p1 = squared_sum_launch_cuda(&buf1, n1, stream)?;
/// let p2 = squared_sum_launch_cuda(&buf2, n2, stream)?;
/// stream.synchronize()?;  // single sync for both
/// let norm1 = squared_sum_collect(&p1)?;
/// let norm2 = squared_sum_collect(&p2)?;
/// ```
#[cfg(feature = "cuda")]
pub fn squared_sum_launch_cuda(
    input: &GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<PendingSquaredSum> {
    // Type-erase a reference to a host-side kernel argument for the launch call.
    fn arg_ptr<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let reducer = SquaredSumKernel::new(n);
    let block_count = reducer.num_blocks();

    // Take the context handle now: `module` below keeps `kernels` mutably borrowed.
    let device_ctx = std::sync::Arc::clone(&kernels.ctx);

    let cache_key = format!("squared_sum_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = reducer.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    // One f32 partial sum per block (num_blocks × 4 bytes, typically ≤1KB).
    let partial_sums = match GpuBuffer::<f32>::new(&device_ctx, block_count as usize) {
        Ok(buffer) => buffer,
        Err(alloc_err) => {
            return Err(CudaTensorError::KernelError(format!(
                "Failed to allocate squared_sum output: {alloc_err:?}"
            )));
        }
    };

    let launch = LaunchConfig {
        grid: (block_count, 1, 1),
        block: (reducer.block_size(), 1, 1),
        shared_mem: 8 * 4, // 8 warp partials × 4 bytes
    };

    let device_input = input.as_ptr();
    let device_output = partial_sums.as_ptr();
    let mut kernel_args: [*mut std::ffi::c_void; 3] =
        [arg_ptr(&device_input), arg_ptr(&device_output), arg_ptr(&n)];

    // SAFETY: Kernel launch requires FFI. input has n elements, output has num_blocks elements,
    // parameters match PTX signature (u64 input_ptr, u64 output_ptr, u32 n).
    unsafe {
        stream
            .launch_kernel(module, "squared_sum_reduce", &launch, &mut kernel_args)
            .map_err(|err| {
                CudaTensorError::KernelError(format!("Squared sum launch failed: {err:?}"))
            })?;
    }

    Ok(PendingSquaredSum { output: partial_sums, num_blocks: block_count })
}
481
/// Download the partial sums of a launched reduction and turn them into an L2 norm
/// (KAIZEN-055).
///
/// **Precondition**: `stream.synchronize()` must have been called after the launch.
///
/// # Errors
///
/// Returns `Err` if the device-to-host copy of the partial sums fails.
#[cfg(feature = "cuda")]
pub fn squared_sum_collect(pending: &PendingSquaredSum) -> Result<f32> {
    let block_count = pending.num_blocks as usize;
    let mut host_partials = vec![0.0f32; block_count];

    if let Err(copy_err) = pending.output.copy_to_host(&mut host_partials) {
        return Err(CudaTensorError::KernelError(format!(
            "Failed to download partial sums: {copy_err:?}"
        )));
    }

    // Accumulate in f64 for precision; the square root turns the sum of squares into a norm.
    let sum_of_squares = host_partials.iter().copied().map(f64::from).sum::<f64>();
    let l2_norm = sum_of_squares.sqrt();
    Ok(l2_norm as f32)
}
496
/// Enqueue a squared-sum reduction that writes into caller-provided device memory (ALB-078).
///
/// Unlike `squared_sum_launch_cuda`, no output buffer is allocated here: the partial sums go
/// to `output_ptr`, a raw device address to which the caller has already applied any offset
/// into its own buffer. Used by the fused clip pipeline to collect all 9 groups' partials
/// into one contiguous buffer.
///
/// # Returns
///
/// Number of blocks launched (= number of f32 partial sums at `output_ptr`).
#[cfg(feature = "cuda")]
pub fn squared_sum_launch_into(
    input: &GpuBuffer<f32>,
    n: u32,
    output_ptr: u64, // raw device pointer with offset already applied
    stream: &CudaStream,
) -> Result<u32> {
    // Type-erase a reference to a host-side kernel argument for the launch call.
    fn arg_ptr<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let reducer = SquaredSumKernel::new(n);
    let block_count = reducer.num_blocks();

    // Same cache key as `squared_sum_launch_cuda`, so both entry points share one module.
    let cache_key = format!("squared_sum_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = reducer.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    let launch = LaunchConfig {
        grid: (block_count, 1, 1),
        block: (reducer.block_size(), 1, 1),
        shared_mem: 8 * 4, // 8 warp partials × 4 bytes
    };

    let device_input = input.as_ptr();
    let mut kernel_args: [*mut std::ffi::c_void; 3] =
        [arg_ptr(&device_input), arg_ptr(&output_ptr), arg_ptr(&n)];

    // SAFETY: Kernel launch requires FFI. input has n elements, output region has num_blocks elements,
    // parameters match PTX signature (u64 input_ptr, u64 output_ptr, u32 n).
    unsafe {
        stream
            .launch_kernel(module, "squared_sum_reduce", &launch, &mut kernel_args)
            .map_err(|err| {
                CudaTensorError::KernelError(format!("Squared sum launch_into failed: {err:?}"))
            })?;
    }

    Ok(block_count)
}
554
/// Enqueue the clip-scale reduction kernel (ALB-078).
///
/// The kernel reads a contiguous buffer of squared-sum partial results, computes the global
/// L2 norm and the clip scale entirely on GPU, and writes `[scale, norm]` to `output`.
///
/// # Contract (C-FUSEDCLIP-001)
///
/// - output[0] = min(1.0, max_norm / sqrt(sum(partials[0..total_n])))
/// - output[1] = sqrt(sum(partials[0..total_n]))
#[cfg(feature = "cuda")]
pub fn clip_scale_reduce_cuda(
    partials: &GpuBuffer<f32>,
    total_n: u32,
    max_norm: f32,
    output: &GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // The kernel takes no size parameter at construction, so one cache entry serves every call.
    const MODULE_KEY: &str = "clip_scale_reduce";

    // Type-erase a reference to a host-side kernel argument for the launch call.
    fn arg_ptr<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let module = match kernels.get_cached(MODULE_KEY) {
        Some(compiled) => compiled,
        None => {
            let ptx = ClipScaleReduceKernel.emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(MODULE_KEY, &ptx)?
        }
    };

    // Single CTA, single thread — partials buffer is tiny (~1800 f32 for 350M)
    let launch = LaunchConfig { grid: (1, 1, 1), block: (1, 1, 1), shared_mem: 0 };

    let device_partials = partials.as_ptr();
    let device_output = output.as_ptr();
    let mut kernel_args: [*mut std::ffi::c_void; 4] = [
        arg_ptr(&device_partials),
        arg_ptr(&total_n),
        arg_ptr(&max_norm),
        arg_ptr(&device_output),
    ];

    // SAFETY: partials has total_n elements, output has 2 elements (scale + norm).
    unsafe {
        stream
            .launch_kernel(module, "clip_scale_reduce", &launch, &mut kernel_args)
            .map_err(|err| {
                CudaTensorError::KernelError(format!("Clip scale reduce launch failed: {err:?}"))
            })?;
    }

    Ok(())
}
609
/// Clip gradients in place using a scale that already lives in GPU memory (ALB-078).
///
/// Counterpart of `gradient_clip_cuda` that takes a device pointer to the scale instead of a
/// host f32, which avoids a D2H transfer of the clip scale.
///
/// The kernel exits early if scale ≈ 1.0, avoiding unnecessary write bandwidth.
#[cfg(feature = "cuda")]
pub fn gradient_clip_gpu_scale_cuda(
    grads: &mut GpuBuffer<f32>,
    scale_ptr: u64, // device pointer to f32 scale value
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Type-erase a reference to a host-side kernel argument for the launch call.
    fn arg_ptr<T>(value: &T) -> *mut std::ffi::c_void {
        value as *const T as *mut std::ffi::c_void
    }

    let cache_cell = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut kernels = cache_cell.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let cache_key = format!("gradient_clip_gpu_scale_{n}");
    let module = match kernels.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = GradientClipGpuScaleKernel::new(n).emit_ptx_for_target(kernels.sm_target());
            kernels.get_or_compile(&cache_key, &ptx)?
        }
    };

    let launch = LaunchConfig {
        grid: (n.div_ceil(256), 1, 1),
        block: (256, 1, 1),
        shared_mem: 0,
    };

    let device_grads = grads.as_ptr();
    let mut kernel_args: [*mut std::ffi::c_void; 3] =
        [arg_ptr(&device_grads), arg_ptr(&scale_ptr), arg_ptr(&n)];

    // SAFETY: grads has n elements, scale_ptr points to valid f32 on GPU.
    unsafe {
        stream
            .launch_kernel(module, "gradient_clip_gpu_scale", &launch, &mut kernel_args)
            .map_err(|err| {
                CudaTensorError::KernelError(format!(
                    "Gradient clip GPU scale launch failed: {err:?}"
                ))
            })?;
    }

    Ok(())
}
661
/// Reusable GPU buffers and layout for the fused gradient clipping pipeline (ALB-078).
///
/// Holds the contiguous partial-sum collection buffer and the clip scale output.
/// Initialized once at trainer creation, reused every step.
#[cfg(feature = "cuda")]
pub struct FusedClipState {
    /// Contiguous buffer for all squared-sum partial results across 9 gradient groups
    pub partials_buf: GpuBuffer<f32>,
    /// Output buffer: [clip_scale, grad_norm] (2 × f32)
    pub scale_buf: GpuBuffer<f32>,
    /// Start of each group's partial sums inside `partials_buf`, counted in f32 elements
    /// (the running total of the preceding groups' block counts), not in bytes
    pub offsets: [u32; 9],
    /// Number of partial-sum blocks for each group
    pub num_blocks: [u32; 9],
    /// Total number of partial sums across all groups
    pub total_partials: u32,
}
679
#[cfg(feature = "cuda")]
impl FusedClipState {
    /// Lay out and allocate the fused clip buffers for nine gradient groups.
    ///
    /// Each group's block count comes from `SquaredSumKernel::new(n).num_blocks()`; groups are
    /// packed back to back, so a group's offset is the sum of the block counts before it.
    ///
    /// # Arguments
    ///
    /// * `ctx` - CUDA context
    /// * `grad_sizes` - sizes (in elements) of the 9 gradient buffers:
    ///   [w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm]
    ///
    /// # Errors
    ///
    /// Returns `Err` if either device buffer cannot be allocated.
    pub fn new(ctx: &std::sync::Arc<CudaContext>, grad_sizes: &[u32; 9]) -> Result<Self> {
        let mut group_offsets = [0u32; 9];
        let mut group_blocks = [0u32; 9];
        let mut running_total = 0u32;

        let slots = group_offsets.iter_mut().zip(group_blocks.iter_mut());
        for ((offset_slot, blocks_slot), &elements) in slots.zip(grad_sizes.iter()) {
            let blocks = SquaredSumKernel::new(elements).num_blocks();
            *offset_slot = running_total;
            *blocks_slot = blocks;
            running_total += blocks;
        }

        let partials_buf = match GpuBuffer::<f32>::new(ctx, running_total as usize) {
            Ok(buffer) => buffer,
            Err(alloc_err) => {
                return Err(CudaTensorError::KernelError(format!(
                    "Failed to allocate partials buffer: {alloc_err:?}"
                )));
            }
        };

        // Two slots: clip scale followed by gradient norm.
        let scale_buf = match GpuBuffer::<f32>::new(ctx, 2) {
            Ok(buffer) => buffer,
            Err(alloc_err) => {
                return Err(CudaTensorError::KernelError(format!(
                    "Failed to allocate scale buffer: {alloc_err:?}"
                )));
            }
        };

        Ok(Self {
            partials_buf,
            scale_buf,
            offsets: group_offsets,
            num_blocks: group_blocks,
            total_partials: running_total,
        })
    }
}
719
/// Fused GPU cross-entropy loss + softmax backward, in-place (KAIZEN-050 + KAIZEN-052).
///
/// Computes cross-entropy loss and writes gradient **in-place** to the logits buffer,
/// eliminating both the logits D2H (77.8MB) + CPU softmax (40ms) + gradient H2D (77.8MB)
/// AND the separate gradient buffer allocation (77.8MB for Qwen3-4B).
///
/// Only the first `seq_len` entries of `target_ids` are uploaded; extra entries are ignored.
///
/// # Returns
///
/// Scalar loss (averaged over seq_len). Gradient is written in-place to `logits_buf`.
///
/// # Errors
///
/// Returns `Err` if `target_ids` has fewer than `seq_len` entries, if the kernel cache is not
/// initialized or its lock cannot be acquired, if kernel compilation fails, or if a GPU
/// allocation, upload, launch, synchronization, or download fails.
///
/// # Contract (C-XENT-002, updated KAIZEN-052)
///
/// - **Precondition**: `logits_buf` has `seq_len * vocab_size` elements, targets in `[0, vocab_size)`
/// - **Postcondition**: `logits_buf[i] = (softmax - one_hot) * scale` (gradient, in-place)
/// - **Postcondition**: `loss = mean(-log(softmax[target]))`
/// - **Transfer**: H2D targets (seq_len×4 bytes) + D2H loss_partials (seq_len×4 bytes)
/// - **Allocation**: 0 bytes grad buffer (was 77.8MB before KAIZEN-052)
#[cfg(feature = "cuda")]
pub fn fused_cross_entropy_cuda(
    logits_buf: &mut GpuBuffer<f32>,
    target_ids: &[u32],
    seq_len: u32,
    vocab_size: u32,
    scale: f32,
    stream: &CudaStream,
) -> Result<f32> {
    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = FusedCrossEntropyKernel::new(vocab_size);

    // Clone ctx before mutable borrow via get_or_compile/get_cached
    let ctx = std::sync::Arc::clone(&cache.ctx);

    let key = format!("fused_xent_{vocab_size}");
    let module = match cache.get_cached(&key) {
        Some(cached) => cached,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // A short `target_ids` is a caller error: report it instead of panicking on the slice.
    let targets = target_ids.get(..seq_len as usize).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "target_ids has {} entries but seq_len is {seq_len}",
            target_ids.len()
        ))
    })?;

    // Upload targets to GPU (seq_len × u32 = ~512 bytes for seq_len=128), straight from the
    // caller's slice — no intermediate host copy is needed.
    let targets_gpu = GpuBuffer::<u32>::from_host(&ctx, targets)
        .map_err(|e| CudaTensorError::KernelError(format!("Failed to upload targets: {e:?}")))?;

    // KAIZEN-052: No grad_gpu allocation — gradient written in-place to logits_buf.

    // Allocate loss partials buffer (seq_len × f32 — downloaded for scalar average)
    let loss_gpu = GpuBuffer::<f32>::new(&ctx, seq_len as usize).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to allocate loss partials: {e:?}"))
    })?;

    // Shared memory: 72 bytes (8 warp maxes + global max + 8 warp sums + global sum)
    let config =
        LaunchConfig { grid: (seq_len, 1, 1), block: (kernel.block_size(), 1, 1), shared_mem: 72 };

    let logits_grad_ptr = logits_buf.as_ptr();
    let targets_ptr = targets_gpu.as_ptr();
    let loss_ptr = loss_gpu.as_ptr();

    let mut args: [*mut std::ffi::c_void; 5] = [
        &logits_grad_ptr as *const _ as *mut _,
        &targets_ptr as *const _ as *mut _,
        &loss_ptr as *const _ as *mut _,
        &vocab_size as *const _ as *mut _,
        &scale as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. logits_buf has seq_len*vocab_size elements
    // (read as logits, overwritten with gradients in-place). targets_gpu has seq_len u32
    // elements, loss_gpu has seq_len f32 elements. Parameters match PTX signature
    // (u64 logits_grad_ptr, u64 targets_ptr, u64 loss_ptr, u32 vocab_size, f32 scale).
    unsafe {
        stream.launch_kernel(module, "fused_cross_entropy", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Fused cross-entropy launch failed: {e:?}"))
        })?;
    }

    // Synchronize and download loss partials (~512 bytes)
    stream
        .synchronize()
        .map_err(|e| CudaTensorError::KernelError(format!("Stream sync failed: {e:?}")))?;

    let mut loss_partials = vec![0.0f32; seq_len as usize];
    loss_gpu.copy_to_host(&mut loss_partials).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to download loss partials: {e:?}"))
    })?;

    // Average loss across sequence positions (f64 for precision)
    let total_loss: f64 = loss_partials.iter().map(|&x| f64::from(x)).sum();
    let avg_loss = (total_loss / f64::from(seq_len)) as f32;

    Ok(avg_loss)
}
819
/// Fused GPU causal cross-entropy loss + softmax backward, in-place (KAIZEN-064).
///
/// Like `fused_cross_entropy_cuda` but with causal LM masking: only positions
/// `[loss_start, loss_end)` contribute to loss and gradient. Positions outside
/// this range get zero gradient. Eliminates:
/// - Logits D2H download (~296MB for Qwen3-4B)
/// - CPU softmax computation
/// - Gradient H2D upload (~296MB for Qwen3-4B)
///
/// # Contract (C-CAUSALXENT-001)
///
/// - **Precondition**: logits_buf has `seq_len * vocab_size` valid elements
/// - **Precondition**: `loss_start < loss_end <= seq_len`
/// - **Precondition**: `target_ids` holds at least `seq_len` entries (only the
///   first `seq_len` are uploaded)
/// - **Postcondition**: logits_buf overwritten with gradient (in-place)
/// - **Postcondition**: Returns average loss over `[loss_start, loss_end)` only
///
/// An empty or inverted range (`loss_start >= loss_end`) still launches the
/// kernel and then returns `Ok(0.0)`.
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the optimizer kernel cache has
///   not been initialized.
/// - `CudaTensorError::KernelError` if the cache lock is poisoned, if
///   `loss_end > seq_len`, if `target_ids` is shorter than `seq_len`, or if
///   compilation, allocation, upload, launch, synchronization or download fails.
#[cfg(feature = "cuda")]
pub fn fused_causal_cross_entropy_cuda(
    logits_buf: &mut GpuBuffer<f32>,
    target_ids: &[u32],
    seq_len: u32,
    vocab_size: u32,
    loss_start: u32,
    loss_end: u32,
    scale: f32,
    stream: &CudaStream,
) -> Result<f32> {
    let cache = OPTIM_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let num_positions = seq_len as usize;

    // Reject arguments that would otherwise panic on a host-side slice, and do
    // so before any GPU work is issued.
    if loss_end > seq_len {
        return Err(CudaTensorError::KernelError(format!(
            "Invalid loss range [{loss_start}, {loss_end}) for seq_len {seq_len}"
        )));
    }
    let targets = target_ids.get(..num_positions).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "target_ids has {} elements but seq_len is {seq_len}",
            target_ids.len()
        ))
    })?;

    let kernel = FusedCausalCrossEntropyKernel::new(vocab_size);

    let ctx = std::sync::Arc::clone(&cache.ctx);

    // Compiled modules are cached per vocabulary size.
    let key = format!("fused_causal_xent_{vocab_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Upload targets to GPU (seq_len × u32 = ~2KB for seq_len=512); the borrowed
    // prefix is uploaded directly, no intermediate host copy is needed.
    let targets_gpu = GpuBuffer::<u32>::from_host(&ctx, targets)
        .map_err(|e| CudaTensorError::KernelError(format!("Failed to upload targets: {e:?}")))?;

    // Allocate loss partials buffer (seq_len × f32)
    let loss_gpu = GpuBuffer::<f32>::new(&ctx, num_positions).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to allocate loss partials: {e:?}"))
    })?;

    // One block per sequence position. Shared memory: 72 bytes (same as base kernel)
    let config =
        LaunchConfig { grid: (seq_len, 1, 1), block: (kernel.block_size(), 1, 1), shared_mem: 72 };

    let logits_grad_ptr = logits_buf.as_ptr();
    let targets_ptr = targets_gpu.as_ptr();
    let loss_ptr = loss_gpu.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values.
    let mut args: [*mut std::ffi::c_void; 7] = [
        &logits_grad_ptr as *const _ as *mut _,
        &targets_ptr as *const _ as *mut _,
        &loss_ptr as *const _ as *mut _,
        &vocab_size as *const _ as *mut _,
        &scale as *const _ as *mut _,
        &loss_start as *const _ as *mut _,
        &loss_end as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. logits_buf has seq_len*vocab_size elements
    // (read as logits, overwritten with gradients in-place). targets_gpu has seq_len u32
    // elements. loss_gpu has seq_len f32 elements. Parameters match PTX signature.
    // Positions outside [loss_start, loss_end) get zero gradient and zero loss.
    unsafe {
        stream.launch_kernel(module, "fused_causal_cross_entropy", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused causal cross-entropy launch failed: {e:?}"
                ))
            },
        )?;
    }

    // Synchronize and download loss partials (~2KB)
    stream
        .synchronize()
        .map_err(|e| CudaTensorError::KernelError(format!("Stream sync failed: {e:?}")))?;

    let mut loss_partials = vec![0.0f32; num_positions];
    loss_gpu.copy_to_host(&mut loss_partials).map_err(|e| {
        CudaTensorError::KernelError(format!("Failed to download loss partials: {e:?}"))
    })?;

    // Average loss across loss positions only (f64 for precision)
    let num_loss_tokens = loss_end.saturating_sub(loss_start) as usize;
    if num_loss_tokens == 0 {
        return Ok(0.0);
    }
    // In bounds: loss_start < loss_end here, and loss_end <= seq_len was checked above.
    let total_loss: f64 =
        loss_partials[loss_start as usize..loss_end as usize].iter().map(|&x| f64::from(x)).sum();
    let avg_loss = (total_loss / num_loss_tokens as f64) as f32;

    Ok(avg_loss)
}
927
928#[cfg(test)]
929mod tests {
930    use super::*;
931
932    #[test]
933    fn test_cuda_optim_module_compiles() {
934        // This test verifies the module compiles correctly
935        // Actual CUDA tests require GPU hardware
936        assert!(true);
937    }
938
939    #[test]
940    #[cfg(feature = "cuda")]
941    fn test_optim_kernel_cache_initialization() {
942        use trueno_gpu::driver::cuda_available;
943
944        if !cuda_available() {
945            return;
946        }
947
948        let ctx = CudaContext::new(0).expect("operation should succeed");
949        let ctx = std::sync::Arc::new(ctx);
950        let result = init_optim_kernel_cache(ctx);
951        assert!(result.is_ok());
952    }
953
954    /// Create a fresh GPU context for a test
955    /// Note: Using fresh contexts per-test avoids CUDA driver state issues
956    /// when running multiple tests sequentially
957    #[cfg(feature = "cuda")]
958    fn get_test_gpu_context() -> Option<std::sync::Arc<CudaContext>> {
959        use trueno_gpu::driver::cuda_available;
960
961        if cuda_available() {
962            CudaContext::new(0).ok().map(std::sync::Arc::new)
963        } else {
964            None
965        }
966    }
967
968    /// CPU reference implementation for AdamW step
969    fn adamw_step_cpu(
970        params: &mut [f32],
971        grads: &[f32],
972        m: &mut [f32],
973        v: &mut [f32],
974        lr: f32,
975        beta1: f32,
976        beta2: f32,
977        eps: f32,
978        weight_decay: f32,
979        step: u32,
980    ) {
981        let bias_adjust1 = 1.0 / (1.0 - beta1.powi(step as i32));
982        let bias_adjust2 = 1.0 / (1.0 - beta2.powi(step as i32));
983
984        for i in 0..params.len() {
985            // Update biased first moment estimate
986            m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
987            // Update biased second moment estimate
988            v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
989
990            // Compute bias-corrected estimates
991            let m_hat = m[i] * bias_adjust1;
992            let v_hat = v[i] * bias_adjust2;
993
994            // AdamW update: weight decay is applied directly to params
995            params[i] = params[i] * (1.0 - lr * weight_decay) - lr * m_hat / (v_hat.sqrt() + eps);
996        }
997    }
998
999    /// CPU reference implementation for Adam step (no weight decay)
1000    fn adam_step_cpu(
1001        params: &mut [f32],
1002        grads: &[f32],
1003        m: &mut [f32],
1004        v: &mut [f32],
1005        lr: f32,
1006        beta1: f32,
1007        beta2: f32,
1008        eps: f32,
1009        step: u32,
1010    ) {
1011        let bias_adjust1 = 1.0 / (1.0 - beta1.powi(step as i32));
1012        let bias_adjust2 = 1.0 / (1.0 - beta2.powi(step as i32));
1013
1014        for i in 0..params.len() {
1015            // Update biased first moment estimate
1016            m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
1017            // Update biased second moment estimate
1018            v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
1019
1020            // Compute bias-corrected estimates
1021            let m_hat = m[i] * bias_adjust1;
1022            let v_hat = v[i] * bias_adjust2;
1023
1024            // Adam update (no weight decay)
1025            params[i] -= lr * m_hat / (v_hat.sqrt() + eps);
1026        }
1027    }
1028
1029    /// CPU reference implementation for gradient clipping
1030    fn gradient_clip_cpu(grads: &mut [f32], scale: f32) {
1031        for g in grads.iter_mut() {
1032            *g *= scale;
1033        }
1034    }
1035
1036    #[test]
1037    #[cfg(feature = "cuda")]
1038    fn test_adamw_step_basic() {
1039        let ctx = match get_test_gpu_context() {
1040            Some(c) => c,
1041            None => return,
1042        };
1043        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1044        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1045
1046        let n = 4u32;
1047        let lr = 0.001f32;
1048        let beta1 = 0.9f32;
1049        let beta2 = 0.999f32;
1050        let eps = 1e-8f32;
1051        let weight_decay = 0.01f32;
1052        let step = 1u32;
1053
1054        // Initial values
1055        let mut params_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1056        let grads_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
1057        let mut m_data: Vec<f32> = vec![0.0; n as usize];
1058        let mut v_data: Vec<f32> = vec![0.0; n as usize];
1059
1060        // CPU reference
1061        let mut cpu_params = params_data.clone();
1062        let mut cpu_m = m_data.clone();
1063        let mut cpu_v = v_data.clone();
1064        adamw_step_cpu(
1065            &mut cpu_params,
1066            &grads_data,
1067            &mut cpu_m,
1068            &mut cpu_v,
1069            lr,
1070            beta1,
1071            beta2,
1072            eps,
1073            weight_decay,
1074            step,
1075        );
1076
1077        // GPU execution
1078        let mut params =
1079            GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
1080        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1081        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1082        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1083
1084        adamw_step_cuda(
1085            &mut params,
1086            &grads,
1087            &mut m,
1088            &mut v,
1089            lr,
1090            beta1,
1091            beta2,
1092            eps,
1093            weight_decay,
1094            step,
1095            n,
1096            &stream,
1097        )
1098        .expect("operation should succeed");
1099        stream.synchronize().expect("operation should succeed");
1100
1101        params.copy_to_host(&mut params_data).expect("operation should succeed");
1102        m.copy_to_host(&mut m_data).expect("operation should succeed");
1103        v.copy_to_host(&mut v_data).expect("operation should succeed");
1104
1105        // Compare GPU vs CPU results
1106        for i in 0..n as usize {
1107            assert!(
1108                (params_data[i] - cpu_params[i]).abs() < 1e-4,
1109                "AdamW params mismatch at {i}: GPU={}, CPU={}",
1110                params_data[i],
1111                cpu_params[i]
1112            );
1113            assert!(
1114                (m_data[i] - cpu_m[i]).abs() < 1e-5,
1115                "AdamW m mismatch at {i}: GPU={}, CPU={}",
1116                m_data[i],
1117                cpu_m[i]
1118            );
1119            assert!(
1120                (v_data[i] - cpu_v[i]).abs() < 1e-5,
1121                "AdamW v mismatch at {i}: GPU={}, CPU={}",
1122                v_data[i],
1123                cpu_v[i]
1124            );
1125        }
1126    }
1127
1128    #[test]
1129    #[cfg(feature = "cuda")]
1130    fn test_adamw_step_not_hardcoded() {
1131        // Mutation-killing test: verify params actually change
1132        let ctx = match get_test_gpu_context() {
1133            Some(c) => c,
1134            None => return,
1135        };
1136        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1137        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1138
1139        let n = 4u32;
1140        let initial_params: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1141        let grads_data: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5]; // Non-zero gradients
1142        let m_data: Vec<f32> = vec![0.0; n as usize];
1143        let v_data: Vec<f32> = vec![0.0; n as usize];
1144
1145        let mut params =
1146            GpuBuffer::from_host(&ctx, &initial_params).expect("operation should succeed");
1147        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1148        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1149        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1150
1151        adamw_step_cuda(
1152            &mut params,
1153            &grads,
1154            &mut m,
1155            &mut v,
1156            0.01, // Larger LR to see effect
1157            0.9,
1158            0.999,
1159            1e-8,
1160            0.01,
1161            1,
1162            n,
1163            &stream,
1164        )
1165        .expect("operation should succeed");
1166        stream.synchronize().expect("operation should succeed");
1167
1168        let mut result_params = vec![0.0f32; n as usize];
1169        params.copy_to_host(&mut result_params).expect("operation should succeed");
1170
1171        // Kill mutant: params should have changed
1172        assert_ne!(result_params, initial_params, "mutant: AdamW params unchanged after step");
1173        // Verify params decreased (negative gradient update)
1174        for (i, (&new, &old)) in result_params.iter().zip(initial_params.iter()).enumerate() {
1175            assert!(new < old, "AdamW params[{i}] should decrease with positive gradients");
1176        }
1177    }
1178
1179    #[test]
1180    #[cfg(feature = "cuda")]
1181    fn test_adamw_weight_decay() {
1182        // Test that weight decay is actually applied
1183        let ctx = match get_test_gpu_context() {
1184            Some(c) => c,
1185            None => return,
1186        };
1187        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1188        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1189
1190        let n = 4u32;
1191        let params_data: Vec<f32> = vec![10.0, 10.0, 10.0, 10.0]; // Large weights
1192        let grads_data: Vec<f32> = vec![0.0, 0.0, 0.0, 0.0]; // Zero gradients
1193        let m_data: Vec<f32> = vec![0.0; n as usize];
1194        let v_data: Vec<f32> = vec![0.0; n as usize];
1195
1196        let mut params =
1197            GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
1198        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1199        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1200        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1201
1202        // With zero gradients, only weight decay should affect params
1203        adamw_step_cuda(
1204            &mut params,
1205            &grads,
1206            &mut m,
1207            &mut v,
1208            0.01, // LR
1209            0.9,
1210            0.999,
1211            1e-8,
1212            0.1, // High weight decay
1213            1,
1214            n,
1215            &stream,
1216        )
1217        .expect("operation should succeed");
1218        stream.synchronize().expect("operation should succeed");
1219
1220        let mut result = vec![0.0f32; n as usize];
1221        params.copy_to_host(&mut result).expect("operation should succeed");
1222
1223        // With zero gradients, params should decay: p = p * (1 - lr * wd)
1224        let expected = 10.0 * (1.0 - 0.01 * 0.1);
1225        for (i, &p) in result.iter().enumerate() {
1226            assert!(
1227                (p - expected).abs() < 1e-3,
1228                "Weight decay not applied correctly at {i}: got {p}, expected {expected}"
1229            );
1230        }
1231    }
1232
1233    #[test]
1234    #[cfg(feature = "cuda")]
1235    fn test_adam_step_basic() {
1236        let ctx = match get_test_gpu_context() {
1237            Some(c) => c,
1238            None => return,
1239        };
1240        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1241        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1242
1243        let n = 4u32;
1244        let lr = 0.001f32;
1245        let beta1 = 0.9f32;
1246        let beta2 = 0.999f32;
1247        let eps = 1e-8f32;
1248        let step = 1u32;
1249
1250        // Initial values
1251        let mut params_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1252        let grads_data: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];
1253        let mut m_data: Vec<f32> = vec![0.0; n as usize];
1254        let mut v_data: Vec<f32> = vec![0.0; n as usize];
1255
1256        // CPU reference
1257        let mut cpu_params = params_data.clone();
1258        let mut cpu_m = m_data.clone();
1259        let mut cpu_v = v_data.clone();
1260        adam_step_cpu(
1261            &mut cpu_params,
1262            &grads_data,
1263            &mut cpu_m,
1264            &mut cpu_v,
1265            lr,
1266            beta1,
1267            beta2,
1268            eps,
1269            step,
1270        );
1271
1272        // GPU execution
1273        let mut params =
1274            GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
1275        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1276        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1277        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1278
1279        adam_step_cuda(
1280            &mut params,
1281            &grads,
1282            &mut m,
1283            &mut v,
1284            lr,
1285            beta1,
1286            beta2,
1287            eps,
1288            step,
1289            n,
1290            &stream,
1291        )
1292        .expect("operation should succeed");
1293        stream.synchronize().expect("operation should succeed");
1294
1295        params.copy_to_host(&mut params_data).expect("operation should succeed");
1296        m.copy_to_host(&mut m_data).expect("operation should succeed");
1297        v.copy_to_host(&mut v_data).expect("operation should succeed");
1298
1299        // Compare GPU vs CPU results
1300        for i in 0..n as usize {
1301            assert!(
1302                (params_data[i] - cpu_params[i]).abs() < 1e-4,
1303                "Adam params mismatch at {i}: GPU={}, CPU={}",
1304                params_data[i],
1305                cpu_params[i]
1306            );
1307        }
1308    }
1309
1310    #[test]
1311    #[cfg(feature = "cuda")]
1312    fn test_adam_step_multiple_iterations() {
1313        let ctx = match get_test_gpu_context() {
1314            Some(c) => c,
1315            None => return,
1316        };
1317        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1318        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1319
1320        let n = 4u32;
1321        let lr = 0.01f32;
1322        let beta1 = 0.9f32;
1323        let beta2 = 0.999f32;
1324        let eps = 1e-8f32;
1325
1326        let mut params_data: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0];
1327        let grads_data: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5];
1328        let m_data: Vec<f32> = vec![0.0; n as usize];
1329        let v_data: Vec<f32> = vec![0.0; n as usize];
1330
1331        let mut params =
1332            GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
1333        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1334        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1335        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1336
1337        // Run 10 steps
1338        for step in 1..=10 {
1339            adam_step_cuda(
1340                &mut params,
1341                &grads,
1342                &mut m,
1343                &mut v,
1344                lr,
1345                beta1,
1346                beta2,
1347                eps,
1348                step,
1349                n,
1350                &stream,
1351            )
1352            .expect("operation should succeed");
1353        }
1354        stream.synchronize().expect("operation should succeed");
1355
1356        params.copy_to_host(&mut params_data).expect("operation should succeed");
1357
1358        // Params should have decreased significantly after 10 steps
1359        for &p in &params_data {
1360            assert!(p < 1.0, "Params should decrease after multiple Adam steps");
1361            assert!(p > 0.0, "Params should remain positive");
1362        }
1363    }
1364
1365    #[test]
1366    #[cfg(feature = "cuda")]
1367    fn test_gradient_clip_basic() {
1368        let ctx = match get_test_gpu_context() {
1369            Some(c) => c,
1370            None => return,
1371        };
1372        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1373        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1374
1375        let n = 4u32;
1376        let grads_data: Vec<f32> = vec![2.0, 4.0, 6.0, 8.0];
1377        let scale = 0.5f32; // Scale down by half
1378
1379        let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1380
1381        gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1382        stream.synchronize().expect("operation should succeed");
1383
1384        let mut result = vec![0.0f32; n as usize];
1385        grads.copy_to_host(&mut result).expect("operation should succeed");
1386
1387        // CPU reference
1388        let mut expected = grads_data.clone();
1389        gradient_clip_cpu(&mut expected, scale);
1390
1391        for (i, (&got, &exp)) in result.iter().zip(expected.iter()).enumerate() {
1392            assert!(
1393                (got - exp).abs() < 1e-5,
1394                "Gradient clip mismatch at {i}: got {got}, expected {exp}"
1395            );
1396        }
1397    }
1398
1399    #[test]
1400    #[cfg(feature = "cuda")]
1401    fn test_gradient_clip_no_op() {
1402        // Test that scale=1.0 is a no-op
1403        let ctx = match get_test_gpu_context() {
1404            Some(c) => c,
1405            None => return,
1406        };
1407        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1408        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1409
1410        let n = 4u32;
1411        let grads_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
1412        let scale = 1.0f32; // No scaling
1413
1414        let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1415
1416        // This should be a no-op (kernel not even launched)
1417        gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1418        stream.synchronize().expect("operation should succeed");
1419
1420        let mut result = vec![0.0f32; n as usize];
1421        grads.copy_to_host(&mut result).expect("operation should succeed");
1422
1423        // Gradients should be unchanged
1424        for (i, (&got, &exp)) in result.iter().zip(grads_data.iter()).enumerate() {
1425            assert!(
1426                (got - exp).abs() < 1e-6,
1427                "Gradient clip with scale=1 should not modify values at {i}"
1428            );
1429        }
1430    }
1431
1432    #[test]
1433    #[cfg(feature = "cuda")]
1434    fn test_gradient_clip_not_hardcoded() {
1435        // Mutation-killing test
1436        let ctx = match get_test_gpu_context() {
1437            Some(c) => c,
1438            None => return,
1439        };
1440        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1441        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1442
1443        let n = 4u32;
1444        let grads_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
1445        let scale = 0.1f32;
1446
1447        let mut grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1448
1449        gradient_clip_cuda(&mut grads, scale, n, &stream).expect("operation should succeed");
1450        stream.synchronize().expect("operation should succeed");
1451
1452        let mut result = vec![0.0f32; n as usize];
1453        grads.copy_to_host(&mut result).expect("operation should succeed");
1454
1455        // Kill mutant: result should NOT equal original
1456        assert_ne!(result, grads_data, "mutant: gradient clip had no effect");
1457
1458        // Verify scaled values
1459        assert!((result[0] - 1.0).abs() < 1e-5);
1460        assert!((result[1] - 2.0).abs() < 1e-5);
1461        assert!((result[2] - 3.0).abs() < 1e-5);
1462        assert!((result[3] - 4.0).abs() < 1e-5);
1463    }
1464
1465    #[test]
1466    #[cfg(feature = "cuda")]
1467    fn test_optimizer_large_scale() {
1468        // Test with larger parameter count
1469        let ctx = match get_test_gpu_context() {
1470            Some(c) => c,
1471            None => return,
1472        };
1473        init_optim_kernel_cache(ctx.clone()).expect("operation should succeed");
1474        let stream = CudaStream::new(&ctx).expect("operation should succeed");
1475
1476        let n = 1024u32;
1477        let params_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
1478        let grads_data: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.01).sin()).collect();
1479        let m_data: Vec<f32> = vec![0.0; n as usize];
1480        let v_data: Vec<f32> = vec![0.0; n as usize];
1481
1482        let mut params =
1483            GpuBuffer::from_host(&ctx, &params_data).expect("operation should succeed");
1484        let grads = GpuBuffer::from_host(&ctx, &grads_data).expect("operation should succeed");
1485        let mut m = GpuBuffer::from_host(&ctx, &m_data).expect("operation should succeed");
1486        let mut v = GpuBuffer::from_host(&ctx, &v_data).expect("operation should succeed");
1487
1488        adamw_step_cuda(
1489            &mut params,
1490            &grads,
1491            &mut m,
1492            &mut v,
1493            0.001,
1494            0.9,
1495            0.999,
1496            1e-8,
1497            0.01,
1498            1,
1499            n,
1500            &stream,
1501        )
1502        .expect("operation should succeed");
1503        stream.synchronize().expect("operation should succeed");
1504
1505        let mut result = vec![0.0f32; n as usize];
1506        params.copy_to_host(&mut result).expect("operation should succeed");
1507
1508        // Verify no NaN or Inf
1509        assert!(
1510            !result.iter().any(|x| x.is_nan() || x.is_infinite()),
1511            "Large-scale optimizer should not produce NaN/Inf"
1512        );
1513    }
1514}