pyany_serde/pyany_serde_impl/dataclass_serde.rs

1use std::collections::HashSet;
2
3use pyo3::exceptions::asyncio::InvalidStateError;
4use pyo3::exceptions::PyValueError;
5use pyo3::types::{PyCFunction, PyDict, PyString, PyTuple, PyType};
6use pyo3::{prelude::*, PyTypeInfo};
7use strum_macros::Display;
8
9use crate::communication::{append_string_vec, retrieve_string, retrieve_usize};
10use crate::PyAnySerde;
11
12#[derive(Clone)]
13pub struct DataclassSerde {
14    class: PyObject,
15    init_strategy: InternalInitStrategy,
16    field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
17}
18
19#[pyclass]
20#[derive(Clone)]
21pub struct PickleableInitStrategy(pub Option<InitStrategy>);
22
23#[pymethods]
24impl PickleableInitStrategy {
25    #[new]
26    #[pyo3(signature = (*args))]
27    fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
28        let vec_args = args.iter().collect::<Vec<_>>();
29        if vec_args.len() > 1 {
30            return Err(PyValueError::new_err(format!(
31                "PickleableInitStrategy constructor takes 0 or 1 parameters, received {}",
32                args.as_any().repr()?.to_str()?
33            )));
34        }
35        if vec_args.len() == 1 {
36            Ok(PickleableInitStrategy(
37                vec_args[0].extract::<Option<InitStrategy>>()?,
38            ))
39        } else {
40            Ok(PickleableInitStrategy(None))
41        }
42    }
43    pub fn __getstate__(&self) -> Vec<u8> {
44        match self.0.as_ref().unwrap() {
45            InitStrategy::ALL {} => vec![0],
46            InitStrategy::SOME { kwargs } => {
47                let mut bytes = vec![1];
48                bytes.extend_from_slice(&kwargs.len().to_ne_bytes());
49                for kwarg in kwargs.iter() {
50                    append_string_vec(&mut bytes, kwarg);
51                }
52                bytes
53            }
54            InitStrategy::NONE {} => vec![2],
55        }
56    }
57    pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
58        let buf = &state[..];
59        let type_byte = buf[0];
60        let mut offset = 1;
61        self.0 = Some(match type_byte {
62            0 => InitStrategy::ALL {},
63            1 => {
64                let n_kwargs;
65                (n_kwargs, offset) = retrieve_usize(buf, offset)?;
66                let mut kwargs = Vec::with_capacity(n_kwargs);
67                for _ in 0..n_kwargs {
68                    let kwarg;
69                    (kwarg, offset) = retrieve_string(buf, offset)?;
70                    kwargs.push(kwarg)
71                }
72                InitStrategy::SOME { kwargs }
73            }
74            2 => InitStrategy::NONE {},
75            v => Err(InvalidStateError::new_err(format!(
76                "Got invalid type byte for InitStrategy: {v}"
77            )))?,
78        });
79        Ok(())
80    }
81}
82
83#[pyclass]
84#[derive(Clone, Debug, PartialEq, Display)]
85pub enum InitStrategy {
86    ALL {},
87    SOME { kwargs: Vec<String> },
88    NONE {},
89}
90
/// Asks the pydantic `handler` to generate a schema for each listed `InitStrategy_<VARIANT>`
/// class, in the order given.
///
/// Expands to a `PyResult<Vec<_>>`. Each `generate_schema` call uses `?`, so the first
/// failure returns early from the function the macro is used in.
macro_rules! create_union {
    ($handler:expr, $py:expr, $($type:ident),+) => {{
        let schemas = vec![
            $(
                $handler.call_method1(
                    "generate_schema",
                    (paste::paste! { [<InitStrategy_ $type>]::type_object($py) },),
                )?
            ),+
        ];
        Ok::<_, PyErr>(schemas)
    }};
}
105
106fn get_enum_subclass_before_validator_fn<'py>(
107    cls: &Bound<'py, PyType>,
108) -> PyResult<Bound<'py, PyCFunction>> {
109    let _py = cls.py();
110    let py_cls = cls.clone().unbind();
111    let func = move |args: &Bound<'_, PyTuple>,
112                     _kwargs: Option<&Bound<'_, PyDict>>|
113          -> PyResult<PyObject> {
114        let py = args.py();
115        let data = args.get_item(0)?;
116        let cls = py_cls.bind(py);
117        if cls.eq(InitStrategy_ALL::type_object(py))? {
118            Ok(InitStrategy::ALL {}.into_pyobject(py)?.into_any().unbind())
119        } else if cls.eq(InitStrategy_SOME::type_object(py))? {
120            let kwargs = data.get_item("kwargs")?.extract::<Vec<String>>()?;
121            Ok(InitStrategy::SOME { kwargs }
122                .into_pyobject(py)?
123                .into_any()
124                .unbind())
125        } else if cls.eq(InitStrategy_NONE::type_object(py))? {
126            Ok(InitStrategy::NONE {}.into_pyobject(py)?.into_any().unbind())
127        } else {
128            Err(PyValueError::new_err(format!(
129                "Unexpected class: {}",
130                cls.repr()?.to_str()?
131            )))
132        }
133    };
134    PyCFunction::new_closure(_py, None, None, func)
135}
136
137fn get_enum_subclass_typed_dict_schema<'py>(
138    cls: &Bound<'py, PyType>,
139    core_schema: &Bound<'py, PyAny>,
140) -> PyResult<Bound<'py, PyAny>> {
141    let py = cls.py();
142    let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
143    let typed_dict_field = core_schema.getattr("typed_dict_field")?;
144    let str_schema = core_schema.getattr("str_schema")?;
145    let list_schema = core_schema.getattr("list_schema")?;
146    let cls_name = cls.name()?.to_string();
147    let (_, enum_subclass) = cls_name.split_once("_").unwrap();
148    let typed_dict_fields = PyDict::new(py);
149    typed_dict_fields.set_item(
150        "type",
151        typed_dict_field.call1((str_schema.call(
152            (),
153            Some(&PyDict::from_sequence(
154                &vec![(
155                    "pattern",
156                    vec![
157                        "^".to_owned(),
158                        enum_subclass.to_ascii_lowercase(),
159                        "$".to_owned(),
160                    ]
161                    .join("")
162                    .into_pyobject(py)?
163                    .into_any(),
164                )]
165                .into_pyobject(py)?,
166            )?),
167        )?,))?,
168    )?;
169    if cls.eq(InitStrategy_SOME::type_object(py))? {
170        typed_dict_fields.set_item(
171            "kwargs",
172            typed_dict_field.call1((list_schema.call1((str_schema.call0()?,))?,))?,
173        )?;
174    }
175    typed_dict_schema.call1((typed_dict_fields,))
176}
177
178#[pymethods]
179impl InitStrategy {
180    // pydantic methods
181    #[classmethod]
182    fn __get_pydantic_core_schema__<'py>(
183        cls: &Bound<'py, PyType>,
184        _source_type: Bound<'py, PyAny>,
185        handler: Bound<'py, PyAny>,
186    ) -> PyResult<Bound<'py, PyAny>> {
187        let py = cls.py();
188        let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
189        if cls.eq(InitStrategy::type_object(py))? {
190            let union_list = create_union!(handler, py, ALL, SOME, NONE)?;
191            return core_schema.call_method1("union_schema", (union_list,));
192        }
193        let python_schema = core_schema.getattr("is_instance_schema")?.call1((cls,))?;
194        core_schema.getattr("json_or_python_schema")?.call1((
195            core_schema.getattr("chain_schema")?.call1((vec![
196                get_enum_subclass_typed_dict_schema(cls, &core_schema)?,
197                core_schema
198                    .getattr("no_info_before_validator_function")?
199                    .call1((get_enum_subclass_before_validator_fn(cls)?, &python_schema))?,
200            ],))?,
201            python_schema,
202        ))
203    }
204
205    pub fn to_json(&self) -> PyResult<PyObject> {
206        Python::with_gil(|py| {
207            let data = PyDict::new(py);
208            data.set_item("type", self.to_string().to_ascii_lowercase())?;
209            if let InitStrategy::SOME { kwargs } = self {
210                data.set_item("kwargs", kwargs)?;
211            }
212            Ok(data.into_any().unbind())
213        })
214    }
215}
216
217#[derive(Clone, Debug)]
218pub enum InternalInitStrategy {
219    ALL(Py<PyDict>),
220    SOME(Py<PyDict>, HashSet<usize>),
221    NONE,
222}
223
224impl DataclassSerde {
225    pub fn new(
226        class: PyObject,
227        init_strategy: InitStrategy,
228        field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
229    ) -> PyResult<Self> {
230        let internal_init_strategy = match &init_strategy {
231            InitStrategy::ALL {} => Python::with_gil::<_, PyResult<_>>(|py| {
232                let kwargs_kv_list = field_serde_kv_list
233                    .iter()
234                    .map(|(field, _)| (field, None::<PyObject>))
235                    .collect::<Vec<_>>();
236                let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
237                Ok(InternalInitStrategy::ALL(kwargs))
238            })?,
239            InitStrategy::SOME { kwargs } => Python::with_gil::<_, PyResult<_>>(|py| {
240                let init_field_idxs = kwargs.iter().map(|init_field| field_serde_kv_list.iter().position(|(field, _)| field.to_string() == *init_field).ok_or_else(|| PyValueError::new_err(format!("field name {} provided in InitStrategy_SOME not contained in field_serde_kv_list", init_field)))).collect::<PyResult<HashSet<_>>>()?;
241                let kwargs_kv_list = field_serde_kv_list
242                    .iter()
243                    .enumerate()
244                    .filter(|(idx, _)| init_field_idxs.contains(idx))
245                    .map(|(_, (field, _))| (field, None::<PyObject>))
246                    .collect::<Vec<_>>();
247                let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
248                Ok(InternalInitStrategy::SOME(kwargs, init_field_idxs))
249            })?,
250            InitStrategy::NONE {} => InternalInitStrategy::NONE,
251        };
252        Ok(DataclassSerde {
253            class,
254            init_strategy: internal_init_strategy,
255            field_serde_kv_list,
256        })
257    }
258}
259
260impl PyAnySerde for DataclassSerde {
261    fn append<'py>(
262        &mut self,
263        buf: &mut [u8],
264        mut offset: usize,
265        obj: &Bound<'py, PyAny>,
266    ) -> PyResult<usize> {
267        for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
268            offset = pyany_serde.append(buf, offset, &obj.getattr(&*field)?)?;
269        }
270        Ok(offset)
271    }
272
273    fn append_vec<'py>(
274        &mut self,
275        v: &mut Vec<u8>,
276        start_addr: Option<usize>,
277        obj: &Bound<'py, PyAny>,
278    ) -> PyResult<()> {
279        for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
280            pyany_serde.append_vec(v, start_addr, &obj.getattr(&*field)?)?;
281        }
282        Ok(())
283    }
284
285    fn retrieve<'py>(
286        &mut self,
287        py: Python<'py>,
288        buf: &[u8],
289        mut offset: usize,
290    ) -> PyResult<(Bound<'py, PyAny>, usize)> {
291        let mut kv_list = Vec::with_capacity(self.field_serde_kv_list.len());
292        for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
293            let field_value;
294            (field_value, offset) = pyany_serde.retrieve(py, buf, offset)?;
295            kv_list.push((field.clone_ref(py).into_bound(py), field_value));
296        }
297        let class = self.class.bind(py);
298        let obj = match &self.init_strategy {
299            InternalInitStrategy::ALL(py_kwargs) => {
300                let kwargs = py_kwargs.bind(py);
301                for (field, field_value) in kv_list.iter() {
302                    kwargs.set_item(field, field_value)?;
303                }
304                class.call((), Some(kwargs))?
305            }
306            InternalInitStrategy::SOME(py_kwargs, init_field_idxs) => {
307                let kwargs = py_kwargs.bind(py);
308                let (init_kv_list, other_kv_list) = kv_list
309                    .into_iter()
310                    .enumerate()
311                    .partition::<Vec<_>, _>(|(idx, _)| init_field_idxs.contains(idx));
312                for (_, (field, field_value)) in init_kv_list.iter() {
313                    kwargs.set_item(field, field_value)?;
314                }
315                let obj = class.call((), Some(kwargs))?;
316                for (_, (field, field_value)) in other_kv_list.iter() {
317                    obj.setattr(field, field_value)?;
318                }
319                obj
320            }
321            InternalInitStrategy::NONE => {
322                let obj = class.call0()?;
323                for (field, field_value) in kv_list.iter() {
324                    obj.setattr(field, field_value)?;
325                }
326                obj
327            }
328        };
329        Ok((obj, offset))
330    }
331}