use core::fmt::Debug;
use crate::tensor::traits::{DType, Device};
/// Common metadata interface implemented by every tensor storage backend in this module.
///
/// Implementors must be `Clone + Send + Sync + Debug` so storages can be duplicated,
/// shared across threads, and printed.
pub trait TensorStorage: Clone + Send + Sync + Debug {
/// Element data type this storage reports.
fn dtype(&self) -> DType;
/// Device this storage reports its data as residing on.
fn device(&self) -> Device;
/// Size of the storage in bytes, as computed by the implementation.
fn nbytes(&self) -> usize;
/// Whether the storage reports its elements as laid out contiguously.
fn is_contiguous(&self) -> bool;
/// Byte alignment the implementation reports for its underlying buffer.
fn alignment(&self) -> usize;
}
/// CPU storage that keeps every element as an `f64` in a heap-allocated `Vec`,
/// tagged with the `DType` it reports through `TensorStorage::dtype`.
#[cfg(feature = "tensor")]
#[derive(Clone, Debug)]
pub struct NdArrayStorage {
    /// Backing element buffer; always `f64`, whatever `dtype` says.
    data: Vec<f64>,
    /// Data type this storage reports.
    dtype: DType,
}
#[cfg(feature = "tensor")]
impl NdArrayStorage {
    /// Wraps an existing `f64` buffer, reporting it as `dtype`.
    ///
    /// The buffer is taken as-is; no conversion to `dtype` is performed.
    pub fn new(data: Vec<f64>, dtype: DType) -> Self {
        NdArrayStorage { dtype, data }
    }

    /// Borrows the backing elements as a read-only slice.
    pub fn data(&self) -> &[f64] {
        self.data.as_slice()
    }

    /// Borrows the backing elements mutably; the element count cannot change through it.
    pub fn data_mut(&mut self) -> &mut [f64] {
        self.data.as_mut_slice()
    }
}
#[cfg(feature = "tensor")]
impl TensorStorage for NdArrayStorage {
    /// The dtype supplied at construction.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Always the host CPU: the elements live in an ordinary `Vec`.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Element count multiplied by the declared dtype's size.
    ///
    /// NOTE(review): the backing buffer is always `Vec<f64>`, so when `dtype` is narrower
    /// than 8 bytes this is smaller than the bytes actually allocated
    /// (`len * size_of::<f64>()`). Presumably intended as the logical size — confirm.
    fn nbytes(&self) -> usize {
        self.data.len() * self.dtype.size_bytes()
    }

    /// A `Vec` is a single contiguous allocation.
    fn is_contiguous(&self) -> bool {
        true
    }

    /// Alignment actually guaranteed for the backing buffer.
    ///
    /// The global allocator only guarantees `align_of::<f64>()` for a `Vec<f64>`, so a
    /// larger value (such as a 64-byte cache line) must not be reported here: a caller
    /// relying on it for aligned SIMD loads could touch misaligned memory.
    fn alignment(&self) -> usize {
        core::mem::align_of::<f64>()
    }
}
/// Storage wrapping a `dfdx::tensor::Tensor1D<f64>`; only compiled with `tensor-gpu`.
#[cfg(feature = "tensor-gpu")]
#[derive(Clone, Debug)]
pub struct DfdxStorage {
/// The wrapped dfdx tensor.
inner: dfdx::tensor::Tensor1D<f64>,
}
#[cfg(feature = "tensor-gpu")]
impl DfdxStorage {
    /// Takes ownership of a dfdx tensor and wraps it as storage.
    pub fn from_dfdx(tensor: dfdx::tensor::Tensor1D<f64>) -> Self {
        DfdxStorage { inner: tensor }
    }

    /// Borrows the wrapped dfdx tensor.
    pub fn inner(&self) -> &dfdx::tensor::Tensor1D<f64> {
        &(self.inner)
    }
}
#[cfg(feature = "tensor-gpu")]
impl TensorStorage for DfdxStorage {
    /// Elements are always `f64`.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Always reported as CUDA device 0.
    ///
    /// NOTE(review): the ordinal is hard-coded; the wrapped tensor's real device is not
    /// consulted — confirm this is acceptable for multi-GPU setups.
    fn device(&self) -> Device {
        Device::Cuda(0)
    }

    /// Length of the first dimension times the width of one `f64`.
    fn nbytes(&self) -> usize {
        let len = self.inner.shape().0;
        len * core::mem::size_of::<f64>()
    }

    /// Reported as contiguous unconditionally.
    fn is_contiguous(&self) -> bool {
        true
    }

    /// Fixed reported alignment of 128 bytes.
    fn alignment(&self) -> usize {
        128
    }
}
/// Storage wrapping a `candle_core::Tensor`; only compiled with `tensor-candle`.
#[cfg(feature = "tensor-candle")]
#[derive(Clone, Debug)]
pub struct CandleStorage {
/// The wrapped candle tensor; dtype, device, and layout queries read from it.
inner: candle_core::Tensor,
}
#[cfg(feature = "tensor-candle")]
impl CandleStorage {
    /// Takes ownership of a candle tensor and wraps it as storage.
    pub fn from_candle(tensor: candle_core::Tensor) -> Self {
        CandleStorage { inner: tensor }
    }

