mlua_codemp_patch/serde/de.rs

use std::cell::RefCell;
use std::os::raw::c_void;
use std::rc::Rc;
use std::result::Result as StdResult;
use std::string::String as StdString;

use rustc_hash::FxHashSet;
use serde::de::{self, IntoDeserializer};

use crate::error::{Error, Result};
use crate::table::{Table, TablePairs, TableSequence};
use crate::userdata::AnyUserData;
use crate::value::Value;
14
15/// A struct for deserializing Lua values into Rust values.
16#[derive(Debug)]
17pub struct Deserializer {
18    value: Value,
19    options: Options,
20    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
21}
22
/// A struct with options to change default deserializer behavior.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// If true, an attempt to deserialize values of types such as [`Function`], [`Thread`],
    /// [`LightUserData`] and [`Error`] will cause an error (a null [`LightUserData`] is exempt:
    /// it is treated as `None`/unit).
    /// Otherwise values of these types are skipped when iterating over a table, or
    /// deserialized as the unit type.
    ///
    /// Default: **true**
    ///
    /// [`Function`]: crate::Function
    /// [`Thread`]: crate::Thread
    /// [`LightUserData`]: crate::LightUserData
    /// [`Error`]: crate::Error
    pub deny_unsupported_types: bool,

    /// If true, an attempt to deserialize a recursive table (a table that is reached again
    /// while it is still being traversed, e.g. a table that refers to itself) will cause an
    /// error.
    /// Otherwise such a nested reference is skipped. A table that is merely referenced from
    /// several places, without nesting, is still deserialized every time it is reached.
    ///
    /// Default: **true**
    pub deny_recursive_tables: bool,

    /// If true, map-like tables are traversed in ascending key order.
    ///
    /// Default: **false**
    pub sort_keys: bool,
}

impl Default for Options {
    /// Same as [`Options::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Options {
    /// Returns a new instance of `Options` with default parameters.
    ///
    /// Defaults: `deny_unsupported_types = true`, `deny_recursive_tables = true`,
    /// `sort_keys = false`.
    pub const fn new() -> Self {
        Options {
            deny_unsupported_types: true,
            deny_recursive_tables: true,
            sort_keys: false,
        }
    }

    /// Sets [`deny_unsupported_types`] option.
    ///
    /// [`deny_unsupported_types`]: #structfield.deny_unsupported_types
    #[must_use]
    pub const fn deny_unsupported_types(mut self, enabled: bool) -> Self {
        self.deny_unsupported_types = enabled;
        self
    }

    /// Sets [`deny_recursive_tables`] option.
    ///
    /// [`deny_recursive_tables`]: #structfield.deny_recursive_tables
    #[must_use]
    pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self {
        self.deny_recursive_tables = enabled;
        self
    }

