Skip to main content

numina/dtype/
mod.rs

1//! Data type definitions and implementations for arrays.
2//!
3//! Numina's dtype system has two layers:
4//! - [`DType`]/[`DTypeId`]: small, stable identifiers for dispatch and serialization
5//! - "concrete" types (e.g. [`Float16`], [`BFloat16`], [`Float8E4M3Fn`]) that define byte layout and
6//!   conversion behavior
7//!
8//! ## Serialization
9//! When converting values to/from bytes, Numina uses **little-endian** encoding.
10
11use std::fmt;
12
13// Core modules
14pub mod conversions;
15pub mod types;
16
17// Re-exports for convenience
18pub use types::{
19    BFloat8, BFloat16, Complex32, Complex64, Complex128, Float8E4M3Fn, Float8E5M2, Float16,
20    Float32, QuantizedI4, QuantizedU8,
21};
22
/// Stable dtype identifier for Lamina/Laminax/Cetana serialization.
///
/// This single byte is the value stored in IR and runtime formats. The numeric values belong to
/// Numina's compatibility contract and must not be renumbered.
///
/// Ids are grouped by family: floats `1..=7`, signed integers `10..=13`, unsigned integers
/// `20..=23`, bool `30`, quantized `40..=41`, complex `50..=52`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DTypeId(pub u8);

impl DTypeId {
    // ---- Floating point ----
    /// Stable ID for [`DType::F16`] (`float16`).
    pub const F16: Self = Self(1);
    /// Stable ID for [`DType::F32`] (`float32`).
    pub const F32: Self = Self(2);
    /// Stable ID for [`DType::F64`] (`float64`).
    pub const F64: Self = Self(3);
    /// Stable ID for [`DType::BF16`] (`bfloat16`).
    pub const BF16: Self = Self(4);
    /// Stable ID for [`DType::BF8`] (`bfloat8`).
    pub const BF8: Self = Self(5);
    /// Stable ID for [`DType::F8E4M3FN`] (`float8_e4m3fn`).
    pub const F8E4M3FN: Self = Self(6);
    /// Stable ID for [`DType::F8E5M2`] (`float8_e5m2`).
    pub const F8E5M2: Self = Self(7);

    // ---- Signed integers ----
    /// Stable ID for [`DType::I8`] (`int8`).
    pub const I8: Self = Self(10);
    /// Stable ID for [`DType::I16`] (`int16`).
    pub const I16: Self = Self(11);
    /// Stable ID for [`DType::I32`] (`int32`).
    pub const I32: Self = Self(12);
    /// Stable ID for [`DType::I64`] (`int64`).
    pub const I64: Self = Self(13);

    // ---- Unsigned integers ----
    /// Stable ID for [`DType::U8`] (`uint8`).
    pub const U8: Self = Self(20);
    /// Stable ID for [`DType::U16`] (`uint16`).
    pub const U16: Self = Self(21);
    /// Stable ID for [`DType::U32`] (`uint32`).
    pub const U32: Self = Self(22);
    /// Stable ID for [`DType::U64`] (`uint64`).
    pub const U64: Self = Self(23);

    // ---- Boolean ----
    /// Stable ID for [`DType::Bool`] (`bool`).
    pub const BOOL: Self = Self(30);

    // ---- Quantized ----
    /// Stable ID for [`DType::QI4`] (`qi4`).
    pub const QI4: Self = Self(40);
    /// Stable ID for [`DType::QU8`] (`qu8`).
    pub const QU8: Self = Self(41);

