Skip to main content

anodized_core/instrument/fns/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use crate::{
5    Capture, PostCondition, PreCondition, Spec,
6    instrument::{Config, build_assert, build_eprint},
7    qualifiers::FnQualifiers,
8};
9
10use proc_macro2::Span;
11use quote::{ToTokens, quote};
12use syn::{
13    Attribute, Block, Expr, Ident, Pat, PatIdent, Path, ReturnType, Signature, Stmt, Type,
14    parse::{Parse, Result},
15    parse_quote,
16};
17
18impl Config {
19    pub fn instrument_fn(&self, spec: &Spec, sig: &Signature, body: &mut Block) -> syn::Result<()> {
20        self.instrument_loops_in_fn_body(body)?;
21
22        let is_async = sig.asyncness.is_some();
23
24        // Extract the return type from the function signature
25        let return_type = match &sig.output {
26            syn::ReturnType::Default => syn::parse_quote!(()),
27            syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
28        };
29
30        // Generate the new, instrumented function body.
31        let new_body = self.instrument_fn_body(spec, body, is_async, &return_type)?;
32
33        // Replace the old function body with the new one.
34        *body = new_body;
35
36        Ok(())
37    }
38
39    pub fn build_spec_fn_sig(prefix: &str, sig: &Signature) -> Signature {
40        Signature {
41            constness: sig.constness,
42            asyncness: sig.asyncness,
43            unsafety: sig.unsafety,
44            abi: sig.abi.clone(),
45            fn_token: sig.fn_token,
46            ident: syn::Ident::new(&format!("{prefix}_{}", sig.ident), sig.ident.span()),
47            generics: sig.generics.clone(),
48            paren_token: sig.paren_token,
49            inputs: sig.inputs.clone(),
50            variadic: sig.variadic.clone(),
51            output: syn::ReturnType::Default,
52        }
53    }
54
55    pub fn build_qualifier_const_item<SomeConstItem: Parse>(
56        attrs: &[Attribute],
57        prefix: &str,
58        qualifiers: FnQualifiers,
59        fn_ident: &Ident,
60    ) -> SomeConstItem {
61        let qualifier_bits = qualifiers.bits();
62        let name: Ident = syn::Ident::new(&format!("{}_{}", prefix, fn_ident), fn_ident.span());
63        parse_quote! {
64            #(#attrs)*
65            const #name: u32 = #qualifier_bits;
66        }
67    }
68
69    pub fn build_qualifier_check_stmt(
70        fn_ident: &Ident,
71        impl_type: &Type,
72        trait_path: &Path,
73    ) -> Stmt {
74        let impl_const_name = Ident::new(
75            &format!("__anodized_fn_qualifiers_{}", fn_ident),
76            fn_ident.span(),
77        );
78
79        let trait_const_name = Ident::new(
80            &format!("__anodized_fn_qualifiers_trait_{}", fn_ident),
81            fn_ident.span(),
82        );
83
84        let message = format!(
85            "the qualifiers on the impl `{}::{fn_ident}` cannot be weaker than the qualifiers on the trait `{}::{fn_ident}`",
86            impl_type.to_token_stream(),
87            trait_path.to_token_stream(),
88        );
89
90        parse_quote! {
91            const {
92                assert!(
93                    Self::#impl_const_name == Self::#trait_const_name | Self::#impl_const_name,
94                    #message,
95                );
96            };
97        }
98    }
99
100    pub fn build_precondition_fn_body(conditions: &[PreCondition]) -> Block {
101        let statements = conditions.iter().map(|condition| -> Stmt {
102            let closure = &condition.closure;
103            parse_quote! { let _ = #closure; }
104        });
105        parse_quote! {
106            {
107                #(#statements)*
108            }
109        }
110    }
111
112    pub fn build_poscondition_fn_body(
113        captures: &[Capture],
114        conditions: &[PostCondition],
115        return_type: &ReturnType,
116    ) -> Result<Block> {
117        let aliases = captures.iter().map(|capture| &capture.pat);
118        let capture_exprs = captures.iter().map(|capture| -> Expr {
119            let expr = &capture.expr;
120            // Wrap in closure to guard against `return`.
121            parse_quote! { (|| #expr)() }
122        });
123
124        let mut statements = vec![];
125
126        for condition in conditions {
127            let closure = &condition.closure;
128            // TODO: This sort of validation should happen during parsing.
129            let output_binder = match closure.inputs.first() {
130                Some(output_binder) if closure.inputs.len() == 1 => output_binder,
131                _ => {
132                    return Err(syn::Error::new_spanned(
133                        &closure.inputs,
134                        "Postcondition closure must have exactly one parameter.",
135                    ));
136                }
137            };
138            let statement: Stmt = if let Pat::Type(_) = output_binder {
139                // If the output binder has a type annotation, use as-is.
140                parse_quote! { let _ = #closure; }
141            } else {
142                // Otherwise add a type annotation.
143                let body = &closure.body;
144                match &return_type {
145                    ReturnType::Default => {
146                        parse_quote! { let _ = |#output_binder: &()| #body; }
147                    }
148                    ReturnType::Type(_, ty) => {
149                        parse_quote! { let _ = |#output_binder: &#ty| #body; }
150                    }
151                }
152            };
153            statements.push(statement);
154        }
155
156        Ok(parse_quote! {
157            {
158                let (#(#aliases),*) = (#(#capture_exprs),*);
159                #(#statements)*
160            }
161        })
162    }
163
164    fn instrument_fn_body(
165        &self,
166        spec: &Spec,
167        original_body: &Block,
168        is_async: bool,
169        return_type: &syn::Type,
170    ) -> Result<Block> {
171        let build_check = match (self.emit_print, self.emit_panic) {
172            (true, true) => build_assert,
173            (true, false) => build_eprint,
174            (false, true) => build_assert,
175            (false, false) => return Ok(original_body.clone()),
176        };
177
178        // The identifier for the return value binding.
179        let output_ident = Pat::Ident(PatIdent {
180            attrs: vec![],
181            by_ref: None,
182            mutability: None,
183            ident: Ident::new("__anodized_output", Span::mixed_site()),
184            subpat: None,
185        });
186
187        // --- Generate Precondition Checks ---
188        let precondition_checks = spec
189            .requires
190            .iter()
191            .map(|condition| {
192                let closure = condition.closure.to_token_stream();
193                let expr = quote! { (#closure)() };
194                let repr = condition.closure.body.to_token_stream();
195                build_check(
196                    condition.cfg.as_ref(),
197                    &expr,
198                    "Precondition failed: {}",
199                    &repr,
200                )
201            })
202            .chain(spec.maintains.iter().map(|condition| {
203                let closure = condition.closure.to_token_stream();
204                let expr = quote! { (#closure)() };
205                let repr = condition.closure.body.to_token_stream();
206                build_check(
207                    condition.cfg.as_ref(),
208                    &expr,
209                    "Pre-invariant failed: {}",
210                    &repr,
211                )
212            }));
213
214        // --- Generate Combined Body and Capture Statement ---
215        // Capture values and execute body in a single tuple assignment
216        // This ensures captured values aren't accessible to the body itself
217
218        // Chain capture aliases with output binding
219        let aliases = spec
220            .captures
221            .iter()
222            .map(|cb| &cb.pat)
223            .chain(std::iter::once(&output_ident));
224
225        // Chain capture expressions with body expression
226        let capture_exprs = spec.captures.iter().map(|cb| {
227            let expr = &cb.expr;
228            // Evaluate expression in a closure to prevent early return.
229            quote! { (|| #expr)() }
230        });
231
232        // Chain underscore types with return type for tuple type annotation
233        let types = spec
234            .captures
235            .iter()
236            .map(|_| quote! { _ })
237            .chain(std::iter::once(quote! { #return_type }));
238
239        let body_expr = if is_async {
240            quote! { (async || #original_body)().await }
241        } else {
242            quote! { (|| #original_body)() }
243        };
244
245        let exprs = capture_exprs.chain(std::iter::once(body_expr));
246
247        // Build tuple assignment with type annotation on the tuple
248        let body_and_captures = quote! {
249            let (#(#aliases),*): (#(#types),*) = (#(#exprs),*);
250        };
251
252        // --- Generate Postcondition Checks ---
253        let postcondition_checks = spec
254            .maintains
255            .iter()
256            .map(|condition| {
257                let closure = condition.closure.to_token_stream();
258                let expr = quote! { (#closure)() };
259                let repr = condition.closure.body.to_token_stream();
260                build_check(
261                    condition.cfg.as_ref(),
262                    &expr,
263                    "Post-invariant failed: {}",
264                    &repr,
265                )
266            })
267            .chain(spec.ensures.iter().map(|postcondition| {
268                let closure = annotate_postcondition_closure_argument(
269                    postcondition.closure.clone(),
270                    return_type.clone(),
271                );
272
273                let expr = quote! { (#closure)(&#output_ident) };
274                build_check(
275                    postcondition.cfg.as_ref(),
276                    &expr,
277                    "Postcondition failed: {}",
278                    &postcondition.closure.to_token_stream(),
279                )
280            }));
281
282        Ok(parse_quote! {
283            {
284                #(#precondition_checks)*
285                #body_and_captures
286                #(#postcondition_checks)*
287                #output_ident
288            }
289        })
290    }
291}
292
293fn annotate_postcondition_closure_argument(
294    mut closure: syn::ExprClosure,
295    return_type: syn::Type,
296) -> syn::ExprClosure {
297    // Add type annotation: convert |param| to |param: &ReturnType|.
298    if let Some(first_input) = closure.inputs.first_mut() {
299        // Wrap the pattern with a type annotation
300        let pattern = first_input.clone();
301        *first_input = syn::Pat::Type(syn::PatType {
302            attrs: vec![],
303            pat: Box::new(pattern),
304            colon_token: Default::default(),
305            ty: Box::new(syn::Type::Reference(syn::TypeReference {
306                and_token: Default::default(),
307                lifetime: None,
308                mutability: None,
309                elem: Box::new(return_type),
310            })),
311        });
312    }
313    closure
314}