1use std::env;
2
3use bytemuck::{cast_slice, AnyBitPattern, NoUninit};
4use numpy::ndarray::ArrayD;
5use numpy::{Element, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
6use numpy::{IntoPyArray, PyArray};
7use pyo3::exceptions::asyncio::InvalidStateError;
8use pyo3::exceptions::PyValueError;
9use pyo3::sync::GILOnceCell;
10use pyo3::types::{PyBytes, PyCFunction, PyDict, PyList, PyTuple, PyType};
11use pyo3::{intern, prelude::*, PyTypeInfo};
12use strum_macros::Display;
13
14use crate::communication::{
15 append_bool_vec, append_bytes_vec, append_usize, append_usize_vec, retrieve_bool,
16 retrieve_usize,
17};
18use crate::{
19 common::{get_bytes_to_alignment, NumpyDtype},
20 communication::{append_bytes, retrieve_bytes},
21 PyAnySerde,
22};
23
24fn append_usize_option_vec(v: &mut Vec<u8>, val_option: &Option<usize>) {
25 if let Some(val) = val_option {
26 append_bool_vec(v, true);
27 append_usize_vec(v, *val);
28 } else {
29 append_bool_vec(v, false);
30 }
31}
32
33fn retrieve_usize_option(buf: &[u8], mut offset: usize) -> PyResult<(Option<usize>, usize)> {
34 let has_val;
35 (has_val, offset) = retrieve_bool(buf, offset)?;
36 if has_val {
37 let val;
38 (val, offset) = retrieve_usize(buf, offset)?;
39 Ok((Some(val), offset))
40 } else {
41 Ok((None, offset))
42 }
43}
44
45fn append_python_pkl_option_vec(v: &mut Vec<u8>, obj_option: &Option<PyObject>) -> PyResult<()> {
46 if let Some(obj) = obj_option {
47 append_bool_vec(v, true);
48 Python::with_gil::<_, PyResult<_>>(|py| {
49 let preprocessor_fn_py_bytes = py
50 .import("pickle")?
51 .getattr("dumps")?
52 .call1((obj,))?
53 .downcast_into::<PyBytes>()?;
54 append_bytes_vec(v, preprocessor_fn_py_bytes.as_bytes());
55 Ok(())
56 })?;
57 } else {
58 append_bool_vec(v, false);
59 }
60 Ok(())
61}
62
63fn retrieve_python_pkl_option(
64 buf: &[u8],
65 mut offset: usize,
66) -> PyResult<(Option<PyObject>, usize)> {
67 let has_obj;
68 (has_obj, offset) = retrieve_bool(buf, offset)?;
69 if has_obj {
70 Python::with_gil::<_, PyResult<_>>(|py| {
71 let obj_bytes;
72 (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
73 Ok((
74 Some(
75 py.import("pickle")?
76 .getattr("loads")?
77 .call1((PyBytes::new(py, obj_bytes).into_pyobject(py)?,))?
78 .unbind(),
79 ),
80 offset,
81 ))
82 })
83 } else {
84 Ok((None, offset))
85 }
86}
87
/// Python-picklable wrapper around an optional [`NumpySerdeConfig`].
///
/// The wrapped value is `None` when the Python constructor is called without arguments.
/// Pickling goes through the `__getstate__` / `__setstate__` methods in this type's
/// `#[pymethods]` impl, which encode the config into a custom byte layout.
#[pyclass]
#[derive(Clone)]
pub struct PickleableNumpySerdeConfig(pub Option<NumpySerdeConfig>);
91
92#[pymethods]
93impl PickleableNumpySerdeConfig {
94 #[new]
95 #[pyo3(signature = (*args))]
96 fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
97 let vec_args = args.iter().collect::<Vec<_>>();
98 if vec_args.len() > 1 {
99 return Err(PyValueError::new_err(format!(
100 "PickleableNumpySerdeConfig constructor takes 0 or 1 parameters, received {}",
101 args.as_any().repr()?.to_str()?
102 )));
103 }
104 if vec_args.len() == 1 {
105 Ok(PickleableNumpySerdeConfig(
106 vec_args[0].extract::<Option<NumpySerdeConfig>>()?,
107 ))
108 } else {
109 Ok(PickleableNumpySerdeConfig(None))
110 }
111 }
112 pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
113 Ok(match self.0.as_ref().unwrap() {
114 NumpySerdeConfig::DYNAMIC {
115 preprocessor_fn,
116 postprocessor_fn,
117 } => {
118 let mut bytes = vec![0];
119 append_python_pkl_option_vec(&mut bytes, preprocessor_fn)?;
120 append_python_pkl_option_vec(&mut bytes, postprocessor_fn)?;
121 bytes
122 }
123 NumpySerdeConfig::STATIC {
124 preprocessor_fn,
125 postprocessor_fn,
126 shape,
127 allocation_pool_min_size,
128 allocation_pool_max_size,
129 allocation_pool_warning_size,
130 } => {
131 let mut bytes = vec![1];
132 append_python_pkl_option_vec(&mut bytes, preprocessor_fn)?;
133 append_python_pkl_option_vec(&mut bytes, postprocessor_fn)?;
134 append_usize_vec(&mut bytes, shape.len());
135 for &dim in shape.iter() {
136 append_usize_vec(&mut bytes, dim);
137 }
138 append_usize_vec(&mut bytes, *allocation_pool_min_size);
139 append_usize_option_vec(&mut bytes, allocation_pool_max_size);
140 append_usize_option_vec(&mut bytes, allocation_pool_warning_size);
141 bytes
142 }
143 })
144 }
145 pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
146 let buf = &state[..];
147 let type_byte = buf[0];
148 let mut offset = 1;
149 self.0 = Some(match type_byte {
150 0 => {
151 let preprocessor_fn;
152 (preprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
153 let postprocessor_fn;
154 (postprocessor_fn, _) = retrieve_python_pkl_option(buf, offset)?;
155 NumpySerdeConfig::DYNAMIC {
156 preprocessor_fn,
157 postprocessor_fn,
158 }
159 }
160 1 => {
161 let preprocessor_fn;
162 (preprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
163 let postprocessor_fn;
164 (postprocessor_fn, offset) = retrieve_python_pkl_option(buf, offset)?;
165 let shape_len;
166 (shape_len, offset) = retrieve_usize(buf, offset)?;
167 let mut shape = Vec::with_capacity(shape_len);
168 for _ in 0..shape_len {
169 let dim;
170 (dim, offset) = retrieve_usize(buf, offset)?;
171 shape.push(dim);
172 }
173 let allocation_pool_min_size;
174 (allocation_pool_min_size, offset) = retrieve_usize(buf, offset)?;
175 let allocation_pool_max_size;
176 (allocation_pool_max_size, _) = retrieve_usize_option(buf, offset)?;
177 let allocation_pool_warning_size;
178 (allocation_pool_warning_size, _) = retrieve_usize_option(buf, offset)?;
179 NumpySerdeConfig::STATIC {
180 preprocessor_fn,
181 postprocessor_fn,
182 shape,
183 allocation_pool_min_size,
184 allocation_pool_max_size,
185 allocation_pool_warning_size,
186 }
187 }
188 v => Err(InvalidStateError::new_err(format!(
189 "Got invalid type byte for NumpySerdeConfig: {v}"
190 )))?,
191 });
192 Ok(())
193 }
194}
195
/// Configuration for the numpy `PyAnySerde`.
///
/// - `DYNAMIC`: arrays of any shape are accepted. Each serialized array is preceded by its
///   number of dimensions and the size of each dimension.
/// - `STATIC`: arrays are expected to have the configured `shape`. This is not checked when
///   serializing, and only element data is written. Retrieved arrays may be taken from an
///   allocation pool governed by the `allocation_pool_*` fields.
///
/// In both variants, `preprocessor_fn` (if set) is called on each object before it is
/// serialized and must return a numpy array of the serde's dtype. `postprocessor_fn` (if set)
/// is invoked during retrieval, after the array has been decoded.
#[pyclass]
#[derive(Debug, Clone, Display)]
pub enum NumpySerdeConfig {
    // Constructor defaults: no pre/post-processors.
    #[pyo3(constructor = (preprocessor_fn = None, postprocessor_fn = None))]
    DYNAMIC {
        preprocessor_fn: Option<PyObject>,
        postprocessor_fn: Option<PyObject>,
    },
    // Constructor defaults: no pre/post-processors, allocation_pool_min_size = 0, no maximum
    // pool size, and a pool-size warning threshold of Some(10000).
    // allocation_pool_max_size = Some(0) makes retrieval allocate a fresh array every time
    // instead of using the pool. create_numpy_pyany_serde! raises the minimum to at least 2.
    #[pyo3(constructor = (shape, preprocessor_fn = None, postprocessor_fn = None, allocation_pool_min_size = 0, allocation_pool_max_size = None, allocation_pool_warning_size = Some(10000)))]
    STATIC {
        shape: Vec<usize>,
        preprocessor_fn: Option<PyObject>,
        postprocessor_fn: Option<PyObject>,
        allocation_pool_min_size: usize,
        allocation_pool_max_size: Option<usize>,
        allocation_pool_warning_size: Option<usize>,
    },
}
215
/// Builds a `Vec` of pydantic core schemas, one per listed `NumpySerdeConfig_<VARIANT>` class,
/// by calling `handler.generate_schema(cls)` for each class in order.
///
/// Evaluates to `Ok::<_, PyErr>(vec)`. A failing `generate_schema` call is propagated out of
/// the enclosing function with `?`.
macro_rules! create_union {
    ($handler:expr, $py:expr, $($type:ident),+) => {{
        let schemas = vec![
            $(
                $handler.call_method1(
                    "generate_schema",
                    (paste::paste! { [<NumpySerdeConfig_ $type>]::type_object($py) },),
                )?
            ),+
        ];
        Ok::<_, PyErr>(schemas)
    }};
}
230
231pub fn check_for_unpickling<'py>(data: &Bound<'py, PyAny>) -> PyResult<bool> {
232 let preprocessor_fn_hex_option = data
233 .get_item("preprocessor_fn_pkl")?
234 .extract::<Option<String>>()?;
235 let postprocessor_fn_hex_option = data
236 .get_item("postprocessor_fn_pkl")?
237 .extract::<Option<String>>()?;
238 Ok(preprocessor_fn_hex_option.is_some() || postprocessor_fn_hex_option.is_some())
239}
240
241fn get_enum_subclass_before_validator_fn<'py>(
242 cls: &Bound<'py, PyType>,
243) -> PyResult<Bound<'py, PyCFunction>> {
244 let _py = cls.py();
245 let py_cls = cls.clone().unbind();
246 let func = move |args: &Bound<'_, PyTuple>,
247 _kwargs: Option<&Bound<'_, PyDict>>|
248 -> PyResult<PyObject> {
249 let py = args.py();
250 let data = args.get_item(0)?;
251 let cls = py_cls.bind(py);
252 let preprocessor_fn_hex_option = data
253 .get_item("preprocessor_fn_pkl")?
254 .extract::<Option<String>>()?;
255 let preprocessor_fn_option = preprocessor_fn_hex_option
256 .map(|preprocessor_fn_hex| {
257 Ok::<_, PyErr>(
258 py.import("pickle")?
259 .getattr("loads")?
260 .call1((PyBytes::new(
261 py,
262 &hex::decode(preprocessor_fn_hex.as_str()).map_err(|err| {
263 PyValueError::new_err(format!(
264 "python_serde_pkl could not be decoded from hex into bytes: {}",
265 err.to_string()
266 ))
267 })?,
268 ),))?
269 .unbind(),
270 )
271 })
272 .transpose()?;
273 let postprocessor_fn_hex_option = data
274 .get_item("postprocessor_fn_pkl")?
275 .extract::<Option<String>>()?;
276 let postprocessor_fn_option = postprocessor_fn_hex_option
277 .map(|postprocessor_fn_hex| {
278 Ok::<_, PyErr>(
279 py.import("pickle")?
280 .getattr("loads")?
281 .call1((PyBytes::new(
282 py,
283 &hex::decode(postprocessor_fn_hex.as_str()).map_err(|err| {
284 PyValueError::new_err(format!(
285 "python_serde_pkl could not be decoded from hex into bytes: {}",
286 err.to_string()
287 ))
288 })?,
289 ),))?
290 .unbind(),
291 )
292 })
293 .transpose()?;
294 if cls.eq(NumpySerdeConfig_DYNAMIC::type_object(py))? {
295 Ok(NumpySerdeConfig::DYNAMIC {
296 preprocessor_fn: preprocessor_fn_option,
297 postprocessor_fn: postprocessor_fn_option,
298 }
299 .into_pyobject(py)?
300 .into_any()
301 .unbind())
302 } else if cls.eq(NumpySerdeConfig_STATIC::type_object(py))? {
303 let shape = data.get_item("shape")?.extract::<Vec<usize>>()?;
304 let allocation_pool_min_size = data
305 .get_item("allocation_pool_min_size")?
306 .extract::<usize>()?;
307 let allocation_pool_max_size = data
308 .get_item("allocation_pool_max_size")?
309 .extract::<Option<usize>>()?;
310 let allocation_pool_warning_size = data
311 .get_item("allocation_pool_warning_size")?
312 .extract::<Option<usize>>()?;
313 if allocation_pool_max_size.is_some()
314 && allocation_pool_min_size > allocation_pool_max_size.unwrap()
315 {
316 Err(PyValueError::new_err(format!(
317 "Validation error: allocation_pool_min_size ({}) cannot be greater than allocation_pool_max_size ({})", allocation_pool_min_size, allocation_pool_max_size.unwrap()
318 )))?
319 }
320 Ok(NumpySerdeConfig::STATIC {
321 preprocessor_fn: preprocessor_fn_option,
322 postprocessor_fn: postprocessor_fn_option,
323 shape,
324 allocation_pool_min_size,
325 allocation_pool_max_size,
326 allocation_pool_warning_size,
327 }
328 .into_pyobject(py)?
329 .into_any()
330 .unbind())
331 } else {
332 Err(PyValueError::new_err(format!(
333 "Unexpected class: {}",
334 cls.repr()?.to_str()?
335 )))
336 }
337 };
338 PyCFunction::new_closure(_py, None, None, func)
339}
340
341fn get_enum_subclass_typed_dict_schema<'py>(
342 cls: &Bound<'py, PyType>,
343 core_schema: &Bound<'py, PyAny>,
344) -> PyResult<Bound<'py, PyAny>> {
345 let py = cls.py();
346 let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
347 let typed_dict_field = core_schema.getattr("typed_dict_field")?;
348 let int_schema = core_schema.getattr("int_schema")?;
349 let str_schema = core_schema.getattr("str_schema")?;
350 let list_schema = core_schema.getattr("list_schema")?;
351 let nullable_schema = core_schema.getattr("nullable_schema")?;
352 let cls_name = cls.name()?.to_string();
353 let (_, enum_subclass) = cls_name.split_once("_").unwrap();
354 let typed_dict_fields = PyDict::new(py);
355 typed_dict_fields.set_item(
356 "type",
357 typed_dict_field.call1((str_schema.call(
358 (),
359 Some(&PyDict::from_sequence(
360 &vec![(
361 "pattern",
362 vec![
363 "^".to_owned(),
364 enum_subclass.to_ascii_lowercase(),
365 "$".to_owned(),
366 ]
367 .join("")
368 .into_pyobject(py)?
369 .into_any(),
370 )]
371 .into_pyobject(py)?,
372 )?),
373 )?,))?,
374 )?;
375 typed_dict_fields.set_item(
376 "preprocessor_fn_pkl",
377 typed_dict_field.call1((nullable_schema.call1((str_schema.call0()?,))?,))?,
378 )?;
379 typed_dict_fields.set_item(
380 "postprocessor_fn_pkl",
381 typed_dict_field.call1((nullable_schema.call1((str_schema.call0()?,))?,))?,
382 )?;
383 if cls.eq(NumpySerdeConfig_STATIC::type_object(py))? {
384 typed_dict_fields.set_item(
385 "shape",
386 typed_dict_field.call1((list_schema.call1((int_schema.call(
387 (),
388 Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
389 )?,))?,))?,
390 )?;
391 typed_dict_fields.set_item(
392 "allocation_pool_min_size",
393 typed_dict_field.call1((int_schema.call(
394 (),
395 Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
396 )?,))?,
397 )?;
398 typed_dict_fields.set_item(
399 "allocation_pool_max_size",
400 typed_dict_field.call1((nullable_schema.call1((int_schema.call(
401 (),
402 Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
403 )?,))?,))?,
404 )?;
405 typed_dict_fields.set_item(
406 "allocation_pool_warning_size",
407 typed_dict_field.call1((nullable_schema.call1((int_schema.call(
408 (),
409 Some(&PyDict::from_sequence(&vec![("ge", 0)].into_pyobject(py)?)?),
410 )?,))?,))?,
411 )?;
412 }
413 typed_dict_schema.call1((typed_dict_fields,))
414}
415
416#[pymethods]
417impl NumpySerdeConfig {
418 #[classmethod]
420 fn __get_pydantic_core_schema__<'py>(
421 cls: &Bound<'py, PyType>,
422 _source_type: Bound<'py, PyAny>,
423 handler: Bound<'py, PyAny>,
424 ) -> PyResult<Bound<'py, PyAny>> {
425 let py = cls.py();
426 let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
427 if cls.eq(NumpySerdeConfig::type_object(py))? {
428 let union_list = create_union!(handler, py, DYNAMIC, STATIC)?;
429 return core_schema.call_method1("union_schema", (union_list,));
430 }
431 let python_schema = core_schema.getattr("is_instance_schema")?.call1((cls,))?;
432 core_schema.getattr("json_or_python_schema")?.call1((
433 core_schema.getattr("chain_schema")?.call1((vec![
434 get_enum_subclass_typed_dict_schema(cls, &core_schema)?,
435 core_schema
436 .getattr("no_info_before_validator_function")?
437 .call1((get_enum_subclass_before_validator_fn(cls)?, &python_schema))?,
438 ],))?,
439 python_schema,
440 ))
441 }
442
443 pub fn to_json(&self) -> PyResult<PyObject> {
444 Python::with_gil(|py| {
445 let data = PyDict::new(py);
446 data.set_item("type", self.to_string().to_ascii_lowercase())?;
447 match self {
448 NumpySerdeConfig::DYNAMIC {
449 preprocessor_fn,
450 postprocessor_fn,
451 } => {
452 let preprocessor_fn_pkl = preprocessor_fn
453 .as_ref()
454 .map(|preprocessor_fn| {
455 Ok::<_, PyErr>(
456 py.import("pickle")?
457 .getattr("dumps")?
458 .call1((preprocessor_fn,))?
459 .call_method0("hex")?,
460 )
461 })
462 .transpose()?;
463 data.set_item("preprocessor_fn_pkl", preprocessor_fn_pkl)?;
464 let postprocessor_fn_pkl = postprocessor_fn
465 .as_ref()
466 .map(|postprocessor_fn| {
467 Ok::<_, PyErr>(
468 py.import("pickle")?
469 .getattr("dumps")?
470 .call1((postprocessor_fn,))?
471 .call_method0("hex")?,
472 )
473 })
474 .transpose()?;
475 data.set_item("postprocessor_fn_pkl", postprocessor_fn_pkl)?;
476 }
477 NumpySerdeConfig::STATIC {
478 preprocessor_fn,
479 postprocessor_fn,
480 shape,
481 allocation_pool_min_size,
482 allocation_pool_max_size,
483 allocation_pool_warning_size,
484 } => {
485 let preprocessor_fn_pkl = preprocessor_fn
486 .as_ref()
487 .map(|preprocessor_fn| {
488 Ok::<_, PyErr>(
489 py.import("pickle")?
490 .getattr("dumps")?
491 .call1((preprocessor_fn,))?
492 .call_method0("hex")?,
493 )
494 })
495 .transpose()?;
496 data.set_item("preprocessor_fn_pkl", preprocessor_fn_pkl)?;
497 let postprocessor_fn_pkl = postprocessor_fn
498 .as_ref()
499 .map(|postprocessor_fn| {
500 Ok::<_, PyErr>(
501 py.import("pickle")?
502 .getattr("dumps")?
503 .call1((postprocessor_fn,))?
504 .call_method0("hex")?,
505 )
506 })
507 .transpose()?;
508 data.set_item("postprocessor_fn_pkl", postprocessor_fn_pkl)?;
509 data.set_item("shape", shape)?;
510 data.set_item("allocation_pool_min_size", allocation_pool_min_size)?;
511 data.set_item("allocation_pool_max_size", allocation_pool_max_size)?;
512 data.set_item("allocation_pool_warning_size", allocation_pool_warning_size)?;
513 }
514 }
515 Ok(data.into_any().unbind())
516 })
517 }
518}
519
/// `PyAnySerde` for numpy arrays whose element type is `T`.
///
/// `config` selects the serialized layout (see `NumpySerdeConfig`). `allocation_pool` holds
/// arrays that `retrieve_inner` may hand out again for `STATIC` configs. An entry counts as
/// free when its reference count is 1, i.e. only the pool refers to it.
///
/// NOTE(review): the derived `Clone` clones the `Py` handles in `allocation_pool`. After a clone
/// every entry is referenced by both pools, its refcount never drops back to 1, and neither copy
/// will treat it as free. Confirm that cloned serdes are not expected to share a pool.
#[derive(Clone)]
pub struct NumpySerde<T: Element> {
    pub config: NumpySerdeConfig,
    pub allocation_pool: Vec<Py<PyArrayDyn<T>>>,
}
525
526impl<T: Element + AnyBitPattern + NoUninit> NumpySerde<T> {
527 pub fn append_inner<'py>(
528 &mut self,
529 buf: &mut [u8],
530 mut offset: usize,
531 array: &Bound<'py, PyArrayDyn<T>>,
532 ) -> PyResult<usize> {
533 match &self.config {
534 NumpySerdeConfig::DYNAMIC { .. } => {
535 let shape = array.shape();
536 offset = append_usize(buf, offset, shape.len());
537 for &dim in shape.iter() {
538 offset = append_usize(buf, offset, dim);
539 }
540 let obj_vec = array.to_vec()?;
541 offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
542 offset = append_bytes(buf, offset, cast_slice::<T, u8>(&obj_vec));
543 }
544 NumpySerdeConfig::STATIC { .. } => {
545 let obj_vec = array.to_vec()?;
546 offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
547 offset = append_bytes(buf, offset, cast_slice::<T, u8>(&obj_vec));
548 }
549 }
550 Ok(offset)
551 }
552
553 fn append_inner_vec<'py>(
554 &mut self,
555 v: &mut Vec<u8>,
556 start_addr: Option<usize>,
557 array: &Bound<'py, PyArrayDyn<T>>,
558 ) -> PyResult<()> {
559 let Some(start_addr) = start_addr else {
560 Err(InvalidStateError::new_err("Tried to serialize numpy data, but there was no start_addr provided so there's no way to know how to align the data. (was this called from inside a preprocessor function?)"))?
561 };
562 match &self.config {
563 NumpySerdeConfig::DYNAMIC { .. } => {
564 let shape = array.shape();
565 append_usize_vec(v, shape.len());
566 for &dim in shape.iter() {
567 append_usize_vec(v, dim);
568 }
569 let obj_vec = array.to_vec()?;
570 v.append(&mut vec![
571 0;
572 get_bytes_to_alignment::<T>(start_addr + v.len())
573 ]);
574 append_bytes_vec(v, cast_slice::<T, u8>(&obj_vec));
575 }
576 NumpySerdeConfig::STATIC { .. } => {
577 let obj_vec = array.to_vec()?;
578 v.append(&mut vec![
579 0;
580 get_bytes_to_alignment::<T>(start_addr + v.len())
581 ]);
582 append_bytes_vec(v, cast_slice::<T, u8>(&obj_vec));
583 }
584 }
585 Ok(())
586 }
587
588 pub fn retrieve_inner<'py>(
589 &mut self,
590 py: Python<'py>,
591 buf: &[u8],
592 mut offset: usize,
593 ) -> PyResult<(Bound<'py, PyArrayDyn<T>>, usize)> {
594 let py_array = match &self.config {
595 NumpySerdeConfig::DYNAMIC { .. } => {
596 let shape_len;
597 (shape_len, offset) = retrieve_usize(buf, offset)?;
598 let mut shape = Vec::with_capacity(shape_len);
599 for _ in 0..shape_len {
600 let dim;
601 (dim, offset) = retrieve_usize(buf, offset)?;
602 shape.push(dim);
603 }
604 offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
605 let obj_bytes;
606 (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
607 let array_vec = cast_slice::<u8, T>(obj_bytes).to_vec();
608 ArrayD::from_shape_vec(shape, array_vec)
609 .map_err(|err| {
610 InvalidStateError::new_err(format!(
611 "Failed create Numpy array of T from shape and Vec<T>: {}",
612 err
613 ))
614 })?
615 .into_pyarray(py)
616 }
617 NumpySerdeConfig::STATIC {
618 shape,
619 allocation_pool_min_size,
620 allocation_pool_max_size,
621 allocation_pool_warning_size,
622 ..
623 } => {
624 offset = offset + get_bytes_to_alignment::<T>(buf.as_ptr() as usize + offset);
625 let obj_bytes;
626 (obj_bytes, offset) = retrieve_bytes(buf, offset)?;
627 let array_vec = cast_slice::<u8, T>(obj_bytes).to_vec();
628 let py_array;
629 if allocation_pool_max_size.is_none() || allocation_pool_max_size.unwrap() > 0 {
630 let pool_size = self.allocation_pool.len();
632 let idx1 = fastrand::usize(..pool_size);
633 let idx2 = fastrand::usize(..pool_size);
634 let e1 = &self.allocation_pool[idx1];
635 let e2 = &self.allocation_pool[idx2];
636 let e1_free = e1.get_refcnt(py) == 1;
637 let e2_free = e2.get_refcnt(py) == 1;
638 if e1_free && e2_free {
639 py_array = e1.clone_ref(py).into_bound(py);
640 if self.allocation_pool.len() > *allocation_pool_min_size {
641 self.allocation_pool.swap_remove(idx2);
642 }
643 } else if e1_free {
644 py_array = e1.clone_ref(py).into_bound(py);
645 } else if e2_free {
646 py_array = e2.clone_ref(py).into_bound(py);
647 } else {
648 let arr: Bound<'_, PyArray<T, _>> =
649 unsafe { PyArrayDyn::new(py, &shape[..], false) };
650 if allocation_pool_max_size.is_none()
651 || self.allocation_pool.len() < allocation_pool_max_size.unwrap()
652 {
653 self.allocation_pool.push(arr.clone().unbind());
654 }
655 py_array = arr;
656 if let Some(allocation_pool_warning_size) = allocation_pool_warning_size {
657 if pool_size > *allocation_pool_warning_size {
658 if pool_size % 100 == 0 {
659 let recursion_depth = env::var(
660 "PYANY_SERDE_NUMPY_ALLOCATION_WARNING_RECUSION_DEPTH",
661 )
662 .map(|v| v.parse::<usize>().unwrap_or(5))
663 .unwrap_or(5);
664 println!("Warning: the allocation pool for this Numpy PyAny serde instance is currently {pool_size}, which is larger than the warning limit set ({allocation_pool_warning_size}). Here is a random element from the allocation pool and a dict of the types of its referrers (and the referrers of those referrers, etc, up to the recursion depth set by PYANY_SERDE_NUMPY_ALLOCATION_WARNING_RECUSION_DEPTH (5 by default)):");
665 let mut total_in_use = 0;
666 for item in self.allocation_pool.iter() {
667 if item.get_refcnt(py) > 1 {
668 total_in_use += 1;
669 }
670 }
671 println!("Number of elements in allocation pool which are currently in use: {total_in_use}");
672 let idx = fastrand::usize(..pool_size);
673 let e = &self.allocation_pool[idx];
674 println!(
675 "{}\n\n",
676 get_ref_types(e.bind(py), recursion_depth)?
677 .repr()?
678 .to_string()
679 );
680 }
681 }
682 }
683 }
684 unsafe { py_array.as_slice_mut().unwrap().copy_from_slice(&array_vec) };
685 } else {
686 py_array = ArrayD::from_shape_vec(&shape[..], array_vec)
687 .map_err(|err| {
688 InvalidStateError::new_err(format!(
689 "Failed create Numpy array of T from shape and Vec<T>: {}",
690 err
691 ))
692 })?
693 .into_pyarray(py);
694 }
695 py_array
696 }
697 };
698
699 Ok((py_array, offset))
700 }
701}
702
/// Builds a `Box<NumpySerde<$ty>>` from the `NumpySerdeConfig` expression `$config`.
///
/// `$config` is evaluated exactly once. For `STATIC` configs:
/// - `allocation_pool_min_size` is raised to at least 2.
/// - Unless `allocation_pool_max_size` is `Some(0)`, the pool is pre-filled with
///   `min(min_size, max_size)` uninitialized arrays of the configured shape.
///
/// Expects `NumpySerdeConfig`, `NumpySerde`, `Python`, and `Bound` to be in scope at the call
/// site.
#[macro_export]
macro_rules! create_numpy_pyany_serde {
    ($ty: ty, $config: expr) => {{
        let serde_config = $config;
        let mut allocation_pool = Vec::new();
        let new_config = match serde_config {
            NumpySerdeConfig::STATIC {
                shape,
                preprocessor_fn,
                postprocessor_fn,
                allocation_pool_min_size,
                allocation_pool_max_size,
                allocation_pool_warning_size,
            } => {
                let allocation_pool_min_size = allocation_pool_min_size.max(2);
                if allocation_pool_max_size.map_or(true, |max| max > 0) {
                    let starting_pool_size = allocation_pool_min_size
                        .min(allocation_pool_max_size.unwrap_or(allocation_pool_min_size));
                    Python::with_gil(|py| {
                        for _ in 0..starting_pool_size {
                            // SAFETY: pooled arrays are uninitialized; `retrieve_inner`
                            // overwrites every element before handing one out.
                            let arr: Bound<'_, numpy::PyArray<$ty, _>> =
                                unsafe { numpy::PyArrayDyn::new(py, &shape[..], false) };
                            allocation_pool.push(arr.unbind());
                        }
                    });
                }
                NumpySerdeConfig::STATIC {
                    shape,
                    preprocessor_fn,
                    postprocessor_fn,
                    allocation_pool_min_size,
                    allocation_pool_max_size,
                    allocation_pool_warning_size,
                }
            }
            dynamic @ NumpySerdeConfig::DYNAMIC { .. } => dynamic,
        };

        Box::new(NumpySerde::<$ty> {
            config: new_config,
            allocation_pool,
        })
    }};
}
747
748pub fn get_numpy_serde(dtype: NumpyDtype, config: NumpySerdeConfig) -> Box<dyn PyAnySerde> {
749 match dtype {
750 NumpyDtype::INT8 => {
751 create_numpy_pyany_serde!(i8, config)
752 }
753 NumpyDtype::INT16 => {
754 create_numpy_pyany_serde!(i16, config)
755 }
756 NumpyDtype::INT32 => {
757 create_numpy_pyany_serde!(i32, config)
758 }
759 NumpyDtype::INT64 => {
760 create_numpy_pyany_serde!(i64, config)
761 }
762 NumpyDtype::UINT8 => {
763 create_numpy_pyany_serde!(u8, config)
764 }
765 NumpyDtype::UINT16 => {
766 create_numpy_pyany_serde!(u16, config)
767 }
768 NumpyDtype::UINT32 => {
769 create_numpy_pyany_serde!(u32, config)
770 }
771 NumpyDtype::UINT64 => {
772 create_numpy_pyany_serde!(u64, config)
773 }
774 NumpyDtype::FLOAT32 => {
775 create_numpy_pyany_serde!(f32, config)
776 }
777 NumpyDtype::FLOAT64 => {
778 create_numpy_pyany_serde!(f64, config)
779 }
780 }
781}
782
783impl<T: Element + AnyBitPattern + NoUninit> PyAnySerde for NumpySerde<T> {
784 fn append<'py>(
785 &mut self,
786 buf: &mut [u8],
787 offset: usize,
788 obj: &Bound<'py, PyAny>,
789 ) -> PyResult<usize> {
790 let preprocessor_fn_option = match &self.config {
791 NumpySerdeConfig::DYNAMIC {
792 preprocessor_fn, ..
793 } => preprocessor_fn,
794 NumpySerdeConfig::STATIC {
795 preprocessor_fn, ..
796 } => preprocessor_fn,
797 };
798 match preprocessor_fn_option {
799 Some(preprocessor_fn) => self.append_inner(
800 buf,
801 offset,
802 preprocessor_fn
803 .bind(obj.py())
804 .call1((obj,))?
805 .downcast::<PyArrayDyn<T>>()?,
806 ),
807 None => self.append_inner(buf, offset, obj.downcast::<PyArrayDyn<T>>()?),
808 }
809 }
810
811 fn append_vec<'py>(
812 &mut self,
813 v: &mut Vec<u8>,
814 start_addr: Option<usize>,
815 obj: &Bound<'py, PyAny>,
816 ) -> PyResult<()> {
817 let preprocessor_fn_option = match &self.config {
818 NumpySerdeConfig::DYNAMIC {
819 preprocessor_fn, ..
820 } => preprocessor_fn,
821 NumpySerdeConfig::STATIC {
822 preprocessor_fn, ..
823 } => preprocessor_fn,
824 };
825 match preprocessor_fn_option {
826 Some(preprocessor_fn) => self.append_inner_vec(
827 v,
828 start_addr,
829 preprocessor_fn
830 .bind(obj.py())
831 .call1((obj,))?
832 .downcast::<PyArrayDyn<T>>()?,
833 ),
834 None => self.append_inner_vec(v, start_addr, obj.downcast::<PyArrayDyn<T>>()?),
835 }
836 }
837
838 fn retrieve<'py>(
839 &mut self,
840 py: Python<'py>,
841 buf: &[u8],
842 offset: usize,
843 ) -> PyResult<(Bound<'py, PyAny>, usize)> {
844 let (array, offset) = self.retrieve_inner(py, buf, offset)?;
845
846 let postprocessor_fn_option = match &self.config {
847 NumpySerdeConfig::DYNAMIC {
848 postprocessor_fn, ..
849 } => postprocessor_fn,
850 NumpySerdeConfig::STATIC {
851 postprocessor_fn, ..
852 } => postprocessor_fn,
853 };
854
855 Ok(match postprocessor_fn_option {
856 Some(postprocessor_fn) => (postprocessor_fn.bind(py).call1((array, offset))?, offset),
857 None => (array.into_any(), offset),
858 })
859 }
860}
861
862static GC: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
863fn get_ref_types<'py>(o: &Bound<'py, PyAny>, recursion: usize) -> PyResult<Bound<'py, PyAny>> {
864 let py = o.py();
865 let gc = GC
866 .get_or_try_init(py, || Ok::<_, PyErr>(py.import("gc")?.unbind()))?
867 .bind(o.py());
868 let referrers = gc
869 .call_method1(intern!(py, "get_referrers"), (o,))?
870 .downcast_into::<PyList>()?;
871 if recursion > 0 {
872 Ok(PyDict::from_sequence(
873 &referrers
874 .iter()
875 .map(|referrer| {
876 Ok::<_, PyErr>((
877 referrer.get_type().repr()?.to_string(),
878 get_ref_types(&referrer, recursion - 1)?,
879 ))
880 })
881 .collect::<PyResult<Vec<_>>>()?
882 .into_pyobject(py)?,
883 )?
884 .into_any())
885 } else {
886 referrers
887 .iter()
888 .map(|referrer| Ok::<_, PyErr>(referrer.get_type().repr()?.to_string()))
889 .collect::<PyResult<Vec<_>>>()?
890 .into_pyobject(py)
891 }
892}