1#[cfg(feature = "cuda")]
28use std::sync::Arc;
29
30#[cfg(feature = "cuda")]
31use trueno_gpu::driver::{cuda_available, CudaContext, CudaStream, GpuBuffer};
32#[cfg(feature = "cuda")]
33use trueno_gpu::GpuError;
34
/// Errors reported by the CUDA device and tensor wrappers in this module.
#[derive(Debug, thiserror::Error)]
pub enum CudaTensorError {
    /// CUDA cannot be used; the payload carries the reason (missing driver, failed context
    /// creation, or a build without the `cuda` feature).
    #[error("CUDA not available: {0}")]
    CudaNotAvailable(String),

    /// Allocating a GPU resource failed; the payload carries the details.
    #[error("GPU allocation failed: {0}")]
    AllocationFailed(String),

    /// Copying data to or from the GPU failed.
    #[error("Data transfer failed: {0}")]
    TransferFailed(String),

    /// A host slice did not have the same number of elements as the tensor:
    /// `expected` is the tensor length, `actual` is the length of the slice supplied.
    #[error("Shape mismatch: expected {expected}, got {actual}")]
    ShapeMismatch { expected: usize, actual: usize },

    /// A kernel-related failure. In this module it is also the fallback for stream
    /// synchronization errors and for any GPU error without a more specific variant.
    #[error("Kernel launch failed: {0}")]
    KernelError(String),

    /// The CUDA device has not been initialized.
    #[error("CUDA device not initialized")]
    DeviceNotInitialized,
}
62
#[cfg(feature = "cuda")]
impl From<GpuError> for CudaTensorError {
    /// Translates a low-level [`GpuError`] into the corresponding [`CudaTensorError`].
    ///
    /// Out-of-memory, transfer and "CUDA not available" errors keep their category; any other
    /// error becomes [`CudaTensorError::KernelError`] carrying its `Debug` rendering.
    fn from(err: GpuError) -> Self {
        match err {
            GpuError::OutOfMemory { requested, available } => {
                let detail = format!(
                    "Out of GPU memory: requested {requested} bytes, {available} available"
                );
                Self::AllocationFailed(detail)
            }
            GpuError::Transfer(msg) => Self::TransferFailed(msg),
            GpuError::CudaNotAvailable(msg) => Self::CudaNotAvailable(msg),
            other => Self::KernelError(format!("{other:?}")),
        }
    }
}
76
/// Convenience alias for results whose error type is [`CudaTensorError`].
pub type Result<T> = std::result::Result<T, CudaTensorError>;
79
/// Handle bundling a shared CUDA context with a stream created on it.
#[cfg(feature = "cuda")]
pub struct CudaDevice {
    /// Context for this device; held in an `Arc` so tensors can keep their own handle to it.
    ctx: Arc<CudaContext>,
    /// Stream created from `ctx` when the device is constructed.
    stream: CudaStream,
}
86
#[cfg(feature = "cuda")]
impl CudaDevice {
    /// Creates a context for the device identified by `device_id` and a stream on that context.
    ///
    /// # Errors
    ///
    /// * [`CudaTensorError::CudaNotAvailable`] if `cuda_available()` reports no driver or the
    ///   context cannot be created.
    /// * [`CudaTensorError::AllocationFailed`] if the stream cannot be created.
    pub fn new(device_id: i32) -> Result<Self> {
        if !cuda_available() {
            return Err(CudaTensorError::CudaNotAvailable("No CUDA driver found".into()));
        }

        let context = match CudaContext::new(device_id) {
            Ok(context) => context,
            Err(e) => return Err(CudaTensorError::CudaNotAvailable(format!("{e:?}"))),
        };
        let stream = match CudaStream::new(&context) {
            Ok(stream) => stream,
            Err(e) => return Err(CudaTensorError::AllocationFailed(format!("{e:?}"))),
        };

        Ok(Self { ctx: Arc::new(context), stream })
    }

    /// Shorthand for [`CudaDevice::new`] with device id `0`.
    pub fn default_device() -> Result<Self> {
        Self::new(0)
    }

