use crate::config::types::TrackingTable;
use crate::db::error_context::SqlErrorContext;
use crate::migration::section_parser::{
BackoffStrategy, LockTimeoutAction, MigrationSection, RetryConfig, TransactionMode,
};
use crate::migration_tracking::section_tracking::*;
use crate::progress::SectionReporter;
use anyhow::Result;
use sqlx::PgPool;
use sqlx::postgres::PgDatabaseError;
use std::time::{Duration, Instant};
fn format_section_error(error: sqlx::Error, sql: &str, section_name: &str) -> anyhow::Error {
let ctx = SqlErrorContext::from_sqlx_error(&error, sql);
anyhow::anyhow!("{}", ctx.format(section_name, sql))
}
/// Selects which path `SectionExecutor::execute_section` takes for a section.
#[derive(Debug, Clone, Copy)]
pub enum ExecutionMode {
/// Consults the tracking table, dispatches on the section's transaction
/// mode and reports failures to the reporter.
Production,
/// Delegates straight to `execute_validation`, before any tracking-table
/// lookup is made.
Validation,
}
/// Runs migration sections against a PostgreSQL pool, recording their status
/// in a tracking table and forwarding progress to a reporter.
pub struct SectionExecutor {
/// Connection pool used both for section SQL and for the tracking helpers.
pool: PgPool,
/// Tracking-table handle passed to `get_section_status` and the
/// `record_section_*` helpers.
tracking_table: TrackingTable,
/// Receives start / skip / attempt / retry / complete / fail notifications.
reporter: SectionReporter,
/// Decides between the production path and the validation path in
/// `execute_section`.
mode: ExecutionMode,
}
impl SectionExecutor {
pub fn new(
pool: PgPool,
tracking_table: TrackingTable,
reporter: SectionReporter,
mode: ExecutionMode,
) -> Self {
Self {
pool,
tracking_table,
reporter,
mode,
}
}
/// Executes one migration section according to the executor's mode.
///
/// * In [`ExecutionMode::Validation`] the call is forwarded to
///   `execute_validation` and nothing else happens here.
/// * Otherwise the tracking table is consulted first: a section whose status
///   is `SectionStatus::Completed` is reported as skipped and `Ok(())` is
///   returned without running any SQL.
/// * Any other section is dispatched on `section.mode` to the transactional,
///   non-transactional or autocommit runner.
///
/// # Errors
///
/// Returns the error from the tracking lookup or from the selected runner.
/// Runner errors are passed to the reporter via `fail_section` before being
/// returned; lookup errors are returned directly.
pub async fn execute_section(
    &mut self,
    migration_version: u64,
    section: &MigrationSection,
) -> Result<()> {
    if matches!(self.mode, ExecutionMode::Validation) {
        return self.execute_validation(section).await;
    }

    let status = get_section_status(
        &self.pool,
        &self.tracking_table,
        migration_version,
        &section.name,
    )
    .await?;
    if matches!(status, Some(SectionStatus::Completed)) {
        self.reporter.skip_section(&section.name);
        return Ok(());
    }

    let result = match section.mode {
        TransactionMode::Transactional => {
            self.execute_transactional(migration_version, section).await
        }
        TransactionMode::NonTransactional => {
            self.execute_non_transactional(migration_version, section)
                .await
        }
        TransactionMode::Autocommit => {
            self.execute_autocommit(migration_version, section).await
        }
    };

    // Surface the failure to the reporter, then hand the same error back.
    if let Err(e) = &result {
        self.reporter.fail_section(&section.name, e);
    }
    result
}
/// Runs `section` inside a single transaction and records the outcome.
///
/// Steps, in order: notify the reporter, record the section start, begin a
/// transaction, apply `statement_timeout` (and `lock_timeout` when the
/// section configures one) with `SET LOCAL`, execute the section SQL, commit,
/// record completion with the affected-row count and elapsed milliseconds,
/// and notify the reporter of completion.
///
/// # Errors
///
/// * A failure of the section SQL or of the commit is recorded with
///   `record_section_failed` and returned through `format_section_error`.
/// * Failures while recording status, beginning the transaction or applying
///   the timeouts are returned as-is.
async fn execute_transactional(
    &mut self,
    migration_version: u64,
    section: &MigrationSection,
) -> Result<()> {
    use sqlx::Executor;

    self.reporter
        .start_section(&section.name, section.description.as_deref());
    let start = Instant::now();
    record_section_start(
        &self.pool,
        &self.tracking_table,
        migration_version,
        &section.name,
    )
    .await?;

    let mut tx = self.pool.begin().await?;

    // SET LOCAL: the timeouts apply to this transaction only.
    let timeout_ms = section.timeout.as_millis();
    sqlx::query(&format!("SET LOCAL statement_timeout = '{}'", timeout_ms))
        .execute(&mut *tx)
        .await?;
    if let Some(lock_timeout) = section.lock_timeout {
        let lock_timeout_ms = lock_timeout.as_millis();
        sqlx::query(&format!("SET LOCAL lock_timeout = '{}'", lock_timeout_ms))
            .execute(&mut *tx)
            .await?;
    }

    let result = match tx.execute(section.sql.as_str()).await {
        Ok(result) => result,
        Err(e) => {
            // Best-effort rollback: a rollback failure (for example on a
            // broken connection) must not replace the real SQL error or
            // prevent the failure from being recorded. Consuming `tx` here
            // discards the transaction either way.
            let _ = tx.rollback().await;
            record_section_failed(
                &self.pool,
                &self.tracking_table,
                migration_version,
                &section.name,
                &e.to_string(),
            )
            .await?;
            return Err(format_section_error(e, &section.sql, &section.name));
        }
    };

    // A failed commit means the section did not take effect, so it is
    // recorded as failed just like a failed statement.
    if let Err(e) = tx.commit().await {
        record_section_failed(
            &self.pool,
            &self.tracking_table,
            migration_version,
            &section.name,
            &e.to_string(),
        )
        .await?;
        return Err(format_section_error(e, &section.sql, &section.name));
    }

    let duration = start.elapsed();
    let rows = result.rows_affected() as i64;
    record_section_complete(
        &self.pool,
        &self.tracking_table,
        migration_version,
        &section.name,
        Some(rows),
        duration.as_millis() as i64,
    )
    .await?;
    self.reporter
        .complete_section(&section.name, duration, Some(rows as usize));
    Ok(())
}
/// Runs `section` outside an explicit transaction, retrying failed attempts.
///
/// The retry policy comes from `section.retry_config`, falling back to
/// `RetryConfig::default()`. An attempt budget of zero is treated as one so
/// the statement always runs at least once.
///
/// Every attempt checks out one pooled connection and uses it for the
/// session-level `SET statement_timeout` / `SET lock_timeout` as well as for
/// the section SQL, so the timeouts apply to the statement they are meant
/// for. The settings are reset before the connection goes back to the pool,
/// and the connection is released before any tracking call or retry sleep.
///
/// A failed attempt is retried while attempts remain, unless the failure is
/// a lock timeout and `on_lock_timeout` is not `LockTimeoutAction::Retry`.
/// The delay between attempts comes from `calculate_retry_delay`.
///
/// # Errors
///
/// * The final SQL failure is recorded with `record_section_failed` and
///   returned through `format_section_error`.
/// * Failures while recording status, acquiring a connection or applying the
///   timeouts are returned as-is without a retry.
async fn execute_non_transactional(
    &mut self,
    migration_version: u64,
    section: &MigrationSection,
) -> Result<()> {
    use sqlx::Executor;

    self.reporter
        .start_section(&section.name, section.description.as_deref());

    let default_retry_config = RetryConfig::default();
    let retry_config = section
        .retry_config
        .as_ref()
        .unwrap_or(&default_retry_config);
    // A zero budget would otherwise skip the statement entirely.
    let max_attempts = retry_config.attempts.max(1);

    let start = Instant::now();
    record_section_start(
        &self.pool,
        &self.tracking_table,
        migration_version,
        &section.name,
    )
    .await?;

    let mut attempt = 1;
    loop {
        self.reporter.attempt(attempt, max_attempts);

        let outcome = {
            // One connection for SET + SQL: session settings are per
            // connection, and separate pool calls may use different ones.
            let mut conn = self.pool.acquire().await?;

            let timeout_ms = section.timeout.as_millis();
            sqlx::query(&format!("SET statement_timeout = '{}'", timeout_ms))
                .execute(&mut *conn)
                .await?;
            if let Some(lock_timeout) = section.lock_timeout {
                let lock_timeout_ms = lock_timeout.as_millis();
                sqlx::query(&format!("SET lock_timeout = '{}'", lock_timeout_ms))
                    .execute(&mut *conn)
                    .await?;
            }

            let outcome = conn.execute(section.sql.as_str()).await;

            // Best-effort cleanup so the timeouts do not leak to whoever
            // borrows this connection next; the statement outcome is what
            // matters, so reset errors are deliberately ignored.
            let _ = sqlx::query("RESET statement_timeout")
                .execute(&mut *conn)
                .await;
            let _ = sqlx::query("RESET lock_timeout")
                .execute(&mut *conn)
                .await;

            outcome
            // `conn` returns to the pool here, before tracking or sleeping.
        };

        match outcome {
            Ok(result) => {
                let duration = start.elapsed();
                let rows = result.rows_affected() as i64;
                record_section_complete(
                    &self.pool,
                    &self.tracking_table,
                    migration_version,
                    &section.name,
                    Some(rows),
                    duration.as_millis() as i64,
                )
                .await?;
                self.reporter.complete_section_with_retry(
                    &section.name,
                    duration,
                    Some(rows as usize),
                    attempt,
                    max_attempts,
                );
                return Ok(());
            }
            Err(e) => {
                let is_lock_timeout = classify_timeout_error(&e) == Some(TimeoutKind::Lock);
                let should_retry = attempt < max_attempts
                    && (retry_config.on_lock_timeout == LockTimeoutAction::Retry
                        || !is_lock_timeout);

                if !should_retry {
                    record_section_failed(
                        &self.pool,
                        &self.tracking_table,
                        migration_version,
                        &section.name,
                        &e.to_string(),
                    )
                    .await?;
                    return Err(format_section_error(e, &section.sql, &section.name));
                }

                let delay = calculate_retry_delay(retry_config, attempt);
                self.reporter
                    .retry(&section.name, attempt, &e.into(), delay);
                tokio::time::sleep(delay).await;
                // Cannot overflow: `attempt < max_attempts` held above.
                attempt += 1;
            }
        }
    }
}
async fn execute_autocommit(
&mut self,
migration_version: u64,
section: &MigrationSection,
) -> Result<()> {
self.reporter
.start_section(§ion.name, section.description.as_deref());
let start = Instant::now();
record_section_start(
&self.pool,
&self.tracking_table,
migration_version,
§ion.name,
)
.await?;
let timeout_ms = section.timeout.as_millis();
sqlx::query(&format!("SET statement_timeout = '{}'", timeout_ms))
.execute(&self.pool)
.await?;
if let Some(lock_timeout) = section.lock_timeout {
let lock_timeout_ms = lock_timeout.as_millis();
sqlx::query(&format!("SET lock_timeout = '{}'", lock_timeout_ms))
.execute(&self.pool)
.await?;
}
let result = sqlx::query(§ion.sql)
.execute(&self.pool)
.await
.map_err(|e| format_section_error(e, §ion.sql, §ion.name))?;
let duration = start.elapsed();
let rows = result.rows_affected() as i64;
record_section_complete(
&self.pool,
&self.tracking_table,
migration_version,
§ion.name,
Some(rows),
duration.as_millis() as i64,
)
.await?;
self.reporter
.complete_section(§ion.name, duration, Some(rows as usize));
Ok(())
}
async fn execute_validation(&mut self, section: &MigrationSection) -> Result<()> {
use sqlx::Executor;
match section.mode {
TransactionMode::Transactional => {
let mut tx = self.pool.begin().await?;
tx.execute(section.sql.as_str())
.await
.map_err(|e| format_section_error(e, §ion.sql, §ion.name))?;
tx.commit().await?;
}
TransactionMode::NonTransactional | TransactionMode::Autocommit => {
self.pool
.execute(section.sql.as_str())
.await
.map_err(|e| format_section_error(e, §ion.sql, §ion.name))?;
}
};
Ok(())
}
}
/// Timeout categories that `classify_timeout_error` derives from a
/// PostgreSQL error code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimeoutKind {
/// Returned for error code `55P03`.
Lock,
/// Returned for error code `57014`.
Statement,
}
/// Maps a sqlx error to the kind of timeout it represents, if any.
///
/// Returns `None` when the error is not a database error, is not a
/// `PgDatabaseError`, or carries a code other than `55P03` (lock) or
/// `57014` (statement).
fn classify_timeout_error(error: &sqlx::Error) -> Option<TimeoutKind> {
    let pg_error = error
        .as_database_error()?
        .try_downcast_ref::<PgDatabaseError>()?;

    match pg_error.code() {
        "55P03" => Some(TimeoutKind::Lock),
        "57014" => Some(TimeoutKind::Statement),
        _ => None,
    }
}
/// Computes the pause to take after attempt number `attempt` (1-based).
///
/// * `BackoffStrategy::None` always yields `config.delay`.
/// * `BackoffStrategy::Exponential` doubles the delay per attempt, starting
///   at `config.delay` for attempt 1 and capped at 32 times `config.delay`.
///
/// The exponent is clamped before the power of two is computed, so any
/// `attempt` value (including 0 and `u32::MAX`) is safe. The multiplication
/// saturates at `Duration::MAX`.
fn calculate_retry_delay(config: &RetryConfig, attempt: u32) -> Duration {
    // 2^5 = 32 is the largest multiplier ever applied.
    const MAX_BACKOFF_EXPONENT: u32 = 5;

    match config.backoff {
        BackoffStrategy::None => config.delay,
        BackoffStrategy::Exponential => {
            let exponent = attempt.saturating_sub(1).min(MAX_BACKOFF_EXPONENT);
            config.delay.saturating_mul(1_u32 << exponent)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`RetryConfig`] that retries on lock timeouts, varying only
    /// the fields the delay tests care about.
    fn retry_config(attempts: u32, delay_secs: u64, backoff: BackoffStrategy) -> RetryConfig {
        RetryConfig {
            attempts,
            delay: Duration::from_secs(delay_secs),
            backoff,
            on_lock_timeout: LockTimeoutAction::Retry,
        }
    }

    /// Checks the computed delay, in whole seconds, for each
    /// `(attempt, expected_secs)` pair.
    fn assert_delays(config: &RetryConfig, expectations: &[(u32, u64)]) {
        for &(attempt, expected_secs) in expectations {
            assert_eq!(
                calculate_retry_delay(config, attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    /// Errors that are not database errors never classify as a timeout, even
    /// when their text mentions one.
    #[test]
    fn test_classify_timeout_error_non_database_errors() {
        for message in ["statement timeout", "lock not available", "invalid syntax"] {
            let error = sqlx::Error::Configuration(message.into());
            assert_eq!(classify_timeout_error(&error), None);
        }
    }

    /// Without backoff the delay stays constant across attempts.
    #[test]
    fn test_calculate_retry_delay_none() {
        let config = retry_config(5, 2, BackoffStrategy::None);
        assert_delays(&config, &[(1, 2), (2, 2), (5, 2)]);
    }

    /// Exponential backoff doubles the delay on every attempt.
    #[test]
    fn test_calculate_retry_delay_exponential() {
        let config = retry_config(5, 1, BackoffStrategy::Exponential);
        assert_delays(&config, &[(1, 1), (2, 2), (3, 4), (4, 8), (5, 16)]);
    }

    /// Exponential backoff stops growing at 32 times the base delay.
    #[test]
    fn test_calculate_retry_delay_exponential_capped() {
        let config = retry_config(10, 1, BackoffStrategy::Exponential);
        assert_delays(&config, &[(6, 32), (7, 32), (10, 32)]);
    }
}