use toasty_core::{
driver::operation::TypedValue,
schema::{Schema, db},
stmt,
};
/// Expression context over the database-level schema; used while walking a
/// statement to resolve column references to their table columns.
type Cx<'a> = stmt::ExprContext<'a, db::Schema>;
/// Pulls every non-NULL literal value out of `stmt` into a positional
/// parameter list.
///
/// Each extracted literal is replaced in `stmt` by an `Expr::Arg` holding its
/// index in the returned vector. Parameter types start out as whatever
/// `db::Type::from_value` reports for the value. They are then refined against
/// the column storage types in `schema.db` wherever the statement ties a
/// parameter to a column.
pub(crate) fn extract_params(stmt: &mut stmt::Statement, schema: &Schema) -> Vec<TypedValue> {
    let mut extracted: Vec<TypedValue> = Vec::new();
    extract_values(stmt, &mut extracted);
    refine_param_types(stmt, &schema.db, &mut extracted);
    extracted
}
/// Type information tracked for an expression while refining parameter types.
#[derive(Debug, Clone)]
enum Ty {
    /// Storage type of a resolved table column. This is the only variant whose
    /// type `check` writes back into a parameter.
    Column(db::Type),
    /// Type obtained without a column: a parameter's current type, or the fixed
    /// result type of an operator (e.g. `Boolean` for comparisons).
    Inferred(db::Type),
    /// Per-field types of a record expression.
    Record(Vec<Ty>),
    /// Element type of a list expression.
    List(Box<Ty>),
    /// No type information; `merge` treats it as the identity.
    Unknown,
}
impl Ty {
    /// Returns the concrete database type carried by `Column` and `Inferred`,
    /// or `None` for `Record`, `List` and `Unknown`.
    fn db_type(&self) -> Option<&db::Type> {
        if let Ty::Column(db_ty) | Ty::Inferred(db_ty) = self {
            Some(db_ty)
        } else {
            None
        }
    }

