// pyo3_utils/from_py_dict.rs
//! Helpers for building Rust structs from Python `dict`s.
//!
//! See: <https://github.com/PyO3/pyo3/issues/5163>
2
3use std::borrow::{Cow, ToOwned};
4
5use pyo3::{
6    conversion::{FromPyObjectBound, IntoPyObjectExt as _},
7    exceptions::PyTypeError,
8    prelude::*,
9    types::{PyDict, PyString},
10};
11
/// A value that may be absent, modelled after Python's
/// [`typing.NotRequired`](https://docs.python.org/3/library/typing.html#typing.NotRequired).
///
/// `NotRequired(None)` represents an absent value (for example a missing dict key when the
/// field is marked `#[pyo3(default)]`), while extracting from a Python object always
/// produces `NotRequired(Some(_))`.
///
/// See also: [derive_from_py_dict].
#[derive(Debug, Clone, Copy)]
pub struct NotRequired<T>(pub Option<T>);

// Written by hand on purpose: `#[derive(Default)]` would add a `T: Default` bound,
// yet an absent value never needs a default `T`.
impl<T> Default for NotRequired<T> {
    fn default() -> Self {
        Self(None)
    }
}
24
25impl<'py, T> FromPyObject<'py> for NotRequired<T>
26where
27    for<'a, 'py_a> T: FromPyObjectBound<'a, 'py_a>,
28{
29    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
30        let value = ob.extract::<T>()?;
31        Ok(NotRequired(Some(value)))
32    }
33}
34
35impl<'py, T> NotRequired<T>
36where
37    for<'a> &'a T: IntoPyObject<'py>,
38    T: IntoPyObject<'py>,
39    // TODO, FIXME: We could have avoided this constraint,
40    // but it is imposed on us by the `Cow` used in pyo3.
41    // We should create an issue for pyo3 about this.
42    Self: ToOwned<Owned = Self>,
43    // 👆
44{
45    /// See: <https://pyo3.rs/v0.25.1/conversions/traits.html#deriveintopyobjectderiveintopyobjectref-field-attributes>
46    ///
47    /// You should always specify the type `T` like [NotRequired::<T>::into_py_with] when using these methods,
48    /// otherwise you may encounter a recursive `IntoPyObject` error.
49    pub fn into_py_with(
50        f: impl FnOnce(Python<'py>) -> PyResult<Bound<'py, PyAny>>,
51    ) -> impl FnOnce(Cow<'_, Self>, Python<'py>) -> PyResult<Bound<'py, PyAny>> {
52        move |value, py| match value {
53            Cow::Borrowed(v) => match &v.0 {
54                Some(inner) => inner.into_bound_py_any(py),
55                None => f(py),
56            },
57            Cow::Owned(v) => match v.0 {
58                Some(inner) => inner.into_bound_py_any(py),
59                None => f(py),
60            },
61        }
62    }
63
64    #[inline]
65    /// See also: [NotRequired::into_py_with]
66    pub fn into_py_with_none(slf: Cow<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
67        fn none(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
68            Ok(py.None().into_bound(py))
69        }
70        Self::into_py_with(none)(slf, py)
71    }
72
73    #[inline]
74    /// See also: [NotRequired::into_py_with]
75    pub fn into_py_with_default(slf: Cow<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>
76    where
77        T: Default,
78    {
79        fn default<'py, T>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>>
80        where
81            T: Default + IntoPyObject<'py>,
82        {
83            T::default().into_bound_py_any(py)
84        }
85        Self::into_py_with(default::<'py, T>)(slf, py)
86    }
87
88    #[inline]
89    /// See also: [NotRequired::into_py_with]
90    pub fn into_py_with_err(slf: Cow<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
91        fn not_required_into_pyobject_err(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
92            const NOT_REQUIRED_INTO_PYOBJECT_ERR: &str =
93                "`NotRequired` value does not exist, cannot convert to PyObject";
94
95            Err(PyTypeError::new_err(
96                pyo3::intern!(py, NOT_REQUIRED_INTO_PYOBJECT_ERR)
97                    .clone()
98                    .unbind(),
99            ))
100        }
101
102        Self::into_py_with(not_required_into_pyobject_err)(slf, py)
103    }
104}
105
106// TODO: once <https://github.com/PyO3/pyo3/issues/5163> is resolved, we can deprecate this trait.
107/// See also: [derive_from_py_dict]
108pub trait FromPyDict: Sized {
109    fn from_py_dict(dict: &Bound<'_, PyDict>) -> PyResult<Self>;
110}
111
112#[doc(hidden)]
113pub fn __get_item_with_default<T>(
114    dict: &Bound<'_, PyDict>,
115    key: &Bound<'_, PyString>,
116) -> PyResult<T>
117where
118    for<'a, 'py> T: FromPyObjectBound<'a, 'py> + Default,
119{
120    let value = match dict.get_item(key)? {
121        Some(value) => value.extract::<T>()?,
122        None => Default::default(),
123    };
124    Ok(value)
125}
126
127#[doc(hidden)]
128pub fn __get_item<T>(dict: &Bound<'_, PyDict>, key: &Bound<'_, PyString>) -> PyResult<T>
129where
130    for<'a, 'py> T: FromPyObjectBound<'a, 'py>,
131{
132    let value = dict.as_any().get_item(key)?.extract::<T>()?;
133    Ok(value)
134}
135
136// ref: <https://github.com/PyO3/pyo3/blob/3914daff760fc23aae4602378b4c010332baa920/src/impl_/frompyobject.rs#L82-L93>
137#[doc(hidden)]
138pub fn __failed_to_extract_struct_field<T>(
139    py: Python<'_>,
140    result: PyResult<T>,
141    struct_name: &'static str,
142    field_name: &'static str,
143) -> PyResult<T> {
144    result.map_err(|err| {
145        let new_err = PyTypeError::new_err(format!(
146            "failed to extract field {struct_name}.{field_name}"
147        ));
148        new_err.set_cause(py, Some(err));
149        new_err
150    })
151}
152
/// Derives the [FromPyDict] trait for a struct.
///
/// > Why do we need this trait?
/// >
/// > ref: <https://github.com/PyO3/pyo3/issues/5163>
///
/// Invoke it as `derive_from_py_dict!(Type { field, #[pyo3(default)] other })`:
///
/// - Every field of the struct must be listed. Fields are separated by commas, and the
///   trailing comma after the last field is optional.
/// - A plain field is read with Python item access (`dict[name]`), so a missing key is
///   an error.
/// - `#[pyo3(default)]` makes the key optional: when it is missing, the field is set to
///   `Default::default()`. This is the only accepted `#[pyo3(...)]` attribute.
/// - A single `#[cfg(...)]` attribute may be given, but it must come before
///   `#[pyo3(default)]`.
/// - If a field fails to extract, the returned error is a `TypeError`
///   (`failed to extract field Type.field`) whose cause is the original error.
///
/// # Example:
/**
```rust
use pyo3_utils::from_py_dict::{derive_from_py_dict, FromPyDict as _, NotRequired};
use pyo3::{
    prelude::*,
    types::{IntoPyDict as _, PyDict},
};

fn main() -> PyResult<()> {
    pub struct Foo {
        a: i32,
        b: NotRequired<i32>,
        #[cfg(all())]
        c: NotRequired<Option<i32>>,
    }

    derive_from_py_dict!(Foo {
        a,
        #[pyo3(default)]
        b,
        // optional cfg attribute, but must be before `#[pyo3(default)]`
        #[cfg(all())]
        #[pyo3(default)]
        c,
    });

    // the trailing comma after the last field is optional
    pub struct Bar {
        x: i32,
    }

    derive_from_py_dict!(Bar { x });

    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        // optional default `b`
        let dict_0 = [("a", 1)].into_py_dict(py)?;
        let foo_0 = Foo::from_py_dict(&dict_0)?;
        assert_eq!(foo_0.a, 1);
        assert_eq!(foo_0.b.0, None);

        // missing required field `a`
        let dict_1 = [("b", 2)].into_py_dict(py)?;
        assert!(Foo::from_py_dict(&dict_1).is_err());

        // provide a value for the optional field `b`
        let dict_2 = [("a", 1), ("b", 2)].into_py_dict(py)?;
        let foo_2 = Foo::from_py_dict(&dict_2)?;
        assert_eq!(foo_2.a, 1);
        assert_eq!(foo_2.b.0, Some(2));

        // provide a value for the optional field `c: NotRequired[Optional[int]]`
        let dict_3 = [("a", 1), ("c", 2)].into_py_dict(py)?;
        let foo_3 = Foo::from_py_dict(&dict_3)?;
        assert_eq!(foo_3.c.0, Some(Some(2)));

        // provide `None` for the optional field `c: NotRequired[Optional[int]]`
        let dict_4 = PyDict::new(py);
        dict_4.set_item("a", 1)?;
        dict_4.set_item("c", py.None())?;
        let foo_4 = Foo::from_py_dict(&dict_4)?;
        assert_eq!(foo_4.c.0, Some(None));

        // `Bar` was declared without a trailing comma
        let dict_5 = [("x", 5)].into_py_dict(py)?;
        let bar = Bar::from_py_dict(&dict_5)?;
        assert_eq!(bar.x, 5);

        Ok(())
    })
}
```
*/
#[macro_export]
macro_rules! __derive_from_py_dict {
    // Internal rules: choose the lookup strategy for one field from its optional
    // `#[pyo3(...)]` attribute, which is forwarded after the `#` token.
    ($dict:expr, $key:expr, #) => {
        $crate::from_py_dict::__get_item($dict, $key)
    };
    ($dict:expr, $key:expr, #default) => {
        $crate::from_py_dict::__get_item_with_default($dict, $key)
    };
    ($dict:expr, $key:expr, #$attribute:ident) => {
        compile_error!(concat!(
            "Invalid attribute: #[pyo3(",
            stringify!($attribute),
            ")]. Only the optional `#[pyo3(default)]` attribute is accepted."
        ))
    };

    // Public entry point. Fields are comma separated; `$(,)?` makes the trailing comma
    // after the last field optional.
    (
        $name:ty {
            $(
                $( #[cfg($cfg_meta:meta)] )?
                $( #[pyo3($pyo3_meta:ident)] )?
                $field:ident
            ),* $(,)?
        }
    ) => {
        impl $crate::from_py_dict::FromPyDict for $name {
            fn from_py_dict(dict: &::pyo3::Bound<'_, ::pyo3::types::PyDict>) -> ::pyo3::PyResult<Self> {
                // A struct literal cannot be written with a `ty` fragment directly, but
                // `Self { .. }` names the impl target without any `use` aliasing.
                ::core::result::Result::Ok(Self {
                    $(
                        $( #[cfg($cfg_meta)] )*
                        $field: $crate::from_py_dict::__failed_to_extract_struct_field(
                            dict.py(),
                            {
                                let key = ::pyo3::intern!(dict.py(), stringify!($field));
                                $crate::from_py_dict::derive_from_py_dict!(dict, key, #$($pyo3_meta)?)
                            },
                            stringify!($name),
                            stringify!($field),
                        )?,
                    )*
                })
            }
        }
    };
}
267
268pub use __derive_from_py_dict as derive_from_py_dict;