    /// Borrows the wrapped candle tensor.
    pub fn inner(&self) -> &candle_core::Tensor {
        &(self.inner)
    }
}
#[cfg(feature = "tensor-candle")]
impl TensorStorage for CandleStorage {
    /// Maps candle's element type onto this crate's `DType`.
    ///
    /// NOTE(review): candle types with no arm below (e.g. `U8`, `U32`, `F16`, `BF16`) are
    /// reported as `F64`, which misstates the element type. `nbytes` deliberately does
    /// not go through this mapping.
    fn dtype(&self) -> DType {
        match self.inner.dtype() {
            candle_core::DType::F32 => DType::F32,
            candle_core::DType::F64 => DType::F64,
            candle_core::DType::I32 => DType::I32,
            candle_core::DType::I64 => DType::I64,
            _ => DType::F64,
        }
    }

    /// Maps candle's device onto this crate's `Device`.
    ///
    /// NOTE(review): every CUDA device is reported as ordinal 0 and Metal is reported as
    /// `Cpu`; both lose information if `Device` can represent them — confirm.
    fn device(&self) -> Device {
        match self.inner.device() {
            candle_core::Device::Cpu => Device::Cpu,
            candle_core::Device::Cuda(_) => Device::Cuda(0),
            candle_core::Device::Metal(_) => Device::Cpu,
        }
    }

    /// Element count times candle's own per-element width.
    ///
    /// Uses the tensor's native dtype size instead of `self.dtype().size_bytes()`.
    /// `dtype()` falls back to `F64` for unmapped types, which would overstate,
    /// for example, an `F16` tensor's size by 4x.
    fn nbytes(&self) -> usize {
        self.inner.elem_count() * self.inner.dtype().size_in_bytes()
    }

    /// Delegates to candle's own layout check.
    fn is_contiguous(&self) -> bool {
        self.inner.is_contiguous()
    }

    /// Fixed reported alignment of 64 bytes.
    ///
    /// NOTE(review): this is not derived from the tensor's actual buffer — confirm.
    fn alignment(&self) -> usize {
        64
    }
}
/// Backend-agnostic storage with one variant per compiled-in backend.
///
/// Its `TensorStorage` impl forwards every query to the wrapped backend storage.
#[derive(Clone)]
pub enum UnifiedStorage {
/// CPU `Vec<f64>`-backed storage.
// NOTE(review): unlike the other variants this one is not feature-gated, yet
// `NdArrayStorage` is only defined under `feature = "tensor"`. This presumably relies on
// the whole module being compiled only with that feature — confirm.
NdArray(NdArrayStorage),
/// dfdx-backed storage (`tensor-gpu` feature).
#[cfg(feature = "tensor-gpu")]
Dfdx(DfdxStorage),
/// candle-backed storage (`tensor-candle` feature).
#[cfg(feature = "tensor-candle")]
Candle(CandleStorage),
}
#[cfg(feature = "tensor")]
impl UnifiedStorage {
    /// Builds the `NdArray` variant from an `f64` buffer reported as `dtype`.
    pub fn ndarray(data: Vec<f64>, dtype: DType) -> Self {
        let storage = NdArrayStorage::new(data, dtype);
        Self::NdArray(storage)
    }
}
impl TensorStorage for UnifiedStorage {
fn dtype(&self) -> DType {
match self {
UnifiedStorage::NdArray(s) => s.dtype(),
#[cfg(feature = "tensor-gpu")]
UnifiedStorage::Dfdx(s) => s.dtype(),
#[cfg(feature = "tensor-candle")]
UnifiedStorage::Candle(s) => s.dtype(),
}
}
fn device(&self) -> Device {
match self {
UnifiedStorage::NdArray(s) => s.device(),
#[cfg(feature = "tensor-gpu")]
UnifiedStorage::Dfdx(s) => s.device(),
#[cfg(feature = "tensor-candle")]
UnifiedStorage::Candle(s) => s.device(),
}
}
fn nbytes(&self) -> usize {
match self {
UnifiedStorage::NdArray(s) => s.nbytes(),
#[cfg(feature = "tensor-gpu")]
UnifiedStorage::Dfdx(s) => s.nbytes(),
#[cfg(feature = "tensor-candle")]
UnifiedStorage::Candle(s) => s.nbytes(),
}
}
fn is_contiguous(&self) -> bool {
match self {
UnifiedStorage::NdArray(s) => s.is_contiguous(),
#[cfg(feature = "tensor-gpu")]
UnifiedStorage::Dfdx(s) => s.is_contiguous(),
#[cfg(feature = "tensor-candle")]
UnifiedStorage::Candle(s) => s.is_contiguous(),
}
}
fn alignment(&self) -> usize {
match self {
UnifiedStorage::NdArray(s) => s.alignment(),
#[cfg(feature = "tensor-gpu")]
UnifiedStorage::Dfdx(s) => s.alignment(),
#[cfg(feature = "tensor-candle")]
UnifiedStorage::Candle(s) => s.alignment(),
}
}
}
/// Prints only the active variant's name; the wrapped storage itself is not formatted.
impl Debug for UnifiedStorage {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            UnifiedStorage::NdArray(_) => "UnifiedStorage::NdArray",
            #[cfg(feature = "tensor-gpu")]
            UnifiedStorage::Dfdx(_) => "UnifiedStorage::Dfdx",
            #[cfg(feature = "tensor-candle")]
            UnifiedStorage::Candle(_) => "UnifiedStorage::Candle",
        };
        f.write_str(label)
    }
}