1use std::fmt;
12
13pub mod conversions;
15pub mod types;
16
17pub use types::{
19 BFloat8, BFloat16, Complex32, Complex64, Complex128, Float8E4M3Fn, Float8E5M2, Float16,
20 Float32, QuantizedI4, QuantizedU8,
21};
22
/// Compact numeric identifier of a dtype, stored in a single byte.
///
/// The values mirror the `#[repr(u8)]` discriminants of `DType`, which is what
/// lets `DType::id` / `DType::from_id` convert between the two. Ids are grouped
/// by family: real floats `1..=7`, signed ints `10..=13`, unsigned ints
/// `20..=23`, bool `30`, quantized `40..=41`, complex `50..=52`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct DTypeId(pub u8);

impl DTypeId {
    // Real floating-point ids.
    /// Id of `DType::F16`.
    pub const F16: Self = Self(1);
    /// Id of `DType::F32`.
    pub const F32: Self = Self(2);
    /// Id of `DType::F64`.
    pub const F64: Self = Self(3);
    /// Id of `DType::BF16`.
    pub const BF16: Self = Self(4);
    /// Id of `DType::BF8`.
    pub const BF8: Self = Self(5);
    /// Id of `DType::F8E4M3FN`.
    pub const F8E4M3FN: Self = Self(6);
    /// Id of `DType::F8E5M2`.
    pub const F8E5M2: Self = Self(7);

    // Signed integer ids.
    /// Id of `DType::I8`.
    pub const I8: Self = Self(10);
    /// Id of `DType::I16`.
    pub const I16: Self = Self(11);
    /// Id of `DType::I32`.
    pub const I32: Self = Self(12);
    /// Id of `DType::I64`.
    pub const I64: Self = Self(13);

    // Unsigned integer ids.
    /// Id of `DType::U8`.
    pub const U8: Self = Self(20);
    /// Id of `DType::U16`.
    pub const U16: Self = Self(21);
    /// Id of `DType::U32`.
    pub const U32: Self = Self(22);
    /// Id of `DType::U64`.
    pub const U64: Self = Self(23);

    // Boolean id.
    /// Id of `DType::Bool`.
    pub const BOOL: Self = Self(30);

    // Quantized ids.
    /// Id of `DType::QI4`.
    pub const QI4: Self = Self(40);
    /// Id of `DType::QU8`.
    pub const QU8: Self = Self(41);

    // Complex ids.
    /// Id of `DType::Complex32`.
    pub const COMPLEX32: Self = Self(50);
    /// Id of `DType::Complex64`.
    pub const COMPLEX64: Self = Self(51);
    /// Id of `DType::Complex128`.
    pub const COMPLEX128: Self = Self(52);
}
80
81#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub struct DTypeInfo {
87 pub id: DTypeId,
89 pub name: &'static str,
91 pub byte_size: usize,
93 pub storage_bits: u16,
95 pub align: usize,
97 pub is_float: bool,
99 pub is_int: bool,
101 pub is_bool: bool,
103}
104
105pub unsafe trait DTypeLike: Copy {
112 const DTYPE: DType;
114}
115
116pub trait DTypeValue: Copy {
120 const DTYPE: DType;
122 fn write_bytes(self, out: &mut Vec<u8>);
124}
125
126pub trait DTypeElement: DTypeLike + DTypeValue + Copy + Default + Send + Sync + 'static {}
130
131impl<T> DTypeElement for T where T: DTypeLike + DTypeValue + Copy + Default + Send + Sync + 'static {}
132
/// Runtime description of an element type: its size, numeric category, name
/// and raw byte encoding.
pub trait DTypeCandidate: Copy + Clone + PartialEq + Eq + std::hash::Hash {
    // ---- size and naming ----

    /// Size in bytes of one element.
    fn size_bytes(&self) -> usize;

    /// Canonical name of the type, e.g. `"float32"`.
    fn type_name(&self) -> &'static str;

