promql_parser/util/
visitor.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::parser::{
16    AggregateExpr, BinaryExpr, Expr, Extension, ParenExpr, SubqueryExpr, UnaryExpr,
17};
18
19/// Trait that implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern)
20/// for a depth first walk on [Expr] AST. [`pre_visit`](ExprVisitor::pre_visit) is called
21/// before any children are visited, and then [`post_visit`](ExprVisitor::post_visit) is called
22/// after all children have been visited. Only [`pre_visit`](ExprVisitor::pre_visit) is required.
23pub trait ExprVisitor {
24    type Error;
25
26    /// Called before any children are visited. Return `Ok(false)` to cut short the recursion
27    /// (skip traversing and return).
28    fn pre_visit(&mut self, plan: &Expr) -> Result<bool, Self::Error>;
29
30    /// Called after all children are visited. Return `Ok(false)` to cut short the recursion
31    /// (skip traversing and return).
32    fn post_visit(&mut self, _plan: &Expr) -> Result<bool, Self::Error> {
33        Ok(true)
34    }
35}
36
37/// A util function that traverses an AST [Expr] in depth-first order. Returns
38/// `Ok(true)` if all nodes were visited, and `Ok(false)` if any call to
39/// [`pre_visit`](ExprVisitor::pre_visit) or [`post_visit`](ExprVisitor::post_visit)
40/// returned `Ok(false)` and may have cut short the recursion.
41pub fn walk_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> Result<bool, V::Error> {
42    if !visitor.pre_visit(expr)? {
43        return Ok(false);
44    }
45
46    let recurse = match expr {
47        Expr::Aggregate(AggregateExpr { expr, .. }) => walk_expr(visitor, expr)?,
48        Expr::Unary(UnaryExpr { expr }) => walk_expr(visitor, expr)?,
49        Expr::Binary(BinaryExpr { lhs, rhs, .. }) => {
50            walk_expr(visitor, lhs)? && walk_expr(visitor, rhs)?
51        }
52        Expr::Paren(ParenExpr { expr }) => walk_expr(visitor, expr)?,
53        Expr::Subquery(SubqueryExpr { expr, .. }) => walk_expr(visitor, expr)?,
54        Expr::Extension(Extension { expr }) => {
55            for child in expr.children() {
56                if !walk_expr(visitor, child)? {
57                    return Ok(false);
58                }
59            }
60            true
61        }
62        Expr::Call(call) => {
63            for func_argument_expr in &call.args.args {
64                if !walk_expr(visitor, func_argument_expr)? {
65                    return Ok(false);
66                }
67            }
68            true
69        }
70        Expr::NumberLiteral(_)
71        | Expr::StringLiteral(_)
72        | Expr::VectorSelector(_)
73        | Expr::MatrixSelector(_) => true,
74    };
75
76    if !recurse {
77        return Ok(false);
78    }
79
80    if !visitor.post_visit(expr)? {
81        return Ok(false);
82    }
83
84    Ok(true)
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::label::MatchOp;
    use crate::parser;
    use crate::parser::VectorSelector;

    /// Test visitor that only lets the walk continue past selectors carrying an
    /// equality matcher on the `namespace` label with the configured value.
    struct NamespaceVisitor {
        namespace: String,
    }

    /// Returns `true` when `vector_selector` has a matcher named `namespace` whose value
    /// equals `namespace` and whose operator is [`MatchOp::Equal`].
    fn vector_selector_includes_namespace(
        namespace: &str,
        vector_selector: &VectorSelector,
    ) -> bool {
        vector_selector.matchers.matchers.iter().any(|matcher| {
            matcher.name.eq("namespace")
                && matcher.value.eq(namespace)
                && matcher.op == MatchOp::Equal
        })
    }

    impl ExprVisitor for NamespaceVisitor {
        type Error = &'static str;

        /// Stops the walk at number/string literals and at selectors that lack the
        /// expected namespace matcher; every other node lets the walk continue.
        fn pre_visit(&mut self, expr: &Expr) -> Result<bool, Self::Error> {
            let keep_walking = match expr {
                Expr::VectorSelector(selector) => {
                    vector_selector_includes_namespace(&self.namespace, selector)
                }
                Expr::MatrixSelector(matrix) => {
                    vector_selector_includes_namespace(&self.namespace, &matrix.vs)
                }
                Expr::NumberLiteral(_) | Expr::StringLiteral(_) => false,
                _ => true,
            };
            Ok(keep_walking)
        }
    }

    /// Parses `query`, walks it with a [`NamespaceVisitor`] configured for `namespace`,
    /// and returns the value reported by [`walk_expr`].
    fn walk_with_namespace(namespace: &str, query: &str) -> bool {
        let ast = parser::parse(query).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: namespace.to_string(),
        };
        walk_expr(&mut visitor, &ast).unwrap()
    }

    #[test]
    fn test_check_for_namespace_basic_query() {
        assert!(walk_with_namespace(
            "sample",
            "pg_stat_activity_count{namespace=\"sample\"}"
        ));
    }

    #[test]
    fn test_check_for_namespace_label_present() {
        assert!(walk_with_namespace(
            "sample",
            "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))"
        ));
    }

    #[test]
    fn test_check_for_namespace_label_wrong_namespace() {
        assert!(!walk_with_namespace(
            "foobar",
            "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))"
        ));
    }

    #[test]
    fn test_check_for_namespace_label_missing_namespace() {
        assert!(!walk_with_namespace(
            "sample",
            "(sum by (namespace) (max_over_time(pg_stat_activity_count{}[1h])))"
        ));
    }

    #[test]
    fn test_literal_expr() {
        // Literals always stop the walk, whether alone or as operands.
        assert!(!walk_with_namespace("sample", "1"));
        assert!(!walk_with_namespace("sample", "1 + 1"));
        assert!(!walk_with_namespace("sample", r#""1""#));
    }

    #[test]
    fn test_binary_expr() {
        // Only one side carries the namespace matcher.
        assert!(!walk_with_namespace(
            "sample",
            "pg_stat_activity_count{namespace=\"sample\"} + pg_stat_activity_count{}"
        ));
        assert!(!walk_with_namespace(
            "sample",
            "pg_stat_activity_count{} - pg_stat_activity_count{namespace=\"sample\"}"
        ));

        // Neither side carries it.
        assert!(!walk_with_namespace(
            "sample",
            "pg_stat_activity_count{} * pg_stat_activity_count{}"
        ));

        // A literal operand stops the walk even when the selector matches.
        assert!(!walk_with_namespace(
            "sample",
            "pg_stat_activity_count{namespace=\"sample\"} / 1"
        ));
        assert!(!walk_with_namespace(
            "sample",
            "1 % pg_stat_activity_count{namespace=\"sample\"}"
        ));

        // Both sides carry it.
        assert!(walk_with_namespace(
            "sample",
            "pg_stat_activity_count{namespace=\"sample\"} ^ pg_stat_activity_count{namespace=\"sample\"}"
        ));
    }
}
227}