use crate::{
dataclass::dataclass_as_dict,
error::{Error, Result},
pydantic::pydantic_model_as_dict,
};
use pyo3::{types::*, Bound};
use serde::{
de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
forward_to_deserialize_any, Deserialize, Deserializer,
};
/// Deserialize a Rust value of type `T` from a Python object.
///
/// The bound object is first converted with `into_any()` and then wrapped in
/// `PyAnyDeserializer`, which drives `T`'s `Deserialize` implementation.
///
/// # Errors
///
/// Returns the crate's `Error` when `T::deserialize` fails.
pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
    let deserializer = PyAnyDeserializer(any.into_any());
    T::deserialize(deserializer)
}
/// Serde deserializer that wraps a single bound Python object.
///
/// The `Deserializer` implementation inspects the wrapped object at runtime to
/// decide which visitor method to call.
struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
type Error = Error;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.0.is_instance_of::<PyDict>() {
return visitor.visit_map(MapDeserializer::new(self.0.cast()?));
}
if self.0.is_instance_of::<PyList>() {
return visitor.visit_seq(SeqDeserializer::from_list(self.0.cast()?));
}
if self.0.is_instance_of::<PyTuple>() {
return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.cast()?));
}
if self.0.is_instance_of::<PyString>() {
return visitor.visit_str(&self.0.extract::<String>()?);
}
if self.0.is_instance_of::<PyBool>() {
return visitor.visit_bool(self.0.extract()?);
}
if self.0.is_instance_of::<PyInt>() {
return visitor.visit_i64(self.0.extract()?);
}
if self.0.is_instance_of::<PyFloat>() {
return visitor.visit_f64(self.0.extract()?);
}
if let Some(dict) = dataclass_as_dict(self.0.py(), &self.0)? {
return visitor.visit_map(MapDeserializer::new(&dict));
}
if let Some(dict) = pydantic_model_as_dict(self.0.py(), &self.0)? {
return visitor.visit_map(MapDeserializer::new(&dict));
}
if self.0.hasattr("__dict__")? {
return visitor.visit_map(MapDeserializer::new(self.0.getattr("__dict__")?.cast()?));
}
if self.0.hasattr("__slots__")? {
return visitor.visit_map(MapDeserializer::from_slots(&self.0)?);
}
if self.0.is_none() {
return visitor.visit_none();
}
unreachable!("Unsupported type: {}", self.0.get_type());
}
fn deserialize_struct<V: de::Visitor<'de>>(
self,
name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &Bound<PyDict> = self.0.cast()?;
if let Some(inner) = dict.get_item(name)? {
if let Ok(inner) = inner.cast() {
return visitor.visit_map(MapDeserializer::new(inner));
}
}
}
self.deserialize_any(visitor)
}
fn deserialize_newtype_struct<V: de::Visitor<'de>>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value> {
visitor.visit_seq(SeqDeserializer {
seq_reversed: vec![self.0],
})
}
fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
if self.0.is_none() {
visitor.visit_none()
} else {
visitor.visit_some(self)
}
}
fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
if self.0.is(PyTuple::empty(self.0.py())) {
visitor.visit_unit()
} else {
self.deserialize_any(visitor)
}
}
fn deserialize_unit_struct<V: de::Visitor<'de>>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value> {
if self.0.is(PyTuple::empty(self.0.py())) {
visitor.visit_unit()
} else {
self.deserialize_any(visitor)
}
}
fn deserialize_enum<V: de::Visitor<'de>>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyString>() {
let variant: String = self.0.extract()?;
let py = self.0.py();
let none = py.None().into_bound(py);
return visitor.visit_enum(EnumDeserializer {
variant: &variant,
inner: none,
});
}
if self.0.is_instance_of::<PyDict>() {
let dict: &Bound<PyDict> = self.0.cast()?;
if dict.len() == 1 {
let key = dict.keys().get_item(0).unwrap();
let value = dict.values().get_item(0).unwrap();
if key.is_instance_of::<PyString>() {
let variant: String = key.extract()?;
return visitor.visit_enum(EnumDeserializer {
variant: &variant,
inner: value,
});
}
}
}
self.deserialize_any(visitor)
}
fn deserialize_tuple_struct<V: de::Visitor<'de>>(
self,
name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &Bound<PyDict> = self.0.cast()?;
if let Some(value) = dict.get_item(name)? {
if value.is_instance_of::<PyTuple>() {
let tuple: &Bound<PyTuple> = value.cast()?;
return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
}
}
}
self.deserialize_any(visitor)
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf seq tuple
map identifier ignored_any
}
}
/// Sequence accessor over Python objects.
///
/// Items are kept in reverse order so that `Vec::pop` returns them in their
/// original order.
struct SeqDeserializer<'py> {
    seq_reversed: Vec<Bound<'py, PyAny>>,
}

