// canpack_macros/lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{punctuated::Punctuated, spanned::Spanned, Attribute};
5
6#[derive(Default)]
7struct CanpackAttribute {
8    rename: Option<String>,
9    mode: Option<MethodMode>,
10}
11
/// Kind of canister method a generated wrapper is exported as.
///
/// Chosen by a bare mode name inside `#[canpack(...)]`; `init` is not
/// currently supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MethodMode {
    /// Exported with `#[ic_cdk::query]` and candid mode `query`.
    Query,
    /// Exported with `#[ic_cdk::query(composite = true)]` and candid mode
    /// `composite_query`.
    CompositeQuery,
    /// Exported with `#[ic_cdk::update]` and candid mode `update`.
    Update,
}
19
20impl MethodMode {
21    pub fn candid_mode(&self) -> TokenStream2 {
22        match self {
23            // Self::Init => "init",
24            Self::Query => quote! {query},
25            Self::CompositeQuery => quote! {composite_query},
26            Self::Update => quote! {update},
27        }
28    }
29
30    pub fn ic_cdk_attr(&self) -> TokenStream2 {
31        match self {
32            // Self::Init => "init",
33            Self::Query => quote! {query},
34            Self::CompositeQuery => quote! {query(composite = true)},
35            Self::Update => quote! {update},
36        }
37    }
38}
39
40impl<'a> TryFrom<&'a str> for MethodMode {
41    type Error = &'a str;
42
43    fn try_from(mode: &'a str) -> Result<Self, Self::Error> {
44        use MethodMode::*;
45        Ok(match mode {
46            // "init" => Init,
47            "query" => Query,
48            "composite_query" => CompositeQuery,
49            "update" => Update,
50            _ => Err(mode)?,
51        })
52    }
53}
54
55fn get_lit_str(expr: &syn::Expr) -> std::result::Result<syn::LitStr, ()> {
56    if let syn::Expr::Lit(expr) = expr {
57        if let syn::Lit::Str(lit) = &expr.lit {
58            return Ok(lit.clone());
59        }
60    }
61    Err(())
62}
63
64fn parse_canpack_args(attr: &Attribute) -> syn::Result<Punctuated<syn::Meta, syn::Token![,]>> {
65    Ok(attr
66        .meta
67        .require_list()?
68        .parse_args_with(Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)?)
69}
70
71fn read_canpack_attribute(args: Vec<syn::Meta>) -> syn::Result<CanpackAttribute> {
72    let mut attr = CanpackAttribute {
73        rename: None,
74        mode: None,
75    };
76    for meta in args {
77        match &meta {
78            syn::Meta::NameValue(m) if m.path.is_ident("rename") && attr.rename.is_none() => {
79                if let Ok(lit) = get_lit_str(&m.value) {
80                    attr.rename = Some(lit.value());
81                }
82            }
83            syn::Meta::Path(p) if attr.mode.is_none() => {
84                let mode = p.get_ident().unwrap().to_string();
85                attr.mode = Some(mode.as_str().try_into().map_err(|mode| {
86                    syn::Error::new_spanned(meta, format!("unknown mode: {}", mode))
87                })?)
88            }
89            meta => {
90                return Err(syn::Error::new_spanned(
91                    meta,
92                    format!("unknown or conflicting attribute: {}", quote! {#meta}),
93                ))
94            }
95        }
96    }
97    Ok(attr)
98}
99
100#[proc_macro]
101pub fn export(input: TokenStream) -> TokenStream {
102    match export_macro(input) {
103        Ok(output) => output,
104        Err(err) => err.to_compile_error().into(),
105    }
106}
107
108fn export_macro(input: TokenStream) -> syn::Result<TokenStream> {
109    let input2: TokenStream2 = input.into();
110    let module: syn::ItemMod = syn::parse(
111        quote! {
112            mod __canpack_export {
113                #input2
114            }
115        }
116        .into(),
117    )?;
118
119    let mut module_output = quote! {};
120    let mut canpack_output = quote! {};
121
122    let mut functions = vec![];
123
124    for item in module.content.unwrap().1 {
125        if let syn::Item::Fn(function) = item {
126            functions.push(function);
127        } else {
128            return Err(syn::Error::new_spanned(
129                item,
130                "expected a function in `canpack::export!` macro",
131            ));
132        }
133    }
134    for mut function in functions {
135        let (canpack_attrs, fn_attrs) = function
136            .attrs
137            .into_iter()
138            .partition(|attr| attr.path().is_ident("canpack"));
139        function.attrs = fn_attrs;
140        if canpack_attrs.len() > 1 {
141            return Err(syn::Error::new_spanned(
142                canpack_attrs.last().unwrap(),
143                "more than one #[canpack] attribute on the same function",
144            ));
145        }
146
147        let fn_sig = &function.sig;
148        let fn_name = &function.sig.ident;
149        let fn_args = function
150            .sig
151            .inputs
152            .iter()
153            .map(|arg| {
154                if let syn::FnArg::Typed(pat_type) = arg {
155                    if let syn::Pat::Ident(id) = &*pat_type.pat {
156                        return Ok(&id.ident);
157                    }
158                }
159                Err(syn::Error::new_spanned(
160                    arg,
161                    "non-identifier pattern in function args",
162                ))
163            })
164            .collect::<syn::Result<Punctuated<_, syn::Token![,]>>>()?;
165
166        let mut mode = MethodMode::Query;
167        let mut fn_sig_rename = fn_sig.clone();
168
169        for attr in canpack_attrs {
170            let args = parse_canpack_args(&attr)?;
171            let canpack_attr = read_canpack_attribute(args.clone().into_iter().collect())?;
172            mode = canpack_attr.mode.unwrap_or(mode);
173            fn_sig_rename.ident = canpack_attr
174                .rename
175                .map(|name| syn::Ident::new(&name, attr.span()))
176                .unwrap_or(fn_sig_rename.ident);
177        }
178
179        let ic_cdk_attr = mode.ic_cdk_attr();
180        let candid_mode = mode.candid_mode();
181        let opt_await = fn_sig.asyncness.map(|_| quote! { .await });
182
183        module_output = quote! {
184            #module_output
185            #function
186        };
187        canpack_output = quote! {
188            #canpack_output
189            #[ic_cdk::#ic_cdk_attr]
190            #[candid::candid_method(#candid_mode)]
191            #fn_sig_rename {
192                $crate::#fn_name(#fn_args)#opt_await
193            }
194        };
195    }
196    let output = quote! {
197        #module_output
198        #[macro_export]
199        macro_rules! canpack {
200            () => {
201                #canpack_output
202            };
203        }
204    };
205    // Err(syn::Error::new_spanned(
206    //     output.clone(),
207    //     format!(">>>\n{}", quote! {#output}),
208    // ))?;
209    Ok(output.into())
210}