1use std::collections::{HashMap, HashSet};
8
9use rhai::{AST, Expr, Stmt};
10
/// Result of statically analysing a compiled Rhai script.
///
/// All collections are sets, so the same access or comparison appearing several
/// times in a script is recorded once.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ScriptAnalysisResult {
    /// Dotted variable paths referenced by the script (e.g. `tx.value`,
    /// `log.params.user`), including assignment targets.
    pub accessed_variables: HashSet<String>,

    /// Names declared by the script itself (e.g. `let` bindings and `for` loop
    /// variables).
    pub local_variables: HashSet<String>,

    /// For every variable path compared with `==` / `!=` against a string
    /// literal, the set of literals it was compared with
    /// (e.g. `"log.name" -> {"Transfer", "Approval"}`).
    pub string_comparisons: HashMap<String, HashSet<String>>,
}
29
30pub fn analyze_ast(ast: &AST) -> ScriptAnalysisResult {
34 let mut result = ScriptAnalysisResult::default();
35 for stmt in ast.statements() {
36 walk_stmt(stmt, &mut result);
37 }
38 result
39}
40
41fn walk_stmt(stmt: &Stmt, result: &mut ScriptAnalysisResult) {
46 match stmt {
49 Stmt::Expr(expr) => check_for_string_comparisons(expr, result),
50 Stmt::FnCall(fn_call_expr, _) => {
51 let expr = Expr::FnCall(fn_call_expr.clone(), rhai::Position::NONE);
52 check_for_string_comparisons(&expr, result);
53 }
54 _ => {}
55 }
56
57 match stmt {
58 Stmt::Expr(expr) => walk_expr(expr, result),
59 Stmt::Block(stmt_block) => {
60 for s in stmt_block.statements() {
61 walk_stmt(s, result);
62 }
63 }
64 Stmt::If(flow_control, _) => {
65 walk_expr(&flow_control.expr, result);
66 for s in flow_control.body.statements() {
67 walk_stmt(s, result);
68 }
69 for s in flow_control.branch.statements() {
70 walk_stmt(s, result);
71 }
72 }
73 Stmt::While(flow_control, _) => {
74 walk_expr(&flow_control.expr, result);
75 for s in flow_control.body.statements() {
76 walk_stmt(s, result);
77 }
78 }
79 Stmt::Do(flow_control, _, _) => {
80 for s in flow_control.body.statements() {
81 walk_stmt(s, result);
82 }
83 walk_expr(&flow_control.expr, result);
84 }
85 Stmt::For(for_loop, _) => {
86 result.local_variables.insert(for_loop.0.name.to_string());
87 if let Some(second_var) = &for_loop.1 {
88 result.local_variables.insert(second_var.name.to_string());
89 }
90 walk_expr(&for_loop.2.expr, result);
91 for s in for_loop.2.body.statements() {
92 walk_stmt(s, result);
93 }
94 }
95 Stmt::Var(var_definition, _, _) => {
96 result
97 .local_variables
98 .insert(var_definition.0.name.to_string());
99 walk_expr(&var_definition.1, result);
100 }
101 Stmt::Assignment(assignment) => {
102 walk_expr(&assignment.1.lhs, result);
103 walk_expr(&assignment.1.rhs, result);
104 }
105 Stmt::FnCall(fn_call_expr, _) => {
106 for arg in &fn_call_expr.args {
107 walk_expr(arg, result);
108 }
109 }
110 Stmt::Switch(switch_data, _) => {
111 let (expr, cases_collection) = &**switch_data;
112 walk_expr(expr, result);
113 for case_expr in &cases_collection.expressions {
114 walk_expr(&case_expr.lhs, result);
115 walk_expr(&case_expr.rhs, result);
116 }
117 }
118 Stmt::TryCatch(flow_control, _) => {
119 for s in flow_control.body.statements() {
120 walk_stmt(s, result);
121 }
122 for s in flow_control.branch.statements() {
123 walk_stmt(s, result);
124 }
125 }
126 Stmt::Return(Some(expr), _, _) | Stmt::BreakLoop(Some(expr), _, _) => {
127 walk_expr(expr, result);
128 }
129 Stmt::Import(import_data, _) => {
130 walk_expr(&import_data.0, result);
131 }
132 Stmt::Noop(_)
133 | Stmt::Return(None, _, _)
134 | Stmt::BreakLoop(None, _, _)
135 | Stmt::Export(_, _)
136 | Stmt::Share(_) => {}
137 _ => {}
138 }
139}
140
141fn walk_expr(expr: &Expr, result: &mut ScriptAnalysisResult) {
146 check_for_string_comparisons(expr, result);
147
148 if let Some(path) = get_full_variable_path(expr) {
149 result.accessed_variables.insert(path);
150 if let Expr::Index(binary_expr, _, _) = expr
151 && let Some(index_path) = get_full_variable_path(&binary_expr.rhs)
152 {
153 result.accessed_variables.insert(index_path);
154 }
155 return;
156 }
157
158 match expr {
159 Expr::Dot(binary_expr, _, _) => {
160 walk_expr(&binary_expr.lhs, result);
161 walk_expr(&binary_expr.rhs, result);
162 }
163 Expr::Index(binary_expr, _, _) => {
164 walk_expr(&binary_expr.lhs, result);
165 if let Some(index_path) = get_full_variable_path(&binary_expr.rhs) {
166 result.accessed_variables.insert(index_path);
167 } else {
168 walk_expr(&binary_expr.rhs, result);
169 }
170 }
171 Expr::MethodCall(method_call_expr, _) => {
172 for arg in &method_call_expr.args {
173 walk_expr(arg, result);
174 }
175 }
176 Expr::FnCall(fn_call_expr, _) => {
177 for arg in &fn_call_expr.args {
178 walk_expr(arg, result);
179 }
180 }
181 Expr::And(expr_vec, _) | Expr::Or(expr_vec, _) | Expr::Coalesce(expr_vec, _) => {
182 for e in &**expr_vec {
183 walk_expr(e, result);
184 }
185 }
186 Expr::Array(expr_vec, _) | Expr::InterpolatedString(expr_vec, _) => {
187 for e in expr_vec {
188 walk_expr(e, result);
189 }
190 }
191 Expr::Map(map_data, _) => {
192 for (_, value_expr) in &map_data.0 {
193 walk_expr(value_expr, result);
194 }
195 }
196 Expr::Stmt(stmt_block) => {
197 for s in stmt_block.statements() {
198 walk_stmt(s, result);
199 }
200 }
201 Expr::Custom(custom_expr, _) => {
202 for e in &custom_expr.inputs {
203 walk_expr(e, result);
204 }
205 }
206 _ => {}
207 }
208}
209
210fn get_full_variable_path(expr: &Expr) -> Option<String> {
217 fn collect_path(expr: &Expr, parts: &mut Vec<String>) -> bool {
218 match expr {
219 Expr::Dot(binary_expr, _, _) => {
220 collect_path(&binary_expr.lhs, parts) && collect_path(&binary_expr.rhs, parts)
221 }
222 Expr::Property(prop_info, _) => {
223 parts.push(prop_info.2.to_string());
224 true
225 }
226 Expr::Variable(var_info, _, _) => {
227 parts.push(var_info.1.to_string());
228 true
229 }
230 Expr::Index(binary_expr, _, _) => collect_path(&binary_expr.lhs, parts),
231 _ => false,
232 }
233 }
234
235 let mut path_parts = Vec::new();
236 if collect_path(expr, &mut path_parts) && !path_parts.is_empty() {
237 Some(path_parts.join("."))
238 } else {
239 None
240 }
241}
242
243fn check_for_string_comparisons(expr: &Expr, result: &mut ScriptAnalysisResult) {
251 match expr {
252 Expr::FnCall(fn_call_expr, _) => {
253 if fn_call_expr.namespace.is_empty() && fn_call_expr.args.len() == 2 {
254 match fn_call_expr.name.as_str() {
255 "==" | "!=" => {
256 record_string_comparison(
257 &fn_call_expr.args[0],
258 &fn_call_expr.args[1],
259 result,
260 );
261 record_string_comparison(
262 &fn_call_expr.args[1],
263 &fn_call_expr.args[0],
264 result,
265 );
266 }
267 _ => {
268 for arg in &fn_call_expr.args {
269 check_for_string_comparisons(arg, result);
270 }
271 }
272 }
273 } else {
274 for arg in &fn_call_expr.args {
275 check_for_string_comparisons(arg, result);
276 }
277 }
278 }
279 Expr::And(expr_vec, _) | Expr::Or(expr_vec, _) => {
280 for e in &**expr_vec {
281 check_for_string_comparisons(e, result);
282 }
283 }
284 Expr::Dot(binary_expr, _, _) | Expr::Index(binary_expr, _, _) => {
285 check_for_string_comparisons(&binary_expr.lhs, result);
286 check_for_string_comparisons(&binary_expr.rhs, result);
287 }
288 Expr::MethodCall(method_call_expr, _) => {
289 for arg in &method_call_expr.args {
290 check_for_string_comparisons(arg, result);
291 }
292 }
293 Expr::Array(expr_vec, _) => {
294 for e in expr_vec {
295 check_for_string_comparisons(e, result);
296 }
297 }
298 Expr::Stmt(stmt_block) => {
299 for s in stmt_block.statements() {
300 if let Stmt::Expr(inner_expr) = s {
301 check_for_string_comparisons(inner_expr, result);
302 }
303 }
304 }
305 _ => {}
306 }
307}
308
309fn record_string_comparison(lhs: &Expr, rhs: &Expr, result: &mut ScriptAnalysisResult) {
312 if let Some(var_path) = get_full_variable_path(lhs)
313 && let Expr::StringConstant(string_val, _) = rhs
314 {
315 result
316 .string_comparisons
317 .entry(var_path)
318 .or_default()
319 .insert(string_val.to_string());
320 }
321}
322
#[cfg(test)]
mod tests {
    use rhai::{Engine, ParseError};