impl<'py> SeqDeserializer<'py> {
    /// Build the accessor from a Python list, storing its items back-to-front.
    fn from_list(list: &Bound<'py, PyList>) -> Self {
        Self {
            seq_reversed: list.iter().rev().collect(),
        }
    }

    /// Build the accessor from a Python tuple, storing its items back-to-front.
    fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
        Self {
            seq_reversed: tuple.iter().rev().collect(),
        }
    }
}
impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
    type Error = Error;

    /// Pop the next stored item and deserialize it with `seed`.
    ///
    /// Returns `Ok(None)` once no items remain.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.seq_reversed.pop() {
            Some(item) => {
                let element = seed.deserialize(PyAnyDeserializer(item))?;
                Ok(Some(element))
            }
            None => Ok(None),
        }
    }
}
/// Map accessor over parallel vectors of Python keys and values.
///
/// Both vectors are stored in reverse order: entries are consumed with
/// `Vec::pop`, so reversing up front makes them come out in the source's
/// iteration order.
struct MapDeserializer<'py> {
    keys: Vec<Bound<'py, PyAny>>,
    values: Vec<Bound<'py, PyAny>>,
}

impl<'py> MapDeserializer<'py> {
    /// Build the accessor from a Python dict, preserving the dict's iteration
    /// order for the consumer.
    fn new(dict: &Bound<'py, PyDict>) -> Self {
        let (mut keys, mut values): (Vec<_>, Vec<_>) = dict.iter().unzip();
        // Entries are popped from the back, so store them back-to-front.
        keys.reverse();
        values.reverse();
        Self { keys, values }
    }

    /// Build the accessor from an object's `__slots__`: each slot name is a
    /// key and the attribute of that name is its value.
    ///
    /// # Errors
    ///
    /// Returns an error when `__slots__` cannot be read or iterated, or when
    /// looking up one of the named attributes fails.
    fn from_slots(obj: &Bound<'py, PyAny>) -> Result<Self> {
        let mut keys = vec![];
        let mut values = vec![];
        for key in obj.getattr("__slots__")?.try_iter()? {
            let key = key?;
            // Resolve the attribute first so the key can be moved, not cloned.
            let value = obj.getattr(key.str()?)?;
            keys.push(key);
            values.push(value);
        }
        // Entries are popped from the back, so store them back-to-front.
        keys.reverse();
        values.reverse();
        Ok(Self { keys, values })
    }
}
impl<'de> MapAccess<'de> for MapDeserializer<'_> {
    type Error = Error;

    /// Pop the last stored key and deserialize it with `seed`.
    ///
    /// Returns `Ok(None)` once no keys remain.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        self.keys
            .pop()
            .map(|key| seed.deserialize(PyAnyDeserializer(key)))
            .transpose()
    }

    /// Pop the last stored value and deserialize it with `seed`.
    ///
    /// # Panics
    ///
    /// Panics if no value is left to pop.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        match self.values.pop() {
            Some(value) => seed.deserialize(PyAnyDeserializer(value)),
            None => unreachable!(),
        }
    }
}
/// Enum accessor: a variant name plus the Python object carrying the variant's
/// payload.
struct EnumDeserializer<'py> {
    // Variant name, deserialized through serde's `StrDeserializer`.
    variant: &'py str,
    // Payload object handed to `PyAnyDeserializer` for newtype, tuple and
    // struct variants; not read for unit variants.
    inner: Bound<'py, PyAny>,
}
impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: de::DeserializeSeed<'de>,
{
Ok((
seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
self,
))
}
}
impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
    type Error = Error;

    /// Unit variants carry no data; the stored payload is not inspected.
    fn unit_variant(self) -> Result<()> {
        Ok(())
    }

    /// Deserialize the payload object as the newtype variant's single field.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: de::DeserializeSeed<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        seed.deserialize(payload)
    }

    /// Deserialize the payload object through `deserialize_seq`.
    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        payload.deserialize_seq(visitor)
    }

    /// Deserialize the payload object through `deserialize_map`.
    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        payload.deserialize_map(visitor)
    }
}