use athena_query::query_builder::sanitize_identifier;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sqlx::postgres::{PgPool, PgRow};
use sqlx::types::Json;
use sqlx::{Column, Either, Row, ValueRef};
use std::time::Instant;
/// How a single SQL statement is executed and how its result rows are collected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PostgresSqlExecutionMode {
    /// Row-producing statement; it is wrapped in a CTE and each row is converted with `to_jsonb`.
    JsonRows,
    /// Row-producing statement (`SHOW`/`EXPLAIN`) that is run as-is and converted row by row.
    DirectRows,
    /// Any other statement; executed as a command.
    Command,
}
/// Aggregate counters for an executed query or script.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct PostgresSqlExecutionSummary {
    /// Number of statements that were executed.
    pub statement_count: usize,
    /// Rows affected, summed over the executed statements.
    pub rows_affected: u64,
    /// Number of rows carried in the accompanying result.
    pub returned_row_count: usize,
}

/// JSON rows plus counters produced by executing a single query.
#[derive(Clone, Debug, PartialEq)]
pub struct PostgresSqlExecutionResult {
    /// Result rows, one JSON value per row.
    pub rows: Vec<Value>,
    /// Counters describing the execution.
    pub summary: PostgresSqlExecutionSummary,
}
/// Transaction strategy used by `execute_postgres_sql_script`.
///
/// Serialized in `snake_case` (`"single_transaction"`, `"per_statement"`).
/// The default is [`PostgresSqlTransactionMode::SingleTransaction`].
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PostgresSqlTransactionMode {
    /// Run every statement in one transaction that is committed only after all
    /// statements succeed.
    #[default]
    SingleTransaction,
    /// Run and commit each statement in its own transaction.
    PerStatement,
}
/// Outcome of one statement executed as part of a SQL script.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PostgresSqlStatementExecution {
    /// 1-based position of the statement within the script.
    pub statement_index: usize,
    /// Number of statements in the script.
    pub total_statements: usize,
    /// Statement text as executed (after reserved-identifier preprocessing).
    pub statement: String,
    /// 1-based line on which the statement starts.
    pub line_start: usize,
    /// 1-based line on which the statement ends.
    pub line_end: usize,
    /// Rows affected by the statement.
    pub rows_affected: u64,
    /// Number of rows the statement returned.
    pub returned_row_count: usize,
    /// Elapsed wall-clock milliseconds measured around the statement's execution.
    pub duration_ms: u64,
    /// Reserved column names that preprocessing wrapped in double quotes.
    pub quoted_reserved_identifiers: Vec<String>,
}

/// What the reserved-identifier preprocessing pass changed across a script.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
pub struct PostgresSqlPreprocessSummary {
    /// How many identifiers were auto-quoted in total.
    pub rewritten_reserved_identifier_count: usize,
    /// The auto-quoted identifiers, in script order.
    pub rewritten_reserved_identifiers: Vec<String>,
}

/// Successful result of running a SQL script.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PostgresSqlScriptExecutionResult {
    /// Rows of the last statement that returned any rows.
    pub rows: Vec<Value>,
    /// Totals across the whole script.
    pub summary: PostgresSqlExecutionSummary,
    /// Per-statement outcomes, in execution order.
    pub statements: Vec<PostgresSqlStatementExecution>,
    /// Preprocessing report for the script.
    pub preprocess: PostgresSqlPreprocessSummary,
}

/// Failure while validating, preprocessing, or executing a SQL script.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct PostgresSqlScriptError {
    /// Human-readable description of the failure.
    pub message: String,
    /// HTTP-style status code suggested to callers (e.g. 400 for invalid input,
    /// 503 for connectivity failures).
    pub status_hint: u16,
    /// 1-based index of the failing statement, when one is known.
    pub statement_index: Option<usize>,
    /// Number of statements in the script, when it was split successfully.
    pub total_statements: Option<usize>,
    /// Text of the failing statement, when one is known.
    pub statement: Option<String>,
    /// First line of the failing statement, when one is known.
    pub line_start: Option<usize>,
    /// Last line of the failing statement, when one is known.
    pub line_end: Option<usize>,
    /// Preprocessing performed before the failure.
    pub preprocess: PostgresSqlPreprocessSummary,
}
/// Trims surrounding whitespace and strips any trailing run of `;` (with
/// whitespace interleaved in it) from `query`.
///
/// `"  SELECT 1 ; ;\n"` becomes `"SELECT 1"`. A query made only of semicolons
/// and whitespace becomes the empty string.
pub fn normalize_sql_query(query: &str) -> String {
    let is_trailing_noise = |ch: char| ch == ';' || ch.is_whitespace();
    query.trim_start().trim_end_matches(is_trailing_noise).to_owned()
}
/// Decides how a single statement should be executed.
///
/// Leading whitespace and `--` / `/* */` comments are skipped before the first
/// keyword is inspected. This matters because statements split out of a script
/// often start with a comment (e.g. `-- list users\nSELECT * FROM users`).
///
/// Classification:
/// - `SELECT` / `VALUES` / `WITH` → [`PostgresSqlExecutionMode::JsonRows`].
/// - `INSERT` / `UPDATE` / `DELETE` / `MERGE` that contain a `RETURNING` word
///   → [`PostgresSqlExecutionMode::JsonRows`].
/// - `SHOW` / `EXPLAIN` → [`PostgresSqlExecutionMode::DirectRows`].
/// - Everything else → [`PostgresSqlExecutionMode::Command`].
pub fn classify_sql_query(query: &str) -> PostgresSqlExecutionMode {
    let normalized: String = normalize_sql_query(query);
    let lowered: String = strip_leading_sql_comments(&normalized).to_ascii_lowercase();
    let first_keyword: &str = lowered
        .split(|ch: char| ch.is_whitespace() || ch == '(')
        .find(|segment| !segment.is_empty())
        .unwrap_or_default();
    // Split on every character that cannot be part of a word or quote, so
    // `RETURNING` is found after a newline, `)` or before `*`. Quote characters
    // stay attached, which keeps a quoted column named `"returning"` from
    // counting as the keyword.
    let has_returning: bool = lowered
        .split(|ch: char| !(ch.is_alphanumeric() || matches!(ch, '_' | '"' | '\'' | '$')))
        .any(|token| token == "returning");
    match first_keyword {
        "select" | "values" | "with" => PostgresSqlExecutionMode::JsonRows,
        "insert" | "update" | "delete" | "merge" if has_returning => {
            PostgresSqlExecutionMode::JsonRows
        }
        "show" | "explain" => PostgresSqlExecutionMode::DirectRows,
        _ => PostgresSqlExecutionMode::Command,
    }
}

