use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use crate::api::management::types::{
CreateIndexRequest, CreateTableRequest, ManagementColumnRequest, TableEditOperation,
};
use crate::parser::query_builder::sanitize_identifier;
/// Human-readable list of the column data types accepted by
/// [`validate_data_type`].
///
/// The list is joined into the "Unsupported data_type" error message; the
/// validator itself does not consult it. `varchar(n)` and `numeric(p,s)` are
/// display placeholders for the parameterised forms that the validator matches
/// with `RE_VARCHAR` and `RE_NUMERIC`.
pub const ALLOWED_COLUMN_DATA_TYPES: &[&str] = &[
"text",
"varchar(n)",
"boolean",
"integer",
"bigint",
"numeric(p,s)",
"double precision",
"uuid",
"jsonb",
"date",
"timestamptz",
];
/// Index access methods accepted by [`build_create_index_statement`].
///
/// The requested method is trimmed and ASCII-lowercased before being compared
/// against this list.
pub const ALLOWED_INDEX_METHODS: &[&str] = &["btree", "gin", "gist", "hash", "brin"];
/// Extension names accepted by [`build_create_extension_statement`].
///
/// The requested name is trimmed and ASCII-lowercased before being compared
/// against this list, so every entry here must be lowercase.
pub const ALLOWED_EXTENSIONS: &[&str] = &[
"pgcrypto",
"uuid-ossp",
"citext",
"pg_stat_statements",
"vector",
];
// The patterns below spell digits as `[0-9]` instead of `\d`: in the `regex`
// crate `\d` is Unicode-aware by default and would also accept non-ASCII
// decimal digits, which must not be let through into generated SQL as if they
// were numbers.

/// Matches a lowercase `varchar(<length>)` type such as `varchar(255)`.
static RE_VARCHAR: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^varchar\([0-9]+\)$").expect("hard-coded pattern is valid"));

/// Matches a lowercase `numeric(<precision>,<scale>)` type such as `numeric(10,2)`.
static RE_NUMERIC: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^numeric\([0-9]+,[0-9]+\)$").expect("hard-coded pattern is valid"));

/// Matches a single-quoted string literal in which every embedded quote is
/// doubled (`''`).
static RE_STRING_LITERAL: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^'(?:''|[^'])*'$").expect("hard-coded pattern is valid"));

/// Matches an optionally negative integer literal such as `-42`.
static RE_INTEGER_LITERAL: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^-?[0-9]+$").expect("hard-coded pattern is valid"));