    /// Returns the shared context owned by this device.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Returns the stream that was created together with this device.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Synchronizes this device's stream.
    ///
    /// # Errors
    ///
    /// Any failure is reported as [`CudaTensorError::KernelError`] with the underlying error's
    /// `Debug` rendering.
    pub fn synchronize(&self) -> Result<()> {
        let outcome = self.stream.synchronize();
        outcome.map_err(|e| CudaTensorError::KernelError(format!("{e:?}")))
    }
}
123
/// Flat `f32` tensor stored in a GPU buffer, with an optional gradient buffer.
#[cfg(feature = "cuda")]
pub struct CudaTensor {
    /// GPU buffer holding the tensor values.
    data: GpuBuffer<f32>,
    /// GPU buffer holding the gradient, if one has been allocated.
    grad: Option<GpuBuffer<f32>>,
    /// Context of the device this tensor was created on.
    device: Arc<CudaContext>,
    /// Flag passed at construction; when set, a zeroed gradient buffer is allocated up front.
    requires_grad: bool,
    /// Number of `f32` elements in `data` (and in `grad`, when present).
    len: usize,
}
140
#[cfg(feature = "cuda")]
impl CudaTensor {
    /// Uploads `data` to the GPU on `device` and wraps it in a tensor.
    ///
    /// When `requires_grad` is `true`, a zero-filled gradient buffer of the same length is
    /// uploaded as well; otherwise the tensor starts without a gradient buffer.
    ///
    /// # Errors
    ///
    /// Returns the converted GPU error if either upload fails.
    pub fn from_vec(device: &CudaDevice, data: Vec<f32>, requires_grad: bool) -> Result<Self> {
        let len = data.len();
        let gpu_data = GpuBuffer::from_host(&device.ctx, &data)?;

        let grad = if requires_grad {
            let zeros = vec![0.0f32; len];
            Some(GpuBuffer::from_host(&device.ctx, &zeros)?)
        } else {
            None
        };

        Ok(Self { data: gpu_data, grad, device: Arc::clone(&device.ctx), requires_grad, len })
    }

    /// Creates a tensor of `len` elements, all `0.0`.
    ///
    /// # Errors
    ///
    /// Same as [`CudaTensor::from_vec`].
    pub fn zeros(device: &CudaDevice, len: usize, requires_grad: bool) -> Result<Self> {
        let data = vec![0.0f32; len];
        Self::from_vec(device, data, requires_grad)
    }

    /// Creates a tensor of `len` elements, all `1.0`.
    ///
    /// # Errors
    ///
    /// Same as [`CudaTensor::from_vec`].
    pub fn ones(device: &CudaDevice, len: usize, requires_grad: bool) -> Result<Self> {
        let data = vec![1.0f32; len];
        Self::from_vec(device, data, requires_grad)
    }

    /// Copies the tensor values into a freshly allocated host vector of `len()` elements.
    ///
    /// # Errors
    ///
    /// Returns the converted GPU error if the copy fails.
    pub fn to_vec(&self) -> Result<Vec<f32>> {
        let mut result = vec![0.0f32; self.len];
        self.data.copy_to_host(&mut result)?;
        Ok(result)
    }

    /// Copies the gradient into a host vector, or returns `Ok(None)` when the tensor has no
    /// gradient buffer.
    ///
    /// # Errors
    ///
    /// Returns the converted GPU error if the copy fails.
    pub fn grad_to_vec(&self) -> Result<Option<Vec<f32>>> {
        match &self.grad {
            Some(grad_buf) => {
                let mut result = vec![0.0f32; self.len];
                grad_buf.copy_to_host(&mut result)?;
                Ok(Some(result))
            }
            None => Ok(None),
        }
    }

    /// Overwrites the tensor values with `data`.
    ///
    /// # Errors
    ///
    /// * [`CudaTensorError::ShapeMismatch`] if `data.len()` differs from `len()`.
    /// * The converted GPU error if the copy fails.
    pub fn copy_from_vec(&mut self, data: &[f32]) -> Result<()> {
        if data.len() != self.len {
            return Err(CudaTensorError::ShapeMismatch { expected: self.len, actual: data.len() });
        }
        self.data.copy_from_host(data)?;
        Ok(())
    }

