athena_rs 2.0.2

Database gateway API
Documentation
use regex::Regex;
use std::collections::HashSet;

use crate::api::management::types::{
    CreateIndexRequest, CreateTableRequest, ManagementColumnRequest, TableEditOperation,
};
use crate::parser::query_builder::sanitize_identifier;

/// Human-readable list of the column data types accepted by [`validate_data_type`].
///
/// The entries are interpolated into that function's error message. `varchar(n)` and
/// `numeric(p,s)` are placeholders for the parameterised forms (digits in place of
/// `n`, `p` and `s`); the literal placeholder strings themselves are not accepted.
pub const ALLOWED_COLUMN_DATA_TYPES: &[&str] = &[
    "text",
    "varchar(n)",
    "boolean",
    "integer",
    "bigint",
    "numeric(p,s)",
    "double precision",
    "uuid",
    "jsonb",
    "date",
    "timestamptz",
];

/// Index methods accepted by [`build_create_index_statement`]; the requested method is
/// trimmed and ASCII-lowercased before being compared against this list.
pub const ALLOWED_INDEX_METHODS: &[&str] = &["btree", "gin", "gist", "hash", "brin"];

/// Returns the name of the generated uuid column for `table_name`, i.e. the table
/// name with `_id` appended. The input is used verbatim (no trimming or validation).
pub fn table_uuid_column_name(table_name: &str) -> String {
    const SUFFIX: &str = "_id";

    let mut column = String::with_capacity(table_name.len() + SUFFIX.len());
    column.push_str(table_name);
    column.push_str(SUFFIX);
    column
}

/// Returns the set of column names that are reserved for `table_name`:
/// the fixed names `id`, `created_at` and `time`, plus the generated uuid column
/// name produced by [`table_uuid_column_name`].
pub fn reserved_column_names(table_name: &str) -> HashSet<String> {
    let mut reserved = HashSet::with_capacity(4);
    for fixed in ["id", "created_at", "time"] {
        reserved.insert(fixed.to_string());
    }
    reserved.insert(table_uuid_column_name(table_name));
    reserved
}

/// Trims `identifier` and runs it through `sanitize_identifier`.
///
/// # Errors
///
/// Returns `Invalid <label> '<identifier>'.` (quoting the original, untrimmed input)
/// when the sanitizer rejects the trimmed value.
pub fn validate_identifier(identifier: &str, label: &str) -> Result<String, String> {
    match sanitize_identifier(identifier.trim()) {
        Some(sanitized) => Ok(sanitized),
        None => Err(format!("Invalid {} '{}'.", label, identifier)),
    }
}

pub fn validate_data_type(data_type: &str) -> Result<String, String> {
    let normalized = data_type.trim().to_ascii_lowercase();
    let varchar = Regex::new(r"^varchar\(\d+\)$").expect("valid regex");
    let numeric = Regex::new(r"^numeric\(\d+,\d+\)$").expect("valid regex");

    if matches!(
        normalized.as_str(),
        "text"
            | "boolean"
            | "integer"
            | "bigint"
            | "double precision"
            | "uuid"
            | "jsonb"
            | "date"
            | "timestamptz"
    ) || varchar.is_match(&normalized)
        || numeric.is_match(&normalized)
    {
        Ok(normalized)
    } else {
        Err(format!(
            "Unsupported data_type '{}'. Allowed values: {}.",
            data_type,
            ALLOWED_COLUMN_DATA_TYPES.join(", ")
        ))
    }
}

