1#![deny(unsafe_code)]
2#![deny(unsafe_op_in_unsafe_fn)]
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{format_ident, quote};
7use syn::{
8 parse_macro_input, parse_quote, Attribute, Data, DataEnum, DataStruct, DeriveInput, Error,
9 Fields, Generics, LitStr, Path, Result, WherePredicate,
10};
11
12#[proc_macro_derive(SecureSanitize, attributes(sanitization))]
33pub fn derive_secure_sanitize(input: TokenStream) -> TokenStream {
34 let input = parse_macro_input!(input as DeriveInput);
35 expand_secure_sanitize(&input)
36 .unwrap_or_else(Error::into_compile_error)
37 .into()
38}
39
40#[proc_macro_derive(SecureSanitizeOnDrop, attributes(sanitization))]
59pub fn derive_secure_sanitize_on_drop(input: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(input as DeriveInput);
61 expand_secure_sanitize_on_drop(&input)
62 .unwrap_or_else(Error::into_compile_error)
63 .into()
64}
65
66#[proc_macro_derive(ConstantTimeEq, attributes(sanitization))]
77pub fn derive_constant_time_eq(input: TokenStream) -> TokenStream {
78 let input = parse_macro_input!(input as DeriveInput);
79 expand_constant_time_eq(&input)
80 .unwrap_or_else(Error::into_compile_error)
81 .into()
82}
83
84#[proc_macro_derive(ConditionallySelectable, attributes(sanitization))]
94pub fn derive_conditionally_selectable(input: TokenStream) -> TokenStream {
95 let input = parse_macro_input!(input as DeriveInput);
96 expand_conditionally_selectable(&input)
97 .unwrap_or_else(Error::into_compile_error)
98 .into()
99}
100
101#[derive(Default)]
102struct ContainerOptions {
103 crate_path: Option<Path>,
104 bound_override: Option<Vec<WherePredicate>>,
105 enum_inactive_variant_bytes_acknowledged: bool,
106}
107
108#[derive(Default)]
109struct FieldOptions {
110 skip: bool,
111 bound_override: Option<Vec<WherePredicate>>,
112}
113
114fn expand_secure_sanitize(input: &DeriveInput) -> Result<TokenStream2> {
115 let options = parse_container_options(&input.attrs)?;
116 let crate_path = crate_path(&options);
117 let body = match &input.data {
118 Data::Struct(data) => expand_struct_body(data, &crate_path)?,
119 Data::Enum(data) => {
120 validate_enum_options(input, &options)?;
121 expand_enum_body(data, &crate_path)?
122 }
123 Data::Union(_) => {
124 return Err(Error::new_spanned(
125 input,
126 "SecureSanitize cannot be derived for unions; implement it manually using documented unsafe code for the active field",
127 ))
128 }
129 };
130 let generics = add_sanitize_bounds(input.generics.clone(), &input.data, &crate_path, &options)?;
131 let name = &input.ident;
132 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
133
134 Ok(quote! {
135 impl #impl_generics #crate_path::SecureSanitize for #name #type_generics #where_clause {
136 #[inline]
137 fn secure_sanitize(&mut self) {
138 #body
139 }
140 }
141 })
142}
143
144fn validate_enum_options(input: &DeriveInput, options: &ContainerOptions) -> Result<()> {
145 if cfg!(feature = "strict-enum-derive") && !options.enum_inactive_variant_bytes_acknowledged {
146 return Err(Error::new_spanned(
147 input,
148 "SecureSanitize enum derives are rejected by the strict-enum-derive feature unless #[sanitization(enum_inactive_variant_bytes = \"acknowledged\")] is present; derived enum sanitization only clears the active variant",
149 ));
150 }
151
152 Ok(())
153}
154
155fn expand_secure_sanitize_on_drop(input: &DeriveInput) -> Result<TokenStream2> {
156 let options = parse_container_options(&input.attrs)?;
157 let crate_path = crate_path(&options);
158
159 if matches!(input.data, Data::Union(_)) {
160 return Err(Error::new_spanned(
161 input,
162 "SecureSanitizeOnDrop cannot be derived for unions",
163 ));
164 }
165
166 let name = &input.ident;
167 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
168
169 Ok(quote! {
170 impl #impl_generics Drop for #name #type_generics #where_clause {
171 #[inline]
172 fn drop(&mut self) {
173 #crate_path::SecureSanitize::secure_sanitize(self);
174 }
175 }
176 })
177}
178
179fn expand_constant_time_eq(input: &DeriveInput) -> Result<TokenStream2> {
180 let options = parse_container_options(&input.attrs)?;
181 let crate_path = crate_path(&options);
182 let body = match &input.data {
183 Data::Struct(data) => expand_ct_eq_struct_body(data, &crate_path)?,
184 Data::Enum(_) => {
185 return Err(Error::new_spanned(
186 input,
187 "ConstantTimeEq cannot be derived for enums; compare explicit struct wrappers or implement the active-variant semantics manually",
188 ))
189 }
190 Data::Union(_) => {
191 return Err(Error::new_spanned(
192 input,
193 "ConstantTimeEq cannot be derived for unions; implement it manually using documented unsafe code for the active field",
194 ))
195 }
196 };
197 let trait_path: TokenStream2 = quote!(#crate_path::ct::ConstantTimeEq);
198 let generics = add_trait_bounds(
199 input.generics.clone(),
200 &input.data,
201 &trait_path,
202 &options,
203 SkipPolicy::Allow,
204 )?;
205 let name = &input.ident;
206 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
207
208 Ok(quote! {
209 impl #impl_generics #crate_path::ct::ConstantTimeEq for #name #type_generics #where_clause {
210 #[inline]
211 fn ct_eq(&self, other: &Self) -> #crate_path::ct::Choice {
212 #body
213 }
214 }
215 })
216}
217
218fn expand_conditionally_selectable(input: &DeriveInput) -> Result<TokenStream2> {
219 let options = parse_container_options(&input.attrs)?;
220 let crate_path = crate_path(&options);
221 let body = match &input.data {
222 Data::Struct(data) => expand_ct_select_struct_body(data, &crate_path)?,
223 Data::Enum(_) => {
224 return Err(Error::new_spanned(
225 input,
226 "ConditionallySelectable cannot be derived for enums; select explicit struct wrappers or implement the active-variant semantics manually",
227 ))
228 }
229 Data::Union(_) => {
230 return Err(Error::new_spanned(
231 input,
232 "ConditionallySelectable cannot be derived for unions; implement it manually using documented unsafe code for the active field",
233 ))
234 }
235 };
236 let trait_path: TokenStream2 = quote!(#crate_path::ct::ConditionallySelectable);
237 let generics = add_trait_bounds(
238 input.generics.clone(),
239 &input.data,
240 &trait_path,
241 &options,
242 SkipPolicy::Reject,
243 )?;
244 let name = &input.ident;
245 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
246
247 Ok(quote! {
248 impl #impl_generics #crate_path::ct::ConditionallySelectable for #name #type_generics #where_clause {
249 #[inline]
250 fn conditional_select(
251 left: &Self,
252 right: &Self,
253 choice: #crate_path::ct::Choice,
254 ) -> Self {
255 #body
256 }
257 }
258 })
259}
260
261fn crate_path(options: &ContainerOptions) -> Path {
262 options
263 .crate_path
264 .clone()
265 .unwrap_or_else(|| parse_quote!(::sanitization))
266}
267
268fn add_sanitize_bounds(
269 mut generics: Generics,
270 data: &Data,
271 crate_path: &Path,
272 options: &ContainerOptions,
273) -> Result<Generics> {
274 let where_clause = generics.make_where_clause();
275
276 if let Some(bounds) = &options.bound_override {
277 where_clause.predicates.extend(bounds.iter().cloned());
278 return Ok(generics);
279 }
280
281 for field in sanitized_fields(data)? {
282 let field_options = parse_field_options(&field.attrs)?;
283 if field_options.skip {
284 continue;
285 }
286
287 if let Some(bounds) = field_options.bound_override {
288 where_clause.predicates.extend(bounds);
289 } else {
290 let ty = &field.ty;
291 where_clause
292 .predicates
293 .push(parse_quote!(#ty: #crate_path::SecureSanitize));
294 }
295 }
296
297 Ok(generics)
298}
299
/// How `add_trait_bounds` treats fields marked `#[sanitization(skip)]`.
#[derive(Clone, Copy)]
enum SkipPolicy {
    /// Skipped fields are ignored and receive no bound.
    Allow,
    /// Skipped fields are a compile error, for derives that must construct every field.
    Reject,
}
305
306fn add_trait_bounds(
307 mut generics: Generics,
308 data: &Data,
309 trait_path: &TokenStream2,
310 options: &ContainerOptions,
311 skip_policy: SkipPolicy,
312) -> Result<Generics> {
313 let where_clause = generics.make_where_clause();
314
315 if let Some(bounds) = &options.bound_override {
316 where_clause.predicates.extend(bounds.iter().cloned());
317 return Ok(generics);
318 }
319
320 for field in sanitized_fields(data)? {
321 let field_options = parse_field_options(&field.attrs)?;
322 if field_options.skip {
323 if matches!(skip_policy, SkipPolicy::Reject) {
324 return Err(Error::new_spanned(
325 field,
326 "#[sanitization(skip)] is not supported for this derive because every output field must be constructed",
327 ));
328 }
329 continue;
330 }
331
332 if let Some(bounds) = field_options.bound_override {
333 where_clause.predicates.extend(bounds);
334 } else {
335 let ty = &field.ty;
336 where_clause.predicates.push(parse_quote!(#ty: #trait_path));
337 }
338 }
339
340 Ok(generics)
341}
342
343fn sanitized_fields(data: &Data) -> Result<Vec<&syn::Field>> {
344 let mut fields = Vec::new();
345 match data {
346 Data::Struct(data) => fields.extend(data.fields.iter()),
347 Data::Enum(data) => {
348 for variant in &data.variants {
349 fields.extend(variant.fields.iter());
350 }
351 }
352 Data::Union(_) => {}
353 }
354 Ok(fields)
355}
356
357fn expand_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
358 let calls = field_calls_for_struct(&data.fields, crate_path)?;
359 Ok(quote!(#(#calls)*))
360}
361
362fn expand_ct_eq_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
363 let mut calls = Vec::new();
364
365 for (index, field) in data.fields.iter().enumerate() {
366 if parse_field_options(&field.attrs)?.skip {
367 continue;
368 }
369
370 let (left, right) = match &field.ident {
371 Some(ident) => (quote!(&self.#ident), quote!(&other.#ident)),
372 None => {
373 let index = syn::Index::from(index);
374 (quote!(&self.#index), quote!(&other.#index))
375 }
376 };
377 calls.push(quote! {
378 result = result & #crate_path::ct::ConstantTimeEq::ct_eq(#left, #right);
379 });
380 }
381
382 Ok(quote! {
383 let mut result = #crate_path::ct::Choice::TRUE;
384 #(#calls)*
385 result
386 })
387}
388
389fn expand_ct_select_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
390 match &data.fields {
391 Fields::Named(fields) => {
392 let mut selected = Vec::new();
393 for field in &fields.named {
394 if parse_field_options(&field.attrs)?.skip {
395 return Err(Error::new_spanned(
396 field,
397 "#[sanitization(skip)] is not supported for ConditionallySelectable derives",
398 ));
399 }
400 let ident = field.ident.as_ref().expect("named field");
401 selected.push(quote! {
402 #ident: #crate_path::ct::ConditionallySelectable::conditional_select(
403 &left.#ident,
404 &right.#ident,
405 choice,
406 )
407 });
408 }
409 Ok(quote!(Self { #(#selected),* }))
410 }
411 Fields::Unnamed(fields) => {
412 let mut selected = Vec::new();
413 for (index, field) in fields.unnamed.iter().enumerate() {
414 if parse_field_options(&field.attrs)?.skip {
415 return Err(Error::new_spanned(
416 field,
417 "#[sanitization(skip)] is not supported for ConditionallySelectable derives",
418 ));
419 }
420 let index = syn::Index::from(index);
421 selected.push(quote! {
422 #crate_path::ct::ConditionallySelectable::conditional_select(
423 &left.#index,
424 &right.#index,
425 choice,
426 )
427 });
428 }
429 Ok(quote!(Self(#(#selected),*)))
430 }
431 Fields::Unit => Ok(quote!(Self)),
432 }
433}
434
435fn field_calls_for_struct(fields: &Fields, crate_path: &Path) -> Result<Vec<TokenStream2>> {
436 let mut calls = Vec::new();
437
438 for (index, field) in fields.iter().enumerate() {
439 if parse_field_options(&field.attrs)?.skip {
440 continue;
441 }
442
443 let access = match &field.ident {
444 Some(ident) => quote!(&mut self.#ident),
445 None => {
446 let index = syn::Index::from(index);
447 quote!(&mut self.#index)
448 }
449 };
450 calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#access);));
451 }
452
453 Ok(calls)
454}
455
456fn expand_enum_body(data: &DataEnum, crate_path: &Path) -> Result<TokenStream2> {
457 let mut arms = Vec::new();
458
459 for variant in &data.variants {
460 let variant_ident = &variant.ident;
461 let (pattern, calls) = match &variant.fields {
462 Fields::Named(fields) => {
463 let mut bindings = Vec::new();
464 let mut calls = Vec::new();
465 for field in &fields.named {
466 let ident = field.ident.as_ref().expect("named field");
467 if parse_field_options(&field.attrs)?.skip {
468 continue;
469 }
470 bindings.push(quote!(#ident));
471 calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#ident);));
472 }
473
474 let pattern = if bindings.is_empty() {
475 quote!(Self::#variant_ident { .. })
476 } else {
477 quote!(Self::#variant_ident { #(#bindings),*, .. })
478 };
479 (pattern, calls)
480 }
481 Fields::Unnamed(fields) => {
482 let mut pattern_fields = Vec::new();
483 let mut calls = Vec::new();
484 for (index, field) in fields.unnamed.iter().enumerate() {
485 if parse_field_options(&field.attrs)?.skip {
486 pattern_fields.push(quote!(_));
487 } else {
488 let binding = format_ident!("field_{index}");
489 pattern_fields.push(quote!(#binding));
490 calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#binding);));
491 }
492 }
493 (quote!(Self::#variant_ident(#(#pattern_fields),*)), calls)
494 }
495 Fields::Unit => (quote!(Self::#variant_ident), Vec::new()),
496 };
497
498 arms.push(quote!(#pattern => { #(#calls)* }));
499 }
500
501 Ok(quote! {
502 match self {
503 #(#arms),*
504 }
505 })
506}
507
508fn parse_container_options(attrs: &[Attribute]) -> Result<ContainerOptions> {
509 let mut options = ContainerOptions::default();
510
511 for attr in attrs
512 .iter()
513 .filter(|attr| attr.path().is_ident("sanitization"))
514 {
515 attr.parse_nested_meta(|meta| {
516 if meta.path.is_ident("crate") {
517 let value = meta.value()?;
518 let literal: LitStr = value.parse()?;
519 options.crate_path = Some(literal.parse()?);
520 Ok(())
521 } else if meta.path.is_ident("bound") {
522 let value = meta.value()?;
523 let literal: LitStr = value.parse()?;
524 options.bound_override = Some(parse_bounds(&literal)?);
525 Ok(())
526 } else if meta.path.is_ident("enum_inactive_variant_bytes") {
527 let value = meta.value()?;
528 let literal: LitStr = value.parse()?;
529 if literal.value() == "acknowledged" {
530 options.enum_inactive_variant_bytes_acknowledged = true;
531 Ok(())
532 } else {
533 Err(meta.error("enum_inactive_variant_bytes must be exactly \"acknowledged\""))
534 }
535 } else {
536 Err(meta.error("unsupported sanitization container attribute"))
537 }
538 })?;
539 }
540
541 Ok(options)
542}
543
544fn parse_field_options(attrs: &[Attribute]) -> Result<FieldOptions> {
545 let mut options = FieldOptions::default();
546
547 for attr in attrs
548 .iter()
549 .filter(|attr| attr.path().is_ident("sanitization"))
550 {
551 attr.parse_nested_meta(|meta| {
552 if meta.path.is_ident("skip") {
553 options.skip = true;
554 Ok(())
555 } else if meta.path.is_ident("bound") {
556 let value = meta.value()?;
557 let literal: LitStr = value.parse()?;
558 options.bound_override = Some(parse_bounds(&literal)?);
559 Ok(())
560 } else {
561 Err(meta.error("unsupported sanitization field attribute"))
562 }
563 })?;
564 }
565
566 Ok(options)
567}
568
569fn parse_bounds(literal: &LitStr) -> Result<Vec<WherePredicate>> {
570 let text = literal.value();
571 if text.trim().is_empty() {
572 return Ok(Vec::new());
573 }
574
575 let where_clause: syn::WhereClause = syn::parse_str(&format!("where {text}"))
576 .map_err(|error| Error::new(literal.span(), error))?;
577 Ok(where_clause.predicates.into_iter().collect())
578}