1use crate::parser::{
16 AggregateExpr, BinaryExpr, Expr, Extension, ParenExpr, SubqueryExpr, UnaryExpr,
17};
18
19pub trait ExprVisitor {
24 type Error;
25
26 fn pre_visit(&mut self, plan: &Expr) -> Result<bool, Self::Error>;
29
30 fn post_visit(&mut self, _plan: &Expr) -> Result<bool, Self::Error> {
33 Ok(true)
34 }
35}
36
37pub trait ExprVisitorMut {
42 type Error;
43
44 fn pre_visit(&mut self, plan: &mut Expr) -> Result<bool, Self::Error>;
47
48 fn post_visit(&mut self, _plan: &mut Expr) -> Result<bool, Self::Error> {
51 Ok(true)
52 }
53}
54
55pub fn walk_expr<V: ExprVisitor>(visitor: &mut V, expr: &Expr) -> Result<bool, V::Error> {
60 if !visitor.pre_visit(expr)? {
61 return Ok(false);
62 }
63
64 let recurse = match expr {
65 Expr::Aggregate(AggregateExpr { expr, .. }) => walk_expr(visitor, expr)?,
66 Expr::Unary(UnaryExpr { expr }) => walk_expr(visitor, expr)?,
67 Expr::Binary(BinaryExpr { lhs, rhs, .. }) => {
68 walk_expr(visitor, lhs)? && walk_expr(visitor, rhs)?
69 }
70 Expr::Paren(ParenExpr { expr }) => walk_expr(visitor, expr)?,
71 Expr::Subquery(SubqueryExpr { expr, .. }) => walk_expr(visitor, expr)?,
72 Expr::Extension(Extension { expr }) => {
73 for child in expr.children() {
74 if !walk_expr(visitor, child)? {
75 return Ok(false);
76 }
77 }
78 true
79 }
80 Expr::Call(call) => {
81 for func_argument_expr in &call.args.args {
82 if !walk_expr(visitor, func_argument_expr)? {
83 return Ok(false);
84 }
85 }
86 true
87 }
88 Expr::NumberLiteral(_)
89 | Expr::StringLiteral(_)
90 | Expr::VectorSelector(_)
91 | Expr::MatrixSelector(_) => true,
92 };
93
94 if !recurse {
95 return Ok(false);
96 }
97
98 if !visitor.post_visit(expr)? {
99 return Ok(false);
100 }
101
102 Ok(true)
103}
104
105pub fn walk_expr_mut<V: ExprVisitorMut>(
110 visitor: &mut V,
111 expr: &mut Expr,
112) -> Result<bool, V::Error> {
113 if !visitor.pre_visit(expr)? {
114 return Ok(false);
115 }
116
117 let recurse = match expr {
118 Expr::Aggregate(AggregateExpr { expr, .. }) => walk_expr_mut(visitor, expr)?,
119 Expr::Unary(UnaryExpr { expr }) => walk_expr_mut(visitor, expr)?,
120 Expr::Binary(BinaryExpr { lhs, rhs, .. }) => {
121 walk_expr_mut(visitor, lhs)? && walk_expr_mut(visitor, rhs)?
122 }
123 Expr::Paren(ParenExpr { expr }) => walk_expr_mut(visitor, expr)?,
124 Expr::Subquery(SubqueryExpr { expr, .. }) => walk_expr_mut(visitor, expr)?,
125 Expr::Extension(Extension { expr }) => {
126 let mut children = expr.children().to_vec();
127 let mut recurse = true;
128 for child in &mut children {
129 if !walk_expr_mut(visitor, child)? {
130 recurse = false;
131 break;
132 }
133 }
134 *expr = expr.with_new_children(children);
135 recurse
136 }
137 Expr::Call(call) => {
138 for func_argument_expr in &mut call.args.args {
139 if !walk_expr_mut(visitor, func_argument_expr)? {
140 return Ok(false);
141 }
142 }
143 true
144 }
145 Expr::NumberLiteral(_)
146 | Expr::StringLiteral(_)
147 | Expr::VectorSelector(_)
148 | Expr::MatrixSelector(_) => true,
149 };
150
151 if !recurse {
152 return Ok(false);
153 }
154
155 if !visitor.post_visit(expr)? {
156 return Ok(false);
157 }
158
159 Ok(true)
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::label::MatchOp;
    use crate::parser;
    use crate::parser::ast::ExtensionExpr;
    use crate::parser::value::ValueType;
    use crate::parser::VectorSelector;
    use std::sync::Arc;

    /// Read-only visitor that stops the walk (returning `false`) as soon as it
    /// meets a selector without an exact `namespace="<namespace>"` matcher, or
    /// any number/string literal.
    struct NamespaceVisitor {
        namespace: String,
    }

    /// Returns `true` if `vector_selector` has a `namespace` matcher using
    /// `=` with exactly the given value.
    fn vector_selector_includes_namespace(
        namespace: &str,
        vector_selector: &VectorSelector,
    ) -> bool {
        vector_selector.matchers.matchers.iter().any(|matcher| {
            matcher.name == "namespace"
                && matcher.value == namespace
                && matcher.op == MatchOp::Equal
        })
    }

    impl ExprVisitor for NamespaceVisitor {
        type Error = &'static str;

        fn pre_visit(&mut self, expr: &Expr) -> Result<bool, Self::Error> {
            let keep_walking = match expr {
                // Selectors pass only when they pin the expected namespace.
                Expr::VectorSelector(vector_selector) => {
                    vector_selector_includes_namespace(&self.namespace, vector_selector)
                }
                Expr::MatrixSelector(matrix_selector) => {
                    vector_selector_includes_namespace(&self.namespace, &matrix_selector.vs)
                }
                // Literals are rejected outright by this test policy.
                Expr::NumberLiteral(_) | Expr::StringLiteral(_) => false,
                // Any other node is structural: keep descending.
                _ => true,
            };
            Ok(keep_walking)
        }
    }

    #[test]
    fn test_check_for_namespace_basic_query() {
        let expr = "pg_stat_activity_count{namespace=\"sample\"}";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_present() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_wrong_namespace() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{namespace=\"sample\"}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "foobar".to_string(),
        };
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_check_for_namespace_label_missing_namespace() {
        let expr = "(sum by (namespace) (max_over_time(pg_stat_activity_count{}[1h])))";
        let ast = parser::parse(expr).unwrap();
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_literal_expr() {
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };

        let ast = parser::parse("1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("1 + 1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(r#""1""#).unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());
    }

    #[test]
    fn test_binary_expr() {
        let mut visitor = NamespaceVisitor {
            namespace: "sample".to_string(),
        };

        let ast = parser::parse(
            "pg_stat_activity_count{namespace=\"sample\"} + pg_stat_activity_count{}",
        )
        .unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(
            "pg_stat_activity_count{} - pg_stat_activity_count{namespace=\"sample\"}",
        )
        .unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("pg_stat_activity_count{} * pg_stat_activity_count{}").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("pg_stat_activity_count{namespace=\"sample\"} / 1").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse("1 % pg_stat_activity_count{namespace=\"sample\"}").unwrap();
        assert!(!walk_expr(&mut visitor, &ast).unwrap());

        let ast = parser::parse(
            "pg_stat_activity_count{namespace=\"sample\"} ^ \
            pg_stat_activity_count{namespace=\"sample\"}",
        )
        .unwrap();
        assert!(walk_expr(&mut visitor, &ast).unwrap());
    }

    /// Mutating visitor that appends `label_name = label_value` to every
    /// vector selector it meets. With `inject_once`, it stops the walk right
    /// after the first injection.
    struct LabelInjectorVisitor {
        label_name: String,
        label_value: String,
        inject_once: bool,
    }

    impl ExprVisitorMut for LabelInjectorVisitor {
        type Error = &'static str;

        fn pre_visit(&mut self, expr: &mut Expr) -> Result<bool, Self::Error> {
            if let Expr::VectorSelector(vector_selector) = expr {
                vector_selector
                    .matchers
                    .matchers
                    .push(crate::label::Matcher {
                        op: MatchOp::Equal,
                        name: self.label_name.clone(),
                        value: self.label_value.clone(),
                    });
                // Returning `false` here stops the whole walk.
                return Ok(!self.inject_once);
            }
            Ok(true)
        }
    }

    #[test]
    fn test_inject_label_into_vector_selector() {
        let expr = "pg_stat_activity_count{}";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "namespace".to_string(),
            label_value: "injected".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::VectorSelector(vs) = &ast {
            assert_eq!(vs.matchers.matchers.len(), 1);
            assert_eq!(vs.matchers.matchers[0].name, "namespace");
            assert_eq!(vs.matchers.matchers[0].value, "injected");
            assert_eq!(vs.matchers.matchers[0].op, MatchOp::Equal);
        } else {
            panic!("expected VectorSelector");
        }
    }

    #[test]
    fn test_inject_label_into_nested_expr() {
        let expr = "sum(pg_stat_activity_count{})";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::Aggregate(agg) = &ast {
            if let Expr::VectorSelector(vs) = &*agg.expr {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected VectorSelector inside Aggregate");
            }
        } else {
            panic!("expected Aggregate");
        }
    }

    #[test]
    fn test_inject_label_into_multiple_selectors() {
        let expr = "pg_stat_activity_count{} + pg_stat_activity_count{}";
        let mut ast = parser::parse(expr).unwrap();

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };

        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::Binary(binary) = &ast {
            if let Expr::VectorSelector(lhs_vs) = &*binary.lhs {
                assert_eq!(lhs_vs.matchers.matchers.len(), 1);
                assert_eq!(lhs_vs.matchers.matchers[0].name, "env");
                assert_eq!(lhs_vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected LHS to be a VectorSelector");
            }

            if let Expr::VectorSelector(rhs_vs) = &*binary.rhs {
                assert_eq!(rhs_vs.matchers.matchers.len(), 1);
                assert_eq!(rhs_vs.matchers.matchers[0].name, "env");
                assert_eq!(rhs_vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected RHS to be a VectorSelector");
            }
        } else {
            panic!("expected a Binary expression");
        }
    }

    /// Minimal extension node that just holds its children.
    #[derive(Debug)]
    struct DummyExtension {
        children: Vec<Expr>,
    }

    impl ExtensionExpr for DummyExtension {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn name(&self) -> &str {
            "dummy"
        }
        fn value_type(&self) -> ValueType {
            ValueType::Vector
        }
        fn children(&self) -> &[Expr] {
            &self.children
        }
        fn with_new_children(&self, children: Vec<Expr>) -> Arc<dyn ExtensionExpr> {
            Arc::new(DummyExtension { children })
        }
    }

    #[test]
    fn test_inject_label_into_extension() {
        let inner_expr = parser::parse("pg_stat_activity_count{}").unwrap();
        let dummy_ext = DummyExtension {
            children: vec![inner_expr],
        };

        let shared_arc = Arc::new(dummy_ext);
        // Keep a second handle on the extension so it is genuinely shared and
        // we can check below that the walk leaves it untouched.
        let second_reference = Arc::clone(&shared_arc);

        let mut ast = Expr::Extension(parser::Extension { expr: shared_arc });

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: false,
        };
        assert!(walk_expr_mut(&mut visitor, &mut ast).unwrap());

        if let Expr::Extension(ext) = &ast {
            let children = ext.expr.children();
            assert_eq!(children.len(), 1);
            if let Expr::VectorSelector(vs) = &children[0] {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected inner expression to be a VectorSelector");
            }
        } else {
            panic!("expected Extension expression");
        }

        // The extension reachable through the other handle still holds the
        // original, un-injected child.
        if let Expr::VectorSelector(vs) = &second_reference.children[0] {
            assert!(vs.matchers.matchers.is_empty());
        } else {
            panic!("expected shared child to be a VectorSelector");
        }
    }

    #[test]
    fn test_extension_partial_mutation_on_short_circuit() {
        let child1 = parser::parse("metric_a{}").unwrap();
        let child2 = parser::parse("metric_b{}").unwrap();

        let dummy_ext = DummyExtension {
            children: vec![child1, child2],
        };

        let mut ast = Expr::Extension(parser::Extension {
            expr: Arc::new(dummy_ext),
        });

        let mut visitor = LabelInjectorVisitor {
            label_name: "env".to_string(),
            label_value: "prod".to_string(),
            inject_once: true,
        };

        assert_eq!(walk_expr_mut(&mut visitor, &mut ast), Ok(false));

        if let Expr::Extension(ext) = &ast {
            let children = ext.expr.children();
            assert_eq!(children.len(), 2);

            // The first child was visited before the walk stopped.
            if let Expr::VectorSelector(vs) = &children[0] {
                assert_eq!(vs.matchers.matchers.len(), 1);
                assert_eq!(vs.matchers.matchers[0].name, "env");
                assert_eq!(vs.matchers.matchers[0].value, "prod");
            } else {
                panic!("expected first child to be a VectorSelector");
            }

            // The second child was never reached.
            if let Expr::VectorSelector(vs) = &children[1] {
                assert!(vs.matchers.matchers.is_empty());
            } else {
                panic!("expected second child to be a VectorSelector");
            }
        } else {
            panic!("expected Extension expression");
        }
    }
}