Skip to main content

anodized_core/instrument/traits/
mod.rs

1use quote::quote;
2use syn::{FnArg, ImplItem, ItemFn, Pat, TraitItem, parse_quote};
3
4use crate::{
5    Spec,
6    instrument::{Backend, find_spec_attr, make_item_error},
7};
8
9impl Backend {
10    /// Expand trait items by mangling each method and adding a wrapper default impl.
11    ///
12    /// Mangling a function involves the following:
13    /// 1. Rename the function following the pattern: `fn add` -> `fn __anodized_add`.
14    /// 2. Make a new function with the original name that has a default impl; the
15    ///    default impl performs runtime validation and calls the mangled function.
16    pub fn instrument_trait(
17        &self,
18        spec: Spec,
19        mut the_trait: syn::ItemTrait,
20    ) -> syn::Result<syn::ItemTrait> {
21        // Currently we don't support any spec arguments for traits themselves.
22        if !spec.is_empty() {
23            return Err(spec.spec_err(
24                "Unsupported spec element on trait. Try placing it on an item inside the trait",
25            ));
26        }
27
28        let mut new_trait_items = Vec::with_capacity(the_trait.items.len() * 2);
29
30        for item in the_trait.items.into_iter() {
31            match item {
32                TraitItem::Fn(mut func) => {
33                    let (spec_attr, other_attrs) = find_spec_attr(func.attrs)?;
34
35                    // NOTE: We have no way of knowing which attributes are
36                    //   "external" - meant for the interface and belong on the wrapper,
37                    //   "internal" - meant for the mangled implementation.
38                    //   Right now we put all attribs on both functions, but that's certainly
39                    //   not going to work in every situation.
40                    func.attrs = other_attrs.clone();
41
42                    let original_ident = func.sig.ident.clone();
43                    let mangled_ident = mangle_ident(&original_ident);
44
45                    let mut mangled_fn = func.clone();
46                    mangled_fn.sig.ident = mangled_ident.clone();
47                    mangled_fn.attrs.retain(|attr| !attr.path().is_ident("doc"));
48                    mangled_fn.attrs.push(parse_quote!(#[doc(hidden)]));
49
50                    let call_args = build_call_args(&func.sig.inputs)?;
51                    let mut wrapper_block: syn::Block = parse_quote!({
52                        Self::#mangled_ident(#(#call_args),*)
53                    });
54
55                    if let Some(spec_attr) = spec_attr {
56                        let spec = spec_attr.parse_args()?;
57                        let wrapper_item = ItemFn {
58                            attrs: Vec::new(),
59                            vis: syn::Visibility::Inherited,
60                            sig: func.sig.clone(),
61                            block: Box::new(wrapper_block),
62                        };
63                        let instrumented = self.instrument_fn(spec, wrapper_item)?;
64                        wrapper_block = *instrumented.block;
65                    }
66
67                    let mut wrapper_fn = func;
68                    wrapper_fn.attrs = other_attrs;
69                    wrapper_fn.default = Some(wrapper_block);
70                    wrapper_fn.semi_token = None;
71
72                    new_trait_items.push(TraitItem::Fn(mangled_fn));
73                    new_trait_items.push(TraitItem::Fn(wrapper_fn));
74                }
75                TraitItem::Const(mut const_item) => {
76                    let (spec, attrs) = find_spec_attr(const_item.attrs)?;
77                    if let Some(ref spec_attr) = spec {
78                        return Err(make_item_error(&spec_attr, "trait const"));
79                    }
80                    const_item.attrs = attrs;
81                    new_trait_items.push(TraitItem::Const(const_item));
82                }
83                TraitItem::Type(mut type_item) => {
84                    let (spec, attrs) = find_spec_attr(type_item.attrs)?;
85                    if let Some(ref spec_attr) = spec {
86                        return Err(make_item_error(&spec_attr, "trait type"));
87                    }
88                    type_item.attrs = attrs;
89                    new_trait_items.push(TraitItem::Type(type_item));
90                }
91                TraitItem::Macro(mut macro_item) => {
92                    let (spec, attrs) = find_spec_attr(macro_item.attrs)?;
93                    if let Some(ref spec_attr) = spec {
94                        return Err(make_item_error(&spec_attr, "trait macro"));
95                    }
96                    macro_item.attrs = attrs;
97                    new_trait_items.push(TraitItem::Macro(macro_item));
98                }
99                TraitItem::Verbatim(token_stream) => {
100                    new_trait_items.push(TraitItem::Verbatim(token_stream));
101                }
102                _ => unimplemented!(),
103            }
104        }
105        the_trait.items = new_trait_items;
106        Ok(the_trait)
107    }
108
109    /// Expand impl items by mangling methods for trait impls
110    ///
111    /// `#[spec]` attributes on the impl items themselves are not allowed.
112    pub fn instrument_trait_impl(
113        &self,
114        spec: Spec,
115        mut the_impl: syn::ItemImpl,
116    ) -> syn::Result<syn::ItemImpl> {
117        let Some((trait_bang, ref _trait_path, _trait_for)) = the_impl.trait_ else {
118            return Err(make_item_error(&the_impl, "inherent impl"));
119        };
120
121        if trait_bang.is_some() {
122            return Err(make_item_error(&the_impl, "negative trait impl"));
123        }
124
125        if !spec.is_empty() {
126            return Err(spec.spec_err("Unsupported spec element on trait impl."));
127        }
128
129        let mut new_items = Vec::with_capacity(the_impl.items.len());
130
131        for item in the_impl.items.into_iter() {
132            let new_item = match item {
133                ImplItem::Fn(mut func) => {
134                    let (spec, mut func_attrs) = find_spec_attr(func.attrs)?;
135                    if let Some(ref spec_attr) = spec {
136                        return Err(make_item_error(&spec_attr, "trait impl fn"));
137                    }
138
139                    let original_ident = func.sig.ident;
140                    if original_ident.to_string().starts_with("__anodized_") {
141                        return Err(syn::Error::new_spanned(
142                            &original_ident,
143                            r#"An item with the `__anodized_` prefix is internal. Do not implement it directly.
144Instead, ensure that both the trait and the impl fn have a `#[spec]` annotation."#,
145                        ));
146                    }
147                    func.sig.ident = mangle_ident(&original_ident);
148
149                    // Add a default `#[inline]` attribute unless one is already there.
150                    // The caller can supress this with `#[inline(never)]`
151                    if !has_inline_attr(&func_attrs) {
152                        func_attrs.push(parse_quote!(#[inline]));
153                    }
154
155                    func.attrs = func_attrs;
156                    ImplItem::Fn(func)
157                }
158                ImplItem::Const(mut const_item) => {
159                    let (spec, attrs) = find_spec_attr(const_item.attrs)?;
160                    if let Some(ref spec_attr) = spec {
161                        return Err(make_item_error(&spec_attr, "trait impl const"));
162                    }
163                    const_item.attrs = attrs;
164                    ImplItem::Const(const_item)
165                }
166                ImplItem::Type(mut type_item) => {
167                    let (spec, attrs) = find_spec_attr(type_item.attrs)?;
168                    if let Some(ref spec_attr) = spec {
169                        return Err(make_item_error(&spec_attr, "trait impl type"));
170                    }
171                    type_item.attrs = attrs;
172                    ImplItem::Type(type_item)
173                }
174                ImplItem::Macro(mut macro_item) => {
175                    let (spec, attrs) = find_spec_attr(macro_item.attrs)?;
176                    if let Some(ref spec_attr) = spec {
177                        return Err(make_item_error(&spec_attr, "trait impl macro"));
178                    }
179                    macro_item.attrs = attrs;
180                    ImplItem::Macro(macro_item)
181                }
182                ImplItem::Verbatim(token_stream) => ImplItem::Verbatim(token_stream),
183                _ => unimplemented!(),
184            };
185
186            new_items.push(new_item);
187        }
188
189        the_impl.items = new_items;
190        Ok(the_impl)
191    }
192}
193
194/// Build argument tokens for calling the mangled trait method from the wrapper.
195///
196/// Purpose: the wrapper method needs to forward its arguments to the mangled
197/// implementation, so this extracts a usable token for each input.
198///
199/// Examples (inputs -> output tokens):
200/// - `fn f(&self, x: i32)` -> `self, x`
201/// - `fn f(self, a: u8, b: u8)` -> `self, a, b`
202///
203/// The caller is responsible for ensuring these tokens are used in a call
204/// expression like `Self::__anodized_f(#(#args),*)`.
205///
206/// Callers: only `instrument_trait` in this module should use this; it is not
207/// part of the public API.
208fn build_call_args(
209    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
210) -> syn::Result<Vec<proc_macro2::TokenStream>> {
211    let mut args = Vec::new();
212    for input in inputs.iter() {
213        match input {
214            FnArg::Receiver(_) => {
215                args.push(quote! { self });
216            }
217            FnArg::Typed(pat) => match pat.pat.as_ref() {
218                Pat::Ident(pat_ident) => {
219                    let ident = &pat_ident.ident;
220                    args.push(quote! { #ident });
221                }
222                _ => {
223                    return Err(syn::Error::new_spanned(
224                        &pat.pat,
225                        "unsupported pattern in trait method arguments",
226                    ));
227                }
228            },
229        }
230    }
231    Ok(args)
232}
233
234/// Prefix an identifier with `__anodized_`, preserving the original span.
235/// Used when generating mangled method names in trait and impl expansion.
236fn mangle_ident(original_ident: &syn::Ident) -> syn::Ident {
237    syn::Ident::new(
238        &format!("__anodized_{original_ident}"),
239        original_ident.span(),
240    )
241}
242
243/// Checks to see if any `#[inline]` (with or without arg) exists in the function's attribs.
244fn has_inline_attr(attrs: &[syn::Attribute]) -> bool {
245    attrs.iter().any(|attr| attr.path().is_ident("inline"))
246}