/// Scalar element types understood by this module.
///
/// Sizes are reported by [`DType::size_bytes`]; classification is reported by
/// [`DType::is_float`], [`DType::is_int`] and [`DType::is_complex`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 4-byte floating point.
    F32,
    /// 2-byte floating point.
    F16,
    /// 2-byte floating point, distinct from [`DType::F16`]; mixing the two promotes to `F32`.
    BF16,
    /// 8-byte floating point.
    F64,
    /// 1-byte signed integer.
    I8,
    /// 2-byte signed integer.
    I16,
    /// 4-byte signed integer.
    I32,
    /// 8-byte signed integer.
    I64,
    /// 1-byte unsigned integer.
    U8,
    /// 4-byte unsigned integer.
    U32,
    /// 1-byte boolean.
    Bool,
    /// 8-byte complex type (presumably a pair of `f32` components; confirm with the backend).
    C64,
}

impl DType {
    /// Returns the storage width of one element, in bytes.
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Bool | Self::I8 | Self::U8 => 1,
            Self::F16 | Self::BF16 | Self::I16 => 2,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F64 | Self::I64 | Self::C64 => 8,
        }
    }

    /// Returns `true` for the real floating-point types (`F16`, `BF16`, `F32`, `F64`).
    ///
    /// `C64` is reported by [`DType::is_complex`], not here.
    pub const fn is_float(self) -> bool {
        matches!(self, Self::F32 | Self::F16 | Self::BF16 | Self::F64)
    }

    /// Returns `true` only for `C64`.
    pub const fn is_complex(self) -> bool {
        matches!(self, Self::C64)
    }

    /// Returns `true` for the signed and unsigned integer types. `Bool` is not an integer.
    pub const fn is_int(self) -> bool {
        matches!(
            self,
            Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 | Self::U32
        )
    }

    /// Returns the fallback ordering used by [`DType::promote`]; the higher rank wins.
    ///
    /// Ranks are not unique: `U8`/`I8` share rank 1, `I16`/`BF16` share rank 2 and
    /// `U32`/`I32` share rank 4. `promote` resolves every such pair with an explicit
    /// rule before it consults the rank, so a tie never decides a promotion.
    pub const fn promotion_rank(self) -> u8 {
        match self {
            Self::Bool => 0,
            Self::U8 | Self::I8 => 1,
            Self::I16 | Self::BF16 => 2,
            Self::F16 => 3,
            Self::U32 | Self::I32 => 4,
            Self::I64 => 5,
            Self::F32 => 6,
            Self::F64 => 7,
            Self::C64 => 8,
        }
    }

    /// Returns the common type for a binary operation on `self` and `other`.
    ///
    /// The result does not depend on argument order. Rules, in priority order:
    ///
    /// 1. Identical types promote to themselves.
    /// 2. `F16` with `BF16` promotes to `F32`.
    /// 3. Two integers of the same signedness promote to the higher-ranked one.
    /// 4. An unsigned integer with a signed integer promotes to the smallest signed
    ///    type that covers both ranges: `U8` + `I8` gives `I16`, `U8` + a wider signed
    ///    type gives that signed type, and `U32` + any signed type gives `I64`.
    /// 5. An integer with a float promotes to `F64` when the float is `F64` or the
    ///    integer is `I64`, and to `F32` otherwise.
    /// 6. Everything else (`Bool`, `C64`, float with float) takes the higher
    ///    [`DType::promotion_rank`].
    pub fn promote(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        if matches!(
            (self, other),
            (Self::F16, Self::BF16) | (Self::BF16, Self::F16)
        ) {
            return Self::F32;
        }

        let is_unsigned = |d: Self| matches!(d, Self::U8 | Self::U32);

        // Only called with a genuine (integer, float) pair.
        let int_with_float = |int: Self, float: Self| -> Self {
            match (int, float) {
                (_, Self::F64) | (Self::I64, _) => Self::F64,
                _ => Self::F32,
            }
        };

        // Only called with an (unsigned, signed) integer pair. Picking either operand
        // could drop the sign or the upper half of the unsigned range.
        let signed_cover = |unsigned: Self, signed: Self| -> Self {
            match (unsigned, signed) {
                (Self::U8, Self::I8) => Self::I16,
                // I16, I32 and I64 already contain 0..=255.
                (Self::U8, _) => signed,
                // U32 needs 33 signed bits; I64 is the only signed type wide enough.
                _ => Self::I64,
            }
        };

        // Ranks are distinct for every pair that reaches this fallback, so `>=` cannot
        // make the result depend on argument order.
        let by_rank = if self.promotion_rank() >= other.promotion_rank() {
            self
        } else {
            other
        };

        match (self.is_int(), other.is_int()) {
            (true, true) => match (is_unsigned(self), is_unsigned(other)) {
                (true, false) => signed_cover(self, other),
                (false, true) => signed_cover(other, self),
                _ => by_rank,
            },
            (true, false) if other.is_float() => int_with_float(self, other),
            (false, true) if self.is_float() => int_with_float(other, self),
            _ => by_rank,
        }
    }
}
144
/// Converts `value` to an `i64`, accepting only values that are exact integers.
///
/// `name` is the dtype label interpolated into the error messages.
///
/// # Errors
///
/// Returns a message when `value` is NaN or infinite, has a fractional part, or lies
/// outside the `i64` range. The range check is required because a float-to-int `as`
/// cast saturates: without it, a huge literal would come back as `i64::MAX` or
/// `i64::MIN` instead of being rejected.
fn integral_scalar(value: f64, name: &str) -> Result<i64, String> {
    // 2^63, the first f64 above i64::MAX; its negation is exactly i64::MIN.
    const I64_BOUND: f64 = 9_223_372_036_854_775_808.0;

    if !value.is_finite() {
        return Err(format!(
            "constant value {value} is not finite for dtype {name}"
        ));
    }
    if value.fract() != 0.0 {
        return Err(format!(
            "constant value {value} must be integral for dtype {name}"
        ));
    }
    if value >= I64_BOUND || value < -I64_BOUND {
        return Err(format!(
            "constant value {value} is out of range for dtype {name}"
        ));
    }
    // Finite, integral and inside the i64 range, so the cast is exact.
    Ok(value as i64)
}
158
159pub fn scalar_constant_bytes(value: f64, dtype: DType) -> Result<Vec<u8>, String> {
161 let out_of_range =
162 |name: &str| format!("constant value {value} is out of range for dtype {name}");
163 match dtype {
164 DType::F32 => Ok((value as f32).to_le_bytes().to_vec()),
165 DType::F64 => Ok(value.to_le_bytes().to_vec()),
166 DType::I8 => {
167 let v = integral_scalar(value, "i8")?;
168 if !(i8::MIN as i64..=i8::MAX as i64).contains(&v) {
169 return Err(out_of_range("i8"));
170 }
171 Ok((v as i8).to_le_bytes().to_vec())
172 }
173 DType::I16 => {
174 let v = integral_scalar(value, "i16")?;
175 if !(i16::MIN as i64..=i16::MAX as i64).contains(&v) {
176 return Err(out_of_range("i16"));
177 }
178 Ok((v as i16).to_le_bytes().to_vec())
179 }
180 DType::I32 => {
181 let v = integral_scalar(value, "i32")?;
182 if !(i32::MIN as i64..=i32::MAX as i64).contains(&v) {
183 return Err(out_of_range("i32"));
184 }
185 Ok((v as i32).to_le_bytes().to_vec())
186 }
187 DType::I64 => {
188 if !value.is_finite() {
189 return Err(format!(
190 "constant value {value} is not finite for dtype i64"
191 ));
192 }
193 if value.fract() != 0.0 {
194 return Err(format!(
195 "constant value {value} must be integral for dtype i64"
196 ));
197 }
198 if value >= 9.223372036854776e18 || value < -9.223372036854776e18 {
200 return Err(out_of_range("i64"));
201 }
202 Ok((value as i64).to_le_bytes().to_vec())
203 }
204 DType::U8 => {
205 let v = integral_scalar(value, "u8")?;
206 if !(0..=u8::MAX as i64).contains(&v) {
207 return Err(out_of_range("u8"));
208 }
209 Ok((v as u8).to_le_bytes().to_vec())
210 }
211 DType::U32 => {
212 let v = integral_scalar(value, "u32")?;
213 if v < 0 || v > u32::MAX as i64 {
214 return Err(out_of_range("u32"));
215 }
216 Ok((v as u32).to_le_bytes().to_vec())
217 }
218 DType::Bool => Ok(vec![u8::from(value != 0.0)]),
219 DType::F16 | DType::BF16 | DType::C64 => Err(format!(
220 "scalar literal dtype '{dtype:?}' is built via f32 constant + cast"
221 )),
222 }
223}
224
/// An element description: a storage [`DType`] plus an [`ElementSubtype`] and a
/// saturation flag.
///
/// Build values with [`Element::new`], [`Element::fp8_e4m3`] or [`Element::fp8_e5m2`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Element {
    /// Storage dtype. The fp8 constructors use `DType::U8` here.
    pub dtype: DType,
    /// Subtype tag; `Standard` for elements built with [`Element::new`].
    pub subtype: ElementSubtype,
    /// Saturation flag. `false` from [`Element::new`], `true` from the fp8 constructors
    /// and from [`Element::saturating`].
    // NOTE(review): presumably selects saturating conversion in consumers — confirm.
    pub saturating: bool,
}
245
/// Subtype tag carried by [`Element`] alongside its storage dtype.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementSubtype {
    /// No special interpretation; used by [`Element::new`].
    Standard,
    /// Tag used by [`Element::fp8_e4m3`] with `U8` storage.
    // NOTE(review): name suggests an 8-bit float with 4 exponent / 3 mantissa bits — confirm.
    Fp8E4m3,
    /// Tag used by [`Element::fp8_e5m2`] with `U8` storage.
    // NOTE(review): name suggests an 8-bit float with 5 exponent / 2 mantissa bits — confirm.
    Fp8E5m2,
}
256
257impl Element {
258 pub const fn new(dtype: DType) -> Self {
259 Self {
260 dtype,
261 subtype: ElementSubtype::Standard,
262 saturating: false,
263 }
264 }
265 pub const fn fp8_e4m3() -> Self {
266 Self {
267 dtype: DType::U8,
268 subtype: ElementSubtype::Fp8E4m3,
269 saturating: true,
270 }
271 }
272 pub const fn fp8_e5m2() -> Self {
273 Self {
274 dtype: DType::U8,
275 subtype: ElementSubtype::Fp8E5m2,
276 saturating: true,
277 }
278 }
279 pub const fn saturating(self) -> Self {
280 Self {
281 saturating: true,
282 ..self
283 }
284 }
285}
286
287impl std::fmt::Display for DType {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 match self {
290 Self::F32 => write!(f, "f32"),
291 Self::F16 => write!(f, "f16"),
292 Self::BF16 => write!(f, "bf16"),
293 Self::F64 => write!(f, "f64"),
294 Self::I8 => write!(f, "i8"),
295 Self::I16 => write!(f, "i16"),
296 Self::I32 => write!(f, "i32"),
297 Self::I64 => write!(f, "i64"),
298 Self::U8 => write!(f, "u8"),
299 Self::U32 => write!(f, "u32"),
300 Self::Bool => write!(f, "bool"),
301 Self::C64 => write!(f, "c64"),
302 }
303 }
304}
305
// Unit tests for `Element` construction, `DType::promote` and `scalar_constant_bytes`.
#[cfg(test)]
mod tests {
    use super::*;

