ovunto_security_macros/
lib.rs

1extern crate proc_macro;
2extern crate proc_macro2;
3#[macro_use]
4extern crate quote;
5extern crate syn;
6
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Ident, Type};
10
11#[proc_macro_derive(Encrypt, attributes(encrypt))]
12pub fn encrypt(input: TokenStream) -> TokenStream {
13    struct Field {
14        encrypt: bool,
15        name: Ident,
16        ty: Type,
17    }
18
19    // Parse the string representation
20    let ast: DeriveInput = syn::parse(input).unwrap();
21
22    let name = &ast.ident;
23    let encrypted_struct_name = Ident::new(&format!("Encrypted{name}"), Span::call_site());
24
25    let fields = match ast.data {
26        Data::Struct(DataStruct {
27            fields: Fields::Named(FieldsNamed { named: fields, .. }),
28            ..
29        }) => fields,
30        _ => panic!("Invalid data"),
31    };
32
33    let fields = fields.iter().map(|field| {
34        let ident = field
35            .ident
36            .as_ref()
37            .expect("Named fields should have an identifier");
38        let ty = &field.ty;
39
40        // Check if the field has the `encrypt` attribute
41        let mut encrypt = false;
42        for attr in &field.attrs {
43            if attr.path().is_ident("encrypt") {
44                encrypt = true;
45                break;
46            }
47        }
48
49        Field {
50            encrypt,
51            name: ident.clone(),
52            ty: ty.clone(),
53        }
54    });
55
56    let fields_names = fields.clone().map(|f| f.name);
57
58    let encrypted_struct_fields = fields.clone().map(|f| {
59        let name = f.name;
60        let ty = f.ty;
61        if f.encrypt {
62            quote! { #name: <#ty as Encrypt>::Encrypted }
63        } else {
64            quote! { #name: #ty }
65        }
66    });
67
68    let encrypting_fields = fields.clone().map(|f| {
69        let name = f.name;
70        if f.encrypt {
71            quote! { self.#name.encrypt(key)? }
72        } else {
73            quote! { self.#name }
74        }
75    });
76
77    let tokens = quote! {
78        #[derive(Debug, Clone, Serialize, Deserialize)]
79        pub struct #encrypted_struct_name {
80            #(#encrypted_struct_fields),*
81        }
82
83        impl Encrypt for #name {
84            type Encrypted = #encrypted_struct_name;
85
86            fn encrypt(self, key: crate::Key) -> crate::Result<Self::Encrypted> {
87                Ok(#encrypted_struct_name {
88                    #(
89                        #fields_names: #encrypting_fields,
90                    )*
91                })
92            }
93        }
94    };
95
96    tokens.into()
97}
98
99#[proc_macro_derive(Decrypt, attributes(encrypt))]
100pub fn decrypt(input: TokenStream) -> TokenStream {
101    struct Field {
102        encrypt: bool,
103        name: Ident,
104    }
105
106    // Parse the string representation
107    let ast: DeriveInput = syn::parse(input).unwrap();
108
109    let name = &ast.ident;
110    let encrypted_struct_name = Ident::new(&format!("Encrypted{name}"), Span::call_site());
111
112    let fields = match ast.data {
113        Data::Struct(DataStruct {
114            fields: Fields::Named(FieldsNamed { named: fields, .. }),
115            ..
116        }) => fields,
117        _ => panic!("Invalid data"),
118    };
119
120    let fields = fields.iter().map(|field| {
121        let ident = field
122            .ident
123            .as_ref()
124            .expect("Named fields should have an identifier");
125
126        // Check if the field has the `encrypt` attribute
127        let mut encrypt = false;
128        for attr in &field.attrs {
129            if attr.path().is_ident("encrypt") {
130                encrypt = true;
131                break;
132            }
133        }
134
135        Field {
136            encrypt,
137            name: ident.clone(),
138        }
139    });
140
141    let fields_names = fields.clone().map(|f| f.name);
142
143    let decrypting_fields = fields.clone().map(|f| {
144        let name = f.name;
145        if f.encrypt {
146            quote! { self.#name.decrypt(key)? }
147        } else {
148            quote! { self.#name }
149        }
150    });
151
152    let tokens = quote! {
153        impl Decrypt<#name> for #encrypted_struct_name {
154            fn decrypt(self, key: crate::Key) -> crate::Result<#name> {
155                Ok(#name {
156                    #(
157                        #fields_names: #decrypting_fields,
158                    )*
159                })
160            }
161        }
162    };
163
164    tokens.into()
165}