/// Matches an optionally negative decimal literal with digits on both sides of
/// the point, such as `-3.14`.
static RE_NUMERIC_LITERAL: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^-?[0-9]+\.[0-9]+$").expect("hard-coded pattern is valid"));
/// Returns the name of the generated UUID column for `table_name`, which is
/// the table name followed by `_id`.
///
/// The input is used verbatim; no trimming or validation happens here.
pub fn table_uuid_column_name(table_name: &str) -> String {
    const SUFFIX: &str = "_id";
    let mut column_name = String::with_capacity(table_name.len() + SUFFIX.len());
    column_name.push_str(table_name);
    column_name.push_str(SUFFIX);
    column_name
}
/// Returns the set of column names that are reserved for `table_name`.
///
/// The set always holds the fixed names `id`, `created_at` and `time`, plus
/// the table-specific UUID column name produced by
/// [`table_uuid_column_name`].
pub fn reserved_column_names(table_name: &str) -> HashSet<String> {
    let mut reserved: HashSet<String> = ["id", "created_at", "time"]
        .iter()
        .map(|fixed| fixed.to_string())
        .collect();
    reserved.insert(table_uuid_column_name(table_name));
    reserved
}
/// Trims `identifier` and passes it through `sanitize_identifier`.
///
/// # Errors
///
/// Returns `Invalid <label> '<identifier>'.` (quoting the original, untrimmed
/// input) when the sanitizer yields `None`.
pub fn validate_identifier(identifier: &str, label: &str) -> Result<String, String> {
    match sanitize_identifier(identifier.trim()) {
        Some(sanitized) => Ok(sanitized),
        None => Err(format!("Invalid {} '{}'.", label, identifier)),
    }
}
/// Normalises and validates a column data type.
///
/// The input is trimmed and ASCII-lowercased. It is accepted when it equals
/// one of the fixed type names below or matches `RE_VARCHAR` / `RE_NUMERIC`.
///
/// # Errors
///
/// Returns an "Unsupported data_type" message that quotes the original input
/// and lists [`ALLOWED_COLUMN_DATA_TYPES`].
pub fn validate_data_type(data_type: &str) -> Result<String, String> {
    // Types that take no parameters and must match exactly.
    const FIXED_TYPES: [&str; 9] = [
        "text",
        "boolean",
        "integer",
        "bigint",
        "double precision",
        "uuid",
        "jsonb",
        "date",
        "timestamptz",
    ];

    let normalized = data_type.trim().to_ascii_lowercase();
    let is_fixed = FIXED_TYPES.contains(&normalized.as_str());
    let is_parameterised = || RE_VARCHAR.is_match(&normalized) || RE_NUMERIC.is_match(&normalized);

    if !(is_fixed || is_parameterised()) {
        return Err(format!(
            "Unsupported data_type '{}'. Allowed values: {}.",
            data_type,
            ALLOWED_COLUMN_DATA_TYPES.join(", ")
        ));
    }
    Ok(normalized)
}
/// Validates a column default expression against a small allowlist.
///
/// Accepted forms, checked on the trimmed input:
/// - the case-insensitive keywords/expressions `now()`, `current_timestamp`,
///   `gen_random_uuid()`, `extract(epoch from now())::bigint`, `true`,
///   `false` and `null`;
/// - anything matching `RE_STRING_LITERAL`, `RE_INTEGER_LITERAL` or
///   `RE_NUMERIC_LITERAL`.
///
/// On success the trimmed input is returned with its original casing.
///
/// # Errors
///
/// Returns an "Unsupported default_expression" message quoting the original
/// input.
pub fn validate_default_expression(expression: &str) -> Result<String, String> {
    let candidate = expression.trim();
    let folded = candidate.to_ascii_lowercase();

    let is_whitelisted_expression = matches!(
        folded.as_str(),
        "now()"
            | "current_timestamp"
            | "gen_random_uuid()"
            | "extract(epoch from now())::bigint"
            | "true"
            | "false"
            | "null"
    );
    // Literals are matched against the trimmed text, not the lowercased copy.
    let is_literal = || {
        RE_STRING_LITERAL.is_match(candidate)
            || RE_INTEGER_LITERAL.is_match(candidate)
            || RE_NUMERIC_LITERAL.is_match(candidate)
    };

    if is_whitelisted_expression || is_literal() {
        return Ok(candidate.to_string());
    }
    Err(format!(
        "Unsupported default_expression '{}'. Only simple literals and whitelisted expressions are allowed.",
        expression
    ))
}
fn build_column_definition(column: &ManagementColumnRequest) -> Result<String, String> {
let column_name = validate_identifier(&column.name, "column name")?;
let data_type = validate_data_type(&column.data_type)?;
let default_expression = column
.default_expression
.as_deref()
.map(validate_default_expression)
.transpose()?;
let mut parts = vec![column_name, data_type];
if !column.nullable {
parts.push("NOT NULL".to_string());
}
if let Some(default_expression) = default_expression {
parts.push(format!("DEFAULT {}", default_expression));
}
Ok(parts.join(" "))
}
/// Builds a `CREATE TABLE` statement for `request`.
///
/// Four managed columns are always emitted first: `id`, `created_at`, the
/// table-specific UUID column (see [`table_uuid_column_name`]) and `time`.
/// The caller's columns follow, each rendered by `build_column_definition`.
///
/// # Errors
///
/// Returns an error when the schema name, table name or generated UUID column
/// name fails [`validate_identifier`], when a requested column uses a reserved
/// name (see [`reserved_column_names`]), or when a column definition is
/// invalid.
pub fn build_create_table_statement(request: &CreateTableRequest) -> Result<String, String> {
    // NOTE(review): `request.if_not_exists` is not consulted here; confirm
    // whether the caller handles it or whether it should add `IF NOT EXISTS`.
    let schema_name = validate_identifier(&request.schema_name, "schema name")?;
    let table_name = validate_identifier(&request.table_name, "table name")?;

    // `validate_identifier` trims its input, so derived names must be built
    // from the trimmed table name too. Otherwise a padded but valid table name
    // yields an invalid UUID column name and a reserved set that cannot match.
    let trimmed_table_name = request.table_name.trim();
    let table_uuid_column = validate_identifier(
        &table_uuid_column_name(trimmed_table_name),
        "generated table uuid column name",
    )?;
    let reserved = reserved_column_names(trimmed_table_name);

    // Compare trimmed names: column names are trimmed again during validation,
    // so `" id"` must be treated as the reserved `id`.
    if let Some(column) = request
        .columns
        .iter()
        .find(|column| reserved.contains(column.name.trim()))
    {
        return Err(format!(
            "Column '{}' is reserved and managed by Athena.",
            column.name
        ));
    }

    let mut columns = vec![
        "\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY".to_string(),
        "\"created_at\" timestamptz NOT NULL DEFAULT now()".to_string(),
        format!(
            "{} uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE",
            table_uuid_column
        ),
        "\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint".to_string(),
    ];
    for column in &request.columns {
        columns.push(build_column_definition(column)?);
    }

    Ok(format!(
        "CREATE TABLE {}.{} (\n {}\n)",
        schema_name,
        table_name,
        columns.join(",\n ")
    ))
}
/// Builds one `ALTER TABLE` statement per entry in `operations`, in order.
///
/// The schema and table names are validated once up front. The reserved column
/// set is computed from the validated table name with its double quotes
/// stripped.
///
/// # Errors
///
/// Returns the first error encountered: an invalid identifier, an invalid
/// column definition or default expression, or an attempt to add or rename a
/// reserved column (see [`reserved_column_names`]). No statements are returned
/// when any operation fails.
pub fn build_edit_table_statements(
    schema_name: &str,
    table_name: &str,
    operations: &[TableEditOperation],
) -> Result<Vec<String>, String> {
    let schema_name: String = validate_identifier(schema_name, "schema name")?;
    let table_name: String = validate_identifier(table_name, "table name")?;
    let reserved: HashSet<String> = reserved_column_names(table_name.trim_matches('"'));
    let mut statements: Vec<String> = Vec::with_capacity(operations.len());

    // NOTE(review): the default / NOT NULL operations below do not check the
    // reserved set; confirm that altering managed columns this way is intended.
    for operation in operations {
        let statement: String = match operation {
            TableEditOperation::AddColumn { column } => {
                // Column names are trimmed during validation, so the reserved
                // check must also look at the trimmed name or `" id"` slips by.
                if reserved.contains(column.name.trim()) {
                    return Err(format!(
                        "Column '{}' is reserved and managed by Athena.",
                        column.name
                    ));
                }
                format!(
                    "ALTER TABLE {}.{} ADD COLUMN {}",
                    schema_name,
                    table_name,
                    build_column_definition(column)?
                )
            }
            TableEditOperation::RenameColumn { from, to } => {
                // Trimmed for the same reason as above: otherwise a padded name
                // would let a reserved column be renamed away or shadowed.
                if reserved.contains(from.trim()) || reserved.contains(to.trim()) {
                    return Err("Reserved columns cannot be renamed.".to_string());
                }
                let from: String = validate_identifier(from, "column name")?;
                let to: String = validate_identifier(to, "column name")?;
                format!(
                    "ALTER TABLE {}.{} RENAME COLUMN {} TO {}",
                    schema_name, table_name, from, to
                )
            }
            TableEditOperation::SetDefault {
                column_name,
                default_expression,
            } => {
                let column_name: String = validate_identifier(column_name, "column name")?;
                let default_expression: String = validate_default_expression(default_expression)?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET DEFAULT {}",
                    schema_name, table_name, column_name, default_expression
                )
            }
            TableEditOperation::DropDefault { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP DEFAULT",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::SetNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET NOT NULL",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::DropNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP NOT NULL",
                    schema_name, table_name, column_name
                )
            }
        };
        statements.push(statement);
    }
    Ok(statements)
}
/// Builds `DROP TABLE <schema>.<table> CASCADE|RESTRICT`.
///
/// `cascade` selects `CASCADE`; otherwise `RESTRICT` is emitted explicitly.
///
/// # Errors
///
/// Returns the [`validate_identifier`] error for the schema name or, failing
/// that, the table name.
pub fn build_drop_table_statement(
    schema_name: &str,
    table_name: &str,
    cascade: bool,
) -> Result<String, String> {
    let schema = validate_identifier(schema_name, "schema name")?;
    let table = validate_identifier(table_name, "table name")?;
    let drop_behavior = match cascade {
        true => "CASCADE",
        false => "RESTRICT",
    };
    Ok(format!("DROP TABLE {}.{} {}", schema, table, drop_behavior))
}
/// Builds `ALTER TABLE <schema>.<table> DROP COLUMN <column> CASCADE|RESTRICT`.
///
/// `cascade` selects `CASCADE`; otherwise `RESTRICT` is emitted explicitly.
///
/// # Errors
///
/// Returns an error when `column_name` is reserved for the table (see
/// [`reserved_column_names`]) or when any of the three identifiers fails
/// [`validate_identifier`]. The reserved check runs first.
pub fn build_drop_column_statement(
    schema_name: &str,
    table_name: &str,
    column_name: &str,
    cascade: bool,
) -> Result<String, String> {
    // `validate_identifier` trims its input, so the reserved check has to look
    // at trimmed names as well. Otherwise `" id"` would pass this guard and
    // then be emitted as the reserved `id` column.
    let reserved: HashSet<String> = reserved_column_names(table_name.trim());
    if reserved.contains(column_name.trim()) {
        return Err(format!(
            "Column '{}' is reserved and cannot be dropped.",
            column_name
        ));
    }
    let schema_name: String = validate_identifier(schema_name, "schema name")?;
    let table_name: String = validate_identifier(table_name, "table name")?;
    let column_name: String = validate_identifier(column_name, "column name")?;
    Ok(format!(
        "ALTER TABLE {}.{} DROP COLUMN {} {}",
        schema_name,
        table_name,
        column_name,
        if cascade { "CASCADE" } else { "RESTRICT" }
    ))
}
/// Derives a default index name of the form `<table>_<col1>_<col2>..._idx`.
///
/// The inputs are used verbatim; the result is not validated here.
pub fn generate_index_name(table_name: &str, columns: &[String]) -> String {
    let column_part = columns.join("_");
    let mut index_name = String::with_capacity(table_name.len() + column_part.len() + 5);
    index_name.push_str(table_name);
    index_name.push('_');
    index_name.push_str(&column_part);
    index_name.push_str("_idx");
    index_name
}
/// Builds a `CREATE [UNIQUE] INDEX` statement for `request`.
///
/// Returns `(index_name, statement)`. `index_name` is the unquoted name: the
/// caller-supplied `request.index_name` as given, or a name generated by
/// [`generate_index_name`] from the trimmed table and column names.
///
/// # Errors
///
/// Returns an error when an identifier fails [`validate_identifier`], when no
/// columns are given, or when the method is not in [`ALLOWED_INDEX_METHODS`]
/// after trimming and ASCII-lowercasing.
pub fn build_create_index_statement(
    request: &CreateIndexRequest,
) -> Result<(String, String), String> {
    let schema_name: String = validate_identifier(&request.schema_name, "schema name")?;
    let table_name: String = validate_identifier(&request.table_name, "table name")?;
    if request.columns.is_empty() {
        return Err("At least one index column is required.".to_string());
    }

    let method: String = request.method.trim().to_ascii_lowercase();
    if !ALLOWED_INDEX_METHODS.contains(&method.as_str()) {
        return Err(format!(
            "Unsupported index method '{}'. Allowed values: {}.",
            request.method,
            ALLOWED_INDEX_METHODS.join(", ")
        ));
    }

    let sanitized_columns: Vec<String> = request
        .columns
        .iter()
        .map(|column| validate_identifier(column, "column name"))
        .collect::<Result<Vec<_>, _>>()?;

    let index_name_raw: String = match request.index_name.as_deref() {
        Some(explicit) => explicit.to_string(),
        None => {
            // The table and column names above were validated after trimming.
            // Generate the default name from the same trimmed values so that
            // padding in the request cannot produce an invalid index name.
            let trimmed_columns: Vec<String> = request
                .columns
                .iter()
                .map(|column| column.trim().to_string())
                .collect();
            generate_index_name(request.table_name.trim(), &trimmed_columns)
        }
    };
    let index_name = validate_identifier(&index_name_raw, "index name")?;

    // Choose the whole keyword phrase up front so the non-unique form does not
    // end up with a doubled space between `CREATE` and `INDEX`.
    let create_keyword = if request.unique {
        "CREATE UNIQUE INDEX"
    } else {
        "CREATE INDEX"
    };
    let statement = format!(
        "{} {} ON {}.{} USING {} ({})",
        create_keyword,
        index_name,
        schema_name,
        table_name,
        method,
        sanitized_columns.join(", ")
    );
    Ok((index_name_raw, statement))
}
/// Builds `DROP INDEX <schema>.<index>`.
///
/// # Errors
///
/// Returns the [`validate_identifier`] error for the schema name or, failing
/// that, the index name.
pub fn build_drop_index_statement(schema_name: &str, index_name: &str) -> Result<String, String> {
    let schema = validate_identifier(schema_name, "schema name")?;
    let index = validate_identifier(index_name, "index name")?;
    Ok(["DROP INDEX ", schema.as_str(), ".", index.as_str()].concat())
}
/// Builds `CREATE EXTENSION [IF NOT EXISTS] <extension>`.
///
/// `extension_name` is trimmed and ASCII-lowercased, then must equal one of
/// [`ALLOWED_EXTENSIONS`].
///
/// # Errors
///
/// Returns an error when the name is empty after trimming or is not in the
/// allowlist.
pub fn build_create_extension_statement(
    extension_name: &str,
    if_not_exists: bool,
) -> Result<String, String> {
    let normalized: String = extension_name.trim().to_ascii_lowercase();
    if normalized.is_empty() {
        return Err("Extension name must not be empty.".to_string());
    }
    if !ALLOWED_EXTENSIONS.contains(&normalized.as_str()) {
        return Err(format!(
            "Unsupported extension '{}'. Allowed values: {}.",
            extension_name,
            ALLOWED_EXTENSIONS.join(", ")
        ));
    }
    // Some allowlisted names (`uuid-ossp`) contain a hyphen, which the generic
    // identifier validator rejects, so they could never be created. At this
    // point `normalized` is one of the fixed allowlist entries, none of which
    // contains a quote character, so it is safe to double-quote it directly
    // when the validator refuses it.
    let extension_ident: String = validate_identifier(&normalized, "extension name")
        .unwrap_or_else(|_| format!("\"{}\"", normalized));
    Ok(format!(
        "CREATE EXTENSION {}{}",
        if if_not_exists { "IF NOT EXISTS " } else { "" },
        extension_ident
    ))
}
#[cfg(test)]
mod tests {
    use super::{
        build_create_table_statement, generate_index_name, reserved_column_names,
        table_uuid_column_name, validate_identifier,
    };
    use crate::api::management::types::CreateTableRequest;