/// Validates a column `DEFAULT` expression and returns it trimmed.
///
/// Accepted forms (keywords and whitelisted calls are matched case-insensitively):
///
/// * `now()`, `current_timestamp`, `gen_random_uuid()`,
///   `extract(epoch from now())::bigint`
/// * `true`, `false`, `null`
/// * a single-quoted string literal in which every embedded quote is doubled (`''`)
/// * an integer (`-?digits`) or decimal (`-?digits.digits`) literal, ASCII digits only
///
/// The checks use plain string operations, so nothing is compiled per call.
///
/// # Errors
///
/// Returns a message quoting the original input when it matches none of the forms.
pub fn validate_default_expression(expression: &str) -> Result<String, String> {
    /// True when `s` is a non-empty run of ASCII digits.
    fn is_digits(s: &str) -> bool {
        !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit())
    }

    /// True for `-?digits` and `-?digits.digits`.
    fn is_number_literal(s: &str) -> bool {
        let unsigned = s.strip_prefix('-').unwrap_or(s);
        match unsigned.split_once('.') {
            Some((whole, fraction)) => is_digits(whole) && is_digits(fraction),
            None => is_digits(unsigned),
        }
    }

    /// True for `'...'` where every run of quotes inside the delimiters has even
    /// length, i.e. embedded quotes only appear as the escaped pair `''`.
    fn is_string_literal(s: &str) -> bool {
        match s.strip_prefix('\'').and_then(|rest| rest.strip_suffix('\'')) {
            Some(inner) => inner
                .split(|c: char| c != '\'')
                .all(|quote_run| quote_run.len() % 2 == 0),
            None => false,
        }
    }

    let trimmed = expression.trim();
    let lowered = trimmed.to_ascii_lowercase();

    let is_whitelisted = matches!(
        lowered.as_str(),
        "now()"
            | "current_timestamp"
            | "gen_random_uuid()"
            | "extract(epoch from now())::bigint"
            | "true"
            | "false"
            | "null"
    );

    if is_whitelisted || is_string_literal(trimmed) || is_number_literal(trimmed) {
        Ok(trimmed.to_string())
    } else {
        Err(format!(
            "Unsupported default_expression '{}'. Only simple literals and whitelisted expressions are allowed.",
            expression
        ))
    }
}

fn build_column_definition(column: &ManagementColumnRequest) -> Result<String, String> {
    let column_name = validate_identifier(&column.name, "column name")?;
    let data_type = validate_data_type(&column.data_type)?;
    let default_expression = column
        .default_expression
        .as_deref()
        .map(validate_default_expression)
        .transpose()?;

    let mut parts = vec![column_name, data_type];
    if !column.nullable {
        parts.push("NOT NULL".to_string());
    }
    if let Some(default_expression) = default_expression {
        parts.push(format!("DEFAULT {}", default_expression));
    }

    Ok(parts.join(" "))
}

/// Builds the `CREATE TABLE` statement for `request`.
///
/// Four managed columns are always emitted first: `"id"`, `"created_at"`, the
/// generated `<table>_id` uuid column and `"time"`. The caller-supplied columns follow
/// in request order.
///
/// # Errors
///
/// Returns an error message when the schema, table or generated uuid column name is
/// not a valid identifier, when a requested column uses a reserved name, or when any
/// column definition fails validation.
///
/// NOTE(review): `request.if_not_exists` is not consulted here — confirm the caller
/// handles it.
pub fn build_create_table_statement(request: &CreateTableRequest) -> Result<String, String> {
    let schema_name = validate_identifier(&request.schema_name, "schema name")?;
    let table_name = validate_identifier(&request.table_name, "table name")?;

    // `validate_identifier` trims its input, so the generated column name and the
    // reserved set must be derived from the same trimmed table name; otherwise a
    // padded table name yields a mismatching (or invalid) `<table>_id`.
    let bare_table_name = request.table_name.trim();
    let table_uuid_column = validate_identifier(
        &table_uuid_column_name(bare_table_name),
        "generated table uuid column name",
    )?;
    let reserved = reserved_column_names(bare_table_name);

    // Compare the trimmed column name: the raw value would let e.g. " id" slip past
    // this check and then be emitted as the reserved column "id".
    if let Some(column) = request
        .columns
        .iter()
        .find(|column| reserved.contains(column.name.trim()))
    {
        return Err(format!(
            "Column '{}' is reserved and managed by Athena.",
            column.name
        ));
    }

    let mut columns = vec![
        "\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY".to_string(),
        "\"created_at\" timestamptz NOT NULL DEFAULT now()".to_string(),
        format!(
            "{} uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE",
            table_uuid_column
        ),
        "\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint".to_string(),
    ];

    for column in &request.columns {
        columns.push(build_column_definition(column)?);
    }

    Ok(format!(
        "CREATE TABLE {}.{} (\n    {}\n)",
        schema_name,
        table_name,
        columns.join(",\n    ")
    ))
}

