clang_ast/
deserializer.rs

1use crate::kind::{AnyKind, Kind, SometimesBorrowedStrDeserializer};
2use crate::Node;
3use serde::de::value::BorrowedStrDeserializer;
4use serde::de::{
5    Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, Expected, IgnoredAny, MapAccess,
6    Unexpected, VariantAccess, Visitor,
7};
8use serde::forward_to_deserialize_any;
9use std::error::Error as StdError;
10use std::fmt::{self, Display};
11use std::marker::PhantomData;
12
13pub(crate) struct NodeDeserializer<'de, 'a, T, M> {
14    kind: &'a AnyKind<'de>,
15    inner: &'a mut Vec<Node<T>>,
16    map: M,
17    has_kind: bool,
18}
19
20impl<'de, 'a, T, M> NodeDeserializer<'de, 'a, T, M> {
21    pub(crate) fn new(kind: &'a AnyKind<'de>, inner: &'a mut Vec<Node<T>>, map: M) -> Self {
22        let has_kind = match kind {
23            AnyKind::Kind(Kind::null) => false,
24            _ => true,
25        };
26        NodeDeserializer {
27            kind,
28            inner,
29            map,
30            has_kind,
31        }
32    }
33}
34
35impl<'de, 'a, T, M> Deserializer<'de> for NodeDeserializer<'de, 'a, T, M>
36where
37    T: Deserialize<'de>,
38    M: MapAccess<'de>,
39{
40    type Error = M::Error;
41
42    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
43    where
44        V: Visitor<'de>,
45    {
46        visitor.visit_map(self)
47    }
48
49    fn deserialize_enum<V>(
50        self,
51        name: &'static str,
52        variants: &'static [&'static str],
53        visitor: V,
54    ) -> Result<V::Value, Self::Error>
55    where
56        V: Visitor<'de>,
57    {
58        let _ = name;
59        let expected = self.kind.as_str();
60        let mut expects_the_unexpected = None;
61        for &variant in variants {
62            if variant == expected {
63                return visitor.visit_enum(self);
64            } else if variant == "Unknown" || variant == "Other" {
65                expects_the_unexpected = Some(variant);
66            }
67        }
68        if let Some(unexpected) = expects_the_unexpected {
69            visitor.visit_enum(UnknownNode {
70                name: unexpected,
71                node: self,
72            })
73        } else {
74            visitor.visit_enum(self)
75        }
76    }
77
78    fn deserialize_unit_struct<V>(
79        self,
80        name: &'static str,
81        visitor: V,
82    ) -> Result<V::Value, Self::Error>
83    where
84        V: Visitor<'de>,
85    {
86        let _ = name;
87        self.deserialize_unit(visitor)
88    }
89
90    fn deserialize_unit<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
91    where
92        V: Visitor<'de>,
93    {
94        self.ignore()?;
95        visitor.visit_unit()
96    }
97
98    forward_to_deserialize_any! {
99        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
100        bytes byte_buf option newtype_struct seq tuple tuple_struct map struct
101        identifier ignored_any
102    }
103}
104
105impl<'de, 'a, T, M> EnumAccess<'de> for NodeDeserializer<'de, 'a, T, M>
106where
107    T: Deserialize<'de>,
108    M: MapAccess<'de>,
109{
110    type Error = M::Error;
111    type Variant = Self;
112
113    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
114    where
115        V: DeserializeSeed<'de>,
116    {
117        let deserializer = match &self.kind {
118            AnyKind::Kind(kind) => SometimesBorrowedStrDeserializer::borrowed(kind.as_str()),
119            AnyKind::Borrowed(kind) => SometimesBorrowedStrDeserializer::borrowed(kind),
120            AnyKind::Owned(kind) => SometimesBorrowedStrDeserializer::transient(kind),
121        };
122        let value = seed.deserialize(deserializer)?;
123        Ok((value, self))
124    }
125}
126
127impl<'de, 'a, T, M> VariantAccess<'de> for NodeDeserializer<'de, 'a, T, M>
128where
129    T: Deserialize<'de>,
130    M: MapAccess<'de>,
131{
132    type Error = M::Error;
133
134    fn unit_variant(mut self) -> Result<(), Self::Error> {
135        self.ignore()?;
136        Ok(())
137    }
138
139    fn newtype_variant_seed<V>(self, seed: V) -> Result<V::Value, Self::Error>
140    where
141        V: DeserializeSeed<'de>,
142    {
143        seed.deserialize(NodeFieldsDeserializer { node: self })
144    }
145
146    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
147    where
148        V: Visitor<'de>,
149    {
150        let _ = len;
151        let _ = visitor;
152        let kind = self.kind.as_str();
153        let expected = ExpectedTupleVariant { kind };
154        Err(Error::invalid_type(Unexpected::StructVariant, &expected))
155    }
156
157    fn struct_variant<V>(
158        self,
159        fields: &'static [&'static str],
160        visitor: V,
161    ) -> Result<V::Value, Self::Error>
162    where
163        V: Visitor<'de>,
164    {
165        let _ = fields;
166        let kind = self.kind;
167        match visitor.visit_map(NodeFieldsDeserializer { node: self }) {
168            Ok(value) => Ok(value),
169            Err(error) => Err(error.with_kind(kind)),
170        }
171    }
172}
173
174impl<'de, 'a, T, M> MapAccess<'de> for NodeDeserializer<'de, 'a, T, M>
175where
176    T: Deserialize<'de>,
177    M: MapAccess<'de>,
178{
179    type Error = M::Error;
180
181    fn next_key_seed<K>(&mut self, mut seed: K) -> Result<Option<K::Value>, Self::Error>
182    where
183        K: DeserializeSeed<'de>,
184    {
185        if self.has_kind {
186            let deserializer = BorrowedStrDeserializer::new("kind");
187            seed.deserialize(deserializer).map(Some)
188        } else {
189            loop {
190                seed = match self.map.next_key_seed(NodeFieldSeed {
191                    kind: self.kind,
192                    seed,
193                })? {
194                    None => return Ok(None),
195                    Some(NodeField::Inner(seed)) => {
196                        *self.inner = self.map.next_value()?;
197                        seed
198                    }
199                    Some(NodeField::Delegate(value)) => return Ok(Some(value)),
200                };
201            }
202        }
203    }
204
205    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
206    where
207        V: DeserializeSeed<'de>,
208    {
209        if self.has_kind {
210            let deserializer = match &self.kind {
211                AnyKind::Kind(kind) => SometimesBorrowedStrDeserializer::borrowed(kind.as_str()),
212                AnyKind::Borrowed(kind) => SometimesBorrowedStrDeserializer::borrowed(kind),
213                AnyKind::Owned(kind) => SometimesBorrowedStrDeserializer::transient(kind),
214            };
215            let value = seed.deserialize(deserializer);
216            self.has_kind = false;
217            value
218        } else {
219            self.map.next_value_seed(seed)
220        }
221    }
222}
223
224impl<'de, 'a, T, M> NodeDeserializer<'de, 'a, T, M>
225where
226    T: Deserialize<'de>,
227    M: MapAccess<'de>,
228{
229    fn ignore(&mut self) -> Result<(), M::Error> {
230        while let Some(node_field) = self.map.next_key_seed(NodeFieldSeed {
231            kind: self.kind,
232            seed: PhantomData::<IgnoredAny>,
233        })? {
234            match node_field {
235                NodeField::Inner(PhantomData) => {
236                    *self.inner = self.map.next_value()?;
237                }
238                NodeField::Delegate(IgnoredAny) => {
239                    let _: IgnoredAny = self.map.next_value()?;
240                }
241            }
242        }
243        Ok(())
244    }
245}
246
247struct UnknownNode<'de, 'a, T, M> {
248    name: &'static str,
249    node: NodeDeserializer<'de, 'a, T, M>,
250}
251
252impl<'de, 'a, T, M> EnumAccess<'de> for UnknownNode<'de, 'a, T, M>
253where
254    T: Deserialize<'de>,
255    M: MapAccess<'de>,
256{
257    type Error = M::Error;
258    type Variant = Self;
259
260    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
261    where
262        V: DeserializeSeed<'de>,
263    {
264        let deserializer = BorrowedStrDeserializer::new(self.name);
265        let value = seed.deserialize(deserializer)?;
266        Ok((value, self))
267    }
268}
269
270impl<'de, 'a, T, M> VariantAccess<'de> for UnknownNode<'de, 'a, T, M>
271where
272    T: Deserialize<'de>,
273    M: MapAccess<'de>,
274{
275    type Error = M::Error;
276
277    fn unit_variant(self) -> Result<(), Self::Error> {
278        self.node.unit_variant()
279    }
280
281    fn newtype_variant_seed<V>(self, seed: V) -> Result<V::Value, Self::Error>
282    where
283        V: DeserializeSeed<'de>,
284    {
285        seed.deserialize(self.node)
286    }
287
288    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
289    where
290        V: Visitor<'de>,
291    {
292        self.node.tuple_variant(len, visitor)
293    }
294
295    fn struct_variant<V>(
296        self,
297        fields: &'static [&'static str],
298        visitor: V,
299    ) -> Result<V::Value, Self::Error>
300    where
301        V: Visitor<'de>,
302    {
303        let _ = fields;
304        visitor.visit_map(self.node)
305    }
306}
307
308struct NodeFieldsDeserializer<'de, 'a, T, M> {
309    node: NodeDeserializer<'de, 'a, T, M>,
310}
311
312impl<'de, 'a, T, M> Deserializer<'de> for NodeFieldsDeserializer<'de, 'a, T, M>
313where
314    T: Deserialize<'de>,
315    M: MapAccess<'de>,
316{
317    type Error = M::Error;
318
319    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
320    where
321        V: Visitor<'de>,
322    {
323        let kind = self.node.kind;
324        match visitor.visit_map(self) {
325            Ok(value) => Ok(value),
326            Err(error) => Err(error.with_kind(kind)),
327        }
328    }
329
330    fn deserialize_enum<V>(
331        self,
332        name: &'static str,
333        variants: &'static [&'static str],
334        visitor: V,
335    ) -> Result<V::Value, Self::Error>
336    where
337        V: Visitor<'de>,
338    {
339        let _ = variants;
340        visitor.visit_enum(NodeEnumDeserializer {
341            name,
342            node: self.node,
343        })
344    }
345
346    fn deserialize_unit_struct<V>(
347        self,
348        name: &'static str,
349        visitor: V,
350    ) -> Result<V::Value, Self::Error>
351    where
352        V: Visitor<'de>,
353    {
354        let _ = name;
355        self.deserialize_unit(visitor)
356    }
357
358    fn deserialize_unit<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
359    where
360        V: Visitor<'de>,
361    {
362        self.node.ignore()?;
363        visitor.visit_unit()
364    }
365
366    forward_to_deserialize_any! {
367        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
368        bytes byte_buf option newtype_struct seq tuple tuple_struct map struct
369        identifier ignored_any
370    }
371}
372
373impl<'de, 'a, T, M> MapAccess<'de> for NodeFieldsDeserializer<'de, 'a, T, M>
374where
375    T: Deserialize<'de>,
376    M: MapAccess<'de>,
377{
378    type Error = FieldOfKindError<M::Error>;
379
380    fn next_key_seed<K>(&mut self, mut seed: K) -> Result<Option<K::Value>, Self::Error>
381    where
382        K: DeserializeSeed<'de>,
383    {
384        loop {
385            seed = match self
386                .node
387                .map
388                .next_key_seed(NodeFieldSeed {
389                    kind: self.node.kind,
390                    seed,
391                })
392                .map_err(FieldOfKindError::Other)?
393            {
394                None => return Ok(None),
395                Some(NodeField::Inner(seed)) => {
396                    *self.node.inner = self
397                        .node
398                        .map
399                        .next_value()
400                        .map_err(FieldOfKindError::Other)?;
401                    seed
402                }
403                Some(NodeField::Delegate(value)) => return Ok(Some(value)),
404            };
405        }
406    }
407
408    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
409    where
410        V: DeserializeSeed<'de>,
411    {
412        self.node
413            .map
414            .next_value_seed(seed)
415            .map_err(FieldOfKindError::Other)
416    }
417}
418
419struct NodeEnumDeserializer<'de, 'a, T, M> {
420    name: &'static str,
421    node: NodeDeserializer<'de, 'a, T, M>,
422}
423
424impl<'de, 'a, T, M> EnumAccess<'de> for NodeEnumDeserializer<'de, 'a, T, M>
425where
426    T: Deserialize<'de>,
427    M: MapAccess<'de>,
428{
429    type Error = M::Error;
430    type Variant = Self;
431
432    fn variant_seed<V>(mut self, mut seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
433    where
434        V: DeserializeSeed<'de>,
435    {
436        loop {
437            seed = match self.node.map.next_key_seed(NodeFieldSeed {
438                kind: self.node.kind,
439                seed,
440            })? {
441                None => {
442                    let expected = ExpectedEnum { name: self.name };
443                    return Err(Error::invalid_type(Unexpected::Map, &expected));
444                }
445                Some(NodeField::Inner(seed)) => {
446                    *self.node.inner = self.node.map.next_value()?;
447                    seed
448                }
449                Some(NodeField::Delegate(value)) => return Ok((value, self)),
450            }
451        }
452    }
453}
454
455impl<'de, 'a, T, M> VariantAccess<'de> for NodeEnumDeserializer<'de, 'a, T, M>
456where
457    T: Deserialize<'de>,
458    M: MapAccess<'de>,
459{
460    type Error = M::Error;
461
462    fn unit_variant(self) -> Result<(), Self::Error> {
463        let expected = "unit variant";
464        Err(Error::invalid_type(Unexpected::NewtypeVariant, &expected))
465    }
466
467    fn newtype_variant_seed<V>(mut self, seed: V) -> Result<V::Value, Self::Error>
468    where
469        V: DeserializeSeed<'de>,
470    {
471        let value = self.node.map.next_value_seed(seed)?;
472        loop {
473            match self.node.map.next_key_seed(NodeFieldSeed {
474                kind: self.node.kind,
475                seed: PhantomData::<UnexpectedField>,
476            })? {
477                None => return Ok(value),
478                Some(NodeField::Inner(PhantomData)) => {
479                    *self.node.inner = self.node.map.next_value()?;
480                }
481                #[allow(unreachable_patterns)]
482                Some(NodeField::Delegate(unexpected)) => match unexpected {},
483            }
484        }
485    }
486
487    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
488    where
489        V: Visitor<'de>,
490    {
491        let _ = len;
492        let _ = visitor;
493        let expected = "tuple variant";
494        Err(Error::invalid_type(Unexpected::NewtypeVariant, &expected))
495    }
496
497    fn struct_variant<V>(
498        self,
499        fields: &'static [&'static str],
500        visitor: V,
501    ) -> Result<V::Value, Self::Error>
502    where
503        V: Visitor<'de>,
504    {
505        let _ = fields;
506        let _ = visitor;
507        let expected = "struct variant";
508        Err(Error::invalid_type(Unexpected::NewtypeVariant, &expected))
509    }
510}
511
512struct NodeFieldSeed<'a, K> {
513    kind: &'a AnyKind<'a>,
514    seed: K,
515}
516
517enum NodeField<K, X> {
518    Inner(K),
519    Delegate(X),
520}
521
522impl<'de, 'a, K> DeserializeSeed<'de> for NodeFieldSeed<'a, K>
523where
524    K: DeserializeSeed<'de>,
525{
526    type Value = NodeField<K, K::Value>;
527
528    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
529    where
530        D: Deserializer<'de>,
531    {
532        deserializer.deserialize_identifier(self)
533    }
534}
535
536impl<'de, 'a, K> Visitor<'de> for NodeFieldSeed<'a, K>
537where
538    K: DeserializeSeed<'de>,
539{
540    type Value = NodeField<K, K::Value>;
541
542    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
543        formatter.write_str("field of syntax tree node")
544    }
545
546    fn visit_str<E>(self, identifier: &str) -> Result<Self::Value, E>
547    where
548        E: Error,
549    {
550        match identifier {
551            "inner" => Ok(NodeField::Inner(self.seed)),
552            other => match self.seed.deserialize(FieldOfKindDeserializer {
553                field: other,
554                error: PhantomData,
555            }) {
556                Ok(field) => Ok(NodeField::Delegate(field)),
557                Err(error) => Err(error.with_kind(self.kind)),
558            },
559        }
560    }
561}
562
563struct FieldOfKindDeserializer<'a, E> {
564    field: &'a str,
565    error: PhantomData<E>,
566}
567
568impl<'de, 'a, E> Deserializer<'de> for FieldOfKindDeserializer<'a, E>
569where
570    E: Error,
571{
572    type Error = FieldOfKindError<E>;
573
574    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
575    where
576        V: Visitor<'de>,
577    {
578        visitor.visit_str(self.field)
579    }
580
581    forward_to_deserialize_any! {
582        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
583        bytes byte_buf option unit unit_struct newtype_struct seq tuple
584        tuple_struct map struct enum identifier ignored_any
585    }
586}
587
588#[derive(Debug)]
589enum FieldOfKindError<E> {
590    UnknownField {
591        field: Box<str>,
592        expected: &'static [&'static str],
593    },
594    MissingField {
595        field: &'static str,
596    },
597    Other(E),
598}
599
600impl<E> FieldOfKindError<E>
601where
602    E: Error,
603{
604    fn with_kind(self, kind: &AnyKind) -> E {
605        match self {
606            FieldOfKindError::UnknownField { field, expected } => {
607                if let AnyKind::Kind(Kind::null) = kind {
608                    E::unknown_field(&field, expected)
609                } else if expected.is_empty() {
610                    E::custom(format_args!(
611                        "unknown field `{}` in {}, there are no fields",
612                        field, kind,
613                    ))
614                } else {
615                    E::custom(format_args!(
616                        "unknown field `{}` in {}, expected {}",
617                        field,
618                        kind,
619                        OneOf { names: expected },
620                    ))
621                }
622            }
623            FieldOfKindError::MissingField { field } => {
624                if let AnyKind::Kind(Kind::null) = kind {
625                    E::missing_field(field)
626                } else {
627                    E::custom(format_args!("missing field `{}` in {}", field, kind))
628                }
629            }
630            FieldOfKindError::Other(error) => error,
631        }
632    }
633}
634
635impl<E> Error for FieldOfKindError<E>
636where
637    E: Error,
638{
639    fn unknown_field(field: &str, expected: &'static [&'static str]) -> Self {
640        FieldOfKindError::UnknownField {
641            field: Box::from(field),
642            expected,
643        }
644    }
645
646    fn missing_field(field: &'static str) -> Self {
647        FieldOfKindError::MissingField { field }
648    }
649
650    fn custom<T>(msg: T) -> Self
651    where
652        T: Display,
653    {
654        FieldOfKindError::Other(E::custom(msg))
655    }
656
657    fn invalid_type(unexp: Unexpected, exp: &dyn Expected) -> Self {
658        FieldOfKindError::Other(E::invalid_type(unexp, exp))
659    }
660
661    fn invalid_value(unexp: Unexpected, exp: &dyn Expected) -> Self {
662        FieldOfKindError::Other(E::invalid_value(unexp, exp))
663    }
664
665    fn invalid_length(len: usize, exp: &dyn Expected) -> Self {
666        FieldOfKindError::Other(E::invalid_length(len, exp))
667    }
668
669    fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
670        FieldOfKindError::Other(E::unknown_variant(variant, expected))
671    }
672
673    fn duplicate_field(field: &'static str) -> Self {
674        FieldOfKindError::Other(E::duplicate_field(field))
675    }
676}
677
678impl<E> StdError for FieldOfKindError<E>
679where
680    E: StdError,
681{
682    fn source(&self) -> Option<&(dyn StdError + 'static)> {
683        match self {
684            FieldOfKindError::UnknownField { .. } | FieldOfKindError::MissingField { .. } => None,
685            FieldOfKindError::Other(error) => error.source(),
686        }
687    }
688}
689
690impl<E> Display for FieldOfKindError<E>
691where
692    E: Display,
693{
694    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
695        match self {
696            FieldOfKindError::UnknownField { field, expected } => {
697                if expected.is_empty() {
698                    write!(formatter, "unknown field `{}`, there are no fields", field)
699                } else {
700                    write!(
701                        formatter,
702                        "unknown field `{}`, expected {}",
703                        field,
704                        OneOf { names: expected },
705                    )
706                }
707            }
708            FieldOfKindError::MissingField { field } => {
709                write!(formatter, "missing field `{}`", field)
710            }
711            FieldOfKindError::Other(error) => Display::fmt(error, formatter),
712        }
713    }
714}
715
/// Formats a non-empty list of names as "`a`", "`a` or `b`", or
/// "one of `a`, `b`, `c`".
struct OneOf {
    names: &'static [&'static str],
}

