1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, quote};
5use syn::{
6 AngleBracketedGenericArguments, Data, DeriveInput, Fields, GenericParam, Ident, Lit, LitStr,
7 PathArguments, Token, Type, TypePath,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11};
12
13enum StorableAttr {
15 Item(Vec<LitStr>),
16 Flattened(),
17 Ignore(),
18}
19
20impl Parse for StorableAttr {
21 fn parse(input: ParseStream) -> syn::Result<Self> {
22 let metas = Punctuated::<syn::Meta, Token![,]>::parse_terminated(input)?;
23
24 for meta in metas {
25 match meta {
26 syn::Meta::List(list) => {
27 if list.path.is_ident("dims") {
28 return Ok(StorableAttr::Item(
29 list.nested
30 .into_iter()
31 .map(|e| match e {
32 syn::NestedMeta::Lit(Lit::Str(s)) => Ok(s),
33 _ => Err(syn::Error::new_spanned(e, "Expected string literal")),
34 })
35 .collect::<Result<Vec<_>, _>>()?,
36 ));
37 }
38 }
39 syn::Meta::Path(path) => {
40 if path.is_ident("flatten") {
41 return Ok(StorableAttr::Flattened());
42 }
43 if path.is_ident("ignore") {
44 return Ok(StorableAttr::Ignore());
45 }
46 }
47 _ => {
48 return Err(syn::Error::new_spanned(
49 meta,
50 "Unsupported storable attribute. Expected `dims(...)` or `flatten`",
51 ));
52 }
53 }
54 }
55
56 Ok(StorableAttr::Item(vec![]))
57 }
58}
59
60struct StorableBasicField {
61 name: Ident,
62 item_type: proc_macro2::TokenStream,
63 is_vec: bool,
64 is_option: bool,
65 dims: Vec<LitStr>,
66}
67
68struct StorableInnerField {
69 name: Ident,
70 item_type: proc_macro2::TokenStream,
71 is_option: bool,
72}
73
74enum StorableField {
75 Basic(StorableBasicField),
76 Inner(StorableInnerField),
77 Generic(StorableInnerField),
78}
79
80fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool {
82 if let Type::Path(type_path) = ty
83 && type_path.path.segments.len() == 1
84 {
85 let type_name = &type_path.path.segments.first().unwrap().ident;
86 return generics.params.iter().any(|param| {
87 if let GenericParam::Type(type_param) = param {
88 &type_param.ident == type_name
89 } else {
90 false
91 }
92 });
93 }
94 false
95}
96
97fn has_storable_bound(ty: &Ident, generics: &syn::Generics) -> bool {
99 for param in &generics.params {
100 if let GenericParam::Type(type_param) = param
101 && &type_param.ident == ty
102 {
103 for bound in &type_param.bounds {
104 if let syn::TypeParamBound::Trait(trait_bound) = bound {
105 let path = &trait_bound.path;
106 if path.segments.len() == 1
107 && path.segments.first().unwrap().ident == "Storable"
108 {
109 return true;
110 }
111 }
112 }
113 }
114 }
115 false
116}
117
118#[proc_macro_derive(Storable, attributes(storable))]
119pub fn storable_derive(input: TokenStream) -> TokenStream {
120 let ast = parse_macro_input!(input as DeriveInput);
121 let name = &ast.ident;
122 let generics = &ast.generics;
123
124 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
125 let impl_generics = if generics.params.is_empty() {
126 quote! { <P: nuts_storable::HasDims> }
127 } else {
128 quote! { #impl_generics }
129 };
130
131 let fields = if let Data::Struct(s) = ast.data {
132 if let Fields::Named(fields) = s.fields {
133 fields.named
134 } else {
135 panic!("Storable can only be derived for structs with named fields");
136 }
137 } else {
138 panic!("Storable can only be derived on structs");
139 };
140
141 let mut storable_fields = Vec::new();
142 for field in fields {
143 let field_name = field.ident.clone().unwrap();
144 let ty = &field.ty;
145 let ty_str = quote!(#ty).to_string();
146
147 let attr = field
148 .attrs
149 .iter()
150 .find(|a| a.path.is_ident("storable"))
151 .map(|a| a.parse_args::<StorableAttr>().unwrap());
152
153 if let Some(StorableAttr::Ignore()) = attr {
154 continue; }
156
157 let attr = attr.unwrap_or(StorableAttr::Item(vec![]));
158
159 if let StorableAttr::Flattened() = attr {
160 let path = if let Type::Path(TypePath { path: p, qself: _ }) = ty {
161 p
162 } else {
163 panic!(
164 "Unsupported field type with flattened attribute: {}",
165 ty_str
166 );
167 };
168 let item = if path.segments.first().unwrap().ident == "Option" {
169 if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
170 args, ..
171 }) = &path.segments.first().unwrap().arguments
172 {
173 if let Some(arg) = args.first() {
174 let inner_type = quote!(#arg);
175 StorableField::Inner(StorableInnerField {
176 name: field_name.clone(),
177 item_type: inner_type,
178 is_option: true,
179 })
180 } else {
181 panic!("Invalid Option type for flattened field");
182 }
183 } else {
184 panic!("Invalid Option type for flattened field");
185 }
186 } else {
187 StorableField::Inner(StorableInnerField {
188 name: field_name.clone(),
189 item_type: path.into_token_stream(),
190 is_option: false,
191 })
192 };
193 storable_fields.push(item);
194 continue;
195 }
196
197 let dims = if let StorableAttr::Item(dims) = attr {
198 dims
199 } else {
200 vec![]
201 };
202
203 if let Type::Path(type_path) = ty {
205 if type_path.path.segments.len() == 1 {
206 let type_name = &type_path.path.segments.first().unwrap().ident;
207
208 if is_generic_param(ty, generics) && has_storable_bound(type_name, generics) {
210 storable_fields.push(StorableField::Generic(StorableInnerField {
211 name: field_name,
212 item_type: quote!(#type_name),
213 is_option: false,
214 }));
215 continue;
216 }
217
218 if type_name == "Option" {
220 if let PathArguments::AngleBracketed(args) =
221 &type_path.path.segments.first().unwrap().arguments
222 {
223 if let Some(arg) = args.args.first() {
224 if let syn::GenericArgument::Type(inner_ty) = arg {
225 if let Type::Path(inner_path) = inner_ty {
226 if inner_path.path.segments.len() == 1 {
227 let inner_name =
228 &inner_path.path.segments.first().unwrap().ident;
229 if is_generic_param(inner_ty, generics)
230 && has_storable_bound(inner_name, generics)
231 {
232 storable_fields.push(StorableField::Generic(
233 StorableInnerField {
234 name: field_name,
235 item_type: quote!(#inner_name),
236 is_option: true,
237 },
238 ));
239 continue;
240 }
241 }
242 }
243 }
244 }
245 }
246 }
247 }
248 }
249
250 let item = match ty_str.as_str() {
251 "u64" => StorableField::Basic(StorableBasicField {
252 name: field_name.clone(),
253 item_type: quote! { nuts_storable::ItemType::U64 },
254 is_vec: false,
255 is_option: false,
256 dims,
257 }),
258 "i64" => StorableField::Basic(StorableBasicField {
259 name: field_name.clone(),
260 item_type: quote! { nuts_storable::ItemType::I64 },
261 is_vec: false,
262 is_option: false,
263 dims,
264 }),
265 "f64" => StorableField::Basic(StorableBasicField {
266 name: field_name.clone(),
267 item_type: quote! { nuts_storable::ItemType::F64 },
268 is_vec: false,
269 is_option: false,
270 dims,
271 }),
272 "f32" => StorableField::Basic(StorableBasicField {
273 name: field_name.clone(),
274 item_type: quote! { nuts_storable::ItemType::F32 },
275 is_vec: false,
276 is_option: false,
277 dims,
278 }),
279 "bool" => StorableField::Basic(StorableBasicField {
280 name: field_name.clone(),
281 item_type: quote! { nuts_storable::ItemType::Bool },
282 is_vec: false,
283 is_option: false,
284 dims,
285 }),
286 "Option < u64 >" => StorableField::Basic(StorableBasicField {
287 name: field_name.clone(),
288 item_type: quote! { nuts_storable::ItemType::U64 },
289 is_vec: false,
290 is_option: true,
291 dims,
292 }),
293 "Option < i64 >" => StorableField::Basic(StorableBasicField {
294 name: field_name.clone(),
295 item_type: quote! { nuts_storable::ItemType::I64 },
296 is_vec: false,
297 is_option: true,
298 dims,
299 }),
300 "Option < f64 >" => StorableField::Basic(StorableBasicField {
301 name: field_name.clone(),
302 item_type: quote! { nuts_storable::ItemType::F64 },
303 is_vec: false,
304 is_option: true,
305 dims,
306 }),
307 "Option < f32 >" => StorableField::Basic(StorableBasicField {
308 name: field_name.clone(),
309 item_type: quote! { nuts_storable::ItemType::F32 },
310 is_vec: false,
311 is_option: true,
312 dims,
313 }),
314 "Option < bool >" => StorableField::Basic(StorableBasicField {
315 name: field_name.clone(),
316 item_type: quote! { nuts_storable::ItemType::Bool },
317 is_vec: false,
318 is_option: true,
319 dims,
320 }),
321 "Vec < u64 >" => StorableField::Basic(StorableBasicField {
322 name: field_name.clone(),
323 item_type: quote! { nuts_storable::ItemType::U64 },
324 is_vec: true,
325 is_option: false,
326 dims,
327 }),
328 "Vec < i64 >" => StorableField::Basic(StorableBasicField {
329 name: field_name.clone(),
330 item_type: quote! { nuts_storable::ItemType::I64 },
331 is_vec: true,
332 is_option: false,
333 dims,
334 }),
335 "Vec < f64 >" => StorableField::Basic(StorableBasicField {
336 name: field_name.clone(),
337 item_type: quote! { nuts_storable::ItemType::F64 },
338 is_vec: true,
339 is_option: false,
340 dims,
341 }),
342 "Vec < f32 >" => StorableField::Basic(StorableBasicField {
343 name: field_name.clone(),
344 item_type: quote! { nuts_storable::ItemType::F32 },
345 is_vec: true,
346 is_option: false,
347 dims,
348 }),
349 "Vec < bool >" => StorableField::Basic(StorableBasicField {
350 name: field_name.clone(),
351 item_type: quote! { nuts_storable::ItemType::Bool },
352 is_vec: true,
353 is_option: false,
354 dims,
355 }),
356 "Option < Vec < u64 > >" => StorableField::Basic(StorableBasicField {
357 name: field_name.clone(),
358 item_type: quote! { nuts_storable::ItemType::U64 },
359 is_vec: true,
360 is_option: true,
361 dims,
362 }),
363 "Option < Vec < i64 > >" => StorableField::Basic(StorableBasicField {
364 name: field_name.clone(),
365 item_type: quote! { nuts_storable::ItemType::I64 },
366 is_vec: true,
367 is_option: true,
368 dims,
369 }),
370 "Option < Vec < f64 > >" => StorableField::Basic(StorableBasicField {
371 name: field_name.clone(),
372 item_type: quote! { nuts_storable::ItemType::F64 },
373 is_vec: true,
374 is_option: true,
375 dims,
376 }),
377 "Option < Vec < f32 > >" => StorableField::Basic(StorableBasicField {
378 name: field_name.clone(),
379 item_type: quote! { nuts_storable::ItemType::F32 },
380 is_vec: true,
381 is_option: true,
382 dims,
383 }),
384 "Option< Vec < bool > >" => StorableField::Basic(StorableBasicField {
385 name: field_name.clone(),
386 item_type: quote! { nuts_storable::ItemType::Bool },
387 is_vec: true,
388 is_option: true,
389 dims,
390 }),
391 _ => {
392 if let Type::Path(type_path) = ty {
394 let type_token = quote!(#type_path);
396 storable_fields.push(StorableField::Inner(StorableInnerField {
397 name: field_name.clone(),
398 item_type: type_token,
399 is_option: false,
400 }));
401 continue;
402 } else {
403 panic!("Unsupported field type: {}", ty_str);
404 }
405 }
406 };
407 storable_fields.push(item);
408 }
409
410 let names_exprs = storable_fields.iter().map(|f| match f {
411 StorableField::Basic(field) => {
412 let name = field.name.to_string();
413 quote! { vec![#name] }
414 }
415 StorableField::Inner(field) => {
416 let item_type = &field.item_type;
417 quote! { #item_type::names(parent) }
418 }
419 StorableField::Generic(field) => {
420 let name = field.name.to_string();
421 if field.is_option {
422 quote! { vec![#name] }
423 } else {
424 let item_type = &field.item_type;
425 quote! { #item_type::names(parent) }
426 }
427 }
428 });
429
430 let names_fn = quote! {
431 fn names(parent: &P) -> Vec<&str> {
432 let mut names = Vec::new();
433 #(names.extend(#names_exprs);)*
434 names
435 }
436 };
437
438 let item_type_arms = storable_fields.iter().map(|f| match f {
439 StorableField::Basic(field) => {
440 let name_str = field.name.to_string();
441 let item_type = &field.item_type;
442 quote! { #name_str => #item_type, }
443 }
444 StorableField::Inner(field) => {
445 let item_type = &field.item_type;
446 quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), }
447 }
448 StorableField::Generic(field) => {
449 let name_str = field.name.to_string();
450 let item_type = &field.item_type;
451 if field.is_option {
452 quote! { #name_str => nuts_storable::ItemType::Generic, }
453 } else {
454 quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), }
455 }
456 }
457 });
458
459 let item_type_fn = quote! {
460 fn item_type(parent: &P, item: &str) -> nuts_storable::ItemType {
461 match item {
462 #(#item_type_arms)*
463 _ => { panic!("Unknown item: {}", item); }
464 }
465 }
466 };
467
468 let dims_arms = storable_fields.iter().map(|f| match f {
469 StorableField::Basic(field) => {
470 let name_str = field.name.to_string();
471 let dims = &field.dims;
472 quote! { #name_str => vec![#(#dims),*], }
473 }
474 StorableField::Inner(field) => {
475 let item_type = &field.item_type;
476 quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), }
477 }
478 StorableField::Generic(field) => {
479 let name_str = field.name.to_string();
480 let item_type = &field.item_type;
481 if field.is_option {
482 quote! { #name_str => vec![], }
483 } else {
484 quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), }
485 }
486 }
487 });
488
489 let dims_fn = quote! {
490 fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str> {
491 match item {
492 #(#dims_arms)*
493 _ => { panic!("Unknown item: {}", item); }
494 }
495 }
496 };
497
498 let get_all_exprs = storable_fields.iter().map(|f| match f {
499 StorableField::Basic(field) => {
500 let name = &field.name;
501 let name_str = name.to_string();
502 let value_expr = if field.is_option {
503 if field.is_vec {
504 quote! { self.#name.as_ref().map(|v| nuts_storable::Value::from(v.clone())) }
505 } else {
506 quote! { self.#name.map(nuts_storable::Value::from) }
507 }
508 } else {
509 quote! { Some(nuts_storable::Value::from(self.#name.clone())) }
510 };
511 quote! { result.push((#name_str, #value_expr)); }
512 }
513 StorableField::Inner(field) => {
514 let name = &field.name;
515 if field.is_option {
516 quote! {
517 if let Some(inner) = &mut self.#name {
518 result.extend(inner.get_all(parent));
519 }
520 }
521 } else {
522 quote! { result.extend(self.#name.get_all(parent)); }
523 }
524 }
525 StorableField::Generic(field) => {
526 let name = &field.name;
527 if field.is_option {
528 quote! {
529 if let Some(inner) = &mut self.#name {
530 result.push((#name.to_string().as_str(), Some(nuts_storable::Value::Generic(Box::new(inner.clone())))));
531 } else {
532 result.push((#name.to_string().as_str(), None));
533 }
534 }
535 } else {
536 quote! { result.extend(self.#name.get_all(parent)); }
537 }
538 }
539 });
540
541 let get_all_fn = quote! {
542 fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
543 let mut result = Vec::with_capacity(Self::names(parent).len());
544 #(#get_all_exprs)*
545 result
546 }
547 };
548
549 let r#gen = quote! {
550 impl #impl_generics nuts_storable::Storable<P> for #name #ty_generics #where_clause {
551 #names_fn
552 #item_type_fn
553 #dims_fn
554 #get_all_fn
555 }
556 };
557
558 r#gen.into()
559}