1use std::ffi::c_void;
7
/// Device kinds understood by DLPack, using the same integer codes as the C
/// `DLDeviceType` enum.
///
/// The discriminants are ABI values and must not be renumbered. Codes 0, 5
/// and 6 have no variant here, so [`DLDeviceType::from_raw`] rejects them.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDeviceType {
    /// `kDLCPU`: CPU device.
    Cpu = 1,
    /// `kDLCUDA`: CUDA GPU device.
    Cuda = 2,
    /// `kDLCUDAHost`: pinned CUDA host memory.
    CudaHost = 3,
    /// `kDLOpenCL`: OpenCL device.
    OpenCL = 4,
    /// `kDLVulkan`: Vulkan buffer.
    Vulkan = 7,
    /// `kDLMetal`: Metal device.
    Metal = 8,
    /// `kDLVPI`: Verilog simulator buffer.
    Vpi = 9,
    /// `kDLROCM`: ROCm GPU device.
    Rocm = 10,
    /// `kDLROCMHost`: pinned ROCm host memory.
    RocmHost = 11,
    /// `kDLExtDev`: reserved extension device type.
    ExtDev = 12,
    /// `kDLCUDAManaged`: CUDA managed/unified memory.
    CudaManaged = 13,
    /// `kDLOneAPI`: oneAPI unified shared memory.
    OneApi = 14,
    /// `kDLWebGPU`: WebGPU device.
    WebGpu = 15,
    /// `kDLHexagon`: Hexagon DSP.
    Hexagon = 16,
    /// `kDLMAIA`: MAIA device.
    Maia = 17,
}

impl DLDeviceType {
    /// Decodes a raw device-type code.
    ///
    /// Returns `Some(variant)` when `value` equals one of the discriminants
    /// above and `None` for every other value, so codes received over FFI can
    /// be validated without any unchecked cast.
    ///
    /// This is a `const fn`, so it can also be evaluated in `const`/`static`
    /// initialisers.
    #[must_use]
    pub const fn from_raw(value: u32) -> Option<Self> {
        let decoded = match value {
            1 => Self::Cpu,
            2 => Self::Cuda,
            3 => Self::CudaHost,
            4 => Self::OpenCL,
            7 => Self::Vulkan,
            8 => Self::Metal,
            9 => Self::Vpi,
            10 => Self::Rocm,
            11 => Self::RocmHost,
            12 => Self::ExtDev,
            13 => Self::CudaManaged,
            14 => Self::OneApi,
            15 => Self::WebGpu,
            16 => Self::Hexagon,
            17 => Self::Maia,
            // Anything else (including the unassigned 0, 5 and 6) is unknown.
            _ => return None,
        };
        Some(decoded)
    }
}
71
72#[repr(C)]
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
77pub struct DLDevice {
78 pub device_type: u32,
80 pub device_id: i32,
82}
83
84impl DLDevice {
85 pub fn new(device_type: DLDeviceType, device_id: i32) -> Self {
87 Self {
88 device_type: device_type as u32,
89 device_id,
90 }
91 }
92
93 pub fn device_type_enum(&self) -> Option<DLDeviceType> {
97 DLDeviceType::from_raw(self.device_type)
98 }
99
100 pub fn is_cuda(&self) -> bool {
102 self.device_type == DLDeviceType::Cuda as u32
103 }
104
105 pub fn is_cpu(&self) -> bool {
107 self.device_type == DLDeviceType::Cpu as u32
108 }
109
110 pub fn is_cuda_host(&self) -> bool {
112 self.device_type == DLDeviceType::CudaHost as u32
113 }
114
115 pub fn is_rocm(&self) -> bool {
117 self.device_type == DLDeviceType::Rocm as u32
118 }
119
120 pub fn is_metal(&self) -> bool {
122 self.device_type == DLDeviceType::Metal as u32
123 }
124}
125
/// Element type categories understood by DLPack, using the same integer codes
/// as the C `DLDataTypeCode` enum.
///
/// The discriminants are ABI values and must not be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDataTypeCode {
    /// `kDLInt`: signed integer.
    Int = 0,
    /// `kDLUInt`: unsigned integer.
    UInt = 1,
    /// `kDLFloat`: IEEE floating point.
    Float = 2,
    /// `kDLOpaqueHandle`: opaque handle.
    OpaqueHandle = 3,
    /// `kDLBfloat`: bfloat16.
    Bfloat = 4,
    /// `kDLComplex`: complex number.
    Complex = 5,
    /// `kDLBool`: boolean.
    Bool = 6,
}