    // ---- Complex ----
    /// Stable ID for [`DType::Complex32`] (`complex32`).
    pub const COMPLEX32: Self = Self(50);
    /// Stable ID for [`DType::Complex64`] (`complex64`).
    pub const COMPLEX64: Self = Self(51);
    /// Stable ID for [`DType::Complex128`] (`complex128`).
    pub const COMPLEX128: Self = Self(52);
}
80
81/// Static descriptor for dtype metadata.
82///
83/// Values are intended to be ABI-relevant (byte size, alignment, storage bits) and stable across
84/// Numina versions for a given [`DTypeId`].
85#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub struct DTypeInfo {
87    /// Stable dtype ID used for serialization.
88    pub id: DTypeId,
89    /// Human-readable name.
90    pub name: &'static str,
91    /// Logical size in bytes for one element.
92    pub byte_size: usize,
93    /// Number of storage bits used (e.g. `QI4` uses 4 bits per value but is stored in a byte).
94    pub storage_bits: u16,
95    /// Required alignment in bytes.
96    pub align: usize,
97    /// Whether the dtype should be treated as a "float-like" dtype for certain conversions.
98    pub is_float: bool,
99    /// Whether the dtype is an integer-like dtype.
100    pub is_int: bool,
101    /// Whether the dtype is boolean.
102    pub is_bool: bool,
103}
104
105/// Trait for mapping concrete Rust types to Numina dtypes.
106///
107/// # Safety
108/// Implementors must guarantee their in-memory representation matches the canonical layout for
109/// `Self::DTYPE` (size, alignment, and byte encoding), since Numina may reinterpret raw bytes as the
110/// corresponding primitive/type for that dtype.
111pub unsafe trait DTypeLike: Copy {
112    /// Static dtype descriptor for this Rust type
113    const DTYPE: DType;
114}
115
116/// Trait for dtype-backed value serialization.
117///
118/// Implementations must write values using little-endian encoding where applicable.
119pub trait DTypeValue: Copy {
120    /// Static dtype descriptor for this Rust type.
121    const DTYPE: DType;
122    /// Append the canonical little-endian encoding of this value to `out`.
123    fn write_bytes(self, out: &mut Vec<u8>);
124}
125
126/// Marker trait for types that can be used as `Array<T>` elements.
127///
128/// This is mainly a convenience bound for "Numina element" types used in generic APIs.
129pub trait DTypeElement: DTypeLike + DTypeValue + Copy + Default + Send + Sync + 'static {}
130
131impl<T> DTypeElement for T where T: DTypeLike + DTypeValue + Copy + Default + Send + Sync + 'static {}
132
/// Common interface shared by the [`DType`] enum and the concrete dtype wrapper types.
///
/// Implementors can report their element size and classification and convert to and from raw
/// bytes.
pub trait DTypeCandidate: Copy + Clone + PartialEq + Eq + std::hash::Hash {
    /// Size in bytes of one element of this data type.
    fn size_bytes(&self) -> usize;

    /// Whether this is a floating point type.
    fn is_float(&self) -> bool;

    /// Whether this is an integer type.
    fn is_int(&self) -> bool;

    /// Whether this is a signed type (meaningful for integers).
    fn is_signed(&self) -> bool;

    /// Whether this is the boolean type.
    fn is_bool(&self) -> bool;

    /// Canonical name of the type.
    fn type_name(&self) -> &'static str;

    /// Builds a value from its raw byte representation (used internally).
    ///
    /// # Safety
    /// `bytes` must hold a valid encoding of this type.
    unsafe fn from_bytes(bytes: &[u8]) -> Self;

    /// Produces the raw byte representation of this value (used internally).
    fn to_bytes(&self) -> Vec<u8>;

    /// Whether this is a signed integer type: an integer type that is also signed.
    fn is_signed_int(&self) -> bool {
        // `is_signed` is consulted only for integer types.
        if self.is_int() {
            self.is_signed()
        } else {
            false
        }
    }

