Skip to main content

sanitization_derive/
lib.rs

1#![deny(unsafe_code)]
2#![deny(unsafe_op_in_unsafe_fn)]
3
4use proc_macro::TokenStream;
5use proc_macro2::TokenStream as TokenStream2;
6use quote::{format_ident, quote};
7use syn::{
8    parse_macro_input, parse_quote, Attribute, Data, DataEnum, DataStruct, DeriveInput, Error,
9    Fields, Generics, LitStr, Path, Result, WherePredicate,
10};
11
12/// Derive `sanitization::SecureSanitize` for structs and enums.
13///
14/// Every non-skipped field must implement `SecureSanitize`. Use
15/// `#[sanitization(skip)]` only for fields that are intentionally non-secret or
16/// cleared elsewhere.
17#[proc_macro_derive(SecureSanitize, attributes(sanitization))]
18pub fn derive_secure_sanitize(input: TokenStream) -> TokenStream {
19    let input = parse_macro_input!(input as DeriveInput);
20    expand_secure_sanitize(&input)
21        .unwrap_or_else(Error::into_compile_error)
22        .into()
23}
24
25/// Derive `Drop` by calling `sanitization::SecureSanitize::secure_sanitize`.
26///
27/// # Generics
28///
29/// For structs with type parameters that hold sanitizable data, the parameter
30/// must carry the `SecureSanitize` bound at the type declaration:
31///
32/// ```ignore
33/// use sanitization::SecureSanitize;
34///
35/// #[derive(SecureSanitize, SecureSanitizeOnDrop)]
36/// struct Wrapper<T: SecureSanitize> {
37///     inner: T,
38/// }
39/// ```
40///
41/// This is a Rust `Drop` restriction: the generated `Drop` impl cannot add a
42/// stricter `T: SecureSanitize` bound than the struct declaration itself.
43#[proc_macro_derive(SecureSanitizeOnDrop, attributes(sanitization))]
44pub fn derive_secure_sanitize_on_drop(input: TokenStream) -> TokenStream {
45    let input = parse_macro_input!(input as DeriveInput);
46    expand_secure_sanitize_on_drop(&input)
47        .unwrap_or_else(Error::into_compile_error)
48        .into()
49}
50
51#[derive(Default)]
52struct ContainerOptions {
53    crate_path: Option<Path>,
54    bound_override: Option<Vec<WherePredicate>>,
55}
56
57#[derive(Default)]
58struct FieldOptions {
59    skip: bool,
60    bound_override: Option<Vec<WherePredicate>>,
61}
62
63fn expand_secure_sanitize(input: &DeriveInput) -> Result<TokenStream2> {
64    let options = parse_container_options(&input.attrs)?;
65    let crate_path = crate_path(&options);
66    let body = match &input.data {
67        Data::Struct(data) => expand_struct_body(data, &crate_path)?,
68        Data::Enum(data) => expand_enum_body(data, &crate_path)?,
69        Data::Union(_) => {
70            return Err(Error::new_spanned(
71                input,
72                "SecureSanitize cannot be derived for unions; implement it manually using documented unsafe code for the active field",
73            ))
74        }
75    };
76    let generics = add_sanitize_bounds(input.generics.clone(), &input.data, &crate_path, &options)?;
77    let name = &input.ident;
78    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
79
80    Ok(quote! {
81        impl #impl_generics #crate_path::SecureSanitize for #name #type_generics #where_clause {
82            #[inline]
83            fn secure_sanitize(&mut self) {
84                #body
85            }
86        }
87    })
88}
89
90fn expand_secure_sanitize_on_drop(input: &DeriveInput) -> Result<TokenStream2> {
91    let options = parse_container_options(&input.attrs)?;
92    let crate_path = crate_path(&options);
93
94    if matches!(input.data, Data::Union(_)) {
95        return Err(Error::new_spanned(
96            input,
97            "SecureSanitizeOnDrop cannot be derived for unions",
98        ));
99    }
100
101    let name = &input.ident;
102    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
103
104    Ok(quote! {
105        impl #impl_generics Drop for #name #type_generics #where_clause {
106            #[inline]
107            fn drop(&mut self) {
108                #crate_path::SecureSanitize::secure_sanitize(self);
109            }
110        }
111    })
112}
113
114fn crate_path(options: &ContainerOptions) -> Path {
115    options
116        .crate_path
117        .clone()
118        .unwrap_or_else(|| parse_quote!(::sanitization))
119}
120
121fn add_sanitize_bounds(
122    mut generics: Generics,
123    data: &Data,
124    crate_path: &Path,
125    options: &ContainerOptions,
126) -> Result<Generics> {
127    let where_clause = generics.make_where_clause();
128
129    if let Some(bounds) = &options.bound_override {
130        where_clause.predicates.extend(bounds.iter().cloned());
131        return Ok(generics);
132    }
133
134    for field in sanitized_fields(data)? {
135        let field_options = parse_field_options(&field.attrs)?;
136        if field_options.skip {
137            continue;
138        }
139
140        if let Some(bounds) = field_options.bound_override {
141            where_clause.predicates.extend(bounds);
142        } else {
143            let ty = &field.ty;
144            where_clause
145                .predicates
146                .push(parse_quote!(#ty: #crate_path::SecureSanitize));
147        }
148    }
149
150    Ok(generics)
151}
152
153fn sanitized_fields(data: &Data) -> Result<Vec<&syn::Field>> {
154    let mut fields = Vec::new();
155    match data {
156        Data::Struct(data) => fields.extend(data.fields.iter()),
157        Data::Enum(data) => {
158            for variant in &data.variants {
159                fields.extend(variant.fields.iter());
160            }
161        }
162        Data::Union(_) => {}
163    }
164    Ok(fields)
165}
166
167fn expand_struct_body(data: &DataStruct, crate_path: &Path) -> Result<TokenStream2> {
168    let calls = field_calls_for_struct(&data.fields, crate_path)?;
169    Ok(quote!(#(#calls)*))
170}
171
172fn field_calls_for_struct(fields: &Fields, crate_path: &Path) -> Result<Vec<TokenStream2>> {
173    let mut calls = Vec::new();
174
175    for (index, field) in fields.iter().enumerate() {
176        if parse_field_options(&field.attrs)?.skip {
177            continue;
178        }
179
180        let access = match &field.ident {
181            Some(ident) => quote!(&mut self.#ident),
182            None => {
183                let index = syn::Index::from(index);
184                quote!(&mut self.#index)
185            }
186        };
187        calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#access);));
188    }
189
190    Ok(calls)
191}
192
193fn expand_enum_body(data: &DataEnum, crate_path: &Path) -> Result<TokenStream2> {
194    let mut arms = Vec::new();
195
196    for variant in &data.variants {
197        let variant_ident = &variant.ident;
198        let (pattern, calls) = match &variant.fields {
199            Fields::Named(fields) => {
200                let mut bindings = Vec::new();
201                let mut calls = Vec::new();
202                for field in &fields.named {
203                    let ident = field.ident.as_ref().expect("named field");
204                    if parse_field_options(&field.attrs)?.skip {
205                        continue;
206                    }
207                    bindings.push(quote!(#ident));
208                    calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#ident);));
209                }
210
211                let pattern = if bindings.is_empty() {
212                    quote!(Self::#variant_ident { .. })
213                } else {
214                    quote!(Self::#variant_ident { #(#bindings),*, .. })
215                };
216                (pattern, calls)
217            }
218            Fields::Unnamed(fields) => {
219                let mut pattern_fields = Vec::new();
220                let mut calls = Vec::new();
221                for (index, field) in fields.unnamed.iter().enumerate() {
222                    if parse_field_options(&field.attrs)?.skip {
223                        pattern_fields.push(quote!(_));
224                    } else {
225                        let binding = format_ident!("field_{index}");
226                        pattern_fields.push(quote!(#binding));
227                        calls.push(quote!(#crate_path::SecureSanitize::secure_sanitize(#binding);));
228                    }
229                }
230                (quote!(Self::#variant_ident(#(#pattern_fields),*)), calls)
231            }
232            Fields::Unit => (quote!(Self::#variant_ident), Vec::new()),
233        };
234
235        arms.push(quote!(#pattern => { #(#calls)* }));
236    }
237
238    Ok(quote! {
239        match self {
240            #(#arms),*
241        }
242    })
243}
244
245fn parse_container_options(attrs: &[Attribute]) -> Result<ContainerOptions> {
246    let mut options = ContainerOptions::default();
247
248    for attr in attrs
249        .iter()
250        .filter(|attr| attr.path().is_ident("sanitization"))
251    {
252        attr.parse_nested_meta(|meta| {
253            if meta.path.is_ident("crate") {
254                let value = meta.value()?;
255                let literal: LitStr = value.parse()?;
256                options.crate_path = Some(literal.parse()?);
257                Ok(())
258            } else if meta.path.is_ident("bound") {
259                let value = meta.value()?;
260                let literal: LitStr = value.parse()?;
261                options.bound_override = Some(parse_bounds(&literal)?);
262                Ok(())
263            } else {
264                Err(meta.error("unsupported sanitization container attribute"))
265            }
266        })?;
267    }
268
269    Ok(options)
270}
271
272fn parse_field_options(attrs: &[Attribute]) -> Result<FieldOptions> {
273    let mut options = FieldOptions::default();
274
275    for attr in attrs
276        .iter()
277        .filter(|attr| attr.path().is_ident("sanitization"))
278    {
279        attr.parse_nested_meta(|meta| {
280            if meta.path.is_ident("skip") {
281                options.skip = true;
282                Ok(())
283            } else if meta.path.is_ident("bound") {
284                let value = meta.value()?;
285                let literal: LitStr = value.parse()?;
286                options.bound_override = Some(parse_bounds(&literal)?);
287                Ok(())
288            } else {
289                Err(meta.error("unsupported sanitization field attribute"))
290            }
291        })?;
292    }
293
294    Ok(options)
295}
296
297fn parse_bounds(literal: &LitStr) -> Result<Vec<WherePredicate>> {
298    let text = literal.value();
299    if text.trim().is_empty() {
300        return Ok(Vec::new());
301    }
302
303    let where_clause: syn::WhereClause = syn::parse_str(&format!("where {text}"))
304        .map_err(|error| Error::new(literal.span(), error))?;
305    Ok(where_clause.predicates.into_iter().collect())
306}