use toasty_core::{
driver::{Capability, operation::TypedValue},
schema::{Schema, db},
stmt::{self, VisitMut},
};
/// Expression context over the database-level schema, used while walking a statement to
/// resolve column references during parameter type refinement.
type Cx<'a> = stmt::ExprContext<'a, db::Schema>;
/// Lifts every literal value out of `stmt`, replacing each with a positional argument
/// placeholder, and returns the lifted values paired with their database types.
///
/// The result is indexed by argument position: entry `i` is the value that the
/// `stmt::Expr::arg(i)` placeholder now stands for. Types are first inferred from the values
/// themselves and then refined against the column types reachable from `stmt` in `schema.db`.
/// `capability.bind_list_param` decides whether a list of scalars may be bound as a single
/// list-valued parameter.
///
/// # Panics
///
/// Panics (via `finalize_ty`) if a lifted value ends up with an unresolved or record type.
pub(crate) fn extract_params(
    stmt: &mut stmt::Statement,
    schema: &Schema,
    capability: &Capability,
) -> Vec<TypedValue> {
    let mut extracted = Vec::<Param>::new();
    extract_values(stmt, &mut extracted, capability);
    refine_param_types(stmt, &schema.db, &mut extracted);

    let mut typed = Vec::with_capacity(extracted.len());
    for Param { value, ty } in extracted {
        let ty = finalize_ty(&value, ty);
        typed.push(TypedValue { ty, value });
    }
    typed
}
/// A value lifted out of a statement together with the type information gathered for it.
struct Param {
    /// The literal that was replaced by an argument placeholder.
    value: stmt::Value,
    /// Starts as the type inferred from `value` and is later refined from surrounding context.
    ty: Ty,
}
/// Converts a parameter's refined `Ty` into the concrete `db::Type` handed to the driver.
///
/// `value` is only used to make the panic messages useful.
///
/// # Panics
///
/// Panics if `ty` (or any list element type inside it) is still `Ty::Unknown`, or if it is a
/// `Ty::Record`.
fn finalize_ty(value: &stmt::Value, ty: Ty) -> db::Type {
    match ty {
        Ty::List(inner) => {
            let elem = finalize_ty(value, *inner);
            db::Type::List(Box::new(elem))
        }
        Ty::Column(resolved) | Ty::Inferred(resolved) => resolved,
        Ty::Record(_) => panic!(
            "extract_params left {value:?} typed as a record; only scalars and lists are extracted as params"
        ),
        Ty::Unknown => panic!("extract_params left {value:?} with unresolved type"),
    }
}
/// Guesses a database type for `value` from its variant alone.
///
/// Scalars map to `Ty::Inferred`. A list is typed by its first non-null item (`Ty::Unknown`
/// elements if it has none). Nulls, records and any other variant yield `Ty::Unknown`.
fn infer_ty(value: &stmt::Value) -> Ty {
    use stmt::Value as V;

    let scalar = match value {
        V::Bool(_) => db::Type::Boolean,
        V::I8(_) => db::Type::Integer(1),
        V::I16(_) => db::Type::Integer(2),
        V::I32(_) => db::Type::Integer(4),
        V::I64(_) => db::Type::Integer(8),
        V::U8(_) => db::Type::UnsignedInteger(1),
        V::U16(_) => db::Type::UnsignedInteger(2),
        V::U32(_) => db::Type::UnsignedInteger(4),
        V::U64(_) => db::Type::UnsignedInteger(8),
        V::String(_) => db::Type::Text,
        V::Uuid(_) => db::Type::Uuid,
        V::Bytes(_) => db::Type::Blob,
        #[cfg(feature = "rust_decimal")]
        V::Decimal(_) => db::Type::Numeric(None),
        #[cfg(feature = "jiff")]
        V::Timestamp(_) => db::Type::Timestamp(6),
        #[cfg(feature = "jiff")]
        V::Date(_) => db::Type::Date,
        #[cfg(feature = "jiff")]
        V::Time(_) => db::Type::Time(6),
        #[cfg(feature = "jiff")]
        V::DateTime(_) => db::Type::DateTime(6),
        // A list takes its element type from the first non-null item; an empty or all-null
        // list keeps an unknown element type.
        V::List(items) => {
            let elem = match items.iter().find(|item| !item.is_null()) {
                Some(first) => infer_ty(first),
                None => Ty::Unknown,
            };
            return Ty::List(Box::new(elem));
        }
        // Nulls, records, and anything else carry no standalone type.
        _ => return Ty::Unknown,
    };
    Ty::Inferred(scalar)
}
/// Type information tracked for a parameter (or an expression) while refining parameter types.
#[derive(Debug, Clone)]
enum Ty {
    /// Type taken from a schema column; wins over `Inferred` when the two are merged.
    Column(db::Type),
    /// Type guessed from a literal value or from an expression's role (e.g. `Text` for LIKE
    /// operands, `Boolean` for predicates).
    Inferred(db::Type),
    /// One type per field of a record expression.
    Record(Vec<Ty>),
    /// A list whose elements all have the boxed type.
    List(Box<Ty>),
    /// Nothing is known yet; merging with any other type adopts that type.
    Unknown,
}
impl Ty {
    /// The database type carried by a `Column` or `Inferred` type; `None` for records, lists,
    /// and unknown types.
    fn db_type(&self) -> Option<&db::Type> {
        if let Ty::Column(inner) | Ty::Inferred(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Test helper: whether this type was taken from a schema column.
    #[cfg(test)]
    fn is_column(&self) -> bool {
        matches!(self, Ty::Column(_))
    }
}
/// Replaces literal values in `stmt` with positional `Expr::Arg` placeholders, pushing each
/// lifted value (typed by `infer_ty`) onto `params`.
///
/// A placeholder's position is the length of `params` at the moment its value is pushed, so
/// the traversal order below fixes the parameter numbering.
///
/// When `capability.bind_list_param` is set, a list made up entirely of extractable scalars is
/// bound as a single list-valued parameter instead of one parameter per item. The exception is
/// a literal `Value::List` on the right of `IN`, which is always expanded item by item.
fn extract_values(stmt: &mut stmt::Statement, params: &mut Vec<Param>, capability: &Capability) {
    // Mutable visitor that performs the extraction.
    struct Extract<'a> {
        // Destination for lifted values; an argument's position is its index here.
        params: &'a mut Vec<Param>,
        // Copied from `Capability::bind_list_param`.
        bind_list_param: bool,
    }
    impl stmt::VisitMut for Extract<'_> {
        fn visit_expr_mut(&mut self, expr: &mut stmt::Expr) {
            match expr {
                // `lhs <op> ANY(rhs)`: a literal list on the right (a `Value::List`, or an
                // `Expr::List` of plain values) becomes one array parameter, regardless of
                // `bind_list_param`. Otherwise the right side is visited normally.
                stmt::Expr::AnyOp(e) => {
                    self.visit_expr_mut(&mut e.lhs);
                    if let Some(arg) = extract_array_operand(&mut e.rhs, self.params) {
                        *e.rhs = arg;
                    } else {
                        self.visit_expr_mut(&mut e.rhs);
                    }
                    return;
                }
                // Same treatment for `lhs <op> ALL(rhs)`.
                stmt::Expr::AllOp(e) => {
                    self.visit_expr_mut(&mut e.lhs);
                    if let Some(arg) = extract_array_operand(&mut e.rhs, self.params) {
                        *e.rhs = arg;
                    } else {
                        self.visit_expr_mut(&mut e.rhs);
                    }
                    return;
                }
                // `expr IN (...)`: a literal `Value::List` is rewritten into an `Expr::List`
                // with one parameter per scalar item (nulls stay inline; records and nested
                // lists are expanded recursively). It is never bound as one list parameter.
                // Any other list operand is visited normally.
                // NOTE(review): with `bind_list_param`, an `Expr::List` of scalar values here
                // still reaches the whole-list binding below and becomes a single list
                // parameter. Confirm that is intended for `IN`.
                stmt::Expr::InList(e) => {
                    self.visit_expr_mut(&mut e.expr);
                    if let stmt::Expr::Value(stmt::Value::List(_)) = e.list.as_ref() {
                        // Take ownership of the list; the null left behind is overwritten below.
                        let stmt::Expr::Value(stmt::Value::List(items)) =
                            std::mem::replace(e.list.as_mut(), stmt::Expr::null())
                        else {
                            unreachable!()
                        };
                        let items = items
                            .into_iter()
                            .map(|v| value_to_extracted_expr(v, self.params, false))
                            .collect();
                        *e.list = stmt::Expr::List(stmt::ExprList { items });
                    } else {
                        self.visit_expr_mut(&mut e.list);
                    }
                    return;
                }
                _ => {}
            }
            // Whole-list binding. This runs before the children are visited; otherwise each
            // item would already have been lifted into its own argument and the list would no
            // longer look like a list of scalar values.
            if self.bind_list_param
                && is_scalar_list(expr)
                && let Some(arg) = extract_array_operand(expr, self.params)
            {
                *expr = arg;
                return;
            }
            // Lift nested literals first, then this node if it is itself a literal.
            stmt::visit_mut::visit_expr_mut(self, expr);
            match expr {
                // A single non-null scalar literal becomes one parameter.
                stmt::Expr::Value(value) if is_extractable_scalar(value) => {
                    let ty = infer_ty(value);
                    let position = self.params.len();
                    let value = std::mem::replace(value, stmt::Value::Null);
                    self.params.push(Param { value, ty });
                    *expr = stmt::Expr::arg(position);
                }
                // Record and list literals are decomposed by `value_to_extracted_expr`.
                stmt::Expr::Value(value @ (stmt::Value::Record(_) | stmt::Value::List(_))) => {
                    let owned = std::mem::replace(value, stmt::Value::Null);
                    *expr = value_to_extracted_expr(owned, self.params, self.bind_list_param);
                }
                // Everything else, including null literals, stays in place.
                _ => {}
            }
        }
    }
    Extract {
        params,
        bind_list_param: capability.bind_list_param,
    }
    .visit_mut(stmt);
}
/// Returns `true` when `expr` is a literal value that would become a single scalar parameter,
/// i.e. a value that is not null, a record, or a list.
fn is_extractable_scalar_expr(expr: &stmt::Expr) -> bool {
    if let stmt::Expr::Value(value) = expr {
        is_extractable_scalar(value)
    } else {
        false
    }
}
/// Returns `true` when `expr` is a list whose items are all extractable scalars. The list may
/// be either a literal `Value::List` or an `Expr::List` of literal values. Empty lists qualify.
fn is_scalar_list(expr: &stmt::Expr) -> bool {
    if let stmt::Expr::Value(stmt::Value::List(values)) = expr {
        return values.iter().all(is_extractable_scalar);
    }
    let stmt::Expr::List(list) = expr else {
        return false;
    };
    list.items.iter().all(is_extractable_scalar_expr)
}
/// Replaces a literal list operand with a single list-valued parameter.
///
/// Accepts a `Value::List` or an `Expr::List` whose items are all `Expr::Value`. On success
/// the list is moved into `params` (typed via `infer_ty`). `expr` is left as a null expression
/// for the caller to overwrite, and the `Expr::arg` placeholder to put there is returned. Any
/// other expression yields `None` and is left untouched.
fn extract_array_operand(expr: &mut stmt::Expr, params: &mut Vec<Param>) -> Option<stmt::Expr> {
    let liftable = match &*expr {
        stmt::Expr::Value(stmt::Value::List(_)) => true,
        stmt::Expr::List(list) => list
            .items
            .iter()
            .all(|item| matches!(item, stmt::Expr::Value(_))),
        _ => false,
    };
    if !liftable {
        return None;
    }

    let values: Vec<stmt::Value> = match std::mem::replace(expr, stmt::Expr::null()) {
        stmt::Expr::Value(stmt::Value::List(values)) => values,
        stmt::Expr::List(list) => list
            .items
            .into_iter()
            .map(|item| {
                let stmt::Expr::Value(value) = item else {
                    unreachable!()
                };
                value
            })
            .collect(),
        _ => unreachable!(),
    };

    let list = stmt::Value::List(values);
    let position = params.len();
    let ty = infer_ty(&list);
    params.push(Param { value: list, ty });
    Some(stmt::Expr::arg(position))
}
/// Turns an owned literal into an expression in which every non-null scalar is replaced by a
/// parameter placeholder. Lifted values are pushed onto `params` in depth-first order.
///
/// - `Null` stays inline as a literal.
/// - A record becomes an `Expr::Record` with each field converted recursively.
/// - A list becomes one list parameter when `bind_list_param` is set and every item is an
///   extractable scalar. Otherwise it becomes an `Expr::List` of recursively converted items.
/// - Any other value becomes one scalar parameter.
fn value_to_extracted_expr(
    value: stmt::Value,
    params: &mut Vec<Param>,
    bind_list_param: bool,
) -> stmt::Expr {
    let lifted = match value {
        stmt::Value::Null => return stmt::Expr::Value(stmt::Value::Null),
        stmt::Value::Record(record) => {
            let fields = record
                .fields
                .into_iter()
                .map(|field| value_to_extracted_expr(field, params, bind_list_param))
                .collect();
            return stmt::Expr::Record(stmt::ExprRecord::from_vec(fields));
        }
        stmt::Value::List(items) => {
            let bind_whole = bind_list_param && items.iter().all(is_extractable_scalar);
            if !bind_whole {
                let items = items
                    .into_iter()
                    .map(|item| value_to_extracted_expr(item, params, bind_list_param))
                    .collect();
                return stmt::Expr::List(stmt::ExprList { items });
            }
            stmt::Value::List(items)
        }
        scalar => scalar,
    };

    // `lifted` is either a scalar or a scalar-only list bound as a whole: one parameter.
    let position = params.len();
    let ty = infer_ty(&lifted);
    params.push(Param { value: lifted, ty });
    stmt::Expr::arg(position)
}
/// Returns `true` for values that become a single scalar parameter: anything other than
/// `Null`, a record, or a list.
fn is_extractable_scalar(value: &stmt::Value) -> bool {
    let composite_or_null = matches!(
        value,
        stmt::Value::Null | stmt::Value::List(_) | stmt::Value::Record(_)
    );
    !composite_or_null
}
/// Refines the inferred types in `params` using the column types reachable from `stmt`.
///
/// Walks the statement starting from a root expression context built over `db_schema`.
fn refine_param_types(stmt: &stmt::Statement, db_schema: &db::Schema, params: &mut [Param]) {
    refine_stmt(stmt, &stmt::ExprContext::new(db_schema), db_schema, params);
}
/// Dispatches type refinement on the statement kind.
///
/// Inserts, updates and deletes get `cx` scoped to the statement here. Queries do their own
/// scoping inside `refine_query`.
fn refine_stmt(stmt: &stmt::Statement, cx: &Cx<'_>, db_schema: &db::Schema, params: &mut [Param]) {
    match stmt {
        stmt::Statement::Query(query) => refine_query(query, cx, params),
        stmt::Statement::Delete(delete) => {
            refine_filter(&delete.filter, &cx.scope(delete), params)
        }
        stmt::Statement::Update(update) => {
            refine_update(update, &cx.scope(update), db_schema, params)
        }
        stmt::Statement::Insert(insert) => {
            refine_insert(insert, &cx.scope(insert), db_schema, params)
        }
    }
}
/// Converts a column's storage type into a `Ty`.
///
/// List types, including nested ones, are unwrapped into `Ty::List` so element types can be
/// unified separately. Every other storage type becomes `Ty::Column`.
fn ty_from_column(storage_ty: db::Type) -> Ty {
    if let db::Type::List(inner) = storage_ty {
        return Ty::List(Box::new(ty_from_column(*inner)));
    }
    Ty::Column(storage_ty)
}
/// Pushes column types into the parameters of an `INSERT ... VALUES` statement.
///
/// For a table target, each VALUES row is checked against a record of the target columns'
/// storage types. For any other target the expected type is `Ty::Unknown`, which leaves
/// parameter types unchanged. Sources other than VALUES are not inspected.
fn refine_insert(
    insert: &stmt::Insert,
    _cx: &Cx<'_>,
    db_schema: &db::Schema,
    params: &mut [Param],
) {
    // The shape every VALUES row is expected to have.
    let expected = if let stmt::InsertTarget::Table(table) = &insert.target {
        let columns = &db_schema.tables[table.table.0].columns;
        Ty::Record(
            table
                .columns
                .iter()
                .map(|col_id| ty_from_column(columns[col_id.index].storage_ty.clone()))
                .collect(),
        )
    } else {
        Ty::Unknown
    };

    let stmt::ExprSet::Values(values) = &insert.source.body else {
        return;
    };
    for row in &values.rows {
        check(row, &expected, params);
    }
}
/// Pushes column types into the parameters of an UPDATE's assignments, then refines its filter.
///
/// Only table targets have their assignments inspected.
/// - Each assignment projection must be exactly one column index; this is asserted.
/// - Assignments to a column index outside the table are skipped.
/// - `Set` and `Append` values are checked against the column's type.
/// - `Remove` values are checked against the element type of a list column.
/// - `RemoveAt`, `Pop`, `Insert` and `Batch` are left alone.
fn refine_update(update: &stmt::Update, cx: &Cx<'_>, db_schema: &db::Schema, params: &mut [Param]) {
    if let stmt::UpdateTarget::Table(table_id) = &update.target {
        let columns = &db_schema.tables[table_id.0].columns;
        for (projection, assignment) in update.assignments.iter() {
            let steps = projection.as_slice();
            assert_eq!(
                steps.len(),
                1,
                "UPDATE assignment projection should be a single column index, got {steps:?}"
            );
            let Some(col) = columns.get(steps[0]) else {
                continue;
            };
            match assignment {
                stmt::Assignment::Set(value) | stmt::Assignment::Append(value) => {
                    check(value, &ty_from_column(col.storage_ty.clone()), params);
                }
                stmt::Assignment::Remove(value) => {
                    if let db::Type::List(elem) = &col.storage_ty {
                        check(value, &ty_from_column((**elem).clone()), params);
                    }
                }
                stmt::Assignment::RemoveAt(_)
                | stmt::Assignment::Pop
                | stmt::Assignment::Insert(_)
                | stmt::Assignment::Batch(_) => {}
            }
        }
    }
    refine_filter(&update.filter, cx, params);
}
/// Refines parameter types inside a query.
///
/// Looks at the SELECT filter (or each VALUES row), then at every CTE query in the WITH
/// clause. Other set expressions are not inspected.
fn refine_query(query: &stmt::Query, cx: &Cx<'_>, params: &mut [Param]) {
    let query_cx = cx.scope(query);

    if let stmt::ExprSet::Select(select) = &query.body {
        refine_filter(&select.filter, &query_cx.scope(&**select), params);
    } else if let stmt::ExprSet::Values(values) = &query.body {
        for row in &values.rows {
            synthesize(row, &query_cx, params);
        }
    }

    for cte in query.with.iter().flat_map(|with| &with.ctes) {
        refine_query(&cte.query, &query_cx, params);
    }
}
/// Refines parameter types in a filter's predicate, if the filter has one.
fn refine_filter(filter: &stmt::Filter, cx: &Cx<'_>, params: &mut [Param]) {
    let Some(predicate) = &filter.expr else {
        return;
    };
    synthesize(predicate, cx, params);
}
fn synthesize(expr: &stmt::Expr, cx: &Cx<'_>, params: &mut [Param]) -> Ty {
match expr {
stmt::Expr::Arg(arg) => params[arg.position].ty.clone(),
stmt::Expr::Reference(expr_ref @ stmt::ExprReference::Column(_)) => {
match cx.resolve_expr_reference(expr_ref) {
stmt::ResolvedRef::Column(col) => Ty::Column(col.storage_ty.clone()),
_ => Ty::Unknown,
}
}
stmt::Expr::Project(project) => {
let mut ty = synthesize(&project.base, cx, params);
for &step in project.projection.as_slice() {
ty = match ty {
Ty::Record(fields) => {
assert!(
step < fields.len(),
"projection step {step} out of range for record with {} fields",
fields.len()
);
fields.into_iter().nth(step).unwrap()
}
other => panic!("cannot project from non-record type: {other:?}"),
};
}
ty
}
stmt::Expr::Record(record) => {
let fields: Vec<Ty> = record
.fields
.iter()
.map(|f| synthesize(f, cx, params))
.collect();
Ty::Record(fields)
}
stmt::Expr::List(list) => {
let mut merged = Ty::Unknown;
for item in &list.items {
let item_ty = synthesize(item, cx, params);
merged = merge(&merged, &item_ty);
}
Ty::List(Box::new(merged))
}
stmt::Expr::BinaryOp(binary) => {
let lhs_ty = synthesize(&binary.lhs, cx, params);
let rhs_ty = synthesize(&binary.rhs, cx, params);
let merged = merge(&lhs_ty, &rhs_ty);
check(&binary.lhs, &merged, params);
check(&binary.rhs, &merged, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::InList(in_list) => {
let expr_ty = synthesize(&in_list.expr, cx, params);
synthesize(&in_list.list, cx, params);
check_list(&in_list.list, &expr_ty, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::AnyOp(e) => {
let lhs_ty = synthesize(&e.lhs, cx, params);
check(&e.rhs, &Ty::List(Box::new(lhs_ty)), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::AllOp(e) => {
let lhs_ty = synthesize(&e.lhs, cx, params);
check(&e.rhs, &Ty::List(Box::new(lhs_ty)), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::InSubquery(in_sub) => {
synthesize(&in_sub.expr, cx, params);
refine_query(&in_sub.query, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Exists(exists) => {
refine_query(&exists.subquery, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Stmt(expr_stmt) => {
refine_stmt(&expr_stmt.stmt, cx, cx.schema(), params);
Ty::Unknown
}
stmt::Expr::And(and) => {
for op in &and.operands {
synthesize(op, cx, params);
}
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Or(or) => {
for op in &or.operands {
synthesize(op, cx, params);
}
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Not(not) => {
synthesize(¬.expr, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::IsNull(is_null) => {
synthesize(&is_null.expr, cx, params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::StartsWith(e) => {
check(&e.expr, &Ty::Inferred(db::Type::Text), params);
check(&e.prefix, &Ty::Inferred(db::Type::Text), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Like(e) => {
check(&e.expr, &Ty::Inferred(db::Type::Text), params);
check(&e.pattern, &Ty::Inferred(db::Type::Text), params);
Ty::Inferred(db::Type::Boolean)
}
stmt::Expr::Value(stmt::Value::Null) => Ty::Unknown,
stmt::Expr::Default => Ty::Unknown,
_ => Ty::Unknown,
}
}
/// Pushes the `expected` type down into `expr`, updating the type of every parameter reached.
///
/// - An `Arg` has its current type merged with `expected`.
/// - A `Record` is checked field by field against a `Ty::Record`. Surplus fields on either side
///   are ignored.
/// - A `List` is checked item by item against a `Ty::List`'s element type. If `expected` is a
///   plain column or inferred type, the items are checked against `expected` itself.
/// - Anything else is left alone.
///
/// # Panics
///
/// Panics if a parameter's current type cannot be merged with `expected` (see `merge`).
fn check(expr: &stmt::Expr, expected: &Ty, params: &mut [Param]) {
    match (expr, expected) {
        (stmt::Expr::Arg(arg), ty) => {
            let param = &mut params[arg.position];
            param.ty = merge(&param.ty, ty);
        }
        (stmt::Expr::Record(record), Ty::Record(field_types)) => {
            for (field, field_ty) in record.fields.iter().zip(field_types) {
                check(field, field_ty, params);
            }
        }
        (stmt::Expr::List(list), Ty::List(elem_ty)) => {
            for item in &list.items {
                check(item, elem_ty, params);
            }
        }
        (stmt::Expr::List(list), ty) if ty.db_type().is_some() => {
            for item in &list.items {
                check(item, ty, params);
            }
        }
        _ => {}
    }
}
fn check_list(list_expr: &stmt::Expr, elem_ty: &Ty, params: &mut [Param]) {
match list_expr {
stmt::Expr::List(list) => {
for item in &list.items {
check(item, elem_ty, params);
}
}
_ => {
check(list_expr, elem_ty, params);
}
}
}
/// Unifies two types that must describe the same value.
///
/// - `Unknown` yields to anything.
/// - A column type wins over an inferred one.
/// - Records of equal length and lists are unified element-wise.
///
/// # Panics
///
/// Panics if two column types or two inferred types differ, or if records differ in field
/// count. Also panics when the shapes are otherwise incompatible (e.g. a list against a scalar).
fn merge(a: &Ty, b: &Ty) -> Ty {
    match (a, b) {
        (Ty::Unknown, known) | (known, Ty::Unknown) => known.clone(),
        (Ty::Column(_), Ty::Inferred(_)) => a.clone(),
        (Ty::Inferred(_), Ty::Column(_)) => b.clone(),
        (Ty::Column(a_ty), Ty::Column(b_ty)) => {
            assert_eq!(
                a_ty, b_ty,
                "two column types in the same expression disagree: {a_ty:?} vs {b_ty:?}"
            );
            a.clone()
        }
        (Ty::Inferred(a_ty), Ty::Inferred(b_ty)) => {
            assert_eq!(
                a_ty, b_ty,
                "two inferred types in the same expression disagree: {a_ty:?} vs {b_ty:?}"
            );
            a.clone()
        }
        (Ty::List(a_elem), Ty::List(b_elem)) => {
            let elem = merge(a_elem, b_elem);
            Ty::List(Box::new(elem))
        }
        (Ty::Record(a_fields), Ty::Record(b_fields)) if a_fields.len() == b_fields.len() => {
            let mut fields = Vec::with_capacity(a_fields.len());
            for (a_field, b_field) in a_fields.iter().zip(b_fields) {
                fields.push(merge(a_field, b_field));
            }
            Ty::Record(fields)
        }
        _ => panic!("cannot merge incompatible types: {a:?} and {b:?}"),
    }
}
// Unit tests for this module, compiled only under `cfg(test)` and loaded from a separate
// `tests` module file.
#[cfg(test)]
mod tests;