1use crate::ast::{CompareOp, FieldRef, Filter as AstFilter, JoinQuery, Projection, QueryExpr};
14use crate::sql_lowering::{
15 effective_graph_filter, effective_join_filter, effective_table_filter, effective_vector_filter,
16};
17use reddb_types::Value;
18
19#[derive(Debug, Clone, Default)]
21pub struct RewriteContext {
22 pub property_cache: Vec<CachedProperty>,
24 pub errors: Vec<String>,
26 pub warnings: Vec<String>,
28 pub stats: RewriteStats,
30}
31
/// A single remembered property lookup, stored in `RewriteContext::property_cache`.
#[derive(Clone, Debug)]
pub struct CachedProperty {
    /// Identifier of the entity the property belongs to.
    /// NOTE(review): whether this is an alias, collection or node variable is
    /// not visible in this module; confirm with whoever fills the cache.
    pub source: String,
    /// Name of the cached property.
    pub property: String,
    /// Cached value rendered as text, or `None` when nothing was cached.
    pub cached_value: Option<String>,
}
42
/// Counters describing the work performed by the rewrite rules.
///
/// All counters start at zero and accumulate for as long as the owning
/// `RewriteContext` is reused, including across multiple rewrite passes.
#[derive(Clone, Debug, Default)]
pub struct RewriteStats {
    /// Number of boolean filter simplifications applied.
    pub filters_simplified: u32,
    /// Bumped by the predicate-pushdown pass.
    pub predicates_pushed: u32,
    /// Not updated by the built-in rules in this module.
    pub properties_cached: u32,
    /// Bumped by the normalization pass.
    pub expressions_normalized: u32,
}
55
56pub trait RewriteRule: Send + Sync {
58 fn name(&self) -> &str;
60
61 fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr;
63
64 fn is_applicable(&self, query: &QueryExpr) -> bool;
66}
67
68pub struct QueryRewriter {
70 rules: Vec<Box<dyn RewriteRule>>,
72 max_iterations: usize,
74}
75
76impl QueryRewriter {
77 pub fn new() -> Self {
79 let rules: Vec<Box<dyn RewriteRule>> = vec![
80 Box::new(NormalizeRule),
81 Box::new(SimplifyFiltersRule),
82 Box::new(PushdownPredicatesRule),
83 Box::new(EliminateDeadCodeRule),
84 Box::new(FoldConstantsRule),
85 ];
86
87 Self {
88 rules,
89 max_iterations: 10,
90 }
91 }
92
93 pub fn add_rule(&mut self, rule: Box<dyn RewriteRule>) {
95 self.rules.push(rule);
96 }
97
98 pub fn rewrite(&self, query: QueryExpr) -> QueryExpr {
100 let mut ctx = RewriteContext::default();
101 self.rewrite_with_context(query, &mut ctx)
102 }
103
104 pub fn rewrite_with_context(
106 &self,
107 mut query: QueryExpr,
108 ctx: &mut RewriteContext,
109 ) -> QueryExpr {
110 for _iteration in 0..self.max_iterations {
112 let original = format!("{:?}", query);
113
114 for rule in &self.rules {
115 if rule.is_applicable(&query) {
116 query = rule.apply(query, ctx);
117 }
118 }
119
120 if format!("{:?}", query) == original {
122 break;
123 }
124 }
125
126 query
127 }
128}
129
130impl Default for QueryRewriter {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136struct NormalizeRule;
142
143impl RewriteRule for NormalizeRule {
144 fn name(&self) -> &str {
145 "Normalize"
146 }
147
148 fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
149 match query {
150 QueryExpr::Table(mut tq) => {
151 tq.columns.sort_by(|a, b| {
153 let a_name = projection_name(a);
154 let b_name = projection_name(b);
155 a_name.cmp(&b_name)
156 });
157 ctx.stats.expressions_normalized += 1;
158 QueryExpr::Table(tq)
159 }
160 QueryExpr::Graph(gq) => {
161 QueryExpr::Graph(gq)
163 }
164 QueryExpr::Join(jq) => {
165 let left = self.apply(*jq.left, ctx);
167 let right = self.apply(*jq.right, ctx);
168 QueryExpr::Join(JoinQuery {
169 left: Box::new(left),
170 right: Box::new(right),
171 ..jq
172 })
173 }
174 QueryExpr::Path(pq) => QueryExpr::Path(pq),
175 QueryExpr::Vector(vq) => {
176 QueryExpr::Vector(vq)
178 }
179 QueryExpr::Hybrid(mut hq) => {
180 hq.structured = Box::new(self.apply(*hq.structured, ctx));
182 QueryExpr::Hybrid(hq)
183 }
184 other @ (QueryExpr::Insert(_)
186 | QueryExpr::Update(_)
187 | QueryExpr::Delete(_)
188 | QueryExpr::CreateTable(_)
189 | QueryExpr::CreateCollection(_)
190 | QueryExpr::CreateVector(_)
191 | QueryExpr::DropTable(_)
192 | QueryExpr::DropGraph(_)
193 | QueryExpr::DropVector(_)
194 | QueryExpr::DropDocument(_)
195 | QueryExpr::DropKv(_)
196 | QueryExpr::DropCollection(_)
197 | QueryExpr::Truncate(_)
198 | QueryExpr::AlterTable(_)
199 | QueryExpr::GraphCommand(_)
200 | QueryExpr::SearchCommand(_)
201 | QueryExpr::CreateIndex(_)
202 | QueryExpr::DropIndex(_)
203 | QueryExpr::ProbabilisticCommand(_)
204 | QueryExpr::Ask(_)
205 | QueryExpr::SetConfig { .. }
206 | QueryExpr::ShowConfig { .. }
207 | QueryExpr::SetSecret { .. }
208 | QueryExpr::DeleteSecret { .. }
209 | QueryExpr::ShowSecrets { .. }
210 | QueryExpr::SetTenant(_)
211 | QueryExpr::ShowTenant
212 | QueryExpr::CreateTimeSeries(_)
213 | QueryExpr::CreateMetric(_)
214 | QueryExpr::AlterMetric(_)
215 | QueryExpr::CreateSlo(_)
216 | QueryExpr::DropTimeSeries(_)
217 | QueryExpr::CreateQueue(_)
218 | QueryExpr::AlterQueue(_)
219 | QueryExpr::DropQueue(_)
220 | QueryExpr::QueueSelect(_)
221 | QueryExpr::QueueCommand(_)
222 | QueryExpr::KvCommand(_)
223 | QueryExpr::ConfigCommand(_)
224 | QueryExpr::CreateTree(_)
225 | QueryExpr::DropTree(_)
226 | QueryExpr::TreeCommand(_)
227 | QueryExpr::ExplainAlter(_)
228 | QueryExpr::TransactionControl(_)
229 | QueryExpr::MaintenanceCommand(_)
230 | QueryExpr::CreateSchema(_)
231 | QueryExpr::DropSchema(_)
232 | QueryExpr::CreateSequence(_)
233 | QueryExpr::DropSequence(_)
234 | QueryExpr::CopyFrom(_)
235 | QueryExpr::CreateView(_)
236 | QueryExpr::DropView(_)
237 | QueryExpr::RefreshMaterializedView(_)
238 | QueryExpr::CreatePolicy(_)
239 | QueryExpr::DropPolicy(_)
240 | QueryExpr::CreateServer(_)
241 | QueryExpr::DropServer(_)
242 | QueryExpr::CreateForeignTable(_)
243 | QueryExpr::DropForeignTable(_)
244 | QueryExpr::Grant(_)
245 | QueryExpr::Revoke(_)
246 | QueryExpr::AlterUser(_)
247 | QueryExpr::CreateUser(_)
248 | QueryExpr::CreateIamPolicy { .. }
249 | QueryExpr::DropIamPolicy { .. }
250 | QueryExpr::AttachPolicy { .. }
251 | QueryExpr::DetachPolicy { .. }
252 | QueryExpr::ShowPolicies { .. }
253 | QueryExpr::ShowEffectivePermissions { .. }
254 | QueryExpr::RankOf(_)
255 | QueryExpr::ApproxRankOf(_)
256 | QueryExpr::RankRange(_)
257 | QueryExpr::SimulatePolicy { .. }
258 | QueryExpr::LintPolicy { .. }
259 | QueryExpr::MigratePolicyMode { .. }
260 | QueryExpr::CreateMigration(_)
261 | QueryExpr::ApplyMigration(_)
262 | QueryExpr::RollbackMigration(_)
263 | QueryExpr::ExplainMigration(_)
264 | QueryExpr::EventsBackfill(_)
265 | QueryExpr::EventsBackfillStatus { .. }) => other,
266 }
267 }
268
269 fn is_applicable(&self, _query: &QueryExpr) -> bool {
270 true
271 }
272}
273
274struct SimplifyFiltersRule;
276
277impl RewriteRule for SimplifyFiltersRule {
278 fn name(&self) -> &str {
279 "SimplifyFilters"
280 }
281
282 fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
283 match query {
284 QueryExpr::Table(mut tq) => {
285 if let Some(filter) = effective_table_filter(&tq) {
286 tq.filter = Some(simplify_filter(filter, ctx));
287 }
288 QueryExpr::Table(tq)
289 }
290 QueryExpr::Graph(mut gq) => {
291 if let Some(filter) = effective_graph_filter(&gq) {
292 gq.filter = Some(simplify_filter(filter, ctx));
293 }
294 QueryExpr::Graph(gq)
295 }
296 QueryExpr::Join(mut jq) => {
297 let join_filter = effective_join_filter(&jq);
298 let left = self.apply(*jq.left, ctx);
299 let right = self.apply(*jq.right, ctx);
300 if let Some(filter) = join_filter {
301 jq.filter = Some(simplify_filter(filter, ctx));
302 }
303 jq.left = Box::new(left);
304 jq.right = Box::new(right);
305 QueryExpr::Join(jq)
306 }
307 QueryExpr::Path(pq) => QueryExpr::Path(pq),
308 QueryExpr::Vector(vq) => {
309 QueryExpr::Vector(vq)
312 }
313 QueryExpr::Hybrid(mut hq) => {
314 hq.structured = Box::new(self.apply(*hq.structured, ctx));
316 QueryExpr::Hybrid(hq)
317 }
318 other @ (QueryExpr::Insert(_)
320 | QueryExpr::Update(_)
321 | QueryExpr::Delete(_)
322 | QueryExpr::CreateTable(_)
323 | QueryExpr::CreateCollection(_)
324 | QueryExpr::CreateVector(_)
325 | QueryExpr::DropTable(_)
326 | QueryExpr::DropGraph(_)
327 | QueryExpr::DropVector(_)
328 | QueryExpr::DropDocument(_)
329 | QueryExpr::DropKv(_)
330 | QueryExpr::DropCollection(_)
331 | QueryExpr::Truncate(_)
332 | QueryExpr::AlterTable(_)
333 | QueryExpr::GraphCommand(_)
334 | QueryExpr::SearchCommand(_)
335 | QueryExpr::CreateIndex(_)
336 | QueryExpr::DropIndex(_)
337 | QueryExpr::ProbabilisticCommand(_)
338 | QueryExpr::Ask(_)
339 | QueryExpr::SetConfig { .. }
340 | QueryExpr::ShowConfig { .. }
341 | QueryExpr::SetSecret { .. }
342 | QueryExpr::DeleteSecret { .. }
343 | QueryExpr::ShowSecrets { .. }
344 | QueryExpr::SetTenant(_)
345 | QueryExpr::ShowTenant
346 | QueryExpr::CreateTimeSeries(_)
347 | QueryExpr::CreateMetric(_)
348 | QueryExpr::AlterMetric(_)
349 | QueryExpr::CreateSlo(_)
350 | QueryExpr::DropTimeSeries(_)
351 | QueryExpr::CreateQueue(_)
352 | QueryExpr::AlterQueue(_)
353 | QueryExpr::DropQueue(_)
354 | QueryExpr::QueueSelect(_)
355 | QueryExpr::QueueCommand(_)
356 | QueryExpr::KvCommand(_)
357 | QueryExpr::ConfigCommand(_)
358 | QueryExpr::CreateTree(_)
359 | QueryExpr::DropTree(_)
360 | QueryExpr::TreeCommand(_)
361 | QueryExpr::ExplainAlter(_)
362 | QueryExpr::TransactionControl(_)
363 | QueryExpr::MaintenanceCommand(_)
364 | QueryExpr::CreateSchema(_)
365 | QueryExpr::DropSchema(_)
366 | QueryExpr::CreateSequence(_)
367 | QueryExpr::DropSequence(_)
368 | QueryExpr::CopyFrom(_)
369 | QueryExpr::CreateView(_)
370 | QueryExpr::DropView(_)
371 | QueryExpr::RefreshMaterializedView(_)
372 | QueryExpr::CreatePolicy(_)
373 | QueryExpr::DropPolicy(_)
374 | QueryExpr::CreateServer(_)
375 | QueryExpr::DropServer(_)
376 | QueryExpr::CreateForeignTable(_)
377 | QueryExpr::DropForeignTable(_)
378 | QueryExpr::Grant(_)
379 | QueryExpr::Revoke(_)
380 | QueryExpr::AlterUser(_)
381 | QueryExpr::CreateUser(_)
382 | QueryExpr::CreateIamPolicy { .. }
383 | QueryExpr::DropIamPolicy { .. }
384 | QueryExpr::AttachPolicy { .. }
385 | QueryExpr::DetachPolicy { .. }
386 | QueryExpr::ShowPolicies { .. }
387 | QueryExpr::ShowEffectivePermissions { .. }
388 | QueryExpr::RankOf(_)
389 | QueryExpr::ApproxRankOf(_)
390 | QueryExpr::RankRange(_)
391 | QueryExpr::SimulatePolicy { .. }
392 | QueryExpr::LintPolicy { .. }
393 | QueryExpr::MigratePolicyMode { .. }
394 | QueryExpr::CreateMigration(_)
395 | QueryExpr::ApplyMigration(_)
396 | QueryExpr::RollbackMigration(_)
397 | QueryExpr::ExplainMigration(_)
398 | QueryExpr::EventsBackfill(_)
399 | QueryExpr::EventsBackfillStatus { .. }) => other,
400 }
401 }
402
403 fn is_applicable(&self, query: &QueryExpr) -> bool {
404 match query {
405 QueryExpr::Table(tq) => effective_table_filter(tq).is_some(),
406 QueryExpr::Graph(gq) => effective_graph_filter(gq).is_some(),
407 QueryExpr::Join(_) => true,
408 QueryExpr::Path(_) => false,
409 QueryExpr::Vector(vq) => effective_vector_filter(vq).is_some(),
410 QueryExpr::Hybrid(_) => true, QueryExpr::Insert(_)
413 | QueryExpr::Update(_)
414 | QueryExpr::Delete(_)
415 | QueryExpr::CreateTable(_)
416 | QueryExpr::CreateCollection(_)
417 | QueryExpr::CreateVector(_)
418 | QueryExpr::DropTable(_)
419 | QueryExpr::DropGraph(_)
420 | QueryExpr::DropVector(_)
421 | QueryExpr::DropDocument(_)
422 | QueryExpr::DropKv(_)
423 | QueryExpr::DropCollection(_)
424 | QueryExpr::Truncate(_)
425 | QueryExpr::AlterTable(_)
426 | QueryExpr::GraphCommand(_)
427 | QueryExpr::SearchCommand(_)
428 | QueryExpr::CreateIndex(_)
429 | QueryExpr::DropIndex(_)
430 | QueryExpr::ProbabilisticCommand(_)
431 | QueryExpr::Ask(_)
432 | QueryExpr::SetConfig { .. }
433 | QueryExpr::ShowConfig { .. }
434 | QueryExpr::SetSecret { .. }
435 | QueryExpr::DeleteSecret { .. }
436 | QueryExpr::ShowSecrets { .. }
437 | QueryExpr::SetTenant(_)
438 | QueryExpr::ShowTenant
439 | QueryExpr::CreateTimeSeries(_)
440 | QueryExpr::CreateMetric(_)
441 | QueryExpr::AlterMetric(_)
442 | QueryExpr::CreateSlo(_)
443 | QueryExpr::DropTimeSeries(_)
444 | QueryExpr::CreateQueue(_)
445 | QueryExpr::AlterQueue(_)
446 | QueryExpr::DropQueue(_)
447 | QueryExpr::QueueSelect(_)
448 | QueryExpr::QueueCommand(_)
449 | QueryExpr::KvCommand(_)
450 | QueryExpr::ConfigCommand(_)
451 | QueryExpr::CreateTree(_)
452 | QueryExpr::DropTree(_)
453 | QueryExpr::TreeCommand(_)
454 | QueryExpr::ExplainAlter(_)
455 | QueryExpr::TransactionControl(_)
456 | QueryExpr::MaintenanceCommand(_)
457 | QueryExpr::CreateSchema(_)
458 | QueryExpr::DropSchema(_)
459 | QueryExpr::CreateSequence(_)
460 | QueryExpr::DropSequence(_)
461 | QueryExpr::CopyFrom(_)
462 | QueryExpr::CreateView(_)
463 | QueryExpr::DropView(_)
464 | QueryExpr::RefreshMaterializedView(_)
465 | QueryExpr::CreatePolicy(_)
466 | QueryExpr::DropPolicy(_)
467 | QueryExpr::CreateServer(_)
468 | QueryExpr::DropServer(_)
469 | QueryExpr::CreateForeignTable(_)
470 | QueryExpr::DropForeignTable(_)
471 | QueryExpr::Grant(_)
472 | QueryExpr::Revoke(_)
473 | QueryExpr::AlterUser(_)
474 | QueryExpr::CreateUser(_)
475 | QueryExpr::CreateIamPolicy { .. }
476 | QueryExpr::DropIamPolicy { .. }
477 | QueryExpr::AttachPolicy { .. }
478 | QueryExpr::DetachPolicy { .. }
479 | QueryExpr::ShowPolicies { .. }
480 | QueryExpr::ShowEffectivePermissions { .. }
481 | QueryExpr::RankOf(_)
482 | QueryExpr::ApproxRankOf(_)
483 | QueryExpr::RankRange(_)
484 | QueryExpr::SimulatePolicy { .. }
485 | QueryExpr::LintPolicy { .. }
486 | QueryExpr::MigratePolicyMode { .. }
487 | QueryExpr::CreateMigration(_)
488 | QueryExpr::ApplyMigration(_)
489 | QueryExpr::RollbackMigration(_)
490 | QueryExpr::ExplainMigration(_)
491 | QueryExpr::EventsBackfill(_)
492 | QueryExpr::EventsBackfillStatus { .. } => false,
493 }
494 }
495}
496
497struct PushdownPredicatesRule;
499
500impl RewriteRule for PushdownPredicatesRule {
501 fn name(&self) -> &str {
502 "PushdownPredicates"
503 }
504
505 fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
506 match query {
507 QueryExpr::Join(mut jq) => {
508 jq.left = Box::new(self.apply(*jq.left, ctx));
514 jq.right = Box::new(self.apply(*jq.right, ctx));
515 ctx.stats.predicates_pushed += 1;
516 QueryExpr::Join(jq)
517 }
518 other => other,
519 }
520 }
521
522 fn is_applicable(&self, query: &QueryExpr) -> bool {
523 matches!(query, QueryExpr::Join(_))
524 }
525}
526
527struct EliminateDeadCodeRule;
529
530impl RewriteRule for EliminateDeadCodeRule {
531 fn name(&self) -> &str {
532 "EliminateDeadCode"
533 }
534
535 fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
536 match query {
537 QueryExpr::Table(mut tq) => {
538 if let Some(filter) = effective_table_filter(&tq).as_ref() {
540 if is_always_true(filter) {
541 tq.filter = None;
542 }
543 }
544 QueryExpr::Table(tq)
545 }
546 other => other,
547 }
548 }
549
550 fn is_applicable(&self, query: &QueryExpr) -> bool {
551 matches!(query, QueryExpr::Table(_))
552 }
553}
554
555struct FoldConstantsRule;
557
558impl RewriteRule for FoldConstantsRule {
559 fn name(&self) -> &str {
560 "FoldConstants"
561 }
562
563 fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
564 query
567 }
568
569 fn is_applicable(&self, _query: &QueryExpr) -> bool {
570 true
571 }
572}
573
574fn projection_name(proj: &Projection) -> String {
579 match proj {
580 Projection::All => "*".to_string(),
581 Projection::Column(name) => name.clone(),
582 Projection::Alias(_, alias) => alias.clone(),
583 Projection::Function(name, _) => name
584 .split_once(':')
585 .map(|(_, alias)| alias.to_string())
586 .unwrap_or_else(|| name.clone()),
587 Projection::Expression(expr, alias) => {
588 alias.clone().unwrap_or_else(|| format!("{:?}", expr))
589 }
590 Projection::Field(field, alias) => alias.clone().unwrap_or_else(|| format!("{:?}", field)),
591 Projection::Window { name, alias, .. } => alias.clone().unwrap_or_else(|| name.clone()),
592 }
593}
594
595fn simplify_filter(filter: AstFilter, ctx: &mut RewriteContext) -> AstFilter {
596 match filter {
597 AstFilter::And(left, right) => {
598 let left = simplify_filter(*left, ctx);
599 let right = simplify_filter(*right, ctx);
600
601 if is_always_true(&left) {
603 ctx.stats.filters_simplified += 1;
604 return right;
605 }
606 if is_always_true(&right) {
607 ctx.stats.filters_simplified += 1;
608 return left;
609 }
610
611 if is_always_false(&left) || is_always_false(&right) {
613 ctx.stats.filters_simplified += 1;
614 return AstFilter::Compare {
615 field: FieldRef::TableColumn {
616 table: String::new(),
617 column: "1".to_string(),
618 },
619 op: CompareOp::Eq,
620 value: Value::Integer(0),
621 };
622 }
623
624 AstFilter::And(Box::new(left), Box::new(right))
625 }
626 AstFilter::Or(left, right) => {
627 let left = simplify_filter(*left, ctx);
628 let right = simplify_filter(*right, ctx);
629
630 if is_always_false(&left) {
632 ctx.stats.filters_simplified += 1;
633 return right;
634 }
635 if is_always_false(&right) {
636 ctx.stats.filters_simplified += 1;
637 return left;
638 }
639
640 if is_always_true(&left) || is_always_true(&right) {
642 ctx.stats.filters_simplified += 1;
643 return AstFilter::Compare {
644 field: FieldRef::TableColumn {
645 table: String::new(),
646 column: "1".to_string(),
647 },
648 op: CompareOp::Eq,
649 value: Value::Integer(1),
650 };
651 }
652
653 AstFilter::Or(Box::new(left), Box::new(right))
654 }
655 AstFilter::Not(inner) => {
656 let inner = simplify_filter(*inner, ctx);
657
658 if let AstFilter::Not(double_inner) = inner {
660 ctx.stats.filters_simplified += 1;
661 return *double_inner;
662 }
663
664 AstFilter::Not(Box::new(inner))
665 }
666 other => other,
667 }
668}
669
670fn is_always_true(filter: &AstFilter) -> bool {
671 match filter {
672 AstFilter::Compare { field, op, value } => {
673 matches!(field, FieldRef::TableColumn { column, .. } if column == "1")
675 && matches!(op, CompareOp::Eq)
676 && matches!(value, Value::Integer(1))
677 }
678 _ => false,
679 }
680}
681
682fn is_always_false(filter: &AstFilter) -> bool {
683 match filter {
684 AstFilter::Compare { field, op, value } => {
685 matches!(field, FieldRef::TableColumn { column, .. } if column == "1")
687 && matches!(op, CompareOp::Eq)
688 && matches!(value, Value::Integer(0))
689 }
690 _ => false,
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{JoinCondition, TableQuery, WindowSpec};

    /// Builds an unqualified column reference (empty `table`). This is the same
    /// shape the `1 = 1` / `1 = 0` sentinel filters use.
    fn make_field(name: &str) -> FieldRef {
        FieldRef::TableColumn {
            table: String::new(),
            column: name.to_string(),
        }
    }

    // `1 = 1 AND x = 5` collapses to just `x = 5`.
    #[test]
    fn test_simplify_and_with_true() {
        let mut ctx = RewriteContext::default();

        let filter = AstFilter::And(
            Box::new(AstFilter::Compare {
                field: make_field("1"),
                op: CompareOp::Eq,
                value: Value::Integer(1),
            }),
            Box::new(AstFilter::Compare {
                field: make_field("x"),
                op: CompareOp::Eq,
                value: Value::Integer(5),
            }),
        );

        let simplified = simplify_filter(filter, &mut ctx);

        match simplified {
            AstFilter::Compare { field, .. } => {
                assert!(matches!(field, FieldRef::TableColumn { column, .. } if column == "x"));
            }
            _ => panic!("Expected Compare filter"),
        }
    }

    // `NOT NOT (x = 5)` collapses to `x = 5`.
    #[test]
    fn test_simplify_double_not() {
        let mut ctx = RewriteContext::default();

        let filter = AstFilter::Not(Box::new(AstFilter::Not(Box::new(AstFilter::Compare {
            field: make_field("x"),
            op: CompareOp::Eq,
            value: Value::Integer(5),
        }))));

        let simplified = simplify_filter(filter, &mut ctx);

        match simplified {
            AstFilter::Compare { field, .. } => {
                assert!(matches!(field, FieldRef::TableColumn { column, .. } if column == "x"));
            }
            _ => panic!("Expected Compare filter"),
        }
    }

    // Every `Projection` variant reports its visible (aliased) output name.
    #[test]
    fn projection_name_uses_visible_output_name_for_all_projection_shapes() {
        assert_eq!(projection_name(&Projection::All), "*");
        assert_eq!(
            projection_name(&Projection::Column("raw".to_string())),
            "raw"
        );
        assert_eq!(
            projection_name(&Projection::Alias("raw".to_string(), "alias".to_string())),
            "alias"
        );
        // Function aliases are encoded after the `:` in the function name.
        assert_eq!(
            projection_name(&Projection::Function(
                "LOWER:display".to_string(),
                Vec::new()
            )),
            "display"
        );
        assert_eq!(
            projection_name(&Projection::Expression(
                Box::new(AstFilter::Compare {
                    field: make_field("x"),
                    op: CompareOp::Eq,
                    value: Value::Integer(1),
                }),
                Some("expr_alias".to_string()),
            )),
            "expr_alias"
        );
        assert_eq!(
            projection_name(&Projection::Field(
                FieldRef::node_prop("n", "name"),
                Some("node_name".to_string()),
            )),
            "node_name"
        );
        assert_eq!(
            projection_name(&Projection::Window {
                name: "ROW_NUMBER".to_string(),
                args: Vec::new(),
                window: Box::new(WindowSpec::default()),
                alias: Some("rn".to_string()),
            }),
            "rn"
        );
    }

    // Unsorted columns come back ordered by output name, and the normalization
    // counter is bumped exactly once for a single application.
    #[test]
    fn normalize_rule_sorts_table_columns_by_output_name() {
        let mut table = TableQuery::new("users");
        table.columns = vec![
            Projection::Column("z".to_string()),
            Projection::Function("LOWER:a_alias".to_string(), Vec::new()),
            Projection::Alias("name".to_string(), "m".to_string()),
        ];

        let mut ctx = RewriteContext::default();
        let normalized = NormalizeRule.apply(QueryExpr::Table(table), &mut ctx);
        let QueryExpr::Table(table) = normalized else {
            panic!("expected table query");
        };

        assert_eq!(ctx.stats.expressions_normalized, 1);
        assert_eq!(
            table
                .columns
                .iter()
                .map(projection_name)
                .collect::<Vec<_>>(),
            vec!["a_alias", "m", "z"]
        );
    }

    // Covers OR with a FALSE operand, OR with a TRUE operand and AND with a
    // FALSE operand; each counts as one simplification.
    #[test]
    fn simplify_filter_covers_or_true_false_and_not_paths() {
        let mut ctx = RewriteContext::default();
        let truth = AstFilter::Compare {
            field: make_field("1"),
            op: CompareOp::Eq,
            value: Value::Integer(1),
        };
        let falsehood = AstFilter::Compare {
            field: make_field("1"),
            op: CompareOp::Eq,
            value: Value::Integer(0),
        };
        let predicate = AstFilter::Compare {
            field: make_field("x"),
            op: CompareOp::Eq,
            value: Value::Integer(5),
        };

        assert_eq!(
            simplify_filter(
                AstFilter::Or(Box::new(falsehood.clone()), Box::new(predicate.clone())),
                &mut ctx,
            ),
            predicate
        );
        assert!(is_always_true(&simplify_filter(
            AstFilter::Or(Box::new(truth), Box::new(falsehood.clone())),
            &mut ctx,
        )));
        assert!(is_always_false(&simplify_filter(
            AstFilter::And(
                Box::new(falsehood),
                Box::new(AstFilter::IsNotNull(make_field("x")))
            ),
            &mut ctx,
        )));
        assert!(ctx.stats.filters_simplified >= 3);
    }

    // A full rewrite strips the TRUE conjunct from a table filter and records
    // the simplification in the caller-supplied context.
    #[test]
    fn query_rewriter_runs_rules_until_fixed_point_and_exposes_context() {
        let mut table = TableQuery::new("users");
        table.filter = Some(AstFilter::And(
            Box::new(AstFilter::Compare {
                field: make_field("1"),
                op: CompareOp::Eq,
                value: Value::Integer(1),
            }),
            Box::new(AstFilter::Compare {
                field: make_field("age"),
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }),
        ));

        let mut ctx = RewriteContext::default();
        let rewritten =
            QueryRewriter::default().rewrite_with_context(QueryExpr::Table(table), &mut ctx);
        let QueryExpr::Table(table) = rewritten else {
            panic!("expected table query");
        };

        assert!(matches!(
            table.filter,
            Some(AstFilter::Compare {
                field: FieldRef::TableColumn { column, .. },
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }) if column == "age"
        ));
        assert!(ctx.stats.filters_simplified >= 1);
    }

    // Rules reach inside join inputs: the left table's columns are sorted and
    // its filter simplified, and the pushdown counter is bumped.
    #[test]
    fn query_rewriter_recurses_into_join_children_and_tracks_pushdown() {
        let mut left = TableQuery::new("users");
        left.columns = vec![
            Projection::Column("z".to_string()),
            Projection::Column("a".to_string()),
        ];
        left.filter = Some(AstFilter::And(
            Box::new(AstFilter::Compare {
                field: make_field("1"),
                op: CompareOp::Eq,
                value: Value::Integer(1),
            }),
            Box::new(AstFilter::Compare {
                field: make_field("age"),
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }),
        ));

        let join = JoinQuery::new(
            QueryExpr::Table(left),
            QueryExpr::Table(TableQuery::new("orders")),
            JoinCondition::new(make_field("id"), make_field("user_id")),
        );

        let mut ctx = RewriteContext::default();
        let rewritten =
            QueryRewriter::default().rewrite_with_context(QueryExpr::Join(join), &mut ctx);
        let QueryExpr::Join(join) = rewritten else {
            panic!("expected join query");
        };
        let QueryExpr::Table(left) = join.left.as_ref() else {
            panic!("expected table on left side");
        };

        assert_eq!(
            left.columns.iter().map(projection_name).collect::<Vec<_>>(),
            vec!["a", "z"]
        );
        assert!(matches!(
            &left.filter,
            Some(AstFilter::Compare {
                field: FieldRef::TableColumn { ref column, .. },
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }) if column == "age"
        ));
        assert!(ctx.stats.predicates_pushed >= 1);
    }

    // A table filter that is exactly `1 = 1` is removed entirely.
    #[test]
    fn query_rewriter_eliminates_always_true_table_filters() {
        let mut table = TableQuery::new("users");
        table.filter = Some(AstFilter::Compare {
            field: make_field("1"),
            op: CompareOp::Eq,
            value: Value::Integer(1),
        });

        let rewritten = QueryRewriter::default().rewrite(QueryExpr::Table(table));
        let QueryExpr::Table(table) = rewritten else {
            panic!("expected table query");
        };

        assert!(table.filter.is_none());
    }

    /// Test-only rule that pushes its name into `ctx.warnings` whenever it is
    /// applied to a table query, so tests can observe that it ran.
    struct CountingRule;

    impl RewriteRule for CountingRule {
        fn name(&self) -> &str {
            "CountingRule"
        }

        fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
            ctx.warnings.push(self.name().to_string());
            query
        }

        fn is_applicable(&self, query: &QueryExpr) -> bool {
            matches!(query, QueryExpr::Table(_))
        }
    }

    // Rules registered through `add_rule` run alongside the default pipeline.
    #[test]
    fn custom_rules_can_be_added_after_defaults() {
        let mut rewriter = QueryRewriter::new();
        rewriter.add_rule(Box::new(CountingRule));

        let mut ctx = RewriteContext::default();
        let rewritten =
            rewriter.rewrite_with_context(QueryExpr::Table(TableQuery::new("users")), &mut ctx);

        assert!(matches!(rewritten, QueryExpr::Table(_)));
        assert!(ctx.warnings.iter().any(|warning| warning == "CountingRule"));
    }
}