burn_backend/element/
scalar.rs1use burn_std::{BoolStore, DType, bf16, f16};
2use num_traits::ToPrimitive;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use crate::{Element, ElementConversion};
9
10#[derive(Clone, Copy, Debug)]
12#[allow(missing_docs)]
13pub enum Scalar {
14 Float(f64),
15 Int(i64),
16 UInt(u64),
17 Bool(bool),
18}
19
20impl Scalar {
21 pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
26 if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) {
27 Self::Float(value.elem())
28 } else if dtype.is_int() {
29 Self::Int(value.elem())
30 } else if dtype.is_uint() {
31 Self::UInt(value.elem())
32 } else if dtype.is_bool() {
33 match dtype {
34 DType::Bool(BoolStore::Native) => Self::Bool(value.elem()),
35 DType::Bool(BoolStore::U8) | DType::Bool(BoolStore::U32) => {
36 Self::UInt(value.elem())
37 }
38 _ => unreachable!(),
39 }
40 } else {
41 unimplemented!("Scalar not supported for {dtype:?}")
42 }
43 }
44
45 pub fn elem<E: Element>(self) -> E {
47 match self {
48 Self::Float(x) => x.elem(),
49 Self::Int(x) => x.elem(),
50 Self::UInt(x) => x.elem(),
51 Self::Bool(x) => x.elem(),
52 }
53 }
54
55 pub fn try_as_integer(&self) -> Option<Self> {
57 match self {
58 Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())),
59 Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
60 Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),
61 }
62 }
63}
64
/// Generates `impl From<T> for Scalar` for each `T => Variant` pair.
///
/// Each generated `from` wraps `input.elem()` in the named `Scalar` variant.
/// The target of `elem()` is therefore inferred from that variant's payload type.
macro_rules! impl_from_scalar {
    ($($source:ty => $kind:ident),+ $(,)?) => {
        $(
            impl From<$source> for Scalar {
                fn from(input: $source) -> Self {
                    Self::$kind(input.elem())
                }
            }
        )+
    };
}
76
77impl_from_scalar! {
78 f64 => Float, f32 => Float, f16 => Float, bf16 => Float,
79 i64 => Int, i32 => Int, i16 => Int, i8 => Int,
80 u64 => UInt, u32 => UInt, u16 => UInt, u8 => UInt, bool => Bool,
81}
82
83impl ToPrimitive for Scalar {
85 fn to_i64(&self) -> Option<i64> {
86 match self {
87 Scalar::Float(x) => x.to_i64(),
88 Scalar::UInt(x) => x.to_i64(),
89 Scalar::Int(x) => Some(*x),
90 Scalar::Bool(x) => Some(*x as i64),
91 }
92 }
93
94 fn to_u64(&self) -> Option<u64> {
95 match self {
96 Scalar::Float(x) => x.to_u64(),
97 Scalar::UInt(x) => Some(*x),
98 Scalar::Int(x) => x.to_u64(),
99 Scalar::Bool(x) => Some(*x as u64),
100 }
101 }
102
103 fn to_f64(&self) -> Option<f64> {
104 match self {
105 Scalar::Float(x) => Some(*x),
106 Scalar::UInt(x) => x.to_f64(),
107 Scalar::Int(x) => x.to_f64(),
108 Scalar::Bool(x) => (*x as u8).to_f64(),
109 }
110 }
111}