1use numpy::PyArrayDyn;
2use pyo3::exceptions::asyncio::InvalidStateError;
3use pyo3::prelude::*;
4use pyo3::types::{
5 PyBool, PyBytes, PyComplex, PyDict, PyFloat, PyInt, PyList, PySet, PyString, PyTuple,
6};
7use pyo3::Bound;
8
9use super::numpy_dtype_enum::NumpyDtype;
10
/// Category of a Python value, as detected by [`detect_python_type`] and
/// encoded as a single tag byte by [`get_python_type_byte`] /
/// [`retrieve_python_type`].
///
/// Built-in variants are matched by *exact* type, so subclasses of builtins
/// fall into [`PythonType::OTHER`].
#[derive(Debug, PartialEq)]
pub enum PythonType {
    /// Exactly `bool`.
    BOOL,
    /// Exactly `bytes`.
    BYTES,
    /// Exactly `complex`.
    COMPLEX,
    /// Exactly `dict`.
    DICT,
    /// Exactly `float`.
    FLOAT,
    /// Exactly `int` (`bool` values are [`PythonType::BOOL`]).
    INT,
    /// Exactly `list`.
    LIST,
    /// A numpy array whose element type is `dtype`.
    NUMPY { dtype: NumpyDtype },
    /// Any value not covered by another variant, including subclasses of
    /// builtins and numpy arrays whose dtype is not a [`NumpyDtype`] variant.
    OTHER,
    /// Exactly `set` (a `frozenset` is [`PythonType::OTHER`]).
    SET,
    /// Exactly `str`.
    STRING,
    /// Exactly `tuple`.
    TUPLE,
}
27
28pub fn get_python_type_byte(python_type: &PythonType) -> u8 {
29 match python_type {
30 PythonType::BOOL => 0,
31 PythonType::BYTES => 1,
32 PythonType::COMPLEX => 2,
33 PythonType::DICT => 3,
34 PythonType::FLOAT => 4,
35 PythonType::INT => 5,
36 PythonType::LIST => 6,
37 PythonType::NUMPY { dtype } => match dtype {
38 NumpyDtype::INT8 => 7,
39 NumpyDtype::INT16 => 8,
40 NumpyDtype::INT32 => 9,
41 NumpyDtype::INT64 => 10,
42 NumpyDtype::UINT8 => 11,
43 NumpyDtype::UINT16 => 12,
44 NumpyDtype::UINT32 => 13,
45 NumpyDtype::UINT64 => 14,
46 NumpyDtype::FLOAT32 => 15,
47 NumpyDtype::FLOAT64 => 16,
48 },
49 PythonType::OTHER => 17,
50 PythonType::SET => 18,
51 PythonType::STRING => 19,
52 PythonType::TUPLE => 20,
53 }
54}
55
56pub fn retrieve_python_type(bytes: &[u8], offset: usize) -> PyResult<(PythonType, usize)> {
57 let python_type = match bytes[offset] {
58 0 => Ok(PythonType::BOOL),
59 1 => Ok(PythonType::BYTES),
60 2 => Ok(PythonType::COMPLEX),
61 3 => Ok(PythonType::DICT),
62 4 => Ok(PythonType::FLOAT),
63 5 => Ok(PythonType::INT),
64 6 => Ok(PythonType::LIST),
65 7 => Ok(PythonType::NUMPY {
66 dtype: NumpyDtype::INT8,
67 }),
68 8 => Ok(PythonType::NUMPY {
69 dtype: NumpyDtype::INT16,
70 }),
71 9 => Ok(PythonType::NUMPY {
72 dtype: NumpyDtype::INT32,
73 }),
74 10 => Ok(PythonType::NUMPY {
75 dtype: NumpyDtype::INT64,
76 }),
77 11 => Ok(PythonType::NUMPY {
78 dtype: NumpyDtype::UINT8,
79 }),
80 12 => Ok(PythonType::NUMPY {
81 dtype: NumpyDtype::UINT16,
82 }),
83 13 => Ok(PythonType::NUMPY {
84 dtype: NumpyDtype::UINT32,
85 }),
86 14 => Ok(PythonType::NUMPY {
87 dtype: NumpyDtype::UINT64,
88 }),
89 15 => Ok(PythonType::NUMPY {
90 dtype: NumpyDtype::FLOAT32,
91 }),
92 16 => Ok(PythonType::NUMPY {
93 dtype: NumpyDtype::FLOAT64,
94 }),
95 17 => Ok(PythonType::OTHER),
96 18 => Ok(PythonType::SET),
97 19 => Ok(PythonType::STRING),
98 20 => Ok(PythonType::TUPLE),
99 v => Err(InvalidStateError::new_err(format!(
100 "tried to deserialize PythonType but got value {v}"
101 ))),
102 }?;
103 Ok((python_type, offset + 1))
104}
105
/// Evaluates to `true` when the object bound to `$obj` can be cast to a numpy
/// array (`PyArrayDyn`) with Rust element type `$elem` (e.g. `i8`, `f64`).
macro_rules! check_numpy {
    ($obj:ident, $elem:ident) => {{
        let array_cast = $obj.cast::<PyArrayDyn<$elem>>();
        array_cast.is_ok()
    }};
}
111
112pub fn detect_python_type<'py>(v: &Bound<'py, PyAny>) -> PyResult<PythonType> {
113 if v.is_exact_instance_of::<PyBool>() {
114 return Ok(PythonType::BOOL);
115 }
116 if v.is_exact_instance_of::<PyInt>() {
117 return Ok(PythonType::INT);
118 }
119 if v.is_exact_instance_of::<PyFloat>() {
120 return Ok(PythonType::FLOAT);
121 }
122 if v.is_exact_instance_of::<PyComplex>() {
123 return Ok(PythonType::COMPLEX);
124 }
125 if v.is_exact_instance_of::<PyString>() {
126 return Ok(PythonType::STRING);
127 }
128 if v.is_exact_instance_of::<PyBytes>() {
129 return Ok(PythonType::BYTES);
130 }
131 if check_numpy!(v, i8) {
132 return Ok(PythonType::NUMPY {
133 dtype: NumpyDtype::INT8,
134 });
135 }
136 if check_numpy!(v, i16) {
137 return Ok(PythonType::NUMPY {
138 dtype: NumpyDtype::INT16,
139 });
140 }
141 if check_numpy!(v, i32) {
142 return Ok(PythonType::NUMPY {
143 dtype: NumpyDtype::INT32,
144 });
145 }
146 if check_numpy!(v, i64) {
147 return Ok(PythonType::NUMPY {
148 dtype: NumpyDtype::INT64,
149 });
150 }
151 if check_numpy!(v, u8) {
152 return Ok(PythonType::NUMPY {
153 dtype: NumpyDtype::UINT8,
154 });
155 }
156 if check_numpy!(v, u16) {
157 return Ok(PythonType::NUMPY {
158 dtype: NumpyDtype::UINT16,
159 });
160 }
161 if check_numpy!(v, u32) {
162 return Ok(PythonType::NUMPY {
163 dtype: NumpyDtype::UINT32,
164 });
165 }
166 if check_numpy!(v, u64) {
167 return Ok(PythonType::NUMPY {
168 dtype: NumpyDtype::UINT64,
169 });
170 }
171 if check_numpy!(v, f32) {
172 return Ok(PythonType::NUMPY {
173 dtype: NumpyDtype::FLOAT32,
174 });
175 }
176 if check_numpy!(v, f64) {
177 return Ok(PythonType::NUMPY {
178 dtype: NumpyDtype::FLOAT64,
179 });
180 }
181 if v.is_exact_instance_of::<PyList>() {
182 return Ok(PythonType::LIST);
183 }
184 if v.is_exact_instance_of::<PySet>() {
185 return Ok(PythonType::SET);
186 }
187 if v.is_exact_instance_of::<PyTuple>() {
188 return Ok(PythonType::TUPLE);
189 }
190 if v.is_exact_instance_of::<PyDict>() {
191 return Ok(PythonType::DICT);
192 }
193 return Ok(PythonType::OTHER);
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::{ffi::c_str, PyResult, Python};

    /// Fetch the variable `name` defined by a test script and classify it.
    fn detect_local(locals: &Bound<'_, PyDict>, name: &str) -> PyResult<PythonType> {
        let value = locals
            .get_item(name)?
            .unwrap_or_else(|| panic!("test script did not define `{name}`"));
        detect_python_type(&value)
    }

    /// Every supported numpy dtype maps to its `NumpyDtype`; anything else
    /// numpy-flavoured falls through to `OTHER`.
    #[test]
    fn python_test_detect_python_type_numpy() -> PyResult<()> {
        Python::initialize();
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c_str!(
                    r#"
import numpy as np
arr_i8 = np.array([1, 2], dtype=np.int8)
arr_i16 = np.array([1, 2], dtype=np.int16)
arr_i32 = np.array([1, 2], dtype=np.int32)
arr_i64 = np.array([1, 2], dtype=np.int64)
arr_u8 = np.array([1, 2], dtype=np.uint8)
arr_u16 = np.array([1, 2], dtype=np.uint16)
arr_u32 = np.array([1, 2], dtype=np.uint32)
arr_u64 = np.array([1, 2], dtype=np.uint64)
arr_f32 = np.array([1, 2], dtype=np.float32)
arr_f64 = np.array([1, 2], dtype=np.float64)
arr_2d = np.zeros((2, 3), dtype=np.float64)
arr_c128 = np.array([1, 2], dtype=np.complex128)
np_scalar = np.int64(3)
"#
                ),
                None,
                Some(&locals),
            )?;
            let cases = vec![
                ("arr_i8", PythonType::NUMPY { dtype: NumpyDtype::INT8 }),
                ("arr_i16", PythonType::NUMPY { dtype: NumpyDtype::INT16 }),
                ("arr_i32", PythonType::NUMPY { dtype: NumpyDtype::INT32 }),
                ("arr_i64", PythonType::NUMPY { dtype: NumpyDtype::INT64 }),
                ("arr_u8", PythonType::NUMPY { dtype: NumpyDtype::UINT8 }),
                ("arr_u16", PythonType::NUMPY { dtype: NumpyDtype::UINT16 }),
                ("arr_u32", PythonType::NUMPY { dtype: NumpyDtype::UINT32 }),
                ("arr_u64", PythonType::NUMPY { dtype: NumpyDtype::UINT64 }),
                ("arr_f32", PythonType::NUMPY { dtype: NumpyDtype::FLOAT32 }),
                ("arr_f64", PythonType::NUMPY { dtype: NumpyDtype::FLOAT64 }),
                // Dimensionality does not matter: detection uses `PyArrayDyn`.
                ("arr_2d", PythonType::NUMPY { dtype: NumpyDtype::FLOAT64 }),
                // Unsupported dtypes and numpy scalars are not classified as arrays.
                ("arr_c128", PythonType::OTHER),
                ("np_scalar", PythonType::OTHER),
            ];
            for (name, expected) in cases {
                assert_eq!(expected, detect_local(&locals, name)?, "variable `{name}`");
            }
            Ok(())
        })
    }

    /// Builtins are matched by exact type; subclasses and other objects are `OTHER`.
    #[test]
    fn python_test_detect_python_type_builtins() -> PyResult<()> {
        Python::initialize();
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c_str!(
                    r#"
class MyInt(int):
    pass

v_bool = True
v_int = 3
v_float = 1.5
v_complex = 1 + 2j
v_str = "abc"
v_bytes = b"abc"
v_list = [1, 2]
v_set = {1, 2}
v_tuple = (1, 2)
v_dict = {"a": 1}
v_none = None
v_frozenset = frozenset([1, 2])
v_int_subclass = MyInt(3)
"#
                ),
                None,
                Some(&locals),
            )?;
            let cases = vec![
                ("v_bool", PythonType::BOOL),
                ("v_int", PythonType::INT),
                ("v_float", PythonType::FLOAT),
                ("v_complex", PythonType::COMPLEX),
                ("v_str", PythonType::STRING),
                ("v_bytes", PythonType::BYTES),
                ("v_list", PythonType::LIST),
                ("v_set", PythonType::SET),
                ("v_tuple", PythonType::TUPLE),
                ("v_dict", PythonType::DICT),
                ("v_none", PythonType::OTHER),
                ("v_frozenset", PythonType::OTHER),
                ("v_int_subclass", PythonType::OTHER),
            ];
            for (name, expected) in cases {
                assert_eq!(expected, detect_local(&locals, name)?, "variable `{name}`");
            }
            Ok(())
        })
    }

    /// Every tag decodes to a type that encodes back to the same tag, and the
    /// returned offset points just past the tag byte.
    #[test]
    fn python_type_tag_round_trip() -> PyResult<()> {
        for tag in 0u8..=20 {
            // A leading padding byte also exercises a non-zero `offset`.
            let buffer = [u8::MAX, tag];
            let (python_type, next_offset) = retrieve_python_type(&buffer, 1)?;
            assert_eq!(tag, get_python_type_byte(&python_type), "tag {tag}");
            assert_eq!(2, next_offset, "tag {tag}");
        }
        Ok(())
    }

    /// Bytes outside the known tag range are rejected with `InvalidStateError`.
    #[test]
    fn python_test_retrieve_python_type_rejects_unknown_tag() {
        Python::initialize();
        Python::attach(|py| {
            let err = retrieve_python_type(&[21], 0).unwrap_err();
            assert!(err.is_instance_of::<InvalidStateError>(py));
        });
    }
}