Skip to main content

burn_backend/element/
scalar.rs

1use burn_std::{DType, bf16, f16};
2use num_traits::ToPrimitive;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use crate::{Element, ElementConversion};
9
/// A scalar element.
///
/// A single value tagged with its numeric kind. Each variant owns its payload directly, so
/// the whole enum is `Copy`.
#[derive(Clone, Copy, Debug)]
pub enum Scalar {
    /// A floating-point value, stored as an `f64`.
    Float(f64),
    /// A signed integer value, stored as an `i64`.
    Int(i64),
    /// An unsigned integer value, stored as a `u64`.
    UInt(u64),
    /// A boolean value.
    Bool(bool),
}
19
20impl Scalar {
21    /// Creates a scalar with the specified data type.
22    ///
23    /// # Note
24    /// [`QFloat`](DType::QFloat) scalars are represented as float for element-wise operations.
25    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {
26        if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) {
27            Self::Float(value.elem())
28        } else if dtype.is_int() {
29            Self::Int(value.elem())
30        } else if dtype.is_uint() {
31            Self::UInt(value.elem())
32        } else if dtype.is_bool() {
33            Self::Bool(value.elem())
34        } else {
35            unimplemented!("Scalar not supported for {dtype:?}")
36        }
37    }
38
39    /// Converts and returns the converted element.
40    pub fn elem<E: Element>(self) -> E {
41        match self {
42            Self::Float(x) => x.elem(),
43            Self::Int(x) => x.elem(),
44            Self::UInt(x) => x.elem(),
45            Self::Bool(x) => x.elem(),
46        }
47    }
48
49    /// Returns the exact integer value, if valid.
50    pub fn try_as_integer(&self) -> Option<Self> {
51        match self {
52            Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())),
53            Scalar::Int(_) | Scalar::UInt(_) => Some(*self),
54            Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),
55        }
56    }
57}
58
/// Generates `From<T> for Scalar` for every `T => Variant` pair in the input.
///
/// Each generated impl converts the incoming value with `.elem()` into the payload type of
/// the named [`Scalar`] variant. A trailing comma after the last pair is accepted.
macro_rules! impl_from_scalar {
    ($($source:ty => $kind:ident),+ $(,)?) => {
        $(
            impl From<$source> for Scalar {
                fn from(input: $source) -> Self {
                    Self::$kind(input.elem())
                }
            }
        )+
    };
}
70
71impl_from_scalar! {
72    f64  => Float, f32  => Float, f16  => Float, bf16 => Float,
73    i64  => Int, i32  => Int, i16  => Int, i8 => Int,
74    u64  => UInt, u32  => UInt, u16  => UInt, u8 => UInt, bool => Bool,
75}
76
77// CubeCL requirement
78impl ToPrimitive for Scalar {
79    fn to_i64(&self) -> Option<i64> {
80        match self {
81            Scalar::Float(x) => x.to_i64(),
82            Scalar::UInt(x) => x.to_i64(),
83            Scalar::Int(x) => Some(*x),
84            Scalar::Bool(x) => Some(*x as i64),
85        }
86    }
87
88    fn to_u64(&self) -> Option<u64> {
89        match self {
90            Scalar::Float(x) => x.to_u64(),
91            Scalar::UInt(x) => Some(*x),
92            Scalar::Int(x) => x.to_u64(),
93            Scalar::Bool(x) => Some(*x as u64),
94        }
95    }
96
97    fn to_f64(&self) -> Option<f64> {
98        match self {
99            Scalar::Float(x) => Some(*x),
100            Scalar::UInt(x) => x.to_f64(),
101            Scalar::Int(x) => x.to_f64(),
102            Scalar::Bool(x) => (*x as u8).to_f64(),
103        }
104    }
105}