executorch/scalar.rs
//! Custom scalar types that can be used in tensors.
//!
//! Half-precision floating point types are provided by the `half` crate if its feature is enabled;
//! otherwise simple wrappers around `u16` are provided, without any arithmetic operations.
//!
//! Complex numbers are provided by the `num-complex` crate if its feature is enabled;
//! otherwise a simple struct with real and imaginary parts is provided, without any arithmetic
//! operations.
8
/// Defines a `#[repr(transparent)]` newtype scalar that wraps its raw bit representation.
///
/// The generated type derives `Copy`, `Clone`, `Debug` and `Default` and only exposes the
/// `from_bits`/`to_bits` accessors; no arithmetic is provided. Attributes given before the type
/// name (including `///` doc comments) are forwarded to the generated struct. A trailing comma
/// after the representation type is accepted.
macro_rules! scalar_type {
    ($(#[$outer:meta])* $name:ident, $repr:ty $(,)?) => {
        #[derive(Copy, Clone, Debug, Default)]
        #[repr(transparent)]
        $(#[$outer])*
        pub struct $name($repr);

        impl $name {
            #[doc = concat!("Creates a new `", stringify!($name), "` from its raw bit representation.")]
            // Trivial cross-crate accessor; ignoring the result is always a bug.
            #[inline]
            #[must_use]
            pub const fn from_bits(bits: $repr) -> Self {
                Self(bits)
            }

            #[doc = concat!("Gets the raw bit representation of the `", stringify!($name), "`.")]
            #[inline]
            #[must_use]
            pub const fn to_bits(&self) -> $repr {
                self.0
            }
        }
    };
}
27
28cfg_if::cfg_if! { if #[cfg(feature = "half")] {
29 pub use half::f16;
30 pub use half::bf16;
31} else {
32 scalar_type!(
33 /// A 16-bit floating point type implementing the IEEE 754-2008 standard [`binary16`] a.k.a "half"
34 /// format.
35 ///
36 /// Doesn't provide any arithmetic operations, but can be converted to/from `u16`.
37 /// Enable the `half` feature to get a fully functional `f16` type.
38 #[allow(non_camel_case_types)]
39 f16, u16
40 );
41
42 scalar_type!(
43 /// A 16-bit floating point type implementing the [`bfloat16`] format.
44 ///
45 /// Doesn't provide any arithmetic operations, but can be converted to/from `u16`.
46 /// Enable the `half` feature to get a fully functional `bf16` type.
47 #[allow(non_camel_case_types)]
48 bf16, u16
49 );
50} }
51
52cfg_if::cfg_if! { if #[cfg(feature = "num-complex")] {
53 pub use num_complex::Complex;
54} else {
55 /// A complex number in Cartesian form.
56 ///
57 /// Doesn't provide any arithmetic operations, but expose the real and imaginary parts.
58 /// Enable the `num-complex` feature to get a fully functional `Complex` type.
59 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
60 #[repr(C)]
61 pub struct Complex<T> {
62 /// Real portion of the complex number
63 pub re: T,
64 /// Imaginary portion of the complex number
65 pub im: T,
66 }
67} }
68
69scalar_type!(
70 /// 8-bit quantized integer.
71 ///
72 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
73 QInt8, u8
74);
75
76scalar_type!(
77 /// 8-bit unsigned quantized integer.
78 ///
79 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
80 QUInt8, u8
81);
82
83scalar_type!(
84 /// 32-bit quantized integer.
85 ///
86 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
87 QInt32, u32
88);
89
90scalar_type!(
91 /// Two 4-bit unsigned quantized integers packed into a byte.
92 ///
93 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
94 QUInt4x2, u8
95);
96
97scalar_type!(
98 /// Four 2-bit unsigned quantized integers packed into a byte.
99 ///
100 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
101 QUInt2x4, u8
102);
103
104scalar_type!(
105 /// Eight 1-bit values packed into a byte.
106 ///
107 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
108 Bits1x8, u8
109);
110
111scalar_type!(
112 /// Four 2-bit values packed into a byte.
113 ///
114 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
115 Bits2x4, u8
116);
117
118scalar_type!(
119 /// Two 4-bit values packed into a byte.
120 ///
121 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
122 Bits4x2, u8
123);
124
125scalar_type!(
126 /// 8-bit bitfield (1 byte).
127 ///
128 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
129 Bits8, u8
130);
131
132scalar_type!(
133 /// 16-bit bitfield (2 bytes).
134 ///
135 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
136 Bits16, u16
137);
138
139scalar_type!(
140 /// 8-bit floating-point with 1 bit for the sign, 5 bits for the exponents, 2 bits for the mantissa.
141 ///
142 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
143 #[allow(non_camel_case_types)]
144 Float8_e5m2, u8
145);
146
147scalar_type!(
148 /// 8-bit floating-point with 1 bit for the sign, 4 bits for the exponents, 3 bits for the mantissa,
149 /// only nan values and no infinite values (FN).
150 ///
151 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
152 #[allow(non_camel_case_types)]
153 Float8_e4m3fn, u8
154);
155
156scalar_type!(
157 /// 8-bit floating-point with 1 bit for the sign, 5 bits for the exponents, 2 bits for the mantissa,
158 /// only nan values and no infinite values (FN), no negative zero (UZ).
159 ///
160 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
161 #[allow(non_camel_case_types)]
162 Float8_e5m2fnuz, u8
163);
164
165scalar_type!(
166 /// 8-bit floating-point with 1 bit for the sign, 4 bits for the exponents, 3 bits for the mantissa,
167 /// only nan values and no infinite values (FN), no negative zero (UZ).
168 ///
169 /// Does not provide any arithmetic operations, but can be converted to/from bits representation.
170 #[allow(non_camel_case_types)]
171 Float8_e4m3fnuz, u8
172);