runar_serializer_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::collections::HashSet;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Ident, Type};
7
8fn parse_runar_labels(attr: &Attribute) -> Vec<String> {
9 if !attr.path().is_ident("runar") {
10 return vec![];
11 }
12 let parsed: Punctuated<Ident, Comma> =
13 attr.parse_args_with(Punctuated::parse_terminated).unwrap();
14 parsed.iter().map(|ident| ident.to_string()).collect()
15}
16
/// Converts a `snake_case` or `kebab-case` label into `UpperCamelCase`.
///
/// Every `_`- or `-`-separated segment has its first character upper-cased
/// (which may expand to several characters, e.g. `ß` -> `SS`). The rest of
/// the segment is kept verbatim. Empty segments contribute nothing.
fn label_to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for segment in s.split(['_', '-']) {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            camel.extend(head.to_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
28
29#[proc_macro_derive(Plain)]
30pub fn derive_plain(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let struct_name = input.ident.clone();
33
34 let expanded = quote! {
35 impl runar_serializer::traits::RunarEncryptable for #struct_name {}
36
37 impl runar_serializer::traits::RunarEncrypt for #struct_name {
38 type Encrypted = #struct_name;
39
40 fn encrypt_with_keystore(
41 &self,
42 _keystore: &std::sync::Arc<runar_serializer::KeyStore>,
43 _resolver: &dyn runar_serializer::LabelResolver,
44 ) -> anyhow::Result<Self::Encrypted> {
45 Ok(self.clone())
46 }
47 }
48
49 impl runar_serializer::traits::RunarDecrypt for #struct_name {
50 type Decrypted = #struct_name;
51
52 fn decrypt_with_keystore(
53 &self,
54 _keystore: &std::sync::Arc<runar_serializer::KeyStore>,
55 ) -> anyhow::Result<Self::Decrypted> {
56 Ok(self.clone())
57 }
58 }
59
60 const _: () = {
62 #[ctor::ctor]
63 fn register_json_converter() {
64 runar_serializer::registry::register_to_json::<#struct_name>();
65 }
66 };
67 };
68
69 TokenStream::from(expanded)
70}
71
72#[proc_macro_derive(Encrypt, attributes(runar))]
73pub fn derive_encrypt(input: TokenStream) -> TokenStream {
74 let input = parse_macro_input!(input as DeriveInput);
75 let struct_name = input.ident.clone();
76 let encrypted_name = format_ident!("Encrypted{}", struct_name);
77
78 let mut plaintext_fields: Vec<(Ident, Type)> = Vec::new();
79 let mut label_groups: std::collections::BTreeMap<String, Vec<(Ident, Type)>> =
80 std::collections::BTreeMap::new();
81
82 if let Data::Struct(ds) = input.data {
83 if let Fields::Named(named) = ds.fields {
84 for field in named.named.iter() {
85 let field_ident = field.ident.clone().expect("Expected named field");
86 let field_ty = field.ty.clone();
87 let labels = field
88 .attrs
89 .iter()
90 .flat_map(parse_runar_labels)
91 .collect::<Vec<_>>();
92 if labels.is_empty() {
93 plaintext_fields.push((field_ident, field_ty));
94 } else {
95 for label in labels {
96 label_groups
97 .entry(label)
98 .or_default()
99 .push((field_ident.clone(), field_ty.clone()));
100 }
101 }
102 }
103 } else {
104 return syn::Error::new_spanned(
105 struct_name,
106 "Encrypt derive only supports structs with named fields",
107 )
108 .to_compile_error()
109 .into();
110 }
111 } else {
112 return syn::Error::new_spanned(struct_name, "Encrypt derive only supports structs")
113 .to_compile_error()
114 .into();
115 }
116
117 let mut label_order: Vec<_> = label_groups.keys().cloned().collect();
118 label_order.sort_by(|a, b| {
119 let rank = |l: &String| match l.as_str() {
120 "system" => 0,
121 "user" => 1,
122 _ => 2,
123 };
124 rank(a).cmp(&rank(b)).then_with(|| a.cmp(b))
125 });
126
127 let mut substruct_defs = Vec::new();
128 let mut encrypt_label_match_arms = Vec::new();
129 let mut decrypt_label_blocks = Vec::new();
130 let mut enc_label_tokens = Vec::new();
131 let mut proto_plaintext_fields = Vec::new();
132
133 for label in label_order.iter() {
134 let fields = &label_groups[label];
135 let cap_label = label_to_camel_case(label);
136 let substruct_ident = format_ident!("{}{}Fields", struct_name, cap_label);
137 let group_field_ident = format_ident!("{}_encrypted", label);
138
139 let sub_fields_tokens: Vec<_> = fields
140 .iter()
141 .map(|(id, ty)| quote! { pub #id: #ty, })
142 .collect();
143 substruct_defs.push(quote! {
144 #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Default)]
145 struct #substruct_ident {
146 #(#sub_fields_tokens)*
147 }
148 });
149
150 let substruct_build_fields: Vec<_> = fields
151 .iter()
152 .map(|(id, _)| quote! { #id: self.#id.clone(), })
153 .collect();
154 let label_lit = syn::LitStr::new(label, proc_macro2::Span::call_site());
155 encrypt_label_match_arms.push(quote! {
156 #group_field_ident: if resolver.can_resolve(#label_lit) {
157 let group_struct = #substruct_ident { #(#substruct_build_fields)* };
158 Some(runar_serializer::encryption::encrypt_label_group(#label_lit, &group_struct, keystore.as_ref(), resolver)?)
159 } else {
160 None
161 },
162 });
163
164 let assign_fields: Vec<_> = fields
165 .iter()
166 .map(|(id, _)| quote! { decrypted.#id = tmp.#id; })
167 .collect();
168 decrypt_label_blocks.push(quote! {
169 if let Some(ref group) = self.#group_field_ident {
170 if let Ok(tmp) = runar_serializer::encryption::decrypt_label_group::<#substruct_ident>(group, keystore.as_ref()) {
171 #(#assign_fields)*
172 }
173 }
174 });
175
176 enc_label_tokens.push(quote! { pub #group_field_ident: ::core::option::Option<runar_serializer::encryption::EncryptedLabelGroup>, });
177 }
178
179 for (fid, fty) in plaintext_fields.iter() {
180 proto_plaintext_fields.push(quote! { pub #fid: #fty, });
181 }
182
183 let encrypted_struct_def = quote! {
184 #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
185 pub struct #encrypted_name {
186 #(#proto_plaintext_fields)*
187 #(#enc_label_tokens)*
188 }
189 };
190
191 let encrypt_plaintext_inits: Vec<_> = plaintext_fields
192 .iter()
193 .map(|(id, _)| quote! { #id: self.#id.clone(), })
194 .collect();
195 let decrypted_plaintext_init: Vec<_> = plaintext_fields
196 .iter()
197 .map(|(id, _)| quote! { #id: self.#id.clone(), })
198 .collect();
199 let mut seen = HashSet::new();
200 let labeled_field_defaults: Vec<_> = label_groups
201 .values()
202 .flat_map(|f| f.iter().map(|(id, _)| quote! { #id: Default::default(), }))
203 .filter(|tok| {
204 let s = tok.to_string();
205 if seen.contains(&s) {
206 false
207 } else {
208 seen.insert(s);
209 true
210 }
211 })
212 .collect();
213
214 let encrypt_impl = quote! { let encrypted = #encrypted_name { #(#encrypt_plaintext_inits)* #(#encrypt_label_match_arms)* }; Ok(encrypted) };
215
216 let decrypt_impl = quote! { let mut decrypted = #struct_name { #(#decrypted_plaintext_init)* #(#labeled_field_defaults)* }; #(#decrypt_label_blocks)* Ok(decrypted) };
217
218 let expanded = quote! {
219 #(#substruct_defs)*
220 #encrypted_struct_def
221
222 impl runar_serializer::traits::RunarEncryptable for #struct_name {}
223
224 impl runar_serializer::traits::RunarEncrypt for #struct_name {
225 type Encrypted = #encrypted_name;
226
227 fn encrypt_with_keystore(
228 &self,
229 keystore: &std::sync::Arc<runar_serializer::KeyStore>,
230 resolver: &dyn runar_serializer::LabelResolver,
231 ) -> anyhow::Result<Self::Encrypted> {
232 let encrypted = #encrypted_name { #(#encrypt_plaintext_inits)* #(#encrypt_label_match_arms)* };
233 Ok(encrypted)
234 }
235 }
236
237 impl runar_serializer::traits::RunarDecrypt for #encrypted_name {
238 type Decrypted = #struct_name;
239
240 fn decrypt_with_keystore(
241 &self,
242 keystore: &std::sync::Arc<runar_serializer::KeyStore>,
243 ) -> anyhow::Result<Self::Decrypted> {
244 let mut decrypted = #struct_name { #(#decrypted_plaintext_init)* #(#labeled_field_defaults)* };
245 #(#decrypt_label_blocks)*
246 Ok(decrypted)
247 }
248 }
249
250 impl #struct_name {
251 fn encrypt_with_keystore(
252 &self,
253 keystore: &std::sync::Arc<runar_serializer::KeyStore>,
254 resolver: &dyn runar_serializer::LabelResolver,
255 ) -> anyhow::Result<#encrypted_name> {
256 #encrypt_impl
257 }
258 }
259
260 impl #encrypted_name {
261 fn decrypt_with_keystore(
262 &self,
263 keystore: &std::sync::Arc<runar_serializer::KeyStore>,
264 ) -> anyhow::Result<#struct_name> {
265 #decrypt_impl
266 }
267 }
268
269 const _: () = {
271 #[ctor::ctor]
272 fn register_decryptor() {
273 runar_serializer::registry::register_decrypt::<#struct_name, #encrypted_name>();
274 }
275 };
276
277 const _: () = {
279 #[ctor::ctor]
280 fn register_json_converter() {
281 runar_serializer::registry::register_to_json::<#struct_name>();
282 }
283 };
284
285 impl runar_serializer::traits::RunarEncryptable for #encrypted_name {}
287 };
288
289 TokenStream::from(expanded)
290}
291
292#[proc_macro_derive(Decrypt, attributes(runar))]
294pub fn derive_decrypt(input: TokenStream) -> TokenStream {
295 derive_encrypt(input)
296}
297
298#[proc_macro_attribute]
300pub fn runar(_attr: TokenStream, item: TokenStream) -> TokenStream {
301 item
302}