use std::fmt::Write as _;
use crate::{
compiler::{
aggregate_types::{AggregateFunction, TemporalBucket},
aggregation::{
AggregateExpression, AggregationPlan, GroupByExpression, OrderByClause, OrderDirection,
ValidatedHavingCondition,
},
fact_table::FactTableMetadata,
},
db::{
identifier::{
quote_mysql_identifier, quote_postgres_identifier, quote_sqlserver_identifier,
},
path_escape::{
escape_mysql_json_path, escape_postgres_jsonb_segment, escape_sqlite_json_path,
escape_sqlserver_json_path,
},
types::DatabaseType,
where_clause::{WhereClause, WhereOperator},
},
error::{FraiseQLError, Result},
utils::casing::to_snake_case,
};
mod expressions;
mod where_clause;
#[cfg(test)]
mod tests;
/// SQL text for an aggregation query together with the bind parameters it references.
///
/// `params` is ordered to match the placeholders embedded in `sql`: each placeholder is
/// generated from the index the value had when it was pushed onto `params`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParameterizedAggregationSql {
    /// The complete SQL statement, with clauses separated by newlines.
    pub sql: String,
    /// Bind values, in placeholder order.
    pub params: Vec<serde_json::Value>,
}
/// Builds dialect-specific SQL for aggregation plans.
///
/// The generator's only state is the target dialect. That dialect selects identifier
/// quoting, placeholder syntax, JSON accessors, string escaping and LIKE operator mapping.
pub struct AggregationSqlGenerator {
    /// Dialect that every generated statement targets.
    database_type: DatabaseType,
}
impl AggregationSqlGenerator {
    /// Creates a generator that emits SQL for the given database dialect.
    #[must_use]
    pub const fn new(database_type: DatabaseType) -> Self {
        Self { database_type }
    }

    /// Returns a SQL expression that extracts `path` from the JSON column `jsonb_column`
    /// as text, using the dialect's JSON accessor.
    ///
    /// `path` goes through the dialect's path-escape helper before it is embedded in a
    /// single-quoted literal. `jsonb_column` is embedded verbatim, so it must already be
    /// a safe column expression.
    pub(super) fn jsonb_extract_sql(&self, jsonb_column: &str, path: &str) -> String {
        match self.database_type {
            DatabaseType::PostgreSQL => {
                let escaped = escape_postgres_jsonb_segment(path);
                // NOTE(review): only the PostgreSQL form carries a trailing space; kept
                // byte-for-byte in case callers or tests depend on the exact text.
                format!("{}->>'{}' ", jsonb_column, escaped)
            },
            DatabaseType::MySQL => {
                let escaped = escape_mysql_json_path(&[path.to_owned()]);
                format!("JSON_UNQUOTE(JSON_EXTRACT({}, '{}'))", jsonb_column, escaped)
            },
            DatabaseType::SQLite => {
                let escaped = escape_sqlite_json_path(&[path.to_owned()]);
                format!("json_extract({}, '{}')", jsonb_column, escaped)
            },
            DatabaseType::SQLServer => {
                let escaped = escape_sqlserver_json_path(&[path.to_owned()]);
                format!("JSON_VALUE({}, '{}')", jsonb_column, escaped)
            },
        }
    }