    /// Whether this type was taken from a resolved column.
    fn is_column(&self) -> bool {
        matches!(self, Ty::Column(..))
    }
}
/// Replaces every non-NULL literal in `stmt` with parameter references,
/// appending the extracted values to `params`.
///
/// A scalar literal becomes a single `Expr::Arg`. Record and list literals are
/// rebuilt as `Expr::Record` / `Expr::List` whose scalar leaves are individual
/// args. Literal NULLs are left in place.
fn extract_values(stmt: &mut stmt::Statement, params: &mut Vec<TypedValue>) {
    stmt::visit_mut::for_each_expr_mut(stmt, |expr| {
        let stmt::Expr::Value(value) = expr else {
            return;
        };
        let is_composite = matches!(value, stmt::Value::Record(_) | stmt::Value::List(_));
        if is_composite || is_extractable_scalar(value) {
            // Take ownership of the literal so it can be moved into `params`.
            let owned = std::mem::replace(value, stmt::Value::Null);
            *expr = value_to_extracted_expr(owned, params);
        }
    });
}
/// Converts an owned literal into an expression, moving each non-NULL scalar
/// leaf into `params` and replacing it with an `Expr::Arg`.
///
/// Records and lists are rebuilt structurally, so each of their scalar leaves
/// becomes its own parameter. NULL leaves stay inline as literal NULLs.
/// Parameters are appended in depth-first, left-to-right order.
fn value_to_extracted_expr(value: stmt::Value, params: &mut Vec<TypedValue>) -> stmt::Expr {
    match value {
        stmt::Value::Record(record) => {
            let mut fields = Vec::new();
            for field in record.fields {
                fields.push(value_to_extracted_expr(field, params));
            }
            stmt::Expr::Record(stmt::ExprRecord::from_vec(fields))
        }
        stmt::Value::List(values) => {
            let mut items = Vec::new();
            for item in values {
                items.push(value_to_extracted_expr(item, params));
            }
            stmt::Expr::List(stmt::ExprList { items })
        }
        stmt::Value::Null => stmt::Expr::Value(stmt::Value::Null),
        scalar => {
            let position = params.len();
            let ty = db::Type::from_value(&scalar);
            params.push(TypedValue { value: scalar, ty });
            stmt::Expr::arg(position)
        }
    }
}
/// Returns `true` for values that become a single parameter: anything that is
/// neither NULL nor a composite (record/list).
fn is_extractable_scalar(value: &stmt::Value) -> bool {
    let is_composite = matches!(value, stmt::Value::Record(_) | stmt::Value::List(_));
    !is_composite && !matches!(value, stmt::Value::Null)
}
/// Refines `params` in place by walking `stmt` with a fresh expression context
/// rooted at `db_schema`.
fn refine_param_types(stmt: &stmt::Statement, db_schema: &db::Schema, params: &mut [TypedValue]) {
    refine_stmt(stmt, &stmt::ExprContext::new(db_schema), db_schema, params);
}
/// Dispatches parameter-type refinement by statement kind.
///
/// Inserts, updates and deletes are refined inside their own scope of `cx`.
/// Queries are handed the outer context and scope themselves.
fn refine_stmt(
    stmt: &stmt::Statement,
    cx: &Cx<'_>,
    db_schema: &db::Schema,
    params: &mut [TypedValue],
) {
    match stmt {
        stmt::Statement::Insert(insert) => {
            refine_insert(insert, &cx.scope(insert), db_schema, params)
        }
        stmt::Statement::Update(update) => {
            refine_update(update, &cx.scope(update), db_schema, params)
        }
        stmt::Statement::Delete(delete) => refine_filter(&delete.filter, &cx.scope(delete), params),
        stmt::Statement::Query(query) => refine_query(query, cx, params),
    }
}
/// Binds inserted row values to the storage types of the target table's
/// columns.
///
/// Only table targets are refined. The expected row type is a record with one
/// column type per inserted column, and it is checked against every row of a
/// `VALUES` body. Other source bodies are not inspected.
fn refine_insert(
    insert: &stmt::Insert,
    _cx: &Cx<'_>,
    db_schema: &db::Schema,
    params: &mut [TypedValue],
) {
    let stmt::InsertTarget::Table(table) = &insert.target else {
        // Without a concrete table there are no column types to bind against.
        return;
    };

    let db_table = &db_schema.tables[table.table.0];
    let row_ty = Ty::Record(
        table
            .columns
            .iter()
            .map(|col_id| Ty::Column(db_table.columns[col_id.index].storage_ty.clone()))
            .collect(),
    );

    let stmt::ExprSet::Values(values) = &insert.source.body else {
        return;
    };
    for row in &values.rows {
        check(row, &row_ty, params);
    }
}
/// Binds each `SET` assignment to the storage type of its column (table
/// targets only), then refines the update's filter.
///
/// Assignments whose column index is past the end of the table's columns are
/// skipped.
///
/// # Panics
///
/// Panics if a `SET` assignment's projection is not exactly one step long.
fn refine_update(
    update: &stmt::Update,
    cx: &Cx<'_>,
    db_schema: &db::Schema,
    params: &mut [TypedValue],
) {
    if let stmt::UpdateTarget::Table(table_id) = &update.target {
        let db_table = &db_schema.tables[table_id.0];
        for (projection, assignment) in update.assignments.iter() {
            let stmt::Assignment::Set(expr) = assignment else {
                continue;
            };
            let steps = projection.as_slice();
            assert_eq!(
                steps.len(),
                1,
                "UPDATE assignment projection should be a single column index, got {steps:?}"
            );
            let Some(col) = db_table.columns.get(steps[0]) else {
                continue;
            };
            check(expr, &Ty::Column(col.storage_ty.clone()), params);
        }
    }
    refine_filter(&update.filter, cx, params);
}
/// Refines parameters inside `query` within the query's own scope.
///
/// For a SELECT the filter is refined (in the select's nested scope), and for
/// VALUES each row is synthesized. Other body kinds are skipped. Every CTE in
/// the query's WITH clause is then refined in the query scope.
fn refine_query(query: &stmt::Query, cx: &Cx<'_>, params: &mut [TypedValue]) {
    let query_cx = cx.scope(query);

    match &query.body {
        stmt::ExprSet::Select(select) => {
            refine_filter(&select.filter, &query_cx.scope(&**select), params);
        }
        stmt::ExprSet::Values(values) => {
            for row in &values.rows {
                synthesize(row, &query_cx, params);
            }
        }
        _ => {}
    }

    for cte in query.with.iter().flat_map(|with| &with.ctes) {
        refine_query(&cte.query, &query_cx, params);
    }
}
/// Synthesizes the filter's expression, if it has one, so that parameters
/// inside it are refined. The synthesized type itself is discarded.
fn refine_filter(filter: &stmt::Filter, cx: &Cx<'_>, params: &mut [TypedValue]) {
    let Some(expr) = &filter.expr else {
        return;
    };
    synthesize(expr, cx, params);
}
fn synthesize(expr: &stmt::Expr, cx: &Cx<'_>, params: &mut [TypedValue]) -> Ty {
match expr {
stmt::Expr::Arg(arg) => {
let tv = ¶ms[arg.position];
Ty::Inferred(tv.ty.clone())
}
stmt::Expr::Reference(expr_ref @ stmt::ExprReference::Column(_)) => {
match cx.resolve_expr_reference(expr_ref) {
stmt::ResolvedRef::Column(col) => Ty::Column(col.storage_ty.clone()),
_ => Ty::Unknown,
}
}
stmt::Expr::Project(project) => {
let mut ty = synthesize(&project.base, cx, params);
for &step in project.projection.as_slice() {
ty = match ty {
Ty::Record(fields) => {
assert!(
step < fields.len(),
"projection step {step} out of range for record with {} fields",
fields.len()
);
fields.into_iter().nth(step).unwrap()
}
other => panic!("cannot project from non-record type: {other:?}"),
};
}
ty
}
stmt::Expr::Record(record) => {
let fields: Vec<Ty> = record
.fields
.iter()
.map(|f| synthesize(f, cx, params))
.collect();
Ty::Record(fields)
}
stmt::Expr::List(list) => {
let mut merged = Ty::Unknown;
for item in &list.items {
let item_ty = synthesize(item, cx, params);
merged = merge(&merged, &item_ty);
}
Ty::List(Box::new(merged))
}
stmt::Expr::BinaryOp(binary) => {
let lhs_ty = synthesize(&binary.lhs, cx, params);
let rhs_ty = synthesize(&binary.rhs, cx, params);
let merged = merge(&lhs_ty, &rhs_ty);
check(&binary.lhs, &merged, params);
check(&binary.rhs, &merged, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::InList(in_list) => {
let expr_ty = synthesize(&in_list.expr, cx, params);
synthesize(&in_list.list, cx, params);
check_list(&in_list.list, &expr_ty, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::InSubquery(in_sub) => {
synthesize(&in_sub.expr, cx, params);
refine_query(&in_sub.query, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Exists(exists) => {
refine_query(&exists.subquery, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Stmt(expr_stmt) => {
refine_stmt(&expr_stmt.stmt, cx, cx.schema(), params);
Ty::Unknown
}
stmt::Expr::And(and) => {
for op in &and.operands {
synthesize(op, cx, params);
}
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Or(or) => {
for op in &or.operands {
synthesize(op, cx, params);
}
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Not(not) => {
synthesize(¬.expr, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::IsNull(is_null) => {
synthesize(&is_null.expr, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::StartsWith(e) => {
check(&e.expr, &Ty::Inferred(db::Type::Text), params);
check(&e.prefix, &Ty::Inferred(db::Type::Text), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Like(e) => {
check(&e.expr, &Ty::Inferred(db::Type::Text), params);
check(&e.pattern, &Ty::Inferred(db::Type::Text), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Value(stmt::Value::Null) => Ty::Unknown,
stmt::Expr::Default => Ty::Unknown,
_ => Ty::Unknown,
}
}
/// Pushes an expected type down into `expr`, overwriting the type of any
/// parameter that the expectation pins to a column.
///
/// Only `Ty::Column` expectations are written into `params`; inferred types
/// are never forced onto a parameter. Records are checked field by field (any
/// surplus fields on either side are ignored). List items are checked against
/// the list's element type, or against the expectation itself when it is a
/// scalar (`Column`/`Inferred`) type. Every other pairing is left untouched.
fn check(expr: &stmt::Expr, expected: &Ty, params: &mut [TypedValue]) {
    match expr {
        stmt::Expr::Arg(arg) => {
            let column_ty = expected.db_type().filter(|_| expected.is_column());
            if let Some(db_ty) = column_ty {
                params[arg.position].ty = db_ty.clone();
            }
        }
        stmt::Expr::Record(record) => {
            if let Ty::Record(field_types) = expected {
                for (field, field_ty) in record.fields.iter().zip(field_types) {
                    check(field, field_ty, params);
                }
            }
        }
        stmt::Expr::List(list) => {
            let item_ty = match expected {
                Ty::List(elem_ty) => Some(&**elem_ty),
                scalar if scalar.db_type().is_some() => Some(scalar),
                _ => None,
            };
            if let Some(item_ty) = item_ty {
                for item in &list.items {
                    check(item, item_ty, params);
                }
            }
        }
        _ => {}
    }
}
/// Checks each item of a list expression against `elem_ty`. Anything that is
/// not a literal list expression is checked against `elem_ty` directly.
fn check_list(list_expr: &stmt::Expr, elem_ty: &Ty, params: &mut [TypedValue]) {
    if let stmt::Expr::List(list) = list_expr {
        list.items
            .iter()
            .for_each(|item| check(item, elem_ty, params));
    } else {
        check(list_expr, elem_ty, params);
    }
}
/// Reconciles two types that must describe the same value (e.g. both sides of
/// a comparison, or successive items of a list).
///
/// `Unknown` defers to the other side, and a column type takes precedence over
/// an inferred one. Records of equal arity are merged field by field, and lists
/// are merged element-wise.
///
/// # Panics
///
/// Panics if two column types or two inferred types differ, if two records
/// have different arity, or if the two shapes are otherwise incompatible.
fn merge(a: &Ty, b: &Ty) -> Ty {
    match (a, b) {
        (Ty::Unknown, known) | (known, Ty::Unknown) => known.clone(),
        (column @ Ty::Column(_), Ty::Inferred(_)) | (Ty::Inferred(_), column @ Ty::Column(_)) => {
            column.clone()
        }
        (Ty::Column(a_ty), Ty::Column(b_ty)) => {
            assert_eq!(
                a_ty, b_ty,
                "two column types in the same expression disagree: {a_ty:?} vs {b_ty:?}"
            );
            a.clone()
        }
        (Ty::Inferred(a_ty), Ty::Inferred(b_ty)) => {
            assert_eq!(
                a_ty, b_ty,
                "two inferred types in the same expression disagree: {a_ty:?} vs {b_ty:?}"
            );
            a.clone()
        }
        (Ty::List(a_elem), Ty::List(b_elem)) => Ty::List(Box::new(merge(a_elem, b_elem))),
        (Ty::Record(a_fields), Ty::Record(b_fields)) if a_fields.len() == b_fields.len() => {
            let merged_fields = a_fields
                .iter()
                .zip(b_fields)
                .map(|(lhs, rhs)| merge(lhs, rhs))
                .collect();
            Ty::Record(merged_fields)
        }
        _ => panic!("cannot merge incompatible types: {a:?} and {b:?}"),
    }
}
#[cfg(test)]
mod tests;