1use std::cell::RefCell;
2use std::os::raw::c_void;
3use std::rc::Rc;
4use std::result::Result as StdResult;
5use std::string::String as StdString;
6
7use rustc_hash::FxHashSet;
8use serde::de::{self, IntoDeserializer};
9
10use crate::error::{Error, Result};
11use crate::table::{Table, TablePairs, TableSequence};
12use crate::userdata::AnyUserData;
13use crate::value::Value;
14
/// A serde deserializer that reads its input from a single Lua [`Value`].
#[derive(Debug)]
pub struct Deserializer {
    /// The value this deserializer hands to the visitor.
    value: Value,
    /// Behaviour switches; copied into every nested deserializer.
    options: Options,
    /// Pointers of the tables currently being traversed, shared (via `Rc`) with every
    /// nested deserializer so recursive tables can be detected.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
22
/// Switches that control how values are deserialized.
///
/// Values are built with [`Options::new`] (or `Default`) and then adjusted with the
/// chainable setters, each of which consumes and returns the options.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// When `true`, values of an unsupported type (functions, threads, non-serializable
    /// userdata, light userdata, errors) cause a deserialization error; when `false`
    /// they are skipped inside tables or visited as unit.
    pub deny_unsupported_types: bool,

    /// When `true`, meeting a table that is already being traversed is reported as a
    /// "recursive table detected" error; when `false` such a table is skipped.
    pub deny_recursive_tables: bool,

    /// When `true`, the pairs of a table deserialized as a map are collected and sorted
    /// by key before being handed to the visitor.
    pub sort_keys: bool,
}

impl Default for Options {
    /// Same as [`Options::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Options {
    /// Returns the default options: both `deny_*` switches on, key sorting off.
    pub const fn new() -> Self {
        Self {
            deny_unsupported_types: true,
            deny_recursive_tables: true,
            sort_keys: false,
        }
    }

    /// Returns these options with `deny_unsupported_types` set to `enabled`.
    #[must_use]
    pub const fn deny_unsupported_types(self, enabled: bool) -> Self {
        Self {
            deny_unsupported_types: enabled,
            ..self
        }
    }

    /// Returns these options with `deny_recursive_tables` set to `enabled`.
    #[must_use]
    pub const fn deny_recursive_tables(self, enabled: bool) -> Self {
        Self {
            deny_recursive_tables: enabled,
            ..self
        }
    }

