canic_macros/
lib.rs

1//! Canic proc macros.
2//!
3//! Thin, opinionated wrappers around IC CDK endpoint attributes
4//! (`#[query]`, `#[update]`), routed through `canic::cdk::*`.
5//!
6//! Pipeline enforced by generated wrappers:
7//!   guard → auth → policy → dispatch
8
9use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{format_ident, quote};
12use syn::{Expr, ItemFn, Meta, Token, parse::Parser, parse_macro_input, punctuated::Punctuated};
13
14//
15// ============================================================================
16// Public entry points
17// ============================================================================
18//
19
20#[proc_macro_attribute]
21pub fn canic_query(attr: TokenStream, item: TokenStream) -> TokenStream {
22    expand_entry(EndpointKind::Query, attr, item)
23}
24
25#[proc_macro_attribute]
26pub fn canic_update(attr: TokenStream, item: TokenStream) -> TokenStream {
27    expand_entry(EndpointKind::Update, attr, item)
28}
29
30//
31// ============================================================================
32// Shared internal types
33// ============================================================================
34//
35
/// Which kind of endpoint a wrapper is generated for.
///
/// Selects the CDK attribute on the wrapper as well as the matching
/// `guard_app_*` and `dispatch_*` entry points in the generated code.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EndpointKind {
    /// Exported via `::canic::cdk::query`; uses the `*_query` guard/dispatch helpers.
    Query,
    /// Exported via `::canic::cdk::update`; uses the `*_update` guard/dispatch helpers.
    Update,
}
41
42//
43// ============================================================================
44// parse — attribute grammar only
45// ============================================================================
46//
47
48mod parse {
49    use super::*;
50
51    #[derive(Clone, Debug)]
52    pub enum AuthSpec {
53        Any(Vec<Expr>),
54        All(Vec<Expr>),
55    }
56
57    #[derive(Debug)]
58    pub struct ParsedArgs {
59        pub forwarded: Vec<TokenStream2>,
60        pub app_guard: bool,
61        pub user_guard: bool,
62        pub auth: Option<AuthSpec>,
63        pub policies: Vec<Expr>,
64    }
65
66    pub fn parse_args(attr: TokenStream2) -> syn::Result<ParsedArgs> {
67        let Ok(metas) = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(attr.clone()) else {
68            // If the attr doesn't parse as Meta list, fall back to forwarding raw tokens to the CDK.
69            // This preserves compatibility with CDK syntax we don't model.
70            if attr.is_empty() {
71                return Ok(empty());
72            }
73
74            return Ok(ParsedArgs {
75                forwarded: vec![attr],
76                ..empty()
77            });
78        };
79
80        let mut forwarded = Vec::new();
81        let mut app_guard = false;
82        let mut user_guard = false;
83        let mut auth = None::<AuthSpec>;
84        let mut policies = Vec::<Expr>::new();
85
86        for meta in metas {
87            match meta {
88                // guard(...)
89                //
90                // Canic-specific guard stage. Top-level `app` is no longer accepted.
91                Meta::List(list) if list.path.is_ident("guard") => {
92                    let inner = Punctuated::<Meta, Token![,]>::parse_terminated
93                        .parse2(list.tokens.clone())?
94                        .into_iter()
95                        .collect::<Vec<_>>();
96
97                    if inner.is_empty() {
98                        return Err(syn::Error::new_spanned(
99                            list,
100                            "`guard(...)` expects at least one argument (e.g., `guard(app)`)",
101                        ));
102                    }
103
104                    // For now, support only guard(app). You can widen this later.
105                    for item in inner {
106                        match item {
107                            Meta::Path(p) if p.is_ident("app") => {
108                                app_guard = true;
109                            }
110                            other => {
111                                return Err(syn::Error::new_spanned(
112                                    other,
113                                    "only `guard(app)` is supported",
114                                ));
115                            }
116                        }
117                    }
118                }
119
120                // auth_any(...)
121                Meta::List(list) if list.path.is_ident("auth_any") => {
122                    if auth.is_some() {
123                        return Err(conflicting_auth(&list));
124                    }
125                    let rules = parse_rules(&list)?;
126                    auth = Some(AuthSpec::Any(rules));
127                }
128
129                // auth_all(...)
130                Meta::List(list) if list.path.is_ident("auth_all") => {
131                    if auth.is_some() {
132                        return Err(conflicting_auth(&list));
133                    }
134                    let rules = parse_rules(&list)?;
135                    auth = Some(AuthSpec::All(rules));
136                }
137
138                // policy(...)
139                //
140                // Parse as Expr so you can do policy(local_only()), policy(max_rounds(rounds, 10_000)), etc.
141                Meta::List(list) if list.path.is_ident("policy") => {
142                    let parsed = Punctuated::<Expr, Token![,]>::parse_terminated
143                        .parse2(list.tokens.clone())?
144                        .into_iter()
145                        .collect::<Vec<_>>();
146
147                    if parsed.is_empty() {
148                        return Err(syn::Error::new_spanned(
149                            list,
150                            "`policy(...)` expects at least one policy expression",
151                        ));
152                    }
153
154                    policies.extend(parsed);
155                }
156
157                // explicit CDK guard = ...
158                //
159                // We still forward it, but track that it exists so validation can ban combinations.
160                Meta::NameValue(nv) if nv.path.is_ident("guard") => {
161                    user_guard = true;
162                    forwarded.push(quote!(#nv));
163                }
164
165                // Everything else is forwarded to the CDK attribute unchanged.
166                _ => forwarded.push(quote!(#meta)),
167            }
168        }
169
170        Ok(ParsedArgs {
171            forwarded,
172            app_guard,
173            user_guard,
174            auth,
175            policies,
176        })
177    }
178    const fn empty() -> ParsedArgs {
179        ParsedArgs {
180            forwarded: Vec::new(),
181            app_guard: false,
182            user_guard: false,
183            auth: None,
184            policies: Vec::new(),
185        }
186    }
187
188    fn parse_rules(list: &syn::MetaList) -> syn::Result<Vec<Expr>> {
189        let rules = Punctuated::<Expr, Token![,]>::parse_terminated
190            .parse2(list.tokens.clone())?
191            .into_iter()
192            .collect::<Vec<_>>();
193
194        if rules.is_empty() {
195            return Err(syn::Error::new_spanned(
196                list,
197                "authorization requires at least one rule",
198            ));
199        }
200
201        Ok(rules)
202    }
203
204    fn conflicting_auth(list: &syn::MetaList) -> syn::Error {
205        syn::Error::new_spanned(list, "conflicting authorization composition")
206    }
207}
208
209//
210// ============================================================================
211// validate — semantic constraints
212// ============================================================================
213//
214
215mod validate {
216    use super::*;
217    use parse::{AuthSpec, ParsedArgs};
218
219    pub struct ValidatedArgs {
220        pub forwarded: Vec<TokenStream2>,
221        pub app_guard: bool,
222        pub auth: Option<AuthSpec>,
223        pub policies: Vec<Expr>,
224    }
225
226    pub fn validate(
227        parsed: ParsedArgs,
228        sig: &syn::Signature,
229        asyncness: bool,
230    ) -> syn::Result<ValidatedArgs> {
231        if parsed.app_guard && parsed.user_guard {
232            return Err(syn::Error::new_spanned(
233                &sig.ident,
234                "`app` cannot be combined with `guard = ...`",
235            ));
236        }
237
238        if parsed.auth.is_some() && parsed.user_guard {
239            return Err(syn::Error::new_spanned(
240                &sig.ident,
241                "authorization cannot be combined with `guard = ...`",
242            ));
243        }
244
245        if parsed.auth.is_some() {
246            if !asyncness {
247                return Err(syn::Error::new_spanned(
248                    &sig.ident,
249                    "authorization requires `async fn`",
250                ));
251            }
252            if !returns_result(sig) {
253                return Err(syn::Error::new_spanned(
254                    &sig.output,
255                    "authorized endpoints must return `Result<_, From<canic::Error>>`",
256                ));
257            }
258        }
259
260        if parsed.app_guard && !returns_result(sig) {
261            return Err(syn::Error::new_spanned(
262                &sig.output,
263                "`app` guard requires `Result<_, From<canic::Error>>`",
264            ));
265        }
266
267        if !parsed.policies.is_empty() && !returns_result(sig) {
268            return Err(syn::Error::new_spanned(
269                &sig.output,
270                "`policy(...)` requires `Result<_, From<canic::Error>>`",
271            ));
272        }
273
274        Ok(ValidatedArgs {
275            forwarded: parsed.forwarded,
276            app_guard: parsed.app_guard,
277            auth: parsed.auth,
278            policies: parsed.policies,
279        })
280    }
281
282    fn returns_result(sig: &syn::Signature) -> bool {
283        let syn::ReturnType::Type(_, ty) = &sig.output else {
284            return false;
285        };
286        let syn::Type::Path(ty) = &**ty else {
287            return false;
288        };
289        ty.path
290            .segments
291            .last()
292            .is_some_and(|seg| seg.ident == "Result")
293    }
294}
295
296//
297// ============================================================================
298// expand — code generation only
299// ============================================================================
300//
301
302mod expand {
303    use super::*;
304    use parse::AuthSpec;
305    use validate::ValidatedArgs;
306
307    pub fn expand(kind: EndpointKind, args: ValidatedArgs, mut func: ItemFn) -> TokenStream {
308        let attrs = func.attrs.clone();
309        let orig_sig = func.sig.clone();
310        let orig_name = orig_sig.ident.clone();
311        let vis = func.vis.clone();
312        let inputs = orig_sig.inputs.clone();
313        let output = orig_sig.output.clone();
314        let asyncness = orig_sig.asyncness.is_some();
315        let returns_result = returns_result(&orig_sig);
316
317        let impl_name = format_ident!("__canic_impl_{}", orig_name);
318        func.sig.ident = impl_name.clone();
319
320        let cdk_attr = cdk_attr(kind, &args.forwarded);
321
322        let dispatch = dispatch(kind, asyncness);
323
324        let wrapper_sig = syn::Signature {
325            ident: orig_name.clone(),
326            inputs,
327            output,
328            ..orig_sig.clone()
329        };
330
331        let label = orig_name.to_string();
332
333        let attempted = attempted(&label);
334        let guard = guard(kind, args.app_guard, &label);
335        let auth = auth(args.auth.as_ref(), &label);
336        let policy = policy(&args.policies, &label);
337
338        let call_args = match extract_args(&orig_sig) {
339            Ok(v) => v,
340            Err(e) => return e.to_compile_error().into(),
341        };
342
343        let call = call(asyncness, dispatch, &label, impl_name, &call_args);
344        let completion = completion(&label, returns_result, call);
345
346        quote! {
347           #(#attrs)*
348           #cdk_attr
349            #vis #wrapper_sig {
350                #attempted
351                #guard
352                #auth
353                #policy
354                #completion
355            }
356
357            #func
358        }
359        .into()
360    }
361
362    fn returns_result(sig: &syn::Signature) -> bool {
363        let syn::ReturnType::Type(_, ty) = &sig.output else {
364            return false;
365        };
366        let syn::Type::Path(ty) = &**ty else {
367            return false;
368        };
369        ty.path
370            .segments
371            .last()
372            .is_some_and(|seg| seg.ident == "Result")
373    }
374
375    fn dispatch(kind: EndpointKind, asyncness: bool) -> TokenStream2 {
376        match (kind, asyncness) {
377            (EndpointKind::Query, false) => quote!(::canic::core::dispatch::dispatch_query),
378            (EndpointKind::Query, true) => quote!(::canic::core::dispatch::dispatch_query_async),
379            (EndpointKind::Update, false) => quote!(::canic::core::dispatch::dispatch_update),
380            (EndpointKind::Update, true) => quote!(::canic::core::dispatch::dispatch_update_async),
381        }
382    }
383
384    fn record_access_denied(label: &String, kind: TokenStream2) -> TokenStream2 {
385        quote! {
386            ::canic::core::ops::runtime::metrics::AccessMetrics::increment(#label, #kind);
387        }
388    }
389
390    fn attempted(label: &String) -> TokenStream2 {
391        quote! {
392            ::canic::core::ops::runtime::metrics::EndpointAttemptMetrics::increment_attempted(#label);
393        }
394    }
395
396    fn guard(kind: EndpointKind, enabled: bool, label: &String) -> TokenStream2 {
397        if !enabled {
398            return quote!();
399        }
400
401        let metric = record_access_denied(
402            label,
403            quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Guard),
404        );
405
406        match kind {
407            EndpointKind::Query => quote! {
408                if let Err(err) = ::canic::core::guard::guard_app_query() {
409                    #metric
410                    return Err(err.into());
411                }
412            },
413            EndpointKind::Update => quote! {
414                if let Err(err) = ::canic::core::guard::guard_app_update() {
415                    #metric
416                    return Err(err.into());
417                }
418            },
419        }
420    }
421
422    fn auth(auth: Option<&AuthSpec>, label: &String) -> TokenStream2 {
423        let metric = record_access_denied(
424            label,
425            quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Auth),
426        );
427
428        match auth {
429            Some(AuthSpec::Any(rules)) => quote! {
430                if let Err(err) = ::canic::core::auth_require_any!(#(#rules),*) {
431                    #metric
432                    return Err(err.into());
433                }
434            },
435            Some(AuthSpec::All(rules)) => quote! {
436                if let Err(err) = ::canic::core::auth_require_all!(#(#rules),*) {
437                    #metric
438                    return Err(err.into());
439                }
440            },
441            None => quote!(),
442        }
443    }
444
445    fn policy(policies: &[Expr], label: &String) -> TokenStream2 {
446        if policies.is_empty() {
447            return quote!();
448        }
449
450        let metric = record_access_denied(
451            label,
452            quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Policy),
453        );
454
455        let checks = policies.iter().map(|expr| {
456            quote! {
457                if let Err(err) = #expr().await {
458                    #metric
459                    return Err(err.into());
460                }
461            }
462        });
463        quote!(#(#checks)*)
464    }
465
466    fn call(
467        asyncness: bool,
468        dispatch: TokenStream2,
469        label: &String,
470        impl_name: syn::Ident,
471        call_args: &[TokenStream2],
472    ) -> TokenStream2 {
473        if asyncness {
474            quote! {
475                #dispatch(#label, || async move {
476                    #impl_name(#(#call_args),*).await
477                }).await
478            }
479        } else {
480            quote! {
481                #dispatch(#label, || {
482                    #impl_name(#(#call_args),*)
483                })
484            }
485        }
486    }
487
488    fn completion(label: &String, returns_result: bool, call: TokenStream2) -> TokenStream2 {
489        let result_metrics = if returns_result {
490            quote! {
491                if out.is_ok() {
492                    ::canic::core::ops::runtime::metrics::EndpointResultMetrics::increment_ok(#label);
493                } else {
494                    ::canic::core::ops::runtime::metrics::EndpointResultMetrics::increment_err(#label);
495                }
496            }
497        } else {
498            quote!()
499        };
500
501        quote! {
502            {
503                let out = #call;
504                ::canic::core::ops::runtime::metrics::EndpointAttemptMetrics::increment_completed(#label);
505                #result_metrics
506                out
507            }
508        }
509    }
510
511    fn extract_args(sig: &syn::Signature) -> syn::Result<Vec<TokenStream2>> {
512        let mut out = Vec::new();
513        for input in &sig.inputs {
514            match input {
515                syn::FnArg::Typed(pat) => match &*pat.pat {
516                    syn::Pat::Ident(id) => out.push(quote!(#id)),
517                    _ => {
518                        return Err(syn::Error::new_spanned(
519                            &pat.pat,
520                            "destructuring parameters not supported",
521                        ));
522                    }
523                },
524                syn::FnArg::Receiver(r) => {
525                    return Err(syn::Error::new_spanned(
526                        r,
527                        "`self` not supported in canic endpoints",
528                    ));
529                }
530            }
531        }
532        Ok(out)
533    }
534}
535
536fn cdk_attr(kind: EndpointKind, forwarded: &[TokenStream2]) -> TokenStream2 {
537    match kind {
538        EndpointKind::Query => {
539            if forwarded.is_empty() {
540                quote!(#[::canic::cdk::query])
541            } else {
542                quote!(#[::canic::cdk::query(#(#forwarded),*)])
543            }
544        }
545        EndpointKind::Update => {
546            if forwarded.is_empty() {
547                quote!(#[::canic::cdk::update])
548            } else {
549                quote!(#[::canic::cdk::update(#(#forwarded),*)])
550            }
551        }
552    }
553}
554
555//
556// ============================================================================
557// Entry dispatcher
558// ============================================================================
559//
560
561fn expand_entry(kind: EndpointKind, attr: TokenStream, item: TokenStream) -> TokenStream {
562    let func = parse_macro_input!(item as ItemFn);
563    let sig = func.sig.clone();
564    let asyncness = sig.asyncness.is_some();
565
566    let parsed = match parse::parse_args(attr.into()) {
567        Ok(v) => v,
568        Err(e) => return e.to_compile_error().into(),
569    };
570
571    let validated = match validate::validate(parsed, &sig, asyncness) {
572        Ok(v) => v,
573        Err(e) => return e.to_compile_error().into(),
574    };
575
576    expand::expand(kind, validated, func)
577}