mod activations;
mod complex;
pub mod constructors;
mod conversions;
mod expression;
mod math_ops;
mod sparse;
pub mod transformations;
mod utils;
#[cfg(test)]
mod complex_tests;
#[cfg(test)]
mod constructors_tests;
#[cfg(test)]
mod property_tests;
use crate::errors::Result;
use scirs2_core::ndarray::{ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
use scirs2_core::Complex;
use scirs2_core::{Complex32, Complex64};
use serde::{Deserialize, Serialize};
/// Element data types supported by [`Tensor`].
///
/// Covers real floats (including half-precision `F16`/`BF16`), complex
/// floats (`C*` variants, including half-precision complex), unsigned and
/// signed integers, and booleans.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DType {
    /// 32-bit float (4 bytes).
    F32,
    /// 16-bit half-precision float (2 bytes).
    F16,
    /// 16-bit bfloat16 (2 bytes).
    BF16,
    /// 64-bit float (8 bytes).
    F64,
    /// Complex of two f32 components (8 bytes).
    C32,
    /// Complex of two f64 components (16 bytes).
    C64,
    /// Complex of two f16 components (4 bytes).
    CF16,
    /// Complex of two bf16 components (4 bytes).
    CBF16,
    /// Unsigned 8-bit integer (1 byte).
    U8,
    /// Unsigned 16-bit integer (2 bytes).
    U16,
    /// Unsigned 32-bit integer (4 bytes).
    U32,
    /// Unsigned 64-bit integer (8 bytes).
    U64,
    /// Signed 8-bit integer (1 byte).
    I8,
    /// Signed 16-bit integer (2 bytes).
    I16,
    /// Signed 32-bit integer (4 bytes).
    I32,
    /// Signed 64-bit integer (8 bytes).
    I64,
    /// Boolean (stored as 1 byte, per `size_in_bytes`).
    Bool,
}
impl DType {
    /// Returns the storage size of a single element of this dtype, in bytes.
    ///
    /// Complex dtypes are twice the size of their real component type
    /// (e.g. `C32` = two `f32`s = 8 bytes, `CF16` = two `f16`s = 4 bytes).
    pub fn size_in_bytes(&self) -> usize {
        match self {
            DType::U8 | DType::I8 | DType::Bool => 1,
            DType::F16 | DType::BF16 | DType::U16 | DType::I16 => 2,
            DType::F32 | DType::CF16 | DType::CBF16 | DType::U32 | DType::I32 => 4,
            DType::F64 | DType::C32 | DType::U64 | DType::I64 => 8,
            DType::C64 => 16,
        }
    }
}
/// Handle to tensor data backed by a Metal GPU buffer (macOS + `metal`
/// feature only).
///
/// Stores an id referencing a buffer managed by `crate::gpu_ops::metal`
/// plus the logical shape and dtype; the element data itself presumably
/// resides in the GPU buffer — confirm against `gpu_ops::metal`.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[derive(Debug)]
pub struct MetalTensorData {
    // Identifier of the backing Metal buffer.
    pub buffer_id: crate::gpu_ops::metal::BufferId,
    // Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    // Element type of the buffer contents.
    pub dtype: DType,
}
#[cfg(all(target_os = "macos", feature = "metal"))]
impl Clone for MetalTensorData {
    /// Duplicates the metadata; `buffer_id` is copied by value, so both
    /// clones presumably refer to the same underlying GPU buffer —
    /// confirm ownership semantics in `gpu_ops::metal`.
    fn clone(&self) -> Self {
        let MetalTensorData {
            buffer_id,
            shape,
            dtype,
        } = self;
        MetalTensorData {
            buffer_id: *buffer_id,
            shape: shape.clone(),
            dtype: *dtype,
        }
    }
}
/// Handle to tensor data backed by a CUDA device buffer (`cuda` feature
/// only).
///
/// Stores an id referencing a buffer managed by `crate::gpu_ops::cuda`
/// plus the logical shape and dtype; the element data itself presumably
/// resides in device memory — confirm against `gpu_ops::cuda`.
#[cfg(feature = "cuda")]
#[derive(Debug)]
pub struct CudaTensorData {
    // Identifier of the backing CUDA buffer.
    pub buffer_id: crate::gpu_ops::cuda::BufferId,
    // Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    // Element type of the buffer contents.
    pub dtype: DType,
}
#[cfg(feature = "cuda")]
impl Clone for CudaTensorData {
    /// Duplicates the metadata; `buffer_id` is copied by value, so both
    /// clones presumably refer to the same underlying device buffer —
    /// confirm ownership semantics in `gpu_ops::cuda`.
    fn clone(&self) -> Self {
        let CudaTensorData {
            buffer_id,
            shape,
            dtype,
        } = self;
        CudaTensorData {
            buffer_id: *buffer_id,
            shape: shape.clone(),
            dtype: *dtype,
        }
    }
}
/// Dynamically-typed n-dimensional tensor: one variant per element type
/// and storage backend.
///
/// CPU-dense variants wrap `ArrayD`; `Sparse` wraps the crate's sparse
/// representation; the remaining variants are feature-gated external
/// backends (tch, candle) or GPU buffer handles (Metal, CUDA).
///
/// NOTE(review): `DType` lists u8/u16/u32/u64, i8/i16/i32 and bool, but
/// the only dense integer variant here is `I64` — confirm how the other
/// integer/bool dtypes are represented.
pub enum Tensor {
    F32(ArrayD<f32>),
    F64(ArrayD<f64>),
    F16(ArrayD<half::f16>),
    BF16(ArrayD<half::bf16>),
    I64(ArrayD<i64>),
    C32(ArrayD<Complex32>),
    C64(ArrayD<Complex64>),
    CF16(ArrayD<Complex<half::f16>>),
    CBF16(ArrayD<Complex<half::bf16>>),
    // Sparse storage; layout defined by `crate::sparse_tensor`.
    Sparse(crate::sparse_tensor::SparseTensor),
    #[cfg(feature = "torch")]
    Torch(tch::Tensor),
    #[cfg(feature = "candle")]
    Candle(candle_core::Tensor),
    #[cfg(all(target_os = "macos", feature = "metal"))]
    Metal(MetalTensorData),
    #[cfg(feature = "cuda")]
    CUDA(CudaTensorData),
}
impl Clone for Tensor {
    /// Manual `Clone` implementation, presumably because not every backend
    /// payload is `Clone` (the Torch variant uses `shallow_clone()` instead
    /// of `clone()` — TODO confirm against tch docs).
    ///
    /// NOTE(review): `shallow_clone` suggests the Torch variant shares
    /// underlying storage while the ndarray-backed variants deep-copy;
    /// callers relying on independence after `clone()` should verify.
    fn clone(&self) -> Self {
        match self {
            Tensor::F32(arr) => Tensor::F32(arr.clone()),
            Tensor::F64(arr) => Tensor::F64(arr.clone()),
            Tensor::F16(arr) => Tensor::F16(arr.clone()),
            Tensor::BF16(arr) => Tensor::BF16(arr.clone()),
            Tensor::I64(arr) => Tensor::I64(arr.clone()),
            Tensor::C32(arr) => Tensor::C32(arr.clone()),
            Tensor::C64(arr) => Tensor::C64(arr.clone()),
            Tensor::CF16(arr) => Tensor::CF16(arr.clone()),
            Tensor::CBF16(arr) => Tensor::CBF16(arr.clone()),
            Tensor::Sparse(s) => Tensor::Sparse(s.clone()),
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => Tensor::Torch(t.shallow_clone()),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => Tensor::Candle(t.clone()),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => Tensor::Metal(data.clone()),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => Tensor::CUDA(data.clone()),
        }
    }
}
impl std::fmt::Debug for Tensor {
    /// Formats a short summary — variant name, shape, dtype — instead of
    /// element data, so output stays compact regardless of tensor size.
    /// Only `Sparse` delegates to its payload's own `Debug`; GPU variants
    /// additionally print their buffer id. `self.shape()` is defined in a
    /// sibling module of this `tensor` module.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Tensor::F32(_) => write!(f, "Tensor::F32(shape: {:?}, dtype: F32)", self.shape()),
            Tensor::F64(_) => write!(f, "Tensor::F64(shape: {:?}, dtype: F64)", self.shape()),
            Tensor::F16(_) => write!(f, "Tensor::F16(shape: {:?}, dtype: F16)", self.shape()),
            Tensor::BF16(_) => write!(f, "Tensor::BF16(shape: {:?}, dtype: BF16)", self.shape()),
            Tensor::I64(_) => write!(f, "Tensor::I64(shape: {:?}, dtype: I64)", self.shape()),
            Tensor::C32(_) => write!(f, "Tensor::C32(shape: {:?}, dtype: C32)", self.shape()),
            Tensor::C64(_) => write!(f, "Tensor::C64(shape: {:?}, dtype: C64)", self.shape()),
            Tensor::CF16(_) => write!(f, "Tensor::CF16(shape: {:?}, dtype: CF16)", self.shape()),
            Tensor::CBF16(_) => write!(f, "Tensor::CBF16(shape: {:?}, dtype: CBF16)", self.shape()),
            Tensor::Sparse(s) => write!(f, "Tensor::Sparse({:?})", s),
            #[cfg(feature = "torch")]
            Tensor::Torch(_) => write!(f, "Tensor::Torch(shape: {:?})", self.shape()),
            #[cfg(feature = "candle")]
            Tensor::Candle(_) => write!(f, "Tensor::Candle(shape: {:?})", self.shape()),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => write!(
                f,
                "Tensor::Metal(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
                data.shape, data.dtype, data.buffer_id
            ),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => write!(
                f,
                "Tensor::CUDA(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
                data.shape, data.dtype, data.buffer_id
            ),
        }
    }
}
// SAFETY(review): this asserts `Tensor` may be shared across threads when
// the `torch`/`candle` backends are enabled — implying `tch::Tensor` or
// `candle_core::Tensor` is not `Sync` on its own. Soundness therefore
// depends on those backend types tolerating concurrent shared access;
// TODO confirm against the backend crates' docs — an unsound `Sync` impl
// is undefined behavior.
#[cfg(any(feature = "torch", feature = "candle"))]
unsafe impl Sync for Tensor {}
/// Wraps an owned dynamic-dimensional `f32` array as `Tensor::F32`
/// (a move — no element copy).
// NOTE(review): the impl header spells out the expansion of the
// `ArrayD<f32>` alias used in the method signature below; the two are
// the same type.
impl From<ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>> for Tensor {
    fn from(arr: ArrayD<f32>) -> Self {
        Tensor::F32(arr)
    }
}
/// Wraps an owned dynamic-dimensional `f64` array as `Tensor::F64`
/// (a move — no element copy).
// NOTE(review): as with the f32 impl, the header is the expansion of the
// `ArrayD<f64>` alias used in the method signature.
impl From<ArrayBase<OwnedRepr<f64>, Dim<IxDynImpl>>> for Tensor {
    fn from(arr: ArrayD<f64>) -> Self {
        Tensor::F64(arr)
    }
}
/// `+` consumes both operands and delegates to the fallible `Tensor::add`
/// (defined in a sibling module); the operator yields `Result<Tensor>`,
/// so failures are reported as `Err` rather than a panic.
impl std::ops::Add for Tensor {
    type Output = Result<Tensor>;
    fn add(self, other: Tensor) -> Self::Output {
        Tensor::add(&self, &other)
    }
}
/// Borrowed `+` — same semantics without consuming either operand.
impl std::ops::Add for &Tensor {
    type Output = Result<Tensor>;
    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}
