use crate::api_request::preferences::PreferCount;
use crate::backend::SqlDialect;
use crate::plan::call_plan::CallPlan;
use crate::plan::mutate_plan::MutatePlan;
use crate::plan::read_plan::ReadPlanTree;
use super::builder;
use super::fragment;
use super::sql_builder::SqlBuilder;
/// Builds the top-level SQL statement for a read request.
///
/// The planned read query becomes the `dbrst_source` CTE. When `prefer_count`
/// is `Some(PreferCount::Exact)`, a second CTE `dbrst_count` holds the count
/// query. The outer query yields these columns:
///
/// - `total_result_set`: `dbrst_filtered_count` from `dbrst_count` for an exact
///   count, otherwise `NULL`;
/// - `page_total`: the dialect's count expression over `_dbrst_t`;
/// - `body`: the aggregated rows, produced through the media `handler` when one
///   is given, or `NULL` when `headers_only` is set;
/// - `response_headers` / `response_status`: the `response.headers` and
///   `response.status` session variables.
///
/// `_dbrst_t` is `dbrst_source`, capped with `LIMIT max_rows` when
/// `max_rows` is `Some`.
pub fn main_read(
    read_plan: &ReadPlanTree,
    prefer_count: Option<PreferCount>,
    max_rows: Option<i64>,
    headers_only: bool,
    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
    dialect: &dyn SqlDialect,
) -> SqlBuilder {
    let exact = matches!(prefer_count, Some(PreferCount::Exact));
    let mut sql = SqlBuilder::new();

    // CTEs: the read itself, then (only for exact counts) the count query.
    sql.push("WITH dbrst_source AS (");
    sql.push_builder(&builder::read_plan_to_query(read_plan, dialect));
    sql.push(")");
    if exact {
        sql.push(", dbrst_count AS (");
        sql.push_builder(&builder::read_plan_to_count_query(read_plan, dialect));
        sql.push(")");
    }

    // Projection.
    sql.push(" SELECT ");
    if exact {
        sql.push("(SELECT ");
        sql.push_ident("dbrst_filtered_count");
        sql.push(" FROM dbrst_count)");
    } else {
        sql.push("NULL");
    }
    sql.push(" AS total_result_set");
    sql.push(", ");
    dialect.count_expr(&mut sql, "_dbrst_t");
    sql.push(" AS page_total");
    if headers_only {
        sql.push(", NULL AS body");
    } else {
        // Column names are only needed when a body is actually produced.
        let names = select_column_names(read_plan);
        let cols: Vec<&str> = names.iter().map(String::as_str).collect();
        sql.push(", ");
        match handler {
            Some(h) => {
                fragment::handler_agg_with_media_cols(&mut sql, h, false, dialect, &cols);
            }
            None => {
                fragment::handler_agg_cols(&mut sql, false, dialect, &cols);
            }
        }
        sql.push(" AS body");
    }
    sql.push(", ");
    dialect.get_session_var(&mut sql, "response.headers", "response_headers");
    sql.push(", ");
    dialect.get_session_var(&mut sql, "response.status", "response_status");

    // Page source, optionally capped at `max_rows`.
    sql.push(" FROM (SELECT * FROM dbrst_source");
    if let Some(limit) = max_rows {
        sql.push(" LIMIT ");
        sql.push(&limit.to_string());
    }
    sql.push(") AS ");
    sql.push_ident("_dbrst_t");
    sql
}
/// Returns the output column names of the root node's select list.
///
/// Each field contributes its alias when it has one, otherwise its field name.
/// If any field selects the full row, an empty list is returned instead.
fn select_column_names(tree: &ReadPlanTree) -> Vec<String> {
    let fields = &tree.node.select;
    let selects_full_row = fields.iter().any(|sf| sf.field.full_row);
    if selects_full_row {
        Vec::new()
    } else {
        fields
            .iter()
            .map(|sf| match &sf.alias {
                Some(alias) => alias.to_string(),
                None => sf.field.name.to_string(),
            })
            .collect()
    }
}
/// Builds the single-statement SQL for a mutation request.
///
/// The planned mutation is wrapped in a `dbrst_source` CTE. The outer query
/// yields these columns:
///
/// - `total_result_set`: always the empty string `''` for writes;
/// - `page_total`: the dialect's count expression over the mutated rows;
/// - `body`: the aggregated `RETURNING` columns, produced through the media
///   `handler` when one is given. It is present only if
///   `return_representation` is set and the plan returns columns; otherwise
///   it is `NULL`;
/// - `response_headers` / `response_status`: the `response.headers` and
///   `response.status` session variables.
///
/// `_read_plan` is not used here; it is kept so the signature stays stable.
pub fn main_write(
    mutate_plan: &MutatePlan,
    _read_plan: &ReadPlanTree,
    return_representation: bool,
    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
    dialect: &dyn SqlDialect,
) -> SqlBuilder {
    let inner = builder::mutate_plan_to_query(mutate_plan, dialect);
    let returning = mutate_plan.returning();
    let has_returning = !returning.is_empty();
    let mut b = SqlBuilder::new();
    b.push("WITH dbrst_source AS (");
    b.push_builder(&inner);
    if !has_returning {
        // A mutation CTE without RETURNING yields no rows to select from, so
        // return a constant per affected row; this keeps `page_total` correct.
        b.push(" RETURNING 1");
    }
    b.push(")");
    b.push(" SELECT ");
    b.push("'' AS total_result_set");
    b.push(", ");
    dialect.count_expr(&mut b, "_dbrst_t");
    b.push(" AS page_total");
    if return_representation && has_returning {
        // Same naming rule as `select_column_names`: alias if present, else the
        // field name. A full-row field means "no explicit column list" (as in
        // `main_read`), rather than passing that field's name as a column.
        let col_names: Vec<String> = if returning.iter().any(|sf| sf.field.full_row) {
            Vec::new()
        } else {
            returning
                .iter()
                .map(|sf| {
                    sf.alias
                        .as_ref()
                        .map(|a| a.to_string())
                        .unwrap_or_else(|| sf.field.name.to_string())
                })
                .collect()
        };
        let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
        b.push(", ");
        if let Some(h) = handler {
            fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
        } else {
            fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
        }
        b.push(" AS body");
    } else {
        b.push(", NULL AS body");
    }
    b.push(", ");
    dialect.get_session_var(&mut b, "response.headers", "response_headers");
    b.push(", ");
    dialect.get_session_var(&mut b, "response.status", "response_status");
    b.push(" FROM (SELECT * FROM dbrst_source) AS ");
    b.push_ident("_dbrst_t");
    b
}
/// Two-statement variant of the mutation query.
///
/// Returns `(mutation, select)`:
///
/// - `mutation`: the planned mutation. ` RETURNING 1` is appended when the
///   plan has no returning columns, so it still yields one row per affected row.
/// - `select`: reads the mutated rows from `_dbrst_mut` (aliased `_dbrst_t`).
///   NOTE(review): presumably the caller materializes `_dbrst_mut` from the
///   first statement's output; confirm against the executor.
///
/// The `select` statement yields these columns:
///
/// - `total_result_set`: `''`;
/// - `page_total`: the dialect's count expression;
/// - `body`: the aggregated `RETURNING` columns when `return_representation`
///   is set and the plan returns columns, otherwise `NULL`;
/// - `response_headers` / `response_status`.
///
/// `_read_plan` is not used here; it is kept so the signature stays stable.
pub fn main_write_split(
    mutate_plan: &MutatePlan,
    _read_plan: &ReadPlanTree,
    return_representation: bool,
    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
    dialect: &dyn SqlDialect,
) -> (SqlBuilder, SqlBuilder) {
    let mut mutation = builder::mutate_plan_to_query(mutate_plan, dialect);
    let returning = mutate_plan.returning();
    let has_returning = !returning.is_empty();
    if !has_returning {
        // Ensure the mutation produces rows so the follow-up SELECT can count them.
        mutation.push(" RETURNING 1");
    }
    let mut b = SqlBuilder::new();
    b.push("SELECT ");
    b.push("'' AS total_result_set");
    b.push(", ");
    dialect.count_expr(&mut b, "_dbrst_t");
    b.push(" AS page_total");
    if return_representation && has_returning {
        // Same naming rule as `select_column_names`: alias if present, else the
        // field name. A full-row field means "no explicit column list" (as in
        // `main_read`), rather than passing that field's name as a column.
        let col_names: Vec<String> = if returning.iter().any(|sf| sf.field.full_row) {
            Vec::new()
        } else {
            returning
                .iter()
                .map(|sf| {
                    sf.alias
                        .as_ref()
                        .map(|a| a.to_string())
                        .unwrap_or_else(|| sf.field.name.to_string())
                })
                .collect()
        };
        let col_refs: Vec<&str> = col_names.iter().map(|s| s.as_str()).collect();
        b.push(", ");
        if let Some(h) = handler {
            fragment::handler_agg_with_media_cols(&mut b, h, false, dialect, &col_refs);
        } else {
            fragment::handler_agg_cols(&mut b, false, dialect, &col_refs);
        }
        b.push(" AS body");
    } else {
        b.push(", NULL AS body");
    }
    b.push(", ");
    dialect.get_session_var(&mut b, "response.headers", "response_headers");
    b.push(", ");
    dialect.get_session_var(&mut b, "response.status", "response_status");
    b.push(" FROM ");
    b.push_ident("_dbrst_mut");
    b.push(" AS ");
    b.push_ident("_dbrst_t");
    (mutation, b)
}
/// Builds the top-level SQL statement for a function call.
///
/// The planned call becomes the `dbrst_source` CTE. The outer query yields
/// these columns:
///
/// - `total_result_set`: the dialect's `count(*)` over `dbrst_source` when
///   `prefer_count` is `Some(PreferCount::Exact)`, otherwise `NULL`;
/// - `page_total`: `1` for scalar calls, otherwise the dialect's count
///   expression over `_dbrst_t`;
/// - `body`: for scalar calls, the row of `dbrst_source` converted by the
///   dialect's `row_to_json_star`; otherwise the rows aggregated through the
///   media `handler` when given, or the default aggregation when not;
/// - `response_headers` / `response_status`: the `response.headers` and
///   `response.status` session variables.
///
/// Scalar calls select directly from `dbrst_source`. Other calls select from
/// `dbrst_source` capped with `LIMIT max_rows`, aliased `_dbrst_t`.
pub fn main_call(
    call_plan: &CallPlan,
    prefer_count: Option<PreferCount>,
    max_rows: Option<i64>,
    handler: Option<&crate::schema_cache::media_handler::MediaHandler>,
    dialect: &dyn SqlDialect,
) -> SqlBuilder {
    let is_scalar = call_plan.scalar;
    let exact = matches!(prefer_count, Some(PreferCount::Exact));

    let mut sql = SqlBuilder::new();
    sql.push("WITH dbrst_source AS (");
    sql.push_builder(&builder::call_plan_to_query(call_plan, dialect));
    sql.push(")");

    sql.push(" SELECT ");
    if exact {
        dialect.count_star_from(&mut sql, "dbrst_source");
    } else {
        sql.push("NULL");
    }
    sql.push(" AS total_result_set");

    // page_total and body differ between scalar and set-returning calls.
    if is_scalar {
        sql.push(", 1 AS page_total");
        sql.push(", ");
        dialect.row_to_json_star(&mut sql, "dbrst_source");
    } else {
        sql.push(", ");
        dialect.count_expr(&mut sql, "_dbrst_t");
        sql.push(" AS page_total");
        sql.push(", ");
        match handler {
            Some(h) => {
                fragment::handler_agg_with_media(&mut sql, h, false, dialect);
            }
            None => {
                fragment::handler_agg(&mut sql, false, dialect);
            }
        }
    }
    sql.push(" AS body");

    sql.push(", ");
    dialect.get_session_var(&mut sql, "response.headers", "response_headers");
    sql.push(", ");
    dialect.get_session_var(&mut sql, "response.status", "response_status");

    if is_scalar {
        sql.push(" FROM dbrst_source");
        return sql;
    }
    sql.push(" FROM (SELECT * FROM dbrst_source");
    if let Some(limit) = max_rows {
        sql.push(" LIMIT ");
        sql.push(&limit.to_string());
    }
    sql.push(") AS ");
    sql.push_ident("_dbrst_t");
    sql
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_request::types::Payload;
    use crate::plan::call_plan::{CallArgs, CallParams, CallPlan};
    use crate::plan::mutate_plan::{InsertPlan, MutatePlan};
    use crate::plan::read_plan::{ReadPlan, ReadPlanTree};
    use crate::plan::types::*;
    use crate::test_helpers::TestPgDialect;
    use crate::types::identifiers::QualifiedIdentifier;
    use bytes::Bytes;
    use smallvec::SmallVec;

    /// The PostgreSQL test dialect used by every test in this module.
    fn dialect() -> &'static dyn SqlDialect {
        &TestPgDialect
    }

    /// The `public.users` table used as the default target.
    fn test_qi() -> QualifiedIdentifier {
        QualifiedIdentifier::new("public", "users")
    }

    /// An un-aliased, un-cast select field named `name`.
    fn select_field(name: &str) -> CoercibleSelectField {
        CoercibleSelectField {
            field: CoercibleField::unknown(name.into(), SmallVec::new()),
            agg_function: None,
            agg_cast: None,
            cast: None,
            alias: None,
        }
    }

    /// A column field with a known base type.
    fn typed_field(name: &str, base_type: &str) -> CoercibleField {
        CoercibleField::from_column(name.into(), SmallVec::new(), base_type.into())
    }

    /// A single-node read tree over `public.users` with an empty select list.
    fn users_tree() -> ReadPlanTree {
        ReadPlanTree::leaf(ReadPlan::root(test_qi()))
    }

    /// A parameterless call plan for `public.<name>`.
    fn call_plan(name: &'static str, scalar: bool) -> CallPlan {
        CallPlan {
            qi: QualifiedIdentifier::new("public", name),
            params: CallParams::KeyParams(vec![]),
            args: CallArgs::JsonArgs(None),
            scalar,
            set_of_scalar: false,
            filter_fields: vec![],
            returning: vec![],
        }
    }

    #[test]
    fn test_main_read_basic() {
        let mut plan = ReadPlan::root(test_qi());
        plan.select = vec![select_field("id"), select_field("name")];
        let tree = ReadPlanTree::leaf(plan);
        let b = main_read(&tree, None, None, false, None, dialect());
        let sql = b.sql();
        assert!(sql.starts_with("WITH dbrst_source AS ("));
        for needle in [
            "AS total_result_set",
            "AS page_total",
            "AS body",
            "AS response_headers",
            "AS response_status",
        ] {
            assert!(sql.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_main_read_with_exact_count() {
        let tree = users_tree();
        let b = main_read(
            &tree,
            Some(PreferCount::Exact),
            None,
            false,
            None,
            dialect(),
        );
        let sql = b.sql();
        for needle in ["dbrst_count", "dbrst_filtered_count"] {
            assert!(sql.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_main_read_headers_only() {
        let tree = users_tree();
        let b = main_read(&tree, None, None, true, None, dialect());
        let sql = b.sql();
        assert!(sql.contains("NULL AS body"));
    }

    #[test]
    fn test_main_read_with_max_rows() {
        let tree = users_tree();
        let b = main_read(&tree, None, Some(100), false, None, dialect());
        let sql = b.sql();
        assert!(sql.contains("LIMIT 100"));
    }

    #[test]
    fn test_main_write_basic() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![typed_field("name", "text")],
            body: Payload::RawJSON(Bytes::from(r#"[{"name":"test"}]"#)),
            on_conflict: None,
            where_: vec![],
            returning: vec![select_field("id")],
            pk_cols: vec!["id".into()],
            apply_defaults: false,
        });
        let read = users_tree();
        let b = main_write(&mutate, &read, true, None, dialect());
        let sql = b.sql();
        assert!(sql.starts_with("WITH dbrst_source AS ("));
        for needle in ["INSERT INTO", "AS body"] {
            assert!(sql.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_main_write_no_representation() {
        let mutate = MutatePlan::Insert(InsertPlan {
            into: test_qi(),
            columns: vec![],
            body: Payload::RawJSON(Bytes::from("{}")),
            on_conflict: None,
            where_: vec![],
            returning: vec![],
            pk_cols: vec![],
            apply_defaults: false,
        });
        let read = users_tree();
        let b = main_write(&mutate, &read, false, None, dialect());
        let sql = b.sql();
        assert!(sql.contains("NULL AS body"));
    }

    #[test]
    fn test_main_call_basic() {
        let call = call_plan("get_time", false);
        let b = main_call(&call, None, None, None, dialect());
        let sql = b.sql();
        assert!(sql.starts_with("WITH dbrst_source AS ("));
        for needle in ["get_time", "AS body"] {
            assert!(sql.contains(needle), "missing `{needle}`");
        }
    }

    #[test]
    fn test_main_call_scalar() {
        let call = call_plan("add_numbers", true);
        let b = main_call(&call, None, None, None, dialect());
        let sql = b.sql();
        assert!(sql.contains("row_to_json"));
    }

    #[test]
    fn test_main_call_with_count() {
        let call = call_plan("get_data", false);
        let b = main_call(&call, Some(PreferCount::Exact), None, None, dialect());
        let sql = b.sql();
        assert!(sql.contains("pg_catalog.count(*)"));
    }

    #[test]
    fn test_main_call_with_max_rows() {
        let call = call_plan("get_data", false);
        let b = main_call(&call, None, Some(50), None, dialect());
        let sql = b.sql();
        assert!(sql.contains("LIMIT 50"));
    }
}
}