use std::sync::LazyLock;
use regex::Regex;
use sqlparser::ast::{
BinaryOperator, CastKind, DataType, Expr, GroupByExpr, OrderBy, OrderByExpr, OrderByKind,
Query, Select, SetExpr, Statement,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use thiserror::Error;
// Lazily compiled regular expressions shared by the text-based normalization
// helpers in this file. Every pattern is a literal, so the `expect` calls can
// only fire if a pattern here is edited into something invalid.

// One or more consecutive whitespace characters.
static RE_WHITESPACE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\s+").expect("valid regex"));
// A `::type` cast; capture 1 is the type name (letters, digits, `_`, `[`, `]`).
static RE_TYPE_CAST: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"::([A-Za-z][A-Za-z0-9_\[\]]*)").expect("valid regex"));
// A single-quoted literal (without embedded quotes) followed by `::text`;
// capture 1 is the literal's content.
const STRING_TEXT_CAST_PATTERN: &str = r"'([^']*)'::text";
// Case-sensitive form of `STRING_TEXT_CAST_PATTERN`.
static RE_STRING_TEXT_CAST: LazyLock<Regex> =
LazyLock::new(|| Regex::new(STRING_TEXT_CAST_PATTERN).expect("valid regex"));
// Case-insensitive form of `STRING_TEXT_CAST_PATTERN` (prefixed with `(?i)`).
static RE_STRING_TEXT_CAST_CI: LazyLock<Regex> =
LazyLock::new(|| Regex::new(&format!("(?i){STRING_TEXT_CAST_PATTERN}")).expect("valid regex"));
// A single-quoted literal cast to a type name that may be prefixed by a
// lowercase `schema.` qualifier and may be wrapped in double quotes;
// capture 1 is the literal's content.
static RE_STRING_CUSTOM_CAST: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r#"'([^']*)'::(?:[a-z_][a-z0-9_]*\.)?"?[A-Za-z_][A-Za-z0-9_]*"?"#)
.expect("valid regex")
});
// `NULL::<type>` in any letter case, where the type may contain dots and
// double quotes.
static RE_NULL_CAST: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r#"(?i)\bNULL::[a-zA-Z0-9_."]+(?:\.[a-zA-Z0-9_."]+)?"#).expect("valid regex")
});
// The `!~~*` operator together with any surrounding whitespace.
static RE_NOT_ILIKE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\s*!~~\*\s*").expect("valid regex"));
// The `~~*` operator together with any surrounding whitespace.
static RE_ILIKE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\s*~~\*\s*").expect("valid regex"));
// The `!~~` operator together with any surrounding whitespace.
static RE_NOT_LIKE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\s*!~~\s*").expect("valid regex"));
// The `~~` operator together with any surrounding whitespace.
static RE_LIKE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\s*~~\s*").expect("valid regex"));
// An opening parenthesis followed by whitespace.
static RE_PAREN_OPEN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\(\s+").expect("valid regex"));
// Whitespace followed by a closing parenthesis.
static RE_PAREN_CLOSE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\s+\)").expect("valid regex"));
// The word `FROM` followed (after optional whitespace) by `(`; the match ends
// just past the parenthesis.
static RE_FROM_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bFROM\s*\(").expect("valid regex"));
// Anchored at the start of the text: a word, whitespace, an optional second
// word, then `JOIN` (e.g. `roles r JOIN`).
static RE_JOIN_PATTERN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^\s*\w+\s+\w*\s*JOIN\b").expect("valid regex"));
// The word `WHERE` followed by two adjacent opening parentheses.
static RE_WHERE_DOUBLE_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bWHERE\s*\(\(").expect("valid regex"));
// The word `WHERE` followed by one opening parenthesis.
static RE_WHERE_SINGLE_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bWHERE\s*\(").expect("valid regex"));
// `((...))` whose content holds no further parentheses; capture 1 is the content.
static RE_DOUBLE_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\(\(([^()]*)\)\)").expect("valid regex"));
// `ON (...)` with non-empty, parenthesis-free content; capture 1 is the content.
static RE_ON_PARENS: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bON\s*\(([^()]+)\)").expect("valid regex"));
// The word `OR` followed by an opening parenthesis.
static RE_OR_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\bOR\s*\(").expect("valid regex"));
// `(...)` with non-empty, parenthesis-free content; capture 1 is the content.
static RE_SIMPLE_PAREN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\(([^()]+)\)").expect("valid regex"));
/// Locates the password inside the user-info part of a connection URL.
///
/// Returns `Some((colon, at))` where `colon` is the byte index of the `:`
/// that separates user name and password and `at` is the byte index of the
/// first `@` in `url`; the password is `url[colon + 1..at]`.
///
/// Returns `None` when there is no `@`, when no `:` lies between the end of
/// the `://` scheme separator (or the start of the string when there is no
/// scheme) and the `@`, or when the password is empty.
fn find_password_span(url: &str) -> Option<(usize, usize)> {
    let at_position = url.find('@')?;
    let search_start = url.find("://").map_or(0, |position| position + 3);
    if search_start >= at_position {
        return None;
    }
    // The user name ends at the FIRST colon of the user-info; any later colon
    // belongs to the password. Searching from the right would leave the
    // leading part of a password such as `pa:ss` outside the reported span.
    let colon_offset = url[search_start..at_position].find(':')?;
    let colon_position = search_start + colon_offset;
    if colon_position + 1 == at_position {
        // `user:@host` — nothing to mask.
        return None;
    }
    Some((colon_position, at_position))
}
/// Returns `url` with its password replaced by `****`.
///
/// The password is located with `find_password_span`; when that finds none,
/// the URL is returned unchanged.
pub fn sanitize_url(url: &str) -> String {
    let Some((colon_position, at_position)) = find_password_span(url) else {
        return url.to_string();
    };
    // Keep everything up to and including the separating colon, then resume
    // at the `@`.
    let prefix = &url[..=colon_position];
    let suffix = &url[at_position..];
    format!("{prefix}****{suffix}")
}
/// Returns an owned copy of the password embedded in `url`, or `None` when
/// `find_password_span` reports no password.
fn extract_password(url: &str) -> Option<String> {
    find_password_span(url)
        .map(|(colon_position, at_position)| url[colon_position + 1..at_position].to_string())
}
/// Masks the connection password wherever it occurs in `error_message`.
///
/// Both the password exactly as written in `connection_url` and its
/// percent-decoded form are replaced with `****`. Passwords shorter than
/// three bytes are left alone, and so is the message when the URL carries no
/// password.
pub fn sanitize_connection_error(connection_url: &str, error_message: &str) -> String {
    let Some(secret) = extract_password(connection_url).filter(|found| found.len() >= 3) else {
        return error_message.to_string();
    };
    let masked = error_message.replace(&secret, "****");
    let decoded_secret = simple_percent_decode(&secret);
    if decoded_secret == secret {
        // Nothing was percent-encoded; a second pass would find nothing new.
        masked
    } else {
        masked.replace(&decoded_secret, "****")
    }
}
/// Decodes `%XX` escapes in `input`.
///
/// A `%` followed by exactly two ASCII hex digits becomes the byte they
/// encode; every other byte, including a `%` that does not start a valid
/// escape, is copied through unchanged. If the decoded bytes are not valid
/// UTF-8 the invalid sequences are replaced with U+FFFD instead of panicking.
fn simple_percent_decode(input: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_value(byte: u8) -> Option<u8> {
        // `to_digit(16)` yields values below 16, so the narrowing is lossless.
        char::from(byte).to_digit(16).map(|digit| digit as u8)
    }

    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        if bytes[index] == b'%' {
            // Work on bytes rather than string slices: slicing `input` here
            // could cut through a multi-byte character and panic.
            let high = bytes.get(index + 1).copied().and_then(hex_value);
            let low = bytes.get(index + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                index += 3;
                continue;
            }
        }
        decoded.push(bytes[index]);
        index += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Removes a matching pair of dollar-quote delimiters (`$$` or `$tag$`)
/// from around `body`.
///
/// Surrounding whitespace is ignored when looking for the delimiters. The
/// text between them is returned as-is; when `body` is not wrapped in a
/// matching pair it is returned unchanged (including its whitespace).
pub(crate) fn strip_dollar_quotes(body: &str) -> String {
    let candidate = body.trim();
    let unquoted = candidate
        .strip_prefix('$')
        // Offset of the `$` closing the opening tag, relative to the text
        // after the first `$`.
        .and_then(|after_dollar| after_dollar.find('$'))
        // The full tag spans both `$` characters.
        .map(|offset| &candidate[..offset + 2])
        .and_then(|tag| candidate.strip_prefix(tag)?.strip_suffix(tag));
    unquoted.unwrap_or(body).to_string()
}
/// Trims `sql` and collapses every run of whitespace into a single space.
pub fn normalize_sql_whitespace(sql: &str) -> String {
    let trimmed = sql.trim();
    RE_WHITESPACE.replace_all(trimmed, " ").into_owned()
}
pub fn normalize_type_casts(expr: &str) -> String {
RE_TYPE_CAST
.replace_all(expr, |caps: ®ex::Captures| {
format!("::{}", caps[1].to_lowercase())
})
.to_string()
}
/// Returns `true` when `dt` is one of the integer, decimal, or
/// floating-point variants listed below.
fn is_numeric_type(dt: &DataType) -> bool {
    match dt {
        // Integer variants.
        DataType::Int(_)
        | DataType::Integer(_)
        | DataType::BigInt(_)
        | DataType::SmallInt(_)
        | DataType::TinyInt(_) => true,
        // Decimal variants.
        DataType::Numeric(_) | DataType::Decimal(_) => true,
        // Floating-point variants.
        DataType::Float(_) | DataType::Real | DataType::Double(_) | DataType::DoublePrecision => {
            true
        }
        _ => false,
    }
}
/// Applies the text normalizations shared by expression and view comparison.
///
/// In order: rewrites the `!~~*`, `~~*`, `!~~`, `~~` operators to their
/// `[NOT] [I]LIKE` keywords, lowercases cast type names, trims and collapses
/// whitespace, and removes whitespace just inside parentheses.
fn apply_common_normalizations(expr: &str) -> String {
    // Order matters: the longer operators must be rewritten before their
    // shorter prefixes/suffixes (`!~~*` before `~~*` before `!~~` before `~~`).
    let operator_rewrites = [
        (&RE_NOT_ILIKE, " NOT ILIKE "),
        (&RE_ILIKE, " ILIKE "),
        (&RE_NOT_LIKE, " NOT LIKE "),
        (&RE_LIKE, " LIKE "),
    ];
    let mut text = expr.to_string();
    for (pattern, keyword) in operator_rewrites {
        text = pattern.replace_all(&text, keyword).into_owned();
    }

    let text = normalize_type_casts(&text);
    let text = RE_WHITESPACE.replace_all(text.trim(), " ");
    let text = RE_PAREN_OPEN.replace_all(&text, "(");
    RE_PAREN_CLOSE.replace_all(&text, ")").into_owned()
}
/// Text-only normalization of a single expression.
///
/// Strips casts from quoted string literals (custom-type casts first, then
/// `::text`), strips casts from `NULL`, and then applies
/// `apply_common_normalizations`.
fn normalize_expression_regex(expr: &str) -> String {
    let without_custom_casts = RE_STRING_CUSTOM_CAST.replace_all(expr, "'$1'");
    let without_text_casts = RE_STRING_TEXT_CAST.replace_all(&without_custom_casts, "'$1'");
    let without_null_casts = RE_NULL_CAST.replace_all(&without_text_casts, "NULL");
    apply_common_normalizations(&without_null_casts)
}
/// Finds the byte index of the `)` that closes the `(` at `open_pos`.
///
/// Returns `None` when the byte at `open_pos` is not `(` (or `open_pos` is
/// out of range) and when the parenthesis is never closed. Parentheses are
/// counted purely by byte value; quoted strings get no special treatment.
fn find_matching_paren(s: &str, open_pos: usize) -> Option<usize> {
    if s.as_bytes().get(open_pos) != Some(&b'(') {
        return None;
    }
    // The scan starts on the opening parenthesis itself, so the counter is at
    // least 1 whenever a `)` is seen and the decrement cannot underflow.
    let mut open_count: u32 = 0;
    s.bytes()
        .enumerate()
        .skip(open_pos)
        .find_map(|(index, byte)| match byte {
            b'(' => {
                open_count += 1;
                None
            }
            b')' => {
                open_count -= 1;
                (open_count == 0).then_some(index)
            }
            _ => None,
        })
}
/// Returns a copy of `s` with the single bytes at `first` and `second` removed.
///
/// # Panics
///
/// Panics when `first >= second`, or when the resulting slice boundaries are
/// out of range or not on character boundaries.
fn remove_byte_pair(s: &str, first: usize, second: usize) -> String {
    assert!(first < second);
    let mut remaining = String::with_capacity(s.len().saturating_sub(2));
    remaining.push_str(&s[..first]);
    remaining.push_str(&s[first + 1..second]);
    remaining.push_str(&s[second + 1..]);
    remaining
}
fn remove_outer_parens_around_pattern(s: &str, pattern: &str) -> String {
let search = format!("({pattern}");
let mut result = s.to_string();
while let Some(pos) = result.find(&search) {
if let Some(close_pos) = find_matching_paren(&result, pos) {
result = remove_byte_pair(&result, pos, close_pos);
} else {
break;
}
}
result
}
/// Removes the parentheses in `FROM (a JOIN b ...)` when the parenthesized
/// text starts with a join (as recognized by `RE_JOIN_PATTERN`).
///
/// Only the first `FROM (` in the text is examined on each pass; passes
/// repeat until that first occurrence no longer qualifies.
fn remove_from_join_parens(s: &str) -> String {
    apply_until_stable(s.to_string(), |input| {
        let from_match = RE_FROM_PAREN.find(input)?;
        // The match ends just past the `(`.
        let open_pos = from_match.end() - 1;
        if !RE_JOIN_PATTERN.is_match(&input[from_match.end()..]) {
            return None;
        }
        let close_pos = find_matching_paren(input, open_pos)?;
        Some(remove_byte_pair(input, open_pos, close_pos))
    })
}
/// Repeatedly applies `transform` until it reports that nothing changed.
///
/// Each `Some(next)` returned by `transform` replaces the current value; the
/// first `None` ends the loop and the current value is returned. The loop
/// does not terminate if `transform` never returns `None`.
fn apply_until_stable<F>(mut input: String, mut transform: F) -> String
where
    F: FnMut(&str) -> Option<String>,
{
    while let Some(next) = transform(&input) {
        input = next;
    }
    input
}
/// Removes redundant parentheses that wrap a whole `WHERE` condition.
///
/// Two passes, each repeated until it stops changing the text:
///
/// 1. `WHERE ((a) AND (b))`: for the first `WHERE ((`, drop the outer pair
///    when the text between the first inner group and the outer `)` is
///    empty or starts with `AND`/`OR`.
/// 2. `WHERE (cond)`: drop the pair when what follows the closing `)` is the
///    end of the text, one of the clause keywords listed in
///    `CLAUSE_TERMINATORS`, `)` or `;`.
fn remove_where_outer_parens(s: &str) -> String {
    /// Prefixes that may follow a fully parenthesized `WHERE` condition.
    const CLAUSE_TERMINATORS: [&str; 10] = [
        "ORDER",
        "GROUP",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        ")",
        ";",
    ];

    let without_double = apply_until_stable(s.to_string(), |input| {
        let where_match = RE_WHERE_DOUBLE_PAREN.find(input)?;
        // The match ends just past the second `(`.
        let inner_open_pos = where_match.end() - 1;
        let outer_open_pos = where_match.end() - 2;
        let outer_close_pos = find_matching_paren(input, outer_open_pos)?;
        let inner_close_pos = find_matching_paren(input, inner_open_pos)?;
        let between = input[inner_close_pos + 1..outer_close_pos].trim();
        let outer_is_redundant =
            between.is_empty() || between.starts_with("AND") || between.starts_with("OR");
        outer_is_redundant.then(|| remove_byte_pair(input, outer_open_pos, outer_close_pos))
    });

    apply_until_stable(without_double, |input| {
        RE_WHERE_SINGLE_PAREN.find_iter(input).find_map(|where_match| {
            let open_pos = where_match.end() - 1;
            let close_pos = find_matching_paren(input, open_pos)?;
            let tail = input[close_pos + 1..].trim_start();
            let wraps_whole_condition = tail.is_empty()
                || CLAUSE_TERMINATORS
                    .iter()
                    .any(|terminator| tail.starts_with(*terminator));
            wraps_whole_condition.then(|| remove_byte_pair(input, open_pos, close_pos))
        })
    })
}
/// Drops the `::text` cast (in any letter case) from quoted string literals,
/// keeping the literal itself.
fn strip_text_cast_from_string_literals(query: &str) -> String {
    let stripped = RE_STRING_TEXT_CAST_CI.replace_all(query, "'$1'");
    stripped.into_owned()
}
/// Collapses `((x))` into `(x)` wherever `x` contains no parentheses,
/// repeating until no such doubled pair remains.
fn collapse_double_parens(query: &str) -> String {
    use std::borrow::Cow;

    apply_until_stable(query.to_string(), |input| {
        // `replace_all` hands back the borrowed input when nothing matched,
        // which is the signal to stop iterating.
        if let Cow::Owned(collapsed) = RE_DOUBLE_PAREN.replace_all(input, "($1)") {
            Some(collapsed)
        } else {
            None
        }
    })
}
/// Rewrites `ON (cond)` to `ON cond` when `cond` contains no parentheses.
fn strip_on_clause_parens(query: &str) -> String {
    let unwrapped = RE_ON_PARENS.replace_all(query, "ON $1");
    unwrapped.into_owned()
}
/// Removes the parentheses in `OR (a AND b)` when the group contains
/// ` AND ` but no ` OR `.
///
/// Only the first `OR (` in the text is examined on each pass; passes repeat
/// until that first occurrence no longer qualifies.
fn remove_parens_around_and_groups_in_or(query: &str) -> String {
    apply_until_stable(query.to_string(), |input| {
        let or_match = RE_OR_PAREN.find(input)?;
        // The match ends just past the `(`.
        let open_pos = or_match.end() - 1;
        let close_pos = find_matching_paren(input, open_pos)?;
        let group = &input[open_pos + 1..close_pos];
        let is_pure_and_group = group.contains(" AND ") && !group.contains(" OR ");
        is_pure_and_group.then(|| remove_byte_pair(input, open_pos, close_pos))
    })
}
fn remove_simple_expression_parens(query: &str) -> String {
apply_until_stable(query.to_string(), |input| {
let new = RE_SIMPLE_PAREN
.replace_all(input, |caps: ®ex::Captures| {
let content = &caps[1];
if !content.contains(" AND ")
&& !content.contains(" OR ")
&& !content.contains(',')
&& !content.to_uppercase().contains("SELECT")
{
content.to_string()
} else {
caps[0].to_string()
}
})
.to_string();
if new != input {
Some(new)
} else {
None
}
})
}
/// Removes parentheses that only reflect query structure: around `EXISTS`,
/// around joins in `FROM (...)`, and around whole `WHERE` conditions —
/// applied in that order.
fn remove_structural_parens(query: &str) -> String {
    let without_exists_parens = remove_outer_parens_around_pattern(query, "EXISTS");
    let without_join_parens = remove_from_join_parens(&without_exists_parens);
    remove_where_outer_parens(&without_join_parens)
}
pub fn normalize_view_query(query: &str) -> String {
let result = strip_text_cast_from_string_literals(query);
let result = apply_common_normalizations(&result);
let result = collapse_double_parens(&result);
let result = strip_on_clause_parens(&result);
let result = remove_parens_around_and_groups_in_or(&result);
let result = remove_simple_expression_parens(&result);
remove_structural_parens(&result)
}
/// Compares two view queries, ignoring differences removed by normalization.
///
/// When both queries parse with the PostgreSQL dialect, their statements are
/// normalized with `normalize_statement` and compared pairwise (a differing
/// statement count means "not equal"). If either query fails to parse, the
/// comparison falls back to the text-based `normalize_view_query`.
pub fn views_semantically_equal(query1: &str, query2: &str) -> bool {
    let dialect = PostgreSqlDialect {};
    let parsed_first = Parser::parse_sql(&dialect, query1);
    let parsed_second = Parser::parse_sql(&dialect, query2);

    let (Ok(first), Ok(second)) = (parsed_first, parsed_second) else {
        return normalize_view_query(query1) == normalize_view_query(query2);
    };

    first.len() == second.len()
        && first
            .iter()
            .zip(&second)
            .all(|(left, right)| normalize_statement(left) == normalize_statement(right))
}
/// Compares two SQL expressions, ignoring differences removed by
/// normalization.
///
/// When both expressions parse with the PostgreSQL dialect, their ASTs are
/// normalized with `normalize_expr` and compared. If either fails to parse,
/// the comparison falls back to the text-based `normalize_expression_regex`.
pub fn expressions_semantically_equal(expr1: &str, expr2: &str) -> bool {
    let dialect = PostgreSqlDialect {};
    let parse = |text: &str| {
        Parser::new(&dialect)
            .try_with_sql(text)
            .and_then(|mut parser| parser.parse_expr())
    };

    if let (Ok(first), Ok(second)) = (parse(expr1), parse(expr2)) {
        normalize_expr(&first) == normalize_expr(&second)
    } else {
        normalize_expression_regex(expr1) == normalize_expression_regex(expr2)
    }
}
/// Compares two optional expressions.
///
/// Two absent expressions are equal, two present ones are compared with
/// `expressions_semantically_equal`, and a present/absent mix is unequal.
pub fn optional_expressions_equal(expr1: &Option<String>, expr2: &Option<String>) -> bool {
    match (expr1.as_deref(), expr2.as_deref()) {
        (Some(first), Some(second)) => expressions_semantically_equal(first, second),
        (None, None) => true,
        (Some(_), None) | (None, Some(_)) => false,
    }
}
/// Normalizes a statement for comparison: query statements go through
/// `normalize_query`, every other kind of statement is cloned unchanged.
fn normalize_statement(stmt: &Statement) -> Statement {
    if let Statement::Query(query) = stmt {
        Statement::Query(Box::new(normalize_query(query)))
    } else {
        stmt.clone()
    }
}
/// Returns a normalized copy of `query`.
///
/// The CTE bodies, the query body and the `ORDER BY` clause are normalized
/// recursively; all remaining fields are cloned unchanged.
fn normalize_query(query: &Query) -> Query {
    /// Normalizes the query inside one CTE, keeping its other fields.
    fn normalize_cte(cte: &sqlparser::ast::Cte) -> sqlparser::ast::Cte {
        sqlparser::ast::Cte {
            alias: cte.alias.clone(),
            query: Box::new(normalize_query(&cte.query)),
            from: cte.from.clone(),
            materialized: cte.materialized.clone(),
            closing_paren_token: cte.closing_paren_token.clone(),
        }
    }

    let with = query.with.as_ref().map(|clause| sqlparser::ast::With {
        with_token: clause.with_token.clone(),
        recursive: clause.recursive,
        cte_tables: clause.cte_tables.iter().map(normalize_cte).collect(),
    });
    let body = Box::new(normalize_set_expr(&query.body));
    let order_by = query.order_by.as_ref().map(normalize_order_by);

    Query {
        with,
        body,
        order_by,
        limit_clause: query.limit_clause.clone(),
        fetch: query.fetch.clone(),
        locks: query.locks.clone(),
        for_clause: query.for_clause.clone(),
        settings: query.settings.clone(),
        format_clause: query.format_clause.clone(),
        pipe_operators: query.pipe_operators.clone(),
    }
}
/// Normalizes each expression of an explicit `GROUP BY` list; any other
/// `GROUP BY` form is cloned unchanged.
fn normalize_group_by(group_by: &GroupByExpr) -> GroupByExpr {
    let GroupByExpr::Expressions(exprs, modifiers) = group_by else {
        return group_by.clone();
    };
    let normalized = exprs.iter().map(normalize_expr).collect();
    GroupByExpr::Expressions(normalized, modifiers.clone())
}
/// Normalizes the expressions of an `ORDER BY` clause.
///
/// Sort options, `WITH FILL` and `INTERPOLATE` are preserved; an `ORDER BY`
/// kind other than an expression list is cloned unchanged.
fn normalize_order_by(order_by: &OrderBy) -> OrderBy {
    let kind = if let OrderByKind::Expressions(items) = &order_by.kind {
        let normalized = items
            .iter()
            .map(|item| OrderByExpr {
                expr: normalize_expr(&item.expr),
                options: item.options,
                with_fill: item.with_fill.clone(),
            })
            .collect();
        OrderByKind::Expressions(normalized)
    } else {
        order_by.kind.clone()
    };

    OrderBy {
        kind,
        interpolate: order_by.interpolate.clone(),
    }
}
/// Normalizes a query body.
///
/// `SELECT`s, nested queries and both sides of set operations are normalized
/// recursively; any other body kind is cloned unchanged.
fn normalize_set_expr(body: &SetExpr) -> SetExpr {
    match body {
        SetExpr::SetOperation {
            op,
            set_quantifier,
            left,
            right,
        } => {
            let left = Box::new(normalize_set_expr(left));
            let right = Box::new(normalize_set_expr(right));
            SetExpr::SetOperation {
                op: *op,
                set_quantifier: *set_quantifier,
                left,
                right,
            }
        }
        SetExpr::Query(nested) => SetExpr::Query(Box::new(normalize_query(nested))),
        SetExpr::Select(select) => SetExpr::Select(Box::new(normalize_select(select))),
        unchanged => unchanged.clone(),
    }
}
/// Returns `ident` lowercased and without its quote style; the span is kept.
fn normalize_ident(ident: &sqlparser::ast::Ident) -> sqlparser::ast::Ident {
    let value = ident.value.to_lowercase();
    let span = ident.span;
    sqlparser::ast::Ident {
        value,
        quote_style: None,
        span,
    }
}
/// Normalizes a possibly qualified object name.
///
/// Every identifier part is normalized with `normalize_ident`; a two-part
/// name whose first part is `public` is reduced to its second part.
fn normalize_object_name(name: &sqlparser::ast::ObjectName) -> sqlparser::ast::ObjectName {
    let mut parts: Vec<_> = name
        .0
        .iter()
        .map(|part| match part {
            sqlparser::ast::ObjectNamePart::Identifier(ident) => {
                sqlparser::ast::ObjectNamePart::Identifier(normalize_ident(ident))
            }
            unchanged => unchanged.clone(),
        })
        .collect();

    let has_public_schema = parts.len() == 2
        && matches!(
            &parts[0],
            sqlparser::ast::ObjectNamePart::Identifier(schema) if schema.value == "public"
        );
    if has_public_schema {
        // Drop the explicit `public.` qualifier.
        parts.remove(0);
    }
    sqlparser::ast::ObjectName(parts)
}
/// Normalizes the expression inside a function argument; non-expression
/// argument forms are cloned unchanged.
fn normalize_function_arg_expr(
    arg_expr: &sqlparser::ast::FunctionArgExpr,
) -> sqlparser::ast::FunctionArgExpr {
    if let sqlparser::ast::FunctionArgExpr::Expr(inner) = arg_expr {
        sqlparser::ast::FunctionArgExpr::Expr(normalize_expr(inner))
    } else {
        arg_expr.clone()
    }
}
/// Normalizes one function-call argument.
///
/// The argument value is always normalized; an identifier name goes through
/// `normalize_ident`, an expression name through `normalize_expr`. The
/// naming operator is preserved.
fn normalize_function_arg(arg: &sqlparser::ast::FunctionArg) -> sqlparser::ast::FunctionArg {
    match arg {
        sqlparser::ast::FunctionArg::Named {
            name: param_name,
            arg: value,
            operator,
        } => sqlparser::ast::FunctionArg::Named {
            name: normalize_ident(param_name),
            arg: normalize_function_arg_expr(value),
            operator: operator.clone(),
        },
        sqlparser::ast::FunctionArg::ExprNamed {
            name: name_expr,
            arg: value,
            operator,
        } => sqlparser::ast::FunctionArg::ExprNamed {
            name: normalize_expr(name_expr),
            arg: normalize_function_arg_expr(value),
            operator: operator.clone(),
        },
        sqlparser::ast::FunctionArg::Unnamed(value) => {
            sqlparser::ast::FunctionArg::Unnamed(normalize_function_arg_expr(value))
        }
    }
}
/// Normalizes the expressions inside a window specification.
///
/// `PARTITION BY` and `ORDER BY` expressions and the frame bounds are
/// normalized; the window name, sort options, `WITH FILL` and frame units
/// are preserved.
fn normalize_window_spec(spec: &sqlparser::ast::WindowSpec) -> sqlparser::ast::WindowSpec {
    let partition_by = spec.partition_by.iter().map(normalize_expr).collect();
    let order_by = spec
        .order_by
        .iter()
        .map(|item| sqlparser::ast::OrderByExpr {
            expr: normalize_expr(&item.expr),
            options: item.options,
            with_fill: item.with_fill.clone(),
        })
        .collect();
    let window_frame = spec
        .window_frame
        .as_ref()
        .map(|frame| sqlparser::ast::WindowFrame {
            units: frame.units,
            start_bound: normalize_window_frame_bound(&frame.start_bound),
            end_bound: frame.end_bound.as_ref().map(normalize_window_frame_bound),
        });

    sqlparser::ast::WindowSpec {
        window_name: spec.window_name.clone(),
        partition_by,
        order_by,
        window_frame,
    }
}
/// Normalizes the offset expression of a `PRECEDING`/`FOLLOWING` frame
/// bound; bounds without an offset expression are cloned unchanged.
fn normalize_window_frame_bound(
    bound: &sqlparser::ast::WindowFrameBound,
) -> sqlparser::ast::WindowFrameBound {
    match bound {
        sqlparser::ast::WindowFrameBound::Following(Some(offset)) => {
            let offset = Box::new(normalize_expr(offset));
            sqlparser::ast::WindowFrameBound::Following(Some(offset))
        }
        sqlparser::ast::WindowFrameBound::Preceding(Some(offset)) => {
            let offset = Box::new(normalize_expr(offset));
            sqlparser::ast::WindowFrameBound::Preceding(Some(offset))
        }
        unchanged => unchanged.clone(),
    }
}
/// Normalizes one item of a `FROM` clause.
///
/// * Plain tables: the name and alias are normalized, everything else is
///   cloned.
/// * Derived tables (subqueries): the subquery and alias are normalized.
/// * Parenthesized joins: the content is normalized; if no joins remain the
///   parentheses are dropped and the inner relation is returned, taking over
///   the outer alias when the inner relation is a plain table.
/// * Any other factor is cloned unchanged.
fn normalize_table_factor(factor: &sqlparser::ast::TableFactor) -> sqlparser::ast::TableFactor {
    use sqlparser::ast::TableFactor;

    /// Normalizes the alias name; the `explicit` flag and column list are kept.
    fn normalize_alias(alias: &sqlparser::ast::TableAlias) -> sqlparser::ast::TableAlias {
        sqlparser::ast::TableAlias {
            name: normalize_ident(&alias.name),
            explicit: alias.explicit,
            columns: alias.columns.clone(),
        }
    }

    match factor {
        TableFactor::Table {
            name,
            alias,
            args,
            with_hints,
            version,
            with_ordinality,
            partitions,
            json_path,
            sample,
            index_hints,
        } => TableFactor::Table {
            name: normalize_object_name(name),
            alias: alias.as_ref().map(normalize_alias),
            args: args.clone(),
            with_hints: with_hints.clone(),
            version: version.clone(),
            with_ordinality: *with_ordinality,
            partitions: partitions.clone(),
            json_path: json_path.clone(),
            sample: sample.clone(),
            index_hints: index_hints.clone(),
        },
        TableFactor::Derived {
            lateral,
            subquery,
            alias,
        } => TableFactor::Derived {
            lateral: *lateral,
            subquery: Box::new(normalize_query(subquery)),
            alias: alias.as_ref().map(normalize_alias),
        },
        TableFactor::NestedJoin {
            table_with_joins,
            alias,
        } => {
            let normalized = normalize_table_with_joins(table_with_joins);
            if !normalized.joins.is_empty() {
                return TableFactor::NestedJoin {
                    table_with_joins: Box::new(normalized),
                    alias: alias.as_ref().map(normalize_alias),
                };
            }

            // `(t)` with nothing joined: the parentheses carry no meaning.
            let mut relation = normalized.relation;
            if let Some(outer_alias) = alias {
                if let TableFactor::Table {
                    alias: table_alias, ..
                } = &mut relation
                {
                    *table_alias = Some(normalize_alias(outer_alias));
                }
            }
            relation
        }
        unchanged => unchanged.clone(),
    }
}
/// Normalizes a relation together with its joins.
///
/// When the relation is an un-aliased parenthesized join, it is flattened:
/// the inner relation becomes the relation and the inner joins are followed
/// by the (normalized) outer joins. Otherwise the relation and each join
/// are normalized in place.
fn normalize_table_with_joins(
    twj: &sqlparser::ast::TableWithJoins,
) -> sqlparser::ast::TableWithJoins {
    match &twj.relation {
        sqlparser::ast::TableFactor::NestedJoin {
            table_with_joins: nested,
            alias: None,
        } => {
            let flattened = normalize_table_with_joins(nested);
            let mut joins = flattened.joins;
            joins.extend(twj.joins.iter().map(normalize_join));
            sqlparser::ast::TableWithJoins {
                relation: flattened.relation,
                joins,
            }
        }
        relation => sqlparser::ast::TableWithJoins {
            relation: normalize_table_factor(relation),
            joins: twj.joins.iter().map(normalize_join).collect(),
        },
    }
}
fn normalize_join(j: &sqlparser::ast::Join) -> sqlparser::ast::Join {
use sqlparser::ast::{Join, JoinOperator};
let normalize_constraint = normalize_join_constraint;
Join {
relation: normalize_table_factor(&j.relation),
global: j.global,
join_operator: match &j.join_operator {
JoinOperator::Join(c) | JoinOperator::Inner(c) => {
JoinOperator::Join(normalize_constraint(c))
}
JoinOperator::Left(c) | JoinOperator::LeftOuter(c) => {
JoinOperator::Left(normalize_constraint(c))
}
JoinOperator::Right(c) | JoinOperator::RightOuter(c) => {
JoinOperator::Right(normalize_constraint(c))
}
JoinOperator::FullOuter(c) => JoinOperator::FullOuter(normalize_constraint(c)),
other => other.clone(),
},
}
}
/// Normalizes a join constraint: the `ON` expression or each name of a
/// `USING` list. Other constraint kinds are cloned unchanged.
fn normalize_join_constraint(
    constraint: &sqlparser::ast::JoinConstraint,
) -> sqlparser::ast::JoinConstraint {
    use sqlparser::ast::JoinConstraint;

    match constraint {
        JoinConstraint::Using(columns) => {
            let normalized = columns.iter().map(normalize_object_name).collect();
            JoinConstraint::Using(normalized)
        }
        JoinConstraint::On(condition) => JoinConstraint::On(normalize_expr(condition)),
        unchanged => unchanged.clone(),
    }
}
/// Maps alternative type spellings onto a single variant each, so that
/// differently spelled types compare equal. Types without a mapping here are
/// cloned unchanged.
fn normalize_data_type(data_type: &DataType) -> DataType {
    match data_type {
        // Integers.
        DataType::Int2(width) => DataType::SmallInt(*width),
        DataType::Int(width) | DataType::Int4(width) => DataType::Integer(*width),
        DataType::Int8(width) => DataType::BigInt(*width),
        // Floating point.
        DataType::Float4 => DataType::Real,
        DataType::Float8 => DataType::DoublePrecision,
        // Character types.
        DataType::Varchar(length) => DataType::CharacterVarying(*length),
        DataType::Char(length) => DataType::Character(*length),
        // Boolean.
        DataType::Bool => DataType::Boolean,
        unchanged => unchanged.clone(),
    }
}
/// Reduces a trivial scalar subquery such as `(SELECT f(...))` to the
/// function call it wraps.
///
/// Returns `Some(normalized call)` only when the query is a bare `SELECT`
/// with exactly one projected item that is a function call (optionally with
/// an alias accepted by `is_auto_generated_alias`) and has no `WITH`,
/// `ORDER BY`, limit, `DISTINCT`, `FROM`, `WHERE`, `GROUP BY` or `HAVING`.
/// Returns `None` otherwise.
fn try_simplify_scalar_subquery(query: &Query) -> Option<Expr> {
    let has_query_level_clauses =
        query.with.is_some() || query.order_by.is_some() || query.limit_clause.is_some();
    if has_query_level_clauses {
        return None;
    }

    let SetExpr::Select(select) = query.body.as_ref() else {
        return None;
    };

    let group_by_is_empty = matches!(
        &select.group_by,
        sqlparser::ast::GroupByExpr::Expressions(exprs, _) if exprs.is_empty()
    );
    let is_bare_select = select.distinct.is_none()
        && select.from.is_empty()
        && select.selection.is_none()
        && group_by_is_empty
        && select.having.is_none();
    if !is_bare_select {
        return None;
    }

    let [item] = select.projection.as_slice() else {
        return None;
    };
    let projected = match item {
        sqlparser::ast::SelectItem::UnnamedExpr(expr) => expr,
        sqlparser::ast::SelectItem::ExprWithAlias { expr, alias }
            if is_auto_generated_alias(expr, alias) =>
        {
            expr
        }
        _ => return None,
    };

    matches!(projected, Expr::Function(_)).then(|| normalize_expr(projected))
}
/// Returns a normalized copy of a `SELECT`.
///
/// The projection, `FROM` items, `PREWHERE`, `WHERE`, `GROUP BY`, `HAVING`
/// and `QUALIFY` parts are normalized; every other field is copied or cloned
/// unchanged.
fn normalize_select(select: &Select) -> Select {
    // Parts that are rewritten.
    let projection = select
        .projection
        .iter()
        .map(normalize_select_item)
        .collect();
    let from = select.from.iter().map(normalize_table_with_joins).collect();
    let prewhere = select.prewhere.as_ref().map(normalize_expr);
    let selection = select.selection.as_ref().map(normalize_expr);
    let group_by = normalize_group_by(&select.group_by);
    let having = select.having.as_ref().map(normalize_expr);
    let qualify = select.qualify.as_ref().map(normalize_expr);

    Select {
        select_token: select.select_token.clone(),
        distinct: select.distinct.clone(),
        top: select.top.clone(),
        top_before_distinct: select.top_before_distinct,
        projection,
        exclude: select.exclude.clone(),
        into: select.into.clone(),
        from,
        lateral_views: select.lateral_views.clone(),
        prewhere,
        selection,
        group_by,
        cluster_by: select.cluster_by.clone(),
        distribute_by: select.distribute_by.clone(),
        sort_by: select.sort_by.clone(),
        having,
        named_window: select.named_window.clone(),
        qualify,
        window_before_qualify: select.window_before_qualify,
        value_table_mode: select.value_table_mode,
        connect_by: select.connect_by.clone(),
        flavor: select.flavor.clone(),
    }
}
/// Normalizes one projected item.
///
/// Expressions are normalized; an alias is dropped when
/// `is_auto_generated_alias` accepts it for the normalized expression.
/// Other item kinds are cloned unchanged.
fn normalize_select_item(item: &sqlparser::ast::SelectItem) -> sqlparser::ast::SelectItem {
    use sqlparser::ast::SelectItem;

    match item {
        SelectItem::ExprWithAlias { expr, alias } => {
            let expr = normalize_expr(expr);
            if is_auto_generated_alias(&expr, alias) {
                return SelectItem::UnnamedExpr(expr);
            }
            SelectItem::ExprWithAlias {
                expr,
                alias: alias.clone(),
            }
        }
        SelectItem::UnnamedExpr(expr) => SelectItem::UnnamedExpr(normalize_expr(expr)),
        unchanged => unchanged.clone(),
    }
}
/// Returns `true` when `expr` is a function call and `alias` equals the last
/// part of the function's name, ignoring letter case.
fn is_auto_generated_alias(expr: &Expr, alias: &sqlparser::ast::Ident) -> bool {
    let Expr::Function(function) = expr else {
        return false;
    };
    match function.name.0.last() {
        Some(sqlparser::ast::ObjectNamePart::Identifier(function_name)) => {
            function_name.value.to_lowercase() == alias.value.to_lowercase()
        }
        _ => false,
    }
}
/// Returns a normalized copy of `expr` so that differently written but
/// equivalent expressions compare equal.
///
/// The function recurses through the expression kinds matched below and
/// applies these rewrites:
///
/// * parenthesized expressions are unwrapped;
/// * the PostgreSQL like-match operator variants become `Like`/`ILike` nodes;
/// * casts that this function treats as redundant are dropped, and casts
///   that remain are re-emitted in `::` form with a normalized data type;
/// * trivial scalar subqueries are simplified (see
///   `try_simplify_scalar_subquery`), other subqueries are normalized;
/// * `NOT EXISTS` written as a unary `NOT` is folded into the `Exists` node;
/// * identifiers are lowercased and unquoted, and two-part compound
///   identifiers keep only their last part;
/// * `= ANY(ARRAY[...])` becomes an `IN` list and `<> ALL(ARRAY[...])` a
///   negated `IN` list.
///
/// Expression kinds without an arm here are cloned unchanged, so their
/// sub-expressions are NOT normalized.
fn normalize_expr(expr: &Expr) -> Expr {
match expr {
// Redundant parentheses: keep only the inner expression.
Expr::Nested(inner) => normalize_expr(inner),
Expr::BinaryOp { left, op, right } => {
let norm_left = normalize_expr(left);
let norm_right = normalize_expr(right);
match op {
// The four PostgreSQL like-match operators are rewritten to the
// corresponding keyword nodes, never with an escape character.
BinaryOperator::PGLikeMatch => Expr::Like {
negated: false,
any: false,
expr: Box::new(norm_left),
pattern: Box::new(norm_right),
escape_char: None,
},
BinaryOperator::PGNotLikeMatch => Expr::Like {
negated: true,
any: false,
expr: Box::new(norm_left),
pattern: Box::new(norm_right),
escape_char: None,
},
BinaryOperator::PGILikeMatch => Expr::ILike {
negated: false,
any: false,
expr: Box::new(norm_left),
pattern: Box::new(norm_right),
escape_char: None,
},
BinaryOperator::PGNotILikeMatch => Expr::ILike {
negated: true,
any: false,
expr: Box::new(norm_left),
pattern: Box::new(norm_right),
escape_char: None,
},
// Any other operator is kept, with both operands normalized.
_ => Expr::BinaryOp {
left: Box::new(norm_left),
op: op.clone(),
right: Box::new(norm_right),
},
}
}
Expr::Cast {
expr: inner,
data_type,
format,
..
} => {
let norm_inner = normalize_expr(inner);
let norm_data_type = normalize_data_type(data_type);
// A cast to text is always dropped.
if matches!(norm_data_type, DataType::Text) {
return norm_inner;
}
// A column reference cast to unbounded varchar or to a numeric type
// is dropped.
if matches!(
norm_inner,
Expr::Identifier(_) | Expr::CompoundIdentifier(_)
) && (matches!(norm_data_type, DataType::CharacterVarying(None))
|| is_numeric_type(&norm_data_type))
{
return norm_inner;
}
if let Expr::Value(v) = &norm_inner {
// Literal casts that are dropped: string literal to a custom type,
// array type or unbounded varchar; number to a numeric type; NULL
// to anything.
let should_strip = match &v.value {
sqlparser::ast::Value::SingleQuotedString(_) => {
matches!(
norm_data_type,
DataType::Custom(_, _)
| DataType::Array(_)
| DataType::CharacterVarying(None)
)
}
sqlparser::ast::Value::Number(_, _) => is_numeric_type(&norm_data_type),
sqlparser::ast::Value::Null => true,
_ => false,
};
if should_strip {
return norm_inner;
}
// A string literal cast to an interval type is rewritten to an
// `Interval` node with no field/precision qualifiers.
let is_interval_literal = matches!(norm_data_type, DataType::Interval { .. })
&& matches!(v.value, sqlparser::ast::Value::SingleQuotedString(_));
if is_interval_literal {
return Expr::Interval(sqlparser::ast::Interval {
value: Box::new(norm_inner),
leading_field: None,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
});
}
}
// Remaining casts are kept but always emitted with the `::` kind,
// whatever kind the original cast had.
Expr::Cast {
kind: CastKind::DoubleColon,
expr: Box::new(norm_inner),
data_type: norm_data_type,
format: format.clone(),
}
}
Expr::Subquery(q) => {
// Prefer the simplified form; otherwise normalize the subquery itself.
if let Some(simplified) = try_simplify_scalar_subquery(q) {
simplified
} else {
Expr::Subquery(Box::new(normalize_query(q)))
}
}
Expr::Exists { subquery, negated } => Expr::Exists {
subquery: Box::new(normalize_query(subquery)),
negated: *negated,
},
Expr::InSubquery {
expr: inner,
subquery,
negated,
} => Expr::InSubquery {
expr: Box::new(normalize_expr(inner)),
subquery: Box::new(normalize_query(subquery)),
negated: *negated,
},
Expr::Like {
negated,
any,
expr: inner,
pattern,
escape_char,
} => Expr::Like {
negated: *negated,
any: *any,
expr: Box::new(normalize_expr(inner)),
pattern: Box::new(normalize_expr(pattern)),
escape_char: escape_char.clone(),
},
Expr::ILike {
negated,
any,
expr: inner,
pattern,
escape_char,
} => Expr::ILike {
negated: *negated,
any: *any,
expr: Box::new(normalize_expr(inner)),
pattern: Box::new(normalize_expr(pattern)),
escape_char: escape_char.clone(),
},
Expr::Case {
case_token,
end_token,
operand,
conditions,
else_result,
} => Expr::Case {
case_token: case_token.clone(),
end_token: end_token.clone(),
operand: operand.as_ref().map(|e| Box::new(normalize_expr(e))),
conditions: conditions
.iter()
.map(|cw| sqlparser::ast::CaseWhen {
condition: normalize_expr(&cw.condition),
result: normalize_expr(&cw.result),
})
.collect(),
else_result: else_result.as_ref().map(|e| Box::new(normalize_expr(e))),
},
Expr::Function(f) => {
// Start from a clone and overwrite only the parts that are normalized:
// name, list arguments, FILTER expression and window specification.
let mut func = f.clone();
func.name = normalize_object_name(&f.name);
func.args = match &f.args {
sqlparser::ast::FunctionArguments::List(args) => {
sqlparser::ast::FunctionArguments::List(sqlparser::ast::FunctionArgumentList {
duplicate_treatment: args.duplicate_treatment,
args: args.args.iter().map(normalize_function_arg).collect(),
clauses: args.clauses.clone(),
})
}
other => other.clone(),
};
func.filter = f.filter.as_ref().map(|e| Box::new(normalize_expr(e)));
func.over = f.over.as_ref().map(|w| match w {
sqlparser::ast::WindowType::WindowSpec(spec) => {
sqlparser::ast::WindowType::WindowSpec(normalize_window_spec(spec))
}
other => other.clone(),
});
Expr::Function(func)
}
Expr::UnaryOp { op, expr: inner } => {
let norm_inner = normalize_expr(inner);
// `NOT (EXISTS ...)` is folded into a negated `Exists` node.
if matches!(op, sqlparser::ast::UnaryOperator::Not) {
if let Expr::Exists {
subquery,
negated: false,
} = norm_inner
{
return Expr::Exists {
subquery,
negated: true,
};
}
}
Expr::UnaryOp {
op: *op,
expr: Box::new(norm_inner),
}
}
Expr::InList {
expr: inner,
list,
negated,
} => Expr::InList {
expr: Box::new(normalize_expr(inner)),
list: list.iter().map(normalize_expr).collect(),
negated: *negated,
},
Expr::Between {
expr: inner,
negated,
low,
high,
} => Expr::Between {
expr: Box::new(normalize_expr(inner)),
negated: *negated,
low: Box::new(normalize_expr(low)),
high: Box::new(normalize_expr(high)),
},
Expr::IsNull(inner) => Expr::IsNull(Box::new(normalize_expr(inner))),
Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(normalize_expr(inner))),
Expr::IsDistinctFrom(left, right) => Expr::IsDistinctFrom(
Box::new(normalize_expr(left)),
Box::new(normalize_expr(right)),
),
Expr::IsNotDistinctFrom(left, right) => Expr::IsNotDistinctFrom(
Box::new(normalize_expr(left)),
Box::new(normalize_expr(right)),
),
Expr::CompoundIdentifier(idents) => {
let normalized: Vec<_> = idents
.iter()
.map(|ident| sqlparser::ast::Ident {
value: ident.value.to_lowercase(),
quote_style: None,
span: ident.span,
})
.collect();
// A two-part identifier is reduced to its last part, so the first
// part (e.g. a table qualifier) is ignored in comparisons.
// NOTE(review): `a.id` and `b.id` therefore normalize to the same
// expression — confirm this looseness is intended.
if normalized.len() == 2 {
Expr::Identifier(normalized[1].clone())
} else {
Expr::CompoundIdentifier(normalized)
}
}
Expr::Identifier(ident) => Expr::Identifier(sqlparser::ast::Ident {
value: ident.value.to_lowercase(),
quote_style: None,
span: ident.span,
}),
// `x = ANY(...)`: an array operand turns this into `x IN (...)`.
Expr::AnyOp {
left,
compare_op,
right,
..
} if *compare_op == BinaryOperator::Eq => {
let norm_left = normalize_expr(left);
let norm_right = normalize_expr(right);
if let Expr::Array(arr) = &norm_right {
Expr::InList {
expr: Box::new(norm_left),
list: arr.elem.iter().map(normalize_expr).collect(),
negated: false,
}
} else {
// `is_some` is always written as `false` here, whatever the
// original node carried.
Expr::AnyOp {
left: Box::new(norm_left),
compare_op: compare_op.clone(),
right: Box::new(norm_right),
is_some: false,
}
}
}
// `x <> ALL(...)`: an array operand turns this into `x NOT IN (...)`.
Expr::AllOp {
left,
compare_op,
right,
} if *compare_op == BinaryOperator::NotEq => {
let norm_left = normalize_expr(left);
let norm_right = normalize_expr(right);
if let Expr::Array(arr) = &norm_right {
Expr::InList {
expr: Box::new(norm_left),
list: arr.elem.iter().map(normalize_expr).collect(),
negated: true,
}
} else {
Expr::AllOp {
left: Box::new(norm_left),
compare_op: compare_op.clone(),
right: Box::new(norm_right),
}
}
}
Expr::Array(arr) => Expr::Array(sqlparser::ast::Array {
elem: arr.elem.iter().map(normalize_expr).collect(),
named: arr.named,
}),
// Everything else (including `AnyOp`/`AllOp` with other comparison
// operators) is cloned without normalizing its sub-expressions.
other => other.clone(),
}
}
/// Error type of this crate's fallible operations.
///
/// Every variant carries a message that is appended to a fixed,
/// variant-specific prefix when the error is displayed.
#[derive(Error, Debug)]
pub enum SchemaError {
/// Displayed as `Parse error: <message>`.
#[error("Parse error: {0}")]
ParseError(String),
/// Displayed as `Database error: <message>`.
#[error("Database error: {0}")]
DatabaseError(String),
/// Displayed as `Validation error: <message>`.
#[error("Validation error: {0}")]
ValidationError(String),
/// Displayed as `Lint error: <message>`.
#[error("Lint error: {0}")]
LintError(String),
}
/// Shorthand for a `std::result::Result` whose error type is [`SchemaError`].
pub type Result<T> = std::result::Result<T, SchemaError>;
/// Unit tests for the SQL normalization / comparison helpers and the
/// connection-string sanitizers defined in the parent module.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- normalize_view_query: textual normalization ----

    /// `::text` on a string literal is dropped.
    #[test]
    fn normalize_view_query_strips_text_cast_from_string_literals() {
        let input = "SELECT 'supplier'::text AS type FROM users";
        let expected = "SELECT 'supplier' AS type FROM users";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// The `~~` operator is rewritten to `LIKE`.
    #[test]
    fn normalize_view_query_converts_tilde_tilde_to_like() {
        let input = "SELECT * FROM users WHERE name ~~ 'test%'";
        let expected = "SELECT * FROM users WHERE name LIKE 'test%'";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// `~~` rewriting and `::text` stripping apply together.
    #[test]
    fn normalize_view_query_handles_combined_patterns() {
        let input = "SELECT * FROM users WHERE type ~~ 'supplier'::text";
        let expected = "SELECT * FROM users WHERE type LIKE 'supplier'";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Type names after `::` are lowercased.
    #[test]
    fn normalize_view_query_lowercases_type_casts() {
        let input = "SELECT id::TEXT, name::VARCHAR FROM users";
        let expected = "SELECT id::text, name::varchar FROM users";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// A newline between tokens collapses to a single space.
    #[test]
    fn normalize_view_query_collapses_whitespace() {
        let input = "SELECT id,
name FROM users";
        let expected = "SELECT id, name FROM users";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Padding just inside parentheses is removed.
    #[test]
    fn normalize_view_query_removes_spaces_around_parens() {
        let input = "SELECT * FROM ( SELECT id FROM users )";
        let expected = "SELECT * FROM (SELECT id FROM users)";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// The `!~~` operator is rewritten to `NOT LIKE`.
    #[test]
    fn normalize_view_query_handles_not_like_operator() {
        let input = "SELECT * FROM users WHERE name !~~ 'test%'";
        let expected = "SELECT * FROM users WHERE name NOT LIKE 'test%'";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Doubled parentheses around an `ON` condition are removed.
    #[test]
    fn normalize_view_query_normalizes_double_parentheses() {
        let input = "SELECT * FROM a JOIN b ON ((a.id = b.id))";
        let expected = "SELECT * FROM a JOIN b ON a.id = b.id";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Triple-nested parentheses around a predicate are removed.
    #[test]
    fn normalize_view_query_handles_nested_double_parentheses() {
        let input = "SELECT * FROM a WHERE (((x > 0)))";
        let expected = "SELECT * FROM a WHERE x > 0";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Parentheses around an `AND` compound and around each operand are removed.
    #[test]
    fn normalize_view_query_removes_outer_parens_in_where_compound() {
        let input = "SELECT * FROM a WHERE ((x > 0) AND (y < 10))";
        let expected = "SELECT * FROM a WHERE x > 0 AND y < 10";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Text casts, `~~` and wrapping parentheses in a select list normalize together.
    #[test]
    fn normalize_view_query_handles_complex_postgresql_normalization() {
        let input = "SELECT 'enterprise'::text AS type, (r.name ~~ 'enterprise_%'::text) AS is_enterprise FROM roles r";
        let expected =
            "SELECT 'enterprise' AS type, r.name LIKE 'enterprise_%' AS is_enterprise FROM roles r";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// The `~~*` operator is rewritten to `ILIKE`.
    #[test]
    fn normalize_view_query_handles_ilike_operator() {
        let input = "SELECT * FROM users WHERE name ~~* 'Test%'";
        let expected = "SELECT * FROM users WHERE name ILIKE 'Test%'";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// The `!~~*` operator is rewritten to `NOT ILIKE`.
    #[test]
    fn normalize_view_query_handles_not_ilike_operator() {
        let input = "SELECT * FROM users WHERE name !~~* 'Test%'";
        let expected = "SELECT * FROM users WHERE name NOT ILIKE 'Test%'";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// A bare `EXISTS (...)` expression with a parenthesized join is flattened.
    #[test]
    fn normalize_view_query_handles_exists_with_nested_join() {
        let input = "(EXISTS (SELECT 1 FROM (roles r JOIN user_roles ur ON ((ur.role_id = r.id))) WHERE ((ur.user_id = u.id) AND (r.name ~~ 'admin_%'::text))))";
        let expected = "EXISTS (SELECT 1 FROM roles r JOIN user_roles ur ON ur.role_id = r.id WHERE ur.user_id = u.id AND r.name LIKE 'admin_%')";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// A full view combining `CASE`, `EXISTS`, nested joins and text casts.
    #[test]
    fn normalize_view_query_handles_complex_view_with_case_and_exists() {
        let input = "SELECT u.id, u.email, 'active'::text AS status, CASE WHEN (EXISTS (SELECT 1 FROM (roles r JOIN user_roles ur ON ((ur.role_id = r.id))) WHERE ((ur.user_id = u.id) AND (r.name ~~ 'admin_%'::text)))) THEN 'admin'::text ELSE 'user'::text END AS role_type FROM users u WHERE (EXISTS (SELECT 1 FROM (user_roles ur JOIN roles r ON ((ur.role_id = r.id))) WHERE ((ur.user_id = u.id) AND (r.name ~~ 'enterprise_%'::text))))";
        let expected = "SELECT u.id, u.email, 'active' AS status, CASE WHEN EXISTS (SELECT 1 FROM roles r JOIN user_roles ur ON ur.role_id = r.id WHERE ur.user_id = u.id AND r.name LIKE 'admin_%') THEN 'admin' ELSE 'user' END AS role_type FROM users u WHERE EXISTS (SELECT 1 FROM user_roles ur JOIN roles r ON ur.role_id = r.id WHERE ur.user_id = u.id AND r.name LIKE 'enterprise_%')";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Uppercase `::TEXT` on a literal is stripped; `::VARCHAR` on a column is lowercased.
    #[test]
    fn normalize_view_query_handles_uppercase_text_cast() {
        let input = "SELECT 'app_admin'::TEXT, name::VARCHAR FROM users";
        let expected = "SELECT 'app_admin', name::varchar FROM users";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Stripping `::text` from a literal ignores the case of the type name.
    #[test]
    fn normalize_view_query_strips_text_cast_case_insensitive() {
        let input = "SELECT 'value'::TEXT AS col FROM t";
        let expected = "SELECT 'value' AS col FROM t";
        assert_eq!(normalize_view_query(input), expected);
    }

    /// Plain and double-parenthesized `ON` clauses normalize to the same text.
    #[test]
    fn normalize_view_query_handles_on_clause_parens() {
        // Naming follows the other tests here: `db_form` is the
        // parenthesized PostgreSQL-style rendering.
        let schema_form = "SELECT * FROM a JOIN b ON a.id = b.id";
        let db_form = "SELECT * FROM a JOIN b ON ((a.id = b.id))";
        assert_eq!(
            normalize_view_query(schema_form),
            normalize_view_query(db_form)
        );
    }

    /// Both the minimal and the fully parenthesized boolean expression reach the same output.
    #[test]
    fn normalize_view_query_handles_boolean_logic_parens() {
        let schema_form = "SELECT * FROM t WHERE a = 'x' OR b = 'y' AND c = 'z'";
        let db_form =
            "SELECT * FROM t WHERE ((a = 'x'::text) OR ((b = 'y'::text) AND (c = 'z'::text)))";
        let expected = "SELECT * FROM t WHERE a = 'x' OR b = 'y' AND c = 'z'";
        assert_eq!(normalize_view_query(schema_form), expected);
        assert_eq!(normalize_view_query(db_form), expected);
    }

    // ---- normalize_expression_regex ----

    /// `normalize_expression_regex` strips `::text` from a literal.
    #[test]
    fn regex_fallback_strips_text_cast() {
        let input = "'foo'::text";
        let result = normalize_expression_regex(input);
        assert_eq!(result, "'foo'");
    }

    /// `normalize_expression_regex` rewrites `~~` to `LIKE`.
    #[test]
    fn regex_fallback_normalizes_like() {
        let input = "name ~~ 'test%'";
        let result = normalize_expression_regex(input);
        assert_eq!(result, "name LIKE 'test%'");
    }

    /// `normalize_expression_regex` rewrites `!~~` to `NOT LIKE`.
    #[test]
    fn regex_fallback_normalizes_not_like() {
        let input = "name !~~ 'test%'";
        let result = normalize_expression_regex(input);
        assert_eq!(result, "name NOT LIKE 'test%'");
    }

    // ---- expressions_semantically_equal ----

    /// `(0)::double precision` compares equal to a bare `0`.
    #[test]
    fn check_expression_with_numeric_cast() {
        let db_expr =
            r#"(("liveTreeAreaHa" IS NULL) OR ("liveTreeAreaHa" >= (0)::double precision))"#;
        let parsed_expr = r#""liveTreeAreaHa" IS NULL OR "liveTreeAreaHa" >= 0"#;
        assert!(expressions_semantically_equal(db_expr, parsed_expr));
    }

    /// `x IN (...)` compares equal to `x = ANY (ARRAY[...])`.
    #[test]
    fn check_expression_in_list_equals_any_array() {
        let schema_expr = "role IN ('user', 'assistant', 'system')";
        let db_expr = "(role = ANY (ARRAY['user'::text, 'assistant'::text, 'system'::text]))";
        assert!(expressions_semantically_equal(schema_expr, db_expr));
    }

    /// `x NOT IN (...)` compares equal to `x <> ALL (ARRAY[...])`.
    #[test]
    fn check_expression_not_in_list_equals_all_array() {
        let schema_expr = "role NOT IN ('user', 'assistant', 'system')";
        let db_expr = "(role <> ALL (ARRAY['user'::text, 'assistant'::text, 'system'::text]))";
        assert!(expressions_semantically_equal(schema_expr, db_expr));
    }

    // ---- nested JOIN flattening ----

    /// Two levels of parenthesized joins equal the flat join chain.
    #[test]
    fn flatten_double_nested_join() {
        let schema_form = "SELECT 1 FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let db_form = "SELECT 1 FROM ((a JOIN b ON a.id = b.id) JOIN c ON b.id = c.id)";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "Double nested JOIN should equal flat JOIN. Schema: {schema_form}, DB: {db_form}"
        );
    }

    /// Nested joins still match when only the schema side spells out `public.`.
    #[test]
    fn flatten_double_nested_join_with_public_schema() {
        let schema_form = r#"SELECT 1 FROM mrv."Cultivation" c JOIN public.user_roles ur1 ON ur1.user_id = c.owner_id JOIN public.user_roles ur2 ON ur2.farmer_id = ur1.farmer_id"#;
        let db_form = r#"SELECT 1 FROM ((mrv."Cultivation" c JOIN user_roles ur1 ON ur1.user_id = c.owner_id) JOIN user_roles ur2 ON ur2.farmer_id = ur1.farmer_id)"#;
        assert!(
            views_semantically_equal(schema_form, db_form),
            "Cross-schema nested JOIN with public prefix removal should match.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// An `EXISTS` policy expression with a parenthesized join matches its flat form.
    #[test]
    fn policy_expression_with_nested_join() {
        let schema_expr = r#"EXISTS (SELECT 1 FROM public.user_roles ur1 JOIN public.user_roles ur2 ON ur2.farmer_id = ur1.farmer_id WHERE ur1.user_id = auth.uid())"#;
        let db_expr = r#"(EXISTS ( SELECT 1 FROM (user_roles ur1 JOIN user_roles ur2 ON ((ur2.farmer_id = ur1.farmer_id))) WHERE (ur1.user_id = auth.uid())))"#;
        assert!(
            expressions_semantically_equal(schema_expr, db_expr),
            "Policy EXISTS with nested JOINs should be semantically equal.\nSchema: {schema_expr}\nDB: {db_expr}"
        );
    }

    /// Three levels of parenthesized joins equal the flat join chain.
    #[test]
    fn flatten_triple_nested_join() {
        let schema_form =
            "SELECT 1 FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id JOIN d ON c.id = d.id";
        let db_form =
            "SELECT 1 FROM (((a JOIN b ON a.id = b.id) JOIN c ON b.id = c.id) JOIN d ON c.id = d.id)";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "Triple nested JOIN should equal flat JOIN.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// Flattening keeps `LEFT JOIN` distinct while `INNER JOIN` matches `JOIN`.
    #[test]
    fn nested_join_preserves_join_types() {
        let schema_form = "SELECT 1 FROM a INNER JOIN b ON a.id = b.id LEFT JOIN c ON b.id = c.id";
        let db_form = "SELECT 1 FROM ((a JOIN b ON a.id = b.id) LEFT JOIN c ON b.id = c.id)";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "Nested JOINs should preserve join types.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// `INNER JOIN` and `JOIN` are the same join.
    #[test]
    fn inner_join_equals_join() {
        let schema_form = "SELECT 1 FROM a INNER JOIN b ON a.id = b.id";
        let db_form = "SELECT 1 FROM a JOIN b ON a.id = b.id";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "INNER JOIN and JOIN should be semantically equal.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// `LEFT OUTER JOIN` and `LEFT JOIN` are the same join.
    #[test]
    fn left_outer_join_equals_left_join() {
        let schema_form = "SELECT 1 FROM a LEFT OUTER JOIN b ON a.id = b.id";
        let db_form = "SELECT 1 FROM a LEFT JOIN b ON a.id = b.id";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "LEFT OUTER JOIN and LEFT JOIN should be semantically equal.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// `RIGHT OUTER JOIN` and `RIGHT JOIN` are the same join.
    #[test]
    fn right_outer_join_equals_right_join() {
        let schema_form = "SELECT 1 FROM a RIGHT OUTER JOIN b ON a.id = b.id";
        let db_form = "SELECT 1 FROM a RIGHT JOIN b ON a.id = b.id";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "RIGHT OUTER JOIN and RIGHT JOIN should be semantically equal.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// Table aliases survive join flattening.
    #[test]
    fn nested_join_with_aliases() {
        let schema_form =
            "SELECT 1 FROM users u JOIN roles r ON u.id = r.user_id JOIN perms p ON r.id = p.role_id";
        let db_form =
            "SELECT 1 FROM ((users u JOIN roles r ON u.id = r.user_id) JOIN perms p ON r.id = p.role_id)";
        assert!(
            views_semantically_equal(schema_form, db_form),
            "Nested JOINs should preserve aliases.\nSchema: {schema_form}\nDB: {db_form}"
        );
    }

    /// Policy `EXISTS` with doubly nested joins; the DB side also drops the
    /// `"Cultivation".` qualifier and the `public.` prefix.
    #[test]
    fn exists_subquery_with_nested_joins_in_policy() {
        let schema_expr = r#"EXISTS (SELECT 1 FROM mrv."Farm" f JOIN public.user_roles ur1 ON ur1.user_id = auth.uid() JOIN public.user_roles ur2 ON ur2.farmer_id = ur1.farmer_id WHERE f.id = "Cultivation"."farmId")"#;
        let db_expr = r#"(EXISTS ( SELECT 1 FROM ((mrv."Farm" f JOIN user_roles ur1 ON ((ur1.user_id = auth.uid()))) JOIN user_roles ur2 ON ((ur2.farmer_id = ur1.farmer_id))) WHERE (f.id = "farmId")))"#;
        assert!(
            expressions_semantically_equal(schema_expr, db_expr),
            "Complex policy EXISTS with nested JOINs should match.\nSchema: {schema_expr}\nDB: {db_expr}"
        );
    }

    // ---- sanitize_url ----

    /// The password component is replaced with `****`.
    #[test]
    fn sanitize_url_replaces_password() {
        assert_eq!(
            sanitize_url("postgres://user:secret@host/db"),
            "postgres://user:****@host/db"
        );
    }

    /// The port after the host is left untouched.
    #[test]
    fn sanitize_url_with_port() {
        assert_eq!(
            sanitize_url("postgres://user:secret@host:5432/db"),
            "postgres://user:****@host:5432/db"
        );
    }

    /// A URL without credentials is returned unchanged.
    #[test]
    fn sanitize_url_without_password() {
        assert_eq!(sanitize_url("postgres://host/db"), "postgres://host/db");
    }

    /// A URL with no `@` is returned unchanged.
    #[test]
    fn sanitize_url_without_at_sign() {
        assert_eq!(
            sanitize_url("postgres://localhost/db"),
            "postgres://localhost/db"
        );
    }

    /// A user name with no password is returned unchanged.
    #[test]
    fn sanitize_url_user_without_password() {
        assert_eq!(
            sanitize_url("postgres://user@host/db"),
            "postgres://user@host/db"
        );
    }

    // ---- sanitize_connection_error ----

    /// The URL's password is replaced with `****` inside the error text.
    #[test]
    fn sanitize_connection_error_scrubs_password_from_message() {
        let url = "postgres://user:s3cret_p4ss@host:5432/db";
        let error = "error connecting to server at host:5432: password authentication failed for user \"user\" (password was s3cret_p4ss)";
        assert_eq!(
            sanitize_connection_error(url, error),
            "error connecting to server at host:5432: password authentication failed for user \"user\" (password was ****)"
        );
    }

    /// With no password in the URL the error text is unchanged.
    #[test]
    fn sanitize_connection_error_no_password_in_url() {
        let url = "postgres://host/db";
        let error = "connection refused";
        assert_eq!(sanitize_connection_error(url, error), "connection refused");
    }

    /// An empty password leaves the error text unchanged.
    #[test]
    fn sanitize_connection_error_empty_password() {
        let url = "postgres://user:@host/db";
        let error = "connection refused";
        assert_eq!(sanitize_connection_error(url, error), "connection refused");
    }

    /// A very short password (`db`) must not be scrubbed from the message.
    #[test]
    fn sanitize_connection_error_short_password_skips_scrubbing() {
        let url = "postgres://user:db@host:5432/mydb";
        // The message has to contain the password text (`db` inside `mydb`);
        // otherwise this assertion would pass whether or not scrubbing is skipped.
        let error = "database \"mydb\" does not exist";
        assert_eq!(sanitize_connection_error(url, error), error);
    }

    /// A percent-encoded password is scrubbed in its decoded form.
    #[test]
    fn sanitize_connection_error_url_encoded_password() {
        let url = "postgres://user:p%40ss%3Aword@host:5432/db";
        let error = "authentication failed with password p@ss:word";
        assert_eq!(
            sanitize_connection_error(url, error),
            "authentication failed with password ****"
        );
    }

    /// An empty password component is returned unchanged by `sanitize_url`.
    #[test]
    fn sanitize_url_empty_password() {
        assert_eq!(
            sanitize_url("postgres://user:@host/db"),
            "postgres://user:@host/db"
        );
    }

    /// The `postgresql://` scheme is handled like `postgres://`.
    #[test]
    fn sanitize_url_postgresql_scheme() {
        assert_eq!(
            sanitize_url("postgresql://user:secret@host:5432/db"),
            "postgresql://user:****@host:5432/db"
        );
    }

    /// Every occurrence of the password in the message is scrubbed.
    #[test]
    fn sanitize_connection_error_password_appears_multiple_times() {
        let url = "postgres://user:hunter2@host/db";
        let error = "failed at hunter2: invalid hunter2 token";
        assert_eq!(
            sanitize_connection_error(url, error),
            "failed at ****: invalid **** token"
        );
    }

    /// `%C3%A9` decodes to the single character U+00E9.
    #[test]
    fn simple_percent_decode_multibyte_utf8() {
        assert_eq!(super::simple_percent_decode("%C3%A9"), "\u{00e9}");
    }
}
/// Expects a flat `LEFT JOIN` chain over `public.`-qualified tables to equal the nested,
/// parenthesized form in which the `public.` prefix is absent (`auth.users` stays qualified).
#[test]
fn view_with_left_join_and_public_schema_prefix() {
let schema_form = r#"SELECT e.id, u.email FROM public.enterprises e LEFT JOIN public.user_roles ur ON ur.enterprise_id = e.id LEFT JOIN auth.users u ON u.id = ur.user_id"#;
let db_form = r#"SELECT e.id, u.email FROM ((enterprises e LEFT JOIN user_roles ur ON (ur.enterprise_id = e.id)) LEFT JOIN auth.users u ON (u.id = ur.user_id))"#;
assert!(
views_semantically_equal(schema_form, db_form),
"View with LEFT JOINs and public prefix should match.\nSchema: {schema_form}\nDB: {db_form}"
);
}
/// `LIKE` and the `~~` operator spelling must compare as the same view.
#[test]
fn ast_comparison_handles_like_vs_tilde() {
    let (keyword_form, operator_form) = (
        "SELECT * FROM t WHERE name LIKE 'test%'",
        "SELECT * FROM t WHERE name ~~ 'test%'",
    );
    assert!(views_semantically_equal(keyword_form, operator_form));
}
/// `NOT LIKE` and the `!~~` operator spelling must compare as the same view.
#[test]
fn ast_comparison_handles_not_like_vs_not_tilde() {
    let (keyword_form, operator_form) = (
        "SELECT * FROM t WHERE name NOT LIKE 'test%'",
        "SELECT * FROM t WHERE name !~~ 'test%'",
    );
    assert!(views_semantically_equal(keyword_form, operator_form));
}
/// `ILIKE` and the `~~*` operator spelling must compare as the same view.
#[test]
fn ast_comparison_handles_ilike_vs_tilde_star() {
    let (keyword_form, operator_form) = (
        "SELECT * FROM t WHERE name ILIKE 'test%'",
        "SELECT * FROM t WHERE name ~~* 'test%'",
    );
    assert!(views_semantically_equal(keyword_form, operator_form));
}
/// A predicate wrapped in zero, one or two pairs of parentheses must compare
/// equal in every pairing.
#[test]
fn ast_comparison_handles_parens() {
    let variants = [
        "SELECT * FROM t WHERE a = 'x'",
        "SELECT * FROM t WHERE (a = 'x')",
        "SELECT * FROM t WHERE ((a = 'x'))",
    ];
    // Visit each unordered pair once: (0,1), (0,2), (1,2).
    for (idx, &lhs) in variants.iter().enumerate() {
        for &rhs in &variants[idx + 1..] {
            assert!(views_semantically_equal(lhs, rhs));
        }
    }
}
/// Explicit parentheses that mirror the default `OR`/`AND` grouping must not
/// make the view differ from the unparenthesized form.
#[test]
fn ast_comparison_handles_nested_parens_in_boolean() {
    let minimal = "SELECT * FROM t WHERE a = 'x' OR b = 'y' AND c = 'z'";
    let parenthesized = [
        "SELECT * FROM t WHERE (a = 'x') OR ((b = 'y') AND (c = 'z'))",
        "SELECT * FROM t WHERE ((a = 'x') OR ((b = 'y') AND (c = 'z')))",
    ];
    for candidate in parenthesized {
        assert!(views_semantically_equal(minimal, candidate));
    }
}
/// A `::text` cast on a string literal must not make two views differ.
#[test]
fn ast_comparison_handles_text_cast_on_strings() {
    let (bare_literal, cast_literal) = ("SELECT 'value' FROM t", "SELECT 'value'::text FROM t");
    assert!(views_semantically_equal(bare_literal, cast_literal));
}
/// A custom-type cast (`::status_enum`) on a string literal must not make two views differ.
#[test]
fn ast_comparison_handles_enum_cast_on_strings() {
    let (bare_literal, cast_literal) = (
        "SELECT * FROM items WHERE status = 'ACTIVE'",
        "SELECT * FROM items WHERE status = 'ACTIVE'::status_enum",
    );
    assert!(views_semantically_equal(bare_literal, cast_literal));
}
/// A schema-qualified custom-type cast (`::public.status_enum`) on a string
/// literal must not make two views differ.
#[test]
fn ast_comparison_handles_schema_qualified_enum_cast() {
    let (bare_literal, cast_literal) = (
        "SELECT * FROM items WHERE status = 'ACTIVE'",
        "SELECT * FROM items WHERE status = 'ACTIVE'::public.status_enum",
    );
    assert!(views_semantically_equal(bare_literal, cast_literal));
}
/// The letter case of a cast's type name (`::TEXT` vs `::text`) is irrelevant.
#[test]
fn ast_comparison_handles_type_cast_case() {
    let (shouting, quiet) = ("SELECT id::TEXT FROM t", "SELECT id::text FROM t");
    assert!(views_semantically_equal(shouting, quiet));
}
/// An `::integer` cast wrapped around a column argument of `GREATEST` must
/// not make two views differ.
#[test]
fn ast_comparison_strips_numeric_cast_on_column_in_greatest() {
    let without_cast = "SELECT t1.id, GREATEST(t1.col, 0) AS col FROM s.t t1";
    let with_cast = "SELECT t1.id, GREATEST((t1.col)::integer, 0) AS col FROM s.t t1";
    assert!(views_semantically_equal(without_cast, with_cast));
}
/// A view with an `EXISTS` subquery must match its rendering that adds
/// `::text` casts, `~~` and extra parentheses.
#[test]
fn ast_comparison_handles_complex_view() {
    let plain = "SELECT u.id, 'active' AS status FROM users u WHERE EXISTS (SELECT 1 FROM roles r WHERE r.user_id = u.id AND r.name LIKE 'admin_%')";
    let decorated = "SELECT u.id, 'active'::text AS status FROM users u WHERE (EXISTS (SELECT 1 FROM roles r WHERE ((r.user_id = u.id) AND (r.name ~~ 'admin_%'::text))))";
    assert!(views_semantically_equal(plain, decorated));
}
/// Views that differ in table, selected column or literal value must NOT be
/// reported as equal.
#[test]
fn ast_comparison_detects_real_differences() {
    let differing_pairs = [
        // different table
        ("SELECT * FROM users", "SELECT * FROM accounts"),
        // different projected column
        ("SELECT id FROM users", "SELECT name FROM users"),
        // different literal in the predicate
        ("SELECT * FROM t WHERE a = 1", "SELECT * FROM t WHERE a = 2"),
    ];
    for (left, right) in differing_pairs {
        assert!(!views_semantically_equal(left, right));
    }
}
/// `::text` casts on the result literals of `CASE` branches must not make two views differ.
#[test]
fn view_normalization_case_branch_text_cast() {
    let uncast = "SELECT CASE WHEN s.is_active = false THEN 'inactive' WHEN u.email_confirmed_at IS NOT NULL THEN 'active' ELSE 'pending' END AS status FROM t";
    let cast = "SELECT CASE WHEN s.is_active = false THEN 'inactive'::text WHEN u.email_confirmed_at IS NOT NULL THEN 'active'::text ELSE 'pending'::text END AS status FROM t";
    let same = views_semantically_equal(uncast, cast);
    assert!(same);
}
/// A `::text` cast applied to a whole `->>` expression must match one written
/// directly after the key literal.
#[test]
fn view_normalization_jsonb_extract_cast_placement() {
    let cast_on_expression =
        "SELECT (u.raw_user_meta_data ->> 'supplier_name')::text AS name FROM t u";
    let cast_on_key = "SELECT u.raw_user_meta_data ->> 'supplier_name'::text AS name FROM t u";
    let same = views_semantically_equal(cast_on_expression, cast_on_key);
    assert!(same);
}
/// A `::uuid` cast over a `->>` extraction in a join condition must match the
/// rendering that also casts the key literal to `text` and shifts the parentheses.
#[test]
fn view_normalization_jsonb_extract_uuid_cast() {
    let authored = "SELECT * FROM t u LEFT JOIN s ON (s.id = (u.data ->> 'supplier_id')::uuid)";
    let rendered =
        "SELECT * FROM t u LEFT JOIN s ON s.id = ((u.data ->> 'supplier_id'::text)::uuid)";
    let same = views_semantically_equal(authored, rendered);
    assert!(same);
}
/// `NOT EXISTS (...)` must match `NOT (EXISTS ( ... ))`.
#[test]
fn view_normalization_not_exists_parens() {
    let authored = "SELECT * FROM t WHERE NOT EXISTS (SELECT 1 FROM u WHERE u.id = t.id)";
    let rendered = "SELECT * FROM t WHERE NOT (EXISTS ( SELECT 1 FROM u WHERE u.id = t.id))";
    let same = views_semantically_equal(authored, rendered);
    assert!(same);
}
/// Parentheses around each `AND` branch of an `OR` must match the form that omits them.
#[test]
fn view_normalization_or_branch_parens() {
    let grouped = "SELECT * FROM t WHERE (a IS NOT NULL AND f(a)) OR (b IS NOT NULL AND f(b))";
    let ungrouped = "SELECT * FROM t WHERE a IS NOT NULL AND f(a) OR b IS NOT NULL AND f(b)";
    let same = views_semantically_equal(grouped, ungrouped);
    assert!(same);
}
/// Expects an `EXISTS` expression with a quoted schema name to equal a multi-line
/// rendering that wraps the expression and its predicate in extra parentheses and
/// leaves the schema name (`mrv`) unquoted.
#[test]
fn expression_comparison_handles_exists_subquery() {
let parsed = r#"EXISTS (SELECT 1 FROM "mrv"."OrganizationUser" ou WHERE ou."organizationId" = "Farm"."organizationId")"#;
let db = r#"(EXISTS ( SELECT 1
FROM mrv."OrganizationUser" ou
WHERE (ou."organizationId" = "Farm"."organizationId")))"#;
assert!(
expressions_semantically_equal(parsed, db),
"EXISTS expressions should be semantically equal"
);
}
/// Expects an `EXISTS` nested inside another `EXISTS`, with an `auth.uid()` call in the
/// predicate, to equal its multi-line rendering with each conjunct parenthesized.
#[test]
fn expression_comparison_handles_nested_exists_with_function_calls() {
let parsed = r#"EXISTS (SELECT 1 FROM public.user_roles ur1 WHERE ur1.user_id = auth.uid() AND ur1.farmer_id IS NOT NULL AND EXISTS (SELECT 1 FROM public.user_roles ur2 WHERE ur2.user_id = "entityId" AND ur2.farmer_id = ur1.farmer_id))"#;
let db = r#"(EXISTS ( SELECT 1
FROM public.user_roles ur1
WHERE ((ur1.user_id = auth.uid()) AND (ur1.farmer_id IS NOT NULL) AND (EXISTS ( SELECT 1
FROM public.user_roles ur2
WHERE ((ur2.user_id = "entityId") AND (ur2.farmer_id = ur1.farmer_id)))))))"#;
assert!(
expressions_semantically_equal(parsed, db),
"Nested EXISTS expressions with function calls should be semantically equal"
);
}
/// `(1)::integer` inside an `EXISTS` subquery must compare equal to a bare `1`.
#[test]
fn expression_comparison_handles_numeric_literal_cast() {
    let bare = r#"EXISTS (SELECT 1 FROM users WHERE id = user_id)"#;
    let cast = r#"(EXISTS (SELECT (1)::integer FROM users WHERE id = user_id))"#;
    let same = expressions_semantically_equal(bare, cast);
    assert!(
        same,
        "Expressions with numeric literal casts should be semantically equal"
    );
}
/// `(1)::integer` in a view's select list must compare equal to a bare `1`.
#[test]
fn view_comparison_handles_numeric_literal_cast() {
    let (bare, cast) = ("SELECT 1 FROM users", "SELECT (1)::integer FROM users");
    let same = views_semantically_equal(bare, cast);
    assert!(
        same,
        "Views with numeric literal casts should be semantically equal"
    );
}
/// `1::integer` (no parentheses around the literal) must compare equal to a bare `1`.
#[test]
fn expression_comparison_handles_numeric_cast_without_parens() {
    let bare = r#"EXISTS (SELECT 1 FROM users WHERE id = user_id)"#;
    let cast = r#"(EXISTS (SELECT 1::integer FROM users WHERE id = user_id))"#;
    let same = expressions_semantically_equal(bare, cast);
    assert!(
        same,
        "Expressions with numeric casts (no parens) should be semantically equal"
    );
}
/// Expects `auth.uid()` to compare equal no matter which parts of the qualified
/// function name are double-quoted: the schema, the function, or both.
#[test]
fn expression_comparison_handles_function_name_quoting() {
let parsed = r#"auth.uid() = user_id"#;
let db_quoted_schema = r#""auth".uid() = user_id"#;
let db_quoted_func = r#"auth."uid"() = user_id"#;
let db_both_quoted = r#""auth"."uid"() = user_id"#;
assert!(
expressions_semantically_equal(parsed, db_quoted_schema),
"Function with quoted schema should be semantically equal: {parsed} vs {db_quoted_schema}"
);
assert!(
expressions_semantically_equal(parsed, db_quoted_func),
"Function with quoted name should be semantically equal: {parsed} vs {db_quoted_func}"
);
assert!(
expressions_semantically_equal(parsed, db_both_quoted),
"Function with both quoted should be semantically equal: {parsed} vs {db_both_quoted}"
);
}
/// Expects a multi-line view that writes column aliases with lowercase `as` to equal
/// the single-line form that writes `AS`.
#[test]
fn view_comparison_handles_alias_case_and_join() {
let schema = r#"SELECT
ff."facilityId" as facility_id,
ff."farmerId" as user_id
FROM mrv."FacilityFarmer" ff
JOIN public.farmer_users_view fu ON fu.user_id = ff."farmerId""#;
let db = r#"SELECT ff."facilityId" AS facility_id, ff."farmerId" AS user_id FROM mrv."FacilityFarmer" ff JOIN public.farmer_users_view fu ON fu.user_id = ff."farmerId""#;
assert!(
views_semantically_equal(schema, db),
"Views with alias case differences should be semantically equal"
);
}
/// Expects a plain `JOIN` over a `public.`-qualified table to equal the form whose
/// `FROM` clause is parenthesized, whose `ON` condition is double-parenthesized and
/// whose `public.` prefix is dropped.
#[test]
fn view_comparison_handles_postgresql_from_clause_normalization() {
let schema = r#"SELECT ff.id FROM mrv."FacilityFarmer" ff JOIN public.farmer_users fu ON fu.user_id = ff."farmerId""#;
let db = r#"SELECT ff.id FROM (mrv."FacilityFarmer" ff JOIN farmer_users fu ON ((fu.user_id = ff."farmerId")))"#;
assert!(
views_semantically_equal(schema, db),
"Views should be semantically equal despite PostgreSQL normalization:\nSchema: {schema}\nDB: {db}"
);
}
/// Covers two identifier differences that must not affect equality:
/// a bare column vs. the same column qualified with a table name, and a
/// `public.`-qualified name vs. the unqualified name.
#[test]
fn expression_comparison_handles_postgresql_identifier_normalization() {
let parsed_column = r#""entityId" = user_id"#;
let db_qualified = r#"farms."entityId" = user_id"#;
assert!(
expressions_semantically_equal(parsed_column, db_qualified),
"Bare column should equal table-qualified column: {parsed_column} vs {db_qualified}"
);
let parsed_schema = r#"public.user_roles"#;
let db_no_schema = r#"user_roles"#;
assert!(
expressions_semantically_equal(parsed_schema, db_no_schema),
"Table with schema should equal table without schema: {parsed_schema} vs {db_no_schema}"
);
}
/// Expects a `CASE` whose `WHEN` literals carry a schema-qualified, quoted enum cast
/// to equal the same `CASE` without the casts. Both inputs are multi-line and
/// surrounded by blank lines.
#[test]
fn expression_comparison_handles_case_with_enum_cast() {
let without_cast = r#"
CASE entity_type
WHEN 'ENTERPRISE' THEN true
WHEN 'SUPPLIER' THEN true
ELSE false
END
"#;
let with_cast = r#"
CASE entity_type
WHEN 'ENTERPRISE'::test_schema."EntityType" THEN true
WHEN 'SUPPLIER'::test_schema."EntityType" THEN true
ELSE false
END
"#;
assert!(
expressions_semantically_equal(without_cast, with_cast),
"CASE with enum casts should be semantically equal"
);
}
/// Same comparison as the enum-cast `CASE` test, but with no leading or trailing
/// blank lines in the inputs and with the cast form passed as the first argument.
#[test]
fn expression_comparison_handles_case_with_exact_pg_format() {
let with_cast = r#"CASE entity_type
WHEN 'ENTERPRISE'::test_schema."EntityType" THEN true
WHEN 'SUPPLIER'::test_schema."EntityType" THEN true
ELSE false
END"#;
let without_cast = r#"CASE entity_type
WHEN 'ENTERPRISE' THEN true
WHEN 'SUPPLIER' THEN true
ELSE false
END"#;
assert!(
expressions_semantically_equal(with_cast, without_cast),
"CASE with exact pg_get_expr enum casts should be semantically equal"
);
}
/// `(col_name)::character varying` inside `lower(...)` must compare equal to the bare column.
#[test]
fn varchar_cast_on_identifier_stripped_in_expression_index() {
    let (bare, cast) = ("lower(col_name)", "lower((col_name)::character varying)");
    let same = expressions_semantically_equal(bare, cast);
    assert!(
        same,
        "PostgreSQL adds ::character varying casts to varchar columns in expression indexes"
    );
}
/// `(t1.col_name)::character varying` must compare equal to the bare qualified column.
#[test]
fn varchar_cast_on_compound_identifier_stripped() {
    let (bare, cast) = (
        "lower(t1.col_name)",
        "lower((t1.col_name)::character varying)",
    );
    let same = expressions_semantically_equal(bare, cast);
    assert!(same, "Compound identifier varchar cast should be stripped");
}
/// `'unknown'::character varying` must compare equal to the bare string literal.
#[test]
fn varchar_cast_on_string_literal_stripped() {
    let (bare, cast) = (
        "COALESCE(col, 'unknown')",
        "COALESCE(col, 'unknown'::character varying)",
    );
    let same = expressions_semantically_equal(bare, cast);
    assert!(
        same,
        "PostgreSQL adds ::character varying casts to string literals in COALESCE with varchar columns"
    );
}
/// A cast that carries a length (`::varchar(50)`) on a column is significant:
/// the expression must NOT equal the uncast one.
#[test]
fn length_qualified_varchar_cast_on_identifier_preserved() {
    let (sized_cast, bare) = ("lower((col_name)::varchar(50))", "lower(col_name)");
    let same = expressions_semantically_equal(sized_cast, bare);
    assert!(
        !same,
        "Length-qualified varchar cast on identifier should not be stripped"
    );
}
/// A cast that carries a length (`::varchar(10)`) on a string literal is
/// significant: the expression must NOT equal the bare literal.
#[test]
fn length_qualified_varchar_cast_on_string_literal_preserved() {
    let (sized_cast, bare) = ("'value'::varchar(10)", "'value'");
    let same = expressions_semantically_equal(sized_cast, bare);
    assert!(
        !same,
        "Length-qualified varchar cast on string literal should not be stripped"
    );
}
/// `CAST(col AS varchar(100))` must equal `(col)::character varying(100)`.
#[test]
fn cast_syntax_equals_double_colon_syntax() {
    let (function_style, operator_style) = (
        "CAST(col AS varchar(100))",
        "(col)::character varying(100)",
    );
    let same = expressions_semantically_equal(function_style, operator_style);
    assert!(
        same,
        "CAST(x AS type) and x::type should be semantically equal"
    );
}
/// `normalize_expression_regex` must drop a `schema."Type"` cast from a string literal.
#[test]
fn regex_fallback_strips_schema_qualified_enum_cast() {
    let enum_cast_literal = r#"'ENTERPRISE'::test_schema."EntityType""#;
    assert_eq!(
        normalize_expression_regex(enum_cast_literal),
        "'ENTERPRISE'",
        "Should strip schema.\"EnumType\" cast"
    );
}
/// A `CASE` with and without enum casts on its `WHEN` literals must produce
/// the same output from `normalize_expression_regex`.
#[test]
fn regex_fallback_strips_case_with_enum_casts() {
    let cast_case = r#"CASE entity_type WHEN 'ENTERPRISE'::test_schema."EntityType" THEN true WHEN 'SUPPLIER'::test_schema."EntityType" THEN true ELSE false END"#;
    let plain_case =
        r#"CASE entity_type WHEN 'ENTERPRISE' THEN true WHEN 'SUPPLIER' THEN true ELSE false END"#;
    assert_eq!(
        normalize_expression_regex(cast_case),
        normalize_expression_regex(plain_case),
        "CASE expressions with enum casts should normalize to same form"
    );
}
/// `NULL::uuid` must compare equal to a bare `NULL`.
#[test]
fn expression_comparison_handles_null_with_type_cast() {
    let (bare_null, typed_null) = ("NULL", "NULL::uuid");
    let same = expressions_semantically_equal(bare_null, typed_null);
    assert!(same, "NULL vs NULL::uuid should be semantically equal");
}
/// Expects a named function argument (`name => value`) whose value is a table-qualified
/// column to compare equal to the same call with the unqualified column.
#[test]
fn expression_comparison_handles_named_function_args_with_table_qualifier() {
let schema_expr = r#"auth.user_in_context(p_supplier_id => farmers.supplier_id)"#;
let db_expr = r#"auth.user_in_context(p_supplier_id => supplier_id)"#;
assert!(
expressions_semantically_equal(schema_expr, db_expr),
"Named function args should normalize table qualifiers: {schema_expr} vs {db_expr}"
);
}
/// A call mixing positional string arguments and a named argument must compare
/// equal when one side adds `::text` casts and drops the table qualifier.
#[test]
fn expression_comparison_handles_multiple_named_args_with_table_qualifiers() {
    let authored =
        r#"auth.user_has_permission('farmers', 'create', p_supplier_id => farmers.supplier_id)"#;
    let rendered = r#"auth.user_has_permission('farmers'::text, 'create'::text, p_supplier_id => supplier_id)"#;
    let same = expressions_semantically_equal(authored, rendered);
    assert!(
        same,
        "Mixed positional/named args with table qualifiers and text casts should normalize"
    );
}
/// In a view predicate, `IN (...)` must equal `= ANY (ARRAY[...])` with `::text` elements.
#[test]
fn in_list_equals_any_array() {
    let list_form = "SELECT * FROM t WHERE r.name IN ('admin', 'member')";
    let array_form = "SELECT * FROM t WHERE r.name = ANY (ARRAY['admin'::text, 'member'::text])";
    let same = views_semantically_equal(list_form, array_form);
    assert!(same, "IN list should equal = ANY(ARRAY[...])");
}
/// In a view predicate, `NOT IN (...)` must equal `<> ALL (ARRAY[...])` with
/// `::text` elements. (The function name says "not_any"; the fixture and the
/// assertion message use `<> ALL`.)
#[test]
fn not_in_list_equals_not_any_array() {
    let list_form = "SELECT * FROM t WHERE r.name NOT IN ('admin', 'member')";
    let array_form = "SELECT * FROM t WHERE r.name <> ALL (ARRAY['admin'::text, 'member'::text])";
    let same = views_semantically_equal(list_form, array_form);
    assert!(same, "NOT IN list should equal <> ALL(ARRAY[...])");
}
/// Extra parentheses around the predicate of an aggregate `FILTER (WHERE ...)`
/// must not make two views differ.
#[test]
fn filter_clause_with_extra_parens() {
    let tight = "SELECT json_agg(x) FILTER (WHERE u.id IS NOT NULL) FROM t";
    let wrapped = "SELECT json_agg(x) FILTER (WHERE (u.id IS NOT NULL)) FROM t";
    let same = views_semantically_equal(tight, wrapped);
    assert!(same, "FILTER clause extra parens should be normalized");
}
/// End-to-end comparison of one view in two renderings (per the assertion message, the
/// view from issue #40). The second rendering differs from the first by: dropped
/// `public.` prefixes, nested parenthesized `LEFT JOIN`s, double-parenthesized `ON`
/// conditions, `IN (...)` written as `= ANY (ARRAY[...::text])`, and a parenthesized
/// `FILTER` predicate.
#[test]
fn issue_40_full_view_query() {
let schema_form = r#"SELECT
t.id AS team_id,
t.name AS team_name,
COALESCE(
json_agg(
json_build_object(
'user_id', u.id,
'email', u.email,
'role', r.name
)
) FILTER (WHERE u.id IS NOT NULL),
'[]'::json
) AS members
FROM public.teams t
LEFT JOIN public.memberships m ON m.team_id = t.id
LEFT JOIN public.roles r ON m.role_id = r.id AND r.name IN ('admin', 'member')
LEFT JOIN auth.users u ON m.user_id = u.id
GROUP BY t.id, t.name"#;
let db_form = r#"SELECT t.id AS team_id,
t.name AS team_name,
COALESCE(json_agg(json_build_object('user_id', u.id, 'email', u.email, 'role', r.name)) FILTER (WHERE (u.id IS NOT NULL)), '[]'::json) AS members
FROM (((teams t
LEFT JOIN memberships m ON ((m.team_id = t.id)))
LEFT JOIN roles r ON (((m.role_id = r.id) AND (r.name = ANY (ARRAY['admin'::text, 'member'::text])))))
LEFT JOIN auth.users u ON ((m.user_id = u.id)))
GROUP BY t.id, t.name"#;
assert!(
views_semantically_equal(schema_form, db_form),
"Full issue #40 view should be semantically equal despite PostgreSQL normalization"
);
}
/// `= ANY (ARRAY[...])` with `::text` elements and outer parentheses must
/// equal the same expression written without them.
#[test]
fn expressions_equal_with_anyarray_operator() {
    let rendered = "(status = ANY (ARRAY['active'::text, 'pending'::text]))";
    let authored = "status = ANY(ARRAY['active', 'pending'])";
    let same = expressions_semantically_equal(rendered, authored);
    assert!(
        same,
        "= ANY(ARRAY[...]) with ::text casts should equal version without casts"
    );
}
/// Parentheses around each operand of an `OR` (plus a `::text` cast) must
/// equal the form that parenthesizes only the whole expression.
#[test]
fn expressions_equal_with_nested_function_parens() {
    let rendered = "((auth.uid() = user_id) OR (role = 'admin'::text))";
    let authored = "(auth.uid() = user_id OR role = 'admin')";
    let same = expressions_semantically_equal(rendered, authored);
    assert!(same, "Extra parens around OR operands should normalize away");
}
/// An `EXISTS` wrapped in parentheses, with a parenthesized inner predicate,
/// must equal the unwrapped form.
#[test]
fn expressions_equal_with_exists_subquery_parens() {
    let rendered = "(EXISTS (SELECT 1 FROM memberships m WHERE (m.user_id = users.id)))";
    let authored = "EXISTS (SELECT 1 FROM memberships m WHERE m.user_id = users.id)";
    let same = expressions_semantically_equal(rendered, authored);
    assert!(same, "EXISTS with extra parens should normalize");
}
/// `::text` casts on both a function call and a column must compare equal to
/// the uncast comparison.
#[test]
fn expressions_equal_with_pg_function_cast() {
    let rendered = "((auth.uid())::text = (user_id)::text)";
    let authored = "auth.uid() = user_id";
    let same = expressions_semantically_equal(rendered, authored);
    assert!(same, "::text casts on both sides should be stripped");
}
/// `'admin'::text` compared against a function call must equal the form with
/// a bare literal.
#[test]
fn expressions_equal_with_text_literal_cast() {
    let rendered = "(role() = 'admin'::text)";
    let authored = "role() = 'admin'";
    let same = expressions_semantically_equal(rendered, authored);
    assert!(same, "::text on string literal should be stripped");
}
/// Expects a scalar subquery `(SELECT auth.uid())` to equal the same subquery whose
/// select item carries an `AS uid` alias.
#[test]
fn expressions_equal_scalar_subquery_with_auto_alias() {
let schema_form = "(SELECT auth.uid()) = id";
let db_form = "( SELECT auth.uid() AS uid) = id";
assert!(
expressions_semantically_equal(schema_form, db_form),
"Scalar subquery with auto-generated alias should match without alias.\nSchema: {schema_form}\nDB: {db_form}"
);
}
/// Expects a direct call `auth.is_admin()` to equal the scalar subquery
/// `( SELECT auth.is_admin() AS is_admin)`.
#[test]
fn function_call_equals_scalar_subquery_form() {
let schema_form = "auth.is_admin()";
let db_form = "( SELECT auth.is_admin() AS is_admin)";
assert!(
expressions_semantically_equal(schema_form, db_form),
"Direct function call should equal its scalar subquery form.\nSchema: {schema_form}\nDB: {db_form}"
);
}
/// Expects a direct call `auth.uid()` to equal the scalar subquery
/// `( SELECT auth.uid() AS uid)`.
///
/// NOTE(review): the function name says "with_args", but the fixture calls
/// `auth.uid()` with no arguments and the assertion message says "(with no args)".
/// Consider renaming the test or adding a real with-arguments case.
#[test]
fn function_call_with_args_equals_scalar_subquery_form() {
let schema_form = "auth.uid()";
let db_form = "( SELECT auth.uid() AS uid)";
assert!(
expressions_semantically_equal(schema_form, db_form),
"Direct function call (with no args) should equal its scalar subquery form.\nSchema: {schema_form}\nDB: {db_form}"
);
}
/// Expects `auth.uid() = user_id` to equal the form in which the call is wrapped in
/// an aliased scalar subquery on the left-hand side of the comparison.
#[test]
fn function_call_in_comparison_equals_scalar_subquery_form() {
let schema_form = "auth.uid() = user_id";
let db_form = "( SELECT auth.uid() AS uid) = user_id";
assert!(
expressions_semantically_equal(schema_form, db_form),
"Function call in comparison should equal scalar subquery form.\nSchema: {schema_form}\nDB: {db_form}"
);
}
#[test]
fn try_simplify_scalar_subquery_matches_sqlparser_group_by_variant() {
    // Parse a scalar subquery directly with sqlparser and feed the resulting
    // query to `try_simplify_scalar_subquery`, bypassing the string-level
    // comparison helpers. The assertion only requires that simplification
    // succeeds (returns `Some`) for this input.
    let expr_str = "( SELECT auth.is_admin() AS is_admin)";

    let dialect = PostgreSqlDialect {};
    let mut parser = Parser::new(&dialect)
        .try_with_sql(expr_str)
        .expect("valid SQL");
    let parsed = parser.parse_expr().expect("parse expr");

    // The input is written as a parenthesized SELECT, so any other expression
    // variant means the parser did not produce the shape this test targets.
    let query = match parsed {
        Expr::Subquery(query) => query,
        _ => panic!("expected Expr::Subquery, got something else"),
    };

    let simplified = try_simplify_scalar_subquery(&query);
    assert!(
        simplified.is_some(),
        "GROUP BY guard in try_simplify_scalar_subquery did not match sqlparser's AST for: {expr_str}"
    );
}
#[test]
fn expressions_equal_interval_literal_vs_cast() {
    // Two spellings of the same interval value: the typed-literal prefix
    // syntax and the `::interval` cast applied to a string literal.
    let (schema_form, db_form) = ("interval '90 days'", "'90 days'::interval");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "interval literal and cast syntax should be equal.\nSchema: {schema_form}\nDB: {db_form}"
    );
}
#[test]
fn view_with_order_by_normalized() {
    // NOTE(review): both inputs are byte-identical, so this only checks that
    // a view query containing ORDER BY compares equal to itself.
    let declared = "SELECT id, name FROM users ORDER BY name";
    let stored = "SELECT id, name FROM users ORDER BY name";

    let equal = views_semantically_equal(declared, stored);
    assert!(equal, "Views with identical ORDER BY should be equal");
}
#[test]
fn view_with_order_by_cast_normalized() {
    // Sorts by a function call (`lower(name)`) instead of a plain column.
    // NOTE(review): the test name mentions a cast, but neither input contains
    // one and the two strings are byte-identical. Confirm whether one side
    // was meant to carry a cast.
    let declared = "SELECT id, name FROM users ORDER BY lower(name)";
    let stored = "SELECT id, name FROM users ORDER BY lower(name)";

    let equal = views_semantically_equal(declared, stored);
    assert!(equal, "Views with function in ORDER BY should be equal");
}
#[test]
fn view_with_order_by_extra_parens() {
    // The only difference is a pair of parentheses around the sort key.
    let bare_key = "SELECT id FROM t ORDER BY name";
    let wrapped_key = "SELECT id FROM t ORDER BY (name)";

    let equal = views_semantically_equal(bare_key, wrapped_key);
    assert!(equal, "ORDER BY with extra parens should be equal");
}
#[test]
fn materialized_view_count_star() {
    // The two queries differ only in the letter case of the function name.
    let upper_case = "SELECT COUNT(*) FROM users";
    let lower_case = "SELECT count(*) FROM users";

    let equal = views_semantically_equal(upper_case, lower_case);
    assert!(equal, "COUNT(*) vs count(*) should be equal");
}
#[test]
fn materialized_view_count_star_with_alias() {
    // Same letter-case difference as the unaliased variant, with an explicit
    // `AS total` alias present on both sides.
    let upper_case = "SELECT COUNT(*) AS total FROM users";
    let lower_case = "SELECT count(*) AS total FROM users";

    let equal = views_semantically_equal(upper_case, lower_case);
    assert!(
        equal,
        "COUNT(*) AS total vs count(*) AS total should be equal"
    );
}
#[test]
fn not_in_view_equals_not_all_array() {
    // `NOT IN (list)` on one side; `<> ALL (ARRAY[...])` with `::text` casts
    // on each array element on the other.
    let not_in_list = "SELECT * FROM t WHERE status NOT IN ('a', 'b')";
    let not_all_array = "SELECT * FROM t WHERE status <> ALL (ARRAY['a'::text, 'b'::text])";

    let equal = views_semantically_equal(not_in_list, not_all_array);
    assert!(equal, "NOT IN should equal <> ALL(ARRAY[...])");
}
#[test]
fn expressions_equal_empty_array_literal_vs_typed_cast() {
    // Untyped `'{}'` literal versus the same literal cast to `text[]`.
    let (schema_form, db_form) = ("'{}'", "'{}'::text[]");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "Empty array literal should equal typed cast form.\nSchema: {schema_form}\nDB: {db_form}"
    );
}
#[test]
fn expressions_equal_array_literal_with_values_vs_typed_cast() {
    // Non-empty `'{a,b}'` literal versus the same literal cast to `text[]`.
    let (schema_form, db_form) = ("'{a,b}'", "'{a,b}'::text[]");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "Array literal with values should equal typed cast form.\nSchema: {schema_form}\nDB: {db_form}"
    );
}
#[test]
fn expressions_equal_empty_array_literal_vs_integer_array_cast() {
    // Untyped `'{}'` literal versus the same literal cast to `integer[]`.
    let (schema_form, db_form) = ("'{}'", "'{}'::integer[]");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "Empty array literal should equal integer[] typed cast form.\nSchema: {schema_form}\nDB: {db_form}"
    );
}
#[test]
fn expressions_equal_empty_array_literal_vs_boolean_array_cast() {
    // Untyped `'{}'` literal versus the same literal cast to `boolean[]`.
    let (schema_form, db_form) = ("'{}'", "'{}'::boolean[]");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "Empty array literal should equal boolean[] typed cast form.\nSchema: {schema_form}\nDB: {db_form}"
    );
}
#[test]
fn expressions_equal_empty_array_literal_vs_uuid_array_cast() {
    // Untyped `'{}'` literal versus the same literal cast to `uuid[]`.
    let (schema_form, db_form) = ("'{}'", "'{}'::uuid[]");

    let equal = expressions_semantically_equal(schema_form, db_form);
    assert!(
        equal,
        "Empty array literal should equal uuid[] typed cast form.\nSchema: {schema_form}\nDB: {db_form}"
    );
}