1use std::cell::RefCell;
2use std::convert::TryInto;
3use std::os::raw::c_void;
4use std::rc::Rc;
5use std::string::String as StdString;
6
7use rustc_hash::FxHashSet;
8use serde::de::{self, IntoDeserializer};
9
10use crate::error::{Error, Result};
11use crate::ffi;
12use crate::table::{Table, TablePairs, TableSequence};
13use crate::value::Value;
14
/// A serde deserializer that reads a Lua [`Value`].
#[derive(Debug)]
pub struct Deserializer<'lua> {
    /// The Lua value this deserializer will consume.
    value: Value<'lua>,
    /// Switches controlling how unsupported values and recursive tables are treated.
    options: Options,
    /// Pointers of the tables currently being traversed. The set sits behind an
    /// `Rc` so nested deserializers created while walking a table can share it.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}
22
/// Options that change how the deserializer treats problematic values.
///
/// Both switches are enabled by [`Options::new`] (and therefore by
/// [`Default`]).
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Options {
    /// When `true`, meeting a value of an unsupported type (function, thread,
    /// userdata, light userdata, error) is reported as an error; when `false`
    /// such values are skipped instead.
    pub deny_unsupported_types: bool,

    /// When `true`, meeting a table that is already being traversed is
    /// reported as an error; when `false` the repeated table is skipped.
    pub deny_recursive_tables: bool,
}

impl Default for Options {
    /// Equivalent to [`Options::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Options {
    /// Returns the default option set: every `deny_*` switch is enabled.
    pub const fn new() -> Self {
        Options {
            deny_unsupported_types: true,
            deny_recursive_tables: true,
        }
    }

    /// Sets the [`deny_unsupported_types`](Self::deny_unsupported_types) switch
    /// and returns the updated options.
    #[must_use]
    pub const fn deny_unsupported_types(mut self, enabled: bool) -> Self {
        self.deny_unsupported_types = enabled;
        self
    }

