// pyo3_utils/serde.rs

1use pyo3::{
2    exceptions::PyValueError,
3    prelude::*,
4    types::{PyBytes, PyString},
5};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8pub use pythonize;
9pub use serde;
10pub use serde_json;
11
12struct SerdeJsonError(serde_json::Error);
13
14impl From<serde_json::Error> for SerdeJsonError {
15    fn from(e: serde_json::Error) -> Self {
16        Self(e)
17    }
18}
19
20impl From<SerdeJsonError> for PyErr {
21    fn from(e: SerdeJsonError) -> Self {
22        PyValueError::new_err(e.0.to_string())
23    }
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
27#[serde(transparent)]
28pub struct PySerde<T>(T);
29
30impl<T> PySerde<T> {
31    pub fn new(de: T) -> Self {
32        Self(de)
33    }
34
35    pub fn into_inner(self) -> T {
36        self.0
37    }
38
39    pub fn as_ref(&self) -> PySerde<&T> {
40        PySerde(&self.0)
41    }
42
43    pub fn as_mut(&mut self) -> PySerde<&mut T> {
44        PySerde(&mut self.0)
45    }
46}
47
48impl<'de, T> PySerde<T>
49where
50    T: Deserialize<'de>,
51{
52    pub fn from_json_str<'py>(ob: &'de Bound<'py, PyString>) -> PyResult<Self> {
53        let de = serde_json::from_str(ob.to_str()?).map_err(SerdeJsonError::from)?;
54        Ok(Self(de))
55    }
56
57    pub fn from_json_bytes<'py>(ob: &'de Bound<'py, PyBytes>) -> PyResult<Self> {
58        let de = serde_json::from_slice(ob.as_bytes()).map_err(SerdeJsonError::from)?;
59        Ok(Self(de))
60    }
61
62    pub fn from_object<'py>(ob: &'de Bound<'py, PyAny>) -> PyResult<Self> {
63        let de = pythonize::depythonize(ob)?;
64        Ok(Self(de))
65    }
66
67    pub fn extract<'py>(ob: &'de Bound<'py, PyAny>) -> PyResult<Self> {
68        if let Ok(v) = ob.downcast::<PyBytes>() {
69            Self::from_json_bytes(v)
70        } else if let Ok(v) = ob.downcast::<PyString>() {
71            Self::from_json_str(v)
72        } else {
73            Self::from_object(ob)
74        }
75    }
76}
77
78impl<'py, T> FromPyObject<'py> for PySerde<T>
79where
80    T: DeserializeOwned,
81{
82    /// TODO: We have to use [DeserializeOwned] because in `pyo3 v0.25` it cannot borrow data from the object.
83    /// We need to wait for [pyo3::conversion::FromPyObjectBound].
84    /// See: <https://github.com/PyO3/pyo3/pull/4390>.
85    ///
86    /// Use [PySerde::extract] as a workaround for now.
87    #[inline]
88    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
89        Self::extract(ob)
90    }
91}
92
93/// Benchmark(less is better):
94/// - [Self::to_object] : 0.9
95/// - [Self::to_json_str] : 0.6
96/// - json.loads([Self::to_json_str]): 1.2
97impl<T> PySerde<T>
98where
99    T: Serialize,
100{
101    pub fn to_json_str<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyString>> {
102        let val = serde_json::to_string(&self.0).map_err(SerdeJsonError::from)?;
103        Ok(PyString::new(py, &val))
104    }
105
106    pub fn to_json_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
107        let val = serde_json::to_vec(&self.0).map_err(SerdeJsonError::from)?;
108        Ok(PyBytes::new(py, &val))
109    }
110
111    pub fn to_object<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
112        let val = pythonize::pythonize(py, &self.0)?;
113        Ok(val)
114    }
115}
116
117impl<'py, T> IntoPyObject<'py> for &PySerde<T>
118where
119    T: Serialize,
120{
121    type Target = PyAny;
122    type Output = Bound<'py, Self::Target>;
123    type Error = PyErr;
124
125    #[inline]
126    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
127        self.to_object(py)
128    }
129}
130
131impl<'py, T> IntoPyObject<'py> for PySerde<T>
132where
133    T: Serialize,
134{
135    type Target = PyAny;
136    type Output = Bound<'py, Self::Target>;
137    type Error = PyErr;
138
139    #[inline]
140    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
141        (&self).into_pyobject(py)
142    }
143}