/// Builds one `ALTER TABLE` statement per entry in `operations`, in order.
///
/// Adding a column with a reserved name, or renaming from/to a reserved name, is
/// rejected. Reserved names are computed from the sanitized table name with its
/// surrounding double quotes removed.
///
/// NOTE(review): the default/nullability operations are not checked against the
/// reserved set — confirm that is intended.
///
/// # Errors
///
/// Returns the first validation failure (identifier, data type, default expression
/// or reserved-name violation); no statements are returned in that case.
pub fn build_edit_table_statements(
    schema_name: &str,
    table_name: &str,
    operations: &[TableEditOperation],
) -> Result<Vec<String>, String> {
    let schema_name = validate_identifier(schema_name, "schema name")?;
    let table_name = validate_identifier(table_name, "table name")?;
    let reserved = reserved_column_names(table_name.trim_matches('"'));
    let mut statements = Vec::with_capacity(operations.len());

    for operation in operations {
        let statement = match operation {
            TableEditOperation::AddColumn { column } => {
                // Check the trimmed name: `validate_identifier` trims, so the raw
                // value would let a padded reserved name (e.g. " id") through.
                if reserved.contains(column.name.trim()) {
                    return Err(format!(
                        "Column '{}' is reserved and managed by Athena.",
                        column.name
                    ));
                }
                format!(
                    "ALTER TABLE {}.{} ADD COLUMN {}",
                    schema_name,
                    table_name,
                    build_column_definition(column)?
                )
            }
            TableEditOperation::RenameColumn { from, to } => {
                // Same trimming rule as above; without it " id" could be renamed.
                if reserved.contains(from.trim()) || reserved.contains(to.trim()) {
                    return Err("Reserved columns cannot be renamed.".to_string());
                }
                let from = validate_identifier(from, "column name")?;
                let to = validate_identifier(to, "column name")?;
                format!(
                    "ALTER TABLE {}.{} RENAME COLUMN {} TO {}",
                    schema_name, table_name, from, to
                )
            }
            TableEditOperation::SetDefault {
                column_name,
                default_expression,
            } => {
                let column_name = validate_identifier(column_name, "column name")?;
                let default_expression = validate_default_expression(default_expression)?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET DEFAULT {}",
                    schema_name, table_name, column_name, default_expression
                )
            }
            TableEditOperation::DropDefault { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP DEFAULT",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::SetNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} SET NOT NULL",
                    schema_name, table_name, column_name
                )
            }
            TableEditOperation::DropNotNull { column_name } => {
                let column_name = validate_identifier(column_name, "column name")?;
                format!(
                    "ALTER TABLE {}.{} ALTER COLUMN {} DROP NOT NULL",
                    schema_name, table_name, column_name
                )
            }
        };
        statements.push(statement);
    }

    Ok(statements)
}

/// Builds `DROP TABLE <schema>.<table> CASCADE|RESTRICT`.
///
/// `cascade` selects `CASCADE`; otherwise `RESTRICT` is emitted explicitly.
///
/// # Errors
///
/// Returns an error message when the schema or table name is not a valid identifier
/// (the schema is validated first).
pub fn build_drop_table_statement(
    schema_name: &str,
    table_name: &str,
    cascade: bool,
) -> Result<String, String> {
    let schema = validate_identifier(schema_name, "schema name")?;
    let table = validate_identifier(table_name, "table name")?;

    let drop_behavior = if cascade { "CASCADE" } else { "RESTRICT" };
    let statement = format!("DROP TABLE {}.{} {}", schema, table, drop_behavior);
    Ok(statement)
}

/// Builds `ALTER TABLE <schema>.<table> DROP COLUMN <column> CASCADE|RESTRICT`.
///
/// Reserved columns (see [`reserved_column_names`]) are refused before any identifier
/// validation takes place.
///
/// # Errors
///
/// Returns an error message when `column_name` is reserved for this table, or when
/// the schema, table or column name is not a valid identifier.
pub fn build_drop_column_statement(
    schema_name: &str,
    table_name: &str,
    column_name: &str,
    cascade: bool,
) -> Result<String, String> {
    // `validate_identifier` trims its input, so the reserved check must look at the
    // trimmed names too; with the raw values, " id" (or a padded table name) would
    // bypass the guard and still produce `DROP COLUMN "id"`.
    let reserved = reserved_column_names(table_name.trim());
    if reserved.contains(column_name.trim()) {
        return Err(format!(
            "Column '{}' is reserved and cannot be dropped.",
            column_name
        ));
    }
    let schema_name = validate_identifier(schema_name, "schema name")?;
    let table_name = validate_identifier(table_name, "table name")?;
    let column_name = validate_identifier(column_name, "column name")?;
    Ok(format!(
        "ALTER TABLE {}.{} DROP COLUMN {} {}",
        schema_name,
        table_name,
        column_name,
        if cascade { "CASCADE" } else { "RESTRICT" }
    ))
}

/// Derives an index name of the form `<table>_<col1>_<col2>..._idx`.
///
/// Inputs are used verbatim; an empty `columns` slice yields `<table>__idx`.
pub fn generate_index_name(table_name: &str, columns: &[String]) -> String {
    let joined_columns = columns.join("_");

    let mut name = String::with_capacity(table_name.len() + joined_columns.len() + 5);
    name.push_str(table_name);
    name.push('_');
    name.push_str(&joined_columns);
    name.push_str("_idx");
    name
}

