use std::fmt;
use reinhardt_query::prelude::{
Alias, AlterTableStatement, ColumnDef, CreateIndexStatement, CreateTableStatement,
DropIndexStatement, DropTableStatement, MySqlQueryBuilder, PostgresQueryBuilder, Query,
QueryBuilder, SqliteQueryBuilder,
};
pub mod ddl_references;
pub mod factory;
/// A single DDL operation, either structured or as a raw SQL string.
#[derive(Debug, Clone, PartialEq)]
pub enum DDLStatement {
    /// Create `table`. `columns` presumably holds `(column name, column definition)`
    /// pairs, mirroring `BaseDatabaseSchemaEditor::create_table_statement` — TODO confirm.
    CreateTable {
        table: String,
        columns: Vec<(String, String)>,
    },
    /// Apply a list of column/constraint changes to `table`.
    AlterTable {
        table: String,
        changes: Vec<AlterTableChange>,
    },
    /// Drop `table`; `cascade` presumably requests a `CASCADE` drop.
    DropTable {
        table: String,
        cascade: bool,
    },
    /// Create the index `name` on `table` over `columns`.
    ///
    /// `condition`, when present, is presumably the `WHERE` clause of a partial index.
    CreateIndex {
        name: String,
        table: String,
        columns: Vec<String>,
        unique: bool,
        condition: Option<String>,
    },
    /// Drop the index `name` (no table is recorded).
    DropIndex {
        name: String,
    },
    /// Create the schema `name`.
    CreateSchema {
        name: String,
        if_not_exists: bool,
    },
    /// Drop the schema `name`.
    DropSchema {
        name: String,
        cascade: bool,
        if_exists: bool,
    },
    /// Arbitrary SQL passed through as-is.
    RawSQL(String),
}

impl DDLStatement {
    /// Returns the table this statement targets.
    ///
    /// Variants that carry no table (`DropIndex`, `CreateSchema`, `DropSchema`,
    /// `RawSQL`) yield the empty string.
    pub fn table_name(&self) -> &str {
        match self {
            Self::CreateTable { table, .. }
            | Self::AlterTable { table, .. }
            | Self::DropTable { table, .. }
            | Self::CreateIndex { table, .. } => table.as_str(),
            Self::DropIndex { .. }
            | Self::CreateSchema { .. }
            | Self::DropSchema { .. }
            | Self::RawSQL(_) => "",
        }
    }
}

/// One change within a `DDLStatement::AlterTable`.
#[derive(Debug, Clone, PartialEq)]
pub enum AlterTableChange {
    /// Add column `name` with the SQL column definition `definition`.
    AddColumn {
        name: String,
        definition: String,
    },
    /// Drop column `name`.
    DropColumn {
        name: String,
    },
    /// Rename column `old_name` to `new_name`.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// Change the type of column `name`, optionally with a collation.
    AlterColumnType {
        name: String,
        new_type: String,
        collation: Option<String>,
    },
    /// Set (`Some`) or presumably remove (`None`) the default of column `name`.
    AlterColumnDefault {
        name: String,
        default: Option<String>,
    },
    /// Make column `name` nullable or `NOT NULL`.
    AlterColumnNullability {
        name: String,
        nullable: bool,
    },
    /// Add constraint `name` using the SQL text `definition`.
    AddConstraint {
        name: String,
        definition: String,
    },
    /// Drop constraint `name`.
    DropConstraint {
        name: String,
    },
}
/// Escapes an identifier for use inside double quotes by doubling every `"`.
///
/// The caller is responsible for adding the surrounding quote characters.
fn escape_schema_identifier(name: &str) -> String {
    let mut escaped = String::with_capacity(name.len());
    for ch in name.chars() {
        // A literal double quote inside a quoted identifier is written as `""`.
        if ch == '"' {
            escaped.push('"');
        }
        escaped.push(ch);
    }
    escaped
}
/// Backend-agnostic DDL helpers shared by every schema editor.
///
/// Implementors must provide [`database_type`](Self::database_type) and
/// [`execute`](Self::execute). Every other method has a default implementation
/// that either builds a `reinhardt_query` statement object or formats a raw
/// SQL string.
///
/// NOTE(review): the raw-SQL helpers (`rename_column_statement`,
/// `alter_column_statement`, `create_schema_statement`,
/// `drop_schema_statement`) always emit double-quoted identifiers and
/// PostgreSQL-style syntax, regardless of `database_type()`. MySQL treats
/// double quotes as identifier delimiters only in `ANSI_QUOTES` mode, so MySQL
/// editors presumably override these. Confirm against the backend impls.
#[async_trait::async_trait]
pub trait BaseDatabaseSchemaEditor: Send + Sync {
    /// Returns the backend this editor targets. The `build_*_sql` helpers use
    /// it to select the matching query builder.
    fn database_type(&self) -> super::types::DatabaseType;

