Skip to main content

finit_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, parse_quote};
5
6mod comparisons;
7mod operations;
8
9#[proc_macro_derive(Set)]
10pub fn set_derive(input: TokenStream) -> TokenStream {
11    let input = proc_macro2::TokenStream::from(input);
12
13    let input: DeriveInput = syn::parse2(input).unwrap();
14
15    let crate_name: syn::Path = parse_quote!(::finit);
16    let struct_name = &input.ident;
17
18    let Data::Struct(struct_data) = &input.data else {
19        unimplemented!("Currently, there is only support for structs.");
20    };
21
22    let is_empty_body: Vec<proc_macro2::TokenStream> = match &struct_data.fields {
23        syn::Fields::Named(fields) => fields
24            .named
25            .iter()
26            .map(|field| {
27                let field_name = field.ident.as_ref().expect("Struct is named.");
28                quote! {
29                    #crate_name::Set::is_empty(&self.#field_name)
30                }
31            })
32            .collect(),
33        syn::Fields::Unnamed(fields) => fields
34            .unnamed
35            .iter()
36            .enumerate()
37            .map(|(i, _field)| {
38                quote! {
39                    #crate_name::Set::is_empty(&self.#i)
40                }
41            })
42            .collect(),
43        syn::Fields::Unit => {
44            return quote! {
45              compile_error!("Unit structs can't be a set.")
46            }
47            .into();
48        }
49    };
50
51    let is_empty_body = is_empty_body
52        .into_iter()
53        .reduce(|acc, value| quote! { #acc & #value})
54        .expect("No unit structs means there must be atleast 1 field.");
55
56    let empty_body: Vec<proc_macro2::TokenStream> = match &struct_data.fields {
57        syn::Fields::Named(fields) => fields
58            .named
59            .iter()
60            .map(|field| {
61                let field_name = field.ident.as_ref().expect("Struct is named.");
62                let field_type = &field.ty;
63                quote! {
64                    #field_name: <#field_type as #crate_name::Set>::empty(),
65                }
66            })
67            .collect(),
68        syn::Fields::Unnamed(fields) => fields
69            .unnamed
70            .iter()
71            .enumerate()
72            .map(|(i, _field)| {
73                quote! {
74                    #crate_name::Set::is_empty(&self.#i)
75                }
76            })
77            .collect(),
78        syn::Fields::Unit => unreachable!("Already returned error earlier."),
79    };
80
81    let empty_body = empty_body
82        .into_iter()
83        .reduce(|acc, value| quote! { #acc #value})
84        .expect("No unit structs means there must be atleast 1 field.");
85
86    quote! {
87        impl #crate_name::Set for #struct_name {
88            type Empty = Self;
89
90            fn is_empty(&self) -> bool {
91                #is_empty_body
92            }
93
94            fn empty() -> Self {
95                Self {
96                    #empty_body
97                }
98            }
99        }
100    }
101    .into()
102}
103
104#[proc_macro_derive(UnionAssign)]
105pub fn union_assign_derive(input: TokenStream) -> TokenStream {
106    let trait_path: syn::Path = parse_quote!(::finit::operations::UnionAssign);
107    let fn_name = format_ident!("union_assign");
108    operations::operation_assign_derive(input, &trait_path, &fn_name)
109}
110
111#[proc_macro_derive(DifferenceAssign)]
112pub fn difference_assign_derive(input: TokenStream) -> TokenStream {
113    let trait_path: syn::Path = parse_quote!(::finit::operations::DifferenceAssign);
114    let fn_name = format_ident!("difference_assign");
115    operations::operation_assign_derive(input, &trait_path, &fn_name)
116}
117
118#[proc_macro_derive(IntersectionAssign)]
119pub fn intersection_assign_derive(input: TokenStream) -> TokenStream {
120    let trait_path: syn::Path = parse_quote!(::finit::operations::IntersectionAssign);
121    let fn_name = format_ident!("intersection_assign");
122    operations::operation_assign_derive(input, &trait_path, &fn_name)
123}
124
125#[proc_macro_derive(DisjunctiveUnionAssign)]
126pub fn disjunctive_union_assign_derive(input: TokenStream) -> TokenStream {
127    let trait_path: syn::Path = parse_quote!(::finit::operations::DisjunctiveUnionAssign);
128    let fn_name = format_ident!("disjunctive_union_assign");
129    operations::operation_assign_derive(input, &trait_path, &fn_name)
130}
131
132#[proc_macro_derive(Union)]
133pub fn union_derive(input: TokenStream) -> TokenStream {
134    let trait_path: syn::Path = parse_quote!(::finit::operations::Union);
135    let fn_name = format_ident!("union");
136    operations::operation_derive(input, &trait_path, &fn_name)
137}
138
139#[proc_macro_derive(Difference)]
140pub fn difference_derive(input: TokenStream) -> TokenStream {
141    let trait_path: syn::Path = parse_quote!(::finit::operations::Difference);
142    let fn_name = format_ident!("difference");
143    operations::operation_derive(input, &trait_path, &fn_name)
144}
145
146#[proc_macro_derive(Intersection)]
147pub fn intersection_derive(input: TokenStream) -> TokenStream {
148    let trait_path: syn::Path = parse_quote!(::finit::operations::Intersection);
149    let fn_name = format_ident!("intersection");
150    operations::operation_derive(input, &trait_path, &fn_name)
151}
152
153#[proc_macro_derive(DisjunctiveUnion)]
154pub fn disjunctive_union_derive(input: TokenStream) -> TokenStream {
155    let trait_path: syn::Path = parse_quote!(::finit::operations::DisjunctiveUnion);
156    let fn_name = format_ident!("disjunctive_union");
157    operations::operation_derive(input, &trait_path, &fn_name)
158}
159
160#[proc_macro_derive(SetEq)]
161pub fn set_eq_derive(input: TokenStream) -> TokenStream {
162    comparisons::set_eq_derive(input)
163}
164
165#[proc_macro_derive(SubsetOf)]
166pub fn subset_of_derive(input: TokenStream) -> TokenStream {
167    comparisons::subset_of_derive(input)
168}