Skip to main content

rhai_analyzer/
lib.rs

1//! Static analysis utilities for Rhai scripts.
2//!
3//! Traverses a compiled [`rhai::AST`] to extract variable access paths,
4//! local variable definitions, and string literal comparisons. The output is
5//! entirely domain-agnostic — callers decide what to do with the results.
6
7use std::collections::{HashMap, HashSet};
8
9use rhai::{AST, Expr, Stmt};
10
/// The result of a static analysis pass over a Rhai [`AST`].
///
/// All fields are plain standard-library collections, so the result can be
/// cloned and compared for equality (handy for caching and for tests).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ScriptAnalysisResult {
    /// All unique, fully-qualified variable paths accessed in the script
    /// (e.g. `"tx.value"`, `"log.name"`).
    pub accessed_variables: HashSet<String>,

    /// Local variables defined within the script via `let` or loop iteration
    /// variables.
    pub local_variables: HashSet<String>,

    /// Maps a fully-qualified variable path to the set of string literals it
    /// is compared against (via `==` or `!=`) anywhere in the script.
    ///
    /// Example: `log.name == "Transfer"` produces
    /// `"log.name" => {"Transfer"}`.
    pub string_comparisons: HashMap<String, HashSet<String>>,
}
29
30/// Traverses a compiled [`AST`] and returns a [`ScriptAnalysisResult`].
31///
32/// This is the primary entry point for the analyzer.
33pub fn analyze_ast(ast: &AST) -> ScriptAnalysisResult {
34    let mut result = ScriptAnalysisResult::default();
35    for stmt in ast.statements() {
36        walk_stmt(stmt, &mut result);
37    }
38    result
39}
40
41// ---------------------------------------------------------------------------
42// Statement walker
43// ---------------------------------------------------------------------------
44
45fn walk_stmt(stmt: &Stmt, result: &mut ScriptAnalysisResult) {
46    // Check for string comparisons at the statement level before the main
47    // structural dispatch below.
48    match stmt {
49        Stmt::Expr(expr) => check_for_string_comparisons(expr, result),
50        Stmt::FnCall(fn_call_expr, _) => {
51            let expr = Expr::FnCall(fn_call_expr.clone(), rhai::Position::NONE);
52            check_for_string_comparisons(&expr, result);
53        }
54        _ => {}
55    }
56
57    match stmt {
58        Stmt::Expr(expr) => walk_expr(expr, result),
59        Stmt::Block(stmt_block) => {
60            for s in stmt_block.statements() {
61                walk_stmt(s, result);
62            }
63        }
64        Stmt::If(flow_control, _) => {
65            walk_expr(&flow_control.expr, result);
66            for s in flow_control.body.statements() {
67                walk_stmt(s, result);
68            }
69            for s in flow_control.branch.statements() {
70                walk_stmt(s, result);
71            }
72        }
73        Stmt::While(flow_control, _) => {
74            walk_expr(&flow_control.expr, result);
75            for s in flow_control.body.statements() {
76                walk_stmt(s, result);
77            }
78        }
79        Stmt::Do(flow_control, _, _) => {
80            for s in flow_control.body.statements() {
81                walk_stmt(s, result);
82            }
83            walk_expr(&flow_control.expr, result);
84        }
85        Stmt::For(for_loop, _) => {
86            result.local_variables.insert(for_loop.0.name.to_string());
87            if let Some(second_var) = &for_loop.1 {
88                result.local_variables.insert(second_var.name.to_string());
89            }
90            walk_expr(&for_loop.2.expr, result);
91            for s in for_loop.2.body.statements() {
92                walk_stmt(s, result);
93            }
94        }
95        Stmt::Var(var_definition, _, _) => {
96            result
97                .local_variables
98                .insert(var_definition.0.name.to_string());
99            walk_expr(&var_definition.1, result);
100        }
101        Stmt::Assignment(assignment) => {
102            walk_expr(&assignment.1.lhs, result);
103            walk_expr(&assignment.1.rhs, result);
104        }
105        Stmt::FnCall(fn_call_expr, _) => {
106            for arg in &fn_call_expr.args {
107                walk_expr(arg, result);
108            }
109        }
110        Stmt::Switch(switch_data, _) => {
111            let (expr, cases_collection) = &**switch_data;
112            walk_expr(expr, result);
113            for case_expr in &cases_collection.expressions {
114                walk_expr(&case_expr.lhs, result);
115                walk_expr(&case_expr.rhs, result);
116            }
117        }
118        Stmt::TryCatch(flow_control, _) => {
119            for s in flow_control.body.statements() {
120                walk_stmt(s, result);
121            }
122            for s in flow_control.branch.statements() {
123                walk_stmt(s, result);
124            }
125        }
126        Stmt::Return(Some(expr), _, _) | Stmt::BreakLoop(Some(expr), _, _) => {
127            walk_expr(expr, result);
128        }
129        Stmt::Import(import_data, _) => {
130            walk_expr(&import_data.0, result);
131        }
132        Stmt::Noop(_)
133        | Stmt::Return(None, _, _)
134        | Stmt::BreakLoop(None, _, _)
135        | Stmt::Export(_, _)
136        | Stmt::Share(_) => {}
137        _ => {}
138    }
139}
140
141// ---------------------------------------------------------------------------
142// Expression walker
143// ---------------------------------------------------------------------------
144
145fn walk_expr(expr: &Expr, result: &mut ScriptAnalysisResult) {
146    check_for_string_comparisons(expr, result);
147
148    if let Some(path) = get_full_variable_path(expr) {
149        result.accessed_variables.insert(path);
150        if let Expr::Index(binary_expr, _, _) = expr
151            && let Some(index_path) = get_full_variable_path(&binary_expr.rhs)
152        {
153            result.accessed_variables.insert(index_path);
154        }
155        return;
156    }
157
158    match expr {
159        Expr::Dot(binary_expr, _, _) => {
160            walk_expr(&binary_expr.lhs, result);
161            walk_expr(&binary_expr.rhs, result);
162        }
163        Expr::Index(binary_expr, _, _) => {
164            walk_expr(&binary_expr.lhs, result);
165            if let Some(index_path) = get_full_variable_path(&binary_expr.rhs) {
166                result.accessed_variables.insert(index_path);
167            } else {
168                walk_expr(&binary_expr.rhs, result);
169            }
170        }
171        Expr::MethodCall(method_call_expr, _) => {
172            for arg in &method_call_expr.args {
173                walk_expr(arg, result);
174            }
175        }
176        Expr::FnCall(fn_call_expr, _) => {
177            for arg in &fn_call_expr.args {
178                walk_expr(arg, result);
179            }
180        }
181        Expr::And(expr_vec, _) | Expr::Or(expr_vec, _) | Expr::Coalesce(expr_vec, _) => {
182            for e in &**expr_vec {
183                walk_expr(e, result);
184            }
185        }
186        Expr::Array(expr_vec, _) | Expr::InterpolatedString(expr_vec, _) => {
187            for e in expr_vec {
188                walk_expr(e, result);
189            }
190        }
191        Expr::Map(map_data, _) => {
192            for (_, value_expr) in &map_data.0 {
193                walk_expr(value_expr, result);
194            }
195        }
196        Expr::Stmt(stmt_block) => {
197            for s in stmt_block.statements() {
198                walk_stmt(s, result);
199            }
200        }
201        Expr::Custom(custom_expr, _) => {
202            for e in &custom_expr.inputs {
203                walk_expr(e, result);
204            }
205        }
206        _ => {}
207    }
208}
209
210// ---------------------------------------------------------------------------
211// Path reconstruction
212// ---------------------------------------------------------------------------
213
214/// Attempts to reconstruct a full dotted variable path (e.g. `"tx.value"`)
215/// from an expression.
216fn get_full_variable_path(expr: &Expr) -> Option<String> {
217    fn collect_path(expr: &Expr, parts: &mut Vec<String>) -> bool {
218        match expr {
219            Expr::Dot(binary_expr, _, _) => {
220                collect_path(&binary_expr.lhs, parts) && collect_path(&binary_expr.rhs, parts)
221            }
222            Expr::Property(prop_info, _) => {
223                parts.push(prop_info.2.to_string());
224                true
225            }
226            Expr::Variable(var_info, _, _) => {
227                parts.push(var_info.1.to_string());
228                true
229            }
230            Expr::Index(binary_expr, _, _) => collect_path(&binary_expr.lhs, parts),
231            _ => false,
232        }
233    }
234
235    let mut path_parts = Vec::new();
236    if collect_path(expr, &mut path_parts) && !path_parts.is_empty() {
237        Some(path_parts.join("."))
238    } else {
239        None
240    }
241}
242
243// ---------------------------------------------------------------------------
244// String comparison tracking
245// ---------------------------------------------------------------------------
246
247/// Recursively checks an expression for `variable_path == "literal"` (or
248/// `!=`) patterns and records them in
249/// [`ScriptAnalysisResult::string_comparisons`].
250fn check_for_string_comparisons(expr: &Expr, result: &mut ScriptAnalysisResult) {
251    match expr {
252        Expr::FnCall(fn_call_expr, _) => {
253            if fn_call_expr.namespace.is_empty() && fn_call_expr.args.len() == 2 {
254                match fn_call_expr.name.as_str() {
255                    "==" | "!=" => {
256                        record_string_comparison(
257                            &fn_call_expr.args[0],
258                            &fn_call_expr.args[1],
259                            result,
260                        );
261                        record_string_comparison(
262                            &fn_call_expr.args[1],
263                            &fn_call_expr.args[0],
264                            result,
265                        );
266                    }
267                    _ => {
268                        for arg in &fn_call_expr.args {
269                            check_for_string_comparisons(arg, result);
270                        }
271                    }
272                }
273            } else {
274                for arg in &fn_call_expr.args {
275                    check_for_string_comparisons(arg, result);
276                }
277            }
278        }
279        Expr::And(expr_vec, _) | Expr::Or(expr_vec, _) => {
280            for e in &**expr_vec {
281                check_for_string_comparisons(e, result);
282            }
283        }
284        Expr::Dot(binary_expr, _, _) | Expr::Index(binary_expr, _, _) => {
285            check_for_string_comparisons(&binary_expr.lhs, result);
286            check_for_string_comparisons(&binary_expr.rhs, result);
287        }
288        Expr::MethodCall(method_call_expr, _) => {
289            for arg in &method_call_expr.args {
290                check_for_string_comparisons(arg, result);
291            }
292        }
293        Expr::Array(expr_vec, _) => {
294            for e in expr_vec {
295                check_for_string_comparisons(e, result);
296            }
297        }
298        Expr::Stmt(stmt_block) => {
299            for s in stmt_block.statements() {
300                if let Stmt::Expr(inner_expr) = s {
301                    check_for_string_comparisons(inner_expr, result);
302                }
303            }
304        }
305        _ => {}
306    }
307}
308
309/// If `lhs` is a variable path and `rhs` is a string literal, records the
310/// comparison in `result.string_comparisons`.
311fn record_string_comparison(lhs: &Expr, rhs: &Expr, result: &mut ScriptAnalysisResult) {
312    if let Some(var_path) = get_full_variable_path(lhs)
313        && let Expr::StringConstant(string_val, _) = rhs
314    {
315        result
316            .string_comparisons
317            .entry(var_path)
318            .or_default()
319            .insert(string_val.to_string());
320    }
321}
322
323// ---------------------------------------------------------------------------
324// Tests
325// ---------------------------------------------------------------------------
326
#[cfg(test)]
mod tests {
    use rhai::{Engine, ParseError};

