1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{
5 parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Data, DeriveInput,
6 Field, Ident, Meta, Token,
7};
8
9#[proc_macro_derive(DataLen)]
17pub fn derive_data_len(input: TokenStream) -> TokenStream {
18 let input = parse_macro_input!(input as DeriveInput);
19 let name = input.ident;
20
21 let expanded = quote! {
22 impl pinocchio_util::DataLen for #name {
23 const LEN: usize = core::mem::size_of::<#name>();
24 }
25 };
26
27 TokenStream::from(expanded)
28}
29
30#[proc_macro_derive(Updates)]
49pub fn derive_updates(input: TokenStream) -> TokenStream {
50 let input = parse_macro_input!(input as DeriveInput);
51 let name = input.ident;
52 let update_enum_name = Ident::new(&format!("{}Update", name), name.span());
53
54 let fields = match input.data {
55 Data::Struct(data) => data.fields,
56 _ => panic!("Updates derive macro only supports structs"),
57 };
58
59 let field_variants: Vec<_> = fields
60 .iter()
61 .enumerate()
62 .map(|(_i, field)| {
63 let field_name = field.ident.as_ref().unwrap();
64 let _field_type = &field.ty;
65 let variant_name = Ident::new(
66 &format!(
67 "Set{}",
68 field_name
69 .to_string()
70 .chars()
71 .next()
72 .unwrap()
73 .to_uppercase()
74 .chain(field_name.to_string().chars().skip(1))
75 .collect::<String>()
76 ),
77 field_name.span(),
78 );
79
80 quote! {
81 #variant_name(#_field_type)
82 }
83 })
84 .collect();
85
86 let match_arms: Vec<_> = fields
87 .iter()
88 .enumerate()
89 .map(|(_i, field)| {
90 let field_name = field.ident.as_ref().unwrap();
91 let _field_type = &field.ty;
92 let variant_name = Ident::new(
93 &format!(
94 "Set{}",
95 field_name
96 .to_string()
97 .chars()
98 .next()
99 .unwrap()
100 .to_uppercase()
101 .chain(field_name.to_string().chars().skip(1))
102 .collect::<String>()
103 ),
104 field_name.span(),
105 );
106
107 quote! {
108 #update_enum_name::#variant_name(value) => self.#field_name = value,
109 }
110 })
111 .collect();
112
113 let expanded = quote! {
114 pub enum #update_enum_name {
115 #(#field_variants),*
116 }
117
118 impl pinocchio_util::AccountUpdates for #name {
119 type Update = #update_enum_name;
120
121 fn updates(&mut self, updates: Self::Update) -> Result<(), pinocchio::program_error::ProgramError> {
122 match updates {
123 #(#match_arms)*
124 }
125 Ok(())
126 }
127 }
128 };
129
130 TokenStream::from(expanded)
131}
132
133struct ValidationAttr {
134 non_empty: bool,
135 is_signer: bool,
136 is_executable: bool,
137 len: Option<usize>,
138 id: Option<syn::Expr>,
139}
140
141impl Parse for ValidationAttr {
142 fn parse(input: ParseStream) -> syn::Result<Self> {
143 let mut non_empty = false;
144 let mut len = None;
145 let mut id = None;
146 let mut is_signer = false;
147 let mut is_executable = false;
148
149 let args = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
150
151 for arg in args {
152 match arg {
153 Meta::Path(path) => {
154 if path.is_ident("non_empty") {
155 non_empty = true;
156 }
157 }
158 Meta::NameValue(name_value) => {
159 if name_value.path.is_ident("len") {
160 if let syn::Expr::Lit(syn::ExprLit {
161 lit: syn::Lit::Int(lit_int),
162 ..
163 }) = &name_value.value
164 {
165 len = Some(lit_int.base10_parse()?);
166 }
167 } else if name_value.path.is_ident("id") {
168 id = Some(name_value.value);
169 }
170 }
171 _ => {}
172 }
173 }
174
175 Ok(ValidationAttr {
176 non_empty,
177 len,
178 id,
179 is_signer,
180 is_executable,
181 })
182 }
183}
184
185#[proc_macro_derive(Validate, attributes(validate))]
215pub fn derive_validate(input: TokenStream) -> TokenStream {
216 let input = parse_macro_input!(input as DeriveInput);
217 let name = input.ident;
218
219 let fields = match input.data {
220 Data::Struct(data) => data.fields,
221 _ => panic!("This macro only supports structs"),
222 };
223
224 let validation_checks: Vec<_> = fields
225 .iter()
226 .enumerate()
227 .map(|(_i, field)| {
228 let field_name = field.ident.as_ref().unwrap();
229
230 let mut validation_attr = None;
231 for attr in &field.attrs {
232 if attr.path().is_ident("validate") {
233 validation_attr = Some(attr.parse_args::<ValidationAttr>().unwrap());
234 break;
235 }
236 }
237
238 if let Some(attr) = validation_attr {
239 let mut checks = Vec::new();
240
241 if attr.non_empty {
242 checks.push(quote! {
243 if self.#field_name.data_len() == 0 {
244 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
245 }
246 });
247 }
248
249 if attr.is_signer {
250 checks.push(quote! {
251 if !self.#field_name.is_signer() {
252 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
253 }
254 });
255 }
256
257 if attr.is_executable {
258 checks.push(quote! {
259 if !self.#field_name.is_executable() {
260 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
261 }
262 });
263 }
264
265 if let Some(len) = attr.len {
266 checks.push(quote! {
267 if self.#field_name.data_len() != #len {
268 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
269 }
270 });
271 }
272
273 if let Some(id) = attr.id {
274 checks.push(quote! {
275 if self.#field_name.key() != &#id {
276 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
277 }
278 });
279 }
280
281 quote! {
282 #(#checks)*
283 }
284 } else {
285 quote! {}
286 }
287 })
288 .collect();
289
290 let expanded = quote! {
291 impl<'info> pinocchio_util::Validate<'info> for #name<'info> {
292 fn validate(&self) -> Result<(), pinocchio::program_error::ProgramError> {
293 #(#validation_checks)*
294 Ok(())
295 }
296 }
297 };
298
299 TokenStream::from(expanded)
300}
301
302#[proc_macro_derive(Context)]
327pub fn derive_context(input: TokenStream) -> TokenStream {
328 let input = parse_macro_input!(input as DeriveInput);
329 let name = &input.ident;
330
331 let lifetime_params: Vec<_> = input.generics.lifetimes().collect();
332
333 if lifetime_params.len() != 1 {
334 panic!("Context derive requires exactly one lifetime parameter");
335 }
336
337 let lifetime_param = &lifetime_params[0];
338 let lifetime = &lifetime_param.lifetime;
339
340 if lifetime.ident != "info" {
341 panic!("Context derive requires the lifetime parameter to be named 'info");
342 }
343
344 let fields = match input.data {
345 Data::Struct(ref data) => &data.fields,
346 _ => panic!("Context derive only works on structs"),
347 };
348
349 let accounts_len = fields.len();
350 let field_assignments: Vec<_> = fields
351 .iter()
352 .enumerate()
353 .map(|(i, field)| {
354 let field_name = field.ident.as_ref().unwrap();
355 quote! { #field_name: &accounts.get_unchecked(#i), }
356 })
357 .collect();
358
359 let expanded = quote! {
360 impl<'info> pinocchio_util::Context<'info> for #name<'info> {
361 const ACCOUNTS_LEN: usize = #accounts_len;
362
363 fn build(accounts: &'info [pinocchio::account_info::AccountInfo])
364 -> Result<Self, pinocchio::program_error::ProgramError>
365 {
366 if accounts.len() != Self::ACCOUNTS_LEN {
367 return Err(pinocchio::program_error::ProgramError::InvalidAccountData);
368 }
369
370 Ok(unsafe {
371 Self {
372 #(#field_assignments)*
373 }
374 })
375 }
376 }
377 };
378
379 TokenStream::from(expanded)
380}