Skip to main content

entrenar/autograd/
cuda_tensor.rs

1//! CUDA-accelerated tensor type for GPU training
2//!
3//! This module provides GPU-resident tensors using trueno-gpu's CUDA backend.
4//! It replaces ndarray::Array1<f32> with CUDA-backed storage for 100x speedup.
5//!
6//! # Architecture (SPEC-FT-001 v3.0.0)
7//!
//! ```text
//! CudaTensor
//!   ├── data: GpuBuffer<f32>          (GPU memory)
//!   ├── grad: Option<GpuBuffer<f32>>  (gradient on GPU)
//!   ├── device: Arc<CudaContext>      (shared device context)
//!   ├── requires_grad: bool           (whether a gradient is tracked)
//!   └── len: usize                    (number of elements)
//! ```
14//!
15//! # Example
16//!
//! ```ignore
//! use entrenar::autograd::CudaTensor;
//!
//! // `device` is a `CudaDevice`, e.g. obtained via `CudaDevice::default_device()?`
//!
//! // Create tensor on GPU
//! let t = CudaTensor::from_vec(&device, vec![1.0, 2.0, 3.0], true)?;
//!
//! // Transfer back to CPU
//! let data = t.to_vec()?;
//! ```
26
27#[cfg(feature = "cuda")]
28use std::sync::Arc;
29
30#[cfg(feature = "cuda")]
31use trueno_gpu::driver::{cuda_available, CudaContext, CudaStream, GpuBuffer};
32#[cfg(feature = "cuda")]
33use trueno_gpu::GpuError;
34
35/// Error type for CUDA tensor operations
36#[derive(Debug, thiserror::Error)]
37pub enum CudaTensorError {
38    /// CUDA is not available on this system
39    #[error("CUDA not available: {0}")]
40    CudaNotAvailable(String),
41
42    /// GPU memory allocation failed
43    #[error("GPU allocation failed: {0}")]
44    AllocationFailed(String),
45
46    /// Data transfer failed
47    #[error("Data transfer failed: {0}")]
48    TransferFailed(String),
49
50    /// Shape mismatch
51    #[error("Shape mismatch: expected {expected}, got {actual}")]
52    ShapeMismatch { expected: usize, actual: usize },
53
54    /// Kernel launch failed
55    #[error("Kernel launch failed: {0}")]
56    KernelError(String),
57
58    /// Device not initialized
59    #[error("CUDA device not initialized")]
60    DeviceNotInitialized,
61}
62
/// Maps backend `GpuError` values onto [`CudaTensorError`] so `?` can be used
/// directly on `trueno_gpu` calls.
#[cfg(feature = "cuda")]
impl From<GpuError> for CudaTensorError {
    fn from(err: GpuError) -> Self {
        match err {
            GpuError::CudaNotAvailable(reason) => Self::CudaNotAvailable(reason),
            GpuError::Transfer(reason) => Self::TransferFailed(reason),
            GpuError::OutOfMemory { requested, available } => {
                let detail = format!(
                    "Out of GPU memory: requested {requested} bytes, {available} available"
                );
                Self::AllocationFailed(detail)
            }
            // Anything without a dedicated variant is reported via its Debug form.
            unmapped => Self::KernelError(format!("{unmapped:?}")),
        }
    }
}
76
77/// Result type for CUDA tensor operations
78pub type Result<T> = std::result::Result<T, CudaTensorError>;
79
/// Handle pairing a CUDA context with a stream created on that context.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub struct CudaDevice {
    /// Context for the selected device; reference-counted so tensors can share it.
    ctx: Arc<CudaContext>,
    /// Stream created from `ctx` in [`CudaDevice::new`].
    stream: CudaStream,
}

