mpl_candy_guard_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput};
4
5#[proc_macro_derive(GuardSet)]
6pub fn derive(input: TokenStream) -> TokenStream {
7    let ast = parse_macro_input!(input as DeriveInput);
8    let name = &ast.ident;
9
10    let fields = if let syn::Data::Struct(syn::DataStruct {
11        fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }),
12        ..
13    }) = ast.data
14    {
15        named
16    } else {
17        panic!("No fields found");
18    };
19
20    let is_option_t = |ty: &syn::Type| -> bool {
21        if let syn::Type::Path(ref p) = ty {
22            if p.path.segments.len() != 1 || p.path.segments[0].ident != "Option" {
23                return false;
24            }
25            if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments {
26                if inner_ty.args.len() != 1 {
27                    return false;
28                } else if let syn::GenericArgument::Type(ref _ty) = inner_ty.args.first().unwrap() {
29                    return true;
30                }
31            }
32        }
33        false
34    };
35
36    let unwrap_option_t = |ty: &syn::Type| -> syn::Type {
37        if let syn::Type::Path(ref p) = ty {
38            if p.path.segments.len() != 1 || p.path.segments[0].ident != "Option" {
39                panic!("Type was not Option<T>");
40            }
41            if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments {
42                if inner_ty.args.len() != 1 {
43                    panic!("Option type was not Option<T>");
44                } else if let syn::GenericArgument::Type(ref ty) = inner_ty.args.first().unwrap() {
45                    return ty.clone();
46                }
47            }
48        }
49        panic!("Type was not Option<T>");
50    };
51
52    let from_data = fields.iter().map(|f| {
53        let name = &f.ident;
54
55        if is_option_t(&f.ty) {
56            let ty = unwrap_option_t(&f.ty);
57            quote! {
58                let #name = if #ty::is_enabled(features) {
59                    cursor += #ty::size();
60                    #ty::load(data, cursor)?
61                } else {
62                    None
63                };
64            }
65        } else {
66            quote! {}
67        }
68    });
69
70    let to_data = fields.iter().map(|f| {
71        let name = &f.ident;
72
73        if is_option_t(&f.ty) {
74            let ty = unwrap_option_t(&f.ty);
75            quote! {
76                if let Some(#name) = &self.#name {
77                    cursor += #ty::size();
78                    if cursor <= data.len() {
79                        #name.save(data, cursor - #ty::size())?;
80                        features = #ty::enable(features);
81                    } else {
82                        return err!(crate::errors::CandyGuardError::InvalidAccountSize);
83                    }
84                }
85            }
86        } else {
87            quote! {}
88        }
89    });
90
91    let merge_data = fields.iter().map(|f| {
92        let name = &f.ident;
93
94        if is_option_t(&f.ty) {
95            quote! {
96                if let Some(#name) = other.#name {
97                    self.#name = Some(#name);
98                }
99            }
100        } else {
101            quote! {}
102        }
103    });
104
105    let struct_fields = fields.iter().map(|f| {
106        let name = &f.ident;
107        quote! { #name }
108    });
109
110    let enabled = fields.iter().map(|f| {
111        let name = &f.ident;
112
113        if is_option_t(&f.ty) {
114            quote! {
115                if let Some(#name) = &self.#name {
116                    conditions.push(#name);
117                }
118            }
119        } else {
120            quote! {}
121        }
122    });
123
124    let struct_size = fields.iter().map(|f| {
125        let name = &f.ident;
126
127        if is_option_t(&f.ty) {
128            let ty = unwrap_option_t(&f.ty);
129            quote! {
130                if self.#name.is_some() {
131                    size += #ty::size();
132                }
133            }
134        } else {
135            quote! {}
136        }
137    });
138
139    let bytes_count = fields.iter().map(|f| {
140        if is_option_t(&f.ty) {
141            let ty = unwrap_option_t(&f.ty);
142            quote! {
143                if #ty::is_enabled(features) {
144                    count += #ty::size();
145                }
146            }
147        } else {
148            quote! {}
149        }
150    });
151    /* This is used to generate the GuardType enum
152    let types_list = fields.iter().map(|f| {
153        if is_option_t(&f.ty) {
154            let ty = unwrap_option_t(&f.ty);
155            quote! { #ty }
156        } else {
157            quote! {}
158        }
159    });
160    */
161    let route_arm = fields.iter().map(|f| {
162        if is_option_t(&f.ty) {
163            let ty = unwrap_option_t(&f.ty);
164            quote! {
165                GuardType::#ty => #ty::instruction(&ctx, route_context, args.data)
166            }
167        } else {
168            quote! {}
169        }
170    });
171
172    let verify = fields.iter().map(|f| {
173        if is_option_t(&f.ty) {
174            let ty = unwrap_option_t(&f.ty);
175            quote! {
176                #ty::verify(data)?;
177            }
178        } else {
179            quote! {}
180        }
181    });
182
183    let expanded = quote! {
184        impl #name {
185            pub fn from_data(data: &[u8]) -> anchor_lang::Result<(Self, u64)> {
186                let mut cursor = 0;
187
188                let features = u64::from_le_bytes(*arrayref::array_ref![data, cursor, 8]);
189                cursor += 8;
190
191                #(#from_data)*
192
193                Ok((Self {
194                    #(#struct_fields,)*
195                }, features))
196            }
197
198            pub fn bytes_count(features: u64) -> usize {
199                let mut count = 8; // features (u64)
200                #(#bytes_count)*
201                count
202            }
203
204            pub fn to_data(&self, data: &mut [u8]) -> anchor_lang::Result<u64> {
205                let mut features = 0;
206                // leave space to write the features flag at the end
207                let mut cursor = 8;
208
209                #(#to_data)*
210
211                // features
212                data[0..8].copy_from_slice(&u64::to_le_bytes(features));
213
214                Ok(features)
215            }
216
217            pub fn merge(&mut self, other: GuardSet) {
218                #(#merge_data)*
219            }
220
221            pub fn enabled_conditions(&self) -> Vec<&dyn Condition> {
222                // list of condition trait objects
223                let mut conditions: Vec<&dyn Condition> = vec![];
224                #(#enabled)*
225
226                conditions
227            }
228
229            pub fn size(&self) -> usize {
230                let mut size = 8; // features (u64)
231                #(#struct_size)*
232                size
233            }
234
235            pub fn route<'info>(
236                ctx: Context<'_, '_, '_, 'info, crate::instructions::Route<'info>>,
237                route_context: crate::instructions::RouteContext<'info>,
238                args: crate::instructions::RouteArgs
239            ) -> anchor_lang::Result<()> {
240                match args.guard {
241                    #(#route_arm,)*
242                    _ => err!(CandyGuardError::InstructionNotFound)
243                }
244            }
245
246            pub fn verify(data: &CandyGuardData) -> Result<()> {
247                #(#verify)*
248
249                Ok(())
250            }
251        }
252        /*
253        #[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)]
254        pub enum GuardType {
255            #(#types_list,)*
256        }
257         */
258    };
259
260    TokenStream::from(expanded)
261}