    /// Executes a single SQL string against the database.
    async fn execute(&mut self, sql: &str) -> SchemaEditorResult<()>;

    /// Builds a `CREATE TABLE ... IF NOT EXISTS` statement.
    ///
    /// Each `(name, definition)` pair becomes one column. `definition` (for
    /// example `"INTEGER PRIMARY KEY"`) is passed to `ColumnDef::custom`
    /// verbatim.
    fn create_table_statement(
        &self,
        table: &str,
        columns: &[(&str, &str)],
    ) -> CreateTableStatement {
        let mut binding = Query::create_table();
        let stmt = binding.table(Alias::new(table)).if_not_exists();
        for (name, definition) in columns {
            stmt.col(ColumnDef::new(Alias::new(*name)).custom(*definition));
        }
        stmt.to_owned()
    }

    /// Builds a `DROP TABLE ... IF EXISTS` statement and requests `CASCADE`
    /// when `cascade` is true.
    fn drop_table_statement(&self, table: &str, cascade: bool) -> DropTableStatement {
        let mut binding = Query::drop_table();
        let stmt = binding.table(Alias::new(table)).if_exists();
        if cascade {
            stmt.cascade();
        }
        stmt.to_owned()
    }

    /// Builds an `ALTER TABLE ... ADD COLUMN` statement.
    ///
    /// `definition` is a raw SQL column definition (e.g. `"VARCHAR(255)"`).
    fn add_column_statement(
        &self,
        table: &str,
        column: &str,
        definition: &str,
    ) -> AlterTableStatement {
        Query::alter_table()
            .table(Alias::new(table))
            // Pass the definition as a raw custom type, exactly as
            // `create_table_statement` does, instead of wrapping it in an
            // identifier alias.
            .add_column(ColumnDef::new(Alias::new(column)).custom(definition))
            .to_owned()
    }

    /// Builds an `ALTER TABLE ... DROP COLUMN` statement.
    fn drop_column_statement(&self, table: &str, column: &str) -> AlterTableStatement {
        Query::alter_table()
            .table(Alias::new(table))
            .drop_column(Alias::new(column))
            .to_owned()
    }

    /// Returns raw SQL that renames `old_name` to `new_name` on `table`.
    ///
    /// Embedded double quotes in any identifier are doubled so a name cannot
    /// break out of its quoted identifier.
    fn rename_column_statement(&self, table: &str, old_name: &str, new_name: &str) -> String {
        format!(
            "ALTER TABLE \"{}\" RENAME COLUMN \"{}\" TO \"{}\"",
            escape_schema_identifier(table),
            escape_schema_identifier(old_name),
            escape_schema_identifier(new_name)
        )
    }

    /// Returns raw SQL that changes the type of `column` on `table`.
    ///
    /// `table` and `column` are escaped as quoted identifiers. `new_type` is a
    /// SQL type expression and is inserted verbatim, so it must come from a
    /// trusted source.
    fn alter_column_statement(&self, table: &str, column: &str, new_type: &str) -> String {
        format!(
            "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" TYPE {}",
            escape_schema_identifier(table),
            escape_schema_identifier(column),
            new_type
        )
    }

    /// Builds a `CREATE [UNIQUE] INDEX` statement.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `condition` is `Some`: partial indexes cannot be
    /// expressed with `reinhardt-query`. The error message contains the
    /// equivalent raw SQL (with escaped identifiers) so callers can run it
    /// directly.
    fn create_index_statement(
        &self,
        name: &str,
        table: &str,
        columns: &[&str],
        unique: bool,
        condition: Option<&str>,
    ) -> Result<CreateIndexStatement, String> {
        if let Some(cond) = condition {
            let quoted_columns = columns
                .iter()
                .map(|c| format!("\"{}\"", escape_schema_identifier(c)))
                .collect::<Vec<_>>()
                .join(", ");
            return Err(format!(
                "Partial indexes not supported by reinhardt-query. Use raw SQL: CREATE {}INDEX \"{}\" ON \"{}\" ({}) WHERE {}",
                if unique { "UNIQUE " } else { "" },
                escape_schema_identifier(name),
                escape_schema_identifier(table),
                quoted_columns,
                cond
            ));
        }
        let mut binding = Query::create_index();
        let stmt = binding.name(Alias::new(name)).table(Alias::new(table));
        if unique {
            stmt.unique();
        }
        for col in columns {
            stmt.col(Alias::new(*col));
        }
        Ok(stmt.to_owned())
    }