#[cfg(feature = "cuda")]
impl CudaDevice {
    /// Opens the device with the given ID and creates a stream on it.
    ///
    /// # Errors
    ///
    /// * [`CudaTensorError::CudaNotAvailable`] if no CUDA driver is found or
    ///   the context cannot be created.
    /// * [`CudaTensorError::AllocationFailed`] if the stream cannot be created.
    pub fn new(device_id: i32) -> Result<Self> {
        if !cuda_available() {
            return Err(CudaTensorError::CudaNotAvailable("No CUDA driver found".into()));
        }

        let context = match CudaContext::new(device_id) {
            Ok(context) => context,
            Err(err) => return Err(CudaTensorError::CudaNotAvailable(format!("{err:?}"))),
        };
        let stream = match CudaStream::new(&context) {
            Ok(stream) => stream,
            Err(err) => return Err(CudaTensorError::AllocationFailed(format!("{err:?}"))),
        };

        Ok(Self { ctx: Arc::new(context), stream })
    }

    /// Opens device 0.
    ///
    /// # Errors
    ///
    /// Same as [`CudaDevice::new`].
    pub fn default_device() -> Result<Self> {
        Self::new(0)
    }

    /// Borrows the shared CUDA context.
    pub fn context(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Borrows the stream owned by this handle.
    pub fn stream(&self) -> &CudaStream {
        &self.stream
    }

    /// Synchronizes this handle's stream.
    ///
    /// # Errors
    ///
    /// Returns [`CudaTensorError::KernelError`] if the backend reports a failure.
    pub fn synchronize(&self) -> Result<()> {
        match self.stream.synchronize() {
            Ok(()) => Ok(()),
            Err(err) => Err(CudaTensorError::KernelError(format!("{err:?}"))),
        }
    }
}
123
/// GPU-resident tensor with gradient support
///
/// This is the CUDA-accelerated replacement for `Tensor` when the `cuda` feature is enabled.
/// The tensor is a flat sequence of `f32` values; `len` is the element count of
/// both the data buffer and (when present) the gradient buffer.
#[cfg(feature = "cuda")]
pub struct CudaTensor {
    /// Data stored on GPU
    data: GpuBuffer<f32>,
    /// Gradient stored on GPU (allocated up front if `requires_grad`, or on the
    /// first call to `set_grad_from_vec`)
    grad: Option<GpuBuffer<f32>>,
    /// Shared device context this tensor's buffers were allocated with
    device: Arc<CudaContext>,
    /// Whether this tensor requires gradient computation
    requires_grad: bool,
    /// Number of elements
    len: usize,
}

#[cfg(feature = "cuda")]
impl CudaTensor {
    /// Create a new tensor on GPU from host data
    ///
    /// When `requires_grad` is true a zero-filled gradient buffer of the same
    /// length is allocated alongside the data.
    ///
    /// # Errors
    ///
    /// Propagates any backend error from uploading the data or the gradient.
    pub fn from_vec(device: &CudaDevice, data: Vec<f32>, requires_grad: bool) -> Result<Self> {
        let len = data.len();
        let gpu_data = GpuBuffer::from_host(&device.ctx, &data)?;

        let grad = if requires_grad {
            // Initialize gradient to zeros
            let zeros = vec![0.0f32; len];
            Some(GpuBuffer::from_host(&device.ctx, &zeros)?)
        } else {
            None
        };

        Ok(Self { data: gpu_data, grad, device: Arc::clone(&device.ctx), requires_grad, len })
    }

    /// Create a tensor filled with zeros
    ///
    /// # Errors
    ///
    /// Same as [`CudaTensor::from_vec`].
    pub fn zeros(device: &CudaDevice, len: usize, requires_grad: bool) -> Result<Self> {
        Self::from_vec(device, vec![0.0f32; len], requires_grad)
    }

    /// Create a tensor filled with ones
    ///
    /// # Errors
    ///
    /// Same as [`CudaTensor::from_vec`].
    pub fn ones(device: &CudaDevice, len: usize, requires_grad: bool) -> Result<Self> {
        Self::from_vec(device, vec![1.0f32; len], requires_grad)
    }

    /// Copy tensor data back to CPU
    ///
    /// # Errors
    ///
    /// Propagates any backend error from the device-to-host copy.
    pub fn to_vec(&self) -> Result<Vec<f32>> {
        let mut result = vec![0.0f32; self.len];
        self.data.copy_to_host(&mut result)?;
        Ok(result)
    }