// Double-reference overloads — presumably so `+` compiles on `&&Tensor`
// values such as iterator-adapter items; TODO confirm which callers need
// these.
impl std::ops::Add<&&Tensor> for &Tensor {
    type Output = Result<Tensor>;
    fn add(self, other: &&Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}
impl std::ops::Add<&Tensor> for &&Tensor {
    type Output = Result<Tensor>;
    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}
/// `-` consumes both operands and delegates to the fallible `Tensor::sub`
/// (defined in a sibling module); yields `Result<Tensor>` like the other
/// operators in this file.
impl std::ops::Sub for Tensor {
    type Output = Result<Tensor>;
    fn sub(self, other: Tensor) -> Self::Output {
        Tensor::sub(&self, &other)
    }
}
/// `* f32` delegates to the fallible `scalar_mul` (defined in a sibling
/// module); yields `Result<Tensor>`.
impl std::ops::Mul<f32> for Tensor {
    type Output = Result<Tensor>;
    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}
impl std::ops::Mul<f32> for &Tensor {
    type Output = Result<Tensor>;
    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}
// NOTE(review): the f64 overloads narrow the scalar with `as f32`,
// silently losing precision for values not representable in f32; confirm
// callers passing f64 literals accept f32-precision math.
impl std::ops::Mul<f64> for Tensor {
    type Output = Result<Tensor>;
    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}
impl std::ops::Mul<f64> for &Tensor {
    type Output = Result<Tensor>;
    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}
/// `*` between two tensors delegates to the fallible `Tensor::mul`
/// (defined in a sibling module — presumably element-wise, given the
/// separate `scalar_mul`; TODO confirm). Yields `Result<Tensor>`, so
/// failures are reported as `Err` rather than a panic.
impl std::ops::Mul<&Tensor> for &Tensor {
    type Output = Result<Tensor>;
    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(self, other)
    }
}
impl std::ops::Mul<Tensor> for &Tensor {
    type Output = Result<Tensor>;
    fn mul(self, other: Tensor) -> Self::Output {
        Tensor::mul(self, &other)
    }
}
impl std::ops::Mul<&Tensor> for Tensor {
    type Output = Result<Tensor>;
    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(&self, other)
    }
}
/// Owned × owned overload, added for consistency: `Add` and `Sub` both
/// support owned operands, but `a * b` with two owned tensors previously
/// failed to compile.
impl std::ops::Mul for Tensor {
    type Output = Result<Tensor>;
    fn mul(self, other: Tensor) -> Self::Output {
        Tensor::mul(&self, &other)
    }
}
/// `/ f32` delegates to the fallible `scalar_div` (defined in a sibling
/// module); yields `Result<Tensor>`.
impl std::ops::Div<f32> for Tensor {
    type Output = Result<Tensor>;
    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}
