//! ferrotorch-core 0.1.3
//!
//! Core tensor and autograd engine for ferrotorch — PyTorch in Rust.
//! See the crate-level documentation for usage details.
use crate::device::Device;
use crate::dtype::Element;
use crate::gpu_dispatch::GpuBufferHandle;

/// The underlying data buffer for a tensor, tagged with its device.
///
/// Owns the data directly (`Vec<T>` for CPU, `GpuBufferHandle` for GPU).
/// The GPU handle is type-erased -- ferrotorch-gpu provides the concrete
/// implementation via the `GpuBackend` trait.
#[derive(Debug)]
pub struct TensorStorage<T: Element> {
    /// Where the elements physically live; the active variant must agree
    /// with `device` (the `cpu`/`gpu` constructors guarantee this).
    pub(crate) data: StorageBuffer<T>,
    /// Cached device tag: `Device::Cpu`, or `Device::Cuda(ordinal)` taken
    /// from the GPU handle at construction time.
    pub(crate) device: Device,
}

/// Device-specific data buffer.
///
/// The active variant determines where element data lives. `TensorStorage`
/// pairs this with a matching `Device` tag via its constructors.
pub enum StorageBuffer<T: Element> {
    /// CPU heap-allocated data.
    Cpu(Vec<T>),
    /// GPU device memory, accessed via the registered `GpuBackend`.
    /// The handle is type-erased; ferrotorch-gpu supplies the implementation.
    Gpu(GpuBufferHandle),
}

impl<T: Element> TensorStorage<T> {
    /// Create a new CPU storage from a `Vec<T>`.
    pub fn cpu(data: Vec<T>) -> Self {
        Self {
            data: StorageBuffer::Cpu(data),
            device: Device::Cpu,
        }
    }

    /// Create a new GPU storage from a handle.
    pub fn gpu(handle: GpuBufferHandle) -> Self {
        let device = Device::Cuda(handle.device_ordinal());
        Self {
            data: StorageBuffer::Gpu(handle),
            device,
        }
    }

    /// The device this storage resides on.
    #[inline]
    pub fn device(&self) -> Device {
        self.device
    }

    /// Total number of elements in the buffer.
    pub fn len(&self) -> usize {
        match &self.data {
            StorageBuffer::Cpu(v) => v.len(),
            StorageBuffer::Gpu(h) => h.len(),
        }
    }

    /// Whether the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow the data as a slice. Only available for CPU storage.
    ///
    /// # Panics
    /// Panics if the tensor is on a GPU device. Call `.cpu()` first.
    pub fn as_slice(&self) -> &[T] {
        match &self.data {
            StorageBuffer::Cpu(v) => v.as_slice(),
            StorageBuffer::Gpu(_) => panic!("cannot access GPU tensor as CPU slice -- call .cpu() first"),
        }
    }

    /// Borrow the data as a mutable slice. Only available for CPU storage.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        match &mut self.data {
            StorageBuffer::Cpu(v) => v.as_mut_slice(),
            StorageBuffer::Gpu(_) => panic!("cannot mutate GPU tensor as CPU slice -- call .cpu() first"),
        }
    }

    /// Returns `true` if this storage is on CPU.
    #[inline]
    pub fn is_cpu(&self) -> bool {
        matches!(&self.data, StorageBuffer::Cpu(_))
    }

    /// Returns `true` if this storage is on a GPU.
    #[inline]
    pub fn is_gpu(&self) -> bool {
        matches!(&self.data, StorageBuffer::Gpu(_))
    }

    /// Get the GPU buffer handle. Returns `None` for CPU storage.
    pub fn gpu_handle(&self) -> Option<&GpuBufferHandle> {
        match &self.data {
            StorageBuffer::Gpu(h) => Some(h),
            StorageBuffer::Cpu(_) => None,
        }
    }
}

impl<T: Element> Clone for TensorStorage<T> {
    /// Deep-copy the storage. CPU buffers are cloned on the host; GPU
    /// buffers are duplicated through the registered `GpuBackend`.
    ///
    /// # Panics
    /// Panics for GPU storage when no backend is registered, or when the
    /// backend fails to duplicate the buffer.
    fn clone(&self) -> Self {
        let data = match &self.data {
            StorageBuffer::Cpu(buf) => StorageBuffer::Cpu(buf.clone()),
            StorageBuffer::Gpu(handle) => {
                // Duplication must go through the backend; there is no
                // host-side view of device memory to copy directly.
                let backend = crate::gpu_dispatch::gpu_backend()
                    .expect("no GPU backend registered -- cannot clone GPU tensor");
                match backend.clone_buffer(handle) {
                    Ok(copy) => StorageBuffer::Gpu(copy),
                    Err(_) => panic!("failed to clone GPU buffer"),
                }
            }
        };
        Self {
            data,
            device: self.device,
        }
    }
}

impl<T: Element> std::fmt::Debug for StorageBuffer<T> {
    /// Compact debug rendering: element count for CPU buffers, the
    /// handle's own `Debug` output for GPU buffers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cpu(buf) => write!(f, "Cpu({} elements)", buf.len()),
            Self::Gpu(handle) => write!(f, "Gpu({:?})", handle),
        }
    }
}