    use super::*;

    /// Compiles `script` with a default engine and analyzes the resulting AST.
    fn analyze_script(script: &str) -> Result<ScriptAnalysisResult, ParseError> {
        let ast = Engine::new().compile(script)?;
        Ok(analyze_ast(&ast))
    }

    /// Builds an owned string set from string literals.
    fn string_set<const N: usize>(items: [&str; N]) -> HashSet<String> {
        items.into_iter().map(String::from).collect()
    }

    /// Asserts that analyzing `script` yields exactly the `expected` accessed
    /// variable paths.
    #[track_caller]
    fn assert_accessed<const N: usize>(script: &str, expected: [&str; N]) {
        let result = analyze_script(script).unwrap();
        assert_eq!(result.accessed_variables, string_set(expected));
    }

    /// Asserts that `path` was compared against exactly the `expected` string
    /// literals.
    #[track_caller]
    fn assert_compared<const N: usize>(
        result: &ScriptAnalysisResult,
        path: &str,
        expected: [&str; N],
    ) {
        let literals = result.string_comparisons.get(path).unwrap();
        assert_eq!(literals, &string_set(expected));
    }

    #[test]
    fn test_simple_binary_op() {
        assert_accessed("tx.value > 100", ["tx.value"]);
    }

    #[test]
    fn test_logical_operators() {
        assert_accessed(
            r#"tx.from == owner && log.name != "Transfer" || block.number > 1000"#,
            ["tx.from", "owner", "log.name", "block.number"],
        );
    }

