rocket_csrf_guard_derive/
lib.rs1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::parse::{Parse, ParseStream, Parser, Result};
7use syn::{
8 parse_macro_input, punctuated::Punctuated, Error, Field, Fields, GenericParam, Ident,
9 ItemStruct, LitStr, Token, Type,
10};
11
12#[derive(Debug)]
13struct MaybeName {
14 name: Option<LitStr>,
15}
16
17impl Parse for MaybeName {
18 fn parse(input: ParseStream) -> Result<Self> {
19 let vars = Punctuated::<syn::LitStr, Token![,]>::parse_terminated(input)?;
20 if vars.is_empty() {
21 Ok(Self { name: None })
22 } else if vars.len() == 1 {
23 Ok(Self {
24 name: vars.first().cloned(),
25 })
26 } else {
27 Err(Error::new(
28 input.span(),
29 "expected at most one field name for csrf token, got multiple!",
30 ))
31 }
32 }
33}
34
35fn get_singular_lifetime(item: &ItemStruct) -> Option<Ident> {
36 let generics = &item.generics;
37 if generics.params.len() != 1 {
38 return None;
39 }
40 if let Some(GenericParam::Lifetime(lifetime)) = generics.params.first() {
41 Some(lifetime.lifetime.ident.clone())
42 } else {
43 None
44 }
45}
46
47#[proc_macro_attribute]
48pub fn with_csrf_token(args: TokenStream, input: TokenStream) -> TokenStream {
49 let mut item_struct = parse_macro_input!(input as ItemStruct);
50 let struct_name = item_struct.ident.clone();
51 let maybe_names = parse_macro_input!(args as MaybeName);
52
53 let field_name = maybe_names
54 .name
55 .map_or_else(|| "csrf_token".to_owned(), |s| s.value());
56
57 let lifetime = get_singular_lifetime(&item_struct);
58 let ident = Ident::new(&field_name, Span::call_site());
59
60 if let Fields::Named(ref mut fields) = item_struct.fields {
61 let existing = fields
62 .named
63 .iter()
64 .any(|f| f.ident.as_ref().map_or(false, |i| *i == ident));
65 if !existing {
67 if let Some(lifetime) = lifetime {
68 if let Ok(mut field) = Field::parse_named.parse2(quote! { #ident: &'a str }) {
69 if let Type::Reference(reference) = &mut field.ty {
70 if let Some(field_lifetime) = reference.lifetime.as_mut() {
71 field_lifetime.ident = lifetime;
72 }
73 }
74 fields.named.push(field);
75 }
76 } else if let Ok(field) = syn::Field::parse_named.parse2(quote! { #ident: String }) {
77 fields.named.push(field);
78 }
79 }
80 }
81
82 let (impl_generics, ty_generics, _) = item_struct.generics.split_for_impl();
83
84 quote! {
85 #item_struct
86
87 impl #impl_generics rocket_csrf_guard::WithUserProvidedCsrfToken for #struct_name #ty_generics {
88 fn csrf_token(&self) -> &str {
89 &self.#ident
90 }
91 }
92 }
93 .into()
94}