    /// Sets [`sort_keys`] option.
    ///
    /// [`sort_keys`]: #structfield.sort_keys
    #[must_use]
    pub const fn sort_keys(mut self, enabled: bool) -> Self {
        self.sort_keys = enabled;
        self
    }
}
95
96impl Deserializer {
97    /// Creates a new Lua Deserializer for the `Value`.
98    pub fn new(value: Value) -> Self {
99        Self::new_with_options(value, Options::default())
100    }
101
102    /// Creates a new Lua Deserializer for the `Value` with custom options.
103    pub fn new_with_options(value: Value, options: Options) -> Self {
104        Deserializer {
105            value,
106            options,
107            visited: Rc::new(RefCell::new(FxHashSet::default())),
108        }
109    }
110
111    fn from_parts(value: Value, options: Options, visited: Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
112        Deserializer {
113            value,
114            options,
115            visited,
116        }
117    }
118}
119
120impl<'de> serde::Deserializer<'de> for Deserializer {
121    type Error = Error;
122
123    #[inline]
124    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
125    where
126        V: de::Visitor<'de>,
127    {
128        match self.value {
129            Value::Nil => visitor.visit_unit(),
130            Value::Boolean(b) => visitor.visit_bool(b),
131            #[allow(clippy::useless_conversion)]
132            Value::Integer(i) => visitor.visit_i64(i.into()),
133            #[allow(clippy::useless_conversion)]
134            Value::Number(n) => visitor.visit_f64(n.into()),
135            #[cfg(feature = "luau")]
136            Value::Vector(_) => self.deserialize_seq(visitor),
137            Value::String(s) => match s.to_str() {
138                Ok(s) => visitor.visit_str(&s),
139                Err(_) => visitor.visit_bytes(&s.as_bytes()),
140            },
141            Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
142            Value::Table(_) => self.deserialize_map(visitor),
143            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
144            Value::UserData(ud) if ud.is_serializable() => {
145                serde_userdata(ud, |value| value.deserialize_any(visitor))
146            }
147            #[cfg(feature = "luau")]
148            Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
149                let lua = ud.0.lua.lock();
150                let mut size = 0usize;
151                let buf = ffi::lua_tobuffer(lua.ref_thread(), ud.0.index, &mut size);
152                mlua_assert!(!buf.is_null(), "invalid Luau buffer");
153                let buf = std::slice::from_raw_parts(buf as *const u8, size);
154                visitor.visit_bytes(buf)
155            },
156            Value::Function(_)
157            | Value::Thread(_)
158            | Value::UserData(_)
159            | Value::LightUserData(_)
160            | Value::Error(_) => {
161                if self.options.deny_unsupported_types {
162                    let msg = format!("unsupported value type `{}`", self.value.type_name());
163                    Err(de::Error::custom(msg))
164                } else {
165                    visitor.visit_unit()
166                }
167            }
168        }
169    }
170
171    #[inline]
172    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
173    where
174        V: de::Visitor<'de>,
175    {
176        match self.value {
177            Value::Nil => visitor.visit_none(),
178            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
179            _ => visitor.visit_some(self),
180        }
181    }
182
183    #[inline]
184    fn deserialize_enum<V>(
185        self,
186        name: &'static str,
187        variants: &'static [&'static str],
188        visitor: V,
189    ) -> Result<V::Value>
190    where
191        V: de::Visitor<'de>,
192    {
193        let (variant, value, _guard) = match self.value {
194            Value::Table(table) => {
195                let _guard = RecursionGuard::new(&table, &self.visited);
196
197                let mut iter = table.pairs::<StdString, Value>();
198                let (variant, value) = match iter.next() {
199                    Some(v) => v?,
200                    None => {
201                        return Err(de::Error::invalid_value(
202                            de::Unexpected::Map,
203                            &"map with a single key",
204                        ))
205                    }
206                };
207
208                if iter.next().is_some() {
209                    return Err(de::Error::invalid_value(
210                        de::Unexpected::Map,
211                        &"map with a single key",
212                    ));
213                }
214                let skip = check_value_for_skip(&value, self.options, &self.visited)
215                    .map_err(|err| Error::DeserializeError(err.to_string()))?;
216                if skip {
217                    return Err(de::Error::custom("bad enum value"));
218                }
219
220                (variant, Some(value), Some(_guard))
221            }
222            Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
223            Value::UserData(ud) if ud.is_serializable() => {
224                return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor));
225            }
226            _ => return Err(de::Error::custom("bad enum value")),
227        };
228
229        visitor.visit_enum(EnumDeserializer {
230            variant,
231            value,
232            options: self.options,
233            visited: self.visited,
234        })
235    }
236
237    #[inline]
238    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
239    where
240        V: de::Visitor<'de>,
241    {
242        match self.value {
243            #[cfg(feature = "luau")]
244            Value::Vector(vec) => {
245                let mut deserializer = VecDeserializer {
246                    vec,
247                    next: 0,
248                    options: self.options,
249                    visited: self.visited,
250                };
251                visitor.visit_seq(&mut deserializer)
252            }
253            Value::Table(t) => {
254                let _guard = RecursionGuard::new(&t, &self.visited);
255
256                let len = t.raw_len();
257                let mut deserializer = SeqDeserializer {
258                    seq: t.sequence_values(),
259                    options: self.options,
260                    visited: self.visited,
261                };
262                let seq = visitor.visit_seq(&mut deserializer)?;
263                if deserializer.seq.count() == 0 {
264                    Ok(seq)
265                } else {
266                    Err(de::Error::invalid_length(len, &"fewer elements in the table"))
267                }
268            }
269            Value::UserData(ud) if ud.is_serializable() => {
270                serde_userdata(ud, |value| value.deserialize_seq(visitor))
271            }
272            value => Err(de::Error::invalid_type(
273                de::Unexpected::Other(value.type_name()),
274                &"table",
275            )),
276        }
277    }
278
279    #[inline]
280    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
281    where
282        V: de::Visitor<'de>,
283    {
284        self.deserialize_seq(visitor)
285    }
286
287    #[inline]
288    fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value>
289    where
290        V: de::Visitor<'de>,
291    {
292        self.deserialize_seq(visitor)
293    }
294
295    #[inline]
296    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
297    where
298        V: de::Visitor<'de>,
299    {
300        match self.value {
301            Value::Table(t) => {
302                let _guard = RecursionGuard::new(&t, &self.visited);
303
304                let mut deserializer = MapDeserializer {
305                    pairs: MapPairs::new(&t, self.options.sort_keys)?,
306                    value: None,
307                    options: self.options,
308                    visited: self.visited,
309                    processed: 0,
310                };
311                let map = visitor.visit_map(&mut deserializer)?;
312                let count = deserializer.pairs.count();
313                if count == 0 {
314                    Ok(map)
315                } else {
316                    Err(de::Error::invalid_length(
317                        deserializer.processed + count,
318                        &"fewer elements in the table",
319                    ))
320                }
321            }
322            Value::UserData(ud) if ud.is_serializable() => {
323                serde_userdata(ud, |value| value.deserialize_map(visitor))
324            }
325            value => Err(de::Error::invalid_type(
326                de::Unexpected::Other(value.type_name()),
327                &"table",
328            )),
329        }
330    }
331
332    #[inline]
333    fn deserialize_struct<V>(
334        self,
335        _name: &'static str,
336        _fields: &'static [&'static str],
337        visitor: V,
338    ) -> Result<V::Value>
339    where
340        V: de::Visitor<'de>,
341    {
342        self.deserialize_map(visitor)
343    }
344
345    #[inline]
346    fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
347    where
348        V: de::Visitor<'de>,
349    {
350        match self.value {
351            Value::UserData(ud) if ud.is_serializable() => {
352                serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor))
353            }
354            _ => visitor.visit_newtype_struct(self),
355        }
356    }
357
358    #[inline]
359    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
360    where
361        V: de::Visitor<'de>,
362    {
363        match self.value {
364            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
365            _ => self.deserialize_any(visitor),
366        }
367    }
368
369    #[inline]
370    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
371    where
372        V: de::Visitor<'de>,
373    {
374        match self.value {
375            Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
376            _ => self.deserialize_any(visitor),
377        }
378    }
379
380    serde::forward_to_deserialize_any! {
381        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
382        byte_buf identifier ignored_any
383    }
384}
385
386struct SeqDeserializer<'a> {
387    seq: TableSequence<'a, Value>,
388    options: Options,
389    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
390}
391
392impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_> {
393    type Error = Error;
394
395    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
396    where
397        T: de::DeserializeSeed<'de>,
398    {
399        loop {
400            match self.seq.next() {
401                Some(value) => {
402                    let value = value?;
403                    let skip = check_value_for_skip(&value, self.options, &self.visited)
404                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
405                    if skip {
406                        continue;
407                    }
408                    let visited = Rc::clone(&self.visited);
409                    let deserializer = Deserializer::from_parts(value, self.options, visited);
410                    return seed.deserialize(deserializer).map(Some);
411                }
412                None => return Ok(None),
413            }
414        }
415    }
416
417    fn size_hint(&self) -> Option<usize> {
418        match self.seq.size_hint() {
419            (lower, Some(upper)) if lower == upper => Some(upper),
420            _ => None,
421        }
422    }
423}
424
/// [`de::SeqAccess`] over the components of a Luau vector.
///
/// `next` is the index of the next component to hand out; every component is
/// yielded as a `Value::Number`.
#[cfg(feature = "luau")]
struct VecDeserializer {
    vec: crate::types::Vector,
    next: usize,
    options: Options,
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}

