anchor_modular_program/
lib.rs

1use std::{collections::HashMap, io::Read};
2use anchor_syn::*;
3//use proc_macro::Ident;
4use syn::*;
5use quote::*;
6
7
8#[proc_macro_attribute]
9pub fn modular_program(
10    args: proc_macro::TokenStream,
11    input: proc_macro::TokenStream,
12) -> proc_macro::TokenStream {
13    let ProgramMacro { modules } = syn::parse_macro_input!(args as ProgramMacro);
14
15    let fns = modules.into_iter()
16        .map(|m| (m.clone(), get_program(m.clone())))
17        .flat_map(|(spec, p)| {
18            p.ixs.into_iter().map(|ix| build_relay(&spec, ix)).collect::<Vec<_>>()
19        })
20        .collect();
21
22    let input = insert_fns_into_first_module(input, fns);
23    let program = syn::parse_macro_input!(input as anchor_syn::Program);
24    program.to_token_stream().into()
25}
26
27
28/*
29 * This builds relay instructions, i.e, for `foo::instructions::do_thing`:
30 *
31 * pub fn foo_do_thing(ctx: Context<YourInstructionContext>, ...) -> Result<()> {
32 *     foo::instructions::do_thing(ctx, ...)
33 * }
34 */
35
36fn build_relay(spec: &ModuleSpec, ix: Ix) -> ItemFn {
37
38    let item_fn = &ix.raw_method;
39    let ItemFn { attrs, .. } = &item_fn;
40    let Signature { ident: fn_name, inputs, generics, output, .. } = &item_fn.sig;
41
42    let path = spec.module.clone();
43    let first = path.segments[0].clone().ident;
44
45    let new_name = match &spec.prefix {
46        Some(s) if s.is_empty() => fn_name.clone(),
47        o => {
48            let prefix = o.clone().unwrap_or(first.to_string());
49            Ident::new(format!("{}_{}", prefix, fn_name).as_str(), first.span())
50        }
51    };
52
53    // Extract argument names for the function call
54    let arg_names: Vec<Ident> = inputs
55        .iter()
56        .filter_map(|arg| match arg { FnArg::Typed(pt) => Some(&*pt.pat), _ => None })
57        .filter_map(|pt| match pt { syn::Pat::Ident(id) => Some(id.ident.clone()), _ => None })
58        .collect();
59
60    if let Some(w) = &spec.wrapper {
61        parse_quote! {
62            #(#attrs)*
63            pub fn #new_name #generics(#inputs) #output {
64                { #w!(#path::#fn_name, #inputs) }
65            }
66        }
67    } else {
68        parse_quote! {
69            #(#attrs)*
70            pub fn #new_name #generics(#inputs) #output {
71                #path::#fn_name(#(#arg_names),*)
72            }
73        }
74    }
75}
76
77
78/*
79 * Get an anchor Program from the given path, by parsing the file, i.e.
80 * foo::instructions is converted to "$PROGRAM_DIR/src/foo/instructions.rs"
81 */
82
83fn get_program(spec: ModuleSpec) -> Program {
84
85    let mod_path = format!(
86        "{}/{}",
87        std::env::var("CARGO_MANIFEST_DIR").unwrap(),
88        spec.get_file_path()
89    );
90
91    let mut code_str = String::new();
92    std::fs::File::open(mod_path).unwrap().read_to_string(&mut code_str).unwrap();
93    let parsed = syn::parse_file(&code_str).unwrap();
94
95    let program_mod = ItemMod {
96        vis: Visibility::Public(VisPublic { pub_token: Default::default() }),
97        attrs: vec![],
98        mod_token: syn::token::Mod::default(),
99        ident: Ident::new("abc", proc_macro2::Span::call_site()),
100        content: Some((
101            syn::token::Brace::default(),
102            parsed.items,
103        )),
104        semi: None,
105    };
106
107    let program = anchor_syn::parser::program::parse(program_mod).unwrap();
108    assert!(program.fallback_fn.is_none(), "additional program module cant have fallback");
109    program
110}
111
112
113
114/*
115 * Parse the macro arguments
116 */
117
118#[derive(Debug)]
119struct ProgramMacro { modules: Vec<ModuleSpec>, }
120
121impl parse::Parse for ProgramMacro {
122    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
123
124        // Parse `modules`
125        let modules_ident: Ident = input.parse()?;
126        if modules_ident != "modules" {
127            return Err(syn::Error::new(modules_ident.span(), "expected `modules`"));
128        }
129
130        input.parse::<Token![=]>()?;
131
132        // Parse the bracketed list `[cell, placement]`
133        let content;
134        syn::bracketed!(content in input);
135        let specs = content.parse_terminated::<ModuleSpec, Token![,]>(|p| p.parse())?;
136
137        // Convert Punctuated<Ident, _> to Vec<Ident>
138        let modules = specs.into_iter().collect();
139
140        Ok(ProgramMacro { modules })
141    }
142}
143
144#[derive(Clone, Debug)]
145struct ModuleSpec {
146    module: Path,
147    prefix: Option<String>,
148    file_path: Option<String>,
149    wrapper: Option<Path>
150}
151
152impl ModuleSpec {
153    fn get_file_path(&self) -> String {
154        self.file_path.clone().unwrap_or_else(|| {
155            let p = self.module.segments.iter().fold(String::new(), |s, p| format!("{}/{}", s, p.ident));
156            format!("./src{}.rs", p)
157        })
158    }
159}
160
161impl parse::Parse for ModuleSpec {
162    fn parse(input: parse::ParseStream) -> Result<Self> {
163
164        type T = (String, (Option<String>, Option<Path>));
165        fn parse_field(p: parse::ParseStream) -> syn::Result<T> {
166            let name = p.parse::<Ident>()?.to_string();
167            p.parse::<Token![:]>()?;
168            Ok(
169                if name == "file_path" || name == "prefix" {
170                    (name, (Some(p.parse::<LitStr>()?.value()), None))
171                } else if name == "module" || name == "wrapper" {
172                    (name, (None, Some(p.parse::<Path>()?)))
173                } else {
174                    panic!("Invalid module spec param: {}", name);
175                }
176            )
177        }
178
179
180        if input.peek(Ident) {
181            let module = input.parse::<Path>()?;
182            Ok(ModuleSpec { module, prefix: None, file_path: None, wrapper: None })
183        } else {
184            let content;
185            syn::braced!(content in input);
186            let fields = content.parse_terminated::<T, Token![,]>(parse_field)?;
187            let mut hm = fields.clone().into_iter().collect::<HashMap<String, _>>();
188            assert!(hm.len() == fields.len(), "duplicate field");
189            Ok(ModuleSpec {
190                module: hm.remove("module").expect("module is required").1.unwrap(),
191                prefix: hm.remove("prefix").map(|t| t.0).flatten(),
192                file_path: hm.remove("file_path").map(|t| t.0).flatten(),
193                wrapper: hm.remove("wrapper").map(|t| t.1).flatten(),
194            })
195        }
196    }
197}
198
199
200/*
201 * Append instruction functions to main program module
202 */
203
204fn insert_fns_into_first_module(input: proc_macro::TokenStream, fns: Vec<ItemFn>) -> proc_macro::TokenStream {
205
206    let mut item_mod: ItemMod = parse2(input.into()).expect("Failed to parse main program module");
207
208    item_mod.content
209        .as_mut()
210        .expect("Program module has no body?")
211        .1.extend(fns.into_iter().map(Into::into));
212
213    quote! { #item_mod }.into()
214}
215