eiffel_macros_gen/
lib.rs

//! This crate contains the procedural macros.
//!
//! Chiefly it provides the attribute macro [contract](macro@contract), which checks that a given invariant holds before and/or after a method call.
//!
5#![deny(warnings)]
6#![deny(missing_docs)]
7extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use quote::{ quote, format_ident };
11use syn::{parse_macro_input, ItemFn, ReturnType, Result, FnArg, Pat, Ident};
12use syn::parse::{Parse, ParseStream};
13use proc_macro2::TokenTree;
14use syn::token::Comma;
15
/// When the invariant named in the `contract` attribute is evaluated relative to the
/// wrapped method body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CheckTime {
    /// Check the invariant before the body runs (`"require"`).
    Require,
    /// Check the invariant after the body returns (`"ensure"`).
    Ensure,
    /// Check the invariant both before and after the body (`"require_and_ensure"`);
    /// this is what `contract` uses when no check time is given.
    RequireAndEnsure,
}
23
24struct AttrList {
25    #[allow(dead_code)]
26    invariant_function_identifier: Ident,
27    #[allow(dead_code)]
28    rest: Vec<TokenTree>,
29}
30
31impl Parse for AttrList {
32    fn parse(input: ParseStream) -> Result<Self> {
33        let first_ident: Ident = input.parse()?;
34
35        if input.is_empty() {
36            return Ok(AttrList { invariant_function_identifier: first_ident, rest: vec![] });
37        }
38
39        let mut rest = Vec::new();
40
41        while !input.is_empty() {
42            let _: Comma = input.parse()?;
43            let item: TokenTree = input.parse()?;
44            rest.push(item);
45        }
46
47        Ok(AttrList { invariant_function_identifier: first_ident, rest })
48    }
49}
50
51/// `contract` is a procedural macro that checks if a given invariant holds true before and after a method call.
52/// If the invariant does not hold, the macro will cause the program to panic with a specified message.
53/// 
54/// # Arguments
55/// 
56/// * `invariant`: A struct method identifier that returns a boolean. This is the invariant that needs to be checked.
57/// * `check_time`: An optional string literal that specifies when the invariant should be checked.
58///   * `"require"` - The invariant is checked before the operation.
59///   * `"ensure"` - The invariant is checked after the operation.
60///   * `"require_and_ensure"` - The invariant is checked both before and after the operation.
61/// 
62/// # Example
63///
64/// ```
65/// use eiffel_macros_gen::contract;
66/// 
67/// struct MyClass {
68///     // Fields
69///     a: i32,
70/// };
71///
72/// impl MyClass {
73///     fn my_invariant(&self) -> bool {
74///         // Your invariant checks here
75///         true
76///     }
77///
78///     #[contract(my_invariant)]
79///     fn my_method(&self) {
80///         // Method body
81///         println!("Method body {:?}", self.a);
82///     }
83///
84///     // Only check the invariant before the method call
85///     #[contract(my_invariant, "require")]
86///     fn my_other_method(&self) {
87///         // Method body
88///         println!("Method body {:?}", self.a);
89///     }
90///
91///     // Only check the invariant after the method call
92///     #[contract(my_invariant, "ensure")]
93///     fn my_other_method_after(&self) {
94///         // Method body
95///         println!("Method body {:?}", self.a);
96///     }
97///
98///     // Only check the invariant before and after (default)
99///     #[contract(my_invariant, "require_and_ensure")]
100///     fn my_other_method_before_and_after(&self) {
101///         // Method body
102///         println!("Method body {:?}", self.a);
103///     }
104///
105/// }       
106/// ```
107///
108/// # Test
109///
110/// ```
111/// #[cfg(test)]
112/// mod tests {
113///     use super::*;
114///
115///     #[test]
116///     fn test_my_method() {
117///         let my_class = MyClass;
118///         my_class.my_method(); // This should not panic as the invariant is true
119///     }
120/// }
121/// ```
122#[proc_macro_attribute]
123pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream {
124    // let invariant_name = parse_macro_input!(attr as Ident);
125    // let check_time = CheckTime::RequireAndEnsure;
126    let mut check_time = None;
127    
128    let attr = parse_macro_input!(attr as AttrList);
129    let invariant_name = attr.invariant_function_identifier;
130
131    for item in attr.rest.into_iter() {
132        match item {
133            TokenTree::Literal(literal) => {
134                let msg = literal.to_string();
135                match msg.as_str() {
136                    "\"require\"" => check_time = Some(CheckTime::Require),
137                    "\"ensure\"" => check_time = Some(CheckTime::Ensure),
138                    "\"require_and_ensure\"" => check_time = Some(CheckTime::RequireAndEnsure),
139                    _ => panic!("Invalid check time: {}, expected one of: \"require\", \"ensure\", \"require_and_ensure\"", msg)
140                }
141            }
142            _ => {}
143        }
144    }
145
146    let check_time = check_time.unwrap_or(CheckTime::RequireAndEnsure);
147
148    // Extract the name, arguments, and return type of the input function
149    let input_fn = parse_macro_input!(item as ItemFn);
150    let input_fn_name = &input_fn.sig.ident;
151    let input_fn_body = &input_fn.block;
152
153    let args = &input_fn.sig.inputs;
154    let arg_names: Vec<Ident> = args
155        .iter()
156        .filter_map(|arg| {
157            if let FnArg::Typed(pat) = arg {
158                if let Pat::Ident(pat_ident) = &*pat.pat {
159                    return Some(pat_ident.ident.clone());
160                }
161            }
162            None
163        })
164        .collect();
165    
166    let _self_arg = match args.first() {
167        Some(FnArg::Receiver(receiver)) => receiver,
168        _ => panic!("The input function must have a self argument"),
169    };
170
171    let return_type = match &input_fn.sig.output {
172        ReturnType::Default => None,
173        ReturnType::Type(_, ty) => Some(quote! { #ty }),
174    };
175
176    // Rename the original function
177    let fn_without_invariant = format_ident!("{}_no_invariant", input_fn_name);
178    
179    let wrapped_function = match &return_type {
180        None => quote! {
181            fn #fn_without_invariant(#args) { 
182                #input_fn_body
183            }
184        },
185        Some(return_type) => quote! {
186            fn #fn_without_invariant(#args) -> #return_type { 
187                #input_fn_body
188            }
189        }
190    };
191    
192    let call_invariant_before = match check_time {
193        CheckTime::Require | CheckTime::RequireAndEnsure => quote! {
194            if !self.#invariant_name() {
195                panic!("Invariant {} failed on entry", stringify!(#invariant_name));
196            }
197        },
198        _ => quote! {},
199    };
200
201    let call_invariant_after = match check_time {
202        CheckTime::Ensure | CheckTime::RequireAndEnsure => quote! {
203            if !self.#invariant_name() {
204                panic!("Invariant {} failed on exit", stringify!(#invariant_name));
205            }
206        },
207        _ => quote! {},
208    };
209
210    let call_wrapped = quote! {
211        self.#fn_without_invariant( #(#arg_names),*)
212    };
213
214    let invariant_checked_function = match return_type {
215        None => quote! {
216            fn #input_fn_name(#args) { 
217                #call_invariant_before
218                #call_wrapped;
219                #call_invariant_after
220            }
221        },
222        Some(return_type) => quote! {
223            fn #input_fn_name(#args) -> #return_type {
224                #call_invariant_before
225                let result = #call_wrapped;
226                #call_invariant_after
227                result
228            }
229        }
230    };
231
232    // Generate the wrapper code
233    let output = quote! {
234        #wrapped_function
235    
236        #invariant_checked_function
237    };
238
239    output.into()
240}