use crate::error::{DbxError, DbxResult};
use crate::sql::planner::types::ScalarValue;
impl Database {
    /// Parses, plans and optimizes `sql` once so it can later be run with
    /// different parameter values via [`Database::execute_prepared`].
    ///
    /// Each `?` placeholder that appears outside string literals, double-quoted
    /// identifiers and comments is temporarily rewritten to a `NULL` literal so
    /// the statement can go through the normal parse → plan → optimize pipeline.
    /// The number of rewritten placeholders is stored on the returned statement
    /// together with the original SQL text.
    ///
    /// NOTE(review): after this rewrite a placeholder is indistinguishable from a
    /// `NULL` literal the user wrote; parameter binding treats every `NULL`
    /// literal in the plan as a placeholder.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the SQL parser, planner or optimizer.
    pub fn prepare(&self, sql: &str) -> DbxResult<crate::engine::prepared::PreparedStatement> {
        // Quote/comment-aware: a `?` inside `'...'` is data, not a placeholder,
        // and must neither be counted nor replaced.
        let (temp_sql, placeholder_count) = Self::substitute_placeholders(sql);
        let ast = self.sql_parser.parse(&temp_sql)?;
        let logical_plan = self.sql_planner.plan(&ast)?;
        let physical_plan = self.sql_optimizer.optimize(logical_plan)?;
        Ok(crate::engine::prepared::PreparedStatement::new(
            sql.to_string(),
            physical_plan,
            placeholder_count,
        ))
    }

    /// Executes a statement produced by [`Database::prepare`], binding `params`
    /// to its placeholders in order.
    ///
    /// # Errors
    ///
    /// Returns an error if `stmt.validate_params` rejects `params` (presumably a
    /// count mismatch — confirm against `PreparedStatement`), if binding the
    /// parameters into the plan fails, or if executing the bound plan fails.
    pub fn execute_prepared(
        &self,
        stmt: &crate::engine::prepared::PreparedStatement,
        params: &[ScalarValue],
    ) -> DbxResult<Vec<arrow::record_batch::RecordBatch>> {
        stmt.validate_params(params)?;
        let bound_plan = bind_parameters_to_plan(&stmt.plan, params)?;
        self.sql_interface.execute_plan(&bound_plan)
    }

    /// Rewrites every `?` placeholder in `sql` to `NULL` and counts them, in a
    /// single pass.
    ///
    /// `?` characters inside the following regions are copied through untouched,
    /// so they are neither replaced nor counted:
    /// - single-quoted string literals;
    /// - double-quoted identifiers;
    /// - `-- …` line comments;
    /// - `/* … */` block comments.
    ///
    /// A doubled quote (`''` / `""`) closes and immediately reopens the quoted
    /// run, so escaped quotes stay inside the literal.
    fn substitute_placeholders(sql: &str) -> (String, usize) {
        let mut rewritten = String::with_capacity(sql.len());
        let mut count = 0;
        let mut chars = sql.chars().peekable();
        while let Some(c) = chars.next() {
            match c {
                '\'' | '"' => {
                    rewritten.push(c);
                    for inner in chars.by_ref() {
                        rewritten.push(inner);
                        if inner == c {
                            break;
                        }
                    }
                }
                '-' if chars.peek() == Some(&'-') => {
                    rewritten.push(c);
                    for inner in chars.by_ref() {
                        rewritten.push(inner);
                        if inner == '\n' {
                            break;
                        }
                    }
                }
                '/' if chars.peek() == Some(&'*') => {
                    rewritten.push(c);
                    // Consume the opening `*` so `/*/` is not mistaken for a closed comment.
                    if let Some(star) = chars.next() {
                        rewritten.push(star);
                    }
                    let mut prev = '\0';
                    for inner in chars.by_ref() {
                        rewritten.push(inner);
                        if prev == '*' && inner == '/' {
                            break;
                        }
                        prev = inner;
                    }
                }
                '?' => {
                    rewritten.push_str("NULL");
                    count += 1;
                }
                _ => rewritten.push(c),
            }
        }
        (rewritten, count)
    }
}
fn bind_parameters_to_plan(
plan: &crate::sql::planner::types::PhysicalPlan,
params: &[ScalarValue],
) -> DbxResult<crate::sql::planner::types::PhysicalPlan> {
use crate::sql::planner::types::PhysicalPlan;
let mut param_index = 0;
match plan {
PhysicalPlan::TableScan { table, projection, filter } => {
let bound_filter = if let Some(f) = filter {
Some(bind_expr(f, params, &mut param_index)?)
} else {
None
};
Ok(PhysicalPlan::TableScan {
table: table.clone(),
projection: projection.clone(),
filter: bound_filter,
})
}
other => Ok(other.clone()),
}
}
fn bind_expr(
expr: &crate::sql::planner::types::Expr,
params: &[ScalarValue],
param_index: &mut usize,
) -> DbxResult<crate::sql::planner::types::Expr> {
use crate::sql::planner::types::Expr;
match expr {
Expr::Literal(ScalarValue::Null) => {
if *param_index >= params.len() {
return Err(DbxError::Schema(
format!("Not enough parameters: need {}, got {}", param_index + 1, params.len())
));
}
let result = Expr::Literal(params[*param_index].clone());
*param_index += 1;
Ok(result)
}
Expr::BinaryOp { left, op, right } => {
let bound_left = bind_expr(left, params, param_index)?;
let bound_right = bind_expr(right, params, param_index)?;
Ok(Expr::BinaryOp {
left: Box::new(bound_left),
op: *op,
right: Box::new(bound_right),
})
}
Expr::Function { name, args } => {
let bound_args: Result<Vec<_>, _> = args
.iter()
.map(|arg| bind_expr(arg, params, param_index))
.collect();
Ok(Expr::Function {
name: name.clone(),
args: bound_args?,
})
}
Expr::InList { expr: inner, list, negated } => {
let bound_inner = bind_expr(inner, params, param_index)?;
let bound_list: Result<Vec<_>, _> = list
.iter()
.map(|item| bind_expr(item, params, param_index))
.collect();
Ok(Expr::InList {
expr: Box::new(bound_inner),
list: bound_list?,
negated: *negated,
})
}
Expr::IsNull(inner) => {
let bound_inner = bind_expr(inner, params, param_index)?;
Ok(Expr::IsNull(Box::new(bound_inner)))
}
Expr::IsNotNull(inner) => {
let bound_inner = bind_expr(inner, params, param_index)?;
Ok(Expr::IsNotNull(Box::new(bound_inner)))
}
other => Ok(other.clone()),
}
}
/// Renders `value` as SQL literal text.
///
/// - Strings are single-quoted, with each embedded `'` doubled.
/// - `NULL` renders as the keyword.
/// - Booleans and integers use their `Display` form.
/// - Floats use `{:?}`, so finite values always carry a `.` or an exponent
///   (`1.0`, `1e20`) and re-parse as floats. `Display` would render `1.0` as
///   `1`, which re-parses as an integer.
///
/// NOTE(review): non-finite floats render as `NaN` / `inf` / `-inf`, which are
/// not standard SQL literals.
fn scalar_to_sql_string(value: &ScalarValue) -> String {
    match value {
        ScalarValue::Boolean(b) => b.to_string(),
        ScalarValue::Int32(i) => i.to_string(),
        ScalarValue::Int64(i) => i.to_string(),
        ScalarValue::Float64(f) => format!("{:?}", f),
        ScalarValue::Utf8(s) => format!("'{}'", s.replace('\'', "''")),
        ScalarValue::Null => "NULL".to_string(),
    }
}