podstru_derive/
lib.rs

1#[macro_use]
2extern crate quote;
3#[macro_use]
4extern crate syn;
5
6extern crate proc_macro;
7
8use std::collections::HashMap;
9
10use proc_macro::TokenStream;
11use quote::ToTokens;
12use syn::{
13  punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed, Ident, Meta,
14  PathArguments, Type,
15};
16
17/// Allows derivation of a builder pattern on any struct
18///
19/// # Examples
20///
21/// ```rust
22/// use podstru_derive::Builder;
23/// use podstru_internal::Builder;
24///
25/// #[derive(Builder, Debug, PartialEq)]
26/// struct Data {
27///   field: usize
28/// }
29///
30/// fn main() {
31///   let data = Data::builder().with_field(42).build();
32///   assert_eq!(data, Data { field: 42 });
33/// }
34/// ```
35#[proc_macro_derive(Builder, attributes(builder))]
36pub fn builder(input: TokenStream) -> TokenStream {
37  // Parse the input tokens into a syntax tree
38  let input = parse_macro_input!(input as DeriveInput);
39
40  let in_name = input.ident;
41  let fname = format!("{}Builder", in_name);
42  let builder_ty = syn::Ident::new(&fname, in_name.span());
43
44  let orig_fields = match validate_struct(&input.data) {
45    Ok(fields) => fields,
46    Err(e) => return e.into(),
47  };
48
49  let field_accessors = orig_fields
50    .named
51    .iter()
52    .map(|f| {
53      let field_name = f.ident.clone().unwrap();
54      let field_ty = f.ty.clone();
55      let with_func_name = Ident::new(&format!("with_{}", field_name.clone()), f.span());
56      // let get_func_name = Ident::new(&format!("get_{}", field_name.clone()), f.span());
57      let set_func_name = Ident::new(&format!("set_{}", field_name.clone()), f.span());
58      let ref_func_name = Ident::new(&format!("{}", field_name.clone()), f.span());
59      let ref_mut_func_name = Ident::new(&format!("{}_mut", field_name.clone()), f.span());
60      let unwrapped = if let Type::Path(path) = &field_ty {
61        path
62          .path
63          .segments
64          .iter()
65          .find_map(|seg| match &seg.arguments {
66            PathArguments::AngleBracketed(args) => args.args.first(),
67            _ => None,
68          })
69      } else {
70        None
71      };
72      let field_ty = match unwrapped {
73        Some(ty) => quote! {#ty},
74        None => quote! {#field_ty},
75      };
76      quote! {
77        pub fn #with_func_name(mut self, v: #field_ty) -> Self {
78          self.#field_name = Some(v);
79          self
80        }
81
82        pub fn #ref_func_name(&self) -> Option<&#field_ty> {
83          self.#field_name.as_ref()
84        }
85
86        pub fn #ref_mut_func_name(&mut self) -> &mut Option<#field_ty> {
87          &mut self.#field_name
88        }
89
90        pub fn #set_func_name(&mut self, v: #field_ty) -> &mut Self {
91          self.#field_name = Some(v);
92          self
93        }
94      }
95    })
96    .collect::<proc_macro2::TokenStream>();
97  let mut field_values: HashMap<Ident, proc_macro2::TokenStream> = HashMap::from_iter(
98    orig_fields
99      .named
100      .iter()
101      .flat_map(|f| {
102        f.attrs.iter().find_map(move |attr| {
103          if attr.path().is_ident("builder") {
104            let nested = attr
105              .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
106              .unwrap();
107            for meta in nested {
108              match meta {
109                Meta::NameValue(meta_name_value) => {
110                  if meta_name_value.path.is_ident("default") {
111                    println!(
112                      "[{}] {} = {}",
113                      f.ident.clone().to_token_stream().to_string(),
114                      meta_name_value.path.to_token_stream().to_string(),
115                      meta_name_value.value.to_token_stream().to_string()
116                    );
117                    let field_name = f.ident.clone().unwrap();
118                    let field_value = meta_name_value.value.to_token_stream();
119                    if is_option(&f.ty) {
120                      return Some((
121                        field_name.clone(),
122                        quote! {
123                          Some(self.#field_name.unwrap_or_else(|| #field_value))
124                        },
125                      ));
126                    } else {
127                      return Some((
128                        field_name.clone(),
129                        quote! {
130                          self.#field_name.unwrap_or_else(|| #field_value)
131                        },
132                      ));
133                    }
134                  }
135                }
136                _ => {}
137              }
138            }
139          }
140          None
141        })
142      })
143      .collect::<Vec<_>>(),
144  );
145
146  for orig_field in &orig_fields.named {
147    let field_name = orig_field.ident.clone().unwrap();
148    if !field_values.contains_key(&field_name) {
149      field_values.insert(
150        field_name.clone(),
151        quote! {
152          self.#field_name.unwrap_or_default()
153        },
154      );
155    }
156  }
157
158  let new_fields = orig_fields
159    .named
160    .iter()
161    .map(|field| {
162      let field_name = field.ident.clone().unwrap();
163      let field_vis = field.vis.clone();
164      let field_ty = field.ty.clone();
165      if is_option(&field_ty) {
166        quote! {
167          #field_vis #field_name: #field_ty,
168        }
169      } else {
170        quote! {
171          #field_vis #field_name: Option<#field_ty>,
172        }
173      }
174    })
175    .collect::<proc_macro2::TokenStream>();
176
177  let builder_ctor: proc_macro2::TokenStream = orig_fields
178    .named
179    .iter()
180    .map(|field| {
181      let field_name = field.ident.clone().unwrap();
182      quote! {
183        #field_name: Default::default(),
184      }
185    })
186    .collect();
187  let builder_ctor: proc_macro2::TokenStream = quote! {
188    #builder_ty {
189      #builder_ctor
190    }
191  };
192
193  let orig_ctor: proc_macro2::TokenStream = orig_fields
194    .named
195    .iter()
196    .map(|field| {
197      let field_name = field.ident.clone().unwrap();
198      let field_value = &field_values[&field_name];
199      quote! {
200          #field_name: #field_value,
201      }
202    })
203    .collect();
204  let orig_ctor: proc_macro2::TokenStream = quote! {
205    #in_name {
206      #orig_ctor
207    }
208  };
209
210  // Build the output, possibly using quasi-quotation
211  let expanded = quote! {
212        struct #builder_ty {
213          #new_fields
214        }
215
216        impl Default for #builder_ty {
217          fn default() -> Self {
218            #builder_ctor
219          }
220        }
221
222        impl podstru_internal::Builder for #in_name {
223          type Target = #builder_ty;
224
225          fn builder() -> Self::Target {
226            Self::Target::default()
227          }
228        }
229
230        impl #builder_ty {
231          #field_accessors
232
233          pub fn build(mut self) -> #in_name {
234            #orig_ctor
235          }
236        }
237  };
238
239  // Hand the output tokens back to the compiler
240  TokenStream::from(expanded)
241}
242
243#[proc_macro_derive(Getters, attributes(getters))]
244pub fn getters(input: TokenStream) -> TokenStream {
245  // Parse the input tokens into a syntax tree
246  let input = parse_macro_input!(input as DeriveInput);
247
248  let in_name = input.ident;
249
250  let orig_fields = match validate_struct(&input.data) {
251    Ok(fields) => fields,
252    Err(e) => return e.into(),
253  };
254
255  let mut field_skips: Vec<Ident> = vec![];
256  for field in &orig_fields.named {
257    for attr in &field.attrs {
258      if attr.path().is_ident("getters") {
259        let nested = attr
260          .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
261          .unwrap();
262        for meta in nested {
263          match meta {
264            Meta::NameValue(meta_name_value) => {
265              if meta_name_value.path.is_ident("skip") {
266                return quote_spanned! {
267                  attr.path().span() => compile_error!("`skip` attribute on `Getters` derive macro cannot have a value")
268                }.into();
269              }
270            }
271            Meta::Path(path) => {
272              if path.is_ident("skip") {
273                let field_name = field.ident.clone().unwrap();
274                field_skips.push(field_name.clone());
275              }
276            }
277            Meta::List(list) => {
278              if list.path.is_ident("skip") {
279                let field_name = field.ident.clone().unwrap();
280                field_skips.push(field_name.clone());
281              }
282            }
283          }
284        }
285      }
286    }
287  }
288  let field_accessors = orig_fields
289    .named
290    .iter()
291    .map(|f| {
292      let field_name = f.ident.clone().unwrap();
293      let field_ty = f.ty.clone();
294      let ref_func_name = Ident::new(&format!("{}", field_name.clone()), f.span());
295      let unwrapped = if let Type::Path(path) = &field_ty {
296        path
297          .path
298          .segments
299          .iter()
300          .find_map(|seg| match &seg.arguments {
301            PathArguments::AngleBracketed(args) => args.args.first(),
302            _ => None,
303          })
304      } else {
305        None
306      };
307      let unwrapped_field_ty = match unwrapped {
308        Some(ty) => quote! {#ty},
309        None => quote! {#field_ty},
310      };
311
312      if !field_skips.contains(&field_name) {
313        if is_option(&field_ty) {
314          quote! {
315            pub fn #ref_func_name(&self) -> Option<&#unwrapped_field_ty> {
316              self.#field_name.as_ref()
317            }
318          }
319        } else {
320          quote! {
321            pub fn #ref_func_name(&self) -> &#field_ty {
322              &self.#field_name
323            }
324          }
325        }
326      } else {
327        quote! {}
328      }
329    })
330    .collect::<proc_macro2::TokenStream>();
331
332  // Build the output, possibly using quasi-quotation
333  let expanded = quote! {
334      impl #in_name {
335        #field_accessors
336      }
337  };
338
339  // Hand the output tokens back to the compiler
340  TokenStream::from(expanded)
341}
342
343#[proc_macro_derive(Setters, attributes(setters))]
344pub fn setters(input: TokenStream) -> TokenStream {
345  // Parse the input tokens into a syntax tree
346  let input = parse_macro_input!(input as DeriveInput);
347
348  let in_name = input.ident;
349
350  let orig_fields = match validate_struct(&input.data) {
351    Ok(fields) => fields,
352    Err(e) => return e.into(),
353  };
354
355  let mut field_skips: Vec<Ident> = vec![];
356  for field in &orig_fields.named {
357    for attr in &field.attrs {
358      if attr.path().is_ident("setters") {
359        let nested = attr
360          .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
361          .unwrap();
362        for meta in nested {
363          match meta {
364            Meta::NameValue(meta_name_value) => {
365              if meta_name_value.path.is_ident("skip") {
366                return quote_spanned! {
367                  attr.path().span() => compile_error!("`skip` attribute on `Setters` derive macro cannot have a value")
368                }.into();
369              }
370            }
371            Meta::Path(path) => {
372              if path.is_ident("skip") {
373                let field_name = field.ident.clone().unwrap();
374                field_skips.push(field_name.clone());
375              }
376            }
377            Meta::List(list) => {
378              if list.path.is_ident("skip") {
379                let field_name = field.ident.clone().unwrap();
380                field_skips.push(field_name.clone());
381              }
382            }
383          }
384        }
385      }
386    }
387  }
388
389  let field_accessors = orig_fields
390    .named
391    .iter()
392    .map(|f| {
393      let field_name = f.ident.clone().unwrap();
394      let field_ty = f.ty.clone();
395      let ref_mut_func_name = Ident::new(&format!("{}_mut", field_name.clone()), f.span());
396      let set_func_name = Ident::new(&format!("set_{}", field_name.clone()), f.span());
397      let with_func_name = Ident::new(&format!("with_{}", field_name.clone()), f.span());
398      if !field_skips.contains(&field_name) {
399        if is_option(&field_ty) {
400          quote! {
401            pub fn #ref_mut_func_name(&mut self) -> &mut #field_ty {
402              &mut self.#field_name
403            }
404
405            pub fn #set_func_name(&mut self, v: #field_ty) -> &mut Self {
406              self.#field_name = v;
407              self
408            }
409
410            pub fn #with_func_name(mut self, v: #field_ty) -> Self {
411              self.#field_name = v;
412              self
413            }
414          }
415        } else {
416          quote! {
417            pub fn #ref_mut_func_name(&self) -> &#field_ty {
418              &self.#field_name
419            }
420
421            pub fn #set_func_name(&mut self, v: #field_ty) -> &mut Self {
422              self.#field_name = v;
423              self
424            }
425
426            pub fn #with_func_name(mut self, v: #field_ty) -> Self {
427              self.#field_name = v;
428              self
429            }
430          }
431        }
432      } else {
433        quote! {}
434      }
435    })
436    .collect::<proc_macro2::TokenStream>();
437
438  // Build the output, possibly using quasi-quotation
439  let expanded = quote! {
440      impl #in_name {
441        #field_accessors
442      }
443  };
444
445  // Hand the output tokens back to the compiler
446  TokenStream::from(expanded)
447}
448
449#[proc_macro_derive(Fields, attributes(fields))]
450pub fn fields(input: TokenStream) -> TokenStream {
451  // Parse the input tokens into a syntax tree
452  let input = parse_macro_input!(input as DeriveInput);
453
454  let in_name = input.ident;
455
456  let orig_fields = match validate_struct(&input.data) {
457    Ok(fields) => fields,
458    Err(e) => return e.into(),
459  };
460
461  let mut field_skips: Vec<Ident> = vec![];
462  for field in &orig_fields.named {
463    for attr in &field.attrs {
464      if attr.path().is_ident("fields") {
465        let nested = attr
466          .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
467          .unwrap();
468        for meta in nested {
469          match meta {
470            Meta::NameValue(meta_name_value) => {
471              if meta_name_value.path.is_ident("skip") {
472                return quote_spanned! {
473                  attr.path().span() => compile_error!("`skip` attribute on `Fields` derive macro cannot have a value")
474                }.into();
475              }
476            }
477            Meta::Path(path) => {
478              if path.is_ident("skip") {
479                let field_name = field.ident.clone().unwrap();
480                field_skips.push(field_name.clone());
481              }
482            }
483            Meta::List(list) => {
484              if list.path.is_ident("skip") {
485                let field_name = field.ident.clone().unwrap();
486                field_skips.push(field_name.clone());
487              }
488            }
489          }
490        }
491      }
492    }
493  }
494  let field_accessors = orig_fields
495    .named
496    .iter()
497    .map(|f| {
498      let field_name = f.ident.clone().unwrap();
499      let field_ty = f.ty.clone();
500      let ref_func_name = Ident::new(&format!("{}", field_name.clone()), f.span());
501      let ref_mut_func_name = Ident::new(&format!("{}_mut", field_name.clone()), f.span());
502      let set_func_name = Ident::new(&format!("set_{}", field_name.clone()), f.span());
503      let with_func_name = Ident::new(&format!("with_{}", field_name.clone()), f.span());
504      let unwrapped = if let Type::Path(path) = &field_ty {
505        path
506          .path
507          .segments
508          .iter()
509          .find_map(|seg| match &seg.arguments {
510            PathArguments::AngleBracketed(args) => args.args.first(),
511            _ => None,
512          })
513      } else {
514        None
515      };
516      let unwrapped_field_ty = match unwrapped {
517        Some(ty) => quote! {#ty},
518        None => quote! {#field_ty},
519      };
520
521      if !field_skips.contains(&field_name) {
522        if is_option(&field_ty) {
523          quote! {
524            #[doc = concat!("Return the `", stringify!(#field_name), "` field as a mutable reference.")]
525            pub fn #ref_mut_func_name(&mut self) -> &mut #field_ty {
526              &mut self.#field_name
527            }
528
529            #[doc = concat!("Define the `", stringify!(#field_name), "` field.")]
530            pub fn #set_func_name(&mut self, v: #field_ty) -> &mut Self {
531              self.#field_name = v;
532              self
533            }
534
535            #[doc = concat!("Define the `", stringify!(#field_name), "` field.")]
536            pub fn #with_func_name(mut self, v: #field_ty) -> Self {
537              self.#field_name = v;
538              self
539            }
540
541            #[doc = concat!("Return the `", stringify!(#field_name), "` field.")]
542            pub fn #ref_func_name(&self) -> Option<&#unwrapped_field_ty> {
543              self.#field_name.as_ref()
544            }
545          }
546        } else {
547          quote! {
548            #[doc = concat!("Retrieve the `", stringify!(#field_name), "` field as a mutable reference.")]
549            pub fn #ref_mut_func_name(&mut self) -> &mut #field_ty {
550              &mut self.#field_name
551            }
552
553            #[doc = concat!("Define the `", stringify!(#field_name), "` field.")]
554            pub fn #set_func_name(&mut self, v: #field_ty) -> &mut Self {
555              self.#field_name = v;
556              self
557            }
558
559            #[doc = concat!("Define the `", stringify!(#field_name), "` field.")]
560            pub fn #with_func_name(mut self, v: #field_ty) -> Self {
561              self.#field_name = v;
562              self
563            }
564
565            #[doc = concat!("Retrieve the `", stringify!(#field_name), "` field as a reference.")]
566            pub fn #ref_func_name(&self) -> &#field_ty {
567              &self.#field_name
568            }
569          }
570        }
571      } else {
572        quote!{}
573      }
574    })
575    .collect::<proc_macro2::TokenStream>();
576
577  // Build the output, possibly using quasi-quotation
578  let expanded = quote! {
579      impl #in_name {
580        #field_accessors
581      }
582  };
583
584  // Hand the output tokens back to the compiler
585  TokenStream::from(expanded)
586}
587
588fn is_option(ty: &Type) -> bool {
589  match ty {
590    Type::Path(path) if path.qself.is_none() => path
591      .path
592      .segments
593      .iter()
594      .find(|seg| {
595        seg
596          .ident
597          .span()
598          .source_text()
599          .unwrap_or_default()
600          .eq("Option")
601      })
602      .is_some(),
603    _ => false,
604  }
605}
606fn validate_struct(data: &Data) -> Result<&FieldsNamed, proc_macro2::TokenStream> {
607  Ok(match data {
608    Data::Struct(s) => match &s.fields {
609      Fields::Named(fields) => fields,
610      Fields::Unit => {
611        return Err(quote_spanned! {
612          s.struct_token.span() =>
613            compile_error!("Builder pattern only available for named fields");
614        })
615      }
616
617      Fields::Unnamed(u) => {
618        return Err(quote_spanned! {
619          u.paren_token.span =>
620          compile_error!("Builder pattern only available for named fields");
621        })
622      }
623    },
624    Data::Enum(e) => {
625      return Err(quote_spanned! {
626          e.enum_token.span =>
627          compile_error!("Builder pattern only available for Struct");
628      })
629    }
630    Data::Union(u) => {
631      return Err(quote_spanned! {
632          u.union_token.span =>
633          compile_error!("Builder pattern only available for Struct");
634      })
635    }
636  })
637}
638
639#[proc_macro_derive(Ctor, attributes(ctor))]
640pub fn ctor(input: TokenStream) -> TokenStream {
641  // Parse the input tokens into a syntax tree
642  let input = parse_macro_input!(input as DeriveInput);
643
644  let in_ty = input.ident;
645
646  let orig_fields = match validate_struct(&input.data) {
647    Ok(fields) => fields,
648    Err(e) => return e.into(),
649  };
650
651  let field_skips: HashMap<Ident, proc_macro2::TokenStream> = HashMap::from_iter(
652    orig_fields
653      .named
654      .iter()
655      .flat_map(|f| {
656        f.attrs.iter().find_map(move |attr| {
657          if attr.path().is_ident("ctor") {
658            let nested = attr
659              .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
660              .unwrap();
661            let field_name = f.ident.clone().unwrap();
662            for meta in nested {
663              match meta {
664                Meta::NameValue(meta_name_value) => {
665                  if meta_name_value.path.is_ident("skip") {
666                    println!(
667                      "[{}] {} = {}",
668                      f.ident.clone().to_token_stream().to_string(),
669                      meta_name_value.path.to_token_stream().to_string(),
670                      meta_name_value.value.to_token_stream().to_string()
671                    );
672                    let field_value = meta_name_value.value.to_token_stream();
673                    return Some((
674                      field_name.clone(),
675                      quote! {
676                        #field_value
677                      },
678                    ));
679                  }
680                }
681                Meta::Path(path) => {
682                  if path.is_ident("skip") {
683                    return Some((field_name.clone(), quote_spanned! {
684                      path.span() => compile_error!("`skip` attribute on `Ctor` derive macro must have a value: the default value of the skipped field")
685                    }))
686                  }
687                }
688                Meta::List(list) => {
689                  if list.path.is_ident("skip") {
690                    return Some((field_name.clone(), quote_spanned! {
691                      list.span() => compile_error!("`skip` attribute on `Ctor` derive macro must have a value: the default value of the skipped field")
692                    }))
693                  }
694                }
695              }
696            }
697          }
698          None
699        })
700      })
701      .collect::<Vec<_>>(),
702  );
703
704  let orig_ctor_params: proc_macro2::TokenStream = orig_fields
705    .named
706    .iter()
707    .map(|field| {
708      let field_name = field.ident.as_ref().unwrap();
709      let field_ty = &field.ty;
710      if !field_skips.contains_key(field_name) {
711        quote! {
712            #field_name: #field_ty,
713        }
714      } else {
715        quote! {}
716      }
717    })
718    .collect();
719
720  let orig_ctor: proc_macro2::TokenStream = orig_fields
721    .named
722    .iter()
723    .map(|field| {
724      let field_name = field.ident.clone().unwrap();
725      if let Some(skipped_field) = field_skips.get(&field_name) {
726        quote! {#field_name: #skipped_field,}
727      } else {
728        quote! {#field_name,}
729      }
730    })
731    .collect();
732
733  // Build the output, possibly using quasi-quotation
734  let expanded = quote! {
735    impl #in_ty {
736      pub fn new(#orig_ctor_params) -> Self {
737        Self {
738          #orig_ctor
739        }
740      }
741    }
742  };
743
744  // Hand the output tokens back to the compiler
745  TokenStream::from(expanded)
746}