    /// Builds a `DROP INDEX ... IF EXISTS` statement for the index `name`.
    fn drop_index_statement(&self, name: &str) -> DropIndexStatement {
        let mut binding = Query::drop_index();
        binding.name(Alias::new(name)).if_exists().to_owned()
    }

    /// Returns raw SQL creating schema `name`, with `IF NOT EXISTS` when
    /// requested. The name is escaped as a quoted identifier.
    fn create_schema_statement(&self, name: &str, if_not_exists: bool) -> String {
        let escaped_name = escape_schema_identifier(name);
        if if_not_exists {
            format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", escaped_name)
        } else {
            format!("CREATE SCHEMA \"{}\"", escaped_name)
        }
    }

    /// Returns raw SQL dropping schema `name`, with optional `IF EXISTS` and
    /// `CASCADE` clauses. The name is escaped as a quoted identifier.
    fn drop_schema_statement(&self, name: &str, cascade: bool, if_exists: bool) -> String {
        let if_exists_clause = if if_exists { " IF EXISTS" } else { "" };
        let cascade_clause = if cascade { " CASCADE" } else { "" };
        format!(
            "DROP SCHEMA{} \"{}\"{}",
            if_exists_clause,
            escape_schema_identifier(name),
            cascade_clause
        )
    }

    /// Renders a `CreateTableStatement` with the builder for `database_type()`.
    /// Bound values returned by the builder are discarded.
    fn build_create_table_sql(&self, stmt: &CreateTableStatement) -> String {
        use super::types::DatabaseType;
        let (sql, _values) = match self.database_type() {
            DatabaseType::Postgres => PostgresQueryBuilder.build_create_table(stmt),
            DatabaseType::Mysql => MySqlQueryBuilder.build_create_table(stmt),
            DatabaseType::Sqlite => SqliteQueryBuilder.build_create_table(stmt),
        };
        sql
    }

    /// Renders a `DropTableStatement` with the builder for `database_type()`.
    fn build_drop_table_sql(&self, stmt: &DropTableStatement) -> String {
        use super::types::DatabaseType;
        let (sql, _values) = match self.database_type() {
            DatabaseType::Postgres => PostgresQueryBuilder.build_drop_table(stmt),
            DatabaseType::Mysql => MySqlQueryBuilder.build_drop_table(stmt),
            DatabaseType::Sqlite => SqliteQueryBuilder.build_drop_table(stmt),
        };
        sql
    }

    /// Renders an `AlterTableStatement` with the builder for `database_type()`.
    fn build_alter_table_sql(&self, stmt: &AlterTableStatement) -> String {
        use super::types::DatabaseType;
        let (sql, _values) = match self.database_type() {
            DatabaseType::Postgres => PostgresQueryBuilder.build_alter_table(stmt),
            DatabaseType::Mysql => MySqlQueryBuilder.build_alter_table(stmt),
            DatabaseType::Sqlite => SqliteQueryBuilder.build_alter_table(stmt),
        };
        sql
    }

    /// Renders a `CreateIndexStatement` with the builder for `database_type()`.
    fn build_create_index_sql(&self, stmt: &CreateIndexStatement) -> String {
        use super::types::DatabaseType;
        let (sql, _values) = match self.database_type() {
            DatabaseType::Postgres => PostgresQueryBuilder.build_create_index(stmt),
            DatabaseType::Mysql => MySqlQueryBuilder.build_create_index(stmt),
            DatabaseType::Sqlite => SqliteQueryBuilder.build_create_index(stmt),
        };
        sql
    }

