1use serde;
2use serde::de::IntoDeserializer;
3
4use mlua::{TablePairs, TableSequence, Value};
5
6use error::{Error, Result};
7
8pub struct Deserializer<'lua> {
9 pub value: Value<'lua>,
10}
11
12impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
13 type Error = Error;
14
15 #[inline]
16 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
17 where
18 V: serde::de::Visitor<'de>,
19 {
20 match self.value {
21 Value::Nil => visitor.visit_unit(),
22 Value::Boolean(v) => visitor.visit_bool(v),
23 Value::Integer(v) => visitor.visit_i64(v),
24 Value::Number(v) => visitor.visit_f64(v),
25 Value::String(v) => visitor.visit_str(v.to_str()?),
26 Value::Table(v) => {
27 let len = v.len()? as usize;
28 let mut deserializer = MapDeserializer(v.pairs(), None);
29 let map = visitor.visit_map(&mut deserializer)?;
30 let remaining = deserializer.0.count();
31 if remaining == 0 {
32 Ok(map)
33 } else {
34 Err(serde::de::Error::invalid_length(
35 len,
36 &"fewer elements in array",
37 ))
38 }
39 }
40 _ => Err(serde::de::Error::custom("invalid value type")),
41 }
42 }
43
44 #[inline]
45 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
46 where
47 V: serde::de::Visitor<'de>,
48 {
49 match self.value {
50 Value::Nil => visitor.visit_none(),
51 _ => visitor.visit_some(self),
52 }
53 }
54
55 #[inline]
56 fn deserialize_enum<V>(
57 self,
58 _name: &str,
59 _variants: &'static [&'static str],
60 visitor: V,
61 ) -> Result<V::Value>
62 where
63 V: serde::de::Visitor<'de>,
64 {
65 let (variant, value) = match self.value {
66 Value::Table(value) => {
67 let mut iter = value.pairs::<String, Value>();
68 let (variant, value) = match iter.next() {
69 Some(v) => v?,
70 None => {
71 return Err(serde::de::Error::invalid_value(
72 serde::de::Unexpected::Map,
73 &"map with a single key",
74 ))
75 }
76 };
77
78 if iter.next().is_some() {
79 return Err(serde::de::Error::invalid_value(
80 serde::de::Unexpected::Map,
81 &"map with a single key",
82 ));
83 }
84 (variant, Some(value))
85 }
86 Value::String(variant) => (variant.to_str()?.to_owned(), None),
87 _ => return Err(serde::de::Error::custom("bad enum value")),
88 };
89
90 visitor.visit_enum(EnumDeserializer { variant, value })
91 }
92
93 #[inline]
94 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
95 where
96 V: serde::de::Visitor<'de>,
97 {
98 match self.value {
99 Value::Table(v) => {
100 let len = v.len()? as usize;
101 let mut deserializer = SeqDeserializer(v.sequence_values());
102 let seq = visitor.visit_seq(&mut deserializer)?;
103 let remaining = deserializer.0.count();
104 if remaining == 0 {
105 Ok(seq)
106 } else {
107 Err(serde::de::Error::invalid_length(
108 len,
109 &"fewer elements in array",
110 ))
111 }
112 }
113 _ => Err(serde::de::Error::custom("invalid value type")),
114 }
115 }
116
117 #[inline]
118 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
119 where
120 V: serde::de::Visitor<'de>,
121 {
122 self.deserialize_seq(visitor)
123 }
124
125 #[inline]
126 fn deserialize_tuple_struct<V>(
127 self,
128 _name: &'static str,
129 _len: usize,
130 visitor: V,
131 ) -> Result<V::Value>
132 where
133 V: serde::de::Visitor<'de>,
134 {
135 self.deserialize_seq(visitor)
136 }
137
138 forward_to_deserialize_any! {
139 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
140 byte_buf unit unit_struct newtype_struct
141 map struct identifier ignored_any
142 }
143}
144
145struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>);
146
147impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> {
148 type Error = Error;
149
150 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
151 where
152 T: serde::de::DeserializeSeed<'de>,
153 {
154 match self.0.next() {
155 Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some),
156 None => Ok(None),
157 }
158 }
159
160 fn size_hint(&self) -> Option<usize> {
161 match self.0.size_hint() {
162 (lower, Some(upper)) if lower == upper => Some(upper),
163 _ => None,
164 }
165 }
166}
167
168struct MapDeserializer<'lua>(
169 TablePairs<'lua, Value<'lua>, Value<'lua>>,
170 Option<Value<'lua>>,
171);
172
173impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> {
174 type Error = Error;
175
176 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
177 where
178 T: serde::de::DeserializeSeed<'de>,
179 {
180 match self.0.next() {
181 Some(item) => {
182 let (key, value) = item?;
183 self.1 = Some(value);
184 let key_de = Deserializer { value: key };
185 seed.deserialize(key_de).map(Some)
186 }
187 None => Ok(None),
188 }
189 }
190
191 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
192 where
193 T: serde::de::DeserializeSeed<'de>,
194 {
195 match self.1.take() {
196 Some(value) => seed.deserialize(Deserializer { value }),
197 None => Err(serde::de::Error::custom("value is missing")),
198 }
199 }
200
201 fn size_hint(&self) -> Option<usize> {
202 match self.0.size_hint() {
203 (lower, Some(upper)) if lower == upper => Some(upper),
204 _ => None,
205 }
206 }
207}
208
209struct EnumDeserializer<'lua> {
210 variant: String,
211 value: Option<Value<'lua>>,
212}
213
214impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> {
215 type Error = Error;
216 type Variant = VariantDeserializer<'lua>;
217
218 fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
219 where
220 T: serde::de::DeserializeSeed<'de>,
221 {
222 let variant = self.variant.into_deserializer();
223 let variant_access = VariantDeserializer { value: self.value };
224 seed.deserialize(variant).map(|v| (v, variant_access))
225 }
226}
227
228struct VariantDeserializer<'lua> {
229 value: Option<Value<'lua>>,
230}
231
232impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> {
233 type Error = Error;
234
235 fn unit_variant(self) -> Result<()> {
236 match self.value {
237 Some(_) => Err(serde::de::Error::invalid_type(
238 serde::de::Unexpected::NewtypeVariant,
239 &"unit variant",
240 )),
241 None => Ok(()),
242 }
243 }
244
245 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
246 where
247 T: serde::de::DeserializeSeed<'de>,
248 {
249 match self.value {
250 Some(value) => seed.deserialize(Deserializer { value }),
251 None => Err(serde::de::Error::invalid_type(
252 serde::de::Unexpected::UnitVariant,
253 &"newtype variant",
254 )),
255 }
256 }
257
258 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
259 where
260 V: serde::de::Visitor<'de>,
261 {
262 match self.value {
263 Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor),
264 None => Err(serde::de::Error::invalid_type(
265 serde::de::Unexpected::UnitVariant,
266 &"tuple variant",
267 )),
268 }
269 }
270
271 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
272 where
273 V: serde::de::Visitor<'de>,
274 {
275 match self.value {
276 Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor),
277 None => Err(serde::de::Error::invalid_type(
278 serde::de::Unexpected::UnitVariant,
279 &"struct variant",
280 )),
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use mlua::{Lua, Value};

    use from_value;

    /// Evaluates `chunk` in `lua` and returns its result; any Lua error fails the test.
    fn eval<'lua>(lua: &'lua Lua, chunk: &str) -> Value<'lua> {
        lua.load(chunk).eval().unwrap()
    }

    #[test]
    fn test_struct() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Test {
            int: u32,
            seq: Vec<String>,
            map: std::collections::HashMap<i32, i32>,
            empty: Vec<()>,
        }

        let expected = Test {
            int: 1,
            seq: vec!["a".to_owned(), "b".to_owned()],
            map: vec![(1, 2), (4, 1)].into_iter().collect(),
            empty: vec![],
        };

        let lua = Lua::new();
        // `local` keeps the fixture out of the Lua global table.
        let value = eval(
            &lua,
            r#"
                local a = {}
                a.int = 1
                a.seq = {"a", "b"}
                a.map = {2, [4]=1}
                a.empty = {}
                return a
            "#,
        );
        let got: Test = from_value(value).unwrap();
        assert_eq!(expected, got);
    }

    #[test]
    fn test_tuple() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Rgb(u8, u8, u8);

        let lua = Lua::new();

        let value = eval(&lua, "local a = {1, 2, 3}; return a");
        let got: Rgb = from_value(value).unwrap();
        assert_eq!(Rgb(1, 2, 3), got);

        let value = eval(&lua, "local a = {1, 2, 3}; return a");
        let got: (i32, i32, i32) = from_value(value).unwrap();
        assert_eq!((1, 2, 3), got);
    }

    #[test]
    fn test_enum() {
        #[derive(Deserialize, PartialEq, Debug)]
        enum E {
            Unit,
            Newtype(u32),
            Tuple(u32, u32),
            Struct { a: u32 },
        }

        let lua = Lua::new();

        // Unit variants are plain strings.
        let value = eval(&lua, r#"return "Unit""#);
        let got: E = from_value(value).unwrap();
        assert_eq!(E::Unit, got);

        // Data-carrying variants are single-key tables: { [variant] = payload }.
        let value = eval(
            &lua,
            r#"
                local a = {}
                a["Newtype"] = 1
                return a
            "#,
        );
        let got: E = from_value(value).unwrap();
        assert_eq!(E::Newtype(1), got);

        let value = eval(
            &lua,
            r#"
                local a = {}
                a["Tuple"] = {1, 2}
                return a
            "#,
        );
        let got: E = from_value(value).unwrap();
        assert_eq!(E::Tuple(1, 2), got);

        let value = eval(
            &lua,
            r#"
                local a = {}
                a["Struct"] = {}
                a["Struct"]["a"] = 1
                return a
            "#,
        );
        let got: E = from_value(value).unwrap();
        assert_eq!(E::Struct { a: 1 }, got);
    }
}