    /// A table created without user columns still contains all four managed
    /// column definitions.
    #[test]
    fn create_table_statement_injects_required_columns_first() {
        let request = CreateTableRequest {
            schema_name: String::from("public"),
            table_name: String::from("users"),
            columns: Vec::new(),
            if_not_exists: false,
        };
        let sql = build_create_table_statement(&request).expect("statement");

        let expected_fragments = [
            "\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY",
            "\"created_at\" timestamptz NOT NULL DEFAULT now()",
            "\"users_id\" uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE",
            "\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint",
        ];
        for fragment in expected_fragments.iter() {
            assert!(sql.contains(*fragment));
        }
    }

    /// The per-table UUID column is part of the reserved set.
    #[test]
    fn reserved_columns_include_generated_uuid_column() {
        assert_eq!(table_uuid_column_name("tickets"), "tickets_id");
        let reserved_for_tickets = reserved_column_names("tickets");
        assert!(reserved_for_tickets.contains("tickets_id"));
    }

    /// Identifiers containing a hyphen are refused.
    #[test]
    fn invalid_identifier_is_rejected() {
        let outcome = validate_identifier("bad-name", "table name");
        assert!(outcome.is_err());
    }

    /// The default index name depends only on the table and column names.
    #[test]
    fn generated_index_name_is_deterministic() {
        let indexed_columns = vec![String::from("email"), String::from("created_at")];
        assert_eq!(
            generate_index_name("users", &indexed_columns),
            "users_email_created_at_idx"
        );
    }
}