    /// Whether this is an unsigned integer type: an integer type that is not signed.
    fn is_unsigned_int(&self) -> bool {
        if self.is_int() {
            !self.is_signed()
        } else {
            false
        }
    }
}
173
174/// Trait for float-like dtype conversions.
175///
176/// This is used to implement `encode_float_bytes` / `decode_float_bytes` for custom float formats.
177pub trait FloatDType: DTypeCandidate {
178    /// Convert from an `f32` (possibly lossy).
179    fn from_f32(value: f32) -> Self;
180    /// Convert to `f32` (possibly lossy).
181    fn to_f32(self) -> f32;
182}
183
/// Element data type of a Numina array.
///
/// Every variant carries an explicit discriminant equal to its [`DTypeId`] value, so
/// `dtype as u8` is the stable serialized id. The discriminants are part of the compatibility
/// contract.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    // ---- Floating point (ids 1..=7) ----
    /// 16-bit float
    F16 = 1,
    /// 32-bit float
    F32 = 2,
    /// 64-bit float
    F64 = 3,
    /// 16-bit brain float
    BF16 = 4,
    /// 8-bit brain float
    BF8 = 5,
    /// 8-bit float, E4M3FN layout
    F8E4M3FN = 6,
    /// 8-bit float, E5M2 layout
    F8E5M2 = 7,

    // ---- Signed integers (ids 10..=13) ----
    /// Signed 8-bit integer
    I8 = 10,
    /// Signed 16-bit integer
    I16 = 11,
    /// Signed 32-bit integer
    I32 = 12,
    /// Signed 64-bit integer
    I64 = 13,

    // ---- Unsigned integers (ids 20..=23) ----
    /// Unsigned 8-bit integer
    U8 = 20,
    /// Unsigned 16-bit integer
    U16 = 21,
    /// Unsigned 32-bit integer
    U32 = 22,
    /// Unsigned 64-bit integer
    U64 = 23,

    // ---- Boolean (id 30) ----
    /// Boolean
    Bool = 30,

    // ---- Quantized (ids 40..=41) ----
    /// Quantized signed 4-bit integer
    QI4 = 40,
    /// Quantized unsigned 8-bit integer
    QU8 = 41,

    // ---- Complex (ids 50..=52) ----
    /// Complex number with float16 components
    Complex32 = 50,
    /// Complex number with float32 components
    Complex64 = 51,
    /// Complex number with float64 components
    Complex128 = 52,
}
233
234impl DTypeCandidate for DType {
235    fn size_bytes(&self) -> usize {
236        self.dtype_size_bytes()
237    }
238
239    fn is_float(&self) -> bool {
240        self.is_float()
241    }
242
243    fn is_int(&self) -> bool {
244        self.is_int()
245    }
246
247    fn is_signed(&self) -> bool {
248        self.is_signed()
249    }
250
251    fn is_bool(&self) -> bool {
252        self.is_bool()
253    }
254
255    fn type_name(&self) -> &'static str {
256        self.type_name()
257    }
258
259    unsafe fn from_bytes(_bytes: &[u8]) -> Self {
260        panic!("Cannot convert bytes to DType enum directly - use concrete types instead")
261    }
262
263    fn to_bytes(&self) -> Vec<u8> {
264        vec![self.id().0]
265    }
266}
267
268// Instance methods that delegate to the enum variants
269impl DType {
270    /// Returns the size in bytes of this data type
271    pub fn dtype_size_bytes(&self) -> usize {
272        match self {
273            DType::F16 => 2,
274            DType::F32 => 4,
275            DType::F64 => 8,
276            DType::BF16 => 2,
277            DType::BF8 => 1,
278            DType::F8E4M3FN => 1,
279            DType::F8E5M2 => 1,
280            DType::Complex32 => 4,
281            DType::Complex64 => 8,
282            DType::Complex128 => 16,
283            DType::I8 => 1,
284            DType::I16 => 2,
285            DType::I32 => 4,
286            DType::I64 => 8,
287            DType::U8 => 1,
288            DType::U16 => 2,
289            DType::U32 => 4,
290            DType::U64 => 8,
291            DType::Bool => 1,
292            DType::QI4 => 1, // 4 bits per value, but allocated per byte
293            DType::QU8 => 1,
294        }
295    }
296
297    /// Returns the storage size in bits for this dtype
298    pub fn storage_bits(&self) -> u16 {
299        match self {
300            DType::QI4 => 4,
301            _ => (self.dtype_size_bytes() * 8) as u16,
302        }
303    }
304
305    /// Returns the stable dtype id
306    pub fn id(&self) -> DTypeId {
307        DTypeId(*self as u8)
308    }
309
310    /// Convert from a stable dtype id
311    pub fn from_id(id: DTypeId) -> Option<Self> {
312        match id.0 {
313            1 => Some(DType::F16),
314            2 => Some(DType::F32),
315            3 => Some(DType::F64),
316            4 => Some(DType::BF16),
317            5 => Some(DType::BF8),
318            6 => Some(DType::F8E4M3FN),
319            7 => Some(DType::F8E5M2),
320            50 => Some(DType::Complex32),
321            51 => Some(DType::Complex64),
322            52 => Some(DType::Complex128),
323            10 => Some(DType::I8),
324            11 => Some(DType::I16),
325            12 => Some(DType::I32),
326            13 => Some(DType::I64),
327            20 => Some(DType::U8),
328            21 => Some(DType::U16),
329            22 => Some(DType::U32),
330            23 => Some(DType::U64),
331            30 => Some(DType::Bool),
332            40 => Some(DType::QI4),
333            41 => Some(DType::QU8),
334            _ => None,
335        }
336    }
337
338    /// Returns a static descriptor for this dtype
339    pub fn info(&self) -> DTypeInfo {
340        let (name, align) = match self {
341            DType::F16 => ("float16", 2),
342            DType::F32 => ("float32", 4),
343            DType::F64 => ("float64", 8),
344            DType::BF16 => ("bfloat16", 2),
345            DType::BF8 => ("bfloat8", 1),
346            DType::F8E4M3FN => ("float8_e4m3fn", 1),
347            DType::F8E5M2 => ("float8_e5m2", 1),
348            DType::Complex32 => ("complex32", 2),
349            DType::Complex64 => ("complex64", 4),
350            DType::Complex128 => ("complex128", 8),
351            DType::I8 => ("int8", 1),
352            DType::I16 => ("int16", 2),
353            DType::I32 => ("int32", 4),
354            DType::I64 => ("int64", 8),
355            DType::U8 => ("uint8", 1),
356            DType::U16 => ("uint16", 2),
357            DType::U32 => ("uint32", 4),
358            DType::U64 => ("uint64", 8),
359            DType::Bool => ("bool", 1),
360            DType::QI4 => ("quantized_i4", 1),
361            DType::QU8 => ("quantized_u8", 1),
362        };
363
364        DTypeInfo {
365            id: self.id(),
366            name,
367            byte_size: self.dtype_size_bytes(),
368            storage_bits: self.storage_bits(),
369            align,
370            is_float: self.is_float(),
371            is_int: self.is_int(),
372            is_bool: self.is_bool(),
373        }
374    }
375
376    /// Returns true if this is a floating point type
377    pub fn is_float(&self) -> bool {
378        matches!(
379            self,
380            DType::F16
381                | DType::F32
382                | DType::F64
383                | DType::BF16
384                | DType::BF8
385                | DType::F8E4M3FN
386                | DType::F8E5M2
387                | DType::Complex32
388                | DType::Complex64
389                | DType::Complex128
390        )
391    }
392
393    /// Returns true if this is an integer type
394    pub fn is_int(&self) -> bool {
395        matches!(
396            self,
397            DType::I8
398                | DType::I16
399                | DType::I32
400                | DType::I64
401                | DType::U8
402                | DType::U16
403                | DType::U32
404                | DType::U64
405                | DType::QI4
406                | DType::QU8
407        )
408    }
409
410    /// Returns true if this is a signed integer type
411    pub fn is_signed_int(&self) -> bool {
412        matches!(
413            self,
414            DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::QI4
415        )
416    }
417
418    /// Returns true if this is an unsigned integer type
419    pub fn is_unsigned_int(&self) -> bool {
420        matches!(
421            self,
422            DType::U8 | DType::U16 | DType::U32 | DType::U64 | DType::QU8
423        )
424    }
425
426    /// Returns true if this is a signed type (for integers)
427    pub fn is_signed(&self) -> bool {
428        self.is_signed_int()
429    }
430
431    /// Returns true if this is a boolean type
432    pub fn is_bool(&self) -> bool {
433        matches!(self, DType::Bool)
434    }
435
436    /// Returns a string representation of the type
437    pub fn type_name(&self) -> &'static str {
438        match self {
439            DType::F16 => "float16",
440            DType::F32 => "float32",
441            DType::F64 => "float64",
442            DType::BF16 => "bfloat16",
443            DType::BF8 => "bfloat8",
444            DType::F8E4M3FN => "float8_e4m3fn",
445            DType::F8E5M2 => "float8_e5m2",
446            DType::Complex32 => "complex32",
447            DType::Complex64 => "complex64",
448            DType::Complex128 => "complex128",
449            DType::I8 => "int8",
450            DType::I16 => "int16",
451            DType::I32 => "int32",
452            DType::I64 => "int64",
453            DType::U8 => "uint8",
454            DType::U16 => "uint16",
455            DType::U32 => "uint32",
456            DType::U64 => "uint64",
457            DType::Bool => "bool",
458            DType::QI4 => "quantized_i4",
459            DType::QU8 => "quantized_u8",
460        }
461    }
462}
463
464/// Returns `true` if `dtype` can be lossily converted to/from `f32` via
465/// [`encode_float_bytes`] / [`decode_float_bytes`].
466pub fn is_float_convertible(dtype: DType) -> bool {
467    matches!(
468        dtype,
469        DType::F16
470            | DType::F32
471            | DType::F64
472            | DType::BF16
473            | DType::BF8
474            | DType::F8E4M3FN
475            | DType::F8E5M2
476    )
477}
478
479fn decode_with<T: FloatDType>(bytes: &[u8]) -> Result<Vec<f32>, String> {
480    let element_size = std::mem::size_of::<T>();
481    if element_size == 0 || !bytes.len().is_multiple_of(element_size) {
482        return Err("invalid byte length for float dtype".to_string());
483    }
484
485    Ok(bytes
486        .chunks_exact(element_size)
487        .map(|chunk| unsafe { T::from_bytes(chunk) }.to_f32())
488        .collect())
489}
490
491fn encode_with<T: FloatDType>(values: &[f32]) -> Vec<u8> {
492    let element_size = std::mem::size_of::<T>();
493    let mut bytes = Vec::with_capacity(values.len() * element_size);
494    for value in values {
495        bytes.extend_from_slice(&T::from_f32(*value).to_bytes());
496    }
497    bytes
498}
499
500/// Decode a byte buffer of `dtype`-encoded floating point values into `Vec<f32>`.
501///
502/// Bytes are interpreted as little-endian.
503///
504/// # Errors
505/// Returns `Err` if the byte length is invalid for the dtype or if the dtype is unsupported.
506pub fn decode_float_bytes(dtype: DType, bytes: &[u8]) -> Result<Vec<f32>, String> {
507    match dtype {
508        DType::F32 => {
509            if !bytes.len().is_multiple_of(4) {
510                return Err("invalid f32 byte length".to_string());
511            }
512            Ok(bytes
513                .chunks_exact(4)
514                .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
515                .collect())
516        }
517        DType::F64 => {
518            if !bytes.len().is_multiple_of(8) {
519                return Err("invalid f64 byte length".to_string());
520            }
521            Ok(bytes
522                .chunks_exact(8)
523                .map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()) as f32)
524                .collect())
525        }
526        DType::F16 => decode_with::<Float16>(bytes),
527        DType::BF16 => decode_with::<BFloat16>(bytes),
528        DType::BF8 => decode_with::<BFloat8>(bytes),
529        DType::F8E4M3FN => decode_with::<Float8E4M3Fn>(bytes),
530        DType::F8E5M2 => decode_with::<Float8E5M2>(bytes),
531        _ => Err(format!("dtype {} is not supported", dtype)),
532    }
533}
534
535/// Encode `values` into a byte buffer for a given float-like dtype.
536///
537/// Bytes are emitted in little-endian order.
538///
539/// # Errors
540/// Returns `Err` if the dtype is unsupported.
541pub fn encode_float_bytes(dtype: DType, values: &[f32]) -> Result<Vec<u8>, String> {
542    match dtype {
543        DType::F32 => {
544            let mut bytes = Vec::with_capacity(values.len() * 4);
545            for value in values {
546                bytes.extend_from_slice(&value.to_le_bytes());
547            }
548            Ok(bytes)
549        }
550        DType::F64 => {
551            let mut bytes = Vec::with_capacity(values.len() * 8);
552            for value in values {
553                bytes.extend_from_slice(&(*value as f64).to_le_bytes());
554            }
555            Ok(bytes)
556        }
557        DType::F16 => Ok(encode_with::<Float16>(values)),
558        DType::BF16 => Ok(encode_with::<BFloat16>(values)),
559        DType::BF8 => Ok(encode_with::<BFloat8>(values)),
560        DType::F8E4M3FN => Ok(encode_with::<Float8E4M3Fn>(values)),
561        DType::F8E5M2 => Ok(encode_with::<Float8E5M2>(values)),
562        _ => Err(format!("dtype {} is not supported", dtype)),
563    }
564}
565
566impl fmt::Display for DType {
567    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568        write!(f, "{}", self.type_name())
569    }
570}
571
572// DTypeLike implementations for primitive Rust types
573unsafe impl DTypeLike for f32 {
574    const DTYPE: DType = DType::F32;
575}
576
577unsafe impl DTypeLike for f64 {
578    const DTYPE: DType = DType::F64;
579}
580
581impl DTypeValue for f32 {
582    const DTYPE: DType = DType::F32;
583
584    fn write_bytes(self, out: &mut Vec<u8>) {
585        out.extend_from_slice(&self.to_le_bytes());
586    }
587}
588
589impl DTypeValue for f64 {
590    const DTYPE: DType = DType::F64;
591
592    fn write_bytes(self, out: &mut Vec<u8>) {
593        out.extend_from_slice(&self.to_le_bytes());
594    }
595}
596
597impl DTypeValue for i8 {
598    const DTYPE: DType = DType::I8;
599
600    fn write_bytes(self, out: &mut Vec<u8>) {
601        out.push(self as u8);
602    }
603}
604
605impl DTypeValue for i16 {
606    const DTYPE: DType = DType::I16;
607
608    fn write_bytes(self, out: &mut Vec<u8>) {
609        out.extend_from_slice(&self.to_le_bytes());
610    }
611}
612
613impl DTypeValue for i32 {
614    const DTYPE: DType = DType::I32;
615
616    fn write_bytes(self, out: &mut Vec<u8>) {
617        out.extend_from_slice(&self.to_le_bytes());
618    }
619}
620
621impl DTypeValue for i64 {
622    const DTYPE: DType = DType::I64;
623
624    fn write_bytes(self, out: &mut Vec<u8>) {
625        out.extend_from_slice(&self.to_le_bytes());
626    }
627}
628
629impl DTypeValue for u8 {
630    const DTYPE: DType = DType::U8;
631
632    fn write_bytes(self, out: &mut Vec<u8>) {
633        out.push(self);
634    }
635}
636
637impl DTypeValue for u16 {
638    const DTYPE: DType = DType::U16;
639
640    fn write_bytes(self, out: &mut Vec<u8>) {
641        out.extend_from_slice(&self.to_le_bytes());
642    }
643}
644
645impl DTypeValue for u32 {
646    const DTYPE: DType = DType::U32;
647
648    fn write_bytes(self, out: &mut Vec<u8>) {
649        out.extend_from_slice(&self.to_le_bytes());
650    }
651}
652
653impl DTypeValue for u64 {
654    const DTYPE: DType = DType::U64;
655
656    fn write_bytes(self, out: &mut Vec<u8>) {
657        out.extend_from_slice(&self.to_le_bytes());
658    }
659}
660
661impl DTypeValue for bool {
662    const DTYPE: DType = DType::Bool;
663
664    fn write_bytes(self, out: &mut Vec<u8>) {
665        out.push(u8::from(self));
666    }
667}
668
669impl<T> DTypeValue for T
670where
671    T: DTypeCandidate + DTypeLike,
672{
673    const DTYPE: DType = T::DTYPE;
674
675    fn write_bytes(self, out: &mut Vec<u8>) {
676        out.extend_from_slice(&self.to_bytes());
677    }
678}
679
680unsafe impl DTypeLike for i8 {
681    const DTYPE: DType = DType::I8;
682}
683
684unsafe impl DTypeLike for i16 {
685    const DTYPE: DType = DType::I16;
686}
687
688unsafe impl DTypeLike for i32 {
689    const DTYPE: DType = DType::I32;
690}
691
692unsafe impl DTypeLike for i64 {
693    const DTYPE: DType = DType::I64;
694}
695
696unsafe impl DTypeLike for u8 {
697    const DTYPE: DType = DType::U8;
698}
699
700unsafe impl DTypeLike for u16 {
701    const DTYPE: DType = DType::U16;
702}
703
704unsafe impl DTypeLike for u32 {
705    const DTYPE: DType = DType::U32;
706}
707
708unsafe impl DTypeLike for u64 {
709    const DTYPE: DType = DType::U64;
710}
711
712unsafe impl DTypeLike for bool {
713    const DTYPE: DType = DType::Bool;
714}
715
716unsafe impl DTypeLike for BFloat16 {
717    const DTYPE: DType = DType::BF16;
718}
719
720unsafe impl DTypeLike for BFloat8 {
721    const DTYPE: DType = DType::BF8;
722}
723
724unsafe impl DTypeLike for Float16 {
725    const DTYPE: DType = DType::F16;
726}
727
728unsafe impl DTypeLike for Float32 {
729    const DTYPE: DType = DType::F32;
730}
731
732unsafe impl DTypeLike for Float8E4M3Fn {
733    const DTYPE: DType = DType::F8E4M3FN;
734}
735
736unsafe impl DTypeLike for Float8E5M2 {
737    const DTYPE: DType = DType::F8E5M2;
738}
739
740unsafe impl DTypeLike for Complex32 {
741    const DTYPE: DType = DType::Complex32;
742}
743
744unsafe impl DTypeLike for Complex64 {
745    const DTYPE: DType = DType::Complex64;
746}
747
748unsafe impl DTypeLike for Complex128 {
749    const DTYPE: DType = DType::Complex128;
750}
751
752unsafe impl DTypeLike for QuantizedI4 {
753    const DTYPE: DType = DType::QI4;
754}
755
756unsafe impl DTypeLike for QuantizedU8 {
757    const DTYPE: DType = DType::QU8;
758}
759
760// Convenience constants
761/// Alias for [`DType::F32`].
762pub const F32: DType = DType::F32;
763/// Alias for [`DType::F64`].
764pub const F64: DType = DType::F64;
765/// Alias for [`DType::F16`].
766pub const F16: DType = DType::F16;
767/// Alias for [`DType::F8E4M3FN`].
768pub const F8E4M3FN: DType = DType::F8E4M3FN;
769/// Alias for [`DType::F8E5M2`].
770pub const F8E5M2: DType = DType::F8E5M2;
771/// Alias for [`DType::F16`].
772pub const FLOAT16: DType = DType::F16;
773/// Alias for [`DType::F32`].
774pub const FLOAT32: DType = DType::F32;
775/// Alias for [`DType::F64`].
776pub const FLOAT64: DType = DType::F64;
777/// Alias for [`DType::F8E4M3FN`].
778pub const FLOAT8_E4M3FN: DType = DType::F8E4M3FN;
779/// Alias for [`DType::F8E5M2`].
780pub const FLOAT8_E5M2: DType = DType::F8E5M2;
781/// Alias for [`DType::I8`].
782pub const I8: DType = DType::I8;
783/// Alias for [`DType::I8`].
784pub const INT8: DType = DType::I8;
785/// Alias for [`DType::I16`].
786pub const I16: DType = DType::I16;
787/// Alias for [`DType::I16`].
788pub const INT16: DType = DType::I16;
789/// Alias for [`DType::I32`].
790pub const I32: DType = DType::I32;
791/// Alias for [`DType::I32`].
792pub const INT32: DType = DType::I32;
793/// Alias for [`DType::I64`].
794pub const I64: DType = DType::I64;
795/// Alias for [`DType::I64`].
796pub const INT64: DType = DType::I64;
797/// Alias for [`DType::U8`].
798pub const U8: DType = DType::U8;
799/// Alias for [`DType::U8`].
800pub const UINT8: DType = DType::U8;
801/// Alias for [`DType::U16`].
802pub const U16: DType = DType::U16;
803/// Alias for [`DType::U16`].
804pub const UINT16: DType = DType::U16;
805/// Alias for [`DType::U32`].
806pub const U32: DType = DType::U32;
807/// Alias for [`DType::U32`].
808pub const UINT32: DType = DType::U32;
809/// Alias for [`DType::U64`].
810pub const U64: DType = DType::U64;
811/// Alias for [`DType::U64`].
812pub const UINT64: DType = DType::U64;
813/// Alias for [`DType::Bool`].
814pub const BOOL: DType = DType::Bool;
815/// Alias for [`DType::BF16`].
816pub const BF16: DType = DType::BF16;
817/// Alias for [`DType::BF16`].
818pub const BFLOAT16: DType = DType::BF16;
819/// Alias for [`DType::BF8`].
820pub const BF8: DType = DType::BF8;
821/// Alias for [`DType::BF8`].
822pub const BFLOAT8: DType = DType::BF8;
823/// Alias for [`DType::Complex32`].
824pub const COMPLEX32: DType = DType::Complex32;
825/// Alias for [`DType::Complex64`].
826pub const COMPLEX64: DType = DType::Complex64;
827/// Alias for [`DType::Complex128`].
828pub const COMPLEX128: DType = DType::Complex128;
829/// Alias for [`DType::QI4`].
830pub const QI4: DType = DType::QI4;
831/// Alias for [`DType::QU8`].
832pub const QU8: DType = DType::QU8;
833
834// Constants are already defined above, no need to re-export
835
#[cfg(test)]
mod tests {
    use super::*;

