1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
//! This crate contains procedural macros.
//!
//! Chiefly the attribute macro [check_invariant](macro@check_invariant), which checks that a given invariant holds before and/or after a method call.
//!
#![deny(warnings)]
#![deny(missing_docs)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::{ quote, format_ident };
use syn::{parse_macro_input, ItemFn, ReturnType, Result, FnArg, Pat, Ident};
use syn::parse::{Parse, ParseStream};
use proc_macro2::TokenTree;
use syn::token::Comma;
/// When the invariant named in `#[check_invariant(...)]` is evaluated,
/// relative to the body of the annotated method.
///
/// Every variant is constructed while parsing the macro arguments, so no
/// `dead_code` allowance is needed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CheckTime {
    /// Only on entry (`"before"`).
    Before,
    /// Only on exit (`"after"`).
    After,
    /// On entry and on exit (`"before_and_after"`, used when no check time is given).
    BeforeAndAfter,
}
/// Parsed arguments of `#[check_invariant(...)]`: the invariant method's name
/// followed by zero or more comma-separated single tokens (the optional check time).
struct AttrList {
    /// Name of the invariant method; the macro calls it as `self.<name>()`.
    invariant_function_identifier: Ident,
    /// Remaining comma-separated tokens, interpreted later by the macro.
    rest: Vec<TokenTree>,
}

impl Parse for AttrList {
    /// Parses `ident (, token)* ,?`. A single trailing comma is accepted.
    fn parse(input: ParseStream) -> Result<Self> {
        let invariant_function_identifier: Ident = input.parse()?;
        let mut rest = Vec::new();
        while !input.is_empty() {
            input.parse::<Comma>()?;
            // Allow a trailing comma, as in ordinary Rust argument lists.
            if input.is_empty() {
                break;
            }
            rest.push(input.parse::<TokenTree>()?);
        }
        Ok(AttrList {
            invariant_function_identifier,
            rest,
        })
    }
}
/// `check_invariant` is a procedural macro that checks if a given invariant holds true before and after a method call.
/// If the invariant does not hold, the macro will cause the program to panic with a specified message.
///
/// # Arguments
///
/// * `invariant`: A method that returns a boolean. This is the invariant that needs to be checked.
/// * `check_time`: An optional string literal that specifies when the invariant should be checked.
/// * `"before"` - The invariant is checked before the operation.
/// * `"after"` - The invariant is checked after the operation.
/// * `"before_and_after"` - The invariant is checked both before and after the operation.
///
/// # Example
///
/// ```
/// use eiffel_macros_gen::check_invariant;
///
/// struct MyClass {
/// // Fields
/// a: i32,
/// };
///
/// impl MyClass {
/// fn my_invariant(&self) -> bool {
/// // Your invariant checks here
/// true
/// }
///
/// #[check_invariant(my_invariant)]
/// fn my_method(&self) {
/// // Method body
/// println!("Method body {:?}", self.a);
/// }
///
/// // Only check the invariant before the method call
/// #[check_invariant(my_invariant, "before")]
/// fn my_other_method(&self) {
/// // Method body
/// println!("Method body {:?}", self.a);
/// }
///
/// // Only check the invariant after the method call
/// #[check_invariant(my_invariant, "after")]
/// fn my_other_method_after(&self) {
/// // Method body
/// println!("Method body {:?}", self.a);
/// }
///
/// // Only check the invariant before and after (default)
/// #[check_invariant(my_invariant, "before_and_after")]
/// fn my_other_method_before_and_after(&self) {
/// // Method body
/// println!("Method body {:?}", self.a);
/// }
///
/// }
/// ```
///
/// # Test
///
/// ```
/// #[cfg(test)]
/// mod tests {
/// use super::*;
///
/// #[test]
/// fn test_my_method() {
/// let my_class = MyClass;
/// my_class.my_method(); // This should not panic as the invariant is true
/// }
/// }
/// ```
#[proc_macro_attribute]
pub fn check_invariant(attr: TokenStream, item: TokenStream) -> TokenStream {
// let invariant_name = parse_macro_input!(attr as Ident);
// let check_time = CheckTime::BeforeAndAfter;
let mut check_time = None;
let attr = parse_macro_input!(attr as AttrList);
let invariant_name = attr.invariant_function_identifier;
for item in attr.rest.into_iter() {
match item {
TokenTree::Literal(literal) => {
let msg = literal.to_string();
match msg.as_str() {
"\"before\"" => check_time = Some(CheckTime::Before),
"\"after\"" => check_time = Some(CheckTime::After),
"\"before_and_after\"" => check_time = Some(CheckTime::BeforeAndAfter),
_ => panic!("Invalid check time: {}, expected one of: \"before\", \"after\", \"before_and_after\"", msg)
}
}
_ => {}
}
}
let check_time = check_time.unwrap_or(CheckTime::BeforeAndAfter);
// Extract the name, arguments, and return type of the input function
let input_fn = parse_macro_input!(item as ItemFn);
let input_fn_name = &input_fn.sig.ident;
let input_fn_body = &input_fn.block;
let args = &input_fn.sig.inputs;
let arg_names: Vec<Ident> = args
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat) = arg {
if let Pat::Ident(pat_ident) = &*pat.pat {
return Some(pat_ident.ident.clone());
}
}
None
})
.collect();
let _self_arg = match args.first() {
Some(FnArg::Receiver(receiver)) => receiver,
_ => panic!("The input function must have a self argument"),
};
let return_type = match &input_fn.sig.output {
ReturnType::Default => None,
ReturnType::Type(_, ty) => Some(quote! { #ty }),
};
// Rename the original function
let fn_without_invariant = format_ident!("{}_no_invariant", input_fn_name);
let wrapped_function = match &return_type {
None => quote! {
fn #fn_without_invariant(#args) {
#input_fn_body
}
},
Some(return_type) => quote! {
fn #fn_without_invariant(#args) -> #return_type {
#input_fn_body
}
}
};
let call_invariant_before = match check_time {
CheckTime::Before | CheckTime::BeforeAndAfter => quote! {
if !self.#invariant_name() {
panic!("Invariant {} failed on entry", stringify!(#invariant_name));
}
},
_ => quote! {},
};
let call_invariant_after = match check_time {
CheckTime::After | CheckTime::BeforeAndAfter => quote! {
if !self.#invariant_name() {
panic!("Invariant {} failed on exit", stringify!(#invariant_name));
}
},
_ => quote! {},
};
let call_wrapped = quote! {
self.#fn_without_invariant( #(#arg_names),*)
};
let invariant_checked_function = match return_type {
None => quote! {
fn #input_fn_name(#args) {
#call_invariant_before
#call_wrapped;
#call_invariant_after
}
},
Some(return_type) => quote! {
fn #input_fn_name(#args) -> #return_type {
#call_invariant_before
let result = #call_wrapped;
#call_invariant_after
result
}
}
};
// Generate the wrapper code
let output = quote! {
#wrapped_function
#invariant_checked_function
};
output.into()
}