/// Returns `sql` without its leading whitespace, `--` line comments and
/// (possibly nested) `/* */` block comments.
///
/// An unterminated comment consumes the rest of the input, so an empty string
/// is returned.
fn strip_leading_sql_comments(sql: &str) -> &str {
    let mut rest = sql;
    loop {
        rest = rest.trim_start();
        if let Some(after_dashes) = rest.strip_prefix("--") {
            rest = match after_dashes.find('\n') {
                Some(newline) => &after_dashes[newline + 1..],
                None => "",
            };
        } else if rest.starts_with("/*") {
            let bytes = rest.as_bytes();
            let mut depth = 0usize;
            let mut i = 0usize;
            while i < bytes.len() {
                if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
                    depth += 1;
                    i += 2;
                } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                    depth -= 1;
                    i += 2;
                    if depth == 0 {
                        break;
                    }
                } else {
                    i += 1;
                }
            }
            // `i` is either just past an ASCII `*/` or at the end, so it is a
            // valid char boundary.
            rest = &rest[i..];
        } else {
            return rest;
        }
    }
}
/// Normalises and classifies `query`, then runs it directly on `pool` (no
/// explicit transaction) with the executor matching its classification.
///
/// # Errors
///
/// Propagates any `sqlx::Error` raised while executing the statement.
pub async fn execute_postgres_sql(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let sql = normalize_sql_query(query);
    match classify_sql_query(&sql) {
        PostgresSqlExecutionMode::Command => execute_command_query(pool, &sql).await,
        PostgresSqlExecutionMode::DirectRows => execute_direct_row_query(pool, &sql).await,
        PostgresSqlExecutionMode::JsonRows => execute_json_row_query(pool, &sql).await,
    }
}
/// One executable statement carved out of a SQL script.
#[derive(Clone, Debug)]
struct SqlStatementSpan {
    /// 1-based position of the statement within the script.
    index: usize,
    /// Statement text, trimmed and without its terminating `;`.
    statement: String,
    /// 1-based line on which the statement text starts.
    line_start: usize,
    /// 1-based line on which the statement text ends.
    line_end: usize,
    /// Reserved identifiers that preprocessing wrapped in double quotes.
    quoted_reserved_identifiers: Vec<String>,
}

/// Lexical context tracked by the hand-written SQL scanners in this module.
#[derive(Clone, Debug)]
enum SqlScannerState {
    /// Plain SQL text.
    Normal,
    /// Inside a `'...'` literal.
    SingleQuotedString,
    /// Inside a `"..."` identifier.
    DoubleQuotedIdentifier,
    /// Inside a `--` comment, up to the next newline.
    LineComment,
    /// Inside a `/* ... */` comment; the payload is the current nesting depth.
    BlockComment(usize),
    /// Inside a `$tag$ ... $tag$` body; the payload is the tag (possibly empty).
    DollarQuoted(String),
}
/// Lower-case keywords that are auto-quoted when one appears as the bare
/// leading column name of a `CREATE TABLE` column definition.
///
/// Compared case-insensitively by `is_reserved_identifier`. Kept in strict
/// alphabetical order.
const RESERVED_IDENTIFIER_KEYWORDS: &[&str] = &[
    "all", "analyse", "analyze", "and", "any", "array", "as", "asc", "asymmetric",
    "authorization",
    "between", "binary", "both",
    "case", "cast", "check", "collate", "column", "constraint", "create", "cross",
    "current_catalog", "current_date", "current_role", "current_time", "current_timestamp",
    "current_user",
    "default", "deferrable", "desc", "distinct", "do",
    "else", "end", "except",
    "false", "fetch", "for", "foreign", "from",
    "grant", "group",
    "having",
    "in", "initially", "intersect", "into",
    "leading", "limit", "localtime", "localtimestamp",
    "new", "not", "null",
    "off", "offset", "old", "on", "only", "or", "order",
    "placing", "primary",
    "references", "returning",
    "select", "session_user", "some", "symmetric",
    "table", "then", "to", "trailing", "true",
    "union", "unique", "user", "using",
    "variadic",
    "when", "where", "window", "with",
];

/// Prefixes of `CREATE TABLE` list segments that are table constraints or
/// `LIKE` clauses rather than column definitions.
///
/// Such segments are never rewritten, even though several of these words are
/// also in [`RESERVED_IDENTIFIER_KEYWORDS`].
const CREATE_TABLE_SEGMENT_GUARD_KEYWORDS: &[&str] =
    &["constraint", "primary", "foreign", "unique", "check", "exclude", "like"];
