anchor_modular_program/
lib.rs1#![warn(missing_docs)]
2
3use std::{collections::HashMap, io::Read};
24use anchor_syn::*;
25use syn::*;
26use quote::*;
27
28
29#[proc_macro_attribute]
62pub fn modular_program(
63 args: proc_macro::TokenStream,
64 input: proc_macro::TokenStream,
65) -> proc_macro::TokenStream {
66 let ProgramMacro { modules } = syn::parse_macro_input!(args as ProgramMacro);
67
68 let fns = modules.into_iter()
69 .map(|m| (m.clone(), get_program(m.clone())))
70 .flat_map(|(spec, p)| {
71 p.ixs.into_iter().map(|ix| build_relay(&spec, ix)).collect::<Vec<_>>()
72 })
73 .collect();
74
75 let input = insert_fns_into_first_module(input, fns);
76 let program = syn::parse_macro_input!(input as anchor_syn::Program);
77 program.to_token_stream().into()
78}
79
80
81fn build_relay(spec: &ModuleSpec, ix: Ix) -> ItemFn {
90
91 let item_fn = &ix.raw_method;
92 let ItemFn { attrs, .. } = &item_fn;
93 let Signature { ident: fn_name, inputs, generics, output, .. } = &item_fn.sig;
94
95 let path = spec.module.clone();
96 let first = path.segments[0].clone().ident;
97
98 let new_name = match &spec.prefix {
99 Some(s) if s.is_empty() => fn_name.clone(),
100 o => {
101 let prefix = o.clone().unwrap_or(first.to_string());
102 Ident::new(format!("{}_{}", prefix, fn_name).as_str(), first.span())
103 }
104 };
105
106 let mut inputs = inputs.clone();
107 let arg_names = inputs.iter_mut().filter_map(|arg| {
108 if let FnArg::Typed(pt) = arg {
109 if let syn::Pat::Ident(id) = &mut *pt.pat {
110 id.mutability = None;
111 return Some(id.ident.clone());
112 }
113 }
114 None
115 }).collect::<Vec<_>>();
116
117 if let Some(w) = &spec.wrapper {
118 parse_quote! {
119 #(#attrs)*
120 pub fn #new_name #generics(#inputs) #output {
121 { #w!(#path::#fn_name, #inputs) }
122 }
123 }
124 } else {
125 parse_quote! {
126 #(#attrs)*
127 pub fn #new_name #generics(#inputs) #output {
128 #path::#fn_name(#(#arg_names),*)
129 }
130 }
131 }
132}
133
134
135fn get_program(spec: ModuleSpec) -> Program {
141
142 let mod_path = format!(
143 "{}/{}",
144 std::env::var("CARGO_MANIFEST_DIR").unwrap(),
145 spec.get_file_path()
146 );
147
148 let mut code_str = String::new();
149 std::fs::File::open(mod_path).unwrap().read_to_string(&mut code_str).unwrap();
150 let parsed = syn::parse_file(&code_str).unwrap();
151
152 let program_mod = ItemMod {
153 vis: Visibility::Public(VisPublic { pub_token: Default::default() }),
154 attrs: vec![],
155 mod_token: syn::token::Mod::default(),
156 ident: Ident::new("abc", proc_macro2::Span::call_site()),
157 content: Some((
158 syn::token::Brace::default(),
159 parsed.items,
160 )),
161 semi: None,
162 };
163
164 let program = anchor_syn::parser::program::parse(program_mod).unwrap();
165 assert!(program.fallback_fn.is_none(), "additional program module cant have fallback");
166 program
167}
168
169
170
171#[derive(Debug)]
176struct ProgramMacro { modules: Vec<ModuleSpec>, }
177
178impl parse::Parse for ProgramMacro {
179 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
180
181 let modules_ident: Ident = input.parse()?;
183 if modules_ident != "modules" {
184 return Err(syn::Error::new(modules_ident.span(), "expected `modules`"));
185 }
186
187 input.parse::<Token![=]>()?;
188
189 let content;
191 syn::bracketed!(content in input);
192 let specs = content.parse_terminated::<ModuleSpec, Token![,]>(|p| p.parse())?;
193
194 let modules = specs.into_iter().collect();
196
197 Ok(ProgramMacro { modules })
198 }
199}
200
201#[derive(Clone, Debug)]
202struct ModuleSpec {
203 module: Path,
204 prefix: Option<String>,
205 file_path: Option<String>,
206 wrapper: Option<Path>
207}
208
209impl ModuleSpec {
210 fn get_file_path(&self) -> String {
211 self.file_path.clone().unwrap_or_else(|| {
212 let p = self.module.segments.iter().fold(String::new(), |s, p| format!("{}/{}", s, p.ident));
213 format!("./src{}.rs", p)
214 })
215 }
216}
217
218impl parse::Parse for ModuleSpec {
219 fn parse(input: parse::ParseStream) -> Result<Self> {
220
221 type T = (String, (Option<String>, Option<Path>));
222 fn parse_field(p: parse::ParseStream) -> syn::Result<T> {
223 let name = p.parse::<Ident>()?.to_string();
224 p.parse::<Token![:]>()?;
225 Ok(
226 if name == "file_path" || name == "prefix" {
227 (name, (Some(p.parse::<LitStr>()?.value()), None))
228 } else if name == "module" || name == "wrapper" {
229 (name, (None, Some(p.parse::<Path>()?)))
230 } else {
231 panic!("Invalid module spec param: {}", name);
232 }
233 )
234 }
235
236
237 if input.peek(Ident) {
238 let module = input.parse::<Path>()?;
239 Ok(ModuleSpec { module, prefix: None, file_path: None, wrapper: None })
240 } else {
241 let content;
242 syn::braced!(content in input);
243 let fields = content.parse_terminated::<T, Token![,]>(parse_field)?;
244 let mut hm = fields.clone().into_iter().collect::<HashMap<String, _>>();
245 assert!(hm.len() == fields.len(), "duplicate field");
246 Ok(ModuleSpec {
247 module: hm.remove("module").expect("module is required").1.unwrap(),
248 prefix: hm.remove("prefix").map(|t| t.0).flatten(),
249 file_path: hm.remove("file_path").map(|t| t.0).flatten(),
250 wrapper: hm.remove("wrapper").map(|t| t.1).flatten(),
251 })
252 }
253 }
254}
255
256
257fn insert_fns_into_first_module(input: proc_macro::TokenStream, fns: Vec<ItemFn>) -> proc_macro::TokenStream {
262
263 let mut item_mod: ItemMod = parse2(input.into()).expect("Failed to parse main program module");
264
265 item_mod.content
266 .as_mut()
267 .expect("Program module has no body?")
268 .1.extend(fns.into_iter().map(Into::into));
269
270 quote! { #item_mod }.into()
271}
272