1use std::collections::HashSet;
2
3use pyo3::exceptions::asyncio::InvalidStateError;
4use pyo3::exceptions::PyValueError;
5use pyo3::types::{PyCFunction, PyDict, PyString, PyTuple, PyType};
6use pyo3::{prelude::*, PyTypeInfo};
7use strum_macros::Display;
8
9use crate::communication::{append_string_vec, retrieve_string, retrieve_usize};
10use crate::PyAnySerde;
11
12#[derive(Clone)]
13pub struct DataclassSerde {
14 class: PyObject,
15 init_strategy: InternalInitStrategy,
16 field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
17}
18
19#[pyclass]
20#[derive(Clone)]
21pub struct PickleableInitStrategy(pub Option<InitStrategy>);
22
23#[pymethods]
24impl PickleableInitStrategy {
25 #[new]
26 #[pyo3(signature = (*args))]
27 fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
28 let vec_args = args.iter().collect::<Vec<_>>();
29 if vec_args.len() > 1 {
30 return Err(PyValueError::new_err(format!(
31 "PickleableInitStrategy constructor takes 0 or 1 parameters, received {}",
32 args.as_any().repr()?.to_str()?
33 )));
34 }
35 if vec_args.len() == 1 {
36 Ok(PickleableInitStrategy(
37 vec_args[0].extract::<Option<InitStrategy>>()?,
38 ))
39 } else {
40 Ok(PickleableInitStrategy(None))
41 }
42 }
43 pub fn __getstate__(&self) -> Vec<u8> {
44 match self.0.as_ref().unwrap() {
45 InitStrategy::ALL {} => vec![0],
46 InitStrategy::SOME { kwargs } => {
47 let mut bytes = vec![1];
48 bytes.extend_from_slice(&kwargs.len().to_ne_bytes());
49 for kwarg in kwargs.iter() {
50 append_string_vec(&mut bytes, kwarg);
51 }
52 bytes
53 }
54 InitStrategy::NONE {} => vec![2],
55 }
56 }
57 pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
58 let buf = &state[..];
59 let type_byte = buf[0];
60 let mut offset = 1;
61 self.0 = Some(match type_byte {
62 0 => InitStrategy::ALL {},
63 1 => {
64 let n_kwargs;
65 (n_kwargs, offset) = retrieve_usize(buf, offset)?;
66 let mut kwargs = Vec::with_capacity(n_kwargs);
67 for _ in 0..n_kwargs {
68 let kwarg;
69 (kwarg, offset) = retrieve_string(buf, offset)?;
70 kwargs.push(kwarg)
71 }
72 InitStrategy::SOME { kwargs }
73 }
74 2 => InitStrategy::NONE {},
75 v => Err(InvalidStateError::new_err(format!(
76 "Got invalid type byte for InitStrategy: {v}"
77 )))?,
78 });
79 Ok(())
80 }
81}
82
83#[pyclass]
84#[derive(Clone, Debug, PartialEq, Display)]
85pub enum InitStrategy {
86 ALL {},
87 SOME { kwargs: Vec<String> },
88 NONE {},
89}
90
// Expands to `Ok(Vec<..>)` holding one `handler.generate_schema(..)` result per
// listed variant, where each variant name `X` is resolved to the
// `InitStrategy_X` type object. A failing `generate_schema` call returns early
// from the enclosing function through `?`.
macro_rules! create_union {
    ($handler:expr, $py:expr, $($type:ident),+) => {{
        Ok::<_, PyErr>(vec![
            $(
                $handler.call_method1(
                    "generate_schema",
                    (paste::paste! { [<InitStrategy_ $type>]::type_object($py) },),
                )?
            ),+
        ])
    }};
}
105
106fn get_enum_subclass_before_validator_fn<'py>(
107 cls: &Bound<'py, PyType>,
108) -> PyResult<Bound<'py, PyCFunction>> {
109 let _py = cls.py();
110 let py_cls = cls.clone().unbind();
111 let func = move |args: &Bound<'_, PyTuple>,
112 _kwargs: Option<&Bound<'_, PyDict>>|
113 -> PyResult<PyObject> {
114 let py = args.py();
115 let data = args.get_item(0)?;
116 let cls = py_cls.bind(py);
117 if cls.eq(InitStrategy_ALL::type_object(py))? {
118 Ok(InitStrategy::ALL {}.into_pyobject(py)?.into_any().unbind())
119 } else if cls.eq(InitStrategy_SOME::type_object(py))? {
120 let kwargs = data.get_item("kwargs")?.extract::<Vec<String>>()?;
121 Ok(InitStrategy::SOME { kwargs }
122 .into_pyobject(py)?
123 .into_any()
124 .unbind())
125 } else if cls.eq(InitStrategy_NONE::type_object(py))? {
126 Ok(InitStrategy::NONE {}.into_pyobject(py)?.into_any().unbind())
127 } else {
128 Err(PyValueError::new_err(format!(
129 "Unexpected class: {}",
130 cls.repr()?.to_str()?
131 )))
132 }
133 };
134 PyCFunction::new_closure(_py, None, None, func)
135}
136
137fn get_enum_subclass_typed_dict_schema<'py>(
138 cls: &Bound<'py, PyType>,
139 core_schema: &Bound<'py, PyAny>,
140) -> PyResult<Bound<'py, PyAny>> {
141 let py = cls.py();
142 let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
143 let typed_dict_field = core_schema.getattr("typed_dict_field")?;
144 let str_schema = core_schema.getattr("str_schema")?;
145 let list_schema = core_schema.getattr("list_schema")?;
146 let cls_name = cls.name()?.to_string();
147 let (_, enum_subclass) = cls_name.split_once("_").unwrap();
148 let typed_dict_fields = PyDict::new(py);
149 typed_dict_fields.set_item(
150 "type",
151 typed_dict_field.call1((str_schema.call(
152 (),
153 Some(&PyDict::from_sequence(
154 &vec![(
155 "pattern",
156 vec![
157 "^".to_owned(),
158 enum_subclass.to_ascii_lowercase(),
159 "$".to_owned(),
160 ]
161 .join("")
162 .into_pyobject(py)?
163 .into_any(),
164 )]
165 .into_pyobject(py)?,
166 )?),
167 )?,))?,
168 )?;
169 if cls.eq(InitStrategy_SOME::type_object(py))? {
170 typed_dict_fields.set_item(
171 "kwargs",
172 typed_dict_field.call1((list_schema.call1((str_schema.call0()?,))?,))?,
173 )?;
174 }
175 typed_dict_schema.call1((typed_dict_fields,))
176}
177
178#[pymethods]
179impl InitStrategy {
180 #[classmethod]
182 fn __get_pydantic_core_schema__<'py>(
183 cls: &Bound<'py, PyType>,
184 _source_type: Bound<'py, PyAny>,
185 handler: Bound<'py, PyAny>,
186 ) -> PyResult<Bound<'py, PyAny>> {
187 let py = cls.py();
188 let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
189 if cls.eq(InitStrategy::type_object(py))? {
190 let union_list = create_union!(handler, py, ALL, SOME, NONE)?;
191 return core_schema.call_method1("union_schema", (union_list,));
192 }
193 let python_schema = core_schema.getattr("is_instance_schema")?.call1((cls,))?;
194 core_schema.getattr("json_or_python_schema")?.call1((
195 core_schema.getattr("chain_schema")?.call1((vec![
196 get_enum_subclass_typed_dict_schema(cls, &core_schema)?,
197 core_schema
198 .getattr("no_info_before_validator_function")?
199 .call1((get_enum_subclass_before_validator_fn(cls)?, &python_schema))?,
200 ],))?,
201 python_schema,
202 ))
203 }
204
205 pub fn to_json(&self) -> PyResult<PyObject> {
206 Python::with_gil(|py| {
207 let data = PyDict::new(py);
208 data.set_item("type", self.to_string().to_ascii_lowercase())?;
209 if let InitStrategy::SOME { kwargs } = self {
210 data.set_item("kwargs", kwargs)?;
211 }
212 Ok(data.into_any().unbind())
213 })
214 }
215}
216
217#[derive(Clone, Debug)]
218pub enum InternalInitStrategy {
219 ALL(Py<PyDict>),
220 SOME(Py<PyDict>, HashSet<usize>),
221 NONE,
222}
223
224impl DataclassSerde {
225 pub fn new(
226 class: PyObject,
227 init_strategy: InitStrategy,
228 field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
229 ) -> PyResult<Self> {
230 let internal_init_strategy = match &init_strategy {
231 InitStrategy::ALL {} => Python::with_gil::<_, PyResult<_>>(|py| {
232 let kwargs_kv_list = field_serde_kv_list
233 .iter()
234 .map(|(field, _)| (field, None::<PyObject>))
235 .collect::<Vec<_>>();
236 let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
237 Ok(InternalInitStrategy::ALL(kwargs))
238 })?,
239 InitStrategy::SOME { kwargs } => Python::with_gil::<_, PyResult<_>>(|py| {
240 let init_field_idxs = kwargs.iter().map(|init_field| field_serde_kv_list.iter().position(|(field, _)| field.to_string() == *init_field).ok_or_else(|| PyValueError::new_err(format!("field name {} provided in InitStrategy_SOME not contained in field_serde_kv_list", init_field)))).collect::<PyResult<HashSet<_>>>()?;
241 let kwargs_kv_list = field_serde_kv_list
242 .iter()
243 .enumerate()
244 .filter(|(idx, _)| init_field_idxs.contains(idx))
245 .map(|(_, (field, _))| (field, None::<PyObject>))
246 .collect::<Vec<_>>();
247 let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
248 Ok(InternalInitStrategy::SOME(kwargs, init_field_idxs))
249 })?,
250 InitStrategy::NONE {} => InternalInitStrategy::NONE,
251 };
252 Ok(DataclassSerde {
253 class,
254 init_strategy: internal_init_strategy,
255 field_serde_kv_list,
256 })
257 }
258}
259
260impl PyAnySerde for DataclassSerde {
261 fn append<'py>(
262 &mut self,
263 buf: &mut [u8],
264 mut offset: usize,
265 obj: &Bound<'py, PyAny>,
266 ) -> PyResult<usize> {
267 for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
268 offset = pyany_serde.append(buf, offset, &obj.getattr(&*field)?)?;
269 }
270 Ok(offset)
271 }
272
273 fn append_vec<'py>(
274 &mut self,
275 v: &mut Vec<u8>,
276 start_addr: Option<usize>,
277 obj: &Bound<'py, PyAny>,
278 ) -> PyResult<()> {
279 for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
280 pyany_serde.append_vec(v, start_addr, &obj.getattr(&*field)?)?;
281 }
282 Ok(())
283 }
284
285 fn retrieve<'py>(
286 &mut self,
287 py: Python<'py>,
288 buf: &[u8],
289 mut offset: usize,
290 ) -> PyResult<(Bound<'py, PyAny>, usize)> {
291 let mut kv_list = Vec::with_capacity(self.field_serde_kv_list.len());
292 for (field, pyany_serde) in self.field_serde_kv_list.iter_mut() {
293 let field_value;
294 (field_value, offset) = pyany_serde.retrieve(py, buf, offset)?;
295 kv_list.push((field.clone_ref(py).into_bound(py), field_value));
296 }
297 let class = self.class.bind(py);
298 let obj = match &self.init_strategy {
299 InternalInitStrategy::ALL(py_kwargs) => {
300 let kwargs = py_kwargs.bind(py);
301 for (field, field_value) in kv_list.iter() {
302 kwargs.set_item(field, field_value)?;
303 }
304 class.call((), Some(kwargs))?
305 }
306 InternalInitStrategy::SOME(py_kwargs, init_field_idxs) => {
307 let kwargs = py_kwargs.bind(py);
308 let (init_kv_list, other_kv_list) = kv_list
309 .into_iter()
310 .enumerate()
311 .partition::<Vec<_>, _>(|(idx, _)| init_field_idxs.contains(idx));
312 for (_, (field, field_value)) in init_kv_list.iter() {
313 kwargs.set_item(field, field_value)?;
314 }
315 let obj = class.call((), Some(kwargs))?;
316 for (_, (field, field_value)) in other_kv_list.iter() {
317 obj.setattr(field, field_value)?;
318 }
319 obj
320 }
321 InternalInitStrategy::NONE => {
322 let obj = class.call0()?;
323 for (field, field_value) in kv_list.iter() {
324 obj.setattr(field, field_value)?;
325 }
326 obj
327 }
328 };
329 Ok((obj, offset))
330 }
331}