1use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{format_ident, quote};
12use syn::{Expr, ItemFn, Meta, Token, parse::Parser, parse_macro_input, punctuated::Punctuated};
13
14#[proc_macro_attribute]
21pub fn canic_query(attr: TokenStream, item: TokenStream) -> TokenStream {
22 expand_entry(EndpointKind::Query, attr, item)
23}
24
25#[proc_macro_attribute]
26pub fn canic_update(attr: TokenStream, item: TokenStream) -> TokenStream {
27 expand_entry(EndpointKind::Update, attr, item)
28}
29
/// Which kind of canister entry point is being generated.
///
/// Selects the cdk attribute (`query` vs `update`), the dispatcher function
/// and the app-guard function used in the generated wrapper.
#[derive(Copy, Clone)]
enum EndpointKind {
    /// Generated with `#[::canic::cdk::query]`.
    Query,
    /// Generated with `#[::canic::cdk::update]`.
    Update,
}
41
42mod parse {
49 use super::*;
50
51 #[derive(Clone, Debug)]
52 pub enum AuthSpec {
53 Any(Vec<Expr>),
54 All(Vec<Expr>),
55 }
56
57 #[derive(Debug)]
58 pub struct ParsedArgs {
59 pub forwarded: Vec<TokenStream2>,
60 pub app_guard: bool,
61 pub user_guard: bool,
62 pub auth: Option<AuthSpec>,
63 pub policies: Vec<Expr>,
64 }
65
66 pub fn parse_args(attr: TokenStream2) -> syn::Result<ParsedArgs> {
67 let Ok(metas) = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(attr.clone()) else {
68 if attr.is_empty() {
71 return Ok(empty());
72 }
73
74 return Ok(ParsedArgs {
75 forwarded: vec![attr],
76 ..empty()
77 });
78 };
79
80 let mut forwarded = Vec::new();
81 let mut app_guard = false;
82 let mut user_guard = false;
83 let mut auth = None::<AuthSpec>;
84 let mut policies = Vec::<Expr>::new();
85
86 for meta in metas {
87 match meta {
88 Meta::List(list) if list.path.is_ident("guard") => {
92 let inner = Punctuated::<Meta, Token![,]>::parse_terminated
93 .parse2(list.tokens.clone())?
94 .into_iter()
95 .collect::<Vec<_>>();
96
97 if inner.is_empty() {
98 return Err(syn::Error::new_spanned(
99 list,
100 "`guard(...)` expects at least one argument (e.g., `guard(app)`)",
101 ));
102 }
103
104 for item in inner {
106 match item {
107 Meta::Path(p) if p.is_ident("app") => {
108 app_guard = true;
109 }
110 other => {
111 return Err(syn::Error::new_spanned(
112 other,
113 "only `guard(app)` is supported",
114 ));
115 }
116 }
117 }
118 }
119
120 Meta::List(list) if list.path.is_ident("auth_any") => {
122 if auth.is_some() {
123 return Err(conflicting_auth(&list));
124 }
125 let rules = parse_rules(&list)?;
126 auth = Some(AuthSpec::Any(rules));
127 }
128
129 Meta::List(list) if list.path.is_ident("auth_all") => {
131 if auth.is_some() {
132 return Err(conflicting_auth(&list));
133 }
134 let rules = parse_rules(&list)?;
135 auth = Some(AuthSpec::All(rules));
136 }
137
138 Meta::List(list) if list.path.is_ident("policy") => {
142 let parsed = Punctuated::<Expr, Token![,]>::parse_terminated
143 .parse2(list.tokens.clone())?
144 .into_iter()
145 .collect::<Vec<_>>();
146
147 if parsed.is_empty() {
148 return Err(syn::Error::new_spanned(
149 list,
150 "`policy(...)` expects at least one policy expression",
151 ));
152 }
153
154 policies.extend(parsed);
155 }
156
157 Meta::NameValue(nv) if nv.path.is_ident("guard") => {
161 user_guard = true;
162 forwarded.push(quote!(#nv));
163 }
164
165 _ => forwarded.push(quote!(#meta)),
167 }
168 }
169
170 Ok(ParsedArgs {
171 forwarded,
172 app_guard,
173 user_guard,
174 auth,
175 policies,
176 })
177 }
178 const fn empty() -> ParsedArgs {
179 ParsedArgs {
180 forwarded: Vec::new(),
181 app_guard: false,
182 user_guard: false,
183 auth: None,
184 policies: Vec::new(),
185 }
186 }
187
188 fn parse_rules(list: &syn::MetaList) -> syn::Result<Vec<Expr>> {
189 let rules = Punctuated::<Expr, Token![,]>::parse_terminated
190 .parse2(list.tokens.clone())?
191 .into_iter()
192 .collect::<Vec<_>>();
193
194 if rules.is_empty() {
195 return Err(syn::Error::new_spanned(
196 list,
197 "authorization requires at least one rule",
198 ));
199 }
200
201 Ok(rules)
202 }
203
204 fn conflicting_auth(list: &syn::MetaList) -> syn::Error {
205 syn::Error::new_spanned(list, "conflicting authorization composition")
206 }
207}
208
209mod validate {
216 use super::*;
217 use parse::{AuthSpec, ParsedArgs};
218
219 pub struct ValidatedArgs {
220 pub forwarded: Vec<TokenStream2>,
221 pub app_guard: bool,
222 pub auth: Option<AuthSpec>,
223 pub policies: Vec<Expr>,
224 }
225
226 pub fn validate(
227 parsed: ParsedArgs,
228 sig: &syn::Signature,
229 asyncness: bool,
230 ) -> syn::Result<ValidatedArgs> {
231 if parsed.app_guard && parsed.user_guard {
232 return Err(syn::Error::new_spanned(
233 &sig.ident,
234 "`app` cannot be combined with `guard = ...`",
235 ));
236 }
237
238 if parsed.auth.is_some() && parsed.user_guard {
239 return Err(syn::Error::new_spanned(
240 &sig.ident,
241 "authorization cannot be combined with `guard = ...`",
242 ));
243 }
244
245 if parsed.auth.is_some() {
246 if !asyncness {
247 return Err(syn::Error::new_spanned(
248 &sig.ident,
249 "authorization requires `async fn`",
250 ));
251 }
252 if !returns_result(sig) {
253 return Err(syn::Error::new_spanned(
254 &sig.output,
255 "authorized endpoints must return `Result<_, From<canic::Error>>`",
256 ));
257 }
258 }
259
260 if parsed.app_guard && !returns_result(sig) {
261 return Err(syn::Error::new_spanned(
262 &sig.output,
263 "`app` guard requires `Result<_, From<canic::Error>>`",
264 ));
265 }
266
267 if !parsed.policies.is_empty() && !returns_result(sig) {
268 return Err(syn::Error::new_spanned(
269 &sig.output,
270 "`policy(...)` requires `Result<_, From<canic::Error>>`",
271 ));
272 }
273
274 Ok(ValidatedArgs {
275 forwarded: parsed.forwarded,
276 app_guard: parsed.app_guard,
277 auth: parsed.auth,
278 policies: parsed.policies,
279 })
280 }
281
282 fn returns_result(sig: &syn::Signature) -> bool {
283 let syn::ReturnType::Type(_, ty) = &sig.output else {
284 return false;
285 };
286 let syn::Type::Path(ty) = &**ty else {
287 return false;
288 };
289 ty.path
290 .segments
291 .last()
292 .is_some_and(|seg| seg.ident == "Result")
293 }
294}
295
296mod expand {
303 use super::*;
304 use parse::AuthSpec;
305 use validate::ValidatedArgs;
306
307 pub fn expand(kind: EndpointKind, args: ValidatedArgs, mut func: ItemFn) -> TokenStream {
308 let attrs = func.attrs.clone();
309 let orig_sig = func.sig.clone();
310 let orig_name = orig_sig.ident.clone();
311 let vis = func.vis.clone();
312 let inputs = orig_sig.inputs.clone();
313 let output = orig_sig.output.clone();
314 let asyncness = orig_sig.asyncness.is_some();
315 let returns_result = returns_result(&orig_sig);
316
317 let impl_name = format_ident!("__canic_impl_{}", orig_name);
318 func.sig.ident = impl_name.clone();
319
320 let cdk_attr = cdk_attr(kind, &args.forwarded);
321
322 let dispatch = dispatch(kind, asyncness);
323
324 let wrapper_sig = syn::Signature {
325 ident: orig_name.clone(),
326 inputs,
327 output,
328 ..orig_sig.clone()
329 };
330
331 let label = orig_name.to_string();
332
333 let attempted = attempted(&label);
334 let guard = guard(kind, args.app_guard, &label);
335 let auth = auth(args.auth.as_ref(), &label);
336 let policy = policy(&args.policies, &label);
337
338 let call_args = match extract_args(&orig_sig) {
339 Ok(v) => v,
340 Err(e) => return e.to_compile_error().into(),
341 };
342
343 let call = call(asyncness, dispatch, &label, impl_name, &call_args);
344 let completion = completion(&label, returns_result, call);
345
346 quote! {
347 #(#attrs)*
348 #cdk_attr
349 #vis #wrapper_sig {
350 #attempted
351 #guard
352 #auth
353 #policy
354 #completion
355 }
356
357 #func
358 }
359 .into()
360 }
361
362 fn returns_result(sig: &syn::Signature) -> bool {
363 let syn::ReturnType::Type(_, ty) = &sig.output else {
364 return false;
365 };
366 let syn::Type::Path(ty) = &**ty else {
367 return false;
368 };
369 ty.path
370 .segments
371 .last()
372 .is_some_and(|seg| seg.ident == "Result")
373 }
374
375 fn dispatch(kind: EndpointKind, asyncness: bool) -> TokenStream2 {
376 match (kind, asyncness) {
377 (EndpointKind::Query, false) => quote!(::canic::core::dispatch::dispatch_query),
378 (EndpointKind::Query, true) => quote!(::canic::core::dispatch::dispatch_query_async),
379 (EndpointKind::Update, false) => quote!(::canic::core::dispatch::dispatch_update),
380 (EndpointKind::Update, true) => quote!(::canic::core::dispatch::dispatch_update_async),
381 }
382 }
383
384 fn record_access_denied(label: &String, kind: TokenStream2) -> TokenStream2 {
385 quote! {
386 ::canic::core::ops::runtime::metrics::AccessMetrics::increment(#label, #kind);
387 }
388 }
389
390 fn attempted(label: &String) -> TokenStream2 {
391 quote! {
392 ::canic::core::ops::runtime::metrics::EndpointAttemptMetrics::increment_attempted(#label);
393 }
394 }
395
396 fn guard(kind: EndpointKind, enabled: bool, label: &String) -> TokenStream2 {
397 if !enabled {
398 return quote!();
399 }
400
401 let metric = record_access_denied(
402 label,
403 quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Guard),
404 );
405
406 match kind {
407 EndpointKind::Query => quote! {
408 if let Err(err) = ::canic::core::guard::guard_app_query() {
409 #metric
410 return Err(err.into());
411 }
412 },
413 EndpointKind::Update => quote! {
414 if let Err(err) = ::canic::core::guard::guard_app_update() {
415 #metric
416 return Err(err.into());
417 }
418 },
419 }
420 }
421
422 fn auth(auth: Option<&AuthSpec>, label: &String) -> TokenStream2 {
423 let metric = record_access_denied(
424 label,
425 quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Auth),
426 );
427
428 match auth {
429 Some(AuthSpec::Any(rules)) => quote! {
430 if let Err(err) = ::canic::core::auth_require_any!(#(#rules),*) {
431 #metric
432 return Err(err.into());
433 }
434 },
435 Some(AuthSpec::All(rules)) => quote! {
436 if let Err(err) = ::canic::core::auth_require_all!(#(#rules),*) {
437 #metric
438 return Err(err.into());
439 }
440 },
441 None => quote!(),
442 }
443 }
444
445 fn policy(policies: &[Expr], label: &String) -> TokenStream2 {
446 if policies.is_empty() {
447 return quote!();
448 }
449
450 let metric = record_access_denied(
451 label,
452 quote!(::canic::core::ops::runtime::metrics::AccessMetricKind::Policy),
453 );
454
455 let checks = policies.iter().map(|expr| {
456 quote! {
457 if let Err(err) = #expr().await {
458 #metric
459 return Err(err.into());
460 }
461 }
462 });
463 quote!(#(#checks)*)
464 }
465
466 fn call(
467 asyncness: bool,
468 dispatch: TokenStream2,
469 label: &String,
470 impl_name: syn::Ident,
471 call_args: &[TokenStream2],
472 ) -> TokenStream2 {
473 if asyncness {
474 quote! {
475 #dispatch(#label, || async move {
476 #impl_name(#(#call_args),*).await
477 }).await
478 }
479 } else {
480 quote! {
481 #dispatch(#label, || {
482 #impl_name(#(#call_args),*)
483 })
484 }
485 }
486 }
487
488 fn completion(label: &String, returns_result: bool, call: TokenStream2) -> TokenStream2 {
489 let result_metrics = if returns_result {
490 quote! {
491 if out.is_ok() {
492 ::canic::core::ops::runtime::metrics::EndpointResultMetrics::increment_ok(#label);
493 } else {
494 ::canic::core::ops::runtime::metrics::EndpointResultMetrics::increment_err(#label);
495 }
496 }
497 } else {
498 quote!()
499 };
500
501 quote! {
502 {
503 let out = #call;
504 ::canic::core::ops::runtime::metrics::EndpointAttemptMetrics::increment_completed(#label);
505 #result_metrics
506 out
507 }
508 }
509 }
510
511 fn extract_args(sig: &syn::Signature) -> syn::Result<Vec<TokenStream2>> {
512 let mut out = Vec::new();
513 for input in &sig.inputs {
514 match input {
515 syn::FnArg::Typed(pat) => match &*pat.pat {
516 syn::Pat::Ident(id) => out.push(quote!(#id)),
517 _ => {
518 return Err(syn::Error::new_spanned(
519 &pat.pat,
520 "destructuring parameters not supported",
521 ));
522 }
523 },
524 syn::FnArg::Receiver(r) => {
525 return Err(syn::Error::new_spanned(
526 r,
527 "`self` not supported in canic endpoints",
528 ));
529 }
530 }
531 }
532 Ok(out)
533 }
534}
535
536fn cdk_attr(kind: EndpointKind, forwarded: &[TokenStream2]) -> TokenStream2 {
537 match kind {
538 EndpointKind::Query => {
539 if forwarded.is_empty() {
540 quote!(#[::canic::cdk::query])
541 } else {
542 quote!(#[::canic::cdk::query(#(#forwarded),*)])
543 }
544 }
545 EndpointKind::Update => {
546 if forwarded.is_empty() {
547 quote!(#[::canic::cdk::update])
548 } else {
549 quote!(#[::canic::cdk::update(#(#forwarded),*)])
550 }
551 }
552 }
553}
554
555fn expand_entry(kind: EndpointKind, attr: TokenStream, item: TokenStream) -> TokenStream {
562 let func = parse_macro_input!(item as ItemFn);
563 let sig = func.sig.clone();
564 let asyncness = sig.asyncness.is_some();
565
566 let parsed = match parse::parse_args(attr.into()) {
567 Ok(v) => v,
568 Err(e) => return e.to_compile_error().into(),
569 };
570
571 let validated = match validate::validate(parsed, &sig, asyncness) {
572 Ok(v) => v,
573 Err(e) => return e.to_compile_error().into(),
574 };
575
576 expand::expand(kind, validated, func)
577}