    /// Sets the [`deny_recursive_tables`](Self::deny_recursive_tables) switch
    /// and returns the updated options.
    ///
    /// This is a `const fn`, like the other builder methods, so a complete
    /// option set can be assembled in a `const` or `static` initializer.
    #[must_use]
    pub const fn deny_recursive_tables(mut self, enabled: bool) -> Self {
        self.deny_recursive_tables = enabled;
        self
    }
}
80
81impl<'lua> Deserializer<'lua> {
82 pub fn new(value: Value<'lua>) -> Self {
84 Self::new_with_options(value, Options::default())
85 }
86
87 pub fn new_with_options(value: Value<'lua>, options: Options) -> Self {
89 Deserializer {
90 value,
91 options,
92 visited: Rc::new(RefCell::new(FxHashSet::default())),
93 }
94 }
95
96 fn from_parts(
97 value: Value<'lua>,
98 options: Options,
99 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
100 ) -> Self {
101 Deserializer {
102 value,
103 options,
104 visited,
105 }
106 }
107}
108
109impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
110 type Error = Error;
111
112 #[inline]
113 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
114 where
115 V: de::Visitor<'de>,
116 {
117 match self.value {
118 Value::Nil => visitor.visit_unit(),
119 Value::Boolean(b) => visitor.visit_bool(b),
120 #[allow(clippy::useless_conversion)]
121 Value::Integer(i) => {
122 visitor.visit_i64(i.try_into().expect("cannot convert lua_Integer to i64"))
123 }
124 #[allow(clippy::useless_conversion)]
125 Value::Number(n) => visitor.visit_f64(n.into()),
126 #[cfg(feature = "luau")]
127 Value::Vector(_, _, _) => self.deserialize_seq(visitor),
128 Value::String(s) => match s.to_str() {
129 Ok(s) => visitor.visit_str(s),
130 Err(_) => visitor.visit_bytes(s.as_bytes()),
131 },
132 Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor),
133 Value::Table(_) => self.deserialize_map(visitor),
134 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
135 Value::Function(_)
136 | Value::Thread(_)
137 | Value::UserData(_)
138 | Value::LightUserData(_)
139 | Value::Error(_) => {
140 if self.options.deny_unsupported_types {
141 Err(de::Error::custom(format!(
142 "unsupported value type `{}`",
143 self.value.type_name()
144 )))
145 } else {
146 visitor.visit_unit()
147 }
148 }
149 }
150 }
151
152 #[inline]
153 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
154 where
155 V: de::Visitor<'de>,
156 {
157 match self.value {
158 Value::Nil => visitor.visit_none(),
159 Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(),
160 _ => visitor.visit_some(self),
161 }
162 }
163
164 #[inline]
165 fn deserialize_enum<V>(
166 self,
167 _name: &str,
168 _variants: &'static [&'static str],
169 visitor: V,
170 ) -> Result<V::Value>
171 where
172 V: de::Visitor<'de>,
173 {
174 let (variant, value, _guard) = match self.value {
175 Value::Table(table) => {
176 let _guard = RecursionGuard::new(&table, &self.visited);
177
178 let mut iter = table.pairs::<StdString, Value>();
179 let (variant, value) = match iter.next() {
180 Some(v) => v?,
181 None => {
182 return Err(de::Error::invalid_value(
183 de::Unexpected::Map,
184 &"map with a single key",
185 ))
186 }
187 };
188
189 if iter.next().is_some() {
190 return Err(de::Error::invalid_value(
191 de::Unexpected::Map,
192 &"map with a single key",
193 ));
194 }
195 if check_value_if_skip(&value, self.options, &self.visited)? {
196 return Err(de::Error::custom("bad enum value"));
197 }
198
199 (variant, Some(value), Some(_guard))
200 }
201 Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
202 _ => return Err(de::Error::custom("bad enum value")),
203 };
204
205 visitor.visit_enum(EnumDeserializer {
206 variant,
207 value,
208 options: self.options,
209 visited: self.visited,
210 })
211 }
212
213 #[inline]
214 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
215 where
216 V: de::Visitor<'de>,
217 {
218 match self.value {
219 #[cfg(feature = "luau")]
220 Value::Vector(x, y, z) => {
221 let mut deserializer = VecDeserializer {
222 vec: [x, y, z],
223 next: 0,
224 options: self.options,
225 visited: self.visited,
226 };
227 visitor.visit_seq(&mut deserializer)
228 }
229 Value::Table(t) => {
230 let _guard = RecursionGuard::new(&t, &self.visited);
231
232 let len = t.raw_len() as usize;
233 let mut deserializer = SeqDeserializer {
234 seq: t.raw_sequence_values(),
235 options: self.options,
236 visited: self.visited,
237 };
238 let seq = visitor.visit_seq(&mut deserializer)?;
239 if deserializer.seq.count() == 0 {
240 Ok(seq)
241 } else {
242 Err(de::Error::invalid_length(
243 len,
244 &"fewer elements in the table",
245 ))
246 }
247 }
248 value => Err(de::Error::invalid_type(
249 de::Unexpected::Other(value.type_name()),
250 &"table",
251 )),
252 }
253 }
254
255 #[inline]
256 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
257 where
258 V: de::Visitor<'de>,
259 {
260 self.deserialize_seq(visitor)
261 }
262
263 #[inline]
264 fn deserialize_tuple_struct<V>(
265 self,
266 _name: &'static str,
267 _len: usize,
268 visitor: V,
269 ) -> Result<V::Value>
270 where
271 V: de::Visitor<'de>,
272 {
273 self.deserialize_seq(visitor)
274 }
275
276 #[inline]
277 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
278 where
279 V: de::Visitor<'de>,
280 {
281 match self.value {
282 Value::Table(t) => {
283 let _guard = RecursionGuard::new(&t, &self.visited);
284
285 let mut deserializer = MapDeserializer {
286 pairs: t.pairs(),
287 value: None,
288 options: self.options,
289 visited: self.visited,
290 processed: 0,
291 };
292 let map = visitor.visit_map(&mut deserializer)?;
293 let count = deserializer.pairs.count();
294 if count == 0 {
295 Ok(map)
296 } else {
297 Err(de::Error::invalid_length(
298 deserializer.processed + count,
299 &"fewer elements in the table",
300 ))
301 }
302 }
303 value => Err(de::Error::invalid_type(
304 de::Unexpected::Other(value.type_name()),
305 &"table",
306 )),
307 }
308 }
309
310 #[inline]
311 fn deserialize_struct<V>(
312 self,
313 _name: &'static str,
314 _fields: &'static [&'static str],
315 visitor: V,
316 ) -> Result<V::Value>
317 where
318 V: de::Visitor<'de>,
319 {
320 self.deserialize_map(visitor)
321 }
322
323 #[inline]
324 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
325 where
326 V: de::Visitor<'de>,
327 {
328 visitor.visit_newtype_struct(self)
329 }
330
331 serde::forward_to_deserialize_any! {
332 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes
333 byte_buf unit unit_struct identifier ignored_any
334 }
335}
336
337struct SeqDeserializer<'lua> {
338 seq: TableSequence<'lua, Value<'lua>>,
339 options: Options,
340 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
341}
342
343impl<'lua, 'de> de::SeqAccess<'de> for SeqDeserializer<'lua> {
344 type Error = Error;
345
346 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
347 where
348 T: de::DeserializeSeed<'de>,
349 {
350 loop {
351 match self.seq.next() {
352 Some(value) => {
353 let value = value?;
354 if check_value_if_skip(&value, self.options, &self.visited)? {
355 continue;
356 }
357 let visited = Rc::clone(&self.visited);
358 let deserializer = Deserializer::from_parts(value, self.options, visited);
359 return seed.deserialize(deserializer).map(Some);
360 }
361 None => return Ok(None),
362 }
363 }
364 }
365
366 fn size_hint(&self) -> Option<usize> {
367 match self.seq.size_hint() {
368 (lower, Some(upper)) if lower == upper => Some(upper),
369 _ => None,
370 }
371 }
372}
373
/// Sequence access over the three components of a vector value.
#[cfg(feature = "luau")]
struct VecDeserializer {
    /// The vector components, in order.
    vec: [f32; 3],
    /// Index of the component to hand out next.
    next: usize,
    /// Options handed down to every component deserializer.
    options: Options,
    /// Set of tables currently being traversed, passed on unchanged.
    visited: Rc<RefCell<FxHashSet<*const c_void>>>,
}

