burn_backend/element/
scalar.rs1use burn_std::{DType, bf16, f16};
2use num_traits::ToPrimitive;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use crate::{Element, ElementConversion};
9
/// A single scalar value held in one of four 64-bit (or boolean) representations.
///
/// Floating-point inputs are stored as `f64`, signed integers as `i64`, unsigned
/// integers as `u64`, and booleans as `bool`.
#[derive(Clone, Copy, Debug)]
pub enum Scalar {
    /// A floating-point value, stored as `f64`.
    Float(f64),
    /// A signed integer value, stored as `i64`.
    Int(i64),
    /// An unsigned integer value, stored as `u64`.
    UInt(u64),
    /// A boolean value.
    Bool(bool),
}
19
20impl Scalar {
21 pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
26 if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) {
27 Self::Float(value.elem())
28 } else if dtype.is_int() {
29 Self::Int(value.elem())
30 } else if dtype.is_uint() {
31 Self::UInt(value.elem())
32 } else if dtype.is_bool() {
33 Self::Bool(value.elem())
34 } else {
35 unimplemented!("Scalar not supported for {dtype:?}")
36 }
37 }
38
39 pub fn elem<E: Element>(self) -> E {
41 match self {
42 Self::Float(x) => x.elem(),
43 Self::Int(x) => x.elem(),
44 Self::UInt(x) => x.elem(),
45 Self::Bool(x) => x.elem(),
46 }
47 }
48
49 pub fn try_as_integer(&self) -> Option<Self> {
51 match self {
52 Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())),
53 Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
54 Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),
55 }
56 }
57}
58
// Generates `impl From<Src> for Scalar` for every `Src => Variant` pair. The input is
// converted into the variant's payload type with `ElementConversion::elem`.
macro_rules! impl_from_scalar {
    ($($src:ty => $kind:ident),+ $(,)?) => {
        $(
            impl From<$src> for Scalar {
                fn from(input: $src) -> Self {
                    Self::$kind(input.elem())
                }
            }
        )+
    };
}
70
71impl_from_scalar! {
72 f64 => Float, f32 => Float, f16 => Float, bf16 => Float,
73 i64 => Int, i32 => Int, i16 => Int, i8 => Int,
74 u64 => UInt, u32 => UInt, u16 => UInt, u8 => UInt, bool => Bool,
75}
76
77impl ToPrimitive for Scalar {
79 fn to_i64(&self) -> Option<i64> {
80 match self {
81 Scalar::Float(x) => x.to_i64(),
82 Scalar::UInt(x) => x.to_i64(),
83 Scalar::Int(x) => Some(*x),
84 Scalar::Bool(x) => Some(*x as i64),
85 }
86 }
87
88 fn to_u64(&self) -> Option<u64> {
89 match self {
90 Scalar::Float(x) => x.to_u64(),
91 Scalar::UInt(x) => Some(*x),
92 Scalar::Int(x) => x.to_u64(),
93 Scalar::Bool(x) => Some(*x as u64),
94 }
95 }
96
97 fn to_f64(&self) -> Option<f64> {
98 match self {
99 Scalar::Float(x) => Some(*x),
100 Scalar::UInt(x) => x.to_f64(),
101 Scalar::Int(x) => x.to_f64(),
102 Scalar::Bool(x) => (*x as u8).to_f64(),
103 }
104 }
105}