// runar_serializer_macros/lib.rs

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::{BTreeMap, HashSet};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident, Type};

8fn parse_runar_labels(attr: &Attribute) -> Vec<String> {
9    if !attr.path().is_ident("runar") {
10        return vec![];
11    }
12    let parsed: Punctuated<Ident, Comma> =
13        attr.parse_args_with(Punctuated::parse_terminated).unwrap();
14    parsed.iter().map(|ident| ident.to_string()).collect()
15}
16
/// Converts a label such as `user_profile` or `user-profile` into `UserProfile`.
///
/// The input is split on `_` and `-`. The first character of every non-empty
/// segment is upper-cased and the rest of the segment is kept verbatim. Empty
/// segments contribute nothing.
fn label_to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for segment in s.split(|c: char| c == '_' || c == '-') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            camel.extend(head.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}

29#[proc_macro_derive(Plain)]
30pub fn derive_plain(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let struct_name = input.ident.clone();
33
34    let expanded = quote! {
35        impl runar_serializer::traits::RunarEncryptable for #struct_name {}
36
37        impl runar_serializer::traits::RunarEncrypt for #struct_name {
38            type Encrypted = #struct_name;
39
40            fn encrypt_with_keystore(
41                &self,
42                _keystore: &std::sync::Arc<runar_serializer::KeyStore>,
43                _resolver: &dyn runar_serializer::LabelResolver,
44            ) -> anyhow::Result<Self::Encrypted> {
45                Ok(self.clone())
46            }
47        }
48
49        impl runar_serializer::traits::RunarDecrypt for #struct_name {
50            type Decrypted = #struct_name;
51
52            fn decrypt_with_keystore(
53                &self,
54                _keystore: &std::sync::Arc<runar_serializer::KeyStore>,
55            ) -> anyhow::Result<Self::Decrypted> {
56                Ok(self.clone())
57            }
58        }
59
60        // Automatically register JSON converter for this struct at program start.
61        const _: () = {
62            #[ctor::ctor]
63            fn register_json_converter() {
64                runar_serializer::registry::register_to_json::<#struct_name>();
65            }
66        };
67    };
68
69    TokenStream::from(expanded)
70}
71
72#[proc_macro_derive(Encrypt, attributes(runar))]
73pub fn derive_encrypt(input: TokenStream) -> TokenStream {
74    let input = parse_macro_input!(input as DeriveInput);
75    let struct_name = input.ident.clone();
76    let encrypted_name = format_ident!("Encrypted{}", struct_name);
77
78    let mut plaintext_fields: Vec<(Ident, Type)> = Vec::new();
79    let mut label_groups: std::collections::BTreeMap<String, Vec<(Ident, Type)>> =
80        std::collections::BTreeMap::new();
81
82    if let Data::Struct(ds) = input.data {
83        if let Fields::Named(named) = ds.fields {
84            for field in named.named.iter() {
85                let field_ident = field.ident.clone().expect("Expected named field");
86                let field_ty = field.ty.clone();
87                let labels = field
88                    .attrs
89                    .iter()
90                    .flat_map(parse_runar_labels)
91                    .collect::<Vec<_>>();
92                if labels.is_empty() {
93                    plaintext_fields.push((field_ident, field_ty));
94                } else {
95                    for label in labels {
96                        label_groups
97                            .entry(label)
98                            .or_default()
99                            .push((field_ident.clone(), field_ty.clone()));
100                    }
101                }
102            }
103        } else {
104            return syn::Error::new_spanned(
105                struct_name,
106                "Encrypt derive only supports structs with named fields",
107            )
108            .to_compile_error()
109            .into();
110        }
111    } else {
112        return syn::Error::new_spanned(struct_name, "Encrypt derive only supports structs")
113            .to_compile_error()
114            .into();
115    }
116
117    let mut label_order: Vec<_> = label_groups.keys().cloned().collect();
118    label_order.sort_by(|a, b| {
119        let rank = |l: &String| match l.as_str() {
120            "system" => 0,
121            "user" => 1,
122            _ => 2,
123        };
124        rank(a).cmp(&rank(b)).then_with(|| a.cmp(b))
125    });
126
127    let mut substruct_defs = Vec::new();
128    let mut encrypt_label_match_arms = Vec::new();
129    let mut decrypt_label_blocks = Vec::new();
130    let mut enc_label_tokens = Vec::new();
131    let mut proto_plaintext_fields = Vec::new();
132
133    for label in label_order.iter() {
134        let fields = &label_groups[label];
135        let cap_label = label_to_camel_case(label);
136        let substruct_ident = format_ident!("{}{}Fields", struct_name, cap_label);
137        let group_field_ident = format_ident!("{}_encrypted", label);
138
139        let sub_fields_tokens: Vec<_> = fields
140            .iter()
141            .map(|(id, ty)| quote! { pub #id: #ty, })
142            .collect();
143        substruct_defs.push(quote! {
144            #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Default)]
145            struct #substruct_ident {
146                #(#sub_fields_tokens)*
147            }
148        });
149
150        let substruct_build_fields: Vec<_> = fields
151            .iter()
152            .map(|(id, _)| quote! { #id: self.#id.clone(), })
153            .collect();
154        let label_lit = syn::LitStr::new(label, proc_macro2::Span::call_site());
155        encrypt_label_match_arms.push(quote! {
156            #group_field_ident: if resolver.can_resolve(#label_lit) {
157                let group_struct = #substruct_ident { #(#substruct_build_fields)* };
158                Some(runar_serializer::encryption::encrypt_label_group(#label_lit, &group_struct, keystore.as_ref(), resolver)?)
159            } else {
160                None
161            },
162        });
163
164        let assign_fields: Vec<_> = fields
165            .iter()
166            .map(|(id, _)| quote! { decrypted.#id = tmp.#id; })
167            .collect();
168        decrypt_label_blocks.push(quote! {
169            if let Some(ref group) = self.#group_field_ident {
170                if let Ok(tmp) = runar_serializer::encryption::decrypt_label_group::<#substruct_ident>(group, keystore.as_ref()) {
171                    #(#assign_fields)*
172                }
173            }
174        });
175
176        enc_label_tokens.push(quote! { pub #group_field_ident: ::core::option::Option<runar_serializer::encryption::EncryptedLabelGroup>, });
177    }
178
179    for (fid, fty) in plaintext_fields.iter() {
180        proto_plaintext_fields.push(quote! { pub #fid: #fty, });
181    }
182
183    let encrypted_struct_def = quote! {
184        #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
185        pub struct #encrypted_name {
186            #(#proto_plaintext_fields)*
187            #(#enc_label_tokens)*
188        }
189    };
190
191    let encrypt_plaintext_inits: Vec<_> = plaintext_fields
192        .iter()
193        .map(|(id, _)| quote! { #id: self.#id.clone(), })
194        .collect();
195    let decrypted_plaintext_init: Vec<_> = plaintext_fields
196        .iter()
197        .map(|(id, _)| quote! { #id: self.#id.clone(), })
198        .collect();
199    let mut seen = HashSet::new();
200    let labeled_field_defaults: Vec<_> = label_groups
201        .values()
202        .flat_map(|f| f.iter().map(|(id, _)| quote! { #id: Default::default(), }))
203        .filter(|tok| {
204            let s = tok.to_string();
205            if seen.contains(&s) {
206                false
207            } else {
208                seen.insert(s);
209                true
210            }
211        })
212        .collect();
213
214    let encrypt_impl = quote! { let encrypted = #encrypted_name { #(#encrypt_plaintext_inits)* #(#encrypt_label_match_arms)* }; Ok(encrypted) };
215
216    let decrypt_impl = quote! { let mut decrypted = #struct_name { #(#decrypted_plaintext_init)* #(#labeled_field_defaults)* }; #(#decrypt_label_blocks)* Ok(decrypted) };
217
218    let expanded = quote! {
219        #(#substruct_defs)*
220        #encrypted_struct_def
221
222        impl runar_serializer::traits::RunarEncryptable for #struct_name {}
223
224        impl runar_serializer::traits::RunarEncrypt for #struct_name {
225            type Encrypted = #encrypted_name;
226
227            fn encrypt_with_keystore(
228                &self,
229                keystore: &std::sync::Arc<runar_serializer::KeyStore>,
230                resolver: &dyn runar_serializer::LabelResolver,
231            ) -> anyhow::Result<Self::Encrypted> {
232                let encrypted = #encrypted_name { #(#encrypt_plaintext_inits)* #(#encrypt_label_match_arms)* };
233                Ok(encrypted)
234            }
235        }
236
237        impl runar_serializer::traits::RunarDecrypt for #encrypted_name {
238            type Decrypted = #struct_name;
239
240            fn decrypt_with_keystore(
241                &self,
242                keystore: &std::sync::Arc<runar_serializer::KeyStore>,
243            ) -> anyhow::Result<Self::Decrypted> {
244                let mut decrypted = #struct_name { #(#decrypted_plaintext_init)* #(#labeled_field_defaults)* };
245                #(#decrypt_label_blocks)*
246                Ok(decrypted)
247            }
248        }
249
250        impl #struct_name {
251            fn encrypt_with_keystore(
252                &self,
253                keystore: &std::sync::Arc<runar_serializer::KeyStore>,
254                resolver: &dyn runar_serializer::LabelResolver,
255            ) -> anyhow::Result<#encrypted_name> {
256                #encrypt_impl
257            }
258        }
259
260        impl #encrypted_name {
261            fn decrypt_with_keystore(
262                &self,
263                keystore: &std::sync::Arc<runar_serializer::KeyStore>,
264            ) -> anyhow::Result<#struct_name> {
265                #decrypt_impl
266            }
267        }
268
269        // Automatically register decryptor for this struct at program start.
270        const _: () = {
271            #[ctor::ctor]
272            fn register_decryptor() {
273                runar_serializer::registry::register_decrypt::<#struct_name, #encrypted_name>();
274            }
275        };
276
277        // Automatically register JSON converter for this struct at program start.
278        const _: () = {
279            #[ctor::ctor]
280            fn register_json_converter() {
281                runar_serializer::registry::register_to_json::<#struct_name>();
282            }
283        };
284
285        // Mark encrypted struct as RunarEncryptable so it can appear inside ArcValue without further bounds.
286        impl runar_serializer::traits::RunarEncryptable for #encrypted_name {}
287    };
288
289    TokenStream::from(expanded)
290}
291
292/// Decryption derive is just an alias
293#[proc_macro_derive(Decrypt, attributes(runar))]
294pub fn derive_decrypt(input: TokenStream) -> TokenStream {
295    derive_encrypt(input)
296}
297
298/// No-op attribute macro to allow `#[runar(...)]` field annotations.
299#[proc_macro_attribute]
300pub fn runar(_attr: TokenStream, item: TokenStream) -> TokenStream {
301    item
302}