// burn_backend/element/scalar.rs

1use burn_std::{BoolStore, DType, bf16, f16};
2use num_traits::ToPrimitive;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use crate::{Element, ElementConversion};
9
/// A scalar element.
///
/// Each variant stores its value in the widest type of its kind.
#[derive(Clone, Copy, Debug)]
pub enum Scalar {
    /// A floating-point value, stored as `f64`.
    Float(f64),
    /// A signed integer value, stored as `i64`.
    Int(i64),
    /// An unsigned integer value, stored as `u64`.
    UInt(u64),
    /// A boolean value.
    Bool(bool),
}
19
20impl Scalar {
21    /// Creates a scalar with the specified data type.
22    ///
23    /// # Note
24    /// [`QFloat`](DType::QFloat) scalars are represented as float for element-wise operations.
25    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
26        if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) {
27            Self::Float(value.elem())
28        } else if dtype.is_int() {
29            Self::Int(value.elem())
30        } else if dtype.is_uint() {
31            Self::UInt(value.elem())
32        } else if dtype.is_bool() {
33            match dtype {
34                DType::Bool(BoolStore::Native) => Self::Bool(value.elem()),
35                DType::Bool(BoolStore::U8) | DType::Bool(BoolStore::U32) => {
36                    Self::UInt(value.elem())
37                }
38                _ => unreachable!(),
39            }
40        } else {
41            unimplemented!("Scalar not supported for {dtype:?}")
42        }
43    }
44
45    /// Converts and returns the converted element.
46    pub fn elem<E: Element>(self) -> E {
47        match self {
48            Self::Float(x) => x.elem(),
49            Self::Int(x) => x.elem(),
50            Self::UInt(x) => x.elem(),
51            Self::Bool(x) => x.elem(),
52        }
53    }
54
55    /// Returns the exact integer value, if valid.
56    pub fn try_as_integer(&self) -> Option<Self> {
57        match self {
58            Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())),
59            Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
60            Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),
61        }
62    }
63}
64
// Generates one `From<T> for Scalar` impl per `T => Variant` pair. The incoming value is
// converted with `.elem()` into the payload type of `Scalar::Variant`.
macro_rules! impl_from_scalar {
    ($($source:ty => $target:ident),+ $(,)?) => {
        $(
            impl From<$source> for Scalar {
                fn from(value: $source) -> Self {
                    Self::$target(value.elem())
                }
            }
        )+
    };
}
76
77impl_from_scalar! {
78    f64  => Float, f32  => Float, f16  => Float, bf16 => Float,
79    i64  => Int, i32  => Int, i16  => Int, i8 => Int,
80    u64  => UInt, u32  => UInt, u16  => UInt, u8 => UInt, bool => Bool,
81}
82
83// CubeCL requirement
84impl ToPrimitive for Scalar {
85    fn to_i64(&self) -> Option<i64> {
86        match self {
87            Scalar::Float(x) => x.to_i64(),
88            Scalar::UInt(x) => x.to_i64(),
89            Scalar::Int(x) => Some(*x),
90            Scalar::Bool(x) => Some(*x as i64),
91        }
92    }
93
94    fn to_u64(&self) -> Option<u64> {
95        match self {
96            Scalar::Float(x) => x.to_u64(),
97            Scalar::UInt(x) => Some(*x),
98            Scalar::Int(x) => x.to_u64(),
99            Scalar::Bool(x) => Some(*x as u64),
100        }
101    }
102
103    fn to_f64(&self) -> Option<f64> {
104        match self {
105            Scalar::Float(x) => Some(*x),
106            Scalar::UInt(x) => x.to_f64(),
107            Scalar::Int(x) => x.to_f64(),
108            Scalar::Bool(x) => (*x as u8).to_f64(),
109        }
110    }
111}