#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let Some(&component) = self.vec.0.get(self.next) else {
            return Ok(None);
        };
        self.next += 1;
        let nested = Deserializer::from_parts(
            Value::Number(component as _),
            self.options,
            Rc::clone(&self.visited),
        );
        seed.deserialize(nested).map(Some)
    }

    fn size_hint(&self) -> Option<usize> {
        Some(crate::types::Vector::SIZE)
    }
}
456
457pub(crate) enum MapPairs<'a> {
458    Iter(TablePairs<'a, Value, Value>),
459    Vec(Vec<(Value, Value)>),
460}
461
462impl<'a> MapPairs<'a> {
463    pub(crate) fn new(t: &'a Table, sort_keys: bool) -> Result<Self> {
464        if sort_keys {
465            let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
466            pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); // reverse order as we pop values from the end
467            Ok(MapPairs::Vec(pairs))
468        } else {
469            Ok(MapPairs::Iter(t.pairs::<Value, Value>()))
470        }
471    }
472
473    pub(crate) fn count(self) -> usize {
474        match self {
475            MapPairs::Iter(iter) => iter.count(),
476            MapPairs::Vec(vec) => vec.len(),
477        }
478    }
479
480    pub(crate) fn size_hint(&self) -> (usize, Option<usize>) {
481        match self {
482            MapPairs::Iter(iter) => iter.size_hint(),
483            MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
484        }
485    }
486}
487
488impl Iterator for MapPairs<'_> {
489    type Item = Result<(Value, Value)>;
490
491    fn next(&mut self) -> Option<Self::Item> {
492        match self {
493            MapPairs::Iter(iter) => iter.next(),
494            MapPairs::Vec(vec) => vec.pop().map(Ok),
495        }
496    }
497}
498
499struct MapDeserializer<'a> {
500    pairs: MapPairs<'a>,
501    value: Option<Value>,
502    options: Options,
503    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
504    processed: usize,
505}
506
507impl<'a> MapDeserializer<'a> {
508    fn next_key_deserializer(&mut self) -> Result<Option<Deserializer>> {
509        loop {
510            match self.pairs.next() {
511                Some(item) => {
512                    let (key, value) = item?;
513                    let skip_key = check_value_for_skip(&key, self.options, &self.visited)
514                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
515                    let skip_value = check_value_for_skip(&value, self.options, &self.visited)
516                        .map_err(|err| Error::DeserializeError(err.to_string()))?;
517                    if skip_key || skip_value {
518                        continue;
519                    }
520                    self.processed += 1;
521                    self.value = Some(value);
522                    let visited = Rc::clone(&self.visited);
523                    let key_de = Deserializer::from_parts(key, self.options, visited);
524                    return Ok(Some(key_de));
525                }
526                None => return Ok(None),
527            }
528        }
529    }
530
531    fn next_value_deserializer(&mut self) -> Result<Deserializer> {
532        match self.value.take() {
533            Some(value) => {
534                let visited = Rc::clone(&self.visited);
535                Ok(Deserializer::from_parts(value, self.options, visited))
536            }
537            None => Err(de::Error::custom("value is missing")),
538        }
539    }
540}
541
542impl<'de> de::MapAccess<'de> for MapDeserializer<'_> {
543    type Error = Error;
544
545    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
546    where
547        T: de::DeserializeSeed<'de>,
548    {
549        match self.next_key_deserializer() {
550            Ok(Some(key_de)) => seed.deserialize(key_de).map(Some),
551            Ok(None) => Ok(None),
552            Err(error) => Err(error),
553        }
554    }
555
556    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
557    where
558        T: de::DeserializeSeed<'de>,
559    {
560        match self.next_value_deserializer() {
561            Ok(value_de) => seed.deserialize(value_de),
562            Err(error) => Err(error),
563        }
564    }
565
566    fn size_hint(&self) -> Option<usize> {
567        match self.pairs.size_hint() {
568            (lower, Some(upper)) if lower == upper => Some(upper),
569            _ => None,
570        }
571    }
572}
573
574struct EnumDeserializer {
575    variant: StdString,
576    value: Option<Value>,
577    options: Options,
578    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
579}
580
581impl<'de> de::EnumAccess<'de> for EnumDeserializer {
582    type Error = Error;
583    type Variant = VariantDeserializer;
584
585    fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
586    where
587        T: de::DeserializeSeed<'de>,
588    {
589        let variant = self.variant.into_deserializer();
590        let variant_access = VariantDeserializer {
591            value: self.value,
592            options: self.options,
593            visited: self.visited,
594        };
595        seed.deserialize(variant).map(|v| (v, variant_access))
596    }
597}
598
599struct VariantDeserializer {
600    value: Option<Value>,
601    options: Options,
602    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
603}
604
605impl<'de> de::VariantAccess<'de> for VariantDeserializer {
606    type Error = Error;
607
608    fn unit_variant(self) -> Result<()> {
609        match self.value {
610            Some(_) => Err(de::Error::invalid_type(
611                de::Unexpected::NewtypeVariant,
612                &"unit variant",
613            )),
614            None => Ok(()),
615        }
616    }
617
618    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
619    where
620        T: de::DeserializeSeed<'de>,
621    {
622        match self.value {
623            Some(value) => seed.deserialize(Deserializer::from_parts(value, self.options, self.visited)),
624            None => Err(de::Error::invalid_type(
625                de::Unexpected::UnitVariant,
626                &"newtype variant",
627            )),
628        }
629    }
630
631    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
632    where
633        V: de::Visitor<'de>,
634    {
635        match self.value {
636            Some(value) => serde::Deserializer::deserialize_seq(
637                Deserializer::from_parts(value, self.options, self.visited),
638                visitor,
639            ),
640            None => Err(de::Error::invalid_type(
641                de::Unexpected::UnitVariant,
642                &"tuple variant",
643            )),
644        }
645    }
646
647    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
648    where
649        V: de::Visitor<'de>,
650    {
651        match self.value {
652            Some(value) => serde::Deserializer::deserialize_map(
653                Deserializer::from_parts(value, self.options, self.visited),
654                visitor,
655            ),
656            None => Err(de::Error::invalid_type(
657                de::Unexpected::UnitVariant,
658                &"struct variant",
659            )),
660        }
661    }
662}
663
664// Adds `ptr` to the `visited` map and removes on drop
665// Used to track recursive tables but allow to traverse same tables multiple times
666pub(crate) struct RecursionGuard {
667    ptr: *const c_void,
668    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
669}
670
671impl RecursionGuard {
672    #[inline]
673    pub(crate) fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
674        let visited = Rc::clone(visited);
675        let ptr = table.to_pointer();
676        visited.borrow_mut().insert(ptr);
677        RecursionGuard { ptr, visited }
678    }
679}
680
681impl Drop for RecursionGuard {
682    fn drop(&mut self) {
683        self.visited.borrow_mut().remove(&self.ptr);
684    }
685}
686
687// Checks `options` and decides should we emit an error or skip next element
688pub(crate) fn check_value_for_skip(
689    value: &Value,
690    options: Options,
691    visited: &RefCell<FxHashSet<*const c_void>>,
692) -> StdResult<bool, &'static str> {
693    match value {
694        Value::Table(table) => {
695            let ptr = table.to_pointer();
696            if visited.borrow().contains(&ptr) {
697                if options.deny_recursive_tables {
698                    return Err("recursive table detected");
699                }
700                return Ok(true); // skip
701            }
702        }
703        Value::UserData(ud) if ud.is_serializable() => {}
704        Value::Function(_)
705        | Value::Thread(_)
706        | Value::UserData(_)
707        | Value::LightUserData(_)
708        | Value::Error(_)
709            if !options.deny_unsupported_types =>
710        {
711            return Ok(true); // skip
712        }
713        _ => {}
714    }
715    Ok(false) // do not skip
716}
717
718fn serde_userdata<V>(
719    ud: AnyUserData,
720    f: impl FnOnce(serde_value::Value) -> std::result::Result<V, serde_value::DeserializerError>,
721) -> Result<V> {
722    match serde_value::to_value(ud) {
723        Ok(value) => match f(value) {
724            Ok(r) => Ok(r),
725            Err(error) => Err(Error::DeserializeError(error.to_string())),
726        },
727        Err(error) => Err(Error::SerializeError(error.to_string())),
728    }
729}