nu_protocol/ast/
traverse.rs

1use crate::engine::StateWorkingSet;
2
3use super::{
4    Block, Expr, Expression, ListItem, MatchPattern, Pattern, PipelineRedirection, RecordItem,
5};
6
/// Outcome returned by the closure passed to `Traverse::find_map` for one expression.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FindMapResult<T> {
    /// The search succeeded at this expression; `find_map` yields `Some` of this value.
    Found(T),
    /// No verdict for this expression; the traversal falls back to its default
    /// behavior and looks into the expression's sub-expressions.
    #[default]
    Continue,
    /// Give up on this expression without looking at its sub-expressions;
    /// this subtree contributes `None`.
    Stop,
}
15
16/// Trait for traversing the AST
17pub trait Traverse {
18    /// Generic function that do flat_map on an AST node
19    /// concatenates all recursive results on sub-expressions
20    ///
21    /// # Arguments
22    /// * `f` - function that overrides the default behavior
23    fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T>
24    where
25        F: Fn(&'a Expression) -> Option<Vec<T>>;
26
27    /// Generic function that do find_map on an AST node
28    /// return the first Some
29    ///
30    /// # Arguments
31    /// * `f` - function that overrides the default behavior
32    fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
33    where
34        F: Fn(&'a Expression) -> FindMapResult<T>;
35}
36
37impl Traverse for Block {
38    fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T>
39    where
40        F: Fn(&'a Expression) -> Option<Vec<T>>,
41    {
42        self.pipelines
43            .iter()
44            .flat_map(|pipeline| {
45                pipeline.elements.iter().flat_map(|element| {
46                    element.expr.flat_map(working_set, f).into_iter().chain(
47                        element
48                            .redirection
49                            .as_ref()
50                            .map(|redir| redir.flat_map(working_set, f))
51                            .unwrap_or_default(),
52                    )
53                })
54            })
55            .collect()
56    }
57
58    fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
59    where
60        F: Fn(&'a Expression) -> FindMapResult<T>,
61    {
62        self.pipelines.iter().find_map(|pipeline| {
63            pipeline.elements.iter().find_map(|element| {
64                element.expr.find_map(working_set, f).or(element
65                    .redirection
66                    .as_ref()
67                    .and_then(|redir| redir.find_map(working_set, f)))
68            })
69        })
70    }
71}
72
73impl Traverse for PipelineRedirection {
74    fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T>
75    where
76        F: Fn(&'a Expression) -> Option<Vec<T>>,
77    {
78        let recur = |expr: &'a Expression| expr.flat_map(working_set, f);
79        match self {
80            PipelineRedirection::Single { target, .. } => {
81                target.expr().map(recur).unwrap_or_default()
82            }
83            PipelineRedirection::Separate { out, err } => [out, err]
84                .iter()
85                .filter_map(|t| t.expr())
86                .flat_map(recur)
87                .collect(),
88        }
89    }
90
91    fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
92    where
93        F: Fn(&'a Expression) -> FindMapResult<T>,
94    {
95        let recur = |expr: &'a Expression| expr.find_map(working_set, f);
96        match self {
97            PipelineRedirection::Single { target, .. } => {
98                target.expr().map(recur).unwrap_or_default()
99            }
100            PipelineRedirection::Separate { out, err } => {
101                [out, err].iter().filter_map(|t| t.expr()).find_map(recur)
102            }
103        }
104    }
105}
106
107impl Traverse for Expression {
108    fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T>
109    where
110        F: Fn(&'a Expression) -> Option<Vec<T>>,
111    {
112        // behavior overridden by f
113        if let Some(vec) = f(self) {
114            return vec;
115        }
116        let recur = |expr: &'a Expression| expr.flat_map(working_set, f);
117        match &self.expr {
118            Expr::RowCondition(block_id)
119            | Expr::Subexpression(block_id)
120            | Expr::Block(block_id)
121            | Expr::Closure(block_id) => {
122                let block = working_set.get_block(block_id.to_owned());
123                block.flat_map(working_set, f)
124            }
125            Expr::Range(range) => [&range.from, &range.next, &range.to]
126                .iter()
127                .filter_map(|e| e.as_ref())
128                .flat_map(recur)
129                .collect(),
130            Expr::Call(call) => call
131                .arguments
132                .iter()
133                .filter_map(|arg| arg.expr())
134                .flat_map(recur)
135                .collect(),
136            Expr::ExternalCall(head, args) => recur(head.as_ref())
137                .into_iter()
138                .chain(args.iter().flat_map(|arg| recur(arg.expr())))
139                .collect(),
140            Expr::UnaryNot(expr) | Expr::Collect(_, expr) => recur(expr.as_ref()),
141            Expr::BinaryOp(lhs, op, rhs) => recur(lhs)
142                .into_iter()
143                .chain(recur(op))
144                .chain(recur(rhs))
145                .collect(),
146            Expr::MatchBlock(matches) => matches
147                .iter()
148                .flat_map(|(pattern, expr)| {
149                    pattern
150                        .flat_map(working_set, f)
151                        .into_iter()
152                        .chain(recur(expr))
153                })
154                .collect(),
155            Expr::List(items) => items
156                .iter()
157                .flat_map(|item| match item {
158                    ListItem::Item(expr) | ListItem::Spread(_, expr) => recur(expr),
159                })
160                .collect(),
161            Expr::Record(items) => items
162                .iter()
163                .flat_map(|item| match item {
164                    RecordItem::Spread(_, expr) => recur(expr),
165                    RecordItem::Pair(key, val) => [key, val].into_iter().flat_map(recur).collect(),
166                })
167                .collect(),
168            Expr::Table(table) => table
169                .columns
170                .iter()
171                .flat_map(recur)
172                .chain(table.rows.iter().flat_map(|row| row.iter().flat_map(recur)))
173                .collect(),
174            Expr::ValueWithUnit(vu) => recur(&vu.expr),
175            Expr::FullCellPath(fcp) => recur(&fcp.head),
176            Expr::Keyword(kw) => recur(&kw.expr),
177            Expr::StringInterpolation(vec) | Expr::GlobInterpolation(vec, _) => {
178                vec.iter().flat_map(recur).collect()
179            }
180            Expr::AttributeBlock(ab) => ab
181                .attributes
182                .iter()
183                .flat_map(|attr| recur(&attr.expr))
184                .chain(recur(&ab.item))
185                .collect(),
186
187            _ => Vec::new(),
188        }
189    }
190
191    fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
192    where
193        F: Fn(&'a Expression) -> FindMapResult<T>,
194    {
195        // behavior overridden by f
196        match f(self) {
197            FindMapResult::Found(t) => Some(t),
198            FindMapResult::Stop => None,
199            FindMapResult::Continue => {
200                let recur = |expr: &'a Expression| expr.find_map(working_set, f);
201                match &self.expr {
202                    Expr::RowCondition(block_id)
203                    | Expr::Subexpression(block_id)
204                    | Expr::Block(block_id)
205                    | Expr::Closure(block_id) => {
206                        let block = working_set.get_block(block_id.to_owned());
207                        block.find_map(working_set, f)
208                    }
209                    Expr::Range(range) => [&range.from, &range.next, &range.to]
210                        .iter()
211                        .find_map(|e| e.as_ref().and_then(recur)),
212                    Expr::Call(call) => call
213                        .arguments
214                        .iter()
215                        .find_map(|arg| arg.expr().and_then(recur)),
216                    Expr::ExternalCall(head, args) => {
217                        recur(head.as_ref()).or(args.iter().find_map(|arg| recur(arg.expr())))
218                    }
219                    Expr::UnaryNot(expr) | Expr::Collect(_, expr) => recur(expr.as_ref()),
220                    Expr::BinaryOp(lhs, op, rhs) => recur(lhs).or(recur(op)).or(recur(rhs)),
221                    Expr::MatchBlock(matches) => matches.iter().find_map(|(pattern, expr)| {
222                        pattern.find_map(working_set, f).or(recur(expr))
223                    }),
224                    Expr::List(items) => items.iter().find_map(|item| match item {
225                        ListItem::Item(expr) | ListItem::Spread(_, expr) => recur(expr),
226                    }),
227                    Expr::Record(items) => items.iter().find_map(|item| match item {
228                        RecordItem::Spread(_, expr) => recur(expr),
229                        RecordItem::Pair(key, val) => [key, val].into_iter().find_map(recur),
230                    }),
231                    Expr::Table(table) => table
232                        .columns
233                        .iter()
234                        .find_map(recur)
235                        .or(table.rows.iter().find_map(|row| row.iter().find_map(recur))),
236                    Expr::ValueWithUnit(vu) => recur(&vu.expr),
237                    Expr::FullCellPath(fcp) => recur(&fcp.head),
238                    Expr::Keyword(kw) => recur(&kw.expr),
239                    Expr::StringInterpolation(vec) | Expr::GlobInterpolation(vec, _) => {
240                        vec.iter().find_map(recur)
241                    }
242                    Expr::AttributeBlock(ab) => ab
243                        .attributes
244                        .iter()
245                        .find_map(|attr| recur(&attr.expr))
246                        .or_else(|| recur(&ab.item)),
247
248                    _ => None,
249                }
250            }
251        }
252    }
253}
254
255impl Traverse for MatchPattern {
256    fn flat_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Vec<T>
257    where
258        F: Fn(&'a Expression) -> Option<Vec<T>>,
259    {
260        let recur = |expr: &'a Expression| expr.flat_map(working_set, f);
261        let recur_pattern = |pattern: &'a MatchPattern| pattern.flat_map(working_set, f);
262        match &self.pattern {
263            Pattern::Expression(expr) => recur(expr),
264            Pattern::List(patterns) | Pattern::Or(patterns) => {
265                patterns.iter().flat_map(recur_pattern).collect()
266            }
267            Pattern::Record(entries) => {
268                entries.iter().flat_map(|(_, p)| recur_pattern(p)).collect()
269            }
270            _ => Vec::new(),
271        }
272        .into_iter()
273        .chain(self.guard.as_ref().map(|g| recur(g)).unwrap_or_default())
274        .collect()
275    }
276
277    fn find_map<'a, T, F>(&'a self, working_set: &'a StateWorkingSet, f: &F) -> Option<T>
278    where
279        F: Fn(&'a Expression) -> FindMapResult<T>,
280    {
281        let recur = |expr: &'a Expression| expr.find_map(working_set, f);
282        let recur_pattern = |pattern: &'a MatchPattern| pattern.find_map(working_set, f);
283        match &self.pattern {
284            Pattern::Expression(expr) => recur(expr),
285            Pattern::List(patterns) | Pattern::Or(patterns) => {
286                patterns.iter().find_map(recur_pattern)
287            }
288            Pattern::Record(entries) => entries.iter().find_map(|(_, p)| recur_pattern(p)),
289            _ => None,
290        }
291        .or(self.guard.as_ref().and_then(|g| recur(g)))
292    }
293}