quaint_forked/
serde.rs

1//! Convert results from the database into any type implementing `serde::Deserialize`.
2
3use std::borrow::Cow;
4
5use crate::{
6    ast::Value,
7    connector::{ResultRow, ResultSet},
8    error::{Error, ErrorKind},
9};
10use serde::{de::Error as SerdeError, de::*};
11
12impl ResultSet {
13    /// Takes the first row and deserializes it.
14    #[allow(clippy::wrong_self_convention)]
15    pub fn from_first<T: DeserializeOwned>(self) -> crate::Result<T> {
16        from_row(self.into_single()?)
17    }
18}
19
20/// Deserialize each row of a [`ResultSet`](../connector/struct.ResultSet.html).
21///
22/// For an example, see the docs for [`from_row`](fn.from_row.html).
23pub fn from_rows<T: DeserializeOwned>(result_set: ResultSet) -> crate::Result<Vec<T>> {
24    let mut deserialized_rows = Vec::with_capacity(result_set.len());
25
26    for row in result_set {
27        deserialized_rows.push(from_row(row)?)
28    }
29
30    Ok(deserialized_rows)
31}
32
33/// Deserialize a row into any type implementing `Deserialize`.
34///
35/// ```
36/// # use serde::Deserialize;
37/// # use quaint::ast::Value;
38/// #
39/// # #[derive(Deserialize, Debug, PartialEq)]
40/// # struct User {
41/// #     id: u64,
42/// #     name: String,
43/// # }
44/// #
45/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
46/// #
47/// #   let row = quaint::serde::make_row(vec![
48/// #       ("id", Value::from(12)),
49/// #       ("name", "Georgina".into()),
50/// #   ]);
51/// #
52/// #
53/// let user: User = quaint::serde::from_row(row)?;
54///
55/// assert_eq!(user, User { name: "Georgina".to_string(), id: 12 });
56/// # Ok(())
57/// # }
58/// ```
59pub fn from_row<T: DeserializeOwned>(row: ResultRow) -> crate::Result<T> {
60    let deserializer = RowDeserializer(row);
61
62    T::deserialize(deserializer).map_err(|e| Error::builder(ErrorKind::FromRowError(e)).build())
63}
64
65type DeserializeError = serde::de::value::Error;
66
67#[derive(Debug)]
68struct RowDeserializer(ResultRow);
69
70impl<'de> Deserializer<'de> for RowDeserializer {
71    type Error = DeserializeError;
72
73    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
74        let ResultRow { columns, mut values } = self.0;
75
76        let kvs = columns.iter().enumerate().map(move |(v, k)| {
77            // The unwrap is safe if `columns` is correct.
78            let value = values.get_mut(v).unwrap();
79            let taken_value = std::mem::replace(value, Value::Int64(None));
80            (k.as_str(), taken_value)
81        });
82
83        let deserializer = serde::de::value::MapDeserializer::new(kvs);
84
85        visitor.visit_map(deserializer)
86    }
87
88    serde::forward_to_deserialize_any! {
89        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf
90        option unit unit_struct newtype_struct seq tuple tuple_struct map
91        struct enum identifier ignored_any
92    }
93}
94
95impl<'de> IntoDeserializer<'de, DeserializeError> for Value<'de> {
96    type Deserializer = ValueDeserializer<'de>;
97
98    fn into_deserializer(self) -> Self::Deserializer {
99        ValueDeserializer(self)
100    }
101}
102
103#[derive(Debug)]
104pub struct ValueDeserializer<'a>(Value<'a>);
105
106impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
107    type Error = DeserializeError;
108
109    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
110        match self.0 {
111            Value::Text(Some(s)) => visitor.visit_string(s.into_owned()),
112            Value::Text(None) => visitor.visit_none(),
113            Value::Bytes(Some(bytes)) => visitor.visit_bytes(bytes.as_ref()),
114            Value::Bytes(None) => visitor.visit_none(),
115            Value::Enum(Some(s)) => visitor.visit_string(s.into_owned()),
116            Value::Enum(None) => visitor.visit_none(),
117            Value::Int32(Some(i)) => visitor.visit_i32(i),
118            Value::Int32(None) => visitor.visit_none(),
119            Value::Int64(Some(i)) => visitor.visit_i64(i),
120            Value::Int64(None) => visitor.visit_none(),
121            Value::Boolean(Some(b)) => visitor.visit_bool(b),
122            Value::Boolean(None) => visitor.visit_none(),
123            Value::Char(Some(c)) => visitor.visit_char(c),
124            Value::Char(None) => visitor.visit_none(),
125            Value::Float(Some(num)) => visitor.visit_f64(num as f64),
126            Value::Float(None) => visitor.visit_none(),
127            Value::Double(Some(num)) => visitor.visit_f64(num),
128            Value::Double(None) => visitor.visit_none(),
129
130            #[cfg(feature = "bigdecimal")]
131            Value::Numeric(Some(num)) => {
132                use crate::bigdecimal::ToPrimitive;
133                visitor.visit_f64(num.to_f64().unwrap())
134            }
135            #[cfg(feature = "bigdecimal")]
136            Value::Numeric(None) => visitor.visit_none(),
137
138            #[cfg(feature = "uuid")]
139            Value::Uuid(Some(uuid)) => visitor.visit_string(uuid.to_string()),
140            #[cfg(feature = "uuid")]
141            Value::Uuid(None) => visitor.visit_none(),
142
143            #[cfg(feature = "json")]
144            Value::Json(Some(value)) => {
145                let de = value.into_deserializer();
146
147                de.deserialize_any(visitor)
148                    .map_err(|err| serde::de::value::Error::custom(format!("Error deserializing JSON value: {err}")))
149            }
150            #[cfg(feature = "json")]
151            Value::Json(None) => visitor.visit_none(),
152
153            Value::Xml(Some(s)) => visitor.visit_string(s.into_owned()),
154            Value::Xml(None) => visitor.visit_none(),
155
156            #[cfg(feature = "chrono")]
157            Value::DateTime(Some(dt)) => visitor.visit_string(dt.to_rfc3339()),
158            #[cfg(feature = "chrono")]
159            Value::DateTime(None) => visitor.visit_none(),
160
161            #[cfg(feature = "chrono")]
162            Value::Date(Some(d)) => visitor.visit_string(format!("{d}")),
163            #[cfg(feature = "chrono")]
164            Value::Date(None) => visitor.visit_none(),
165
166            #[cfg(feature = "chrono")]
167            Value::Time(Some(t)) => visitor.visit_string(format!("{t}")),
168            #[cfg(feature = "chrono")]
169            Value::Time(None) => visitor.visit_none(),
170
171            Value::Array(Some(values)) => {
172                let deserializer = serde::de::value::SeqDeserializer::new(values.into_iter());
173                visitor.visit_seq(deserializer)
174            }
175            Value::Array(None) => visitor.visit_none(),
176        }
177    }
178
179    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
180        if self.0.is_null() {
181            visitor.visit_none()
182        } else {
183            visitor.visit_some(self)
184        }
185    }
186
187    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188    where
189        V: Visitor<'de>,
190    {
191        if let Value::Bytes(Some(bytes)) = self.0 {
192            match bytes {
193                Cow::Borrowed(bytes) => visitor.visit_borrowed_bytes(bytes),
194                Cow::Owned(bytes) => visitor.visit_byte_buf(bytes),
195            }
196        } else {
197            Err(DeserializeError::invalid_type(
198                Unexpected::Other(&format!("{:?}", self.0)),
199                &visitor,
200            ))
201        }
202    }
203
204    serde::forward_to_deserialize_any! {
205        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str byte_buf
206        string unit unit_struct newtype_struct seq tuple tuple_struct map
207        struct enum identifier ignored_any
208    }
209}
210
211#[doc(hidden)]
212pub fn make_row(cols: Vec<(&'static str, Value<'static>)>) -> ResultRow {
213    let mut columns = Vec::with_capacity(cols.len());
214    let mut values = Vec::with_capacity(cols.len());
215
216    for (name, value) in cols.into_iter() {
217        columns.push(name.to_owned());
218        values.push(value);
219    }
220
221    ResultRow {
222        values,
223        columns: std::sync::Arc::new(columns),
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    // Only the `Cat` fixture needs chrono, and it is feature-gated below.
    #[cfg(all(feature = "bigdecimal", feature = "chrono", feature = "json"))]
    use chrono::{DateTime, Utc};
    use serde::Deserialize;

    /// Row type with two required columns and one nullable column.
    #[derive(Deserialize, Debug, PartialEq)]
    struct User {
        id: u64,
        name: String,
        bio: Option<String>,
    }

    /// Row type mixing a numeric, a datetime and a nested JSON column.
    ///
    /// The `Numeric`, `DateTime` and `Json` value variants are handled only
    /// behind the `bigdecimal`, `chrono` and `json` features (see the `#[cfg]`s
    /// in `ValueDeserializer`). This fixture and its test therefore compile
    /// only when all three are enabled.
    #[cfg(all(feature = "bigdecimal", feature = "chrono", feature = "json"))]
    #[derive(Deserialize, PartialEq, Debug)]
    struct Cat {
        age: f32,
        birthday: DateTime<Utc>,
        human: User,
    }

    /// A row without the nullable `bio` column still deserializes, with `bio: None`.
    #[test]
    fn deserialize_user() {
        let row = make_row(vec![("id", Value::integer(12)), ("name", "Georgina".into())]);
        let user: User = from_row(row).unwrap();

        assert_eq!(
            user,
            User {
                id: 12,
                name: "Georgina".to_owned(),
                bio: None,
            }
        )
    }

    /// Every row of a result set is converted, in order, including explicit nulls.
    #[test]
    fn from_rows_works() {
        let first_row = make_row(vec![
            ("id", Value::integer(12)),
            ("name", "Georgina".into()),
            ("bio", Value::Text(None)),
        ]);
        let second_row = make_row(vec![
            ("id", 33.into()),
            ("name", "Philbert".into()),
            (
                "bio",
                "Invented sliced bread on a meditation retreat in the Himalayas.".into(),
            ),
        ]);

        let result_set = ResultSet {
            columns: std::sync::Arc::clone(&first_row.columns),
            rows: vec![first_row.values, second_row.values],
            last_insert_id: None,
        };

        let users: Vec<User> = from_rows(result_set).unwrap();

        assert_eq!(
            users,
            &[
                User {
                    id: 12,
                    name: "Georgina".to_owned(),
                    bio: None,
                },
                User {
                    id: 33,
                    name: "Philbert".to_owned(),
                    bio: Some("Invented sliced bread on a meditation retreat in the Himalayas.".into()),
                }
            ]
        );
    }

    /// Numeric -> f32, datetime -> `DateTime<Utc>` and JSON -> nested struct.
    #[cfg(all(feature = "bigdecimal", feature = "chrono", feature = "json"))]
    #[test]
    fn deserialize_cat() {
        let row = make_row(vec![
            ("age", Value::numeric("18.800001".parse().unwrap())),
            ("birthday", Value::datetime("2019-08-01T20:00:00Z".parse().unwrap())),
            (
                "human",
                Value::json(serde_json::json!({
                    "id": 19,
                    "name": "Georgina"
                })),
            ),
        ]);
        let cat: Cat = from_row(row).unwrap();

        let expected_cat = Cat {
            age: 18.800001,
            birthday: "2019-08-01T20:00:00Z".parse().unwrap(),
            human: User {
                name: "Georgina".into(),
                id: 19,
                bio: None,
            },
        };

        assert_eq!(cat, expected_cat);
    }
}