impl std::ops::Div<f32> for &Tensor {
    type Output = Result<Tensor>;
    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}
// NOTE(review): the f64 overloads narrow the scalar with `as f32`,
// silently losing precision; confirm callers accept f32-precision math.
impl std::ops::Div<f64> for Tensor {
    type Output = Result<Tensor>;
    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}
impl std::ops::Div<f64> for &Tensor {
    type Output = Result<Tensor>;
    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}
// NOTE(review): a `Div<f32> for &&Tensor` counterpart does not exist even
// though this f64 form does — possibly an oversight.
impl std::ops::Div<f64> for &&Tensor {
    type Output = Result<Tensor>;
    fn div(self, scalar: f64) -> Self::Output {
        (*self).scalar_div(scalar as f32)
    }
}
/// Borrowed `-` — delegates to the fallible `Tensor::sub` without
/// consuming either operand; yields `Result<Tensor>`.
impl std::ops::Sub for &Tensor {
    type Output = Result<Tensor>;
    fn sub(self, other: &Tensor) -> Self::Output {
        Tensor::sub(self, other)
    }
}
/// Alias for [`DType`] — presumably kept for backward compatibility with
/// older callers; TODO confirm before removing.
pub type TensorType = DType;
// Re-export the expression-graph API and gradient-mode helpers so callers
// can reach them directly from this module.
pub use expression::{EvalContext, ExprNode, OpType, OptimizationHints, TensorExpr};
pub use utils::{clear_gradients, disable_grad, enable_grad, is_grad_enabled};