use crate::ast::{
Action, Condition, Expr, Merge, MergeAction, MergeMatchKind, MergeSource, Operator, Qail, Value,
};
use crate::transpiler::conditions::{ConditionToSql, read_only_subquery_sql};
use crate::transpiler::dialect::Dialect;
use crate::transpiler::traits::escape_sql_string_literal;
use crate::transpiler::{SqlGenerator, ToSql};
/// Builds a PostgreSQL `MERGE` statement from `cmd`.
///
/// Any CTEs on `cmd` are emitted as a leading `WITH` prefix. The statement is then
/// `MERGE INTO <target> USING <source> ON <conditions>`, followed by each `WHEN` clause
/// in order and an optional `RETURNING` list.
///
/// Instead of SQL, a comment string is returned when:
/// - the dialect is not PostgreSQL;
/// - `cmd.merge` is absent;
/// - there are no ON conditions or no WHEN clauses;
/// - `validate_merge_shape` reports a problem.
pub fn build_merge(cmd: &Qail, dialect: Dialect) -> String {
    if dialect != Dialect::Postgres {
        return "-- MERGE is only supported by the PostgreSQL dialect".to_string();
    }
    let merge = match cmd.merge.as_ref() {
        Some(merge) => merge,
        None => {
            return "/* ERROR: MERGE requires source, ON conditions, and WHEN clauses */"
                .to_string();
        }
    };
    if merge.on.is_empty() {
        return "/* ERROR: MERGE requires at least one ON condition */".to_string();
    }
    if merge.clauses.is_empty() {
        return "/* ERROR: MERGE requires at least one WHEN clause */".to_string();
    }
    if let Some(problem) = validate_merge_shape(merge) {
        return format!("/* ERROR: {problem} */");
    }

    let generator = dialect.generator();
    let sql_gen = generator.as_ref();

    // Target table, optionally aliased.
    let target = {
        let table = sql_gen.quote_identifier(&cmd.table);
        match &merge.target_alias {
            Some(alias) => format!("{table} AS {}", sql_gen.quote_identifier(alias)),
            None => table,
        }
    };
    let source = merge_source_sql(&merge.source, dialect, sql_gen);
    let join_on = conditions_sql(&merge.on, sql_gen);

    let mut sql = String::new();
    push_cte_prefix(&mut sql, cmd, dialect);
    sql.push_str(&format!("MERGE INTO {target} USING {source} ON {join_on}"));

    for clause in &merge.clauses {
        let kind = match clause.match_kind {
            MergeMatchKind::Matched => "MATCHED",
            MergeMatchKind::NotMatchedByTarget => "NOT MATCHED BY TARGET",
            MergeMatchKind::NotMatchedBySource => "NOT MATCHED BY SOURCE",
        };
        sql.push_str(" WHEN ");
        sql.push_str(kind);
        // Optional extra predicate: WHEN <kind> AND <conditions> THEN ...
        if !clause.condition.is_empty() {
            sql.push_str(" AND ");
            sql.push_str(&conditions_sql(&clause.condition, sql_gen));
        }
        sql.push_str(" THEN ");
        sql.push_str(&merge_action_sql(&clause.action, sql_gen));
    }

    if let Some(returning) = cmd.returning.as_ref().filter(|items| !items.is_empty()) {
        let items: Vec<String> = returning
            .iter()
            .map(|expr| expr_sql(expr, sql_gen))
            .collect();
        sql.push_str(" RETURNING ");
        sql.push_str(&items.join(", "));
    }
    sql
}
/// Appends the `WITH` / `WITH RECURSIVE` prefix for `cmd.ctes` to `sql`.
///
/// Does nothing when there are no CTEs. `RECURSIVE` is emitted if any CTE is marked
/// recursive. Definitions are rendered by `super::cte::build_single_cte` and separated by
/// `", "`. A trailing space separates the prefix from the statement that follows.
fn push_cte_prefix(sql: &mut String, cmd: &Qail, dialect: Dialect) {
    if cmd.ctes.is_empty() {
        return;
    }
    let keyword = if cmd.ctes.iter().any(|cte| cte.recursive) {
        "WITH RECURSIVE "
    } else {
        "WITH "
    };
    let definitions: Vec<_> = cmd
        .ctes
        .iter()
        .map(|cte| super::cte::build_single_cte(cte, dialect))
        .collect();
    sql.push_str(keyword);
    sql.push_str(&definitions.join(", "));
    sql.push(' ');
}
/// Renders the `USING` source of a MERGE.
///
/// - A table source is its quoted name.
/// - A query source is `(<query SQL for dialect>)`.
///
/// Either form is followed by ` AS <alias>` when an alias is present.
fn merge_source_sql(
    source: &MergeSource,
    dialect: Dialect,
    generator: &dyn SqlGenerator,
) -> String {
    let (relation, alias) = match source {
        MergeSource::Table { name, alias } => (generator.quote_identifier(name), alias.as_deref()),
        MergeSource::Query { query, alias } => (
            format!("({})", query.to_sql_with_dialect(dialect)),
            alias.as_deref(),
        ),
    };
    match alias {
        Some(alias) => format!("{relation} AS {}", generator.quote_identifier(alias)),
        None => relation,
    }
}
/// Renders every condition with `condition_sql` and joins them with ` AND `.
///
/// An empty slice yields an empty string.
fn conditions_sql(conditions: &[Condition], generator: &dyn SqlGenerator) -> String {
    let mut rendered = String::new();
    for (position, condition) in conditions.iter().enumerate() {
        if position > 0 {
            rendered.push_str(" AND ");
        }
        rendered.push_str(&condition_sql(condition, generator));
    }
    rendered
}
/// Renders one condition as a SQL boolean expression.
///
/// Array-unnest conditions are delegated to the shared `ConditionToSql` implementation.
/// For every other condition, the left-hand expression is rendered with `expr_sql` and the
/// output is chosen by the operator. A malformed BETWEEN or EXISTS operand renders as a
/// `FALSE /* ERROR ... */` fragment.
fn condition_sql(condition: &Condition, generator: &dyn SqlGenerator) -> String {
    if condition.is_array_unnest {
        return condition.to_sql(generator, None);
    }
    let lhs = expr_sql(&condition.left, generator);
    match condition.op {
        Operator::IsNull => format!("{lhs} IS NULL"),
        Operator::IsNotNull => format!("{lhs} IS NOT NULL"),
        Operator::In | Operator::NotIn => in_condition_sql(condition, &lhs, generator),
        Operator::Fuzzy => {
            let pattern = fuzzy_pattern_sql(&condition.value, generator);
            format!("{lhs} {} {pattern}", generator.fuzzy_operator())
        }
        Operator::Contains => {
            let rhs = value_sql(&condition.value, generator);
            generator.json_contains(&lhs, &rhs)
        }
        Operator::KeyExists => {
            let key = value_sql(&condition.value, generator);
            generator.json_key_exists(&lhs, &key)
        }
        Operator::JsonExists => generator.json_exists(&lhs, &json_path_arg(condition, generator)),
        Operator::JsonQuery => format!(
            "{} IS NOT NULL",
            generator.json_query(&lhs, &json_path_arg(condition, generator))
        ),
        Operator::JsonValue => format!(
            "{} IS NOT NULL",
            generator.json_value(&lhs, &json_path_arg(condition, generator))
        ),
        // BETWEEN needs exactly a [low, high] pair.
        Operator::Between | Operator::NotBetween => match &condition.value {
            Value::Array(bounds) => match bounds.as_slice() {
                [low, high] => format!(
                    "{lhs} {} {} AND {}",
                    condition.op.sql_symbol(),
                    value_sql(low, generator),
                    value_sql(high, generator)
                ),
                _ => invalid_between_condition_sql(),
            },
            _ => invalid_between_condition_sql(),
        },
        // EXISTS ignores the left-hand side and only wraps the subquery.
        Operator::Exists | Operator::NotExists => {
            let Value::Subquery(query) = &condition.value else {
                return invalid_exists_condition_sql();
            };
            format!(
                "{} ({})",
                condition.op.sql_symbol(),
                read_only_subquery_sql(query)
            )
        }
        _ => format!(
            "{lhs} {} {}",
            condition.op.sql_symbol(),
            value_sql(&condition.value, generator)
        ),
    }
}
/// Always-false SQL fragment (with an explanatory comment) emitted when an EXISTS
/// condition has no subquery value.
fn invalid_exists_condition_sql() -> String {
    String::from("FALSE /* ERROR: EXISTS condition requires subquery value */")
}
/// Always-false SQL fragment (with an explanatory comment) emitted when an IN / NOT IN
/// condition has an unusable right-hand side.
fn invalid_in_condition_sql() -> String {
    String::from(
        "FALSE /* ERROR: IN condition requires a non-empty array, subquery, or array parameter */",
    )
}
/// Always-false SQL fragment (with an explanatory comment) emitted when a BETWEEN
/// condition does not carry exactly two bounds.
fn invalid_between_condition_sql() -> String {
    String::from("FALSE /* ERROR: BETWEEN condition requires exactly two array values */")
}
/// Renders a `Value` as a SQL operand.
///
/// - Columns, expressions, subqueries, raw function values and named parameters go
///   through their dedicated renderers.
/// - Arrays become a parenthesised, comma-separated list of recursively rendered items.
/// - Everything else uses the value's `Display` implementation.
fn value_sql(value: &Value, generator: &dyn SqlGenerator) -> String {
    match value {
        Value::Array(items) => {
            let rendered: Vec<String> = items
                .iter()
                .map(|item| value_sql(item, generator))
                .collect();
            format!("({})", rendered.join(", "))
        }
        Value::Subquery(query) => format!("({})", read_only_subquery_sql(query)),
        Value::Expr(expr) => expr_sql(expr, generator),
        Value::Column(column) => render_named_expr(column, generator),
        Value::NamedParam(name) => render_named_param(name),
        Value::Function(function) => render_raw_function_value(function),
        other => other.to_string(),
    }
}
fn json_path_arg(condition: &Condition, generator: &dyn SqlGenerator) -> String {
match &condition.value {
Value::String(path) => path.clone(),
Value::Param(index) => generator.placeholder(*index),
Value::NamedParam(name) => format!(":{}", name),
_ => value_sql(&condition.value, generator),
}
}
/// Renders an `IN` / `NOT IN` condition whose left-hand side has already been rendered.
///
/// - A non-empty array becomes `left IN (a, b, ...)`.
/// - A subquery becomes `left IN (SELECT ...)`.
/// - A positional or named parameter is delegated to the generator's
///   `in_array` / `not_in_array`, depending on the operator.
/// - Anything else (including an empty array) yields the always-false error fragment for
///   both operators.
fn in_condition_sql(condition: &Condition, left: &str, generator: &dyn SqlGenerator) -> String {
    match &condition.value {
        Value::Array(values) if !values.is_empty() => {
            let values = values
                .iter()
                .map(|value| value_sql(value, generator))
                .collect::<Vec<_>>()
                .join(", ");
            format!("{left} {} ({values})", condition.op.sql_symbol())
        }
        Value::Subquery(_) => format!(
            "{left} {} {}",
            condition.op.sql_symbol(),
            value_sql(&condition.value, generator)
        ),
        Value::Param(_) | Value::NamedParam(_) => {
            let array = value_sql(&condition.value, generator);
            if condition.op == Operator::In {
                generator.in_array(left, &array)
            } else {
                generator.not_in_array(left, &array)
            }
        }
        // IN and NOT IN share the same fallback; the old operator-guarded arm was identical.
        _ => invalid_in_condition_sql(),
    }
}
fn fuzzy_pattern_sql(value: &Value, generator: &dyn SqlGenerator) -> String {
match value {
Value::String(value) => format!("'%{}%'", escape_sql_string_literal(value)),
Value::Function(value) => format!("'%{}%'", escape_sql_string_literal(value)),
Value::Param(index) => {
let placeholder = generator.placeholder(*index);
generator.string_concat(&["'%'", &placeholder, "'%'"])
}
Value::NamedParam(name) => {
let placeholder = format!(":{}", name);
generator.string_concat(&["'%'", &placeholder, "'%'"])
}
value => format!(
"'%{}%'",
escape_sql_string_literal(&value_sql(value, generator))
),
}
}
/// Renders the action that follows `THEN` in a MERGE `WHEN` clause.
///
/// - `UPDATE SET col = expr, ...`
/// - `INSERT [(cols)] VALUES (exprs)`; the column list is omitted when empty.
/// - `DELETE`
/// - `DO NOTHING`
fn merge_action_sql(action: &MergeAction, generator: &dyn SqlGenerator) -> String {
    match action {
        MergeAction::Delete => String::from("DELETE"),
        MergeAction::DoNothing => String::from("DO NOTHING"),
        MergeAction::Update { assignments } => {
            let set_list: Vec<String> = assignments
                .iter()
                .map(|(column, value)| {
                    format!(
                        "{} = {}",
                        generator.quote_identifier(column),
                        expr_sql(value, generator)
                    )
                })
                .collect();
            format!("UPDATE SET {}", set_list.join(", "))
        }
        MergeAction::Insert { columns, values } => {
            let column_list = if columns.is_empty() {
                String::new()
            } else {
                let quoted: Vec<String> = columns
                    .iter()
                    .map(|column| generator.quote_identifier(column))
                    .collect();
                format!(" ({})", quoted.join(", "))
            };
            let value_list: Vec<String> = values
                .iter()
                .map(|expr| expr_sql(expr, generator))
                .collect();
            format!("INSERT{column_list} VALUES ({})", value_list.join(", "))
        }
    }
}
fn expr_sql(expr: &Expr, generator: &dyn SqlGenerator) -> String {
match expr {
Expr::Star => "*".to_string(),
Expr::Named(name) => render_named_expr(name, generator),
Expr::Aliased { name, alias } => format!(
"{} AS {}",
render_named_expr(name, generator),
render_identifier_or_error(alias, generator)
),
Expr::Aggregate {
col,
func,
distinct,
filter,
..
} => {
let mut sql = if *distinct {
format!("{}(DISTINCT {})", func, render_named_expr(col, generator))
} else {
format!("{}({})", func, render_named_expr(col, generator))
};
if let Some(conditions) = filter
&& !conditions.is_empty()
{
sql.push_str(" FILTER (WHERE ");
sql.push_str(&conditions_sql(conditions, generator));
sql.push(')');
}
sql
}
Expr::Literal(value) => value_sql(value, generator),
Expr::Case {
when_clauses,
else_value,
..
} => {
let mut sql = String::from("CASE");
for (condition, value) in when_clauses {
sql.push_str(" WHEN ");
sql.push_str(&condition_sql(condition, generator));
sql.push_str(" THEN ");
sql.push_str(&expr_sql(value, generator));
}
if let Some(value) = else_value {
sql.push_str(" ELSE ");
sql.push_str(&expr_sql(value, generator));
}
sql.push_str(" END");
sql
}
Expr::Binary {
left, op, right, ..
} => match op {
crate::ast::BinaryOp::IsNull => {
format!("({} IS NULL)", expr_sql(left, generator))
}
crate::ast::BinaryOp::IsNotNull => {
format!("({} IS NOT NULL)", expr_sql(left, generator))
}
_ => format!(
"({} {} {})",
expr_sql(left, generator),
op,
expr_sql(right, generator)
),
},
Expr::FunctionCall { name, args, .. } => {
let Some(function) = render_function_name(name) else {
return "/* ERROR: Invalid function name */".to_string();
};
let args = args
.iter()
.map(|arg| expr_sql(arg, generator))
.collect::<Vec<_>>()
.join(", ");
format!("{function}({args})")
}
Expr::SpecialFunction { name, args, .. } => {
let Some(function) = render_function_name(name) else {
return "/* ERROR: Invalid function name */".to_string();
};
let mut parts = Vec::new();
for (keyword, expr) in args {
let expr = expr_sql(expr, generator);
if let Some(keyword) = keyword {
let Some(keyword) = render_sql_keyword(keyword) else {
return "/* ERROR: Invalid function keyword */".to_string();
};
parts.push(format!("{keyword} {expr}"));
} else {
parts.push(expr);
}
}
format!("{function}({})", parts.join(" "))
}
Expr::Cast {
expr, target_type, ..
} => {
let Some(target_type) = checked_sql_type_fragment(target_type) else {
return "/* ERROR: Invalid cast target type */".to_string();
};
let inner = expr_sql(expr, generator);
if matches!(expr.as_ref(), Expr::JsonAccess { .. } | Expr::Case { .. }) {
format!("({inner})::{target_type}")
} else {
format!("{inner}::{target_type}")
}
}
Expr::JsonAccess {
column,
path_segments,
..
} => {
let mut sql = generator.quote_identifier(column);
for (path, as_text) in path_segments {
let op = if *as_text { "->>" } else { "->" };
if path.parse::<i64>().is_ok() {
sql.push_str(&format!("{}{}", op, path));
} else {
sql.push_str(&format!("{}'{}'", op, escape_sql_string_literal(path)));
}
}
sql
}
Expr::Collate {
expr, collation, ..
} => format!(
"{} COLLATE {}",
expr_sql(expr, generator),
render_identifier_or_error(collation, generator)
),
Expr::FieldAccess { expr, field, .. } => format!(
"({}).{}",
expr_sql(expr, generator),
render_identifier_or_error(field, generator)
),
Expr::ArrayConstructor { elements, .. } => {
let elements = elements
.iter()
.map(|element| expr_sql(element, generator))
.collect::<Vec<_>>()
.join(", ");
format!("ARRAY[{elements}]")
}
Expr::RowConstructor { elements, .. } => {
let elements = elements
.iter()
.map(|element| expr_sql(element, generator))
.collect::<Vec<_>>()
.join(", ");
format!("ROW({elements})")
}
Expr::Subscript { expr, index, .. } => {
format!(
"{}[{}]",
expr_sql(expr, generator),
expr_sql(index, generator)
)
}
Expr::Subquery { query, .. } => format!("({})", read_only_subquery_sql(query)),
Expr::Exists { query, negated, .. } => {
if *negated {
format!("NOT EXISTS ({})", read_only_subquery_sql(query))
} else {
format!("EXISTS ({})", read_only_subquery_sql(query))
}
}
Expr::Def { .. } | Expr::Mod { .. } | Expr::Window { .. } => {
"/* ERROR: Invalid MERGE expression */".to_string()
}
}
}
/// Validates a (possibly schema-qualified) function name and returns it uppercased.
///
/// The name is accepted only if:
/// - it is non-empty;
/// - it contains only ASCII letters, digits, `_` and `.`;
/// - it has no empty dot-separated segment.
///
/// Otherwise `None` is returned.
fn render_function_name(name: &str) -> Option<String> {
    let allowed_chars = name
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'.');
    let segments_present = name.split('.').all(|segment| !segment.is_empty());
    let well_formed =
        !name.is_empty() && !name.contains('\0') && segments_present && allowed_chars;
    well_formed.then(|| name.to_uppercase())
}
/// Validates a keyword used inside a special function call (e.g. `FROM`, `FOR`).
///
/// Returns the keyword uppercased when it is non-empty and consists only of ASCII
/// letters and `_`; otherwise returns `None`.
fn render_sql_keyword(keyword: &str) -> Option<String> {
    let is_word_byte = |b: u8| b.is_ascii_alphabetic() || b == b'_';
    if !keyword.is_empty() && keyword.bytes().all(is_word_byte) {
        Some(keyword.to_ascii_uppercase())
    } else {
        None
    }
}
/// Renders a named bind parameter as `:name`.
///
/// The name must start with an ASCII letter or `_` and continue with ASCII letters,
/// digits or `_`. Any other name, including an empty one, renders as an error comment.
fn render_named_param(name: &str) -> String {
    let valid = match name.as_bytes().split_first() {
        Some((&head, tail)) => {
            (head.is_ascii_alphabetic() || head == b'_')
                && tail.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'_')
        }
        None => false,
    };
    if valid {
        format!(":{name}")
    } else {
        "/* ERROR: Invalid parameter name */".to_string()
    }
}
/// Passes a raw function-expression value through unchanged, unless it looks unsafe.
///
/// The value is rejected, and an error comment returned instead, when:
/// - it is longer than 1024 bytes; or
/// - it contains a NUL byte, `;`, `--`, `/*` or `*/`.
fn render_raw_function_value(value: &str) -> String {
    const FORBIDDEN: [&str; 5] = ["\0", ";", "--", "/*", "*/"];
    let too_long = value.len() > 1024;
    if too_long || FORBIDDEN.iter().any(|&needle| value.contains(needle)) {
        return "/* ERROR: Invalid function expression */".to_string();
    }
    value.to_string()
}
/// Quotes `value` as an identifier via the generator.
///
/// An error comment is returned instead when `value` is empty, contains a NUL byte, or
/// has an empty dot-separated segment.
fn render_identifier_or_error(value: &str, generator: &dyn SqlGenerator) -> String {
    // `"".split('.')` yields one empty segment, so this also covers the empty string.
    let has_empty_segment = value.split('.').any(|part| part.is_empty());
    if has_empty_segment || value.contains('\0') {
        return "/* ERROR: Invalid identifier */".to_string();
    }
    generator.quote_identifier(value)
}
/// Renders a column name or raw expression string.
///
/// - Text containing a statement delimiter or comment outside quotes is always quoted
///   as an identifier.
/// - The following pass through verbatim: `*`; text containing `(`; text starting with
///   a quote, `:` or `$`; numeric literals; and `NULL` / `TRUE` / `FALSE`.
/// - Everything else is quoted as an identifier.
fn render_named_expr(name: &str, generator: &dyn SqlGenerator) -> String {
    if contains_unquoted_statement_delimiter(name) {
        return generator.quote_identifier(name);
    }
    // `f64::from_str` also accepts "inf", "infinity" and "nan" (any case), which are
    // identifiers in SQL, not numeric literals. Every real numeric literal has a digit.
    let is_numeric_literal =
        name.parse::<f64>().is_ok() && name.bytes().any(|b| b.is_ascii_digit());
    if name == "*"
        || name.contains('(')
        || name.starts_with('\'')
        || name.starts_with('"')
        || name.starts_with(':')
        || name.starts_with('$')
        || is_numeric_literal
        || name.eq_ignore_ascii_case("NULL")
        || name.eq_ignore_ascii_case("TRUE")
        || name.eq_ignore_ascii_case("FALSE")
    {
        name.to_string()
    } else {
        generator.quote_identifier(name)
    }
}
/// Validates a type name used as a cast target and returns it trimmed.
///
/// Accepted bytes are ASCII letters, digits, and `_ . space ( ) , [ ] % + -`. The
/// sequence `--` is rejected. An empty (or all-whitespace) fragment yields `None`.
///
/// Because the allow-list already excludes NUL, `;`, quotes, `/` and `*`, comment
/// openers and closers and statement separators cannot pass.
fn checked_sql_type_fragment(fragment: &str) -> Option<String> {
    let candidate = fragment.trim();
    let allowed_byte = |b: u8| b.is_ascii_alphanumeric() || b"_. (),[]%+-".contains(&b);
    let acceptable =
        !candidate.is_empty() && candidate.bytes().all(allowed_byte) && !candidate.contains("--");
    acceptable.then(|| candidate.to_owned())
}
/// Reports whether `value` could terminate or truncate the surrounding SQL statement if
/// it were emitted verbatim.
///
/// Returns `true` when the text contains any of:
/// - a NUL byte;
/// - outside quoted sections, a `;`, a `--` line comment or a `/*` comment opener;
/// - a quoted section that is never closed. An unterminated quote would pair up with
///   quotes in SQL generated after it and change how that SQL is tokenised.
///
/// Quoted sections are recognised following the PostgreSQL lexer:
/// - `'...'` standard strings: `''` escapes a quote and backslashes are literal. This
///   assumes `standard_conforming_strings = on`, the server default.
/// - `E'...'` escape strings: `\` escapes the next byte, including `'`.
/// - `"..."` quoted identifiers: `""` escapes a quote.
/// - `$$...$$` and `$tag$...$tag$` dollar-quoted strings.
fn contains_unquoted_statement_delimiter(value: &str) -> bool {
    // Bytes that continue a PostgreSQL identifier or number. An `E` prefix or a `$`
    // only opens a string when it is not glued to one of these.
    fn continues_word(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    // True when byte `at` begins a new token rather than continuing a word.
    fn starts_token(bytes: &[u8], at: usize) -> bool {
        at == 0 || !continues_word(bytes[at - 1])
    }

    // Index just past the closing quote of the section opened at `open`, or `None`
    // when the section never closes.
    fn skip_quoted(bytes: &[u8], open: usize, backslash_escapes: bool) -> Option<usize> {
        let quote = bytes[open];
        let mut i = open + 1;
        while i < bytes.len() {
            match bytes[i] {
                b'\\' if backslash_escapes => i += 2,
                b if b == quote && bytes.get(i + 1) == Some(&quote) => i += 2,
                b if b == quote => return Some(i + 1),
                _ => i += 1,
            }
        }
        None
    }

    // Length of a `$$` / `$tag$` dollar-quote delimiter starting at `at`, if one is there.
    fn dollar_tag_len(bytes: &[u8], at: usize) -> Option<usize> {
        let tag_start = at + 1;
        let mut i = tag_start;
        while let Some(&b) = bytes.get(i) {
            if b == b'$' {
                return Some(i + 1 - at);
            }
            let valid = if i == tag_start {
                b.is_ascii_alphabetic() || b == b'_' || b >= 0x80
            } else {
                b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
            };
            if !valid {
                return None;
            }
            i += 1;
        }
        None
    }

    let bytes = value.as_bytes();
    if bytes.contains(&0) {
        return true;
    }
    let mut i = 0;
    while i < bytes.len() {
        let next = match bytes[i] {
            b'\'' => {
                let escape_string = i > 0
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && starts_token(bytes, i - 1);
                skip_quoted(bytes, i, escape_string)
            }
            b'"' => skip_quoted(bytes, i, false),
            b'$' if starts_token(bytes, i) => match dollar_tag_len(bytes, i) {
                Some(tag_len) => {
                    let tag = &bytes[i..i + tag_len];
                    let body = i + tag_len;
                    bytes[body..]
                        .windows(tag_len)
                        .position(|window| window == tag)
                        .map(|offset| body + offset + tag_len)
                }
                // `$1`-style parameter or a stray `$`: not a quote.
                None => Some(i + 1),
            },
            b';' => return true,
            b'-' if bytes.get(i + 1) == Some(&b'-') => return true,
            b'/' if bytes.get(i + 1) == Some(&b'*') => return true,
            _ => Some(i + 1),
        };
        match next {
            Some(position) => i = position,
            // Unterminated quoted section: refuse to treat the text as safe.
            None => return true,
        }
    }
    false
}
/// Checks a MERGE for structural problems before any SQL is generated.
///
/// Checks, in order:
/// 1. A table source must have a non-blank name; a query source must pass
///    `validate_merge_source_query`.
/// 2. Each WHEN clause's action must be legal for its match kind.
/// 3. UPDATE needs at least one assignment.
/// 4. INSERT needs at least one value, and a column list (if given) must match the
///    value count.
///
/// Returns the first problem found, or `None` if the shape is valid.
fn validate_merge_shape(merge: &Merge) -> Option<String> {
    let source_problem = match &merge.source {
        MergeSource::Table { name, .. } => name
            .trim()
            .is_empty()
            .then(|| "MERGE requires a USING source table or query".to_string()),
        MergeSource::Query { query, .. } => validate_merge_source_query(query),
    };
    if source_problem.is_some() {
        return source_problem;
    }

    merge.clauses.iter().find_map(|clause| {
        let message = match (&clause.match_kind, &clause.action) {
            (MergeMatchKind::Matched, MergeAction::Insert { .. }) => "WHEN MATCHED cannot INSERT",
            (
                MergeMatchKind::NotMatchedByTarget,
                MergeAction::Update { .. } | MergeAction::Delete,
            ) => "WHEN NOT MATCHED BY TARGET can only INSERT or DO NOTHING",
            (MergeMatchKind::NotMatchedBySource, MergeAction::Insert { .. }) => {
                "WHEN NOT MATCHED BY SOURCE cannot INSERT"
            }
            (_, MergeAction::Update { assignments }) if assignments.is_empty() => {
                "MERGE UPDATE requires at least one assignment"
            }
            (_, MergeAction::Insert { values, .. }) if values.is_empty() => {
                "MERGE INSERT requires at least one value"
            }
            (_, MergeAction::Insert { columns, values })
                if !columns.is_empty() && columns.len() != values.len() =>
            {
                "MERGE INSERT column count must match value count"
            }
            _ => return None,
        };
        Some(message.to_string())
    })
}
/// Ensures a MERGE `USING (...)` query is read-only.
///
/// The query itself must be a `Get` or `With` action. The same check is applied
/// recursively to:
/// - each CTE's base query;
/// - each CTE's recursive part, when present;
/// - every set-operation operand.
///
/// Returns the first violation found, or `None` if the query is read-only.
fn validate_merge_source_query(query: &Qail) -> Option<String> {
    if !matches!(query.action, Action::Get | Action::With) {
        return Some(format!(
            "MERGE source query must be read-only SELECT, got {}",
            query.action
        ));
    }
    for cte in &query.ctes {
        let nested = validate_merge_source_query(&cte.base_query).or_else(|| {
            cte.recursive_query
                .as_ref()
                .and_then(|recursive| validate_merge_source_query(recursive))
        });
        if nested.is_some() {
            return nested;
        }
    }
    query
        .set_ops
        .iter()
        .find_map(|(_, set_query)| validate_merge_source_query(set_query))
}