    /// Overwrites the gradient with `grad`, allocating the gradient buffer first if the tensor
    /// does not have one yet. The `requires_grad` flag is left unchanged.
    ///
    /// # Errors
    ///
    /// * [`CudaTensorError::ShapeMismatch`] if `grad.len()` differs from `len()`.
    /// * The converted GPU error if the copy or allocation fails.
    pub fn set_grad_from_vec(&mut self, grad: &[f32]) -> Result<()> {
        if grad.len() != self.len {
            return Err(CudaTensorError::ShapeMismatch { expected: self.len, actual: grad.len() });
        }

        match &mut self.grad {
            Some(grad_buf) => {
                grad_buf.copy_from_host(grad)?;
            }
            None => {
                // Allocate on the context this tensor was created with. Building a fresh
                // context for device 0 here would target the wrong device for tensors living
                // elsewhere, and that temporary context would be dropped right after the
                // allocation while the buffer is still kept.
                self.grad = Some(GpuBuffer::from_host(&self.device, grad)?);
            }
        }
        Ok(())
    }

    /// Resets every gradient element to `0.0`. Does nothing when there is no gradient buffer.
    ///
    /// # Errors
    ///
    /// Returns the converted GPU error if the copy fails.
    pub fn zero_grad(&mut self) -> Result<()> {
        if let Some(ref mut grad_buf) = self.grad {
            let zeros = vec![0.0f32; self.len];
            grad_buf.copy_from_host(&zeros)?;
        }
        Ok(())
    }

    /// Returns the `requires_grad` flag the tensor was created with.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Returns the number of elements in the tensor.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the tensor has no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Borrows the GPU buffer holding the tensor values.
    pub fn data_buffer(&self) -> &GpuBuffer<f32> {
        &self.data
    }

    /// Mutably borrows the GPU buffer holding the tensor values.
    pub fn data_buffer_mut(&mut self) -> &mut GpuBuffer<f32> {
        &mut self.data
    }

    /// Borrows the gradient buffer, if one is allocated.
    pub fn grad_buffer(&self) -> Option<&GpuBuffer<f32>> {
        self.grad.as_ref()
    }

    /// Mutably borrows the gradient buffer, if one is allocated.
    pub fn grad_buffer_mut(&mut self) -> Option<&mut GpuBuffer<f32>> {
        self.grad.as_mut()
    }