    /// Renders a `DropIndexStatement` with the builder for `database_type()`.
    fn build_drop_index_sql(&self, stmt: &DropIndexStatement) -> String {
        use super::types::DatabaseType;
        let (sql, _values) = match self.database_type() {
            DatabaseType::Postgres => PostgresQueryBuilder.build_drop_index(stmt),
            DatabaseType::Mysql => MySqlQueryBuilder.build_drop_index(stmt),
            DatabaseType::Sqlite => SqliteQueryBuilder.build_drop_index(stmt),
        };
        sql
    }
}
/// Result type returned by schema-editor operations.
pub type SchemaEditorResult<T> = Result<T, SchemaEditorError>;

/// Errors produced by schema-editor operations. Each variant carries a
/// human-readable detail message.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum SchemaEditorError {
    /// A statement failed while being executed.
    ExecutionError(String),
    /// The requested operation is not valid in this context.
    InvalidOperation(String),
    /// An error reported by the database layer.
    DatabaseError(String),
}

impl fmt::Display for SchemaEditorError {
    /// Formats the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ExecutionError(detail) => write!(f, "Execution error: {}", detail),
            Self::InvalidOperation(detail) => write!(f, "Invalid operation: {}", detail),
            Self::DatabaseError(detail) => write!(f, "Database error: {}", detail),
        }
    }
}

