// mpl_candy_guard_derive/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

5#[proc_macro_derive(GuardSet)]
6pub fn derive(input: TokenStream) -> TokenStream {
7 let ast = parse_macro_input!(input as DeriveInput);
8 let name = &ast.ident;
9
10 let fields = if let syn::Data::Struct(syn::DataStruct {
11 fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }),
12 ..
13 }) = ast.data
14 {
15 named
16 } else {
17 panic!("No fields found");
18 };
19
20 let is_option_t = |ty: &syn::Type| -> bool {
21 if let syn::Type::Path(ref p) = ty {
22 if p.path.segments.len() != 1 || p.path.segments[0].ident != "Option" {
23 return false;
24 }
25 if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments {
26 if inner_ty.args.len() != 1 {
27 return false;
28 } else if let syn::GenericArgument::Type(ref _ty) = inner_ty.args.first().unwrap() {
29 return true;
30 }
31 }
32 }
33 false
34 };
35
36 let unwrap_option_t = |ty: &syn::Type| -> syn::Type {
37 if let syn::Type::Path(ref p) = ty {
38 if p.path.segments.len() != 1 || p.path.segments[0].ident != "Option" {
39 panic!("Type was not Option<T>");
40 }
41 if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments {
42 if inner_ty.args.len() != 1 {
43 panic!("Option type was not Option<T>");
44 } else if let syn::GenericArgument::Type(ref ty) = inner_ty.args.first().unwrap() {
45 return ty.clone();
46 }
47 }
48 }
49 panic!("Type was not Option<T>");
50 };
51
52 let from_data = fields.iter().map(|f| {
53 let name = &f.ident;
54
55 if is_option_t(&f.ty) {
56 let ty = unwrap_option_t(&f.ty);
57 quote! {
58 let #name = if #ty::is_enabled(features) {
59 cursor += #ty::size();
60 #ty::load(data, cursor)?
61 } else {
62 None
63 };
64 }
65 } else {
66 quote! {}
67 }
68 });
69
70 let to_data = fields.iter().map(|f| {
71 let name = &f.ident;
72
73 if is_option_t(&f.ty) {
74 let ty = unwrap_option_t(&f.ty);
75 quote! {
76 if let Some(#name) = &self.#name {
77 cursor += #ty::size();
78 if cursor <= data.len() {
79 #name.save(data, cursor - #ty::size())?;
80 features = #ty::enable(features);
81 } else {
82 return err!(crate::errors::CandyGuardError::InvalidAccountSize);
83 }
84 }
85 }
86 } else {
87 quote! {}
88 }
89 });
90
91 let merge_data = fields.iter().map(|f| {
92 let name = &f.ident;
93
94 if is_option_t(&f.ty) {
95 quote! {
96 if let Some(#name) = other.#name {
97 self.#name = Some(#name);
98 }
99 }
100 } else {
101 quote! {}
102 }
103 });
104
105 let struct_fields = fields.iter().map(|f| {
106 let name = &f.ident;
107 quote! { #name }
108 });
109
110 let enabled = fields.iter().map(|f| {
111 let name = &f.ident;
112
113 if is_option_t(&f.ty) {
114 quote! {
115 if let Some(#name) = &self.#name {
116 conditions.push(#name);
117 }
118 }
119 } else {
120 quote! {}
121 }
122 });
123
124 let struct_size = fields.iter().map(|f| {
125 let name = &f.ident;
126
127 if is_option_t(&f.ty) {
128 let ty = unwrap_option_t(&f.ty);
129 quote! {
130 if self.#name.is_some() {
131 size += #ty::size();
132 }
133 }
134 } else {
135 quote! {}
136 }
137 });
138
139 let bytes_count = fields.iter().map(|f| {
140 if is_option_t(&f.ty) {
141 let ty = unwrap_option_t(&f.ty);
142 quote! {
143 if #ty::is_enabled(features) {
144 count += #ty::size();
145 }
146 }
147 } else {
148 quote! {}
149 }
150 });
151 let route_arm = fields.iter().map(|f| {
162 if is_option_t(&f.ty) {
163 let ty = unwrap_option_t(&f.ty);
164 quote! {
165 GuardType::#ty => #ty::instruction(&ctx, route_context, args.data)
166 }
167 } else {
168 quote! {}
169 }
170 });
171
172 let verify = fields.iter().map(|f| {
173 if is_option_t(&f.ty) {
174 let ty = unwrap_option_t(&f.ty);
175 quote! {
176 #ty::verify(data)?;
177 }
178 } else {
179 quote! {}
180 }
181 });
182
183 let expanded = quote! {
184 impl #name {
185 pub fn from_data(data: &[u8]) -> anchor_lang::Result<(Self, u64)> {
186 let mut cursor = 0;
187
188 let features = u64::from_le_bytes(*arrayref::array_ref![data, cursor, 8]);
189 cursor += 8;
190
191 #(#from_data)*
192
193 Ok((Self {
194 #(#struct_fields,)*
195 }, features))
196 }
197
198 pub fn bytes_count(features: u64) -> usize {
199 let mut count = 8; #(#bytes_count)*
201 count
202 }
203
204 pub fn to_data(&self, data: &mut [u8]) -> anchor_lang::Result<u64> {
205 let mut features = 0;
206 let mut cursor = 8;
208
209 #(#to_data)*
210
211 data[0..8].copy_from_slice(&u64::to_le_bytes(features));
213
214 Ok(features)
215 }
216
217 pub fn merge(&mut self, other: GuardSet) {
218 #(#merge_data)*
219 }
220
221 pub fn enabled_conditions(&self) -> Vec<&dyn Condition> {
222 let mut conditions: Vec<&dyn Condition> = vec![];
224 #(#enabled)*
225
226 conditions
227 }
228
229 pub fn size(&self) -> usize {
230 let mut size = 8; #(#struct_size)*
232 size
233 }
234
235 pub fn route<'info>(
236 ctx: Context<'_, '_, '_, 'info, crate::instructions::Route<'info>>,
237 route_context: crate::instructions::RouteContext<'info>,
238 args: crate::instructions::RouteArgs
239 ) -> anchor_lang::Result<()> {
240 match args.guard {
241 #(#route_arm,)*
242 _ => err!(CandyGuardError::InstructionNotFound)
243 }
244 }
245
246 pub fn verify(data: &CandyGuardData) -> Result<()> {
247 #(#verify)*
248
249 Ok(())
250 }
251 }
252 };
259
260 TokenStream::from(expanded)
261}