1use nu_protocol::{
2 BlockId, Span, VarId,
3 ast::{Block, Expr, Expression, Pipeline},
4 engine::StateWorkingSet,
5};
6
7pub trait AstVisitor {
9 fn visit_block(&mut self, block: &Block, context: &VisitContext) {
10 for pipeline in &block.pipelines {
11 self.visit_pipeline(pipeline, context);
12 }
13 }
14
15 fn visit_pipeline(&mut self, pipeline: &Pipeline, context: &VisitContext) {
16 for element in &pipeline.elements {
17 self.visit_expression(&element.expr, context);
18 }
19 }
20
21 fn visit_expression(&mut self, expr: &Expression, context: &VisitContext) {
22 walk_expression(self, expr, context);
23 }
24
25 fn visit_call(&mut self, call: &nu_protocol::ast::Call, context: &VisitContext) {
26 walk_call(self, call, context);
27 }
28
29 fn visit_var_decl(&mut self, _var_id: VarId, _span: Span, _context: &VisitContext) {}
32
33 fn visit_var_ref(&mut self, _var_id: VarId, _span: Span, _context: &VisitContext) {}
34
35 fn visit_binary_op(
36 &mut self,
37 lhs: &Expression,
38 op: &Expression,
39 rhs: &Expression,
40 context: &VisitContext,
41 ) {
42 self.visit_expression(lhs, context);
43 self.visit_expression(op, context);
44 self.visit_expression(rhs, context);
45 }
46
47 fn visit_list(&mut self, items: &[nu_protocol::ast::ListItem], context: &VisitContext) {
48 for item in items {
49 let expr = match item {
50 nu_protocol::ast::ListItem::Item(expr)
51 | nu_protocol::ast::ListItem::Spread(_, expr) => expr,
52 };
53 self.visit_expression(expr, context);
54 }
55 }
56
57 fn visit_string(&mut self, _content: &str, _span: Span, _context: &VisitContext) {}
58
59 fn visit_int(&mut self, _value: i64, _span: Span, _context: &VisitContext) {}
60}
61
62pub struct VisitContext<'a> {
64 pub working_set: &'a StateWorkingSet<'a>,
65 pub source: &'a str,
66}
67
68impl VisitContext<'_> {
69 #[must_use]
71 pub fn get_span_contents(&self, span: Span) -> &str {
72 let start = span.start.min(self.source.len());
73 let end = span.end.min(self.source.len());
74 &self.source[start..end]
75 }
76
77 #[must_use]
79 pub fn get_block(&self, block_id: BlockId) -> &Block {
80 self.working_set.get_block(block_id)
81 }
82
83 #[must_use]
85 pub fn get_variable(&self, var_id: VarId) -> &nu_protocol::engine::Variable {
86 self.working_set.get_variable(var_id)
87 }
88
89 #[must_use]
91 pub fn get_decl(&self, decl_id: nu_protocol::DeclId) -> &dyn nu_protocol::engine::Command {
92 self.working_set.get_decl(decl_id)
93 }
94
95 #[must_use]
97 pub fn extract_external_args(
98 &self,
99 args: &[nu_protocol::ast::ExternalArgument],
100 ) -> Vec<String> {
101 args.iter()
102 .map(|arg| match arg {
103 nu_protocol::ast::ExternalArgument::Regular(expr) => {
104 self.get_span_contents(expr.span).to_string()
105 }
106 nu_protocol::ast::ExternalArgument::Spread(expr) => {
107 format!("...{}", self.get_span_contents(expr.span))
108 }
109 })
110 .collect()
111 }
112}
113
114pub fn walk_expression<V: AstVisitor + ?Sized>(
116 visitor: &mut V,
117 expr: &Expression,
118 context: &VisitContext,
119) {
120 match &expr.expr {
121 Expr::VarDecl(var_id) => visitor.visit_var_decl(*var_id, expr.span, context),
122 Expr::Var(var_id) => visitor.visit_var_ref(*var_id, expr.span, context),
123 Expr::Call(call) => visitor.visit_call(call, context),
124 Expr::BinaryOp(lhs, op, rhs) => visitor.visit_binary_op(lhs, op, rhs, context),
125 Expr::UnaryNot(inner) => visitor.visit_expression(inner, context),
126 Expr::List(items) => visitor.visit_list(items, context),
127 Expr::Record(items) => {
128 for item in items {
129 match item {
130 nu_protocol::ast::RecordItem::Pair(key, value) => {
131 visitor.visit_expression(key, context);
132 visitor.visit_expression(value, context);
133 }
134 nu_protocol::ast::RecordItem::Spread(_, expr) => {
135 visitor.visit_expression(expr, context);
136 }
137 }
138 }
139 }
140 Expr::Block(block_id)
141 | Expr::Closure(block_id)
142 | Expr::Subexpression(block_id)
143 | Expr::RowCondition(block_id) => {
144 let block = context.get_block(*block_id);
145 visitor.visit_block(block, context);
146 }
147 Expr::Keyword(keyword) => visitor.visit_expression(&keyword.expr, context),
148 Expr::FullCellPath(cell_path) => visitor.visit_expression(&cell_path.head, context),
149 Expr::String(string_content) | Expr::RawString(string_content) => {
150 visitor.visit_string(string_content, expr.span, context);
151 }
152 Expr::Int(value) => visitor.visit_int(*value, expr.span, context),
153 Expr::StringInterpolation(exprs) => {
154 for expr in exprs {
155 visitor.visit_expression(expr, context);
156 }
157 }
158 Expr::MatchBlock(arms) => {
159 for (_, expr) in arms {
160 visitor.visit_expression(expr, context);
161 }
162 }
163 Expr::ExternalCall(head, args) => {
164 visitor.visit_expression(head, context);
165 for arg in args {
166 match arg {
167 nu_protocol::ast::ExternalArgument::Regular(expr)
168 | nu_protocol::ast::ExternalArgument::Spread(expr) => {
169 visitor.visit_expression(expr, context);
170 }
171 }
172 }
173 }
174 _ => {}
175 }
176}
177
178pub fn walk_call<V: AstVisitor + ?Sized>(
180 visitor: &mut V,
181 call: &nu_protocol::ast::Call,
182 context: &VisitContext,
183) {
184 for arg in &call.arguments {
185 match arg {
186 nu_protocol::ast::Argument::Named(named) => {
187 if let Some(expr) = &named.2 {
188 visitor.visit_expression(expr, context);
189 }
190 }
191 nu_protocol::ast::Argument::Positional(expr)
192 | nu_protocol::ast::Argument::Unknown(expr)
193 | nu_protocol::ast::Argument::Spread(expr) => {
194 visitor.visit_expression(expr, context);
195 }
196 }
197 }
198}