serde_pyobject/
de.rs

1use crate::{
2    dataclass::dataclass_as_dict,
3    error::{Error, Result},
4    pydantic::pydantic_model_as_dict,
5};
6use pyo3::{types::*, Bound};
7use serde::{
8    de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
9    forward_to_deserialize_any, Deserialize, Deserializer,
10};
11
12/// Deserialize a Python object into Rust type `T: Deserialize`.
13///
14/// # Examples
15///
16/// ## primitive
17///
18/// ```
19/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
20/// use serde_pyobject::from_pyobject;
21///
22/// Python::attach(|py| {
23///     // integer
24///     let any: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
25///     let i: i32 = from_pyobject(any.into_bound(py)).unwrap();
26///     assert_eq!(i, 42);
27///
28///     // float
29///     let any: Py<PyAny> = 0.1_f32.into_bound_py_any(py).unwrap().unbind();
30///     let x: f32 = from_pyobject(any.into_bound(py)).unwrap();
31///     assert_eq!(x, 0.1);
32///
33///     // bool
34///     let any: Py<PyAny> = true.into_bound_py_any(py).unwrap().unbind();
35///     let x: bool = from_pyobject(any.into_bound(py)).unwrap();
36///     assert_eq!(x, true);
37/// });
38/// ```
39///
40/// ## option
41///
42/// ```
43/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
44/// use serde_pyobject::from_pyobject;
45///
46/// Python::attach(|py| {
47///     let none = py.None();
48///     let option: Option<i32> = from_pyobject(none.into_bound(py)).unwrap();
49///     assert_eq!(option, None);
50///
51///     let py_int: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
52///     let i: Option<i32> = from_pyobject(py_int.into_bound(py)).unwrap();
53///     assert_eq!(i, Some(42));
54/// })
55/// ```
56///
57/// ## unit
58///
59/// ```
60/// use pyo3::{Python, types::PyTuple};
61/// use serde_pyobject::from_pyobject;
62///
63/// Python::attach(|py| {
64///     let py_unit = PyTuple::empty(py);
65///     let unit: () = from_pyobject(py_unit).unwrap();
66///     assert_eq!(unit, ());
67/// })
68/// ```
69///
70/// ## unit_struct
71///
72/// ```
73/// use serde::Deserialize;
74/// use pyo3::{Python, types::PyTuple};
75/// use serde_pyobject::from_pyobject;
76///
77/// #[derive(Debug, PartialEq, Deserialize)]
78/// struct UnitStruct;
79///
80/// Python::attach(|py| {
81///     let py_unit = PyTuple::empty(py);
82///     let unit: UnitStruct = from_pyobject(py_unit).unwrap();
83///     assert_eq!(unit, UnitStruct);
84/// })
85/// ```
86///
87/// ## unit variant
88///
89/// ```
90/// use serde::Deserialize;
91/// use pyo3::{Python, types::PyString};
92/// use serde_pyobject::from_pyobject;
93///
94/// #[derive(Debug, PartialEq, Deserialize)]
95/// enum E {
96///     A,
97///     B,
98/// }
99///
100/// Python::attach(|py| {
101///     let any = PyString::new(py, "A");
102///     let out: E = from_pyobject(any).unwrap();
103///     assert_eq!(out, E::A);
104/// })
105/// ```
106///
107/// ## newtype struct
108///
109/// ```
110/// use serde::Deserialize;
111/// use pyo3::{Python, Bound, PyAny, IntoPyObject};
112/// use serde_pyobject::from_pyobject;
113///
114/// #[derive(Debug, PartialEq, Deserialize)]
115/// struct NewTypeStruct(u8);
116///
117/// Python::attach(|py| {
118///     let any: Bound<PyAny> = 1_u32.into_pyobject(py).unwrap().into_any();
119///     let obj: NewTypeStruct = from_pyobject(any).unwrap();
120///     assert_eq!(obj, NewTypeStruct(1));
121/// });
122/// ```
123///
124/// ## newtype variant
125///
126/// ```
127/// use serde::Deserialize;
128/// use pyo3::Python;
129/// use serde_pyobject::{from_pyobject, pydict};
130///
131/// #[derive(Debug, PartialEq, Deserialize)]
132/// enum NewTypeVariant {
133///     N(u8),
134/// }
135///
136/// Python::attach(|py| {
137///     let dict = pydict! { py, "N" => 41 }.unwrap();
138///     let obj: NewTypeVariant = from_pyobject(dict).unwrap();
139///     assert_eq!(obj, NewTypeVariant::N(41));
140/// });
141/// ```
142///
143/// ## seq
144///
145/// ```
146/// use pyo3::Python;
147/// use serde_pyobject::{from_pyobject, pylist};
148///
149/// Python::attach(|py| {
150///     let list = pylist![py; 1, 2, 3].unwrap();
151///     let seq: Vec<i32> = from_pyobject(list).unwrap();
152///     assert_eq!(seq, vec![1, 2, 3]);
153/// });
154/// ```
155///
156/// ## tuple
157///
158/// ```
159/// use pyo3::{Python, types::PyTuple};
160/// use serde_pyobject::from_pyobject;
161///
162/// Python::attach(|py| {
163///     let tuple = PyTuple::new(py, &[1, 2, 3]).unwrap();
164///     let tuple: (i32, i32, i32) = from_pyobject(tuple).unwrap();
165///     assert_eq!(tuple, (1, 2, 3));
166/// });
167/// ```
168///
169/// ## tuple struct
170///
171/// ```
172/// use serde::Deserialize;
173/// use pyo3::{Python, IntoPyObject, types::PyTuple};
174/// use serde_pyobject::from_pyobject;
175///
176/// #[derive(Debug, PartialEq, Deserialize)]
177/// struct T(u8, String);
178///
179/// Python::attach(|py| {
180///     let tuple = PyTuple::new(py, &[1_u32.into_pyobject(py).unwrap().into_any(), "test".into_pyobject(py).unwrap().into_any()]).unwrap();
181///     let obj: T = from_pyobject(tuple).unwrap();
182///     assert_eq!(obj, T(1, "test".to_string()));
183/// });
184/// ```
185///
186/// ## tuple variant
187///
188/// ```
189/// use serde::Deserialize;
190/// use pyo3::Python;
191/// use serde_pyobject::{from_pyobject, pydict};
192///
193/// #[derive(Debug, PartialEq, Deserialize)]
194/// enum TupleVariant {
195///     T(u8, u8),
196/// }
197///
198/// Python::attach(|py| {
199///     let dict = pydict! { py, "T" => (1, 2) }.unwrap();
200///     let obj: TupleVariant = from_pyobject(dict).unwrap();
201///     assert_eq!(obj, TupleVariant::T(1, 2));
202/// });
203/// ```
204///
205/// ## map
206///
207/// ```
208/// use pyo3::Python;
209/// use serde_pyobject::{from_pyobject, pydict};
210/// use std::collections::BTreeMap;
211///
212/// Python::attach(|py| {
213///     let dict = pydict! { py,
214///         "a" => "hom",
215///         "b" => "test"
216///     }
217///     .unwrap();
218///     let map: BTreeMap<String, String> = from_pyobject(dict).unwrap();
219///     assert_eq!(map.get("a"), Some(&"hom".to_string()));
220///     assert_eq!(map.get("b"), Some(&"test".to_string()));
221/// });
222/// ```
223///
224/// ## struct
225///
226/// ```
227/// use serde::Deserialize;
228/// use pyo3::Python;
229/// use serde_pyobject::{from_pyobject, pydict};
230///
231/// #[derive(Debug, PartialEq, Deserialize)]
232/// struct A {
233///     a: i32,
234///     b: String,
235/// }
236///
237/// Python::attach(|py| {
238///     let dict = pydict! {
239///         "a" => 1,
240///         "b" => "test"
241///     }
242///     .unwrap();
243///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
244///     assert_eq!(
245///         a,
246///         A {
247///             a: 1,
248///             b: "test".to_string()
249///         }
250///     );
251/// });
252///
253/// Python::attach(|py| {
254///     let dict = pydict! {
255///         "A" => pydict! {
256///             "a" => 1,
257///             "b" => "test"
258///         }
259///         .unwrap()
260///     }
261///     .unwrap();
262///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
263///     assert_eq!(
264///         a,
265///         A {
266///             a: 1,
267///             b: "test".to_string()
268///         }
269///     );
270/// });
271/// ```
272///
273/// ## struct variant
274///
275/// ```
276/// use serde::Deserialize;
277/// use pyo3::Python;
278/// use serde_pyobject::{from_pyobject, pydict};
279///
280/// #[derive(Debug, PartialEq, Deserialize)]
281/// enum StructVariant {
282///     S { r: u8, g: u8, b: u8 },
283/// }
284///
285/// Python::attach(|py| {
286///     let dict = pydict! {
287///         py,
288///         "S" => pydict! {
289///             "r" => 1,
290///             "g" => 2,
291///             "b" => 3
292///         }.unwrap()
293///     }
294///     .unwrap();
295///     let obj: StructVariant = from_pyobject(dict).unwrap();
296///     assert_eq!(obj, StructVariant::S { r: 1, g: 2, b: 3 });
297/// });
298/// ```
299pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
300    let any = any.into_any();
301    T::deserialize(PyAnyDeserializer(any))
302}
303
304struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
305
306impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
307    type Error = Error;
308
309    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
310    where
311        V: Visitor<'de>,
312    {
313        if self.0.is_instance_of::<PyDict>() {
314            return visitor.visit_map(MapDeserializer::new(self.0.cast()?));
315        }
316        if self.0.is_instance_of::<PyList>() {
317            return visitor.visit_seq(SeqDeserializer::from_list(self.0.cast()?));
318        }
319        if self.0.is_instance_of::<PyTuple>() {
320            return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.cast()?));
321        }
322        if self.0.is_instance_of::<PyString>() {
323            return visitor.visit_str(&self.0.extract::<String>()?);
324        }
325        if self.0.is_instance_of::<PyBool>() {
326            // must be match before PyLong
327            return visitor.visit_bool(self.0.extract()?);
328        }
329        if self.0.is_instance_of::<PyInt>() {
330            return visitor.visit_i64(self.0.extract()?);
331        }
332        if self.0.is_instance_of::<PyFloat>() {
333            return visitor.visit_f64(self.0.extract()?);
334        }
335        if let Some(dict) = dataclass_as_dict(self.0.py(), &self.0)? {
336            return visitor.visit_map(MapDeserializer::new(&dict));
337        }
338        if let Some(dict) = pydantic_model_as_dict(self.0.py(), &self.0)? {
339            return visitor.visit_map(MapDeserializer::new(&dict));
340        }
341        if self.0.hasattr("__dict__")? {
342            return visitor.visit_map(MapDeserializer::new(self.0.getattr("__dict__")?.cast()?));
343        }
344        if self.0.hasattr("__slots__")? {
345            // __slots__ and __dict__ are mutually exclusive, see
346            // https://docs.python.org/3/reference/datamodel.html#slots
347            return visitor.visit_map(MapDeserializer::from_slots(&self.0)?);
348        }
349        if self.0.is_none() {
350            return visitor.visit_none();
351        }
352
353        unreachable!("Unsupported type: {}", self.0.get_type());
354    }
355
356    fn deserialize_struct<V: de::Visitor<'de>>(
357        self,
358        name: &'static str,
359        _fields: &'static [&'static str],
360        visitor: V,
361    ) -> Result<V::Value> {
362        // Nested dict `{ "A": { "a": 1, "b": 2 } }` is deserialized as `A { a: 1, b: 2 }`
363        if self.0.is_instance_of::<PyDict>() {
364            let dict: &Bound<PyDict> = self.0.cast()?;
365            if let Some(inner) = dict.get_item(name)? {
366                if let Ok(inner) = inner.cast() {
367                    return visitor.visit_map(MapDeserializer::new(inner));
368                }
369            }
370        }
371        // Default to `any` case
372        self.deserialize_any(visitor)
373    }
374
375    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
376        self,
377        _name: &'static str,
378        visitor: V,
379    ) -> Result<V::Value> {
380        visitor.visit_seq(SeqDeserializer {
381            seq_reversed: vec![self.0],
382        })
383    }
384
385    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
386        if self.0.is_none() {
387            visitor.visit_none()
388        } else {
389            visitor.visit_some(self)
390        }
391    }
392
393    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
394        if self.0.is(PyTuple::empty(self.0.py())) {
395            visitor.visit_unit()
396        } else {
397            self.deserialize_any(visitor)
398        }
399    }
400
401    fn deserialize_unit_struct<V: de::Visitor<'de>>(
402        self,
403        _name: &'static str,
404        visitor: V,
405    ) -> Result<V::Value> {
406        if self.0.is(PyTuple::empty(self.0.py())) {
407            visitor.visit_unit()
408        } else {
409            self.deserialize_any(visitor)
410        }
411    }
412
413    fn deserialize_enum<V: de::Visitor<'de>>(
414        self,
415        _name: &'static str,
416        _variants: &'static [&'static str],
417        visitor: V,
418    ) -> Result<V::Value> {
419        if self.0.is_instance_of::<PyString>() {
420            let variant: String = self.0.extract()?;
421            let py = self.0.py();
422            let none = py.None().into_bound(py);
423            return visitor.visit_enum(EnumDeserializer {
424                variant: &variant,
425                inner: none,
426            });
427        }
428        if self.0.is_instance_of::<PyDict>() {
429            let dict: &Bound<PyDict> = self.0.cast()?;
430            if dict.len() == 1 {
431                let key = dict.keys().get_item(0).unwrap();
432                let value = dict.values().get_item(0).unwrap();
433                if key.is_instance_of::<PyString>() {
434                    let variant: String = key.extract()?;
435                    return visitor.visit_enum(EnumDeserializer {
436                        variant: &variant,
437                        inner: value,
438                    });
439                }
440            }
441        }
442        self.deserialize_any(visitor)
443    }
444
445    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
446        self,
447        name: &'static str,
448        _len: usize,
449        visitor: V,
450    ) -> Result<V::Value> {
451        if self.0.is_instance_of::<PyDict>() {
452            let dict: &Bound<PyDict> = self.0.cast()?;
453            if let Some(value) = dict.get_item(name)? {
454                if value.is_instance_of::<PyTuple>() {
455                    let tuple: &Bound<PyTuple> = value.cast()?;
456                    return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
457                }
458            }
459        }
460        self.deserialize_any(visitor)
461    }
462
463    forward_to_deserialize_any! {
464        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
465        bytes byte_buf seq tuple
466        map identifier ignored_any
467    }
468}
469
470struct SeqDeserializer<'py> {
471    seq_reversed: Vec<Bound<'py, PyAny>>,
472}
473
474impl<'py> SeqDeserializer<'py> {
475    fn from_list(list: &Bound<'py, PyList>) -> Self {
476        let mut seq_reversed = Vec::new();
477        for item in list.iter().rev() {
478            seq_reversed.push(item);
479        }
480        Self { seq_reversed }
481    }
482
483    fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
484        let mut seq_reversed = Vec::new();
485        for item in tuple.iter().rev() {
486            seq_reversed.push(item);
487        }
488        Self { seq_reversed }
489    }
490}
491
492impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
493    type Error = Error;
494    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
495    where
496        T: de::DeserializeSeed<'de>,
497    {
498        self.seq_reversed.pop().map_or(Ok(None), |value| {
499            let value = seed.deserialize(PyAnyDeserializer(value))?;
500            Ok(Some(value))
501        })
502    }
503}
504
505struct MapDeserializer<'py> {
506    keys: Vec<Bound<'py, PyAny>>,
507    values: Vec<Bound<'py, PyAny>>,
508}
509
510impl<'py> MapDeserializer<'py> {
511    fn new(dict: &Bound<'py, PyDict>) -> Self {
512        let mut keys = Vec::new();
513        let mut values = Vec::new();
514        for (key, value) in dict.iter() {
515            keys.push(key);
516            values.push(value);
517        }
518        Self { keys, values }
519    }
520
521    fn from_slots(obj: &Bound<'py, PyAny>) -> Result<Self> {
522        let mut keys = vec![];
523        let mut values = vec![];
524        for key in obj.getattr("__slots__")?.try_iter()? {
525            let key = key?;
526            keys.push(key.clone());
527            let v = obj.getattr(key.str()?)?;
528            values.push(v);
529        }
530        Ok(Self { keys, values })
531    }
532}
533
534impl<'de> MapAccess<'de> for MapDeserializer<'_> {
535    type Error = Error;
536
537    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
538    where
539        K: de::DeserializeSeed<'de>,
540    {
541        if let Some(key) = self.keys.pop() {
542            let key = seed.deserialize(PyAnyDeserializer(key))?;
543            Ok(Some(key))
544        } else {
545            Ok(None)
546        }
547    }
548
549    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
550    where
551        V: de::DeserializeSeed<'de>,
552    {
553        if let Some(value) = self.values.pop() {
554            let value = seed.deserialize(PyAnyDeserializer(value))?;
555            Ok(value)
556        } else {
557            unreachable!()
558        }
559    }
560}
561
562// this lifetime is technically no longer 'py
563struct EnumDeserializer<'py> {
564    variant: &'py str,
565    inner: Bound<'py, PyAny>,
566}
567
568impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
569    type Error = Error;
570    type Variant = Self;
571
572    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
573    where
574        V: de::DeserializeSeed<'de>,
575    {
576        Ok((
577            seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
578            self,
579        ))
580    }
581}
582
583impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
584    type Error = Error;
585
586    fn unit_variant(self) -> Result<()> {
587        Ok(())
588    }
589
590    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
591    where
592        T: de::DeserializeSeed<'de>,
593    {
594        seed.deserialize(PyAnyDeserializer(self.inner))
595    }
596
597    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
598    where
599        V: Visitor<'de>,
600    {
601        PyAnyDeserializer(self.inner).deserialize_seq(visitor)
602    }
603
604    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
605    where
606        V: Visitor<'de>,
607    {
608        PyAnyDeserializer(self.inner).deserialize_map(visitor)
609    }
610}