    // Each constructor sets dtype, subtype and the saturation flag as expected.
    #[test]
    fn element_constructors() {
        let f = Element::new(DType::F32);
        assert_eq!(f.dtype, DType::F32);
        assert_eq!(f.subtype, ElementSubtype::Standard);
        assert!(!f.saturating);

        let e4 = Element::fp8_e4m3();
        assert_eq!(e4.subtype, ElementSubtype::Fp8E4m3);
        assert!(e4.saturating);
        assert_eq!(e4.dtype, DType::U8);

        // `saturating()` flips only the flag and keeps the dtype.
        let s = Element::new(DType::I32).saturating();
        assert!(s.saturating);
        assert_eq!(s.dtype, DType::I32);
    }

    // Promoting a type with itself is the identity.
    #[test]
    fn promote_same() {
        assert_eq!(DType::F32.promote(DType::F32), DType::F32);
        assert_eq!(DType::I8.promote(DType::I8), DType::I8);
    }

    // Signed integers widen to the larger signed type.
    #[test]
    fn promote_int_widening() {
        assert_eq!(DType::I8.promote(DType::I16), DType::I16);
        assert_eq!(DType::I32.promote(DType::I64), DType::I64);
    }

    // Integer with float: I64 forces F64, half-width floats are lifted to F32.
    #[test]
    fn promote_int_to_float() {
        assert_eq!(DType::I32.promote(DType::F32), DType::F32);
        assert_eq!(DType::I64.promote(DType::F32), DType::F64);
        assert_eq!(DType::I8.promote(DType::F16), DType::F32);
    }