    /// Get gradient as CPU vector, or `None` if no gradient buffer exists
    ///
    /// # Errors
    ///
    /// Propagates any backend error from the device-to-host copy.
    pub fn grad_to_vec(&self) -> Result<Option<Vec<f32>>> {
        match &self.grad {
            Some(grad_buf) => {
                let mut result = vec![0.0f32; self.len];
                grad_buf.copy_to_host(&mut result)?;
                Ok(Some(result))
            }
            None => Ok(None),
        }
    }

    /// Update data from CPU vector
    ///
    /// # Errors
    ///
    /// Returns [`CudaTensorError::ShapeMismatch`] if `data.len()` differs from
    /// the tensor length; otherwise propagates any backend copy error.
    pub fn copy_from_vec(&mut self, data: &[f32]) -> Result<()> {
        if data.len() != self.len {
            return Err(CudaTensorError::ShapeMismatch { expected: self.len, actual: data.len() });
        }
        self.data.copy_from_host(data)?;
        Ok(())
    }

    /// Set gradient from CPU vector
    ///
    /// Overwrites the existing gradient buffer, or allocates one if the tensor
    /// does not have a gradient yet.
    ///
    /// # Errors
    ///
    /// Returns [`CudaTensorError::ShapeMismatch`] if `grad.len()` differs from
    /// the tensor length; otherwise propagates any backend copy/allocation error.
    pub fn set_grad_from_vec(&mut self, grad: &[f32]) -> Result<()> {
        if grad.len() != self.len {
            return Err(CudaTensorError::ShapeMismatch { expected: self.len, actual: grad.len() });
        }

        match &mut self.grad {
            Some(grad_buf) => {
                grad_buf.copy_from_host(grad)?;
            }
            None => {
                // Allocate with the context this tensor already holds, exactly as
                // `from_vec` does. Creating a fresh `CudaContext::new(0)` here would
                // always target device 0 and would hand the buffer a context that
                // is dropped as soon as this statement finishes.
                self.grad = Some(GpuBuffer::from_host(&self.device, grad)?);
            }
        }
        Ok(())
    }

    /// Zero out gradient (no-op when there is no gradient buffer)
    ///
    /// # Errors
    ///
    /// Propagates any backend error from the host-to-device copy.
    pub fn zero_grad(&mut self) -> Result<()> {
        if let Some(grad_buf) = self.grad.as_mut() {
            let zeros = vec![0.0f32; self.len];
            grad_buf.copy_from_host(&zeros)?;
        }
        Ok(())
    }

    /// Check if requires gradient
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Get number of elements
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get raw GPU buffer for data (for kernel operations)
    pub fn data_buffer(&self) -> &GpuBuffer<f32> {
        &self.data
    }

    /// Get mutable raw GPU buffer for data
    pub fn data_buffer_mut(&mut self) -> &mut GpuBuffer<f32> {
        &mut self.data
    }

    /// Get raw GPU buffer for gradient (for kernel operations)
    pub fn grad_buffer(&self) -> Option<&GpuBuffer<f32>> {
        self.grad.as_ref()
    }

    /// Get mutable raw GPU buffer for gradient
    pub fn grad_buffer_mut(&mut self) -> Option<&mut GpuBuffer<f32>> {
        self.grad.as_mut()
    }

    /// Get device context
    pub fn device(&self) -> &Arc<CudaContext> {
        &self.device
    }
}
271
/// Debug output shows tensor metadata only; buffer contents are not read back
/// from the device.
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_grad = self.grad.is_some();

