1#[cfg(feature = "cuda")]
27use std::sync::Arc;
28
29#[cfg(feature = "cuda")]
30use trueno_gpu::driver::{cuda_available, CudaContext, CudaStream, GpuBuffer};
31
32use super::cuda_tensor::{CudaTensorError, Result};
33#[cfg(feature = "cuda")]
34use provable_contracts_macros::requires;
35
36#[cfg(feature = "cuda")]
37use super::cuda_backward::{gemm_backward_a, gemm_backward_b, init_kernel_cache};
38#[cfg(feature = "cuda")]
39use super::cuda_forward::{gemm_forward, init_forward_kernel_cache};
40#[cfg(feature = "cuda")]
41use super::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, init_optim_kernel_cache};
42
/// Drives training on a single CUDA device.
///
/// Owns the CUDA context and the stream on which the trainer's forward,
/// backward, and optimizer kernels are queued, plus the AdamW step counter
/// that is passed to the optimizer kernel.
#[cfg(feature = "cuda")]
pub struct CudaTrainer {
    // Field order matters: Rust drops struct fields in declaration order, so
    // `stream` is declared before `ctx` to guarantee the stream is destroyed
    // while this trainer still holds a reference to its context.
    /// Stream on which every kernel launched by this trainer is queued.
    stream: CudaStream,
    /// Context for the selected device; clones are also handed to the kernel caches.
    ctx: Arc<CudaContext>,
    /// AdamW step counter forwarded to the optimizer kernel.
    step: u32,
}
52
#[cfg(feature = "cuda")]
impl CudaTrainer {
    /// Creates a trainer on CUDA device 0.
    ///
    /// # Errors
    ///
    /// Same as [`CudaTrainer::with_device`].
    pub fn new() -> Result<Self> {
        Self::with_device(0)
    }

    /// Creates a trainer on the CUDA device identified by `device_id`.
    ///
    /// Builds a context and a stream for the device, then initialises the
    /// forward, backward, and optimizer kernel caches with that context.
    ///
    /// # Errors
    ///
    /// - [`CudaTensorError::CudaNotAvailable`] if no CUDA driver is found or the
    ///   context cannot be created.
    /// - [`CudaTensorError::AllocationFailed`] if the stream cannot be created.
    /// - Any error propagated from the kernel-cache initialisers.
    pub fn with_device(device_id: i32) -> Result<Self> {
        if !cuda_available() {
            return Err(CudaTensorError::CudaNotAvailable("No CUDA driver found".into()));
        }

        let ctx = Arc::new(
            CudaContext::new(device_id)
                .map_err(|e| CudaTensorError::CudaNotAvailable(format!("{e:?}")))?,
        );
        let stream = CudaStream::new(&ctx)
            .map_err(|e| CudaTensorError::AllocationFailed(format!("{e:?}")))?;

        init_forward_kernel_cache(ctx.clone())?;
        init_kernel_cache(ctx.clone())?;
        init_optim_kernel_cache(ctx.clone())?;

        Ok(Self { ctx, stream, step: 0 })
    }

    /// Returns the CUDA context this trainer was created with.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Returns the stream on which this trainer queues its kernels.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Blocks until all work queued on this trainer's stream has finished.
    ///
    /// # Errors
    ///
    /// [`CudaTensorError::KernelError`] if the stream reports a failure.
    pub fn synchronize(&self) -> Result<()> {
        self.stream.synchronize().map_err(|e| CudaTensorError::KernelError(format!("{e:?}")))
    }

    /// Allocates a device buffer initialised from `data` and associates it
    /// with this trainer's context.
    ///
    /// # Errors
    ///
    /// [`CudaTensorError::AllocationFailed`] if the buffer cannot be created.
    pub fn upload(&self, data: &[f32]) -> Result<GpuBuffer<f32>> {
        let mut buf = GpuBuffer::from_host(&self.ctx, data)
            .map_err(|e| CudaTensorError::AllocationFailed(format!("{e:?}")))?;
        buf.set_context(&self.ctx);
        Ok(buf)
    }

    /// Allocates a device buffer of `len` elements, all set to `0.0`.
    ///
    /// # Errors
    ///
    /// Same as [`CudaTrainer::upload`].
    pub fn zeros(&self, len: usize) -> Result<GpuBuffer<f32>> {
        let data = vec![0.0f32; len];
        self.upload(&data)
    }

    /// Free device memory in MiB, or `None` if the driver query fails.
    pub fn free_memory_mb(&self) -> Option<u64> {
        self.ctx.memory_info().map(|(free, _total)| (free / (1024 * 1024)) as u64).ok()
    }

