use crate::ast::{Action, Expr, Qail, Value};
use std::fmt;
/// Error returned when part of an AST fails validation.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole errors
/// with `assert_eq!` instead of inspecting fields one by one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeError {
    /// Path of the offending AST field (for example `table` or `columns[0]`).
    pub field: String,
    /// The rejected value as reported by the validator (may be truncated).
    pub value: String,
    /// Human-readable explanation of why the value was rejected.
    pub reason: String,
}
impl fmt::Display for SanitizeError {
    /// Renders the error as `AST validation failed: <field> '<value>' — <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            field,
            value,
            reason,
        } = self;
        write!(
            f,
            "AST validation failed: {} '{}' — {}",
            field, value, reason
        )
    }
}
// Marker impl: makes `SanitizeError` a standard error type; the trait's
// default method implementations are used as-is.
impl std::error::Error for SanitizeError {}
/// Maximum identifier length, in bytes, accepted by `is_safe_identifier`.
// NOTE(review): presumably mirrors PostgreSQL's 63-byte identifier limit — confirm.
const MAX_IDENT_LEN: usize = 63;

/// Returns `true` when `s` is non-empty, at most `MAX_IDENT_LEN` bytes long,
/// and consists solely of ASCII letters, ASCII digits, `_` or `.`.
fn is_safe_identifier(s: &str) -> bool {
    if s.is_empty() || s.len() > MAX_IDENT_LEN {
        return false;
    }
    // Any non-ASCII character contributes bytes outside these ranges and is
    // therefore rejected as well.
    s.bytes().all(|byte| {
        matches!(
            byte,
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.'
        )
    })
}
/// Validates `value` as an identifier belonging to the AST field `field`.
///
/// Returns `Ok(())` when `is_safe_identifier` accepts the value. Otherwise
/// returns a `SanitizeError` carrying `field`, the first 40 characters of
/// `value`, and a fixed reason string.
fn check_ident(field: &str, value: &str) -> Result<(), SanitizeError> {
    if !is_safe_identifier(value) {
        // Only a short prefix of the rejected input is kept in the error.
        let preview: String = value.chars().take(40).collect();
        return Err(SanitizeError {
            field: field.to_owned(),
            value: preview,
            reason: String::from("identifiers must match [a-zA-Z0-9_.] and be ≤63 chars"),
        });
    }
    Ok(())
}
fn check_expr(field: &str, expr: &Expr) -> Result<(), SanitizeError> {
match expr {
Expr::Star => Ok(()),
Expr::Named(name) => check_ident(field, name),
Expr::Aliased { name, alias } => {
check_ident(field, name)?;
check_ident(&format!("{field}.alias"), alias)
}
Expr::Aggregate {
col, alias, filter, ..
} => {
check_ident(field, col)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
if let Some(conditions) = filter {
for cond in conditions {
check_expr(&format!("{field}.filter"), &cond.left)?;
}
}
Ok(())
}
Expr::FunctionCall { name, args, alias } => {
check_ident(field, name)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
for arg in args {
check_expr(&format!("{field}.arg"), arg)?;
}
Ok(())
}
Expr::Cast {
expr,
target_type,
alias,
} => {
check_expr(field, expr)?;
check_ident(&format!("{field}.cast_type"), target_type)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Binary {
left, right, alias, ..
} => {
check_expr(field, left)?;
check_expr(field, right)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Literal(_) => Ok(()),
Expr::JsonAccess {
column,
alias,
path_segments,
..
} => {
check_ident(field, column)?;
for (key, _) in path_segments {
if key.parse::<i64>().is_err() && !is_safe_identifier(key) {
return Err(SanitizeError {
field: format!("{field}.json_path"),
value: key.chars().take(40).collect(),
reason: "JSON path key must be a safe identifier or integer".to_string(),
});
}
}
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Subquery { query, alias } => {
validate_ast(query)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Exists { query, alias, .. } => {
validate_ast(query)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Window {
name,
func,
partition,
params,
order,
..
} => {
if !name.is_empty() {
check_ident(&format!("{field}.window_alias"), name)?;
}
check_ident(&format!("{field}.window_func"), func)?;
for p in partition {
check_ident(&format!("{field}.partition"), p)?;
}
for p in params {
check_expr(&format!("{field}.window_param"), p)?;
}
for cage in order {
for cond in &cage.conditions {
check_expr(&format!("{field}.window_order"), &cond.left)?;
check_value(&format!("{field}.window_order"), &cond.value)?;
}
}
Ok(())
}
Expr::Case {
when_clauses,
else_value,
alias,
} => {
for (cond, val) in when_clauses {
check_expr(
&format!("{field}.case_when"),
&Expr::Named(cond.left.to_string()),
)?;
check_expr(&format!("{field}.case_then"), val)?;
}
if let Some(e) = else_value {
check_expr(&format!("{field}.case_else"), e)?;
}
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::SpecialFunction { args, alias, name } => {
check_ident(&format!("{field}.special_func"), name)?;
for (_, arg) in args {
check_expr(&format!("{field}.special_func_arg"), arg)?;
}
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::ArrayConstructor { elements, alias } | Expr::RowConstructor { elements, alias } => {
for elem in elements {
check_expr(&format!("{field}.element"), elem)?;
}
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Subscript { expr, index, alias } => {
check_expr(&format!("{field}.subscript_expr"), expr)?;
check_expr(&format!("{field}.subscript_index"), index)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Collate {
expr,
collation,
alias,
} => {
check_expr(&format!("{field}.collate_expr"), expr)?;
check_ident(&format!("{field}.collation"), collation)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::FieldAccess {
expr,
field: f,
alias,
} => {
check_expr(&format!("{field}.field_access_expr"), expr)?;
check_ident(&format!("{field}.field"), f)?;
if let Some(a) = alias {
check_ident(&format!("{field}.alias"), a)?;
}
Ok(())
}
Expr::Def { name, .. } => check_ident(field, name),
Expr::Mod { col, .. } => check_expr(field, col),
}
}
/// Validates the parts of a condition value that can embed identifiers.
///
/// Expressions are delegated to `check_expr`, subqueries to `validate_ast`,
/// and arrays are checked element by element (stopping at the first error).
/// Every other variant is accepted without inspection.
fn check_value(field: &str, value: &Value) -> Result<(), SanitizeError> {
    match value {
        Value::Expr(inner) => check_expr(field, inner),
        Value::Subquery(query) => validate_ast(query),
        Value::Array(items) => items.iter().try_for_each(|item| check_value(field, item)),
        // NOTE(review): all remaining variants pass unchecked. If `Value` has a
        // variant that holds a raw column/identifier string, it is not
        // validated here — confirm against the `Value` definition.
        _ => Ok(()),
    }
}
/// Validates a join target of the form `table`, `table alias` or
/// `table AS alias`.
///
/// Every whitespace-separated token must be a safe identifier. The shape is
/// restricted as well: otherwise a string made only of keyword-like tokens
/// (for example `x ON true JOIN other`) would pass token-by-token validation.
/// An empty or whitespace-only table yields no tokens and is accepted.
fn check_join_table(field: &str, table: &str) -> Result<(), SanitizeError> {
    let tokens: Vec<&str> = table.split_whitespace().collect();
    // Per-token checks run first so errors for unsafe characters are reported
    // before any shape error.
    for token in &tokens {
        check_ident(field, token)?;
    }
    let shape_ok = match tokens.as_slice() {
        [_, keyword, _] => keyword.eq_ignore_ascii_case("as"),
        other => other.len() <= 2,
    };
    if shape_ok {
        Ok(())
    } else {
        Err(SanitizeError {
            field: field.to_string(),
            value: table.chars().take(40).collect(),
            reason: "join table must be `table`, `table alias` or `table AS alias`".to_string(),
        })
    }
}

/// Validates a whole `Qail` command before it is used.
///
/// Rejects the `Call`, `Do`, `SessionSet` and `SessionReset` actions outright,
/// then checks, in this order: the table name (when non-empty), columns,
/// joins (table and `ON` conditions), cage conditions, CTEs (recursively),
/// `distinct_on`, `returning`, `on_conflict` columns, `from_tables`,
/// `using_tables`, set-operation subqueries, the source query, `having`
/// conditions and the channel name.
///
/// # Errors
///
/// Returns the first `SanitizeError` found; its `field` names the part of the
/// command that failed.
pub fn validate_ast(cmd: &Qail) -> Result<(), SanitizeError> {
    if matches!(
        cmd.action,
        Action::Call | Action::Do | Action::SessionSet | Action::SessionReset
    ) {
        return Err(SanitizeError {
            field: "action".to_string(),
            value: format!("{:?}", cmd.action),
            reason: "procedural/session actions are not allowed via binary AST".to_string(),
        });
    }

    // An empty table name is allowed and skipped.
    if !cmd.table.is_empty() {
        check_ident("table", &cmd.table)?;
    }

    for (i, col) in cmd.columns.iter().enumerate() {
        check_expr(&format!("columns[{i}]"), col)?;
    }

    for (i, join) in cmd.joins.iter().enumerate() {
        check_join_table(&format!("joins[{i}].table"), &join.table)?;
        if let Some(conditions) = &join.on {
            let on_field = format!("joins[{i}].on");
            for cond in conditions {
                check_expr(&on_field, &cond.left)?;
                check_value(&on_field, &cond.value)?;
            }
        }
    }

    for cage in &cmd.cages {
        for cond in &cage.conditions {
            check_expr("cage.condition.left", &cond.left)?;
            check_value("cage.condition.value", &cond.value)?;
        }
    }

    for cte in &cmd.ctes {
        check_ident("cte.name", &cte.name)?;
        for col in &cte.columns {
            check_ident("cte.column", col)?;
        }
        validate_ast(&cte.base_query)?;
        if let Some(rq) = &cte.recursive_query {
            validate_ast(rq)?;
        }
    }

    for expr in &cmd.distinct_on {
        check_expr("distinct_on", expr)?;
    }

    if let Some(cols) = &cmd.returning {
        for col in cols {
            check_expr("returning", col)?;
        }
    }

    if let Some(oc) = &cmd.on_conflict {
        for col in &oc.columns {
            check_ident("on_conflict.column", col)?;
        }
    }

    for t in &cmd.from_tables {
        check_ident("from_tables", t)?;
    }
    for t in &cmd.using_tables {
        check_ident("using_tables", t)?;
    }

    for (_, sub) in &cmd.set_ops {
        validate_ast(sub)?;
    }
    if let Some(sq) = &cmd.source_query {
        validate_ast(sq)?;
    }

    for cond in &cmd.having {
        check_expr("having", &cond.left)?;
        check_value("having", &cond.value)?;
    }

    if let Some(ch) = &cmd.channel {
        check_ident("channel", ch)?;
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Qail;

    /// A plain table with plain column names is accepted.
    #[test]
    fn valid_simple_query_passes() {
        let query = Qail::get("users").columns(["id", "name"]);
        let outcome = validate_ast(&query);
        assert!(outcome.is_ok());
    }

    /// A table name containing SQL punctuation is rejected on the `table` field.
    #[test]
    fn sql_injection_in_table_rejected() {
        let query = Qail::get("users; DROP TABLE users; --");
        let failure = validate_ast(&query).expect_err("unsafe table name must be rejected");
        assert_eq!(failure.field, "table");
    }

    /// The `Call` action is refused before anything else is inspected.
    #[test]
    fn call_action_rejected() {
        let query = Qail {
            action: Action::Call,
            table: String::from("my_proc()"),
            ..Default::default()
        };
        let failure = validate_ast(&query).expect_err("Call action must be rejected");
        assert_eq!(failure.field, "action");
    }

    /// The `Do` action is refused as well.
    #[test]
    fn do_action_rejected() {
        let query = Qail {
            action: Action::Do,
            table: String::from("plpgsql"),
            ..Default::default()
        };
        let failure = validate_ast(&query).expect_err("Do action must be rejected");
        assert_eq!(failure.field, "action");
    }

    /// Dotted (qualified) table and column names are accepted.
    #[test]
    fn valid_qualified_name_passes() {
        let query = Qail::get("public.users").columns(["users.id", "users.name"]);
        let outcome = validate_ast(&query);
        assert!(outcome.is_ok());
    }

    /// An unsafe join table is reported under a `joins…` field.
    #[test]
    fn injection_in_join_table_rejected() {
        use crate::ast::JoinKind;
        let query = Qail::get("users").join(
            JoinKind::Left,
            "orders; DROP TABLE x",
            "users.id",
            "orders.user_id",
        );
        let failure = validate_ast(&query).expect_err("unsafe join table must be rejected");
        assert!(failure.field.contains("joins"));
    }

    /// An unsafe column is reported under a `columns…` field.
    #[test]
    fn injection_in_column_rejected() {
        let query = Qail::get("users").columns(["id", "name; DROP TABLE x"]);
        let failure = validate_ast(&query).expect_err("unsafe column must be rejected");
        assert!(failure.field.contains("columns"));
    }

    /// An empty table name is skipped rather than rejected.
    #[test]
    fn empty_table_name_passes() {
        let query = Qail {
            action: Action::TxnStart,
            table: String::new(),
            ..Default::default()
        };
        let outcome = validate_ast(&query);
        assert!(outcome.is_ok());
    }

    /// A 64-byte identifier exceeds the limit and the reason mentions `63`.
    #[test]
    fn oversized_identifier_rejected() {
        let long_name = "a".repeat(64);
        let query = Qail::get(&long_name);
        let failure = validate_ast(&query).expect_err("64-char identifier must be rejected");
        assert!(failure.reason.contains("63"));
    }
}