    /// Returns the context of the device this tensor was created on.
    pub fn device(&self) -> &Arc<CudaContext> {
        &self.device
    }
}
271
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaTensor {
    /// Formats only the length and the gradient flags; the remaining fields are omitted and
    /// the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_grad = self.grad.is_some();

        let mut summary = f.debug_struct("CudaTensor");
        summary.field("len", &self.len);
        summary.field("requires_grad", &self.requires_grad);
        summary.field("has_grad", &has_grad);
        summary.finish_non_exhaustive()
    }
}
282
/// Placeholder used when the crate is built without the `cuda` feature; its constructors
/// always return an error.
#[cfg(not(feature = "cuda"))]
pub struct CudaDevice;
286
287#[cfg(not(feature = "cuda"))]
288impl CudaDevice {
289 pub fn new(_device_id: i32) -> Result<Self> {
290 Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
291 }
292
293 pub fn default_device() -> Result<Self> {
294 Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
295 }
296}
297
/// Placeholder type used when the crate is built without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub struct CudaTensor;
300
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Display` output includes the variant's message and its fields.
    #[test]
    fn test_cuda_tensor_error_display() {
        let unavailable = CudaTensorError::CudaNotAvailable("test".into()).to_string();
        assert!(unavailable.contains("CUDA not available"));

        let mismatch = CudaTensorError::ShapeMismatch { expected: 10, actual: 5 }.to_string();
        assert!(mismatch.contains("10"));
        assert!(mismatch.contains('5'));
    }

    /// The default device can be opened whenever a CUDA driver is present.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_device_creation() {
        // Skip silently on machines without CUDA.
        if !cuda_available() {
            return;
        }

        assert!(CudaDevice::default_device().is_ok());
    }

    /// Data uploaded with `from_vec` comes back unchanged from `to_vec`.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_from_vec() {
        if !cuda_available() {
            return;
        }

        let device = CudaDevice::default_device().expect("operation should succeed");
        let host_values = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = CudaTensor::from_vec(&device, host_values.clone(), true)
            .expect("operation should succeed");

        assert_eq!(tensor.len(), 4);
        assert!(tensor.requires_grad());

        let round_trip = tensor.to_vec().expect("operation should succeed");
        assert_eq!(round_trip, host_values);
    }

    /// `zeros` yields the requested length, all elements zero, no gradient requirement.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_zeros() {
        if !cuda_available() {
            return;
        }

        let device = CudaDevice::default_device().expect("operation should succeed");
        let tensor = CudaTensor::zeros(&device, 100, false).expect("operation should succeed");

        assert_eq!(tensor.len(), 100);
        assert!(!tensor.requires_grad());

        let values = tensor.to_vec().expect("operation should succeed");
        assert!(values.iter().all(|v| *v == 0.0));
    }

    /// `ones` yields the requested length with all elements one.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_ones() {
        if !cuda_available() {
            return;
        }

        let device = CudaDevice::default_device().expect("operation should succeed");
        let tensor = CudaTensor::ones(&device, 50, true).expect("operation should succeed");

        assert_eq!(tensor.len(), 50);
        let values = tensor.to_vec().expect("operation should succeed");
        assert!(values.iter().all(|v| *v == 1.0));
    }

    /// The gradient starts zeroed, can be overwritten, and can be reset to zero.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_gradient() {
        if !cuda_available() {
            return;
        }

        // Downloads the gradient, requiring both the transfer and the buffer to exist.
        let read_grad = |t: &CudaTensor| -> Vec<f32> {
            t.grad_to_vec().expect("operation should succeed").expect("operation should succeed")
        };

        let device = CudaDevice::default_device().expect("operation should succeed");
        let mut tensor = CudaTensor::from_vec(&device, vec![1.0, 2.0, 3.0], true)
            .expect("operation should succeed");

        let initial = read_grad(&tensor);
        assert!(initial.iter().all(|g| *g == 0.0));

        let expected = [0.1f32, 0.2, 0.3];
        tensor.set_grad_from_vec(&expected).expect("operation should succeed");
        let updated = read_grad(&tensor);
        for (i, want) in expected.iter().enumerate() {
            assert!((updated[i] - *want).abs() < 1e-6);
        }

        tensor.zero_grad().expect("gradient should be available");
        let cleared = read_grad(&tensor);
        assert!(cleared.iter().all(|g| *g == 0.0));
    }

    /// `copy_from_vec` replaces the tensor contents.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_copy_from_vec() {
        if !cuda_available() {
            return;
        }

        let device = CudaDevice::default_device().expect("operation should succeed");
        let mut tensor = CudaTensor::zeros(&device, 4, false).expect("operation should succeed");

        let replacement = [5.0, 6.0, 7.0, 8.0];
        tensor.copy_from_vec(&replacement).expect("operation should succeed");

        let values = tensor.to_vec().expect("operation should succeed");
        assert_eq!(values, vec![5.0, 6.0, 7.0, 8.0]);
    }

    /// Copying a slice of the wrong length is rejected with `ShapeMismatch`.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_shape_mismatch() {
        if !cuda_available() {
            return;
        }

        let device = CudaDevice::default_device().expect("operation should succeed");
        let mut tensor = CudaTensor::zeros(&device, 4, false).expect("operation should succeed");

        let outcome = tensor.copy_from_vec(&[1.0, 2.0]);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CudaTensorError::ShapeMismatch { .. })));
    }

    /// Without the `cuda` feature, opening a device reports `CudaNotAvailable`.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_not_available_fallback() {
        let outcome = CudaDevice::default_device();
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CudaTensorError::CudaNotAvailable(_))));
    }
}