impl std::error::Error for SchemaEditorError {}
// Unit tests for the default builders provided by `BaseDatabaseSchemaEditor`
// and for the identifier-escaping helper.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use super::*;

    /// Minimal editor that relies entirely on the trait's default methods.
    ///
    /// `execute` is a no-op and the backend is fixed to PostgreSQL, so these
    /// tests only inspect generated SQL and never touch a database.
    struct TestSchemaEditor;

    #[async_trait::async_trait]
    impl BaseDatabaseSchemaEditor for TestSchemaEditor {
        async fn execute(&mut self, _sql: &str) -> SchemaEditorResult<()> {
            Ok(())
        }
        fn database_type(&self) -> crate::backends::types::DatabaseType {
            crate::backends::types::DatabaseType::Postgres
        }
    }

    // Builder-based tests render statements with the PostgreSQL builder and
    // use substring checks rather than exact-string comparisons.
    #[test]
    fn test_create_table_statement() {
        use reinhardt_query::prelude::{PostgresQueryBuilder, QueryBuilder};
        let editor = TestSchemaEditor;
        let stmt = editor.create_table_statement(
            "users",
            &[("id", "INTEGER PRIMARY KEY"), ("name", "VARCHAR(100)")],
        );
        let (sql, _) = PostgresQueryBuilder.build_create_table(&stmt);
        assert!(sql.contains("CREATE TABLE"));
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"id\""));
        assert!(sql.contains("\"name\""));
    }

    // Covers both the plain and the CASCADE variants of DROP TABLE.
    #[test]
    fn test_drop_table_statement() {
        use reinhardt_query::prelude::{PostgresQueryBuilder, QueryBuilder};
        let editor = TestSchemaEditor;
        let stmt_no_cascade = editor.drop_table_statement("users", false);
        let (sql_no_cascade, _) = PostgresQueryBuilder.build_drop_table(&stmt_no_cascade);
        assert!(sql_no_cascade.contains("DROP TABLE"));
        assert!(sql_no_cascade.contains("\"users\""));
        let stmt_cascade = editor.drop_table_statement("users", true);
        let (sql_cascade, _) = PostgresQueryBuilder.build_drop_table(&stmt_cascade);
        assert!(sql_cascade.contains("DROP TABLE"));
        assert!(sql_cascade.contains("\"users\""));
        assert!(sql_cascade.contains("CASCADE"));
    }

    #[test]
    fn test_add_column_statement() {
        use reinhardt_query::prelude::{PostgresQueryBuilder, QueryBuilder};
        let editor = TestSchemaEditor;
        let stmt = editor.add_column_statement("users", "email", "VARCHAR(255)");
        let (sql, _) = PostgresQueryBuilder.build_alter_table(&stmt);
        assert!(sql.contains("ALTER TABLE"));
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("ADD COLUMN"));
        assert!(sql.contains("\"email\""));
        assert!(sql.contains("VARCHAR(255)"));
    }

    // Checks regular and unique indexes, and that a partial index (one with a
    // condition) is rejected with an error embedding the equivalent raw SQL.
    #[test]
    fn test_create_index_statement() {
        use reinhardt_query::prelude::{PostgresQueryBuilder, QueryBuilder};
        let editor = TestSchemaEditor;
        let stmt = editor.create_index_statement("idx_email", "users", &["email"], false, None);
        let (sql, _) = PostgresQueryBuilder.build_create_index(&stmt.unwrap());
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_email"));
        assert!(sql.contains("\"users\""));
        let unique_stmt =
            editor.create_index_statement("idx_email_uniq", "users", &["email"], true, None);
        let (unique_sql, _) = PostgresQueryBuilder.build_create_index(&unique_stmt.unwrap());
        assert!(unique_sql.contains("CREATE UNIQUE INDEX"));
        let partial_result = editor.create_index_statement(
            "idx_active",
            "users",
            &["email"],
            false,
            Some("active = true"),
        );
        assert!(partial_result.is_err());
        let error_msg = partial_result.unwrap_err();
        assert!(error_msg.contains("Partial indexes not supported"));
        assert!(error_msg.contains("WHERE active = true"));
    }

    // `alter_column_statement` returns a raw string, so it is compared exactly.
    #[test]
    fn test_alter_column_statement() {
        let editor = TestSchemaEditor;
        let sql = editor.alter_column_statement("users", "email", "TEXT");
        assert_eq!(
            sql,
            "ALTER TABLE \"users\" ALTER COLUMN \"email\" TYPE TEXT"
        );
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"email\""));
        assert!(sql.contains("TYPE TEXT"));
    }

    #[test]
    fn test_ddl_statement_table_name() {
        let stmt = DDLStatement::CreateTable {
            table: "users".to_string(),
            columns: vec![],
        };
        assert_eq!(stmt.table_name(), "users");
        let alter_stmt = DDLStatement::AlterTable {
            table: "posts".to_string(),
            changes: vec![],
        };
        assert_eq!(alter_stmt.table_name(), "posts");
    }

    // Embedded double quotes in schema names must be doubled so the name cannot
    // terminate the quoted identifier early.
    #[rstest]
    #[case("my_schema", "CREATE SCHEMA IF NOT EXISTS \"my_schema\"")]
    #[case(
        "schema\"injection",
        "CREATE SCHEMA IF NOT EXISTS \"schema\"\"injection\""
    )]
    #[case(
        "special-chars_123",
        "CREATE SCHEMA IF NOT EXISTS \"special-chars_123\""
    )]
    fn test_create_schema_escapes_identifier(
        #[case] schema_name: &str,
        #[case] expected_sql: &str,
    ) {
        let editor = TestSchemaEditor;
        let sql = editor.create_schema_statement(schema_name, true);
        assert_eq!(sql, expected_sql);
    }

    #[rstest]
    fn test_create_schema_without_if_not_exists() {
        let editor = TestSchemaEditor;
        let sql = editor.create_schema_statement("my_schema", false);
        assert_eq!(sql, "CREATE SCHEMA \"my_schema\"");
    }

    // Same escaping check for DROP SCHEMA, with both optional clauses enabled.
    #[rstest]
    #[case("my_schema", "DROP SCHEMA IF EXISTS \"my_schema\" CASCADE")]
    #[case(
        "schema\"injection",
        "DROP SCHEMA IF EXISTS \"schema\"\"injection\" CASCADE"
    )]
    #[case(
        "special-chars_123",
        "DROP SCHEMA IF EXISTS \"special-chars_123\" CASCADE"
    )]
    fn test_drop_schema_escapes_identifier(#[case] schema_name: &str, #[case] expected_sql: &str) {
        let editor = TestSchemaEditor;
        let sql = editor.drop_schema_statement(schema_name, true, true);
        assert_eq!(sql, expected_sql);
    }

    #[rstest]
    fn test_drop_schema_without_cascade_and_if_exists() {
        let editor = TestSchemaEditor;
        let sql = editor.drop_schema_statement("my_schema", false, false);
        assert_eq!(sql, "DROP SCHEMA \"my_schema\"");
    }

    // Direct checks on the helper: each `"` is doubled and other text is untouched.
    #[rstest]
    fn test_escape_schema_identifier_helper() {
        assert_eq!(escape_schema_identifier("simple"), "simple");
        assert_eq!(escape_schema_identifier("has\"quote"), "has\"\"quote");
        assert_eq!(
            escape_schema_identifier("multiple\"\"quotes"),
            "multiple\"\"\"\"quotes"
        );
        assert_eq!(escape_schema_identifier(""), "");
    }
}
#[cfg(test)]
pub mod test_utils;