/// Builds a `CREATE [UNIQUE] INDEX` statement for `request`.
///
/// Returns `(index_name, statement)`, where `index_name` is the name as supplied in
/// the request (or generated by [`generate_index_name`] from the request's table and
/// column names when none was supplied) before sanitization, and `statement` uses the
/// sanitized identifiers.
///
/// # Errors
///
/// Returns an error message when the schema, table, any column or the index name is
/// not a valid identifier, when no columns are given, or when the index method is not
/// listed in [`ALLOWED_INDEX_METHODS`].
pub fn build_create_index_statement(
    request: &CreateIndexRequest,
) -> Result<(String, String), String> {
    let schema_name = validate_identifier(&request.schema_name, "schema name")?;
    let table_name = validate_identifier(&request.table_name, "table name")?;
    if request.columns.is_empty() {
        return Err("At least one index column is required.".to_string());
    }

    let method = request.method.trim().to_ascii_lowercase();
    if !ALLOWED_INDEX_METHODS.contains(&method.as_str()) {
        return Err(format!(
            "Unsupported index method '{}'. Allowed values: {}.",
            request.method,
            ALLOWED_INDEX_METHODS.join(", ")
        ));
    }

    let sanitized_columns = request
        .columns
        .iter()
        .map(|column| validate_identifier(column, "column name"))
        .collect::<Result<Vec<_>, _>>()?;
    let index_name_raw = request
        .index_name
        .clone()
        .unwrap_or_else(|| generate_index_name(&request.table_name, &request.columns));
    let index_name = validate_identifier(&index_name_raw, "index name")?;

    // Choose the leading clause up front instead of formatting an empty `UNIQUE` slot
    // and patching the resulting double space with a whole-statement `replace`.
    let create_clause = if request.unique {
        "CREATE UNIQUE INDEX"
    } else {
        "CREATE INDEX"
    };
    let statement = format!(
        "{} {} ON {}.{} USING {} ({})",
        create_clause,
        index_name,
        schema_name,
        table_name,
        method,
        sanitized_columns.join(", ")
    );

    Ok((index_name_raw, statement))
}

/// Builds `DROP INDEX <schema>.<index>`.
///
/// # Errors
///
/// Returns an error message when the schema or index name is not a valid identifier
/// (the schema is validated first).
pub fn build_drop_index_statement(schema_name: &str, index_name: &str) -> Result<String, String> {
    let schema = validate_identifier(schema_name, "schema name")?;
    let index = validate_identifier(index_name, "index name")?;

    let mut statement = String::from("DROP INDEX ");
    statement.push_str(&schema);
    statement.push('.');
    statement.push_str(&index);
    Ok(statement)
}

#[cfg(test)]
mod tests {
    use super::{
        build_create_table_statement, generate_index_name, reserved_column_names,
        table_uuid_column_name, validate_identifier,
    };
    use crate::api::management::types::CreateTableRequest;

    /// A table requested without any user columns still carries every managed column.
    #[test]
    fn create_table_statement_injects_required_columns_first() {
        let request = CreateTableRequest {
            schema_name: "public".to_string(),
            table_name: "users".to_string(),
            columns: vec![],
            if_not_exists: false,
        };
        let sql = build_create_table_statement(&request).expect("statement");

        let required_definitions = [
            "\"id\" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY",
            "\"created_at\" timestamptz NOT NULL DEFAULT now()",
            "\"users_id\" uuid NOT NULL DEFAULT gen_random_uuid() UNIQUE",
            "\"time\" bigint NOT NULL DEFAULT EXTRACT(EPOCH FROM now())::bigint",
        ];
        for definition in required_definitions {
            assert!(
                sql.contains(definition),
                "missing column definition: {}",
                definition
            );
        }
    }

    /// The generated `<table>_id` column is part of the reserved set.
    #[test]
    fn reserved_columns_include_generated_uuid_column() {
        let table = "tickets";

        assert_eq!(table_uuid_column_name(table), "tickets_id");
        assert!(reserved_column_names(table).contains("tickets_id"));
    }

    /// A hyphenated name is not accepted as an identifier.
    #[test]
    fn invalid_identifier_is_rejected() {
        let outcome = validate_identifier("bad-name", "table name");
        assert!(outcome.is_err());
    }

    /// The generated index name depends only on the table and column names.
    #[test]
    fn generated_index_name_is_deterministic() {
        let columns = vec!["email".to_string(), "created_at".to_string()];
        let name = generate_index_name("users", &columns);

        assert_eq!(name, "users_email_created_at_idx");
    }
}