        let mut out = f.debug_struct("CudaTensor");
        out.field("len", &self.len);
        out.field("requires_grad", &self.requires_grad);
        out.field("has_grad", &has_grad);
        // The GPU buffers and context are deliberately omitted.
        out.finish_non_exhaustive()
    }
}
282
283// CPU fallback when CUDA is not available
284#[cfg(not(feature = "cuda"))]
285pub struct CudaDevice;
286
287#[cfg(not(feature = "cuda"))]
288impl CudaDevice {
289    pub fn new(_device_id: i32) -> Result<Self> {
290        Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
291    }
292
293    pub fn default_device() -> Result<Self> {
294        Err(CudaTensorError::CudaNotAvailable("Compiled without CUDA support".into()))
295    }
296}
297
/// Zero-sized stand-in for the GPU tensor type when the `cuda` feature is disabled.
#[cfg(not(feature = "cuda"))]
pub struct CudaTensor;
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the default device, or `None` when no CUDA driver is present so
    /// that GPU tests can skip themselves.
    #[cfg(feature = "cuda")]
    fn device_or_skip() -> Option<CudaDevice> {
        if cuda_available() {
            Some(CudaDevice::default_device().expect("operation should succeed"))
        } else {
            None
        }
    }

    /// Reads the gradient back to the host, panicking if it is missing.
    #[cfg(feature = "cuda")]
    fn host_grad(tensor: &CudaTensor) -> Vec<f32> {
        tensor
            .grad_to_vec()
            .expect("operation should succeed")
            .expect("operation should succeed")
    }

    #[test]
    fn test_cuda_tensor_error_display() {
        let unavailable = CudaTensorError::CudaNotAvailable("test".into());
        assert!(unavailable.to_string().contains("CUDA not available"));

        let mismatch = CudaTensorError::ShapeMismatch { expected: 10, actual: 5 };
        let rendered = mismatch.to_string();
        assert!(rendered.contains("10"));
        assert!(rendered.contains('5'));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_device_creation() {
        // Only meaningful on machines that actually have CUDA.
        if cuda_available() {
            assert!(CudaDevice::default_device().is_ok());
        }
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_from_vec() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let host = vec![1.0, 2.0, 3.0, 4.0];
        let tensor =
            CudaTensor::from_vec(&device, host.clone(), true).expect("operation should succeed");

        assert_eq!(tensor.len(), 4);
        assert!(tensor.requires_grad());

        // Data must survive the host -> device -> host round trip.
        let round_trip = tensor.to_vec().expect("operation should succeed");
        assert_eq!(round_trip, host);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_zeros() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let tensor = CudaTensor::zeros(&device, 100, false).expect("operation should succeed");
        assert_eq!(tensor.len(), 100);
        assert!(!tensor.requires_grad());

        let host = tensor.to_vec().expect("operation should succeed");
        assert!(host.iter().all(|&x| x == 0.0));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_ones() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let tensor = CudaTensor::ones(&device, 50, true).expect("operation should succeed");
        assert_eq!(tensor.len(), 50);

        let host = tensor.to_vec().expect("operation should succeed");
        assert!(host.iter().all(|&x| x == 1.0));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_gradient() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let mut tensor = CudaTensor::from_vec(&device, vec![1.0, 2.0, 3.0], true)
            .expect("operation should succeed");

        // A freshly created tensor starts with an all-zero gradient.
        let grad = host_grad(&tensor);
        assert!(grad.iter().all(|&x| x == 0.0));

        // Upload a gradient and read it back.
        tensor.set_grad_from_vec(&[0.1, 0.2, 0.3]).expect("operation should succeed");
        let grad = host_grad(&tensor);
        assert!((grad[0] - 0.1).abs() < 1e-6);
        assert!((grad[1] - 0.2).abs() < 1e-6);
        assert!((grad[2] - 0.3).abs() < 1e-6);

        // Clearing must bring it back to zeros.
        tensor.zero_grad().expect("gradient should be available");
        let grad = host_grad(&tensor);
        assert!(grad.iter().all(|&x| x == 0.0));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_copy_from_vec() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let mut tensor = CudaTensor::zeros(&device, 4, false).expect("operation should succeed");
        tensor.copy_from_vec(&[5.0, 6.0, 7.0, 8.0]).expect("operation should succeed");

        let host = tensor.to_vec().expect("operation should succeed");
        assert_eq!(host, vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_tensor_shape_mismatch() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };

        let mut tensor = CudaTensor::zeros(&device, 4, false).expect("operation should succeed");

        // Two elements into a four-element tensor must be rejected.
        let outcome = tensor.copy_from_vec(&[1.0, 2.0]);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CudaTensorError::ShapeMismatch { .. })));
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_not_available_fallback() {
        let outcome = CudaDevice::default_device();
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(CudaTensorError::CudaNotAvailable(_))));
    }
}