//! DLPack data-type interop for tensor dtypes (`rlmesh_spaces/tensor/dlpack.rs`).
1use crate::dtype::DType;
2
/// Raw `DLDataTypeCode` values from the DLPack specification.
///
/// Only the codes that some RLMesh dtype maps onto are listed. Others, such as
/// opaque handle (3) and complex (5), are intentionally absent.
mod code {
    /// Signed integer (`kDLInt`).
    pub const INT: u8 = 0;
    /// Unsigned integer (`kDLUInt`).
    pub const UINT: u8 = 1;
    /// IEEE floating point (`kDLFloat`).
    pub const FLOAT: u8 = 2;
    /// Brain floating point (`kDLBfloat`).
    pub const BFLOAT: u8 = 4;
    /// Boolean (`kDLBool`).
    pub const BOOL: u8 = 6;
}
11
/// Element type of a tensor, expressed as a DLPack `DLDataType`:
/// a (type code, bit width, lane count) triple.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct DLPackType {
    /// Raw `DLDataTypeCode`: 0 = int, 1 = uint, 2 = float, 4 = bfloat, 6 = bool.
    pub code: u8,
    /// Width of a single element (one lane) in bits.
    pub bits: u8,
    /// Number of vector lanes; RLMesh tensors always use 1.
    pub lanes: u16,
}
22
23/// Map a dtype to its DLPack data type. `Unspecified` has no DLPack form.
24pub fn dlpack_type(dtype: DType) -> Option<DLPackType> {
25    let (code, bits) = match dtype {
26        DType::Unspecified => return None,
27        DType::Bool => (code::BOOL, 8),
28        DType::Uint8 => (code::UINT, 8),
29        DType::Uint16 => (code::UINT, 16),
30        DType::Uint32 => (code::UINT, 32),
31        DType::Uint64 => (code::UINT, 64),
32        DType::Int8 => (code::INT, 8),
33        DType::Int16 => (code::INT, 16),
34        DType::Int32 => (code::INT, 32),
35        DType::Int64 => (code::INT, 64),
36        DType::Float16 => (code::FLOAT, 16),
37        DType::Float32 => (code::FLOAT, 32),
38        DType::Float64 => (code::FLOAT, 64),
39        DType::Bfloat16 => (code::BFLOAT, 16),
40    };
41    Some(DLPackType {
42        code,
43        bits,
44        lanes: 1,
45    })
46}
47
48/// Map a DLPack data type back to a dtype. Returns `None` for unsupported
49/// codes or widths and for vectorized types (`lanes != 1`).
50pub fn dtype_from_dlpack(ty: DLPackType) -> Option<DType> {
51    if ty.lanes != 1 {
52        return None;
53    }
54    match (ty.code, ty.bits) {
55        (code::BOOL, 8) => Some(DType::Bool),
56        (code::UINT, 8) => Some(DType::Uint8),
57        (code::UINT, 16) => Some(DType::Uint16),
58        (code::UINT, 32) => Some(DType::Uint32),
59        (code::UINT, 64) => Some(DType::Uint64),
60        (code::INT, 8) => Some(DType::Int8),
61        (code::INT, 16) => Some(DType::Int16),
62        (code::INT, 32) => Some(DType::Int32),
63        (code::INT, 64) => Some(DType::Int64),
64        (code::FLOAT, 16) => Some(DType::Float16),
65        (code::FLOAT, 32) => Some(DType::Float32),
66        (code::FLOAT, 64) => Some(DType::Float64),
67        (code::BFLOAT, 16) => Some(DType::Bfloat16),
68        _ => None,
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::dtype_size;

    #[test]
    fn test_dlpack_type_table() {
        // (dtype, DLDataTypeCode, bits), spelled with raw spec numbers so the
        // test checks the real codes rather than echoing the `code` constants.
        let table: &[(DType, u8, u8)] = &[
            (DType::Bool, 6, 8),
            (DType::Uint8, 1, 8),
            (DType::Uint16, 1, 16),
            (DType::Uint32, 1, 32),
            (DType::Uint64, 1, 64),
            (DType::Int8, 0, 8),
            (DType::Int16, 0, 16),
            (DType::Int32, 0, 32),
            (DType::Int64, 0, 64),
            (DType::Float16, 2, 16),
            (DType::Float32, 2, 32),
            (DType::Float64, 2, 64),
            (DType::Bfloat16, 4, 16),
        ];
        for &(dtype, want_code, want_bits) in table {
            let got = dlpack_type(dtype).expect("supported dtype");
            assert_eq!(
                (got.code, got.bits, got.lanes),
                (want_code, want_bits, 1),
                "{dtype:?}"
            );
        }
        assert_eq!(dlpack_type(DType::Unspecified), None);
    }

    #[test]
    fn test_dlpack_type_bits_match_dtype_size() {
        for dtype in DType::ALL {
            // `Unspecified` has no DLPack form, so it has no bit width to check.
            if dtype == DType::Unspecified {
                continue;
            }
            let ty = dlpack_type(dtype).expect("supported dtype");
            assert_eq!(usize::from(ty.bits), dtype_size(dtype) * 8, "{dtype:?}");
        }
    }

    #[test]
    fn test_dlpack_type_roundtrip() {
        DType::ALL
            .into_iter()
            .filter(|&d| d != DType::Unspecified)
            .for_each(|dtype| {
                let ty = dlpack_type(dtype).expect("supported dtype");
                assert_eq!(dtype_from_dlpack(ty), Some(dtype));
            });
    }

    #[test]
    fn test_dtype_from_dlpack_rejects_unsupported() {
        let rejected = [
            // Vectorized types.
            DLPackType {
                code: 2,
                bits: 32,
                lanes: 4,
            },
            // Unknown width.
            DLPackType {
                code: 0,
                bits: 128,
                lanes: 1,
            },
            // Unknown code (e.g. complex = 5).
            DLPackType {
                code: 5,
                bits: 64,
                lanes: 1,
            },
        ];
        for ty in rejected {
            assert_eq!(dtype_from_dlpack(ty), None);
        }
    }
}