use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use crate::api::management::types::{
CreateIndexRequest, CreateTableRequest, ManagementColumnRequest, TableEditOperation,
};
use crate::parser::query_builder::sanitize_identifier;
/// Column data types accepted by `validate_data_type`. The `varchar(n)` and
/// `numeric(p,s)` entries are placeholders: actual values are matched by
/// `RE_VARCHAR` / `RE_NUMERIC`. This list is only used for error messages.
pub const ALLOWED_COLUMN_DATA_TYPES: &[&str] = &[
    "text",
    "varchar(n)",
    "boolean",
    "integer",
    "bigint",
    "numeric(p,s)",
    "double precision",
    "uuid",
    "jsonb",
    "date",
    "timestamptz",
];
/// Index access methods accepted by `build_create_index_statement`.
pub const ALLOWED_INDEX_METHODS: &[&str] = &["btree", "gin", "gist", "hash", "brin"];
/// Extensions that `build_create_extension_statement` is allowed to install.
/// Note "uuid-ossp" contains a hyphen, which a plain identifier check rejects.
pub const ALLOWED_EXTENSIONS: &[&str] = &[
    "pgcrypto",
    "uuid-ossp",
    "citext",
    "pg_stat_statements",
    "vector",
];
// Shape checks for the parameterized types accepted by `validate_data_type`.
static RE_VARCHAR: Lazy<Regex> = Lazy::new(|| Regex::new(r"^varchar\(\d+\)$").unwrap());
static RE_NUMERIC: Lazy<Regex> = Lazy::new(|| Regex::new(r"^numeric\(\d+,\d+\)$").unwrap());
// SQL literal forms accepted by `validate_default_expression`: a single-quoted
// string with '' escapes, an optionally signed integer, or a simple decimal.
static RE_STRING_LITERAL: Lazy<Regex> = Lazy::new(|| Regex::new(r"^'(?:''|[^'])*'$").unwrap());
static RE_INTEGER_LITERAL: Lazy<Regex> = Lazy::new(|| Regex::new(r"^-?\d+$").unwrap());
static RE_NUMERIC_LITERAL: Lazy<Regex> = Lazy::new(|| Regex::new(r"^-?\d+\.\d+$").unwrap());
// Case-insensitive, dot-all match for the head of a CREATE OR REPLACE FUNCTION
// statement. Capture group 1 is the (optionally schema-qualified, optionally
// double-quoted) function target; the whole match ends at the opening paren of
// the argument list, which `validate_create_or_replace_function_ddl` relies on.
static RE_CREATE_OR_REPLACE_FUNCTION: Lazy<Regex> = Lazy::new(|| {
    Regex::new(
        r#"(?is)^\s*create\s+or\s+replace\s+function\s+((?:(?:"?[a-zA-Z_][a-zA-Z0-9_]*"?)[.])?(?:"?[a-zA-Z_][a-zA-Z0-9_]*"?))\s*\("#,
    )
    .unwrap()
});
// Character allow-list for function argument type fragments (permits quoted
// identifiers, arrays, and parameterized types like numeric(10,2)).
static RE_TYPE_FRAGMENT: Lazy<Regex> =
    Lazy::new(|| Regex::new(r#"^[a-zA-Z_"][a-zA-Z0-9_"\[\]\s(),.]*$"#).unwrap());
// Matches one or more whitespace characters (used to collapse runs to a space).
static RE_WHITESPACE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\s+").unwrap());
/// Identity of a function parsed out of a `CREATE OR REPLACE FUNCTION` DDL
/// statement by `validate_create_or_replace_function_ddl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedFunctionDdlTarget {
    // Sanitized schema name as returned by `validate_identifier`
    // (quoted, e.g. `"public"` — see the unit tests below).
    pub schema_name: String,
    // Sanitized function name as returned by `validate_identifier`.
    pub function_name: String,
    // Normalized argument list identifying the overload, with DEFAULT clauses
    // stripped and whitespace collapsed, e.g. `name text, suffix text`.
    pub identity_signature: String,
}
/// Returns the name of the Athena-managed UUID column for `table_name`
/// (the table name with an `_id` suffix).
pub fn table_uuid_column_name(table_name: &str) -> String {
    format!("{table_name}_id")
}
/// Returns the set of column names that Athena manages for `table_name`:
/// `id`, `created_at`, `time`, and the generated `<table>_id` UUID column.
pub fn reserved_column_names(table_name: &str) -> HashSet<String> {
    let mut reserved = HashSet::with_capacity(4);
    reserved.insert("id".to_string());
    reserved.insert("created_at".to_string());
    reserved.insert("time".to_string());
    reserved.insert(table_uuid_column_name(table_name));
    reserved
}
/// Trims and sanitizes `identifier`, returning the sanitized form or an error
/// that names the offending field via `label`.
pub fn validate_identifier(identifier: &str, label: &str) -> Result<String, String> {
    match sanitize_identifier(identifier.trim()) {
        Some(sanitized) => Ok(sanitized),
        None => Err(format!("Invalid {} '{}'.", label, identifier)),
    }
}
/// Validates a column data type against the supported whitelist, returning the
/// lowercased, trimmed type on success.
///
/// Accepts the fixed scalar types plus parameterized `varchar(n)` and
/// `numeric(p,s)` forms (checked structurally via regex).
pub fn validate_data_type(data_type: &str) -> Result<String, String> {
    let normalized = data_type.trim().to_ascii_lowercase();
    let is_simple_type = matches!(
        normalized.as_str(),
        "text"
            | "boolean"
            | "integer"
            | "bigint"
            | "double precision"
            | "uuid"
            | "jsonb"
            | "date"
            | "timestamptz"
    );
    let is_parameterized_type =
        RE_VARCHAR.is_match(&normalized) || RE_NUMERIC.is_match(&normalized);
    if is_simple_type || is_parameterized_type {
        return Ok(normalized);
    }
    Err(format!(
        "Unsupported data_type '{}'. Allowed values: {}.",
        data_type,
        ALLOWED_COLUMN_DATA_TYPES.join(", ")
    ))
}
/// Validates a column DEFAULT expression, returning the trimmed expression on
/// success.
///
/// Only simple forms are allowed: a small whitelist of function expressions,
/// the keyword literals true/false/null (case-insensitive), and string /
/// integer / decimal literals (checked structurally via regex).
pub fn validate_default_expression(expression: &str) -> Result<String, String> {
    let trimmed = expression.trim();
    let lowered = trimmed.to_ascii_lowercase();
    let is_whitelisted_call = matches!(
        lowered.as_str(),
        "now()" | "current_timestamp" | "gen_random_uuid()" | "extract(epoch from now())::bigint"
    );
    let is_keyword_literal = matches!(lowered.as_str(), "true" | "false" | "null");
    let is_simple_literal = RE_STRING_LITERAL.is_match(trimmed)
        || RE_INTEGER_LITERAL.is_match(trimmed)
        || RE_NUMERIC_LITERAL.is_match(trimmed);
    if is_whitelisted_call || is_keyword_literal || is_simple_literal {
        Ok(trimmed.to_string())
    } else {
        Err(format!(
            "Unsupported default_expression '{}'. Only simple literals and whitelisted expressions are allowed.",
            expression
        ))
    }
}
/// Tries to read a dollar-quote tag (e.g. `$$` or `$fn$`) starting at byte
/// offset `start` in `input`.
///
/// Returns the full tag including both `$` delimiters, or `None` when the
/// byte at `start` is not `$` or the characters before the closing `$` are
/// not ASCII alphanumerics/underscores.
fn parse_dollar_tag_at(input: &str, start: usize) -> Option<String> {
    let bytes = input.as_bytes();
    if bytes.get(start) != Some(&b'$') {
        return None;
    }
    for (offset, &byte) in bytes[start + 1..].iter().enumerate() {
        if byte == b'$' {
            // Include both delimiters in the returned tag.
            return Some(input[start..=start + 1 + offset].to_string());
        }
        if !byte.is_ascii_alphanumeric() && byte != b'_' {
            return None;
        }
    }
    None
}
/// Reports whether `fragment` contains nothing but whitespace, `--` line
/// comments, and `/* ... */` block comments.
///
/// An unterminated block comment counts as non-comment content and yields
/// `false`. Used to decide whether text after a statement's `;` is harmless.
fn has_only_whitespace_or_comments(fragment: &str) -> bool {
    let data = fragment.as_bytes();
    let mut pos = 0usize;
    while pos < data.len() {
        match (data[pos], data.get(pos + 1).copied()) {
            (byte, _) if byte.is_ascii_whitespace() => pos += 1,
            (b'-', Some(b'-')) => {
                // Line comment: skip to end of line (or end of input).
                pos += 2;
                while pos < data.len() && data[pos] != b'\n' {
                    pos += 1;
                }
            }
            (b'/', Some(b'*')) => {
                // Block comment: skip to the matching `*/`; reject if missing.
                pos += 2;
                loop {
                    if pos + 1 >= data.len() {
                        return false;
                    }
                    if data[pos] == b'*' && data[pos + 1] == b'/' {
                        pos += 2;
                        break;
                    }
                    pos += 1;
                }
            }
            _ => return false,
        }
    }
    true
}
/// Returns the byte index of the `)` that closes the `(` at
/// `open_paren_index`, or `None` when that position is not `(` or the list
/// never balances.
///
/// The scan is SQL-aware: parentheses inside single-quoted strings (with ''
/// escapes), double-quoted identifiers (with "" escapes), `--` line comments,
/// nested `/* ... */` block comments, and dollar-quoted bodies
/// (`$tag$ ... $tag$`) do not affect the depth count. Works on bytes; all
/// state transitions key off ASCII bytes, so multi-byte UTF-8 content is
/// stepped over harmlessly.
fn find_matching_paren(input: &str, open_paren_index: usize) -> Option<usize> {
    let bytes = input.as_bytes();
    if bytes.get(open_paren_index) != Some(&b'(') {
        return None;
    }
    let mut i = open_paren_index + 1;
    // Depth 1 accounts for the opening paren we were handed.
    let mut depth = 1usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut in_line_comment = false;
    // Block comments are tracked as a depth so nested `/* /* */ */` works.
    let mut block_comment_depth = 0usize;
    // When Some, we are inside a dollar-quoted body; the stored tag (including
    // both '$' delimiters, e.g. "$fn$") is the only thing that ends it.
    let mut dollar_tag: Option<String> = None;
    while i < bytes.len() {
        let current = bytes[i];
        let next = bytes.get(i + 1).copied();
        if in_line_comment {
            // Line comments run to end of line.
            if current == b'\n' {
                in_line_comment = false;
            }
            i += 1;
            continue;
        }
        if block_comment_depth > 0 {
            if current == b'/' && next == Some(b'*') {
                block_comment_depth += 1;
                i += 2;
                continue;
            }
            if current == b'*' && next == Some(b'/') {
                block_comment_depth -= 1;
                i += 2;
                continue;
            }
            i += 1;
            continue;
        }
        if let Some(tag) = dollar_tag.as_deref() {
            // Only the exact opening tag closes the dollar-quoted section.
            if input[i..].starts_with(tag) {
                i += tag.len();
                dollar_tag = None;
                continue;
            }
            i += 1;
            continue;
        }
        if in_single_quote {
            if current == b'\'' {
                if next == Some(b'\'') {
                    // '' is an escaped quote; stay inside the string.
                    i += 2;
                } else {
                    in_single_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        if in_double_quote {
            if current == b'"' {
                if next == Some(b'"') {
                    // "" is an escaped quote inside a quoted identifier.
                    i += 2;
                } else {
                    in_double_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        // From here on we are in plain SQL text; check for state entries.
        if current == b'-' && next == Some(b'-') {
            in_line_comment = true;
            i += 2;
            continue;
        }
        if current == b'/' && next == Some(b'*') {
            block_comment_depth = 1;
            i += 2;
            continue;
        }
        if current == b'\'' {
            in_single_quote = true;
            i += 1;
            continue;
        }
        if current == b'"' {
            in_double_quote = true;
            i += 1;
            continue;
        }
        if current == b'$'
            && let Some(tag) = parse_dollar_tag_at(input, i)
        {
            i += tag.len();
            dollar_tag = Some(tag);
            continue;
        }
        if current == b'(' {
            depth += 1;
            i += 1;
            continue;
        }
        if current == b')' {
            // saturating_sub guards against underflow on malformed input.
            depth = depth.saturating_sub(1);
            if depth == 0 {
                return Some(i);
            }
            i += 1;
            continue;
        }
        i += 1;
    }
    // Ran off the end without balancing.
    None
}
/// Reports whether `ddl` contains at most one SQL statement.
///
/// A top-level `;` — i.e. one outside single-quoted strings, double-quoted
/// identifiers, `--` / `/* */` comments, and dollar-quoted bodies — is only
/// tolerated when everything after it is whitespace or comments. Input with
/// no semicolon at all passes. The lexer states mirror `find_matching_paren`.
fn is_single_statement_sql(ddl: &str) -> bool {
    let bytes = ddl.as_bytes();
    let mut i = 0usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut in_line_comment = false;
    // Block comments are tracked as a depth so nested comments work.
    let mut block_comment_depth = 0usize;
    // When Some, we are inside a dollar-quoted body ended only by this tag.
    let mut dollar_tag: Option<String> = None;
    while i < bytes.len() {
        let current = bytes[i];
        let next = bytes.get(i + 1).copied();
        if in_line_comment {
            // Line comments run to end of line.
            if current == b'\n' {
                in_line_comment = false;
            }
            i += 1;
            continue;
        }
        if block_comment_depth > 0 {
            if current == b'/' && next == Some(b'*') {
                block_comment_depth += 1;
                i += 2;
                continue;
            }
            if current == b'*' && next == Some(b'/') {
                block_comment_depth -= 1;
                i += 2;
                continue;
            }
            i += 1;
            continue;
        }
        if let Some(tag) = dollar_tag.as_deref() {
            // Only the exact opening tag closes the dollar-quoted section.
            if ddl[i..].starts_with(tag) {
                i += tag.len();
                dollar_tag = None;
                continue;
            }
            i += 1;
            continue;
        }
        if in_single_quote {
            if current == b'\'' {
                if next == Some(b'\'') {
                    // '' is an escaped quote; stay inside the string.
                    i += 2;
                } else {
                    in_single_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        if in_double_quote {
            if current == b'"' {
                if next == Some(b'"') {
                    // "" is an escaped quote inside a quoted identifier.
                    i += 2;
                } else {
                    in_double_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        // Plain SQL text: check for comment/string/dollar-quote entries.
        if current == b'-' && next == Some(b'-') {
            in_line_comment = true;
            i += 2;
            continue;
        }
        if current == b'/' && next == Some(b'*') {
            block_comment_depth = 1;
            i += 2;
            continue;
        }
        if current == b'\'' {
            in_single_quote = true;
            i += 1;
            continue;
        }
        if current == b'"' {
            in_double_quote = true;
            i += 1;
            continue;
        }
        if current == b'$'
            && let Some(tag) = parse_dollar_tag_at(ddl, i)
        {
            i += tag.len();
            dollar_tag = Some(tag);
            continue;
        }
        if current == b';' {
            // A statement terminator is fine only when followed by nothing
            // but whitespace/comments; otherwise a second statement exists.
            if !has_only_whitespace_or_comments(&ddl[i + 1..]) {
                return false;
            }
            break;
        }
        i += 1;
    }
    true
}
/// Splits `input` on commas that sit at the top level — i.e. outside
/// parentheses, single-quoted strings (with '' escapes), and double-quoted
/// identifiers (with "" escapes).
///
/// Returns the trimmed, non-empty segments in order. The scanner walks bytes
/// but only ever slices at ASCII delimiter positions (quotes, parens,
/// commas), so multi-byte UTF-8 sequences pass through untouched.
///
/// Fix: the previous implementation rebuilt each segment with
/// `String::push(byte as char)`, which reinterprets UTF-8 continuation bytes
/// (0x80..=0xFF) as Latin-1 code points and corrupts any non-ASCII content
/// (e.g. accented identifiers or string literals). Slicing the original
/// `&str` preserves the bytes exactly.
fn split_top_level_commas(input: &str) -> Vec<String> {
    let mut segments: Vec<String> = Vec::new();
    // Byte offset where the current segment begins; always a char boundary
    // because it is only ever set to 0 or just past an ASCII comma.
    let mut segment_start = 0usize;
    let mut depth = 0usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let bytes = input.as_bytes();
    let mut i = 0usize;
    while i < bytes.len() {
        let current = bytes[i];
        let next = bytes.get(i + 1).copied();
        if in_single_quote {
            if current == b'\'' {
                if next == Some(b'\'') {
                    // '' is an escaped quote; stay inside the string.
                    i += 2;
                } else {
                    in_single_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        if in_double_quote {
            if current == b'"' {
                if next == Some(b'"') {
                    // "" is an escaped quote inside a quoted identifier.
                    i += 2;
                } else {
                    in_double_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        match current {
            b'\'' => in_single_quote = true,
            b'"' => in_double_quote = true,
            b'(' => depth += 1,
            // saturating_sub guards against underflow on malformed input.
            b')' => depth = depth.saturating_sub(1),
            b',' if depth == 0 => {
                let fragment = input[segment_start..i].trim();
                if !fragment.is_empty() {
                    segments.push(fragment.to_string());
                }
                segment_start = i + 1;
            }
            _ => {}
        }
        i += 1;
    }
    // Flush the trailing segment (text after the last top-level comma).
    let fragment = input[segment_start..].trim();
    if !fragment.is_empty() {
        segments.push(fragment.to_string());
    }
    segments
}
/// Strips a trailing default clause from a single function-argument fragment,
/// leaving just the "name type" part used for overload identity.
///
/// Handles both spellings PostgreSQL accepts: `arg type = expr` and
/// `arg type DEFAULT expr` (keyword matched case-insensitively, only when
/// bounded by whitespace or the ends of the fragment). The `=` / `default`
/// must occur at top level — occurrences inside single/double quotes or
/// parentheses are ignored. When no default clause is present, the trimmed
/// input is returned unchanged.
fn trim_function_arg_default(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut i = 0usize;
    let mut depth = 0usize;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    while i < bytes.len() {
        let current = bytes[i];
        let next = bytes.get(i + 1).copied();
        if in_single_quote {
            if current == b'\'' {
                if next == Some(b'\'') {
                    // '' is an escaped quote; stay inside the string.
                    i += 2;
                } else {
                    in_single_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        if in_double_quote {
            if current == b'"' {
                if next == Some(b'"') {
                    // "" is an escaped quote inside a quoted identifier.
                    i += 2;
                } else {
                    in_double_quote = false;
                    i += 1;
                }
            } else {
                i += 1;
            }
            continue;
        }
        if current == b'\'' {
            in_single_quote = true;
            i += 1;
            continue;
        }
        if current == b'"' {
            in_double_quote = true;
            i += 1;
            continue;
        }
        if current == b'(' {
            depth += 1;
            i += 1;
            continue;
        }
        if current == b')' {
            depth = depth.saturating_sub(1);
            i += 1;
            continue;
        }
        // `= expr` form: cut at the first top-level '='.
        if depth == 0 && current == b'=' {
            return input[..i].trim().to_string();
        }
        // `DEFAULT expr` form: case-insensitive byte comparison against
        // "default", accepted only as a whole word (preceded and followed by
        // whitespace or the fragment boundary).
        if depth == 0 && i + 7 <= bytes.len() {
            let is_default_keyword = bytes[i..(i + 7)]
                .iter()
                .zip(b"default")
                .all(|(candidate, expected)| candidate.to_ascii_lowercase() == *expected);
            if is_default_keyword
                && (i == 0 || bytes[i - 1].is_ascii_whitespace())
                && (i + 7 == bytes.len() || bytes[i + 7].is_ascii_whitespace())
            {
                return input[..i].trim().to_string();
            }
        }
        i += 1;
    }
    // No default clause found; return the fragment trimmed.
    input.trim().to_string()
}
/// Trims `signature` and collapses every internal whitespace run to a single
/// space, producing a canonical form for overload comparison.
fn normalize_signature(signature: &str) -> String {
    let trimmed = signature.trim();
    RE_WHITESPACE.replace_all(trimmed, " ").into_owned()
}
/// Converts the raw text between a function's argument parentheses into the
/// normalized identity signature (defaults stripped, whitespace collapsed,
/// fragments joined with ", ").
///
/// An empty argument list yields an empty string; a non-empty list that
/// produces no usable fragments is a parse error.
fn parse_identity_signature(args_segment: &str) -> Result<String, String> {
    let trimmed = args_segment.trim();
    if trimmed.is_empty() {
        return Ok(String::new());
    }
    let mut fragments: Vec<String> = Vec::new();
    for raw_fragment in split_top_level_commas(trimmed) {
        let normalized = normalize_signature(&trim_function_arg_default(&raw_fragment));
        if !normalized.is_empty() {
            fragments.push(normalized);
        }
    }
    if fragments.is_empty() {
        return Err("Failed to parse function argument signature from DDL.".to_string());
    }
    Ok(fragments.join(", "))
}
/// Validates that `ddl` is a single `CREATE OR REPLACE FUNCTION` statement
/// and extracts the sanitized schema/function names plus the normalized
/// argument signature identifying the overload.
///
/// An unqualified function target defaults to the `public` schema.
pub fn validate_create_or_replace_function_ddl(
    ddl: &str,
) -> Result<ParsedFunctionDdlTarget, String> {
    let normalized = ddl.trim();
    if normalized.is_empty() {
        return Err("DDL statement cannot be empty.".to_string());
    }
    if !is_single_statement_sql(normalized) {
        return Err("Function DDL must contain exactly one SQL statement.".to_string());
    }
    let Some(captures) = RE_CREATE_OR_REPLACE_FUNCTION.captures(normalized) else {
        return Err(
            "Only CREATE OR REPLACE FUNCTION statements are allowed for this endpoint.".to_string(),
        );
    };
    let Some(target_match) = captures.get(1) else {
        return Err("Failed to resolve function target from DDL.".to_string());
    };
    let target = target_match.as_str();
    let Some(head_match) = captures.get(0) else {
        return Err("Failed to resolve function signature from DDL.".to_string());
    };
    // The regex match ends exactly at the opening '(' of the argument list.
    let open_paren_index = head_match.end().saturating_sub(1);
    let close_paren_index = find_matching_paren(normalized, open_paren_index)
        .ok_or_else(|| "Function DDL has an unbalanced argument list.".to_string())?;
    let identity_signature =
        parse_identity_signature(&normalized[(open_paren_index + 1)..close_paren_index])?;
    // Split an optional schema qualifier; unqualified targets live in public.
    let (raw_schema, raw_function) = match target.rsplit_once('.') {
        Some((schema, function_name)) => (schema, function_name),
        None => ("public", target),
    };
    let schema_name = validate_identifier(raw_schema.trim().trim_matches('"'), "schema name")?;
    let function_name =
        validate_identifier(raw_function.trim().trim_matches('"'), "function name")?;
    Ok(ParsedFunctionDdlTarget {
        schema_name,
        function_name,
        identity_signature,
    })
}
/// Validates the caller-supplied function argument types, returning the
/// trimmed values on success.
///
/// Each entry must be non-empty, contain no statement/comment tokens
/// (`;`, `--`, `/*`, `*/`), and match the type-fragment allow-list.
pub fn validate_function_arg_types(arg_types: &[String]) -> Result<Vec<String>, String> {
    let mut sanitized: Vec<String> = Vec::with_capacity(arg_types.len());
    for raw in arg_types {
        let value = raw.trim();
        if value.is_empty() {
            return Err("Function arg_types cannot contain empty values.".to_string());
        }
        let contains_forbidden_token = value.contains(';')
            || ["--", "/*", "*/"].iter().any(|token| value.contains(token));
        if contains_forbidden_token || !RE_TYPE_FRAGMENT.is_match(value) {
            return Err(format!("Function argument type '{}' is invalid.", raw));
        }
        sanitized.push(value.to_string());
    }
    Ok(sanitized)
}
fn build_column_definition(column: &ManagementColumnRequest) -> Result<String, String> {
let column_name = validate_identifier(&column.name, "column name")?;
let data_type = validate_data_type(&column.data_type)?;
let default_expression = column
.default_expression
.as_deref()
.map(validate_default_expression)
.transpose()?;
let mut parts = vec![column_name, data_type];
if !column.nullable {
parts.push("NOT NULL".to_string());
}
if let Some(default_expression) = default_expression {
parts.push(format!("DEFAULT {}", default_expression));
}
Ok(parts.join(" "))
}
/// Builds a CREATE TABLE statement with the Athena-managed columns (`id`,
/// `created_at`, `<table>_id` UUID, `time`) injected ahead of the caller's
/// columns.
///
/// Errors when any identifier/type/default is invalid or when a requested
/// column collides with a reserved name.
///
/// Fix: `request.if_not_exists` was previously accepted but silently ignored;
/// it now emits `IF NOT EXISTS` when set (false leaves the output unchanged,
/// so existing callers are unaffected).
pub fn build_create_table_statement(request: &CreateTableRequest) -> Result<String, String> {
    let schema_name = validate_identifier(&request.schema_name, "schema name")?;
    let table_name = validate_identifier(&request.table_name, "table name")?;
    let table_uuid_column = validate_identifier(
        &table_uuid_column_name(&request.table_name),
        "generated table uuid column name",
    )?;
    // Reject user columns that collide with the managed ones before building.
    let reserved = reserved_column_names(&request.table_name);
    for column in &request.columns {
        if reserved.contains(&column.name) {
            return Err(format!(
                "Column '{}' is reserved and managed by Athena.",
                column.name
            ));
        }
    }
    let mut columns = vec![
        "\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY".to_string(),
        "\"created_at\" timestamptz NOT NULL DEFAULT now()".to_string(),
        format!(
            "{} uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE",
            table_uuid_column
        ),
        "\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint".to_string(),
    ];
    for column in &request.columns {
        columns.push(build_column_definition(column)?);
    }
    Ok(format!(
        "CREATE TABLE {}{}.{} (\n {}\n)",
        if request.if_not_exists { "IF NOT EXISTS " } else { "" },
        schema_name,
        table_name,
        columns.join(",\n ")
    ))
}
/// Builds one ALTER TABLE statement per edit operation, in order, failing
/// fast on the first invalid input.
///
/// Reserved columns (id / created_at / time / `<table>_id`) cannot be added
/// or renamed. NOTE(review): the SetDefault/DropDefault/SetNotNull/
/// DropNotNull arms do not consult `reserved` — confirm that altering
/// defaults/nullability of managed columns is intentional.
pub fn build_edit_table_statements(
    schema_name: &str,
    table_name: &str,
    operations: &[TableEditOperation],
) -> Result<Vec<String>, String> {
    let schema_name: String = validate_identifier(schema_name, "schema name")?;
    let table_name: String = validate_identifier(table_name, "table name")?;
    // The sanitized table name is quoted; strip the quotes so the reserved
    // lookup uses the bare name that reserved_column_names expects.
    let reserved: HashSet<String> = reserved_column_names(table_name.trim_matches('"'));
    let mut statements: Vec<String> = Vec::with_capacity(operations.len());
    for operation in operations {
        let statement: String = match operation {
            // ADD COLUMN: reject reserved names, then reuse the shared
            // column-definition builder (validates name/type/default).
            TableEditOperation::AddColumn { column } => {
                if reserved.contains(&column.name) {
                    return Err(format!(
                        "Column '{}' is reserved and managed by Athena.",
                        column.name
                    ));
                }
                format!(
                    "ALTER TABLE {}.{} ADD COLUMN {}",
                    schema_name,
                    table_name,
                    build_column_definition(column)?
                )
            }
            // RENAME COLUMN: both endpoints checked against reserved names
            // (using the raw values, before sanitization quotes them).
            TableEditOperation::RenameColumn { from, to } => {
                if reserved.contains(from) || reserved.contains(to) {
                    return Err("Reserved columns cannot be renamed.".to_string());
                }
                let from: String = validate_identifier(from, "column name")?;
                let to: String = validate_identifier(to, "column name")?;
                format!(
                    "ALTER TABLE {}.{} RENAME COLUMN {} TO {}",
                    schema_name, table_name, from, to
                )
            }
            // SET DEFAULT: the expression goes through the literal whitelist.
            TableEditOperation::SetDefault {
                column_name,
                default_expression,
            } => {
                let column_name: String = validate_identifier(column_name, "column name")?;
                let default_expression: String = validate_default_expression(default_expression)?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET DEFAULT {}",
                    schema_name, table_name, column_name, default_expression
                )
            }
            TableEditOperation::DropDefault { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP DEFAULT",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::SetNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET NOT NULL",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::DropNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP NOT NULL",
                    schema_name, table_name, column_name
                )
            }
        };
        statements.push(statement);
    }
    Ok(statements)
}
/// Builds a DROP TABLE statement with either CASCADE or RESTRICT behavior.
pub fn build_drop_table_statement(
    schema_name: &str,
    table_name: &str,
    cascade: bool,
) -> Result<String, String> {
    let schema_name: String = validate_identifier(schema_name, "schema name")?;
    let table_name: String = validate_identifier(table_name, "table name")?;
    let drop_behavior = if cascade { "CASCADE" } else { "RESTRICT" };
    Ok(format!(
        "DROP TABLE {}.{} {}",
        schema_name, table_name, drop_behavior
    ))
}
/// Builds an ALTER TABLE ... DROP COLUMN statement with CASCADE or RESTRICT,
/// refusing to drop any Athena-managed (reserved) column.
pub fn build_drop_column_statement(
    schema_name: &str,
    table_name: &str,
    column_name: &str,
    cascade: bool,
) -> Result<String, String> {
    // The reserved check runs on the raw names, before sanitization quotes them.
    let reserved: HashSet<String> = reserved_column_names(table_name);
    if reserved.contains(column_name) {
        return Err(format!(
            "Column '{}' is reserved and cannot be dropped.",
            column_name
        ));
    }
    let schema_name: String = validate_identifier(schema_name, "schema name")?;
    let table_name: String = validate_identifier(table_name, "table name")?;
    let column_name: String = validate_identifier(column_name, "column name")?;
    let drop_behavior = if cascade { "CASCADE" } else { "RESTRICT" };
    Ok(format!(
        "ALTER TABLE {}.{} DROP COLUMN {} {}",
        schema_name, table_name, column_name, drop_behavior
    ))
}
/// Derives a deterministic index name: `<table>_<col1>_<col2>..._idx`.
pub fn generate_index_name(table_name: &str, columns: &[String]) -> String {
    let joined_columns = columns.join("_");
    format!("{table_name}_{joined_columns}_idx")
}
/// Builds a CREATE [UNIQUE] INDEX statement, returning the raw (unsanitized)
/// index name alongside the SQL.
///
/// Validates the schema, table, column, and index identifiers, and restricts
/// the access method to `ALLOWED_INDEX_METHODS`. A missing `index_name` is
/// derived deterministically via `generate_index_name`.
///
/// Fix: the non-unique path previously rendered `"CREATE  INDEX"` (double
/// space, because the unique slot was filled with an empty string) and the
/// trailing `.replace("CREATE INDEX", "CREATE INDEX")` was a no-op that
/// replaced the string with itself. The CREATE clause is now selected as a
/// whole keyword phrase.
pub fn build_create_index_statement(
    request: &CreateIndexRequest,
) -> Result<(String, String), String> {
    let schema_name: String = validate_identifier(&request.schema_name, "schema name")?;
    let table_name: String = validate_identifier(&request.table_name, "table name")?;
    if request.columns.is_empty() {
        return Err("At least one index column is required.".to_string());
    }
    let method: String = request.method.trim().to_ascii_lowercase();
    if !ALLOWED_INDEX_METHODS.contains(&method.as_str()) {
        return Err(format!(
            "Unsupported index method '{}'. Allowed values: {}.",
            request.method,
            ALLOWED_INDEX_METHODS.join(", ")
        ));
    }
    let sanitized_columns: Vec<String> = request
        .columns
        .iter()
        .map(|column| validate_identifier(column, "column name"))
        .collect::<Result<Vec<_>, _>>()?;
    let index_name_raw: String = request
        .index_name
        .clone()
        .unwrap_or_else(|| generate_index_name(&request.table_name, &request.columns));
    let index_name = validate_identifier(&index_name_raw, "index name")?;
    let create_clause = if request.unique {
        "CREATE UNIQUE INDEX"
    } else {
        "CREATE INDEX"
    };
    Ok((
        index_name_raw,
        format!(
            "{} {} ON {}.{} USING {} ({})",
            create_clause,
            index_name,
            schema_name,
            table_name,
            method,
            sanitized_columns.join(", ")
        ),
    ))
}
/// Builds a DROP INDEX statement for a schema-qualified index.
pub fn build_drop_index_statement(schema_name: &str, index_name: &str) -> Result<String, String> {
    let schema = validate_identifier(schema_name, "schema name")?;
    let index = validate_identifier(index_name, "index name")?;
    Ok(format!("DROP INDEX {schema}.{index}"))
}
/// Builds a CREATE EXTENSION statement for a whitelisted extension name
/// (compared case-insensitively after trimming).
///
/// Fix: the name was previously passed through `validate_identifier`, which
/// rejects hyphens (see the `invalid_identifier_is_rejected` test), making
/// the whitelisted "uuid-ossp" extension impossible to install. Because the
/// name is constrained to the static `ALLOWED_EXTENSIONS` list, it is safe to
/// double-quote it directly instead; output for the other entries keeps the
/// same quoted form as before.
pub fn build_create_extension_statement(
    extension_name: &str,
    if_not_exists: bool,
) -> Result<String, String> {
    let normalized: String = extension_name.trim().to_ascii_lowercase();
    if normalized.is_empty() {
        return Err("Extension name must not be empty.".to_string());
    }
    if !ALLOWED_EXTENSIONS.contains(&normalized.as_str()) {
        return Err(format!(
            "Unsupported extension '{}'. Allowed values: {}.",
            extension_name,
            ALLOWED_EXTENSIONS.join(", ")
        ));
    }
    Ok(format!(
        "CREATE EXTENSION {}\"{}\"",
        if if_not_exists { "IF NOT EXISTS " } else { "" },
        normalized
    ))
}
#[cfg(test)]
mod tests {
    use super::{
        build_create_table_statement, generate_index_name, reserved_column_names,
        table_uuid_column_name, validate_create_or_replace_function_ddl,
        validate_function_arg_types, validate_identifier,
    };
    use crate::api::management::types::CreateTableRequest;

    // The Athena-managed columns must always be present in CREATE TABLE output.
    #[test]
    fn create_table_statement_injects_required_columns_first() {
        let statement: String = build_create_table_statement(&CreateTableRequest {
            schema_name: "public".to_string(),
            table_name: "users".to_string(),
            columns: vec![],
            if_not_exists: false,
        })
        .expect("statement");
        assert!(
            statement
                .contains("\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY")
        );
        assert!(statement.contains("\"created_at\" timestamptz NOT NULL DEFAULT now()"));
        assert!(statement.contains("\"users_id\" uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE"));
        assert!(
            statement
                .contains("\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint")
        );
    }

    // The per-table UUID column (<table>_id) must be part of the reserved set.
    #[test]
    fn reserved_columns_include_generated_uuid_column() {
        let reserved = reserved_column_names("tickets");
        assert!(reserved.contains("tickets_id"));
        assert_eq!(table_uuid_column_name("tickets"), "tickets_id");
    }

    // sanitize_identifier must reject names with characters like '-'.
    #[test]
    fn invalid_identifier_is_rejected() {
        assert!(validate_identifier("bad-name", "table name").is_err());
    }

    #[test]
    fn generated_index_name_is_deterministic() {
        assert_eq!(
            generate_index_name("users", &["email".to_string(), "created_at".to_string()]),
            "users_email_created_at_idx"
        );
    }

    // Happy path: sanitized names come back double-quoted; no args means an
    // empty identity signature.
    #[test]
    fn validate_function_ddl_accepts_create_or_replace() {
        let parsed = validate_create_or_replace_function_ddl(
            r#"
CREATE OR REPLACE FUNCTION public.hello_world()
RETURNS text
LANGUAGE sql
AS $$ SELECT 'Hello'; $$;
"#,
        )
        .expect("valid function ddl");
        assert_eq!(parsed.schema_name, "\"public\"");
        assert_eq!(parsed.function_name, "\"hello_world\"");
        assert_eq!(parsed.identity_signature, "");
    }

    // DEFAULT clauses must be stripped from the identity signature.
    #[test]
    fn validate_function_ddl_extracts_identity_signature() {
        let parsed = validate_create_or_replace_function_ddl(
            r#"
CREATE OR REPLACE FUNCTION public.echo_city(name text, suffix text DEFAULT '!')
RETURNS text
LANGUAGE sql
AS $$ SELECT name || suffix; $$;
"#,
        )
        .expect("valid function ddl");
        assert_eq!(parsed.identity_signature, "name text, suffix text");
    }

    #[test]
    fn validate_function_ddl_rejects_non_function_ddl() {
        let err = validate_create_or_replace_function_ddl("DROP TABLE users;").unwrap_err();
        assert!(err.contains("CREATE OR REPLACE FUNCTION"));
    }

    // A second statement after the terminating ';' must be rejected.
    #[test]
    fn validate_function_ddl_rejects_multi_statement_input() {
        let err = validate_create_or_replace_function_ddl(
            r#"
CREATE OR REPLACE FUNCTION public.hello_world() RETURNS text
LANGUAGE sql AS $$ SELECT 'ok'; $$;
DROP TABLE users;
"#,
        )
        .unwrap_err();
        assert!(err.contains("exactly one SQL statement"));
    }

    // A ';' inside a tagged dollar-quoted body must not count as a statement
    // terminator.
    #[test]
    fn validate_function_ddl_allows_semicolon_inside_dollar_quoted_body() {
        let parsed = validate_create_or_replace_function_ddl(
            r#"
CREATE OR REPLACE FUNCTION public.fn_with_body(name text)
RETURNS text
LANGUAGE plpgsql
AS $fn$
BEGIN
RETURN name || ';ok';
END;
$fn$;
"#,
        )
        .expect("valid function ddl");
        assert_eq!(parsed.function_name, "\"fn_with_body\"");
        assert_eq!(parsed.identity_signature, "name text");
    }

    // Comments after the final ';' are harmless trailing content.
    #[test]
    fn validate_function_ddl_allows_trailing_comments_after_statement() {
        let parsed = validate_create_or_replace_function_ddl(
            r#"
CREATE OR REPLACE FUNCTION public.comment_tail()
RETURNS text
LANGUAGE sql
AS $$ SELECT 'ok'; $$;
-- trailing migration comment
"#,
        )
        .expect("valid function ddl");
        assert_eq!(parsed.schema_name, "\"public\"");
        assert_eq!(parsed.function_name, "\"comment_tail\"");
    }

    // SQL comment tokens embedded in an arg type are an injection attempt.
    #[test]
    fn validate_function_arg_types_rejects_sql_comment_tokens() {
        let err = validate_function_arg_types(&["text -- inject".to_string()]).unwrap_err();
        assert!(err.contains("invalid"));
    }
}