impl Display for OneOf {
    /// # Panics
    ///
    /// Panics if `names` is empty; callers handle that case separately.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        match self.names {
            [] => unreachable!(), // special case elsewhere
            [only] => write!(formatter, "`{}`", only),
            [first, second] => write!(formatter, "`{}` or `{}`", first, second),
            names => {
                formatter.write_str("one of ")?;
                let mut separator = "";
                for name in names {
                    formatter.write_str(separator)?;
                    write!(formatter, "`{}`", name)?;
                    separator = ", ";
                }
                Ok(())
            }
        }
    }
}
739
740enum UnexpectedField {}
741
742impl<'de> Deserialize<'de> for UnexpectedField {
743    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
744    where
745        D: Deserializer<'de>,
746    {
747        deserializer.deserialize_identifier(UnexpectedFieldVisitor)
748    }
749}
750
751struct UnexpectedFieldVisitor;
752
753impl<'de> Visitor<'de> for UnexpectedFieldVisitor {
754    type Value = UnexpectedField;
755
756    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
757        formatter.write_str("no more fields")
758    }
759
760    fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
761    where
762        E: Error,
763    {
764        Err(Error::unknown_field(string, &[]))
765    }
766}
767
768struct ExpectedTupleVariant<'a> {
769    kind: &'a str,
770}
771
772impl<'a> Expected for ExpectedTupleVariant<'a> {
773    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
774        write!(formatter, "tuple variant of type `{}`", self.kind)
775    }
776}
777
778struct ExpectedEnum {
779    name: &'static str,
780}
781
782impl Expected for ExpectedEnum {
783    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
784        write!(formatter, "enum `{}`", self.name)
785    }
786}