use super::*;
use crate::catalog::ColumnRefId;
use crate::v1::binder::*;
/// Rewrites the expressions of a plan tree into [`BoundExpr::InputRef`]s.
///
/// `bindings` is replaced each time a plan node that defines new output
/// expressions (scan, projection, aggregate) has been rewritten. While the
/// expressions of a node are being resolved, an expression equal to
/// `bindings[i]` is replaced by an input reference with index `i`.
#[derive(Default)]
pub struct InputRefResolver {
// Position `i` holds the expression that input reference `i` stands for.
bindings: Vec<Option<BoundExpr>>,
}
impl InputRefResolver {
fn rewrite_template(&self, expr: &mut BoundExpr) {
use BoundExpr::*;
if let Some(idx) = self
.bindings
.iter()
.position(|col| *col == Some(expr.clone()))
{
*expr = InputRef(BoundInputRef {
index: idx,
return_type: expr.return_type(),
});
return;
}
match expr {
BoundExpr::BinaryOp(op) => {
self.rewrite_expr(op.left_expr.as_mut());
self.rewrite_expr(op.right_expr.as_mut());
}
BoundExpr::UnaryOp(expr) => self.rewrite_expr(expr.expr.as_mut()),
BoundExpr::TypeCast(expr) => self.rewrite_expr(expr.expr.as_mut()),
BoundExpr::IsNull(expr) => self.rewrite_expr(expr.expr.as_mut()),
_ => unreachable!(),
}
}
}
/// Expression hooks: every expression kind handled here goes through the same
/// binding lookup in `rewrite_template`.
impl ExprRewriter for InputRefResolver {
    /// Resolves a column reference against the current bindings.
    fn rewrite_column_ref(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }

    /// Resolves an aggregate call against the current bindings.
    fn rewrite_agg_call(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }

    /// Resolves a binary operation, or its operands, against the bindings.
    fn rewrite_binary_op(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }

    /// Resolves a unary operation, or its operand, against the bindings.
    fn rewrite_unary_op(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }

    /// Resolves a type cast, or its operand, against the bindings.
    fn rewrite_type_cast(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }

    /// Resolves an `IS NULL` expression, or its operand, against the bindings.
    fn rewrite_is_null(&self, expr: &mut BoundExpr) {
        Self::rewrite_template(self, expr)
    }
}
/// Plan hooks: children are rewritten first, so `self.bindings` describes the
/// child's output by the time this resolver is handed to the node's
/// `clone_with_rewrite_expr`.
impl PlanRewriter for InputRefResolver {
    /// Rewrites both join inputs. The right input is rewritten by a fresh
    /// resolver, and its bindings are appended after the left ones before the
    /// join itself is rebuilt.
    fn rewrite_logical_join(&mut self, join: &LogicalJoin) -> PlanRef {
        let left_plan = self.rewrite(join.left());
        let mut right_resolver = Self::default();
        let right_plan = right_resolver.rewrite(join.right());
        self.bindings.extend(right_resolver.bindings.drain(..));
        Arc::new(join.clone_with_rewrite_expr(left_plan, right_plan, self))
    }

    /// Replaces the bindings with one column reference per scanned column, in
    /// scan order, and returns a clone of the scan.
    fn rewrite_logical_table_scan(&mut self, plan: &LogicalTableScan) -> PlanRef {
        let mut scanned = Vec::new();
        for (col_id, col_desc) in plan.column_ids().iter().zip(plan.column_descs()) {
            let column = BoundColumnRef {
                column_ref_id: ColumnRefId::from_table(plan.table_ref_id(), *col_id),
                is_primary_key: col_desc.is_primary(),
                desc: col_desc.clone(),
            };
            scanned.push(Some(BoundExpr::ColumnRef(column)));
        }
        self.bindings = scanned;
        Arc::new(plan.clone())
    }

    /// Same as the table-scan hook, for an `Internal` plan node: the bindings
    /// become one column reference per column of the node.
    fn rewrite_internal(&mut self, plan: &Internal) -> PlanRef {
        let mut columns = Vec::new();
        for (col_id, col_desc) in plan.column_ids().iter().zip(plan.column_descs()) {
            let column = BoundColumnRef {
                column_ref_id: ColumnRefId::from_table(plan.table_ref_id(), *col_id),
                is_primary_key: col_desc.is_primary(),
                desc: col_desc.clone(),
            };
            columns.push(Some(BoundExpr::ColumnRef(column)));
        }
        self.bindings = columns;
        Arc::new(plan.clone())
    }