    #[test]
    fn test_multiple_variables_and_coalesce() {
        assert_accessed(
            "tx.from ?? fallback_addr.address",
            ["tx.from", "fallback_addr.address"],
        );
    }

    #[test]
    fn test_deeply_nested_variable() {
        assert_accessed(
            r#"log.params.level_one.level_two.user == "admin""#,
            ["log.params.level_one.level_two.user"],
        );
    }

    #[test]
    fn test_variables_in_function_calls() {
        assert_accessed(
            "my_func(tx.value, log.params.user, 42)",
            ["tx.value", "log.params.user"],
        );
    }

    #[test]
    fn test_variables_in_let_and_if() {
        let script = r#"
            let threshold = config.min_value;
            if tx.value > threshold && tx.to != blacklist.address {
                true
            } else {
                false
            }
        "#;
        assert_accessed(
            script,
            [
                "config.min_value",
                "tx.value",
                "threshold",
                "tx.to",
                "blacklist.address",
            ],
        );
    }

    #[test]
    fn test_variables_in_loops() {
        let script = r#"
            for item in tx.items {
                if item.cost > max_cost {
                    return false;
                }
            }
            while x < limit {
                x = x + 1;
            }
        "#;
        assert_accessed(
            script,
            ["tx.items", "item.cost", "max_cost", "x", "limit"],
        );
    }