    use super::*;

    /// Compiles `script` with a default engine and analyses the resulting AST.
    fn analyze_script(script: &str) -> Result<ScriptAnalysisResult, ParseError> {
        let engine = Engine::new();
        let ast = engine.compile(script)?;
        Ok(analyze_ast(&ast))
    }

    /// Builds an owned string set from string literals.
    fn set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn test_simple_binary_op() {
        let result = analyze_script("tx.value > 100").unwrap();
        assert_eq!(result.accessed_variables, set(&["tx.value"]));
    }

    #[test]
    fn test_logical_operators() {
        let script = r#"tx.from == owner && log.name != "Transfer" || block.number > 1000"#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["tx.from", "owner", "log.name", "block.number"])
        );
    }

    #[test]
    fn test_multiple_variables_and_coalesce() {
        let result = analyze_script("tx.from ?? fallback_addr.address").unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["tx.from", "fallback_addr.address"])
        );
    }

    #[test]
    fn test_deeply_nested_variable() {
        let script = r#"log.params.level_one.level_two.user == "admin""#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["log.params.level_one.level_two.user"])
        );
    }

    #[test]
    fn test_variables_in_function_calls() {
        let result = analyze_script("my_func(tx.value, log.params.user, 42)").unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["tx.value", "log.params.user"])
        );
    }

    #[test]
    fn test_variables_in_let_and_if() {
        let script = r#"
            let threshold = config.min_value;
            if tx.value > threshold && tx.to != blacklist.address {
                true
            } else {
                false
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&[
                "config.min_value",
                "tx.value",
                "threshold",
                "tx.to",
                "blacklist.address",
            ])
        );
    }

    #[test]
    fn test_variables_in_loops() {
        let script = r#"
            for item in tx.items {
                if item.cost > max_cost {
                    return false;
                }
            }
            while x < limit {
                x = x + 1;
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["tx.items", "item.cost", "max_cost", "x", "limit"])
        );
    }

    #[test]
    fn test_variables_in_strings_or_comments_are_ignored() {
        let script = r#"
            // This is a comment about tx.value
            let x = "this string mentions log.name";
            tx.from == "0x123"
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(result.accessed_variables, set(&["tx.from"]));
    }

    #[test]
    fn test_indexing_expression() {
        let script = r#"tx.logs[0].name == "Transfer" && some_array[tx.index] > 100"#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["tx.logs", "some_array", "tx.index"])
        );
    }

    #[test]
    fn test_method_calls() {
        let script = r#"my_array.contains(tx.value) && other_var.to_string() == "hello""#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&["my_array", "tx.value", "other_var"])
        );
    }

    #[test]
    fn test_switch_statement() {
        let script = r#"
            switch tx.action {
                "transfer" => do_transfer(log.params.amount),
                "approve" if log.approved => do_approve(),
                _ => do_nothing(contract.address)
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&[
                "tx.action",
                "log.params.amount",
                "log.approved",
                "contract.address",
            ])
        );
    }

    #[test]
    fn test_no_variables() {
        let result = analyze_script("1 + 1 == 2").unwrap();
        assert!(result.accessed_variables.is_empty());
    }

    #[test]
    fn test_array_and_map_literals() {
        let script = r#"
            let my_array = [tx.value, log.topic];
            let my_map = #{ a: some.value, b: 42 };
            my_array[0] > my_map.a
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.accessed_variables,
            set(&[
                "tx.value",
                "log.topic",
                "some.value",
                "my_array",
                "my_map.a",
            ])
        );
    }

    #[test]
    fn test_local_variables_from_let_and_for() {
        let script = r#"
            let total = 0;
            for item in tx.items {
                total += item.value;
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(result.local_variables, set(&["total", "item"]));
    }

    #[test]
    fn test_string_comparison_simple() {
        let result = analyze_script(r#"log.name == "Transfer""#).unwrap();
        assert_eq!(result.accessed_variables, set(&["log.name"]));
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Transfer"]))
        );
    }

    #[test]
    fn test_string_comparison_reversed() {
        let result = analyze_script(r#""Approval" == log.name"#).unwrap();
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Approval"]))
        );
    }

    #[test]
    fn test_string_comparison_in_logical_or() {
        let result = analyze_script(r#"tx.value > 100 || log.name == "Deposit""#).unwrap();
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Deposit"]))
        );
    }

    #[test]
    fn test_string_comparison_multiple_values() {
        let result = analyze_script(r#"log.name == "Transfer" || log.name == "Approval""#).unwrap();
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Transfer", "Approval"]))
        );
    }

    #[test]
    fn test_string_comparison_inequality() {
        let result = analyze_script(r#"log.name != "Transfer""#).unwrap();
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Transfer"]))
        );
    }

    #[test]
    fn test_string_comparison_different_path() {
        let result = analyze_script(r#"tx.status == "success" && tx.type != "mint""#).unwrap();
        assert_eq!(
            result.string_comparisons.get("tx.status"),
            Some(&set(&["success"]))
        );
        assert_eq!(
            result.string_comparisons.get("tx.type"),
            Some(&set(&["mint"]))
        );
    }

    #[test]
    fn test_string_comparison_in_if_condition() {
        let script = r#"
            if log.name == "Swap" {
                true
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(
            result.string_comparisons.get("log.name"),
            Some(&set(&["Swap"]))
        );
    }

    #[test]
    fn test_non_string_comparisons_are_not_recorded() {
        let result = analyze_script("tx.value == 100 || tx.from != owner").unwrap();
        assert!(result.string_comparisons.is_empty());
    }
}