    /// Every dtype reachable through the stable id space (ids `0..=255`).
    fn all_dtypes() -> Vec<DType> {
        (0..=u8::MAX)
            .filter_map(|id| DType::from_id(DTypeId(id)))
            .collect()
    }

    /// Checks that a primitive's `DTypeLike` and `DTypeValue` impls name the same dtype, and that
    /// its in-memory size matches the dtype's declared element size (the `DTypeLike` contract).
    fn assert_primitive_consistent<T: DTypeLike + DTypeValue>() {
        let dtype = <T as DTypeLike>::DTYPE;
        assert_eq!(dtype, <T as DTypeValue>::DTYPE);
        assert_eq!(std::mem::size_of::<T>(), dtype.dtype_size_bytes(), "{}", dtype);
    }

    #[test]
    fn dtype_sizes() {
        assert_eq!(F32.size_bytes(), 4);
        assert_eq!(F64.size_bytes(), 8);
        assert_eq!(I32.size_bytes(), 4);
        assert_eq!(U8.size_bytes(), 1);
        assert_eq!(BOOL.size_bytes(), 1);
    }

    #[test]
    fn dtype_classification() {
        assert!(F32.is_float());
        assert!(!F32.is_int());

        assert!(I32.is_int());
        assert!(I32.is_signed_int());
        assert!(!I32.is_unsigned_int());
        assert!(!I32.is_float());

        assert!(U32.is_int());
        assert!(!U32.is_signed_int());
        assert!(U32.is_unsigned_int());

        assert!(BOOL.is_bool());
        assert!(!BOOL.is_float());
        assert!(!BOOL.is_int());
    }