    // The two 2-byte float types meet at F32, in either argument order.
    #[test]
    fn promote_f16_bf16_goes_to_f32() {
        assert_eq!(DType::F16.promote(DType::BF16), DType::F32);
        assert_eq!(DType::BF16.promote(DType::F16), DType::F32);
    }

    // Swapping the operands must not change the result for these pairs.
    #[test]
    fn promote_is_commutative_for_well_defined_pairs() {
        let pairs = [
            (DType::F32, DType::F16),
            (DType::I32, DType::F64),
            (DType::Bool, DType::I8),
        ];
        for (a, b) in pairs {
            assert_eq!(
                a.promote(b),
                b.promote(a),
                "promote({a},{b}) should equal promote({b},{a})"
            );
        }
    }

    // Encoded bytes match the little-endian bytes of the native value.
    #[test]
    fn scalar_constant_bytes_round_trips() {
        assert_eq!(
            scalar_constant_bytes(2.5, DType::F32).unwrap(),
            2.5f32.to_le_bytes().to_vec()
        );
        assert_eq!(
            scalar_constant_bytes(-1.0, DType::F64).unwrap(),
            (-1.0f64).to_le_bytes().to_vec()
        );
        assert_eq!(
            scalar_constant_bytes(7.0, DType::I32).unwrap(),
            7i32.to_le_bytes()
        );
        assert_eq!(scalar_constant_bytes(0.0, DType::Bool).unwrap(), vec![0]);
        assert_eq!(scalar_constant_bytes(1.0, DType::Bool).unwrap(), vec![1]);
    }

    // Out-of-range and fractional values are rejected for integer dtypes.
    #[test]
    fn scalar_constant_bytes_rejects_out_of_range() {
        assert!(scalar_constant_bytes(128.0, DType::I8).is_err());
        assert!(scalar_constant_bytes(-1.0, DType::U32).is_err());
        // 2^63 is one past i64::MAX.
        assert!(scalar_constant_bytes(9.223372036854776e18, DType::I64).is_err());
        assert!(scalar_constant_bytes(2.5, DType::I32).is_err());
    }

    // F16, BF16 and C64 literals are not encoded directly by this function.
    #[test]
    fn scalar_constant_bytes_rejects_low_precision_direct() {
        assert!(scalar_constant_bytes(1.0, DType::F16).is_err());
        assert!(scalar_constant_bytes(1.0, DType::BF16).is_err());
        assert!(scalar_constant_bytes(1.0, DType::C64).is_err());
    }
}