rocket_csrf_guard_derive/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::parse::{Parse, ParseStream, Parser, Result};
7use syn::{
8    parse_macro_input, punctuated::Punctuated, Error, Field, Fields, GenericParam, Ident,
9    ItemStruct, LitStr, Token, Type,
10};
11
12#[derive(Debug)]
13struct MaybeName {
14    name: Option<LitStr>,
15}
16
17impl Parse for MaybeName {
18    fn parse(input: ParseStream) -> Result<Self> {
19        let vars = Punctuated::<syn::LitStr, Token![,]>::parse_terminated(input)?;
20        if vars.is_empty() {
21            Ok(Self { name: None })
22        } else if vars.len() == 1 {
23            Ok(Self {
24                name: vars.first().cloned(),
25            })
26        } else {
27            Err(Error::new(
28                input.span(),
29                "expected at most one field name for csrf token, got multiple!",
30            ))
31        }
32    }
33}
34
35fn get_singular_lifetime(item: &ItemStruct) -> Option<Ident> {
36    let generics = &item.generics;
37    if generics.params.len() != 1 {
38        return None;
39    }
40    if let Some(GenericParam::Lifetime(lifetime)) = generics.params.first() {
41        Some(lifetime.lifetime.ident.clone())
42    } else {
43        None
44    }
45}
46
47#[proc_macro_attribute]
48pub fn with_csrf_token(args: TokenStream, input: TokenStream) -> TokenStream {
49    let mut item_struct = parse_macro_input!(input as ItemStruct);
50    let struct_name = item_struct.ident.clone();
51    let maybe_names = parse_macro_input!(args as MaybeName);
52
53    let field_name = maybe_names
54        .name
55        .map_or_else(|| "csrf_token".to_owned(), |s| s.value());
56
57    let lifetime = get_singular_lifetime(&item_struct);
58    let ident = Ident::new(&field_name, Span::call_site());
59
60    if let Fields::Named(ref mut fields) = item_struct.fields {
61        let existing = fields
62            .named
63            .iter()
64            .any(|f| f.ident.as_ref().map_or(false, |i| *i == ident));
65        // TODO: Validate field type is string or &str
66        if !existing {
67            if let Some(lifetime) = lifetime {
68                if let Ok(mut field) = Field::parse_named.parse2(quote! { #ident: &'a str }) {
69                    if let Type::Reference(reference) = &mut field.ty {
70                        if let Some(field_lifetime) = reference.lifetime.as_mut() {
71                            field_lifetime.ident = lifetime;
72                        }
73                    }
74                    fields.named.push(field);
75                }
76            } else if let Ok(field) = syn::Field::parse_named.parse2(quote! { #ident: String }) {
77                fields.named.push(field);
78            }
79        }
80    }
81
82    let (impl_generics, ty_generics, _) = item_struct.generics.split_for_impl();
83
84    quote! {
85        #item_struct
86
87        impl #impl_generics rocket_csrf_guard::WithUserProvidedCsrfToken for #struct_name #ty_generics {
88            fn csrf_token(&self) -> &str {
89                &self.#ident
90            }
91        }
92    }
93    .into()
94}