1use std::{
19 collections::HashMap,
20 hash::{BuildHasher, Hash},
21};
22
23use indexmap::IndexMap;
24use pyo3::{
25 conversion::{FromPyObjectOwned, IntoPyObject},
26 prelude::*,
27 types::{PyAny, PyDict},
28};
29use serde::{Serialize, de::DeserializeOwned};
30
31use crate::python::to_pyvalue_err;
32
33pub fn from_dict_pyo3<T>(py: Python<'_>, values: Py<PyDict>) -> Result<T, PyErr>
42where
43 T: DeserializeOwned,
44{
45 let values = values.into_any();
46 from_pyobject_pyo3(py, values.bind(py))
47}
48
49pub fn from_pyobject_pyo3<T>(py: Python<'_>, value: &Bound<'_, PyAny>) -> Result<T, PyErr>
58where
59 T: DeserializeOwned,
60{
61 let kwargs = PyDict::new(py);
65 kwargs.set_item("ensure_ascii", false)?;
66 let json_str: String = PyModule::import(py, "json")?
67 .call_method("dumps", (value,), Some(&kwargs))?
68 .extract()?;
69
70 let instance = serde_json::from_str(&json_str).map_err(to_pyvalue_err)?;
71 Ok(instance)
72}
73
74pub fn to_dict_pyo3<T>(py: Python<'_>, value: &T) -> PyResult<Py<PyDict>>
83where
84 T: Serialize,
85{
86 let py_object = to_pyobject_pyo3(py, value)?;
87 let py_dict = py_object
88 .bind(py)
89 .cast::<PyDict>()
90 .map_err(Into::<PyErr>::into)?
91 .clone()
92 .unbind();
93 Ok(py_dict)
94}
95
96pub fn to_pyobject_pyo3<T>(py: Python<'_>, value: &T) -> PyResult<Py<PyAny>>
105where
106 T: Serialize,
107{
108 let json_str = serde_json::to_string(value).map_err(to_pyvalue_err)?;
109 let py_object = PyModule::import(py, "json")?
110 .call_method("loads", (json_str,), None)?
111 .extract()?;
112 Ok(py_object)
113}
114
115pub fn indexmap_from_pyobject_pyo3<K, V>(
125 py: Python<'_>,
126 value: &Bound<'_, PyAny>,
127) -> PyResult<IndexMap<K, V>>
128where
129 K: for<'py> FromPyObjectOwned<'py> + DeserializeOwned + Eq + Hash,
130 V: for<'py> FromPyObjectOwned<'py> + DeserializeOwned,
131 IndexMap<K, V>: DeserializeOwned,
132{
133 let Ok(dict) = value.cast::<PyDict>() else {
134 return from_pyobject_pyo3(py, value);
135 };
136
137 let mut map = IndexMap::with_capacity(dict.len());
138 for (key, value) in dict.iter() {
139 map.insert(
140 extract_typed_or_json_pyo3(py, &key)?,
141 extract_typed_or_json_pyo3(py, &value)?,
142 );
143 }
144 Ok(map)
145}
146
147pub fn hashmap_from_pyobject_pyo3<K, V>(
156 py: Python<'_>,
157 value: &Bound<'_, PyAny>,
158) -> PyResult<HashMap<K, V>>
159where
160 K: for<'py> FromPyObjectOwned<'py> + DeserializeOwned + Eq + Hash,
161 V: for<'py> FromPyObjectOwned<'py> + DeserializeOwned,
162 HashMap<K, V>: DeserializeOwned,
163{
164 let Ok(dict) = value.cast::<PyDict>() else {
165 return from_pyobject_pyo3(py, value);
166 };
167
168 let mut map = HashMap::with_capacity(dict.len());
169 for (key, value) in dict.iter() {
170 map.insert(
171 extract_typed_or_json_pyo3(py, &key)?,
172 extract_typed_or_json_pyo3(py, &value)?,
173 );
174 }
175 Ok(map)
176}
177
178pub fn indexmap_to_pydict_pyo3<K, V>(py: Python<'_>, value: &IndexMap<K, V>) -> PyResult<Py<PyAny>>
184where
185 K: for<'py> IntoPyObject<'py> + Clone,
186 V: for<'py> IntoPyObject<'py> + Clone,
187{
188 let dict = PyDict::new(py);
189 for (key, value) in value {
190 dict.set_item(key.clone(), value.clone())?;
191 }
192 Ok(dict.into_any().unbind())
193}
194
195pub fn hashmap_to_pydict_pyo3<K, V, S>(
201 py: Python<'_>,
202 value: &HashMap<K, V, S>,
203) -> PyResult<Py<PyAny>>
204where
205 K: for<'py> IntoPyObject<'py> + Clone,
206 V: for<'py> IntoPyObject<'py> + Clone,
207 S: BuildHasher,
208{
209 let dict = PyDict::new(py);
210 for (key, value) in value {
211 dict.set_item(key.clone(), value.clone())?;
212 }
213 Ok(dict.into_any().unbind())
214}
215
216fn extract_typed_or_json_pyo3<T>(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<T>
217where
218 T: for<'py> FromPyObjectOwned<'py> + DeserializeOwned,
219{
220 value.extract::<T>().or_else(|typed_err| {
221 let typed_err: PyErr = typed_err.into();
222 from_pyobject_pyo3(py, value).map_err(|json_err| {
223 to_pyvalue_err(format!(
224 "typed extraction failed: {typed_err}; JSON extraction failed: {json_err}"
225 ))
226 })
227 })
228}
229
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use pyo3::types::PyDict;
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Minimal payload with a nested string map, used to exercise the JSON bridge.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Payload {
        values: HashMap<String, String>,
    }

    /// Non-ASCII keys and values must survive the Python -> JSON -> Rust round trip.
    #[rstest]
    fn test_from_pyobject_pyo3_preserves_non_ascii_strings() {
        Python::initialize();
        Python::attach(|py| {
            let inner = PyDict::new(py);
            inner.set_item("clé", "café").unwrap();

            let outer = PyDict::new(py);
            outer.set_item("values", inner).unwrap();

            let payload: Payload = from_pyobject_pyo3(py, outer.as_any()).unwrap();
            let decoded = payload.values.get("clé").unwrap();
            assert_eq!(decoded, "café");
        });
    }
}