use burn_std::{BoolStore, DType, bf16, f16};
use num_traits::ToPrimitive;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::{Element, ElementConversion};
/// A dynamically-typed scalar value.
///
/// Each variant stores its value in the widest primitive of its kind (`f64`, `i64`, `u64`,
/// `bool`). This lets a single type carry a constant for tensors of any element type.
#[derive(Clone, Copy, Debug)]
pub enum Scalar {
    /// A floating-point value, stored as `f64` (used for float and quantized-float dtypes).
    Float(f64),
    /// A signed integer value, stored as `i64`.
    Int(i64),
    /// An unsigned integer value, stored as `u64` (also used for booleans stored as `U8`/`U32`).
    UInt(u64),
    /// A native boolean value.
    Bool(bool),
}
impl Scalar {
    /// Creates a scalar holding `value`, choosing the variant from `dtype`.
    ///
    /// - float dtypes and `DType::QFloat(_)` produce [`Scalar::Float`];
    /// - signed integer dtypes produce [`Scalar::Int`];
    /// - unsigned integer dtypes produce [`Scalar::UInt`];
    /// - `DType::Bool(BoolStore::Native)` produces [`Scalar::Bool`], while
    ///   `DType::Bool(BoolStore::U8 | BoolStore::U32)` produces [`Scalar::UInt`].
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) for any dtype not covered above.
    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
        // Short-circuiting `||`: the `QFloat` pattern is only tested for non-float dtypes.
        if dtype.is_float() || matches!(dtype, &DType::QFloat(_)) {
            Self::Float(value.elem())
        } else if dtype.is_int() {
            Self::Int(value.elem())
        } else if dtype.is_uint() {
            Self::UInt(value.elem())
        } else if dtype.is_bool() {
            match dtype {
                DType::Bool(BoolStore::Native) => Self::Bool(value.elem()),
                // Booleans backed by integer storage are represented by their integer value.
                DType::Bool(BoolStore::U8) | DType::Bool(BoolStore::U32) => {
                    Self::UInt(value.elem())
                }
                // `is_bool()` returned true, so only `DType::Bool(_)` can reach here.
                _ => unreachable!(),
            }
        } else {
            unimplemented!("Scalar not supported for {dtype:?}")
        }
    }

    /// Converts the stored value to the element type `E` using `ElementConversion::elem`.
    pub fn elem<E: Element>(self) -> E {
        match self {
            Self::Float(x) => x.elem(),
            Self::Int(x) => x.elem(),
            Self::UInt(x) => x.elem(),
            Self::Bool(x) => x.elem(),
        }
    }

    /// Returns this scalar as an integer scalar if it represents a whole number.
    ///
    /// - `Int` and `UInt` are returned unchanged.
    /// - `Bool` becomes `Int(0)` or `Int(1)`.
    /// - `Float` becomes `Int` when it has no fractional part and is representable as `i64`.
    ///   It becomes `UInt` when it is only representable as `u64`. Fractional values, NaN,
    ///   infinities and out-of-range whole numbers yield `None`.
    pub fn try_as_integer(&self) -> Option<Self> {
        match *self {
            Scalar::Float(x) => {
                // Fractional values and NaN (NaN != NaN) are not integers.
                if x.floor() != x {
                    return None;
                }
                // Infinities and huge whole numbers pass the check above but do not fit in an
                // integer type. Previously `unwrap()` panicked here; now they map to `None`.
                x.to_i64()
                    .map(Self::Int)
                    .or_else(|| x.to_u64().map(Self::UInt))
            }
            Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
            Scalar::Bool(x) => Some(Scalar::Int(i64::from(x))),
        }
    }
}
macro_rules! impl_from_scalar {
($($ty:ty => $variant:ident),+ $(,)?) => {
$(
impl From<$ty> for Scalar {
fn from(value: $ty) -> Self {
Scalar::$variant(value.elem())
}
}
)+
};
}
impl_from_scalar! {
f64 => Float, f32 => Float, f16 => Float, bf16 => Float,
i64 => Int, i32 => Int, i16 => Int, i8 => Int,
u64 => UInt, u32 => UInt, u16 => UInt, u8 => UInt, bool => Bool,
}
/// Numeric conversions for [`Scalar`] through `num_traits::ToPrimitive`.
///
/// Non-matching variants delegate to the stored primitive's own `ToPrimitive` impl, which
/// decides when the target type cannot hold the value. Booleans map to `0` / `1`.
impl ToPrimitive for Scalar {
    fn to_i64(&self) -> Option<i64> {
        match *self {
            Self::Int(v) => Some(v),
            Self::UInt(v) => v.to_i64(),
            Self::Float(v) => v.to_i64(),
            Self::Bool(v) => Some(i64::from(v)),
        }
    }

    fn to_u64(&self) -> Option<u64> {
        match *self {
            Self::UInt(v) => Some(v),
            Self::Int(v) => v.to_u64(),
            Self::Float(v) => v.to_u64(),
            Self::Bool(v) => Some(u64::from(v)),
        }
    }

    fn to_f64(&self) -> Option<f64> {
        match *self {
            Self::Float(v) => Some(v),
            Self::Int(v) => v.to_f64(),
            Self::UInt(v) => v.to_f64(),
            Self::Bool(v) => Some(if v { 1.0 } else { 0.0 }),
        }
    }
}