/// Reports whether any statement in `query` looks like a `CREATE TABLE ... (`
/// statement.
///
/// An empty query (after normalisation) never does.
pub fn query_contains_create_table_statement(query: &str) -> bool {
    let normalized = normalize_sql_query(query);
    !normalized.is_empty()
        && split_sql_statements_with_spans(&normalized)
            .iter()
            .any(|span| looks_like_create_table_statement(&span.statement))
}
pub async fn execute_postgres_sql_script(
pool: &PgPool,
query: &str,
mode: PostgresSqlTransactionMode,
schema_name: Option<&str>,
) -> Result<PostgresSqlScriptExecutionResult, PostgresSqlScriptError> {
let normalized_query: String = normalize_sql_query(query);
if normalized_query.is_empty() {
return Err(PostgresSqlScriptError {
message: "Query cannot be empty or contain only semicolons.".to_string(),
status_hint: 400,
statement_index: None,
total_statements: None,
statement: None,
line_start: None,
line_end: None,
preprocess: PostgresSqlPreprocessSummary::default(),
});
}
let mut statements: Vec<SqlStatementSpan> = split_sql_statements_with_spans(&normalized_query);
if statements.is_empty() {
return Err(PostgresSqlScriptError {
message: "Query does not contain executable SQL statements.".to_string(),
status_hint: 400,
statement_index: None,
total_statements: None,
statement: None,
line_start: None,
line_end: None,
preprocess: PostgresSqlPreprocessSummary::default(),
});
}
let preprocess = preprocess_reserved_identifiers(&mut statements)?;
let total_statements = statements.len();
let sanitized_schema_name = match schema_name {
Some(value) => Some(
sanitize_identifier(value).ok_or_else(|| PostgresSqlScriptError {
message: "schema_name must be a valid SQL identifier".to_string(),
status_hint: 400,
statement_index: None,
total_statements: Some(total_statements),
statement: None,
line_start: None,
line_end: None,
preprocess: preprocess.clone(),
})?,
),
None => None,
};
let mut statement_results: Vec<PostgresSqlStatementExecution> = Vec::new();
let mut rows_affected_total: u64 = 0;
let mut statement_count_total: usize = 0;
let mut last_rows: Vec<Value> = Vec::new();
match mode {
PostgresSqlTransactionMode::SingleTransaction => {
let mut transaction = pool.begin().await.map_err(|err| {
to_script_sqlx_error(
err,
None,
total_statements,
preprocess.clone(),
"Failed to open SQL transaction",
)
})?;
if let Some(schema) = sanitized_schema_name.as_deref() {
let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
sqlx::query(&set_search_path)
.execute(&mut *transaction)
.await
.map_err(|err| {
to_script_sqlx_error(
err,
None,
total_statements,
preprocess.clone(),
"Failed to set search_path for SQL execution",
)
})?;
}
for span in &statements {
let started = Instant::now();
let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
.await
.map_err(|err| {
to_script_sqlx_error(
err,
Some(span),
total_statements,
preprocess.clone(),
"SQL statement execution failed",
)
})?;
rows_affected_total += result.summary.rows_affected;
statement_count_total += result.summary.statement_count;
if !result.rows.is_empty() {
last_rows = result.rows.clone();
}
statement_results.push(PostgresSqlStatementExecution {
statement_index: span.index,
total_statements,
statement: span.statement.clone(),
line_start: span.line_start,
line_end: span.line_end,
rows_affected: result.summary.rows_affected,
returned_row_count: result.summary.returned_row_count,
duration_ms: started.elapsed().as_millis() as u64,
quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
});
}
transaction.commit().await.map_err(|err| {
to_script_sqlx_error(
err,
None,
total_statements,
preprocess.clone(),
"Failed to commit SQL transaction",
)
})?;
}
PostgresSqlTransactionMode::PerStatement => {
for span in &statements {
let mut transaction = pool.begin().await.map_err(|err| {
to_script_sqlx_error(
err,
Some(span),
total_statements,
preprocess.clone(),
"Failed to open SQL transaction",
)
})?;
if let Some(schema) = sanitized_schema_name.as_deref() {
let set_search_path = format!("SET LOCAL search_path TO {schema}, public");
sqlx::query(&set_search_path)
.execute(&mut *transaction)
.await
.map_err(|err| {
to_script_sqlx_error(
err,
Some(span),
total_statements,
preprocess.clone(),
"Failed to set search_path for SQL execution",
)
})?;
}
let started = Instant::now();
let result = execute_postgres_sql_in_transaction(&mut transaction, &span.statement)
.await
.map_err(|err| {
to_script_sqlx_error(
err,
Some(span),
total_statements,
preprocess.clone(),
"SQL statement execution failed",
)
})?;
transaction.commit().await.map_err(|err| {
to_script_sqlx_error(
err,
Some(span),
total_statements,
preprocess.clone(),
"Failed to commit SQL transaction",
)
})?;
rows_affected_total += result.summary.rows_affected;
statement_count_total += result.summary.statement_count;
if !result.rows.is_empty() {
last_rows = result.rows.clone();
}
statement_results.push(PostgresSqlStatementExecution {
statement_index: span.index,
total_statements,
statement: span.statement.clone(),
line_start: span.line_start,
line_end: span.line_end,
rows_affected: result.summary.rows_affected,
returned_row_count: result.summary.returned_row_count,
duration_ms: started.elapsed().as_millis() as u64,
quoted_reserved_identifiers: span.quoted_reserved_identifiers.clone(),
});
}
}
}
Ok(PostgresSqlScriptExecutionResult {
rows: last_rows.clone(),
summary: PostgresSqlExecutionSummary {
statement_count: statement_count_total,
rows_affected: rows_affected_total,
returned_row_count: last_rows.len(),
},
statements: statement_results,
preprocess,
})
}
/// Converts a driver error into a [`PostgresSqlScriptError`].
///
/// Status hints:
/// - Database errors whose SQLSTATE indicates a client-side problem get 400;
///   other database errors get 500.
/// - Pool, I/O and TLS failures get 503.
/// - Anything else gets 500.
///
/// With a `span`, the message names the failing statement and its lines.
/// Otherwise it is prefixed with `fallback_message`.
fn to_script_sqlx_error(
    error: sqlx::Error,
    span: Option<&SqlStatementSpan>,
    total_statements: usize,
    preprocess: PostgresSqlPreprocessSummary,
    fallback_message: &str,
) -> PostgresSqlScriptError {
    let (status_hint, db_message) = match &error {
        sqlx::Error::Database(db) => {
            let status = match db.code().as_deref() {
                Some(code) if is_client_error_sqlstate(code) => 400,
                _ => 500,
            };
            (status, db.message().to_string())
        }
        sqlx::Error::PoolTimedOut
        | sqlx::Error::PoolClosed
        | sqlx::Error::Io(_)
        | sqlx::Error::Tls(_) => (503, error.to_string()),
        _ => (500, error.to_string()),
    };
    match span {
        Some(span) => PostgresSqlScriptError {
            message: format!(
                "Statement {}/{} failed at lines {}-{}: {}",
                span.index, total_statements, span.line_start, span.line_end, db_message
            ),
            status_hint,
            statement_index: Some(span.index),
            total_statements: Some(total_statements),
            statement: Some(span.statement.clone()),
            line_start: Some(span.line_start),
            line_end: Some(span.line_end),
            preprocess,
        },
        None => PostgresSqlScriptError {
            message: format!("{fallback_message}: {db_message}"),
            status_hint,
            statement_index: None,
            total_statements: Some(total_statements),
            statement: None,
            line_start: None,
            line_end: None,
            preprocess,
        },
    }
}

