use crate::api_request::types::Payload;
use crate::backend::SqlDialect;
use crate::plan::call_plan::{CallArgs, CallParams, CallPlan, RpcParamValue};
use crate::plan::mutate_plan::{DeletePlan, InsertPlan, MutatePlan, UpdatePlan};
use crate::plan::read_plan::ReadPlanTree;
use super::fragment;
use super::sql_builder::{SqlBuilder, SqlParam};
/// Alias of the derived table that wraps each embedded child query.
const EMBED_ROW_ALIAS: &str = "_dbrst_t";

/// Builds the `SELECT` statement for a read plan tree, embedding every child
/// relationship as one extra JSON column.
///
/// The root select list is `plan.select`, or `<qi>.*` when that list is empty.
/// Each child in `tree.forest` adds a column named after its `rel_alias`
/// (falling back to `rel_name`). That column holds the dialect's `row_to_json`
/// of the child row for to-one relationships and its `json_agg` of the child
/// rows otherwise.
///
/// * When the dialect supports lateral joins, each child is computed in a
///   `LEFT`/`INNER JOIN LATERAL (...) AS <rel_agg_alias> ON TRUE` and selected
///   as `<rel_agg_alias>.body`. For an `INNER` to-many embed, the aggregate
///   subquery also gets `HAVING COUNT(*) > 0`. An aggregate over zero rows
///   still yields one row, so without it `INNER` would keep parents that have
///   no child rows (i.e. behave like `LEFT`).
/// * Otherwise each child is a scalar subquery in the select list.
///   `rel_join_type` is not consulted on this path.
///
/// Child queries are built recursively with this same function, so each one
/// carries its own `rel_join_conds` in its `WHERE` clause. The root then gets
/// its filters and join conditions (combined with `AND`), followed by the
/// `GROUP BY`, `ORDER BY` and `LIMIT`/`OFFSET` fragments.
pub fn read_plan_to_query(tree: &ReadPlanTree, dialect: &dyn SqlDialect) -> SqlBuilder {
    let plan = &tree.node;
    let qi = &plan.from;
    let lateral = dialect.supports_lateral_join();
    let mut b = SqlBuilder::new();

    // Select list: root columns, then one JSON column per embedded child.
    b.push("SELECT ");
    if plan.select.is_empty() {
        b.push_qi(qi);
        b.push(".*");
    } else {
        b.push_separated(", ", &plan.select, |b, sel| {
            fragment::fmt_select_item(b, qi, sel, dialect);
        });
    }
    for child in &tree.forest {
        b.push(", ");
        if lateral {
            // The JSON is produced by the lateral join below; just reference it.
            b.push_ident(&child.node.rel_agg_alias);
            b.push(".body");
        } else {
            b.push("(");
            push_embed_json_select(&mut b, child, dialect, None);
            b.push(")");
        }
        push_embed_output_alias(&mut b, child);
    }

    b.push(" FROM ");
    b.push_qi(qi);
    if let Some(alias) = &plan.from_alias {
        b.push(" AS ");
        b.push_ident(alias);
    }

    if lateral {
        for child in &tree.forest {
            let inner = matches!(
                child.node.rel_join_type,
                Some(crate::api_request::types::JoinType::Inner)
            );
            b.push(if inner {
                " INNER JOIN LATERAL ("
            } else {
                " LEFT JOIN LATERAL ("
            });
            push_embed_json_select(&mut b, child, dialect, Some("body"));
            if inner && !child_is_to_one(child) {
                // json_agg over zero rows still returns one row; drop it so the
                // INNER join actually filters out parents without children.
                b.push(" HAVING COUNT(*) > 0");
            }
            b.push(") AS ");
            b.push_ident(&child.node.rel_agg_alias);
            b.push(" ON TRUE");
        }
    }

    // WHERE: user filters first, then the correlation to the parent row.
    let has_filters = !plan.where_.is_empty();
    if has_filters {
        fragment::where_clause(&mut b, qi, &plan.where_, dialect);
    }
    if !plan.rel_join_conds.is_empty() {
        b.push(if has_filters { " AND " } else { " WHERE " });
        b.push_separated(" AND ", &plan.rel_join_conds, |b, jc| {
            fragment::fmt_join_condition(b, jc);
        });
    }

    fragment::group_clause(&mut b, qi, &plan.select);
    fragment::order_clause(&mut b, qi, &plan.order);
    fragment::limit_offset(&mut b, plan.range.offset, plan.range.limit_to);
    b
}

