Skip to main content

entrenar/autograd/
cuda_training.rs

1//! CUDA-accelerated training utilities
2//!
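//! This module provides high-level training primitives backed by CUDA kernels
//! when the `cuda` feature is enabled. In builds without it, `CudaTrainer::new`
//! returns a `CudaNotAvailable` error so callers can fall back to a CPU path.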
5//!
6//! # Architecture (SPEC-FT-001 v3.2.0)
7//!
8//! ```text
9//! CudaTrainer
10//!   ├── device: CudaDevice
11//!   ├── forward: gemm_forward kernel
12//!   ├── backward: gemm_backward_a/b kernels
13//!   └── optimizer: adamw_step_cuda kernel
14//! ```
15//!
16//! # Example
17//!
18//! ```ignore
19//! use entrenar::autograd::cuda_training::CudaTrainer;
20//!
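//! let mut trainer = CudaTrainer::new()?;
//! let mut logits = trainer.zeros((m * n) as usize)?;
//! trainer.matmul_forward(&hidden, &weights, &mut logits, m, k, n)?;
//! trainer.adamw_step(
//!     &mut weights, &grads, &mut m_state, &mut v_state,
//!     lr, beta1, beta2, eps, weight_decay,
//! )?;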
24//! ```
25
26#[cfg(feature = "cuda")]
27use std::sync::Arc;
28
29#[cfg(feature = "cuda")]
30use trueno_gpu::driver::{cuda_available, CudaContext, CudaStream, GpuBuffer};
31
32use super::cuda_tensor::{CudaTensorError, Result};
33#[cfg(feature = "cuda")]
34use provable_contracts_macros::requires;
35
36#[cfg(feature = "cuda")]
37use super::cuda_backward::{gemm_backward_a, gemm_backward_b, init_kernel_cache};
38#[cfg(feature = "cuda")]
39use super::cuda_forward::{gemm_forward, init_forward_kernel_cache};
40#[cfg(feature = "cuda")]
41use super::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, init_optim_kernel_cache};
42
/// CUDA-accelerated training context
///
/// Manages GPU resources and provides high-level training operations.
/// Holds the CUDA context handle, the stream handed to every kernel wrapper
/// on this type, and the AdamW step counter.
#[cfg(feature = "cuda")]
pub struct CudaTrainer {
    /// Shared handle to the CUDA context; clones of it are passed to the
    /// kernel caches when the trainer is created.
    // NOTE(review): Rust drops fields in declaration order, so `ctx` is released
    // before `stream`. Confirm the stream does not need to be destroyed while
    // the context is still alive (or that another handle keeps it alive).
    ctx: Arc<CudaContext>,
    /// Stream passed to the GEMM, optimizer and gradient-clipping kernels.
    stream: CudaStream,
    /// Number of `adamw_step` calls since creation or the last `reset_step`.
    step: u32,
}
52
53#[cfg(feature = "cuda")]
54impl CudaTrainer {
55    /// Create a new CUDA trainer on the default GPU
56    pub fn new() -> Result<Self> {
57        Self::with_device(0)
58    }
59
60    /// Create a new CUDA trainer on the specified GPU
61    pub fn with_device(device_id: i32) -> Result<Self> {
62        if !cuda_available() {
63            return Err(CudaTensorError::CudaNotAvailable("No CUDA driver found".into()));
64        }
65
66        let ctx = Arc::new(
67            CudaContext::new(device_id)
68                .map_err(|e| CudaTensorError::CudaNotAvailable(format!("{e:?}")))?,
69        );
70        let stream = CudaStream::new(&ctx)
71            .map_err(|e| CudaTensorError::AllocationFailed(format!("{e:?}")))?;
72
73        // Initialize all kernel caches
74        init_forward_kernel_cache(ctx.clone())?;
75        init_kernel_cache(ctx.clone())?;
76        init_optim_kernel_cache(ctx.clone())?;
77
78        Ok(Self { ctx, stream, step: 0 })
79    }
80
81    /// Get the CUDA context
82    pub fn context(&self) -> &Arc<CudaContext> {
83        &self.ctx
84    }
85
86    /// Get the CUDA stream
87    pub fn stream(&self) -> &CudaStream {
88        &self.stream
89    }
90
91    /// Synchronize the stream (wait for all operations to complete)
92    pub fn synchronize(&self) -> Result<()> {
93        self.stream.synchronize().map_err(|e| CudaTensorError::KernelError(format!("{e:?}")))
94    }
95
96    /// Allocate a GPU buffer from host data
97    pub fn upload(&self, data: &[f32]) -> Result<GpuBuffer<f32>> {
98        let mut buf = GpuBuffer::from_host(&self.ctx, data)
99            .map_err(|e| CudaTensorError::AllocationFailed(format!("{e:?}")))?;
100        // PMAT-420: Set context for thread-safe transfers
101        buf.set_context(&self.ctx);
102        Ok(buf)
103    }
104
105    /// Allocate a zero-initialized GPU buffer
106    pub fn zeros(&self, len: usize) -> Result<GpuBuffer<f32>> {
107        let data = vec![0.0f32; len];
108        self.upload(&data)
109    }
110
111    /// Query free VRAM in MB (via cuMemGetInfo).
112    /// Returns None if query fails.
113    pub fn free_memory_mb(&self) -> Option<u64> {
114        self.ctx.memory_info().map(|(free, _total)| (free / (1024 * 1024)) as u64).ok()
115    }
116
117    /// Download GPU buffer to host
118    pub fn download(&self, buffer: &GpuBuffer<f32>) -> Result<Vec<f32>> {
119        let mut result = vec![0.0f32; buffer.len()];
120        buffer
121            .copy_to_host(&mut result)
122            .map_err(|e| CudaTensorError::TransferFailed(format!("{e:?}")))?;
123        Ok(result)
124    }
125
126    /// Matrix multiply forward pass: C = A @ B
127    ///
128    /// # Arguments
129    /// - `a`: Input matrix (m × k)
130    /// - `b`: Weight matrix (k × n)
131    /// - `c`: Output matrix (m × n)
132    /// - `m`, `k`, `n`: Matrix dimensions
133    pub fn matmul_forward(
134        &self,
135        a: &GpuBuffer<f32>,
136        b: &GpuBuffer<f32>,
137        c: &mut GpuBuffer<f32>,
138        m: u32,
139        k: u32,
140        n: u32,
141    ) -> Result<()> {
142        gemm_forward(a, b, c, m, k, n, &self.stream)
143    }
144
145    /// Matrix multiply backward pass for weight gradients
146    ///
147    /// Given C = A @ B, computes:
148    /// - grad_A = grad_C @ B^T
149    /// - grad_B = A^T @ grad_C
150    // Contract: backward-pass-v1 / matmul_backward
151    #[requires(m > 0 && k > 0 && n > 0)]
152    pub fn matmul_backward(
153        &self,
154        a: &GpuBuffer<f32>,
155        b: &GpuBuffer<f32>,
156        grad_c: &GpuBuffer<f32>,
157        grad_a: &mut GpuBuffer<f32>,
158        grad_b: &mut GpuBuffer<f32>,
159        m: u32,
160        k: u32,
161        n: u32,
162    ) -> Result<()> {
163        gemm_backward_a(grad_c, b, grad_a, m, k, n, &self.stream)?;
164        gemm_backward_b(a, grad_c, grad_b, m, k, n, &self.stream)?;
165        Ok(())
166    }
167
168    /// AdamW optimizer step on GPU
169    ///
170    /// Updates weights in-place using the AdamW algorithm.
171    pub fn adamw_step(
172        &mut self,
173        params: &mut GpuBuffer<f32>,
174        grads: &GpuBuffer<f32>,
175        m_state: &mut GpuBuffer<f32>,
176        v_state: &mut GpuBuffer<f32>,
177        lr: f32,
178        beta1: f32,
179        beta2: f32,
180        eps: f32,
181        weight_decay: f32,
182    ) -> Result<()> {
183        self.step += 1;
184        let n = params.len() as u32;
185        adamw_step_cuda(
186            params,
187            grads,
188            m_state,
189            v_state,
190            lr,
191            beta1,
192            beta2,
193            eps,
194            weight_decay,
195            self.step,
196            n,
197            &self.stream,
198        )
199    }
200
201    /// Apply gradient clipping
202    pub fn clip_gradients(&self, grads: &mut GpuBuffer<f32>, max_norm: f32) -> Result<()> {
203        // Compute gradient norm on CPU (requires download)
204        let grad_data = self.download(grads)?;
205        let grad_norm: f32 = grad_data.iter().map(|x| x * x).sum::<f32>().sqrt();
206
207        // Compute scale factor
208        let scale = if grad_norm > max_norm { max_norm / grad_norm } else { 1.0 };
209
210        // Apply clipping on GPU
211        gradient_clip_cuda(grads, scale, grads.len() as u32, &self.stream)
212    }
213
214    /// Get current optimizer step count
215    pub fn step_count(&self) -> u32 {
216        self.step
217    }
218
219    /// Reset optimizer step count (for new training run)
220    pub fn reset_step(&mut self) {
221        self.step = 0;
222    }
223
224    /// Get device name
225    pub fn device_name(&self) -> String {
226        self.ctx.device_name().unwrap_or_else(|_err| "Unknown GPU".to_string())
227    }
228
229    /// Get total GPU memory in bytes
230    pub fn total_memory(&self) -> usize {
231        self.ctx.total_memory().unwrap_or(0)
232    }
233}
234
// The context and stream handles are deliberately left out of the output;
// the device name and memory size are shown in their place.
#[cfg(feature = "cuda")]
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for CudaTrainer {
    /// Formats the trainer as `CudaTrainer { device, memory_gb, step }`,
    /// where `memory_gb` is the total memory in bytes divided by 1e9.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = self.device_name();
        let memory_gb = self.total_memory() as f64 / 1e9;

        let mut out = f.debug_struct("CudaTrainer");
        out.field("device", &device);
        out.field("memory_gb", &memory_gb);
        out.field("step", &self.step);
        out.finish()
    }
}
246
247// CPU fallback when CUDA is not available
248#[cfg(not(feature = "cuda"))]
249pub struct CudaTrainer;
250
251#[cfg(not(feature = "cuda"))]
252impl CudaTrainer {
253    pub fn new() -> Result<Self> {
254        Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
255    }
256}
257
/// Report whether CUDA-backed training can be used in this build.
///
/// Builds without the `cuda` feature always return `false`; builds with it
/// return whatever `trueno_gpu::driver::cuda_available` reports.
pub fn cuda_training_available() -> bool {
    #[cfg(feature = "cuda")]
    let available = trueno_gpu::driver::cuda_available();
    #[cfg(not(feature = "cuda"))]
    let available = false;

    available
}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a trainer on the default GPU, or `None` when CUDA is not
    /// available so the calling test can skip itself.
    #[cfg(feature = "cuda")]
    fn trainer_or_skip() -> Option<CudaTrainer> {
        if cuda_training_available() {
            Some(CudaTrainer::new().expect("operation should succeed"))
        } else {
            None
        }
    }

    /// Smoke test: the availability probe can be called in any build.
    #[test]
    fn test_cuda_training_available() {
        let _ = cuda_training_available();
    }

    /// A trainer can be created and reports a device name and memory size.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_creation() {
        if !cuda_training_available() {
            return;
        }

        let created = CudaTrainer::new();
        assert!(created.is_ok());

        let trainer = created.expect("operation should succeed");
        assert!(!trainer.device_name().is_empty());
        assert!(trainer.total_memory() > 0);
    }

    /// Data survives a host -> device -> host round trip unchanged.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_upload_download() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let data = vec![1.0, 2.0, 3.0, 4.0];
        let on_device = trainer.upload(&data).expect("load should succeed");
        let round_trip = trainer.download(&on_device).expect("load should succeed");

        assert_eq!(data, round_trip);
    }

    /// `zeros` yields a buffer of the requested length filled with 0.0.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_zeros() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let on_device = trainer.zeros(100).expect("operation should succeed");
        let host = trainer.download(&on_device).expect("load should succeed");

        assert_eq!(host.len(), 100);
        assert!(host.iter().all(|&x| x == 0.0));
    }

    /// Synchronizing an idle stream succeeds.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_synchronize() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        assert!(trainer.synchronize().is_ok());
    }

    /// The context and stream accessors can be called without panicking.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_context_and_stream() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let _ctx = trainer.context();
        let _stream = trainer.stream();
    }

    /// The step counter starts at 0, advances with `adamw_step`, and resets.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_step_count() {
        let Some(mut trainer) = trainer_or_skip() else {
            return;
        };
        assert_eq!(trainer.step_count(), 0);

        // One real optimizer step is enough to move the counter.
        let mut params = trainer.upload(&[1.0, 2.0, 3.0]).expect("load should succeed");
        let grads = trainer.upload(&[0.1, 0.1, 0.1]).expect("load should succeed");
        let mut m_state = trainer.zeros(3).expect("operation should succeed");
        let mut v_state = trainer.zeros(3).expect("operation should succeed");

        let (lr, beta1, beta2, eps, weight_decay) = (0.001, 0.9, 0.999, 1e-8, 0.0);
        trainer
            .adamw_step(
                &mut params,
                &grads,
                &mut m_state,
                &mut v_state,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            )
            .expect("operation should succeed");
        assert_eq!(trainer.step_count(), 1);

        trainer.reset_step();
        assert_eq!(trainer.step_count(), 0);
    }

    /// A 2x3 @ 3x2 product writes something non-zero into the output.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_matmul_forward() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let (m, k, n) = (2, 3, 2);
        let a_host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let b_host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 3x2
        let c_host: Vec<f32> = vec![0.0; 4]; // 2x2

        let a = trainer.upload(&a_host).expect("load should succeed");
        let b = trainer.upload(&b_host).expect("load should succeed");
        let mut c = trainer.upload(&c_host).expect("load should succeed");

        trainer.matmul_forward(&a, &b, &mut c, m, k, n).expect("operation should succeed");
        trainer.synchronize().expect("operation should succeed");

        let product = trainer.download(&c).expect("load should succeed");
        // Only checks that the kernel produced output, not the exact values.
        assert!(product.iter().any(|&x| x != 0.0));
    }

    /// Gradients with norm 20 are scaled down to roughly norm 1.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_clip_gradients() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let oversized: Vec<f32> = vec![10.0, 10.0, 10.0, 10.0]; // norm = 20
        let mut grads = trainer.upload(&oversized).expect("load should succeed");

        let max_norm = 1.0;
        trainer.clip_gradients(&mut grads, max_norm).expect("operation should succeed");
        trainer.synchronize().expect("operation should succeed");

        let clipped = trainer.download(&grads).expect("load should succeed");
        let norm: f32 = clipped.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(norm <= 1.1, "Gradient norm should be clipped to ~1.0, got {norm}");
    }

    /// The `Debug` output names the type and the fields it reports.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_debug_impl() {
        let Some(trainer) = trainer_or_skip() else {
            return;
        };

        let rendered = format!("{trainer:?}");
        for expected in ["CudaTrainer", "device", "step"] {
            assert!(rendered.contains(expected));
        }
    }
}