#[cfg(feature = "luau")]
impl<'de> de::SeqAccess<'de> for VecDeserializer {
    type Error = Error;

    /// Deserializes the next component as a Lua number; returns `Ok(None)`
    /// after the last one.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        let component = match self.vec.get(self.next) {
            Some(&component) => component,
            None => return Ok(None),
        };
        self.next += 1;

        let component_de = Deserializer::from_parts(
            Value::Number(component as _),
            self.options,
            Rc::clone(&self.visited),
        );
        seed.deserialize(component_de).map(Some)
    }

    /// Always reports three elements.
    fn size_hint(&self) -> Option<usize> {
        Some(3)
    }
}
406
407struct MapDeserializer<'lua> {
408 pairs: TablePairs<'lua, Value<'lua>, Value<'lua>>,
409 value: Option<Value<'lua>>,
410 options: Options,
411 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
412 processed: usize,
413}
414
415impl<'lua, 'de> de::MapAccess<'de> for MapDeserializer<'lua> {
416 type Error = Error;
417
418 fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
419 where
420 T: de::DeserializeSeed<'de>,
421 {
422 loop {
423 match self.pairs.next() {
424 Some(item) => {
425 let (key, value) = item?;
426 if check_value_if_skip(&key, self.options, &self.visited)?
427 || check_value_if_skip(&value, self.options, &self.visited)?
428 {
429 continue;
430 }
431 self.processed += 1;
432 self.value = Some(value);
433 let visited = Rc::clone(&self.visited);
434 let key_de = Deserializer::from_parts(key, self.options, visited);
435 return seed.deserialize(key_de).map(Some);
436 }
437 None => return Ok(None),
438 }
439 }
440 }
441
442 fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
443 where
444 T: de::DeserializeSeed<'de>,
445 {
446 match self.value.take() {
447 Some(value) => {
448 let visited = Rc::clone(&self.visited);
449 seed.deserialize(Deserializer::from_parts(value, self.options, visited))
450 }
451 None => Err(de::Error::custom("value is missing")),
452 }
453 }
454
455 fn size_hint(&self) -> Option<usize> {
456 match self.pairs.size_hint() {
457 (lower, Some(upper)) if lower == upper => Some(upper),
458 _ => None,
459 }
460 }
461}
462
463struct EnumDeserializer<'lua> {
464 variant: StdString,
465 value: Option<Value<'lua>>,
466 options: Options,
467 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
468}
469
470impl<'lua, 'de> de::EnumAccess<'de> for EnumDeserializer<'lua> {
471 type Error = Error;
472 type Variant = VariantDeserializer<'lua>;
473
474 fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
475 where
476 T: de::DeserializeSeed<'de>,
477 {
478 let variant = self.variant.into_deserializer();
479 let variant_access = VariantDeserializer {
480 value: self.value,
481 options: self.options,
482 visited: self.visited,
483 };
484 seed.deserialize(variant).map(|v| (v, variant_access))
485 }
486}
487
488struct VariantDeserializer<'lua> {
489 value: Option<Value<'lua>>,
490 options: Options,
491 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
492}
493
494impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
495 type Error = Error;
496
497 fn unit_variant(self) -> Result<()> {
498 match self.value {
499 Some(_) => Err(de::Error::invalid_type(
500 de::Unexpected::NewtypeVariant,
501 &"unit variant",
502 )),
503 None => Ok(()),
504 }
505 }
506
507 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
508 where
509 T: de::DeserializeSeed<'de>,
510 {
511 match self.value {
512 Some(value) => {
513 seed.deserialize(Deserializer::from_parts(value, self.options, self.visited))
514 }
515 None => Err(de::Error::invalid_type(
516 de::Unexpected::UnitVariant,
517 &"newtype variant",
518 )),
519 }
520 }
521
522 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
523 where
524 V: de::Visitor<'de>,
525 {
526 match self.value {
527 Some(value) => serde::Deserializer::deserialize_seq(
528 Deserializer::from_parts(value, self.options, self.visited),
529 visitor,
530 ),
531 None => Err(de::Error::invalid_type(
532 de::Unexpected::UnitVariant,
533 &"tuple variant",
534 )),
535 }
536 }
537
538 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
539 where
540 V: de::Visitor<'de>,
541 {
542 match self.value {
543 Some(value) => serde::Deserializer::deserialize_map(
544 Deserializer::from_parts(value, self.options, self.visited),
545 visitor,
546 ),
547 None => Err(de::Error::invalid_type(
548 de::Unexpected::UnitVariant,
549 &"struct variant",
550 )),
551 }
552 }
553}
554
555struct RecursionGuard {
558 ptr: *const c_void,
559 visited: Rc<RefCell<FxHashSet<*const c_void>>>,
560}
561
562impl RecursionGuard {
563 #[inline]
564 fn new(table: &Table, visited: &Rc<RefCell<FxHashSet<*const c_void>>>) -> Self {
565 let visited = Rc::clone(visited);
566 let lua = table.0.lua;
567 let ptr =
568 unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
569 visited.borrow_mut().insert(ptr);
570 RecursionGuard { ptr, visited }
571 }
572}
573
574impl Drop for RecursionGuard {
575 fn drop(&mut self) {
576 self.visited.borrow_mut().remove(&self.ptr);
577 }
578}
579
580fn check_value_if_skip(
582 value: &Value,
583 options: Options,
584 visited: &RefCell<FxHashSet<*const c_void>>,
585) -> Result<bool> {
586 match value {
587 Value::Table(table) => {
588 let lua = table.0.lua;
589 let ptr =
590 unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
591 if visited.borrow().contains(&ptr) {
592 if options.deny_recursive_tables {
593 return Err(de::Error::custom("recursive table detected"));
594 }
595 return Ok(true); }
597 }
598 Value::Function(_)
599 | Value::Thread(_)
600 | Value::UserData(_)
601 | Value::LightUserData(_)
602 | Value::Error(_)
603 if !options.deny_unsupported_types =>
604 {
605 return Ok(true); }
607 _ => {}
608 }
609 Ok(false) }