impl DLDataTypeCode {
    /// Decodes a raw type code.
    ///
    /// Returns `Some(variant)` for the codes 0 through 6 and `None` for any
    /// other value, so codes received over FFI can be validated without an
    /// unchecked cast.
    ///
    /// This is a `const fn`, so it can also be evaluated in `const`/`static`
    /// initialisers.
    #[must_use]
    pub const fn from_raw(value: u8) -> Option<Self> {
        let decoded = match value {
            0 => Self::Int,
            1 => Self::UInt,
            2 => Self::Float,
            3 => Self::OpaqueHandle,
            4 => Self::Bfloat,
            5 => Self::Complex,
            6 => Self::Bool,
            _ => return None,
        };
        Some(decoded)
    }
}
165
166#[repr(C)]
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
171pub struct DLDataType {
172 pub code: u8,
174 pub bits: u8,
176 pub lanes: u16,
178}
179
180impl DLDataType {
181 pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self {
183 Self {
184 code: code as u8,
185 bits,
186 lanes,
187 }
188 }
189
190 pub fn code_enum(&self) -> Option<DLDataTypeCode> {
194 DLDataTypeCode::from_raw(self.code)
195 }
196
197 pub fn is_f16(&self) -> bool {
199 self.code == DLDataTypeCode::Float as u8 && self.bits == 16 && self.lanes == 1
200 }
201
202 pub fn is_f32(&self) -> bool {
204 self.code == DLDataTypeCode::Float as u8 && self.bits == 32 && self.lanes == 1
205 }
206
207 pub fn is_f64(&self) -> bool {
209 self.code == DLDataTypeCode::Float as u8 && self.bits == 64 && self.lanes == 1
210 }
211
212 pub fn is_bf16(&self) -> bool {
214 self.code == DLDataTypeCode::Bfloat as u8 && self.bits == 16 && self.lanes == 1
215 }
216
217 pub fn is_i8(&self) -> bool {
219 self.code == DLDataTypeCode::Int as u8 && self.bits == 8 && self.lanes == 1
220 }
221
222 pub fn is_i16(&self) -> bool {
224 self.code == DLDataTypeCode::Int as u8 && self.bits == 16 && self.lanes == 1
225 }
226
227 pub fn is_i32(&self) -> bool {
229 self.code == DLDataTypeCode::Int as u8 && self.bits == 32 && self.lanes == 1
230 }
231
232 pub fn is_i64(&self) -> bool {
234 self.code == DLDataTypeCode::Int as u8 && self.bits == 64 && self.lanes == 1
235 }
236
237 pub fn is_u8(&self) -> bool {
239 self.code == DLDataTypeCode::UInt as u8 && self.bits == 8 && self.lanes == 1
240 }
241
242 pub fn is_u16(&self) -> bool {
244 self.code == DLDataTypeCode::UInt as u8 && self.bits == 16 && self.lanes == 1
245 }
246
247 pub fn is_u32(&self) -> bool {
249 self.code == DLDataTypeCode::UInt as u8 && self.bits == 32 && self.lanes == 1
250 }
251
252 pub fn is_u64(&self) -> bool {
254 self.code == DLDataTypeCode::UInt as u8 && self.bits == 64 && self.lanes == 1
255 }
256
257 pub fn is_bool(&self) -> bool {
259 self.code == DLDataTypeCode::Bool as u8 && self.bits == 8 && self.lanes == 1
260 }
261
262 pub fn itemsize(&self) -> usize {
264 ((self.bits as usize) * (self.lanes as usize)).div_ceil(8)
265 }
266}
267
268#[repr(C)]
272pub struct DLTensor {
273 pub data: *mut c_void,
276 pub device: DLDevice,
278 pub ndim: i32,
280 pub dtype: DLDataType,
282 pub shape: *mut i64,
285 pub strides: *mut i64,
288 pub byte_offset: u64,
290}
291
292pub type DLManagedTensorDeleter = unsafe extern "C" fn(*mut DLManagedTensor);
296
297#[repr(C)]
302pub struct DLManagedTensor {
303 pub dl_tensor: DLTensor,
305 pub manager_ctx: *mut c_void,
308 pub deleter: Option<DLManagedTensorDeleter>,
311}
312
/// A DLPack version number as two 32-bit fields, with C layout.
///
/// Size is 8 bytes, with `major` at offset 0 and `minor` at offset 4.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DLPackVersion {
    /// Major version number.
    pub major: u32,
    /// Minor version number.
    pub minor: u32,
}
324
/// Major DLPack version these bindings are written against.
pub const DLPACK_MAJOR_VERSION: u32 = 1;
/// Minor DLPack version these bindings are written against.
pub const DLPACK_MINOR_VERSION: u32 = 0;

