nuts_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, quote};
5use syn::{
6    AngleBracketedGenericArguments, Data, DeriveInput, Fields, GenericParam, Ident, Lit, LitStr,
7    PathArguments, Token, Type, TypePath,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11};
12
13// Helper struct to parse `#[storable(dims(...))]` or #[storable(flattened)]
14enum StorableAttr {
15    Item(Vec<LitStr>),
16    Flattened(),
17    Ignore(),
18}
19
20impl Parse for StorableAttr {
21    fn parse(input: ParseStream) -> syn::Result<Self> {
22        let metas = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
23
24        for meta in metas {
25            match meta {
26                syn::Meta::List(list) => {
27                    if list.path.is_ident("dims") {
28                        return Ok(StorableAttr::Item(
29                            list.nested
30                                .into_iter()
31                                .map(|e| match e {
32                                    syn::NestedMeta::Lit(Lit::Str(s)) => Ok(s),
33                                    _ => Err(syn::Error::new_spanned(e, "Expected string literal")),
34                                })
35                                .collect::<Result<Vec<_>, _>>()?,
36                        ));
37                    }
38                }
39                syn::Meta::Path(path) => {
40                    if path.is_ident("flatten") {
41                        return Ok(StorableAttr::Flattened());
42                    }
43                    if path.is_ident("ignore") {
44                        return Ok(StorableAttr::Ignore());
45                    }
46                }
47                _ => {
48                    return Err(syn::Error::new_spanned(
49                        meta,
50                        "Unsupported storable attribute. Expected `dims(...)` or `flatten`",
51                    ));
52                }
53            }
54        }
55
56        Ok(StorableAttr::Item(vec![]))
57    }
58}
59
60struct StorableBasicField {
61    name: Ident,
62    item_type: proc_macro2::TokenStream,
63    is_vec: bool,
64    is_option: bool,
65    dims: Vec<LitStr>,
66}
67
68struct StorableInnerField {
69    name: Ident,
70    item_type: proc_macro2::TokenStream,
71    is_option: bool,
72}
73
74enum StorableField {
75    Basic(StorableBasicField),
76    Inner(StorableInnerField),
77    Generic(StorableInnerField),
78}
79
80// Check if a type is a generic type parameter
81fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool {
82    if let Type::Path(type_path) = ty
83        && type_path.path.segments.len() == 1
84    {
85        let type_name = &type_path.path.segments.first().unwrap().ident;
86        return generics.params.iter().any(|param| {
87            if let GenericParam::Type(type_param) = param {
88                &type_param.ident == type_name
89            } else {
90                false
91            }
92        });
93    }
94    false
95}
96
97// Check if a type implements Storable trait based on bounds
98fn has_storable_bound(ty: &Ident, generics: &syn::Generics) -> bool {
99    for param in &generics.params {
100        if let GenericParam::Type(type_param) = param
101            && &type_param.ident == ty
102        {
103            for bound in &type_param.bounds {
104                if let syn::TypeParamBound::Trait(trait_bound) = bound {
105                    let path = &trait_bound.path;
106                    if path.segments.len() == 1
107                        && path.segments.first().unwrap().ident == "Storable"
108                    {
109                        return true;
110                    }
111                }
112            }
113        }
114    }
115    false
116}
117
118#[proc_macro_derive(Storable, attributes(storable))]
119pub fn storable_derive(input: TokenStream) -> TokenStream {
120    let ast = parse_macro_input!(input as DeriveInput);
121    let name = &ast.ident;
122    let generics = &ast.generics;
123
124    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
125    let impl_generics = if generics.params.is_empty() {
126        quote! { <P: nuts_storable::HasDims> }
127    } else {
128        quote! { #impl_generics }
129    };
130
131    let fields = if let Data::Struct(s) = ast.data {
132        if let Fields::Named(fields) = s.fields {
133            fields.named
134        } else {
135            panic!("Storable can only be derived for structs with named fields");
136        }
137    } else {
138        panic!("Storable can only be derived on structs");
139    };
140
141    let mut storable_fields = Vec::new();
142    for field in fields {
143        let field_name = field.ident.clone().unwrap();
144        let ty = &field.ty;
145        let ty_str = quote!(#ty).to_string();
146
147        let attr = field
148            .attrs
149            .iter()
150            .find(|a| a.path.is_ident("storable"))
151            .map(|a| a.parse_args::<StorableAttr>().unwrap());
152
153        if let Some(StorableAttr::Ignore()) = attr {
154            continue; // Skip this field
155        }
156
157        let attr = attr.unwrap_or(StorableAttr::Item(vec![]));
158
159        if let StorableAttr::Flattened() = attr {
160            let path = if let Type::Path(TypePath { path: p, qself: _ }) = ty {
161                p
162            } else {
163                panic!(
164                    "Unsupported field type with flattened attribute: {}",
165                    ty_str
166                );
167            };
168            let item = if path.segments.first().unwrap().ident == "Option" {
169                if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
170                    args, ..
171                }) = &path.segments.first().unwrap().arguments
172                {
173                    if let Some(arg) = args.first() {
174                        let inner_type = quote!(#arg);
175                        StorableField::Inner(StorableInnerField {
176                            name: field_name.clone(),
177                            item_type: inner_type,
178                            is_option: true,
179                        })
180                    } else {
181                        panic!("Invalid Option type for flattened field");
182                    }
183                } else {
184                    panic!("Invalid Option type for flattened field");
185                }
186            } else {
187                StorableField::Inner(StorableInnerField {
188                    name: field_name.clone(),
189                    item_type: path.into_token_stream(),
190                    is_option: false,
191                })
192            };
193            storable_fields.push(item);
194            continue;
195        }
196
197        let dims = if let StorableAttr::Item(dims) = attr {
198            dims
199        } else {
200            vec![]
201        };
202
203        // Check if the field is a generic type parameter
204        if let Type::Path(type_path) = ty {
205            if type_path.path.segments.len() == 1 {
206                let type_name = &type_path.path.segments.first().unwrap().ident;
207
208                // Check if this is a generic type parameter with Storable bound
209                if is_generic_param(ty, generics) && has_storable_bound(type_name, generics) {
210                    storable_fields.push(StorableField::Generic(StorableInnerField {
211                        name: field_name,
212                        item_type: quote!(#type_name),
213                        is_option: false,
214                    }));
215                    continue;
216                }
217
218                // Check if this is Option<T> where T is a generic type parameter
219                if type_name == "Option" {
220                    if let PathArguments::AngleBracketed(args) =
221                        &type_path.path.segments.first().unwrap().arguments
222                    {
223                        if let Some(arg) = args.args.first() {
224                            if let syn::GenericArgument::Type(inner_ty) = arg {
225                                if let Type::Path(inner_path) = inner_ty {
226                                    if inner_path.path.segments.len() == 1 {
227                                        let inner_name =
228                                            &inner_path.path.segments.first().unwrap().ident;
229                                        if is_generic_param(inner_ty, generics)
230                                            && has_storable_bound(inner_name, generics)
231                                        {
232                                            storable_fields.push(StorableField::Generic(
233                                                StorableInnerField {
234                                                    name: field_name,
235                                                    item_type: quote!(#inner_name),
236                                                    is_option: true,
237                                                },
238                                            ));
239                                            continue;
240                                        }
241                                    }
242                                }
243                            }
244                        }
245                    }
246                }
247            }
248        }
249
250        let item = match ty_str.as_str() {
251            "u64" => StorableField::Basic(StorableBasicField {
252                name: field_name.clone(),
253                item_type: quote! { nuts_storable::ItemType::U64 },
254                is_vec: false,
255                is_option: false,
256                dims,
257            }),
258            "i64" => StorableField::Basic(StorableBasicField {
259                name: field_name.clone(),
260                item_type: quote! { nuts_storable::ItemType::I64 },
261                is_vec: false,
262                is_option: false,
263                dims,
264            }),
265            "f64" => StorableField::Basic(StorableBasicField {
266                name: field_name.clone(),
267                item_type: quote! { nuts_storable::ItemType::F64 },
268                is_vec: false,
269                is_option: false,
270                dims,
271            }),
272            "f32" => StorableField::Basic(StorableBasicField {
273                name: field_name.clone(),
274                item_type: quote! { nuts_storable::ItemType::F32 },
275                is_vec: false,
276                is_option: false,
277                dims,
278            }),
279            "bool" => StorableField::Basic(StorableBasicField {
280                name: field_name.clone(),
281                item_type: quote! { nuts_storable::ItemType::Bool },
282                is_vec: false,
283                is_option: false,
284                dims,
285            }),
286            "Option < u64 >" => StorableField::Basic(StorableBasicField {
287                name: field_name.clone(),
288                item_type: quote! { nuts_storable::ItemType::U64 },
289                is_vec: false,
290                is_option: true,
291                dims,
292            }),
293            "Option < i64 >" => StorableField::Basic(StorableBasicField {
294                name: field_name.clone(),
295                item_type: quote! { nuts_storable::ItemType::I64 },
296                is_vec: false,
297                is_option: true,
298                dims,
299            }),
300            "Option < f64 >" => StorableField::Basic(StorableBasicField {
301                name: field_name.clone(),
302                item_type: quote! { nuts_storable::ItemType::F64 },
303                is_vec: false,
304                is_option: true,
305                dims,
306            }),
307            "Option < f32 >" => StorableField::Basic(StorableBasicField {
308                name: field_name.clone(),
309                item_type: quote! { nuts_storable::ItemType::F32 },
310                is_vec: false,
311                is_option: true,
312                dims,
313            }),
314            "Option < bool >" => StorableField::Basic(StorableBasicField {
315                name: field_name.clone(),
316                item_type: quote! { nuts_storable::ItemType::Bool },
317                is_vec: false,
318                is_option: true,
319                dims,
320            }),
321            "Vec < u64 >" => StorableField::Basic(StorableBasicField {
322                name: field_name.clone(),
323                item_type: quote! { nuts_storable::ItemType::U64 },
324                is_vec: true,
325                is_option: false,
326                dims,
327            }),
328            "Vec < i64 >" => StorableField::Basic(StorableBasicField {
329                name: field_name.clone(),
330                item_type: quote! { nuts_storable::ItemType::I64 },
331                is_vec: true,
332                is_option: false,
333                dims,
334            }),
335            "Vec < f64 >" => StorableField::Basic(StorableBasicField {
336                name: field_name.clone(),
337                item_type: quote! { nuts_storable::ItemType::F64 },
338                is_vec: true,
339                is_option: false,
340                dims,
341            }),
342            "Vec < f32 >" => StorableField::Basic(StorableBasicField {
343                name: field_name.clone(),
344                item_type: quote! { nuts_storable::ItemType::F32 },
345                is_vec: true,
346                is_option: false,
347                dims,
348            }),
349            "Vec < bool >" => StorableField::Basic(StorableBasicField {
350                name: field_name.clone(),
351                item_type: quote! { nuts_storable::ItemType::Bool },
352                is_vec: true,
353                is_option: false,
354                dims,
355            }),
356            "Option < Vec < u64 > >" => StorableField::Basic(StorableBasicField {
357                name: field_name.clone(),
358                item_type: quote! { nuts_storable::ItemType::U64 },
359                is_vec: true,
360                is_option: true,
361                dims,
362            }),
363            "Option < Vec < i64 > >" => StorableField::Basic(StorableBasicField {
364                name: field_name.clone(),
365                item_type: quote! { nuts_storable::ItemType::I64 },
366                is_vec: true,
367                is_option: true,
368                dims,
369            }),
370            "Option < Vec < f64 > >" => StorableField::Basic(StorableBasicField {
371                name: field_name.clone(),
372                item_type: quote! { nuts_storable::ItemType::F64 },
373                is_vec: true,
374                is_option: true,
375                dims,
376            }),
377            "Option < Vec < f32 > >" => StorableField::Basic(StorableBasicField {
378                name: field_name.clone(),
379                item_type: quote! { nuts_storable::ItemType::F32 },
380                is_vec: true,
381                is_option: true,
382                dims,
383            }),
384            "Option< Vec < bool > >" => StorableField::Basic(StorableBasicField {
385                name: field_name.clone(),
386                item_type: quote! { nuts_storable::ItemType::Bool },
387                is_vec: true,
388                is_option: true,
389                dims,
390            }),
391            _ => {
392                // Attempt to handle complex generic types that are still Storable
393                if let Type::Path(type_path) = ty {
394                    // Check if it's a type that has the Storable trait
395                    let type_token = quote!(#type_path);
396                    storable_fields.push(StorableField::Inner(StorableInnerField {
397                        name: field_name.clone(),
398                        item_type: type_token,
399                        is_option: false,
400                    }));
401                    continue;
402                } else {
403                    panic!("Unsupported field type: {}", ty_str);
404                }
405            }
406        };
407        storable_fields.push(item);
408    }
409
410    let names_exprs = storable_fields.iter().map(|f| match f {
411        StorableField::Basic(field) => {
412            let name = field.name.to_string();
413            quote! { vec![#name] }
414        }
415        StorableField::Inner(field) => {
416            let item_type = &field.item_type;
417            quote! { #item_type::names(parent) }
418        }
419        StorableField::Generic(field) => {
420            let name = field.name.to_string();
421            if field.is_option {
422                quote! { vec![#name] }
423            } else {
424                let item_type = &field.item_type;
425                quote! { #item_type::names(parent) }
426            }
427        }
428    });
429
430    let names_fn = quote! {
431        fn names(parent: &P) -> Vec<&str> {
432            let mut names = Vec::new();
433            #(names.extend(#names_exprs);)*
434            names
435        }
436    };
437
438    let item_type_arms = storable_fields.iter().map(|f| match f {
439        StorableField::Basic(field) => {
440            let name_str = field.name.to_string();
441            let item_type = &field.item_type;
442            quote! { #name_str => #item_type, }
443        }
444        StorableField::Inner(field) => {
445            let item_type = &field.item_type;
446            quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), }
447        }
448        StorableField::Generic(field) => {
449            let name_str = field.name.to_string();
450            let item_type = &field.item_type;
451            if field.is_option {
452                quote! { #name_str => nuts_storable::ItemType::Generic, }
453            } else {
454                quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), }
455            }
456        }
457    });
458
459    let item_type_fn = quote! {
460        fn item_type(parent: &P, item: &str) -> nuts_storable::ItemType {
461            match item {
462                #(#item_type_arms)*
463                _ => { panic!("Unknown item: {}", item); }
464            }
465        }
466    };
467
468    let dims_arms = storable_fields.iter().map(|f| match f {
469        StorableField::Basic(field) => {
470            let name_str = field.name.to_string();
471            let dims = &field.dims;
472            quote! { #name_str => vec![#(#dims),*], }
473        }
474        StorableField::Inner(field) => {
475            let item_type = &field.item_type;
476            quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), }
477        }
478        StorableField::Generic(field) => {
479            let name_str = field.name.to_string();
480            let item_type = &field.item_type;
481            if field.is_option {
482                quote! { #name_str => vec![], }
483            } else {
484                quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), }
485            }
486        }
487    });
488
489    let dims_fn = quote! {
490        fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str> {
491            match item {
492                #(#dims_arms)*
493                _ => { panic!("Unknown item: {}", item); }
494            }
495        }
496    };
497
498    let get_all_exprs = storable_fields.iter().map(|f| match f {
499        StorableField::Basic(field) => {
500            let name = &field.name;
501            let name_str = name.to_string();
502            let value_expr = if field.is_option {
503                if field.is_vec {
504                    quote! { self.#name.as_ref().map(|v| nuts_storable::Value::from(v.clone())) }
505                } else {
506                    quote! { self.#name.map(nuts_storable::Value::from) }
507                }
508            } else {
509                quote! { Some(nuts_storable::Value::from(self.#name.clone())) }
510            };
511            quote! { result.push((#name_str, #value_expr)); }
512        }
513        StorableField::Inner(field) => {
514            let name = &field.name;
515            if field.is_option {
516                quote! {
517                    if let Some(inner) = &mut self.#name {
518                        result.extend(inner.get_all(parent));
519                    }
520                }
521            } else {
522                quote! { result.extend(self.#name.get_all(parent)); }
523            }
524        }
525        StorableField::Generic(field) => {
526            let name = &field.name;
527            if field.is_option {
528                quote! {
529                    if let Some(inner) = &mut self.#name {
530                        result.push((#name.to_string().as_str(), Some(nuts_storable::Value::Generic(Box::new(inner.clone())))));
531                    } else {
532                        result.push((#name.to_string().as_str(), None));
533                    }
534                }
535            } else {
536                quote! { result.extend(self.#name.get_all(parent)); }
537            }
538        }
539    });
540
541    let get_all_fn = quote! {
542        fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
543            let mut result = Vec::with_capacity(Self::names(parent).len());
544            #(#get_all_exprs)*
545            result
546        }
547    };
548
549    let r#gen = quote! {
550        impl #impl_generics nuts_storable::Storable<P> for #name #ty_generics #where_clause {
551            #names_fn
552            #item_type_fn
553            #dims_fn
554            #get_all_fn
555        }
556    };
557
558    r#gen.into()
559}