cookiebox_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse_macro_input, DeriveInput, Expr, Fields, ItemStruct, Lit, Meta, PathArguments, Type,
7};
8
9/// Implements a CookieName trait using passed in name from the macro attribute
10#[proc_macro_attribute]
11pub fn cookie(attr: TokenStream, item: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(item as ItemStruct);
13
14    let parsed_attr = parse_macro_input!(attr as Meta);
15
16    let mut cookie_name = String::new();
17
18    if !parsed_attr.path().is_ident("name") {
19        return syn::Error::new_spanned(
20            parsed_attr.path().get_ident(),
21            "Expected `name` parameter: #[cookie(name = \"...\")]",
22        )
23        .into_compile_error()
24        .into();
25    }
26    if let Meta::NameValue(nv) = parsed_attr {
27        if let Expr::Lit(expr) = &nv.value {
28            if let Lit::Str(lit_str) = &expr.lit {
29                cookie_name.push_str(&lit_str.value());
30            }
31        }
32    }
33
34    let cookie_struct = &input.ident;
35
36    let expanded = quote! {
37        #input
38
39        impl CookieName for #cookie_struct {
40            const COOKIE_NAME: &'static str = #cookie_name;
41        }
42    };
43
44    expanded.into()
45}
46
47/// Implements a FromRequest for a struct that holds cookie types
48///
49/// **Note**: only allows structs with either a single unnamed field or multiple unnamed fields
50#[proc_macro_derive(FromRequest)]
51pub fn cookie_collection(item: TokenStream) -> TokenStream {
52    let input = parse_macro_input!(item as DeriveInput);
53    let collection_struct = &input.ident;
54
55    // Extract the field types based on whether it's a tuple or named struct.
56    let (field_names, field_types) = match extract_fields_types(&input) {
57        Ok(fields) => fields,
58        Err(e) => return e.into_compile_error().into(),
59    };
60
61    // Extract the generic type argument from a Cookie<'c, SomeType> type.
62    let inner_types = field_types
63        .iter()
64        .try_fold(
65            Vec::new(),
66            |mut types, field_type| match extract_cookie_inner_type(field_type) {
67                Some(inner_type) => {
68                    types.push(inner_type);
69                    Ok(types)
70                }
71                None => Err(syn::Error::new_spanned(
72                    field_type,
73                    "Expected field type to be `Cookie<'c, SomeType>`",
74                )),
75            },
76        );
77
78    let inner_types = match inner_types {
79        Ok(types) => types,
80        Err(error) => return error.into_compile_error().into(),
81    };
82
83    let generated_types = if let Some(field_names) = field_names {
84        quote! { #collection_struct { #( #field_names: Cookie::<#inner_types>::new(&storage),)* }}
85    } else {
86        quote! { #collection_struct ( #( Cookie::<#inner_types>::new(&storage),)* )}
87    };
88
89    // Generate the implementation for FromRequest
90    let expanded = quote! {
91        impl actix_web::FromRequest for #collection_struct<'static> {
92            type Error = Box<dyn std::error::Error>;
93            type Future = std::future::Ready<Result<Self, Self::Error>>;
94
95            fn from_request(req: &actix_web::HttpRequest, _payload: &mut actix_web::dev::Payload) -> Self::Future {
96                match req.extensions().get::<cookiebox::Storage>() {
97                    Some(storage) => {
98                        std::future::ready(Ok( #generated_types ))
99                    }
100                    None => std::future::ready(Err("Storage not found in request extension".into())),
101                }
102            }
103        }
104    };
105
106    expanded.into()
107}
108
109fn extract_fields_types(
110    input: &DeriveInput,
111) -> Result<(Option<Vec<syn::Ident>>, Vec<&Type>), syn::Error> {
112    match &input.data {
113        syn::Data::Struct(data_struct) => match &data_struct.fields {
114            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
115                Ok((None, vec![&fields.unnamed[0].ty]))
116            }
117            Fields::Named(fields) => {
118                // Unwrap here is okay since Fields::Named require a field name which make a None ident value impossible to represent
119                let field_names = fields
120                    .named
121                    .iter()
122                    .map(|f| f.ident.clone().unwrap())
123                    .collect();
124                let field_types = fields.named.iter().map(|f| &f.ty).collect();
125                Ok((Some(field_names), field_types))
126            }
127            // Units and unnamed with more than 1 fields
128            token => Err(syn::Error::new_spanned(
129                token,
130                "Expected a single unnamed field or multiple named fields",
131            )),
132        },
133        // Enum and union
134        _ => Err(syn::Error::new_spanned(input, "Expected a struct")),
135    }
136}
137
138/// Extracts the inner type (SomeType) from a `Cookie<'c, SomeType>` type.
139fn extract_cookie_inner_type(field_type: &Type) -> Option<&Type> {
140    if let Type::Path(type_path) = field_type {
141        let segment = type_path.path.segments.first()?;
142        if segment.ident == "Cookie" {
143            if let PathArguments::AngleBracketed(generics) = &segment.arguments {
144                if generics.args.len() == 2 {
145                    if let syn::GenericArgument::Type(inner_type) = &generics.args[1] {
146                        return Some(inner_type);
147                    }
148                }
149            }
150        }
151    }
152    None
153}