// builders 0.5.0
//
// Rust macros for building structs
// Documentation
use std::borrow::Cow;

use proc_macro2::{TokenStream, TokenTree};
use quote::quote;
use syn::{parse_macro_input, parse_quote, Data, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, Ident, Lit, Type};

use crate::util::{get_inner_ty, get_stripped_generics};

/// Looks for the flag `name` inside the `#[builder(...)]` attributes of `f`.
///
/// Only list-style attributes whose path is exactly `builder` are inspected,
/// and only the first entry inside the parentheses is compared against `name`.
/// The flag counts as present for the bare form `#[builder(name)]` and for the
/// explicit form `#[builder(name = true)]`. Any other `name = <value>` form is
/// a name/value pair rather than a set flag and is not reported here, so that
/// `#[builder(name = false)]` is no longer mistaken for an enabled flag.
///
/// Returns `Some(())` when the flag is present and `None` otherwise.
///
/// # Panics
///
/// Panics if the first token of a `builder` attribute is not an identifier.
fn try_find_attr(f: &[syn::Attribute], name: &str) -> Option<()> {
    for attr in f {
        let Ok(list) = attr.meta.require_list() else { continue };
        let segments = &list.path.segments;
        if segments.len() != 1 || segments[0].ident != "builder" { continue }
        let mut tokens = list.tokens.clone().into_iter();
        // An empty `#[builder()]` carries no flag at all; skip it instead of panicking.
        let Some(first) = tokens.next() else { continue };
        match first {
            TokenTree::Ident(ref i) => {
                if i != name { continue }
                let flag_set = match tokens.next() {
                    // `name = value`: only an explicit `true` counts as a set flag.
                    Some(TokenTree::Punct(ref p)) if p.as_char() == '=' => {
                        matches!(tokens.next(), Some(TokenTree::Ident(ref v)) if v == "true")
                    }
                    // Bare `name`, or `name` followed by anything other than `=`.
                    _ => true,
                };
                if flag_set {
                    return Some(());
                }
            }
            tt => panic!("Expected \"{name}\", found: {tt}"),
        }
    }
    None
}

/// Finds `#[builder(name = value)]` among `f` and returns the value token.
///
/// Only list-style attributes whose path is exactly `builder` are inspected;
/// every other attribute is skipped, so unrelated list attributes placed
/// before the `builder` one do not hide it. Within a `builder` attribute only
/// the first entry is looked at, and the value is the single token tree that
/// follows the `=` sign.
///
/// # Errors
///
/// Returns an empty token stream as the error when no `builder` attribute
/// starts with `name`.
///
/// # Panics
///
/// Panics if the first token of a `builder` attribute is not an identifier,
/// or if the matching name is not followed by `=` and a value.
fn find_attr_nameval(f: &[syn::Attribute], name: &str) -> Result<TokenStream,proc_macro2::TokenStream> {
    for attr in f {
        let Ok(list) = attr.meta.require_list() else { continue };
        let segments = &list.path.segments;
        // Foreign list attributes must not end the search: keep scanning.
        if segments.len() != 1 || segments[0].ident != "builder" { continue }
        let mut tokens = list.tokens.clone().into_iter();
        // An empty `#[builder()]` cannot hold the pair we are looking for.
        let Some(first) = tokens.next() else { continue };
        match first {
            TokenTree::Ident(ref i) => if i != name {
                continue
            },
            tt => panic!("Expected \"{name}\", found: {tt}"),
        }
        match tokens.next() {
            Some(TokenTree::Punct(ref p)) => assert_eq!(p.as_char(), '='),
            Some(tt) => panic!("Expected '=', found: {tt}"),
            None => panic!("Expected '=' after \"{name}\""),
        }
        let Some(value) = tokens.next() else {
            panic!("Expected a value after \"{name} =\"");
        };
        return Ok(value.into());
    }
    Err(TokenStream::new())
}

