/// Driver-neutral database operations, implemented in this file for SQLite,
/// PostgreSQL and MySQL (each behind its own cargo feature).
///
/// The methods cover three areas:
/// - a seed tracking table with one row per seed set, optionally carrying a `content_hash`;
/// - a companion `<tracking table>_rows` table that records individual rows;
/// - generic row CRUD.
///
/// Every error is reported as a human-readable `String`.
pub trait Database: Send {
/// Creates the seed tracking table `table_name` (keyed by `seed_set`, with an
/// `applied_at` timestamp) if it does not exist yet.
fn ensure_tracking_table(&mut self, table_name: &str) -> Result<(), String>;
/// Returns `true` when the tracking table contains a row for `seed_set`.
fn is_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<bool, String>;
/// Records `seed_set` as applied. Implementations ignore an already-present row.
fn mark_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<(), String>;
/// Deletes the tracking row for `seed_set`, if any.
fn remove_seed_mark(&mut self, table_name: &str, seed_set: &str) -> Result<(), String>;
/// Inserts one row, pairing `columns[i]` with `values[i]` (values are passed as text).
///
/// Returns an identifier for the new row when the driver can provide one.
/// The exact meaning differs per driver: SQLite returns the rowid, PostgreSQL
/// returns `auto_id_column` (or `None` when it is not given), and MySQL returns
/// the last insert id.
fn insert_row(
&mut self,
table: &str,
columns: &[String],
values: &[String],
auto_id_column: Option<&str>,
) -> Result<Option<i64>, String>;
/// Returns `true` if some row of `table` matches every `unique_columns[i] = unique_values[i]`.
/// Implementations return `false` without querying when `unique_columns` is empty.
fn row_exists(
&mut self,
table: &str,
unique_columns: &[String],
unique_values: &[String],
) -> Result<bool, String>;
/// Deletes every row of `table` and returns the number of rows removed.
fn delete_rows(&mut self, table: &str) -> Result<u64, String>;
/// Starts a transaction on this connection.
fn begin_transaction(&mut self) -> Result<(), String>;
/// Commits the transaction opened by `begin_transaction`. A no-op when none is open.
fn commit_transaction(&mut self) -> Result<(), String>;
/// Rolls back the transaction opened by `begin_transaction`. A no-op when none is open.
fn rollback_transaction(&mut self) -> Result<(), String>;
/// Creates database `name` if missing. SQLite returns an error (unsupported).
fn create_database(&mut self, name: &str) -> Result<(), String>;
/// Creates schema `name` if missing. SQLite returns an error; MySQL creates a database.
fn create_schema(&mut self, name: &str) -> Result<(), String>;
/// Checks whether an object exists. `obj_type` is one of "table", "view",
/// "schema" or "database"; any other value (or one a driver cannot check) is an error.
fn object_exists(&mut self, obj_type: &str, name: &str) -> Result<bool, String>;
/// Short driver identifier: "sqlite", "postgres" or "mysql".
fn driver_name(&self) -> &str;
/// Adds the `content_hash` column to an existing tracking table if it is missing.
fn migrate_tracking_table(&mut self, table_name: &str) -> Result<(), String>;
/// Creates `<table_name>_rows`, keyed by (seed_set, table_name, row_key), if missing.
fn ensure_row_tracking_table(&mut self, table_name: &str) -> Result<(), String>;
/// Returns the stored `content_hash` for `seed_set`. Returns `None` when there is
/// no row or the hash is NULL.
fn get_seed_hash(&mut self, table_name: &str, seed_set: &str)
-> Result<Option<String>, String>;
/// Upserts `seed_set` with `hash`, refreshing its `applied_at` timestamp.
fn update_seed_entry(
&mut self,
table_name: &str,
seed_set: &str,
hash: &str,
) -> Result<(), String>;
/// Upserts one tracked row into `<tracking_table>_rows`.
fn store_tracked_row(
&mut self,
tracking_table: &str,
seed_set: &str,
table_name: &str,
row_key: &str,
row_values: &str,
) -> Result<(), String>;
/// Returns `(row_key, row_values)` pairs tracked for `seed_set` / `table_name`.
fn get_tracked_rows(
&mut self,
tracking_table: &str,
seed_set: &str,
table_name: &str,
) -> Result<Vec<(String, String)>, String>;
/// Removes a single tracked row from `<tracking_table>_rows`.
fn delete_tracked_row(
&mut self,
tracking_table: &str,
seed_set: &str,
table_name: &str,
row_key: &str,
) -> Result<(), String>;
/// Removes every tracked row belonging to `seed_set` from `<tracking_table>_rows`.
fn delete_all_tracked_rows(
&mut self,
tracking_table: &str,
seed_set: &str,
) -> Result<(), String>;
/// Updates rows matching all `where_columns[i] = where_values[i]`, setting
/// `set_columns[i] = set_values[i]`. Returns the affected-row count.
fn update_row(
&mut self,
table: &str,
set_columns: &[String],
set_values: &[String],
where_columns: &[String],
where_values: &[String],
) -> Result<u64, String>;
/// Fetches `fetch_columns` (cast to text, NULL → empty string) from the row matching
/// the key pairs. Returns `None` when no row matches or `fetch_columns` is empty.
fn get_row_columns(
&mut self,
table: &str,
key_columns: &[String],
key_values: &[String],
fetch_columns: &[String],
) -> Result<Option<Vec<String>>, String>;
/// Deletes rows matching all key pairs and returns the number removed.
fn delete_row_by_key(
&mut self,
table: &str,
key_columns: &[String],
key_values: &[String],
) -> Result<u64, String>;
}
/// SQLite-backed [`Database`] using a single `rusqlite` connection.
#[cfg(feature = "sqlite")]
pub struct SqliteDb {
/// The open connection; visible crate-wide for direct access.
pub(crate) conn: rusqlite::Connection,
/// Set by `begin_transaction`; cleared after a successful COMMIT or ROLLBACK.
/// Commit/rollback are skipped while this is `false`.
in_transaction: bool,
}
#[cfg(feature = "sqlite")]
impl SqliteDb {
    /// Opens a SQLite database and configures the new connection.
    ///
    /// `url` is either the literal `":memory:"`, which opens a private in-memory
    /// database, or a path passed to `rusqlite::Connection::open`. WAL journaling
    /// and foreign-key enforcement are enabled right after opening.
    ///
    /// # Errors
    /// Returns a message when the database cannot be opened or the pragmas fail.
    pub fn connect(url: &str) -> Result<Self, String> {
        let opened = match url {
            ":memory:" => rusqlite::Connection::open_in_memory(),
            path => rusqlite::Connection::open(path),
        };
        let conn = opened.map_err(|e| format!("opening sqlite database '{}': {}", url, e))?;
        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")
            .map_err(|e| format!("setting sqlite pragmas: {}", e))?;
        Ok(SqliteDb {
            conn,
            in_transaction: false,
        })
    }
}
#[cfg(feature = "sqlite")]
impl Database for SqliteDb {
    /// Creates the seed tracking table (`seed_set` primary key plus an `applied_at`
    /// text timestamp defaulting to `datetime('now')`) unless it already exists.
    fn ensure_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS \"{}\" (\n\
             seed_set TEXT PRIMARY KEY,\n\
             applied_at TEXT NOT NULL DEFAULT (datetime('now'))\n\
             )",
            sanitize_identifier(table_name)
        );
        self.conn
            .execute(&sql, [])
            .map_err(|e| format!("creating tracking table: {}", e))?;
        Ok(())
    }

    /// Returns `true` when the tracking table holds a row for `seed_set`.
    fn is_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<bool, String> {
        let sql = format!(
            "SELECT COUNT(*) FROM \"{}\" WHERE seed_set = ?1",
            sanitize_identifier(table_name)
        );
        let count: i64 = self
            .conn
            .query_row(&sql, [seed_set], |row| row.get(0))
            .map_err(|e| format!("checking seed status: {}", e))?;
        Ok(count > 0)
    }

    /// Records `seed_set` as applied. `INSERT OR IGNORE` makes repeat calls a no-op.
    fn mark_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
        let sql = format!(
            "INSERT OR IGNORE INTO \"{}\" (seed_set) VALUES (?1)",
            sanitize_identifier(table_name)
        );
        self.conn
            .execute(&sql, [seed_set])
            .map_err(|e| format!("marking seed applied: {}", e))?;
        Ok(())
    }

    /// Deletes the tracking row for `seed_set`. Deleting a missing row is not an error.
    fn remove_seed_mark(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
        let sql = format!(
            "DELETE FROM \"{}\" WHERE seed_set = ?1",
            sanitize_identifier(table_name)
        );
        self.conn
            .execute(&sql, [seed_set])
            .map_err(|e| format!("removing seed mark: {}", e))?;
        Ok(())
    }

    /// Inserts one row, binding `values` positionally as `?1..?N` text parameters.
    ///
    /// Always returns `Some(last_insert_rowid())`; `_auto_id_column` is not used
    /// by this driver.
    fn insert_row(
        &mut self,
        table: &str,
        columns: &[String],
        values: &[String],
        _auto_id_column: Option<&str>,
    ) -> Result<Option<i64>, String> {
        let col_list: Vec<String> = columns
            .iter()
            .map(|c| format!("\"{}\"", sanitize_identifier(c)))
            .collect();
        let placeholders: Vec<String> = (1..=values.len()).map(|i| format!("?{}", i)).collect();
        let sql = format!(
            "INSERT INTO \"{}\" ({}) VALUES ({})",
            sanitize_identifier(table),
            col_list.join(", "),
            placeholders.join(", ")
        );
        let params: Vec<&dyn rusqlite::types::ToSql> = values
            .iter()
            .map(|v| v as &dyn rusqlite::types::ToSql)
            .collect();
        self.conn
            .execute(&sql, params.as_slice())
            .map_err(|e| format!("inserting row into '{}': {}", table, e))?;
        Ok(Some(self.conn.last_insert_rowid()))
    }

    /// Returns `true` if a row matches every `unique_columns[i] = unique_values[i]`.
    /// An empty `unique_columns` short-circuits to `false`.
    fn row_exists(
        &mut self,
        table: &str,
        unique_columns: &[String],
        unique_values: &[String],
    ) -> Result<bool, String> {
        if unique_columns.is_empty() {
            return Ok(false);
        }
        let conditions: Vec<String> = unique_columns
            .iter()
            .enumerate()
            .map(|(i, c)| format!("\"{}\" = ?{}", sanitize_identifier(c), i + 1))
            .collect();
        let sql = format!(
            "SELECT COUNT(*) FROM \"{}\" WHERE {}",
            sanitize_identifier(table),
            conditions.join(" AND ")
        );
        let params: Vec<&dyn rusqlite::types::ToSql> = unique_values
            .iter()
            .map(|v| v as &dyn rusqlite::types::ToSql)
            .collect();
        let count: i64 = self
            .conn
            .query_row(&sql, params.as_slice(), |row| row.get(0))
            .map_err(|e| format!("checking row existence in '{}': {}", table, e))?;
        Ok(count > 0)
    }

    /// Deletes every row in `table` and returns how many were removed.
    fn delete_rows(&mut self, table: &str) -> Result<u64, String> {
        let sql = format!("DELETE FROM \"{}\"", sanitize_identifier(table));
        let count = self
            .conn
            .execute(&sql, [])
            .map_err(|e| format!("deleting rows from '{}': {}", table, e))?;
        Ok(count as u64)
    }

    /// Issues `BEGIN` and remembers that a transaction is open.
    fn begin_transaction(&mut self) -> Result<(), String> {
        self.conn
            .execute("BEGIN", [])
            .map_err(|e| format!("beginning transaction: {}", e))?;
        self.in_transaction = true;
        Ok(())
    }

    /// Issues `COMMIT` if `begin_transaction` opened a transaction; otherwise a no-op.
    fn commit_transaction(&mut self) -> Result<(), String> {
        if self.in_transaction {
            self.conn
                .execute("COMMIT", [])
                .map_err(|e| format!("committing transaction: {}", e))?;
            self.in_transaction = false;
        }
        Ok(())
    }

    /// Issues `ROLLBACK` if `begin_transaction` opened a transaction; otherwise a no-op.
    fn rollback_transaction(&mut self) -> Result<(), String> {
        if self.in_transaction {
            self.conn
                .execute("ROLLBACK", [])
                .map_err(|e| format!("rolling back transaction: {}", e))?;
            self.in_transaction = false;
        }
        Ok(())
    }

    /// Unsupported: each SQLite file is its own database.
    fn create_database(&mut self, _name: &str) -> Result<(), String> {
        Err("sqlite does not support CREATE DATABASE (each file is a database)".into())
    }

    /// Unsupported: SQLite has no schemas.
    fn create_schema(&mut self, _name: &str) -> Result<(), String> {
        Err("sqlite does not support schemas".into())
    }

    /// Looks up tables and views in `sqlite_master`. Schemas, databases and any
    /// other object type are reported as errors.
    fn object_exists(&mut self, obj_type: &str, name: &str) -> Result<bool, String> {
        match obj_type {
            // `obj_type` matches sqlite_master's `type` column verbatim for both kinds.
            "table" | "view" => {
                let count: i64 = self
                    .conn
                    .query_row(
                        "SELECT COUNT(*) FROM sqlite_master WHERE type=?1 AND name=?2",
                        [obj_type, name],
                        |row| row.get(0),
                    )
                    .map_err(|e| format!("checking {} existence: {}", obj_type, e))?;
                Ok(count > 0)
            }
            "schema" => Err("sqlite does not support schemas".into()),
            "database" => Err("sqlite does not support checking database existence".into()),
            _ => Err(format!("unsupported object type '{}' for sqlite", obj_type)),
        }
    }

    fn driver_name(&self) -> &str {
        "sqlite"
    }

    /// Adds a `content_hash TEXT` column to the tracking table if it lacks one.
    fn migrate_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        let sql = format!("PRAGMA table_info(\"{}\")", safe);
        let mut stmt = self
            .conn
            .prepare(&sql)
            .map_err(|e| format!("checking tracking table schema: {}", e))?;
        // Column 1 of `PRAGMA table_info` is the column name. A row-read failure is
        // propagated: treating it as "column missing" would trigger a bogus ALTER.
        let column_names = stmt
            .query_map([], |row| row.get::<_, String>(1))
            .map_err(|e| format!("reading tracking table schema: {}", e))?
            .collect::<Result<Vec<String>, _>>()
            .map_err(|e| format!("reading tracking table schema: {}", e))?;
        // Finalize the PRAGMA statement before altering the table it inspected.
        drop(stmt);
        if !column_names.iter().any(|n| n == "content_hash") {
            let alter = format!("ALTER TABLE \"{}\" ADD COLUMN content_hash TEXT", safe);
            self.conn
                .execute(&alter, [])
                .map_err(|e| format!("migrating tracking table: {}", e))?;
        }
        Ok(())
    }

    /// Creates `<table_name>_rows`, keyed by (seed_set, table_name, row_key).
    fn ensure_row_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS \"{}_rows\" (\n\
             seed_set TEXT NOT NULL,\n\
             table_name TEXT NOT NULL,\n\
             row_key TEXT NOT NULL,\n\
             row_values TEXT NOT NULL,\n\
             applied_at TEXT NOT NULL DEFAULT (datetime('now')),\n\
             PRIMARY KEY (seed_set, table_name, row_key)\n\
             )",
            safe
        );
        self.conn
            .execute(&sql, [])
            .map_err(|e| format!("creating row tracking table: {}", e))?;
        Ok(())
    }

    /// Returns the stored hash for `seed_set`. Returns `None` when the row is
    /// absent or the hash is NULL.
    fn get_seed_hash(
        &mut self,
        table_name: &str,
        seed_set: &str,
    ) -> Result<Option<String>, String> {
        let sql = format!(
            "SELECT content_hash FROM \"{}\" WHERE seed_set = ?1",
            sanitize_identifier(table_name)
        );
        match self
            .conn
            .query_row(&sql, [seed_set], |row| row.get::<_, Option<String>>(0))
        {
            Ok(hash) => Ok(hash),
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(format!("getting seed hash: {}", e)),
        }
    }

    /// Upserts `seed_set` with `hash`, refreshing `applied_at` on conflict.
    fn update_seed_entry(
        &mut self,
        table_name: &str,
        seed_set: &str,
        hash: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        let sql = format!(
            "INSERT INTO \"{}\" (seed_set, content_hash) VALUES (?1, ?2) \
             ON CONFLICT(seed_set) DO UPDATE SET content_hash = ?2, applied_at = datetime('now')",
            safe
        );
        self.conn
            .execute(&sql, [seed_set, hash])
            .map_err(|e| format!("updating seed entry: {}", e))?;
        Ok(())
    }

    /// Upserts one tracked row, replacing `row_values` and `applied_at` on conflict.
    fn store_tracked_row(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
        row_key: &str,
        row_values: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "INSERT INTO \"{}_rows\" (seed_set, table_name, row_key, row_values) VALUES (?1, ?2, ?3, ?4) \
             ON CONFLICT(seed_set, table_name, row_key) DO UPDATE SET row_values = ?4, applied_at = datetime('now')",
            safe
        );
        self.conn
            .execute(&sql, [seed_set, table_name, row_key, row_values])
            .map_err(|e| format!("storing tracked row: {}", e))?;
        Ok(())
    }

    /// Returns `(row_key, row_values)` for every row tracked under `seed_set`/`table_name`.
    fn get_tracked_rows(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
    ) -> Result<Vec<(String, String)>, String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "SELECT row_key, row_values FROM \"{}_rows\" WHERE seed_set = ?1 AND table_name = ?2",
            safe
        );
        let mut stmt = self
            .conn
            .prepare(&sql)
            .map_err(|e| format!("preparing tracked rows query: {}", e))?;
        let rows = stmt
            .query_map([seed_set, table_name], |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
            })
            .map_err(|e| format!("querying tracked rows: {}", e))?
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| format!("reading tracked rows: {}", e))?;
        Ok(rows)
    }

    /// Removes one tracked row identified by (seed_set, table_name, row_key).
    fn delete_tracked_row(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
        row_key: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "DELETE FROM \"{}_rows\" WHERE seed_set = ?1 AND table_name = ?2 AND row_key = ?3",
            safe
        );
        self.conn
            .execute(&sql, [seed_set, table_name, row_key])
            .map_err(|e| format!("deleting tracked row: {}", e))?;
        Ok(())
    }

    /// Removes every tracked row belonging to `seed_set`.
    fn delete_all_tracked_rows(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!("DELETE FROM \"{}_rows\" WHERE seed_set = ?1", safe);
        self.conn
            .execute(&sql, [seed_set])
            .map_err(|e| format!("deleting all tracked rows: {}", e))?;
        Ok(())
    }

    /// Runs `UPDATE ... SET set_columns WHERE where_columns` and returns the affected count.
    ///
    /// SET placeholders are numbered `?1..?k` and WHERE placeholders continue from
    /// `set_values.len() + 1`, matching the order of the chained parameter list.
    fn update_row(
        &mut self,
        table: &str,
        set_columns: &[String],
        set_values: &[String],
        where_columns: &[String],
        where_values: &[String],
    ) -> Result<u64, String> {
        let set_clause: Vec<String> = set_columns
            .iter()
            .enumerate()
            .map(|(i, c)| format!("\"{}\" = ?{}", sanitize_identifier(c), i + 1))
            .collect();
        let where_clause: Vec<String> = where_columns
            .iter()
            .enumerate()
            .map(|(i, c)| {
                format!(
                    "\"{}\" = ?{}",
                    sanitize_identifier(c),
                    set_values.len() + i + 1
                )
            })
            .collect();
        let sql = format!(
            "UPDATE \"{}\" SET {} WHERE {}",
            sanitize_identifier(table),
            set_clause.join(", "),
            where_clause.join(" AND ")
        );
        let all_values: Vec<&dyn rusqlite::types::ToSql> = set_values
            .iter()
            .chain(where_values)
            .map(|v| v as &dyn rusqlite::types::ToSql)
            .collect();
        let count = self
            .conn
            .execute(&sql, all_values.as_slice())
            .map_err(|e| format!("updating row in '{}': {}", table, e))?;
        Ok(count as u64)
    }

    /// Reads `fetch_columns` (as TEXT; NULL becomes "") from the row matching the key.
    /// Returns `None` when no row matches or `fetch_columns` is empty.
    fn get_row_columns(
        &mut self,
        table: &str,
        key_columns: &[String],
        key_values: &[String],
        fetch_columns: &[String],
    ) -> Result<Option<Vec<String>>, String> {
        if fetch_columns.is_empty() {
            return Ok(None);
        }
        let select_cols: Vec<String> = fetch_columns
            .iter()
            .map(|c| format!("CAST(\"{}\" AS TEXT)", sanitize_identifier(c)))
            .collect();
        let where_clause: Vec<String> = key_columns
            .iter()
            .enumerate()
            .map(|(i, c)| format!("\"{}\" = ?{}", sanitize_identifier(c), i + 1))
            .collect();
        let sql = format!(
            "SELECT {} FROM \"{}\" WHERE {}",
            select_cols.join(", "),
            sanitize_identifier(table),
            where_clause.join(" AND ")
        );
        let params: Vec<&dyn rusqlite::types::ToSql> = key_values
            .iter()
            .map(|v| v as &dyn rusqlite::types::ToSql)
            .collect();
        let fetched = self.conn.query_row(&sql, params.as_slice(), |row| {
            (0..fetch_columns.len())
                .map(|i| row.get::<_, Option<String>>(i).map(Option::unwrap_or_default))
                .collect::<Result<Vec<String>, _>>()
        });
        match fetched {
            Ok(vals) => Ok(Some(vals)),
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            Err(e) => Err(format!("getting row from '{}': {}", table, e)),
        }
    }

    /// Deletes rows matching every `key_columns[i] = key_values[i]`; returns the count.
    fn delete_row_by_key(
        &mut self,
        table: &str,
        key_columns: &[String],
        key_values: &[String],
    ) -> Result<u64, String> {
        let where_clause: Vec<String> = key_columns
            .iter()
            .enumerate()
            .map(|(i, c)| format!("\"{}\" = ?{}", sanitize_identifier(c), i + 1))
            .collect();
        let sql = format!(
            "DELETE FROM \"{}\" WHERE {}",
            sanitize_identifier(table),
            where_clause.join(" AND ")
        );
        let params: Vec<&dyn rusqlite::types::ToSql> = key_values
            .iter()
            .map(|v| v as &dyn rusqlite::types::ToSql)
            .collect();
        let count = self
            .conn
            .execute(&sql, params.as_slice())
            .map_err(|e| format!("deleting row from '{}': {}", table, e))?;
        Ok(count as u64)
    }
}
/// PostgreSQL-backed [`Database`] using a synchronous `postgres::Client`.
#[cfg(feature = "postgres")]
pub struct PostgresDb {
/// The connected client (created without TLS in `connect`).
client: postgres::Client,
/// Set by `begin_transaction`; cleared after a successful COMMIT or ROLLBACK.
in_transaction: bool,
}
#[cfg(feature = "postgres")]
impl PostgresDb {
    /// Connects to PostgreSQL at `url` without TLS (`postgres::NoTls`).
    ///
    /// # Errors
    /// Returns a message describing the connection failure.
    pub fn connect(url: &str) -> Result<Self, String> {
        match postgres::Client::connect(url, postgres::NoTls) {
            Ok(client) => Ok(PostgresDb {
                client,
                in_transaction: false,
            }),
            Err(e) => Err(format!("connecting to postgres: {}", e)),
        }
    }
}
#[cfg(feature = "postgres")]
impl Database for PostgresDb {
    /// Creates the seed tracking table (`seed_set` primary key plus a TIMESTAMPTZ
    /// `applied_at` defaulting to `NOW()`) unless it already exists.
    fn ensure_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS \"{}\" (\n\
             seed_set TEXT PRIMARY KEY,\n\
             applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()\n\
             )",
            sanitize_identifier(table_name)
        );
        self.client
            .execute(&sql, &[])
            .map_err(|e| format!("creating tracking table: {}", e))?;
        Ok(())
    }

    /// Returns `true` when the tracking table holds a row for `seed_set`.
    fn is_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<bool, String> {
        let sql = format!(
            "SELECT COUNT(*) FROM \"{}\" WHERE seed_set = $1",
            sanitize_identifier(table_name)
        );
        let row = self
            .client
            .query_one(&sql, &[&seed_set])
            .map_err(|e| format!("checking seed status: {}", e))?;
        let count: i64 = row.get(0);
        Ok(count > 0)
    }

    /// Records `seed_set` as applied. `ON CONFLICT DO NOTHING` makes repeat calls a no-op.
    fn mark_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
        let sql = format!(
            "INSERT INTO \"{}\" (seed_set) VALUES ($1) ON CONFLICT DO NOTHING",
            sanitize_identifier(table_name)
        );
        self.client
            .execute(&sql, &[&seed_set])
            .map_err(|e| format!("marking seed applied: {}", e))?;
        Ok(())
    }

    /// Deletes the tracking row for `seed_set`, if present.
    fn remove_seed_mark(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
        let sql = format!(
            "DELETE FROM \"{}\" WHERE seed_set = $1",
            sanitize_identifier(table_name)
        );
        self.client
            .execute(&sql, &[&seed_set])
            .map_err(|e| format!("removing seed mark: {}", e))?;
        Ok(())
    }

    /// Inserts one row.
    ///
    /// Values are rendered inline with `escape_sql_value`, not bound as parameters.
    /// Presumably this lets the server coerce the text to each column's type,
    /// since a `&str` parameter would be rejected for non-text columns; that is
    /// an assumption, to be confirmed against `escape_sql_value`.
    ///
    /// With `auto_id_column`, the inserted value of that column is returned via
    /// `RETURNING ... CAST AS BIGINT` (NULL → 0). Without it, returns `None`.
    fn insert_row(
        &mut self,
        table: &str,
        columns: &[String],
        values: &[String],
        auto_id_column: Option<&str>,
    ) -> Result<Option<i64>, String> {
        let col_list: Vec<String> = columns
            .iter()
            .map(|c| format!("\"{}\"", sanitize_identifier(c)))
            .collect();
        let value_list: Vec<String> = values.iter().map(|v| escape_sql_value(v)).collect();
        if let Some(auto_col) = auto_id_column {
            let returning_col = sanitize_identifier(auto_col);
            let sql = format!(
                "INSERT INTO \"{}\" ({}) VALUES ({}) RETURNING COALESCE(CAST(\"{}\" AS BIGINT), 0)",
                sanitize_identifier(table),
                col_list.join(", "),
                value_list.join(", "),
                returning_col
            );
            let row = self
                .client
                .query_one(&sql, &[])
                .map_err(|e| format!("inserting row into '{}': {}", table, e))?;
            let id: i64 = row.get(0);
            Ok(Some(id))
        } else {
            let sql = format!(
                "INSERT INTO \"{}\" ({}) VALUES ({})",
                sanitize_identifier(table),
                col_list.join(", "),
                value_list.join(", "),
            );
            self.client
                .execute(&sql, &[])
                .map_err(|e| format!("inserting row into '{}': {}", table, e))?;
            Ok(None)
        }
    }

    /// Returns `true` if a row matches every column/value pair. Pairs are formed with
    /// `zip`, so surplus entries in the longer slice are ignored. An empty
    /// `unique_columns` short-circuits to `false`.
    fn row_exists(
        &mut self,
        table: &str,
        unique_columns: &[String],
        unique_values: &[String],
    ) -> Result<bool, String> {
        if unique_columns.is_empty() {
            return Ok(false);
        }
        let conditions: Vec<String> = unique_columns
            .iter()
            .zip(unique_values.iter())
            .map(|(c, v)| format!("\"{}\" = {}", sanitize_identifier(c), escape_sql_value(v)))
            .collect();
        let sql = format!(
            "SELECT COUNT(*) FROM \"{}\" WHERE {}",
            sanitize_identifier(table),
            conditions.join(" AND ")
        );
        let row = self
            .client
            .query_one(&sql, &[])
            .map_err(|e| format!("checking row existence in '{}': {}", table, e))?;
        let count: i64 = row.get(0);
        Ok(count > 0)
    }

    /// Deletes every row in `table` and returns how many were removed.
    fn delete_rows(&mut self, table: &str) -> Result<u64, String> {
        let sql = format!("DELETE FROM \"{}\"", sanitize_identifier(table));
        let count = self
            .client
            .execute(&sql, &[])
            .map_err(|e| format!("deleting rows from '{}': {}", table, e))?;
        Ok(count)
    }

    /// Issues `BEGIN` and remembers that a transaction is open.
    fn begin_transaction(&mut self) -> Result<(), String> {
        self.client
            .execute("BEGIN", &[])
            .map_err(|e| format!("beginning transaction: {}", e))?;
        self.in_transaction = true;
        Ok(())
    }

    /// Issues `COMMIT` if `begin_transaction` opened a transaction; otherwise a no-op.
    fn commit_transaction(&mut self) -> Result<(), String> {
        if self.in_transaction {
            self.client
                .execute("COMMIT", &[])
                .map_err(|e| format!("committing transaction: {}", e))?;
            self.in_transaction = false;
        }
        Ok(())
    }

    /// Issues `ROLLBACK` if `begin_transaction` opened a transaction; otherwise a no-op.
    fn rollback_transaction(&mut self) -> Result<(), String> {
        if self.in_transaction {
            self.client
                .execute("ROLLBACK", &[])
                .map_err(|e| format!("rolling back transaction: {}", e))?;
            self.in_transaction = false;
        }
        Ok(())
    }

    /// Creates database `name` unless `pg_database` already lists it.
    /// PostgreSQL has no `CREATE DATABASE IF NOT EXISTS`, hence the explicit check.
    fn create_database(&mut self, name: &str) -> Result<(), String> {
        let safe = sanitize_identifier(name);
        let row = self
            .client
            .query_one(
                "SELECT COUNT(*) FROM pg_database WHERE datname = $1",
                &[&safe],
            )
            .map_err(|e| format!("checking database existence: {}", e))?;
        let count: i64 = row.get(0);
        if count == 0 {
            let sql = format!("CREATE DATABASE \"{}\"", safe);
            self.client
                .execute(&sql, &[])
                .map_err(|e| format!("creating database '{}': {}", name, e))?;
        }
        Ok(())
    }

    /// Creates schema `name` if it does not exist.
    fn create_schema(&mut self, name: &str) -> Result<(), String> {
        let sql = format!(
            "CREATE SCHEMA IF NOT EXISTS \"{}\"",
            sanitize_identifier(name)
        );
        self.client
            .execute(&sql, &[])
            .map_err(|e| format!("creating schema '{}': {}", name, e))?;
        Ok(())
    }

    /// Checks existence via `information_schema` (tables, views, schemas) or
    /// `pg_database`. Table and view lookups are not restricted to a schema.
    fn object_exists(&mut self, obj_type: &str, name: &str) -> Result<bool, String> {
        let sql = match obj_type {
            "table" => "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1",
            "view" => "SELECT COUNT(*) FROM information_schema.views WHERE table_name = $1",
            "schema" => "SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = $1",
            "database" => "SELECT COUNT(*) FROM pg_database WHERE datname = $1",
            _ => {
                return Err(format!(
                    "unsupported object type '{}' for postgres",
                    obj_type
                ))
            }
        };
        let row = self
            .client
            .query_one(sql, &[&name])
            .map_err(|e| format!("checking {} existence: {}", obj_type, e))?;
        let count: i64 = row.get(0);
        Ok(count > 0)
    }

    fn driver_name(&self) -> &str {
        "postgres"
    }

    /// Adds a `content_hash TEXT` column to the tracking table if it lacks one.
    fn migrate_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        // Resolve the unqualified name through the search_path, exactly as the
        // ALTER TABLE below will. Matching `information_schema.columns.table_name`
        // alone would also see same-named tables in other schemas. The name is
        // bound as a parameter rather than spliced into a SQL string literal.
        // A missing table gives `to_regclass` = NULL, so the count is 0 and the
        // ALTER TABLE reports the missing relation.
        let row = self
            .client
            .query_one(
                "SELECT COUNT(*) FROM pg_attribute \
                 WHERE attrelid = to_regclass(quote_ident($1)) \
                 AND attname = 'content_hash' AND NOT attisdropped",
                &[&safe],
            )
            .map_err(|e| format!("checking tracking table schema: {}", e))?;
        let count: i64 = row.get(0);
        if count == 0 {
            let alter = format!("ALTER TABLE \"{}\" ADD COLUMN content_hash TEXT", safe);
            self.client
                .execute(&alter, &[])
                .map_err(|e| format!("migrating tracking table: {}", e))?;
        }
        Ok(())
    }

    /// Creates `<table_name>_rows`, keyed by (seed_set, table_name, row_key).
    fn ensure_row_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS \"{}_rows\" (\n\
             seed_set TEXT NOT NULL,\n\
             table_name TEXT NOT NULL,\n\
             row_key TEXT NOT NULL,\n\
             row_values TEXT NOT NULL,\n\
             applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),\n\
             PRIMARY KEY (seed_set, table_name, row_key)\n\
             )",
            safe
        );
        self.client
            .execute(&sql, &[])
            .map_err(|e| format!("creating row tracking table: {}", e))?;
        Ok(())
    }

    /// Returns the stored hash for `seed_set`. Returns `None` when the row is
    /// absent or the hash is NULL.
    fn get_seed_hash(
        &mut self,
        table_name: &str,
        seed_set: &str,
    ) -> Result<Option<String>, String> {
        let sql = format!(
            "SELECT content_hash FROM \"{}\" WHERE seed_set = $1",
            sanitize_identifier(table_name)
        );
        let rows = self
            .client
            .query(&sql, &[&seed_set])
            .map_err(|e| format!("getting seed hash: {}", e))?;
        Ok(rows
            .first()
            .and_then(|row| row.get::<_, Option<String>>(0)))
    }

    /// Upserts `seed_set` with `hash`, refreshing `applied_at` on conflict.
    fn update_seed_entry(
        &mut self,
        table_name: &str,
        seed_set: &str,
        hash: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(table_name);
        let sql = format!(
            "INSERT INTO \"{}\" (seed_set, content_hash) VALUES ($1, $2) \
             ON CONFLICT(seed_set) DO UPDATE SET content_hash = $2, applied_at = NOW()",
            safe
        );
        self.client
            .execute(&sql, &[&seed_set, &hash])
            .map_err(|e| format!("updating seed entry: {}", e))?;
        Ok(())
    }

    /// Upserts one tracked row, replacing `row_values` and `applied_at` on conflict.
    fn store_tracked_row(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
        row_key: &str,
        row_values: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "INSERT INTO \"{}_rows\" (seed_set, table_name, row_key, row_values) VALUES ($1, $2, $3, $4) \
             ON CONFLICT(seed_set, table_name, row_key) DO UPDATE SET row_values = $4, applied_at = NOW()",
            safe
        );
        self.client
            .execute(&sql, &[&seed_set, &table_name, &row_key, &row_values])
            .map_err(|e| format!("storing tracked row: {}", e))?;
        Ok(())
    }

    /// Returns `(row_key, row_values)` for every row tracked under `seed_set`/`table_name`.
    fn get_tracked_rows(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
    ) -> Result<Vec<(String, String)>, String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "SELECT row_key, row_values FROM \"{}_rows\" WHERE seed_set = $1 AND table_name = $2",
            safe
        );
        let rows = self
            .client
            .query(&sql, &[&seed_set, &table_name])
            .map_err(|e| format!("querying tracked rows: {}", e))?;
        Ok(rows
            .iter()
            .map(|r| (r.get::<_, String>(0), r.get::<_, String>(1)))
            .collect())
    }

    /// Removes one tracked row identified by (seed_set, table_name, row_key).
    fn delete_tracked_row(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
        table_name: &str,
        row_key: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!(
            "DELETE FROM \"{}_rows\" WHERE seed_set = $1 AND table_name = $2 AND row_key = $3",
            safe
        );
        self.client
            .execute(&sql, &[&seed_set, &table_name, &row_key])
            .map_err(|e| format!("deleting tracked row: {}", e))?;
        Ok(())
    }

    /// Removes every tracked row belonging to `seed_set`.
    fn delete_all_tracked_rows(
        &mut self,
        tracking_table: &str,
        seed_set: &str,
    ) -> Result<(), String> {
        let safe = sanitize_identifier(tracking_table);
        let sql = format!("DELETE FROM \"{}_rows\" WHERE seed_set = $1", safe);
        self.client
            .execute(&sql, &[&seed_set])
            .map_err(|e| format!("deleting all tracked rows: {}", e))?;
        Ok(())
    }

    /// Runs an UPDATE with inline-escaped values and returns the affected count.
    /// Column/value pairs are formed with `zip` (surplus entries are ignored).
    fn update_row(
        &mut self,
        table: &str,
        set_columns: &[String],
        set_values: &[String],
        where_columns: &[String],
        where_values: &[String],
    ) -> Result<u64, String> {
        let set_clause: Vec<String> = set_columns
            .iter()
            .zip(set_values.iter())
            .map(|(c, v)| format!("\"{}\" = {}", sanitize_identifier(c), escape_sql_value(v)))
            .collect();
        let where_clause: Vec<String> = where_columns
            .iter()
            .zip(where_values.iter())
            .map(|(c, v)| format!("\"{}\" = {}", sanitize_identifier(c), escape_sql_value(v)))
            .collect();
        let sql = format!(
            "UPDATE \"{}\" SET {} WHERE {}",
            sanitize_identifier(table),
            set_clause.join(", "),
            where_clause.join(" AND ")
        );
        let count = self
            .client
            .execute(&sql, &[])
            .map_err(|e| format!("updating row in '{}': {}", table, e))?;
        Ok(count)
    }

    /// Reads `fetch_columns` (as TEXT; NULL becomes "") from the first row matching
    /// the key. Returns `None` when no row matches or `fetch_columns` is empty.
    fn get_row_columns(
        &mut self,
        table: &str,
        key_columns: &[String],
        key_values: &[String],
        fetch_columns: &[String],
    ) -> Result<Option<Vec<String>>, String> {
        if fetch_columns.is_empty() {
            return Ok(None);
        }
        let select_cols: Vec<String> = fetch_columns
            .iter()
            .map(|c| format!("CAST(\"{}\" AS TEXT)", sanitize_identifier(c)))
            .collect();
        let where_clause: Vec<String> = key_columns
            .iter()
            .zip(key_values.iter())
            .map(|(c, v)| format!("\"{}\" = {}", sanitize_identifier(c), escape_sql_value(v)))
            .collect();
        let sql = format!(
            "SELECT {} FROM \"{}\" WHERE {}",
            select_cols.join(", "),
            sanitize_identifier(table),
            where_clause.join(" AND ")
        );
        let rows = self
            .client
            .query(&sql, &[])
            .map_err(|e| format!("getting row from '{}': {}", table, e))?;
        Ok(rows.first().map(|row| {
            (0..fetch_columns.len())
                .map(|i| row.get::<_, Option<String>>(i).unwrap_or_default())
                .collect()
        }))
    }

    /// Deletes rows matching the zipped key pairs and returns the count removed.
    fn delete_row_by_key(
        &mut self,
        table: &str,
        key_columns: &[String],
        key_values: &[String],
    ) -> Result<u64, String> {
        let where_clause: Vec<String> = key_columns
            .iter()
            .zip(key_values.iter())
            .map(|(c, v)| format!("\"{}\" = {}", sanitize_identifier(c), escape_sql_value(v)))
            .collect();
        let sql = format!(
            "DELETE FROM \"{}\" WHERE {}",
            sanitize_identifier(table),
            where_clause.join(" AND ")
        );
        let count = self
            .client
            .execute(&sql, &[])
            .map_err(|e| format!("deleting row from '{}': {}", table, e))?;
        Ok(count)
    }
}
/// MySQL-backed [`Database`] holding one connection checked out of a `mysql::Pool`.
#[cfg(feature = "mysql")]
pub struct MysqlDb {
/// Connection obtained from the pool in `connect`.
conn: mysql::PooledConn,
/// Set by `begin_transaction`; cleared after a successful COMMIT or ROLLBACK.
in_transaction: bool,
}
#[cfg(feature = "mysql")]
impl MysqlDb {
    /// Builds a pool for `url` and checks out a single connection from it.
    ///
    /// The pool itself is not stored; only the checked-out connection is kept.
    ///
    /// # Errors
    /// Returns a message if the pool cannot be created or no connection is available.
    pub fn connect(url: &str) -> Result<Self, String> {
        let conn = mysql::Pool::new(url)
            .map_err(|e| format!("connecting to mysql: {}", e))?
            .get_conn()
            .map_err(|e| format!("getting mysql connection: {}", e))?;
        Ok(MysqlDb {
            conn,
            in_transaction: false,
        })
    }
}
#[cfg(feature = "mysql")]
impl Database for MysqlDb {
/// Creates the seed tracking table (`seed_set` primary key plus an `applied_at`
/// timestamp defaulting to `CURRENT_TIMESTAMP`) unless it already exists.
fn ensure_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let create_sql = format!(
        "CREATE TABLE IF NOT EXISTS `{}` (\n\
         seed_set VARCHAR(255) PRIMARY KEY,\n\
         applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP\n\
         )",
        sanitize_identifier(table_name)
    );
    self.conn
        .query_drop(&create_sql)
        .map_err(|e| format!("creating tracking table: {}", e))
}
/// Returns `true` when the tracking table holds a row for `seed_set`.
/// A missing result row is treated as a count of zero.
fn is_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<bool, String> {
    use mysql::prelude::Queryable;
    let query = format!(
        "SELECT COUNT(*) FROM `{}` WHERE seed_set = ?",
        sanitize_identifier(table_name)
    );
    let first: Option<i64> = self
        .conn
        .exec_first(&query, (seed_set,))
        .map_err(|e| format!("checking seed status: {}", e))?;
    Ok(matches!(first, Some(n) if n > 0))
}
/// Records `seed_set` as applied. `INSERT IGNORE` leaves an existing row untouched.
fn mark_seed_applied(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let table = sanitize_identifier(table_name);
    let statement = format!("INSERT IGNORE INTO `{}` (seed_set) VALUES (?)", table);
    self.conn
        .exec_drop(&statement, (seed_set,))
        .map_err(|e| format!("marking seed applied: {}", e))
}
/// Deletes the tracking row for `seed_set`, if present.
fn remove_seed_mark(&mut self, table_name: &str, seed_set: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let statement = format!(
        "DELETE FROM `{}` WHERE seed_set = ?",
        sanitize_identifier(table_name)
    );
    match self.conn.exec_drop(&statement, (seed_set,)) {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("removing seed mark: {}", e)),
    }
}
/// Inserts one row, binding `values` positionally as text parameters.
///
/// Returns `Some(id)`, where `id` is the AUTO_INCREMENT value reported in the
/// INSERT's own OK packet. It is 0 when the insert generated no such value.
/// `_auto_id_column` is not used by this driver.
fn insert_row(
    &mut self,
    table: &str,
    columns: &[String],
    values: &[String],
    _auto_id_column: Option<&str>,
) -> Result<Option<i64>, String> {
    use mysql::prelude::Queryable;
    use std::convert::TryFrom;
    let col_list: Vec<String> = columns
        .iter()
        .map(|c| format!("`{}`", sanitize_identifier(c)))
        .collect();
    let placeholders = vec!["?"; columns.len()];
    let sql = format!(
        "INSERT INTO `{}` ({}) VALUES ({})",
        sanitize_identifier(table),
        col_list.join(", "),
        placeholders.join(", ")
    );
    let params: Vec<mysql::Value> = values
        .iter()
        .map(|v| mysql::Value::from(v.as_str()))
        .collect();
    self.conn
        .exec_drop(&sql, params)
        .map_err(|e| format!("inserting row into '{}': {}", table, e))?;
    // Read the id from this INSERT's OK packet instead of `SELECT LAST_INSERT_ID()`.
    // The session value is not reset by inserts that generate no AUTO_INCREMENT
    // value, so it could return an id from an earlier insert into another table.
    let raw_id = self.conn.last_insert_id();
    let id = i64::try_from(raw_id)
        .map_err(|_| format!("last insert id {} for '{}' does not fit in i64", raw_id, table))?;
    Ok(Some(id))
}
/// Returns `true` if a row matches every `unique_columns[i] = unique_values[i]`
/// (values bound as text parameters). An empty `unique_columns` returns `false`.
fn row_exists(
    &mut self,
    table: &str,
    unique_columns: &[String],
    unique_values: &[String],
) -> Result<bool, String> {
    use mysql::prelude::Queryable;
    if unique_columns.is_empty() {
        return Ok(false);
    }
    let conditions: Vec<String> = unique_columns
        .iter()
        .map(|c| format!("`{}` = ?", sanitize_identifier(c)))
        .collect();
    let sql = format!(
        "SELECT COUNT(*) FROM `{}` WHERE {}",
        sanitize_identifier(table),
        conditions.join(" AND ")
    );
    let params: Vec<mysql::Value> = unique_values
        .iter()
        .map(|v| mysql::Value::from(v.as_str()))
        .collect();
    // The parameter vector is passed by value; it is not needed afterwards.
    let count: Option<i64> = self
        .conn
        .exec_first(&sql, params)
        .map_err(|e| format!("checking row existence in '{}': {}", table, e))?;
    Ok(count.unwrap_or(0) > 0)
}
/// Deletes every row in `table` and returns the affected-row count.
fn delete_rows(&mut self, table: &str) -> Result<u64, String> {
    use mysql::prelude::Queryable;
    let sql = format!("DELETE FROM `{}`", sanitize_identifier(table));
    self.conn
        .query_drop(&sql)
        .map_err(|e| format!("deleting rows from '{}': {}", table, e))?;
    // The DELETE's OK packet already carries the affected-row count. Reading it
    // avoids a second prepared `SELECT ROW_COUNT()` round trip.
    Ok(self.conn.affected_rows())
}
/// Issues `START TRANSACTION` and remembers that a transaction is open.
fn begin_transaction(&mut self) -> Result<(), String> {
    use mysql::prelude::Queryable;
    if let Err(e) = self.conn.query_drop("START TRANSACTION") {
        return Err(format!("beginning transaction: {}", e));
    }
    self.in_transaction = true;
    Ok(())
}
/// Issues `COMMIT` if `begin_transaction` opened a transaction; otherwise a no-op.
fn commit_transaction(&mut self) -> Result<(), String> {
    use mysql::prelude::Queryable;
    if !self.in_transaction {
        return Ok(());
    }
    self.conn
        .query_drop("COMMIT")
        .map_err(|e| format!("committing transaction: {}", e))?;
    // Only cleared on success, so a failed COMMIT can still be rolled back.
    self.in_transaction = false;
    Ok(())
}
/// Issues `ROLLBACK` if `begin_transaction` opened a transaction; otherwise a no-op.
fn rollback_transaction(&mut self) -> Result<(), String> {
    use mysql::prelude::Queryable;
    if !self.in_transaction {
        return Ok(());
    }
    self.conn
        .query_drop("ROLLBACK")
        .map_err(|e| format!("rolling back transaction: {}", e))?;
    self.in_transaction = false;
    Ok(())
}
/// Creates database `name` with `CREATE DATABASE IF NOT EXISTS`.
fn create_database(&mut self, name: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let ddl = format!(
        "CREATE DATABASE IF NOT EXISTS `{}`",
        sanitize_identifier(name)
    );
    self.conn
        .query_drop(&ddl)
        .map_err(|err| format!("creating database '{}': {}", name, err))
}
/// In MySQL a schema is a database, so this delegates to `create_database`.
fn create_schema(&mut self, name: &str) -> Result<(), String> {
self.create_database(name)
}
/// Checks existence via `information_schema`. Tables and views are looked up
/// in the current database (`DATABASE()`); "schema" and "database" are synonyms.
fn object_exists(&mut self, obj_type: &str, name: &str) -> Result<bool, String> {
    use mysql::prelude::Queryable;
    let query: &str = if obj_type == "table" {
        "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?"
    } else if obj_type == "view" {
        "SELECT COUNT(*) FROM information_schema.views WHERE table_schema = DATABASE() AND table_name = ?"
    } else if obj_type == "schema" || obj_type == "database" {
        "SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?"
    } else {
        return Err(format!("unsupported object type '{}' for mysql", obj_type));
    };
    let found: Option<i64> = self
        .conn
        .exec_first(query, (name,))
        .map_err(|e| format!("checking {} existence: {}", obj_type, e))?;
    Ok(found.map_or(false, |n| n > 0))
}
/// Identifier for this backend.
fn driver_name(&self) -> &str {
"mysql"
}
/// Adds a `content_hash TEXT` column to the tracking table (in the current
/// database) if it lacks one.
fn migrate_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let safe = sanitize_identifier(table_name);
    // Bind the table name as a parameter instead of splicing it into a quoted
    // SQL literal, consistent with every other lookup in this impl.
    let count: Option<i64> = self
        .conn
        .exec_first(
            "SELECT COUNT(*) FROM information_schema.columns \
             WHERE table_schema = DATABASE() AND table_name = ? AND column_name = 'content_hash'",
            (safe.to_string(),),
        )
        .map_err(|e| format!("checking tracking table schema: {}", e))?;
    if count.unwrap_or(0) == 0 {
        let alter = format!("ALTER TABLE `{}` ADD COLUMN content_hash TEXT", safe);
        self.conn
            .query_drop(&alter)
            .map_err(|e| format!("migrating tracking table: {}", e))?;
    }
    Ok(())
}
/// Creates the per-row tracking table `<table_name>_rows` if it is missing.
///
/// `row_key` is `TEXT`, which MySQL only indexes with a prefix length. The
/// primary key therefore uses `row_key_hash`, a stored generated column
/// holding the binary SHA-256 of `row_key`.
fn ensure_row_tracking_table(&mut self, table_name: &str) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let ddl = format!(
        "CREATE TABLE IF NOT EXISTS `{}_rows` (
seed_set VARCHAR(255) NOT NULL,
table_name VARCHAR(255) NOT NULL,
row_key TEXT NOT NULL,
row_values TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
row_key_hash BINARY(32) GENERATED ALWAYS AS (UNHEX(SHA2(row_key, 256))) STORED,
PRIMARY KEY (seed_set, table_name, row_key_hash)
)",
        sanitize_identifier(table_name)
    );
    self.conn
        .query_drop(&ddl)
        .map_err(|e| format!("creating row tracking table: {}", e))
}
/// Returns the stored `content_hash` for `seed_set` in the tracking table.
///
/// Yields `Ok(None)` both when no row exists for the seed set and when the
/// row's `content_hash` is NULL.
fn get_seed_hash(
    &mut self,
    table_name: &str,
    seed_set: &str,
) -> Result<Option<String>, String> {
    use mysql::prelude::Queryable;
    let query = format!(
        "SELECT content_hash FROM `{}` WHERE seed_set = ?",
        sanitize_identifier(table_name)
    );
    // Outer Option: was a row found? Inner Option: is content_hash non-NULL?
    let lookup: Option<Option<String>> = self
        .conn
        .exec_first(&query, (seed_set,))
        .map_err(|e| format!("getting seed hash: {}", e))?;
    // `Option<String>::default()` is `None`, so this collapses both levels.
    Ok(lookup.unwrap_or_default())
}
/// Records `hash` as the content hash of `seed_set` in the tracking table.
///
/// Inserts a new row, or, when the insert hits a duplicate key, overwrites
/// `content_hash` and refreshes `applied_at`.
fn update_seed_entry(
    &mut self,
    table_name: &str,
    seed_set: &str,
    hash: &str,
) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let upsert = format!(
        "INSERT INTO `{}` (seed_set, content_hash) VALUES (?, ?) \
         ON DUPLICATE KEY UPDATE content_hash = VALUES(content_hash), applied_at = CURRENT_TIMESTAMP",
        sanitize_identifier(table_name)
    );
    self.conn
        .exec_drop(&upsert, (seed_set, hash))
        .map_err(|e| format!("updating seed entry: {}", e))
}
/// Upserts one tracked row into `<tracking_table>_rows`.
///
/// When a row with the same primary key already exists, its `row_values` is
/// overwritten and `applied_at` refreshed. The primary key is seed set, table
/// name and hash of `row_key`, as defined in `ensure_row_tracking_table`.
fn store_tracked_row(
    &mut self,
    tracking_table: &str,
    seed_set: &str,
    table_name: &str,
    row_key: &str,
    row_values: &str,
) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let upsert = format!(
        "INSERT INTO `{}_rows` (seed_set, table_name, row_key, row_values) VALUES (?, ?, ?, ?) \
         ON DUPLICATE KEY UPDATE row_values = VALUES(row_values), applied_at = CURRENT_TIMESTAMP",
        sanitize_identifier(tracking_table)
    );
    let bound = (seed_set, table_name, row_key, row_values);
    self.conn
        .exec_drop(&upsert, bound)
        .map_err(|e| format!("storing tracked row: {}", e))
}
/// Lists the `(row_key, row_values)` pairs tracked for `seed_set` and
/// `table_name` in `<tracking_table>_rows`.
///
/// The query has no ORDER BY, so rows come back in server order.
fn get_tracked_rows(
    &mut self,
    tracking_table: &str,
    seed_set: &str,
    table_name: &str,
) -> Result<Vec<(String, String)>, String> {
    use mysql::prelude::Queryable;
    let query = format!(
        "SELECT row_key, row_values FROM `{}_rows` WHERE seed_set = ? AND table_name = ?",
        sanitize_identifier(tracking_table)
    );
    let fetched: Result<Vec<(String, String)>, _> =
        self.conn.exec(&query, (seed_set, table_name));
    fetched.map_err(|e| format!("querying tracked rows: {}", e))
}
/// Removes the tracked row for `(seed_set, table_name, row_key)` from
/// `<tracking_table>_rows`.
///
/// The number of deleted rows is not checked, so removing an absent row
/// succeeds.
fn delete_tracked_row(
    &mut self,
    tracking_table: &str,
    seed_set: &str,
    table_name: &str,
    row_key: &str,
) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let ident = sanitize_identifier(tracking_table);
    let query = format!(
        "DELETE FROM `{}_rows` WHERE seed_set = ? AND table_name = ? AND row_key = ?",
        ident
    );
    self.conn
        .exec_drop(&query, (seed_set, table_name, row_key))
        .map_err(|e| format!("deleting tracked row: {}", e))
}
/// Removes every tracked row belonging to `seed_set` from
/// `<tracking_table>_rows`, across all target tables.
fn delete_all_tracked_rows(
    &mut self,
    tracking_table: &str,
    seed_set: &str,
) -> Result<(), String> {
    use mysql::prelude::Queryable;
    let query = format!(
        "DELETE FROM `{}_rows` WHERE seed_set = ?",
        sanitize_identifier(tracking_table)
    );
    self.conn
        .exec_drop(&query, (seed_set,))
        .map_err(|e| format!("deleting all tracked rows: {}", e))
}
/// Updates rows of `table` whose `where_columns` equal `where_values`.
///
/// Each of `set_columns` is set to the matching entry of `set_values`.
/// Identifiers are sanitized and backtick-quoted, and every value is sent as
/// a bound string parameter.
///
/// Returns the affected-row count the server reports for the UPDATE. By
/// MySQL's default semantics this is rows actually changed, not rows matched.
///
/// # Errors
/// Returns an error when:
/// - there is nothing to set;
/// - no key column is given;
/// - a column list and its value list differ in length;
/// - the server rejects the statement.
fn update_row(
    &mut self,
    table: &str,
    set_columns: &[String],
    set_values: &[String],
    where_columns: &[String],
    where_values: &[String],
) -> Result<u64, String> {
    use mysql::prelude::Queryable;
    // An empty SET or WHERE list would otherwise produce malformed SQL
    // ("SET  WHERE ..." / "WHERE "). Reject it up front with a clear message.
    if set_columns.is_empty() || where_columns.is_empty() {
        return Err(format!(
            "updating row in '{}': at least one SET column and one WHERE column are required",
            table
        ));
    }
    if set_columns.len() != set_values.len() || where_columns.len() != where_values.len() {
        return Err(format!(
            "updating row in '{}': column and value counts differ",
            table
        ));
    }
    let set_clause: Vec<String> = set_columns
        .iter()
        .map(|c| format!("`{}` = ?", sanitize_identifier(c)))
        .collect();
    let where_clause: Vec<String> = where_columns
        .iter()
        .map(|c| format!("`{}` = ?", sanitize_identifier(c)))
        .collect();
    let sql = format!(
        "UPDATE `{}` SET {} WHERE {}",
        sanitize_identifier(table),
        set_clause.join(", "),
        where_clause.join(" AND ")
    );
    // Placeholders appear SET-first, then WHERE, so bind values in that order.
    let params: Vec<mysql::Value> = set_values
        .iter()
        .chain(where_values.iter())
        .map(|v| mysql::Value::from(v.as_str()))
        .collect();
    self.conn
        .exec_drop(&sql, &params)
        .map_err(|e| format!("updating row in '{}': {}", table, e))?;
    // The UPDATE's OK packet already carries the affected-row count. Reading
    // it from the connection avoids a second round trip for
    // `SELECT ROW_COUNT()`, and avoids decoding that function's possible -1
    // into a u64.
    Ok(self.conn.affected_rows())
}
/// Fetches `fetch_columns` from the first row of `table` whose `key_columns`
/// equal `key_values`.
///
/// Each fetched column is cast to text with `CAST(... AS CHAR)`. Returns
/// `Ok(None)` when `fetch_columns` is empty or no row matches. SQL NULL
/// values come back as empty strings.
///
/// # Errors
/// Returns an error when:
/// - no key column is given;
/// - the key column and key value counts differ;
/// - the query fails.
fn get_row_columns(
    &mut self,
    table: &str,
    key_columns: &[String],
    key_values: &[String],
    fetch_columns: &[String],
) -> Result<Option<Vec<String>>, String> {
    use mysql::prelude::Queryable;
    if fetch_columns.is_empty() {
        return Ok(None);
    }
    // Without a key the statement would end in a bare "WHERE ".
    if key_columns.is_empty() {
        return Err(format!(
            "getting row from '{}': no key columns given",
            table
        ));
    }
    if key_columns.len() != key_values.len() {
        return Err(format!(
            "getting row from '{}': key column and value counts differ",
            table
        ));
    }
    let select_cols: Vec<String> = fetch_columns
        .iter()
        .map(|c| format!("CAST(`{}` AS CHAR)", sanitize_identifier(c)))
        .collect();
    let where_clause: Vec<String> = key_columns
        .iter()
        .map(|c| format!("`{}` = ?", sanitize_identifier(c)))
        .collect();
    let sql = format!(
        "SELECT {} FROM `{}` WHERE {}",
        select_cols.join(", "),
        sanitize_identifier(table),
        where_clause.join(" AND ")
    );
    let params: Vec<mysql::Value> = key_values
        .iter()
        .map(|v| mysql::Value::from(v.as_str()))
        .collect();
    let row: Option<mysql::Row> = self
        .conn
        .exec_first(&sql, &params)
        .map_err(|e| format!("getting row from '{}': {}", table, e))?;
    Ok(row.map(|r| {
        (0..fetch_columns.len())
            .map(|i| {
                // Decode as Option<String>. `Row::get::<String>` panics when the
                // value is SQL NULL, while Option<String> maps NULL to None.
                let cell: Option<Option<String>> = r.get(i);
                cell.flatten().unwrap_or_default()
            })
            .collect()
    }))
}
/// Deletes the rows of `table` whose `key_columns` equal `key_values`.
///
/// Identifiers are sanitized and backtick-quoted, and values are sent as bound
/// string parameters. Returns the affected-row count the server reports.
///
/// # Errors
/// Returns an error when:
/// - no key column is given (so this can never become an unfiltered DELETE);
/// - the key column and key value counts differ;
/// - the server rejects the statement.
fn delete_row_by_key(
    &mut self,
    table: &str,
    key_columns: &[String],
    key_values: &[String],
) -> Result<u64, String> {
    use mysql::prelude::Queryable;
    if key_columns.is_empty() {
        return Err(format!(
            "deleting row from '{}': no key columns given",
            table
        ));
    }
    if key_columns.len() != key_values.len() {
        return Err(format!(
            "deleting row from '{}': key column and value counts differ",
            table
        ));
    }
    let where_clause: Vec<String> = key_columns
        .iter()
        .map(|c| format!("`{}` = ?", sanitize_identifier(c)))
        .collect();
    let sql = format!(
        "DELETE FROM `{}` WHERE {}",
        sanitize_identifier(table),
        where_clause.join(" AND ")
    );
    let params: Vec<mysql::Value> = key_values
        .iter()
        .map(|v| mysql::Value::from(v.as_str()))
        .collect();
    self.conn
        .exec_drop(&sql, &params)
        .map_err(|e| format!("deleting row from '{}': {}", table, e))?;
    // The DELETE's OK packet already carries the affected-row count. Reading
    // it from the connection avoids an extra `SELECT ROW_COUNT()` round trip.
    Ok(self.conn.affected_rows())
}
}
/// Opens the database described by `config` and returns it as a boxed
/// [`Database`].
///
/// When `config.has_structured_config()` is true, the work is handed to
/// `connect_structured`. Otherwise the URL comes from the first available of:
/// 1. the environment variable named by `config.url_env`;
/// 2. `config.url`;
/// 3. the `DATABASE_URL` environment variable.
///
/// # Errors
/// Fails when the chosen URL source is unavailable, when the driver is not
/// compiled into this build, or when the driver fails to connect.
pub fn connect(config: &crate::seed::schema::DatabaseConfig) -> Result<Box<dyn Database>, String> {
    if config.has_structured_config() {
        return connect_structured(config);
    }
    let url = match (config.url_env.is_empty(), config.url.is_empty()) {
        (false, _) => std::env::var(&config.url_env).map_err(|_| {
            format!(
                "environment variable '{}' not set for database URL",
                config.url_env
            )
        })?,
        (true, false) => config.url.clone(),
        (true, true) => std::env::var("DATABASE_URL").map_err(|_| {
            "no database URL configured: set database.url, database.url_env, or DATABASE_URL env var, or use structured fields (host, port, user, password, name)".to_string()
        })?,
    };
    let driver = config.driver.as_str();
    match driver {
        #[cfg(feature = "sqlite")]
        "sqlite" => Ok(Box::new(SqliteDb::connect(&url)?)),
        #[cfg(feature = "postgres")]
        "postgres" | "postgresql" => Ok(Box::new(PostgresDb::connect(&url)?)),
        #[cfg(feature = "mysql")]
        "mysql" => Ok(Box::new(MysqlDb::connect(&url)?)),
        _ => Err(unsupported_driver_error(driver)),
    }
}
/// Connects using the structured fields of `config` (host, port, user,
/// password, name, options) instead of a URL.
///
/// - sqlite: always rejected; a URL must be used.
/// - postgres: the fields are rendered into a key/value DSN by
///   `build_postgres_dsn`.
/// - mysql: the fields feed `mysql::OptsBuilder`. The port defaults to 3306,
///   and empty user, password or name are left unset. `options` is not
///   supported; any keys present are reported, sorted, in the error.
fn connect_structured(
    config: &crate::seed::schema::DatabaseConfig,
) -> Result<Box<dyn Database>, String> {
    let driver = config.driver.as_str();
    match driver {
        #[cfg(feature = "sqlite")]
        "sqlite" => {
            Err("structured database config is not supported for sqlite; use url instead".into())
        }
        #[cfg(feature = "postgres")]
        "postgres" | "postgresql" => {
            let dsn = build_postgres_dsn(config);
            Ok(Box::new(PostgresDb::connect(&dsn)?))
        }
        #[cfg(feature = "mysql")]
        "mysql" => {
            if !config.options.is_empty() {
                // HashMap iteration order is unspecified. Sort so the message
                // is deterministic, as `build_postgres_dsn` does for its options.
                let mut keys: Vec<_> = config.options.keys().cloned().collect();
                keys.sort();
                return Err(format!(
                    "structured database config does not support 'options' for mysql (unsupported keys: {})",
                    keys.join(", ")
                ));
            }
            let port = config.port.unwrap_or(3306);
            let mut opts = mysql::OptsBuilder::default()
                .ip_or_hostname(Some(&config.host))
                .tcp_port(port);
            if !config.user.is_empty() {
                opts = opts.user(Some(&config.user));
            }
            if !config.password.is_empty() {
                opts = opts.pass(Some(&config.password));
            }
            if !config.name.is_empty() {
                opts = opts.db_name(Some(&config.name));
            }
            let pool = mysql::Pool::new(opts).map_err(|e| format!("connecting to mysql: {}", e))?;
            let conn = pool
                .get_conn()
                .map_err(|e| format!("getting mysql connection: {}", e))?;
            Ok(Box::new(MysqlDb {
                conn,
                in_transaction: false,
            }))
        }
        _ => Err(unsupported_driver_error(driver)),
    }
}
/// Renders `config` as a space-separated PostgreSQL key/value connection
/// string.
///
/// Every value is single-quoted and passed through `escape_dsn_value`.
/// - `host` and `port` are always present; the port defaults to 5432.
/// - `user`, `password` and `dbname` appear only when non-empty.
/// - Entries from `config.options` follow, sorted by key so the output is
///   deterministic.
#[cfg(feature = "postgres")]
fn build_postgres_dsn(config: &crate::seed::schema::DatabaseConfig) -> String {
    let mut segments = vec![
        format!("host='{}'", escape_dsn_value(&config.host)),
        format!("port='{}'", config.port.unwrap_or(5432)),
    ];
    if !config.user.is_empty() {
        segments.push(format!("user='{}'", escape_dsn_value(&config.user)));
    }
    if !config.password.is_empty() {
        segments.push(format!("password='{}'", escape_dsn_value(&config.password)));
    }
    if !config.name.is_empty() {
        segments.push(format!("dbname='{}'", escape_dsn_value(&config.name)));
    }
    let mut extra: Vec<_> = config.options.iter().collect();
    extra.sort_by(|(a, _), (b, _)| a.cmp(b));
    segments.extend(
        extra
            .into_iter()
            .map(|(k, v)| format!("{}='{}'", escape_dsn_value(k), escape_dsn_value(v))),
    );
    segments.join(" ")
}
/// Escapes `val` for use inside a single-quoted DSN value.
///
/// Each backslash becomes `\\` and each single quote becomes `\'`. All other
/// characters pass through unchanged.
fn escape_dsn_value(val: &str) -> String {
    let mut escaped = String::with_capacity(val.len());
    for ch in val.chars() {
        if ch == '\\' || ch == '\'' {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Builds the error message for a driver name that this build cannot serve.
///
/// The message lists the drivers compiled in, controlled by the `sqlite`,
/// `postgres` and `mysql` cargo features, in that order.
fn unsupported_driver_error(driver: &str) -> String {
    // `cfg!` yields a plain bool, so the list always has a concrete element
    // type. This keeps the function compiling for every feature combination,
    // including one with no driver feature enabled at all.
    let supported: Vec<&str> = [
        ("sqlite", cfg!(feature = "sqlite")),
        ("postgres", cfg!(feature = "postgres")),
        ("mysql", cfg!(feature = "mysql")),
    ]
    .iter()
    .filter(|(_, enabled)| *enabled)
    .map(|(name, _)| *name)
    .collect();
    format!(
        "unsupported database driver: '{}' (supported: {})",
        driver,
        supported.join(", ")
    )
}
/// Reduces `name` to a safe SQL identifier by keeping only alphanumeric
/// characters and underscores.
///
/// `char::is_alphanumeric` is Unicode-aware, so non-ASCII letters and digits
/// are kept. Quotes, backticks, whitespace and punctuation are dropped.
fn sanitize_identifier(name: &str) -> String {
    let mut cleaned = name.to_owned();
    cleaned.retain(|ch| ch == '_' || ch.is_alphanumeric());
    cleaned
}
/// Renders `val` as a single-quoted SQL string literal, doubling any
/// embedded single quotes.
///
/// NOTE(review): backslashes are not escaped. That is only safe where
/// backslashes are literal inside string constants (e.g. SQLite, or
/// PostgreSQL with standard_conforming_strings). Confirm no MySQL code path
/// uses this, since MySQL treats `\` as an escape character by default.
fn escape_sql_value(val: &str) -> String {
    let doubled = val.replace('\'', "''");
    let mut literal = String::with_capacity(doubled.len() + 2);
    literal.push('\'');
    literal.push_str(&doubled);
    literal.push('\'');
    literal
}
#[cfg(test)]
mod tests {
    //! Unit tests for the driver-independent helpers, the SQLite backend
    //! (in-memory databases) and connection-config resolution.
    use super::*;

    #[test]
    fn test_sanitize_identifier() {
        assert_eq!(sanitize_identifier("users"), "users");
        assert_eq!(sanitize_identifier("my_table"), "my_table");
        assert_eq!(sanitize_identifier("bad;drop"), "baddrop");
        assert_eq!(sanitize_identifier("table--name"), "tablename");
    }

    #[test]
    fn test_escape_sql_value() {
        assert_eq!(escape_sql_value("plain"), "'plain'");
        assert_eq!(escape_sql_value("it's"), "'it''s'");
        assert_eq!(escape_sql_value(""), "''");
    }

    #[test]
    fn test_unsupported_driver_error_names_driver() {
        let err = unsupported_driver_error("oracle");
        assert!(err.starts_with("unsupported database driver: 'oracle'"));
    }

    #[test]
    fn test_sqlite_tracking_table() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.ensure_tracking_table("initium_seed").unwrap();
        assert!(!db.is_seed_applied("initium_seed", "test_set").unwrap());
        db.mark_seed_applied("initium_seed", "test_set").unwrap();
        assert!(db.is_seed_applied("initium_seed", "test_set").unwrap());
        db.remove_seed_mark("initium_seed", "test_set").unwrap();
        assert!(!db.is_seed_applied("initium_seed", "test_set").unwrap());
    }

    #[test]
    fn test_sqlite_insert_and_exists() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.conn
            .execute(
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT UNIQUE)",
                [],
            )
            .unwrap();
        let columns = vec!["name".into(), "email".into()];
        let values = vec!["Alice".into(), "alice@example.com".into()];
        let id = db.insert_row("users", &columns, &values, None).unwrap();
        assert!(id.is_some());
        assert_eq!(id.unwrap(), 1);
        let unique_cols = vec!["email".into()];
        let unique_vals = vec!["alice@example.com".into()];
        assert!(db.row_exists("users", &unique_cols, &unique_vals).unwrap());
        let unique_vals2 = vec!["bob@example.com".into()];
        assert!(!db.row_exists("users", &unique_cols, &unique_vals2).unwrap());
    }

    #[test]
    fn test_sqlite_delete_rows() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.conn
            .execute("CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)", [])
            .unwrap();
        db.insert_row("items", &["name".into()], &["item1".into()], None)
            .unwrap();
        db.insert_row("items", &["name".into()], &["item2".into()], None)
            .unwrap();
        let count = db.delete_rows("items").unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_sqlite_transactions() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.conn
            .execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)", [])
            .unwrap();
        db.begin_transaction().unwrap();
        db.insert_row("t", &["v".into()], &["a".into()], None)
            .unwrap();
        db.rollback_transaction().unwrap();
        let count: i64 = db
            .conn
            .query_row("SELECT COUNT(*) FROM t", [], |r| r.get(0))
            .unwrap();
        assert_eq!(count, 0);
        db.begin_transaction().unwrap();
        db.insert_row("t", &["v".into()], &["b".into()], None)
            .unwrap();
        db.commit_transaction().unwrap();
        let count: i64 = db
            .conn
            .query_row("SELECT COUNT(*) FROM t", [], |r| r.get(0))
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_row_exists_empty_unique_key() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        assert!(!db.row_exists("any", &[], &[]).unwrap());
    }

    #[test]
    fn test_connect_unsupported_driver() {
        let config = crate::seed::schema::DatabaseConfig {
            driver: "oracle".into(),
            url: "localhost".into(),
            ..Default::default()
        };
        let result = connect(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_connect_sqlite() {
        let config = crate::seed::schema::DatabaseConfig {
            driver: "sqlite".into(),
            url: ":memory:".into(),
            ..Default::default()
        };
        let db = connect(&config);
        assert!(db.is_ok());
    }

    #[test]
    fn test_connect_structured_sqlite_rejected() {
        let config = crate::seed::schema::DatabaseConfig {
            driver: "sqlite".into(),
            host: "localhost".into(),
            ..Default::default()
        };
        let result = connect(&config);
        let err = result.err().expect("expected error");
        assert!(err.contains("not supported for sqlite"));
    }

    #[cfg(feature = "mysql")]
    #[test]
    fn test_connect_structured_mysql_rejects_options() {
        use std::collections::HashMap;
        let config = crate::seed::schema::DatabaseConfig {
            driver: "mysql".into(),
            host: "localhost".into(),
            user: "root".into(),
            name: "test".into(),
            options: {
                let mut m = HashMap::new();
                m.insert("charset".into(), "utf8mb4".into());
                m
            },
            ..Default::default()
        };
        let err = connect(&config).err().expect("expected error");
        assert!(err.contains("does not support 'options' for mysql"));
        assert!(err.contains("charset"));
    }

    #[test]
    fn test_escape_dsn_value() {
        assert_eq!(escape_dsn_value("simple"), "simple");
        assert_eq!(escape_dsn_value("it's"), "it\\'s");
        assert_eq!(escape_dsn_value("back\\slash"), "back\\\\slash");
        assert_eq!(escape_dsn_value("p@ss:word"), "p@ss:word");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_build_postgres_dsn() {
        use std::collections::HashMap;
        let config = crate::seed::schema::DatabaseConfig {
            driver: "postgres".into(),
            host: "pg.example.com".into(),
            port: Some(5432),
            user: "admin".into(),
            password: "s3cr't".into(),
            name: "mydb".into(),
            options: {
                let mut m = HashMap::new();
                m.insert("sslmode".into(), "disable".into());
                m
            },
            ..Default::default()
        };
        let dsn = build_postgres_dsn(&config);
        assert!(dsn.contains("host='pg.example.com'"));
        assert!(dsn.contains("port='5432'"));
        assert!(dsn.contains("user='admin'"));
        assert!(dsn.contains("password='s3cr\\'t'"));
        assert!(dsn.contains("dbname='mydb'"));
        assert!(dsn.contains("sslmode='disable'"));
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_build_postgres_dsn_default_port() {
        let config = crate::seed::schema::DatabaseConfig {
            driver: "postgres".into(),
            host: "localhost".into(),
            user: "app".into(),
            name: "db".into(),
            ..Default::default()
        };
        let dsn = build_postgres_dsn(&config);
        assert!(dsn.contains("port='5432'"));
        assert!(!dsn.contains("password="));
    }

    #[test]
    fn test_connect_url_from_env() {
        // A dedicated variable name keeps this test independent of the other
        // env-var tests.
        std::env::set_var("TEST_CONNECT_DB_URL_39", ":memory:");
        let config = crate::seed::schema::DatabaseConfig {
            driver: "sqlite".into(),
            url_env: "TEST_CONNECT_DB_URL_39".into(),
            ..Default::default()
        };
        let result = connect(&config);
        // Clean up before asserting so a failure does not leave the variable set.
        std::env::remove_var("TEST_CONNECT_DB_URL_39");
        assert!(result.is_ok());
    }

    #[test]
    fn test_connect_missing_url_env() {
        std::env::remove_var("TEST_MISSING_DB_URL_39");
        let config = crate::seed::schema::DatabaseConfig {
            driver: "sqlite".into(),
            url_env: "TEST_MISSING_DB_URL_39".into(),
            ..Default::default()
        };
        let err = connect(&config).err().expect("expected error");
        assert!(err.contains("TEST_MISSING_DB_URL_39"));
    }

    #[test]
    fn test_connect_no_url_no_env_no_structured() {
        std::env::remove_var("DATABASE_URL");
        let config = crate::seed::schema::DatabaseConfig {
            driver: "sqlite".into(),
            ..Default::default()
        };
        let err = connect(&config).err().expect("expected error");
        assert!(err.contains("no database URL configured"));
    }

    #[test]
    fn test_mark_seed_idempotent() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.ensure_tracking_table("initium_seed").unwrap();
        db.mark_seed_applied("initium_seed", "set1").unwrap();
        db.mark_seed_applied("initium_seed", "set1").unwrap();
        assert!(db.is_seed_applied("initium_seed", "set1").unwrap());
    }

    #[test]
    fn test_sqlite_object_exists_table() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        assert!(!db.object_exists("table", "users").unwrap());
        db.conn
            .execute("CREATE TABLE users (id INTEGER PRIMARY KEY)", [])
            .unwrap();
        assert!(db.object_exists("table", "users").unwrap());
    }

    #[test]
    fn test_sqlite_object_exists_view() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        db.conn
            .execute("CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT)", [])
            .unwrap();
        assert!(!db.object_exists("view", "items_view").unwrap());
        db.conn
            .execute("CREATE VIEW items_view AS SELECT * FROM items", [])
            .unwrap();
        assert!(db.object_exists("view", "items_view").unwrap());
    }

    #[test]
    fn test_sqlite_object_exists_schema_unsupported() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        let result = db.object_exists("schema", "public");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not support schemas"));
    }

    #[test]
    fn test_sqlite_object_exists_database_unsupported() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        let result = db.object_exists("database", "mydb");
        assert!(result.is_err());
    }

    #[test]
    fn test_sqlite_create_database_unsupported() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        let result = db.create_database("mydb");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not support"));
    }

    #[test]
    fn test_sqlite_create_schema_unsupported() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        let result = db.create_schema("myschema");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not support"));
    }

    #[test]
    fn test_sqlite_driver_name() {
        let db = SqliteDb::connect(":memory:").unwrap();
        assert_eq!(db.driver_name(), "sqlite");
    }

    #[test]
    fn test_sqlite_object_exists_unknown_type() {
        let mut db = SqliteDb::connect(":memory:").unwrap();
        let result = db.object_exists("index", "my_index");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unsupported object type"));
    }
}