pyany_serde/
pyany_serde_type.rs

1use std::collections::BTreeMap;
2use std::env;
3use std::io;
4use std::io::Write;
5use std::str::FromStr;
6
7use num_traits::{FromPrimitive, ToPrimitive};
8use pyo3::exceptions::asyncio::InvalidStateError;
9use pyo3::exceptions::PyValueError;
10use pyo3::types::{PyBytes, PyCFunction, PyDict, PyFunction, PyTuple, PyType};
11use pyo3::{prelude::*, PyTypeInfo};
12use strum::{IntoEnumIterator, VariantNames};
13use strum_macros::Display;
14
15use crate::common::NumpyDtype;
16use crate::communication::{
17    append_bytes_vec, append_string_vec, append_usize_vec, retrieve_bytes, retrieve_string,
18    retrieve_usize,
19};
20use crate::pyany_serde_impl::{
21    numpy_check_for_unpickling, InitStrategy, NumpySerdeConfig, PickleableInitStrategy,
22    PickleableNumpySerdeConfig,
23};
24
// Pickle-friendly wrapper around `PyAnySerdeType`. Per the original note, it carries
// information about a type that is sent between processes so that the receiving side
// can dynamically recover a Box<dyn PyAnySerde>.
//
// The two `Option` layers mean different things:
// - outer `None`: the object was created by the zero-argument constructor and has not
//   been populated yet (`__setstate__` always replaces it with `Some(..)`);
// - inner `None`: the wrapper deliberately holds "no serde type" (serialized as the
//   single byte 0 by `__getstate__`).
#[pyclass]
#[derive(Clone)]
pub struct PickleablePyAnySerdeType(pub Option<Option<PyAnySerdeType>>);
29
30#[pymethods]
31impl PickleablePyAnySerdeType {
32    // We need a zero-args constructor for compatibility with unpickling
33    #[new]
34    #[pyo3(signature = (*args))]
35    fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
36        let vec_args = args.iter().collect::<Vec<_>>();
37        if vec_args.len() > 1 {
38            return Err(PyValueError::new_err(format!(
39                "PickleablePyAnySerde constructor takes 0 or 1 parameters, received {}",
40                args.as_any().repr()?.to_str()?
41            )));
42        }
43        if vec_args.len() == 1 {
44            Ok(PickleablePyAnySerdeType(Some(
45                vec_args[0].extract::<Option<PyAnySerdeType>>()?,
46            )))
47        } else {
48            Ok(PickleablePyAnySerdeType(None))
49        }
50    }
51
52    // pickle methods
53    pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
54        let pyany_serde_type_option = self.0.as_ref().unwrap();
55        Ok(match pyany_serde_type_option {
56            Some(pyany_serde_type) => {
57                let mut option_bytes = vec![1];
58                let mut pyany_serde_type_bytes = match pyany_serde_type {
59                    PyAnySerdeType::BOOL {} => vec![0],
60                    PyAnySerdeType::BYTES {} => vec![1],
61                    PyAnySerdeType::COMPLEX {} => vec![2],
62                    PyAnySerdeType::DATACLASS {
63                        clazz,
64                        init_strategy,
65                        field_serde_type_dict,
66                    } => {
67                        let mut bytes = vec![3];
68                        append_bytes_vec(
69                            &mut bytes,
70                            &PickleableInitStrategy(Some(init_strategy.clone())).__getstate__()[..],
71                        );
72                        append_usize_vec(&mut bytes, field_serde_type_dict.len());
73                        for (field, serde_type) in field_serde_type_dict.iter() {
74                            append_string_vec(&mut bytes, field);
75                            append_bytes_vec(
76                                &mut bytes,
77                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
78                                    .__getstate__()?[..],
79                            );
80                        }
81                        Python::with_gil::<_, PyResult<_>>(|py| {
82                            let clazz_py_bytes = py
83                                .import("pickle")?
84                                .getattr("dumps")?
85                                .call1((clazz,))?
86                                .downcast_into::<PyBytes>()?;
87                            append_bytes_vec(&mut bytes, clazz_py_bytes.as_bytes());
88                            Ok(bytes)
89                        })?
90                    }
91                    PyAnySerdeType::DICT {
92                        keys_serde_type,
93                        values_serde_type,
94                    } => {
95                        let mut bytes = vec![4];
96                        Python::with_gil::<_, PyResult<_>>(|py| {
97                            for py_serde_type in
98                                vec![keys_serde_type, values_serde_type].into_iter()
99                            {
100                                let serde_type = py_serde_type.extract::<PyAnySerdeType>(py)?;
101                                append_bytes_vec(
102                                    &mut bytes,
103                                    &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
104                                        .__getstate__()?[..],
105                                );
106                            }
107                            Ok(bytes)
108                        })?
109                    }
110                    PyAnySerdeType::DYNAMIC {} => vec![5],
111                    PyAnySerdeType::FLOAT {} => vec![6],
112                    PyAnySerdeType::INT {} => vec![7],
113                    PyAnySerdeType::LIST { items_serde_type } => {
114                        let mut bytes = vec![8];
115                        Python::with_gil::<_, PyResult<_>>(|py| {
116                            let serde_type = items_serde_type.extract::<PyAnySerdeType>(py)?;
117                            append_bytes_vec(
118                                &mut bytes,
119                                &PickleablePyAnySerdeType(Some(Some(serde_type))).__getstate__()?[..],
120                            );
121                            Ok(bytes)
122                        })?
123                    }
124                    PyAnySerdeType::NUMPY { dtype, config } => {
125                        let mut bytes = vec![9, dtype.to_u8().unwrap()];
126                        append_bytes_vec(
127                            &mut bytes,
128                            &PickleableNumpySerdeConfig(Some(config.clone())).__getstate__()?[..],
129                        );
130                        bytes
131                    }
132                    PyAnySerdeType::OPTION { value_serde_type } => {
133                        let mut bytes = vec![10];
134                        Python::with_gil::<_, PyResult<_>>(|py| {
135                            let serde_type = value_serde_type.extract::<PyAnySerdeType>(py)?;
136                            append_bytes_vec(
137                                &mut bytes,
138                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
139                                    .__getstate__()?[..],
140                            );
141                            Ok(bytes)
142                        })?
143                    }
144                    PyAnySerdeType::PICKLE {} => vec![11],
145                    PyAnySerdeType::PYTHONSERDE { python_serde } => {
146                        let mut bytes = vec![12];
147                        Python::with_gil::<_, PyResult<_>>(|py| {
148                            let python_serde_py_bytes = py
149                                .import("pickle")?
150                                .getattr("dumps")?
151                                .call1((python_serde,))?
152                                .downcast_into::<PyBytes>()?;
153                            append_bytes_vec(&mut bytes, python_serde_py_bytes.as_bytes());
154                            Ok(bytes)
155                        })?
156                    }
157                    PyAnySerdeType::SET { items_serde_type } => {
158                        let mut bytes = vec![13];
159                        Python::with_gil::<_, PyResult<_>>(|py| {
160                            let serde_type = items_serde_type.extract::<PyAnySerdeType>(py)?;
161                            append_bytes_vec(
162                                &mut bytes,
163                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
164                                    .__getstate__()?[..],
165                            );
166                            Ok(bytes)
167                        })?
168                    }
169                    PyAnySerdeType::STRING {} => vec![14],
170                    PyAnySerdeType::TUPLE { item_serde_types } => {
171                        let mut bytes = vec![15];
172                        bytes.extend_from_slice(&item_serde_types.len().to_ne_bytes());
173                        for serde_type in item_serde_types.iter() {
174                            append_bytes_vec(
175                                &mut bytes,
176                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
177                                    .__getstate__()?[..],
178                            );
179                        }
180                        bytes
181                    }
182                    PyAnySerdeType::TYPEDDICT {
183                        key_serde_type_dict,
184                    } => {
185                        let mut bytes = vec![16];
186                        bytes.extend_from_slice(&key_serde_type_dict.len().to_ne_bytes());
187                        for (key, serde_type) in key_serde_type_dict.iter() {
188                            append_string_vec(&mut bytes, key);
189                            append_bytes_vec(
190                                &mut bytes,
191                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
192                                    .__getstate__()?[..],
193                            );
194                        }
195                        bytes
196                    }
197                    PyAnySerdeType::UNION {
198                        option_serde_types,
199                        option_choice_fn,
200                    } => {
201                        let mut bytes = vec![17];
202                        bytes.extend_from_slice(&option_serde_types.len().to_ne_bytes());
203                        for serde_type in option_serde_types.iter() {
204                            append_bytes_vec(
205                                &mut bytes,
206                                &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
207                                    .__getstate__()?[..],
208                            );
209                        }
210                        Python::with_gil::<_, PyResult<_>>(|py| {
211                            let option_choice_fn_py_bytes = py
212                                .import("pickle")?
213                                .getattr("dumps")?
214                                .call1((option_choice_fn,))?
215                                .downcast_into::<PyBytes>()?;
216                            append_bytes_vec(&mut bytes, option_choice_fn_py_bytes.as_bytes());
217                            Ok(bytes)
218                        })?
219                    }
220                };
221                option_bytes.append(&mut pyany_serde_type_bytes);
222                option_bytes
223            }
224            None => vec![0],
225        })
226    }
227
228    pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
229        let buf = &state[..];
230        let option_byte = state[0];
231        self.0 = Some(match option_byte {
232            0 => None,
233            1 => {
234                let type_byte = state[1];
235                let mut offset = 2;
236                Some(match type_byte {
237                    0 => PyAnySerdeType::BOOL {},
238                    1 => PyAnySerdeType::BYTES {},
239                    2 => PyAnySerdeType::COMPLEX {},
240                    3 => {
241                        let init_strategy_bytes;
242                        (init_strategy_bytes, offset) = retrieve_bytes(buf, offset)?;
243                        let mut pickleable_init_strategy = PickleableInitStrategy(None);
244                        pickleable_init_strategy.__setstate__(init_strategy_bytes.to_vec())?;
245                        let n_fields;
246                        (n_fields, offset) = retrieve_usize(buf, offset)?;
247                        let mut field_serde_type_dict = BTreeMap::new();
248                        for _ in 0..n_fields {
249                            let field;
250                            (field, offset) = retrieve_string(buf, offset)?;
251                            let serde_type_bytes;
252                            (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
253                            let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
254                            pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
255                            field_serde_type_dict
256                                .insert(field, pickleable_serde_type.0.unwrap().unwrap());
257                        }
258                        Python::with_gil::<_, PyResult<_>>(|py| {
259                            let clazz_bytes;
260                            (clazz_bytes, offset) = retrieve_bytes(buf, offset)?;
261                            let clazz = py
262                                .import("pickle")?
263                                .getattr("loads")?
264                                .call1((PyBytes::new(py, clazz_bytes).into_pyobject(py)?,))?
265                                .unbind();
266                            Ok(PyAnySerdeType::DATACLASS {
267                                clazz,
268                                init_strategy: pickleable_init_strategy.0.unwrap(),
269                                field_serde_type_dict,
270                            })
271                        })?
272                    }
273                    4 => Python::with_gil::<_, PyResult<_>>(|py| {
274                        let keys_serde_type_bytes;
275                        (keys_serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
276                        let mut pickleable_keys_serde_type = PickleablePyAnySerdeType(None);
277                        pickleable_keys_serde_type.__setstate__(keys_serde_type_bytes.to_vec())?;
278                        let values_serde_type_bytes;
279                        (values_serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
280                        let mut pickleable_values_serde_type = PickleablePyAnySerdeType(None);
281                        pickleable_values_serde_type
282                            .__setstate__(values_serde_type_bytes.to_vec())?;
283                        Ok(PyAnySerdeType::DICT {
284                            keys_serde_type: Py::new(
285                                py,
286                                pickleable_keys_serde_type.0.unwrap().unwrap(),
287                            )?,
288                            values_serde_type: Py::new(
289                                py,
290                                pickleable_values_serde_type.0.unwrap().unwrap(),
291                            )?,
292                        })
293                    })?,
294                    5 => PyAnySerdeType::DYNAMIC {},
295                    6 => PyAnySerdeType::FLOAT {},
296                    7 => PyAnySerdeType::INT {},
297                    8 => Python::with_gil::<_, PyResult<_>>(|py| {
298                        let serde_type_bytes;
299                        (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
300                        let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
301                        pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
302                        Ok(PyAnySerdeType::LIST {
303                            items_serde_type: Py::new(
304                                py,
305                                pickleable_serde_type.0.unwrap().unwrap(),
306                            )?,
307                        })
308                    })?,
309                    9 => {
310                        let dtype = NumpyDtype::from_u8(buf[offset]).unwrap();
311                        offset += 1;
312                        let numpy_serde_config_bytes;
313                        (numpy_serde_config_bytes, _) = retrieve_bytes(buf, offset)?;
314                        let mut pickleable_numpy_serde_config = PickleableNumpySerdeConfig(None);
315                        pickleable_numpy_serde_config
316                            .__setstate__(numpy_serde_config_bytes.to_vec())?;
317                        PyAnySerdeType::NUMPY {
318                            dtype,
319                            config: pickleable_numpy_serde_config.0.unwrap(),
320                        }
321                    }
322                    10 => Python::with_gil::<_, PyResult<_>>(|py| {
323                        let serde_type_bytes;
324                        (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
325                        let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
326                        pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
327                        Ok(PyAnySerdeType::OPTION {
328                            value_serde_type: Py::new(
329                                py,
330                                pickleable_serde_type.0.unwrap().unwrap(),
331                            )?,
332                        })
333                    })?,
334                    11 => PyAnySerdeType::PICKLE {},
335                    12 => Python::with_gil::<_, PyResult<_>>(|py| {
336                        let python_serde_bytes;
337                        (python_serde_bytes, offset) = retrieve_bytes(buf, offset)?;
338                        let python_serde = py
339                            .import("pickle")?
340                            .getattr("loads")?
341                            .call1((PyBytes::new(py, python_serde_bytes).into_pyobject(py)?,))?
342                            .unbind();
343                        Ok(PyAnySerdeType::PYTHONSERDE { python_serde })
344                    })?,
345                    13 => Python::with_gil::<_, PyResult<_>>(|py| {
346                        let serde_type_bytes;
347                        (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
348                        let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
349                        pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
350                        Ok(PyAnySerdeType::SET {
351                            items_serde_type: Py::new(
352                                py,
353                                pickleable_serde_type.0.unwrap().unwrap(),
354                            )?,
355                        })
356                    })?,
357                    14 => PyAnySerdeType::STRING {},
358                    15 => {
359                        let n_items;
360                        (n_items, offset) = retrieve_usize(buf, offset)?;
361                        let mut item_serde_types = Vec::with_capacity(n_items);
362                        for _ in 0..n_items {
363                            let serde_type_bytes;
364                            (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
365                            let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
366                            pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
367                            item_serde_types.push(pickleable_serde_type.0.unwrap().unwrap())
368                        }
369                        PyAnySerdeType::TUPLE { item_serde_types }
370                    }
371                    16 => {
372                        let n_keys;
373                        (n_keys, offset) = retrieve_usize(buf, offset)?;
374                        let mut key_serde_type_dict = BTreeMap::new();
375                        for _ in 0..n_keys {
376                            let key;
377                            (key, offset) = retrieve_string(buf, offset)?;
378                            let serde_type_bytes;
379                            (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
380                            let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
381                            pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
382                            key_serde_type_dict
383                                .insert(key, pickleable_serde_type.0.unwrap().unwrap());
384                        }
385                        PyAnySerdeType::TYPEDDICT {
386                            key_serde_type_dict,
387                        }
388                    }
389                    17 => {
390                        let n_options;
391                        (n_options, offset) = retrieve_usize(buf, offset)?;
392                        let mut option_serde_types = Vec::with_capacity(n_options);
393                        for _ in 0..n_options {
394                            let serde_type_bytes;
395                            (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
396                            let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
397                            pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
398                            option_serde_types.push(pickleable_serde_type.0.unwrap().unwrap())
399                        }
400                        Python::with_gil::<_, PyResult<_>>(|py| {
401                            let option_choice_fn_bytes;
402                            (option_choice_fn_bytes, offset) = retrieve_bytes(buf, offset)?;
403                            let option_choice_fn = py.import("pickle")?.getattr("loads")?.call1(
404                                (PyBytes::new(py, option_choice_fn_bytes).into_pyobject(py)?,),
405                            )?;
406                            Ok(PyAnySerdeType::UNION {
407                                option_serde_types,
408                                option_choice_fn: option_choice_fn
409                                    .downcast_into::<PyFunction>()?
410                                    .unbind(),
411                            })
412                        })?
413                    }
414                    v => Err(InvalidStateError::new_err(format!(
415                        "Got invalid type byte for PyAnySerde: {v}"
416                    )))?,
417                })
418            }
419            v => Err(InvalidStateError::new_err(format!(
420                "Got invalid option byte for PyAnySerdeType: {v}"
421            )))?,
422        });
423
424        Ok(())
425    }
426}
427
// Description of how a Python value should be serialized, exposed to Python as a class.
// Each variant names one kind of serde; variants with fields carry the extra information
// needed to build it. The lowercase variant names are the `type` strings matched by the
// validator code later in this file (e.g. "bool", "dict", "numpy").
//
// Variant order matches the numeric tags used by PickleablePyAnySerdeType's
// `__getstate__` / `__setstate__` (BOOL = 0 ... UNION = 17).
// NOTE(review): the derived strum traits presumably also depend on declaration order and
// names, so reorder or rename with care - confirm against callers.
#[pyclass]
#[derive(Debug, Clone, Display, strum_macros::VariantNames)]
pub enum PyAnySerdeType {
    BOOL {},
    BYTES {},
    COMPLEX {},
    DATACLASS {
        // Python object pickled as part of this variant's state (the dataclass type,
        // judging by the name and the "dataclass_pkl" config key).
        clazz: PyObject,
        init_strategy: InitStrategy,
        // Serde type per field name.
        field_serde_type_dict: BTreeMap<String, PyAnySerdeType>,
    },
    DICT {
        // Nested serde types are held as Python-owned objects in this variant.
        keys_serde_type: Py<PyAnySerdeType>,
        values_serde_type: Py<PyAnySerdeType>,
    },
    DYNAMIC {},
    FLOAT {},
    INT {},
    LIST {
        items_serde_type: Py<PyAnySerdeType>,
    },
    // `config` may be omitted from Python; it then defaults to the DYNAMIC config with no
    // preprocessor or postprocessor function.
    #[pyo3(constructor = (dtype, config = NumpySerdeConfig::DYNAMIC { preprocessor_fn: None, postprocessor_fn: None }))]
    NUMPY {
        dtype: NumpyDtype,
        config: NumpySerdeConfig,
    },
    OPTION {
        value_serde_type: Py<PyAnySerdeType>,
    },
    PICKLE {},
    PYTHONSERDE {
        // Arbitrary Python object; it is pickled when this type is serialized.
        python_serde: PyObject,
    },
    SET {
        items_serde_type: Py<PyAnySerdeType>,
    },
    STRING {},
    TUPLE {
        // One serde type per tuple position, in order.
        item_serde_types: Vec<PyAnySerdeType>,
    },
    TYPEDDICT {
        // Serde type per dictionary key.
        key_serde_type_dict: BTreeMap<String, PyAnySerdeType>,
    },
    UNION {
        option_serde_types: Vec<PyAnySerdeType>,
        // Python function stored alongside the options (presumably selects which option
        // serde to use for a value - verify against the serde implementation).
        option_choice_fn: Py<PyFunction>,
    },
}
476
477fn check_for_unpickling_aux<'py>(data: &Bound<'py, PyAny>) -> PyResult<bool> {
478    let pyany_serde_type_field = data
479        .get_item("type")?
480        .extract::<String>()?
481        .to_ascii_lowercase();
482    Ok(match pyany_serde_type_field.as_str() {
483        "dataclass" => true,
484        "dict" => {
485            check_for_unpickling_aux(&data.get_item("keys_serde_type")?)?
486                || check_for_unpickling_aux(&data.get_item("values_serde_type")?)?
487        }
488        "list" => check_for_unpickling_aux(&data.get_item("items_serde_type")?)?,
489        "numpy" => numpy_check_for_unpickling(&data.get_item("config")?)?,
490        "option" => check_for_unpickling_aux(&data.get_item("value_serde_type")?)?,
491        "pythonserde" => true,
492        "set" => check_for_unpickling_aux(&data.get_item("items_serde_type")?)?,
493        "tuple" => {
494            let mut has_unpickling = false;
495            for item_serde_type_data in data
496                .get_item("item_serde_types")?
497                .extract::<Vec<Bound<'_, PyAny>>>()?
498                .iter()
499            {
500                has_unpickling |= check_for_unpickling_aux(&item_serde_type_data)?;
501            }
502            has_unpickling
503        }
504        "typeddict" => {
505            let mut has_unpickling = false;
506            for (_, serde_type_data) in data
507                .get_item("key_serde_type_dict")?
508                .downcast_into::<PyDict>()?
509                .iter()
510            {
511                has_unpickling |= check_for_unpickling_aux(&serde_type_data)?;
512            }
513            has_unpickling
514        }
515        "union" => true,
516        _ => false,
517    })
518}
519
520#[pyfunction]
521fn check_for_unpickling<'py, 'a>(data: &'a Bound<'py, PyAny>) -> PyResult<&'a Bound<'py, PyAny>> {
522    let silent_mode = env::var("PYANY_SERDE_UNPICKLE_WITHOUT_PROMPT")
523        .map(|v| v.eq("1"))
524        .unwrap_or(false);
525    if !silent_mode && check_for_unpickling_aux(&data)? {
526        println!("WARNING: About to call unpickle on the hexadecimal-encoded binary contents of some config fields. If you do not trust the origins of this json, or you cannot otherwise verify the safety of this field's contents, you should not proceed.");
527        print!("Proceed? (y/N)\t");
528        io::stdout().flush()?;
529        let mut response = String::new();
530        io::stdin().read_line(&mut response).unwrap();
531        if !response.trim().eq_ignore_ascii_case("y") {
532            Err(PyValueError::new_err("Operation cancelled by user due to unpickling required to build config model from json"))?
533        } else {
534            println!("Continuing with execution. If you would like to ignore this warning in the future, set the environment variable PYANY_SERDE_UNPICKLE_WITHOUT_PROMPT to \"1\".")
535        }
536    }
537    Ok(data)
538}
539
540fn get_before_validator_fn<'py>(
541    _handler: &Bound<'py, PyAny>,
542    _schema_validator: &Bound<'py, PyAny>,
543) -> PyResult<Bound<'py, PyCFunction>> {
544    let _py = _handler.py();
545    let py_handler = _handler.clone().unbind();
546    let py_schema_validator = _schema_validator.clone().unbind();
547    let func = move |args: &Bound<'_, PyTuple>,
548                     _kwargs: Option<&Bound<'_, PyDict>>|
549          -> PyResult<PyObject> {
550        // initial setup
551        let py = args.py();
552        let data = args.get_item(0)?;
553        let handler = py_handler.bind(py);
554        let schema_validator = py_schema_validator.bind(py);
555
556        // processing of data
557        let pyany_serde_type_field = data
558            .get_item("type")?
559            .extract::<String>()?
560            .to_ascii_lowercase();
561        let pyany_serde_type = match pyany_serde_type_field.as_str() {
562            "bool" => PyAnySerdeType::BOOL {},
563            "bytes" => PyAnySerdeType::BYTES {},
564            "complex" => PyAnySerdeType::COMPLEX {},
565            "dataclass" => {
566                let clazz_bytes_hex = data.get_item("dataclass_pkl")?.extract::<String>()?;
567                let clazz = py
568                    .import("pickle")?
569                    .getattr("loads")?
570                    .call1((PyBytes::new(
571                        py,
572                        &hex::decode(clazz_bytes_hex.as_str()).map_err(|err| {
573                            PyValueError::new_err(format!(
574                                "dataclass_pkl could not be decoded from hex into bytes: {}",
575                                err.to_string()
576                            ))
577                        })?,
578                    ),))?
579                    .unbind();
580                let init_strategy = schema_validator
581                    .call1((handler
582                        .call_method1("generate_schema", (InitStrategy::type_object(py),))?,))?
583                    .call_method1("validate_python", (data.get_item("init_strategy")?,))?
584                    .extract::<InitStrategy>()?;
585                let mut field_serde_type_dict = BTreeMap::new();
586                for (key, serde_type_data) in data
587                    .get_item("field_serde_type_dict")?
588                    .downcast_into::<PyDict>()?
589                    .into_iter()
590                {
591                    let key = key.extract::<String>()?;
592                    let value = get_before_validator_fn(handler, schema_validator)?
593                        .call1((serde_type_data,))?
594                        .extract::<PyAnySerdeType>()?;
595                    field_serde_type_dict.insert(key, value);
596                }
597                PyAnySerdeType::DATACLASS {
598                    clazz,
599                    init_strategy,
600                    field_serde_type_dict,
601                }
602            }
603            "dict" => {
604                let keys_serde_type_data = data.get_item("keys_serde_type")?;
605                let keys_serde_type = get_before_validator_fn(handler, schema_validator)?
606                    .call1((keys_serde_type_data,))?
607                    .extract::<PyAnySerdeType>()?;
608                let values_serde_type_data = data.get_item("values_serde_type")?;
609                let values_serde_type = get_before_validator_fn(handler, schema_validator)?
610                    .call1((values_serde_type_data,))?
611                    .extract::<PyAnySerdeType>()?;
612                PyAnySerdeType::DICT {
613                    keys_serde_type: Py::new(py, keys_serde_type)?,
614                    values_serde_type: Py::new(py, values_serde_type)?,
615                }
616            }
617            "dynamic" => PyAnySerdeType::DYNAMIC {},
618            "float" => PyAnySerdeType::FLOAT {},
619            "int" => PyAnySerdeType::INT {},
620            "list" => {
621                let items_serde_type_data = data.get_item("items_serde_type")?;
622                let items_serde_type = get_before_validator_fn(handler, schema_validator)?
623                    .call1((items_serde_type_data,))?
624                    .extract::<PyAnySerdeType>()?;
625                PyAnySerdeType::LIST {
626                    items_serde_type: Py::new(py, items_serde_type)?,
627                }
628            }
629            "numpy" => {
630                let dtype_string = data.get_item("dtype")?.extract::<String>()?;
631                let dtype = NumpyDtype::from_str(dtype_string.as_str()).map_err(|_| {
632                    PyValueError::new_err(format!(
633                        "dtype was provided as {dtype_string} which is not a valid dtype"
634                    ))
635                })?;
636                let numpy_serde_config = schema_validator
637                    .call1((handler
638                        .call_method1("generate_schema", (NumpySerdeConfig::type_object(py),))?,))?
639                    .call_method1("validate_python", (data.get_item("config")?,))?
640                    .extract::<NumpySerdeConfig>()?;
641                PyAnySerdeType::NUMPY {
642                    dtype,
643                    config: numpy_serde_config,
644                }
645            }
646            "option" => {
647                let value_serde_type_data = data.get_item("value_serde_type")?;
648                let value_serde_type = get_before_validator_fn(handler, schema_validator)?
649                    .call1((value_serde_type_data,))?
650                    .extract::<PyAnySerdeType>()?;
651                PyAnySerdeType::OPTION {
652                    value_serde_type: Py::new(py, value_serde_type)?,
653                }
654            }
655            "pickle" => PyAnySerdeType::PICKLE {},
656            "pythonserde" => {
657                let python_serde_bytes_hex =
658                    data.get_item("python_serde_pkl")?.extract::<String>()?;
659                let python_serde = py
660                    .import("pickle")?
661                    .getattr("loads")?
662                    .call1((PyBytes::new(
663                        py,
664                        &hex::decode(python_serde_bytes_hex.as_str()).map_err(|err| {
665                            PyValueError::new_err(format!(
666                                "python_serde_pkl could not be decoded from hex into bytes: {}",
667                                err.to_string()
668                            ))
669                        })?,
670                    ),))?
671                    .unbind();
672                PyAnySerdeType::PYTHONSERDE { python_serde }
673            }
674            "set" => {
675                let items_serde_type_data = data.get_item("items_serde_type")?;
676                let items_serde_type = get_before_validator_fn(handler, schema_validator)?
677                    .call1((items_serde_type_data,))?
678                    .extract::<PyAnySerdeType>()?;
679                PyAnySerdeType::SET {
680                    items_serde_type: Py::new(py, items_serde_type)?,
681                }
682            }
683            "string" => PyAnySerdeType::STRING {},
684            "tuple" => {
685                let item_serde_types_data = data
686                    .get_item("item_serde_types")?
687                    .extract::<Vec<Bound<'_, PyAny>>>()?;
688                let item_serde_types = item_serde_types_data
689                    .iter()
690                    .map(|item_serde_type_data| {
691                        Ok(get_before_validator_fn(handler, schema_validator)?
692                            .call1((item_serde_type_data,))?
693                            .extract::<PyAnySerdeType>()?)
694                    })
695                    .collect::<PyResult<Vec<_>>>()?;
696                PyAnySerdeType::TUPLE { item_serde_types }
697            }
698            "typeddict" => {
699                let mut key_serde_type_dict = BTreeMap::new();
700                for (key, serde_type_data) in data
701                    .get_item("key_serde_type_dict")?
702                    .downcast_into::<PyDict>()?
703                    .into_iter()
704                {
705                    let key = key.extract::<String>()?;
706                    let value = get_before_validator_fn(handler, schema_validator)?
707                        .call1((serde_type_data,))?
708                        .extract::<PyAnySerdeType>()?;
709                    key_serde_type_dict.insert(key, value);
710                }
711                PyAnySerdeType::TYPEDDICT {
712                    key_serde_type_dict,
713                }
714            }
715            "union" => {
716                let option_serde_types_data = data
717                    .get_item("option_serde_types")?
718                    .extract::<Vec<Bound<'_, PyAny>>>()?;
719                let option_serde_types = option_serde_types_data
720                    .iter()
721                    .map(|option_serde_type_data| {
722                        Ok(get_before_validator_fn(handler, schema_validator)?
723                            .call1((option_serde_type_data,))?
724                            .extract::<PyAnySerdeType>()?)
725                    })
726                    .collect::<PyResult<Vec<_>>>()?;
727                let option_choice_fn_bytes_hex =
728                    data.get_item("option_choice_fn_pkl")?.extract::<String>()?;
729                let option_choice_fn = py
730                    .import("pickle")?
731                    .getattr("loads")?
732                    .call1((PyBytes::new(
733                        py,
734                        &hex::decode(option_choice_fn_bytes_hex.as_str()).map_err(|err| {
735                            PyValueError::new_err(format!(
736                                "option_choice_fn_pkl could not be decoded from hex into bytes: {}",
737                                err.to_string()
738                            ))
739                        })?,
740                    ),))?
741                    .downcast_into::<PyFunction>()?
742                    .unbind();
743                PyAnySerdeType::UNION {
744                    option_serde_types,
745                    option_choice_fn,
746                }
747            }
748            v => Err(PyValueError::new_err(format!("Unexpected type: {v}")))?,
749        };
750
751        Ok(pyany_serde_type.into_pyobject(py)?.into_any().unbind())
752    };
753    PyCFunction::new_closure(_py, None, None, func)
754}
755
756#[pymethods]
757impl PyAnySerdeType {
758    fn as_pickleable<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
759        Ok(PickleablePyAnySerdeType(Some(Some(self.clone())))
760            .into_pyobject(py)?
761            .into_any())
762    }
763
764    // pydantic methods
765    #[classmethod]
766    fn __get_pydantic_core_schema__<'py>(
767        cls: &Bound<'py, PyType>,
768        _source_type: Bound<'py, PyAny>,
769        handler: Bound<'py, PyAny>,
770    ) -> PyResult<Bound<'py, PyAny>> {
771        let py = cls.py();
772        let generate_schema = handler.getattr("generate_schema")?;
773        let pydantic_core = py.import("pydantic_core")?;
774        let schema_validator = pydantic_core.getattr("SchemaValidator")?;
775        let core_schema = pydantic_core.getattr("core_schema")?;
776
777        let str_schema = core_schema.getattr("str_schema")?;
778        let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
779        let list_schema = core_schema.getattr("list_schema")?;
780        let dict_schema = core_schema.getattr("dict_schema")?;
781        let any_schema = core_schema.getattr("any_schema")?;
782        let typed_dict_field = core_schema.getattr("typed_dict_field")?;
783
784        let pyany_serde_type_reference_schema = core_schema
785            .call_method1("definition_reference_schema", ("pyany_serde_type_schema",))?;
786        let pyany_serde_type_reference_schema_field =
787            typed_dict_field.call1((&pyany_serde_type_reference_schema,))?;
788
789        let union_list = PyAnySerdeType::VARIANTS
790            .iter()
791            .map(|pyany_serde_type_variant| {
792                let pyany_serde_type_field = pyany_serde_type_variant.to_ascii_lowercase();
793                let typed_dict_fields = PyDict::new(py);
794                typed_dict_fields.set_item(
795                    "type",
796                    typed_dict_field.call1((str_schema.call(
797                        (),
798                        Some(&PyDict::from_sequence(
799                            &vec![(
800                                "pattern",
801                                vec![
802                                    "^".to_owned(),
803                                    pyany_serde_type_field.clone(),
804                                    "$".to_owned(),
805                                ]
806                                .join("")
807                                .into_pyobject(py)?
808                                .into_any(),
809                            )]
810                            .into_pyobject(py)?,
811                        )?),
812                    )?,))?,
813                )?;
814                match pyany_serde_type_field.as_str() {
815                    "dataclass" => {
816                        typed_dict_fields.set_item(
817                            "dataclass_pkl",
818                            typed_dict_field.call1((str_schema.call0()?,))?,
819                        )?;
820                        typed_dict_fields.set_item(
821                            "init_strategy",
822                            typed_dict_field.call1((
823                                generate_schema.call1((InitStrategy::type_object(py),))?,
824                            ))?,
825                        )?;
826                        typed_dict_fields.set_item(
827                            "field_serde_type_dict",
828                            typed_dict_field.call1((dict_schema.call1((
829                                str_schema.call0()?,
830                                &pyany_serde_type_reference_schema,
831                            ))?,))?,
832                        )?;
833                    }
834                    "dict" => {
835                        typed_dict_fields.set_item(
836                            "keys_serde_type",
837                            &pyany_serde_type_reference_schema_field,
838                        )?;
839                        typed_dict_fields.set_item(
840                            "values_serde_type",
841                            &pyany_serde_type_reference_schema_field,
842                        )?;
843                    }
844                    "list" => {
845                        typed_dict_fields.set_item(
846                            "items_serde_type",
847                            &pyany_serde_type_reference_schema_field,
848                        )?;
849                    }
850                    "numpy" => {
851                        typed_dict_fields.set_item(
852                            "dtype",
853                            typed_dict_field.call1((str_schema.call(
854                                (),
855                                Some(&PyDict::from_sequence(
856                                    &vec![(
857                                        "pattern",
858                                        vec![
859                                            "^(".to_owned(),
860                                            NumpyDtype::iter()
861                                                .map(|dtype_str| dtype_str.to_string())
862                                                .collect::<Vec<_>>()
863                                                .join("|"),
864                                            ")$".to_owned(),
865                                        ]
866                                        .join(""),
867                                    )]
868                                    .into_pyobject(py)?,
869                                )?),
870                            )?,))?,
871                        )?;
872                        typed_dict_fields.set_item(
873                            "config",
874                            typed_dict_field.call1((
875                                generate_schema.call1((NumpySerdeConfig::type_object(py),))?,
876                            ))?,
877                        )?;
878                    }
879                    "option" => {
880                        typed_dict_fields.set_item(
881                            "value_serde_type",
882                            &pyany_serde_type_reference_schema_field,
883                        )?;
884                    }
885                    "pythonserde" => {
886                        typed_dict_fields.set_item(
887                            "python_serde_pkl",
888                            typed_dict_field.call1((str_schema.call0()?,))?,
889                        )?;
890                    }
891                    "set" => {
892                        typed_dict_fields.set_item(
893                            "items_serde_type",
894                            &pyany_serde_type_reference_schema_field,
895                        )?;
896                    }
897                    "tuple" => {
898                        typed_dict_fields.set_item(
899                            "item_serde_types",
900                            typed_dict_field.call1((
901                                list_schema.call1((&pyany_serde_type_reference_schema,))?,
902                            ))?,
903                        )?;
904                    }
905                    "typeddict" => {
906                        typed_dict_fields.set_item(
907                            "key_serde_type_dict",
908                            typed_dict_field.call1((dict_schema.call1((
909                                str_schema.call0()?,
910                                &pyany_serde_type_reference_schema,
911                            ))?,))?,
912                        )?;
913                    }
914                    "union" => {
915                        typed_dict_fields.set_item(
916                            "option_serde_types",
917                            typed_dict_field.call1((
918                                list_schema.call1((&pyany_serde_type_reference_schema,))?,
919                            ))?,
920                        )?;
921                        typed_dict_fields.set_item(
922                            "option_choice_fn_pkl",
923                            typed_dict_field.call1((str_schema.call0()?,))?,
924                        )?;
925                    }
926                    _ => (),
927                };
928                Ok(typed_dict_schema.call1((typed_dict_fields,))?)
929            })
930            .collect::<PyResult<Vec<_>>>()?;
931        let pyany_serde_type_union_schema = core_schema.call_method(
932            "union_schema",
933            (union_list,),
934            Some(&PyDict::from_sequence(
935                &vec![("ref", "pyany_serde_type_schema")].into_pyobject(py)?,
936            )?),
937        )?;
938
939        let pyany_serde_type_python_schema =
940            core_schema.call_method1("is_instance_schema", (PyAnySerdeType::type_object(py),))?;
941        let pyany_serde_type_json_or_python_schema = core_schema.call_method1(
942            "json_or_python_schema",
943            (
944                core_schema.call_method1(
945                    "chain_schema",
946                    (vec![
947                        core_schema.call_method1(
948                            "no_info_before_validator_function",
949                            (
950                                wrap_pyfunction!(check_for_unpickling, py)?,
951                                any_schema.call0()?,
952                            ),
953                        )?,
954                        pyany_serde_type_union_schema.clone(),
955                        core_schema.call_method1(
956                            "no_info_before_validator_function",
957                            (
958                                get_before_validator_fn(&handler, &schema_validator)?,
959                                &pyany_serde_type_python_schema,
960                            ),
961                        )?,
962                    ],),
963                )?,
964                pyany_serde_type_python_schema,
965            ),
966        )?;
967        core_schema.call_method(
968            "definitions_schema",
969            (&pyany_serde_type_json_or_python_schema,),
970            Some(&PyDict::from_sequence(
971                &vec![("definitions", vec![&pyany_serde_type_union_schema])].into_pyobject(py)?,
972            )?),
973        )
974    }
975
976    fn to_json(&self) -> PyResult<PyObject> {
977        Python::with_gil(|py| {
978            let data = PyDict::new(py);
979            data.set_item("type", self.to_string().to_ascii_lowercase())?;
980            if let PyAnySerdeType::DATACLASS {
981                clazz,
982                init_strategy,
983                field_serde_type_dict,
984            } = self
985            {
986                data.set_item(
987                    "dataclass_pkl",
988                    py.import("pickle")?
989                        .getattr("dumps")?
990                        .call1((clazz,))?
991                        .call_method0("hex")?,
992                )?;
993                data.set_item("init_strategy", init_strategy.to_json()?)?;
994                data.set_item(
995                    "field_serde_type_dict",
996                    field_serde_type_dict
997                        .iter()
998                        .map(|(key, field_serde_type)| Ok((key, field_serde_type.to_json()?)))
999                        .collect::<PyResult<BTreeMap<_, _>>>()?,
1000                )?;
1001            } else if let PyAnySerdeType::DICT {
1002                keys_serde_type,
1003                values_serde_type,
1004            } = self
1005            {
1006                data.set_item(
1007                    "keys_serde_type",
1008                    keys_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1009                )?;
1010                data.set_item(
1011                    "values_serde_type",
1012                    values_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1013                )?;
1014            } else if let PyAnySerdeType::LIST { items_serde_type } = self {
1015                data.set_item(
1016                    "items_serde_type",
1017                    items_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1018                )?;
1019            } else if let PyAnySerdeType::NUMPY { dtype, config } = self {
1020                data.set_item("dtype", dtype.to_string())?;
1021                data.set_item("config", config.to_json()?)?;
1022            } else if let PyAnySerdeType::OPTION { value_serde_type } = self {
1023                data.set_item(
1024                    "value_serde_type",
1025                    value_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1026                )?;
1027            } else if let PyAnySerdeType::PYTHONSERDE { python_serde } = self {
1028                data.set_item(
1029                    "python_serde_pkl",
1030                    py.import("pickle")?
1031                        .getattr("dumps")?
1032                        .call1((python_serde,))?
1033                        .call_method0("hex")?,
1034                )?;
1035            } else if let PyAnySerdeType::SET { items_serde_type } = self {
1036                data.set_item(
1037                    "items_serde_type",
1038                    items_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1039                )?;
1040            } else if let PyAnySerdeType::TUPLE { item_serde_types } = self {
1041                data.set_item(
1042                    "item_serde_types",
1043                    item_serde_types
1044                        .iter()
1045                        .map(|item_serde_type| item_serde_type.to_json())
1046                        .collect::<PyResult<Vec<_>>>()?,
1047                )?;
1048            } else if let PyAnySerdeType::TYPEDDICT {
1049                key_serde_type_dict,
1050            } = self
1051            {
1052                data.set_item(
1053                    "key_serde_type_dict",
1054                    key_serde_type_dict
1055                        .iter()
1056                        .map(|(key, field_serde_type)| Ok((key, field_serde_type.to_json()?)))
1057                        .collect::<PyResult<BTreeMap<_, _>>>()?,
1058                )?;
1059            } else if let PyAnySerdeType::UNION {
1060                option_serde_types,
1061                option_choice_fn,
1062            } = self
1063            {
1064                data.set_item(
1065                    "option_serde_types",
1066                    option_serde_types
1067                        .iter()
1068                        .map(|item_serde_type| item_serde_type.to_json())
1069                        .collect::<PyResult<Vec<_>>>()?,
1070                )?;
1071                data.set_item(
1072                    "option_choice_fn_pkl",
1073                    py.import("pickle")?
1074                        .getattr("dumps")?
1075                        .call1((option_choice_fn,))?
1076                        .call_method0("hex")?,
1077                )?;
1078            }
1079            Ok(data.into_any().unbind())
1080        })
1081    }
1082}