1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
//! This crate contains procedural macros.
//!
//! Chief among them is the attribute macro [check_invariant](macro@check_invariant), which checks that a given invariant holds before and/or after a method call.
//!
#![deny(warnings)]
#![deny(missing_docs)]
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::{ quote, format_ident };
use syn::{parse_macro_input, ItemFn, ReturnType, Result, FnArg, Pat, Ident};
use syn::parse::{Parse, ParseStream};
use proc_macro2::TokenTree;
use syn::token::Comma;

/// When the invariant is evaluated relative to the wrapped method call.
///
/// Selected by the optional string-literal argument of `#[check_invariant(...)]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CheckTime {
    /// Check only on entry, before the original body runs (`"before"`).
    Before,
    /// Check only on exit, after the original body returns (`"after"`).
    After,
    /// Check on both entry and exit (`"before_and_after"`, the default).
    BeforeAndAfter,
}

/// Parsed arguments of `#[check_invariant(...)]`.
///
/// The first argument is the name of the invariant method. Every following
/// comma-separated argument is kept as a single raw token tree so the macro
/// can validate it (and report a spanned error) later.
struct AttrList {
    /// Name of the invariant method (the first attribute argument).
    invariant_function_identifier: Ident,
    /// Remaining arguments, one token tree each (e.g. the `"before"` literal).
    rest: Vec<TokenTree>,
}

impl Parse for AttrList {
    /// Parses `ident (, token)*` with an optional trailing comma.
    fn parse(input: ParseStream) -> Result<Self> {
        let invariant_function_identifier: Ident = input.parse()?;
        let mut rest = Vec::new();

        while !input.is_empty() {
            input.parse::<Comma>()?;
            // Tolerate a trailing comma, e.g. `#[check_invariant(inv, "before",)]`.
            if input.is_empty() {
                break;
            }
            rest.push(input.parse::<TokenTree>()?);
        }

        Ok(AttrList {
            invariant_function_identifier,
            rest,
        })
    }
}

/// `check_invariant` is a procedural macro that checks if a given invariant holds true before and after a method call.
/// If the invariant does not hold, the macro will cause the program to panic with a specified message.
/// 
/// # Arguments
/// 
/// * `invariant`: A method that returns a boolean. This is the invariant that needs to be checked.
/// * `check_time`: An optional string literal that specifies when the invariant should be checked.
///   * `"before"` - The invariant is checked before the operation.
///   * `"after"` - The invariant is checked after the operation.
///   * `"before_and_after"` - The invariant is checked both before and after the operation.
/// 
/// # Example
///
/// ```
/// use eiffel_macros_gen::check_invariant;
/// 
/// struct MyClass {
///     // Fields
///     a: i32,
/// };
///
/// impl MyClass {
///     fn my_invariant(&self) -> bool {
///         // Your invariant checks here
///         true
///     }
///
///     #[check_invariant(my_invariant)]
///     fn my_method(&self) {
///         // Method body
///         println!("Method body {:?}", self.a);
///     }
///
///     // Only check the invariant before the method call
///     #[check_invariant(my_invariant, "before")]
///     fn my_other_method(&self) {
///         // Method body
///         println!("Method body {:?}", self.a);
///     }
///
///     // Only check the invariant after the method call
///     #[check_invariant(my_invariant, "after")]
///     fn my_other_method_after(&self) {
///         // Method body
///         println!("Method body {:?}", self.a);
///     }
///
///     // Only check the invariant before and after (default)
///     #[check_invariant(my_invariant, "before_and_after")]
///     fn my_other_method_before_and_after(&self) {
///         // Method body
///         println!("Method body {:?}", self.a);
///     }
///
/// }       
/// ```
///
/// # Test
///
/// ```
/// #[cfg(test)]
/// mod tests {
///     use super::*;
///
///     #[test]
///     fn test_my_method() {
///         let my_class = MyClass;
///         my_class.my_method(); // This should not panic as the invariant is true
///     }
/// }
/// ```
#[proc_macro_attribute]
pub fn check_invariant(attr: TokenStream, item: TokenStream) -> TokenStream {
    // let invariant_name = parse_macro_input!(attr as Ident);
    // let check_time = CheckTime::BeforeAndAfter;
    let mut check_time = None;
    
    let attr = parse_macro_input!(attr as AttrList);
    let invariant_name = attr.invariant_function_identifier;

    for item in attr.rest.into_iter() {
        match item {
            TokenTree::Literal(literal) => {
                let msg = literal.to_string();
                match msg.as_str() {
                    "\"before\"" => check_time = Some(CheckTime::Before),
                    "\"after\"" => check_time = Some(CheckTime::After),
                    "\"before_and_after\"" => check_time = Some(CheckTime::BeforeAndAfter),
                    _ => panic!("Invalid check time: {}, expected one of: \"before\", \"after\", \"before_and_after\"", msg)
                }
            }
            _ => {}
        }
    }

    let check_time = check_time.unwrap_or(CheckTime::BeforeAndAfter);

    // Extract the name, arguments, and return type of the input function
    let input_fn = parse_macro_input!(item as ItemFn);
    let input_fn_name = &input_fn.sig.ident;
    let input_fn_body = &input_fn.block;

    let args = &input_fn.sig.inputs;
    let arg_names: Vec<Ident> = args
        .iter()
        .filter_map(|arg| {
            if let FnArg::Typed(pat) = arg {
                if let Pat::Ident(pat_ident) = &*pat.pat {
                    return Some(pat_ident.ident.clone());
                }
            }
            None
        })
        .collect();
    
    let _self_arg = match args.first() {
        Some(FnArg::Receiver(receiver)) => receiver,
        _ => panic!("The input function must have a self argument"),
    };

    let return_type = match &input_fn.sig.output {
        ReturnType::Default => None,
        ReturnType::Type(_, ty) => Some(quote! { #ty }),
    };

    // Rename the original function
    let fn_without_invariant = format_ident!("{}_no_invariant", input_fn_name);
    
    let wrapped_function = match &return_type {
        None => quote! {
            fn #fn_without_invariant(#args) { 
                #input_fn_body
            }
        },
        Some(return_type) => quote! {
            fn #fn_without_invariant(#args) -> #return_type { 
                #input_fn_body
            }
        }
    };

    let call_invariant_before = match check_time {
        CheckTime::Before | CheckTime::BeforeAndAfter => quote! {
            if !self.#invariant_name() {
                panic!("Invariant {} failed on entry", stringify!(#invariant_name));
            }
        },
        _ => quote! {},
    };

    let call_invariant_after = match check_time {
        CheckTime::After | CheckTime::BeforeAndAfter => quote! {
            if !self.#invariant_name() {
                panic!("Invariant {} failed on exit", stringify!(#invariant_name));
            }
        },
        _ => quote! {},
    };

    let call_wrapped = quote! {
        self.#fn_without_invariant( #(#arg_names),*)
    };

    let invariant_checked_function = match return_type {
        None => quote! {
            fn #input_fn_name(#args) { 
                #call_invariant_before
                #call_wrapped;
                #call_invariant_after
            }
        },
        Some(return_type) => quote! {
            fn #input_fn_name(#args) -> #return_type {
                #call_invariant_before
                let result = #call_wrapped;
                #call_invariant_after
                result
            }
        }
    };

    // Generate the wrapper code
    let output = quote! {
        #wrapped_function
    
        #invariant_checked_function
    };

    output.into()
}