// mlua_serde/de.rs

1use serde;
2use serde::de::IntoDeserializer;
3
4use mlua::{TablePairs, TableSequence, Value};
5
6use error::{Error, Result};
7
8pub struct Deserializer<'lua> {
9    pub value: Value<'lua>,
10}
11
12impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
13    type Error = Error;
14
15    #[inline]
16    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
17    where
18        V: serde::de::Visitor<'de>,
19    {
20        match self.value {
21            Value::Nil => visitor.visit_unit(),
22            Value::Boolean(v) => visitor.visit_bool(v),
23            Value::Integer(v) => visitor.visit_i64(v),
24            Value::Number(v) => visitor.visit_f64(v),
25            Value::String(v) => visitor.visit_str(v.to_str()?),
26            Value::Table(v) => {
27                let len = v.len()? as usize;
28                let mut deserializer = MapDeserializer(v.pairs(), None);
29                let map = visitor.visit_map(&mut deserializer)?;
30                let remaining = deserializer.0.count();
31                if remaining == 0 {
32                    Ok(map)
33                } else {
34                    Err(serde::de::Error::invalid_length(
35                        len,
36                        &"fewer elements in array",
37                    ))
38                }
39            }
40            _ => Err(serde::de::Error::custom("invalid value type")),
41        }
42    }
43
44    #[inline]
45    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
46    where
47        V: serde::de::Visitor<'de>,
48    {
49        match self.value {
50            Value::Nil => visitor.visit_none(),
51            _ => visitor.visit_some(self),
52        }
53    }
54
55    #[inline]
56    fn deserialize_enum<V>(
57        self,
58        _name: &str,
59        _variants: &'static [&'static str],
60        visitor: V,
61    ) -> Result<V::Value>
62    where
63        V: serde::de::Visitor<'de>,
64    {
65        let (variant, value) = match self.value {
66            Value::Table(value) => {
67                let mut iter = value.pairs::<String, Value>();
68                let (variant, value) = match iter.next() {
69                    Some(v) => v?,
70                    None => {
71                        return Err(serde::de::Error::invalid_value(
72                            serde::de::Unexpected::Map,
73                            &"map with a single key",
74                        ))
75                    }
76                };
77
78                if iter.next().is_some() {
79                    return Err(serde::de::Error::invalid_value(
80                        serde::de::Unexpected::Map,
81                        &"map with a single key",
82                    ));
83                }
84                (variant, Some(value))
85            }
86            Value::String(variant) => (variant.to_str()?.to_owned(), None),
87            _ => return Err(serde::de::Error::custom("bad enum value")),
88        };
89
90        visitor.visit_enum(EnumDeserializer { variant, value })
91    }
92
93    #[inline]
94    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
95    where
96        V: serde::de::Visitor<'de>,
97    {
98        match self.value {
99            Value::Table(v) => {
100                let len = v.len()? as usize;
101                let mut deserializer = SeqDeserializer(v.sequence_values());
102                let seq = visitor.visit_seq(&mut deserializer)?;
103                let remaining = deserializer.0.count();
104                if remaining == 0 {
105                    Ok(seq)
106                } else {
107                    Err(serde::de::Error::invalid_length(
108                        len,
109                        &"fewer elements in array",
110                    ))
111                }
112            }
113            _ => Err(serde::de::Error::custom("invalid value type")),
114        }
115    }
116
117    #[inline]
118    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
119    where
120        V: serde::de::Visitor<'de>,
121    {
122        self.deserialize_seq(visitor)
123    }
124
125    #[inline]
126    fn deserialize_tuple_struct<V>(
127        self,
128        _name: &'static str,
129        _len: usize,
130        visitor: V,
131    ) -> Result<V::Value>
132    where
133        V: serde::de::Visitor<'de>,
134    {
135        self.deserialize_seq(visitor)
136    }
137
138    forward_to_deserialize_any! {
139        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
140        byte_buf unit unit_struct newtype_struct
141        map struct identifier ignored_any
142    }
143}
144
145struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>);
146
147impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> {
148    type Error = Error;
149
150    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
151    where
152        T: serde::de::DeserializeSeed<'de>,
153    {
154        match self.0.next() {
155            Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some),
156            None => Ok(None),
157        }
158    }
159
160    fn size_hint(&self) -> Option<usize> {
161        match self.0.size_hint() {
162            (lower, Some(upper)) if lower == upper => Some(upper),
163            _ => None,
164        }
165    }
166}
167
168struct MapDeserializer<'lua>(
169    TablePairs<'lua, Value<'lua>, Value<'lua>>,
170    Option<Value<'lua>>,
171);
172
173impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> {
174    type Error = Error;
175
176    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
177    where
178        T: serde::de::DeserializeSeed<'de>,
179    {
180        match self.0.next() {
181            Some(item) => {
182                let (key, value) = item?;
183                self.1 = Some(value);
184                let key_de = Deserializer { value: key };
185                seed.deserialize(key_de).map(Some)
186            }
187            None => Ok(None),
188        }
189    }
190
191    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
192    where
193        T: serde::de::DeserializeSeed<'de>,
194    {
195        match self.1.take() {
196            Some(value) => seed.deserialize(Deserializer { value }),
197            None => Err(serde::de::Error::custom("value is missing")),
198        }
199    }
200
201    fn size_hint(&self) -> Option<usize> {
202        match self.0.size_hint() {
203            (lower, Some(upper)) if lower == upper => Some(upper),
204            _ => None,
205        }
206    }
207}
208
209struct EnumDeserializer<'lua> {
210    variant: String,
211    value: Option<Value<'lua>>,
212}
213
214impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> {
215    type Error = Error;
216    type Variant = VariantDeserializer<'lua>;
217
218    fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
219    where
220        T: serde::de::DeserializeSeed<'de>,
221    {
222        let variant = self.variant.into_deserializer();
223        let variant_access = VariantDeserializer { value: self.value };
224        seed.deserialize(variant).map(|v| (v, variant_access))
225    }
226}
227
228struct VariantDeserializer<'lua> {
229    value: Option<Value<'lua>>,
230}
231
232impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> {
233    type Error = Error;
234
235    fn unit_variant(self) -> Result<()> {
236        match self.value {
237            Some(_) => Err(serde::de::Error::invalid_type(
238                serde::de::Unexpected::NewtypeVariant,
239                &"unit variant",
240            )),
241            None => Ok(()),
242        }
243    }
244
245    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
246    where
247        T: serde::de::DeserializeSeed<'de>,
248    {
249        match self.value {
250            Some(value) => seed.deserialize(Deserializer { value }),
251            None => Err(serde::de::Error::invalid_type(
252                serde::de::Unexpected::UnitVariant,
253                &"newtype variant",
254            )),
255        }
256    }
257
258    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
259    where
260        V: serde::de::Visitor<'de>,
261    {
262        match self.value {
263            Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor),
264            None => Err(serde::de::Error::invalid_type(
265                serde::de::Unexpected::UnitVariant,
266                &"tuple variant",
267            )),
268        }
269    }
270
271    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
272    where
273        V: serde::de::Visitor<'de>,
274    {
275        match self.value {
276            Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor),
277            None => Err(serde::de::Error::invalid_type(
278                serde::de::Unexpected::UnitVariant,
279                &"struct variant",
280            )),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use mlua::Lua;

    use from_value;

    #[test]
    fn test_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Test {
            int: u32,
            seq: Vec<String>,
            map: std::collections::HashMap<i32, i32>,
            empty: Vec<()>,
        }

        let expected = Test {
            int: 1,
            seq: vec!["a".to_owned(), "b".to_owned()],
            // `{2, [4]=1}` is `{[1]=2, [4]=1}` in Lua.
            map: vec![(1, 2), (4, 1)].into_iter().collect(),
            empty: vec![],
        };

        let lua = Lua::new();
        let value = lua
            .load(
                r#"
                local a = {}
                a.int = 1
                a.seq = {"a", "b"}
                a.map = {2, [4]=1}
                a.empty = {}
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);
    }

    #[test]
    fn test_tuple() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Rgb(u8, u8, u8);

        let lua = Lua::new();
        let expected = Rgb(1, 2, 3);
        let value = lua
            .load(
                r#"
                local a = {1, 2, 3}
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);

        let expected = (1, 2, 3);
        let value = lua
            .load(
                r#"
                local a = {1, 2, 3}
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);

        // A table with more elements than the tuple has fields must be rejected
        // rather than silently truncated.
        let value = lua
            .load(
                r#"
                local a = {1, 2, 3, 4}
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got: Result<(u8, u8, u8), _> = from_value(value);
        assert!(got.is_err());
    }

    #[test]
    fn test_enum() {
        #[derive(Deserialize, PartialEq, Debug)]
        enum E {
            Unit,
            Newtype(u32),
            Tuple(u32, u32),
            Struct { a: u32 },
        }

        let lua = Lua::new();
        let expected = E::Unit;
        let value = lua
            .load(
                r#"
                return "Unit"
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);

        let expected = E::Newtype(1);
        let value = lua
            .load(
                r#"
                local a = {}
                a["Newtype"] = 1
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);

        let expected = E::Tuple(1, 2);
        let value = lua
            .load(
                r#"
                local a = {}
                a["Tuple"] = {1, 2}
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);

        let expected = E::Struct { a: 1 };
        let value = lua
            .load(
                r#"
                local a = {}
                a["Struct"] = {}
                a["Struct"]["a"] = 1
                return a
            "#,
            )
            .eval()
            .unwrap();
        let got = from_value(value).unwrap();
        assert_eq!(expected, got);
    }
}