    #[test]
    fn dtype_display() {
        assert_eq!(format!("{}", F32), "float32");
        assert_eq!(format!("{}", I64), "int64");
        assert_eq!(format!("{}", BOOL), "bool");
    }

    #[test]
    fn dtype_info_table() {
        let table = [
            (DType::F16, 1, "float16", 2usize, 16u16, 2usize),
            (DType::F32, 2, "float32", 4, 32, 4),
            (DType::F64, 3, "float64", 8, 64, 8),
            (DType::BF16, 4, "bfloat16", 2, 16, 2),
            (DType::BF8, 5, "bfloat8", 1, 8, 1),
            (DType::F8E4M3FN, 6, "float8_e4m3fn", 1, 8, 1),
            (DType::F8E5M2, 7, "float8_e5m2", 1, 8, 1),
            (DType::Complex32, 50, "complex32", 4, 32, 2),
            (DType::Complex64, 51, "complex64", 8, 64, 4),
            (DType::Complex128, 52, "complex128", 16, 128, 8),
            (DType::I8, 10, "int8", 1, 8, 1),
            (DType::I16, 11, "int16", 2, 16, 2),
            (DType::I32, 12, "int32", 4, 32, 4),
            (DType::I64, 13, "int64", 8, 64, 8),
            (DType::U8, 20, "uint8", 1, 8, 1),
            (DType::U16, 21, "uint16", 2, 16, 2),
            (DType::U32, 22, "uint32", 4, 32, 4),
            (DType::U64, 23, "uint64", 8, 64, 8),
            (DType::Bool, 30, "bool", 1, 8, 1),
            (DType::QI4, 40, "quantized_i4", 1, 4, 1),
            (DType::QU8, 41, "quantized_u8", 1, 8, 1),
        ];

        for (dtype, id, name, bytes, bits, align) in table {
            let info = dtype.info();
            assert_eq!(info.id.0, id);
            assert_eq!(info.name, name);
            assert_eq!(info.byte_size, bytes);
            assert_eq!(info.storage_bits, bits);
            assert_eq!(info.align, align);
            assert_eq!(DType::from_id(info.id), Some(dtype));
        }
    }

