pyany_serde/pyany_serde_impl/
numpy_serde.rs

1use std::env;
2
3use bytemuck::{cast_slice, AnyBitPattern, NoUninit};
4use numpy::ndarray::ArrayD;
5use numpy::{Element, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
6use numpy::{IntoPyArray, PyArray};
7use pyo3::exceptions::asyncio::InvalidStateError;
8use pyo3::exceptions::PyValueError;
9use pyo3::sync::GILOnceCell;
10use pyo3::types::{PyBytes, PyCFunction, PyDict, PyList, PyTuple, PyType};
11use pyo3::{intern, prelude::*, PyTypeInfo};
12use strum_macros::Display;
13
14use crate::communication::{
15    append_bool_vec, append_bytes_vec, append_usize, append_usize_vec, retrieve_bool,
16    retrieve_usize,
17};
18use crate::{
19    common::{get_bytes_to_alignment, NumpyDtype},
20    communication::{append_bytes, retrieve_bytes},
21    PyAnySerde,
22};
23
24fn append_usize_option_vec(v: &mut Vec<u8>, val_option: &Option<usize>) {
25    if let Some(val) = val_option {
26        append_bool_vec(v, true);
27        append_usize_vec(v, *val);
28    } else {
29        append_bool_vec(v, false);
30    }
31}
32
33fn retrieve_usize_option(buf: &[u8], mut offset: usize) -> PyResult<(Option<usize>, usize)> {
34    let has_val;
35    (has_val, offset) = retrieve_bool(buf, offset)?;
36    if has_val {
37        let val;
38        (val, offset) = retrieve_usize(buf, offset)?;
39        Ok((Some(val), offset))
40    } else {
41        Ok((None, offset))
42    }
43}
44
45fn append_python_pkl_option_vec(v: &mut Vec<u8>, obj_option: &Option<PyObject>) -> PyResult<()> {
46    if let Some(obj) = obj_option {
47        append_bool_vec(v, true);
48        Python::with_gil::<_, PyResult<_>>(|py| {
49            let preprocessor_fn_py_bytes = py
50                .import("pickle")?
51                .getattr("dumps")?
52                .call1((obj,))?
53                .downcast_into::<PyBytes>()?;
54            append_bytes_vec(v, preprocessor_fn_py_bytes.as_bytes());
55            Ok(())
56        })?;
57    } else {
58        append_bool_vec(v, false);
59    }
60    Ok(())
61}
62
63fn retrieve_python_pkl_option(
64    buf: &[u8],
65    mut offset: usize,
66) -> PyResult<(Option<PyObject>, usize)> {
67    let has_obj;
68    (has_obj, offset) = retrieve_bool(buf, offset)?;
69    if has_obj {
70        Python::with_gil::<_, PyResult<_>>(|py| {
71            let obj_bytes;
72            (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
73            Ok((
74                Some(
75                    py.import("pickle")?
76                        .getattr("loads")?
77                        .call1((PyBytes::new(py, obj_bytes).into_pyobject(py)?,))?
78                        .unbind(),
79                ),
80                offset,
81            ))
82        })
83    } else {
84        Ok((None, offset))
85    }
86}
87
88#[pyclass]
89#[derive(Clone)]
90pub struct PickleableNumpySerdeConfig(pub Option<NumpySerdeConfig>);
91
92#[pymethods]
93impl PickleableNumpySerdeConfig {
94    #[new]
95    #[pyo3(signature = (*args))]
96    fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
97        let vec_args = args.iter().collect::<Vec<_>>();
98        if vec_args.len() > 1 {
99            return Err(PyValueError::new_err(format!(
100                "PickleableNumpySerdeConfig constructor takes 0 or 1 parameters, received {}",
101                args.as_any().repr()?.to_str()?
102            )));
103        }
104        if vec_args.len() == 1 {
105            Ok(PickleableNumpySerdeConfig(
106                vec_args[0].extract::<Option<NumpySerdeConfig>>()?,
107            ))
108        } else {
109            Ok(PickleableNumpySerdeConfig(None))
110        }
111    }
112    pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
113        Ok(match self.0.as_ref().unwrap() {
114            NumpySerdeConfig::DYNAMIC {
115                preprocessor_fn,
116                postprocessor_fn,
117            } => {
118                let mut bytes = vec![0];
119                append_python_pkl_option_vec(&mut bytes, preprocessor_fn)?;
120                append_python_pkl_option_vec(&mut bytes, postprocessor_fn)?;
121                bytes
122            }
123            NumpySerdeConfig::STATIC {
124                preprocessor_fn,
125                postprocessor_fn,
126                shape,
127                allocation_pool_min_size,
128                allocation_pool_max_size,
129                allocation_pool_warning_size,
130            } => {
131                let mut bytes = vec![1];
132                append_python_pkl_option_vec(&mut bytes, preprocessor_fn)?;
133                append_python_pkl_option_vec(&mut bytes, postprocessor_fn)?;
134                append_usize_vec(&mut bytes, shape.len());
135                for &dim in shape.iter() {
136                    append_usize_vec(&mut bytes, dim);
137                }
138                append_usize_vec(&mut bytes, *allocation_pool_min_size);
139                append_usize_option_vec(&mut bytes, allocation_pool_max_size);
140                append_usize_option_vec(&mut bytes, allocation_pool_warning_size);
141                bytes
142            }
143        })
144    }
145    pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
146        let buf = &state[..];
147        let type_byte = buf[0];
148        let mut offset = 1;
149        self.0 = Some(match type_byte {
150            0 => {
151                let preprocessor_fn;
152                (preprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
153                let postprocessor_fn;
154                (postprocessor_fn, _) = retrieve_python_pkl_option(buf, offset)?;
155                NumpySerdeConfig::DYNAMIC {
156                    preprocessor_fn,
157                    postprocessor_fn,
158                }
159            }
160            1 => {
161                let preprocessor_fn;
162                (preprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
163                let postprocessor_fn;
164                (postprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
165                let shape_len;
166                (shape_len, offset) = retrieve_usize(buf, offset)?;
167                let mut shape = Vec::with_capacity(shape_len);
168                for _ in 0..shape_len {
169                    let dim;
170                    (dim, offset) = retrieve_usize(buf, offset)?;
171                    shape.push(dim);
172                }
173                let allocation_pool_min_size;
174                (allocation_pool_min_size, offset) = retrieve_usize(buf, offset)?;
175                let allocation_pool_max_size;
176                (allocation_pool_max_size, _) = retrieve_usize_option(buf, offset)?;
177                let allocation_pool_warning_size;
178                (allocation_pool_warning_size, _) = retrieve_usize_option(buf, offset)?;
179                NumpySerdeConfig::STATIC {
180                    preprocessor_fn,
181                    postprocessor_fn,
182                    shape,
183                    allocation_pool_min_size,
184                    allocation_pool_max_size,
185                    allocation_pool_warning_size,
186                }
187            }
188            v => Err(InvalidStateError::new_err(format!(
189                "Got invalid type byte for NumpySerdeConfig: {v}"
190            )))?,
191        });
192        Ok(())
193    }
194}
195
196// TODO: remove preprocessor and postprocessor fns
197#[pyclass]
198#[derive(Debug, Clone, Display)]
199pub enum NumpySerdeConfig {
200    #[pyo3(constructor = (preprocessor_fn = None, postprocessor_fn = None))]
201    DYNAMIC {
202        preprocessor_fn: Option<PyObject>,
203        postprocessor_fn: Option<PyObject>,
204    },
205    #[pyo3(constructor = (shape, preprocessor_fn = None, postprocessor_fn = None, allocation_pool_min_size = 0, allocation_pool_max_size = None, allocation_pool_warning_size = Some(10000)))]
206    STATIC {
207        shape: Vec<usize>,
208        preprocessor_fn: Option<PyObject>,
209        postprocessor_fn: Option<PyObject>,
210        allocation_pool_min_size: usize,
211        allocation_pool_max_size: Option<usize>,
212        allocation_pool_warning_size: Option<usize>,
213    },
214}
215
// Builds the list of pydantic schemas for the given `NumpySerdeConfig` variants by calling
// `$handler.generate_schema(<variant class>)` once per variant, in the order given.
//
// Any Python error raised while generating a schema is propagated with `?` from the enclosing
// function; the macro itself evaluates to `Ok(Vec<_>)`.
macro_rules! create_union {
    ($handler:expr, $py:expr, $($type:ident),+) => {{
        let variant_schemas = vec![
            $(
                $handler.call_method1(
                    "generate_schema",
                    (paste::paste! { [<NumpySerdeConfig_ $type>]::type_object($py) },)
                )?
            ),+
        ];
        Ok::<_, PyErr>(variant_schemas)
    }};
}
230
231pub fn check_for_unpickling<'py>(data: &Bound<'py, PyAny>) -> PyResult<bool> {
232    let preprocessor_fn_hex_option = data
233        .get_item("preprocessor_fn_pkl")?
234        .extract::<Option<String>>()?;
235    let postprocessor_fn_hex_option = data
236        .get_item("postprocessor_fn_pkl")?
237        .extract::<Option<String>>()?;
238    Ok(preprocessor_fn_hex_option.is_some() || postprocessor_fn_hex_option.is_some())
239}
240
241fn get_enum_subclass_before_validator_fn<'py>(
242    cls: &Bound<'py, PyType>,
243) -> PyResult<Bound<'py, PyCFunction>> {
244    let _py = cls.py();
245    let py_cls = cls.clone().unbind();
246    let func = move |args: &Bound<'_, PyTuple>,
247                     _kwargs: Option<&Bound<'_, PyDict>>|
248          -> PyResult<PyObject> {
249        let py = args.py();
250        let data = args.get_item(0)?;
251        let cls = py_cls.bind(py);
252        let preprocessor_fn_hex_option = data
253            .get_item("preprocessor_fn_pkl")?
254            .extract::<Option<String>>()?;
255        let preprocessor_fn_option = preprocessor_fn_hex_option
256            .map(|preprocessor_fn_hex| {
257                Ok::<_, PyErr>(
258                    py.import("pickle")?
259                        .getattr("loads")?
260                        .call1((PyBytes::new(
261                            py,
262                            &hex::decode(preprocessor_fn_hex.as_str()).map_err(|err| {
263                                PyValueError::new_err(format!(
264                                    "python_serde_pkl could not be decoded from hex into bytes: {}",
265                                    err.to_string()
266                                ))
267                            })?,
268                        ),))?
269                        .unbind(),
270                )
271            })
272            .transpose()?;
273        let postprocessor_fn_hex_option = data
274            .get_item("postprocessor_fn_pkl")?
275            .extract::<Option<String>>()?;
276        let postprocessor_fn_option = postprocessor_fn_hex_option
277            .map(|postprocessor_fn_hex| {
278                Ok::<_, PyErr>(
279                    py.import("pickle")?
280                        .getattr("loads")?
281                        .call1((PyBytes::new(
282                            py,
283                            &hex::decode(postprocessor_fn_hex.as_str()).map_err(|err| {
284                                PyValueError::new_err(format!(
285                                    "python_serde_pkl could not be decoded from hex into bytes: {}",
286                                    err.to_string()
287                                ))
288                            })?,
289                        ),))?
290                        .unbind(),
291                )
292            })
293            .transpose()?;
294        if cls.eq(NumpySerdeConfig_DYNAMIC::type_object(py))? {
295            Ok(NumpySerdeConfig::DYNAMIC {
296                preprocessor_fn: preprocessor_fn_option,
297                postprocessor_fn: postprocessor_fn_option,
298            }
299            .into_pyobject(py)?
300            .into_any()
301            .unbind())
302        } else if cls.eq(NumpySerdeConfig_STATIC::type_object(py))? {
303            let shape = data.get_item("shape")?.extract::<Vec<usize>>()?;
304            let allocation_pool_min_size = data
305                .get_item("allocation_pool_min_size")?
306                .extract::<usize>()?;
307            let allocation_pool_max_size = data
308                .get_item("allocation_pool_max_size")?
309                .extract::<Option<usize>>()?;
310            let allocation_pool_warning_size = data
311                .get_item("allocation_pool_warning_size")?
312                .extract::<Option<usize>>()?;
313            if allocation_pool_max_size.is_some()
314                && allocation_pool_min_size > allocation_pool_max_size.unwrap()
315            {
316                Err(PyValueError::new_err(format!(
317                    "Validation error: allocation_pool_min_size ({}) cannot be greater than allocation_pool_max_size ({})", allocation_pool_min_size, allocation_pool_max_size.unwrap()
318                )))?
319            }
320            Ok(NumpySerdeConfig::STATIC {
321                preprocessor_fn: preprocessor_fn_option,
322                postprocessor_fn: postprocessor_fn_option,
323                shape,
324                allocation_pool_min_size,
325                allocation_pool_max_size,
326                allocation_pool_warning_size,
327            }
328            .into_pyobject(py)?
329            .into_any()
330            .unbind())
331        } else {
332            Err(PyValueError::new_err(format!(
333                "Unexpected class: {}",
334                cls.repr()?.to_str()?
335            )))
336        }
337    };
338    PyCFunction::new_closure(_py, None, None, func)
339}
340
341fn get_enum_subclass_typed_dict_schema<'py>(
342    cls: &Bound<'py, PyType>,
343    core_schema: &Bound<'py, PyAny>,
344) -> PyResult<Bound<'py, PyAny>> {
345    let py = cls.py();
346    let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
347    let typed_dict_field = core_schema.getattr("typed_dict_field")?;
348    let int_schema = core_schema.getattr("int_schema")?;
349    let str_schema = core_schema.getattr("str_schema")?;
350    let list_schema = core_schema.getattr("list_schema")?;
351    let nullable_schema = core_schema.getattr("nullable_schema")?;
352    let cls_name = cls.name()?.to_string();
353    let (_, enum_subclass) = cls_name.split_once("_").unwrap();
354    let typed_dict_fields = PyDict::new(py);
355    typed_dict_fields.set_item(
356        "type",
357        typed_dict_field.call1((str_schema.call(
358            (),
359            Some(&PyDict::from_sequence(
360                &vec![(
361                    "pattern",
362                    vec![
363                        "^".to_owned(),
364                        enum_subclass.to_ascii_lowercase(),
365                        "$".to_owned(),
366                    ]
367                    .join("")
368                    .into_pyobject(py)?
369                    .into_any(),
370                )]
371                .into_pyobject(py)?,
372            )?),
373        )?,))?,
374    )?;
375    typed_dict_fields.set_item(
376        "preprocessor_fn_pkl",
377        typed_dict_field.call1((nullable_schema.call1((str_schema.call0()?,))?,))?,
378    )?;
379    typed_dict_fields.set_item(
380        "postprocessor_fn_pkl",
381        typed_dict_field.call1((nullable_schema.call1((str_schema.call0()?,))?,))?,
382    )?;
383    if cls.eq(NumpySerdeConfig_STATIC::type_object(py))? {
384        typed_dict_fields.set_item(
385            "shape",
386            typed_dict_field.call1((list_schema.call1((int_schema.call(
387                (),
388                Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
389            )?,))?,))?,
390        )?;
391        typed_dict_fields.set_item(
392            "allocation_pool_min_size",
393            typed_dict_field.call1((int_schema.call(
394                (),
395                Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
396            )?,))?,
397        )?;
398        typed_dict_fields.set_item(
399            "allocation_pool_max_size",
400            typed_dict_field.call1((nullable_schema.call1((int_schema.call(
401                (),
402                Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
403            )?,))?,))?,
404        )?;
405        typed_dict_fields.set_item(
406            "allocation_pool_warning_size",
407            typed_dict_field.call1((nullable_schema.call1((int_schema.call(
408                (),
409                Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
410            )?,))?,))?,
411        )?;
412    }
413    typed_dict_schema.call1((typed_dict_fields,))
414}
415
416#[pymethods]
417impl NumpySerdeConfig {
418    // pydantic methods
419    #[classmethod]
420    fn __get_pydantic_core_schema__<'py>(
421        cls: &Bound<'py, PyType>,
422        _source_type: Bound<'py, PyAny>,
423        handler: Bound<'py, PyAny>,
424    ) -> PyResult<Bound<'py, PyAny>> {
425        let py = cls.py();
426        let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
427        if cls.eq(NumpySerdeConfig::type_object(py))? {
428            let union_list = create_union!(handler, py, DYNAMIC, STATIC)?;
429            return core_schema.call_method1("union_schema", (union_list,));
430        }
431        let python_schema = core_schema.getattr("is_instance_schema")?.call1((cls,))?;
432        core_schema.getattr("json_or_python_schema")?.call1((
433            core_schema.getattr("chain_schema")?.call1((vec![
434                get_enum_subclass_typed_dict_schema(cls, &core_schema)?,
435                core_schema
436                    .getattr("no_info_before_validator_function")?
437                    .call1((get_enum_subclass_before_validator_fn(cls)?, &python_schema))?,
438            ],))?,
439            python_schema,
440        ))
441    }
442
443    pub fn to_json(&self) -> PyResult<PyObject> {
444        Python::with_gil(|py| {
445            let data = PyDict::new(py);
446            data.set_item("type", self.to_string().to_ascii_lowercase())?;
447            match self {
448                NumpySerdeConfig::DYNAMIC {
449                    preprocessor_fn,
450                    postprocessor_fn,
451                } => {
452                    let preprocessor_fn_pkl = preprocessor_fn
453                        .as_ref()
454                        .map(|preprocessor_fn| {
455                            Ok::<_, PyErr>(
456                                py.import("pickle")?
457                                    .getattr("dumps")?
458                                    .call1((preprocessor_fn,))?
459                                    .call_method0("hex")?,
460                            )
461                        })
462                        .transpose()?;
463                    data.set_item("preprocessor_fn_pkl", preprocessor_fn_pkl)?;
464                    let postprocessor_fn_pkl = postprocessor_fn
465                        .as_ref()
466                        .map(|postprocessor_fn| {
467                            Ok::<_, PyErr>(
468                                py.import("pickle")?
469                                    .getattr("dumps")?
470                                    .call1((postprocessor_fn,))?
471                                    .call_method0("hex")?,
472                            )
473                        })
474                        .transpose()?;
475                    data.set_item("postprocessor_fn_pkl", postprocessor_fn_pkl)?;
476                }
477                NumpySerdeConfig::STATIC {
478                    preprocessor_fn,
479                    postprocessor_fn,
480                    shape,
481                    allocation_pool_min_size,
482                    allocation_pool_max_size,
483                    allocation_pool_warning_size,
484                } => {
485                    let preprocessor_fn_pkl = preprocessor_fn
486                        .as_ref()
487                        .map(|preprocessor_fn| {
488                            Ok::<_, PyErr>(
489                                py.import("pickle")?
490                                    .getattr("dumps")?
491                                    .call1((preprocessor_fn,))?
492                                    .call_method0("hex")?,
493                            )
494                        })
495                        .transpose()?;
496                    data.set_item("preprocessor_fn_pkl", preprocessor_fn_pkl)?;
497                    let postprocessor_fn_pkl = postprocessor_fn
498                        .as_ref()
499                        .map(|postprocessor_fn| {
500                            Ok::<_, PyErr>(
501                                py.import("pickle")?
502                                    .getattr("dumps")?
503                                    .call1((postprocessor_fn,))?
504                                    .call_method0("hex")?,
505                            )
506                        })
507                        .transpose()?;
508                    data.set_item("postprocessor_fn_pkl", postprocessor_fn_pkl)?;
509                    data.set_item("shape", shape)?;
510                    data.set_item("allocation_pool_min_size", allocation_pool_min_size)?;
511                    data.set_item("allocation_pool_max_size", allocation_pool_max_size)?;
512                    data.set_item("allocation_pool_warning_size", allocation_pool_warning_size)?;
513                }
514            }
515            Ok(data.into_any().unbind())
516        })
517    }
518}
519
520#[derive(Clone)]
521pub struct NumpySerde<T: Element> {
522    pub config: NumpySerdeConfig,
523    pub allocation_pool: Vec<Py<PyArrayDyn<T>>>,
524}
525
526impl<T: Element + AnyBitPattern + NoUninit> NumpySerde<T> {
527    pub fn append_inner<'py>(
528        &mut self,
529        buf: &mut [u8],
530        mut offset: usize,
531        array: &Bound<'py, PyArrayDyn<T>>,
532    ) -> PyResult<usize> {
533        match &self.config {
534            NumpySerdeConfig::DYNAMIC { .. } => {
535                let shape = array.shape();
536                offset = append_usize(buf, offset, shape.len());
537                for &dim in shape.iter() {
538                    offset = append_usize(buf, offset, dim);
539                }
540                let obj_vec = array.to_vec()?;
541                offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
542                offset = append_bytes(buf, offset, cast_slice::<T, u8>(&obj_vec));
543            }
544            NumpySerdeConfig::STATIC { .. } => {
545                let obj_vec = array.to_vec()?;
546                offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
547                offset = append_bytes(buf, offset, cast_slice::<T, u8>(&obj_vec));
548            }
549        }
550        Ok(offset)
551    }
552
553    fn append_inner_vec<'py>(
554        &mut self,
555        v: &mut Vec<u8>,
556        start_addr: Option<usize>,
557        array: &Bound<'py, PyArrayDyn<T>>,
558    ) -> PyResult<()> {
559        let Some(start_addr) = start_addr else {
560            Err(InvalidStateError::new_err("Tried to serialize numpy data, but there was no start_addr provided so there's no way to know how to align the data. (was this called from inside a preprocessor function?)"))?
561        };
562        match &self.config {
563            NumpySerdeConfig::DYNAMIC { .. } => {
564                let shape = array.shape();
565                append_usize_vec(v, shape.len());
566                for &dim in shape.iter() {
567                    append_usize_vec(v, dim);
568                }
569                let obj_vec = array.to_vec()?;
570                v.append(&mut vec![
571                    0;
572                    get_bytes_to_alignment::<T>(start_addr + v.len())
573                ]);
574                append_bytes_vec(v, cast_slice::<T, u8>(&obj_vec));
575            }
576            NumpySerdeConfig::STATIC { .. } => {
577                let obj_vec = array.to_vec()?;
578                v.append(&mut vec![
579                    0;
580                    get_bytes_to_alignment::<T>(start_addr + v.len())
581                ]);
582                append_bytes_vec(v, cast_slice::<T, u8>(&obj_vec));
583            }
584        }
585        Ok(())
586    }
587
588    pub fn retrieve_inner<'py>(
589        &mut self,
590        py: Python<'py>,
591        buf: &[u8],
592        mut offset: usize,
593    ) -> PyResult<(Bound<'py, PyArrayDyn<T>>, usize)> {
594        let py_array = match &self.config {
595            NumpySerdeConfig::DYNAMIC { .. } => {
596                let shape_len;
597                (shape_len, offset) = retrieve_usize(buf, offset)?;
598                let mut shape = Vec::with_capacity(shape_len);
599                for _ in 0..shape_len {
600                    let dim;
601                    (dim, offset) = retrieve_usize(buf, offset)?;
602                    shape.push(dim);
603                }
604                offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
605                let obj_bytes;
606                (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
607                let array_vec = cast_slice::<u8, T>(obj_bytes).to_vec();
608                ArrayD::from_shape_vec(shape, array_vec)
609                    .map_err(|err| {
610                        InvalidStateError::new_err(format!(
611                            "Failed create Numpy array of T from shape and Vec<T>: {}",
612                            err
613                        ))
614                    })?
615                    .into_pyarray(py)
616            }
617            NumpySerdeConfig::STATIC {
618                shape,
619                allocation_pool_min_size,
620                allocation_pool_max_size,
621                allocation_pool_warning_size,
622                ..
623            } => {
624                offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
625                let obj_bytes;
626                (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
627                let array_vec = cast_slice::<u8, T>(obj_bytes).to_vec();
628                let py_array;
629                if allocation_pool_max_size.is_none() || allocation_pool_max_size.unwrap() > 0 {
630                    // Take two random elements from the pool
631                    let pool_size = self.allocation_pool.len();
632                    let idx1 = fastrand::usize(..pool_size);
633                    let idx2 = fastrand::usize(..pool_size);
634                    let e1 = &self.allocation_pool[idx1];
635                    let e2 = &self.allocation_pool[idx2];
636                    let e1_free = e1.get_refcnt(py) == 1;
637                    let e2_free = e2.get_refcnt(py) == 1;
638                    if e1_free && e2_free {
639                        py_array = e1.clone_ref(py).into_bound(py);
640                        if self.allocation_pool.len() > *allocation_pool_min_size {
641                            self.allocation_pool.swap_remove(idx2);
642                        }
643                    } else if e1_free {
644                        py_array = e1.clone_ref(py).into_bound(py);
645                    } else if e2_free {
646                        py_array = e2.clone_ref(py).into_bound(py);
647                    } else {
648                        let arr: Bound<'_, PyArray<T, _>> =
649                            unsafe { PyArrayDyn::new(py, &shape[..], false) };
650                        if allocation_pool_max_size.is_none()
651                            || self.allocation_pool.len() < allocation_pool_max_size.unwrap()
652                        {
653                            self.allocation_pool.push(arr.clone().unbind());
654                        }
655                        py_array = arr;
656                        if let Some(allocation_pool_warning_size) = allocation_pool_warning_size {
657                            if pool_size > *allocation_pool_warning_size {
658                                if pool_size % 100 == 0 {
659                                    let recursion_depth = env::var(
660                                        "PYANY_SERDE_NUMPY_ALLOCATION_WARNING_RECUSION_DEPTH",
661                                    )
662                                    .map(|v| v.parse::<usize>().unwrap_or(5))
663                                    .unwrap_or(5);
664                                    println!("Warning: the allocation pool for this Numpy PyAny serde instance is currently {pool_size}, which is larger than the warning limit set ({allocation_pool_warning_size}). Here is a random element from the allocation pool and a dict of the types of its referrers (and the referrers of those referrers, etc, up to the recursion depth set by PYANY_SERDE_NUMPY_ALLOCATION_WARNING_RECUSION_DEPTH (5 by default)):");
665                                    let mut total_in_use = 0;
666                                    for item in self.allocation_pool.iter() {
667                                        if item.get_refcnt(py) > 1 {
668                                            total_in_use += 1;
669                                        }
670                                    }
671                                    println!("Number of elements in allocation pool which are currently in use: {total_in_use}");
672                                    let idx = fastrand::usize(..pool_size);
673                                    let e = &self.allocation_pool[idx];
674                                    println!(
675                                        "{}\n\n",
676                                        get_ref_types(e.bind(py), recursion_depth)?
677                                            .repr()?
678                                            .to_string()
679                                    );
680                                }
681                            }
682                        }
683                    }
684                    unsafe { py_array.as_slice_mut().unwrap().copy_from_slice(&array_vec) };
685                } else {
686                    py_array = ArrayD::from_shape_vec(&shape[..], array_vec)
687                        .map_err(|err| {
688                            InvalidStateError::new_err(format!(
689                                "Failed create Numpy array of T from shape and Vec<T>: {}",
690                                err
691                            ))
692                        })?
693                        .into_pyarray(py);
694                }
695                py_array
696            }
697        };
698
699        Ok((py_array, offset))
700    }
701}
702
/// Builds a boxed `NumpySerde<$ty>` from a `NumpySerdeConfig` expression.
///
/// `$config` is evaluated exactly once. For STATIC configs the minimum pool
/// size is raised to at least 2 and, unless the maximum pool size is
/// `Some(0)`, the allocation pool is pre-filled with
/// `min(min_size, max_size)` arrays of the configured shape. DYNAMIC configs
/// are passed through unchanged with an empty pool.
#[macro_export]
macro_rules! create_numpy_pyany_serde {
    ($ty: ty, $config: expr) => {{
        let mut allocation_pool = Vec::new();
        // Match on the expression once, so an argument with side effects is not evaluated a
        // second time for non-STATIC configs.
        let new_config = match $config {
            NumpySerdeConfig::STATIC {
                shape,
                preprocessor_fn,
                postprocessor_fn,
                allocation_pool_min_size,
                allocation_pool_max_size,
                allocation_pool_warning_size,
            } => {
                let allocation_pool_min_size = allocation_pool_min_size.max(2);
                if allocation_pool_max_size.map_or(true, |max_size| max_size > 0) {
                    let starting_pool_size = allocation_pool_max_size
                        .map_or(allocation_pool_min_size, |max_size| {
                            allocation_pool_min_size.min(max_size)
                        });
                    Python::with_gil(|py| {
                        for _ in 0..starting_pool_size {
                            // The arrays are created uninitialized; `retrieve_inner` overwrites
                            // a pooled array completely before handing it out.
                            let arr: Bound<'_, numpy::PyArray<$ty, _>> =
                                unsafe { numpy::PyArrayDyn::new(py, &shape[..], false) };
                            allocation_pool.push(arr.unbind());
                        }
                    });
                }
                NumpySerdeConfig::STATIC {
                    shape,
                    preprocessor_fn,
                    postprocessor_fn,
                    allocation_pool_min_size,
                    allocation_pool_max_size,
                    allocation_pool_warning_size,
                }
            }
            dynamic_config @ NumpySerdeConfig::DYNAMIC { .. } => dynamic_config,
        };

        Box::new(NumpySerde::<$ty> {
            config: new_config,
            allocation_pool,
        })
    }};
}
747
748pub fn get_numpy_serde(dtype: NumpyDtype, config: NumpySerdeConfig) -> Box<dyn PyAnySerde> {
749    match dtype {
750        NumpyDtype::INT8 => {
751            create_numpy_pyany_serde!(i8, config)
752        }
753        NumpyDtype::INT16 => {
754            create_numpy_pyany_serde!(i16, config)
755        }
756        NumpyDtype::INT32 => {
757            create_numpy_pyany_serde!(i32, config)
758        }
759        NumpyDtype::INT64 => {
760            create_numpy_pyany_serde!(i64, config)
761        }
762        NumpyDtype::UINT8 => {
763            create_numpy_pyany_serde!(u8, config)
764        }
765        NumpyDtype::UINT16 => {
766            create_numpy_pyany_serde!(u16, config)
767        }
768        NumpyDtype::UINT32 => {
769            create_numpy_pyany_serde!(u32, config)
770        }
771        NumpyDtype::UINT64 => {
772            create_numpy_pyany_serde!(u64, config)
773        }
774        NumpyDtype::FLOAT32 => {
775            create_numpy_pyany_serde!(f32, config)
776        }
777        NumpyDtype::FLOAT64 => {
778            create_numpy_pyany_serde!(f64, config)
779        }
780    }
781}
782
783impl<T: Element + AnyBitPattern + NoUninit> PyAnySerde for NumpySerde<T> {
784    fn append<'py>(
785        &mut self,
786        buf: &mut [u8],
787        offset: usize,
788        obj: &Bound<'py, PyAny>,
789    ) -> PyResult<usize> {
790        let preprocessor_fn_option = match &self.config {
791            NumpySerdeConfig::DYNAMIC {
792                preprocessor_fn, ..
793            } => preprocessor_fn,
794            NumpySerdeConfig::STATIC {
795                preprocessor_fn, ..
796            } => preprocessor_fn,
797        };
798        match preprocessor_fn_option {
799            Some(preprocessor_fn) => self.append_inner(
800                buf,
801                offset,
802                preprocessor_fn
803                    .bind(obj.py())
804                    .call1((obj,))?
805                    .downcast::<PyArrayDyn<T>>()?,
806            ),
807            None => self.append_inner(buf, offset, obj.downcast::<PyArrayDyn<T>>()?),
808        }
809    }
810
811    fn append_vec<'py>(
812        &mut self,
813        v: &mut Vec<u8>,
814        start_addr: Option<usize>,
815        obj: &Bound<'py, PyAny>,
816    ) -> PyResult<()> {
817        let preprocessor_fn_option = match &self.config {
818            NumpySerdeConfig::DYNAMIC {
819                preprocessor_fn, ..
820            } => preprocessor_fn,
821            NumpySerdeConfig::STATIC {
822                preprocessor_fn, ..
823            } => preprocessor_fn,
824        };
825        match preprocessor_fn_option {
826            Some(preprocessor_fn) => self.append_inner_vec(
827                v,
828                start_addr,
829                preprocessor_fn
830                    .bind(obj.py())
831                    .call1((obj,))?
832                    .downcast::<PyArrayDyn<T>>()?,
833            ),
834            None => self.append_inner_vec(v, start_addr, obj.downcast::<PyArrayDyn<T>>()?),
835        }
836    }
837
838    fn retrieve<'py>(
839        &mut self,
840        py: Python<'py>,
841        buf: &[u8],
842        offset: usize,
843    ) -> PyResult<(Bound<'py, PyAny>, usize)> {
844        let (array, offset) = self.retrieve_inner(py, buf, offset)?;
845
846        let postprocessor_fn_option = match &self.config {
847            NumpySerdeConfig::DYNAMIC {
848                postprocessor_fn, ..
849            } => postprocessor_fn,
850            NumpySerdeConfig::STATIC {
851                postprocessor_fn, ..
852            } => postprocessor_fn,
853        };
854
855        Ok(match postprocessor_fn_option {
856            Some(postprocessor_fn) => (postprocessor_fn.bind(py).call1((array, offset))?, offset),
857            None => (array.into_any(), offset),
858        })
859    }
860}
861
862static GC: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
863fn get_ref_types<'py>(o: &Bound<'py, PyAny>, recursion: usize) -> PyResult<Bound<'py, PyAny>> {
864    let py = o.py();
865    let gc = GC
866        .get_or_try_init(py, || Ok::<_, PyErr>(py.import("gc")?.unbind()))?
867        .bind(o.py());
868    let referrers = gc
869        .call_method1(intern!(py, "get_referrers"), (o,))?
870        .downcast_into::<PyList>()?;
871    if recursion > 0 {
872        Ok(PyDict::from_sequence(
873            &referrers
874                .iter()
875                .map(|referrer| {
876                    Ok::<_, PyErr>((
877                        referrer.get_type().repr()?.to_string(),
878                        get_ref_types(&referrer, recursion - 1)?,
879                    ))
880                })
881                .collect::<PyResult<Vec<_>>>()?
882                .into_pyobject(py)?,
883        )?
884        .into_any())
885    } else {
886        referrers
887            .iter()
888            .map(|referrer| Ok::<_, PyErr>(referrer.get_type().repr()?.to_string()))
889            .collect::<PyResult<Vec<_>>>()?
890            .into_pyobject(py)
891    }
892}