    /// Rebuilds the projection while the bindings still describe the child,
    /// then replaces the bindings with the projected expressions (an aliased
    /// expression is recorded without its alias wrapper).
    fn rewrite_logical_projection(&mut self, proj: &LogicalProjection) -> PlanRef {
        let child = self.rewrite(proj.child());
        let projected: Vec<Option<BoundExpr>> = proj
            .project_expressions()
            .iter()
            .map(|item| {
                Some(match item {
                    BoundExpr::ExprWithAlias(aliased) => (*aliased.expr).clone(),
                    other => other.clone(),
                })
            })
            .collect();
        let rewritten = Arc::new(proj.clone_with_rewrite_expr(child, self));
        self.bindings = projected;
        rewritten
    }

    /// Rebuilds the aggregate while the bindings still describe the child,
    /// then replaces the bindings with the group keys followed by the
    /// aggregate calls.
    fn rewrite_logical_aggregate(&mut self, agg: &LogicalAggregate) -> PlanRef {
        let child = self.rewrite(agg.child());
        let produced: Vec<Option<BoundExpr>> = agg
            .group_keys()
            .iter()
            .cloned()
            .chain(agg.agg_calls().iter().cloned().map(BoundExpr::AggCall))
            .map(Some)
            .collect();
        let rewritten = Arc::new(agg.clone_with_rewrite_expr(child, self));
        self.bindings = produced;
        rewritten
    }

    /// Rewrites the child and rebuilds the filter; the bindings are left as
    /// the child set them.
    fn rewrite_logical_filter(&mut self, plan: &LogicalFilter) -> PlanRef {
        let input = self.rewrite(plan.child());
        Arc::new(plan.clone_with_rewrite_expr(input, self))
    }

    /// Rewrites the child and rebuilds the order node; the bindings are left
    /// as the child set them.
    fn rewrite_logical_order(&mut self, plan: &LogicalOrder) -> PlanRef {
        let input = self.rewrite(plan.child());
        Arc::new(plan.clone_with_rewrite_expr(input, self))
    }

    /// Rebuilds the values node with this resolver; there is no child to
    /// rewrite and the bindings are not modified.
    fn rewrite_logical_values(&mut self, plan: &LogicalValues) -> PlanRef {
        Arc::new(plan.clone_with_rewrite_expr(self))
    }
}
/// Rewrites select-list expressions into input references over group keys
/// and aggregate calls.
#[derive(Default)]
#[allow(dead_code)]
struct AggInputRefResolver {
// Input index handed to the next aggregate call that is encountered;
// `new` initialises it to the number of group keys.
agg_start_index: usize,
}
impl AggInputRefResolver {
    /// Creates a resolver whose first aggregate call is mapped to input index
    /// `group_key_count`.
    #[allow(dead_code)]
    fn new(group_key_count: usize) -> Self {
        AggInputRefResolver {
            agg_start_index: group_key_count,
        }
    }