/// Returns whether `child` embeds at most one row, which means it is rendered
/// with `row_to_json` rather than `json_agg`. A child without a relationship
/// to its parent is treated as to-many.
fn child_is_to_one(child: &ReadPlanTree) -> bool {
    child
        .node
        .rel_to_parent
        .as_ref()
        .is_some_and(|rel| rel.is_to_one())
}

/// Pushes ` AS <name>` for an embedded child column. `<name>` is the child's
/// `rel_alias` when present and its `rel_name` otherwise.
fn push_embed_output_alias(b: &mut SqlBuilder, child: &ReadPlanTree) {
    let name = child
        .node
        .rel_alias
        .as_ref()
        .unwrap_or(&child.node.rel_name);
    b.push(" AS ");
    b.push_ident(name);
}

/// Pushes `SELECT <json> [AS <column_alias>] FROM (<child query>) AS "_dbrst_t"`.
///
/// `<json>` is the dialect's `row_to_json` for to-one children and its
/// `json_agg` otherwise. The child query is built recursively.
fn push_embed_json_select(
    b: &mut SqlBuilder,
    child: &ReadPlanTree,
    dialect: &dyn SqlDialect,
    column_alias: Option<&str>,
) {
    b.push("SELECT ");
    if child_is_to_one(child) {
        dialect.row_to_json(b, EMBED_ROW_ALIAS);
    } else {
        dialect.json_agg(b, EMBED_ROW_ALIAS);
    }
    if let Some(alias) = column_alias {
        b.push(" AS ");
        b.push(alias);
    }
    b.push(" FROM (");
    let child_query = read_plan_to_query(child, dialect);
    b.push_builder(&child_query);
    b.push(") AS ");
    b.push_ident(EMBED_ROW_ALIAS);
}
/// Builds the companion count query for a read plan.
///
/// The output is the dialect's count expression followed by
/// `FROM (SELECT 1 FROM <root table> [WHERE ...]) AS _dbrst_count_t`. Only the
/// root node's `from` and `where_` are used. Embedded children, ordering and
/// range do not affect the count.
pub fn read_plan_to_count_query(tree: &ReadPlanTree, dialect: &dyn SqlDialect) -> SqlBuilder {
    let root = &tree.node;
    let table = &root.from;
    let mut sql = SqlBuilder::new();
    fragment::count_f(&mut sql, dialect);
    sql.push(" FROM (SELECT 1 FROM ");
    sql.push_qi(table);
    let filtered = !root.where_.is_empty();
    if filtered {
        fragment::where_clause(&mut sql, table, &root.where_, dialect);
    }
    sql.push(") AS _dbrst_count_t");
    sql
}
/// Dispatches a [`MutatePlan`] to the builder for its statement kind
/// (`INSERT`, `UPDATE` or `DELETE`).
pub fn mutate_plan_to_query(plan: &MutatePlan, dialect: &dyn SqlDialect) -> SqlBuilder {
    match *plan {
        MutatePlan::Insert(ref ins) => insert_to_query(ins, dialect),
        MutatePlan::Update(ref upd) => update_to_query(upd, dialect),
        MutatePlan::Delete(ref del) => delete_to_query(del, dialect),
    }
}
/// Builds an `INSERT` statement for `plan`.
///
/// Shape: `INSERT INTO <qi>(<cols>) SELECT <cols> FROM <json body>`, or
/// `INSERT INTO <qi> DEFAULT VALUES` when `columns` is empty. This is followed
/// by an optional `ON CONFLICT` clause, the `where_` filters and the
/// `RETURNING` list.
///
/// `ON CONFLICT` handling:
/// * The `(<cols>)` conflict target is emitted only when `oc.columns` is
///   non-empty. `ON CONFLICT()` is a syntax error, whereas a bare
///   `ON CONFLICT DO NOTHING` is valid.
/// * `merge_duplicates` renders `DO UPDATE SET c = EXCLUDED.c, ...` for every
///   inserted column. With no inserted columns there is nothing to update, so
///   it degrades to `DO NOTHING`; an empty `DO UPDATE SET` is invalid SQL.
fn insert_to_query(plan: &InsertPlan, dialect: &dyn SqlDialect) -> SqlBuilder {
    let qi = &plan.into;
    let mut b = SqlBuilder::new();
    b.push("INSERT INTO ");
    b.push_qi(qi);
    if plan.columns.is_empty() {
        b.push(" DEFAULT VALUES");
    } else {
        b.push("(");
        b.push_separated(", ", &plan.columns, |b, col| {
            b.push_ident(&col.name);
        });
        b.push(") SELECT ");
        b.push_separated(", ", &plan.columns, |b, col| {
            b.push_ident(&col.name);
        });
        b.push(" FROM ");
        let json_bytes = payload_to_bytes(&plan.body);
        fragment::from_json_body(&mut b, &plan.columns, &json_bytes, dialect);
    }
    if let Some(oc) = &plan.on_conflict {
        b.push(" ON CONFLICT");
        if !oc.columns.is_empty() {
            b.push("(");
            b.push_separated(", ", &oc.columns, |b, col| {
                b.push_ident(col);
            });
            b.push(")");
        }
        if oc.merge_duplicates && !plan.columns.is_empty() {
            b.push(" DO UPDATE SET ");
            b.push_separated(", ", &plan.columns, |b, col| {
                b.push_ident(&col.name);
                b.push(" = EXCLUDED.");
                b.push_ident(&col.name);
            });
        } else {
            b.push(" DO NOTHING");
        }
    }
    // NOTE(review): `where_` is appended after any ON CONFLICT clause. If
    // `fragment::where_clause` emits a WHERE here after `DO NOTHING` or
    // `DEFAULT VALUES`, PostgreSQL will reject it. Presumably callers only set
    // `where_` together with merge-duplicates; confirm.
    fragment::where_clause(&mut b, qi, &plan.where_, dialect);
    fragment::returning_clause(&mut b, qi, &plan.returning, dialect);
    b
}
/// Builds an `UPDATE` statement that assigns the listed columns from the JSON body.
///
/// A single column renders as `SET "c" = (SELECT "c" FROM <json body>)`. Any
/// other count uses the row form
/// `SET ("a", "b") = (SELECT "a", "b" FROM <json body>)`. The `where_`
/// filters and the `RETURNING` list follow.
///
/// NOTE(review): with an empty `columns` list this presumably renders
/// `SET () = (SELECT  FROM ...)`, which PostgreSQL rejects. Confirm that
/// callers never build an `UpdatePlan` without columns, or handle it here.
fn update_to_query(plan: &UpdatePlan, dialect: &dyn SqlDialect) -> SqlBuilder {
    let target = &plan.into;
    let mut sql = SqlBuilder::new();
    sql.push("UPDATE ");
    sql.push_qi(target);
    sql.push(" SET ");
    match plan.columns.as_slice() {
        [only] => {
            sql.push_ident(&only.name);
            sql.push(" = (SELECT ");
            sql.push_ident(&only.name);
        }
        _ => {
            sql.push("(");
            sql.push_separated(", ", &plan.columns, |s, column| {
                s.push_ident(&column.name);
            });
            sql.push(") = (SELECT ");
            sql.push_separated(", ", &plan.columns, |s, column| {
                s.push_ident(&column.name);
            });
        }
    }
    sql.push(" FROM ");
    let body = payload_to_bytes(&plan.body);
    fragment::from_json_body(&mut sql, &plan.columns, &body, dialect);
    sql.push(")");
    fragment::where_clause(&mut sql, target, &plan.where_, dialect);
    fragment::returning_clause(&mut sql, target, &plan.returning, dialect);
    sql
}
fn delete_to_query(plan: &DeletePlan, dialect: &dyn SqlDialect) -> SqlBuilder {
let qi = &plan.from;
let mut b = SqlBuilder::new();
b.push("DELETE FROM ");
b.push_qi(qi);
fragment::where_clause(&mut b, qi, &plan.where_, dialect);
fragment::returning_clause(&mut b, qi, &plan.returning, dialect);
b
}
pub fn call_plan_to_query(plan: &CallPlan, dialect: &dyn SqlDialect) -> SqlBuilder {
let _ = dialect; let mut b = SqlBuilder::new();
b.push("SELECT * FROM ");
b.push_qi(&plan.qi);
b.push("(");
match &plan.args {
CallArgs::DirectArgs(args) => {
match &plan.params {
CallParams::KeyParams(params) => {
let mut first = true;
for param in params {
if let Some(val) = args.get(¶m.name) {
if !first {
b.push(", ");
}
first = false;
b.push_ident(¶m.name);
b.push(" := ");
match val {
RpcParamValue::Fixed(v) => {
b.push_param(SqlParam::Text(v.to_string()));
}
RpcParamValue::Variadic(vals) => {
b.push("VARIADIC ARRAY[");
for (i, v) in vals.iter().enumerate() {
if i > 0 {
b.push(", ");
}
b.push_param(SqlParam::Text(v.to_string()));
}
b.push("]");
}
}
}
}
}
CallParams::OnePosParam(_) => {
if let Some((_, val)) = args.iter().next() {
match val {
RpcParamValue::Fixed(v) => {
b.push_param(SqlParam::Text(v.to_string()));
}
RpcParamValue::Variadic(vals) => {
b.push_param(SqlParam::Text(vals.join(",").to_string()));
}
}
}
}
}
}
CallArgs::JsonArgs(json) => {
if let Some(body) = json {
b.push_param(SqlParam::Json(body.clone()));
}
}
}
b.push(")");
b
}
/// Flattens a request [`Payload`] into the raw bytes passed to the JSON-body fragment.
///
/// JSON and raw payloads are copied verbatim. URL-encoded parameters are
/// re-encoded as one JSON object whose values are all strings. When a key
/// repeats, the value that appears last wins.
fn payload_to_bytes(payload: &Payload) -> Vec<u8> {
    let form = match payload {
        Payload::ProcessedJSON { raw, .. } => return raw.to_vec(),
        Payload::RawJSON(raw) => return raw.to_vec(),
        Payload::RawPayload(raw) => return raw.to_vec(),
        Payload::ProcessedUrlEncoded { params, .. } => params,
    };
    let mut object: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
    for (key, value) in form.iter() {
        object.insert(key.as_str(), value.as_str());
    }
    serde_json::json!(object).to_string().into_bytes()
}
// Unit tests for the query builders. They render SQL through the crate's
// `TestPgDialect` helper and mostly assert on substrings of `SqlBuilder::sql()`
// and on parameter counts; only `test_call_plan_no_args` checks a full statement.
#[cfg(test)]
mod tests {
use super::*;
use crate::api_request::range::Range;
use crate::api_request::types::*;
use crate::plan::call_plan::*;
use crate::plan::mutate_plan::*;
use crate::plan::read_plan::*;
use crate::plan::types::*;
use crate::test_helpers::TestPgDialect;
use crate::types::identifiers::QualifiedIdentifier;
use bytes::Bytes;
use compact_str::CompactString;
use smallvec::SmallVec;
use std::collections::HashMap;
// Shared dialect instance used by every test.
fn dialect() -> &'static dyn SqlDialect {
&TestPgDialect
}
// Target table used by most tests: "public"."users".
fn test_qi() -> QualifiedIdentifier {
QualifiedIdentifier::new("public", "users")
}
// Untyped field; the empty SmallVec is presumably an empty JSON path.
fn field(name: &str) -> CoercibleField {
CoercibleField::unknown(name.into(), SmallVec::new())
}
// Field built from a column with the given base type (e.g. "integer", "text").
fn typed_field(name: &str, base_type: &str) -> CoercibleField {
CoercibleField::from_column(name.into(), SmallVec::new(), base_type.into())
}
// Plain select item: no aggregate, no cast, no alias.
fn select_field(name: &str) -> CoercibleSelectField {
CoercibleSelectField {
field: field(name),
agg_function: None,
agg_cast: None,
cast: None,
alias: None,
}
}
#[test]
fn test_read_plan_simple() {
let mut plan = ReadPlan::root(test_qi());
plan.select = vec![select_field("id"), select_field("name")];
let tree = ReadPlanTree::leaf(plan);
let b = read_plan_to_query(&tree, dialect());
let sql = b.sql();
assert!(sql.starts_with("SELECT "));
assert!(sql.contains("\"id\""));
assert!(sql.contains("\"name\""));
assert!(sql.contains("FROM \"public\".\"users\""));
}
// A single equality filter should render a WHERE with exactly one bound parameter.
#[test]
fn test_read_plan_with_where() {
let mut plan = ReadPlan::root(test_qi());
plan.select = vec![select_field("id")];
plan.where_ = vec![CoercibleLogicTree::Stmnt(CoercibleFilter::Filter {
field: field("id"),
op_expr: OpExpr::Expr {
negated: false,
operation: Operation::Quant(QuantOperator::Equal, None, "1".into()),
},
})];
let tree = ReadPlanTree::leaf(plan);
let b = read_plan_to_query(&tree, dialect());
assert!(b.sql().contains("WHERE"));
assert!(b.sql().contains("$1"));
assert_eq!(b.param_count(), 1);
}
#[test]
fn test_read_plan_with_order() {
let mut plan = ReadPlan::root(test_qi());
plan.select = vec![select_field("name")];
plan.order = vec![CoercibleOrderTerm::Term {
field: field("name"),
direction: Some(OrderDirection::Asc),
nulls: None,
}];
let tree = ReadPlanTree::leaf(plan);
let sql = read_plan_to_query(&tree, dialect()).sql().to_string();
assert!(sql.contains("ORDER BY"));
assert!(sql.contains("ASC"));
}
#[test]
fn test_read_plan_with_limit_offset() {
let mut plan = ReadPlan::root(test_qi());
plan.select = vec![select_field("id")];
// The assertions expect LIMIT 10 for offset 5 / limit_to 14, i.e. limit_to
// is treated as an inclusive end index (14 - 5 + 1 = 10).
plan.range = Range {
offset: 5,
limit_to: Some(14),
};
let tree = ReadPlanTree::leaf(plan);
let sql = read_plan_to_query(&tree, dialect()).sql().to_string();
assert!(sql.contains("LIMIT 10"));
assert!(sql.contains("OFFSET 5"));
}
// With an empty select list the builder falls back to "<table>".*.
#[test]
fn test_read_plan_default_select() {
let plan = ReadPlan::root(test_qi());
let tree = ReadPlanTree::leaf(plan);
let sql = read_plan_to_query(&tree, dialect()).sql().to_string();
assert!(sql.contains("\"public\".\"users\".*"));
}
#[test]
fn test_read_plan_count_query() {
let mut plan = ReadPlan::root(test_qi());
plan.where_ = vec![CoercibleLogicTree::Stmnt(CoercibleFilter::Filter {
field: field("status"),
op_expr: OpExpr::Expr {
negated: false,
operation: Operation::Quant(QuantOperator::Equal, None, "active".into()),
},
})];
let tree = ReadPlanTree::leaf(plan);
let b = read_plan_to_count_query(&tree, dialect());
assert!(b.sql().contains("COUNT(*)"));
assert!(b.sql().contains("_dbrst_count_t"));
}
// users -> posts is an O2M (to-many) embed. The assertions expect
// TestPgDialect to take the lateral-join path and aggregate with json_agg.
#[test]
fn test_read_plan_with_lateral_join() {
use crate::schema_cache::relationship::{AnyRelationship, Cardinality, Relationship};
let root = ReadPlan::root(test_qi());
let mut child = ReadPlan::child(
QualifiedIdentifier::new("public", "posts"),
"posts".into(),
1,
);
child.select = vec![select_field("id"), select_field("title")];
child.rel_to_parent = Some(AnyRelationship::ForeignKey(Relationship {
table: QualifiedIdentifier::new("public", "users"),
foreign_table: QualifiedIdentifier::new("public", "posts"),
is_self: false,
cardinality: Cardinality::O2M {
constraint: "fk_posts".into(),
columns: smallvec::smallvec![("id".into(), "user_id".into())],
},
table_is_view: false,
foreign_table_is_view: false,
}));
child.rel_join_conds = vec![JoinCondition {
parent: (test_qi(), "id".into()),
child: (
QualifiedIdentifier::new("public", "posts"),
"user_id".into(),
),
}];
let tree = ReadPlanTree::with_children(root, vec![ReadPlanTree::leaf(child)]);
let sql = read_plan_to_query(&tree, dialect()).sql().to_string();
assert!(sql.contains("LEFT JOIN LATERAL"));
assert!(sql.contains("json_agg"));
assert!(sql.contains("ON TRUE"));
}
// The assertions expect TestPgDialect to read the JSON body via json_to_recordset.
#[test]
fn test_insert_query() {
let plan = MutatePlan::Insert(InsertPlan {
into: test_qi(),
columns: vec![typed_field("id", "integer"), typed_field("name", "text")],
body: Payload::RawJSON(Bytes::from(r#"[{"id":1,"name":"test"}]"#)),
on_conflict: None,
where_: vec![],
returning: vec![select_field("id")],
pk_cols: vec!["id".into()],
apply_defaults: false,
});
let b = mutate_plan_to_query(&plan, dialect());
let sql = b.sql();
assert!(sql.starts_with("INSERT INTO "));
assert!(sql.contains("json_to_recordset"));
assert!(sql.contains("RETURNING"));
}
// merge_duplicates = true should produce an upsert (DO UPDATE SET ... EXCLUDED).
#[test]
fn test_insert_with_on_conflict() {
let plan = MutatePlan::Insert(InsertPlan {
into: test_qi(),
columns: vec![typed_field("id", "integer"), typed_field("name", "text")],
body: Payload::RawJSON(Bytes::from(r#"[{"id":1,"name":"test"}]"#)),
on_conflict: Some(crate::plan::mutate_plan::OnConflict {
columns: vec!["id".into()],
merge_duplicates: true,
}),
where_: vec![],
returning: vec![],
pk_cols: vec!["id".into()],
apply_defaults: false,
});
let sql = mutate_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.contains("ON CONFLICT"));
assert!(sql.contains("DO UPDATE SET"));
assert!(sql.contains("EXCLUDED"));
}
// merge_duplicates = false should ignore conflicting rows.
#[test]
fn test_insert_do_nothing() {
let plan = MutatePlan::Insert(InsertPlan {
into: test_qi(),
columns: vec![typed_field("id", "integer")],
body: Payload::RawJSON(Bytes::from(r#"[{"id":1}]"#)),
on_conflict: Some(crate::plan::mutate_plan::OnConflict {
columns: vec!["id".into()],
merge_duplicates: false,
}),
where_: vec![],
returning: vec![],
pk_cols: vec!["id".into()],
apply_defaults: false,
});
let sql = mutate_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.contains("DO NOTHING"));
}
#[test]
fn test_update_query() {
let plan = MutatePlan::Update(UpdatePlan {
into: test_qi(),
columns: vec![typed_field("name", "text")],
body: Payload::RawJSON(Bytes::from(r#"{"name":"updated"}"#)),
where_: vec![CoercibleLogicTree::Stmnt(CoercibleFilter::Filter {
field: field("id"),
op_expr: OpExpr::Expr {
negated: false,
operation: Operation::Quant(QuantOperator::Equal, None, "1".into()),
},
})],
returning: vec![select_field("id"), select_field("name")],
apply_defaults: false,
});
let sql = mutate_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.starts_with("UPDATE "));
assert!(sql.contains("SET "));
assert!(sql.contains("WHERE"));
assert!(sql.contains("RETURNING"));
}
#[test]
fn test_delete_query() {
let plan = MutatePlan::Delete(DeletePlan {
from: test_qi(),
where_: vec![CoercibleLogicTree::Stmnt(CoercibleFilter::Filter {
field: field("id"),
op_expr: OpExpr::Expr {
negated: false,
operation: Operation::Quant(QuantOperator::Equal, None, "1".into()),
},
})],
returning: vec![],
});
let sql = mutate_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.starts_with("DELETE FROM "));
assert!(sql.contains("WHERE"));
}
// Direct args with declared key params should be passed by name (`name := $n`).
#[test]
fn test_call_plan_named_args() {
let mut args = HashMap::new();
args.insert(CompactString::from("a"), RpcParamValue::Fixed("1".into()));
args.insert(CompactString::from("b"), RpcParamValue::Fixed("2".into()));
let plan = CallPlan {
qi: QualifiedIdentifier::new("public", "add_numbers"),
params: CallParams::KeyParams(vec![
crate::schema_cache::routine::RoutineParam {
name: "a".into(),
pg_type: "integer".into(),
type_max_length: "integer".into(),
required: true,
is_variadic: false,
},
crate::schema_cache::routine::RoutineParam {
name: "b".into(),
pg_type: "integer".into(),
type_max_length: "integer".into(),
required: true,
is_variadic: false,
},
]),
args: CallArgs::DirectArgs(args),
scalar: true,
set_of_scalar: false,
filter_fields: vec![],
returning: vec![],
};
let sql = call_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.starts_with("SELECT * FROM \"public\".\"add_numbers\"("));
assert!(sql.contains(":="));
}
// A JSON body should be bound as a single parameter.
#[test]
fn test_call_plan_json_body() {
let plan = CallPlan {
qi: QualifiedIdentifier::new("public", "process_data"),
params: CallParams::KeyParams(vec![]),
args: CallArgs::JsonArgs(Some(Bytes::from(r#"{"key":"value"}"#))),
scalar: false,
set_of_scalar: false,
filter_fields: vec![],
returning: vec![],
};
let b = call_plan_to_query(&plan, dialect());
assert!(b.sql().contains("$1"));
assert_eq!(b.param_count(), 1);
}
// JsonArgs(None) must render an empty argument list; checked against the full statement.
#[test]
fn test_call_plan_no_args() {
let plan = CallPlan {
qi: QualifiedIdentifier::new("public", "get_time"),
params: CallParams::KeyParams(vec![]),
args: CallArgs::JsonArgs(None),
scalar: true,
set_of_scalar: false,
filter_fields: vec![],
returning: vec![],
};
let sql = call_plan_to_query(&plan, dialect()).sql().to_string();
assert_eq!(sql, "SELECT * FROM \"public\".\"get_time\"()");
}
// A variadic argument should be passed as VARIADIC ARRAY[...].
#[test]
fn test_call_plan_variadic() {
let mut args = HashMap::new();
args.insert(
CompactString::from("vals"),
RpcParamValue::Variadic(vec!["a".into(), "b".into(), "c".into()]),
);
let plan = CallPlan {
qi: QualifiedIdentifier::new("public", "concat_vals"),
params: CallParams::KeyParams(vec![crate::schema_cache::routine::RoutineParam {
name: "vals".into(),
pg_type: "text".into(),
type_max_length: "text".into(),
required: true,
is_variadic: true,
}]),
args: CallArgs::DirectArgs(args),
scalar: true,
set_of_scalar: false,
filter_fields: vec![],
returning: vec![],
};
let sql = call_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.contains("VARIADIC ARRAY["));
}
// Raw JSON payloads must pass through byte-for-byte.
#[test]
fn test_payload_to_bytes_raw_json() {
let payload = Payload::RawJSON(Bytes::from(r#"[{"id":1}]"#));
let bytes = payload_to_bytes(&payload);
assert_eq!(bytes, b"[{\"id\":1}]");
}
// With no columns the insert builder should fall back to DEFAULT VALUES.
#[test]
fn test_insert_default_values() {
let plan = MutatePlan::Insert(InsertPlan {
into: test_qi(),
columns: vec![],
body: Payload::RawJSON(Bytes::from("{}")),
on_conflict: None,
where_: vec![],
returning: vec![],
pk_cols: vec![],
apply_defaults: true,
});
let sql = mutate_plan_to_query(&plan, dialect()).sql().to_string();
assert!(sql.contains("DEFAULT VALUES"));
}
}