ferrotorch_gpu/tensor_bridge.rs

//! Bridge between ferrotorch-core `Tensor<T>` and GPU operations.
//!
//! Because ferrotorch-core cannot depend on ferrotorch-gpu (that would be a
//! circular dependency), all GPU integration lives here. This module provides:
//!
//! - [`GpuTensor<T>`] — a wrapper combining a [`CudaBuffer<T>`] with shape
//!   metadata and the originating [`GpuDevice`].
//! - Transfer functions to move data between CPU [`Tensor`] and GPU.
//! - Elementwise arithmetic on [`GpuTensor`] backed by PTX kernels.
//!
//! # Kernel coverage
//!
//! Dedicated PTX kernels exist for both `f32` and `f64` elementwise ops, and
//! cuBLAS handles `matmul` for both types. Any other [`GpuFloat`] type falls
//! back to a CPU round-trip (copy to host, compute, copy back); since
//! `GpuFloat` is only implemented for `f32` and `f64`, that fallback is
//! purely defensive. 2-D convolution (`conv2d`) is currently `f32`-only.
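//!
//! # Example
//!
//! A minimal round-trip sketch, marked `no_run` because it needs a
//! CUDA-capable device at ordinal 0; the `use` paths below are assumed
//! from the crate layout rather than verified:
//!
//! ```no_run
//! use ferrotorch_core::{Tensor, TensorStorage};
//! use ferrotorch_gpu::tensor_bridge::cuda_default;
//!
//! // Build a contiguous 2x3 CPU tensor from flat data.
//! let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
//! let t = Tensor::from_storage(storage, vec![2, 3], false).expect("tensor");
//!
//! // Host-to-device copy onto CUDA device 0, then device-to-host back.
//! let gpu = cuda_default(&t).expect("transfer to GPU");
//! assert_eq!(gpu.shape(), &[2, 3]);
//! let back = gpu.cpu().expect("transfer back");
//! assert_eq!(back.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
//! ```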

use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::blas::{gpu_matmul_f32, gpu_matmul_f64};
use crate::buffer::CudaBuffer;
use crate::conv::gpu_conv2d_f32;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::kernels::{
    gpu_add, gpu_add_f64, gpu_mul, gpu_mul_f64, gpu_neg, gpu_neg_f64, gpu_relu, gpu_relu_f64,
    gpu_sub, gpu_sub_f64,
};
use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

// ---------------------------------------------------------------------------
// GpuFloat — Float + DeviceRepr (when cuda is enabled)
// ---------------------------------------------------------------------------
//
// cudarc's transfer functions require `T: DeviceRepr`. Both f32 and f64
// implement it, but we can't unconditionally name the trait when the `cuda`
// feature is off (cudarc isn't compiled). This helper trait bridges the gap.

/// Trait alias: `Float` types that can be transferred to/from GPU.
///
/// When the `cuda` feature is enabled this adds `cudarc::driver::DeviceRepr`.
/// When disabled, it is identical to [`Float`].
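///
/// A sketch of how the bound is meant to be used: one generic function that
/// accepts both `f32` and `f64` tensors (the module path is assumed, not
/// verified against the crate layout):
///
/// ```no_run
/// use ferrotorch_gpu::tensor_bridge::{GpuFloat, GpuTensor};
///
/// // Generic over any float type that can live on the GPU.
/// fn describe<T: GpuFloat>(t: &GpuTensor<T>) -> String {
///     format!("{} elements across {} dims", t.numel(), t.ndim())
/// }
/// ```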
#[cfg(feature = "cuda")]
pub trait GpuFloat: Float + cudarc::driver::DeviceRepr {}

#[cfg(feature = "cuda")]
impl GpuFloat for f32 {}
#[cfg(feature = "cuda")]
impl GpuFloat for f64 {}

#[cfg(not(feature = "cuda"))]
pub trait GpuFloat: Float {}

#[cfg(not(feature = "cuda"))]
impl GpuFloat for f32 {}
#[cfg(not(feature = "cuda"))]
impl GpuFloat for f64 {}

// ---------------------------------------------------------------------------
// GpuTensor
// ---------------------------------------------------------------------------

/// A tensor residing on a CUDA GPU.
///
/// Wraps a [`CudaBuffer<T>`] with shape metadata and a reference to the
/// [`GpuDevice`] that owns the memory. Created by [`tensor_to_gpu`] or
/// the convenience functions [`cuda`] / [`cuda_default`].
///
/// Convert back to a CPU [`Tensor`] with [`GpuTensor::cpu`] or the free
/// function [`tensor_to_cpu`].
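///
/// # Examples
///
/// A metadata-inspection sketch, assuming a CUDA device at ordinal 0
/// (crate paths assumed):
///
/// ```no_run
/// use ferrotorch_core::{Tensor, TensorStorage};
/// use ferrotorch_gpu::tensor_bridge::cuda_default;
///
/// let storage = TensorStorage::cpu(vec![0.0f32; 24]);
/// let t = Tensor::from_storage(storage, vec![2, 3, 4], false).expect("tensor");
/// let gpu = cuda_default(&t).expect("to gpu");
///
/// assert_eq!(gpu.shape(), &[2, 3, 4]);
/// assert_eq!(gpu.numel(), 24);
/// assert_eq!(gpu.ndim(), 3);
/// ```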
pub struct GpuTensor<T: GpuFloat> {
    buffer: CudaBuffer<T>,
    shape: Vec<usize>,
    device: GpuDevice,
}

impl<T: GpuFloat> GpuTensor<T> {
    /// The shape of this tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Total number of elements.
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// The GPU device that holds this tensor's data.
    #[inline]
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Borrow the underlying [`CudaBuffer`].
    #[inline]
    pub fn buffer(&self) -> &CudaBuffer<T> {
        &self.buffer
    }

    /// Number of dimensions.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Copy this tensor back to CPU, returning a [`Tensor<T>`].
    ///
    /// This is a convenience wrapper around [`tensor_to_cpu`].
    pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
        tensor_to_cpu(self)
    }
}

impl<T: GpuFloat> std::fmt::Debug for GpuTensor<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuTensor")
            .field("shape", &self.shape)
            .field("numel", &self.numel())
            .field("device_ordinal", &self.device.ordinal())
            .finish_non_exhaustive()
    }
}

// ---------------------------------------------------------------------------
// Arithmetic operations
// ---------------------------------------------------------------------------

/// Returns `true` if `T` is `f32`.
#[inline]
fn is_f32<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
}

/// Returns `true` if `T` is `f64`.
#[inline]
fn is_f64<T: GpuFloat>() -> bool {
    std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>()
}

/// Shape-validation helper for binary operations.
fn validate_shapes<T: GpuFloat>(a: &GpuTensor<T>, b: &GpuTensor<T>) -> GpuResult<()> {
    if a.shape() != b.shape() {
        return Err(GpuError::LengthMismatch {
            a: a.numel(),
            b: b.numel(),
        });
    }
    if a.device.ordinal() != b.device.ordinal() {
        return Err(GpuError::DeviceMismatch {
            expected: a.device.ordinal(),
            got: b.device.ordinal(),
        });
    }
    Ok(())
}

impl<T: GpuFloat> GpuTensor<T> {
    /// Elementwise addition: `out[i] = self[i] + other[i]`.
    ///
    /// Uses a PTX kernel for `f32` and `f64`; any other `GpuFloat` type
    /// falls back to a CPU round-trip.
    ///
    /// # Errors
    ///
    /// - [`GpuError::LengthMismatch`] if shapes differ.
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Driver`] on CUDA runtime errors.
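    ///
    /// # Examples
    ///
    /// A sketch of elementwise addition on device 0, mirroring the unit
    /// tests below (paths assumed; `no_run` because it needs CUDA hardware):
    ///
    /// ```no_run
    /// use ferrotorch_core::{Tensor, TensorStorage};
    /// use ferrotorch_gpu::device::GpuDevice;
    /// use ferrotorch_gpu::tensor_bridge::tensor_to_gpu;
    ///
    /// let a = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false)
    ///     .expect("a");
    /// let b = Tensor::from_storage(TensorStorage::cpu(vec![10.0f32, 20.0]), vec![2], false)
    ///     .expect("b");
    /// let device = GpuDevice::new(0).expect("device");
    /// let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
    /// let gb = tensor_to_gpu(&b, &device).expect("b to gpu");
    ///
    /// // out[i] = a[i] + b[i], computed by the PTX kernel.
    /// let sum = ga.add(&gb).expect("gpu add").cpu().expect("to cpu");
    /// assert_eq!(sum.data().unwrap(), &[11.0, 22.0]);
    /// ```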
    pub fn add(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_add(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_add_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a + b)
        }
    }

    /// Elementwise subtraction: `out[i] = self[i] - other[i]`.
    ///
    /// Uses a PTX kernel for `f32` and `f64`; any other `GpuFloat` type
    /// falls back to a CPU round-trip.
    pub fn sub(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_sub(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_sub_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a - b)
        }
    }

    /// Elementwise multiplication: `out[i] = self[i] * other[i]`.
    ///
    /// Uses a PTX kernel for `f32` and `f64`; any other `GpuFloat` type
    /// falls back to a CPU round-trip.
    pub fn mul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        validate_shapes(self, other)?;
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_mul(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_mul_f64(a_buf, b_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            binary_cpu_fallback(self, other, |a, b| a * b)
        }
    }

    /// Elementwise negation: `out[i] = -self[i]`.
    ///
    /// Uses a PTX kernel for `f32` and `f64`; any other `GpuFloat` type
    /// falls back to a CPU round-trip.
    pub fn neg(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_neg(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let out_buf = gpu_neg_f64(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| -x)
        }
    }

    /// Elementwise ReLU: `out[i] = max(self[i], 0)`.
    ///
    /// Uses a PTX kernel for `f32` and `f64`; any other `GpuFloat` type
    /// falls back to a CPU round-trip.
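    ///
    /// # Examples
    ///
    /// A sketch showing negatives clamped to zero (paths assumed; needs
    /// CUDA hardware, hence `no_run`):
    ///
    /// ```no_run
    /// use ferrotorch_core::{Tensor, TensorStorage};
    /// use ferrotorch_gpu::tensor_bridge::cuda_default;
    ///
    /// let storage = TensorStorage::cpu(vec![-3.0f32, 0.0, 1.5]);
    /// let t = Tensor::from_storage(storage, vec![3], false).expect("tensor");
    /// let out = cuda_default(&t).expect("to gpu").relu().expect("relu");
    /// assert_eq!(out.cpu().expect("to cpu").data().unwrap(), &[0.0, 0.0, 1.5]);
    /// ```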
    pub fn relu(&self) -> GpuResult<GpuTensor<T>> {
        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let out_buf = gpu_relu(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else if is_f64::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let out_buf = gpu_relu_f64(a_buf, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: self.shape.clone(),
                device: self.device.clone(),
            })
        } else {
            unary_cpu_fallback(self, |x| {
                let z = <T as num_traits::Zero>::zero();
                if x > z { x } else { z }
            })
        }
    }

    /// Matrix multiplication: `C = self @ other`.
    ///
    /// Both tensors must be 2-D. `self` has shape `[m, k]` and `other` has
    /// shape `[k, n]`. The result has shape `[m, n]`.
    ///
    /// Uses cuBLAS SGEMM for `f32` and DGEMM for `f64`.
    ///
    /// # Errors
    ///
    /// - [`GpuError::ShapeMismatch`] if either tensor is not 2-D or if the
    ///   inner dimensions do not match (`self.shape[1] != other.shape[0]`).
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Blas`] on cuBLAS runtime errors.
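    ///
    /// # Examples
    ///
    /// A `[2, 3] @ [3, 2]` sketch mirroring the unit test below (paths
    /// assumed; needs CUDA hardware, hence `no_run`):
    ///
    /// ```no_run
    /// use ferrotorch_core::{Tensor, TensorStorage};
    /// use ferrotorch_gpu::device::GpuDevice;
    /// use ferrotorch_gpu::tensor_bridge::tensor_to_gpu;
    ///
    /// // A = [[1, 2, 3], [4, 5, 6]], B = [[7, 8], [9, 10], [11, 12]]
    /// let a = Tensor::from_storage(
    ///     TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]), vec![2, 3], false,
    /// ).expect("a");
    /// let b = Tensor::from_storage(
    ///     TensorStorage::cpu(vec![7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0]), vec![3, 2], false,
    /// ).expect("b");
    /// let device = GpuDevice::new(0).expect("device");
    /// let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
    /// let gb = tensor_to_gpu(&b, &device).expect("b to gpu");
    ///
    /// // C = [[58, 64], [139, 154]]
    /// let c = ga.matmul(&gb).expect("matmul");
    /// assert_eq!(c.shape(), &[2, 2]);
    /// ```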
    pub fn matmul(&self, other: &GpuTensor<T>) -> GpuResult<GpuTensor<T>> {
        // Validate 2-D shapes.
        if self.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0], // placeholder: "expected 2-D"
                got: self.shape.clone(),
            });
        }
        if other.ndim() != 2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![0, 0],
                got: other.shape.clone(),
            });
        }

        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];

        if k != k2 {
            return Err(GpuError::ShapeMismatch {
                op: "matmul",
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if self.device.ordinal() != other.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: other.device.ordinal(),
            });
        }

        if is_f32::<T>() {
            let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f32>(&other.buffer) };
            let out_buf = gpu_matmul_f32(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        } else {
            // f64 path
            let a_buf = unsafe { transmute_buffer_ref::<T, f64>(&self.buffer) };
            let b_buf = unsafe { transmute_buffer_ref::<T, f64>(&other.buffer) };
            let out_buf = gpu_matmul_f64(a_buf, b_buf, m, k, n, &self.device)?;
            let out_buf = unsafe { transmute_buffer::<f64, T>(out_buf) };
            Ok(GpuTensor {
                buffer: out_buf,
                shape: vec![m, n],
                device: self.device.clone(),
            })
        }
    }

    /// 2-D convolution: `output = conv2d(self, weight, bias)`.
    ///
    /// Uses im2col (CPU) + cuBLAS GEMM (GPU) — no cuDNN required.
    ///
    /// `self` must have shape `[B, C_in, H, W]` and `weight` must have
    /// shape `[C_out, C_in, kH, kW]`. `bias`, if provided, must have
    /// shape `[C_out]`. The result has shape `[B, C_out, H_out, W_out]`.
    ///
    /// Currently only supports `f32`. For `f64` tensors, returns
    /// [`GpuError::ShapeMismatch`] (f64 conv path not yet implemented).
    ///
    /// # Errors
    ///
    /// - [`GpuError::ShapeMismatch`] if tensor dimensions are wrong, channel
    ///   counts don't match, or if `T` is not `f32`.
    /// - [`GpuError::DeviceMismatch`] if tensors are on different devices.
    /// - [`GpuError::Blas`] on cuBLAS runtime errors.
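    ///
    /// # Examples
    ///
    /// A minimal `f32` sketch: one 3x3 single-channel image convolved with
    /// a single 2x2 filter, stride 1, no padding, giving a 2x2 output
    /// (paths assumed; needs CUDA hardware, hence `no_run`):
    ///
    /// ```no_run
    /// use ferrotorch_core::{Tensor, TensorStorage};
    /// use ferrotorch_gpu::device::GpuDevice;
    /// use ferrotorch_gpu::tensor_bridge::tensor_to_gpu;
    ///
    /// let input = Tensor::from_storage(
    ///     TensorStorage::cpu(vec![1.0f32; 9]), vec![1, 1, 3, 3], false,
    /// ).expect("input");
    /// let weight = Tensor::from_storage(
    ///     TensorStorage::cpu(vec![1.0f32; 4]), vec![1, 1, 2, 2], false,
    /// ).expect("weight");
    /// let device = GpuDevice::new(0).expect("device");
    /// let gi = tensor_to_gpu(&input, &device).expect("input to gpu");
    /// let gw = tensor_to_gpu(&weight, &device).expect("weight to gpu");
    ///
    /// // H_out = (3 + 2*0 - 2) / 1 + 1 = 2, likewise W_out.
    /// let out = gi.conv2d(&gw, None, (1, 1), (0, 0)).expect("conv2d");
    /// assert_eq!(out.shape(), &[1, 1, 2, 2]);
    /// ```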
    pub fn conv2d(
        &self,
        weight: &GpuTensor<T>,
        bias: Option<&GpuTensor<T>>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> GpuResult<GpuTensor<T>> {
        // Validate 4-D input.
        if self.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: self.shape.clone(),
            });
        }
        // Validate 4-D weight.
        if weight.ndim() != 4 {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![0, 0, 0, 0],
                got: weight.shape.clone(),
            });
        }
        // Validate 1-D bias.
        if let Some(b) = bias {
            if b.ndim() != 1 {
                return Err(GpuError::ShapeMismatch {
                    op: "conv2d",
                    expected: vec![weight.shape[0]],
                    got: b.shape.clone(),
                });
            }
        }
        // Device consistency.
        if self.device.ordinal() != weight.device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: self.device.ordinal(),
                got: weight.device.ordinal(),
            });
        }
        if let Some(b) = bias {
            if self.device.ordinal() != b.device.ordinal() {
                return Err(GpuError::DeviceMismatch {
                    expected: self.device.ordinal(),
                    got: b.device.ordinal(),
                });
            }
        }

        if !is_f32::<T>() {
            return Err(GpuError::ShapeMismatch {
                op: "conv2d",
                expected: vec![],
                got: vec![],
            });
        }

        let input_shape: [usize; 4] = [self.shape[0], self.shape[1], self.shape[2], self.shape[3]];
        let weight_shape: [usize; 4] = [
            weight.shape[0],
            weight.shape[1],
            weight.shape[2],
            weight.shape[3],
        ];

        let a_buf = unsafe { transmute_buffer_ref::<T, f32>(&self.buffer) };
        let w_buf = unsafe { transmute_buffer_ref::<T, f32>(&weight.buffer) };
        let b_buf = bias.map(|b| unsafe { transmute_buffer_ref::<T, f32>(&b.buffer) });

        let (out_buf, out_shape) = gpu_conv2d_f32(
            a_buf,
            w_buf,
            b_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            &self.device,
        )?;

        let out_buf = unsafe { transmute_buffer::<f32, T>(out_buf) };
        Ok(GpuTensor {
            buffer: out_buf,
            shape: out_shape.to_vec(),
            device: self.device.clone(),
        })
    }
}

// ---------------------------------------------------------------------------
// Transmute helpers for CudaBuffer<T> <-> CudaBuffer<f32>/CudaBuffer<f64>
// ---------------------------------------------------------------------------
//
// The PTX kernel functions take `&CudaBuffer<f32>` (or `&CudaBuffer<f64>`).
// When T is the matching concrete type, CudaBuffer<T> has an identical
// layout, so we can safely reinterpret. These helpers are only called after
// an `is_f32::<T>()` / `is_f64::<T>()` guard.

/// Reinterpret a `&CudaBuffer<T>` as `&CudaBuffer<U>`.
///
/// # Safety
///
/// Caller must have verified `size_of::<T>() == size_of::<U>()` and
/// `align_of::<T>() == align_of::<U>()`.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<U>());
    debug_assert_eq!(std::mem::align_of::<T>(), std::mem::align_of::<U>());
    // CudaBuffer<T> and CudaBuffer<U> have identical layout when T and U
    // are the same size — CudaSlice is a device pointer + length, both
    // size-independent.
    unsafe { &*(buf as *const CudaBuffer<T> as *const CudaBuffer<U>) }
}

/// Reinterpret an owned `CudaBuffer<U>` as `CudaBuffer<T>`.
///
/// # Safety
///
/// Caller must have verified `size_of::<U>() == size_of::<T>()` and
/// `align_of::<U>() == align_of::<T>()`.
#[cfg(feature = "cuda")]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    debug_assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
    debug_assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
    // Move the buffer without running U's drop — T's drop will handle it.
    let result = unsafe { std::ptr::read(&buf as *const CudaBuffer<U> as *const CudaBuffer<T>) };
    std::mem::forget(buf);
    result
}

// Stubs when cuda is not enabled — the transmute helpers are never called
// because all kernel functions return NoCudaFeature, but we still need them
// to exist so the module compiles.

#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer_ref<T, U>(buf: &CudaBuffer<T>) -> &CudaBuffer<U> {
    let _ = buf;
    unreachable!("transmute_buffer_ref called without cuda feature")
}

#[cfg(not(feature = "cuda"))]
unsafe fn transmute_buffer<U, T>(buf: CudaBuffer<U>) -> CudaBuffer<T> {
    let _ = buf;
    unreachable!("transmute_buffer called without cuda feature")
}

// ---------------------------------------------------------------------------
// CPU fallback helpers for types without dedicated GPU kernels
// ---------------------------------------------------------------------------

/// Binary operation fallback: copy both operands to CPU, apply `op`, copy back.
fn binary_cpu_fallback<T: GpuFloat>(
    a: &GpuTensor<T>,
    b: &GpuTensor<T>,
    op: fn(T, T) -> T,
) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let b_cpu = gpu_to_cpu(&b.buffer, &b.device)?;
    let result: Vec<T> = a_cpu
        .iter()
        .zip(b_cpu.iter())
        .map(|(&x, &y)| op(x, y))
        .collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

/// Unary operation fallback: copy operand to CPU, apply `op`, copy back.
fn unary_cpu_fallback<T: GpuFloat>(a: &GpuTensor<T>, op: fn(T) -> T) -> GpuResult<GpuTensor<T>> {
    let a_cpu = gpu_to_cpu(&a.buffer, &a.device)?;
    let result: Vec<T> = a_cpu.iter().map(|&x| op(x)).collect();
    let out_buf = cpu_to_gpu(&result, &a.device)?;
    Ok(GpuTensor {
        buffer: out_buf,
        shape: a.shape.clone(),
        device: a.device.clone(),
    })
}

// ---------------------------------------------------------------------------
// Transfer functions: Tensor <-> GpuTensor
// ---------------------------------------------------------------------------

/// Move a CPU [`Tensor<T>`] to a GPU, returning a [`GpuTensor<T>`].
///
/// The tensor must be contiguous and reside on `Device::Cpu`. Its flat data
/// is copied to the given [`GpuDevice`] via a host-to-device transfer.
///
/// # Errors
///
/// - Returns [`GpuError::Driver`] on CUDA allocation or copy failures.
/// - Returns [`GpuError::LengthMismatch`] if the tensor is not contiguous.
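///
/// # Examples
///
/// A host-to-device sketch targeting an explicit device (paths assumed;
/// needs CUDA hardware, hence `no_run`):
///
/// ```no_run
/// use ferrotorch_core::{Tensor, TensorStorage};
/// use ferrotorch_gpu::device::GpuDevice;
/// use ferrotorch_gpu::tensor_bridge::tensor_to_gpu;
///
/// let t = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32, 2.0]), vec![2], false)
///     .expect("tensor");
/// let device = GpuDevice::new(0).expect("device");
/// let gpu = tensor_to_gpu(&t, &device).expect("transfer");
/// assert_eq!(gpu.device().ordinal(), 0);
/// ```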
pub fn tensor_to_gpu<T: GpuFloat>(
    tensor: &Tensor<T>,
    device: &GpuDevice,
) -> GpuResult<GpuTensor<T>> {
    // Ensure the tensor is contiguous so data() gives a proper flat slice.
    if !tensor.is_contiguous() {
        return Err(GpuError::LengthMismatch {
            a: tensor.numel(),
            b: tensor.data().map_or(0, |d| d.len()),
        });
    }

    // Extract the flat data from the CPU tensor.
    let data = tensor.data().map_err(|_e| GpuError::InvalidDevice {
        ordinal: device.ordinal(),
        count: 0,
    })?;

    let buffer = cpu_to_gpu(data, device)?;
    Ok(GpuTensor {
        buffer,
        shape: tensor.shape().to_vec(),
        device: device.clone(),
    })
}

/// Move a [`GpuTensor<T>`] back to CPU, returning a [`Tensor<T>`].
///
/// Performs a device-to-host copy and wraps the resulting `Vec<T>` in a
/// new leaf [`Tensor`] with `requires_grad = false`.
///
/// # Errors
///
/// - Returns an error if the device-to-host copy fails.
pub fn tensor_to_cpu<T: GpuFloat>(gpu_tensor: &GpuTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let host_data = gpu_to_cpu(&gpu_tensor.buffer, &gpu_tensor.device).map_err(|e| {
        FerrotorchError::InvalidArgument {
            message: format!("GPU-to-CPU transfer failed: {e}"),
        }
    })?;

    let storage = TensorStorage::cpu(host_data);
    Tensor::from_storage(storage, gpu_tensor.shape.clone(), false)
}

// ---------------------------------------------------------------------------
// Convenience free functions
// ---------------------------------------------------------------------------

/// Move a CPU [`Tensor<T>`] to the CUDA device with the given ordinal.
///
/// Shorthand for creating a [`GpuDevice`] and calling [`tensor_to_gpu`].
///
/// # Errors
///
/// Returns a [`GpuError`] if the device cannot be initialized or the
/// transfer fails.
pub fn cuda<T: GpuFloat>(tensor: &Tensor<T>, ordinal: usize) -> GpuResult<GpuTensor<T>> {
    let device = GpuDevice::new(ordinal)?;
    tensor_to_gpu(tensor, &device)
}

/// Move a CPU [`Tensor<T>`] to CUDA device 0.
///
/// Equivalent to `cuda(tensor, 0)`.
pub fn cuda_default<T: GpuFloat>(tensor: &Tensor<T>) -> GpuResult<GpuTensor<T>> {
    cuda(tensor, 0)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    /// Helper: create a CPU tensor from a flat vec with the given shape.
    fn cpu_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data);
        Tensor::from_storage(storage, shape, false).expect("cpu_tensor")
    }

    // -- round-trip -----------------------------------------------------------

    #[test]
    fn tensor_to_gpu_round_trip() {
        let t = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = gpu.cpu().expect("cpu()");

        assert_eq!(back.shape(), &[2, 3]);
        assert_eq!(back.data().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // -- shape preservation ---------------------------------------------------

    #[test]
    fn gpu_tensor_shape_preserved() {
        let t = cpu_tensor(vec![1.0; 24], vec![2, 3, 4]);
        let gpu = cuda_default(&t).expect("cuda_default");

        assert_eq!(gpu.shape(), &[2, 3, 4]);
        assert_eq!(gpu.numel(), 24);
        assert_eq!(gpu.ndim(), 3);
    }

    // -- add ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_add() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = cpu_tensor(vec![10.0, 20.0, 30.0, 40.0], vec![4]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.add(&gb).expect("gpu add");
        let result = gc.cpu().expect("cpu");

        assert_eq!(result.shape(), &[4]);
        let data = result.data().unwrap();
        assert!((data[0] - 11.0).abs() < 1e-6);
        assert!((data[1] - 22.0).abs() < 1e-6);
        assert!((data[2] - 33.0).abs() < 1e-6);
        assert!((data[3] - 44.0).abs() < 1e-6);
    }

    // -- relu -----------------------------------------------------------------

    #[test]
    fn gpu_tensor_relu() {
        let t = cpu_tensor(vec![-3.0, -1.0, 0.0, 1.0, 3.0], vec![5]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.relu().expect("relu");
        let result = out.cpu().expect("cpu");

        let data = result.data().unwrap();
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - 1.0).abs() < 1e-6);
        assert!((data[4] - 3.0).abs() < 1e-6);
    }

    // -- tensor_to_cpu values -------------------------------------------------

    #[test]
    fn tensor_to_cpu_correct_values() {
        let original = vec![0.5, -1.5, 2.25, 0.0, 100.0, -0.001];
        let t = cpu_tensor(original.clone(), vec![2, 3]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let back = tensor_to_cpu(&gpu).expect("tensor_to_cpu");

        let data = back.data().unwrap();
        for (i, (&got, &expected)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "element {i}: got {got}, expected {expected}",
            );
        }
    }

    // -- sub ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_sub() {
        let a = cpu_tensor(vec![10.0, 20.0, 30.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.sub(&gb).expect("gpu sub");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 9.0).abs() < 1e-6);
        assert!((data[1] - 18.0).abs() < 1e-6);
        assert!((data[2] - 27.0).abs() < 1e-6);
    }

    // -- mul ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_mul() {
        let a = cpu_tensor(vec![2.0, 3.0, 4.0], vec![3]);
        let b = cpu_tensor(vec![10.0, 10.0, 10.0], vec![3]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.mul(&gb).expect("gpu mul");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 20.0).abs() < 1e-6);
        assert!((data[1] - 30.0).abs() < 1e-6);
        assert!((data[2] - 40.0).abs() < 1e-6);
    }

    // -- neg ------------------------------------------------------------------

    #[test]
    fn gpu_tensor_neg() {
        let t = cpu_tensor(vec![1.0, -2.0, 0.0, 3.5], vec![4]);
        let gpu = cuda_default(&t).expect("cuda_default");
        let out = gpu.neg().expect("neg");
        let result = out.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 0.0).abs() < 1e-6);
        assert!((data[3] - (-3.5)).abs() < 1e-6);
    }

    // -- matmul ---------------------------------------------------------------

    #[test]
    fn gpu_tensor_matmul_basic() {
        // A = [[1, 2, 3], [4, 5, 6]]  (2x3)
        // B = [[7, 8], [9, 10], [11, 12]]  (3x2)
        // C = [[58, 64], [139, 154]]  (2x2)
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let gc = ga.matmul(&gb).expect("gpu matmul");
        assert_eq!(gc.shape(), &[2, 2]);

        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 58.0).abs() < 1e-4);
        assert!((data[1] - 64.0).abs() < 1e-4);
        assert!((data[2] - 139.0).abs() < 1e-4);
        assert!((data[3] - 154.0).abs() < 1e-4);
    }

    #[test]
    fn gpu_tensor_matmul_identity() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let i = cpu_tensor(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gi = tensor_to_gpu(&i, &device).expect("i to gpu");

        let gc = ga.matmul(&gi).expect("gpu matmul identity");
        let result = gc.cpu().expect("cpu");
        let data = result.data().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[1] - 2.0).abs() < 1e-6);
        assert!((data[2] - 3.0).abs() < 1e-6);
        assert!((data[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn gpu_tensor_matmul_inner_dim_mismatch() {
        // A is [2, 3], B is [2, 2] -- inner dims 3 != 2
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn gpu_tensor_matmul_not_2d() {
        // A is [6] (1-D), should fail
        let a = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
        let b = cpu_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.matmul(&gb).unwrap_err();
        match err {
            GpuError::ShapeMismatch { op: "matmul", .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- shape mismatch error -------------------------------------------------

    #[test]
    fn gpu_tensor_add_shape_mismatch() {
        let a = cpu_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = cpu_tensor(vec![1.0, 2.0], vec![2]);

        let device = GpuDevice::new(0).expect("CUDA device 0");
        let ga = tensor_to_gpu(&a, &device).expect("a to gpu");
        let gb = tensor_to_gpu(&b, &device).expect("b to gpu");

        let err = ga.add(&gb).unwrap_err();
        match err {
            GpuError::LengthMismatch { .. } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- empty tensor ---------------------------------------------------------

    #[test]
    fn gpu_tensor_empty_round_trip() {
        let t = cpu_tensor(vec![], vec![0]);
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.numel(), 0);
        assert_eq!(gpu.shape(), &[0]);

        let back = gpu.cpu().expect("cpu");
        assert_eq!(back.shape(), &[0]);
        assert_eq!(back.data().unwrap().len(), 0);
    }

    // -- scalar tensor --------------------------------------------------------

    #[test]
    fn gpu_tensor_scalar_round_trip() {
        let storage = TensorStorage::cpu(vec![42.0f32]);
        let t = Tensor::from_storage(storage, vec![], false).expect("scalar");
        let gpu = cuda_default(&t).expect("cuda_default");
        assert_eq!(gpu.shape(), &[] as &[usize]);
        assert_eq!(gpu.numel(), 1);

        let back = gpu.cpu().expect("cpu");
        assert!(back.is_scalar());
        assert!((back.item().unwrap() - 42.0).abs() < 1e-6);
    }
}
901}