    /// Rewrites `expr` in place into input references.
    ///
    /// * An expression equal to `group_keys[i]` becomes input reference `i`.
    /// * Each aggregate call encountered becomes an input reference with the
    ///   next free index, starting at the value given to [`Self::new`] and
    ///   increasing by one per call.
    /// * Binary ops, unary ops, type casts, aliased expressions and `IS NULL`
    ///   expressions are resolved recursively through their operands.
    /// * Constants, column references, input references and aliases are left
    ///   untouched.
    ///
    /// `group_keys` is taken as a slice; callers holding a `Vec` can keep
    /// passing `&vec`.
    #[allow(dead_code)]
    fn resolve_select_expr(&mut self, expr: &mut BoundExpr, group_keys: &[BoundExpr]) {
        use BoundExpr::*;
        // Group keys take precedence over structural recursion, so a composite
        // expression that is itself a group key is replaced as a whole.
        if let Some(key_index) = group_keys.iter().position(|key| *key == *expr) {
            *expr = InputRef(BoundInputRef {
                index: key_index,
                return_type: expr.return_type(),
            });
            return;
        }
        match expr {
            AggCall(agg) => {
                *expr = InputRef(BoundInputRef {
                    index: self.agg_start_index,
                    return_type: agg.return_type.clone(),
                });
                self.agg_start_index += 1;
            }
            BinaryOp(bin_op) => {
                self.resolve_select_expr(&mut bin_op.left_expr, group_keys);
                self.resolve_select_expr(&mut bin_op.right_expr, group_keys);
            }
            UnaryOp(unary_op) => self.resolve_select_expr(&mut unary_op.expr, group_keys),
            TypeCast(type_cast) => self.resolve_select_expr(&mut type_cast.expr, group_keys),
            ExprWithAlias(expr_with_alias) => {
                self.resolve_select_expr(&mut expr_with_alias.expr, group_keys)
            }
            IsNull(isnull) => self.resolve_select_expr(&mut isnull.expr, group_keys),
            Constant(_) | ColumnRef(_) | InputRef(_) | Alias(_) => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use sqlparser::ast::BinaryOperator;

    use super::*;
    use crate::types::{DataTypeKind, DataValue};
    use crate::v1::binder::AggKind;

    /// Builds the aggregate call `sum(v1)` over a non-null Int32 column.
    fn sum_v1_call() -> BoundExpr {
        BoundExpr::AggCall(BoundAggCall {
            kind: AggKind::Sum,
            args: vec![BoundExpr::ColumnRef(BoundColumnRef {
                column_ref_id: ColumnRefId::new(0, 0, 0, 0),
                is_primary_key: false,
                desc: DataTypeKind::Int32.not_null().to_column("v1".into()),
            })],
            return_type: DataTypeKind::Int32.not_null(),
        })
    }

    /// Builds the expression `v2 + 1` over a non-null Int32 column.
    fn v2_plus_1_expr() -> BoundExpr {
        let v2_expr = BoundExpr::ColumnRef(BoundColumnRef {
            column_ref_id: ColumnRefId::new(0, 0, 0, 1),
            is_primary_key: false,
            desc: DataTypeKind::Int32.not_null().to_column("v2".into()),
        });
        BoundExpr::BinaryOp(BoundBinaryOp {
            op: BinaryOperator::Plus,
            left_expr: v2_expr.into(),
            right_expr: BoundExpr::Constant(DataValue::Int32(1)).into(),
            return_type: DataTypeKind::Int32.not_null(),
        })
    }

    /// Builds the expected non-null Int32 input reference at `index`.
    fn int32_input_ref(index: usize) -> BoundExpr {
        BoundExpr::InputRef(BoundInputRef {
            index,
            return_type: DataTypeKind::Int32.not_null(),
        })
    }

    /// `SELECT v2 + 1, sum(v1) ... GROUP BY v2 + 1`: the group key resolves to
    /// input 0 and the aggregate call to input 1.
    #[test]
    fn test_resolve_select_expr() {
        let group_keys = vec![v2_plus_1_expr()];
        let mut select_list = vec![v2_plus_1_expr(), sum_v1_call()];

        let mut resolver = AggInputRefResolver::new(group_keys.len());
        for expr in &mut select_list {
            resolver.resolve_select_expr(expr, &group_keys);
        }

        assert_eq!(select_list[0], int32_input_ref(0));
        assert_eq!(select_list[1], int32_input_ref(1));
    }

    /// `SELECT (v2 + 1) + sum(v1) ... GROUP BY v2 + 1`: both operands of the
    /// outer addition are resolved, the outer addition itself is kept.
    #[test]
    fn test_resolve_select_expr_plus_agg_call() {
        let v2_plus_1_plus_sum_expr = BoundExpr::BinaryOp(BoundBinaryOp {
            op: BinaryOperator::Plus,
            left_expr: v2_plus_1_expr().into(),
            right_expr: sum_v1_call().into(),
            return_type: DataTypeKind::Int32.not_null(),
        });
        let group_keys = vec![v2_plus_1_expr()];
        let mut select_list = vec![v2_plus_1_plus_sum_expr];

        let mut resolver = AggInputRefResolver::new(group_keys.len());
        for expr in &mut select_list {
            resolver.resolve_select_expr(expr, &group_keys);
        }

        let expected = BoundExpr::BinaryOp(BoundBinaryOp {
            op: BinaryOperator::Plus,
            left_expr: int32_input_ref(0).into(),
            right_expr: int32_input_ref(1).into(),
            return_type: DataTypeKind::Int32.not_null(),
        });
        assert_eq!(select_list[0], expected);
    }
}