// piccolo_util/serde/de.rs

1use std::{borrow::Cow, fmt};
2
3use piccolo::{table::NextValue, Table, Value};
4use serde::de;
5use thiserror::Error;
6
7use super::markers::{is_none, is_unit};
8
9#[derive(Debug, Error)]
10pub enum Error {
11    #[error("{0}")]
12    Message(String),
13    #[error("expected {expected}, found {found}")]
14    TypeError {
15        expected: &'static str,
16        found: &'static str,
17    },
18}
19
20impl de::Error for Error {
21    fn custom<T: fmt::Display>(msg: T) -> Self {
22        Error::Message(msg.to_string())
23    }
24}
25
26pub fn from_value<'gc, T: de::Deserialize<'gc>>(value: Value<'gc>) -> Result<T, Error> {
27    T::deserialize(Deserializer::from_value(value))
28}
29
30pub struct Deserializer<'gc> {
31    value: Value<'gc>,
32}
33
34impl<'gc> Deserializer<'gc> {
35    pub fn from_value(value: Value<'gc>) -> Self {
36        Self { value }
37    }
38}
39
40impl<'gc> de::Deserializer<'gc> for Deserializer<'gc> {
41    type Error = Error;
42
43    fn deserialize_any<V: de::Visitor<'gc>>(self, visitor: V) -> Result<V::Value, Error> {
44        match self.value {
45            Value::Nil => self.deserialize_unit(visitor),
46            Value::Boolean(_) => self.deserialize_bool(visitor),
47            Value::Integer(_) => self.deserialize_i64(visitor),
48            Value::Number(_) => self.deserialize_f64(visitor),
49            Value::String(_) => self.deserialize_bytes(visitor),
50            Value::Table(t) => {
51                if is_sequence(t) {
52                    self.deserialize_seq(visitor)
53                } else {
54                    self.deserialize_map(visitor)
55                }
56            }
57            Value::Function(_) => Err(de::Error::custom("cannot deserialize from function")),
58            Value::Thread(_) => Err(de::Error::custom("cannot deserialize from thread")),
59            Value::UserData(_) => Err(de::Error::custom("cannot deserialize from userdata")),
60        }
61    }
62
63    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Error>
64    where
65        V: de::Visitor<'gc>,
66    {
67        visitor.visit_bool(self.value.to_bool())
68    }
69
70    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Error>
71    where
72        V: de::Visitor<'gc>,
73    {
74        self.deserialize_i64(visitor)
75    }
76
77    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Error>
78    where
79        V: de::Visitor<'gc>,
80    {
81        self.deserialize_i64(visitor)
82    }
83
84    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Error>
85    where
86        V: de::Visitor<'gc>,
87    {
88        self.deserialize_i64(visitor)
89    }
90
91    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Error>
92    where
93        V: de::Visitor<'gc>,
94    {
95        if let Some(i) = self.value.to_integer() {
96            visitor.visit_i64(i)
97        } else {
98            Err(Error::TypeError {
99                expected: "integer",
100                found: self.value.type_name(),
101            })
102        }
103    }
104
105    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Error>
106    where
107        V: de::Visitor<'gc>,
108    {
109        self.deserialize_i64(visitor)
110    }
111
112    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Error>
113    where
114        V: de::Visitor<'gc>,
115    {
116        self.deserialize_i64(visitor)
117    }
118
119    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Error>
120    where
121        V: de::Visitor<'gc>,
122    {
123        self.deserialize_i64(visitor)
124    }
125
126    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Error>
127    where
128        V: de::Visitor<'gc>,
129    {
130        self.deserialize_i64(visitor)
131    }
132
133    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Error>
134    where
135        V: de::Visitor<'gc>,
136    {
137        self.deserialize_f64(visitor)
138    }
139
140    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Error>
141    where
142        V: de::Visitor<'gc>,
143    {
144        if let Some(f) = self.value.to_number() {
145            visitor.visit_f64(f)
146        } else {
147            Err(Error::TypeError {
148                expected: "number",
149                found: self.value.type_name(),
150            })
151        }
152    }
153
154    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Error>
155    where
156        V: de::Visitor<'gc>,
157    {
158        self.deserialize_str(visitor)
159    }
160
161    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Error>
162    where
163        V: de::Visitor<'gc>,
164    {
165        if let Value::String(s) = self.value {
166            match s.to_str_lossy() {
167                Cow::Borrowed(s) => visitor.visit_borrowed_str(s),
168                Cow::Owned(s) => visitor.visit_string(s),
169            }
170        } else {
171            visitor.visit_string(self.value.to_string())
172        }
173    }
174
175    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Error>
176    where
177        V: de::Visitor<'gc>,
178    {
179        self.deserialize_str(visitor)
180    }
181
182    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Error>
183    where
184        V: de::Visitor<'gc>,
185    {
186        if let Value::String(s) = self.value {
187            visitor.visit_borrowed_bytes(s.as_bytes())
188        } else {
189            Err(Error::TypeError {
190                expected: "string",
191                found: self.value.type_name(),
192            })
193        }
194    }
195
196    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Error>
197    where
198        V: de::Visitor<'gc>,
199    {
200        self.deserialize_bytes(visitor)
201    }
202
203    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
204    where
205        V: de::Visitor<'gc>,
206    {
207        match self.value {
208            Value::Nil => visitor.visit_none(),
209            Value::UserData(ud) if is_none(ud) => visitor.visit_none(),
210            _ => visitor.visit_some(self),
211        }
212    }
213
214    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Error>
215    where
216        V: de::Visitor<'gc>,
217    {
218        match self.value {
219            Value::Nil => visitor.visit_unit(),
220            Value::UserData(ud) if is_unit(ud) => visitor.visit_unit(),
221            v => Err(Error::TypeError {
222                expected: "nil or unit",
223                found: v.type_name(),
224            }),
225        }
226    }
227
228    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Error>
229    where
230        V: de::Visitor<'gc>,
231    {
232        self.deserialize_unit(visitor)
233    }
234
235    fn deserialize_newtype_struct<V>(
236        self,
237        _name: &'static str,
238        visitor: V,
239    ) -> Result<V::Value, Error>
240    where
241        V: de::Visitor<'gc>,
242    {
243        visitor.visit_newtype_struct(self)
244    }
245
246    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Error>
247    where
248        V: de::Visitor<'gc>,
249    {
250        if let Value::Table(table) = self.value {
251            visitor.visit_seq(Seq::new(table))
252        } else {
253            Err(Error::TypeError {
254                expected: "table",
255                found: self.value.type_name(),
256            })
257        }
258    }
259
260    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Error>
261    where
262        V: de::Visitor<'gc>,
263    {
264        if let Value::Table(table) = self.value {
265            visitor.visit_seq(Tuple::new(
266                table,
267                len.try_into()
268                    .map_err(|_| de::Error::custom("tuple length out of range"))?,
269            ))
270        } else {
271            Err(Error::TypeError {
272                expected: "table",
273                found: self.value.type_name(),
274            })
275        }
276    }
277
278    fn deserialize_tuple_struct<V>(
279        self,
280        _name: &'static str,
281        len: usize,
282        visitor: V,
283    ) -> Result<V::Value, Error>
284    where
285        V: de::Visitor<'gc>,
286    {
287        self.deserialize_tuple(len, visitor)
288    }
289
290    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Error>
291    where
292        V: de::Visitor<'gc>,
293    {
294        if let Value::Table(table) = self.value {
295            visitor.visit_map(Map::new(table))
296        } else {
297            Err(Error::TypeError {
298                expected: "table",
299                found: self.value.type_name(),
300            })
301        }
302    }
303
304    fn deserialize_struct<V>(
305        self,
306        _name: &'static str,
307        _fields: &'static [&'static str],
308        visitor: V,
309    ) -> Result<V::Value, Error>
310    where
311        V: de::Visitor<'gc>,
312    {
313        self.deserialize_map(visitor)
314    }
315
316    fn deserialize_enum<V>(
317        self,
318        _name: &'static str,
319        _variants: &'static [&'static str],
320        visitor: V,
321    ) -> Result<V::Value, Error>
322    where
323        V: de::Visitor<'gc>,
324    {
325        match self.value {
326            Value::Table(table) => match table.next(Value::Nil) {
327                NextValue::Found { key, value } => visitor.visit_enum(Enum::new(key, value)),
328                NextValue::Last => Err(de::Error::custom("enum table has no entries")),
329                NextValue::NotFound => unreachable!(),
330            },
331            v => visitor.visit_enum(UnitEnum::new(v)),
332        }
333    }
334
335    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Error>
336    where
337        V: de::Visitor<'gc>,
338    {
339        self.deserialize_str(visitor)
340    }
341
342    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
343    where
344        V: de::Visitor<'gc>,
345    {
346        self.deserialize_any(visitor)
347    }
348}
349
350pub struct Seq<'gc> {
351    table: Table<'gc>,
352    ind: i64,
353}
354
355impl<'gc> Seq<'gc> {
356    fn new(table: Table<'gc>) -> Self {
357        Self { table, ind: 1 }
358    }
359}
360
361impl<'gc> de::SeqAccess<'gc> for Seq<'gc> {
362    type Error = Error;
363
364    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Error>
365    where
366        T: de::DeserializeSeed<'gc>,
367    {
368        let v = self.table.get_value(Value::Integer(self.ind));
369        if v.is_nil() {
370            Ok(None)
371        } else {
372            let res = Some(seed.deserialize(Deserializer::from_value(v))?);
373            self.ind = self
374                .ind
375                .checked_add(1)
376                .ok_or(de::Error::custom("index overflow"))?;
377            Ok(res)
378        }
379    }
380}
381
382pub struct Tuple<'gc> {
383    table: Table<'gc>,
384    len: i64,
385    ind: i64,
386}
387
388impl<'gc> Tuple<'gc> {
389    fn new(table: Table<'gc>, len: i64) -> Self {
390        Self { table, len, ind: 1 }
391    }
392}
393
394impl<'gc> de::SeqAccess<'gc> for Tuple<'gc> {
395    type Error = Error;
396
397    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Error>
398    where
399        T: de::DeserializeSeed<'gc>,
400    {
401        if self.ind > self.len {
402            Ok(None)
403        } else {
404            let v = self.table.get_value(Value::Integer(self.ind));
405            let res = Some(seed.deserialize(Deserializer::from_value(v))?);
406            self.ind += 1;
407            Ok(res)
408        }
409    }
410}
411
412pub struct Map<'gc> {
413    table: Table<'gc>,
414    key: Value<'gc>,
415    value: Value<'gc>,
416}
417
418impl<'gc> Map<'gc> {
419    fn new(table: Table<'gc>) -> Self {
420        Self {
421            table,
422            key: Value::Nil,
423            value: Value::Nil,
424        }
425    }
426}
427
428impl<'gc> de::MapAccess<'gc> for Map<'gc> {
429    type Error = Error;
430
431    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
432    where
433        K: de::DeserializeSeed<'gc>,
434    {
435        match self.table.next(self.key) {
436            NextValue::Found { key, value } => {
437                self.key = key;
438                self.value = value;
439                seed.deserialize(Deserializer::from_value(self.key))
440                    .map(Some)
441            }
442            NextValue::Last => Ok(None),
443            NextValue::NotFound => unreachable!(),
444        }
445    }
446
447    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
448    where
449        V: de::DeserializeSeed<'gc>,
450    {
451        seed.deserialize(Deserializer::from_value(self.value))
452    }
453}
454
455pub struct Enum<'gc> {
456    key: Value<'gc>,
457    value: Value<'gc>,
458}
459
460impl<'gc> Enum<'gc> {
461    fn new(key: Value<'gc>, value: Value<'gc>) -> Self {
462        Self { key, value }
463    }
464}
465
466impl<'gc> de::EnumAccess<'gc> for Enum<'gc> {
467    type Error = Error;
468    type Variant = Variant<'gc>;
469
470    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Variant<'gc>), Error>
471    where
472        V: de::DeserializeSeed<'gc>,
473    {
474        Ok((
475            seed.deserialize(Deserializer::from_value(self.key))?,
476            Variant::new(self.value),
477        ))
478    }
479}
480
481pub struct Variant<'gc> {
482    value: Value<'gc>,
483}
484
485impl<'gc> Variant<'gc> {
486    fn new(value: Value<'gc>) -> Self {
487        Self { value }
488    }
489}
490
491impl<'gc> de::VariantAccess<'gc> for Variant<'gc> {
492    type Error = Error;
493
494    fn unit_variant(self) -> Result<(), Error> {
495        de::Deserialize::deserialize(Deserializer::from_value(self.value))
496    }
497
498    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
499    where
500        T: de::DeserializeSeed<'gc>,
501    {
502        seed.deserialize(Deserializer::from_value(self.value))
503    }
504
505    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Error>
506    where
507        V: de::Visitor<'gc>,
508    {
509        de::Deserializer::deserialize_tuple(Deserializer::from_value(self.value), len, visitor)
510    }
511
512    fn struct_variant<V>(
513        self,
514        _fields: &'static [&'static str],
515        visitor: V,
516    ) -> Result<V::Value, Error>
517    where
518        V: de::Visitor<'gc>,
519    {
520        de::Deserializer::deserialize_map(Deserializer::from_value(self.value), visitor)
521    }
522}
523
524pub struct UnitEnum<'gc> {
525    key: Value<'gc>,
526}
527
528impl<'gc> UnitEnum<'gc> {
529    fn new(key: Value<'gc>) -> Self {
530        Self { key }
531    }
532}
533
534impl<'gc> de::EnumAccess<'gc> for UnitEnum<'gc> {
535    type Error = Error;
536    type Variant = UnitVariant;
537
538    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, UnitVariant), Error>
539    where
540        V: de::DeserializeSeed<'gc>,
541    {
542        Ok((
543            seed.deserialize(Deserializer::from_value(self.key))?,
544            UnitVariant::new(),
545        ))
546    }
547}
548
549pub struct UnitVariant {}
550
551impl UnitVariant {
552    fn new() -> Self {
553        Self {}
554    }
555}
556
557impl<'de> de::VariantAccess<'de> for UnitVariant {
558    type Error = Error;
559
560    fn unit_variant(self) -> Result<(), Error> {
561        Ok(())
562    }
563
564    fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Error>
565    where
566        T: de::DeserializeSeed<'de>,
567    {
568        Err(Error::TypeError {
569            expected: "table",
570            found: "non-table",
571        })
572    }
573
574    fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Error>
575    where
576        V: de::Visitor<'de>,
577    {
578        Err(Error::TypeError {
579            expected: "table",
580            found: "non-table",
581        })
582    }
583
584    fn struct_variant<V>(
585        self,
586        _fields: &'static [&'static str],
587        _visitor: V,
588    ) -> Result<V::Value, Error>
589    where
590        V: de::Visitor<'de>,
591    {
592        Err(Error::TypeError {
593            expected: "table",
594            found: "non-table",
595        })
596    }
597}
598
599fn is_sequence<'gc>(table: Table<'gc>) -> bool {
600    let mut key = match table.next(Value::Nil) {
601        NextValue::Found { key, value: _ } => key,
602        NextValue::Last => return true,
603        NextValue::NotFound => unreachable!(),
604    };
605
606    let mut ind = 1;
607    loop {
608        if !matches!(key, Value::Integer(i) if i == ind) {
609            return false;
610        }
611
612        ind = if let Some(i) = ind.checked_add(1) {
613            i
614        } else {
615            return false;
616        };
617
618        key = match table.next(key) {
619            NextValue::Found { key, value: _ } => key,
620            NextValue::Last => return true,
621            NextValue::NotFound => unreachable!(),
622        };
623    }
624}