use super::{
CTE, ConditionValue, Operator, OrGroup, QueryBuilder, QueryFragment, UnionClause,
WhereCondition, WindowFunction, WindowFunctionType, db_sql,
};
use crate::error::{Error, Result};
use crate::model::Model;
use std::marker::PhantomData;
impl<M: Model> QueryBuilder<M> {
fn split_select_alias(value: &str) -> &str {
let trimmed = value.trim();
let lowered = trimmed.to_ascii_lowercase();
if let Some(index) = lowered.find(" as ") {
trimmed[..index].trim()
} else {
trimmed
}
}
fn simple_column_reference(value: &str) -> Option<(&str, &str)> {
let value = Self::split_select_alias(value);
if value.is_empty()
|| value.starts_with('"')
|| value.ends_with('"')
|| value.starts_with('`')
|| value.ends_with('`')
|| value.contains('(')
|| value.contains(')')
|| value.contains('*')
|| value.contains(' ')
{
return None;
}
match value.split_once('.') {
Some((table, column))
if !table.is_empty() && !column.is_empty() && !column.contains('.') =>
{
Some((table, column))
}
Some(_) => None,
None => Some(("", value)),
}
}
fn validate_model_column_reference(kind: &str, value: &str) -> std::result::Result<(), String> {
let Some((table, column)) = Self::simple_column_reference(value) else {
return Ok(());
};
let reference = if table.is_empty() {
column.to_string()
} else {
format!("{}.{}", table, column)
};
db_sql::validate_identifier_reference(kind, &reference)?;
if table.is_empty() || table == M::table_name() {
if M::column_names().contains(&column) {
Ok(())
} else {
Err(format!(
"unknown {} '{}' for model '{}'; known columns: {}",
kind,
reference,
M::table_name(),
M::column_names().join(", ")
))
}
} else {
Ok(())
}
}
fn validate_condition(condition: &WhereCondition) -> std::result::Result<(), String> {
match (&condition.operator, &condition.value) {
(Operator::Raw, ConditionValue::RawExpr(raw_sql)) => {
let kind = if condition.column.is_empty() {
"WHERE raw SQL"
} else {
"WHERE raw column expression"
};
db_sql::validate_raw_sql_fragment(kind, raw_sql)
}
(Operator::SubqueryIn, ConditionValue::Subquery(query_sql))
| (Operator::SubqueryNotIn, ConditionValue::Subquery(query_sql)) => {
db_sql::validate_subquery_sql(query_sql)
}
_ => Ok(()),
}?;
if !condition.column.is_empty() {
Self::validate_model_column_reference("WHERE column", &condition.column)?;
}
Ok(())
}
fn validate_or_group(group: &OrGroup) -> std::result::Result<(), String> {
for condition in &group.conditions {
Self::validate_condition(condition)?;
}
for nested_group in &group.nested_groups {
Self::validate_or_group(nested_group)?;
}
Ok(())
}
pub(super) fn validate_union_clause(union: &UnionClause) -> std::result::Result<(), String> {
db_sql::validate_subquery_sql(&union.query_sql)
}
pub(super) fn validate_window_function(
window_function: &WindowFunction,
) -> std::result::Result<(), String> {
db_sql::validate_identifier("window alias", &window_function.alias)?;
for column in &window_function.partition_by {
Self::validate_model_column_reference("window PARTITION BY column", column)?;
}
for (column, _) in &window_function.order_by {
Self::validate_model_column_reference("window ORDER BY column", column)?;
}
match &window_function.function {
WindowFunctionType::Lag(column, _, default)
| WindowFunctionType::Lead(column, _, default) => {
Self::validate_model_column_reference("window function column", column)?;
if let Some(default) = default {
db_sql::validate_raw_sql_fragment("LAG/LEAD default expression", default)?;
}
}
WindowFunctionType::FirstValue(column)
| WindowFunctionType::LastValue(column)
| WindowFunctionType::Sum(column)
| WindowFunctionType::Avg(column)
| WindowFunctionType::Min(column)
| WindowFunctionType::Max(column) => {
Self::validate_model_column_reference("window function column", column)?;
}
WindowFunctionType::NthValue(column, _) => {
Self::validate_model_column_reference("window function column", column)?;
}
WindowFunctionType::Count(Some(column)) => {
Self::validate_model_column_reference("window function column", column)?;
}
WindowFunctionType::Custom(expression) => {
db_sql::validate_raw_sql_fragment("window function expression", expression)?;
}
WindowFunctionType::RowNumber
| WindowFunctionType::Rank
| WindowFunctionType::DenseRank
| WindowFunctionType::Ntile(_)
| WindowFunctionType::Count(None) => {}
}
Ok(())
}
pub(super) fn validate_cte_clause(cte: &CTE) -> std::result::Result<(), String> {
db_sql::validate_identifier("CTE name", &cte.name)?;
if let Some(columns) = &cte.columns {
for column in columns {
db_sql::validate_identifier("CTE column", column)?;
}
}
db_sql::validate_subquery_sql(&cte.query_sql)
}
fn validate_query_fragments(&self) -> Result<()> {
for condition in &self.conditions {
Self::validate_condition(condition).map_err(Error::invalid_query)?;
}
for group in &self.or_groups {
Self::validate_or_group(group).map_err(Error::invalid_query)?;
}
for (column, _) in &self.order_by {
Self::validate_model_column_reference("ORDER BY column", column)
.map_err(Error::invalid_query)?;
}
for column in &self.group_by {
Self::validate_model_column_reference("GROUP BY column", column)
.map_err(Error::invalid_query)?;
}
if let Some(columns) = &self.select_columns {
for column in columns {
Self::validate_model_column_reference("SELECT column", column)
.map_err(Error::invalid_query)?;
}
}
for union in &self.unions {
Self::validate_union_clause(union).map_err(Error::invalid_query)?;
}
for window_function in &self.window_functions {
Self::validate_window_function(window_function).map_err(Error::invalid_query)?;
}
for cte in &self.ctes {
Self::validate_cte_clause(cte).map_err(Error::invalid_query)?;
}
Ok(())
}
pub fn new() -> Self {
Self {
_marker: PhantomData,
database: None,
conditions: Vec::new(),
or_groups: Vec::new(),
order_by: Vec::new(),
limit_value: None,
offset_value: None,
select_columns: None,
raw_select_expressions: Vec::new(),
include_trashed: false,
only_trashed: false,
joins: Vec::new(),
invalid_query_reason: None,
group_by: Vec::new(),
having_conditions: Vec::new(),
unions: Vec::new(),
window_functions: Vec::new(),
ctes: Vec::new(),
cache_options: None,
cache_key: None,
}
}
pub(crate) fn with_database(mut self, database: crate::database::Database) -> Self {
self.database = Some(database);
self
}
pub(super) fn current_db(&self) -> Result<crate::database::Database> {
if let Some(database) = &self.database {
Ok(database.clone())
} else {
crate::database::__current_db()
}
}
pub fn consolidate(&self) -> QueryFragment<M> {
QueryFragment {
_marker: PhantomData,
conditions: self.conditions.clone(),
or_groups: self.or_groups.clone(),
order_by: self.order_by.clone(),
limit_value: self.limit_value,
offset_value: self.offset_value,
select_columns: self.select_columns.clone(),
raw_select_expressions: self.raw_select_expressions.clone(),
group_by: self.group_by.clone(),
having_conditions: self.having_conditions.clone(),
joins: self.joins.clone(),
unions: self.unions.clone(),
window_functions: self.window_functions.clone(),
ctes: self.ctes.clone(),
cache_options: self.cache_options.clone(),
cache_key: self.cache_key.clone(),
invalid_query_reason: self.invalid_query_reason.clone(),
include_trashed: self.include_trashed,
only_trashed: self.only_trashed,
}
}
pub fn apply(mut self, fragment: &QueryFragment<M>) -> Self {
self.conditions.extend(fragment.conditions.clone());
self.or_groups.extend(fragment.or_groups.clone());
if self.order_by.is_empty() {
self.order_by.extend(fragment.order_by.clone());
}
if self.limit_value.is_none() {
self.limit_value = fragment.limit_value;
}
if self.offset_value.is_none() {
self.offset_value = fragment.offset_value;
}
if self.select_columns.is_none() {
self.select_columns = fragment.select_columns.clone();
}
self.raw_select_expressions
.extend(fragment.raw_select_expressions.clone());
self.group_by.extend(fragment.group_by.clone());
self.having_conditions
.extend(fragment.having_conditions.clone());
self.joins.extend(fragment.joins.clone());
self.unions.extend(fragment.unions.clone());
self.window_functions
.extend(fragment.window_functions.clone());
self.ctes.extend(fragment.ctes.clone());
if self.cache_options.is_none() {
self.cache_options = fragment.cache_options.clone();
}
if self.cache_key.is_none() {
self.cache_key = fragment.cache_key.clone();
}
if self.invalid_query_reason.is_none() {
self.invalid_query_reason = fragment.invalid_query_reason.clone();
}
if fragment.only_trashed {
self.only_trashed = true;
self.include_trashed = false;
} else if fragment.include_trashed {
self.include_trashed = true;
self.only_trashed = false;
}
self
}
pub(super) fn invalidate_query(&mut self, reason: String) {
if self.invalid_query_reason.is_none() {
self.invalid_query_reason = Some(reason);
}
}
pub(super) fn ensure_query_is_valid(&self) -> Result<()> {
if let Some(reason) = &self.invalid_query_reason {
return Err(Error::invalid_query(reason.clone()));
}
self.validate_query_fragments()?;
Ok(())
}
pub(super) fn validate_join_clause(
table: &str,
alias: Option<&str>,
left_column: &str,
right_column: &str,
) -> std::result::Result<(), String> {
db_sql::validate_identifier("JOIN table", table)?;
if let Some(alias) = alias {
db_sql::validate_identifier("JOIN alias", alias)?;
}
db_sql::validate_join_column(left_column)?;
db_sql::validate_join_column(right_column)?;
Ok(())
}
}