1use crate::dtype::DType;
2
/// Numeric type codes stored in [`DLPackType::code`].
///
/// The values follow DLPack's `DLDataTypeCode` enumeration; only the codes this
/// module maps to or from a `DType` are listed.
mod code {
    /// Signed integer.
    pub const INT: u8 = 0;
    /// Unsigned integer.
    pub const UINT: u8 = 1;
    /// IEEE floating point.
    pub const FLOAT: u8 = 2;
    /// Brain floating point (bfloat).
    pub const BFLOAT: u8 = 4;
    /// Boolean.
    pub const BOOL: u8 = 6;
}
11
/// A DLPack-style element type descriptor: type code, bit width and lane count.
///
/// It is a small plain-data value, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DLPackType {
    /// Type-category code (signed int, unsigned int, float, ...).
    pub code: u8,
    /// Width of a single element, in bits.
    pub bits: u8,
    /// Number of lanes; `1` denotes a scalar (non-vectorized) type.
    pub lanes: u16,
}
22
23pub fn dlpack_type(dtype: DType) -> Option<DLPackType> {
25 let (code, bits) = match dtype {
26 DType::Unspecified => return None,
27 DType::Bool => (code::BOOL, 8),
28 DType::Uint8 => (code::UINT, 8),
29 DType::Uint16 => (code::UINT, 16),
30 DType::Uint32 => (code::UINT, 32),
31 DType::Uint64 => (code::UINT, 64),
32 DType::Int8 => (code::INT, 8),
33 DType::Int16 => (code::INT, 16),
34 DType::Int32 => (code::INT, 32),
35 DType::Int64 => (code::INT, 64),
36 DType::Float16 => (code::FLOAT, 16),
37 DType::Float32 => (code::FLOAT, 32),
38 DType::Float64 => (code::FLOAT, 64),
39 DType::Bfloat16 => (code::BFLOAT, 16),
40 };
41 Some(DLPackType {
42 code,
43 bits,
44 lanes: 1,
45 })
46}
47
48pub fn dtype_from_dlpack(ty: DLPackType) -> Option<DType> {
51 if ty.lanes != 1 {
52 return None;
53 }
54 match (ty.code, ty.bits) {
55 (code::BOOL, 8) => Some(DType::Bool),
56 (code::UINT, 8) => Some(DType::Uint8),
57 (code::UINT, 16) => Some(DType::Uint16),
58 (code::UINT, 32) => Some(DType::Uint32),
59 (code::UINT, 64) => Some(DType::Uint64),
60 (code::INT, 8) => Some(DType::Int8),
61 (code::INT, 16) => Some(DType::Int16),
62 (code::INT, 32) => Some(DType::Int32),
63 (code::INT, 64) => Some(DType::Int64),
64 (code::FLOAT, 16) => Some(DType::Float16),
65 (code::FLOAT, 32) => Some(DType::Float32),
66 (code::FLOAT, 64) => Some(DType::Float64),
67 (code::BFLOAT, 16) => Some(DType::Bfloat16),
68 _ => None,
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::dtype_size;

    /// Every dtype in `DType::ALL` except `DType::Unspecified`.
    fn supported_dtypes() -> impl Iterator<Item = DType> {
        DType::ALL.into_iter().filter(|&d| d != DType::Unspecified)
    }

    /// Pins the exact numeric `(code, bits)` emitted for each dtype.
    #[test]
    fn test_dlpack_type_table() {
        // Codes are spelled as raw numbers on purpose, so an accidental edit to the
        // `code` constants is caught here.
        let expected = [
            (DType::Bool, 6, 8),
            (DType::Uint8, 1, 8),
            (DType::Uint16, 1, 16),
            (DType::Uint32, 1, 32),
            (DType::Uint64, 1, 64),
            (DType::Int8, 0, 8),
            (DType::Int16, 0, 16),
            (DType::Int32, 0, 32),
            (DType::Int64, 0, 64),
            (DType::Float16, 2, 16),
            (DType::Float32, 2, 32),
            (DType::Float64, 2, 64),
            (DType::Bfloat16, 4, 16),
        ];
        for (dtype, want_code, want_bits) in expected {
            let ty = dlpack_type(dtype).expect("supported dtype");
            let want = DLPackType {
                code: want_code,
                bits: want_bits,
                lanes: 1,
            };
            assert_eq!(ty, want, "{dtype:?}");
        }
        assert_eq!(dlpack_type(DType::Unspecified), None);
    }

    /// The reported bit width must agree with the dtype's size in bytes.
    #[test]
    fn test_dlpack_type_bits_match_dtype_size() {
        for dtype in supported_dtypes() {
            let ty = dlpack_type(dtype).expect("supported dtype");
            assert_eq!(ty.bits as usize, dtype_size(dtype) * 8, "{dtype:?}");
        }
    }

    /// Converting to a DLPack descriptor and back yields the original dtype.
    #[test]
    fn test_dlpack_type_roundtrip() {
        for dtype in supported_dtypes() {
            let ty = dlpack_type(dtype).expect("supported dtype");
            assert_eq!(dtype_from_dlpack(ty), Some(dtype));
        }
    }

    /// Descriptors outside the supported table are rejected with `None`.
    #[test]
    fn test_dtype_from_dlpack_rejects_unsupported() {
        let rejected = [
            // Float code with more than one lane.
            DLPackType {
                code: 2,
                bits: 32,
                lanes: 4,
            },
            // Int code with a width that has no dtype.
            DLPackType {
                code: 0,
                bits: 128,
                lanes: 1,
            },
            // A type code that has no constant in the `code` module.
            DLPackType {
                code: 5,
                bits: 64,
                lanes: 1,
            },
        ];
        for ty in rejected {
            assert_eq!(dtype_from_dlpack(ty), None);
        }
    }
}