    #[test]
    fn test_local_variables_are_recorded() {
        // `let` bindings and `for` loop variables are both reported as locals.
        let script = r#"
            let threshold = config.min_value;
            for item in tx.items {
                if item.cost > threshold {
                    return false;
                }
            }
        "#;
        let result = analyze_script(script).unwrap();
        assert_eq!(result.local_variables, string_set(["threshold", "item"]));
    }

    #[test]
    fn test_variables_in_strings_or_comments_are_ignored() {
        let script = r#"
            // This is a comment about tx.value
            let x = "this string mentions log.name";
            tx.from == "0x123"
        "#;
        assert_accessed(script, ["tx.from"]);
    }

    #[test]
    fn test_indexing_expression() {
        assert_accessed(
            r#"tx.logs[0].name == "Transfer" && some_array[tx.index] > 100"#,
            ["tx.logs", "some_array", "tx.index"],
        );
    }

    #[test]
    fn test_method_calls() {
        assert_accessed(
            r#"my_array.contains(tx.value) && other_var.to_string() == "hello""#,
            ["my_array", "tx.value", "other_var"],
        );
    }

    #[test]
    fn test_switch_statement() {
        let script = r#"
            switch tx.action {
                "transfer" => do_transfer(log.params.amount),
                "approve" if log.approved => do_approve(),
                _ => do_nothing(contract.address)
            }
        "#;
        assert_accessed(
            script,
            [
                "tx.action",
                "log.params.amount",
                "log.approved",
                "contract.address",
            ],
        );
    }

    #[test]
    fn test_no_variables() {
        let result = analyze_script("1 + 1 == 2").unwrap();
        assert!(result.accessed_variables.is_empty());
    }

    #[test]
    fn test_array_and_map_literals() {
        let script = r#"
            let my_array = [tx.value, log.topic];
            let my_map = #{ a: some.value, b: 42 };
            my_array[0] > my_map.a
        "#;
        assert_accessed(
            script,
            ["tx.value", "log.topic", "some.value", "my_array", "my_map.a"],
        );
    }

    #[test]
    fn test_string_comparison_simple() {
        let result = analyze_script(r#"log.name == "Transfer""#).unwrap();
        assert_eq!(result.accessed_variables, string_set(["log.name"]));
        assert_compared(&result, "log.name", ["Transfer"]);
    }

    #[test]
    fn test_string_comparison_reversed() {
        let result = analyze_script(r#""Approval" == log.name"#).unwrap();
        assert_compared(&result, "log.name", ["Approval"]);
    }

    #[test]
    fn test_string_comparison_in_logical_or() {
        let result = analyze_script(r#"tx.value > 100 || log.name == "Deposit""#).unwrap();
        assert_compared(&result, "log.name", ["Deposit"]);
    }

    #[test]
    fn test_string_comparison_multiple_values() {
        let result = analyze_script(r#"log.name == "Transfer" || log.name == "Approval""#).unwrap();
        assert_compared(&result, "log.name", ["Transfer", "Approval"]);
    }

    #[test]
    fn test_string_comparison_inequality() {
        let result = analyze_script(r#"log.name != "Transfer""#).unwrap();
        assert_compared(&result, "log.name", ["Transfer"]);
    }

    #[test]
    fn test_string_comparison_different_path() {
        let result = analyze_script(r#"tx.status == "success" && tx.type != "mint""#).unwrap();
        assert_compared(&result, "tx.status", ["success"]);
        assert_compared(&result, "tx.type", ["mint"]);
    }
}