// nested_deserialize/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, Type, parse_macro_input};

5#[proc_macro_derive(NestedDeserialize)]
6pub fn derive_nested_deserialize(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9
10    let enum_data = match input.data {
11        Data::Enum(data) => data,
12        _ => panic!("NestedDeserialize only works on enums"),
13    };
14
15    let mut variant_idents = Vec::new();
16    let mut inner_types = Vec::new();
17
18    for variant in enum_data.variants {
19        variant_idents.push(variant.ident.clone());
20        match variant.fields {
21            Fields::Unnamed(f) if f.unnamed.len() == 1 => {
22                inner_types.push(f.unnamed[0].ty.clone());
23            }
24            _ => panic!("Variants must be single-element tuples"),
25        }
26    }
27
28    // Generates a formatted list of all valid variants for error messages
29    let expected_variants_code = quote! {
30        {
31            let mut expected = Vec::new();
32            #( expected.extend_from_slice(<#inner_types as strum::VariantNames>::VARIANTS); )*
33            expected.join(", ")
34        }
35    };
36
37    let expanded = quote! {
38        impl<'de> serde::Deserialize<'de> for #name {
39            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40            where
41                D: serde::Deserializer<'de>,
42            {
43                use serde::de::{Visitor, Error, MapAccess, DeserializeSeed, IntoDeserializer};
44                use std::fmt;
45                use std::borrow::Cow;
46
47                struct NestedVisitor;
48
49                impl<'de> Visitor<'de> for NestedVisitor {
50                    type Value = #name;
51
52                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53                        formatter.write_str("an externally tagged nested enum")
54                    }
55
56                    // Handles unit variants (e.g. JSON: `"Velocity"`)
57                    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58                    where
59                        E: Error,
60                    {
61                        #(
62                            if <#inner_types as strum::VariantNames>::VARIANTS.contains(&value) {
63                                let inner = <#inner_types as serde::Deserialize>::deserialize(
64                                    value.into_deserializer()
65                                )?;
66                                return Ok(#name::#variant_idents(inner));
67                            }
68                        )*
69
70                        let expected_str = #expected_variants_code;
71                        Err(Error::custom(format!("unknown variant `{}`, expected one of: {}", value, expected_str)))
72                    }
73
74                    fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
75                    where
76                        E: Error,
77                    {
78                        self.visit_str(&value)
79                    }
80
81                    // Handles tuple/struct variants (e.g. JSON: `{"Velocity": {"m_s": 12.0}}`)
82                    fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
83                    where
84                        M: MapAccess<'de>,
85                    {
86                        // 1. Peek at the first key to get the tag
87                        let key: Cow<'de, str> = match map.next_key()? {
88                            Some(k) => k,
89                            None => return Err(Error::custom("expected an externally tagged enum map, found empty map")),
90                        };
91
92                        // 2. Wrap the stream to "replay" the consumed key
93                        struct ReplayMapAccess<'a, M> {
94                            key: Option<&'a str>,
95                            map: M,
96                        }
97
98                        impl<'a, 'de, M> MapAccess<'de> for ReplayMapAccess<'a, M>
99                        where
100                            M: MapAccess<'de>,
101                        {
102                            type Error = M::Error;
103
104                            fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
105                            where
106                                K: DeserializeSeed<'de>,
107                            {
108                                if let Some(k) = self.key.take() {
109                                    // Yield the borrowed key, then fall back to the map stream
110                                    seed.deserialize(k.into_deserializer()).map(Some)
111                                } else {
112                                    self.map.next_key_seed(seed)
113                                }
114                            }
115
116                            fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
117                            where
118                                V: DeserializeSeed<'de>,
119                            {
120                                self.map.next_value_seed(seed)
121                            }
122                        }
123
124                        let replay_map = ReplayMapAccess {
125                            key: Some(&key),
126                            map,
127                        };
128
129                        // Turn our Map wrapper back into a Deserializer that the inner enum can consume
130                        let map_deserializer = serde::de::value::MapAccessDeserializer::new(replay_map);
131                        let tag = key.as_ref();
132
133                        // 3. Dispatch to the correct inner enum based on the tag
134                        #(
135                            if <#inner_types as strum::VariantNames>::VARIANTS.contains(&tag) {
136                                let inner = <#inner_types as serde::Deserialize>::deserialize(map_deserializer)?;
137                                return Ok(#name::#variant_idents(inner));
138                            }
139                        )*
140
141                        // 4. Fallback: Detailed error message
142                        let expected_str = #expected_variants_code;
143                        Err(Error::custom(format!("unknown variant `{}`, expected one of: {}", tag, expected_str)))
144                    }
145                }
146
147                // Since we rely on extracting the tag dynamically, we use deserialize_any
148                deserializer.deserialize_any(NestedVisitor)
149            }
150        }
151    };
152
153    TokenStream::from(expanded)
154}