/// Reads the value of `#[builder(name = value)]` and parses it as a literal.
///
/// The lookup itself is delegated to `find_attr_nameval`, whose error is
/// passed through unchanged. The parsing goes through `parse_quote!`, which
/// panics when the value is not a literal.
fn find_attr_nameval_lit(f: &[syn::Attribute], name: &str) -> Result<Lit,proc_macro2::TokenStream> {
    find_attr_nameval(f, name).map(|value| -> Lit { parse_quote!(#value) })
}

/// Reads the value of `#[builder(name = value)]` and parses it as an expression.
///
/// The lookup itself is delegated to `find_attr_nameval`, whose error is
/// passed through unchanged. The parsing goes through `parse_quote!`, which
/// panics when the value is not a valid expression.
fn find_attr_nameval_expr(f: &[syn::Attribute], name: &str) -> Result<Expr,proc_macro2::TokenStream> {
    find_attr_nameval(f, name).map(|value| -> Expr { parse_quote!(#value) })
}

/// Evaluates the boolean `builder` attribute called `name`.
///
/// Yields `Ok(true)` straight away when `try_find_attr` reports the name.
/// Otherwise the value is fetched through `find_attr_nameval` and must be a
/// boolean literal.
///
/// # Errors
///
/// Propagates the error of `find_attr_nameval`, or returns a `compile_error!`
/// invocation when the value is a literal of another kind.
fn find_attr_nameval_bool(f: &[syn::Attribute], name: &str) -> Result<bool,proc_macro2::TokenStream> {
    if try_find_attr(f, name).is_some() {
        return Ok(true);
    }
    let value = find_attr_nameval(f, name)?;
    let parsed: Lit = parse_quote!(#value);
    match parsed {
        Lit::Bool(flag) => Ok(flag.value),
        _ => Err(quote! { compile_error!("Expected boolean attribute"); }),
    }
}

/// Tells whether the `optional` builder attribute of `f` evaluates to `true`.
///
/// A lookup error counts as `false`.
fn is_optional(f: &syn::Field) -> bool {
    matches!(find_attr_nameval_bool(&f.attrs, "optional"), Ok(true))
}

/// Tells whether the `disabled` builder attribute of `f` evaluates to `true`.
///
/// A lookup error counts as `false`.
fn is_disabled(f: &syn::Field) -> bool {
    matches!(find_attr_nameval_bool(&f.attrs, "disabled"), Ok(true))
}

/// Tells whether `f` carries a `vec = "..."` or `map = "..."` builder attribute
/// whose value is a string literal.
///
/// `vec` is checked first; `map` is only consulted when `vec` does not match.
fn is_each(f: &syn::Field) -> bool {
    ["vec", "map"]
        .iter()
        .any(|key| matches!(find_attr_nameval_lit(&f.attrs, key), Ok(Lit::Str(_))))
}

/// Expands a derive input into a `<Name>Builder` type and its supporting impls.
///
/// For a struct `Name` with named fields the generated code contains:
///
/// * `struct NameBuilder` holding every field that is not `disabled`.
///   `optional`, `vec` and `map` fields keep their own type; all others are
///   wrapped in `Option`.
/// * For each of those fields, a consuming setter (`fn field(mut self, ..) -> Self`)
///   and a borrowing one (`fn set_field(&mut self, ..) -> &mut Self`). For
///   `vec = "item"` / `map = "item"` fields the setters are named after the
///   string and push / insert a single element.
/// * `fn build(self)`, returning `Name` when the struct is marked
///   `#[builder(infallible)]` and `Result<Name, &'static str>` otherwise.
/// * `impl Clone for NameBuilder` when the struct is marked `#[builder(clone)]`.
/// * `fn builder()` on `Name`, returning an empty builder.
///
/// Inputs that are not structs with named fields, and infallible builders with
/// a field that is neither optional, element-wise nor defaulted, expand to a
/// `compile_error!` invocation.
///
/// # Panics
///
/// Panics when a `disabled` field has no `builder(def = ...)` attribute, or
/// when `get_inner_ty` cannot provide the generic arguments required by a
/// `vec` / `map` field.
pub (crate) fn builder_derive_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let ident = &ast.ident;
    let builder_name = format!("{ident}Builder");
    let builder_name = Ident::new(&builder_name, ident.span());

    // Report unsupported inputs as a regular compiler error rather than a macro panic.
    let Data::Struct(DataStruct {
        fields: Fields::Named(FieldsNamed { named: ref fields, .. }),
        ..
    }) = ast.data else {
        return quote! {
            compile_error!("Builder can only be derived for structs with named fields");
        }.into();
    };

    // Field declarations of the builder struct.
    let builder_fields = fields.iter().filter(|f| !is_disabled(f)).map(|f| {
        let name = &f.ident;
        let ty = &f.ty;
        if is_optional(f) || is_each(f) {
            quote!( #name: #ty )
        } else {
            quote!( #name: ::core::option::Option<#ty> )
        }
    });

    // Setter pairs of the builder.
    let builder_methods = fields.iter().filter(|f| !is_disabled(f)).map(|f| {
        let field_name = &f.ident;
        match find_attr_nameval_lit(&f.attrs, "vec") {
            Ok(Lit::Str(lit)) => {
                let arg = Ident::new(&lit.value(), lit.span());
                let set_arg = Ident::new(&format!("set_{}", &lit.value()), lit.span());
                let tys = get_inner_ty(&f.ty).expect("Expected at least one generic argument");
                let inner = tys.first();
                quote! {
                    pub fn #arg(mut self, #arg: impl ::core::convert::Into<#inner>) -> Self {
                        self.#field_name.push(#arg.into());
                        self
                    }

                    pub fn #set_arg(&mut self, #arg: impl ::core::convert::Into<#inner>) -> &mut Self {
                        self.#field_name.push(#arg.into());
                        self
                    }
                }
            },
            Ok(_) => {
                quote! {
                    compile_error!("Expected \"vec\" to be a string");
                }
            },
            Err(_) => {
                match find_attr_nameval_lit(&f.attrs, "map") {
                    Ok(Lit::Str(lit)) => {
                        let arg = Ident::new(&lit.value(), lit.span());
                        let set_arg = Ident::new(&format!("set_{}", &lit.value()), lit.span());
                        let [first, second] = get_inner_ty(&f.ty).expect("Expected at least one generic argument")[..2] else { panic!() };
                        quote! {
                            pub fn #arg(mut self, #arg: impl ::core::convert::Into<#first>, val: impl ::core::convert::Into<#second>) -> Self {
                                self.#field_name.insert(#arg.into(),val.into());
                                self
                            }

                            pub fn #set_arg(&mut self, #arg: impl ::core::convert::Into<#first>, val: impl ::core::convert::Into<#second>) -> &mut Self {
                                self.#field_name.insert(#arg.into(),val.into());
                                self
                            }
                        }
                    },
                    Ok(_) => {
                        quote! {
                            compile_error!("Expected \"map\" to be a string");
                        }
                    },
                    Err(err) => {
                        // Optional fields take the first generic argument of their
                        // declared type when `get_inner_ty` yields one.
                        let ty: Cow<Type> = if is_optional(f) {
                            let t = get_inner_ty(&f.ty).map(|l| l[0]).unwrap_or(&f.ty);
                            let t: Type = parse_quote!(#t);
                            Cow::Owned(t)
                        } else {
                            Cow::Borrowed(&f.ty)
                        };
                        let fname = field_name.as_ref().unwrap();
                        let set_name = Ident::new(&format!("set_{}", &fname.to_string()), fname.span());
                        quote! {
                            #err
                            pub fn #field_name(mut self, #field_name: impl ::core::convert::Into<#ty>) -> Self {
                                self.#field_name = ::core::option::Option::Some(#field_name.into());
                                self
                            }

                            pub fn #set_name(&mut self, #field_name: impl ::core::convert::Into<#ty>) -> &mut Self {
                                self.#field_name = ::core::option::Option::Some(#field_name.into());
                                self
                            }
                        }
                    }
                }
            }
        }
    });

    // Initial values used by `builder()`.
    let empty_fields = fields.iter().filter(|f| !is_disabled(f)).map(|f| {
        let field_name = &f.ident;
        if is_each(f) {
            return quote! { #field_name : ::core::default::Default::default() };
        }
        quote! { #field_name : ::core::option::Option::None }
    });
    let generics = &ast.generics;
    let wher = &generics.where_clause;
    let stripped_generics = get_stripped_generics(generics, false);
    let generics_no_defaults = get_stripped_generics(generics, true);

    let vis = &ast.vis;

    let build_fields = fields.iter().map(|f| {
        let field_name = &f.ident;
        quote! { #field_name }
    });


    let is_infallible = find_attr_nameval_bool(&ast.attrs, "infallible").is_ok_and(|v| v);
    if is_infallible {
        for field in fields.iter() {
            if !is_optional(field) && !is_each(field) && find_attr_nameval_expr(&field.attrs, "def").is_err() {
                return quote! {compile_error!("Infallible builder must have all it's fields \
                    marked as either optional, or provide with a default value for them"); }.into()
            }
        }
    }

    // One `let` per struct field, evaluated inside `build()`.
    let build_fields_let = fields.iter().map(|f| {
        let field_name = &f.ident;
        let expr =
            if is_disabled(f) {
                let Ok(lit) = find_attr_nameval_expr(&f.attrs, "def") else {
                    panic!("Disabled field must have a builder(def = ...) attribute");
                };
                quote! { #lit }
            }
            else if is_optional(f) {
                quote! { self.#field_name }
            } else if let Ok(lit) = find_attr_nameval_expr(&f.attrs, "def"){
                quote! { self.#field_name.unwrap_or_else(|| #lit .into()) }
            }  else if is_each(f) {
                quote! { self. #field_name }
            } else {
                // `concat!` + `stringify!` yields "<field> is not set" as a `&'static str`.
                quote! { self.#field_name.ok_or(concat!(stringify!(#field_name), " is not set"))? }
            };
        quote! { let #field_name = #expr ; }
    });

    let must_clone = find_attr_nameval_bool(&ast.attrs, "clone").is_ok_and(|v| v);

    let mut clone_impl = quote!{};

    if must_clone {
        // Re-emit the user's predicates one by one so that each is followed by a
        // comma, whether or not the original where clause had a trailing one.
        let existing_predicates = wher.iter().flat_map(|w| w.predicates.iter());
        let where_clone = ast.generics.type_params().map(|tp| {
            let id = &tp.ident;
            quote! { #id : ::core::clone::Clone }
        });
        let fields_clone = fields.iter().filter(|f| !is_disabled(f)).map(|field| {
            let name = &field.ident;
            quote! { #name : self . #name .clone() }
        });

        // Uses the same impl generics as the inherent impls below.
        clone_impl = quote! {
            impl #generics_no_defaults ::core::clone::Clone for #builder_name #stripped_generics
            where
                #( #existing_predicates, )*
                #( #where_clone, )*
            {
                fn clone(&self) -> Self {
                    Self {
                        #(#fields_clone ,)*
                    }
                }
            }
        };
    }

    let build_fn_return = if is_infallible {
        quote! { #ident #stripped_generics }
    } else {
        quote! { ::core::result::Result<#ident #stripped_generics, &'static str>  }
    };

    let mut build_self_struct = quote! {
        #ident {
            #( #build_fields ,)*
        }
    };

    if !is_infallible {
        build_self_struct = quote! { ::core::result::Result::Ok( #build_self_struct ) };
    }

    quote! {
        #vis struct #builder_name #generics {
            #( #builder_fields ,)*
        }

        #clone_impl

        impl #generics_no_defaults #builder_name #stripped_generics #wher {
            #vis fn build(self) -> #build_fn_return {
                #( #build_fields_let )*
                #build_self_struct
            }

            #( #builder_methods )*
        }

        impl #generics_no_defaults #ident #stripped_generics #wher {
            #vis fn builder() -> #builder_name #stripped_generics {
                #builder_name {
                    #( #empty_fields ,)*
                }
            }
        }
    }.into()
}