1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_derive(IonType)]
6pub fn derive_ion_type(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = &input.ident;
9 let name_str = name.to_string();
10
11 match &input.data {
12 Data::Struct(data) => derive_struct(name, &name_str, data),
13 Data::Enum(data) => derive_enum(name, &name_str, data),
14 Data::Union(_) => syn::Error::new_spanned(name, "IonType cannot be derived for unions")
15 .to_compile_error()
16 .into(),
17 }
18}
19
20fn derive_struct(name: &syn::Ident, name_str: &str, data: &syn::DataStruct) -> TokenStream {
21 let fields = match &data.fields {
22 Fields::Named(f) => &f.named,
23 _ => {
24 return syn::Error::new_spanned(name, "IonType only supports named struct fields")
25 .to_compile_error()
26 .into();
27 }
28 };
29
30 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
31 let field_name_strs: Vec<String> = field_names.iter().map(|f| f.to_string()).collect();
32
33 let to_ion_fields = field_names.iter().zip(field_name_strs.iter()).map(|(ident, name_s)| {
35 quote! {
36 fields.insert(#name_s.to_string(), ion_core::host_types::IonType::to_ion(&self.#ident));
37 }
38 });
39
40 let from_ion_fields = field_names
42 .iter()
43 .zip(field_name_strs.iter())
44 .map(|(ident, name_s)| {
45 quote! {
46 #ident: {
47 let v = fields.get(#name_s)
48 .ok_or_else(|| format!("missing field '{}' in {}", #name_s, #name_str))?;
49 ion_core::host_types::IonType::from_ion(v)?
50 },
51 }
52 });
53
54 let def_fields = field_name_strs.iter().map(|s| {
56 quote! { #s.to_string() }
57 });
58
59 let expanded = quote! {
60 impl ion_core::host_types::IonType for #name {
61 fn to_ion(&self) -> ion_core::value::Value {
62 let mut fields = indexmap::IndexMap::new();
63 #(#to_ion_fields)*
64 ion_core::value::Value::HostStruct {
65 type_name: #name_str.to_string(),
66 fields,
67 }
68 }
69
70 fn from_ion(val: &ion_core::value::Value) -> Result<Self, String> {
71 if let ion_core::value::Value::HostStruct { type_name, fields } = val {
72 if type_name != #name_str {
73 return Err(format!("expected {}, got {}", #name_str, type_name));
74 }
75 Ok(Self {
76 #(#from_ion_fields)*
77 })
78 } else {
79 Err(format!("expected {}, got {}", #name_str, val.type_name()))
80 }
81 }
82
83 fn ion_type_def() -> ion_core::host_types::IonTypeDef {
84 ion_core::host_types::IonTypeDef::Struct(
85 ion_core::host_types::HostStructDef {
86 name: #name_str.to_string(),
87 fields: vec![#(#def_fields),*],
88 }
89 )
90 }
91 }
92 };
93
94 expanded.into()
95}
96
97fn derive_enum(name: &syn::Ident, name_str: &str, data: &syn::DataEnum) -> TokenStream {
98 let variants = &data.variants;
99 for variant in variants {
100 if matches!(variant.fields, Fields::Named(_)) {
101 return syn::Error::new_spanned(
102 &variant.ident,
103 "IonType does not support enum variants with named fields",
104 )
105 .to_compile_error()
106 .into();
107 }
108 }
109
110 let variant_defs = variants.iter().map(|v| {
112 let vname = v.ident.to_string();
113 let arity = match &v.fields {
114 Fields::Unit => 0usize,
115 Fields::Unnamed(f) => f.unnamed.len(),
116 Fields::Named(_) => unreachable!("named enum fields rejected above"),
117 };
118 quote! {
119 ion_core::host_types::HostVariantDef {
120 name: #vname.to_string(),
121 arity: #arity,
122 }
123 }
124 });
125
126 let to_ion_arms = variants.iter().map(|v| {
128 let vident = &v.ident;
129 let vname = v.ident.to_string();
130 match &v.fields {
131 Fields::Unit => {
132 quote! {
133 #name::#vident => ion_core::value::Value::HostEnum {
134 enum_name: #name_str.to_string(),
135 variant: #vname.to_string(),
136 data: vec![],
137 },
138 }
139 }
140 Fields::Unnamed(fields) => {
141 let bindings: Vec<_> = (0..fields.unnamed.len())
142 .map(|i| syn::Ident::new(&format!("f{}", i), proc_macro2::Span::call_site()))
143 .collect();
144 let to_ions = bindings.iter().map(|b| {
145 quote! { ion_core::host_types::IonType::to_ion(#b) }
146 });
147 quote! {
148 #name::#vident(#(#bindings),*) => ion_core::value::Value::HostEnum {
149 enum_name: #name_str.to_string(),
150 variant: #vname.to_string(),
151 data: vec![#(#to_ions),*],
152 },
153 }
154 }
155 Fields::Named(_) => unreachable!("named enum fields rejected above"),
156 }
157 });
158
159 let from_ion_arms = variants.iter().map(|v| {
161 let vident = &v.ident;
162 let vname = v.ident.to_string();
163 match &v.fields {
164 Fields::Unit => {
165 quote! {
166 #vname => {
167 if !data.is_empty() {
168 return Err(format!("{}::{} takes no arguments", #name_str, #vname));
169 }
170 Ok(#name::#vident)
171 }
172 }
173 }
174 Fields::Unnamed(fields) => {
175 let count = fields.unnamed.len();
176 let extracts: Vec<_> = (0..count)
177 .map(|i| {
178 quote! {
179 ion_core::host_types::IonType::from_ion(&data[#i])?
180 }
181 })
182 .collect();
183 quote! {
184 #vname => {
185 if data.len() != #count {
186 return Err(format!("{}::{} expects {} arguments, got {}", #name_str, #vname, #count, data.len()));
187 }
188 Ok(#name::#vident(#(#extracts),*))
189 }
190 }
191 }
192 Fields::Named(_) => unreachable!("named enum fields rejected above"),
193 }
194 });
195
196 let expanded = quote! {
197 impl ion_core::host_types::IonType for #name {
198 fn to_ion(&self) -> ion_core::value::Value {
199 match self {
200 #(#to_ion_arms)*
201 }
202 }
203
204 fn from_ion(val: &ion_core::value::Value) -> Result<Self, String> {
205 if let ion_core::value::Value::HostEnum { enum_name, variant, data } = val {
206 if enum_name != #name_str {
207 return Err(format!("expected {}, got {}", #name_str, enum_name));
208 }
209 match variant.as_str() {
210 #(#from_ion_arms)*
211 _ => Err(format!("unknown variant '{}' in {}", variant, #name_str)),
212 }
213 } else {
214 Err(format!("expected {}, got {}", #name_str, val.type_name()))
215 }
216 }
217
218 fn ion_type_def() -> ion_core::host_types::IonTypeDef {
219 ion_core::host_types::IonTypeDef::Enum(
220 ion_core::host_types::HostEnumDef {
221 name: #name_str.to_string(),
222 variants: vec![#(#variant_defs),*],
223 }
224 )
225 }
226 }
227 };
228
229 expanded.into()
230}