Skip to main content

executorch/
scalar.rs

//! Custom scalar types that can be used in tensors.
//!
//! Half precision floating point types are provided by the `half` crate if its feature is enabled,
//! otherwise simple wrappers around `u16` are provided without any arithmetic operations.
//!
//! Complex numbers are provided by the `num-complex` crate if its feature is enabled,
//! otherwise a simple struct with real and imaginary parts is provided without any arithmetic operations.
8
/// Declares an opaque scalar newtype around a primitive bit representation.
///
/// `scalar_type!(#[attr]* Name, Repr)` expands to `pub struct Name(Repr)` which is
/// `#[repr(transparent)]` and derives `Copy`, `Clone`, `Debug` and `Default`, together with
/// `const` `from_bits` / `to_bits` accessors. Any attributes (including doc comments) given
/// before the name are forwarded onto the generated struct. No arithmetic is generated.
macro_rules! scalar_type {
    ($(#[$outer:meta])* $name:ident, $repr:ty) => {
        #[derive(Copy, Clone, Debug, Default)]
        #[repr(transparent)]
        $(#[$outer])*
        pub struct $name($repr);

        impl $name {
            #[doc = concat!("Creates a new `", stringify!($name), "` from its raw bit representation.")]
            #[must_use]
            pub const fn from_bits(bits: $repr) -> Self {
                Self(bits)
            }

            // Takes `&self` (rather than `self`) to keep the existing call signature intact.
            #[doc = concat!("Get the raw bit representation of the `", stringify!($name), "`.")]
            #[must_use]
            pub const fn to_bits(&self) -> $repr {
                self.0
            }
        }
    };
}
27
28cfg_if::cfg_if! { if #[cfg(feature = "half")] {
29    pub use half::f16;
30    pub use half::bf16;
31} else {
32    scalar_type!(
33        /// A 16-bit floating point type implementing the IEEE 754-2008 standard [`binary16`] a.k.a "half"
34        /// format.
35        ///
36        /// Doesn't provide any arithmetic operations, but can be converted to/from `u16`.
37        /// Enable the `half` feature to get a fully functional `f16` type.
38        #[allow(non_camel_case_types)]
39        f16, u16
40    );
41
42    scalar_type!(
43        /// A 16-bit floating point type implementing the [`bfloat16`] format.
44        ///
45        /// Doesn't provide any arithmetic operations, but can be converted to/from `u16`.
46        /// Enable the `half` feature to get a fully functional `bf16` type.
47        #[allow(non_camel_case_types)]
48        bf16, u16
49    );
50} }
51
52cfg_if::cfg_if! { if #[cfg(feature = "num-complex")] {
53    pub use num_complex::Complex;
54} else {
55    /// A complex number in Cartesian form.
56    ///
57    /// Doesn't provide any arithmetic operations, but expose the real and imaginary parts.
58    /// Enable the `num-complex` feature to get a fully functional `Complex` type.
59    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
60    #[repr(C)]
61    pub struct Complex<T> {
62        /// Real portion of the complex number
63        pub re: T,
64        /// Imaginary portion of the complex number
65        pub im: T,
66    }
67} }
68
69scalar_type!(
70    /// 8-bit quantized integer.
71    ///
72    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
73    QInt8, u8
74);
75
76scalar_type!(
77    /// 8-bit unsigned quantized integer.
78    ///
79    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
80    QUInt8, u8
81);
82
83scalar_type!(
84    /// 32-bit quantized integer.
85    ///
86    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
87    QInt32, u32
88);
89
90scalar_type!(
91    /// Two 4-bit unsigned quantized integers packed into a byte.
92    ///
93    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
94    QUInt4x2, u8
95);
96
97scalar_type!(
98    /// Four 2-bit unsigned quantized integers packed into a byte.
99    ///
100    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
101    QUInt2x4, u8
102);
103
104scalar_type!(
105    /// Eight 1-bit values packed into a byte.
106    ///
107    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
108    Bits1x8, u8
109);
110
111scalar_type!(
112    /// Four 2-bit values packed into a byte.
113    ///
114    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
115    Bits2x4, u8
116);
117
118scalar_type!(
119    /// Two 4-bit values packed into a byte.
120    ///
121    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
122    Bits4x2, u8
123);
124
125scalar_type!(
126    /// 8-bit bitfield (1 byte).
127    ///
128    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
129    Bits8, u8
130);
131
132scalar_type!(
133    /// 16-bit bitfield (2 bytes).
134    ///
135    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
136    Bits16, u16
137);
138
139scalar_type!(
140    /// 8-bit floating-point with 1 bit for the sign, 5 bits for the exponents, 2 bits for the mantissa.
141    ///
142    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
143    #[allow(non_camel_case_types)]
144    Float8_e5m2, u8
145);
146
147scalar_type!(
148    /// 8-bit floating-point with 1 bit for the sign, 4 bits for the exponents, 3 bits for the mantissa,
149    /// only nan values and no infinite values (FN).
150    ///
151    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
152    #[allow(non_camel_case_types)]
153    Float8_e4m3fn, u8
154);
155
156scalar_type!(
157    /// 8-bit floating-point with 1 bit for the sign, 5 bits for the exponents, 2 bits for the mantissa,
158    /// only nan values and no infinite values (FN), no negative zero (UZ).
159    ///
160    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
161    #[allow(non_camel_case_types)]
162    Float8_e5m2fnuz, u8
163);
164
165scalar_type!(
166    /// 8-bit floating-point with 1 bit for the sign, 4 bits for the exponents, 3 bits for the mantissa,
167    /// only nan values and no infinite values (FN), no negative zero (UZ).
168    ///
169    /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
170    #[allow(non_camel_case_types)]
171    Float8_e4m3fnuz, u8
172);