factorio_mlua/serde/
de.rs

1use std::cell::RefCell;
2use std::convert::TryInto;
3use std::os::raw::c_void;
4use std::rc::Rc;
5use std::string::String as StdString;
6
7use rustc_hash::FxHashSet;
8use serde::de::{self, IntoDeserializer};
9
10use crate::error::{Error, Result};
11use crate::ffi;
12use crate::table::{Table, TablePairs, TableSequence};
13use crate::value::Value;
14
15/// A struct for deserializing Lua values into Rust values.
16#[derive(Debug)]
17pub struct Deserializer<'lua> {
18    value: Value<'lua>,
19    options: Options,
20    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
21}
22
/// A struct with options to change default deserializer behavior.
///
/// Build one with [`Options::new`] (or [`Default::default`]) and chain the setters, e.g.
/// `Options::new().deny_unsupported_types(false)`. All setters are `const fn`, so an
/// `Options` value can also be built in a `const` item.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// If true, an attempt to deserialize types such as [`Function`], [`Thread`], [`UserData`],
    /// [`LightUserData`] and [`Error`] will cause an error.
    /// Otherwise these types are skipped when iterating over a table, or deserialized as the
    /// unit type when encountered directly.
    ///
    /// Default: **true**
    ///
    /// [`Function`]: crate::Function
    /// [`Thread`]: crate::Thread
    /// [`UserData`]: crate::UserData
    /// [`LightUserData`]: crate::LightUserData
    /// [`Error`]: crate::Error
    pub deny_unsupported_types: bool,

    /// If true, an attempt to deserialize a recursive table (a table that refers back to
    /// itself, directly or through nested tables) will cause an error.
    /// Otherwise subsequent references to a table that is already being deserialized are
    /// skipped.
    ///
    /// Default: **true**
    pub deny_recursive_tables: bool,
}

impl Default for Options {
    fn default() -> Self {
        Self::new()
    }
}

impl Options {
    /// Returns a new instance of `Options` with default parameters
    /// (both `deny_unsupported_types` and `deny_recursive_tables` are `true`).
    pub const fn new() -> Self {
        Options {
            deny_unsupported_types: true,
            deny_recursive_tables: true,
        }
    }

    /// Sets [`deny_unsupported_types`] option.
    ///
    /// [`deny_unsupported_types`]: #structfield.deny_unsupported_types
    #[must_use]
    pub const fn deny_unsupported_types(mut self, enabled: bool) -> Self {
        self.deny_unsupported_types = enabled;
        self
    }

