1use std::collections::BTreeMap;
2use std::env;
3use std::io;
4use std::io::Write;
5use std::str::FromStr;
6
7use num_traits::{FromPrimitive, ToPrimitive};
8use pyo3::exceptions::asyncio::InvalidStateError;
9use pyo3::exceptions::PyValueError;
10use pyo3::types::{PyBytes, PyCFunction, PyDict, PyFunction, PyTuple, PyType};
11use pyo3::{prelude::*, PyTypeInfo};
12use strum::{IntoEnumIterator, VariantNames};
13use strum_macros::Display;
14
15use crate::common::NumpyDtype;
16use crate::communication::{
17 append_bytes_vec, append_string_vec, append_usize_vec, retrieve_bytes, retrieve_string,
18 retrieve_usize,
19};
20use crate::pyany_serde_impl::{
21 numpy_check_for_unpickling, InitStrategy, NumpySerdeConfig, PickleableInitStrategy,
22 PickleableNumpySerdeConfig,
23};
24
/// Wrapper that gives a [`PyAnySerdeType`] `__getstate__`/`__setstate__` support.
///
/// The two `Option` layers mean different things:
/// * outer `None` — produced by the constructor when it receives no argument;
///   `__setstate__` always replaces it with `Some(..)`.
/// * inner `None` — the wrapped value itself is Python `None`.
#[pyclass]
#[derive(Clone)]
pub struct PickleablePyAnySerdeType(pub Option<Option<PyAnySerdeType>>);
29
30#[pymethods]
31impl PickleablePyAnySerdeType {
32 #[new]
34 #[pyo3(signature = (*args))]
35 fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
36 let vec_args = args.iter().collect::<Vec<_>>();
37 if vec_args.len() > 1 {
38 return Err(PyValueError::new_err(format!(
39 "PickleablePyAnySerde constructor takes 0 or 1 parameters, received {}",
40 args.as_any().repr()?.to_str()?
41 )));
42 }
43 if vec_args.len() == 1 {
44 Ok(PickleablePyAnySerdeType(Some(
45 vec_args[0].extract::<Option<PyAnySerdeType>>()?,
46 )))
47 } else {
48 Ok(PickleablePyAnySerdeType(None))
49 }
50 }
51
52 pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
54 let pyany_serde_type_option = self.0.as_ref().unwrap();
55 Ok(match pyany_serde_type_option {
56 Some(pyany_serde_type) => {
57 let mut option_bytes = vec![1];
58 let mut pyany_serde_type_bytes = match pyany_serde_type {
59 PyAnySerdeType::BOOL {} => vec![0],
60 PyAnySerdeType::BYTES {} => vec![1],
61 PyAnySerdeType::COMPLEX {} => vec![2],
62 PyAnySerdeType::DATACLASS {
63 clazz,
64 init_strategy,
65 field_serde_type_dict,
66 } => {
67 let mut bytes = vec![3];
68 append_bytes_vec(
69 &mut bytes,
70 &PickleableInitStrategy(Some(init_strategy.clone())).__getstate__()[..],
71 );
72 append_usize_vec(&mut bytes, field_serde_type_dict.len());
73 for (field, serde_type) in field_serde_type_dict.iter() {
74 append_string_vec(&mut bytes, field);
75 append_bytes_vec(
76 &mut bytes,
77 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
78 .__getstate__()?[..],
79 );
80 }
81 Python::with_gil::<_, PyResult<_>>(|py| {
82 let clazz_py_bytes = py
83 .import("pickle")?
84 .getattr("dumps")?
85 .call1((clazz,))?
86 .downcast_into::<PyBytes>()?;
87 append_bytes_vec(&mut bytes, clazz_py_bytes.as_bytes());
88 Ok(bytes)
89 })?
90 }
91 PyAnySerdeType::DICT {
92 keys_serde_type,
93 values_serde_type,
94 } => {
95 let mut bytes = vec![4];
96 Python::with_gil::<_, PyResult<_>>(|py| {
97 for py_serde_type in
98 vec![keys_serde_type, values_serde_type].into_iter()
99 {
100 let serde_type = py_serde_type.extract::<PyAnySerdeType>(py)?;
101 append_bytes_vec(
102 &mut bytes,
103 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
104 .__getstate__()?[..],
105 );
106 }
107 Ok(bytes)
108 })?
109 }
110 PyAnySerdeType::DYNAMIC {} => vec![5],
111 PyAnySerdeType::FLOAT {} => vec![6],
112 PyAnySerdeType::INT {} => vec![7],
113 PyAnySerdeType::LIST { items_serde_type } => {
114 let mut bytes = vec![8];
115 Python::with_gil::<_, PyResult<_>>(|py| {
116 let serde_type = items_serde_type.extract::<PyAnySerdeType>(py)?;
117 append_bytes_vec(
118 &mut bytes,
119 &PickleablePyAnySerdeType(Some(Some(serde_type))).__getstate__()?[..],
120 );
121 Ok(bytes)
122 })?
123 }
124 PyAnySerdeType::NUMPY { dtype, config } => {
125 let mut bytes = vec![9, dtype.to_u8().unwrap()];
126 append_bytes_vec(
127 &mut bytes,
128 &PickleableNumpySerdeConfig(Some(config.clone())).__getstate__()?[..],
129 );
130 bytes
131 }
132 PyAnySerdeType::OPTION { value_serde_type } => {
133 let mut bytes = vec![10];
134 Python::with_gil::<_, PyResult<_>>(|py| {
135 let serde_type = value_serde_type.extract::<PyAnySerdeType>(py)?;
136 append_bytes_vec(
137 &mut bytes,
138 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
139 .__getstate__()?[..],
140 );
141 Ok(bytes)
142 })?
143 }
144 PyAnySerdeType::PICKLE {} => vec![11],
145 PyAnySerdeType::PYTHONSERDE { python_serde } => {
146 let mut bytes = vec![12];
147 Python::with_gil::<_, PyResult<_>>(|py| {
148 let python_serde_py_bytes = py
149 .import("pickle")?
150 .getattr("dumps")?
151 .call1((python_serde,))?
152 .downcast_into::<PyBytes>()?;
153 append_bytes_vec(&mut bytes, python_serde_py_bytes.as_bytes());
154 Ok(bytes)
155 })?
156 }
157 PyAnySerdeType::SET { items_serde_type } => {
158 let mut bytes = vec![13];
159 Python::with_gil::<_, PyResult<_>>(|py| {
160 let serde_type = items_serde_type.extract::<PyAnySerdeType>(py)?;
161 append_bytes_vec(
162 &mut bytes,
163 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
164 .__getstate__()?[..],
165 );
166 Ok(bytes)
167 })?
168 }
169 PyAnySerdeType::STRING {} => vec![14],
170 PyAnySerdeType::TUPLE { item_serde_types } => {
171 let mut bytes = vec![15];
172 bytes.extend_from_slice(&item_serde_types.len().to_ne_bytes());
173 for serde_type in item_serde_types.iter() {
174 append_bytes_vec(
175 &mut bytes,
176 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
177 .__getstate__()?[..],
178 );
179 }
180 bytes
181 }
182 PyAnySerdeType::TYPEDDICT {
183 key_serde_type_dict,
184 } => {
185 let mut bytes = vec![16];
186 bytes.extend_from_slice(&key_serde_type_dict.len().to_ne_bytes());
187 for (key, serde_type) in key_serde_type_dict.iter() {
188 append_string_vec(&mut bytes, key);
189 append_bytes_vec(
190 &mut bytes,
191 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
192 .__getstate__()?[..],
193 );
194 }
195 bytes
196 }
197 PyAnySerdeType::UNION {
198 option_serde_types,
199 option_choice_fn,
200 } => {
201 let mut bytes = vec![17];
202 bytes.extend_from_slice(&option_serde_types.len().to_ne_bytes());
203 for serde_type in option_serde_types.iter() {
204 append_bytes_vec(
205 &mut bytes,
206 &PickleablePyAnySerdeType(Some(Some(serde_type.clone())))
207 .__getstate__()?[..],
208 );
209 }
210 Python::with_gil::<_, PyResult<_>>(|py| {
211 let option_choice_fn_py_bytes = py
212 .import("pickle")?
213 .getattr("dumps")?
214 .call1((option_choice_fn,))?
215 .downcast_into::<PyBytes>()?;
216 append_bytes_vec(&mut bytes, option_choice_fn_py_bytes.as_bytes());
217 Ok(bytes)
218 })?
219 }
220 };
221 option_bytes.append(&mut pyany_serde_type_bytes);
222 option_bytes
223 }
224 None => vec![0],
225 })
226 }
227
228 pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
229 let buf = &state[..];
230 let option_byte = state[0];
231 self.0 = Some(match option_byte {
232 0 => None,
233 1 => {
234 let type_byte = state[1];
235 let mut offset = 2;
236 Some(match type_byte {
237 0 => PyAnySerdeType::BOOL {},
238 1 => PyAnySerdeType::BYTES {},
239 2 => PyAnySerdeType::COMPLEX {},
240 3 => {
241 let init_strategy_bytes;
242 (init_strategy_bytes, offset) = retrieve_bytes(buf, offset)?;
243 let mut pickleable_init_strategy = PickleableInitStrategy(None);
244 pickleable_init_strategy.__setstate__(init_strategy_bytes.to_vec())?;
245 let n_fields;
246 (n_fields, offset) = retrieve_usize(buf, offset)?;
247 let mut field_serde_type_dict = BTreeMap::new();
248 for _ in 0..n_fields {
249 let field;
250 (field, offset) = retrieve_string(buf, offset)?;
251 let serde_type_bytes;
252 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
253 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
254 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
255 field_serde_type_dict
256 .insert(field, pickleable_serde_type.0.unwrap().unwrap());
257 }
258 Python::with_gil::<_, PyResult<_>>(|py| {
259 let clazz_bytes;
260 (clazz_bytes, offset) = retrieve_bytes(buf, offset)?;
261 let clazz = py
262 .import("pickle")?
263 .getattr("loads")?
264 .call1((PyBytes::new(py, clazz_bytes).into_pyobject(py)?,))?
265 .unbind();
266 Ok(PyAnySerdeType::DATACLASS {
267 clazz,
268 init_strategy: pickleable_init_strategy.0.unwrap(),
269 field_serde_type_dict,
270 })
271 })?
272 }
273 4 => Python::with_gil::<_, PyResult<_>>(|py| {
274 let keys_serde_type_bytes;
275 (keys_serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
276 let mut pickleable_keys_serde_type = PickleablePyAnySerdeType(None);
277 pickleable_keys_serde_type.__setstate__(keys_serde_type_bytes.to_vec())?;
278 let values_serde_type_bytes;
279 (values_serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
280 let mut pickleable_values_serde_type = PickleablePyAnySerdeType(None);
281 pickleable_values_serde_type
282 .__setstate__(values_serde_type_bytes.to_vec())?;
283 Ok(PyAnySerdeType::DICT {
284 keys_serde_type: Py::new(
285 py,
286 pickleable_keys_serde_type.0.unwrap().unwrap(),
287 )?,
288 values_serde_type: Py::new(
289 py,
290 pickleable_values_serde_type.0.unwrap().unwrap(),
291 )?,
292 })
293 })?,
294 5 => PyAnySerdeType::DYNAMIC {},
295 6 => PyAnySerdeType::FLOAT {},
296 7 => PyAnySerdeType::INT {},
297 8 => Python::with_gil::<_, PyResult<_>>(|py| {
298 let serde_type_bytes;
299 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
300 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
301 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
302 Ok(PyAnySerdeType::LIST {
303 items_serde_type: Py::new(
304 py,
305 pickleable_serde_type.0.unwrap().unwrap(),
306 )?,
307 })
308 })?,
309 9 => {
310 let dtype = NumpyDtype::from_u8(buf[offset]).unwrap();
311 offset += 1;
312 let numpy_serde_config_bytes;
313 (numpy_serde_config_bytes, _) = retrieve_bytes(buf, offset)?;
314 let mut pickleable_numpy_serde_config = PickleableNumpySerdeConfig(None);
315 pickleable_numpy_serde_config
316 .__setstate__(numpy_serde_config_bytes.to_vec())?;
317 PyAnySerdeType::NUMPY {
318 dtype,
319 config: pickleable_numpy_serde_config.0.unwrap(),
320 }
321 }
322 10 => Python::with_gil::<_, PyResult<_>>(|py| {
323 let serde_type_bytes;
324 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
325 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
326 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
327 Ok(PyAnySerdeType::OPTION {
328 value_serde_type: Py::new(
329 py,
330 pickleable_serde_type.0.unwrap().unwrap(),
331 )?,
332 })
333 })?,
334 11 => PyAnySerdeType::PICKLE {},
335 12 => Python::with_gil::<_, PyResult<_>>(|py| {
336 let python_serde_bytes;
337 (python_serde_bytes, offset) = retrieve_bytes(buf, offset)?;
338 let python_serde = py
339 .import("pickle")?
340 .getattr("loads")?
341 .call1((PyBytes::new(py, python_serde_bytes).into_pyobject(py)?,))?
342 .unbind();
343 Ok(PyAnySerdeType::PYTHONSERDE { python_serde })
344 })?,
345 13 => Python::with_gil::<_, PyResult<_>>(|py| {
346 let serde_type_bytes;
347 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
348 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
349 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
350 Ok(PyAnySerdeType::SET {
351 items_serde_type: Py::new(
352 py,
353 pickleable_serde_type.0.unwrap().unwrap(),
354 )?,
355 })
356 })?,
357 14 => PyAnySerdeType::STRING {},
358 15 => {
359 let n_items;
360 (n_items, offset) = retrieve_usize(buf, offset)?;
361 let mut item_serde_types = Vec::with_capacity(n_items);
362 for _ in 0..n_items {
363 let serde_type_bytes;
364 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
365 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
366 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
367 item_serde_types.push(pickleable_serde_type.0.unwrap().unwrap())
368 }
369 PyAnySerdeType::TUPLE { item_serde_types }
370 }
371 16 => {
372 let n_keys;
373 (n_keys, offset) = retrieve_usize(buf, offset)?;
374 let mut key_serde_type_dict = BTreeMap::new();
375 for _ in 0..n_keys {
376 let key;
377 (key, offset) = retrieve_string(buf, offset)?;
378 let serde_type_bytes;
379 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
380 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
381 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
382 key_serde_type_dict
383 .insert(key, pickleable_serde_type.0.unwrap().unwrap());
384 }
385 PyAnySerdeType::TYPEDDICT {
386 key_serde_type_dict,
387 }
388 }
389 17 => {
390 let n_options;
391 (n_options, offset) = retrieve_usize(buf, offset)?;
392 let mut option_serde_types = Vec::with_capacity(n_options);
393 for _ in 0..n_options {
394 let serde_type_bytes;
395 (serde_type_bytes, offset) = retrieve_bytes(buf, offset)?;
396 let mut pickleable_serde_type = PickleablePyAnySerdeType(None);
397 pickleable_serde_type.__setstate__(serde_type_bytes.to_vec())?;
398 option_serde_types.push(pickleable_serde_type.0.unwrap().unwrap())
399 }
400 Python::with_gil::<_, PyResult<_>>(|py| {
401 let option_choice_fn_bytes;
402 (option_choice_fn_bytes, offset) = retrieve_bytes(buf, offset)?;
403 let option_choice_fn = py.import("pickle")?.getattr("loads")?.call1(
404 (PyBytes::new(py, option_choice_fn_bytes).into_pyobject(py)?,),
405 )?;
406 Ok(PyAnySerdeType::UNION {
407 option_serde_types,
408 option_choice_fn: option_choice_fn
409 .downcast_into::<PyFunction>()?
410 .unbind(),
411 })
412 })?
413 }
414 v => Err(InvalidStateError::new_err(format!(
415 "Got invalid type byte for PyAnySerde: {v}"
416 )))?,
417 })
418 }
419 v => Err(InvalidStateError::new_err(format!(
420 "Got invalid option byte for PyAnySerdeType: {v}"
421 )))?,
422 });
423
424 Ok(())
425 }
426}
427
// Description of how a Python value should be serialized, exposed to Python as a
// class with one variant per supported kind.
//
// Facts visible in this file:
// * The lowercase variant name is the value of the `"type"` key in the JSON form
//   (`to_json` and the pydantic validator both lowercase it).
// * Container-like variants hold their element types either behind a Python
//   reference (`Py<PyAnySerdeType>`) or by value in a `Vec`/`BTreeMap`.
// * `clazz`, `python_serde` and `option_choice_fn` are arbitrary Python objects
//   that are pickled when this type is converted to JSON or to pickle state.
#[pyclass]
#[derive(Debug, Clone, Display, strum_macros::VariantNames)]
pub enum PyAnySerdeType {
    BOOL {},
    BYTES {},
    COMPLEX {},
    DATACLASS {
        // Python object stored pickled; presumably the dataclass type — TODO confirm.
        clazz: PyObject,
        init_strategy: InitStrategy,
        // Serde type per field name (sorted by name because of `BTreeMap`).
        field_serde_type_dict: BTreeMap<String, PyAnySerdeType>,
    },
    DICT {
        keys_serde_type: Py<PyAnySerdeType>,
        values_serde_type: Py<PyAnySerdeType>,
    },
    DYNAMIC {},
    FLOAT {},
    INT {},
    LIST {
        items_serde_type: Py<PyAnySerdeType>,
    },
    // `config` is optional from Python and defaults to the DYNAMIC config with no
    // preprocessor or postprocessor function.
    #[pyo3(constructor = (dtype, config = NumpySerdeConfig::DYNAMIC { preprocessor_fn: None, postprocessor_fn: None }))]
    NUMPY {
        dtype: NumpyDtype,
        config: NumpySerdeConfig,
    },
    OPTION {
        value_serde_type: Py<PyAnySerdeType>,
    },
    PICKLE {},
    PYTHONSERDE {
        // Arbitrary Python object, stored pickled in JSON and pickle state.
        python_serde: PyObject,
    },
    SET {
        items_serde_type: Py<PyAnySerdeType>,
    },
    STRING {},
    TUPLE {
        // One serde type per tuple position.
        item_serde_types: Vec<PyAnySerdeType>,
    },
    TYPEDDICT {
        // Serde type per key (sorted by key because of `BTreeMap`).
        key_serde_type_dict: BTreeMap<String, PyAnySerdeType>,
    },
    UNION {
        option_serde_types: Vec<PyAnySerdeType>,
        // Python function stored pickled; presumably selects which entry of
        // `option_serde_types` applies to a value — TODO confirm against the serde impl.
        option_choice_fn: Py<PyFunction>,
    },
}
476
477fn check_for_unpickling_aux<'py>(data: &Bound<'py, PyAny>) -> PyResult<bool> {
478 let pyany_serde_type_field = data
479 .get_item("type")?
480 .extract::<String>()?
481 .to_ascii_lowercase();
482 Ok(match pyany_serde_type_field.as_str() {
483 "dataclass" => true,
484 "dict" => {
485 check_for_unpickling_aux(&data.get_item("keys_serde_type")?)?
486 || check_for_unpickling_aux(&data.get_item("values_serde_type")?)?
487 }
488 "list" => check_for_unpickling_aux(&data.get_item("items_serde_type")?)?,
489 "numpy" => numpy_check_for_unpickling(&data.get_item("config")?)?,
490 "option" => check_for_unpickling_aux(&data.get_item("value_serde_type")?)?,
491 "pythonserde" => true,
492 "set" => check_for_unpickling_aux(&data.get_item("items_serde_type")?)?,
493 "tuple" => {
494 let mut has_unpickling = false;
495 for item_serde_type_data in data
496 .get_item("item_serde_types")?
497 .extract::<Vec<Bound<'_, PyAny>>>()?
498 .iter()
499 {
500 has_unpickling |= check_for_unpickling_aux(&item_serde_type_data)?;
501 }
502 has_unpickling
503 }
504 "typeddict" => {
505 let mut has_unpickling = false;
506 for (_, serde_type_data) in data
507 .get_item("key_serde_type_dict")?
508 .downcast_into::<PyDict>()?
509 .iter()
510 {
511 has_unpickling |= check_for_unpickling_aux(&serde_type_data)?;
512 }
513 has_unpickling
514 }
515 "union" => true,
516 _ => false,
517 })
518}
519
520#[pyfunction]
521fn check_for_unpickling<'py, 'a>(data: &'a Bound<'py, PyAny>) -> PyResult<&'a Bound<'py, PyAny>> {
522 let silent_mode = env::var("PYANY_SERDE_UNPICKLE_WITHOUT_PROMPT")
523 .map(|v| v.eq("1"))
524 .unwrap_or(false);
525 if !silent_mode && check_for_unpickling_aux(&data)? {
526 println!("WARNING: About to call unpickle on the hexadecimal-encoded binary contents of some config fields. If you do not trust the origins of this json, or you cannot otherwise verify the safety of this field's contents, you should not proceed.");
527 print!("Proceed? (y/N)\t");
528 io::stdout().flush()?;
529 let mut response = String::new();
530 io::stdin().read_line(&mut response).unwrap();
531 if !response.trim().eq_ignore_ascii_case("y") {
532 Err(PyValueError::new_err("Operation cancelled by user due to unpickling required to build config model from json"))?
533 } else {
534 println!("Continuing with execution. If you would like to ignore this warning in the future, set the environment variable PYANY_SERDE_UNPICKLE_WITHOUT_PROMPT to \"1\".")
535 }
536 }
537 Ok(data)
538}
539
540fn get_before_validator_fn<'py>(
541 _handler: &Bound<'py, PyAny>,
542 _schema_validator: &Bound<'py, PyAny>,
543) -> PyResult<Bound<'py, PyCFunction>> {
544 let _py = _handler.py();
545 let py_handler = _handler.clone().unbind();
546 let py_schema_validator = _schema_validator.clone().unbind();
547 let func = move |args: &Bound<'_, PyTuple>,
548 _kwargs: Option<&Bound<'_, PyDict>>|
549 -> PyResult<PyObject> {
550 let py = args.py();
552 let data = args.get_item(0)?;
553 let handler = py_handler.bind(py);
554 let schema_validator = py_schema_validator.bind(py);
555
556 let pyany_serde_type_field = data
558 .get_item("type")?
559 .extract::<String>()?
560 .to_ascii_lowercase();
561 let pyany_serde_type = match pyany_serde_type_field.as_str() {
562 "bool" => PyAnySerdeType::BOOL {},
563 "bytes" => PyAnySerdeType::BYTES {},
564 "complex" => PyAnySerdeType::COMPLEX {},
565 "dataclass" => {
566 let clazz_bytes_hex = data.get_item("dataclass_pkl")?.extract::<String>()?;
567 let clazz = py
568 .import("pickle")?
569 .getattr("loads")?
570 .call1((PyBytes::new(
571 py,
572 &hex::decode(clazz_bytes_hex.as_str()).map_err(|err| {
573 PyValueError::new_err(format!(
574 "dataclass_pkl could not be decoded from hex into bytes: {}",
575 err.to_string()
576 ))
577 })?,
578 ),))?
579 .unbind();
580 let init_strategy = schema_validator
581 .call1((handler
582 .call_method1("generate_schema", (InitStrategy::type_object(py),))?,))?
583 .call_method1("validate_python", (data.get_item("init_strategy")?,))?
584 .extract::<InitStrategy>()?;
585 let mut field_serde_type_dict = BTreeMap::new();
586 for (key, serde_type_data) in data
587 .get_item("field_serde_type_dict")?
588 .downcast_into::<PyDict>()?
589 .into_iter()
590 {
591 let key = key.extract::<String>()?;
592 let value = get_before_validator_fn(handler, schema_validator)?
593 .call1((serde_type_data,))?
594 .extract::<PyAnySerdeType>()?;
595 field_serde_type_dict.insert(key, value);
596 }
597 PyAnySerdeType::DATACLASS {
598 clazz,
599 init_strategy,
600 field_serde_type_dict,
601 }
602 }
603 "dict" => {
604 let keys_serde_type_data = data.get_item("keys_serde_type")?;
605 let keys_serde_type = get_before_validator_fn(handler, schema_validator)?
606 .call1((keys_serde_type_data,))?
607 .extract::<PyAnySerdeType>()?;
608 let values_serde_type_data = data.get_item("values_serde_type")?;
609 let values_serde_type = get_before_validator_fn(handler, schema_validator)?
610 .call1((values_serde_type_data,))?
611 .extract::<PyAnySerdeType>()?;
612 PyAnySerdeType::DICT {
613 keys_serde_type: Py::new(py, keys_serde_type)?,
614 values_serde_type: Py::new(py, values_serde_type)?,
615 }
616 }
617 "dynamic" => PyAnySerdeType::DYNAMIC {},
618 "float" => PyAnySerdeType::FLOAT {},
619 "int" => PyAnySerdeType::INT {},
620 "list" => {
621 let items_serde_type_data = data.get_item("items_serde_type")?;
622 let items_serde_type = get_before_validator_fn(handler, schema_validator)?
623 .call1((items_serde_type_data,))?
624 .extract::<PyAnySerdeType>()?;
625 PyAnySerdeType::LIST {
626 items_serde_type: Py::new(py, items_serde_type)?,
627 }
628 }
629 "numpy" => {
630 let dtype_string = data.get_item("dtype")?.extract::<String>()?;
631 let dtype = NumpyDtype::from_str(dtype_string.as_str()).map_err(|_| {
632 PyValueError::new_err(format!(
633 "dtype was provided as {dtype_string} which is not a valid dtype"
634 ))
635 })?;
636 let numpy_serde_config = schema_validator
637 .call1((handler
638 .call_method1("generate_schema", (NumpySerdeConfig::type_object(py),))?,))?
639 .call_method1("validate_python", (data.get_item("config")?,))?
640 .extract::<NumpySerdeConfig>()?;
641 PyAnySerdeType::NUMPY {
642 dtype,
643 config: numpy_serde_config,
644 }
645 }
646 "option" => {
647 let value_serde_type_data = data.get_item("value_serde_type")?;
648 let value_serde_type = get_before_validator_fn(handler, schema_validator)?
649 .call1((value_serde_type_data,))?
650 .extract::<PyAnySerdeType>()?;
651 PyAnySerdeType::OPTION {
652 value_serde_type: Py::new(py, value_serde_type)?,
653 }
654 }
655 "pickle" => PyAnySerdeType::PICKLE {},
656 "pythonserde" => {
657 let python_serde_bytes_hex =
658 data.get_item("python_serde_pkl")?.extract::<String>()?;
659 let python_serde = py
660 .import("pickle")?
661 .getattr("loads")?
662 .call1((PyBytes::new(
663 py,
664 &hex::decode(python_serde_bytes_hex.as_str()).map_err(|err| {
665 PyValueError::new_err(format!(
666 "python_serde_pkl could not be decoded from hex into bytes: {}",
667 err.to_string()
668 ))
669 })?,
670 ),))?
671 .unbind();
672 PyAnySerdeType::PYTHONSERDE { python_serde }
673 }
674 "set" => {
675 let items_serde_type_data = data.get_item("items_serde_type")?;
676 let items_serde_type = get_before_validator_fn(handler, schema_validator)?
677 .call1((items_serde_type_data,))?
678 .extract::<PyAnySerdeType>()?;
679 PyAnySerdeType::SET {
680 items_serde_type: Py::new(py, items_serde_type)?,
681 }
682 }
683 "string" => PyAnySerdeType::STRING {},
684 "tuple" => {
685 let item_serde_types_data = data
686 .get_item("item_serde_types")?
687 .extract::<Vec<Bound<'_, PyAny>>>()?;
688 let item_serde_types = item_serde_types_data
689 .iter()
690 .map(|item_serde_type_data| {
691 Ok(get_before_validator_fn(handler, schema_validator)?
692 .call1((item_serde_type_data,))?
693 .extract::<PyAnySerdeType>()?)
694 })
695 .collect::<PyResult<Vec<_>>>()?;
696 PyAnySerdeType::TUPLE { item_serde_types }
697 }
698 "typeddict" => {
699 let mut key_serde_type_dict = BTreeMap::new();
700 for (key, serde_type_data) in data
701 .get_item("key_serde_type_dict")?
702 .downcast_into::<PyDict>()?
703 .into_iter()
704 {
705 let key = key.extract::<String>()?;
706 let value = get_before_validator_fn(handler, schema_validator)?
707 .call1((serde_type_data,))?
708 .extract::<PyAnySerdeType>()?;
709 key_serde_type_dict.insert(key, value);
710 }
711 PyAnySerdeType::TYPEDDICT {
712 key_serde_type_dict,
713 }
714 }
715 "union" => {
716 let option_serde_types_data = data
717 .get_item("option_serde_types")?
718 .extract::<Vec<Bound<'_, PyAny>>>()?;
719 let option_serde_types = option_serde_types_data
720 .iter()
721 .map(|option_serde_type_data| {
722 Ok(get_before_validator_fn(handler, schema_validator)?
723 .call1((option_serde_type_data,))?
724 .extract::<PyAnySerdeType>()?)
725 })
726 .collect::<PyResult<Vec<_>>>()?;
727 let option_choice_fn_bytes_hex =
728 data.get_item("option_choice_fn_pkl")?.extract::<String>()?;
729 let option_choice_fn = py
730 .import("pickle")?
731 .getattr("loads")?
732 .call1((PyBytes::new(
733 py,
734 &hex::decode(option_choice_fn_bytes_hex.as_str()).map_err(|err| {
735 PyValueError::new_err(format!(
736 "option_choice_fn_pkl could not be decoded from hex into bytes: {}",
737 err.to_string()
738 ))
739 })?,
740 ),))?
741 .downcast_into::<PyFunction>()?
742 .unbind();
743 PyAnySerdeType::UNION {
744 option_serde_types,
745 option_choice_fn,
746 }
747 }
748 v => Err(PyValueError::new_err(format!("Unexpected type: {v}")))?,
749 };
750
751 Ok(pyany_serde_type.into_pyobject(py)?.into_any().unbind())
752 };
753 PyCFunction::new_closure(_py, None, None, func)
754}
755
756#[pymethods]
757impl PyAnySerdeType {
758 fn as_pickleable<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
759 Ok(PickleablePyAnySerdeType(Some(Some(self.clone())))
760 .into_pyobject(py)?
761 .into_any())
762 }
763
    /// Builds the pydantic-core schema used to validate this type.
    ///
    /// The schema is assembled entirely through dynamic calls into
    /// `pydantic_core.core_schema`:
    /// 1. For every variant name in `VARIANTS`, a typed-dict schema whose `type`
    ///    field must match the lowercase variant name exactly, plus the
    ///    variant-specific fields.
    /// 2. A union of those typed-dict schemas, registered under the reference
    ///    `"pyany_serde_type_schema"` so nested serde types can refer back to it.
    /// 3. A `json_or_python_schema` whose first argument chains the unpickling
    ///    confirmation, the union schema and the conversion function from
    ///    `get_before_validator_fn`, and whose second argument is an
    ///    `is_instance_schema` for `PyAnySerdeType`.
    /// 4. A `definitions_schema` wrapping the result with the union as its definition.
    ///
    /// `_source_type` is not used.
    #[classmethod]
    fn __get_pydantic_core_schema__<'py>(
        cls: &Bound<'py, PyType>,
        _source_type: Bound<'py, PyAny>,
        handler: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let py = cls.py();
        let generate_schema = handler.getattr("generate_schema")?;
        let pydantic_core = py.import("pydantic_core")?;
        let schema_validator = pydantic_core.getattr("SchemaValidator")?;
        let core_schema = pydantic_core.getattr("core_schema")?;

        // Schema factory callables looked up once and reused below.
        let str_schema = core_schema.getattr("str_schema")?;
        let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
        let list_schema = core_schema.getattr("list_schema")?;
        let dict_schema = core_schema.getattr("dict_schema")?;
        let any_schema = core_schema.getattr("any_schema")?;
        let typed_dict_field = core_schema.getattr("typed_dict_field")?;

        // Reference to the union schema defined further down; this is what lets
        // container variants contain further serde types.
        let pyany_serde_type_reference_schema = core_schema
            .call_method1("definition_reference_schema", ("pyany_serde_type_schema",))?;
        let pyany_serde_type_reference_schema_field =
            typed_dict_field.call1((&pyany_serde_type_reference_schema,))?;

        // One typed-dict schema per enum variant.
        let union_list = PyAnySerdeType::VARIANTS
            .iter()
            .map(|pyany_serde_type_variant| {
                let pyany_serde_type_field = pyany_serde_type_variant.to_ascii_lowercase();
                let typed_dict_fields = PyDict::new(py);
                // `type` must match `^<lowercase variant name>$`.
                typed_dict_fields.set_item(
                    "type",
                    typed_dict_field.call1((str_schema.call(
                        (),
                        Some(&PyDict::from_sequence(
                            &vec![(
                                "pattern",
                                vec![
                                    "^".to_owned(),
                                    pyany_serde_type_field.clone(),
                                    "$".to_owned(),
                                ]
                                .join("")
                                .into_pyobject(py)?
                                .into_any(),
                            )]
                            .into_pyobject(py)?,
                        )?),
                    )?,))?,
                )?;
                // Variant-specific fields; variants without extra data fall through
                // to the `_` arm and only carry `type`.
                match pyany_serde_type_field.as_str() {
                    "dataclass" => {
                        typed_dict_fields.set_item(
                            "dataclass_pkl",
                            typed_dict_field.call1((str_schema.call0()?,))?,
                        )?;
                        typed_dict_fields.set_item(
                            "init_strategy",
                            typed_dict_field.call1((
                                generate_schema.call1((InitStrategy::type_object(py),))?,
                            ))?,
                        )?;
                        typed_dict_fields.set_item(
                            "field_serde_type_dict",
                            typed_dict_field.call1((dict_schema.call1((
                                str_schema.call0()?,
                                &pyany_serde_type_reference_schema,
                            ))?,))?,
                        )?;
                    }
                    "dict" => {
                        typed_dict_fields.set_item(
                            "keys_serde_type",
                            &pyany_serde_type_reference_schema_field,
                        )?;
                        typed_dict_fields.set_item(
                            "values_serde_type",
                            &pyany_serde_type_reference_schema_field,
                        )?;
                    }
                    "list" => {
                        typed_dict_fields.set_item(
                            "items_serde_type",
                            &pyany_serde_type_reference_schema_field,
                        )?;
                    }
                    "numpy" => {
                        // `dtype` must be one of the `NumpyDtype` display names.
                        typed_dict_fields.set_item(
                            "dtype",
                            typed_dict_field.call1((str_schema.call(
                                (),
                                Some(&PyDict::from_sequence(
                                    &vec![(
                                        "pattern",
                                        vec![
                                            "^(".to_owned(),
                                            NumpyDtype::iter()
                                                .map(|dtype_str| dtype_str.to_string())
                                                .collect::<Vec<_>>()
                                                .join("|"),
                                            ")$".to_owned(),
                                        ]
                                        .join(""),
                                    )]
                                    .into_pyobject(py)?,
                                )?),
                            )?,))?,
                        )?;
                        typed_dict_fields.set_item(
                            "config",
                            typed_dict_field.call1((
                                generate_schema.call1((NumpySerdeConfig::type_object(py),))?,
                            ))?,
                        )?;
                    }
                    "option" => {
                        typed_dict_fields.set_item(
                            "value_serde_type",
                            &pyany_serde_type_reference_schema_field,
                        )?;
                    }
                    "pythonserde" => {
                        typed_dict_fields.set_item(
                            "python_serde_pkl",
                            typed_dict_field.call1((str_schema.call0()?,))?,
                        )?;
                    }
                    "set" => {
                        typed_dict_fields.set_item(
                            "items_serde_type",
                            &pyany_serde_type_reference_schema_field,
                        )?;
                    }
                    "tuple" => {
                        typed_dict_fields.set_item(
                            "item_serde_types",
                            typed_dict_field.call1((
                                list_schema.call1((&pyany_serde_type_reference_schema,))?,
                            ))?,
                        )?;
                    }
                    "typeddict" => {
                        typed_dict_fields.set_item(
                            "key_serde_type_dict",
                            typed_dict_field.call1((dict_schema.call1((
                                str_schema.call0()?,
                                &pyany_serde_type_reference_schema,
                            ))?,))?,
                        )?;
                    }
                    "union" => {
                        typed_dict_fields.set_item(
                            "option_serde_types",
                            typed_dict_field.call1((
                                list_schema.call1((&pyany_serde_type_reference_schema,))?,
                            ))?,
                        )?;
                        typed_dict_fields.set_item(
                            "option_choice_fn_pkl",
                            typed_dict_field.call1((str_schema.call0()?,))?,
                        )?;
                    }
                    _ => (),
                };
                Ok(typed_dict_schema.call1((typed_dict_fields,))?)
            })
            .collect::<PyResult<Vec<_>>>()?;
        // The `ref` keyword is what `definition_reference_schema` above points at.
        let pyany_serde_type_union_schema = core_schema.call_method(
            "union_schema",
            (union_list,),
            Some(&PyDict::from_sequence(
                &vec![("ref", "pyany_serde_type_schema")].into_pyobject(py)?,
            )?),
        )?;

        let pyany_serde_type_python_schema =
            core_schema.call_method1("is_instance_schema", (PyAnySerdeType::type_object(py),))?;
        // First argument: chain of (unpickling confirmation -> union shape check ->
        // conversion into a `PyAnySerdeType` instance). Second argument: plain
        // instance check.
        let pyany_serde_type_json_or_python_schema = core_schema.call_method1(
            "json_or_python_schema",
            (
                core_schema.call_method1(
                    "chain_schema",
                    (vec![
                        core_schema.call_method1(
                            "no_info_before_validator_function",
                            (
                                wrap_pyfunction!(check_for_unpickling, py)?,
                                any_schema.call0()?,
                            ),
                        )?,
                        pyany_serde_type_union_schema.clone(),
                        core_schema.call_method1(
                            "no_info_before_validator_function",
                            (
                                get_before_validator_fn(&handler, &schema_validator)?,
                                &pyany_serde_type_python_schema,
                            ),
                        )?,
                    ],),
                )?,
                pyany_serde_type_python_schema,
            ),
        )?;
        // Register the union schema as a definition so the reference resolves.
        core_schema.call_method(
            "definitions_schema",
            (&pyany_serde_type_json_or_python_schema,),
            Some(&PyDict::from_sequence(
                &vec![("definitions", vec![&pyany_serde_type_union_schema])].into_pyobject(py)?,
            )?),
        )
    }