    /// Maps a WHERE operator to its SQL comparison keyword for this dialect.
    ///
    /// All substring/prefix/suffix operators become `LIKE`. Their case-insensitive
    /// variants become `ILIKE` on PostgreSQL and plain `LIKE` everywhere else.
    /// NOTE(review): case-insensitivity on non-PostgreSQL dialects presumably comes from
    /// collation or caller-side handling — confirm. Every operator not listed (including
    /// `Eq`) falls through to `=`.
    pub(super) const fn operator_to_sql(&self, operator: &WhereOperator) -> &'static str {
        match operator {
            WhereOperator::Neq => "!=",
            WhereOperator::Gt => ">",
            WhereOperator::Gte => ">=",
            WhereOperator::Lt => "<",
            WhereOperator::Lte => "<=",
            WhereOperator::In => "IN",
            WhereOperator::Nin => "NOT IN",
            WhereOperator::Like
            | WhereOperator::Contains
            | WhereOperator::Startswith
            | WhereOperator::Endswith => "LIKE",
            WhereOperator::Ilike
            | WhereOperator::Icontains
            | WhereOperator::Istartswith
            | WhereOperator::Iendswith => match self.database_type {
                DatabaseType::PostgreSQL => "ILIKE",
                _ => "LIKE",
            },
            _ => "=",
        }
    }

    /// Quotes `name` as an identifier for this dialect.
    ///
    /// MySQL and SQL Server use their own quoting helpers. PostgreSQL and SQLite share
    /// the PostgreSQL helper.
    pub(super) fn quote_identifier(&self, name: &str) -> String {
        match self.database_type {
            DatabaseType::MySQL => quote_mysql_identifier(name),
            DatabaseType::SQLServer => quote_sqlserver_identifier(name),
            DatabaseType::PostgreSQL | DatabaseType::SQLite => quote_postgres_identifier(name),
        }
    }

    /// Escapes `s` for embedding inside a single-quoted SQL string literal.
    ///
    /// NUL characters are removed and single quotes are doubled. On MySQL, backslashes
    /// are doubled as well, because MySQL's default SQL mode treats backslash as an
    /// escape character inside string literals.
    pub(super) fn escape_sql_string(&self, s: &str) -> String {
        let without_nulls: std::borrow::Cow<str> = if s.contains('\0') {
            s.replace('\0', "").into()
        } else {
            s.into()
        };
        if matches!(self.database_type, DatabaseType::MySQL) {
            without_nulls.replace('\\', "\\\\").replace('\'', "''")
        } else {
            without_nulls.replace('\'', "''")
        }
    }

    /// Returns the bind placeholder for the zero-based parameter `index`.
    ///
    /// The forms are `$n` on PostgreSQL and `@Pn` on SQL Server, where `n` is
    /// one-based. MySQL and SQLite use a positional `?`.
    pub(super) fn placeholder(&self, index: usize) -> String {
        match self.database_type {
            DatabaseType::PostgreSQL => format!("${}", index + 1),
            DatabaseType::SQLServer => format!("@P{}", index + 1),
            _ => "?".to_string(),
        }
    }

    /// Binds `value` as the next parameter and returns its placeholder.
    ///
    /// JSON `null` is not bound: the literal `NULL` is returned and `params` is left
    /// untouched.
    pub(super) fn emit_value_param(
        &self,
        value: &serde_json::Value,
        params: &mut Vec<serde_json::Value>,
    ) -> String {
        if matches!(value, serde_json::Value::Null) {
            return "NULL".to_string();
        }
        let idx = params.len();
        params.push(value.clone());
        self.placeholder(idx)
    }

    /// Binds a LIKE pattern derived from `value` and `operator`, returning
    /// `(placeholder, needs_escape)`.
    ///
    /// NUL characters are stripped first.
    ///
    /// - Contains, starts-with and ends-with operators (and their case-insensitive forms)
    ///   escape `value` so it matches literally, then wrap it in the matching `%`
    ///   wildcards.
    /// - `needs_escape` is `true` for those operators. The pattern then relies on `!`
    ///   as the escape character, so the statement needs an `ESCAPE '!'` clause.
    ///   NOTE(review): that clause is presumably appended by the caller — confirm.
    /// - All other operators bind `value` unchanged as a raw pattern and return `false`.
    pub(super) fn emit_like_pattern_param(
        &self,
        operator: &WhereOperator,
        value: &str,
        params: &mut Vec<serde_json::Value>,
    ) -> (String, bool) {
        let clean: String = if value.contains('\0') {
            value.replace('\0', "")
        } else {
            value.to_string()
        };
        let (pattern, needs_escape) = match operator {
            WhereOperator::Contains | WhereOperator::Icontains => {
                (format!("%{}%", self.escape_like_literal(&clean)), true)
            },
            WhereOperator::Startswith | WhereOperator::Istartswith => {
                (format!("{}%", self.escape_like_literal(&clean)), true)
            },
            WhereOperator::Endswith | WhereOperator::Iendswith => {
                (format!("%{}", self.escape_like_literal(&clean)), true)
            },
            _ => (clean, false),
        };
        let ph = self.emit_value_param(&serde_json::Value::String(pattern), params);
        (ph, needs_escape)
    }

    /// Escapes `value` so every character matches literally inside a LIKE pattern that
    /// declares `ESCAPE '!'`.
    ///
    /// `!`, `%` and `_` are prefixed with `!` on every dialect. On SQL Server, `[` also
    /// opens a character-class wildcard, so it is escaped there too.
    fn escape_like_literal(&self, value: &str) -> String {
        let escape_bracket = matches!(self.database_type, DatabaseType::SQLServer);
        let mut out = String::with_capacity(value.len());
        for ch in value.chars() {
            if matches!(ch, '!' | '%' | '_') || (escape_bracket && ch == '[') {
                out.push('!');
            }
            out.push(ch);
        }
        out
    }

    /// Generates the full parameterized SQL for `plan`.
    ///
    /// Clauses are emitted in SELECT, FROM, WHERE, GROUP BY, HAVING, ORDER BY order.
    /// They are joined with newlines, and empty clauses are omitted. Row limiting is
    /// appended last in the dialect's syntax.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the clause builders.
    pub fn generate_parameterized(
        &self,
        plan: &AggregationPlan,
    ) -> Result<ParameterizedAggregationSql> {
        let mut params: Vec<serde_json::Value> = Vec::new();
        let select_sql =
            self.build_select_clause(&plan.group_by_expressions, &plan.aggregate_expressions)?;
        // NOTE(review): the table name is emitted verbatim; presumably validated upstream.
        let from_sql = format!("FROM {}", plan.request.table_name);
        let where_sql = if let Some(ref wc) = plan.request.where_clause {
            self.build_where_clause_parameterized(wc, &plan.metadata, &mut params)?
        } else {
            String::new()
        };
        let group_sql = if !plan.group_by_expressions.is_empty() {
            self.build_group_by_clause(&plan.group_by_expressions)?
        } else {
            String::new()
        };
        let having_sql =
            self.build_having_clause_parameterized(&plan.having_conditions, &mut params)?;
        let native_aliases = plan.native_aliases();
        let order_sql = if !plan.request.order_by.is_empty() {
            self.build_order_by_clause(&plan.request.order_by, &native_aliases)?
        } else {
            String::new()
        };
        let mut parts: Vec<&str> = vec![
            &select_sql,
            &from_sql,
            &where_sql,
            &group_sql,
            &having_sql,
            &order_sql,
        ];
        parts.retain(|s| !s.is_empty());
        let mut sql = parts.join("\n");
        self.append_pagination(
            &mut sql,
            !order_sql.is_empty(),
            plan.request.limit,
            plan.request.offset,
        );
        Ok(ParameterizedAggregationSql { sql, params })
    }

    /// Appends the dialect-specific row-limiting clause to `sql`.
    ///
    /// - PostgreSQL accepts `LIMIT` and `OFFSET` independently.
    /// - MySQL and SQLite accept `OFFSET` only as part of a `LIMIT` clause. An offset
    ///   without a limit is therefore paired with the dialect's "no upper bound" limit
    ///   (`18446744073709551615` for MySQL, `-1` for SQLite).
    /// - SQL Server has no `LIMIT`. It uses `OFFSET … ROWS [FETCH NEXT … ROWS ONLY]`,
    ///   which requires an `ORDER BY`. `ORDER BY (SELECT NULL)` is added when the query
    ///   has none.
    ///
    /// Nothing is appended when both `limit` and `offset` are `None`.
    fn append_pagination<L, O>(
        &self,
        sql: &mut String,
        has_order_by: bool,
        limit: Option<L>,
        offset: Option<O>,
    ) where
        L: std::fmt::Display,
        O: std::fmt::Display,
    {
        // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
        match self.database_type {
            DatabaseType::PostgreSQL => {
                if let Some(limit) = limit {
                    let _ = write!(sql, "\nLIMIT {limit}");
                }
                if let Some(offset) = offset {
                    let _ = write!(sql, "\nOFFSET {offset}");
                }
            },
            DatabaseType::MySQL | DatabaseType::SQLite => {
                if let Some(limit) = limit {
                    let _ = write!(sql, "\nLIMIT {limit}");
                } else if offset.is_some() {
                    let unbounded = if matches!(self.database_type, DatabaseType::MySQL) {
                        "18446744073709551615"
                    } else {
                        "-1"
                    };
                    let _ = write!(sql, "\nLIMIT {unbounded}");
                }
                if let Some(offset) = offset {
                    let _ = write!(sql, "\nOFFSET {offset}");
                }
            },
            DatabaseType::SQLServer => {
                if limit.is_none() && offset.is_none() {
                    return;
                }
                if !has_order_by {
                    sql.push_str("\nORDER BY (SELECT NULL)");
                }
                match offset {
                    Some(offset) => {
                        let _ = write!(sql, "\nOFFSET {offset} ROWS");
                    },
                    None => sql.push_str("\nOFFSET 0 ROWS"),
                }
                if let Some(limit) = limit {
                    let _ = write!(sql, "\nFETCH NEXT {limit} ROWS ONLY");
                }
            },
        }
    }
}