    // ---- numeric category ----

    /// Whether the type is floating point.
    fn is_float(&self) -> bool;

    /// Whether the type is an integer.
    fn is_int(&self) -> bool;

    /// Whether the type is boolean.
    fn is_bool(&self) -> bool;

    /// Whether the type is signed.
    fn is_signed(&self) -> bool;

    /// `true` only for integer types that are also signed.
    ///
    /// `is_signed` is not consulted at all for non-integer types.
    fn is_signed_int(&self) -> bool {
        self.is_int() && self.is_signed()
    }

    /// `true` only for integer types that are not signed.
    ///
    /// `is_signed` is not consulted at all for non-integer types.
    fn is_unsigned_int(&self) -> bool {
        self.is_int() && !self.is_signed()
    }

    // ---- byte encoding ----

    /// Reconstructs a value from its raw byte encoding.
    ///
    /// # Safety
    ///
    /// `bytes` must be a valid encoding of `Self`. Within this module,
    /// `decode_with` passes slices exactly `size_of::<Self>()` bytes long.
    /// NOTE(review): implementations may read without further checks — confirm
    /// the precise contract before relying on shorter or longer slices.
    unsafe fn from_bytes(bytes: &[u8]) -> Self;

    /// Serializes the value to its raw byte encoding.
    fn to_bytes(&self) -> Vec<u8>;
}
173
174pub trait FloatDType: DTypeCandidate {
178 fn from_f32(value: f32) -> Self;
180 fn to_f32(self) -> f32;
182}
183
/// Element data type.
///
/// The explicit `#[repr(u8)]` discriminants are the stable numeric ids exposed
/// as `DTypeId`. They are grouped by family, with gaps between families.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DType {
    /// 16-bit float (`float16`).
    F16 = 1,
    /// 32-bit float (`float32`).
    F32 = 2,
    /// 64-bit float (`float64`).
    F64 = 3,
    /// 16-bit brain float (`bfloat16`).
    BF16 = 4,
    /// 8-bit brain-float variant (`bfloat8`).
    BF8 = 5,
    /// 8-bit float, E4M3 "FN" layout (`float8_e4m3fn`).
    F8E4M3FN = 6,
    /// 8-bit float, E5M2 layout (`float8_e5m2`).
    F8E5M2 = 7,
    /// Complex value occupying 4 bytes in total (`complex32`).
    Complex32 = 50,
    /// Complex value occupying 8 bytes in total (`complex64`).
    Complex64 = 51,
    /// Complex value occupying 16 bytes in total (`complex128`).
    Complex128 = 52,
    /// Signed 8-bit integer (`int8`).
    I8 = 10,
    /// Signed 16-bit integer (`int16`).
    I16 = 11,
    /// Signed 32-bit integer (`int32`).
    I32 = 12,
    /// Signed 64-bit integer (`int64`).
    I64 = 13,
    /// Unsigned 8-bit integer (`uint8`).
    U8 = 20,
    /// Unsigned 16-bit integer (`uint16`).
    U16 = 21,
    /// Unsigned 32-bit integer (`uint32`).
    U32 = 22,
    /// Unsigned 64-bit integer (`uint64`).
    U64 = 23,
    /// Boolean stored in one byte (`bool`).
    Bool = 30,
    /// Signed 4-bit quantized integer (`quantized_i4`).
    QI4 = 40,
    /// Unsigned 8-bit quantized integer (`quantized_u8`).
    QU8 = 41,
}
233
234impl DTypeCandidate for DType {
235 fn size_bytes(&self) -> usize {
236 self.dtype_size_bytes()
237 }
238
239 fn is_float(&self) -> bool {
240 self.is_float()
241 }
242
243 fn is_int(&self) -> bool {
244 self.is_int()
245 }
246
247 fn is_signed(&self) -> bool {
248 self.is_signed()
249 }
250
251 fn is_bool(&self) -> bool {
252 self.is_bool()
253 }
254
255 fn type_name(&self) -> &'static str {
256 self.type_name()
257 }
258
259 unsafe fn from_bytes(_bytes: &[u8]) -> Self {
260 panic!("Cannot convert bytes to DType enum directly - use concrete types instead")
261 }
262
263 fn to_bytes(&self) -> Vec<u8> {
264 vec![self.id().0]
265 }
266}
267
268impl DType {
270 pub fn dtype_size_bytes(&self) -> usize {
272 match self {
273 DType::F16 => 2,
274 DType::F32 => 4,
275 DType::F64 => 8,
276 DType::BF16 => 2,
277 DType::BF8 => 1,
278 DType::F8E4M3FN => 1,
279 DType::F8E5M2 => 1,
280 DType::Complex32 => 4,
281 DType::Complex64 => 8,
282 DType::Complex128 => 16,
283 DType::I8 => 1,
284 DType::I16 => 2,
285 DType::I32 => 4,
286 DType::I64 => 8,
287 DType::U8 => 1,
288 DType::U16 => 2,
289 DType::U32 => 4,
290 DType::U64 => 8,
291 DType::Bool => 1,
292 DType::QI4 => 1, DType::QU8 => 1,
294 }
295 }
296
297 pub fn storage_bits(&self) -> u16 {
299 match self {
300 DType::QI4 => 4,
301 _ => (self.dtype_size_bytes() * 8) as u16,
302 }
303 }
304
305 pub fn id(&self) -> DTypeId {
307 DTypeId(*self as u8)
308 }
309
310 pub fn from_id(id: DTypeId) -> Option<Self> {
312 match id.0 {
313 1 => Some(DType::F16),
314 2 => Some(DType::F32),
315 3 => Some(DType::F64),
316 4 => Some(DType::BF16),
317 5 => Some(DType::BF8),
318 6 => Some(DType::F8E4M3FN),
319 7 => Some(DType::F8E5M2),
320 50 => Some(DType::Complex32),
321 51 => Some(DType::Complex64),
322 52 => Some(DType::Complex128),
323 10 => Some(DType::I8),
324 11 => Some(DType::I16),
325 12 => Some(DType::I32),
326 13 => Some(DType::I64),
327 20 => Some(DType::U8),
328 21 => Some(DType::U16),
329 22 => Some(DType::U32),
330 23 => Some(DType::U64),
331 30 => Some(DType::Bool),
332 40 => Some(DType::QI4),
333 41 => Some(DType::QU8),
334 _ => None,
335 }
336 }
337
338 pub fn info(&self) -> DTypeInfo {
340 let (name, align) = match self {
341 DType::F16 => ("float16", 2),
342 DType::F32 => ("float32", 4),
343 DType::F64 => ("float64", 8),
344 DType::BF16 => ("bfloat16", 2),
345 DType::BF8 => ("bfloat8", 1),
346 DType::F8E4M3FN => ("float8_e4m3fn", 1),
347 DType::F8E5M2 => ("float8_e5m2", 1),
348 DType::Complex32 => ("complex32", 2),
349 DType::Complex64 => ("complex64", 4),
350 DType::Complex128 => ("complex128", 8),
351 DType::I8 => ("int8", 1),
352 DType::I16 => ("int16", 2),
353 DType::I32 => ("int32", 4),
354 DType::I64 => ("int64", 8),
355 DType::U8 => ("uint8", 1),
356 DType::U16 => ("uint16", 2),
357 DType::U32 => ("uint32", 4),
358 DType::U64 => ("uint64", 8),
359 DType::Bool => ("bool", 1),
360 DType::QI4 => ("quantized_i4", 1),
361 DType::QU8 => ("quantized_u8", 1),
362 };
363
364 DTypeInfo {
365 id: self.id(),
366 name,
367 byte_size: self.dtype_size_bytes(),
368 storage_bits: self.storage_bits(),
369 align,
370 is_float: self.is_float(),
371 is_int: self.is_int(),
372 is_bool: self.is_bool(),
373 }
374 }
375
376 pub fn is_float(&self) -> bool {
378 matches!(
379 self,
380 DType::F16
381 | DType::F32
382 | DType::F64
383 | DType::BF16
384 | DType::BF8
385 | DType::F8E4M3FN
386 | DType::F8E5M2
387 | DType::Complex32
388 | DType::Complex64
389 | DType::Complex128
390 )
391 }
392
393 pub fn is_int(&self) -> bool {
395 matches!(
396 self,
397 DType::I8
398 | DType::I16
399 | DType::I32
400 | DType::I64
401 | DType::U8
402 | DType::U16
403 | DType::U32
404 | DType::U64
405 | DType::QI4
406 | DType::QU8
407 )
408 }
409
410 pub fn is_signed_int(&self) -> bool {
412 matches!(
413 self,
414 DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::QI4
415 )
416 }
417
418 pub fn is_unsigned_int(&self) -> bool {
420 matches!(
421 self,
422 DType::U8 | DType::U16 | DType::U32 | DType::U64 | DType::QU8
423 )
424 }
425
426 pub fn is_signed(&self) -> bool {
428 self.is_signed_int()
429 }
430
431 pub fn is_bool(&self) -> bool {
433 matches!(self, DType::Bool)
434 }
435
436 pub fn type_name(&self) -> &'static str {
438 match self {
439 DType::F16 => "float16",
440 DType::F32 => "float32",
441 DType::F64 => "float64",
442 DType::BF16 => "bfloat16",
443 DType::BF8 => "bfloat8",
444 DType::F8E4M3FN => "float8_e4m3fn",
445 DType::F8E5M2 => "float8_e5m2",
446 DType::Complex32 => "complex32",
447 DType::Complex64 => "complex64",
448 DType::Complex128 => "complex128",
449 DType::I8 => "int8",
450 DType::I16 => "int16",
451 DType::I32 => "int32",
452 DType::I64 => "int64",
453 DType::U8 => "uint8",
454 DType::U16 => "uint16",
455 DType::U32 => "uint32",
456 DType::U64 => "uint64",
457 DType::Bool => "bool",
458 DType::QI4 => "quantized_i4",
459 DType::QU8 => "quantized_u8",
460 }
461 }
462}
463
464pub fn is_float_convertible(dtype: DType) -> bool {
467 matches!(
468 dtype,
469 DType::F16
470 | DType::F32
471 | DType::F64
472 | DType::BF16
473 | DType::BF8
474 | DType::F8E4M3FN
475 | DType::F8E5M2
476 )
477}
478
479fn decode_with<T: FloatDType>(bytes: &[u8]) -> Result<Vec<f32>, String> {
480 let element_size = std::mem::size_of::<T>();
481 if element_size == 0 || !bytes.len().is_multiple_of(element_size) {
482 return Err("invalid byte length for float dtype".to_string());
483 }
484
485 Ok(bytes
486 .chunks_exact(element_size)
487 .map(|chunk| unsafe { T::from_bytes(chunk) }.to_f32())
488 .collect())
489}
490
491fn encode_with<T: FloatDType>(values: &[f32]) -> Vec<u8> {
492 let element_size = std::mem::size_of::<T>();
493 let mut bytes = Vec::with_capacity(values.len() * element_size);
494 for value in values {
495 bytes.extend_from_slice(&T::from_f32(*value).to_bytes());
496 }
497 bytes
498}
499
500pub fn decode_float_bytes(dtype: DType, bytes: &[u8]) -> Result<Vec<f32>, String> {
507 match dtype {
508 DType::F32 => {
509 if !bytes.len().is_multiple_of(4) {
510 return Err("invalid f32 byte length".to_string());
511 }
512 Ok(bytes
513 .chunks_exact(4)
514 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
515 .collect())
516 }
517 DType::F64 => {
518 if !bytes.len().is_multiple_of(8) {
519 return Err("invalid f64 byte length".to_string());
520 }
521 Ok(bytes
522 .chunks_exact(8)
523 .map(|chunk| f64::from_le_bytes(chunk.try_into().unwrap()) as f32)
524 .collect())
525 }
526 DType::F16 => decode_with::<Float16>(bytes),
527 DType::BF16 => decode_with::<BFloat16>(bytes),
528 DType::BF8 => decode_with::<BFloat8>(bytes),
529 DType::F8E4M3FN => decode_with::<Float8E4M3Fn>(bytes),
530 DType::F8E5M2 => decode_with::<Float8E5M2>(bytes),
531 _ => Err(format!("dtype {} is not supported", dtype)),
532 }
533}
534
535pub fn encode_float_bytes(dtype: DType, values: &[f32]) -> Result<Vec<u8>, String> {
542 match dtype {
543 DType::F32 => {
544 let mut bytes = Vec::with_capacity(values.len() * 4);
545 for value in values {
546 bytes.extend_from_slice(&value.to_le_bytes());
547 }
548 Ok(bytes)
549 }
550 DType::F64 => {
551 let mut bytes = Vec::with_capacity(values.len() * 8);
552 for value in values {
553 bytes.extend_from_slice(&(*value as f64).to_le_bytes());
554 }
555 Ok(bytes)
556 }
557 DType::F16 => Ok(encode_with::<Float16>(values)),
558 DType::BF16 => Ok(encode_with::<BFloat16>(values)),
559 DType::BF8 => Ok(encode_with::<BFloat8>(values)),
560 DType::F8E4M3FN => Ok(encode_with::<Float8E4M3Fn>(values)),
561 DType::F8E5M2 => Ok(encode_with::<Float8E5M2>(values)),
562 _ => Err(format!("dtype {} is not supported", dtype)),
563 }
564}
565
566impl fmt::Display for DType {
567 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568 write!(f, "{}", self.type_name())
569 }
570}
571
572unsafe impl DTypeLike for f32 {
574 const DTYPE: DType = DType::F32;
575}
576
577unsafe impl DTypeLike for f64 {
578 const DTYPE: DType = DType::F64;
579}
580
581impl DTypeValue for f32 {
582 const DTYPE: DType = DType::F32;
583
584 fn write_bytes(self, out: &mut Vec<u8>) {
585 out.extend_from_slice(&self.to_le_bytes());
586 }
587}
588
589impl DTypeValue for f64 {
590 const DTYPE: DType = DType::F64;
591
592 fn write_bytes(self, out: &mut Vec<u8>) {
593 out.extend_from_slice(&self.to_le_bytes());
594 }
595}
596
597impl DTypeValue for i8 {
598 const DTYPE: DType = DType::I8;
599
600 fn write_bytes(self, out: &mut Vec<u8>) {
601 out.push(self as u8);
602 }
603}
604
605impl DTypeValue for i16 {
606 const DTYPE: DType = DType::I16;
607
608 fn write_bytes(self, out: &mut Vec<u8>) {
609 out.extend_from_slice(&self.to_le_bytes());
610 }
611}
612
613impl DTypeValue for i32 {
614 const DTYPE: DType = DType::I32;
615
616 fn write_bytes(self, out: &mut Vec<u8>) {
617 out.extend_from_slice(&self.to_le_bytes());
618 }
619}
620
621impl DTypeValue for i64 {
622 const DTYPE: DType = DType::I64;
623
624 fn write_bytes(self, out: &mut Vec<u8>) {
625 out.extend_from_slice(&self.to_le_bytes());
626 }
627}
628
629impl DTypeValue for u8 {
630 const DTYPE: DType = DType::U8;
631
632 fn write_bytes(self, out: &mut Vec<u8>) {
633 out.push(self);
634 }
635}
636
637impl DTypeValue for u16 {
638 const DTYPE: DType = DType::U16;
639
640 fn write_bytes(self, out: &mut Vec<u8>) {
641 out.extend_from_slice(&self.to_le_bytes());
642 }
643}
644
645impl DTypeValue for u32 {
646 const DTYPE: DType = DType::U32;
647
648 fn write_bytes(self, out: &mut Vec<u8>) {
649 out.extend_from_slice(&self.to_le_bytes());
650 }
651}
652
653impl DTypeValue for u64 {
654 const DTYPE: DType = DType::U64;
655
656 fn write_bytes(self, out: &mut Vec<u8>) {
657 out.extend_from_slice(&self.to_le_bytes());
658 }
659}
660
661impl DTypeValue for bool {
662 const DTYPE: DType = DType::Bool;
663
664 fn write_bytes(self, out: &mut Vec<u8>) {
665 out.push(u8::from(self));
666 }
667}
668
669impl<T> DTypeValue for T
670where
671 T: DTypeCandidate + DTypeLike,
672{
673 const DTYPE: DType = T::DTYPE;
674
675 fn write_bytes(self, out: &mut Vec<u8>) {
676 out.extend_from_slice(&self.to_bytes());
677 }
678}
679
680unsafe impl DTypeLike for i8 {
681 const DTYPE: DType = DType::I8;
682}
683
684unsafe impl DTypeLike for i16 {
685 const DTYPE: DType = DType::I16;
686}
687
688unsafe impl DTypeLike for i32 {
689 const DTYPE: DType = DType::I32;
690}
691
692unsafe impl DTypeLike for i64 {
693 const DTYPE: DType = DType::I64;
694}
695
696unsafe impl DTypeLike for u8 {
697 const DTYPE: DType = DType::U8;
698}
699
700unsafe impl DTypeLike for u16 {
701 const DTYPE: DType = DType::U16;
702}
703
704unsafe impl DTypeLike for u32 {
705 const DTYPE: DType = DType::U32;
706}
707
708unsafe impl DTypeLike for u64 {
709 const DTYPE: DType = DType::U64;
710}
711
712unsafe impl DTypeLike for bool {
713 const DTYPE: DType = DType::Bool;
714}
715
716unsafe impl DTypeLike for BFloat16 {
717 const DTYPE: DType = DType::BF16;
718}
719
720unsafe impl DTypeLike for BFloat8 {
721 const DTYPE: DType = DType::BF8;
722}
723
724unsafe impl DTypeLike for Float16 {
725 const DTYPE: DType = DType::F16;
726}
727
728unsafe impl DTypeLike for Float32 {
729 const DTYPE: DType = DType::F32;
730}
731
732unsafe impl DTypeLike for Float8E4M3Fn {
733 const DTYPE: DType = DType::F8E4M3FN;
734}
735
736unsafe impl DTypeLike for Float8E5M2 {
737 const DTYPE: DType = DType::F8E5M2;
738}
739
740unsafe impl DTypeLike for Complex32 {
741 const DTYPE: DType = DType::Complex32;
742}
743
744unsafe impl DTypeLike for Complex64 {
745 const DTYPE: DType = DType::Complex64;
746}
747
748unsafe impl DTypeLike for Complex128 {
749 const DTYPE: DType = DType::Complex128;
750}
751
752unsafe impl DTypeLike for QuantizedI4 {
753 const DTYPE: DType = DType::QI4;
754}
755
756unsafe impl DTypeLike for QuantizedU8 {
757 const DTYPE: DType = DType::QU8;
758}
759
760pub const F32: DType = DType::F32;
763pub const F64: DType = DType::F64;
765pub const F16: DType = DType::F16;
767pub const F8E4M3FN: DType = DType::F8E4M3FN;
769pub const F8E5M2: DType = DType::F8E5M2;
771pub const FLOAT16: DType = DType::F16;
773pub const FLOAT32: DType = DType::F32;
775pub const FLOAT64: DType = DType::F64;
777pub const FLOAT8_E4M3FN: DType = DType::F8E4M3FN;
779pub const FLOAT8_E5M2: DType = DType::F8E5M2;
781pub const I8: DType = DType::I8;
783pub const INT8: DType = DType::I8;
785pub const I16: DType = DType::I16;
787pub const INT16: DType = DType::I16;
789pub const I32: DType = DType::I32;
791pub const INT32: DType = DType::I32;
793pub const I64: DType = DType::I64;
795pub const INT64: DType = DType::I64;
797pub const U8: DType = DType::U8;
799pub const UINT8: DType = DType::U8;
801pub const U16: DType = DType::U16;
803pub const UINT16: DType = DType::U16;
805pub const U32: DType = DType::U32;
807pub const UINT32: DType = DType::U32;
809pub const U64: DType = DType::U64;
811pub const UINT64: DType = DType::U64;
813pub const BOOL: DType = DType::Bool;
815pub const BF16: DType = DType::BF16;
817pub const BFLOAT16: DType = DType::BF16;
819pub const BF8: DType = DType::BF8;
821pub const BFLOAT8: DType = DType::BF8;
823pub const COMPLEX32: DType = DType::Complex32;
825pub const COMPLEX64: DType = DType::Complex64;
827pub const COMPLEX128: DType = DType::Complex128;
829pub const QI4: DType = DType::QI4;
831pub const QU8: DType = DType::QU8;
833
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dtype_sizes() {
        assert_eq!(F32.size_bytes(), 4);
        assert_eq!(F64.size_bytes(), 8);
        assert_eq!(I32.size_bytes(), 4);
        assert_eq!(U8.size_bytes(), 1);
        assert_eq!(BOOL.size_bytes(), 1);
    }

    #[test]
    fn dtype_classification() {
        assert!(F32.is_float());
        assert!(!F32.is_int());

        assert!(I32.is_int());
        assert!(I32.is_signed_int());
        assert!(!I32.is_unsigned_int());
        assert!(!I32.is_float());

        assert!(U32.is_int());
        assert!(!U32.is_signed_int());
        assert!(U32.is_unsigned_int());

        assert!(BOOL.is_bool());
        assert!(!BOOL.is_float());
        assert!(!BOOL.is_int());
    }

    #[test]
    fn dtype_display() {
        assert_eq!(format!("{}", F32), "float32");
        assert_eq!(format!("{}", I64), "int64");
        assert_eq!(format!("{}", BOOL), "bool");
    }

    #[test]
    fn dtype_info_table() {
        let table = [
            (DType::F16, 1, "float16", 2usize, 16u16, 2usize),
            (DType::F32, 2, "float32", 4, 32, 4),
            (DType::F64, 3, "float64", 8, 64, 8),
            (DType::BF16, 4, "bfloat16", 2, 16, 2),
            (DType::BF8, 5, "bfloat8", 1, 8, 1),
            (DType::F8E4M3FN, 6, "float8_e4m3fn", 1, 8, 1),
            (DType::F8E5M2, 7, "float8_e5m2", 1, 8, 1),
            (DType::Complex32, 50, "complex32", 4, 32, 2),
            (DType::Complex64, 51, "complex64", 8, 64, 4),
            (DType::Complex128, 52, "complex128", 16, 128, 8),
            (DType::I8, 10, "int8", 1, 8, 1),
            (DType::I16, 11, "int16", 2, 16, 2),
            (DType::I32, 12, "int32", 4, 32, 4),
            (DType::I64, 13, "int64", 8, 64, 8),
            (DType::U8, 20, "uint8", 1, 8, 1),
            (DType::U16, 21, "uint16", 2, 16, 2),
            (DType::U32, 22, "uint32", 4, 32, 4),
            (DType::U64, 23, "uint64", 8, 64, 8),
            (DType::Bool, 30, "bool", 1, 8, 1),
            (DType::QI4, 40, "quantized_i4", 1, 4, 1),
            (DType::QU8, 41, "quantized_u8", 1, 8, 1),
        ];

        for (dtype, id, name, bytes, bits, align) in table {
            let info = dtype.info();
            assert_eq!(info.id.0, id);
            assert_eq!(info.name, name);
            assert_eq!(info.byte_size, bytes);
            assert_eq!(info.storage_bits, bits);
            assert_eq!(info.align, align);
            assert_eq!(DType::from_id(info.id), Some(dtype));
        }
    }

    /// The `DTypeId` constants, the enum discriminants, the one-byte
    /// `to_bytes` encoding and the name tables must all agree.
    #[test]
    fn dtype_id_constants_match_enum() {
        let pairs = [
            (DType::F16, DTypeId::F16),
            (DType::F32, DTypeId::F32),
            (DType::F64, DTypeId::F64),
            (DType::BF16, DTypeId::BF16),
            (DType::BF8, DTypeId::BF8),
            (DType::F8E4M3FN, DTypeId::F8E4M3FN),
            (DType::F8E5M2, DTypeId::F8E5M2),
            (DType::Complex32, DTypeId::COMPLEX32),
            (DType::Complex64, DTypeId::COMPLEX64),
            (DType::Complex128, DTypeId::COMPLEX128),
            (DType::I8, DTypeId::I8),
            (DType::I16, DTypeId::I16),
            (DType::I32, DTypeId::I32),
            (DType::I64, DTypeId::I64),
            (DType::U8, DTypeId::U8),
            (DType::U16, DTypeId::U16),
            (DType::U32, DTypeId::U32),
            (DType::U64, DTypeId::U64),
            (DType::Bool, DTypeId::BOOL),
            (DType::QI4, DTypeId::QI4),
            (DType::QU8, DTypeId::QU8),
        ];

        for (dtype, id) in pairs {
            assert_eq!(dtype.id(), id);
            assert_eq!(DType::from_id(id), Some(dtype));
            assert_eq!(dtype.to_bytes(), vec![id.0]);
            assert_eq!(dtype.info().name, dtype.type_name());
            assert_eq!(dtype.to_string(), dtype.type_name());
        }
    }

    #[test]
    fn from_id_rejects_unknown_ids() {
        for raw in [0u8, 8, 9, 14, 24, 31, 42, 49, 53, 255] {
            assert_eq!(DType::from_id(DTypeId(raw)), None, "id {raw}");
        }
    }

    #[test]
    fn float_convertibility() {
        for dtype in [F16, F32, F64, BF16, BF8, F8E4M3FN, F8E5M2] {
            assert!(is_float_convertible(dtype), "{dtype}");
        }
        for dtype in [COMPLEX32, COMPLEX64, COMPLEX128, I8, U64, BOOL, QI4, QU8] {
            assert!(!is_float_convertible(dtype), "{dtype}");
        }
    }

    #[test]
    fn f32_and_f64_round_trip() {
        let values = [0.0f32, 1.5, -2.25, f32::MAX, f32::MIN_POSITIVE];
        for dtype in [DType::F32, DType::F64] {
            let bytes = encode_float_bytes(dtype, &values).unwrap();
            assert_eq!(bytes.len(), values.len() * dtype.size_bytes());
            assert_eq!(decode_float_bytes(dtype, &bytes).unwrap(), values);
        }
    }

    #[test]
    fn float_byte_helpers_reject_bad_input() {
        assert_eq!(
            decode_float_bytes(DType::F32, &[0, 0, 0]),
            Err("invalid f32 byte length".to_string())
        );
        assert_eq!(
            decode_float_bytes(DType::F64, &[0; 4]),
            Err("invalid f64 byte length".to_string())
        );
        assert_eq!(
            decode_float_bytes(DType::I32, &[]),
            Err("dtype int32 is not supported".to_string())
        );
        assert_eq!(
            encode_float_bytes(DType::Complex64, &[1.0]),
            Err("dtype complex64 is not supported".to_string())
        );
    }
}