/// Whether a SQLSTATE code describes a problem with the submitted SQL or data.
///
/// Matching SQLSTATE classes:
/// - `4x`: syntax/access-rule errors and the other historically mapped classes.
/// - `22`: data exception, e.g. invalid input syntax.
/// - `23`: integrity constraint violation, e.g. a unique violation.
fn is_client_error_sqlstate(code: &str) -> bool {
    code.starts_with('4') || code.starts_with("22") || code.starts_with("23")
}
/// Rewrites every statement in place so bare reserved column names in
/// `CREATE TABLE` column lists are double-quoted.
///
/// Records the quoted names on each span and in the returned summary.
///
/// # Errors
///
/// Returns a 400 [`PostgresSqlScriptError`] describing the first statement
/// that fails preprocessing. It carries the summary accumulated so far and the
/// statement's unmodified text.
fn preprocess_reserved_identifiers(
    statements: &mut [SqlStatementSpan],
) -> Result<PostgresSqlPreprocessSummary, PostgresSqlScriptError> {
    let total_statements = statements.len();
    let mut summary = PostgresSqlPreprocessSummary::default();
    for span in statements.iter_mut() {
        let (rewritten, identifiers) =
            match preprocess_create_table_reserved_identifiers(&span.statement) {
                Ok(outcome) => outcome,
                Err(reason) => {
                    return Err(PostgresSqlScriptError {
                        message: format!(
                            "Statement {}/{} failed preprocessing at lines {}-{}: {}",
                            span.index, total_statements, span.line_start, span.line_end, reason
                        ),
                        status_hint: 400,
                        statement_index: Some(span.index),
                        total_statements: Some(total_statements),
                        statement: Some(span.statement.clone()),
                        line_start: Some(span.line_start),
                        line_end: Some(span.line_end),
                        preprocess: summary.clone(),
                    });
                }
            };
        span.statement = rewritten;
        summary.rewritten_reserved_identifier_count += identifiers.len();
        summary
            .rewritten_reserved_identifiers
            .extend(identifiers.iter().cloned());
        span.quoted_reserved_identifiers = identifiers;
    }
    Ok(summary)
}
/// Auto-quotes bare reserved column names in the column list of a
/// `CREATE TABLE name (...)` statement.
///
/// Returns the (possibly rewritten) statement and the identifiers that were
/// quoted. Statements that are not `CREATE TABLE ... (`, or whose first
/// parenthesis is not a column-definition list, are returned unchanged.
///
/// # Errors
///
/// Propagates the message from `preprocess_column_definition_segment`.
fn preprocess_create_table_reserved_identifiers(
    statement: &str,
) -> Result<(String, Vec<String>), String> {
    let unchanged = || -> Result<(String, Vec<String>), String> {
        Ok((statement.to_string(), Vec::new()))
    };
    if !looks_like_create_table_statement(statement) {
        return unchanged();
    }
    let Some((inner_start, inner_end)) = find_create_table_columns_bounds(statement) else {
        return unchanged();
    };
    // Some statements also reach this point even though their first
    // parenthesis is not a column list:
    // - `CREATE TABLE t AS (SELECT ...)` / `AS SELECT count(distinct x)`: it
    //   holds a query or expression.
    // - `PARTITION OF p FOR VALUES IN (...)`: it holds a partition bound.
    // Quoting keywords there (e.g. `SELECT` -> `"SELECT"`) would corrupt the
    // statement. `inner_start - 1` is the byte index of that ASCII `(`.
    let prefix = statement[..inner_start - 1].to_ascii_lowercase();
    let wraps_non_column_list = prefix
        .split(|ch: char| !(ch.is_ascii_alphanumeric() || ch == '_'))
        .any(|word| word == "as" || word == "values");
    if wraps_non_column_list {
        return unchanged();
    }
    let definitions = &statement[inner_start..inner_end];
    let segment_ranges = split_top_level_comma_ranges(definitions);
    if segment_ranges.is_empty() {
        return unchanged();
    }
    let mut rewritten_identifiers: Vec<String> = Vec::new();
    let mut rewritten_defs = String::with_capacity(definitions.len() + 32);
    let mut cursor = 0usize;
    for (start, end) in segment_ranges {
        // Copy the separator (comma) between segments verbatim.
        rewritten_defs.push_str(&definitions[cursor..start]);
        let rewritten_segment =
            preprocess_column_definition_segment(&definitions[start..end], &mut rewritten_identifiers)?;
        rewritten_defs.push_str(&rewritten_segment);
        cursor = end;
    }
    rewritten_defs.push_str(&definitions[cursor..]);
    if rewritten_identifiers.is_empty() {
        return Ok((statement.to_string(), rewritten_identifiers));
    }
    let mut rewritten_statement = String::with_capacity(statement.len() + 32);
    rewritten_statement.push_str(&statement[..inner_start]);
    rewritten_statement.push_str(&rewritten_defs);
    rewritten_statement.push_str(&statement[inner_end..]);
    Ok((rewritten_statement, rewritten_identifiers))
}
/// Auto-quotes a bare reserved keyword used as the leading column name of one
/// `CREATE TABLE` definition segment.
///
/// The segment is returned unchanged when:
/// - it is blank;
/// - it starts with one of `CREATE_TABLE_SEGMENT_GUARD_KEYWORDS` (a table
///   constraint or `LIKE`);
/// - it starts with an already-quoted identifier;
/// - its leading word is not reserved.
///
/// Otherwise the leading word is replaced by its lower-cased, double-quoted
/// form, and that quoted name is appended to `rewritten_identifiers`.
///
/// Lower-casing mirrors PostgreSQL's folding of unquoted identifiers. Quoting
/// the original spelling would turn `Order int` into a case-sensitive column
/// `"Order"` instead of the `order` column the user wrote.
///
/// # Errors
///
/// Returns a message if the leading word contains characters outside
/// `[A-Za-z_][A-Za-z0-9_]*`. This is a defensive check: the unquoted tokens
/// produced by `parse_leading_identifier_token` already satisfy it.
fn preprocess_column_definition_segment(
    segment: &str,
    rewritten_identifiers: &mut Vec<String>,
) -> Result<String, String> {
    let Some(first_non_ws_idx) = segment.find(|ch: char| !ch.is_whitespace()) else {
        return Ok(segment.to_string());
    };
    let trimmed = &segment[first_non_ws_idx..];
    let lower_trimmed = trimmed.to_ascii_lowercase();
    if CREATE_TABLE_SEGMENT_GUARD_KEYWORDS
        .iter()
        .any(|keyword| lower_trimmed.starts_with(keyword))
    {
        return Ok(segment.to_string());
    }
    let Some((token_start, token_end, quoted)) =
        parse_leading_identifier_token(segment, first_non_ws_idx)
    else {
        return Ok(segment.to_string());
    };
    if quoted {
        return Ok(segment.to_string());
    }
    let token = &segment[token_start..token_end];
    if !is_reserved_identifier(token) {
        return Ok(segment.to_string());
    }
    if !is_safe_identifier(token) {
        return Err(format!(
            "Reserved identifier '{}' contains unsupported characters; only [A-Za-z_][A-Za-z0-9_]* is allowed for auto-quoting",
            token
        ));
    }
    let folded = token.to_ascii_lowercase();
    let mut rewritten = String::with_capacity(segment.len() + 2);
    rewritten.push_str(&segment[..token_start]);
    rewritten.push('"');
    rewritten.push_str(&folded);
    rewritten.push('"');
    rewritten.push_str(&segment[token_end..]);
    rewritten_identifiers.push(folded);
    Ok(rewritten)
}
/// Parses the identifier that begins at byte `start` of `input`.
///
/// Returns `(start, end, quoted)`, where `input[start..end]` is the token:
/// - For a `"..."` identifier (with `""` as an escaped quote), `quoted` is
///   `true` and the range includes both quotes.
/// - Otherwise the token is an ASCII `[A-Za-z_][A-Za-z0-9_]*` word.
///
/// Returns `None` when `start` is out of range, the quoted identifier is
/// unterminated, or no word starts there.
fn parse_leading_identifier_token(input: &str, start: usize) -> Option<(usize, usize, bool)> {
    let bytes = input.as_bytes();
    let &head = bytes.get(start)?;
    if head == b'"' {
        let mut cursor = start + 1;
        while let Some(&byte) = bytes.get(cursor) {
            if byte != b'"' {
                cursor += 1;
                continue;
            }
            // A doubled quote is an escaped quote inside the identifier.
            if bytes.get(cursor + 1) == Some(&b'"') {
                cursor += 2;
                continue;
            }
            return Some((start, cursor + 1, true));
        }
        return None;
    }
    if !(head.is_ascii_alphabetic() || head == b'_') {
        return None;
    }
    let tail_len = bytes[start + 1..]
        .iter()
        .take_while(|byte| byte.is_ascii_alphanumeric() || **byte == b'_')
        .count();
    Some((start, start + 1 + tail_len, false))
}
/// Splits `input` at commas that sit at parenthesis depth 0 and outside of
/// string literals, quoted identifiers, comments, and dollar-quoted bodies.
///
/// Returns half-open byte ranges `(start, end)` into `input`, one per segment,
/// excluding the separating commas. An empty `input` yields no ranges.
/// Otherwise there is always at least one range: the final segment, which may
/// be empty (e.g. after a trailing comma).
///
/// Note: inside `'...'` only a doubled `''` is treated as an escape; backslash
/// escapes (as in `E'...'` strings) are not recognised.
fn split_top_level_comma_ranges(input: &str) -> Vec<(usize, usize)> {
    let bytes = input.as_bytes();
    let mut ranges: Vec<(usize, usize)> = Vec::new();
    if bytes.is_empty() {
        return ranges;
    }
    let mut state = SqlScannerState::Normal;
    // Parenthesis nesting depth; only commas at depth 0 separate segments.
    let mut depth = 0usize;
    // Byte offset where the current segment began.
    let mut start = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    // Skip the whole opening delimiter (`$tag$`).
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == b'(' {
                    depth += 1;
                    i += 1;
                    continue;
                }
                if bytes[i] == b')' {
                    // Saturating so a stray `)` cannot underflow the depth.
                    depth = depth.saturating_sub(1);
                    i += 1;
                    continue;
                }
                if bytes[i] == b',' && depth == 0 {
                    ranges.push((start, i));
                    start = i + 1;
                    i += 1;
                    continue;
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                if bytes[i] == b'\'' {
                    // `''` is an escaped quote; a lone `'` ends the literal.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                if bytes[i] == b'"' {
                    // `""` is an escaped quote; a lone `"` ends the identifier.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Only the identical `$tag$` closes the body.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    // The last segment runs to the end of the input (even if unterminated).
    ranges.push((start, input.len()));
    ranges
}
/// Heuristically detects a `CREATE ... TABLE ... (` statement.
///
/// Leading whitespace and comments are skipped. The statement must then begin
/// with `create` (case-insensitive), and the word `table` must appear before
/// the first `(` that is outside quotes and comments.
fn looks_like_create_table_statement(statement: &str) -> bool {
    let body = match find_first_sql_content_start(statement) {
        Some(offset) => &statement[offset..],
        None => return false,
    };
    let lowered = body.to_ascii_lowercase();
    lowered.starts_with("create")
        && find_first_top_level_char(body, b'(').is_some_and(|paren| {
            lowered[..paren]
                .split_whitespace()
                .any(|word| word == "table")
        })
}
/// Returns the byte offset of the first character in `sql` that is neither
/// ASCII whitespace nor part of a leading `--` or (nested) `/* */` comment.
///
/// Returns `None` when `sql` consists only of whitespace and comments,
/// including an unterminated block comment.
fn find_first_sql_content_start(sql: &str) -> Option<usize> {
    let bytes = sql.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i].is_ascii_whitespace() {
                    i += 1;
                    continue;
                }
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                // Anything else, including a quote character, is real content.
                return Some(i);
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            // Never entered: the `Normal` arm above only switches to comment
            // states. Listed so the match stays exhaustive.
            SqlScannerState::SingleQuotedString
            | SqlScannerState::DoubleQuotedIdentifier
            | SqlScannerState::DollarQuoted(_) => return Some(i),
        }
    }
    None
}
/// Locates the first parenthesised group of `statement` that is outside
/// quotes, comments and dollar-quoted bodies.
///
/// Returns `(start, end)` such that `statement[start..end]` is the text between
/// that `(` and its matching `)`. Parentheses inside strings, identifiers or
/// comments within the group are ignored when matching.
///
/// Returns `None` if there is no such `(` or it is never closed.
fn find_create_table_columns_bounds(statement: &str) -> Option<(usize, usize)> {
    let bytes = statement.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    // Byte index of the first top-level `(`, once seen.
    let mut open_idx: Option<usize> = None;
    // Nesting depth relative to that first `(`.
    let mut depth = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == b'(' {
                    if open_idx.is_none() {
                        open_idx = Some(i);
                        depth = 1;
                    } else {
                        depth += 1;
                    }
                    i += 1;
                    continue;
                }
                // `)` before any `(` is ignored.
                if bytes[i] == b')' && open_idx.is_some() {
                    depth = depth.saturating_sub(1);
                    if depth == 0 {
                        // Cannot fail: guarded by `open_idx.is_some()` above.
                        let open = open_idx?;
                        return Some((open + 1, i));
                    }
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                if bytes[i] == b'\'' {
                    // `''` is an escaped quote; a lone `'` ends the literal.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                if bytes[i] == b'"' {
                    // `""` is an escaped quote; a lone `"` ends the identifier.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Only the identical `$tag$` closes the body.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    None
}
/// Returns the byte index of the first `needle` in `sql` that lies outside
/// string literals, quoted identifiers, comments and dollar-quoted bodies.
///
/// "Top level" refers only to that quoting/comment context: parenthesis depth
/// is not tracked. Returns `None` if there is no such occurrence.
fn find_first_top_level_char(sql: &str, needle: u8) -> Option<usize> {
    let bytes = sql.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == needle {
                    return Some(i);
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                if bytes[i] == b'\'' {
                    // `''` is an escaped quote; a lone `'` ends the literal.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                if bytes[i] == b'"' {
                    // `""` is an escaped quote; a lone `"` ends the identifier.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Only the identical `$tag$` closes the body.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    None
}
/// Splits a SQL script into statements at `;` characters that lie outside
/// string literals, quoted identifiers, comments and dollar-quoted bodies.
///
/// Each piece is trimmed of ASCII whitespace. Pieces that are empty or consist
/// only of comments are dropped (see `push_statement_span`). The returned spans
/// carry the statement text without its `;`, a 1-based `index`, and 1-based
/// start/end line numbers within `sql`.
fn split_sql_statements_with_spans(sql: &str) -> Vec<SqlStatementSpan> {
    let bytes = sql.as_bytes();
    let line_offsets = build_line_offsets(sql);
    let mut spans: Vec<SqlStatementSpan> = Vec::new();
    let mut state = SqlScannerState::Normal;
    // Byte offset where the current statement began.
    let mut statement_start = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                if bytes[i] == b'\'' {
                    state = SqlScannerState::SingleQuotedString;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'"' {
                    state = SqlScannerState::DoubleQuotedIdentifier;
                    i += 1;
                    continue;
                }
                if bytes[i] == b'$'
                    && let Some((tag, len)) = parse_dollar_quote_tag(bytes, i)
                {
                    state = SqlScannerState::DollarQuoted(tag);
                    i += len;
                    continue;
                }
                if bytes[i] == b';' {
                    // The terminating `;` itself is excluded from the statement.
                    push_statement_span(&mut spans, sql, statement_start, i, &line_offsets);
                    statement_start = i + 1;
                    i += 1;
                    continue;
                }
                i += 1;
            }
            SqlScannerState::SingleQuotedString => {
                if bytes[i] == b'\'' {
                    // `''` is an escaped quote; a lone `'` ends the literal.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::DoubleQuotedIdentifier => {
                if bytes[i] == b'"' {
                    // `""` is an escaped quote; a lone `"` ends the identifier.
                    if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
                        i += 2;
                    } else {
                        state = SqlScannerState::Normal;
                        i += 1;
                    }
                } else {
                    i += 1;
                }
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            SqlScannerState::DollarQuoted(tag) => {
                // Only the identical `$tag$` closes the body.
                if matches_dollar_quote_end(bytes, i, tag) {
                    i += tag.len() + 2;
                    state = SqlScannerState::Normal;
                    continue;
                }
                i += 1;
            }
        }
    }
    // Trailing statement without a terminating `;`.
    push_statement_span(&mut spans, sql, statement_start, bytes.len(), &line_offsets);
    // Re-number sequentially. `push_statement_span` already assigns
    // `out.len() + 1`, so this is a safeguard rather than a correction.
    for (idx, span) in spans.iter_mut().enumerate() {
        span.index = idx + 1;
    }
    spans
}
/// Appends `sql[start..end]` to `out` as a statement span.
///
/// The text is trimmed of ASCII whitespace first. Nothing is appended when the
/// range is empty, all whitespace, or only comments. Line numbers are resolved
/// through `line_offsets` (as built by `build_line_offsets`).
fn push_statement_span(
    out: &mut Vec<SqlStatementSpan>,
    sql: &str,
    start: usize,
    end: usize,
    line_offsets: &[usize],
) {
    let Some((first, last)) = trim_bounds(sql.as_bytes(), start, end) else {
        return;
    };
    let text = &sql[first..last];
    if statement_is_comment_only(text) {
        return;
    }
    let (line_start, _column) = line_col_from_offset(line_offsets, first);
    // `last` is exclusive, so the final character sits one byte earlier.
    let (line_end, _column) = line_col_from_offset(line_offsets, last.saturating_sub(1));
    let index = out.len() + 1;
    out.push(SqlStatementSpan {
        index,
        statement: text.to_owned(),
        line_start,
        line_end,
        quoted_reserved_identifiers: Vec::new(),
    });
}
/// Narrows the byte range `start..end` of `bytes` so it excludes leading and
/// trailing ASCII whitespace.
///
/// Returns `None` when the range is empty, extends past `bytes`, or contains
/// only whitespace.
fn trim_bounds(bytes: &[u8], start: usize, end: usize) -> Option<(usize, usize)> {
    if start >= end || end > bytes.len() {
        return None;
    }
    let window = &bytes[start..end];
    let lead = window.iter().position(|byte| !byte.is_ascii_whitespace())?;
    let trail = window.iter().rposition(|byte| !byte.is_ascii_whitespace())?;
    Some((start + lead, start + trail + 1))
}
/// Returns the byte offset at which each line of `sql` begins.
///
/// The first entry is always `0`; a new entry follows every `\n`.
fn build_line_offsets(sql: &str) -> Vec<usize> {
    let after_newlines = sql
        .bytes()
        .enumerate()
        .filter(|&(_, byte)| byte == b'\n')
        .map(|(position, _)| position + 1);
    std::iter::once(0).chain(after_newlines).collect()
}
/// Converts a byte `offset` into a 1-based `(line, column)` pair.
///
/// `line_offsets` must be the ascending line-start table produced by
/// `build_line_offsets`. The column counts bytes, not characters.
fn line_col_from_offset(line_offsets: &[usize], offset: usize) -> (usize, usize) {
    // Count the lines starting at or before `offset`; the last of them holds it.
    let line_index = line_offsets
        .partition_point(|&line_start| line_start <= offset)
        .saturating_sub(1);
    let column = offset.saturating_sub(line_offsets[line_index]) + 1;
    (line_index + 1, column)
}
/// Parses an opening dollar-quote delimiter (`$$` or `$tag$`) at `start`.
///
/// Returns the tag (empty for `$$`) and the delimiter's length in bytes.
/// Returns `None` when `bytes[start]` is not `$` or no valid delimiter starts
/// there.
///
/// Tag rules follow PostgreSQL's lexer:
/// - The first byte is an ASCII letter, `_`, or any non-ASCII byte, so UTF-8
///   letters such as `$función$` are accepted.
/// - Later bytes may additionally be ASCII digits.
/// - Text such as `$1$` is therefore not a delimiter.
fn parse_dollar_quote_tag(bytes: &[u8], start: usize) -> Option<(String, usize)> {
    if bytes.get(start) != Some(&b'$') {
        return None;
    }
    let tag_start = start + 1;
    let mut idx = tag_start;
    while idx < bytes.len() && bytes[idx] != b'$' {
        let b = bytes[idx];
        let allowed = b == b'_'
            || b.is_ascii_alphabetic()
            || !b.is_ascii()
            || (idx > tag_start && b.is_ascii_digit());
        if !allowed {
            return None;
        }
        idx += 1;
    }
    if idx >= bytes.len() {
        return None;
    }
    // Both ends of the tag are ASCII `$`, so for input taken from a `&str` the
    // slice is valid UTF-8; arbitrary non-UTF-8 bytes are rejected here.
    let tag = String::from_utf8(bytes[tag_start..idx].to_vec()).ok()?;
    Some((tag, idx - start + 1))
}
/// Reports whether the closing delimiter `$tag$` starts at byte `start` of `bytes`.
fn matches_dollar_quote_end(bytes: &[u8], start: usize, tag: &str) -> bool {
    let tag_bytes = tag.as_bytes();
    let Some(rest) = bytes.get(start..) else {
        return false;
    };
    let delimiter_len = tag_bytes.len() + 2;
    rest.len() >= delimiter_len
        && rest[0] == b'$'
        && rest[delimiter_len - 1] == b'$'
        && &rest[1..delimiter_len - 1] == tag_bytes
}
/// Returns `true` when `statement` contains nothing but ASCII whitespace and
/// `--` / (nested) `/* */` comments.
///
/// Also returns `true` for an empty string and when an unterminated block
/// comment runs to the end.
fn statement_is_comment_only(statement: &str) -> bool {
    let bytes = statement.as_bytes();
    let mut state = SqlScannerState::Normal;
    let mut i = 0usize;
    while i < bytes.len() {
        match &mut state {
            SqlScannerState::Normal => {
                if bytes[i].is_ascii_whitespace() {
                    i += 1;
                    continue;
                }
                if bytes[i] == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
                    state = SqlScannerState::LineComment;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    state = SqlScannerState::BlockComment(1);
                    i += 2;
                    continue;
                }
                // Any other character outside a comment is executable content.
                return false;
            }
            SqlScannerState::LineComment => {
                if bytes[i] == b'\n' {
                    state = SqlScannerState::Normal;
                }
                i += 1;
            }
            SqlScannerState::BlockComment(depth_state) => {
                // Block comments nest: `/*` deepens, `*/` closes one level.
                if bytes[i] == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
                    *depth_state += 1;
                    i += 2;
                    continue;
                }
                if bytes[i] == b'*' && i + 1 < bytes.len() && bytes[i + 1] == b'/' {
                    *depth_state = depth_state.saturating_sub(1);
                    i += 2;
                    if *depth_state == 0 {
                        state = SqlScannerState::Normal;
                    }
                    continue;
                }
                i += 1;
            }
            // Never entered: the `Normal` arm only switches to comment states.
            // Listed so the match stays exhaustive.
            SqlScannerState::SingleQuotedString
            | SqlScannerState::DoubleQuotedIdentifier
            | SqlScannerState::DollarQuoted(_) => return false,
        }
    }
    true
}
/// Reports whether `identifier` matches an entry of
/// `RESERVED_IDENTIFIER_KEYWORDS`, ignoring ASCII case.
fn is_reserved_identifier(identifier: &str) -> bool {
    // The keyword table is all lower-case ASCII, so folding the candidate is
    // equivalent to a case-insensitive comparison.
    let folded = identifier.to_ascii_lowercase();
    RESERVED_IDENTIFIER_KEYWORDS.contains(&folded.as_str())
}
/// Reports whether `identifier` matches `[A-Za-z_][A-Za-z0-9_]*` (ASCII only).
fn is_safe_identifier(identifier: &str) -> bool {
    match identifier.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail
                    .iter()
                    .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric())
        }
    }
}
pub async fn execute_postgres_sql_in_schema(
pool: &PgPool,
query: &str,
schema_name: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let sanitized_schema_name = sanitize_identifier(schema_name).ok_or_else(|| {
sqlx::Error::Protocol("schema_name must be a valid SQL identifier".to_string())
})?;
let mut transaction = pool.begin().await?;
let set_search_path = format!("SET LOCAL search_path TO {sanitized_schema_name}, public");
sqlx::query(&set_search_path)
.execute(&mut *transaction)
.await?;
let result = execute_postgres_sql_in_transaction(&mut transaction, query).await?;
transaction.commit().await?;
Ok(result)
}
/// Normalises and classifies `query`, then runs it inside `transaction` with
/// the executor matching its classification.
///
/// The transaction is not committed here; that is left to the caller.
async fn execute_postgres_sql_in_transaction(
    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let sql = normalize_sql_query(query);
    match classify_sql_query(&sql) {
        PostgresSqlExecutionMode::Command => execute_command_query_tx(transaction, &sql).await,
        PostgresSqlExecutionMode::DirectRows => {
            execute_direct_row_query_tx(transaction, &sql).await
        }
        PostgresSqlExecutionMode::JsonRows => execute_json_row_query_tx(transaction, &sql).await,
    }
}
async fn execute_json_row_query(
pool: &PgPool,
query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let wrapped_query: String = format!(
"WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
);
let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query).fetch_all(pool).await?;
let data: Vec<Value> = rows
.into_iter()
.filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
.map(|json| json.0)
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
async fn execute_json_row_query_tx(
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let wrapped_query: String = format!(
"WITH athena_query_result AS ({query}) SELECT to_jsonb(athena_query_result) AS row FROM athena_query_result"
);
let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(&wrapped_query)
.fetch_all(&mut **transaction)
.await?;
let data: Vec<Value> = rows
.into_iter()
.filter_map(|row| row.try_get::<Json<Value>, _>("row").ok())
.map(|json| json.0)
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
async fn execute_direct_row_query(
pool: &PgPool,
query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(pool).await?;
let data: Vec<Value> = rows
.into_iter()
.map(|row| row_to_json(&row))
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
async fn execute_direct_row_query_tx(
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
let rows: Vec<sqlx::postgres::PgRow> = sqlx::query(query).fetch_all(&mut **transaction).await?;
let data: Vec<Value> = rows
.into_iter()
.map(|row| row_to_json(&row))
.collect::<Vec<_>>();
Ok(PostgresSqlExecutionResult {
summary: PostgresSqlExecutionSummary {
statement_count: 1,
rows_affected: 0,
returned_row_count: data.len(),
},
rows: data,
})
}
/// Runs `query` (possibly several statements) on `pool` through `sqlx::raw_sql`.
///
/// Each completed statement result increments `statement_count` and adds its
/// `rows_affected`. Any rows streamed back are ignored, so no rows are returned.
async fn execute_command_query(
    pool: &PgPool,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let mut summary = PostgresSqlExecutionSummary {
        statement_count: 0,
        rows_affected: 0,
        returned_row_count: 0,
    };
    let mut results = sqlx::raw_sql(query).fetch_many(pool);
    while let Some(step) = futures::StreamExt::next(&mut results).await {
        // `Left` carries a statement completion; `Right` carries a data row we discard.
        if let Either::Left(done) = step? {
            summary.statement_count += 1;
            summary.rows_affected += done.rows_affected();
        }
    }
    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary,
    })
}
/// Transaction-scoped variant of `execute_command_query`.
///
/// It counts completed statements and sums their `rows_affected`; streamed
/// rows are discarded.
async fn execute_command_query_tx(
    transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    query: &str,
) -> Result<PostgresSqlExecutionResult, sqlx::Error> {
    let mut summary = PostgresSqlExecutionSummary {
        statement_count: 0,
        rows_affected: 0,
        returned_row_count: 0,
    };
    let mut results = sqlx::raw_sql(query).fetch_many(&mut **transaction);
    while let Some(step) = futures::StreamExt::next(&mut results).await {
        // `Left` carries a statement completion; `Right` carries a data row we discard.
        if let Either::Left(done) = step? {
            summary.statement_count += 1;
            summary.rows_affected += done.rows_affected();
        }
    }
    Ok(PostgresSqlExecutionResult {
        rows: Vec::new(),
        summary,
    })
}
/// Converts a row into a JSON object keyed by column name.
///
/// Values are produced by `read_column_value`. If two columns share a name, the
/// later column's value overwrites the earlier one.
fn row_to_json(row: &PgRow) -> Value {
    let fields: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .map(|column| {
            let key = column.name();
            (key.to_string(), read_column_value(row, key))
        })
        .collect();
    Value::Object(fields)
}
fn read_column_value(row: &PgRow, name: &str) -> Value {
if let Ok(raw) = row.try_get_raw(name)
&& raw.is_null()
{
return Value::Null;
}
if let Ok(value) = row.try_get::<Option<Json<Value>>, _>(name) {
return value.map(|json| json.0).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<String>, _>(name) {
return value.map(Value::String).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<bool>, _>(name) {
return value.map(Value::Bool).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i16>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i32>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<i64>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<f32>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<f64>, _>(name) {
return value.map(|inner| json!(inner)).unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<uuid::Uuid>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveDate>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveTime>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::NaiveDateTime>, _>(name) {
return value
.map(|inner| Value::String(inner.to_string()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(name) {
return value
.map(|inner| Value::String(inner.to_rfc3339()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<chrono::DateTime<chrono::FixedOffset>>, _>(name) {
return value
.map(|inner| Value::String(inner.to_rfc3339()))
.unwrap_or(Value::Null);
}
if let Ok(value) = row.try_get::<Option<Vec<u8>>, _>(name) {
return value
.map(|inner| Value::String(String::from_utf8_lossy(&inner).to_string()))
.unwrap_or(Value::Null);
}
Value::String("<unsupported>".to_string())
}
// Unit tests for the pure SQL-text helpers in this module: normalization,
// classification, statement splitting with line spans, CREATE TABLE reserved-
// identifier preprocessing, and script error mapping. None of these tests
// touch a database.
#[cfg(test)]
mod tests {
use super::{
PostgresSqlExecutionMode, PostgresSqlPreprocessSummary, classify_sql_query,
looks_like_create_table_statement, normalize_sql_query,
preprocess_create_table_reserved_identifiers, query_contains_create_table_statement,
split_sql_statements_with_spans, to_script_sqlx_error,
};
// Trailing semicolons, including repeated ones separated by whitespace, are removed.
#[test]
fn normalize_sql_query_trims_trailing_semicolons() {
assert_eq!(normalize_sql_query("SELECT 1; ; \n"), "SELECT 1");
}
// Only the trailing terminator is removed; separators between statements stay.
#[test]
fn normalize_sql_query_keeps_inner_semicolons() {
assert_eq!(
normalize_sql_query("CREATE TABLE test (id int); INSERT INTO test VALUES (1);"),
"CREATE TABLE test (id int); INSERT INTO test VALUES (1)"
);
}
// SELECT and INSERT ... RETURNING use the JSON-row path; EXPLAIN uses direct rows.
#[test]
fn classify_sql_query_detects_row_queries() {
assert_eq!(
classify_sql_query("SELECT 1;"),
PostgresSqlExecutionMode::JsonRows
);
assert_eq!(
classify_sql_query("INSERT INTO users(id) VALUES (1) RETURNING id"),
PostgresSqlExecutionMode::JsonRows
);
assert_eq!(
classify_sql_query("EXPLAIN SELECT 1"),
PostgresSqlExecutionMode::DirectRows
);
}
// DDL and UPDATE without RETURNING are classified as commands.
#[test]
fn classify_sql_query_detects_command_queries() {
assert_eq!(
classify_sql_query("CREATE TABLE test (id int);"),
PostgresSqlExecutionMode::Command
);
assert_eq!(
classify_sql_query("UPDATE users SET active = true"),
PostgresSqlExecutionMode::Command
);
}
// A semicolon inside a quoted string literal must not split the statement;
// line_start is 1-based and points at the statement's first line.
#[test]
fn split_sql_statements_preserves_semicolons_in_strings() {
let statements = split_sql_statements_with_spans(
"INSERT INTO logs(message) VALUES ('first;second');\nSELECT 1;",
);
assert_eq!(statements.len(), 2);
assert_eq!(
statements[0].statement,
"INSERT INTO logs(message) VALUES ('first;second')"
);
assert_eq!(statements[1].statement, "SELECT 1");
assert_eq!(statements[1].line_start, 2);
}
// A column named with a reserved word (`table`) is double-quoted, and the
// rewritten identifier is reported.
#[test]
fn preprocess_quotes_reserved_column_identifier() {
let (rewritten, identifiers) = preprocess_create_table_reserved_identifiers(
"CREATE TABLE public.demo (table text, value text);",
)
.expect("preprocess should succeed");
assert_eq!(
rewritten,
"CREATE TABLE public.demo (\"table\" text, value text);"
);
assert_eq!(identifiers, vec!["table".to_string()]);
}
// Non-DDL input passes through preprocessing unchanged, with no identifiers reported.
#[test]
fn preprocess_keeps_non_ddl_statements_unchanged() {
assert!(!looks_like_create_table_statement("SELECT * FROM demo"));
let (rewritten, identifiers) =
preprocess_create_table_reserved_identifiers("SELECT * FROM demo")
.expect("preprocess should succeed");
assert_eq!(rewritten, "SELECT * FROM demo");
assert!(identifiers.is_empty());
}
// A leading `--` comment does not stop CREATE TABLE detection and is
// preserved verbatim in the output.
#[test]
fn preprocess_quotes_reserved_column_identifier_with_leading_comment() {
let (rewritten, identifiers) = preprocess_create_table_reserved_identifiers(
"-- keep this comment\nCREATE TABLE athena.audit_log (table text, resource_id text);",
)
.expect("preprocess should succeed");
assert_eq!(
rewritten,
"-- keep this comment\nCREATE TABLE athena.audit_log (\"table\" text, resource_id text);"
);
assert_eq!(identifiers, vec!["table".to_string()]);
}
// CREATE and TABLE may be on separate lines after a comment and still be detected
// within a multi-statement script.
#[test]
fn query_contains_create_table_statement_detects_commented_and_multiline_statements() {
let sql = r#"
-- this migration adds the audit table
CREATE
TABLE athena.audit_log (
table text
);
SELECT 1;
"#;
assert!(query_contains_create_table_statement(sql));
}
// A pool timeout maps to HTTP 503 and the message includes the supplied context.
#[test]
fn pool_timeout_script_errors_are_marked_service_unavailable() {
let error = to_script_sqlx_error(
sqlx::Error::PoolTimedOut,
None,
1,
PostgresSqlPreprocessSummary::default(),
"Failed to open SQL transaction",
);
assert_eq!(error.status_hint, 503);
assert!(error.message.contains("Failed to open SQL transaction"));
}
}