Skip to main content

protovalidate_buffa_macros/
lib.rs

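//! `#[auto_validate]` — inserts a validation prologue at the top of every
//! service handler method in an `impl` block whose request parameter is an
//! `OwnedView<_>`. The prologue calls `to_owned_message()` on the request,
//! runs `protovalidate_buffa::Validate::validate` on the result, and returns
//! early via `?` with the failure converted by
//! `ValidationError::into_connect_error`. Single-site safety net: add it once
//! to the service impl and every present-and-future handler is validated on
//! entry.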
5//!
6//! Non-handler `async fn`s inside the same `impl` block are left alone
7//! (they lack an `OwnedView<_>` parameter, so the macro skips them).
8
9use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::quote;
12use syn::{parse_macro_input, Error, FnArg, ImplItem, ItemImpl, PatType, Type, TypePath};
13
14#[proc_macro_attribute]
15pub fn auto_validate(attr: TokenStream, input: TokenStream) -> TokenStream {
16    if !attr.is_empty() {
17        return Error::new_spanned(
18            TokenStream2::from(attr),
19            "protovalidate_buffa::auto_validate takes no arguments",
20        )
21        .to_compile_error()
22        .into();
23    }
24
25    let mut item = parse_macro_input!(input as ItemImpl);
26
27    for impl_item in &mut item.items {
28        if let ImplItem::Fn(f) = impl_item {
29            if let Some(arg_ident) = find_owned_view_arg(&f.sig) {
30                let pv_ident = proc_macro2::Ident::new(
31                    "__protovalidate_buffa_req_owned",
32                    arg_ident.span(),
33                );
34
35                let decode: syn::Stmt = syn::parse_quote! {
36                    let #pv_ident = #arg_ident.to_owned_message();
37                };
38                let validate: syn::Stmt = syn::parse_quote! {
39                    <_ as ::protovalidate_buffa::Validate>::validate(&#pv_ident)
40                        .map_err(::protovalidate_buffa::ValidationError::into_connect_error)?;
41                };
42
43                f.block.stmts.insert(0, decode);
44                f.block.stmts.insert(1, validate);
45            }
46        }
47    }
48
49    TokenStream::from(quote! { #item })
50}
51
52/// Returns the ident of the first parameter whose type is a path ending in
53/// `OwnedView` (e.g. `OwnedView<pb::CreateFooRequestView<'static>>`).
54/// Non-handler methods that lack such a parameter return `None`.
55fn find_owned_view_arg(sig: &syn::Signature) -> Option<syn::Ident> {
56    for arg in &sig.inputs {
57        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
58            if is_owned_view(ty) {
59                if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
60                    return Some(pat_ident.ident.clone());
61                }
62            }
63        }
64    }
65    None
66}
67
68fn is_owned_view(ty: &Type) -> bool {
69    if let Type::Path(TypePath { path, .. }) = ty {
70        if let Some(last) = path.segments.last() {
71            return last.ident == "OwnedView";
72        }
73    }
74    false
75}