975
976 fn to_json(&self) -> PyResult<PyObject> {
977 Python::with_gil(|py| {
978 let data = PyDict::new(py);
979 data.set_item("type", self.to_string().to_ascii_lowercase())?;
980 if let PyAnySerdeType::DATACLASS {
981 clazz,
982 init_strategy,
983 field_serde_type_dict,
984 } = self
985 {
986 data.set_item(
987 "dataclass_pkl",
988 py.import("pickle")?
989 .getattr("dumps")?
990 .call1((clazz,))?
991 .call_method0("hex")?,
992 )?;
993 data.set_item("init_strategy", init_strategy.to_json()?)?;
994 data.set_item(
995 "field_serde_type_dict",
996 field_serde_type_dict
997 .iter()
998 .map(|(key, field_serde_type)| Ok((key, field_serde_type.to_json()?)))
999 .collect::<PyResult<BTreeMap<_, _>>>()?,
1000 )?;
1001 } else if let PyAnySerdeType::DICT {
1002 keys_serde_type,
1003 values_serde_type,
1004 } = self
1005 {
1006 data.set_item(
1007 "keys_serde_type",
1008 keys_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1009 )?;
1010 data.set_item(
1011 "values_serde_type",
1012 values_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1013 )?;
1014 } else if let PyAnySerdeType::LIST { items_serde_type } = self {
1015 data.set_item(
1016 "items_serde_type",
1017 items_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1018 )?;
1019 } else if let PyAnySerdeType::NUMPY { dtype, config } = self {
1020 data.set_item("dtype", dtype.to_string())?;
1021 data.set_item("config", config.to_json()?)?;
1022 } else if let PyAnySerdeType::OPTION { value_serde_type } = self {
1023 data.set_item(
1024 "value_serde_type",
1025 value_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1026 )?;
1027 } else if let PyAnySerdeType::PYTHONSERDE { python_serde } = self {
1028 data.set_item(
1029 "python_serde_pkl",
1030 py.import("pickle")?
1031 .getattr("dumps")?
1032 .call1((python_serde,))?
1033 .call_method0("hex")?,
1034 )?;
1035 } else if let PyAnySerdeType::SET { items_serde_type } = self {
1036 data.set_item(
1037 "items_serde_type",
1038 items_serde_type.extract::<PyAnySerdeType>(py)?.to_json()?,
1039 )?;
1040 } else if let PyAnySerdeType::TUPLE { item_serde_types } = self {
1041 data.set_item(
1042 "item_serde_types",
1043 item_serde_types
1044 .iter()
1045 .map(|item_serde_type| item_serde_type.to_json())
1046 .collect::<PyResult<Vec<_>>>()?,
1047 )?;
1048 } else if let PyAnySerdeType::TYPEDDICT {
1049 key_serde_type_dict,
1050 } = self
1051 {
1052 data.set_item(
1053 "key_serde_type_dict",
1054 key_serde_type_dict
1055 .iter()
1056 .map(|(key, field_serde_type)| Ok((key, field_serde_type.to_json()?)))
1057 .collect::<PyResult<BTreeMap<_, _>>>()?,
1058 )?;
1059 } else if let PyAnySerdeType::UNION {
1060 option_serde_types,
1061 option_choice_fn,
1062 } = self
1063 {
1064 data.set_item(
1065 "option_serde_types",
1066 option_serde_types
1067 .iter()
1068 .map(|item_serde_type| item_serde_type.to_json())
1069 .collect::<PyResult<Vec<_>>>()?,
1070 )?;
1071 data.set_item(
1072 "option_choice_fn_pkl",
1073 py.import("pickle")?
1074 .getattr("dumps")?
1075 .call1((option_choice_fn,))?
1076 .call_method0("hex")?,
1077 )?;
1078 }
1079 Ok(data.into_any().unbind())
1080 })
1081 }
1082}