    /// Copies the contents of `buffer` into a new host vector.
    ///
    /// This does not synchronize the trainer's stream. Call
    /// [`CudaTrainer::synchronize`] first if kernels writing `buffer` may
    /// still be queued.
    ///
    /// # Errors
    ///
    /// [`CudaTensorError::TransferFailed`] if the copy fails.
    pub fn download(&self, buffer: &GpuBuffer<f32>) -> Result<Vec<f32>> {
        let mut result = vec![0.0f32; buffer.len()];
        buffer
            .copy_to_host(&mut result)
            .map_err(|e| CudaTensorError::TransferFailed(format!("{e:?}")))?;
        Ok(result)
    }

    /// Queues the GEMM `c = a * b` on the trainer's stream.
    ///
    /// `a` holds `m * k` elements, `b` holds `k * n`, and `c` holds `m * n`.
    /// The element layout is whatever `gemm_forward` expects.
    ///
    /// # Errors
    ///
    /// - [`CudaTensorError::KernelError`] if any buffer is smaller than its
    ///   matrix; this prevents the kernel from reading or writing out of bounds.
    /// - Any error from the kernel launch.
    pub fn matmul_forward(
        &self,
        a: &GpuBuffer<f32>,
        b: &GpuBuffer<f32>,
        c: &mut GpuBuffer<f32>,
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<()> {
        Self::ensure_len("a", a.len(), Self::matrix_len(m, k))?;
        Self::ensure_len("b", b.len(), Self::matrix_len(k, n))?;
        Self::ensure_len("c", c.len(), Self::matrix_len(m, n))?;
        gemm_forward(a, b, c, m, k, n, &self.stream)
    }

    /// Queues the gradient kernels for `c = a * b` on the trainer's stream.
    ///
    /// Writes the gradient with respect to `a` into `grad_a` (`m * k` elements)
    /// and with respect to `b` into `grad_b` (`k * n` elements), given the
    /// upstream gradient `grad_c` (`m * n` elements).
    ///
    /// # Errors
    ///
    /// - [`CudaTensorError::KernelError`] if any buffer is smaller than its
    ///   matrix.
    /// - Any error from either kernel launch.
    #[requires(m > 0 && k > 0 && n > 0)]
    pub fn matmul_backward(
        &self,
        a: &GpuBuffer<f32>,
        b: &GpuBuffer<f32>,
        grad_c: &GpuBuffer<f32>,
        grad_a: &mut GpuBuffer<f32>,
        grad_b: &mut GpuBuffer<f32>,
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<()> {
        Self::ensure_len("a", a.len(), Self::matrix_len(m, k))?;
        Self::ensure_len("b", b.len(), Self::matrix_len(k, n))?;
        Self::ensure_len("grad_c", grad_c.len(), Self::matrix_len(m, n))?;
        Self::ensure_len("grad_a", grad_a.len(), Self::matrix_len(m, k))?;
        Self::ensure_len("grad_b", grad_b.len(), Self::matrix_len(k, n))?;
        gemm_backward_a(grad_c, b, grad_a, m, k, n, &self.stream)?;
        gemm_backward_b(a, grad_c, grad_b, m, k, n, &self.stream)?;
        Ok(())
    }

    /// Queues one AdamW update of `params` using `grads` and the moment
    /// buffers `m_state` and `v_state`.
    ///
    /// The step counter is advanced, and passed to the kernel, only when the
    /// launch succeeds. A failed call therefore does not skew later steps.
    ///
    /// # Errors
    ///
    /// - [`CudaTensorError::KernelError`] if `params` has more than `u32::MAX`
    ///   elements, or if `grads`, `m_state`, or `v_state` is shorter than
    ///   `params`.
    /// - Any error from the kernel launch.
    pub fn adamw_step(
        &mut self,
        params: &mut GpuBuffer<f32>,
        grads: &GpuBuffer<f32>,
        m_state: &mut GpuBuffer<f32>,
        v_state: &mut GpuBuffer<f32>,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
    ) -> Result<()> {
        let len = params.len();
        let n = Self::kernel_len(len)?;
        // The kernel processes `n` elements of every buffer; shorter
        // companions would be read or written out of bounds.
        Self::ensure_len("grads", grads.len(), len)?;
        Self::ensure_len("m_state", m_state.len(), len)?;
        Self::ensure_len("v_state", v_state.len(), len)?;

        let step = self.step + 1;
        adamw_step_cuda(
            params,
            grads,
            m_state,
            v_state,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            step,
            n,
            &self.stream,
        )?;
        self.step = step;
        Ok(())
    }

    /// Rescales `grads` in place so that its L2 norm is at most `max_norm`.
    ///
    /// The norm is computed on the host after synchronizing the stream and
    /// downloading the buffer. If the norm is already within bounds (or is
    /// NaN), no kernel is launched and the buffer is left unchanged.
    ///
    /// # Errors
    ///
    /// - [`CudaTensorError::KernelError`] if `max_norm` is negative or NaN, or
    ///   if the buffer has more than `u32::MAX` elements.
    /// - Any synchronization, transfer, or kernel error.
    pub fn clip_gradients(&self, grads: &mut GpuBuffer<f32>, max_norm: f32) -> Result<()> {
        // A negative scale would flip the sign of every gradient.
        if max_norm.is_nan() || max_norm < 0.0 {
            return Err(CudaTensorError::KernelError(format!(
                "max_norm must be a non-negative number, got {max_norm}"
            )));
        }

        // Kernels that produced `grads` were queued on `self.stream`; wait for
        // them so the norm is computed from final values, not stale memory.
        self.synchronize()?;
        let grad_data = self.download(grads)?;

        // Accumulate in f64: summing many squares in f32 loses precision and
        // can overflow to infinity, which would turn the scale into 0.
        let grad_norm =
            grad_data.iter().map(|&x| f64::from(x) * f64::from(x)).sum::<f64>().sqrt();
        let limit = f64::from(max_norm);
        if grad_norm.is_nan() || grad_norm <= limit {
            // Scaling by 1.0 would be a no-op launch.
            return Ok(());
        }

        let scale = (limit / grad_norm) as f32;
        let n = Self::kernel_len(grads.len())?;
        gradient_clip_cuda(grads, scale, n, &self.stream)
    }

    /// Returns the current AdamW step counter.
    pub fn step_count(&self) -> u32 {
        self.step
    }

    /// Resets the AdamW step counter to zero.
    pub fn reset_step(&mut self) {
        self.step = 0;
    }

    /// Returns the device name, or `"Unknown GPU"` if the query fails.
    pub fn device_name(&self) -> String {
        self.ctx.device_name().unwrap_or_else(|_err| "Unknown GPU".to_string())
    }

    /// Returns the total device memory reported by the context (presumably in
    /// bytes), or 0 if the query fails.
    pub fn total_memory(&self) -> usize {
        self.ctx.total_memory().unwrap_or(0)
    }

    /// Element count of a `rows x cols` matrix.
    ///
    /// The multiplication saturates on overflow, so the length check that
    /// follows fails instead of wrapping.
    fn matrix_len(rows: u32, cols: u32) -> usize {
        (rows as usize).saturating_mul(cols as usize)
    }

    /// Fails if the buffer named `name` holds fewer than `required` elements.
    fn ensure_len(name: &str, len: usize, required: usize) -> Result<()> {
        if len < required {
            return Err(CudaTensorError::KernelError(format!(
                "buffer `{name}` has {len} elements but {required} are required"
            )));
        }
        Ok(())
    }

    /// Converts a buffer length to the `u32` element count the kernels take.
    ///
    /// Lengths that an `as` cast would silently truncate are rejected.
    fn kernel_len(len: usize) -> Result<u32> {
        if len > u32::MAX as usize {
            return Err(CudaTensorError::KernelError(format!(
                "buffer has {len} elements; kernels support at most {}",
                u32::MAX
            )));
        }
        Ok(len as u32)
    }
}
234
/// Summarises the trainer by device name, total memory (in units of 1e9 of
/// whatever `total_memory` reports), and step counter.
#[cfg(feature = "cuda")]
// The context and stream are opaque driver handles with nothing useful to
// print, so the derived per-field output is replaced with device queries.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for CudaTrainer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let device = self.device_name();
        let memory_gb = self.total_memory() as f64 / 1e9;

        let mut summary = f.debug_struct("CudaTrainer");
        summary.field("device", &device);
        summary.field("memory_gb", &memory_gb);
        summary.field("step", &self.step);
        summary.finish()
    }
}
246
247#[cfg(not(feature = "cuda"))]
249pub struct CudaTrainer;
250
251#[cfg(not(feature = "cuda"))]
252impl CudaTrainer {
253 pub fn new() -> Result<Self> {
254 Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
255 }
256}
257
/// Reports whether CUDA training can run: the crate has the `cuda` feature
/// and the driver reports CUDA as available.
#[cfg(feature = "cuda")]
pub fn cuda_training_available() -> bool {
    trueno_gpu::driver::cuda_available()
}

