1use nu_protocol::{
2 BlockId, Span, VarId,
3 ast::{Block, Expr, Expression, Pipeline},
4 engine::StateWorkingSet,
5};
6
7pub trait AstVisitor {
9 fn visit_block(&mut self, block: &Block, context: &VisitContext) {
10 for pipeline in &block.pipelines {
11 self.visit_pipeline(pipeline, context);
12 }
13 }
14
15 fn visit_pipeline(&mut self, pipeline: &Pipeline, context: &VisitContext) {
16 for element in &pipeline.elements {
17 self.visit_expression(&element.expr, context);
18 }
19 }
20
21 fn visit_expression(&mut self, expr: &Expression, context: &VisitContext) {
22 walk_expression(self, expr, context);
23 }
24
25 fn visit_call(&mut self, call: &nu_protocol::ast::Call, context: &VisitContext) {
26 walk_call(self, call, context);
27 }
28
29 fn visit_var_decl(&mut self, _var_id: VarId, _span: Span, _context: &VisitContext) {}
30
31 fn visit_var_ref(&mut self, _var_id: VarId, _span: Span, _context: &VisitContext) {}
32
33 fn visit_binary_op(
34 &mut self,
35 lhs: &Expression,
36 op: &Expression,
37 rhs: &Expression,
38 context: &VisitContext,
39 ) {
40 self.visit_expression(lhs, context);
41 self.visit_expression(op, context);
42 self.visit_expression(rhs, context);
43 }
44
45 fn visit_list(&mut self, items: &[nu_protocol::ast::ListItem], context: &VisitContext) {
46 for item in items {
47 match item {
48 nu_protocol::ast::ListItem::Item(expr)
49 | nu_protocol::ast::ListItem::Spread(_, expr) => {
50 self.visit_expression(expr, context);
51 }
52 }
53 }
54 }
55
56 fn visit_string(&mut self, _content: &str, _span: Span, _context: &VisitContext) {}
57
58 fn visit_int(&mut self, _value: i64, _span: Span, _context: &VisitContext) {}
59}
60
61pub struct VisitContext<'a> {
63 pub working_set: &'a StateWorkingSet<'a>,
64 pub source: &'a str,
65}
66
67impl<'a> VisitContext<'a> {
68 #[must_use]
69 pub fn new(working_set: &'a StateWorkingSet<'a>, source: &'a str) -> Self {
70 Self {
71 working_set,
72 source,
73 }
74 }
75
76 #[must_use]
78 pub fn get_span_contents(&self, span: Span) -> &str {
79 let start = span.start.min(self.source.len());
80 let end = span.end.min(self.source.len());
81 &self.source[start..end]
82 }
83
84 #[must_use]
86 pub fn get_block(&self, block_id: BlockId) -> &Block {
87 self.working_set.get_block(block_id)
88 }
89
90 #[must_use]
92 pub fn get_variable(&self, var_id: VarId) -> &nu_protocol::engine::Variable {
93 self.working_set.get_variable(var_id)
94 }
95
96 #[must_use]
98 pub fn get_decl(&self, decl_id: nu_protocol::DeclId) -> &dyn nu_protocol::engine::Command {
99 self.working_set.get_decl(decl_id)
100 }
101
102 #[must_use]
105 pub fn extract_external_args(
106 &self,
107 args: &[nu_protocol::ast::ExternalArgument],
108 ) -> Vec<String> {
109 args.iter()
110 .map(|arg| match &arg {
111 nu_protocol::ast::ExternalArgument::Regular(expr) => {
112 self.get_span_contents(expr.span).to_string()
113 }
114 nu_protocol::ast::ExternalArgument::Spread(expr) => {
115 format!("...{}", self.get_span_contents(expr.span))
116 }
117 })
118 .collect()
119 }
120}
121
122pub fn walk_expression<V: AstVisitor + ?Sized>(
124 visitor: &mut V,
125 expr: &Expression,
126 context: &VisitContext,
127) {
128 match &expr.expr {
129 Expr::VarDecl(var_id) => visitor.visit_var_decl(*var_id, expr.span, context),
130 Expr::Var(var_id) => visitor.visit_var_ref(*var_id, expr.span, context),
131 Expr::Call(call) => visitor.visit_call(call, context),
132 Expr::BinaryOp(lhs, op, rhs) => visitor.visit_binary_op(lhs, op, rhs, context),
133 Expr::UnaryNot(inner) => visitor.visit_expression(inner, context),
134 Expr::List(items) => visitor.visit_list(items, context),
135 Expr::Record(items) => {
136 for item in items {
137 match item {
138 nu_protocol::ast::RecordItem::Pair(key, value) => {
139 visitor.visit_expression(key, context);
140 visitor.visit_expression(value, context);
141 }
142 nu_protocol::ast::RecordItem::Spread(_, expr) => {
143 visitor.visit_expression(expr, context);
144 }
145 }
146 }
147 }
148 Expr::Block(block_id) | Expr::Closure(block_id) | Expr::Subexpression(block_id) => {
149 let block = context.get_block(*block_id);
150 visitor.visit_block(block, context);
151 }
152 Expr::FullCellPath(cell_path) => visitor.visit_expression(&cell_path.head, context),
153 Expr::String(string_content) | Expr::RawString(string_content) => {
154 visitor.visit_string(string_content, expr.span, context);
155 }
156 Expr::Int(value) => visitor.visit_int(*value, expr.span, context),
157 Expr::StringInterpolation(exprs) => {
158 for expr in exprs {
159 visitor.visit_expression(expr, context);
160 }
161 }
162 Expr::MatchBlock(arms) => {
163 for (_, expr) in arms {
164 visitor.visit_expression(expr, context);
165 }
166 }
167 _ => {}
168 }
169}
170
171pub fn walk_call<V: AstVisitor + ?Sized>(
173 visitor: &mut V,
174 call: &nu_protocol::ast::Call,
175 context: &VisitContext,
176) {
177 for arg in &call.arguments {
178 match arg {
179 nu_protocol::ast::Argument::Named(named) => {
180 if let Some(expr) = &named.2 {
181 visitor.visit_expression(expr, context);
182 }
183 }
184 nu_protocol::ast::Argument::Positional(expr)
185 | nu_protocol::ast::Argument::Unknown(expr)
186 | nu_protocol::ast::Argument::Spread(expr) => {
187 visitor.visit_expression(expr, context);
188 }
189 }
190 }
191}