    /// Returns these options with `sort_keys` set to `enabled`.
    #[must_use]
    pub const fn sort_keys(self, enabled: bool) -> Self {
        Self {
            sort_keys: enabled,
            ..self
        }
    }
}
95
96impl Deserializer {
97 pub fn new(value: Value) -> Self {
99 Self::new_with_options(value, Options::default())
100 }
101
102 pub fn new_with_options(value: Value, options: Options) -> Self {
104 Deserializer {
105 value,
106 options,
107 visited: Rc::new(RefCell::new(FxHashSet::default())),
108 }
109 }
110
111 fn from_parts(value: Value, options: Options, visited: Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
112 Deserializer {
113 value,
114 options,
115 visited,
116 }
117 }
118}
119
120impl<'de> serde::Deserializer<'de> for Deserializer {
121 type Error = Error;
122
123 #[inline]
124 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
125 where
126 V: de::Visitor<'de>,
127 {
128 match self.value {
129 Value::Nil => visitor.visit_unit(),
130 Value::Boolean(b) => visitor.visit_bool(b),
131 #[allow(clippy::useless_conversion)]
132 Value::Integer(i) => visitor.visit_i64(i.into()),
133 #[allow(clippy::useless_conversion)]
134 Value::Number(n) => visitor.visit_f64(n.into()),
135 #[cfg(feature = "luau")]
136 Value::Vector(_) => self.deserialize_seq(visitor),
137 Value::String(s) => match s.to_str() {
138 Ok(s) => visitor.visit_str(&s),
139 Err(_) => visitor.visit_bytes(&s.as_bytes()),
140 },
141 Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
142 Value::Table(_) => self.deserialize_map(visitor),
143 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
144 Value::UserData(ud) if ud.is_serializable() => {
145 serde_userdata(ud, |value| value.deserialize_any(visitor))
146 }
147 #[cfg(feature = "luau")]
148 Value::UserData(ud) if ud.1 == crate::types::SubtypeId::Buffer => unsafe {
149 let lua = ud.0.lua.lock();
150 let mut size = 0usize;
151 let buf = ffi::lua_tobuffer(lua.ref_thread(), ud.0.index, &mut size);
152 mlua_assert!(!buf.is_null(), "invalid Luau buffer");
153 let buf = std::slice::from_raw_parts(buf as *const u8, size);
154 visitor.visit_bytes(buf)
155 },
156 Value::Function(_)
157 | Value::Thread(_)
158 | Value::UserData(_)
159 | Value::LightUserData(_)
160 | Value::Error(_) => {
161 if self.options.deny_unsupported_types {
162 let msg = format!("unsupported value type `{}`", self.value.type_name());
163 Err(de::Error::custom(msg))
164 } else {
165 visitor.visit_unit()
166 }
167 }
168 }
169 }
170
171 #[inline]
172 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
173 where
174 V: de::Visitor<'de>,
175 {
176 match self.value {
177 Value::Nil => visitor.visit_none(),
178 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
179 _ => visitor.visit_some(self),
180 }
181 }
182
183 #[inline]
184 fn deserialize_enum<V>(
185 self,
186 name: &'static str,
187 variants: &'static [&'static str],
188 visitor: V,
189 ) -> Result<V::Value>
190 where
191 V: de::Visitor<'de>,
192 {
193 let (variant, value, _guard) = match self.value {
194 Value::Table(table) => {
195 let _guard = RecursionGuard::new(&table, &self.visited);
196
197 let mut iter = table.pairs::<StdString, Value>();
198 let (variant, value) = match iter.next() {
199 Some(v) => v?,
200 None => {
201 return Err(de::Error::invalid_value(
202 de::Unexpected::Map,
203 &"map with a single key",
204 ))
205 }
206 };
207
208 if iter.next().is_some() {
209 return Err(de::Error::invalid_value(
210 de::Unexpected::Map,
211 &"map with a single key",
212 ));
213 }
214 let skip = check_value_for_skip(&value, self.options, &self.visited)
215 .map_err(|err| Error::DeserializeError(err.to_string()))?;
216 if skip {
217 return Err(de::Error::custom("bad enum value"));
218 }
219
220 (variant, Some(value), Some(_guard))
221 }
222 Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
223 Value::UserData(ud) if ud.is_serializable() => {
224 return serde_userdata(ud, |value| value.deserialize_enum(name, variants, visitor));
225 }
226 _ => return Err(de::Error::custom("bad enum value")),
227 };
228
229 visitor.visit_enum(EnumDeserializer {
230 variant,
231 value,
232 options: self.options,
233 visited: self.visited,
234 })
235 }
236
237 #[inline]
238 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
239 where
240 V: de::Visitor<'de>,
241 {
242 match self.value {
243 #[cfg(feature = "luau")]
244 Value::Vector(vec) => {
245 let mut deserializer = VecDeserializer {
246 vec,
247 next: 0,
248 options: self.options,
249 visited: self.visited,
250 };
251 visitor.visit_seq(&mut deserializer)
252 }
253 Value::Table(t) => {
254 let _guard = RecursionGuard::new(&t, &self.visited);
255
256 let len = t.raw_len();
257 let mut deserializer = SeqDeserializer {
258 seq: t.sequence_values(),
259 options: self.options,
260 visited: self.visited,
261 };
262 let seq = visitor.visit_seq(&mut deserializer)?;
263 if deserializer.seq.count() == 0 {
264 Ok(seq)
265 } else {
266 Err(de::Error::invalid_length(len, &"fewer elements in the table"))
267 }
268 }
269 Value::UserData(ud) if ud.is_serializable() => {
270 serde_userdata(ud, |value| value.deserialize_seq(visitor))
271 }
272 value => Err(de::Error::invalid_type(
273 de::Unexpected::Other(value.type_name()),
274 &"table",
275 )),
276 }
277 }
278
279 #[inline]
280 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
281 where
282 V: de::Visitor<'de>,
283 {
284 self.deserialize_seq(visitor)
285 }
286
287 #[inline]
288 fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value>
289 where
290 V: de::Visitor<'de>,
291 {
292 self.deserialize_seq(visitor)
293 }
294
295 #[inline]
296 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
297 where
298 V: de::Visitor<'de>,
299 {
300 match self.value {
301 Value::Table(t) => {
302 let _guard = RecursionGuard::new(&t, &self.visited);
303
304 let mut deserializer = MapDeserializer {
305 pairs: MapPairs::new(&t, self.options.sort_keys)?,
306 value: None,
307 options: self.options,
308 visited: self.visited,
309 processed: 0,
310 };
311 let map = visitor.visit_map(&mut deserializer)?;
312 let count = deserializer.pairs.count();
313 if count == 0 {
314 Ok(map)
315 } else {
316 Err(de::Error::invalid_length(
317 deserializer.processed + count,
318 &"fewer elements in the table",
319 ))
320 }
321 }
322 Value::UserData(ud) if ud.is_serializable() => {
323 serde_userdata(ud, |value| value.deserialize_map(visitor))
324 }
325 value => Err(de::Error::invalid_type(
326 de::Unexpected::Other(value.type_name()),
327 &"table",
328 )),
329 }
330 }
331
332 #[inline]
333 fn deserialize_struct<V>(
334 self,
335 _name: &'static str,
336 _fields: &'static [&'static str],
337 visitor: V,
338 ) -> Result<V::Value>
339 where
340 V: de::Visitor<'de>,
341 {
342 self.deserialize_map(visitor)
343 }
344
345 #[inline]
346 fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
347 where
348 V: de::Visitor<'de>,
349 {
350 match self.value {
351 Value::UserData(ud) if ud.is_serializable() => {
352 serde_userdata(ud, |value| value.deserialize_newtype_struct(name, visitor))
353 }
354 _ => visitor.visit_newtype_struct(self),
355 }
356 }
357
358 #[inline]
359 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
360 where
361 V: de::Visitor<'de>,
362 {
363 match self.value {
364 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
365 _ => self.deserialize_any(visitor),
366 }
367 }
368
369 #[inline]
370 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
371 where
372 V: de::Visitor<'de>,
373 {
374 match self.value {
375 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_unit(),
376 _ => self.deserialize_any(visitor),
377 }
378 }
379
380 serde::forward_to_deserialize_any! {
381 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
382 byte_buf identifier ignored_any
383 }
384}
385
/// `SeqAccess` implementation that walks the sequence values of a table.
struct SeqDeserializer<'a> {
    /// Remaining sequence values of the table being deserialized.
    seq: TableSequence<'a, Value>,
    /// Options passed on to each element's deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
