// pyany_serde/common/python_type.rs

1use numpy::PyArrayDyn;
2use pyo3::exceptions::asyncio::InvalidStateError;
3use pyo3::prelude::*;
4use pyo3::types::{
5    PyBool, PyBytes, PyComplex, PyDict, PyFloat, PyInt, PyList, PySet, PyString, PyTuple,
6};
7use pyo3::Bound;
8
9use super::numpy_dtype_enum::NumpyDtype;
10
11// This enum is used to store first-level information about Python types.
12#[derive(Debug, PartialEq)]
13pub enum PythonType {
14    BOOL,
15    BYTES,
16    COMPLEX,
17    DICT,
18    FLOAT,
19    INT,
20    LIST,
21    NUMPY { dtype: NumpyDtype },
22    OTHER,
23    SET,
24    STRING,
25    TUPLE,
26}
27
28pub fn get_python_type_byte(python_type: &PythonType) -> u8 {
29    match python_type {
30        PythonType::BOOL => 0,
31        PythonType::BYTES => 1,
32        PythonType::COMPLEX => 2,
33        PythonType::DICT => 3,
34        PythonType::FLOAT => 4,
35        PythonType::INT => 5,
36        PythonType::LIST => 6,
37        PythonType::NUMPY { dtype } => match dtype {
38            NumpyDtype::INT8 => 7,
39            NumpyDtype::INT16 => 8,
40            NumpyDtype::INT32 => 9,
41            NumpyDtype::INT64 => 10,
42            NumpyDtype::UINT8 => 11,
43            NumpyDtype::UINT16 => 12,
44            NumpyDtype::UINT32 => 13,
45            NumpyDtype::UINT64 => 14,
46            NumpyDtype::FLOAT32 => 15,
47            NumpyDtype::FLOAT64 => 16,
48        },
49        PythonType::OTHER => 17,
50        PythonType::SET => 18,
51        PythonType::STRING => 19,
52        PythonType::TUPLE => 20,
53    }
54}
55
56pub fn retrieve_python_type(bytes: &[u8], offset: usize) -> PyResult<(PythonType, usize)> {
57    let python_type = match bytes[offset] {
58        0 => Ok(PythonType::BOOL),
59        1 => Ok(PythonType::BYTES),
60        2 => Ok(PythonType::COMPLEX),
61        3 => Ok(PythonType::DICT),
62        4 => Ok(PythonType::FLOAT),
63        5 => Ok(PythonType::INT),
64        6 => Ok(PythonType::LIST),
65        7 => Ok(PythonType::NUMPY {
66            dtype: NumpyDtype::INT8,
67        }),
68        8 => Ok(PythonType::NUMPY {
69            dtype: NumpyDtype::INT16,
70        }),
71        9 => Ok(PythonType::NUMPY {
72            dtype: NumpyDtype::INT32,
73        }),
74        10 => Ok(PythonType::NUMPY {
75            dtype: NumpyDtype::INT64,
76        }),
77        11 => Ok(PythonType::NUMPY {
78            dtype: NumpyDtype::UINT8,
79        }),
80        12 => Ok(PythonType::NUMPY {
81            dtype: NumpyDtype::UINT16,
82        }),
83        13 => Ok(PythonType::NUMPY {
84            dtype: NumpyDtype::UINT32,
85        }),
86        14 => Ok(PythonType::NUMPY {
87            dtype: NumpyDtype::UINT64,
88        }),
89        15 => Ok(PythonType::NUMPY {
90            dtype: NumpyDtype::FLOAT32,
91        }),
92        16 => Ok(PythonType::NUMPY {
93            dtype: NumpyDtype::FLOAT64,
94        }),
95        17 => Ok(PythonType::OTHER),
96        18 => Ok(PythonType::SET),
97        19 => Ok(PythonType::STRING),
98        20 => Ok(PythonType::TUPLE),
99        v => Err(InvalidStateError::new_err(format!(
100            "tried to deserialize PythonType but got value {v}"
101        ))),
102    }?;
103    Ok((python_type, offset + 1))
104}
105
// Expands to a `bool` expression: whether `$obj` can be cast to a
// dynamic-dimension NumPy array (`PyArrayDyn`) whose element type is `$elem`.
macro_rules! check_numpy {
    ($obj:ident, $elem:ident) => {
        $obj.cast::<PyArrayDyn<$elem>>().is_ok()
    };
}
111
112pub fn detect_python_type<'py>(v: &Bound<'py, PyAny>) -> PyResult<PythonType> {
113    if v.is_exact_instance_of::<PyBool>() {
114        return Ok(PythonType::BOOL);
115    }
116    if v.is_exact_instance_of::<PyInt>() {
117        return Ok(PythonType::INT);
118    }
119    if v.is_exact_instance_of::<PyFloat>() {
120        return Ok(PythonType::FLOAT);
121    }
122    if v.is_exact_instance_of::<PyComplex>() {
123        return Ok(PythonType::COMPLEX);
124    }
125    if v.is_exact_instance_of::<PyString>() {
126        return Ok(PythonType::STRING);
127    }
128    if v.is_exact_instance_of::<PyBytes>() {
129        return Ok(PythonType::BYTES);
130    }
131    if check_numpy!(v, i8) {
132        return Ok(PythonType::NUMPY {
133            dtype: NumpyDtype::INT8,
134        });
135    }
136    if check_numpy!(v, i16) {
137        return Ok(PythonType::NUMPY {
138            dtype: NumpyDtype::INT16,
139        });
140    }
141    if check_numpy!(v, i32) {
142        return Ok(PythonType::NUMPY {
143            dtype: NumpyDtype::INT32,
144        });
145    }
146    if check_numpy!(v, i64) {
147        return Ok(PythonType::NUMPY {
148            dtype: NumpyDtype::INT64,
149        });
150    }
151    if check_numpy!(v, u8) {
152        return Ok(PythonType::NUMPY {
153            dtype: NumpyDtype::UINT8,
154        });
155    }
156    if check_numpy!(v, u16) {
157        return Ok(PythonType::NUMPY {
158            dtype: NumpyDtype::UINT16,
159        });
160    }
161    if check_numpy!(v, u32) {
162        return Ok(PythonType::NUMPY {
163            dtype: NumpyDtype::UINT32,
164        });
165    }
166    if check_numpy!(v, u64) {
167        return Ok(PythonType::NUMPY {
168            dtype: NumpyDtype::UINT64,
169        });
170    }
171    if check_numpy!(v, f32) {
172        return Ok(PythonType::NUMPY {
173            dtype: NumpyDtype::FLOAT32,
174        });
175    }
176    if check_numpy!(v, f64) {
177        return Ok(PythonType::NUMPY {
178            dtype: NumpyDtype::FLOAT64,
179        });
180    }
181    if v.is_exact_instance_of::<PyList>() {
182        return Ok(PythonType::LIST);
183    }
184    if v.is_exact_instance_of::<PySet>() {
185        return Ok(PythonType::SET);
186    }
187    if v.is_exact_instance_of::<PyTuple>() {
188        return Ok(PythonType::TUPLE);
189    }
190    if v.is_exact_instance_of::<PyDict>() {
191        return Ok(PythonType::DICT);
192    }
193    return Ok(PythonType::OTHER);
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::{ffi::c_str, PyResult, Python};

    /// Builds one NumPy array per supported dtype in an embedded interpreter
    /// and checks that `detect_python_type` reports the matching dtype for each.
    #[test]
    fn python_test_detect_python_type_numpy() -> PyResult<()> {
        Python::initialize();
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c_str!(
                    r#"
import numpy as np
arr_i8 = np.array([1,2], dtype=np.int8)
arr_i16 = np.array([1,2], dtype=np.int16)
arr_i32 = np.array([1,2], dtype=np.int32)
arr_i64 = np.array([1,2], dtype=np.int64)
arr_u8 = np.array([1,2], dtype=np.uint8)
arr_u16 = np.array([1,2], dtype=np.uint16)
arr_u32 = np.array([1,2], dtype=np.uint32)
arr_u64 = np.array([1,2], dtype=np.uint64)
arr_f32 = np.array([1,2], dtype=np.float32)
arr_f64 = np.array([1,2], dtype=np.float64)
"#
                ),
                None,
                Some(&locals),
            )?;

            // One entry per dtype branch in `detect_python_type`.
            let expectations = [
                ("arr_i8", NumpyDtype::INT8),
                ("arr_i16", NumpyDtype::INT16),
                ("arr_i32", NumpyDtype::INT32),
                ("arr_i64", NumpyDtype::INT64),
                ("arr_u8", NumpyDtype::UINT8),
                ("arr_u16", NumpyDtype::UINT16),
                ("arr_u32", NumpyDtype::UINT32),
                ("arr_u64", NumpyDtype::UINT64),
                ("arr_f32", NumpyDtype::FLOAT32),
                ("arr_f64", NumpyDtype::FLOAT64),
            ];
            for (name, dtype) in expectations {
                let array = locals
                    .get_item(name)?
                    .unwrap_or_else(|| panic!("test script did not define {name}"));
                assert_eq!(
                    PythonType::NUMPY { dtype },
                    detect_python_type(&array)?,
                    "unexpected PythonType detected for {name}"
                );
            }
            Ok(())
        })
    }
}