Skip to main content

anchor_modular_program/
lib.rs

1#![warn(missing_docs)]
2
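//! A drop-in replacement for anchor-lang's `#[program]` macro that allows a
//! program's instructions to be split across multiple modules.
//!
//! Instructions defined in the listed modules are exposed on the program
//! through generated relay functions appended to the annotated module.
//!
//! ```ignore
//! mod extra;
//! use extra::types::*;
//!
//! #[modular_program(
//!     modules=[
//!         extra::instructions
//!     ]
//! )]
//! mod my_program {
//!     use super::*;
//! }
//! ```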
20
21
22
23use std::{collections::HashMap, io::Read};
24use anchor_syn::*;
25use syn::*;
26use quote::*;
27
28
29/// Modules can either be a rust path to an instructions module,
30/// or it can be an object:
31///
32/// ```
33/// #[modular_program(modules=[
34///     mymod::instructions,
35///     {
36///         // Required, module path to instructions
37///         module: path::to::instructions,
38///
39///         // Optional path, override where to look for the instructions
40///         file_path: "./src/mod/etc.rs",
41///
42///         // Optional prefix, empty for no prefix
43///         prefix: "prefix",
44///
45///         // Optional, A macro that wraps the call to the instruction, eg:
46///         // ```
47///         // macro_rules ix_wrapper {
48///         //     ($ix:path, $ctx:ident: $ctx_type:ty $(, $arg:ident: $arg_type:ty )*) => {
49///         //         $ix($ctx, $(, $arg)*)
50///         //     };
51///         // }
52///         // ```
53///         macro: path::to::macro
54///     }
55/// ])]
56/// mod my_program {           
57///     use super::*;          
58/// }
59/// ```
60///
61#[proc_macro_attribute]
62pub fn modular_program(
63    args: proc_macro::TokenStream,
64    input: proc_macro::TokenStream,
65) -> proc_macro::TokenStream {
66    let ProgramMacro { modules } = syn::parse_macro_input!(args as ProgramMacro);
67
68    let fns = modules.into_iter()
69        .map(|m| (m.clone(), get_program(m.clone())))
70        .flat_map(|(spec, p)| {
71            p.ixs.into_iter().map(|ix| build_relay(&spec, ix)).collect::<Vec<_>>()
72        })
73        .collect();
74
75    let input = insert_fns_into_first_module(input, fns);
76    let program = syn::parse_macro_input!(input as anchor_syn::Program);
77    program.to_token_stream().into()
78}
79
80
81/*
82 * This builds relay instructions, i.e, for `foo::instructions::do_thing`:
83 *
84 * pub fn foo_do_thing(ctx: Context<YourInstructionContext>, ...) -> Result<()> {
85 *     foo::instructions::do_thing(ctx, ...)
86 * }
87 */
88
89fn build_relay(spec: &ModuleSpec, ix: Ix) -> ItemFn {
90
91    let item_fn = &ix.raw_method;
92    let ItemFn { attrs, .. } = &item_fn;
93    let Signature { ident: fn_name, inputs, generics, output, .. } = &item_fn.sig;
94
95    let path = spec.module.clone();
96    let first = path.segments[0].clone().ident;
97
98    let new_name = match &spec.prefix {
99        Some(s) if s.is_empty() => fn_name.clone(),
100        o => {
101            let prefix = o.clone().unwrap_or(first.to_string());
102            Ident::new(format!("{}_{}", prefix, fn_name).as_str(), first.span())
103        }
104    };
105
106    let mut inputs = inputs.clone();
107    let arg_names = inputs.iter_mut().filter_map(|arg| {
108        if let FnArg::Typed(pt) = arg {
109            if let syn::Pat::Ident(id) = &mut *pt.pat {
110                id.mutability = None;
111                return Some(id.ident.clone());
112            }
113        }
114        None
115    }).collect::<Vec<_>>();
116
117    if let Some(w) = &spec.wrapper {
118        parse_quote! {
119            #(#attrs)*
120            pub fn #new_name #generics(#inputs) #output {
121                { #w!(#path::#fn_name, #inputs) }
122            }
123        }
124    } else {
125        parse_quote! {
126            #(#attrs)*
127            pub fn #new_name #generics(#inputs) #output {
128                #path::#fn_name(#(#arg_names),*)
129            }
130        }
131    }
132}
133
134
135/*
136 * Get an anchor Program from the given path, by parsing the file, i.e.
137 * foo::instructions is converted to "$PROGRAM_DIR/src/foo/instructions.rs"
138 */
139
140fn get_program(spec: ModuleSpec) -> Program {
141
142    let mod_path = format!(
143        "{}/{}",
144        std::env::var("CARGO_MANIFEST_DIR").unwrap(),
145        spec.get_file_path()
146    );
147
148    let mut code_str = String::new();
149    std::fs::File::open(mod_path).unwrap().read_to_string(&mut code_str).unwrap();
150    let parsed = syn::parse_file(&code_str).unwrap();
151
152    let program_mod = ItemMod {
153        vis: Visibility::Public(VisPublic { pub_token: Default::default() }),
154        attrs: vec![],
155        mod_token: syn::token::Mod::default(),
156        ident: Ident::new("abc", proc_macro2::Span::call_site()),
157        content: Some((
158            syn::token::Brace::default(),
159            parsed.items,
160        )),
161        semi: None,
162    };
163
164    let program = anchor_syn::parser::program::parse(program_mod).unwrap();
165    assert!(program.fallback_fn.is_none(), "additional program module cant have fallback");
166    program
167}
168
169
170
171/*
172 * Parse the macro arguments
173 */
174
175#[derive(Debug)]
176struct ProgramMacro { modules: Vec<ModuleSpec>, }
177
178impl parse::Parse for ProgramMacro {
179    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
180
181        // Parse `modules`
182        let modules_ident: Ident = input.parse()?;
183        if modules_ident != "modules" {
184            return Err(syn::Error::new(modules_ident.span(), "expected `modules`"));
185        }
186
187        input.parse::<Token![=]>()?;
188
189        // Parse the bracketed list `[cell, placement]`
190        let content;
191        syn::bracketed!(content in input);
192        let specs = content.parse_terminated::<ModuleSpec, Token![,]>(|p| p.parse())?;
193
194        // Convert Punctuated<Ident, _> to Vec<Ident>
195        let modules = specs.into_iter().collect();
196
197        Ok(ProgramMacro { modules })
198    }
199}
200
201#[derive(Clone, Debug)]
202struct ModuleSpec {
203    module: Path,
204    prefix: Option<String>,
205    file_path: Option<String>,
206    wrapper: Option<Path>
207}
208
209impl ModuleSpec {
210    fn get_file_path(&self) -> String {
211        self.file_path.clone().unwrap_or_else(|| {
212            let p = self.module.segments.iter().fold(String::new(), |s, p| format!("{}/{}", s, p.ident));
213            format!("./src{}.rs", p)
214        })
215    }
216}
217
218impl parse::Parse for ModuleSpec {
219    fn parse(input: parse::ParseStream) -> Result<Self> {
220
221        type T = (String, (Option<String>, Option<Path>));
222        fn parse_field(p: parse::ParseStream) -> syn::Result<T> {
223            let name = p.parse::<Ident>()?.to_string();
224            p.parse::<Token![:]>()?;
225            Ok(
226                if name == "file_path" || name == "prefix" {
227                    (name, (Some(p.parse::<LitStr>()?.value()), None))
228                } else if name == "module" || name == "wrapper" {
229                    (name, (None, Some(p.parse::<Path>()?)))
230                } else {
231                    panic!("Invalid module spec param: {}", name);
232                }
233            )
234        }
235
236
237        if input.peek(Ident) {
238            let module = input.parse::<Path>()?;
239            Ok(ModuleSpec { module, prefix: None, file_path: None, wrapper: None })
240        } else {
241            let content;
242            syn::braced!(content in input);
243            let fields = content.parse_terminated::<T, Token![,]>(parse_field)?;
244            let mut hm = fields.clone().into_iter().collect::<HashMap<String, _>>();
245            assert!(hm.len() == fields.len(), "duplicate field");
246            Ok(ModuleSpec {
247                module: hm.remove("module").expect("module is required").1.unwrap(),
248                prefix: hm.remove("prefix").map(|t| t.0).flatten(),
249                file_path: hm.remove("file_path").map(|t| t.0).flatten(),
250                wrapper: hm.remove("wrapper").map(|t| t.1).flatten(),
251            })
252        }
253    }
254}
255
256
257/*
258 * Append instruction functions to main program module
259 */
260
261fn insert_fns_into_first_module(input: proc_macro::TokenStream, fns: Vec<ItemFn>) -> proc_macro::TokenStream {
262
263    let mut item_mod: ItemMod = parse2(input.into()).expect("Failed to parse main program module");
264
265    item_mod.content
266        .as_mut()
267        .expect("Program module has no body?")
268        .1.extend(fns.into_iter().map(Into::into));
269
270    quote! { #item_mod }.into()
271}
272