use super::Result;
use crate::backends::{
connection::DatabaseConnection,
types::{DatabaseType, QueryValue, Row, TransactionExecutor},
};
/// Runs schema-changing SQL against one database connection, optionally inside
/// a single transaction that `finish` commits and `rollback` undoes.
pub struct SchemaEditor {
/// Connection used directly whenever no transaction executor is held.
connection: DatabaseConnection,
/// Open transaction, present only when the editor was created in effective
/// atomic mode; taken out again by `finish` / `rollback`.
executor: Option<Box<dyn TransactionExecutor>>,
/// Effective atomic flag: the requested flag AND backend support for
/// transactional DDL (see `new`).
atomic: bool,
/// Backend this editor talks to; selects backend-specific SQL.
db_type: DatabaseType,
/// Statements queued by `defer`; `finish` runs them in insertion order
/// before committing.
deferred_sql: Vec<String>,
}
impl SchemaEditor {
/// Creates a schema editor over `connection`.
///
/// A transaction is opened with `connection.begin()` only when `atomic` is
/// `true` *and* `db_type` reports support for transactional DDL. When `atomic`
/// was requested but the backend cannot honour it, a warning is logged and the
/// editor runs every statement directly on the connection instead.
///
/// # Errors
///
/// Returns an error if beginning the transaction fails.
pub async fn new(
    connection: DatabaseConnection,
    atomic: bool,
    db_type: DatabaseType,
) -> Result<Self> {
    let effective_atomic = atomic && db_type.supports_transactional_ddl();

    // `atomic` set while `effective_atomic` is not means the backend lacks
    // transactional DDL: tell the operator that the request was downgraded.
    if atomic && !effective_atomic {
        tracing::warn!(
            "atomic=true requested but {:?} doesn't support transactional DDL. \
            Proceeding without transaction wrapper.",
            db_type
        );
    }

    let executor = if effective_atomic {
        Some(connection.begin().await?)
    } else {
        None
    };

    Ok(Self {
        connection,
        executor,
        atomic: effective_atomic,
        db_type,
        deferred_sql: Vec::new(),
    })
}
/// Executes one SQL statement with no bound parameters.
///
/// The statement runs on the open transaction when the editor holds one,
/// otherwise directly on the connection. Exactly one statement is sent per
/// call, on every backend: no follow-up probe query is issued, matching how
/// `finish` runs deferred statements on the same transaction.
///
/// # Errors
///
/// Propagates any error reported by the transaction or the connection.
pub async fn execute(&mut self, sql: &str) -> Result<()> {
    if let Some(ref mut tx) = self.executor {
        tx.execute(sql, vec![]).await?;
    } else {
        self.connection.execute(sql, vec![]).await?;
    }
    Ok(())
}
/// Runs a query and returns every resulting row.
///
/// The query goes through the open transaction when one is held, otherwise
/// through the connection.
///
/// # Errors
///
/// Propagates any error reported by the transaction or the connection.
pub async fn fetch_all(&mut self, sql: &str, params: Vec<QueryValue>) -> Result<Vec<Row>> {
    let rows = match self.executor.as_mut() {
        Some(tx) => tx.fetch_all(sql, params).await?,
        None => self.connection.fetch_all(sql, params).await?,
    };
    Ok(rows)
}
/// Runs a query and returns the row it yields, or `None` when it yields none.
///
/// The query goes through the open transaction when one is held, otherwise
/// through the connection.
///
/// # Errors
///
/// Propagates any error reported by the transaction or the connection.
pub async fn fetch_optional(
    &mut self,
    sql: &str,
    params: Vec<QueryValue>,
) -> Result<Option<Row>> {
    let maybe_row = match self.executor.as_mut() {
        Some(tx) => tx.fetch_optional(sql, params).await?,
        None => self.connection.fetch_optional(sql, params).await?,
    };
    Ok(maybe_row)
}
pub async fn table_exists(&mut self, table_name: &str) -> Result<bool> {
use reinhardt_query::prelude::{
Alias, Cond, Expr, ExprTrait, MySqlQueryBuilder, PostgresQueryBuilder, Query,
QueryStatementBuilder, SqliteQueryBuilder,
};
match self.db_type {
DatabaseType::Postgres => {
let subquery = Query::select()
.expr(Expr::asterisk())
.from((Alias::new("information_schema"), Alias::new("tables")))
.cond_where(
Cond::all()
.add(Expr::col(Alias::new("table_schema")).eq("public"))
.add(Expr::col(Alias::new("table_name")).eq(table_name)),
)
.to_owned();
let query_str = format!(
"SELECT EXISTS ({}) AS table_exists",
subquery.to_string(PostgresQueryBuilder)
);
match self.fetch_optional(&query_str, vec![]).await? {
Some(row) => match row.data.get("table_exists") {
Some(QueryValue::Bool(b)) => Ok(*b),
_ => Ok(false),
},
None => Ok(false),
}
}
DatabaseType::Sqlite => {
let query = Query::select()
.column(Alias::new("name"))
.from(Alias::new("sqlite_master"))
.cond_where(
Cond::all()
.add(Expr::col(Alias::new("type")).eq("table"))
.add(Expr::col(Alias::new("name")).eq(table_name)),
)
.to_owned();
let query_str = query.to_string(SqliteQueryBuilder);
let row = self.fetch_optional(&query_str, vec![]).await?;
Ok(row.is_some())
}
DatabaseType::Mysql => {
let query = Query::select()
.column(Alias::new("TABLE_NAME"))
.from((Alias::new("information_schema"), Alias::new("tables")))
.cond_where(
Cond::all()
.add(Expr::col(Alias::new("TABLE_SCHEMA")).eq(Expr::cust("DATABASE()")))
.add(Expr::col(Alias::new("TABLE_NAME")).eq(table_name)),
)
.to_owned();
let query_str = query.to_string(MySqlQueryBuilder);
let row = self.fetch_optional(&query_str, vec![]).await?;
Ok(row.is_some())
}
}
}
/// Queues `sql` to be executed later by `finish`.
///
/// Nothing is sent to the database here; the statement is only appended to
/// the editor's list of deferred statements.
pub fn defer(&mut self, sql: String) {
self.deferred_sql.push(sql);
}
pub async fn finish(mut self) -> Result<()> {
for sql in self.deferred_sql.drain(..) {
if let Some(ref mut tx) = self.executor {
tx.execute(&sql, vec![]).await?;
} else {
self.connection.execute(&sql, vec![]).await?;
}
}
if let Some(tx) = self.executor.take() {
tx.commit().await?;
}
Ok(())
}
/// Rolls back the open transaction, if any, and consumes the editor.
///
/// When the editor holds no transaction this does nothing and returns `Ok`.
/// Deferred statements that were queued but never run are dropped with the
/// editor.
///
/// # Errors
///
/// Propagates a failure of the rollback itself.
pub async fn rollback(mut self) -> Result<()> {
    match self.executor.take() {
        None => Ok(()),
        Some(tx) => {
            tx.rollback().await?;
            Ok(())
        }
    }
}
/// Returns the editor's effective atomic flag.
///
/// This is the value stored at construction: `true` only when atomic mode was
/// requested and the backend supports transactional DDL.
pub fn is_atomic(&self) -> bool {
self.atomic
}
/// Returns the database backend this editor was created for.
pub fn database_type(&self) -> DatabaseType {
self.db_type
}
/// Returns a shared reference to the underlying database connection.
pub fn connection(&self) -> &DatabaseConnection {
&self.connection
}
/// Turns SQLite foreign key enforcement off (`PRAGMA foreign_keys = OFF`).
///
/// Does nothing and returns `Ok` for any backend other than SQLite.
///
/// NOTE(review): SQLite documents `PRAGMA foreign_keys` as a no-op while a
/// transaction is pending. When this editor holds a transaction, the pragma
/// is sent through it — confirm callers do not rely on it in atomic mode.
///
/// # Errors
///
/// Propagates any error from executing the pragma.
#[cfg(feature = "sqlite")]
pub async fn disable_foreign_keys(&mut self) -> Result<()> {
    match self.db_type {
        DatabaseType::Sqlite => {
            tracing::debug!("Disabling SQLite foreign key checks");
            self.execute("PRAGMA foreign_keys = OFF").await
        }
        DatabaseType::Postgres | DatabaseType::Mysql => Ok(()),
    }
}
/// Turns SQLite foreign key enforcement on (`PRAGMA foreign_keys = ON`).
///
/// Does nothing and returns `Ok` for any backend other than SQLite.
///
/// NOTE(review): SQLite documents `PRAGMA foreign_keys` as a no-op while a
/// transaction is pending. When this editor holds a transaction, the pragma
/// is sent through it — confirm callers do not rely on it in atomic mode.
///
/// # Errors
///
/// Propagates any error from executing the pragma.
#[cfg(feature = "sqlite")]
pub async fn enable_foreign_keys(&mut self) -> Result<()> {
    match self.db_type {
        DatabaseType::Sqlite => {
            tracing::debug!("Enabling SQLite foreign key checks");
            self.execute("PRAGMA foreign_keys = ON").await
        }
        DatabaseType::Postgres | DatabaseType::Mysql => Ok(()),
    }
}
/// Runs `PRAGMA foreign_key_check` and describes each reported violation.
///
/// For backends other than SQLite this returns an empty list without touching
/// the database. Each returned string names the offending table, the row id
/// and the referenced parent table, read from the `table`, `rowid` and
/// `parent` columns. A column that cannot be read falls back to its default
/// (empty string / `0`). A non-empty result is also logged as a warning.
///
/// # Errors
///
/// Propagates any error from running the pragma.
#[cfg(feature = "sqlite")]
pub async fn check_foreign_key_integrity(&mut self) -> Result<Vec<String>> {
    match self.db_type {
        DatabaseType::Sqlite => {}
        DatabaseType::Postgres | DatabaseType::Mysql => return Ok(Vec::new()),
    }

    tracing::debug!("Checking SQLite foreign key integrity");

    let pragma = "PRAGMA foreign_key_check";
    let reported = match self.executor.as_mut() {
        Some(tx) => tx.fetch_all(pragma, vec![]).await?,
        None => self.connection.fetch_all(pragma, vec![]).await?,
    };

    let mut violations: Vec<String> = Vec::new();
    for row in reported {
        let table: String = row.get("table").unwrap_or_default();
        let rowid: i64 = row.get("rowid").unwrap_or_default();
        let parent_table: String = row.get("parent").unwrap_or_default();
        violations.push(format!(
            "FK violation in '{}' row {} referencing '{}'",
            table, rowid, parent_table
        ));
    }

    if !violations.is_empty() {
        tracing::warn!("Foreign key violations found: {:?}", violations);
    }
    Ok(violations)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins which backends report support for transactional DDL.
    #[test]
    fn test_database_type_transactional_ddl() {
        let expectations = [
            (DatabaseType::Postgres, true),
            (DatabaseType::Sqlite, true),
            (DatabaseType::Mysql, false),
        ];

        for (db_type, supported) in expectations {
            assert_eq!(
                db_type.supports_transactional_ddl(),
                supported,
                "unexpected transactional DDL support for {:?}",
                db_type
            );
        }
    }
}