// ion_derive/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::{parse_macro_input, Data, DeriveInput, Fields};

5#[proc_macro_derive(IonType)]
6pub fn derive_ion_type(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9    let name_str = name.to_string();
10
11    match &input.data {
12        Data::Struct(data) => derive_struct(name, &name_str, data),
13        Data::Enum(data) => derive_enum(name, &name_str, data),
14        Data::Union(_) => syn::Error::new_spanned(name, "IonType cannot be derived for unions")
15            .to_compile_error()
16            .into(),
17    }
18}
19
20fn derive_struct(name: &syn::Ident, name_str: &str, data: &syn::DataStruct) -> TokenStream {
21    let fields = match &data.fields {
22        Fields::Named(f) => &f.named,
23        _ => {
24            return syn::Error::new_spanned(name, "IonType only supports named struct fields")
25                .to_compile_error()
26                .into();
27        }
28    };
29
30    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
31    let field_name_strs: Vec<String> = field_names.iter().map(|f| f.to_string()).collect();
32
33    // to_ion: convert each field
34    let to_ion_fields = field_names.iter().zip(field_name_strs.iter()).map(|(ident, name_s)| {
35        quote! {
36            fields.insert(#name_s.to_string(), ion_core::host_types::IonType::to_ion(&self.#ident));
37        }
38    });
39
40    // from_ion: extract each field
41    let from_ion_fields = field_names
42        .iter()
43        .zip(field_name_strs.iter())
44        .map(|(ident, name_s)| {
45            quote! {
46                #ident: {
47                    let v = fields.get(#name_s)
48                        .ok_or_else(|| format!("missing field '{}' in {}", #name_s, #name_str))?;
49                    ion_core::host_types::IonType::from_ion(v)?
50                },
51            }
52        });
53
54    // ion_type_def: field name list
55    let def_fields = field_name_strs.iter().map(|s| {
56        quote! { #s.to_string() }
57    });
58
59    let expanded = quote! {
60        impl ion_core::host_types::IonType for #name {
61            fn to_ion(&self) -> ion_core::value::Value {
62                let mut fields = indexmap::IndexMap::new();
63                #(#to_ion_fields)*
64                ion_core::value::Value::HostStruct {
65                    type_name: #name_str.to_string(),
66                    fields,
67                }
68            }
69
70            fn from_ion(val: &ion_core::value::Value) -> Result<Self, String> {
71                if let ion_core::value::Value::HostStruct { type_name, fields } = val {
72                    if type_name != #name_str {
73                        return Err(format!("expected {}, got {}", #name_str, type_name));
74                    }
75                    Ok(Self {
76                        #(#from_ion_fields)*
77                    })
78                } else {
79                    Err(format!("expected {}, got {}", #name_str, val.type_name()))
80                }
81            }
82
83            fn ion_type_def() -> ion_core::host_types::IonTypeDef {
84                ion_core::host_types::IonTypeDef::Struct(
85                    ion_core::host_types::HostStructDef {
86                        name: #name_str.to_string(),
87                        fields: vec![#(#def_fields),*],
88                    }
89                )
90            }
91        }
92    };
93
94    expanded.into()
95}
96
97fn derive_enum(name: &syn::Ident, name_str: &str, data: &syn::DataEnum) -> TokenStream {
98    let variants = &data.variants;
99    for variant in variants {
100        if matches!(variant.fields, Fields::Named(_)) {
101            return syn::Error::new_spanned(
102                &variant.ident,
103                "IonType does not support enum variants with named fields",
104            )
105            .to_compile_error()
106            .into();
107        }
108    }
109
110    // ion_type_def: variant definitions
111    let variant_defs = variants.iter().map(|v| {
112        let vname = v.ident.to_string();
113        let arity = match &v.fields {
114            Fields::Unit => 0usize,
115            Fields::Unnamed(f) => f.unnamed.len(),
116            Fields::Named(_) => unreachable!("named enum fields rejected above"),
117        };
118        quote! {
119            ion_core::host_types::HostVariantDef {
120                name: #vname.to_string(),
121                arity: #arity,
122            }
123        }
124    });
125
126    // to_ion arms
127    let to_ion_arms = variants.iter().map(|v| {
128        let vident = &v.ident;
129        let vname = v.ident.to_string();
130        match &v.fields {
131            Fields::Unit => {
132                quote! {
133                    #name::#vident => ion_core::value::Value::HostEnum {
134                        enum_name: #name_str.to_string(),
135                        variant: #vname.to_string(),
136                        data: vec![],
137                    },
138                }
139            }
140            Fields::Unnamed(fields) => {
141                let bindings: Vec<_> = (0..fields.unnamed.len())
142                    .map(|i| syn::Ident::new(&format!("f{}", i), proc_macro2::Span::call_site()))
143                    .collect();
144                let to_ions = bindings.iter().map(|b| {
145                    quote! { ion_core::host_types::IonType::to_ion(#b) }
146                });
147                quote! {
148                    #name::#vident(#(#bindings),*) => ion_core::value::Value::HostEnum {
149                        enum_name: #name_str.to_string(),
150                        variant: #vname.to_string(),
151                        data: vec![#(#to_ions),*],
152                    },
153                }
154            }
155            Fields::Named(_) => unreachable!("named enum fields rejected above"),
156        }
157    });
158
159    // from_ion arms
160    let from_ion_arms = variants.iter().map(|v| {
161        let vident = &v.ident;
162        let vname = v.ident.to_string();
163        match &v.fields {
164            Fields::Unit => {
165                quote! {
166                    #vname => {
167                        if !data.is_empty() {
168                            return Err(format!("{}::{} takes no arguments", #name_str, #vname));
169                        }
170                        Ok(#name::#vident)
171                    }
172                }
173            }
174            Fields::Unnamed(fields) => {
175                let count = fields.unnamed.len();
176                let extracts: Vec<_> = (0..count)
177                    .map(|i| {
178                        quote! {
179                            ion_core::host_types::IonType::from_ion(&data[#i])?
180                        }
181                    })
182                    .collect();
183                quote! {
184                    #vname => {
185                        if data.len() != #count {
186                            return Err(format!("{}::{} expects {} arguments, got {}", #name_str, #vname, #count, data.len()));
187                        }
188                        Ok(#name::#vident(#(#extracts),*))
189                    }
190                }
191            }
192            Fields::Named(_) => unreachable!("named enum fields rejected above"),
193        }
194    });
195
196    let expanded = quote! {
197        impl ion_core::host_types::IonType for #name {
198            fn to_ion(&self) -> ion_core::value::Value {
199                match self {
200                    #(#to_ion_arms)*
201                }
202            }
203
204            fn from_ion(val: &ion_core::value::Value) -> Result<Self, String> {
205                if let ion_core::value::Value::HostEnum { enum_name, variant, data } = val {
206                    if enum_name != #name_str {
207                        return Err(format!("expected {}, got {}", #name_str, enum_name));
208                    }
209                    match variant.as_str() {
210                        #(#from_ion_arms)*
211                        _ => Err(format!("unknown variant '{}' in {}", variant, #name_str)),
212                    }
213                } else {
214                    Err(format!("expected {}, got {}", #name_str, val.type_name()))
215                }
216            }
217
218            fn ion_type_def() -> ion_core::host_types::IonTypeDef {
219                ion_core::host_types::IonTypeDef::Enum(
220                    ion_core::host_types::HostEnumDef {
221                        name: #name_str.to_string(),
222                        variants: vec![#(#variant_defs),*],
223                    }
224                )
225            }
226        }
227    };
228
229    expanded.into()
230}