Skip to main content

freenet_scaffold_macro/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::{parse_macro_input, Data, DeriveInput, Fields, Meta, Token};
7
8/// Parse the `post_apply_delta = "method_name"` from attribute arguments.
9fn parse_post_apply_delta(attr: TokenStream) -> Option<syn::Ident> {
10    if attr.is_empty() {
11        return None;
12    }
13    let parsed = syn::parse::Parser::parse(Punctuated::<Meta, Token![,]>::parse_terminated, attr)
14        .expect("Failed to parse #[composable] attributes");
15
16    for meta in parsed {
17        if let Meta::NameValue(nv) = meta {
18            if nv.path.is_ident("post_apply_delta") {
19                if let syn::Expr::Lit(syn::ExprLit {
20                    lit: syn::Lit::Str(s),
21                    ..
22                }) = &nv.value
23                {
24                    return Some(format_ident!("{}", s.value()));
25                } else {
26                    panic!("post_apply_delta value must be a string literal");
27                }
28            }
29        }
30    }
31    None
32}
33
34#[proc_macro_attribute]
35pub fn composable(attr: TokenStream, item: TokenStream) -> TokenStream {
36    let post_apply_delta_method = parse_post_apply_delta(attr);
37
38    let input = parse_macro_input!(item as DeriveInput);
39    let name = &input.ident;
40
41    let fields = match &input.data {
42        Data::Struct(data_struct) => match &data_struct.fields {
43            Fields::Named(fields_named) => &fields_named.named,
44            _ => panic!("ComposableState can only be applied to structs with named fields"),
45        },
46        _ => panic!("ComposableState can only be applied to structs"),
47    };
48
49    let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
50    let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
51
52    // Take the type of the first field to derive ParentState and Parameters
53    let first_field_type = &field_types[0];
54
55    let summary_name = format_ident!("{}Summary", name);
56    let delta_name = format_ident!("{}Delta", name);
57
58    let summary_fields = field_names
59        .iter()
60        .zip(field_types.iter())
61        .map(|(name, ty)| {
62            quote! {
63                pub #name: <#ty as ComposableState>::Summary
64            }
65        });
66
67    let delta_fields = field_names
68        .iter()
69        .zip(field_types.iter())
70        .map(|(name, ty)| {
71            quote! {
72                pub #name: Option<<#ty as ComposableState>::Delta>
73            }
74        });
75
76    // Error messages for missing ComposableState implementation
77    let check_composable_impls = field_types.iter().map(|ty| {
78        quote! {
79            const _: fn() = || {
80                fn check_composable<T: ComposableState>() {}
81                check_composable::<#ty>();
82            };
83        }
84    });
85
86    // Ensure that all fields share the same ParentState and Parameters
87    let check_matching_parent_state = field_types.iter().map(|ty| {
88        quote! {
89            const _: fn() = || {
90                fn check_parent_state<T: ComposableState<ParentState = <#first_field_type as ComposableState>::ParentState>>() {}
91                check_parent_state::<#ty>();
92            };
93        }
94    });
95
96    let check_matching_parameters = field_types.iter().map(|ty| {
97        quote! {
98            const _: fn() = || {
99                fn check_parameters<T: ComposableState<Parameters = <#first_field_type as ComposableState>::Parameters>>() {}
100                check_parameters::<#ty>();
101            };
102        }
103    });
104
105    let verify_impl = field_names.iter().map(|name| {
106        quote! {
107            self.#name.verify(parent_state, parameters)?;
108        }
109    });
110
111    let summarize_impl = field_names.iter().map(|name| {
112        quote! {
113            #name: self.#name.summarize(parent_state, parameters)
114        }
115    });
116
117    let delta_impl = field_names.iter().map(|name| {
118        quote! {
119            #name: self.#name.delta(parent_state, parameters, &old_state_summary.#name)
120        }
121    });
122
123    let all_none_check = field_names
124        .iter()
125        .map(|name| {
126            quote! {
127                delta.#name.is_none()
128            }
129        })
130        .collect::<Vec<_>>();
131
132    // Note: we're passing self_clone as the parent_state so that dependencies between fields work
133    let apply_delta_impl = field_names.iter().map(|name| {
134        quote! {
135            let self_clone = self.clone();
136            self.#name.apply_delta(&self_clone, parameters, &delta.#name)?;
137        }
138    });
139
140    // Generate post_apply_delta call if the attribute specifies a method
141    let post_apply_delta_call = if let Some(method) = &post_apply_delta_method {
142        quote! { self.#method(parameters)?; }
143    } else {
144        quote! {}
145    };
146
147    let _generic_params: Vec<_> = input.generics.params.iter().collect();
148    let where_clause = input.generics.where_clause.clone();
149    let (impl_generics, ty_generics, _) = input.generics.split_for_impl();
150
151    let expanded = quote! {
152        use freenet_scaffold::ComposableState;
153
154        #input
155
156        // Automatically implement Serialize, Deserialize, Clone, PartialEq, and Debug for the generated Summary and Delta structs
157        #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Debug)]
158        pub struct #summary_name #ty_generics #where_clause {
159            #(#summary_fields,)*
160        }
161
162        #[derive(serde::Serialize, serde::Deserialize, Clone, PartialEq, Debug, Default)]
163        pub struct #delta_name #ty_generics #where_clause {
164            #(#delta_fields,)*
165        }
166
167        impl #impl_generics ComposableState for #name #ty_generics #where_clause
168        where
169            #(#field_types: ComposableState,)*
170        {
171            type ParentState = #name;
172            type Summary = #summary_name #ty_generics;
173            type Delta = #delta_name #ty_generics;
174            type Parameters = <#first_field_type as ComposableState>::Parameters;
175
176            fn verify(&self, parent_state: &Self::ParentState, parameters: &Self::Parameters) -> Result<(), String> {
177                #(#verify_impl)*
178                Ok(())
179            }
180
181            fn summarize(&self, parent_state: &Self::ParentState, parameters: &Self::Parameters) -> Self::Summary {
182                #summary_name {
183                    #(#summarize_impl,)*
184                }
185            }
186
187            fn delta(&self, parent_state: &Self::ParentState, parameters: &Self::Parameters, old_state_summary: &Self::Summary) -> Option<Self::Delta> {
188                let delta = #delta_name {
189                    #(#delta_impl,)*
190                };
191
192                if #(#all_none_check)&&* {
193                    None
194                } else {
195                    Some(delta)
196                }
197            }
198
199            // parent_state disregarded because we need to use self so that dependencies between fields work, ugly
200            fn apply_delta(&mut self, _parent_state: &Self::ParentState, parameters: &Self::Parameters, delta: &Option<Self::Delta>) -> Result<(), String> {
201                if let Some(delta) = delta {
202                    #(#apply_delta_impl)*
203                    #post_apply_delta_call
204                }
205                Ok(())
206            }
207        }
208
209        // Additional checks to provide better compile-time error messages
210        #(#check_composable_impls)*
211        #(#check_matching_parent_state)*
212        #(#check_matching_parameters)*
213    };
214
215    TokenStream::from(expanded)
216}