lets_expect_core/utils/
expr_dependencies.rs

1use std::collections::HashSet;
2
3use proc_macro2::{Ident, TokenTree};
4use syn::{Block, Expr, Pat, Stmt};
5
6/// Returns the identifiers used in an expression.
7pub(crate) fn expr_dependencies(expr: &Expr) -> HashSet<Ident> {
8    let mut dependencies = HashSet::new();
9
10    match expr {
11        Expr::Binary(binary) => {
12            dependencies.extend(expr_dependencies(&binary.left));
13            dependencies.extend(expr_dependencies(&binary.right));
14        }
15        Expr::Unary(unary) => {
16            dependencies.extend(expr_dependencies(&unary.expr));
17        }
18        Expr::Assign(assign) => {
19            dependencies.extend(expr_dependencies(&assign.left));
20            dependencies.extend(expr_dependencies(&assign.right));
21        }
22        Expr::AssignOp(assign_op) => {
23            dependencies.extend(expr_dependencies(&assign_op.left));
24            dependencies.extend(expr_dependencies(&assign_op.right));
25        }
26        Expr::Path(path) => {
27            if let Some(ident) = path.path.get_ident() {
28                dependencies.insert(ident.clone());
29            }
30        }
31        Expr::Call(call) => {
32            dependencies.extend(expr_dependencies(&call.func));
33            dependencies.extend(call.args.iter().flat_map(expr_dependencies));
34        }
35        Expr::MethodCall(method_call) => {
36            dependencies.extend(expr_dependencies(&method_call.receiver));
37            dependencies.extend(method_call.args.iter().flat_map(expr_dependencies));
38        }
39        Expr::Field(field) => {
40            dependencies.extend(expr_dependencies(&field.base));
41        }
42        Expr::Index(index) => {
43            dependencies.extend(expr_dependencies(&index.expr));
44            dependencies.extend(expr_dependencies(&index.index));
45        }
46        Expr::Range(range) => {
47            if let Some(from) = &range.from {
48                dependencies.extend(expr_dependencies(from));
49            }
50            if let Some(to) = &range.to {
51                dependencies.extend(expr_dependencies(to));
52            }
53        }
54        Expr::Reference(reference) => {
55            dependencies.extend(expr_dependencies(&reference.expr));
56        }
57        Expr::Paren(paren) => {
58            dependencies.extend(expr_dependencies(&paren.expr));
59        }
60        Expr::Group(group) => {
61            dependencies.extend(expr_dependencies(&group.expr));
62        }
63        Expr::Block(block) => {
64            dependencies.extend(block_dependencies(&block.block));
65        }
66        Expr::If(r#if) => {
67            dependencies.extend(expr_dependencies(&r#if.cond));
68            dependencies.extend(r#if.then_branch.stmts.iter().flat_map(stmt_dependencies));
69            if let Some(else_branch) = &r#if.else_branch {
70                dependencies.extend(expr_dependencies(&else_branch.1));
71            }
72        }
73        Expr::Match(match_) => {
74            dependencies.extend(expr_dependencies(&match_.expr));
75            dependencies.extend(
76                match_
77                    .arms
78                    .iter()
79                    .map(|arm| &*arm.body)
80                    .flat_map(expr_dependencies),
81            );
82        }
83        Expr::Closure(closure) => {
84            dependencies.extend(closure.inputs.iter().flat_map(pat_dependencies));
85            dependencies.extend(expr_dependencies(&closure.body));
86        }
87        Expr::Unsafe(unsafe_) => {
88            dependencies.extend(unsafe_.block.stmts.iter().flat_map(stmt_dependencies));
89        }
90        Expr::Loop(r#loop) => {
91            dependencies.extend(r#loop.body.stmts.iter().flat_map(stmt_dependencies));
92        }
93        Expr::While(while_) => {
94            dependencies.extend(expr_dependencies(&while_.cond));
95            dependencies.extend(while_.body.stmts.iter().flat_map(stmt_dependencies));
96        }
97        Expr::ForLoop(for_loop) => {
98            dependencies.extend(pat_dependencies(&for_loop.pat));
99            dependencies.extend(expr_dependencies(&for_loop.expr));
100            dependencies.extend(for_loop.body.stmts.iter().flat_map(stmt_dependencies));
101        }
102        Expr::Break(r#break) => {
103            if let Some(expr) = &r#break.expr {
104                dependencies.extend(expr_dependencies(expr));
105            }
106        }
107        Expr::Return(return_) => {
108            if let Some(expr) = &return_.expr {
109                dependencies.extend(expr_dependencies(expr));
110            }
111        }
112        Expr::Yield(yield_) => {
113            if let Some(expr) = &yield_.expr {
114                dependencies.extend(expr_dependencies(expr));
115            }
116        }
117        Expr::Try(try_) => {
118            dependencies.extend(expr_dependencies(&try_.expr));
119        }
120        Expr::Async(async_) => {
121            dependencies.extend(async_.block.stmts.iter().flat_map(stmt_dependencies));
122        }
123        Expr::Await(r#await) => {
124            dependencies.extend(expr_dependencies(&r#await.base));
125        }
126        Expr::Macro(macro_) => {
127            dependencies.extend(
128                macro_
129                    .mac
130                    .path
131                    .segments
132                    .iter()
133                    .map(|segment| segment.ident.clone()),
134            );
135            dependencies.extend(macro_.mac.tokens.clone().into_iter().flat_map(|token| {
136                if let TokenTree::Ident(ident) = token {
137                    Some(ident)
138                } else {
139                    None
140                }
141            }));
142        }
143        Expr::Tuple(tuple) => {
144            dependencies.extend(tuple.elems.iter().flat_map(expr_dependencies));
145        }
146        Expr::Array(array) => {
147            dependencies.extend(array.elems.iter().flat_map(expr_dependencies));
148        }
149        Expr::Repeat(repeat) => {
150            dependencies.extend(expr_dependencies(&repeat.expr));
151            dependencies.extend(expr_dependencies(&repeat.len));
152        }
153        Expr::Struct(r#struct) => {
154            dependencies.extend(
155                r#struct
156                    .fields
157                    .iter()
158                    .flat_map(|field| expr_dependencies(&field.expr)),
159            );
160            dependencies.extend(
161                r#struct
162                    .fields
163                    .iter()
164                    .flat_map(|field| expr_dependencies(&field.expr)),
165            );
166        }
167        _ => {}
168    }
169
170    dependencies
171}
172
173pub fn block_dependencies(block: &Block) -> HashSet<Ident> {
174    block.stmts.iter().flat_map(stmt_dependencies).collect()
175}
176
177fn pat_dependencies(pat: &Pat) -> HashSet<Ident> {
178    let mut dependencies = HashSet::new();
179
180    match pat {
181        Pat::Ident(ident) => {
182            dependencies.insert(ident.ident.clone());
183        }
184        Pat::Type(pat_type) => {
185            dependencies.extend(pat_dependencies(&pat_type.pat));
186        }
187        _ => {}
188    }
189
190    dependencies
191}
192
193pub fn stmt_dependencies(stmt: &Stmt) -> HashSet<Ident> {
194    match stmt {
195        Stmt::Local(local) => {
196            if let Some(init) = &local.init {
197                expr_dependencies(&init.1)
198            } else {
199                HashSet::new()
200            }
201        }
202        Stmt::Item(_) => HashSet::new(),
203        Stmt::Expr(expr) | Stmt::Semi(expr, _) => expr_dependencies(expr),
204    }
205}