    /// The `DTypeId` constants and the enum discriminants are both part of the compatibility
    /// contract; they must agree and together cover every variant.
    #[test]
    fn dtype_id_constants_match_enum() {
        let pairs = [
            (DTypeId::F16, DType::F16),
            (DTypeId::F32, DType::F32),
            (DTypeId::F64, DType::F64),
            (DTypeId::BF16, DType::BF16),
            (DTypeId::BF8, DType::BF8),
            (DTypeId::F8E4M3FN, DType::F8E4M3FN),
            (DTypeId::F8E5M2, DType::F8E5M2),
            (DTypeId::COMPLEX32, DType::Complex32),
            (DTypeId::COMPLEX64, DType::Complex64),
            (DTypeId::COMPLEX128, DType::Complex128),
            (DTypeId::I8, DType::I8),
            (DTypeId::I16, DType::I16),
            (DTypeId::I32, DType::I32),
            (DTypeId::I64, DType::I64),
            (DTypeId::U8, DType::U8),
            (DTypeId::U16, DType::U16),
            (DTypeId::U32, DType::U32),
            (DTypeId::U64, DType::U64),
            (DTypeId::BOOL, DType::Bool),
            (DTypeId::QI4, DType::QI4),
            (DTypeId::QU8, DType::QU8),
        ];

        for (id, dtype) in pairs {
            assert_eq!(dtype.id(), id, "{}", dtype);
            assert_eq!(DType::from_id(id), Some(dtype));
        }
        assert_eq!(all_dtypes().len(), pairs.len());
    }