391
392impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_> {
393 type Error = Error;
394
395 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
396 where
397 T: de::DeserializeSeed<'de>,
398 {
399 loop {
400 match self.seq.next() {
401 Some(value) => {
402 let value = value?;
403 let skip = check_value_for_skip(&value, self.options, &self.visited)
404 .map_err(|err| Error::DeserializeError(err.to_string()))?;
405 if skip {
406 continue;
407 }
408 let visited = Rc::clone(&self.visited);
409 let deserializer = Deserializer::from_parts(value, self.options, visited);
410 return seed.deserialize(deserializer).map(Some);
411 }
412 None => return Ok(None),
413 }
414 }
415 }
416
417 fn size_hint(&self) -> Option<usize> {
418 match self.seq.size_hint() {
419 (lower, Some(upper)) if lower == upper => Some(upper),
420 _ => None,
421 }
422 }
423}
424
/// `SeqAccess` implementation that yields the components of a Luau vector
/// (only compiled with the `luau` feature).
#[cfg(feature = "luau")]
struct VecDeserializer {
    /// The vector whose components are being read.
    vec: crate::types::Vector,
    /// Index of the next component to yield.
    next: usize,
    /// Options passed on to each component's deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
432
#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    /// Deserializes the next vector component as a number, or returns `Ok(None)`
    /// once every component has been read.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let Some(&component) = self.vec.0.get(self.next) else {
            return Ok(None);
        };
        self.next += 1;

        let nested = Deserializer::from_parts(
            Value::Number(component as _),
            self.options,
            Rc::clone(&self.visited),
        );
        seed.deserialize(nested).map(Some)
    }

    /// Always reports the fixed vector size.
    fn size_hint(&self) -> Option<usize> {
        Some(crate::types::Vector::SIZE)
    }
}
456
/// Source of key/value pairs for map deserialization.
pub(crate) enum MapPairs<'a> {
    /// Pairs read lazily from the table iterator.
    Iter(TablePairs<'a, Value, Value>),
    /// Pairs collected up front (used when keys are sorted); consumed from the back.
    Vec(Vec<(Value, Value)>),
}
461
462impl<'a> MapPairs<'a> {
463 pub(crate) fn new(t: &'a Table, sort_keys: bool) -> Result<Self> {
464 if sort_keys {
465 let mut pairs = t.pairs::<Value, Value>().collect::<Result<Vec<_>>>()?;
466 pairs.sort_by(|(a, _), (b, _)| b.cmp(a)); Ok(MapPairs::Vec(pairs))
468 } else {
469 Ok(MapPairs::Iter(t.pairs::<Value, Value>()))
470 }
471 }
472
473 pub(crate) fn count(self) -> usize {
474 match self {
475 MapPairs::Iter(iter) => iter.count(),
476 MapPairs::Vec(vec) => vec.len(),
477 }
478 }
479
480 pub(crate) fn size_hint(&self) -> (usize, Option<usize>) {
481 match self {
482 MapPairs::Iter(iter) => iter.size_hint(),
483 MapPairs::Vec(vec) => (vec.len(), Some(vec.len())),
484 }
485 }
486}
487
488impl Iterator for MapPairs<'_> {
489 type Item = Result<(Value, Value)>;
490
491 fn next(&mut self) -> Option<Self::Item> {
492 match self {
493 MapPairs::Iter(iter) => iter.next(),
494 MapPairs::Vec(vec) => vec.pop().map(Ok),
495 }
496 }
497}
498
/// `MapAccess` implementation that walks the key/value pairs of a table.
struct MapDeserializer<'a> {
    /// Remaining pairs of the table being deserialized.
    pairs: MapPairs<'a>,
    /// Value belonging to the most recently returned key, waiting to be requested.
    value: Option<Value>,
    /// Options passed on to each key/value deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
    /// Number of pairs handed out so far (skipped pairs are not counted).
    processed: usize,
}
506
507impl<'a> MapDeserializer<'a> {
508 fn next_key_deserializer(&mut self) -> Result<Option<Deserializer>> {
509 loop {
510 match self.pairs.next() {
511 Some(item) => {
512 let (key, value) = item?;
513 let skip_key = check_value_for_skip(&key, self.options, &self.visited)
514 .map_err(|err| Error::DeserializeError(err.to_string()))?;
515 let skip_value = check_value_for_skip(&value, self.options, &self.visited)
516 .map_err(|err| Error::DeserializeError(err.to_string()))?;
517 if skip_key || skip_value {
518 continue;
519 }
520 self.processed += 1;
521 self.value = Some(value);
522 let visited = Rc::clone(&self.visited);
523 let key_de = Deserializer::from_parts(key, self.options, visited);
524 return Ok(Some(key_de));
525 }
526 None => return Ok(None),
527 }
528 }
529 }
530
531 fn next_value_deserializer(&mut self) -> Result<Deserializer> {
532 match self.value.take() {
533 Some(value) => {
534 let visited = Rc::clone(&self.visited);
535 Ok(Deserializer::from_parts(value, self.options, visited))
536 }
537 None => Err(de::Error::custom("value is missing")),
538 }
539 }
540}
541
542impl<'de> de::MapAccess<'de> for MapDeserializer<'_> {
543 type Error = Error;
544
545 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
546 where
547 T: de::DeserializeSeed<'de>,
548 {
549 match self.next_key_deserializer() {
550 Ok(Some(key_de)) => seed.deserialize(key_de).map(Some),
551 Ok(None) => Ok(None),
552 Err(error) => Err(error),
553 }
554 }
555
556 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
557 where
558 T: de::DeserializeSeed<'de>,
559 {
560 match self.next_value_deserializer() {
561 Ok(value_de) => seed.deserialize(value_de),
562 Err(error) => Err(error),
563 }
564 }
565
566 fn size_hint(&self) -> Option<usize> {
567 match self.pairs.size_hint() {
568 (lower, Some(upper)) if lower == upper => Some(upper),
569 _ => None,
570 }
571 }
572}
573
/// `EnumAccess` implementation built from a variant name and an optional payload.
struct EnumDeserializer {
    /// Name of the variant to select.
    variant: StdString,
    /// Payload of the variant; `None` when the enum was given as a plain string.
    value: Option<Value>,
    /// Options passed on to the payload's deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
580
581impl<'de> de::EnumAccess<'de> for EnumDeserializer {
582 type Error = Error;
583 type Variant = VariantDeserializer;
584
585 fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
586 where
587 T: de::DeserializeSeed<'de>,
588 {
589 let variant = self.variant.into_deserializer();
590 let variant_access = VariantDeserializer {
591 value: self.value,
592 options: self.options,
593 visited: self.visited,
594 };
595 seed.deserialize(variant).map(|v| (v, variant_access))
596 }
597}
598
/// `VariantAccess` implementation holding the (optional) payload of an enum variant.
struct VariantDeserializer {
    /// Payload of the variant; `None` for a variant given as a plain string.
    value: Option<Value>,
    /// Options passed on to the payload's deserializer.
    options: Options,
    /// Shared set of tables currently being traversed.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
604
605impl<'de> de::VariantAccess<'de> for VariantDeserializer {
606 type Error = Error;
607
608 fn unit_variant(self) -> Result<()> {
609 match self.value {
610 Some(_) => Err(de::Error::invalid_type(
611 de::Unexpected::NewtypeVariant,
612 &"unit variant",
613 )),
614 None => Ok(()),
615 }
616 }
617
618 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
619 where
620 T: de::DeserializeSeed<'de>,
621 {
622 match self.value {
623 Some(value) => seed.deserialize(Deserializer::from_parts(value, self.options, self.visited)),
624 None => Err(de::Error::invalid_type(
625 de::Unexpected::UnitVariant,
626 &"newtype variant",
627 )),
628 }
629 }
630
631 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
632 where
633 V: de::Visitor<'de>,
634 {
635 match self.value {
636 Some(value) => serde::Deserializer::deserialize_seq(
637 Deserializer::from_parts(value, self.options, self.visited),
638 visitor,
639 ),
640 None => Err(de::Error::invalid_type(
641 de::Unexpected::UnitVariant,
642 &"tuple variant",
643 )),
644 }
645 }
646
647 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
648 where
649 V: de::Visitor<'de>,
650 {
651 match self.value {
652 Some(value) => serde::Deserializer::deserialize_map(
653 Deserializer::from_parts(value, self.options, self.visited),
654 visitor,
655 ),
656 None => Err(de::Error::invalid_type(
657 de::Unexpected::UnitVariant,
658 &"struct variant",
659 )),
660 }
661 }
662}
663
/// Scope guard that keeps a table's pointer in the shared `visited` set for as long
/// as the guard is alive; the pointer is removed again when the guard is dropped.
pub(crate) struct RecursionGuard {
    /// Pointer identifying the guarded table.
    ptr: *const c_void,
    /// The shared set the pointer was inserted into.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
670
671impl RecursionGuard {
672 #[inline]
673 pub(crate) fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
674 let visited = Rc::clone(visited);
675 let ptr = table.to_pointer();
676 visited.borrow_mut().insert(ptr);
677 RecursionGuard { ptr, visited }
678 }
679}
680
681impl Drop for RecursionGuard {
682 fn drop(&mut self) {
683 self.visited.borrow_mut().remove(&self.ptr);
684 }
685}
686
687pub(crate) fn check_value_for_skip(
689 value: &Value,
690 options: Options,
691 visited: &RefCell<FxHashSet<*const c_void>>,
692) -> StdResult<bool, &'static str> {
693 match value {
694 Value::Table(table) => {
695 let ptr = table.to_pointer();
696 if visited.borrow().contains(&ptr) {
697 if options.deny_recursive_tables {
698 return Err("recursive table detected");
699 }
700 return Ok(true); }
702 }
703 Value::UserData(ud) if ud.is_serializable() => {}
704 Value::Function(_)
705 | Value::Thread(_)
706 | Value::UserData(_)
707 | Value::LightUserData(_)
708 | Value::Error(_)
709 if !options.deny_unsupported_types =>
710 {
711 return Ok(true); }
713 _ => {}
714 }
715 Ok(false) }
717
718fn serde_userdata<V>(
719 ud: AnyUserData,
720 f: impl FnOnce(serde_value::Value) -> std::result::Result<V, serde_value::DeserializerError>,
721) -> Result<V> {
722 match serde_value::to_value(ud) {
723 Ok(value) => match f(value) {
724 Ok(r) => Ok(r),
725 Err(error) => Err(Error::DeserializeError(error.to_string())),
726 },
727 Err(error) => Err(Error::SerializeError(error.to_string())),
728 }
729}