/// Reports whether CUDA training can run. Without the `cuda` feature it
/// never can, so this always returns `false`.
#[cfg(not(feature = "cuda"))]
pub fn cuda_training_available() -> bool {
    false
}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trainer on the default device. Returns `None` when no CUDA
    /// driver is present, so the calling test can return early.
    #[cfg(feature = "cuda")]
    fn trainer_or_skip() -> Option<CudaTrainer> {
        if cuda_training_available() {
            Some(CudaTrainer::new().expect("operation should succeed"))
        } else {
            None
        }
    }

    #[test]
    fn test_cuda_training_available() {
        // Only checks that the probe runs; the answer depends on the host.
        let _available = cuda_training_available();
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_creation() {
        if !cuda_training_available() {
            return;
        }

        let created = CudaTrainer::new();
        assert!(created.is_ok());

        let trainer = created.expect("operation should succeed");
        assert!(!trainer.device_name().is_empty());
        assert!(trainer.total_memory() > 0);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_upload_download() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        let host = vec![1.0, 2.0, 3.0, 4.0];
        let on_device = trainer.upload(&host).expect("load should succeed");
        let round_trip = trainer.download(&on_device).expect("load should succeed");

        assert_eq!(host, round_trip);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_zeros() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        let zeroed = trainer.zeros(100).expect("operation should succeed");
        let host = trainer.download(&zeroed).expect("load should succeed");

        assert_eq!(host.len(), 100);
        assert!(host.iter().all(|&x| x == 0.0));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_synchronize() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        assert!(trainer.synchronize().is_ok());
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_context_and_stream() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        // Accessors must be callable on a freshly built trainer.
        let _ctx = trainer.context();
        let _stream = trainer.stream();
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_step_count() {
        let mut trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };
        assert_eq!(trainer.step_count(), 0);

        let mut weights = trainer.upload(&[1.0, 2.0, 3.0]).expect("load should succeed");
        let gradients = trainer.upload(&[0.1, 0.1, 0.1]).expect("load should succeed");
        let mut first_moment = trainer.zeros(3).expect("operation should succeed");
        let mut second_moment = trainer.zeros(3).expect("operation should succeed");

        let (lr, beta1, beta2, eps, weight_decay) = (0.001, 0.9, 0.999, 1e-8, 0.0);
        trainer
            .adamw_step(
                &mut weights,
                &gradients,
                &mut first_moment,
                &mut second_moment,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            )
            .expect("operation should succeed");
        assert_eq!(trainer.step_count(), 1);

        trainer.reset_step();
        assert_eq!(trainer.step_count(), 0);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_matmul_forward() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        // Shapes: m = 2, k = 3, n = 2, so `lhs`/`rhs` hold 6 values and `out` holds 4.
        let lhs_host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs_host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out_host: Vec<f32> = vec![0.0; 4];

        let lhs = trainer.upload(&lhs_host).expect("load should succeed");
        let rhs = trainer.upload(&rhs_host).expect("load should succeed");
        let mut out = trainer.upload(&out_host).expect("load should succeed");

        trainer.matmul_forward(&lhs, &rhs, &mut out, 2, 3, 2).expect("operation should succeed");
        trainer.synchronize().expect("operation should succeed");

        let product = trainer.download(&out).expect("load should succeed");
        assert!(!product.iter().all(|&x| x == 0.0));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_clip_gradients() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        // L2 norm of four 10.0 entries is 20, well above the 1.0 limit.
        let oversized: Vec<f32> = vec![10.0, 10.0, 10.0, 10.0];
        let mut grads = trainer.upload(&oversized).expect("load should succeed");

        trainer.clip_gradients(&mut grads, 1.0).expect("operation should succeed");
        trainer.synchronize().expect("operation should succeed");

        let clipped = trainer.download(&grads).expect("load should succeed");
        let norm: f32 = clipped.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(norm <= 1.1, "Gradient norm should be clipped to ~1.0, got {norm}");
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_trainer_debug_impl() {
        let trainer = match trainer_or_skip() {
            Some(trainer) => trainer,
            None => return,
        };

        let rendered = format!("{trainer:?}");
        assert!(rendered.contains("CudaTrainer"));
        assert!(rendered.contains("device"));
        assert!(rendered.contains("step"));
    }
}
449}