// nested_deserialize/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Ident, Type, parse_macro_input};

5#[proc_macro_derive(NestedDeserialize)]
6pub fn derive_nested_deserialize(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = &input.ident;
9
10 let enum_data = match input.data {
11 Data::Enum(data) => data,
12 _ => panic!("NestedDeserialize only works on enums"),
13 };
14
15 let mut variant_idents = Vec::new();
16 let mut inner_types = Vec::new();
17
18 for variant in enum_data.variants {
19 variant_idents.push(variant.ident.clone());
20 match variant.fields {
21 Fields::Unnamed(f) if f.unnamed.len() == 1 => {
22 inner_types.push(f.unnamed[0].ty.clone());
23 }
24 _ => panic!("Variants must be single-element tuples"),
25 }
26 }
27
28 let expected_variants_code = quote! {
30 {
31 let mut expected = Vec::new();
32 #( expected.extend_from_slice(<#inner_types as strum::VariantNames>::VARIANTS); )*
33 expected.join(", ")
34 }
35 };
36
37 let expanded = quote! {
38 impl<'de> serde::Deserialize<'de> for #name {
39 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40 where
41 D: serde::Deserializer<'de>,
42 {
43 use serde::de::{Visitor, Error, MapAccess, DeserializeSeed, IntoDeserializer};
44 use std::fmt;
45 use std::borrow::Cow;
46
47 struct NestedVisitor;
48
49 impl<'de> Visitor<'de> for NestedVisitor {
50 type Value = #name;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("an externally tagged nested enum")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58 where
59 E: Error,
60 {
61 #(
62 if <#inner_types as strum::VariantNames>::VARIANTS.contains(&value) {
63 let inner = <#inner_types as serde::Deserialize>::deserialize(
64 value.into_deserializer()
65 )?;
66 return Ok(#name::#variant_idents(inner));
67 }
68 )*
69
70 let expected_str = #expected_variants_code;
71 Err(Error::custom(format!("unknown variant `{}`, expected one of: {}", value, expected_str)))
72 }
73
74 fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
75 where
76 E: Error,
77 {
78 self.visit_str(&value)
79 }
80
81 fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
83 where
84 M: MapAccess<'de>,
85 {
86 let key: Cow<'de, str> = match map.next_key()? {
88 Some(k) => k,
89 None => return Err(Error::custom("expected an externally tagged enum map, found empty map")),
90 };
91
92 struct ReplayMapAccess<'a, M> {
94 key: Option<&'a str>,
95 map: M,
96 }
97
98 impl<'a, 'de, M> MapAccess<'de> for ReplayMapAccess<'a, M>
99 where
100 M: MapAccess<'de>,
101 {
102 type Error = M::Error;
103
104 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
105 where
106 K: DeserializeSeed<'de>,
107 {
108 if let Some(k) = self.key.take() {
109 seed.deserialize(k.into_deserializer()).map(Some)
111 } else {
112 self.map.next_key_seed(seed)
113 }
114 }
115
116 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
117 where
118 V: DeserializeSeed<'de>,
119 {
120 self.map.next_value_seed(seed)
121 }
122 }
123
124 let replay_map = ReplayMapAccess {
125 key: Some(&key),
126 map,
127 };
128
129 let map_deserializer = serde::de::value::MapAccessDeserializer::new(replay_map);
131 let tag = key.as_ref();
132
133 #(
135 if <#inner_types as strum::VariantNames>::VARIANTS.contains(&tag) {
136 let inner = <#inner_types as serde::Deserialize>::deserialize(map_deserializer)?;
137 return Ok(#name::#variant_idents(inner));
138 }
139 )*
140
141 let expected_str = #expected_variants_code;
143 Err(Error::custom(format!("unknown variant `{}`, expected one of: {}", tag, expected_str)))
144 }
145 }
146
147 deserializer.deserialize_any(NestedVisitor)
149 }
150 }
151 };
152
153 TokenStream::from(expanded)
154}