    #[test]
    fn unknown_ids_are_rejected() {
        for id in [0u8, 8, 9, 14, 24, 29, 31, 39, 42, 49, 53, 255] {
            assert_eq!(DType::from_id(DTypeId(id)), None, "id {}", id);
        }
    }

    #[test]
    fn info_is_consistent_with_accessors() {
        for dtype in all_dtypes() {
            let info = dtype.info();
            assert_eq!(info.name, dtype.type_name());
            assert_eq!(dtype.to_string(), dtype.type_name());
            assert_eq!(info.byte_size, dtype.dtype_size_bytes());
            assert_eq!(info.storage_bits, dtype.storage_bits());
            assert_eq!(info.is_float, dtype.is_float());
            assert_eq!(info.is_int, dtype.is_int());
            assert_eq!(info.is_bool, dtype.is_bool());
            assert!(info.byte_size.is_multiple_of(info.align), "{}", dtype);
        }
    }

    #[test]
    fn integer_signedness_partitions_integers() {
        for dtype in all_dtypes() {
            let signed = dtype.is_signed_int();
            let unsigned = dtype.is_unsigned_int();
            assert!(!(signed && unsigned), "{}", dtype);
            assert_eq!(dtype.is_int(), signed || unsigned, "{}", dtype);
        }
    }

