Skip to main content

reddb_rql/planner/
rewriter.rs

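//! Query Rewriter
//!
//! Multi-pass AST transformation system inspired by Neo4j's query rewriting.
//!
//! # Rewrite Passes
//!
//! Default passes, in the order `QueryRewriter::new` registers them:
//!
//! 1. **Normalize**: Standardize AST structure
//! 2. **SimplifyFilters**: Combine and simplify filter expressions
//! 3. **PushdownPredicates**: Move filters closer to data source
//! 4. **EliminateDeadCode**: Remove always-true filters from table queries
//! 5. **FoldConstants**: Evaluate constant expressions (currently a pass-through)
//!
//! **InjectCachedProperties** (cache property lookups at compile time) and
//! **ValidateFunctions** (check function calls against schema) are not
//! registered by `QueryRewriter::new`.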
12
13use crate::ast::{CompareOp, FieldRef, Filter as AstFilter, JoinQuery, Projection, QueryExpr};
14use crate::sql_lowering::{
15    effective_graph_filter, effective_join_filter, effective_table_filter, effective_vector_filter,
16};
17use reddb_types::Value;
18
19/// Context for rewrite operations
20#[derive(Debug, Clone, Default)]
21pub struct RewriteContext {
22    /// Property cache for compile-time lookups
23    pub property_cache: Vec<CachedProperty>,
24    /// Validation errors encountered
25    pub errors: Vec<String>,
26    /// Warnings generated
27    pub warnings: Vec<String>,
28    /// Statistics about rewrites
29    pub stats: RewriteStats,
30}
31
/// One entry of the property cache: a property lookup on a named source.
#[derive(Clone, Debug)]
pub struct CachedProperty {
    /// Alias of the source (table or node) the property belongs to.
    pub source: String,
    /// Name of the property.
    pub property: String,
    /// The value, when it is already known at compile time; `None` otherwise.
    pub cached_value: Option<String>,
}
42
/// Counters updated by the rewrite passes; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct RewriteStats {
    /// How many filter simplifications were applied.
    pub filters_simplified: u32,
    /// How many predicates were pushed down.
    pub predicates_pushed: u32,
    /// How many properties were cached.
    pub properties_cached: u32,
    /// How many expressions were normalized.
    pub expressions_normalized: u32,
}
55
56/// A rewrite rule that transforms query expressions
57pub trait RewriteRule: Send + Sync {
58    /// Rule name for debugging
59    fn name(&self) -> &str;
60
61    /// Apply the rule to a query expression
62    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr;
63
64    /// Check if this rule is applicable to the query
65    fn is_applicable(&self, query: &QueryExpr) -> bool;
66}
67
68/// Query rewriter with pluggable rules
69pub struct QueryRewriter {
70    /// Ordered list of rewrite rules
71    rules: Vec<Box<dyn RewriteRule>>,
72    /// Maximum number of rewrite iterations
73    max_iterations: usize,
74}
75
76impl QueryRewriter {
77    /// Create a new rewriter with default rules
78    pub fn new() -> Self {
79        let rules: Vec<Box<dyn RewriteRule>> = vec![
80            Box::new(NormalizeRule),
81            Box::new(SimplifyFiltersRule),
82            Box::new(PushdownPredicatesRule),
83            Box::new(EliminateDeadCodeRule),
84            Box::new(FoldConstantsRule),
85        ];
86
87        Self {
88            rules,
89            max_iterations: 10,
90        }
91    }
92
93    /// Add a custom rewrite rule
94    pub fn add_rule(&mut self, rule: Box<dyn RewriteRule>) {
95        self.rules.push(rule);
96    }
97
98    /// Rewrite a query expression
99    pub fn rewrite(&self, query: QueryExpr) -> QueryExpr {
100        let mut ctx = RewriteContext::default();
101        self.rewrite_with_context(query, &mut ctx)
102    }
103
104    /// Rewrite with access to context
105    pub fn rewrite_with_context(
106        &self,
107        mut query: QueryExpr,
108        ctx: &mut RewriteContext,
109    ) -> QueryExpr {
110        // Apply rules iteratively until fixed point
111        for _iteration in 0..self.max_iterations {
112            let original = format!("{:?}", query);
113
114            for rule in &self.rules {
115                if rule.is_applicable(&query) {
116                    query = rule.apply(query, ctx);
117                }
118            }
119
120            // Check for fixed point
121            if format!("{:?}", query) == original {
122                break;
123            }
124        }
125
126        query
127    }
128}
129
130impl Default for QueryRewriter {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136// =============================================================================
137// Built-in Rewrite Rules
138// =============================================================================
139
140/// Normalize AST structure
141struct NormalizeRule;
142
143impl RewriteRule for NormalizeRule {
144    fn name(&self) -> &str {
145        "Normalize"
146    }
147
148    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
149        match query {
150            QueryExpr::Table(mut tq) => {
151                // Normalize column order
152                tq.columns.sort_by(|a, b| {
153                    let a_name = projection_name(a);
154                    let b_name = projection_name(b);
155                    a_name.cmp(&b_name)
156                });
157                ctx.stats.expressions_normalized += 1;
158                QueryExpr::Table(tq)
159            }
160            QueryExpr::Graph(gq) => {
161                // Graph queries don't need normalization currently
162                QueryExpr::Graph(gq)
163            }
164            QueryExpr::Join(jq) => {
165                // Recursively normalize children
166                let left = self.apply(*jq.left, ctx);
167                let right = self.apply(*jq.right, ctx);
168                QueryExpr::Join(JoinQuery {
169                    left: Box::new(left),
170                    right: Box::new(right),
171                    ..jq
172                })
173            }
174            QueryExpr::Path(pq) => QueryExpr::Path(pq),
175            QueryExpr::Vector(vq) => {
176                // Vector queries don't need normalization currently
177                QueryExpr::Vector(vq)
178            }
179            QueryExpr::Hybrid(mut hq) => {
180                // Normalize the structured part
181                hq.structured = Box::new(self.apply(*hq.structured, ctx));
182                QueryExpr::Hybrid(hq)
183            }
184            // DML/DDL/Command statements pass through without normalization
185            other @ (QueryExpr::Insert(_)
186            | QueryExpr::Update(_)
187            | QueryExpr::Delete(_)
188            | QueryExpr::CreateTable(_)
189            | QueryExpr::CreateCollection(_)
190            | QueryExpr::CreateVector(_)
191            | QueryExpr::DropTable(_)
192            | QueryExpr::DropGraph(_)
193            | QueryExpr::DropVector(_)
194            | QueryExpr::DropDocument(_)
195            | QueryExpr::DropKv(_)
196            | QueryExpr::DropCollection(_)
197            | QueryExpr::Truncate(_)
198            | QueryExpr::AlterTable(_)
199            | QueryExpr::GraphCommand(_)
200            | QueryExpr::SearchCommand(_)
201            | QueryExpr::CreateIndex(_)
202            | QueryExpr::DropIndex(_)
203            | QueryExpr::ProbabilisticCommand(_)
204            | QueryExpr::Ask(_)
205            | QueryExpr::SetConfig { .. }
206            | QueryExpr::ShowConfig { .. }
207            | QueryExpr::SetSecret { .. }
208            | QueryExpr::DeleteSecret { .. }
209            | QueryExpr::ShowSecrets { .. }
210            | QueryExpr::SetTenant(_)
211            | QueryExpr::ShowTenant
212            | QueryExpr::CreateTimeSeries(_)
213            | QueryExpr::CreateMetric(_)
214            | QueryExpr::AlterMetric(_)
215            | QueryExpr::CreateSlo(_)
216            | QueryExpr::DropTimeSeries(_)
217            | QueryExpr::CreateQueue(_)
218            | QueryExpr::AlterQueue(_)
219            | QueryExpr::DropQueue(_)
220            | QueryExpr::QueueSelect(_)
221            | QueryExpr::QueueCommand(_)
222            | QueryExpr::KvCommand(_)
223            | QueryExpr::ConfigCommand(_)
224            | QueryExpr::CreateTree(_)
225            | QueryExpr::DropTree(_)
226            | QueryExpr::TreeCommand(_)
227            | QueryExpr::ExplainAlter(_)
228            | QueryExpr::TransactionControl(_)
229            | QueryExpr::MaintenanceCommand(_)
230            | QueryExpr::CreateSchema(_)
231            | QueryExpr::DropSchema(_)
232            | QueryExpr::CreateSequence(_)
233            | QueryExpr::DropSequence(_)
234            | QueryExpr::CopyFrom(_)
235            | QueryExpr::CreateView(_)
236            | QueryExpr::DropView(_)
237            | QueryExpr::RefreshMaterializedView(_)
238            | QueryExpr::CreatePolicy(_)
239            | QueryExpr::DropPolicy(_)
240            | QueryExpr::CreateServer(_)
241            | QueryExpr::DropServer(_)
242            | QueryExpr::CreateForeignTable(_)
243            | QueryExpr::DropForeignTable(_)
244            | QueryExpr::Grant(_)
245            | QueryExpr::Revoke(_)
246            | QueryExpr::AlterUser(_)
247            | QueryExpr::CreateUser(_)
248            | QueryExpr::CreateIamPolicy { .. }
249            | QueryExpr::DropIamPolicy { .. }
250            | QueryExpr::AttachPolicy { .. }
251            | QueryExpr::DetachPolicy { .. }
252            | QueryExpr::ShowPolicies { .. }
253            | QueryExpr::ShowEffectivePermissions { .. }
254            | QueryExpr::RankOf(_)
255            | QueryExpr::ApproxRankOf(_)
256            | QueryExpr::RankRange(_)
257            | QueryExpr::SimulatePolicy { .. }
258            | QueryExpr::LintPolicy { .. }
259            | QueryExpr::MigratePolicyMode { .. }
260            | QueryExpr::CreateMigration(_)
261            | QueryExpr::ApplyMigration(_)
262            | QueryExpr::RollbackMigration(_)
263            | QueryExpr::ExplainMigration(_)
264            | QueryExpr::EventsBackfill(_)
265            | QueryExpr::EventsBackfillStatus { .. }) => other,
266        }
267    }
268
269    fn is_applicable(&self, _query: &QueryExpr) -> bool {
270        true
271    }
272}
273
274/// Simplify filter expressions
275struct SimplifyFiltersRule;
276
277impl RewriteRule for SimplifyFiltersRule {
278    fn name(&self) -> &str {
279        "SimplifyFilters"
280    }
281
282    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
283        match query {
284            QueryExpr::Table(mut tq) => {
285                if let Some(filter) = effective_table_filter(&tq) {
286                    tq.filter = Some(simplify_filter(filter, ctx));
287                }
288                QueryExpr::Table(tq)
289            }
290            QueryExpr::Graph(mut gq) => {
291                if let Some(filter) = effective_graph_filter(&gq) {
292                    gq.filter = Some(simplify_filter(filter, ctx));
293                }
294                QueryExpr::Graph(gq)
295            }
296            QueryExpr::Join(mut jq) => {
297                let join_filter = effective_join_filter(&jq);
298                let left = self.apply(*jq.left, ctx);
299                let right = self.apply(*jq.right, ctx);
300                if let Some(filter) = join_filter {
301                    jq.filter = Some(simplify_filter(filter, ctx));
302                }
303                jq.left = Box::new(left);
304                jq.right = Box::new(right);
305                QueryExpr::Join(jq)
306            }
307            QueryExpr::Path(pq) => QueryExpr::Path(pq),
308            QueryExpr::Vector(vq) => {
309                // Vector queries have MetadataFilter, not AstFilter
310                // Pass through for now
311                QueryExpr::Vector(vq)
312            }
313            QueryExpr::Hybrid(mut hq) => {
314                // Simplify filters in the structured part
315                hq.structured = Box::new(self.apply(*hq.structured, ctx));
316                QueryExpr::Hybrid(hq)
317            }
318            // DML/DDL/Command statements pass through without filter simplification
319            other @ (QueryExpr::Insert(_)
320            | QueryExpr::Update(_)
321            | QueryExpr::Delete(_)
322            | QueryExpr::CreateTable(_)
323            | QueryExpr::CreateCollection(_)
324            | QueryExpr::CreateVector(_)
325            | QueryExpr::DropTable(_)
326            | QueryExpr::DropGraph(_)
327            | QueryExpr::DropVector(_)
328            | QueryExpr::DropDocument(_)
329            | QueryExpr::DropKv(_)
330            | QueryExpr::DropCollection(_)
331            | QueryExpr::Truncate(_)
332            | QueryExpr::AlterTable(_)
333            | QueryExpr::GraphCommand(_)
334            | QueryExpr::SearchCommand(_)
335            | QueryExpr::CreateIndex(_)
336            | QueryExpr::DropIndex(_)
337            | QueryExpr::ProbabilisticCommand(_)
338            | QueryExpr::Ask(_)
339            | QueryExpr::SetConfig { .. }
340            | QueryExpr::ShowConfig { .. }
341            | QueryExpr::SetSecret { .. }
342            | QueryExpr::DeleteSecret { .. }
343            | QueryExpr::ShowSecrets { .. }
344            | QueryExpr::SetTenant(_)
345            | QueryExpr::ShowTenant
346            | QueryExpr::CreateTimeSeries(_)
347            | QueryExpr::CreateMetric(_)
348            | QueryExpr::AlterMetric(_)
349            | QueryExpr::CreateSlo(_)
350            | QueryExpr::DropTimeSeries(_)
351            | QueryExpr::CreateQueue(_)
352            | QueryExpr::AlterQueue(_)
353            | QueryExpr::DropQueue(_)
354            | QueryExpr::QueueSelect(_)
355            | QueryExpr::QueueCommand(_)
356            | QueryExpr::KvCommand(_)
357            | QueryExpr::ConfigCommand(_)
358            | QueryExpr::CreateTree(_)
359            | QueryExpr::DropTree(_)
360            | QueryExpr::TreeCommand(_)
361            | QueryExpr::ExplainAlter(_)
362            | QueryExpr::TransactionControl(_)
363            | QueryExpr::MaintenanceCommand(_)
364            | QueryExpr::CreateSchema(_)
365            | QueryExpr::DropSchema(_)
366            | QueryExpr::CreateSequence(_)
367            | QueryExpr::DropSequence(_)
368            | QueryExpr::CopyFrom(_)
369            | QueryExpr::CreateView(_)
370            | QueryExpr::DropView(_)
371            | QueryExpr::RefreshMaterializedView(_)
372            | QueryExpr::CreatePolicy(_)
373            | QueryExpr::DropPolicy(_)
374            | QueryExpr::CreateServer(_)
375            | QueryExpr::DropServer(_)
376            | QueryExpr::CreateForeignTable(_)
377            | QueryExpr::DropForeignTable(_)
378            | QueryExpr::Grant(_)
379            | QueryExpr::Revoke(_)
380            | QueryExpr::AlterUser(_)
381            | QueryExpr::CreateUser(_)
382            | QueryExpr::CreateIamPolicy { .. }
383            | QueryExpr::DropIamPolicy { .. }
384            | QueryExpr::AttachPolicy { .. }
385            | QueryExpr::DetachPolicy { .. }
386            | QueryExpr::ShowPolicies { .. }
387            | QueryExpr::ShowEffectivePermissions { .. }
388            | QueryExpr::RankOf(_)
389            | QueryExpr::ApproxRankOf(_)
390            | QueryExpr::RankRange(_)
391            | QueryExpr::SimulatePolicy { .. }
392            | QueryExpr::LintPolicy { .. }
393            | QueryExpr::MigratePolicyMode { .. }
394            | QueryExpr::CreateMigration(_)
395            | QueryExpr::ApplyMigration(_)
396            | QueryExpr::RollbackMigration(_)
397            | QueryExpr::ExplainMigration(_)
398            | QueryExpr::EventsBackfill(_)
399            | QueryExpr::EventsBackfillStatus { .. }) => other,
400        }
401    }
402
403    fn is_applicable(&self, query: &QueryExpr) -> bool {
404        match query {
405            QueryExpr::Table(tq) => effective_table_filter(tq).is_some(),
406            QueryExpr::Graph(gq) => effective_graph_filter(gq).is_some(),
407            QueryExpr::Join(_) => true,
408            QueryExpr::Path(_) => false,
409            QueryExpr::Vector(vq) => effective_vector_filter(vq).is_some(),
410            QueryExpr::Hybrid(_) => true, // May have filters in structured part
411            // DML/DDL/Command statements are not applicable for filter simplification
412            QueryExpr::Insert(_)
413            | QueryExpr::Update(_)
414            | QueryExpr::Delete(_)
415            | QueryExpr::CreateTable(_)
416            | QueryExpr::CreateCollection(_)
417            | QueryExpr::CreateVector(_)
418            | QueryExpr::DropTable(_)
419            | QueryExpr::DropGraph(_)
420            | QueryExpr::DropVector(_)
421            | QueryExpr::DropDocument(_)
422            | QueryExpr::DropKv(_)
423            | QueryExpr::DropCollection(_)
424            | QueryExpr::Truncate(_)
425            | QueryExpr::AlterTable(_)
426            | QueryExpr::GraphCommand(_)
427            | QueryExpr::SearchCommand(_)
428            | QueryExpr::CreateIndex(_)
429            | QueryExpr::DropIndex(_)
430            | QueryExpr::ProbabilisticCommand(_)
431            | QueryExpr::Ask(_)
432            | QueryExpr::SetConfig { .. }
433            | QueryExpr::ShowConfig { .. }
434            | QueryExpr::SetSecret { .. }
435            | QueryExpr::DeleteSecret { .. }
436            | QueryExpr::ShowSecrets { .. }
437            | QueryExpr::SetTenant(_)
438            | QueryExpr::ShowTenant
439            | QueryExpr::CreateTimeSeries(_)
440            | QueryExpr::CreateMetric(_)
441            | QueryExpr::AlterMetric(_)
442            | QueryExpr::CreateSlo(_)
443            | QueryExpr::DropTimeSeries(_)
444            | QueryExpr::CreateQueue(_)
445            | QueryExpr::AlterQueue(_)
446            | QueryExpr::DropQueue(_)
447            | QueryExpr::QueueSelect(_)
448            | QueryExpr::QueueCommand(_)
449            | QueryExpr::KvCommand(_)
450            | QueryExpr::ConfigCommand(_)
451            | QueryExpr::CreateTree(_)
452            | QueryExpr::DropTree(_)
453            | QueryExpr::TreeCommand(_)
454            | QueryExpr::ExplainAlter(_)
455            | QueryExpr::TransactionControl(_)
456            | QueryExpr::MaintenanceCommand(_)
457            | QueryExpr::CreateSchema(_)
458            | QueryExpr::DropSchema(_)
459            | QueryExpr::CreateSequence(_)
460            | QueryExpr::DropSequence(_)
461            | QueryExpr::CopyFrom(_)
462            | QueryExpr::CreateView(_)
463            | QueryExpr::DropView(_)
464            | QueryExpr::RefreshMaterializedView(_)
465            | QueryExpr::CreatePolicy(_)
466            | QueryExpr::DropPolicy(_)
467            | QueryExpr::CreateServer(_)
468            | QueryExpr::DropServer(_)
469            | QueryExpr::CreateForeignTable(_)
470            | QueryExpr::DropForeignTable(_)
471            | QueryExpr::Grant(_)
472            | QueryExpr::Revoke(_)
473            | QueryExpr::AlterUser(_)
474            | QueryExpr::CreateUser(_)
475            | QueryExpr::CreateIamPolicy { .. }
476            | QueryExpr::DropIamPolicy { .. }
477            | QueryExpr::AttachPolicy { .. }
478            | QueryExpr::DetachPolicy { .. }
479            | QueryExpr::ShowPolicies { .. }
480            | QueryExpr::ShowEffectivePermissions { .. }
481            | QueryExpr::RankOf(_)
482            | QueryExpr::ApproxRankOf(_)
483            | QueryExpr::RankRange(_)
484            | QueryExpr::SimulatePolicy { .. }
485            | QueryExpr::LintPolicy { .. }
486            | QueryExpr::MigratePolicyMode { .. }
487            | QueryExpr::CreateMigration(_)
488            | QueryExpr::ApplyMigration(_)
489            | QueryExpr::RollbackMigration(_)
490            | QueryExpr::ExplainMigration(_)
491            | QueryExpr::EventsBackfill(_)
492            | QueryExpr::EventsBackfillStatus { .. } => false,
493        }
494    }
495}
496
497/// Push predicates down to data sources
498struct PushdownPredicatesRule;
499
500impl RewriteRule for PushdownPredicatesRule {
501    fn name(&self) -> &str {
502        "PushdownPredicates"
503    }
504
505    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
506        match query {
507            QueryExpr::Join(mut jq) => {
508                // Try to push join predicates down to children
509                // This is a simplified version - real implementation would analyze
510                // which predicates can be pushed to which child
511
512                // For now, just recursively apply to children
513                jq.left = Box::new(self.apply(*jq.left, ctx));
514                jq.right = Box::new(self.apply(*jq.right, ctx));
515                ctx.stats.predicates_pushed += 1;
516                QueryExpr::Join(jq)
517            }
518            other => other,
519        }
520    }
521
522    fn is_applicable(&self, query: &QueryExpr) -> bool {
523        matches!(query, QueryExpr::Join(_))
524    }
525}
526
527/// Eliminate dead code branches
528struct EliminateDeadCodeRule;
529
530impl RewriteRule for EliminateDeadCodeRule {
531    fn name(&self) -> &str {
532        "EliminateDeadCode"
533    }
534
535    fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
536        match query {
537            QueryExpr::Table(mut tq) => {
538                // Remove always-true filters
539                if let Some(filter) = effective_table_filter(&tq).as_ref() {
540                    if is_always_true(filter) {
541                        tq.filter = None;
542                    }
543                }
544                QueryExpr::Table(tq)
545            }
546            other => other,
547        }
548    }
549
550    fn is_applicable(&self, query: &QueryExpr) -> bool {
551        matches!(query, QueryExpr::Table(_))
552    }
553}
554
555/// Fold constant expressions
556struct FoldConstantsRule;
557
558impl RewriteRule for FoldConstantsRule {
559    fn name(&self) -> &str {
560        "FoldConstants"
561    }
562
563    fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
564        // Constant folding is complex - for now just pass through
565        // A real implementation would evaluate constant expressions at compile time
566        query
567    }
568
569    fn is_applicable(&self, _query: &QueryExpr) -> bool {
570        true
571    }
572}
573
574// =============================================================================
575// Helper Functions
576// =============================================================================
577
578fn projection_name(proj: &Projection) -> String {
579    match proj {
580        Projection::All => "*".to_string(),
581        Projection::Column(name) => name.clone(),
582        Projection::Alias(_, alias) => alias.clone(),
583        Projection::Function(name, _) => name
584            .split_once(':')
585            .map(|(_, alias)| alias.to_string())
586            .unwrap_or_else(|| name.clone()),
587        Projection::Expression(expr, alias) => {
588            alias.clone().unwrap_or_else(|| format!("{:?}", expr))
589        }
590        Projection::Field(field, alias) => alias.clone().unwrap_or_else(|| format!("{:?}", field)),
591        Projection::Window { name, alias, .. } => alias.clone().unwrap_or_else(|| name.clone()),
592    }
593}
594
595fn simplify_filter(filter: AstFilter, ctx: &mut RewriteContext) -> AstFilter {
596    match filter {
597        AstFilter::And(left, right) => {
598            let left = simplify_filter(*left, ctx);
599            let right = simplify_filter(*right, ctx);
600
601            // AND with TRUE -> other side
602            if is_always_true(&left) {
603                ctx.stats.filters_simplified += 1;
604                return right;
605            }
606            if is_always_true(&right) {
607                ctx.stats.filters_simplified += 1;
608                return left;
609            }
610
611            // AND with FALSE -> FALSE
612            if is_always_false(&left) || is_always_false(&right) {
613                ctx.stats.filters_simplified += 1;
614                return AstFilter::Compare {
615                    field: FieldRef::TableColumn {
616                        table: String::new(),
617                        column: "1".to_string(),
618                    },
619                    op: CompareOp::Eq,
620                    value: Value::Integer(0),
621                };
622            }
623
624            AstFilter::And(Box::new(left), Box::new(right))
625        }
626        AstFilter::Or(left, right) => {
627            let left = simplify_filter(*left, ctx);
628            let right = simplify_filter(*right, ctx);
629
630            // OR with FALSE -> other side
631            if is_always_false(&left) {
632                ctx.stats.filters_simplified += 1;
633                return right;
634            }
635            if is_always_false(&right) {
636                ctx.stats.filters_simplified += 1;
637                return left;
638            }
639
640            // OR with TRUE -> TRUE
641            if is_always_true(&left) || is_always_true(&right) {
642                ctx.stats.filters_simplified += 1;
643                return AstFilter::Compare {
644                    field: FieldRef::TableColumn {
645                        table: String::new(),
646                        column: "1".to_string(),
647                    },
648                    op: CompareOp::Eq,
649                    value: Value::Integer(1),
650                };
651            }
652
653            AstFilter::Or(Box::new(left), Box::new(right))
654        }
655        AstFilter::Not(inner) => {
656            let inner = simplify_filter(*inner, ctx);
657
658            // NOT NOT x -> x
659            if let AstFilter::Not(double_inner) = inner {
660                ctx.stats.filters_simplified += 1;
661                return *double_inner;
662            }
663
664            AstFilter::Not(Box::new(inner))
665        }
666        other => other,
667    }
668}
669
670fn is_always_true(filter: &AstFilter) -> bool {
671    match filter {
672        AstFilter::Compare { field, op, value } => {
673            // 1 = 1 is always true
674            matches!(field, FieldRef::TableColumn { column, .. } if column == "1")
675                && matches!(op, CompareOp::Eq)
676                && matches!(value, Value::Integer(1))
677        }
678        _ => false,
679    }
680}
681
682fn is_always_false(filter: &AstFilter) -> bool {
683    match filter {
684        AstFilter::Compare { field, op, value } => {
685            // 1 = 0 is always false
686            matches!(field, FieldRef::TableColumn { column, .. } if column == "1")
687                && matches!(op, CompareOp::Eq)
688                && matches!(value, Value::Integer(0))
689        }
690        _ => false,
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697    use crate::ast::{JoinCondition, TableQuery, WindowSpec};
698
699    fn make_field(name: &str) -> FieldRef {
700        FieldRef::TableColumn {
701            table: String::new(),
702            column: name.to_string(),
703        }
704    }
705
706    #[test]
707    fn test_simplify_and_with_true() {
708        let mut ctx = RewriteContext::default();
709
710        let filter = AstFilter::And(
711            Box::new(AstFilter::Compare {
712                field: make_field("1"),
713                op: CompareOp::Eq,
714                value: Value::Integer(1),
715            }),
716            Box::new(AstFilter::Compare {
717                field: make_field("x"),
718                op: CompareOp::Eq,
719                value: Value::Integer(5),
720            }),
721        );
722
723        let simplified = simplify_filter(filter, &mut ctx);
724
725        match simplified {
726            AstFilter::Compare { field, .. } => {
727                assert!(matches!(field, FieldRef::TableColumn { column, .. } if column == "x"));
728            }
729            _ => panic!("Expected Compare filter"),
730        }
731    }
732
733    #[test]
734    fn test_simplify_double_not() {
735        let mut ctx = RewriteContext::default();
736
737        let filter = AstFilter::Not(Box::new(AstFilter::Not(Box::new(AstFilter::Compare {
738            field: make_field("x"),
739            op: CompareOp::Eq,
740            value: Value::Integer(5),
741        }))));
742
743        let simplified = simplify_filter(filter, &mut ctx);
744
745        match simplified {
746            AstFilter::Compare { field, .. } => {
747                assert!(matches!(field, FieldRef::TableColumn { column, .. } if column == "x"));
748            }
749            _ => panic!("Expected Compare filter"),
750        }
751    }
752
753    #[test]
754    fn projection_name_uses_visible_output_name_for_all_projection_shapes() {
755        assert_eq!(projection_name(&Projection::All), "*");
756        assert_eq!(
757            projection_name(&Projection::Column("raw".to_string())),
758            "raw"
759        );
760        assert_eq!(
761            projection_name(&Projection::Alias("raw".to_string(), "alias".to_string())),
762            "alias"
763        );
764        assert_eq!(
765            projection_name(&Projection::Function(
766                "LOWER:display".to_string(),
767                Vec::new()
768            )),
769            "display"
770        );
771        assert_eq!(
772            projection_name(&Projection::Expression(
773                Box::new(AstFilter::Compare {
774                    field: make_field("x"),
775                    op: CompareOp::Eq,
776                    value: Value::Integer(1),
777                }),
778                Some("expr_alias".to_string()),
779            )),
780            "expr_alias"
781        );
782        assert_eq!(
783            projection_name(&Projection::Field(
784                FieldRef::node_prop("n", "name"),
785                Some("node_name".to_string()),
786            )),
787            "node_name"
788        );
789        assert_eq!(
790            projection_name(&Projection::Window {
791                name: "ROW_NUMBER".to_string(),
792                args: Vec::new(),
793                window: Box::new(WindowSpec::default()),
794                alias: Some("rn".to_string()),
795            }),
796            "rn"
797        );
798    }
799
800    #[test]
801    fn normalize_rule_sorts_table_columns_by_output_name() {
802        let mut table = TableQuery::new("users");
803        table.columns = vec![
804            Projection::Column("z".to_string()),
805            Projection::Function("LOWER:a_alias".to_string(), Vec::new()),
806            Projection::Alias("name".to_string(), "m".to_string()),
807        ];
808
809        let mut ctx = RewriteContext::default();
810        let normalized = NormalizeRule.apply(QueryExpr::Table(table), &mut ctx);
811        let QueryExpr::Table(table) = normalized else {
812            panic!("expected table query");
813        };
814
815        assert_eq!(ctx.stats.expressions_normalized, 1);
816        assert_eq!(
817            table
818                .columns
819                .iter()
820                .map(projection_name)
821                .collect::<Vec<_>>(),
822            vec!["a_alias", "m", "z"]
823        );
824    }
825
826    #[test]
827    fn simplify_filter_covers_or_true_false_and_not_paths() {
828        let mut ctx = RewriteContext::default();
829        let truth = AstFilter::Compare {
830            field: make_field("1"),
831            op: CompareOp::Eq,
832            value: Value::Integer(1),
833        };
834        let falsehood = AstFilter::Compare {
835            field: make_field("1"),
836            op: CompareOp::Eq,
837            value: Value::Integer(0),
838        };
839        let predicate = AstFilter::Compare {
840            field: make_field("x"),
841            op: CompareOp::Eq,
842            value: Value::Integer(5),
843        };
844
845        assert_eq!(
846            simplify_filter(
847                AstFilter::Or(Box::new(falsehood.clone()), Box::new(predicate.clone())),
848                &mut ctx,
849            ),
850            predicate
851        );
852        assert!(is_always_true(&simplify_filter(
853            AstFilter::Or(Box::new(truth), Box::new(falsehood.clone())),
854            &mut ctx,
855        )));
856        assert!(is_always_false(&simplify_filter(
857            AstFilter::And(
858                Box::new(falsehood),
859                Box::new(AstFilter::IsNotNull(make_field("x")))
860            ),
861            &mut ctx,
862        )));
863        assert!(ctx.stats.filters_simplified >= 3);
864    }
865
866    #[test]
867    fn query_rewriter_runs_rules_until_fixed_point_and_exposes_context() {
868        let mut table = TableQuery::new("users");
869        table.filter = Some(AstFilter::And(
870            Box::new(AstFilter::Compare {
871                field: make_field("1"),
872                op: CompareOp::Eq,
873                value: Value::Integer(1),
874            }),
875            Box::new(AstFilter::Compare {
876                field: make_field("age"),
877                op: CompareOp::Ge,
878                value: Value::Integer(18),
879            }),
880        ));
881
882        let mut ctx = RewriteContext::default();
883        let rewritten =
884            QueryRewriter::default().rewrite_with_context(QueryExpr::Table(table), &mut ctx);
885        let QueryExpr::Table(table) = rewritten else {
886            panic!("expected table query");
887        };
888
889        assert!(matches!(
890            table.filter,
891            Some(AstFilter::Compare {
892                field: FieldRef::TableColumn { column, .. },
893                op: CompareOp::Ge,
894                value: Value::Integer(18),
895            }) if column == "age"
896        ));
897        assert!(ctx.stats.filters_simplified >= 1);
898    }
899
900    #[test]
901    fn query_rewriter_recurses_into_join_children_and_tracks_pushdown() {
902        let mut left = TableQuery::new("users");
903        left.columns = vec![
904            Projection::Column("z".to_string()),
905            Projection::Column("a".to_string()),
906        ];
907        left.filter = Some(AstFilter::And(
908            Box::new(AstFilter::Compare {
909                field: make_field("1"),
910                op: CompareOp::Eq,
911                value: Value::Integer(1),
912            }),
913            Box::new(AstFilter::Compare {
914                field: make_field("age"),
915                op: CompareOp::Ge,
916                value: Value::Integer(18),
917            }),
918        ));
919
920        let join = JoinQuery::new(
921            QueryExpr::Table(left),
922            QueryExpr::Table(TableQuery::new("orders")),
923            JoinCondition::new(make_field("id"), make_field("user_id")),
924        );
925
926        let mut ctx = RewriteContext::default();
927        let rewritten =
928            QueryRewriter::default().rewrite_with_context(QueryExpr::Join(join), &mut ctx);
929        let QueryExpr::Join(join) = rewritten else {
930            panic!("expected join query");
931        };
932        let QueryExpr::Table(left) = join.left.as_ref() else {
933            panic!("expected table on left side");
934        };
935
936        assert_eq!(
937            left.columns.iter().map(projection_name).collect::<Vec<_>>(),
938            vec!["a", "z"]
939        );
940        assert!(matches!(
941            &left.filter,
942            Some(AstFilter::Compare {
943                field: FieldRef::TableColumn { ref column, .. },
944                op: CompareOp::Ge,
945                value: Value::Integer(18),
946            }) if column == "age"
947        ));
948        assert!(ctx.stats.predicates_pushed >= 1);
949    }
950
951    #[test]
952    fn query_rewriter_eliminates_always_true_table_filters() {
953        let mut table = TableQuery::new("users");
954        table.filter = Some(AstFilter::Compare {
955            field: make_field("1"),
956            op: CompareOp::Eq,
957            value: Value::Integer(1),
958        });
959
960        let rewritten = QueryRewriter::default().rewrite(QueryExpr::Table(table));
961        let QueryExpr::Table(table) = rewritten else {
962            panic!("expected table query");
963        };
964
965        assert!(table.filter.is_none());
966    }
967
968    struct CountingRule;
969
970    impl RewriteRule for CountingRule {
971        fn name(&self) -> &str {
972            "CountingRule"
973        }
974
975        fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
976            ctx.warnings.push(self.name().to_string());
977            query
978        }
979
980        fn is_applicable(&self, query: &QueryExpr) -> bool {
981            matches!(query, QueryExpr::Table(_))
982        }
983    }
984
985    #[test]
986    fn custom_rules_can_be_added_after_defaults() {
987        let mut rewriter = QueryRewriter::new();
988        rewriter.add_rule(Box::new(CountingRule));
989
990        let mut ctx = RewriteContext::default();
991        let rewritten =
992            rewriter.rewrite_with_context(QueryExpr::Table(TableQuery::new("users")), &mut ctx);
993
994        assert!(matches!(rewritten, QueryExpr::Table(_)));
995        assert!(ctx.warnings.iter().any(|warning| warning == "CountingRule"));
996    }
997}