use std::collections::HashSet;
use pyo3::exceptions::asyncio::InvalidStateError;
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyCFunction, PyDict, PyString, PyTuple, PyType};
use pyo3::{prelude::*, PyTypeInfo};
use strum_macros::Display;
use crate::communication::{append_string_vec, retrieve_string, retrieve_usize};
use crate::PyAnySerde;
/// Serde for dataclass-like Python objects.
///
/// Each listed attribute is (de)serialized in order with its own [`PyAnySerde`].
/// On retrieval, instances are rebuilt by calling `class` according to `init_strategy`.
#[derive(Clone)]
pub struct DataclassSerde {
    /// Python class that is called to construct instances when retrieving.
    class: PyObject,
    /// Precomputed form of the user-facing `InitStrategy` (built in `DataclassSerde::new`).
    init_strategy: InternalInitStrategy,
    /// `(attribute name, serde)` pairs; append and retrieve both walk this list in order.
    field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
}
/// Python-visible wrapper around an optional [`InitStrategy`].
///
/// Pickling goes through a byte-encoded state (`__getstate__` / `__setstate__`).
#[pyclass]
#[derive(Clone)]
pub struct PickleableInitStrategy(pub Option<InitStrategy>);
#[pymethods]
impl PickleableInitStrategy {
    /// Constructs the wrapper from zero or one positional arguments.
    ///
    /// With no arguments the wrapped strategy is `None`; this is what the pickle
    /// machinery uses before calling `__setstate__`. With one argument, that
    /// argument is extracted as `Option<InitStrategy>`.
    ///
    /// # Errors
    /// Raises `ValueError` when more than one argument is given. Propagates
    /// extraction errors when the single argument is not an `InitStrategy` or `None`.
    #[new]
    #[pyo3(signature = (*args))]
    fn new<'py>(args: Bound<'py, PyTuple>) -> PyResult<Self> {
        match args.len() {
            0 => Ok(PickleableInitStrategy(None)),
            1 => Ok(PickleableInitStrategy(
                args.get_item(0)?.extract::<Option<InitStrategy>>()?,
            )),
            _ => Err(PyValueError::new_err(format!(
                "PickleableInitStrategy constructor takes 0 or 1 parameters, received {}",
                args.as_any().repr()?.to_str()?
            ))),
        }
    }

    /// Encodes the wrapped strategy as bytes for pickling.
    ///
    /// The layout is one type byte, optionally followed by a payload:
    /// - `0`: `ALL`
    /// - `1`: `SOME`, followed by the kwarg count as a native-endian `usize`,
    ///   then each kwarg written with `append_string_vec`
    /// - `2`: `NONE`
    /// - `3`: no strategy (`None`)
    pub fn __getstate__(&self) -> Vec<u8> {
        match &self.0 {
            Some(InitStrategy::ALL {}) => vec![0],
            Some(InitStrategy::SOME { kwargs }) => {
                let mut bytes = vec![1];
                bytes.extend_from_slice(&kwargs.len().to_ne_bytes());
                for kwarg in kwargs {
                    append_string_vec(&mut bytes, kwarg);
                }
                bytes
            }
            Some(InitStrategy::NONE {}) => vec![2],
            // A default-constructed wrapper holds `None`. Previously this path
            // panicked via `unwrap()`; encode it explicitly so it round-trips.
            None => vec![3],
        }
    }

    /// Restores the wrapped strategy from bytes produced by `__getstate__`.
    ///
    /// # Errors
    /// Raises `InvalidStateError` when `state` is empty or its type byte is
    /// unknown. Propagates decoding errors from `retrieve_usize` /
    /// `retrieve_string` for a malformed `SOME` payload. On error, `self` is
    /// left unchanged.
    pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
        let buf = &state[..];
        let type_byte = buf.first().copied().ok_or_else(|| {
            InvalidStateError::new_err("Got empty state for InitStrategy")
        })?;
        self.0 = match type_byte {
            0 => Some(InitStrategy::ALL {}),
            1 => {
                let (n_kwargs, mut offset) = retrieve_usize(buf, 1)?;
                // `n_kwargs` comes from the serialized bytes, so it is untrusted.
                // Don't let it drive an unbounded up-front allocation; presumably
                // every encoded kwarg takes at least one byte of `buf`.
                let mut kwargs = Vec::with_capacity(n_kwargs.min(buf.len()));
                for _ in 0..n_kwargs {
                    let kwarg;
                    (kwarg, offset) = retrieve_string(buf, offset)?;
                    kwargs.push(kwarg);
                }
                Some(InitStrategy::SOME { kwargs })
            }
            2 => Some(InitStrategy::NONE {}),
            3 => None,
            v => {
                return Err(InvalidStateError::new_err(format!(
                    "Got invalid type byte for InitStrategy: {v}"
                )))
            }
        };
        Ok(())
    }
}
/// How `DataclassSerde` should construct objects when deserializing.
///
/// The strum-derived `Display` yields the variant name ("ALL", "SOME", "NONE").
/// `to_json` emits it lowercased as the `"type"` tag.
#[pyclass]
#[derive(Clone, Debug, PartialEq, Display)]
pub enum InitStrategy {
    // Pass every field to the class constructor as a keyword argument.
    ALL {},
    // Pass only the named fields as keyword arguments; set the rest via setattr afterwards.
    SOME { kwargs: Vec<String> },
    // Call the class with no arguments, then set every field via setattr.
    NONE {},
}
/// Asks `handler` (a pydantic schema handler) for the schema of each listed
/// `InitStrategy_<VARIANT>` class, in the order given.
///
/// Expands to a `Result<Vec<_>, PyErr>`. A failing `generate_schema` call
/// propagates via `?` out of the *enclosing function*, so the expansion itself
/// only ever evaluates to `Ok`.
macro_rules! create_union {
    ($handler:expr, $py:expr, $($type:ident),+) => {{
        let variant_schemas = vec![
            $(
                $handler.call_method1(
                    "generate_schema",
                    (paste::paste! { [<InitStrategy_ $type>]::type_object($py) },),
                )?
            ),+
        ];
        Ok::<_, PyErr>(variant_schemas)
    }};
}
/// Builds a Python callable usable as a pydantic "before" validator for one
/// `InitStrategy` variant class.
///
/// The callable takes the already-validated input as its first positional
/// argument. It returns a new instance of the variant matching `cls`; for
/// `InitStrategy_SOME`, it reads that input's `"kwargs"` entry as `Vec<String>`.
///
/// # Errors
/// The returned callable raises `ValueError` if `cls` is not one of the three
/// variant classes. It propagates lookup/extraction errors for `"kwargs"`.
fn get_enum_subclass_before_validator_fn<'py>(
    cls: &Bound<'py, PyType>,
) -> PyResult<Bound<'py, PyCFunction>> {
    let py = cls.py();
    // Unbind so the closure owns a GIL-independent handle (required for `new_closure`).
    let target_cls = cls.clone().unbind();
    let validator = move |args: &Bound<'_, PyTuple>,
                          _kwargs: Option<&Bound<'_, PyDict>>|
          -> PyResult<PyObject> {
        let py = args.py();
        let data = args.get_item(0)?;
        let variant_cls = target_cls.bind(py);
        let strategy = if variant_cls.eq(InitStrategy_ALL::type_object(py))? {
            InitStrategy::ALL {}
        } else if variant_cls.eq(InitStrategy_SOME::type_object(py))? {
            let kwargs = data.get_item("kwargs")?.extract::<Vec<String>>()?;
            InitStrategy::SOME { kwargs }
        } else if variant_cls.eq(InitStrategy_NONE::type_object(py))? {
            InitStrategy::NONE {}
        } else {
            return Err(PyValueError::new_err(format!(
                "Unexpected class: {}",
                variant_cls.repr()?.to_str()?
            )));
        };
        Ok(strategy.into_pyobject(py)?.into_any().unbind())
    };
    PyCFunction::new_closure(py, None, None, validator)
}
fn get_enum_subclass_typed_dict_schema<'py>(
cls: &Bound<'py, PyType>,
core_schema: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyAny>> {
let py = cls.py();
let typed_dict_schema = core_schema.getattr("typed_dict_schema")?;
let typed_dict_field = core_schema.getattr("typed_dict_field")?;
let str_schema = core_schema.getattr("str_schema")?;
let list_schema = core_schema.getattr("list_schema")?;
let cls_name = cls.name()?.to_string();
let (_, enum_subclass) = cls_name.split_once("_").unwrap();
let typed_dict_fields = PyDict::new(py);
typed_dict_fields.set_item(
"type",
typed_dict_field.call1((str_schema.call(
(),
Some(&PyDict::from_sequence(
&vec![(
"pattern",
vec![
"^".to_owned(),
enum_subclass.to_ascii_lowercase(),
"$".to_owned(),
]
.join("")
.into_pyobject(py)?
.into_any(),
)]
.into_pyobject(py)?,
)?),
)?,))?,
)?;
if cls.eq(InitStrategy_SOME::type_object(py))? {
typed_dict_fields.set_item(
"kwargs",
typed_dict_field.call1((list_schema.call1((str_schema.call0()?,))?,))?,
)?;
}
typed_dict_schema.call1((typed_dict_fields,))
}
#[pymethods]
impl InitStrategy {
    /// Pydantic hook that supplies a core schema for `InitStrategy` and its variant classes.
    ///
    /// - Base class: a union of each variant class's schema, obtained through
    ///   `handler.generate_schema`.
    /// - Variant class, Python input: an `is_instance` check.
    /// - Variant class, JSON input: a chain. It first validates the typed dict
    ///   (`{"type": ..., "kwargs"?: [...]}`), then converts it to the variant
    ///   instance and applies the `is_instance` check.
    #[classmethod]
    fn __get_pydantic_core_schema__<'py>(
        cls: &Bound<'py, PyType>,
        _source_type: Bound<'py, PyAny>,
        handler: Bound<'py, PyAny>,
    ) -> PyResult<Bound<'py, PyAny>> {
        let py = cls.py();
        let core_schema = py.import("pydantic_core")?.getattr("core_schema")?;
        if cls.eq(InitStrategy::type_object(py))? {
            let variant_schemas = create_union!(handler, py, ALL, SOME, NONE)?;
            core_schema.call_method1("union_schema", (variant_schemas,))
        } else {
            let instance_schema = core_schema.getattr("is_instance_schema")?.call1((cls,))?;
            let typed_dict = get_enum_subclass_typed_dict_schema(cls, &core_schema)?;
            let to_instance = core_schema
                .getattr("no_info_before_validator_function")?
                .call1((get_enum_subclass_before_validator_fn(cls)?, &instance_schema))?;
            let json_schema = core_schema
                .getattr("chain_schema")?
                .call1((vec![typed_dict, to_instance],))?;
            core_schema
                .getattr("json_or_python_schema")?
                .call1((json_schema, instance_schema))
        }
    }

    /// Returns a Python dict describing this strategy.
    ///
    /// The dict has `"type"` (the lowercased variant name) and, for `SOME`,
    /// `"kwargs"`.
    pub fn to_json(&self) -> PyResult<PyObject> {
        Python::with_gil(|py| {
            let json_dict = PyDict::new(py);
            let type_tag = self.to_string().to_ascii_lowercase();
            json_dict.set_item("type", type_tag)?;
            match self {
                InitStrategy::SOME { kwargs } => json_dict.set_item("kwargs", kwargs)?,
                InitStrategy::ALL {} | InitStrategy::NONE {} => {}
            }
            Ok(json_dict.into_any().unbind())
        })
    }
}
/// Resolved form of [`InitStrategy`], used when `DataclassSerde` retrieves objects.
///
/// The dicts are kwargs buffers built once in `DataclassSerde::new`. Their keys
/// are pre-seeded with field names mapped to `None`, and retrieval overwrites
/// the values in place on every call.
#[derive(Clone, Debug)]
pub enum InternalInitStrategy {
    /// Every field is passed to the class constructor as a keyword argument.
    ALL(Py<PyDict>),
    /// Fields whose indices (into `field_serde_kv_list`) are in the set are passed
    /// as keyword arguments; the remaining fields are set with `setattr` afterwards.
    SOME(Py<PyDict>, HashSet<usize>),
    /// The class is called with no arguments and every field is set with `setattr`.
    NONE,
}
impl DataclassSerde {
pub fn new(
class: PyObject,
init_strategy: InitStrategy,
field_serde_kv_list: Vec<(Py<PyString>, Box<dyn PyAnySerde>)>,
) -> PyResult<Self> {
let internal_init_strategy = match &init_strategy {
InitStrategy::ALL {} => Python::with_gil::<_, PyResult<_>>(|py| {
let kwargs_kv_list = field_serde_kv_list
.iter()
.map(|(field, _)| (field, None::<PyObject>))
.collect::<Vec<_>>();
let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
Ok(InternalInitStrategy::ALL(kwargs))
})?,
InitStrategy::SOME { kwargs } => Python::with_gil::<_, PyResult<_>>(|py| {
let init_field_idxs = kwargs.iter().map(|init_field| field_serde_kv_list.iter().position(|(field, _)| field.to_string() == *init_field).ok_or_else(|| PyValueError::new_err(format!("field name {} provided in InitStrategy_SOME not contained in field_serde_kv_list", init_field)))).collect::<PyResult<HashSet<_>>>()?;
let kwargs_kv_list = field_serde_kv_list
.iter()
.enumerate()
.filter(|(idx, _)| init_field_idxs.contains(idx))
.map(|(_, (field, _))| (field, None::<PyObject>))
.collect::<Vec<_>>();
let kwargs = PyDict::from_sequence(&kwargs_kv_list.into_pyobject(py)?)?.unbind();
Ok(InternalInitStrategy::SOME(kwargs, init_field_idxs))
})?,
InitStrategy::NONE {} => InternalInitStrategy::NONE,
};
Ok(DataclassSerde {
class,
init_strategy: internal_init_strategy,
field_serde_kv_list,
})
}
}
impl PyAnySerde for DataclassSerde {
    /// Writes each field of `obj` into `buf` in field-list order, starting at
    /// `offset`, and returns the offset just past the last written byte.
    fn append<'py>(
        &mut self,
        buf: &mut [u8],
        offset: usize,
        obj: &Bound<'py, PyAny>,
    ) -> PyResult<usize> {
        // Thread the write cursor through each field's serde.
        self.field_serde_kv_list
            .iter_mut()
            .try_fold(offset, |cursor, (name, serde)| {
                let value = obj.getattr(&*name)?;
                serde.append(&mut *buf, cursor, &value)
            })
    }

    /// Appends each field of `obj` to `v` in field-list order, passing
    /// `start_addr` through unchanged to every field serde.
    fn append_vec<'py>(
        &mut self,
        v: &mut Vec<u8>,
        start_addr: Option<usize>,
        obj: &Bound<'py, PyAny>,
    ) -> PyResult<()> {
        self.field_serde_kv_list
            .iter_mut()
            .try_for_each(|(name, serde)| {
                serde.append_vec(&mut *v, start_addr, &obj.getattr(&*name)?)
            })
    }

    /// Decodes all field values starting at `offset`, then builds an instance
    /// of `class` as directed by the init strategy.
    ///
    /// Returns the instance and the offset just past the consumed bytes.
    fn retrieve<'py>(
        &mut self,
        py: Python<'py>,
        buf: &[u8],
        offset: usize,
    ) -> PyResult<(Bound<'py, PyAny>, usize)> {
        let mut cursor = offset;
        let mut decoded = Vec::with_capacity(self.field_serde_kv_list.len());
        for (name, serde) in &mut self.field_serde_kv_list {
            let (value, next_cursor) = serde.retrieve(py, buf, cursor)?;
            cursor = next_cursor;
            decoded.push((name.bind(py).clone(), value));
        }

        let class = self.class.bind(py);
        let instance = match &self.init_strategy {
            InternalInitStrategy::NONE => {
                let instance = class.call0()?;
                for (name, value) in &decoded {
                    instance.setattr(name, value)?;
                }
                instance
            }
            InternalInitStrategy::ALL(py_kwargs) => {
                let kwargs = py_kwargs.bind(py);
                for (name, value) in &decoded {
                    kwargs.set_item(name, value)?;
                }
                class.call((), Some(kwargs))?
            }
            InternalInitStrategy::SOME(py_kwargs, init_field_idxs) => {
                let kwargs = py_kwargs.bind(py);
                // Constructor fields go into kwargs first...
                for (idx, (name, value)) in decoded.iter().enumerate() {
                    if init_field_idxs.contains(&idx) {
                        kwargs.set_item(name, value)?;
                    }
                }
                let instance = class.call((), Some(kwargs))?;
                // ...then every other field is assigned on the constructed object.
                for (idx, (name, value)) in decoded.iter().enumerate() {
                    if !init_field_idxs.contains(&idx) {
                        instance.setattr(name, value)?;
                    }
                }
                instance
            }
        };
        Ok((instance, cursor))
    }
}