    /// Sets [`deny_recursive_tables`] option.
    ///
    /// This is a `const fn`, like the other setter, so it can be used in constant context.
    ///
    /// [`deny_recursive_tables`]: #structfield.deny_recursive_tables
    #[must_use]
    pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self {
        self.deny_recursive_tables = enabled;
        self
    }
}
80
81impl<'lua> Deserializer<'lua> {
82    /// Creates a new Lua Deserializer for the `Value`.
83    pub fn new(value: Value<'lua>) -> Self {
84        Self::new_with_options(value, Options::default())
85    }
86
87    /// Creates a new Lua Deserializer for the `Value` with custom options.
88    pub fn new_with_options(value: Value<'lua>, options: Options) -> Self {
89        Deserializer {
90            value,
91            options,
92            visited: Rc::new(RefCell::new(FxHashSet::default())),
93        }
94    }
95
96    fn from_parts(
97        value: Value<'lua>,
98        options: Options,
99        visited: Rc<RefCell<FxHashSet<*const c_void>>>,
100    ) -> Self {
101        Deserializer {
102            value,
103            options,
104            visited,
105        }
106    }
107}
108
109impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
110    type Error = Error;
111
112    #[inline]
113    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
114    where
115        V: de::Visitor<'de>,
116    {
117        match self.value {
118            Value::Nil => visitor.visit_unit(),
119            Value::Boolean(b) => visitor.visit_bool(b),
120            #[allow(clippy::useless_conversion)]
121            Value::Integer(i) => {
122                visitor.visit_i64(i.try_into().expect("cannot convert lua_Integer to i64"))
123            }
124            #[allow(clippy::useless_conversion)]
125            Value::Number(n) => visitor.visit_f64(n.into()),
126            #[cfg(feature = "luau")]
127            Value::Vector(_, _, _) => self.deserialize_seq(visitor),
128            Value::String(s) => match s.to_str() {
129                Ok(s) => visitor.visit_str(s),
130                Err(_) => visitor.visit_bytes(s.as_bytes()),
131            },
132            Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
133            Value::Table(_) => self.deserialize_map(visitor),
134            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
135            Value::Function(_)
136            | Value::Thread(_)
137            | Value::UserData(_)
138            | Value::LightUserData(_)
139            | Value::Error(_) => {
140                if self.options.deny_unsupported_types {
141                    Err(de::Error::custom(format!(
142                        "unsupported value type `{}`",
143                        self.value.type_name()
144                    )))
145                } else {
146                    visitor.visit_unit()
147                }
148            }
149        }
150    }
151
152    #[inline]
153    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
154    where
155        V: de::Visitor<'de>,
156    {
157        match self.value {
158            Value::Nil => visitor.visit_none(),
159            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
160            _ => visitor.visit_some(self),
161        }
162    }
163
164    #[inline]
165    fn deserialize_enum<V>(
166        self,
167        _name: &str,
168        _variants: &'static [&'static str],
169        visitor: V,
170    ) -> Result<V::Value>
171    where
172        V: de::Visitor<'de>,
173    {
174        let (variant, value, _guard) = match self.value {
175            Value::Table(table) => {
176                let _guard = RecursionGuard::new(&table, &self.visited);
177
178                let mut iter = table.pairs::<StdString, Value>();
179                let (variant, value) = match iter.next() {
180                    Some(v) => v?,
181                    None => {
182                        return Err(de::Error::invalid_value(
183                            de::Unexpected::Map,
184                            &"map with a single key",
185                        ))
186                    }
187                };
188
189                if iter.next().is_some() {
190                    return Err(de::Error::invalid_value(
191                        de::Unexpected::Map,
192                        &"map with a single key",
193                    ));
194                }
195                if check_value_if_skip(&value, self.options, &self.visited)? {
196                    return Err(de::Error::custom("bad enum value"));
197                }
198
199                (variant, Some(value), Some(_guard))
200            }
201            Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
202            _ => return Err(de::Error::custom("bad enum value")),
203        };
204
205        visitor.visit_enum(EnumDeserializer {
206            variant,
207            value,
208            options: self.options,
209            visited: self.visited,
210        })
211    }
212
213    #[inline]
214    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
215    where
216        V: de::Visitor<'de>,
217    {
218        match self.value {
219            #[cfg(feature = "luau")]
220            Value::Vector(x, y, z) => {
221                let mut deserializer = VecDeserializer {
222                    vec: [x, y, z],
223                    next: 0,
224                    options: self.options,
225                    visited: self.visited,
226                };
227                visitor.visit_seq(&mut deserializer)
228            }
229            Value::Table(t) => {
230                let _guard = RecursionGuard::new(&t, &self.visited);
231
232                let len = t.raw_len() as usize;
233                let mut deserializer = SeqDeserializer {
234                    seq: t.raw_sequence_values(),
235                    options: self.options,
236                    visited: self.visited,
237                };
238                let seq = visitor.visit_seq(&mut deserializer)?;
239                if deserializer.seq.count() == 0 {
240                    Ok(seq)
241                } else {
242                    Err(de::Error::invalid_length(
243                        len,
244                        &"fewer elements in the table",
245                    ))
246                }
247            }
248            value => Err(de::Error::invalid_type(
249                de::Unexpected::Other(value.type_name()),
250                &"table",
251            )),
252        }
253    }
254
255    #[inline]
256    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
257    where
258        V: de::Visitor<'de>,
259    {
260        self.deserialize_seq(visitor)
261    }
262
263    #[inline]
264    fn deserialize_tuple_struct<V>(
265        self,
266        _name: &'static str,
267        _len: usize,
268        visitor: V,
269    ) -> Result<V::Value>
270    where
271        V: de::Visitor<'de>,
272    {
273        self.deserialize_seq(visitor)
274    }
275
276    #[inline]
277    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
278    where
279        V: de::Visitor<'de>,
280    {
281        match self.value {
282            Value::Table(t) => {
283                let _guard = RecursionGuard::new(&t, &self.visited);
284
285                let mut deserializer = MapDeserializer {
286                    pairs: t.pairs(),
287                    value: None,
288                    options: self.options,
289                    visited: self.visited,
290                    processed: 0,
291                };
292                let map = visitor.visit_map(&mut deserializer)?;
293                let count = deserializer.pairs.count();
294                if count == 0 {
295                    Ok(map)
296                } else {
297                    Err(de::Error::invalid_length(
298                        deserializer.processed + count,
299                        &"fewer elements in the table",
300                    ))
301                }
302            }
303            value => Err(de::Error::invalid_type(
304                de::Unexpected::Other(value.type_name()),
305                &"table",
306            )),
307        }
308    }
309
310    #[inline]
311    fn deserialize_struct<V>(
312        self,
313        _name: &'static str,
314        _fields: &'static [&'static str],
315        visitor: V,
316    ) -> Result<V::Value>
317    where
318        V: de::Visitor<'de>,
319    {
320        self.deserialize_map(visitor)
321    }
322
323    #[inline]
324    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
325    where
326        V: de::Visitor<'de>,
327    {
328        visitor.visit_newtype_struct(self)
329    }
330
331    serde::forward_to_deserialize_any! {
332        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
333        byte_buf unit unit_struct identifier ignored_any
334    }
335}
336
337struct SeqDeserializer<'lua> {
338    seq: TableSequence<'lua, Value<'lua>>,
339    options: Options,
340    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
341}
342
343impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> {
344    type Error = Error;
345
346    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
347    where
348        T: de::DeserializeSeed<'de>,
349    {
350        loop {
351            match self.seq.next() {
352                Some(value) => {
353                    let value = value?;
354                    if check_value_if_skip(&value, self.options, &self.visited)? {
355                        continue;
356                    }
357                    let visited = Rc::clone(&self.visited);
358                    let deserializer = Deserializer::from_parts(value, self.options, visited);
359                    return seed.deserialize(deserializer).map(Some);
360                }
361                None => return Ok(None),
362            }
363        }
364    }
365
366    fn size_hint(&self) -> Option<usize> {
367        match self.seq.size_hint() {
368            (lower, Some(upper)) if lower == upper => Some(upper),
369            _ => None,
370        }
371    }
372}
373
/// Sequence access over the three `f32` components of a Luau vector.
#[cfg(feature = "luau")]
struct VecDeserializer {
    /// The vector components, in order.
    vec: [f32; 3],
    /// Index of the next component to hand out.
    next: usize,
    /// Options handed down to each component's deserializer.
    options: Options,
    /// Shared set of tables that are currently being deserialized.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}

#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    /// Yields each component as a Lua number, then `None` once all three were consumed.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.vec.get(self.next) {
            Some(&n) => {
                self.next += 1;
                let visited = Rc::clone(&self.visited);
                let deserializer =
                    Deserializer::from_parts(Value::Number(n as _), self.options, visited);
                seed.deserialize(deserializer).map(Some)
            }
            None => Ok(None),
        }
    }

    /// Number of components not yet consumed.
    ///
    /// `next` only advances after a successful `get`, so it never exceeds `vec.len()` and the
    /// subtraction cannot underflow.
    fn size_hint(&self) -> Option<usize> {
        Some(self.vec.len() - self.next)
    }
}
406
407struct MapDeserializer<'lua> {
408    pairs: TablePairs<'lua, Value<'lua>, Value<'lua>>,
409    value: Option<Value<'lua>>,
410    options: Options,
411    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
412    processed: usize,
413}
414
415impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
416    type Error = Error;
417
418    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
419    where
420        T: de::DeserializeSeed<'de>,
421    {
422        loop {
423            match self.pairs.next() {
424                Some(item) => {
425                    let (key, value) = item?;
426                    if check_value_if_skip(&key, self.options, &self.visited)?
427                        || check_value_if_skip(&value, self.options, &self.visited)?
428                    {
429                        continue;
430                    }
431                    self.processed += 1;
432                    self.value = Some(value);
433                    let visited = Rc::clone(&self.visited);
434                    let key_de = Deserializer::from_parts(key, self.options, visited);
435                    return seed.deserialize(key_de).map(Some);
436                }
437                None => return Ok(None),
438            }
439        }
440    }
441
442    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
443    where
444        T: de::DeserializeSeed<'de>,
445    {
446        match self.value.take() {
447            Some(value) => {
448                let visited = Rc::clone(&self.visited);
449                seed.deserialize(Deserializer::from_parts(value, self.options, visited))
450            }
451            None => Err(de::Error::custom("value is missing")),
452        }
453    }
454
455    fn size_hint(&self) -> Option<usize> {
456        match self.pairs.size_hint() {
457            (lower, Some(upper)) if lower == upper => Some(upper),
458            _ => None,
459        }
460    }
461}
462
463struct EnumDeserializer<'lua> {
464    variant: StdString,
465    value: Option<Value<'lua>>,
466    options: Options,
467    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
468}
469
470impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> {
471    type Error = Error;
472    type Variant = VariantDeserializer<'lua>;
473
474    fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
475    where
476        T: de::DeserializeSeed<'de>,
477    {
478        let variant = self.variant.into_deserializer();
479        let variant_access = VariantDeserializer {
480            value: self.value,
481            options: self.options,
482            visited: self.visited,
483        };
484        seed.deserialize(variant).map(|v| (v, variant_access))
485    }
486}
487
488struct VariantDeserializer<'lua> {
489    value: Option<Value<'lua>>,
490    options: Options,
491    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
492}
493
494impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
495    type Error = Error;
496
497    fn unit_variant(self) -> Result<()> {
498        match self.value {
499            Some(_) => Err(de::Error::invalid_type(
500                de::Unexpected::NewtypeVariant,
501                &"unit variant",
502            )),
503            None => Ok(()),
504        }
505    }
506
507    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
508    where
509        T: de::DeserializeSeed<'de>,
510    {
511        match self.value {
512            Some(value) => {
513                seed.deserialize(Deserializer::from_parts(value, self.options, self.visited))
514            }
515            None => Err(de::Error::invalid_type(
516                de::Unexpected::UnitVariant,
517                &"newtype variant",
518            )),
519        }
520    }
521
522    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
523    where
524        V: de::Visitor<'de>,
525    {
526        match self.value {
527            Some(value) => serde::Deserializer::deserialize_seq(
528                Deserializer::from_parts(value, self.options, self.visited),
529                visitor,
530            ),
531            None => Err(de::Error::invalid_type(
532                de::Unexpected::UnitVariant,
533                &"tuple variant",
534            )),
535        }
536    }
537
538    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
539    where
540        V: de::Visitor<'de>,
541    {
542        match self.value {
543            Some(value) => serde::Deserializer::deserialize_map(
544                Deserializer::from_parts(value, self.options, self.visited),
545                visitor,
546            ),
547            None => Err(de::Error::invalid_type(
548                de::Unexpected::UnitVariant,
549                &"struct variant",
550            )),
551        }
552    }
553}
554
555// Adds `ptr` to the `visited` map and removes on drop
556// Used to track recursive tables but allow to traverse same tables multiple times
557struct RecursionGuard {
558    ptr: *const c_void,
559    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
560}
561
562impl RecursionGuard {
563    #[inline]
564    fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
565        let visited = Rc::clone(visited);
566        let lua = table.0.lua;
567        let ptr =
568            unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
569        visited.borrow_mut().insert(ptr);
570        RecursionGuard { ptr, visited }
571    }
572}
573
574impl Drop for RecursionGuard {
575    fn drop(&mut self) {
576        self.visited.borrow_mut().remove(&self.ptr);
577    }
578}
579
580// Checks `options` and decides should we emit an error or skip next element
581fn check_value_if_skip(
582    value: &Value,
583    options: Options,
584    visited: &RefCell<FxHashSet<*const c_void>>,
585) -> Result<bool> {
586    match value {
587        Value::Table(table) => {
588            let lua = table.0.lua;
589            let ptr =
590                unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
591            if visited.borrow().contains(&ptr) {
592                if options.deny_recursive_tables {
593                    return Err(de::Error::custom("recursive table detected"));
594                }
595                return Ok(true); // skip
596            }
597        }
598        Value::Function(_)
599        | Value::Thread(_)
600        | Value::UserData(_)
601        | Value::LightUserData(_)
602        | Value::Error(_)
603            if !options.deny_unsupported_types =>
604        {
605            return Ok(true); // skip
606        }
607        _ => {}
608    }
609    Ok(false) // do not skip
610}