1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
use syn::{
Block, Local, ExprField, ExprClosure, ExprBlock, TypeReference,
FnArg, ExprPath, ExprIf, ExprWhile, ExprForLoop,
punctuated::Punctuated, token::Comma, Expr, Member,
ItemFn, Pat, Type, Stmt, TypePath, ExprCall, PatIdent};
use proc_macro2::TokenStream as ProcTokenStream;
use crate::field_whitelist::WhitelistArgs;
use std::collections::HashSet;
use quote::quote;
/// Checks that `function` only mutates permitted fields of `macro_data.struct_name` instances.
///
/// Instances are discovered from the function parameters and from `let` bindings in the
/// body. Every field write found on them is checked against `macro_data.values`.
///
/// # Arguments
/// * `macro_data` - struct name and field list parsed from the macro attribute.
/// * `function` - the annotated function.
/// * `restricted_mode` - when `false`, mutating a field that is NOT listed in `values` is an
///   error; when `true`, mutating a field that IS listed is an error.
///
/// # Returns
/// The unchanged function tokens when no violation is found. Otherwise, a single
/// `compile_error!` invocation listing every violation, one per line.
pub fn assert_mutate_impl(macro_data: &WhitelistArgs, function: &ItemFn, restricted_mode: bool) -> ProcTokenStream {
    // NOTE: For better code layout, we will require separate proc-macro for each field
    // to be whitelisted type of #[struct_name: (field1, field2, field3, ...)].
    let struct_name = &macro_data.struct_name;
    let whitelist = &macro_data.values;
    let mut errors: Vec<Error> = Vec::new();
    // Names of variables known to hold a `struct_name` instance; grows as the body is walked.
    let mut found_instances = HashSet::new();

    // Seed the instance set from the parameter list.
    extract_instance_names(&function.sig.inputs, struct_name, &mut found_instances);

    // Walk the body recursively, picking up new instances from `let` bindings and
    // recording every offending field mutation in `errors`.
    check_block_for_mutation(
        &function.block,
        whitelist,
        &mut found_instances,
        struct_name,
        &mut errors,
        restricted_mode,
    );

    if errors.is_empty() {
        // No violations: re-emit the original function unchanged.
        return quote! { #function };
    }

    // Report all violations at once, one " - " bullet per violation.
    let details: Vec<String> = errors
        .iter()
        .map(|e| format!(" - {}", e.message))
        .collect();
    let error_message = format!(
        "Function contains mutations to non-whitelisted struct fields:\n{}",
        details.join("\n")
    );
    quote! { compile_error!(#error_message); }
}
/// Extracts all instance names from given function
/// arguments if matches the specified struct_name.
fn extract_instance_names(
inputs: &Punctuated<FnArg, Comma>,
struct_name: &str,
found_instances: &mut HashSet<String>
) {
for arg in inputs {
match arg {
FnArg::Typed(pat_type) => {
let pat = &*pat_type.pat;
let ty = &*pat_type.ty;
let ty_str = quote! { #ty }.to_string();
match pat {
Pat::Ident(pat_ident) => {
if let Type::Reference(TypeReference { elem, .. }) = ty {
// Handle type reference by dereferencing.
if let Type::Path(TypePath { path, .. }) = &**elem {
// let path_str = quote! { #path }.to_string();
if path.is_ident(struct_name) {
// println!("Added type path (dereferenced): {}", path_str);
found_instances.insert(pat_ident.ident.to_string());
}
}
} else if let Type::Path(TypePath { path, .. }) = ty {
// let path_str = quote! { #path }.to_string();
if path.is_ident(struct_name) {
// println!("Added type path: {}", path_str);
found_instances.insert(pat_ident.ident.to_string());
}
} else {
panic!("Found unexpected argument type: {}. We apologize for the inconvenience.
Please report this issue on GitHub, and we will address it promptly.", ty_str);
}
}
_ => {
let pat_str = quote! { #pat }.to_string();
panic!("Found unexpected pattern: {}. We apologize for the inconvenience.
Please report this issue on GitHub, and we will address it promptly.", pat_str);
}
}
}
FnArg::Receiver(receiver) => {
// Handle the case where the argument is `self` for methods
if receiver.reference.is_some() || receiver.mutability.is_some() {
found_instances.insert("self".to_string());
}
}
}
}
}
/// Extracts all inner instance names from the function
/// body if matches the specified struct_name on Expr::CAll.
fn extract_inner_instance(
left: &Pat,
right: &Expr,
found_instances: &mut HashSet<String>,
struct_name: &str,
) {
// TODO: Check later on if there are more
// complicated cases for the struct initialization.
if let Expr::Call(ExprCall { func, .. }) = right {
if let Expr::Path(ExprPath { path, .. }) = &**func {
let segments = &path.segments;
// Ensure the first segment matches the struct_name.
if segments.len() > 0 && segments[0].ident == struct_name {
// Check if the next segment is an initialization method.
if segments.len() > 1 {
let init_method = &segments[1].ident.to_string();
if init_method == "default" || init_method == "new" {
// Extract instance name from the left side of the initialization.
if let Pat::Ident(PatIdent { ident, .. }) = left {
found_instances.insert(ident.to_string());
// println!("Instance found: {}", ident.to_string());
}
}
}
}
}
}
}
/// One violation found while scanning the annotated function.
///
/// `message` is the already-formatted description; the macro entry point prints each
/// one as a ` - ` bullet inside the generated `compile_error!`.
#[derive(Debug)]
struct Error {
    message: String,
}

impl Error {
    /// Wraps a formatted violation description.
    fn new(message: String) -> Self {
        Self { message }
    }
}
fn check_whitelist(
field_ident_str: &String,
whitelist: &[String],
errors: &mut Vec<Error>,
message: &str,
mode: bool,
) {
let is_whitelisted = whitelist.contains(field_ident_str);
if mode {
if is_whitelisted {
// Custom assetion based on whitelist data and found AST calls.
errors.push(Error::new(format!("{} is resticted by the whitelist`", message)));
}
} else {
if !is_whitelisted {
// Custom assetion based on whitelist data and found AST calls.
errors.push(Error::new(format!("{} is not whitelisted`", message)));
}
}
}
#[allow(dead_code)]
fn print_ast<T>(item: &T, label: &str)
where
T: quote::ToTokens,
{
// A helper function to print AST tokens;
let tokens: ProcTokenStream = quote! { #item };
let item_string = tokens.to_string();
println!("{}: {}", label, item_string);
}
/// Recursively checks every statement of `block` for field mutations of tracked instances.
///
/// * Expression statements are handed to `check_expr_for_mutation`.
/// * `let` statements with an initializer first register a possible new instance via
///   `extract_inner_instance`. The initializer is then checked, and so is the `else` block
///   of a `let … else { … }`, if present.
/// * Other statements (items, macro statements, `let` without initializer) are ignored.
///
/// Violations are appended to `errors`. `found_instances` may grow as new bindings are seen.
fn check_block_for_mutation(
    block: &Block,
    whitelist: &[String],
    found_instances: &mut HashSet<String>,
    struct_name: &str,
    errors: &mut Vec<Error>,
    mode: bool,
) {
    for stmt in &block.stmts {
        match stmt {
            Stmt::Expr(expr, _) => {
                check_expr_for_mutation(expr, whitelist, errors, found_instances, struct_name, mode);
            }
            Stmt::Local(Local { pat, init: Some(init), .. }) => {
                // Register the binding before inspecting its initializer.
                extract_inner_instance(pat, &init.expr, found_instances, struct_name);
                check_expr_for_mutation(&init.expr, whitelist, errors, found_instances, struct_name, mode);
                // `let PAT = EXPR else { .. };` — the diverging block can mutate too.
                if let Some((_, diverge)) = &init.diverge {
                    check_expr_for_mutation(diverge, whitelist, errors, found_instances, struct_name, mode);
                }
            }
            _ => {}
        }
    }
}
fn check_expr_for_mutation(
expr: &Expr,
whitelist: &[String],
errors: &mut Vec<Error>,
found_instances: &mut HashSet<String>,
struct_name: &str,
mode: bool,
) {
match expr {
Expr::Binary(binary_expr) => {
// Handle various binary operations, including compound assignments.
if let Expr::Field(ExprField { base, member, .. }) = &*binary_expr.left {
if let Member::Named(field_ident) = member {
// Check if the base is one of the found instances.
if let Expr::Path(ExprPath { path, .. }) = &**base {
if let Some(instance) = path.get_ident() {
let instance_name = instance.to_string();
if found_instances.contains(&instance_name) {
let field_ident_str = field_ident.to_string();
check_whitelist(
&field_ident_str,
whitelist,
errors,
&format!("Mutation to field `{}::{}`", struct_name, field_ident_str),
mode
);
}
}
}
}
}
}
Expr::Assign(assign_expr) => {
// Handle simple assignments (fails for everything => this is a mutation).
if let Expr::Field(ExprField { base, member, .. }) = &*assign_expr.left {
if let Member::Named(field_ident) = member {
if let Expr::Path(ExprPath { path, .. }) = &**base {
if let Some(instance) = path.get_ident() {
let instance_name = instance.to_string();
if found_instances.contains(&instance_name) {
let field_ident_str = field_ident.to_string();
check_whitelist(
&field_ident_str,
whitelist,
errors,
&format!("Mutation to field `{}::{}`", struct_name, field_ident_str),
mode
);
}
}
}
}
}
}
Expr::Block(ExprBlock { block, .. }) => {
// Handle a block of code: `{ ... }`.
check_block_for_mutation(&block, whitelist, found_instances, struct_name, errors, mode);
}
Expr::If(ExprIf { then_branch, else_branch, .. }) => {
// Process the `then` block.
check_block_for_mutation(&then_branch, whitelist, found_instances, struct_name, errors, mode);
// Process the `else` branch if present.
if let Some((_, else_expr)) = else_branch {
match &**else_expr {
Expr::Block(ExprBlock { block, .. }) => {
// Process the block inside `else_expr`
check_block_for_mutation(&block, whitelist, found_instances, struct_name, errors, mode);
},
// Handle other types of `else_expr` if necessary
_ => check_expr_for_mutation(expr, whitelist, errors, found_instances, struct_name, mode),
}
}
}
Expr::While(ExprWhile { body, .. }) => {
// Handle the expression inside the while loop (always block).
check_block_for_mutation(
&body,
whitelist,
found_instances,
struct_name,
errors,
mode);
}
Expr::ForLoop(ExprForLoop { body, .. }) => {
// Handle the expression inside the for loop (always block).
check_block_for_mutation(&body, whitelist, found_instances, struct_name, errors, mode);
}
Expr::Closure(ExprClosure { body, .. }) => {
// Handle closures (either block or expression).
if let Expr::Block(ExprBlock { block, .. }) = &**body {
check_block_for_mutation(block, whitelist, found_instances, struct_name, errors, mode);
} else {
check_expr_for_mutation(body, whitelist, errors, found_instances, struct_name, mode);
}
}
_ => {}
}
}