Skip to main content

promql_parser/util/
visitor.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::parser::{
16    AggregateExpr, BinaryExpr, Expr, Extension, ParenExpr, SubqueryExpr, UnaryExpr,
17};
18
19/// Trait that implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern)
20/// for a depth first walk on [Expr] AST. [`pre_visit`](ExprVisitor::pre_visit) is called
21/// before any children are visited, and then [`post_visit`](ExprVisitor::post_visit) is called
22/// after all children have been visited. Only [`pre_visit`](ExprVisitor::pre_visit) is required.
23pub trait ExprVisitor {
24    type Error;
25
26    /// Called before any children are visited. Return `Ok(false)` to cut short the recursion
27    /// (skip traversing and return).
28    fn pre_visit(&mut self, plan: &Expr) -> Result<bool, Self::Error>;
29
30    /// Called after all children are visited. Return `Ok(false)` to cut short the recursion
31    /// (skip traversing and return).
32    fn post_visit(&mut self, _plan: &Expr) -> Result<bool, Self::Error> {
33        Ok(true)
34    }
35}
36
37/// Trait that implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern)
38/// for a depth first walk on [Expr] AST. [`pre_visit`](ExprVisitorMut::pre_visit) is called
39/// before any children are visited, and then [`post_visit`](ExprVisitorMut::post_visit) is called
40/// after all children have been visited. Only [`pre_visit`](ExprVisitorMut::pre_visit) is required.
41pub trait ExprVisitorMut {
42    type Error;
43
44    /// Called before any children are visited. Return `Ok(false)` to cut short the recursion
45    /// (skip traversing and return).
46    fn pre_visit(&mut self, plan: &mut Expr) -> Result<bool, Self::Error>;
47
48    /// Called after all children are visited. Return `Ok(false)` to cut short the recursion
49    /// (skip traversing and return).
50    fn post_visit(&mut self, _plan: &mut Expr) -> Result<bool, Self::Error> {
51        Ok(true)
52    }
53}
54
55/// A util function that traverses an AST [Expr] in depth-first order. Returns
56/// `Ok(true)` if all nodes were visited, and `Ok(false)` if any call to
57/// [`pre_visit`](ExprVisitor::pre_visit) or [`post_visit`](ExprVisitor::post_visit)
58/// returned `Ok(false)` and may have cut short the recursion.
59pub fn walk_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> Result<bool, V::Error> {
60    if !visitor.pre_visit(expr)? {
61        return Ok(false);
62    }
63
64    let recurse = match expr {
65        Expr::Aggregate(AggregateExpr { expr, .. }) => walk_expr(visitor, expr)?,
66        Expr::Unary(UnaryExpr { expr }) => walk_expr(visitor, expr)?,
67        Expr::Binary(BinaryExpr { lhs, rhs, .. }) => {
68            walk_expr(visitor, lhs)? && walk_expr(visitor, rhs)?
69        }
70        Expr::Paren(ParenExpr { expr }) => walk_expr(visitor, expr)?,
71        Expr::Subquery(SubqueryExpr { expr, .. }) => walk_expr(visitor, expr)?,
72        Expr::Extension(Extension { expr }) => {
73            for child in expr.children() {
74                if !walk_expr(visitor, child)? {
75                    return Ok(false);
76                }
77            }
78            true
79        }
80        Expr::Call(call) => {
81            for func_argument_expr in &call.args.args {
82                if !walk_expr(visitor, func_argument_expr)? {
83                    return Ok(false);
84                }
85            }
86            true
87        }
88        Expr::NumberLiteral(_)
89        | Expr::StringLiteral(_)
90        | Expr::VectorSelector(_)
91        | Expr::MatrixSelector(_) => true,
92    };
93
94    if !recurse {
95        return Ok(false);
96    }
97
98    if !visitor.post_visit(expr)? {
99        return Ok(false);
100    }
101
102    Ok(true)
103}
104
105/// A util function that traverses an AST [Expr] mutably in depth-first order.
106/// Returns `Ok(true)` if all nodes were visited, and `Ok(false)` if any call to
107/// [`pre_visit`](ExprVisitorMut::pre_visit) or [`post_visit`](ExprVisitorMut::post_visit)
108/// returned `Ok(false)` and may have cut short the recursion.
109pub fn walk_expr_mut<V: ExprVisitorMut>(
110    visitor: &mut V,
111    expr: &mut Expr,
112) -> Result<bool, V::Error> {
113    if !visitor.pre_visit(expr)? {
114        return Ok(false);
115    }
116
117    let recurse = match expr {
118        Expr::Aggregate(AggregateExpr { expr, .. }) => walk_expr_mut(visitor, expr)?,
119        Expr::Unary(UnaryExpr { expr }) => walk_expr_mut(visitor, expr)?,
120        Expr::Binary(BinaryExpr { lhs, rhs, .. }) => {
121            walk_expr_mut(visitor, lhs)? && walk_expr_mut(visitor, rhs)?
122        }
123        Expr::Paren(ParenExpr { expr }) => walk_expr_mut(visitor, expr)?,
124        Expr::Subquery(SubqueryExpr { expr, .. }) => walk_expr_mut(visitor, expr)?,
125        Expr::Extension(Extension { expr }) => {
126            let mut children = expr.children().to_vec();
127            let mut recurse = true;
128            for child in &mut children {
129                if !walk_expr_mut(visitor, child)? {
130                    recurse = false;
131                    break;
132                }
133            }
134            *expr = expr.with_new_children(children);
135            recurse
136        }
137        Expr::Call(call) => {
138            for func_argument_expr in &mut call.args.args {
139                if !walk_expr_mut(visitor, func_argument_expr)? {
140                    return Ok(false);
141                }
142            }
143            true
144        }
145        Expr::NumberLiteral(_)
146        | Expr::StringLiteral(_)
147        | Expr::VectorSelector(_)
148        | Expr::MatrixSelector(_) => true,
149    };
150
151    if !recurse {
152        return Ok(false);
153    }
154
155    if !visitor.post_visit(expr)? {
156        return Ok(false);
157    }
158
159    Ok(true)
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::label::MatchOp;
    use crate::parser;
    use crate::parser::ast::ExtensionExpr;
    use crate::parser::value::ValueType;
    use crate::parser::VectorSelector;
    use std::sync::Arc;

    /// Visitor that keeps walking only while every selector filters on
    /// `namespace="<namespace>"`; any literal stops the walk.
    struct NamespaceVisitor {
        namespace: String,
    }

    /// Returns `true` if `vector_selector` has an equality matcher `namespace="<namespace>"`.
    fn vector_selector_includes_namespace(
        namespace: &str,
        vector_selector: &VectorSelector,
    ) -> bool {
        vector_selector.matchers.matchers.iter().any(|matcher| {
            matcher.name == "namespace"
                && matcher.value == namespace
                && matcher.op == MatchOp::Equal
        })
    }

    impl ExprVisitor for NamespaceVisitor {
        type Error = &'static str;

        fn pre_visit(&mut self, expr: &Expr) -> Result<bool, Self::Error> {
            // Selectors decide on their own whether the walk may continue; literals carry
            // no namespace, so they always end it.
            let keep_walking = match expr {
                Expr::VectorSelector(vector_selector) => {
                    vector_selector_includes_namespace(&self.namespace, vector_selector)
                }
                Expr::MatrixSelector(matrix_selector) => {
                    vector_selector_includes_namespace(&self.namespace, &matrix_selector.vs)
                }
                Expr::NumberLiteral(_) | Expr::StringLiteral(_) => false,
                _ => true,
            };
            Ok(keep_walking)
        }
    }

    #[test]
    fn test_check_for_namespace_basic_query() {
        let expr = "pg_stat_activity_count{namespace=\"sample\"}";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_present() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_wrong_namespace() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "foobar".to_string(),
        };
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_missing_namespace() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_literal_expr() {
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };

        let ast = parser::parse("1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("1 + 1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(r#""1""#).unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_binary_expr() {
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };

        let ast = parser::parse(
            "pg_stat_activity_count{namespace=\"sample\"} + pg_stat_activity_count{}",
        )
        .unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(
            "pg_stat_activity_count{} - pg_stat_activity_count{namespace=\"sample\"}",
        )
        .unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("pg_stat_activity_count{} * pg_stat_activity_count{}").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("pg_stat_activity_count{namespace=\"sample\"} / 1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("1 % pg_stat_activity_count{namespace=\"sample\"}").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(
            "pg_stat_activity_count{namespace=\"sample\"} ^ \
             pg_stat_activity_count{namespace=\"sample\"}",
        )
        .unwrap();
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    /// Short, stable label for each `Expr` variant, used to trace the walk order.
    fn node_kind(expr: &Expr) -> &'static str {
        match expr {
            Expr::Aggregate(_) => "aggregate",
            Expr::Unary(_) => "unary",
            Expr::Binary(_) => "binary",
            Expr::Paren(_) => "paren",
            Expr::Subquery(_) => "subquery",
            Expr::NumberLiteral(_) => "number",
            Expr::StringLiteral(_) => "string",
            Expr::VectorSelector(_) => "vector",
            Expr::MatrixSelector(_) => "matrix",
            Expr::Call(_) => "call",
            Expr::Extension(_) => "extension",
        }
    }

    /// Visitor that records every hook invocation, in order.
    #[derive(Default)]
    struct TraceVisitor {
        events: Vec<String>,
    }

    impl ExprVisitor for TraceVisitor {
        type Error = &'static str;

        fn pre_visit(&mut self, expr: &Expr) -> Result<bool, Self::Error> {
            self.events.push(format!("pre {}", node_kind(expr)));
            Ok(true)
        }

        fn post_visit(&mut self, expr: &Expr) -> Result<bool, Self::Error> {
            self.events.push(format!("post {}", node_kind(expr)));
            Ok(true)
        }
    }

    #[test]
    fn test_walk_order_is_depth_first() {
        let ast = parser::parse("sum(foo) + bar").unwrap();
        let mut visitor = TraceVisitor::default();

        assert!(walk_expr(&mut visitor, &ast).unwrap());
        assert_eq!(
            visitor.events,
            vec![
                "pre binary",
                "pre aggregate",
                "pre vector",
                "post vector",
                "post aggregate",
                "pre vector",
                "post vector",
                "post binary",
            ]
        );
    }

    /// Visitor that appends `label_name="label_value"` to every vector selector it meets.
    /// With `inject_once`, it stops the walk right after the first injection.
    struct LabelInjectorVisitor {
        label_name: String,
        label_value: String,
        inject_once: bool,
    }

    impl ExprVisitorMut for LabelInjectorVisitor {
        type Error = &'static str;

        fn pre_visit(&mut self, expr: &mut Expr) -> Result<bool, Self::Error> {
            if let Expr::VectorSelector(vector_selector) = expr {
                vector_selector
                    .matchers
                    .matchers
                    .push(crate::label::Matcher {
                        op: MatchOp::Equal,
                        name: self.label_name.clone(),
                        value: self.label_value.clone(),
                    });

                if self.inject_once {
                    return Ok(false);
                }
            }
            Ok(true)
        }
    }

    #[test]
    fn test_inject_label_into_vector_selector() {
        let expr = "pg_stat_activity_count{}";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "namespace".to_string(),
            label_value: "injected".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::VectorSelector(vs) = &ast {
            assert_eq!(vs.matchers.matchers.len(), 1);
            assert_eq!(vs.matchers.matchers[0].name, "namespace");
            assert_eq!(vs.matchers.matchers[0].value, "injected");
            assert_eq!(vs.matchers.matchers[0].op, MatchOp::Equal);
        } else {
            panic!("expected VectorSelector");
        }
    }

    #[test]
    fn test_inject_label_into_nested_expr() {
        let expr = "sum(pg_stat_activity_count{})";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::Aggregate(agg) = &ast {
            if let Expr::VectorSelector(vs) = &*agg.expr {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected VectorSelector inside Aggregate");
            }
        } else {
            panic!("expected Aggregate");
        }
    }

    #[test]
    fn test_inject_label_into_multiple_selectors() {
        let expr = "pg_stat_activity_count{} + pg_stat_activity_count{}";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::Binary(binary) = &ast {
            if let Expr::VectorSelector(lhs_vs) = &*binary.lhs {
                assert_eq!(lhs_vs.matchers.matchers.len(), 1);
                assert_eq!(lhs_vs.matchers.matchers[0].name, "env");
                assert_eq!(lhs_vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected LHS to be a VectorSelector");
            }

            if let Expr::VectorSelector(rhs_vs) = &*binary.rhs {
                assert_eq!(rhs_vs.matchers.matchers.len(), 1);
                assert_eq!(rhs_vs.matchers.matchers[0].name, "env");
                assert_eq!(rhs_vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected RHS to be a VectorSelector");
            }
        } else {
            panic!("expected a Binary expression");
        }
    }

    /// Minimal extension node that only owns a list of child expressions.
    #[derive(Debug)]
    struct DummyExtension {
        children: Vec<Expr>,
    }

    impl ExtensionExpr for DummyExtension {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn name(&self) -> &str {
            "dummy"
        }
        fn value_type(&self) -> ValueType {
            ValueType::Vector
        }
        fn children(&self) -> &[Expr] {
            &self.children
        }
        fn with_new_children(&self, children: Vec<Expr>) -> Arc<dyn ExtensionExpr> {
            Arc::new(DummyExtension { children })
        }
    }

    #[test]
    fn test_inject_label_into_extension() {
        let inner_expr = parser::parse("pg_stat_activity_count{}").unwrap();
        let dummy_ext = DummyExtension {
            children: vec![inner_expr],
        };

        let shared_arc = Arc::new(dummy_ext);
        // Clone the Arc to simulate multiple references to the same extension expression.
        let _second_reference = Arc::clone(&shared_arc);

        let mut ast = Expr::Extension(parser::Extension { expr: shared_arc });

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };
        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        // The extension's children should be traversed and mutated like any other expression.
        if let Expr::Extension(ext) = &ast {
            let children = ext.expr.children();
            assert_eq!(children.len(), 1);
            if let Expr::VectorSelector(vs) = &children[0] {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected inner expression to be a VectorSelector");
            }
        } else {
            panic!("expected Extension expression");
        }
    }

    #[test]
    fn test_extension_partial_mutation_on_short_circuit() {
        let child1 = parser::parse("metric_a{}").unwrap();
        let child2 = parser::parse("metric_b{}").unwrap();

        let dummy_ext = DummyExtension {
            children: vec![child1, child2],
        };

        let mut ast = Expr::Extension(parser::Extension {
            expr: Arc::new(dummy_ext),
        });

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: true,
        };

        // The walker returns Ok(false) because it short-circuits after mutating the first child.
        assert_eq!(walk_expr_mut(&mut visitor, &mut ast), Ok(false));

        if let Expr::Extension(ext) = &ast {
            let children = ext.expr.children();
            assert_eq!(children.len(), 2);

            // The first child should have been mutated.
            if let Expr::VectorSelector(vs) = &children[0] {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected first child to be a VectorSelector");
            }

            // The second child remains untouched.
            if let Expr::VectorSelector(vs) = &children[1] {
                assert!(vs.matchers.matchers.is_empty());
            } else {
                panic!("expected second child to be a VectorSelector");
            }
        } else {
            panic!("expected Extension expression");
        }
    }
}