#[cfg(feature = "half")]
mod scalar_bf16;
mod scalar_bool;
#[cfg(feature = "complex")]
mod scalar_cf32;
#[cfg(feature = "complex")]
mod scalar_cf64;
#[cfg(feature = "half")]
mod scalar_f16;
mod scalar_f32;
mod scalar_f64;
mod scalar_i16;
mod scalar_i32;
mod scalar_i64;
mod scalar_i8;
mod scalar_u8;
#[cfg(feature = "complex")]
use num_complex::Complex;
#[cfg(feature = "half")]
use half::{bf16, f16};
use crate::dtype::DType;
/// Common interface implemented by the scalar element types of this crate.
///
/// It bundles conversions from every supported primitive, a [`DType`] tag,
/// constants, and unary/binary operations. Items tied to `bf16`/`f16` exist
/// only with the `half` feature, and items tied to `Complex<f32>`/`Complex<f64>`
/// only with the `complex` feature.
///
/// Implementations must make [`Scalar::dtype`] return the variant that names
/// the implementing type itself; [`Scalar::cast`] relies on that.
pub trait Scalar: Copy + Clone + Sized + core::fmt::Debug + 'static + PartialEq {
    /// Converts a `bf16` into `Self`.
    #[cfg(feature = "half")]
    fn from_bf16(t: bf16) -> Self;
    /// Converts an `f16` into `Self`.
    #[cfg(feature = "half")]
    fn from_f16(t: f16) -> Self;
    /// Converts an `f32` into `Self`.
    fn from_f32(t: f32) -> Self;
    /// Converts an `f64` into `Self`.
    fn from_f64(t: f64) -> Self;
    /// Converts a `Complex<f32>` into `Self`.
    #[cfg(feature = "complex")]
    fn from_cf32(t: Complex<f32>) -> Self;
    /// Converts a `Complex<f64>` into `Self`.
    #[cfg(feature = "complex")]
    fn from_cf64(t: Complex<f64>) -> Self;
    /// Converts a `u8` into `Self`.
    fn from_u8(t: u8) -> Self;
    /// Converts an `i8` into `Self`.
    fn from_i8(t: i8) -> Self;
    /// Converts an `i16` into `Self`.
    fn from_i16(t: i16) -> Self;
    /// Converts an `i32` into `Self`.
    fn from_i32(t: i32) -> Self;
    /// Converts an `i64` into `Self`.
    fn from_i64(t: i64) -> Self;
    /// Converts a `bool` into `Self`.
    fn from_bool(t: bool) -> Self;
    /// Builds a value from its little-endian byte representation.
    // NOTE(review): the required length of `bytes` (presumably `Self::byte_size()`)
    // is decided by each implementation — confirm.
    fn from_le_bytes(bytes: &[u8]) -> Self;
    /// Returns the [`DType`] tag of `Self`; it must name the implementing type.
    fn dtype() -> DType;
    /// The value zero.
    fn zero() -> Self;
    /// The value one.
    fn one() -> Self;
    /// Number of bytes one value occupies.
    fn byte_size() -> usize;
    /// Absolute value.
    fn abs(self) -> Self;
    /// Reciprocal, i.e. `1 / self`.
    fn reciprocal(self) -> Self;
    /// Rounds towards negative infinity.
    fn floor(self) -> Self;
    /// Negation.
    fn neg(self) -> Self;
    /// Rectified linear unit (presumably `max(self, 0)` — confirm per implementation).
    fn relu(self) -> Self;
    /// Sine.
    fn sin(self) -> Self;
    /// Cosine.
    fn cos(self) -> Self;
    /// `2` raised to the power `self`.
    fn exp2(self) -> Self;
    /// Base-2 logarithm.
    fn log2(self) -> Self;
    /// Inverse of `self`.
    // NOTE(review): how this differs from `reciprocal` is implementation-defined.
    fn inv(self) -> Self;
    /// Negation as defined by the implementation (logical or bitwise).
    fn not(self) -> Self;
    /// Non-zero test encoded as a `Self` (presumably one/zero — confirm).
    fn nonzero(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
    /// `self + rhs`.
    fn add(self, rhs: Self) -> Self;
    /// `self - rhs`.
    fn sub(self, rhs: Self) -> Self;
    /// `self * rhs`.
    fn mul(self, rhs: Self) -> Self;
    /// `self / rhs`.
    fn div(self, rhs: Self) -> Self;
    /// `self` raised to the power `rhs`.
    fn pow(self, rhs: Self) -> Self;
    /// `self < rhs`, encoded as a `Self` (presumably one/zero — confirm).
    fn cmplt(self, rhs: Self) -> Self;
    /// `self > rhs`, encoded as a `Self` (presumably one/zero — confirm).
    fn cmpgt(self, rhs: Self) -> Self;
    /// Or of `self` and `rhs`, as defined by the implementation.
    fn or(self, rhs: Self) -> Self;
    /// The larger of `self` and `rhs`.
    fn max(self, rhs: Self) -> Self;
    /// Largest representable value.
    fn max_value() -> Self;
    /// Smallest representable value.
    fn min_value() -> Self;
    /// Tolerance constant (presumably machine epsilon for floats — confirm).
    fn epsilon() -> Self;
    /// Whether `self` and `rhs` are considered equal by the implementation.
    fn is_equal(self, rhs: Self) -> bool;

    /// Converts `self` into another scalar type `T`.
    ///
    /// [`Scalar::dtype`] picks which `T::from_*` constructor is used. `self` is
    /// then recovered as the concrete primitive that variant names. The
    /// recovery uses a checked [`core::any::Any`] downcast rather than
    /// reinterpreting bits, so a wrong `dtype()` cannot cause undefined
    /// behaviour.
    ///
    /// # Panics
    ///
    /// Panics if `Self::dtype()` names a type other than `Self`, which is an
    /// implementation bug.
    fn cast<T: Scalar>(self) -> T {
        // Returns the concrete `U` stored behind `value`, panicking if the type differs.
        fn concrete<U: Copy + 'static>(value: &dyn core::any::Any) -> U {
            *value
                .downcast_ref::<U>()
                .expect("Scalar::dtype() does not match the implementing type")
        }

        // The `'static` supertrait bound is what permits coercing to `&dyn Any`.
        // After monomorphisation the `TypeId` comparison is constant.
        let value: &dyn core::any::Any = &self;
        match Self::dtype() {
            #[cfg(feature = "half")]
            DType::BF16 => T::from_bf16(concrete(value)),
            #[cfg(feature = "half")]
            DType::F16 => T::from_f16(concrete(value)),
            DType::F32 => T::from_f32(concrete(value)),
            DType::F64 => T::from_f64(concrete(value)),
            #[cfg(feature = "complex")]
            DType::CF32 => T::from_cf32(concrete(value)),
            #[cfg(feature = "complex")]
            DType::CF64 => T::from_cf64(concrete(value)),
            DType::U8 => T::from_u8(concrete(value)),
            DType::I8 => T::from_i8(concrete(value)),
            DType::I16 => T::from_i16(concrete(value)),
            DType::I32 => T::from_i32(concrete(value)),
            DType::I64 => T::from_i64(concrete(value)),
            DType::Bool => T::from_bool(concrete(value)),
        }
    }
}
/// Marker sub-trait of [`Scalar`] that adds no methods of its own.
///
/// Presumably implemented only by the floating-point element types, so that
/// generic code can restrict itself to them — confirm against the impls.
pub trait Float
where
    Self: Scalar,
{
}