
ferrotorch_gpu/tensor_bridge.rs

//! Bridge between ferrotorch-core `Tensor<T>` and GPU operations.
//!
//! Because ferrotorch-core cannot depend on ferrotorch-gpu (that would be a
//! circular dependency), all GPU integration lives here. This module provides:
//!
//! - [`GpuTensor<T>`] — a wrapper combining a [`CudaBuffer<T>`] with shape
//!   metadata and the originating [`GpuDevice`].
//! - Transfer functions to move data between CPU [`Tensor`] and GPU.
//! - Elementwise arithmetic on [`GpuTensor`] backed by PTX kernels.
//!
//! # f32-only kernels
//!
//! The PTX kernels are currently f32-only. Elementwise operations on
//! `GpuTensor<f64>` fall back to a CPU round-trip (copy to host, compute,
//! copy back); matmul instead runs cuBLAS DGEMM directly on device, and
//! conv2d is f32-only. The type parameter is kept for API consistency — once
//! f64 PTX kernels are added, the fallback disappears transparently.
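//!
//! # Example
//!
//! A minimal round-trip sketch (requires the `cuda` feature and a visible
//! CUDA device; error-type conversions between `GpuError` and
//! `FerrotorchError` are elided):
//!
//! ```ignore
//! use ferrotorch_core::{Tensor, TensorStorage};
//!
//! // Build a 2x2 CPU tensor, move it to device 0, add it to itself,
//! // and copy the result back to the host.
//! let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]);
//! let t = Tensor::from_storage(storage, vec![2, 2], false)?;
//! let gpu = cuda_default(&t)?;
//! let doubled = gpu.add(&gpu)?;
//! let back = doubled.cpu()?;
//! assert_eq!(back.data()?, &[2.0, 4.0, 6.0, 8.0]);
//! ```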

use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::blas::{gpu_matmul_f32, gpu_matmul_f64};
use crate::buffer::CudaBuffer;
use crate::conv::gpu_conv2d_f32;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::kernels::{gpu_add, gpu_mul, gpu_neg, gpu_relu, gpu_sub};
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

// ---------------------------------------------------------------------------
// GpuFloat — Float + DeviceRepr (when cuda is enabled)
// ---------------------------------------------------------------------------
//
// cudarc's transfer functions require `T: DeviceRepr`. Both f32 and f64
// implement it, but we can't unconditionally name the trait when the `cuda`
// feature is off (cudarc isn't compiled). This helper trait bridges the gap.

/// Trait alias: `Float` types that can be transferred to/from GPU.
///
/// When the `cuda` feature is enabled this adds `cudarc::driver::DeviceRepr`.
/// When disabled, it is identical to [`Float`].
#[cfg(feature = "cuda")]
pub trait GpuFloat: Float + cudarc::driver::DeviceRepr {}

#[cfg(feature = "cuda")]
impl GpuFloat for f32 {}
#[cfg(feature = "cuda")]
impl GpuFloat for f64 {}

#[cfg(not(feature = "cuda"))]
pub trait GpuFloat: Float {}

#[cfg(not(feature = "cuda"))]
impl GpuFloat for f32 {}
#[cfg(not(feature = "cuda"))]
impl GpuFloat for f64 {}

// ---------------------------------------------------------------------------
// GpuTensor
// ---------------------------------------------------------------------------

/// A tensor residing on a CUDA GPU.
///
/// Wraps a [`CudaBuffer<T>`] with shape metadata and a reference to the
/// [`GpuDevice`] that owns the memory. Created by [`tensor_to_gpu`] or
/// the convenience functions [`cuda`] / [`cuda_default`].
///
/// Convert back to a CPU [`Tensor`] with [`GpuTensor::cpu`] or the free
/// function [`tensor_to_cpu`].
pub struct GpuTensor<T: GpuFloat> {
    buffer: CudaBuffer<T>,
    shape: Vec<usize>,
    device: GpuDevice,
}

impl<T: GpuFloat> GpuTensor<T> {
    /// The shape of this tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Total number of elements.
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// The GPU device that holds this tensor's data.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Borrow the underlying [`CudaBuffer`].
    #[inline]
    pub fn buffer(&self) -> &CudaBuffer<T> {
        &self.buffer
    }

    /// Number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Copy this tensor back to CPU, returning a [`Tensor<T>`].
    ///
    /// This is a convenience wrapper around [`tensor_to_cpu`].
    pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
        tensor_to_cpu(self)
    }
}

impl<T: GpuFloat> std::fmt::Debug for GpuTensor<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuTensor")
            .field("shape", &self.shape)
            .field("numel", &self.numel())
            .field("device_ordinal", &self.device.ordinal())
            .finish_non_exhaustive()
    }
}

// ---------------------------------------------------------------------------
// Arithmetic operations
// ---------------------------------------------------------------------------

/// Returns `true` if `T` is `f32` (the type our PTX kernels support).
#[inline]
fn is_f32<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}

/// Shape-validation helper for binary operations.
fn validate_shapes<T: GpuFloat>(a: &GpuTensor<T>, b: &GpuTensor<T>) -> GpuResult<()> {
    if a.shape() != b.shape() {
        return Err(GpuError::LengthMismatch {
            a: a.numel(),
            b: b.numel(),
        });
    }
    if a.device.ordinal() != b.device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: a.device.ordinal(),
            got: b.device.ordinal(),
        });
    }
    Ok(())
}

impl<T: GpuFloat> GpuTensor<T> {
    /// Elementwise addition: `out[i] = self[i] + other[i]`.
    ///
    /// Uses a PTX kernel for `f32`; falls back to CPU round-trip for `f64`.
    ///
    /// # Errors
    ///
    /// - [`GpuError::LengthMismatch`] if shapes differ.
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Driver`] on CUDA runtime errors.
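    ///
    /// # Example
    ///
    /// A sketch, assuming `ga` and `gb` are same-shape `GpuTensor<f32>`
    /// values on the same device (see [`tensor_to_gpu`]):
    ///
    /// ```ignore
    /// let gc = ga.add(&gb)?;
    /// assert_eq!(gc.shape(), ga.shape());
    /// ```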
    pub fn add(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            // SAFETY: We have verified T is f32 via TypeId, so T is f32-layout-compatible.
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_add(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a + b)
        }
    }

    /// Elementwise subtraction: `out[i] = self[i] - other[i]`.
    ///
    /// Uses a PTX kernel for `f32`; falls back to CPU round-trip for `f64`.
    pub fn sub(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_sub(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a - b)
        }
    }

    /// Elementwise multiplication: `out[i] = self[i] * other[i]`.
    ///
    /// Uses a PTX kernel for `f32`; falls back to CPU round-trip for `f64`.
    pub fn mul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_mul(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a * b)
        }
    }

    /// Elementwise negation: `out[i] = -self[i]`.
    ///
    /// Uses a PTX kernel for `f32`; falls back to CPU round-trip for `f64`.
    pub fn neg(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_neg(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| -x)
        }
    }

    /// Elementwise ReLU: `out[i] = max(self[i], 0)`.
    ///
    /// Uses a PTX kernel for `f32`; falls back to CPU round-trip for `f64`.
    pub fn relu(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_relu(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| {
                let z = <T as num_traits::Zero>::zero();
                if x > z { x } else { z }
            })
        }
    }

    /// Matrix multiplication: `C = self @ other`.
    ///
    /// Both tensors must be 2-D. `self` has shape `[m, k]` and `other` has
    /// shape `[k, n]`. The result has shape `[m, n]`.
    ///
    /// Uses cuBLAS SGEMM for `f32` and DGEMM for `f64`.
    ///
    /// # Errors
    ///
    /// - [`GpuError::ShapeMismatch`] if either tensor is not 2-D or if the
    ///   inner dimensions do not match (`self.shape[1] != other.shape[0]`).
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Blas`] on cuBLAS runtime errors.
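    ///
    /// # Example
    ///
    /// A shape-level sketch: multiplying `[2, 3]` by `[3, 2]` yields
    /// `[2, 2]` (the tests at the bottom of this file check the values
    /// against a hand-computed product):
    ///
    /// ```ignore
    /// let c = a.matmul(&b)?; // a: [2, 3], b: [3, 2]
    /// assert_eq!(c.shape(), &[2, 2]);
    /// ```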
    pub fn matmul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        // Validate 2-D shapes.
        if self.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0], // placeholder: "expected 2-D"
                got: self.shape.clone(),
            });
        }
        if other.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: other.shape.clone(),
            });
        }

        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];

        if k != k2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if self.device.ordinal() != other.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: other.device.ordinal(),
            });
        }

        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_matmul_f32(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        } else {
            // f64 path
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_matmul_f64(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        }
    }

    /// 2-D convolution: `output = conv2d(self, weight, bias)`.
    ///
    /// Uses im2col (CPU) + cuBLAS GEMM (GPU) — no cuDNN required.
    ///
    /// `self` must have shape `[B, C_in, H, W]` and `weight` must have
    /// shape `[C_out, C_in, kH, kW]`. `bias`, if provided, must have
    /// shape `[C_out]`. The result has shape `[B, C_out, H_out, W_out]`.
    ///
    /// Currently only supports `f32`. For `f64` tensors, returns
    /// [`GpuError::ShapeMismatch`] (f64 conv path not yet implemented).
    ///
    /// # Errors
    ///
    /// - [`GpuError::ShapeMismatch`] if tensor dimensions are wrong, channel
    ///   counts don't match, or if `T` is not `f32`.
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Blas`] on cuBLAS runtime errors.
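    ///
    /// # Example
    ///
    /// A shape-level sketch: a 3x3 kernel with stride `(1, 1)` and padding
    /// `(1, 1)` preserves spatial size, since
    /// `H_out = (H + 2 * pad - kH) / stride + 1 = (8 + 2 - 3) / 1 + 1 = 8`:
    ///
    /// ```ignore
    /// // input: [B=1, C_in=3, H=8, W=8], weight: [C_out=16, C_in=3, kH=3, kW=3]
    /// let out = input.conv2d(&weight, None, (1, 1), (1, 1))?;
    /// assert_eq!(out.shape(), &[1, 16, 8, 8]);
    /// ```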
    pub fn conv2d(
        &self,
        weight: &GpuTensor<T>,
        bias: Option<&GpuTensor<T>>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> GpuResult<GpuTensor<T>> {
        // Validate 4-D input.
        if self.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: self.shape.clone(),
            });
        }
        // Validate 4-D weight.
        if weight.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: weight.shape.clone(),
            });
        }
        // Validate 1-D bias.
        if let Some(b) = bias {
            if b.ndim() != 1 {
                return Err(GpuError::ShapeMismatch {
                    op: "conv2d",
                    expected: vec![weight.shape[0]],
                    got: b.shape.clone(),
                });
            }
        }
        // Device consistency.
        if self.device.ordinal() != weight.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: weight.device.ordinal(),
            });
        }
        if let Some(b) = bias {
            if self.device.ordinal() != b.device.ordinal() {
                return Err(GpuError::DeviceMismatch {
                    expected: self.device.ordinal(),
                    got: b.device.ordinal(),
                });
            }
        }

        if !is_f32::<T>() {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![],
                got: vec![],
            });
        }

        let input_shape: [usize; 4] = [self.shape[0], self.shape[1], self.shape[2], self.shape[3]];
        let weight_shape: [usize; 4] = [
            weight.shape[0],
            weight.shape[1],
            weight.shape[2],
            weight.shape[3],
        ];

        let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
        let w_buf = unsafe { transmute_buffer_ref::<T, f32>(&weight.buffer) };
        let b_buf = bias.map(|b| unsafe { transmute_buffer_ref::<T, f32>(&b.buffer) });

        let (out_buf, out_shape) = gpu_conv2d_f32(
            a_buf,
            w_buf,
            b_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            &self.device,
        )?;

        let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
        Ok(GpuTensor {
            buffer: out_buf,
            shape: out_shape.to_vec(),
            device: self.device.clone(),
        })
    }
}

// ---------------------------------------------------------------------------
// Transmute helpers for CudaBuffer<T> <-> CudaBuffer<f32>
// ---------------------------------------------------------------------------
//
// The PTX kernel functions take `&CudaBuffer<f32>`. When T == f32 the
// CudaBuffer<T> has an identical layout, so we can safely reinterpret.
// These helpers are only called after an `is_f32::<T>()` guard.

/// Reinterpret a `&CudaBuffer<T>` as `&CudaBuffer<U>`.
///
/// # Safety
///
/// Caller must have verified `size_of::<T>() == size_of::<U>()` and
/// `align_of::<T>() == align_of::<U>()`.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<U>());
    // CudaBuffer<T> and CudaBuffer<U> have identical layout when T and U
    // are the same size — CudaSlice is a device pointer + length, both
    // size-independent.
    unsafe { &*(buf as *const CudaBuffer<T> as *const CudaBuffer<U>) }
}

/// Reinterpret an owned `CudaBuffer<U>` as `CudaBuffer<T>`.
///
/// # Safety
///
/// Caller must have verified `size_of::<U>() == size_of::<T>()` and
/// `align_of::<U>() == align_of::<T>()`.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    debug_assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
    debug_assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
    // Move the buffer without running U's drop — T's drop will handle it.
    let result = unsafe { std::ptr::read(&buf as *const CudaBuffer<U> as *const CudaBuffer<T>) };
    std::mem::forget(buf);
    result
}

// Stubs when cuda is not enabled — the transmute helpers are never called
// because all kernel functions return NoCudaFeature, but we still need them
// to exist so the module compiles.

#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    let _ = buf;
    unreachable!("transmute_buffer_ref called without cuda feature")
}

#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    let _ = buf;
    unreachable!("transmute_buffer called without cuda feature")
}

// ---------------------------------------------------------------------------
// CPU fallback helpers for non-f32 types
// ---------------------------------------------------------------------------

/// Binary operation fallback: copy both operands to CPU, apply `op`, copy back.
fn binary_cpu_fallback<T: GpuFloat>(
    a: &GpuTensor<T>,
    b: &GpuTensor<T>,
    op: fn(T, T) -> T,
) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let b_cpu = gpu_to_cpu(&b.buffer, &b.device)?;
    let result: Vec<T> = a_cpu
        .iter()
        .zip(b_cpu.iter())
        .map(|(&x, &y)| op(x, y))
        .collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

/// Unary operation fallback: copy operand to CPU, apply `op`, copy back.
fn unary_cpu_fallback<T: GpuFloat>(a: &GpuTensor<T>, op: fn(T) -> T) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let result: Vec<T> = a_cpu.iter().map(|&x| op(x)).collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

// ---------------------------------------------------------------------------
// Transfer functions: Tensor <-> GpuTensor
// ---------------------------------------------------------------------------

/// Move a CPU [`Tensor<T>`] to a GPU, returning a [`GpuTensor<T>`].
///
/// The tensor must be contiguous and reside on `Device::Cpu`. Its flat data
/// is copied to the given [`GpuDevice`] via a host-to-device transfer.
///
/// # Errors
///
/// - Returns [`GpuError::Driver`] on CUDA allocation or copy failures.
/// - Returns [`GpuError::LengthMismatch`] if the tensor is not contiguous.
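///
/// # Example
///
/// A sketch, assuming `t` is a contiguous CPU `Tensor<f32>`:
///
/// ```ignore
/// let device = GpuDevice::new(0)?;
/// let gpu = tensor_to_gpu(&t, &device)?;
/// assert_eq!(gpu.shape(), t.shape());
/// ```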
pub fn tensor_to_gpu<T: GpuFloat>(
    tensor: &Tensor<T>,
    device: &GpuDevice,
) -> GpuResult<GpuTensor<T>> {
    // Ensure the tensor is contiguous so data() gives a proper flat slice.
    if !tensor.is_contiguous() {
        return Err(GpuError::LengthMismatch {
            a: tensor.numel(),
            b: tensor.data().map_or(0, |d| d.len()),
        });
    }

    // Extract the flat data from the CPU tensor.
    let data = tensor.data().map_err(|_e| GpuError::InvalidDevice {
        ordinal: device.ordinal(),
        count: 0,
    })?;

    let buffer = cpu_to_gpu(data, device)?;
    Ok(GpuTensor {
        buffer,
        shape: tensor.shape().to_vec(),
        device: device.clone(),
    })
}

/// Move a [`GpuTensor<T>`] back to CPU, returning a [`Tensor<T>`].
///
/// Performs a device-to-host copy and wraps the resulting `Vec<T>` in a
/// new leaf [`Tensor`] with `requires_grad = false`.
///
/// # Errors
///
/// - Returns an error if the device-to-host copy fails.
pub fn tensor_to_cpu<T: GpuFloat>(gpu_tensor: &GpuTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let host_data = gpu_to_cpu(&gpu_tensor.buffer, &gpu_tensor.device).map_err(|e| {
        FerrotorchError::InvalidArgument {
            message: format!("GPU-to-CPU transfer failed: {e}"),
        }
    })?;

    let storage = TensorStorage::cpu(host_data);
    Tensor::from_storage(storage, gpu_tensor.shape.clone(), false)
}

// ---------------------------------------------------------------------------
// Convenience free functions
// ---------------------------------------------------------------------------

/// Move a CPU [`Tensor<T>`] to CUDA device with the given ordinal.
///
/// Shorthand for creating a [`GpuDevice`] and calling [`tensor_to_gpu`].
///
/// # Errors
///
/// Returns a [`GpuError`] if the device cannot be initialized or the
/// transfer fails.
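///
/// # Example
///
/// ```ignore
/// // Sketch: move `t` to the second CUDA device (ordinal 1), if present.
/// let gpu = cuda(&t, 1)?;
/// assert_eq!(gpu.device().ordinal(), 1);
/// ```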
pub fn cuda<T: GpuFloat>(tensor: &Tensor<T>, ordinal: usize) -> GpuResult<GpuTensor<T>> {
    let device = GpuDevice::new(ordinal)?;
    tensor_to_gpu(tensor, &device)
}

/// Move a CPU [`Tensor<T>`] to CUDA device 0.
///
/// Equivalent to `cuda(tensor, 0)`.
pub fn cuda_default<T: GpuFloat>(tensor: &Tensor<T>) -> GpuResult<GpuTensor<T>> {
    cuda(tensor, 0)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    /// Helper: create a CPU tensor from a flat vec with the given shape.
    fn cpu_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape, false).expect("cpu_tensor")
    }

    // -- round-trip -----------------------------------------------------------

    #[test]
    fn tensor_to_gpu_round_trip() {
        let t = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = gpu.cpu().expect("cpu()");

        assert_eq!(back.shape(), &[2, 3]);
        assert_eq!(back.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // -- shape preservation ---------------------------------------------------

    #[test]
    fn gpu_tensor_shape_preserved() {
        let t = cpu_tensor(vec![1.0; 24], vec![2, 3, 4]);
        let gpu = cuda_default(&t).expect("cuda_default");

        assert_eq!(gpu.shape(), &[2, 3, 4]);
        assert_eq!(gpu.numel(), 24);
        assert_eq!(gpu.ndim(), 3);
    }

    // -- add ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_add() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = cpu_tensor(vec![10.0, 20.0, 30.0, 40.0], vec![4]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.add(&gb).expect("gpu add");
        let result = gc.cpu().expect("cpu");

        assert_eq!(result.shape(), &[4]);
        let data = result.data().unwrap();
        assert!((data[0] - 11.0).abs() < 1e-6);
        assert!((data[1] - 22.0).abs() < 1e-6);
        assert!((data[2] - 33.0).abs() < 1e-6);
        assert!((data[3] - 44.0).abs() < 1e-6);
    }

    // -- relu -----------------------------------------------------------------

    #[test]
    fn gpu_tensor_relu() {
        let t = cpu_tensor(vec![-3.0, -1.0, 0.0, 1.0, 3.0], vec![5]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.relu().expect("relu");
        let result = out.cpu().expect("cpu");

        let data = result.data().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - 1.0).abs() < 1e-6);
        assert!((data[4] - 3.0).abs() < 1e-6);
    }

    // -- tensor_to_cpu values -------------------------------------------------

    #[test]
    fn tensor_to_cpu_correct_values() {
        let original = vec![0.5, -1.5, 2.25, 0.0, 100.0, -0.001];
        let t = cpu_tensor(original.clone(), vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = tensor_to_cpu(&gpu).expect("tensor_to_cpu");

        let data = back.data().unwrap();
        for (i, (&got, &expected)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "element {i}: got {got}, expected {expected}",
            );
        }
    }

    // -- sub ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_sub() {
        let a = cpu_tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.sub(&gb).expect("gpu sub");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 9.0).abs() < 1e-6);
        assert!((data[1] - 18.0).abs() < 1e-6);
        assert!((data[2] - 27.0).abs() < 1e-6);
    }

    // -- mul ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_mul() {
        let a = cpu_tensor(vec![2.0, 3.0, 4.0], vec![3]);
        let b = cpu_tensor(vec![10.0, 10.0, 10.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.mul(&gb).expect("gpu mul");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 20.0).abs() < 1e-6);
        assert!((data[1] - 30.0).abs() < 1e-6);
        assert!((data[2] - 40.0).abs() < 1e-6);
    }

    // -- neg ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_neg() {
        let t = cpu_tensor(vec![1.0, -2.0, 0.0, 3.5], vec![4]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.neg().expect("neg");
        let result = out.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - (-3.5)).abs() < 1e-6);
    }

    // -- matmul ---------------------------------------------------------------

    #[test]
    fn gpu_tensor_matmul_basic() {
        // A = [[1, 2, 3], [4, 5, 6]]  (2x3)
        // B = [[7, 8], [9, 10], [11, 12]]  (3x2)
        // C = [[58, 64], [139, 154]]  (2x2)
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.matmul(&gb).expect("gpu matmul");
        assert_eq!(gc.shape(), &[2, 2]);

        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 58.0).abs() < 1e-4);
        assert!((data[1] - 64.0).abs() < 1e-4);
        assert!((data[2] - 139.0).abs() < 1e-4);
        assert!((data[3] - 154.0).abs() < 1e-4);
    }

    #[test]
    fn gpu_tensor_matmul_identity() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let i = cpu_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gi = tensor_to_gpu(&i, &device).expect("i to gpu");

        let gc = ga.matmul(&gi).expect("gpu matmul identity");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 3.0).abs() < 1e-6);
        assert!((data[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_inner_dim_mismatch() {
        // A is [2, 3], B is [2, 2] -- inner dims 3 != 2
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_matmul_not_2d() {
        // A is [6] (1-D), should fail
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- shape mismatch error -------------------------------------------------

    #[test]
    fn gpu_tensor_add_shape_mismatch() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0], vec![2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.add(&gb).unwrap_err();
        match err {
            GpuError::LengthMismatch { .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- empty tensor ---------------------------------------------------------

    #[test]
    fn gpu_tensor_empty_round_trip() {
        let t = cpu_tensor(vec![], vec![0]);
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.numel(), 0);
        assert_eq!(gpu.shape(), &[0]);

        let back = gpu.cpu().expect("cpu");
        assert_eq!(back.shape(), &[0]);
        assert_eq!(back.data().unwrap().len(), 0);
    }

    // -- scalar tensor --------------------------------------------------------

    #[test]
    fn gpu_tensor_scalar_round_trip() {
        let storage = TensorStorage::cpu(vec![42.0f32]);
        let t = Tensor::from_storage(storage, vec![], false).expect("scalar");
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.shape(), &[] as &[usize]);
        assert_eq!(gpu.numel(), 1);

        let back = gpu.cpu().expect("cpu");
        assert!(back.is_scalar());
        assert!((back.item().unwrap() - 42.0).abs() < 1e-6);
    }
}
885}