// Flag masks for `DLManagedTensorVersioned::flags`; each occupies one bit.

/// Bit 0.
// NOTE(review): per the DLPack header this marks the tensor as read-only.
pub const DLPACK_FLAG_BITMASK_READ_ONLY: u64 = 0b001;
/// Bit 1.
// NOTE(review): per the DLPack header this marks the tensor as a copy made by
// the producer.
pub const DLPACK_FLAG_BITMASK_IS_COPIED: u64 = 0b010;
/// Bit 2.
// NOTE(review): per the DLPack header this marks sub-byte element types as
// padded to whole bytes.
pub const DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED: u64 = 0b100;
343
344pub type DLManagedTensorVersionedDeleter = unsafe extern "C" fn(*mut DLManagedTensorVersioned);
350
351#[repr(C)]
357pub struct DLManagedTensorVersioned {
358 pub version: DLPackVersion,
360 pub manager_ctx: *mut c_void,
362 pub deleter: Option<DLManagedTensorVersionedDeleter>,
365 pub flags: u64,
367 pub dl_tensor: DLTensor,
369}
370
371pub fn cuda_device(device_id: i32) -> DLDevice {
377 DLDevice::new(DLDeviceType::Cuda, device_id)
378}
379
380pub fn cpu_device() -> DLDevice {
382 DLDevice::new(DLDeviceType::Cpu, 0)
383}
384
385pub fn metal_device(device_id: i32) -> DLDevice {
387 DLDevice::new(DLDeviceType::Metal, device_id)
388}
389
390pub fn dtype_f32() -> DLDataType {
392 DLDataType::new(DLDataTypeCode::Float, 32, 1)
393}
394
395pub fn dtype_f64() -> DLDataType {
397 DLDataType::new(DLDataTypeCode::Float, 64, 1)
398}
399
400pub fn dtype_f16() -> DLDataType {
402 DLDataType::new(DLDataTypeCode::Float, 16, 1)
403}
404
405pub fn dtype_bf16() -> DLDataType {
407 DLDataType::new(DLDataTypeCode::Bfloat, 16, 1)
408}
409
410pub fn dtype_i8() -> DLDataType {
412 DLDataType::new(DLDataTypeCode::Int, 8, 1)
413}
414
415pub fn dtype_i16() -> DLDataType {
417 DLDataType::new(DLDataTypeCode::Int, 16, 1)
418}
419
420pub fn dtype_i32() -> DLDataType {
422 DLDataType::new(DLDataTypeCode::Int, 32, 1)
423}
424
425pub fn dtype_i64() -> DLDataType {
427 DLDataType::new(DLDataTypeCode::Int, 64, 1)
428}
429
430pub fn dtype_u8() -> DLDataType {
432 DLDataType::new(DLDataTypeCode::UInt, 8, 1)
433}
434
435pub fn dtype_u16() -> DLDataType {
437 DLDataType::new(DLDataTypeCode::UInt, 16, 1)
438}
439
440pub fn dtype_u32() -> DLDataType {
442 DLDataType::new(DLDataTypeCode::UInt, 32, 1)
443}
444
445pub fn dtype_u64() -> DLDataType {
447 DLDataType::new(DLDataTypeCode::UInt, 64, 1)
448}
449
450pub fn dtype_bool() -> DLDataType {
452 DLDataType::new(DLDataTypeCode::Bool, 8, 1)
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::mem::{align_of, offset_of, size_of};

    /// Every raw device-type code paired with the variant it must decode to.
    const DEVICE_TYPE_CODES: [(u32, DLDeviceType); 15] = [
        (1, DLDeviceType::Cpu),
        (2, DLDeviceType::Cuda),
        (3, DLDeviceType::CudaHost),
        (4, DLDeviceType::OpenCL),
        (7, DLDeviceType::Vulkan),
        (8, DLDeviceType::Metal),
        (9, DLDeviceType::Vpi),
        (10, DLDeviceType::Rocm),
        (11, DLDeviceType::RocmHost),
        (12, DLDeviceType::ExtDev),
        (13, DLDeviceType::CudaManaged),
        (14, DLDeviceType::OneApi),
        (15, DLDeviceType::WebGpu),
        (16, DLDeviceType::Hexagon),
        (17, DLDeviceType::Maia),
    ];

    /// Every raw dtype code paired with the variant it must decode to.
    const DTYPE_CODES: [(u8, DLDataTypeCode); 7] = [
        (0, DLDataTypeCode::Int),
        (1, DLDataTypeCode::UInt),
        (2, DLDataTypeCode::Float),
        (3, DLDataTypeCode::OpaqueHandle),
        (4, DLDataTypeCode::Bfloat),
        (5, DLDataTypeCode::Complex),
        (6, DLDataTypeCode::Bool),
    ];

    /// Asserts that `predicate` accepts `accepted` and rejects every entry of
    /// `rejected`.
    fn check_dtype_predicate(
        predicate: fn(&DLDataType) -> bool,
        accepted: DLDataType,
        rejected: &[DLDataType],
    ) {
        assert!(predicate(&accepted), "{:?} should be accepted", accepted);
        for other in rejected {
            assert!(!predicate(other), "{:?} should be rejected", other);
        }
    }

    /// Builds a single-lane `Int`-coded dtype with an arbitrary bit width.
    fn int_with_bits(bits: u8) -> DLDataType {
        DLDataType {
            code: 0,
            bits,
            lanes: 1,
        }
    }

    // ---- DLDeviceType ------------------------------------------------------

    #[test]
    fn test_device_type_from_raw_all_variants() {
        for (raw, variant) in DEVICE_TYPE_CODES {
            assert_eq!(DLDeviceType::from_raw(raw), Some(variant), "raw code {}", raw);
        }
    }

    #[test]
    fn test_device_type_from_raw_unknown() {
        for raw in [0, 5, 6, 18, 100, u32::MAX] {
            assert_eq!(DLDeviceType::from_raw(raw), None, "raw code {}", raw);
        }
    }

    #[test]
    fn test_device_type_debug() {
        for (variant, rendered) in [(DLDeviceType::Cpu, "Cpu"), (DLDeviceType::Cuda, "Cuda")] {
            assert_eq!(format!("{:?}", variant), rendered);
        }
    }

    #[test]
    fn test_device_type_clone_copy() {
        let original = DLDeviceType::Cuda;
        // Two by-value uses of `original` only compile because the type is `Copy`.
        let (first, second) = (original, original);
        assert_eq!(original, first);
        assert_eq!(original, second);
    }

    #[test]
    fn test_device_type_hash() {
        let unique = HashSet::from([DLDeviceType::Cpu, DLDeviceType::Cuda, DLDeviceType::Cpu]);
        assert_eq!(unique.len(), 2);
    }

    // ---- DLDevice ----------------------------------------------------------

    #[test]
    fn test_device_new() {
        let DLDevice {
            device_type,
            device_id,
        } = DLDevice::new(DLDeviceType::Cuda, 3);
        assert_eq!(device_type, 2);
        assert_eq!(device_id, 3);
    }

    #[test]
    fn test_device_type_enum() {
        let known = DLDevice::new(DLDeviceType::Rocm, 1);
        assert_eq!(known.device_type_enum(), Some(DLDeviceType::Rocm));

        let unrecognised = DLDevice {
            device_type: 99,
            device_id: 0,
        };
        assert_eq!(unrecognised.device_type_enum(), None);
    }

    #[test]
    fn test_device_is_cuda() {
        let pinned_host = DLDevice::new(DLDeviceType::CudaHost, 0);
        assert!(cuda_device(0).is_cuda());
        assert!(!cpu_device().is_cuda());
        assert!(!pinned_host.is_cuda());
    }

    #[test]
    fn test_device_is_cpu() {
        assert!(cpu_device().is_cpu());
        assert!(!cuda_device(0).is_cpu());
    }

    #[test]
    fn test_device_is_cuda_host() {
        let pinned_host = DLDevice::new(DLDeviceType::CudaHost, 0);
        assert!(pinned_host.is_cuda_host());
        for other in [cpu_device(), cuda_device(0)] {
            assert!(!other.is_cuda_host(), "{:?}", other);
        }
    }

    #[test]
    fn test_device_is_rocm() {
        let rocm = DLDevice::new(DLDeviceType::Rocm, 0);
        assert!(rocm.is_rocm());
        for other in [cpu_device(), cuda_device(0)] {
            assert!(!other.is_rocm(), "{:?}", other);
        }
    }

    #[test]
    fn test_device_is_metal() {
        for metal in [DLDevice::new(DLDeviceType::Metal, 0), metal_device(0)] {
            assert!(metal.is_metal(), "{:?}", metal);
        }
        for other in [cpu_device(), cuda_device(0)] {
            assert!(!other.is_metal(), "{:?}", other);
        }
    }

    #[test]
    fn test_device_debug() {
        let rendered = format!("{:?}", cuda_device(2));
        for field in ["device_type", "device_id"] {
            assert!(rendered.contains(field), "missing `{}` in {}", field, rendered);
        }
    }

    #[test]
    fn test_device_clone_copy() {
        let original = cuda_device(1);
        let (first, second) = (original, original);
        assert_eq!(original, first);
        assert_eq!(original, second);
    }

    #[test]
    fn test_device_hash() {
        let unique = HashSet::from([cpu_device(), cuda_device(0), cuda_device(1), cpu_device()]);
        assert_eq!(unique.len(), 3);
    }

    // ---- DLDataTypeCode ----------------------------------------------------

    #[test]
    fn test_dtype_code_from_raw_all_variants() {
        for (raw, variant) in DTYPE_CODES {
            assert_eq!(DLDataTypeCode::from_raw(raw), Some(variant), "raw code {}", raw);
        }
    }

    #[test]
    fn test_dtype_code_from_raw_unknown() {
        for raw in [7, 100, u8::MAX] {
            assert_eq!(DLDataTypeCode::from_raw(raw), None, "raw code {}", raw);
        }
    }

    #[test]
    fn test_dtype_code_debug() {
        for (variant, rendered) in [(DLDataTypeCode::Float, "Float"), (DLDataTypeCode::Int, "Int")]
        {
            assert_eq!(format!("{:?}", variant), rendered);
        }
    }

    #[test]
    fn test_dtype_code_clone_copy() {
        let original = DLDataTypeCode::Float;
        let (first, second) = (original, original);
        assert_eq!(original, first);
        assert_eq!(original, second);
    }

    #[test]
    fn test_dtype_code_hash() {
        let unique = HashSet::from([
            DLDataTypeCode::Float,
            DLDataTypeCode::Int,
            DLDataTypeCode::Float,
        ]);
        assert_eq!(unique.len(), 2);
    }

    // ---- DLDataType --------------------------------------------------------

    #[test]
    fn test_dtype_new() {
        let DLDataType { code, bits, lanes } = DLDataType::new(DLDataTypeCode::Float, 32, 1);
        assert_eq!(code, 2);
        assert_eq!(bits, 32);
        assert_eq!(lanes, 1);
    }

    #[test]
    fn test_dtype_code_enum() {
        assert_eq!(dtype_f32().code_enum(), Some(DLDataTypeCode::Float));

        let unrecognised = DLDataType {
            code: 99,
            bits: 32,
            lanes: 1,
        };
        assert_eq!(unrecognised.code_enum(), None);
    }

    #[test]
    fn test_dtype_is_f16() {
        // Same code and width but two lanes: must not count as scalar f16.
        let two_lanes = DLDataType::new(DLDataTypeCode::Float, 16, 2);
        check_dtype_predicate(
            DLDataType::is_f16,
            dtype_f16(),
            &[dtype_f32(), dtype_bf16(), two_lanes],
        );
    }

    #[test]
    fn test_dtype_is_f32() {
        check_dtype_predicate(DLDataType::is_f32, dtype_f32(), &[dtype_f64(), dtype_f16()]);
    }

    #[test]
    fn test_dtype_is_f64() {
        check_dtype_predicate(DLDataType::is_f64, dtype_f64(), &[dtype_f32()]);
    }

    #[test]
    fn test_dtype_is_bf16() {
        check_dtype_predicate(DLDataType::is_bf16, dtype_bf16(), &[dtype_f16(), dtype_f32()]);
    }

    #[test]
    fn test_dtype_is_i8() {
        check_dtype_predicate(DLDataType::is_i8, dtype_i8(), &[dtype_i16(), dtype_u8()]);
    }

    #[test]
    fn test_dtype_is_i16() {
        check_dtype_predicate(DLDataType::is_i16, dtype_i16(), &[dtype_i8(), dtype_i32()]);
    }

    #[test]
    fn test_dtype_is_i32() {
        check_dtype_predicate(DLDataType::is_i32, dtype_i32(), &[dtype_i64(), dtype_u32()]);
    }

    #[test]
    fn test_dtype_is_i64() {
        check_dtype_predicate(DLDataType::is_i64, dtype_i64(), &[dtype_i32()]);
    }

    #[test]
    fn test_dtype_is_u8() {
        check_dtype_predicate(DLDataType::is_u8, dtype_u8(), &[dtype_i8(), dtype_u16()]);
    }

    #[test]
    fn test_dtype_is_u16() {
        check_dtype_predicate(DLDataType::is_u16, dtype_u16(), &[dtype_u8()]);
    }

    #[test]
    fn test_dtype_is_u32() {
        check_dtype_predicate(DLDataType::is_u32, dtype_u32(), &[dtype_i32()]);
    }

    #[test]
    fn test_dtype_is_u64() {
        check_dtype_predicate(DLDataType::is_u64, dtype_u64(), &[dtype_u32()]);
    }

    #[test]
    fn test_dtype_is_bool() {
        check_dtype_predicate(DLDataType::is_bool, dtype_bool(), &[dtype_u8(), dtype_i8()]);
    }

    #[test]
    fn test_dtype_itemsize() {
        let expected_sizes = [
            (dtype_f16(), 2),
            (dtype_f32(), 4),
            (dtype_f64(), 8),
            (dtype_bf16(), 2),
            (dtype_i8(), 1),
            (dtype_i16(), 2),
            (dtype_i32(), 4),
            (dtype_i64(), 8),
            (dtype_u8(), 1),
            (dtype_u16(), 2),
            (dtype_u32(), 4),
            (dtype_u64(), 8),
            (dtype_bool(), 1),
        ];
        for (dtype, bytes) in expected_sizes {
            assert_eq!(dtype.itemsize(), bytes, "{:?}", dtype);
        }
    }

    #[test]
    fn test_dtype_itemsize_vectorized() {
        // 4 lanes x 32 bits and 8 lanes x 16 bits are both 128 bits = 16 bytes.
        let four_f32 = DLDataType::new(DLDataTypeCode::Float, 32, 4);
        let eight_i16 = DLDataType::new(DLDataTypeCode::Int, 16, 8);
        assert_eq!(four_f32.itemsize(), 16);
        assert_eq!(eight_i16.itemsize(), 16);
    }

    #[test]
    fn test_dtype_itemsize_rounding() {
        // Bit widths that are not a multiple of 8 round up to whole bytes.
        for (bits, bytes) in [(1, 1), (7, 1), (9, 2)] {
            assert_eq!(int_with_bits(bits).itemsize(), bytes, "{} bits", bits);
        }
    }

    #[test]
    fn test_dtype_debug() {
        let rendered = format!("{:?}", dtype_f32());
        for field in ["code", "bits", "lanes"] {
            assert!(rendered.contains(field), "missing `{}` in {}", field, rendered);
        }
    }

    #[test]
    fn test_dtype_clone_copy() {
        let original = dtype_f32();
        let (first, second) = (original, original);
        assert_eq!(original, first);
        assert_eq!(original, second);
    }

    #[test]
    fn test_dtype_hash() {
        let unique = HashSet::from([dtype_f32(), dtype_f64(), dtype_f32()]);
        assert_eq!(unique.len(), 2);
    }

    // ---- Convenience constructors ------------------------------------------

    #[test]
    fn test_cuda_device() {
        for id in [0, 1] {
            let device = cuda_device(id);
            assert!(device.is_cuda());
            assert_eq!(device.device_id, id);
        }
    }

    #[test]
    fn test_cpu_device() {
        let device = cpu_device();
        assert!(device.is_cpu());
        assert_eq!(device.device_id, 0);
    }

    #[test]
    fn test_metal_device() {
        for id in [0, 1] {
            let device = metal_device(id);
            assert!(device.is_metal());
            assert_eq!(device.device_id, id);
        }
    }

    #[test]
    fn test_all_dtype_constructors() {
        // Floating point.
        assert!(dtype_f16().is_f16());
        assert!(dtype_f32().is_f32());
        assert!(dtype_f64().is_f64());
        assert!(dtype_bf16().is_bf16());

        // Signed integers.
        assert!(dtype_i8().is_i8());
        assert!(dtype_i16().is_i16());
        assert!(dtype_i32().is_i32());
        assert!(dtype_i64().is_i64());

        // Unsigned integers.
        assert!(dtype_u8().is_u8());
        assert!(dtype_u16().is_u16());
        assert!(dtype_u32().is_u32());
        assert!(dtype_u64().is_u64());

        // Boolean.
        assert!(dtype_bool().is_bool());
    }

    // ---- Layout ------------------------------------------------------------

    #[test]
    fn test_dl_tensor_size() {
        assert_ne!(size_of::<DLTensor>(), 0);
    }

    #[test]
    fn test_dl_managed_tensor_size() {
        assert_ne!(size_of::<DLManagedTensor>(), 0);
    }

    #[test]
    fn test_dl_pack_version_layout() {
        assert_eq!(size_of::<DLPackVersion>(), 8);
        assert_eq!(offset_of!(DLPackVersion, major), 0);
        assert_eq!(offset_of!(DLPackVersion, minor), 4);
    }

    #[test]
    fn test_dl_managed_tensor_versioned_layout() {
        // Offsets listed in declaration order; they must start at zero and be
        // strictly increasing.
        let offsets = [
            offset_of!(DLManagedTensorVersioned, version),
            offset_of!(DLManagedTensorVersioned, manager_ctx),
            offset_of!(DLManagedTensorVersioned, deleter),
            offset_of!(DLManagedTensorVersioned, flags),
            offset_of!(DLManagedTensorVersioned, dl_tensor),
        ];
        assert_eq!(offsets[0], 0);
        assert!(offsets.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(size_of::<DLManagedTensorVersioned>() > size_of::<DLTensor>());
    }

    #[test]
    fn test_read_only_flag_value() {
        assert_eq!(DLPACK_FLAG_BITMASK_READ_ONLY, 1);
        assert_eq!(DLPACK_FLAG_BITMASK_IS_COPIED, 2);
        assert_eq!(DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED, 4);
        assert_eq!(DLPACK_MAJOR_VERSION, 1);
    }

    #[test]
    fn test_dl_device_repr_c() {
        assert_eq!(align_of::<DLDevice>(), 4);
        assert_eq!(size_of::<DLDevice>(), 8);
    }

    #[test]
    fn test_dl_data_type_repr_c() {
        assert_eq!(size_of::<DLDataType>(), 4);
    }
}