    #[test]
    fn float_codecs_cover_exactly_the_convertible_dtypes() {
        for dtype in all_dtypes() {
            let convertible = is_float_convertible(dtype);
            assert_eq!(encode_float_bytes(dtype, &[]).is_ok(), convertible, "{}", dtype);
            assert_eq!(decode_float_bytes(dtype, &[]).is_ok(), convertible, "{}", dtype);
        }
    }

    #[test]
    fn f32_and_f64_codecs_roundtrip() {
        let values = [0.0f32, 1.0, -2.5, f32::MIN_POSITIVE, f32::MAX];
        for dtype in [DType::F32, DType::F64] {
            let bytes = encode_float_bytes(dtype, &values).unwrap();
            assert_eq!(bytes.len(), values.len() * dtype.dtype_size_bytes());
            assert_eq!(decode_float_bytes(dtype, &bytes).unwrap(), values);
        }
    }

    #[test]
    fn float_bytes_are_little_endian() {
        let one_le = [0x00u8, 0x00, 0x80, 0x3f];
        assert_eq!(encode_float_bytes(DType::F32, &[1.0]).unwrap(), one_le);
        assert_eq!(decode_float_bytes(DType::F32, &one_le).unwrap(), [1.0f32]);
    }

    #[test]
    fn float_decode_rejects_bad_input() {
        assert!(decode_float_bytes(DType::F32, &[0, 0, 0]).is_err());
        assert!(decode_float_bytes(DType::F64, &[0; 12]).is_err());
        assert!(decode_float_bytes(DType::I32, &[0; 4]).is_err());
    }

    #[test]
    fn primitive_values_write_little_endian_bytes() {
        let mut out = Vec::new();
        0x0102_i16.write_bytes(&mut out);
        (-1_i8).write_bytes(&mut out);
        0x0A0B_0C0D_u32.write_bytes(&mut out);
        true.write_bytes(&mut out);
        false.write_bytes(&mut out);
        assert_eq!(out, [0x02u8, 0x01, 0xff, 0x0d, 0x0c, 0x0b, 0x0a, 1, 0]);
    }

    #[test]
    fn primitive_dtype_impls_agree() {
        assert_primitive_consistent::<f32>();
        assert_primitive_consistent::<f64>();
        assert_primitive_consistent::<i8>();
        assert_primitive_consistent::<i16>();
        assert_primitive_consistent::<i32>();
        assert_primitive_consistent::<i64>();
        assert_primitive_consistent::<u8>();
        assert_primitive_consistent::<u16>();
        assert_primitive_consistent::<u32>();
        assert_primitive_consistent::<u64>();
        assert_primitive_consistent::<bool>();
    }
}