use crate::ast::{CompareOp, FieldRef, Filter as AstFilter, JoinQuery, Projection, QueryExpr};
use crate::sql_lowering::{
effective_graph_filter, effective_join_filter, effective_table_filter, effective_vector_filter,
};
use reddb_types::Value;
/// Mutable state threaded through every rule during a single rewrite.
///
/// Callers that use `QueryRewriter::rewrite_with_context` own this value and can inspect it
/// afterwards; it is never reset by the rewriter itself.
#[derive(Clone, Debug, Default)]
pub struct RewriteContext {
    /// Cached property lookups. No rule in this file reads or writes this list.
    pub property_cache: Vec<CachedProperty>,
    /// Errors reported by rules. The built-in rules in this file never push to it.
    pub errors: Vec<String>,
    /// Warnings reported by rules. The built-in rules in this file never push to it.
    pub warnings: Vec<String>,
    /// Counters the built-in rules bump as they rewrite.
    pub stats: RewriteStats,
}
/// One entry of `RewriteContext::property_cache`.
///
/// NOTE(review): nothing in this file constructs or reads these entries, so the meaning of
/// each field below is inferred from its name only — confirm against the producer.
#[derive(Clone, Debug)]
pub struct CachedProperty {
    /// Presumably the entity or alias the property belongs to.
    pub source: String,
    /// Presumably the property name on `source`.
    pub property: String,
    /// The cached value, if one has been recorded.
    pub cached_value: Option<String>,
}
/// Counters describing what a rewrite did. All start at zero.
#[derive(Clone, Debug, Default)]
pub struct RewriteStats {
    /// Incremented by `simplify_filter` for every boolean reduction it performs.
    pub filters_simplified: u32,
    /// Incremented by the predicate-pushdown rule for each join it visits.
    pub predicates_pushed: u32,
    /// Not incremented anywhere in this file.
    pub properties_cached: u32,
    /// Incremented by the normalize rule while it canonicalizes table projection lists.
    pub expressions_normalized: u32,
}
/// A single query-to-query transformation driven by `QueryRewriter`.
///
/// The `Send + Sync` bound presumably exists so a rewriter holding boxed rules can be shared
/// across threads — confirm with callers.
pub trait RewriteRule: Send + Sync {
    /// Cheap pre-check. When it returns `false`, the rewriter skips `apply` for this query.
    fn is_applicable(&self, query: &QueryExpr) -> bool;
    /// Consumes `query` and returns its rewritten form, recording stats/diagnostics in `ctx`.
    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr;
    /// Short, human-readable identifier of the rule.
    fn name(&self) -> &str;
}
/// Runs an ordered list of rewrite rules over a query until it stops changing
/// (or a pass limit is hit).
pub struct QueryRewriter {
    /// Upper bound on full passes over `rules` per rewrite.
    max_iterations: usize,
    /// Rules in application order; every pass runs them front to back.
    rules: Vec<Box<dyn RewriteRule>>,
}
impl QueryRewriter {
    /// Builds a rewriter with the built-in pipeline, run in this order: normalize,
    /// simplify filters, push down predicates, eliminate dead code, fold constants.
    ///
    /// At most 10 passes are made per rewrite; see [`QueryRewriter::with_max_iterations`].
    pub fn new() -> Self {
        let rules: Vec<Box<dyn RewriteRule>> = vec![
            Box::new(NormalizeRule),
            Box::new(SimplifyFiltersRule),
            Box::new(PushdownPredicatesRule),
            Box::new(EliminateDeadCodeRule),
            Box::new(FoldConstantsRule),
        ];
        Self {
            rules,
            max_iterations: 10,
        }
    }

    /// Returns this rewriter with its pass limit replaced by `max_iterations`.
    ///
    /// A limit of `0` disables rewriting entirely: queries are returned exactly as given.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Appends `rule` to the pipeline; it runs after every rule already registered.
    pub fn add_rule(&mut self, rule: Box<dyn RewriteRule>) {
        self.rules.push(rule);
    }

    /// Rewrites `query` using a fresh [`RewriteContext`] that is discarded afterwards.
    pub fn rewrite(&self, query: QueryExpr) -> QueryExpr {
        let mut ctx = RewriteContext::default();
        self.rewrite_with_context(query, &mut ctx)
    }

