use std::collections::HashMap;
/// Translates `source_schema` into its destination schema name.
///
/// If `schema_mappings` contains an entry for `source_schema`, the mapped
/// value is returned; otherwise the source schema name is passed through
/// unchanged. The result is always a freshly owned `String`.
pub fn map_schema(schema_mappings: &HashMap<String, String>, source_schema: &str) -> String {
    match schema_mappings.get(source_schema) {
        Some(target_schema) => target_schema.clone(),
        None => source_schema.to_owned(),
    }
}
/// Executes every statement in `commands`, in slice order, inside a single
/// transaction on `pool`, then commits.
///
/// If a statement fails, the transaction is rolled back and an error naming
/// the 1-based position of the failing statement (out of `commands.len()`) is
/// returned. Once all statements succeed, the optional `pre_commit_hook` is
/// awaited. If it reports an error, the transaction is likewise rolled back.
/// The transaction is committed only when both stages succeed.
///
/// An empty `commands` slice still opens a transaction, runs the hook (if
/// any), and commits. `db_name` is used only as a prefix in error and log
/// messages.
///
/// # Errors
///
/// Returns an error built with `CdcError::generic` when BEGIN fails, when any
/// statement fails, when the pre-commit hook fails, or when COMMIT fails. A
/// failed ROLLBACK is logged with `tracing::error!` and does not replace the
/// error that triggered it.
#[cfg(any(feature = "mysql", feature = "sqlite"))]
pub(crate) async fn execute_sqlx_batch_with_hook<DB>(
    pool: &sqlx::Pool<DB>,
    commands: &[std::borrow::Cow<'_, str>],
    pre_commit_hook: Option<super::destination_factory::PreCommitHook>,
    db_name: &str,
) -> crate::error::Result<()>
where
    DB: sqlx::Database,
    for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
    for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
{
    let total = commands.len();

    let mut transaction = match pool.begin().await {
        Ok(started) => started,
        Err(e) => {
            return Err(crate::error::CdcError::generic(format!(
                "{db_name} BEGIN transaction failed: {e}"
            )));
        }
    };

    for (position, statement) in commands.iter().enumerate() {
        let outcome = sqlx::query(statement.as_ref())
            .execute(&mut *transaction)
            .await;
        let Err(exec_failure) = outcome else {
            continue;
        };

        // Roll back explicitly; a rollback failure is only logged so the
        // caller still sees the statement error that caused it.
        if let Err(rollback_failure) = transaction.rollback().await {
            tracing::error!(
                "{db_name} ROLLBACK failed after execution error: {}",
                rollback_failure
            );
        }
        return Err(crate::error::CdcError::generic(format!(
            "{db_name} execute_sql_batch failed at command {}/{}: {}",
            position + 1,
            total,
            exec_failure
        )));
    }

    // Run the hook (if any) while the transaction is still open, keeping only
    // its error so it can abort the commit.
    let hook_failure = match pre_commit_hook {
        Some(hook) => hook().await.err(),
        None => None,
    };

    if let Some(hook_failure) = hook_failure {
        if let Err(rollback_failure) = transaction.rollback().await {
            tracing::error!(
                "{db_name} ROLLBACK failed after pre-commit hook error: {}",
                rollback_failure
            );
        }
        return Err(crate::error::CdcError::generic(format!(
            "{db_name} pre-commit hook failed, transaction rolled back: {}",
            hook_failure
        )));
    }

    match transaction.commit().await {
        Ok(_) => Ok(()),
        Err(e) => Err(crate::error::CdcError::generic(format!(
            "{db_name} COMMIT transaction failed: {e}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema with a mapping entry is translated; one without is returned as-is.
    #[test]
    fn test_map_schema() {
        let mappings = HashMap::from([("public".to_string(), "cdc_db".to_string())]);

        assert_eq!(map_schema(&mappings, "public"), "cdc_db");
        assert_eq!(map_schema(&mappings, "other"), "other");
    }
}