serde_pyobject/de.rs
1use crate::{
2 dataclass::dataclass_as_dict,
3 error::{Error, Result},
4 pydantic::pydantic_model_as_dict,
5};
6use pyo3::{types::*, Bound};
7use serde::{
8 de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
9 forward_to_deserialize_any, Deserialize, Deserializer,
10};
11
12/// Deserialize a Python object into Rust type `T: Deserialize`.
13///
14/// # Examples
15///
16/// ## primitive
17///
18/// ```
19/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
20/// use serde_pyobject::from_pyobject;
21///
22/// Python::attach(|py| {
23/// // integer
24/// let any: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
25/// let i: i32 = from_pyobject(any.into_bound(py)).unwrap();
26/// assert_eq!(i, 42);
27///
28/// // float
29/// let any: Py<PyAny> = 0.1_f32.into_bound_py_any(py).unwrap().unbind();
30/// let x: f32 = from_pyobject(any.into_bound(py)).unwrap();
31/// assert_eq!(x, 0.1);
32///
33/// // bool
34/// let any: Py<PyAny> = true.into_bound_py_any(py).unwrap().unbind();
35/// let x: bool = from_pyobject(any.into_bound(py)).unwrap();
36/// assert_eq!(x, true);
37/// });
38/// ```
39///
40/// ## option
41///
42/// ```
43/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
44/// use serde_pyobject::from_pyobject;
45///
46/// Python::attach(|py| {
47/// let none = py.None();
48/// let option: Option<i32> = from_pyobject(none.into_bound(py)).unwrap();
49/// assert_eq!(option, None);
50///
51/// let py_int: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
52/// let i: Option<i32> = from_pyobject(py_int.into_bound(py)).unwrap();
53/// assert_eq!(i, Some(42));
54/// })
55/// ```
56///
57/// ## unit
58///
59/// ```
60/// use pyo3::{Python, types::PyTuple};
61/// use serde_pyobject::from_pyobject;
62///
63/// Python::attach(|py| {
64/// let py_unit = PyTuple::empty(py);
65/// let unit: () = from_pyobject(py_unit).unwrap();
66/// assert_eq!(unit, ());
67/// })
68/// ```
69///
70/// ## unit_struct
71///
72/// ```
73/// use serde::Deserialize;
74/// use pyo3::{Python, types::PyTuple};
75/// use serde_pyobject::from_pyobject;
76///
77/// #[derive(Debug, PartialEq, Deserialize)]
78/// struct UnitStruct;
79///
80/// Python::attach(|py| {
81/// let py_unit = PyTuple::empty(py);
82/// let unit: UnitStruct = from_pyobject(py_unit).unwrap();
83/// assert_eq!(unit, UnitStruct);
84/// })
85/// ```
86///
87/// ## unit variant
88///
89/// ```
90/// use serde::Deserialize;
91/// use pyo3::{Python, types::PyString};
92/// use serde_pyobject::from_pyobject;
93///
94/// #[derive(Debug, PartialEq, Deserialize)]
95/// enum E {
96/// A,
97/// B,
98/// }
99///
100/// Python::attach(|py| {
101/// let any = PyString::new(py, "A");
102/// let out: E = from_pyobject(any).unwrap();
103/// assert_eq!(out, E::A);
104/// })
105/// ```
106///
107/// ## newtype struct
108///
109/// ```
110/// use serde::Deserialize;
111/// use pyo3::{Python, Bound, PyAny, IntoPyObject};
112/// use serde_pyobject::from_pyobject;
113///
114/// #[derive(Debug, PartialEq, Deserialize)]
115/// struct NewTypeStruct(u8);
116///
117/// Python::attach(|py| {
118/// let any: Bound<PyAny> = 1_u32.into_pyobject(py).unwrap().into_any();
119/// let obj: NewTypeStruct = from_pyobject(any).unwrap();
120/// assert_eq!(obj, NewTypeStruct(1));
121/// });
122/// ```
123///
124/// ## newtype variant
125///
126/// ```
127/// use serde::Deserialize;
128/// use pyo3::Python;
129/// use serde_pyobject::{from_pyobject, pydict};
130///
131/// #[derive(Debug, PartialEq, Deserialize)]
132/// enum NewTypeVariant {
133/// N(u8),
134/// }
135///
136/// Python::attach(|py| {
137/// let dict = pydict! { py, "N" => 41 }.unwrap();
138/// let obj: NewTypeVariant = from_pyobject(dict).unwrap();
139/// assert_eq!(obj, NewTypeVariant::N(41));
140/// });
141/// ```
142///
143/// ## seq
144///
145/// ```
146/// use pyo3::Python;
147/// use serde_pyobject::{from_pyobject, pylist};
148///
149/// Python::attach(|py| {
150/// let list = pylist![py; 1, 2, 3].unwrap();
151/// let seq: Vec<i32> = from_pyobject(list).unwrap();
152/// assert_eq!(seq, vec![1, 2, 3]);
153/// });
154/// ```
155///
156/// ## tuple
157///
158/// ```
159/// use pyo3::{Python, types::PyTuple};
160/// use serde_pyobject::from_pyobject;
161///
162/// Python::attach(|py| {
163/// let tuple = PyTuple::new(py, &[1, 2, 3]).unwrap();
164/// let tuple: (i32, i32, i32) = from_pyobject(tuple).unwrap();
165/// assert_eq!(tuple, (1, 2, 3));
166/// });
167/// ```
168///
169/// ## tuple struct
170///
171/// ```
172/// use serde::Deserialize;
173/// use pyo3::{Python, IntoPyObject, types::PyTuple};
174/// use serde_pyobject::from_pyobject;
175///
176/// #[derive(Debug, PartialEq, Deserialize)]
177/// struct T(u8, String);
178///
179/// Python::attach(|py| {
180/// let tuple = PyTuple::new(py, &[1_u32.into_pyobject(py).unwrap().into_any(), "test".into_pyobject(py).unwrap().into_any()]).unwrap();
181/// let obj: T = from_pyobject(tuple).unwrap();
182/// assert_eq!(obj, T(1, "test".to_string()));
183/// });
184/// ```
185///
186/// ## tuple variant
187///
188/// ```
189/// use serde::Deserialize;
190/// use pyo3::Python;
191/// use serde_pyobject::{from_pyobject, pydict};
192///
193/// #[derive(Debug, PartialEq, Deserialize)]
194/// enum TupleVariant {
195/// T(u8, u8),
196/// }
197///
198/// Python::attach(|py| {
199/// let dict = pydict! { py, "T" => (1, 2) }.unwrap();
200/// let obj: TupleVariant = from_pyobject(dict).unwrap();
201/// assert_eq!(obj, TupleVariant::T(1, 2));
202/// });
203/// ```
204///
205/// ## map
206///
207/// ```
208/// use pyo3::Python;
209/// use serde_pyobject::{from_pyobject, pydict};
210/// use std::collections::BTreeMap;
211///
212/// Python::attach(|py| {
213/// let dict = pydict! { py,
214/// "a" => "hom",
215/// "b" => "test"
216/// }
217/// .unwrap();
218/// let map: BTreeMap<String, String> = from_pyobject(dict).unwrap();
219/// assert_eq!(map.get("a"), Some(&"hom".to_string()));
220/// assert_eq!(map.get("b"), Some(&"test".to_string()));
221/// });
222/// ```
223///
224/// ## struct
225///
226/// ```
227/// use serde::Deserialize;
228/// use pyo3::Python;
229/// use serde_pyobject::{from_pyobject, pydict};
230///
231/// #[derive(Debug, PartialEq, Deserialize)]
232/// struct A {
233/// a: i32,
234/// b: String,
235/// }
236///
237/// Python::attach(|py| {
238/// let dict = pydict! {
239/// "a" => 1,
240/// "b" => "test"
241/// }
242/// .unwrap();
243/// let a: A = from_pyobject(dict.into_bound(py)).unwrap();
244/// assert_eq!(
245/// a,
246/// A {
247/// a: 1,
248/// b: "test".to_string()
249/// }
250/// );
251/// });
252///
253/// Python::attach(|py| {
254/// let dict = pydict! {
255/// "A" => pydict! {
256/// "a" => 1,
257/// "b" => "test"
258/// }
259/// .unwrap()
260/// }
261/// .unwrap();
262/// let a: A = from_pyobject(dict.into_bound(py)).unwrap();
263/// assert_eq!(
264/// a,
265/// A {
266/// a: 1,
267/// b: "test".to_string()
268/// }
269/// );
270/// });
271/// ```
272///
273/// ## struct variant
274///
275/// ```
276/// use serde::Deserialize;
277/// use pyo3::Python;
278/// use serde_pyobject::{from_pyobject, pydict};
279///
280/// #[derive(Debug, PartialEq, Deserialize)]
281/// enum StructVariant {
282/// S { r: u8, g: u8, b: u8 },
283/// }
284///
285/// Python::attach(|py| {
286/// let dict = pydict! {
287/// py,
288/// "S" => pydict! {
289/// "r" => 1,
290/// "g" => 2,
291/// "b" => 3
292/// }.unwrap()
293/// }
294/// .unwrap();
295/// let obj: StructVariant = from_pyobject(dict).unwrap();
296/// assert_eq!(obj, StructVariant::S { r: 1, g: 2, b: 3 });
297/// });
298/// ```
299pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
300 let any = any.into_any();
301 T::deserialize(PyAnyDeserializer(any))
302}
303
304struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
305
306impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
307 type Error = Error;
308
309 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
310 where
311 V: Visitor<'de>,
312 {
313 if self.0.is_instance_of::<PyDict>() {
314 return visitor.visit_map(MapDeserializer::new(self.0.cast()?));
315 }
316 if self.0.is_instance_of::<PyList>() {
317 return visitor.visit_seq(SeqDeserializer::from_list(self.0.cast()?));
318 }
319 if self.0.is_instance_of::<PyTuple>() {
320 return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.cast()?));
321 }
322 if self.0.is_instance_of::<PyString>() {
323 return visitor.visit_str(&self.0.extract::<String>()?);
324 }
325 if self.0.is_instance_of::<PyBool>() {
326 // must be match before PyLong
327 return visitor.visit_bool(self.0.extract()?);
328 }
329 if self.0.is_instance_of::<PyInt>() {
330 return visitor.visit_i64(self.0.extract()?);
331 }
332 if self.0.is_instance_of::<PyFloat>() {
333 return visitor.visit_f64(self.0.extract()?);
334 }
335 if let Some(dict) = dataclass_as_dict(self.0.py(), &self.0)? {
336 return visitor.visit_map(MapDeserializer::new(&dict));
337 }
338 if let Some(dict) = pydantic_model_as_dict(self.0.py(), &self.0)? {
339 return visitor.visit_map(MapDeserializer::new(&dict));
340 }
341 if self.0.hasattr("__dict__")? {
342 return visitor.visit_map(MapDeserializer::new(self.0.getattr("__dict__")?.cast()?));
343 }
344 if self.0.hasattr("__slots__")? {
345 // __slots__ and __dict__ are mutually exclusive, see
346 // https://docs.python.org/3/reference/datamodel.html#slots
347 return visitor.visit_map(MapDeserializer::from_slots(&self.0)?);
348 }
349 if self.0.is_none() {
350 return visitor.visit_none();
351 }
352
353 unreachable!("Unsupported type: {}", self.0.get_type());
354 }
355
356 fn deserialize_struct<V: de::Visitor<'de>>(
357 self,
358 name: &'static str,
359 _fields: &'static [&'static str],
360 visitor: V,
361 ) -> Result<V::Value> {
362 // Nested dict `{ "A": { "a": 1, "b": 2 } }` is deserialized as `A { a: 1, b: 2 }`
363 if self.0.is_instance_of::<PyDict>() {
364 let dict: &Bound<PyDict> = self.0.cast()?;
365 if let Some(inner) = dict.get_item(name)? {
366 if let Ok(inner) = inner.cast() {
367 return visitor.visit_map(MapDeserializer::new(inner));
368 }
369 }
370 }
371 // Default to `any` case
372 self.deserialize_any(visitor)
373 }
374
375 fn deserialize_newtype_struct<V: de::Visitor<'de>>(
376 self,
377 _name: &'static str,
378 visitor: V,
379 ) -> Result<V::Value> {
380 visitor.visit_seq(SeqDeserializer {
381 seq_reversed: vec![self.0],
382 })
383 }
384
385 fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
386 if self.0.is_none() {
387 visitor.visit_none()
388 } else {
389 visitor.visit_some(self)
390 }
391 }
392
393 fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
394 if self.0.is(PyTuple::empty(self.0.py())) {
395 visitor.visit_unit()
396 } else {
397 self.deserialize_any(visitor)
398 }
399 }
400
401 fn deserialize_unit_struct<V: de::Visitor<'de>>(
402 self,
403 _name: &'static str,
404 visitor: V,
405 ) -> Result<V::Value> {
406 if self.0.is(PyTuple::empty(self.0.py())) {
407 visitor.visit_unit()
408 } else {
409 self.deserialize_any(visitor)
410 }
411 }
412
413 fn deserialize_enum<V: de::Visitor<'de>>(
414 self,
415 _name: &'static str,
416 _variants: &'static [&'static str],
417 visitor: V,
418 ) -> Result<V::Value> {
419 if self.0.is_instance_of::<PyString>() {
420 let variant: String = self.0.extract()?;
421 let py = self.0.py();
422 let none = py.None().into_bound(py);
423 return visitor.visit_enum(EnumDeserializer {
424 variant: &variant,
425 inner: none,
426 });
427 }
428 if self.0.is_instance_of::<PyDict>() {
429 let dict: &Bound<PyDict> = self.0.cast()?;
430 if dict.len() == 1 {
431 let key = dict.keys().get_item(0).unwrap();
432 let value = dict.values().get_item(0).unwrap();
433 if key.is_instance_of::<PyString>() {
434 let variant: String = key.extract()?;
435 return visitor.visit_enum(EnumDeserializer {
436 variant: &variant,
437 inner: value,
438 });
439 }
440 }
441 }
442 self.deserialize_any(visitor)
443 }
444
445 fn deserialize_tuple_struct<V: de::Visitor<'de>>(
446 self,
447 name: &'static str,
448 _len: usize,
449 visitor: V,
450 ) -> Result<V::Value> {
451 if self.0.is_instance_of::<PyDict>() {
452 let dict: &Bound<PyDict> = self.0.cast()?;
453 if let Some(value) = dict.get_item(name)? {
454 if value.is_instance_of::<PyTuple>() {
455 let tuple: &Bound<PyTuple> = value.cast()?;
456 return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
457 }
458 }
459 }
460 self.deserialize_any(visitor)
461 }
462
463 forward_to_deserialize_any! {
464 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
465 bytes byte_buf seq tuple
466 map identifier ignored_any
467 }
468}
469
470struct SeqDeserializer<'py> {
471 seq_reversed: Vec<Bound<'py, PyAny>>,
472}
473
474impl<'py> SeqDeserializer<'py> {
475 fn from_list(list: &Bound<'py, PyList>) -> Self {
476 let mut seq_reversed = Vec::new();
477 for item in list.iter().rev() {
478 seq_reversed.push(item);
479 }
480 Self { seq_reversed }
481 }
482
483 fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
484 let mut seq_reversed = Vec::new();
485 for item in tuple.iter().rev() {
486 seq_reversed.push(item);
487 }
488 Self { seq_reversed }
489 }
490}
491
492impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
493 type Error = Error;
494 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
495 where
496 T: de::DeserializeSeed<'de>,
497 {
498 self.seq_reversed.pop().map_or(Ok(None), |value| {
499 let value = seed.deserialize(PyAnyDeserializer(value))?;
500 Ok(Some(value))
501 })
502 }
503}
504
505struct MapDeserializer<'py> {
506 keys: Vec<Bound<'py, PyAny>>,
507 values: Vec<Bound<'py, PyAny>>,
508}
509
510impl<'py> MapDeserializer<'py> {
511 fn new(dict: &Bound<'py, PyDict>) -> Self {
512 let mut keys = Vec::new();
513 let mut values = Vec::new();
514 for (key, value) in dict.iter() {
515 keys.push(key);
516 values.push(value);
517 }
518 Self { keys, values }
519 }
520
521 fn from_slots(obj: &Bound<'py, PyAny>) -> Result<Self> {
522 let mut keys = vec![];
523 let mut values = vec![];
524 for key in obj.getattr("__slots__")?.try_iter()? {
525 let key = key?;
526 keys.push(key.clone());
527 let v = obj.getattr(key.str()?)?;
528 values.push(v);
529 }
530 Ok(Self { keys, values })
531 }
532}
533
534impl<'de> MapAccess<'de> for MapDeserializer<'_> {
535 type Error = Error;
536
537 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
538 where
539 K: de::DeserializeSeed<'de>,
540 {
541 if let Some(key) = self.keys.pop() {
542 let key = seed.deserialize(PyAnyDeserializer(key))?;
543 Ok(Some(key))
544 } else {
545 Ok(None)
546 }
547 }
548
549 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
550 where
551 V: de::DeserializeSeed<'de>,
552 {
553 if let Some(value) = self.values.pop() {
554 let value = seed.deserialize(PyAnyDeserializer(value))?;
555 Ok(value)
556 } else {
557 unreachable!()
558 }
559 }
560}
561
562// this lifetime is technically no longer 'py
563struct EnumDeserializer<'py> {
564 variant: &'py str,
565 inner: Bound<'py, PyAny>,
566}
567
568impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
569 type Error = Error;
570 type Variant = Self;
571
572 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
573 where
574 V: de::DeserializeSeed<'de>,
575 {
576 Ok((
577 seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
578 self,
579 ))
580 }
581}
582
583impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
584 type Error = Error;
585
586 fn unit_variant(self) -> Result<()> {
587 Ok(())
588 }
589
590 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
591 where
592 T: de::DeserializeSeed<'de>,
593 {
594 seed.deserialize(PyAnyDeserializer(self.inner))
595 }
596
597 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
598 where
599 V: Visitor<'de>,
600 {
601 PyAnyDeserializer(self.inner).deserialize_seq(visitor)
602 }
603
604 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
605 where
606 V: Visitor<'de>,
607 {
608 PyAnyDeserializer(self.inner).deserialize_map(visitor)
609 }
610}