    /// Applies every applicable rule, in registration order, repeatedly until a full pass
    /// leaves the query's `Debug` rendering unchanged or the pass limit is reached.
    ///
    /// Rules record statistics and diagnostics in `ctx`; it is not reset first.
    pub fn rewrite_with_context(
        &self,
        mut query: QueryExpr,
        ctx: &mut RewriteContext,
    ) -> QueryExpr {
        if self.max_iterations == 0 {
            return query;
        }
        // Change detection compares `Debug` renderings. The rendering taken at the end of one
        // pass is the baseline for the next, so each pass formats the query only once.
        let mut previous = format!("{:?}", query);
        for _ in 0..self.max_iterations {
            for rule in &self.rules {
                if rule.is_applicable(&query) {
                    query = rule.apply(query, ctx);
                }
            }
            let current = format!("{:?}", query);
            if current == previous {
                break;
            }
            previous = current;
        }
        query
    }
}
impl Default for QueryRewriter {
fn default() -> Self {
Self::new()
}
}
struct NormalizeRule;
impl RewriteRule for NormalizeRule {
fn name(&self) -> &str {
"Normalize"
}
fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
match query {
QueryExpr::Table(mut tq) => {
tq.columns.sort_by(|a, b| {
let a_name = projection_name(a);
let b_name = projection_name(b);
a_name.cmp(&b_name)
});
ctx.stats.expressions_normalized += 1;
QueryExpr::Table(tq)
}
QueryExpr::Graph(gq) => {
QueryExpr::Graph(gq)
}
QueryExpr::Join(jq) => {
let left = self.apply(*jq.left, ctx);
let right = self.apply(*jq.right, ctx);
QueryExpr::Join(JoinQuery {
left: Box::new(left),
right: Box::new(right),
..jq
})
}
QueryExpr::Path(pq) => QueryExpr::Path(pq),
QueryExpr::Vector(vq) => {
QueryExpr::Vector(vq)
}
QueryExpr::Hybrid(mut hq) => {
hq.structured = Box::new(self.apply(*hq.structured, ctx));
QueryExpr::Hybrid(hq)
}
other @ (QueryExpr::Insert(_)
| QueryExpr::Update(_)
| QueryExpr::Delete(_)
| QueryExpr::CreateTable(_)
| QueryExpr::CreateCollection(_)
| QueryExpr::CreateVector(_)
| QueryExpr::DropTable(_)
| QueryExpr::DropGraph(_)
| QueryExpr::DropVector(_)
| QueryExpr::DropDocument(_)
| QueryExpr::DropKv(_)
| QueryExpr::DropCollection(_)
| QueryExpr::Truncate(_)
| QueryExpr::AlterTable(_)
| QueryExpr::GraphCommand(_)
| QueryExpr::SearchCommand(_)
| QueryExpr::CreateIndex(_)
| QueryExpr::DropIndex(_)
| QueryExpr::ProbabilisticCommand(_)
| QueryExpr::Ask(_)
| QueryExpr::SetConfig { .. }
| QueryExpr::ShowConfig { .. }
| QueryExpr::SetSecret { .. }
| QueryExpr::DeleteSecret { .. }
| QueryExpr::ShowSecrets { .. }
| QueryExpr::SetTenant(_)
| QueryExpr::ShowTenant
| QueryExpr::CreateTimeSeries(_)
| QueryExpr::CreateMetric(_)
| QueryExpr::AlterMetric(_)
| QueryExpr::CreateSlo(_)
| QueryExpr::DropTimeSeries(_)
| QueryExpr::CreateQueue(_)
| QueryExpr::AlterQueue(_)
| QueryExpr::DropQueue(_)
| QueryExpr::QueueSelect(_)
| QueryExpr::QueueCommand(_)
| QueryExpr::KvCommand(_)
| QueryExpr::ConfigCommand(_)
| QueryExpr::CreateTree(_)
| QueryExpr::DropTree(_)
| QueryExpr::TreeCommand(_)
| QueryExpr::ExplainAlter(_)
| QueryExpr::TransactionControl(_)
| QueryExpr::MaintenanceCommand(_)
| QueryExpr::CreateSchema(_)
| QueryExpr::DropSchema(_)
| QueryExpr::CreateSequence(_)
| QueryExpr::DropSequence(_)
| QueryExpr::CopyFrom(_)
| QueryExpr::CreateView(_)
| QueryExpr::DropView(_)
| QueryExpr::RefreshMaterializedView(_)
| QueryExpr::CreatePolicy(_)
| QueryExpr::DropPolicy(_)
| QueryExpr::CreateServer(_)
| QueryExpr::DropServer(_)
| QueryExpr::CreateForeignTable(_)
| QueryExpr::DropForeignTable(_)
| QueryExpr::Grant(_)
| QueryExpr::Revoke(_)
| QueryExpr::AlterUser(_)
| QueryExpr::CreateUser(_)
| QueryExpr::CreateIamPolicy { .. }
| QueryExpr::DropIamPolicy { .. }
| QueryExpr::AttachPolicy { .. }
| QueryExpr::DetachPolicy { .. }
| QueryExpr::ShowPolicies { .. }
| QueryExpr::ShowEffectivePermissions { .. }
| QueryExpr::RankOf(_)
| QueryExpr::ApproxRankOf(_)
| QueryExpr::RankRange(_)
| QueryExpr::SimulatePolicy { .. }
| QueryExpr::LintPolicy { .. }
| QueryExpr::MigratePolicyMode { .. }
| QueryExpr::CreateMigration(_)
| QueryExpr::ApplyMigration(_)
| QueryExpr::RollbackMigration(_)
| QueryExpr::ExplainMigration(_)
| QueryExpr::EventsBackfill(_)
| QueryExpr::EventsBackfillStatus { .. }) => other,
}
}
fn is_applicable(&self, _query: &QueryExpr) -> bool {
true
}
}
struct SimplifyFiltersRule;
impl RewriteRule for SimplifyFiltersRule {
fn name(&self) -> &str {
"SimplifyFilters"
}
fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
match query {
QueryExpr::Table(mut tq) => {
if let Some(filter) = effective_table_filter(&tq) {
tq.filter = Some(simplify_filter(filter, ctx));
}
QueryExpr::Table(tq)
}
QueryExpr::Graph(mut gq) => {
if let Some(filter) = effective_graph_filter(&gq) {
gq.filter = Some(simplify_filter(filter, ctx));
}
QueryExpr::Graph(gq)
}
QueryExpr::Join(mut jq) => {
let join_filter = effective_join_filter(&jq);
let left = self.apply(*jq.left, ctx);
let right = self.apply(*jq.right, ctx);
if let Some(filter) = join_filter {
jq.filter = Some(simplify_filter(filter, ctx));
}
jq.left = Box::new(left);
jq.right = Box::new(right);
QueryExpr::Join(jq)
}
QueryExpr::Path(pq) => QueryExpr::Path(pq),
QueryExpr::Vector(vq) => {
QueryExpr::Vector(vq)
}
QueryExpr::Hybrid(mut hq) => {
hq.structured = Box::new(self.apply(*hq.structured, ctx));
QueryExpr::Hybrid(hq)
}
other @ (QueryExpr::Insert(_)
| QueryExpr::Update(_)
| QueryExpr::Delete(_)
| QueryExpr::CreateTable(_)
| QueryExpr::CreateCollection(_)
| QueryExpr::CreateVector(_)
| QueryExpr::DropTable(_)
| QueryExpr::DropGraph(_)
| QueryExpr::DropVector(_)
| QueryExpr::DropDocument(_)
| QueryExpr::DropKv(_)
| QueryExpr::DropCollection(_)
| QueryExpr::Truncate(_)
| QueryExpr::AlterTable(_)
| QueryExpr::GraphCommand(_)
| QueryExpr::SearchCommand(_)
| QueryExpr::CreateIndex(_)
| QueryExpr::DropIndex(_)
| QueryExpr::ProbabilisticCommand(_)
| QueryExpr::Ask(_)
| QueryExpr::SetConfig { .. }
| QueryExpr::ShowConfig { .. }
| QueryExpr::SetSecret { .. }
| QueryExpr::DeleteSecret { .. }
| QueryExpr::ShowSecrets { .. }
| QueryExpr::SetTenant(_)
| QueryExpr::ShowTenant
| QueryExpr::CreateTimeSeries(_)
| QueryExpr::CreateMetric(_)
| QueryExpr::AlterMetric(_)
| QueryExpr::CreateSlo(_)
| QueryExpr::DropTimeSeries(_)
| QueryExpr::CreateQueue(_)
| QueryExpr::AlterQueue(_)
| QueryExpr::DropQueue(_)
| QueryExpr::QueueSelect(_)
| QueryExpr::QueueCommand(_)
| QueryExpr::KvCommand(_)
| QueryExpr::ConfigCommand(_)
| QueryExpr::CreateTree(_)
| QueryExpr::DropTree(_)
| QueryExpr::TreeCommand(_)
| QueryExpr::ExplainAlter(_)
| QueryExpr::TransactionControl(_)
| QueryExpr::MaintenanceCommand(_)
| QueryExpr::CreateSchema(_)
| QueryExpr::DropSchema(_)
| QueryExpr::CreateSequence(_)
| QueryExpr::DropSequence(_)
| QueryExpr::CopyFrom(_)
| QueryExpr::CreateView(_)
| QueryExpr::DropView(_)
| QueryExpr::RefreshMaterializedView(_)
| QueryExpr::CreatePolicy(_)
| QueryExpr::DropPolicy(_)
| QueryExpr::CreateServer(_)
| QueryExpr::DropServer(_)
| QueryExpr::CreateForeignTable(_)
| QueryExpr::DropForeignTable(_)
| QueryExpr::Grant(_)
| QueryExpr::Revoke(_)
| QueryExpr::AlterUser(_)
| QueryExpr::CreateUser(_)
| QueryExpr::CreateIamPolicy { .. }
| QueryExpr::DropIamPolicy { .. }
| QueryExpr::AttachPolicy { .. }
| QueryExpr::DetachPolicy { .. }
| QueryExpr::ShowPolicies { .. }
| QueryExpr::ShowEffectivePermissions { .. }
| QueryExpr::RankOf(_)
| QueryExpr::ApproxRankOf(_)
| QueryExpr::RankRange(_)
| QueryExpr::SimulatePolicy { .. }
| QueryExpr::LintPolicy { .. }
| QueryExpr::MigratePolicyMode { .. }
| QueryExpr::CreateMigration(_)
| QueryExpr::ApplyMigration(_)
| QueryExpr::RollbackMigration(_)
| QueryExpr::ExplainMigration(_)
| QueryExpr::EventsBackfill(_)
| QueryExpr::EventsBackfillStatus { .. }) => other,
}
}
fn is_applicable(&self, query: &QueryExpr) -> bool {
match query {
QueryExpr::Table(tq) => effective_table_filter(tq).is_some(),
QueryExpr::Graph(gq) => effective_graph_filter(gq).is_some(),
QueryExpr::Join(_) => true,
QueryExpr::Path(_) => false,
QueryExpr::Vector(vq) => effective_vector_filter(vq).is_some(),
QueryExpr::Hybrid(_) => true, QueryExpr::Insert(_)
| QueryExpr::Update(_)
| QueryExpr::Delete(_)
| QueryExpr::CreateTable(_)
| QueryExpr::CreateCollection(_)
| QueryExpr::CreateVector(_)
| QueryExpr::DropTable(_)
| QueryExpr::DropGraph(_)
| QueryExpr::DropVector(_)
| QueryExpr::DropDocument(_)
| QueryExpr::DropKv(_)
| QueryExpr::DropCollection(_)
| QueryExpr::Truncate(_)
| QueryExpr::AlterTable(_)
| QueryExpr::GraphCommand(_)
| QueryExpr::SearchCommand(_)
| QueryExpr::CreateIndex(_)
| QueryExpr::DropIndex(_)
| QueryExpr::ProbabilisticCommand(_)
| QueryExpr::Ask(_)
| QueryExpr::SetConfig { .. }
| QueryExpr::ShowConfig { .. }
| QueryExpr::SetSecret { .. }
| QueryExpr::DeleteSecret { .. }
| QueryExpr::ShowSecrets { .. }
| QueryExpr::SetTenant(_)
| QueryExpr::ShowTenant
| QueryExpr::CreateTimeSeries(_)
| QueryExpr::CreateMetric(_)
| QueryExpr::AlterMetric(_)
| QueryExpr::CreateSlo(_)
| QueryExpr::DropTimeSeries(_)
| QueryExpr::CreateQueue(_)
| QueryExpr::AlterQueue(_)
| QueryExpr::DropQueue(_)
| QueryExpr::QueueSelect(_)
| QueryExpr::QueueCommand(_)
| QueryExpr::KvCommand(_)
| QueryExpr::ConfigCommand(_)
| QueryExpr::CreateTree(_)
| QueryExpr::DropTree(_)
| QueryExpr::TreeCommand(_)
| QueryExpr::ExplainAlter(_)
| QueryExpr::TransactionControl(_)
| QueryExpr::MaintenanceCommand(_)
| QueryExpr::CreateSchema(_)
| QueryExpr::DropSchema(_)
| QueryExpr::CreateSequence(_)
| QueryExpr::DropSequence(_)
| QueryExpr::CopyFrom(_)
| QueryExpr::CreateView(_)
| QueryExpr::DropView(_)
| QueryExpr::RefreshMaterializedView(_)
| QueryExpr::CreatePolicy(_)
| QueryExpr::DropPolicy(_)
| QueryExpr::CreateServer(_)
| QueryExpr::DropServer(_)
| QueryExpr::CreateForeignTable(_)
| QueryExpr::DropForeignTable(_)
| QueryExpr::Grant(_)
| QueryExpr::Revoke(_)
| QueryExpr::AlterUser(_)
| QueryExpr::CreateUser(_)
| QueryExpr::CreateIamPolicy { .. }
| QueryExpr::DropIamPolicy { .. }
| QueryExpr::AttachPolicy { .. }
| QueryExpr::DetachPolicy { .. }
| QueryExpr::ShowPolicies { .. }
| QueryExpr::ShowEffectivePermissions { .. }
| QueryExpr::RankOf(_)
| QueryExpr::ApproxRankOf(_)
| QueryExpr::RankRange(_)
| QueryExpr::SimulatePolicy { .. }
| QueryExpr::LintPolicy { .. }
| QueryExpr::MigratePolicyMode { .. }
| QueryExpr::CreateMigration(_)
| QueryExpr::ApplyMigration(_)
| QueryExpr::RollbackMigration(_)
| QueryExpr::ExplainMigration(_)
| QueryExpr::EventsBackfill(_)
| QueryExpr::EventsBackfillStatus { .. } => false,
}
}
}
/// Join-only pass named for predicate pushdown.
///
/// As written it moves no predicates. It recurses through nested joins and bumps
/// `ctx.stats.predicates_pushed` once for every join it visits (inner joins first).
struct PushdownPredicatesRule;
impl RewriteRule for PushdownPredicatesRule {
    fn name(&self) -> &str {
        "PushdownPredicates"
    }
    fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
        match query {
            QueryExpr::Join(mut join) => {
                let left_side = self.apply(*join.left, ctx);
                let right_side = self.apply(*join.right, ctx);
                join.left = Box::new(left_side);
                join.right = Box::new(right_side);
                ctx.stats.predicates_pushed += 1;
                QueryExpr::Join(join)
            }
            passthrough => passthrough,
        }
    }
    fn is_applicable(&self, query: &QueryExpr) -> bool {
        matches!(query, QueryExpr::Join(_))
    }
}
struct EliminateDeadCodeRule;
impl RewriteRule for EliminateDeadCodeRule {
fn name(&self) -> &str {
"EliminateDeadCode"
}
fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
match query {
QueryExpr::Table(mut tq) => {
if let Some(filter) = effective_table_filter(&tq).as_ref() {
if is_always_true(filter) {
tq.filter = None;
}
}
QueryExpr::Table(tq)
}
other => other,
}
}
fn is_applicable(&self, query: &QueryExpr) -> bool {
matches!(query, QueryExpr::Table(_))
}
}
/// Placeholder constant-folding pass.
///
/// It is part of the default pipeline, but currently returns every query unchanged.
struct FoldConstantsRule;
impl RewriteRule for FoldConstantsRule {
    fn is_applicable(&self, _query: &QueryExpr) -> bool {
        // Accepts everything; harmless because `apply` is the identity.
        true
    }
    fn apply(&self, query: QueryExpr, _ctx: &mut RewriteContext) -> QueryExpr {
        // No folding is implemented yet: hand the query back untouched.
        query
    }
    fn name(&self) -> &str {
        "FoldConstants"
    }
}
fn projection_name(proj: &Projection) -> String {
match proj {
Projection::All => "*".to_string(),
Projection::Column(name) => name.clone(),
Projection::Alias(_, alias) => alias.clone(),
Projection::Function(name, _) => name
.split_once(':')
.map(|(_, alias)| alias.to_string())
.unwrap_or_else(|| name.clone()),
Projection::Expression(expr, alias) => {
alias.clone().unwrap_or_else(|| format!("{:?}", expr))
}
Projection::Field(field, alias) => alias.clone().unwrap_or_else(|| format!("{:?}", field)),
Projection::Window { name, alias, .. } => alias.clone().unwrap_or_else(|| name.clone()),
}
}
/// Recursively folds boolean structure around the constant sentinels `1 = 1` (always true)
/// and `1 = 0` (always false).
///
/// Children are simplified first, then these rules apply:
/// - `TRUE AND x` / `x AND TRUE` → `x`; `FALSE AND x` / `x AND FALSE` → `FALSE`
/// - `FALSE OR x` / `x OR FALSE` → `x`; `TRUE OR x` / `x OR TRUE` → `TRUE`
/// - `NOT TRUE` → `FALSE`; `NOT FALSE` → `TRUE`; `NOT NOT x` → `x`
///
/// Every reduction increments `ctx.stats.filters_simplified`. Any other filter shape is
/// returned unchanged.
fn simplify_filter(filter: AstFilter, ctx: &mut RewriteContext) -> AstFilter {
    // Builds the sentinel comparison used in this file for a constant boolean:
    // `1 = 1` for true, `1 = 0` for false.
    fn constant(truth: bool) -> AstFilter {
        AstFilter::Compare {
            field: FieldRef::TableColumn {
                table: String::new(),
                column: "1".to_string(),
            },
            op: CompareOp::Eq,
            value: Value::Integer(if truth { 1 } else { 0 }),
        }
    }

    match filter {
        AstFilter::And(left, right) => {
            let left = simplify_filter(*left, ctx);
            let right = simplify_filter(*right, ctx);
            if is_always_true(&left) {
                ctx.stats.filters_simplified += 1;
                return right;
            }
            if is_always_true(&right) {
                ctx.stats.filters_simplified += 1;
                return left;
            }
            if is_always_false(&left) || is_always_false(&right) {
                ctx.stats.filters_simplified += 1;
                return constant(false);
            }
            AstFilter::And(Box::new(left), Box::new(right))
        }
        AstFilter::Or(left, right) => {
            let left = simplify_filter(*left, ctx);
            let right = simplify_filter(*right, ctx);
            if is_always_false(&left) {
                ctx.stats.filters_simplified += 1;
                return right;
            }
            if is_always_false(&right) {
                ctx.stats.filters_simplified += 1;
                return left;
            }
            if is_always_true(&left) || is_always_true(&right) {
                ctx.stats.filters_simplified += 1;
                return constant(true);
            }
            AstFilter::Or(Box::new(left), Box::new(right))
        }
        AstFilter::Not(inner) => {
            let inner = simplify_filter(*inner, ctx);
            // Negating a constant yields the opposite constant, which lets enclosing
            // AND/OR nodes keep folding.
            if is_always_true(&inner) {
                ctx.stats.filters_simplified += 1;
                return constant(false);
            }
            if is_always_false(&inner) {
                ctx.stats.filters_simplified += 1;
                return constant(true);
            }
            if let AstFilter::Not(double_inner) = inner {
                ctx.stats.filters_simplified += 1;
                return *double_inner;
            }
            AstFilter::Not(Box::new(inner))
        }
        other => other,
    }
}
/// Reports whether `filter` is the always-true sentinel: a `Compare` on a table column
/// literally named `"1"` (any table), with operator `Eq` and value `Integer(1)`.
fn is_always_true(filter: &AstFilter) -> bool {
    matches!(
        filter,
        AstFilter::Compare {
            field: FieldRef::TableColumn { column, .. },
            op: CompareOp::Eq,
            value: Value::Integer(1),
        } if column == "1"
    )
}
/// Reports whether `filter` is the always-false sentinel: a `Compare` on a table column
/// literally named `"1"` (any table), with operator `Eq` and value `Integer(0)`.
fn is_always_false(filter: &AstFilter) -> bool {
    match filter {
        AstFilter::Compare {
            field: FieldRef::TableColumn { column, .. },
            op: CompareOp::Eq,
            value: Value::Integer(0),
        } => column == "1",
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{JoinCondition, TableQuery, WindowSpec};

    /// Unqualified column reference (empty table name).
    fn column_ref(name: &str) -> FieldRef {
        FieldRef::TableColumn {
            table: String::new(),
            column: name.to_string(),
        }
    }

    /// `column <op> value` against an unqualified column.
    fn compare(column: &str, op: CompareOp, value: Value) -> AstFilter {
        AstFilter::Compare {
            field: column_ref(column),
            op,
            value,
        }
    }

    /// The always-true sentinel `1 = 1`.
    fn always_true() -> AstFilter {
        compare("1", CompareOp::Eq, Value::Integer(1))
    }

    /// The always-false sentinel `1 = 0`.
    fn always_false() -> AstFilter {
        compare("1", CompareOp::Eq, Value::Integer(0))
    }

    /// Visible output names of a projection list, in order.
    fn output_names(columns: &[Projection]) -> Vec<String> {
        columns.iter().map(projection_name).collect()
    }

    /// True when `filter` is a `Compare` whose field is the table column `name`.
    fn compares_column(filter: &AstFilter, name: &str) -> bool {
        matches!(
            filter,
            AstFilter::Compare { field: FieldRef::TableColumn { column, .. }, .. } if column == name
        )
    }

    #[test]
    fn test_simplify_and_with_true() {
        let mut ctx = RewriteContext::default();
        let filter = AstFilter::And(
            Box::new(always_true()),
            Box::new(compare("x", CompareOp::Eq, Value::Integer(5))),
        );
        let simplified = simplify_filter(filter, &mut ctx);
        assert!(compares_column(&simplified, "x"), "Expected Compare filter");
    }

    #[test]
    fn test_simplify_double_not() {
        let mut ctx = RewriteContext::default();
        let inner = compare("x", CompareOp::Eq, Value::Integer(5));
        let filter = AstFilter::Not(Box::new(AstFilter::Not(Box::new(inner))));
        let simplified = simplify_filter(filter, &mut ctx);
        assert!(compares_column(&simplified, "x"), "Expected Compare filter");
    }

    #[test]
    fn projection_name_uses_visible_output_name_for_all_projection_shapes() {
        let cases: Vec<(Projection, &str)> = vec![
            (Projection::All, "*"),
            (Projection::Column("raw".to_string()), "raw"),
            (
                Projection::Alias("raw".to_string(), "alias".to_string()),
                "alias",
            ),
            (
                Projection::Function("LOWER:display".to_string(), Vec::new()),
                "display",
            ),
            (
                Projection::Expression(
                    Box::new(compare("x", CompareOp::Eq, Value::Integer(1))),
                    Some("expr_alias".to_string()),
                ),
                "expr_alias",
            ),
            (
                Projection::Field(
                    FieldRef::node_prop("n", "name"),
                    Some("node_name".to_string()),
                ),
                "node_name",
            ),
            (
                Projection::Window {
                    name: "ROW_NUMBER".to_string(),
                    args: Vec::new(),
                    window: Box::new(WindowSpec::default()),
                    alias: Some("rn".to_string()),
                },
                "rn",
            ),
        ];
        for (projection, expected) in &cases {
            assert_eq!(projection_name(projection), *expected);
        }
    }

    #[test]
    fn normalize_rule_sorts_table_columns_by_output_name() {
        let mut table = TableQuery::new("users");
        table.columns = vec![
            Projection::Column("z".to_string()),
            Projection::Function("LOWER:a_alias".to_string(), Vec::new()),
            Projection::Alias("name".to_string(), "m".to_string()),
        ];
        let mut ctx = RewriteContext::default();
        let QueryExpr::Table(table) = NormalizeRule.apply(QueryExpr::Table(table), &mut ctx)
        else {
            panic!("expected table query");
        };
        assert_eq!(ctx.stats.expressions_normalized, 1);
        assert_eq!(output_names(&table.columns), vec!["a_alias", "m", "z"]);
    }

    #[test]
    fn simplify_filter_covers_or_true_false_and_not_paths() {
        let mut ctx = RewriteContext::default();
        let predicate = compare("x", CompareOp::Eq, Value::Integer(5));

        let false_or_predicate =
            AstFilter::Or(Box::new(always_false()), Box::new(predicate.clone()));
        assert_eq!(simplify_filter(false_or_predicate, &mut ctx), predicate);

        let true_or_false = AstFilter::Or(Box::new(always_true()), Box::new(always_false()));
        assert!(is_always_true(&simplify_filter(true_or_false, &mut ctx)));

        let false_and_not_null = AstFilter::And(
            Box::new(always_false()),
            Box::new(AstFilter::IsNotNull(column_ref("x"))),
        );
        assert!(is_always_false(&simplify_filter(false_and_not_null, &mut ctx)));

        assert!(ctx.stats.filters_simplified >= 3);
    }

    #[test]
    fn query_rewriter_runs_rules_until_fixed_point_and_exposes_context() {
        let mut table = TableQuery::new("users");
        table.filter = Some(AstFilter::And(
            Box::new(always_true()),
            Box::new(compare("age", CompareOp::Ge, Value::Integer(18))),
        ));
        let mut ctx = RewriteContext::default();
        let rewritten =
            QueryRewriter::default().rewrite_with_context(QueryExpr::Table(table), &mut ctx);
        let QueryExpr::Table(table) = rewritten else {
            panic!("expected table query");
        };
        assert!(matches!(
            table.filter,
            Some(AstFilter::Compare {
                field: FieldRef::TableColumn { column, .. },
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }) if column == "age"
        ));
        assert!(ctx.stats.filters_simplified >= 1);
    }

    #[test]
    fn query_rewriter_recurses_into_join_children_and_tracks_pushdown() {
        let mut users = TableQuery::new("users");
        users.columns = vec![
            Projection::Column("z".to_string()),
            Projection::Column("a".to_string()),
        ];
        users.filter = Some(AstFilter::And(
            Box::new(always_true()),
            Box::new(compare("age", CompareOp::Ge, Value::Integer(18))),
        ));
        let join = JoinQuery::new(
            QueryExpr::Table(users),
            QueryExpr::Table(TableQuery::new("orders")),
            JoinCondition::new(column_ref("id"), column_ref("user_id")),
        );
        let mut ctx = RewriteContext::default();
        let rewritten =
            QueryRewriter::default().rewrite_with_context(QueryExpr::Join(join), &mut ctx);
        let QueryExpr::Join(join) = rewritten else {
            panic!("expected join query");
        };
        let QueryExpr::Table(left) = join.left.as_ref() else {
            panic!("expected table on left side");
        };
        assert_eq!(output_names(&left.columns), vec!["a", "z"]);
        assert!(matches!(
            &left.filter,
            Some(AstFilter::Compare {
                field: FieldRef::TableColumn { ref column, .. },
                op: CompareOp::Ge,
                value: Value::Integer(18),
            }) if column == "age"
        ));
        assert!(ctx.stats.predicates_pushed >= 1);
    }

    #[test]
    fn query_rewriter_eliminates_always_true_table_filters() {
        let mut table = TableQuery::new("users");
        table.filter = Some(always_true());
        let QueryExpr::Table(table) = QueryRewriter::default().rewrite(QueryExpr::Table(table))
        else {
            panic!("expected table query");
        };
        assert!(table.filter.is_none());
    }

    /// Test rule that records its own name as a warning every time it runs on a table.
    struct CountingRule;
    impl RewriteRule for CountingRule {
        fn name(&self) -> &str {
            "CountingRule"
        }
        fn apply(&self, query: QueryExpr, ctx: &mut RewriteContext) -> QueryExpr {
            ctx.warnings.push(self.name().to_string());
            query
        }
        fn is_applicable(&self, query: &QueryExpr) -> bool {
            matches!(query, QueryExpr::Table(_))
        }
    }

    #[test]
    fn custom_rules_can_be_added_after_defaults() {
        let mut rewriter = QueryRewriter::new();
        rewriter.add_rule(Box::new(CountingRule));
        let mut ctx = RewriteContext::default();
        let rewritten =
            rewriter.rewrite_with_context(QueryExpr::Table(TableQuery::new("users")), &mut ctx);
        assert!(matches!(rewritten, QueryExpr::Table(_)));
        assert!(ctx.warnings.iter().any(|warning| warning == "CountingRule"));
    }
}