1#![deny(warnings)]
6#![deny(missing_docs)]
7extern crate proc_macro;
8
9use proc_macro::TokenStream;
10use quote::{ quote, format_ident };
11use syn::{parse_macro_input, ItemFn, ReturnType, Result, FnArg, Pat, Ident};
12use syn::parse::{Parse, ParseStream};
13use proc_macro2::TokenTree;
14use syn::token::Comma;
15
/// When the generated wrapper evaluates the invariant method.
///
/// Every variant is constructed by `contract` when it parses the check-time
/// argument, so no `dead_code` suppression is needed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CheckTime {
    /// Check only before the wrapped body runs (`"require"`).
    Require,
    /// Check only after the wrapped body returns (`"ensure"`).
    Ensure,
    /// Check both before and after (`"require_and_ensure"`, used when no
    /// check time is given).
    RequireAndEnsure,
}
23
24struct AttrList {
25 #[allow(dead_code)]
26 invariant_function_identifier: Ident,
27 #[allow(dead_code)]
28 rest: Vec<TokenTree>,
29}
30
31impl Parse for AttrList {
32 fn parse(input: ParseStream) -> Result<Self> {
33 let first_ident: Ident = input.parse()?;
34
35 if input.is_empty() {
36 return Ok(AttrList { invariant_function_identifier: first_ident, rest: vec![] });
37 }
38
39 let mut rest = Vec::new();
40
41 while !input.is_empty() {
42 let _: Comma = input.parse()?;
43 let item: TokenTree = input.parse()?;
44 rest.push(item);
45 }
46
47 Ok(AttrList { invariant_function_identifier: first_ident, rest })
48 }
49}
50
51#[proc_macro_attribute]
123pub fn contract(attr: TokenStream, item: TokenStream) -> TokenStream {
124 let mut check_time = None;
127
128 let attr = parse_macro_input!(attr as AttrList);
129 let invariant_name = attr.invariant_function_identifier;
130
131 for item in attr.rest.into_iter() {
132 match item {
133 TokenTree::Literal(literal) => {
134 let msg = literal.to_string();
135 match msg.as_str() {
136 "\"require\"" => check_time = Some(CheckTime::Require),
137 "\"ensure\"" => check_time = Some(CheckTime::Ensure),
138 "\"require_and_ensure\"" => check_time = Some(CheckTime::RequireAndEnsure),
139 _ => panic!("Invalid check time: {}, expected one of: \"require\", \"ensure\", \"require_and_ensure\"", msg)
140 }
141 }
142 _ => {}
143 }
144 }
145
146 let check_time = check_time.unwrap_or(CheckTime::RequireAndEnsure);
147
148 let input_fn = parse_macro_input!(item as ItemFn);
150 let input_fn_name = &input_fn.sig.ident;
151 let input_fn_body = &input_fn.block;
152
153 let args = &input_fn.sig.inputs;
154 let arg_names: Vec<Ident> = args
155 .iter()
156 .filter_map(|arg| {
157 if let FnArg::Typed(pat) = arg {
158 if let Pat::Ident(pat_ident) = &*pat.pat {
159 return Some(pat_ident.ident.clone());
160 }
161 }
162 None
163 })
164 .collect();
165
166 let _self_arg = match args.first() {
167 Some(FnArg::Receiver(receiver)) => receiver,
168 _ => panic!("The input function must have a self argument"),
169 };
170
171 let return_type = match &input_fn.sig.output {
172 ReturnType::Default => None,
173 ReturnType::Type(_, ty) => Some(quote! { #ty }),
174 };
175
176 let fn_without_invariant = format_ident!("{}_no_invariant", input_fn_name);
178
179 let wrapped_function = match &return_type {
180 None => quote! {
181 fn #fn_without_invariant(#args) {
182 #input_fn_body
183 }
184 },
185 Some(return_type) => quote! {
186 fn #fn_without_invariant(#args) -> #return_type {
187 #input_fn_body
188 }
189 }
190 };
191
192 let call_invariant_before = match check_time {
193 CheckTime::Require | CheckTime::RequireAndEnsure => quote! {
194 if !self.#invariant_name() {
195 panic!("Invariant {} failed on entry", stringify!(#invariant_name));
196 }
197 },
198 _ => quote! {},
199 };
200
201 let call_invariant_after = match check_time {
202 CheckTime::Ensure | CheckTime::RequireAndEnsure => quote! {
203 if !self.#invariant_name() {
204 panic!("Invariant {} failed on exit", stringify!(#invariant_name));
205 }
206 },
207 _ => quote! {},
208 };
209
210 let call_wrapped = quote! {
211 self.#fn_without_invariant( #(#arg_names),*)
212 };
213
214 let invariant_checked_function = match return_type {
215 None => quote! {
216 fn #input_fn_name(#args) {
217 #call_invariant_before
218 #call_wrapped;
219 #call_invariant_after
220 }
221 },
222 Some(return_type) => quote! {
223 fn #input_fn_name(#args) -> #return_type {
224 #call_invariant_before
225 let result = #call_wrapped;
226 #call_invariant_after
227 result
228 }
229 }
230 };
231
232 let output = quote! {
234 #wrapped_function
235
236 #invariant_checked_function
237 };
238
239 output.into()
240}