use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio_postgres::Client;
use tracing::{debug, info, warn};
use crate::error::Result;
use crate::tls::connect_with_sslmode;
/// A sequence observed on the source database, as returned by the
/// collection queries below.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SourceSequence {
    // Schema (namespace) the sequence lives in.
    pub schema: String,
    // Sequence relation name.
    pub name: String,
    // Value reported by `pg_sequence_last_value`; `None` means the sequence
    // has never been advanced on the source (NULL in the catalog query).
    pub last_value: Option<i64>,
}
/// Lists every user-visible sequence (`relkind = 'S'`) with its
/// `pg_sequence_last_value` (NULL until the sequence is first advanced).
/// System schemas (`pg_catalog`, `information_schema`, `pg_toast`) and
/// per-session temp schemas are excluded.
pub const COLLECT_SEQUENCES_SQL_NO_FILTER: &str = "\
    SELECT n.nspname::text, c.relname::text, \
    pg_sequence_last_value(c.oid::regclass) AS last_value \
    FROM pg_class c \
    JOIN pg_namespace n ON n.oid = c.relnamespace \
    WHERE c.relkind = 'S' \
    AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \
    AND n.nspname NOT LIKE 'pg_temp_%' \
    AND n.nspname NOT LIKE 'pg_toast_temp_%'";
/// Same collection query restricted to an explicit schema list, passed as a
/// parameterised text array (`$1`). No temp/system-schema exclusions are
/// needed here because the caller names the schemas explicitly.
pub const COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER: &str = "\
    SELECT n.nspname::text, c.relname::text, \
    pg_sequence_last_value(c.oid::regclass) AS last_value \
    FROM pg_class c \
    JOIN pg_namespace n ON n.oid = c.relnamespace \
    WHERE c.relkind = 'S' \
    AND n.nspname = ANY($1::text[])";
/// Collect all sequences (and their last values) from the source database.
///
/// An empty `schema_filter` scans every non-system schema; otherwise only
/// the named schemas are queried, via a parameterised `= ANY($1)` array so
/// no identifier interpolation is needed.
pub async fn collect_source_sequences(
    source: &Client,
    schema_filter: &[String],
) -> Result<Vec<SourceSequence>> {
    let rows = if schema_filter.is_empty() {
        source.query(COLLECT_SEQUENCES_SQL_NO_FILTER, &[]).await?
    } else {
        source
            .query(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER, &[&schema_filter])
            .await?
    };
    // Column order matches the SELECT list: schema, name, last_value.
    Ok(rows
        .into_iter()
        .map(|row| SourceSequence {
            schema: row.get(0),
            name: row.get(1),
            last_value: row.get(2),
        })
        .collect())
}
/// Minimal surface of the target connection used by sequence sync.
///
/// Abstracted as a trait so the apply logic can be tested against a mock
/// (see `MockTarget` in this file's tests) instead of a live `Client`.
#[async_trait]
pub(crate) trait SeqSyncTarget: Send + Sync {
    /// Run one parameterised `setval` statement; returns rows affected.
    async fn execute_setval(&self, sql: &str, last_value: i64) -> Result<u64>;
    /// Execute a multi-statement SQL script in a single round trip.
    async fn batch_execute_sql(&self, sql: &str) -> Result<()>;
    /// Read back the applied-count the batch script left in `_seq_sync_result`.
    async fn query_batch_applied(&self) -> Result<i32>;
}
// Production implementation: delegate each trait method to the real
// tokio_postgres connection.
#[async_trait]
impl SeqSyncTarget for Client {
    async fn execute_setval(&self, sql: &str, last_value: i64) -> Result<u64> {
        // `Client::execute` returns the number of rows affected.
        Ok(self.execute(sql, &[&last_value]).await?)
    }
    async fn batch_execute_sql(&self, sql: &str) -> Result<()> {
        // Fully-qualified call to make explicit that this invokes the
        // inherent `Client::batch_execute`, not this trait method.
        Ok(Client::batch_execute(self, sql).await?)
    }
    async fn query_batch_applied(&self) -> Result<i32> {
        // The batch DO-block inserts exactly one row into this temp table,
        // so `query_one` is the right shape here.
        let row = self
            .query_one("SELECT applied FROM _seq_sync_result", &[])
            .await?;
        Ok(row.get(0))
    }
}
/// Apply collected source sequence values to the target database.
///
/// Thin public wrapper that fixes the target type to a concrete
/// `tokio_postgres` [`Client`] and forwards to the trait-based
/// implementation. Returns the number of sequences actually applied.
pub async fn apply_sequences_to_target(
    target: &Client,
    sequences: &[SourceSequence],
) -> Result<usize> {
    apply_sequences_impl(target, sequences).await
}
/// Core apply logic, generic over the target so tests can inject a mock.
///
/// Routing: sequences whose `last_value` is `None` are skipped (never
/// advanced on the source, nothing to replay); a single remaining sequence
/// uses one plain `setval`; two or more go through the one-round-trip batch
/// script, falling back to per-sequence statements if the batch fails.
async fn apply_sequences_impl(
    target: &dyn SeqSyncTarget,
    sequences: &[SourceSequence],
) -> Result<usize> {
    let mut actionable: Vec<&SourceSequence> = Vec::with_capacity(sequences.len());
    for s in sequences {
        if s.last_value.is_some() {
            actionable.push(s);
        } else {
            debug!(
                schema = %s.schema,
                name = %s.name,
                "skipping sequence: never advanced on source",
            );
        }
    }
    if actionable.is_empty() {
        return Ok(0);
    }
    if actionable.len() == 1 {
        return apply_single(target, actionable[0]).await;
    }
    match apply_batch(target, &actionable).await {
        Ok(applied) => Ok(applied),
        Err(e) => {
            warn!(
                error = %e,
                "batch sequence sync failed — falling back to individual statements"
            );
            apply_individually(target, &actionable).await
        }
    }
}
/// Build a multi-statement script that applies every `setval` in a single
/// round trip.
///
/// Layout: a temp results table, then a `DO` block that calls `setval` for
/// each sequence inside its own BEGIN/EXCEPTION sub-block (so one bad
/// sequence cannot abort the rest), counting successes into `_applied`,
/// which is finally inserted into `_seq_sync_result` for the caller to
/// read back via `query_batch_applied`.
pub fn build_batch_setval_sql(sequences: &[&SourceSequence]) -> Result<String> {
    let mut script =
        String::from("CREATE TEMP TABLE IF NOT EXISTS _seq_sync_result (applied int);\n");
    script.push_str("TRUNCATE _seq_sync_result;\n");
    script.push_str("DO $__pg_dbmigrator_seq_sync__$\nDECLARE\n _applied int := 0;\nBEGIN\n");
    for entry in sequences {
        let value = entry.last_value.unwrap_or(0);
        let ident = format!(
            "{}.{}",
            pg_walstream::quote_ident(&entry.schema)?,
            pg_walstream::quote_ident(&entry.name)?,
        );
        let ident_lit = pg_walstream::quote_literal(&ident)?;
        // Inside the RAISE format string a literal ' must be doubled for the
        // enclosing SQL literal and a literal % written as %% (RAISE treats
        // % as a placeholder).
        let warn_name = ident.replace('\'', "''").replace('%', "%%");
        script.push_str(&format!(
            " BEGIN\n PERFORM setval({ident_lit}::regclass, {value}::bigint, true);\n _applied := _applied + 1;\n EXCEPTION WHEN OTHERS THEN\n RAISE WARNING 'setval failed for {warn_name}: %', SQLERRM;\n END;\n"
        ));
    }
    script.push_str(" INSERT INTO _seq_sync_result VALUES (_applied);\n");
    script.push_str("END;\n$__pg_dbmigrator_seq_sync__$;");
    Ok(script)
}
/// Execute the batch script and report how many sequences it applied.
///
/// The script itself never errors on an individual `setval` failure (each
/// is wrapped in an EXCEPTION block); the count read back from
/// `_seq_sync_result` is therefore the number that actually succeeded and
/// may be smaller than `sequences.len()`.
async fn apply_batch(target: &dyn SeqSyncTarget, sequences: &[&SourceSequence]) -> Result<usize> {
    let sql = build_batch_setval_sql(sequences)?;
    target.batch_execute_sql(&sql).await?;
    let applied = target.query_batch_applied().await?;
    for seq in sequences {
        // NOTE(review): logged for every attempted sequence; when
        // `applied < sequences.len()` some of these only got a RAISE WARNING
        // on the server side.
        debug!(schema = %seq.schema, name = %seq.name, "synced sequence (batch)");
    }
    // `applied as usize` would wrap a negative i32 into a huge count;
    // clamp a malformed/negative result to 0 instead.
    Ok(usize::try_from(applied).unwrap_or(0))
}
/// Sync exactly one sequence with a plain parameterised `setval`.
///
/// A failure is logged and reported as 0 applied rather than propagated:
/// sequence sync is best-effort per sequence.
async fn apply_single(target: &dyn SeqSyncTarget, seq: &SourceSequence) -> Result<usize> {
    let value = seq.last_value.unwrap_or(0);
    let sql = build_setval_sql(&seq.schema, &seq.name)?;
    if let Err(e) = target.execute_setval(&sql, value).await {
        warn!(
            schema = %seq.schema,
            name = %seq.name,
            error = %e,
            "failed to sync sequence (continuing)"
        );
        return Ok(0);
    }
    debug!(schema = %seq.schema, name = %seq.name, last_value = value, "synced sequence");
    Ok(1)
}
/// Fallback path: sync each sequence with its own `setval` statement.
///
/// Per-sequence failures are logged and skipped so one bad sequence does
/// not stop the rest; returns how many succeeded.
async fn apply_individually(
    target: &dyn SeqSyncTarget,
    sequences: &[&SourceSequence],
) -> Result<usize> {
    let mut synced = 0usize;
    for seq in sequences {
        let value = seq.last_value.unwrap_or(0);
        let sql = build_setval_sql(&seq.schema, &seq.name)?;
        match target.execute_setval(&sql, value).await {
            Ok(_) => {
                synced += 1;
                debug!(schema = %seq.schema, name = %seq.name, last_value = value, "synced sequence");
            }
            Err(e) => warn!(
                schema = %seq.schema,
                name = %seq.name,
                error = %e,
                "failed to sync sequence (continuing)"
            ),
        }
    }
    Ok(synced)
}
/// Build a parameterised `setval` statement for one sequence.
///
/// Identifiers are double-quote escaped, then the qualified name is wrapped
/// in a single-quoted literal and cast to `regclass`; the value itself is
/// bound as `$1`, so no numeric interpolation happens here.
pub fn build_setval_sql(schema: &str, name: &str) -> Result<String> {
    let schema_q = pg_walstream::quote_ident(schema)?;
    let name_q = pg_walstream::quote_ident(name)?;
    let target = pg_walstream::quote_literal(&format!("{schema_q}.{name_q}"))?;
    Ok(format!("SELECT setval({target}::regclass, $1::bigint, true)"))
}
/// End-to-end sequence sync: connect to both databases, collect sequence
/// values from the source, and apply them to the target.
///
/// Returns the number of sequences applied. `schema_filter` has the same
/// semantics as in [`collect_source_sequences`] (empty = all user schemas).
pub async fn sync_sequences(
    source_conn: &str,
    target_conn: &str,
    schema_filter: &[String],
) -> Result<usize> {
    info!("syncing sequences from source to target");
    let source = connect_with_sslmode(source_conn).await?;
    let target = connect_with_sslmode(target_conn).await?;
    let seqs = collect_source_sequences(&source, schema_filter).await?;
    let total = seqs.len();
    info!(total, "collected sequences from source");
    let applied = apply_sequences_to_target(&target, &seqs).await?;
    // saturating_sub: if a misbehaving target ever reports more applied
    // rows than we collected, don't panic on usize underflow in a log field.
    info!(
        applied,
        skipped = total.saturating_sub(applied),
        "sequence sync complete"
    );
    Ok(applied)
}
// Unit tests: SQL-text shape checks, serde round-trips, SQL-builder
// escaping, and mock-driven routing tests for the apply_* paths.
#[cfg(test)]
mod tests {
    use super::*;

    // --- SQL constant shape checks --------------------------------------

    #[test]
    fn collect_sql_no_filter_includes_pg_sequence_last_value_function() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_sequence_last_value"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("c.relkind = 'S'"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_catalog"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("information_schema"));
    }

    #[test]
    fn collect_sql_with_schema_filter_uses_parameterised_array() {
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("$1::text[]"));
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("ANY"));
    }

    // --- SourceSequence serde behaviour ---------------------------------

    #[test]
    fn source_sequence_serde_roundtrip() {
        let s = SourceSequence {
            schema: "public".into(),
            name: "widgets_id_seq".into(),
            last_value: Some(100),
        };
        let json = serde_json::to_string(&s).unwrap();
        let s2: SourceSequence = serde_json::from_str(&json).unwrap();
        assert_eq!(s, s2);
    }

    #[test]
    fn source_sequence_handles_never_advanced() {
        let s = SourceSequence {
            schema: "public".into(),
            name: "fresh_seq".into(),
            last_value: None,
        };
        assert!(s.last_value.is_none());
    }

    // --- Single-statement SQL builder escaping --------------------------

    #[test]
    fn build_setval_sql_quotes_identifiers_and_literal() {
        let sql = build_setval_sql("public", "widgets_id_seq").unwrap();
        assert_eq!(
            sql,
            "SELECT setval('\"public\".\"widgets_id_seq\"'::regclass, $1::bigint, true)"
        );
    }

    #[test]
    fn build_setval_sql_escapes_embedded_double_quote() {
        let sql = build_setval_sql("we\"ird", "seq").unwrap();
        assert!(sql.contains("\"we\"\"ird\""));
        assert!(sql.contains("'\"we\"\"ird\".\"seq\"'::regclass"));
    }

    #[test]
    fn build_setval_sql_escapes_embedded_single_quote() {
        let sql = build_setval_sql("public", "o'reilly").unwrap();
        assert!(sql.contains("''"));
        assert!(sql.starts_with("SELECT setval('"));
        assert!(sql.contains("::regclass, $1::bigint, true)"));
    }

    #[test]
    fn build_setval_sql_handles_spaces_in_identifiers() {
        let sql = build_setval_sql("my schema", "my seq").unwrap();
        assert!(sql.contains("\"my schema\""));
        assert!(sql.contains("\"my seq\""));
    }

    #[test]
    fn build_setval_sql_handles_backtick_in_identifiers() {
        // Backticks are ordinary characters in Postgres identifiers; they
        // must pass through unchanged inside the double quotes.
        let sql = build_setval_sql("pub`lic", "seq`name").unwrap();
        assert!(sql.contains("\"pub`lic\""));
        assert!(sql.contains("\"seq`name\""));
    }

    #[test]
    fn source_sequence_none_last_value_serde_roundtrip() {
        let s = SourceSequence {
            schema: "public".into(),
            name: "fresh".into(),
            last_value: None,
        };
        let json = serde_json::to_string(&s).unwrap();
        assert!(json.contains("null"));
        let s2: SourceSequence = serde_json::from_str(&json).unwrap();
        assert_eq!(s, s2);
        assert!(s2.last_value.is_none());
    }

    #[test]
    fn collect_sql_no_filter_excludes_temp_schemas() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_temp_%"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_toast_temp_%"));
    }

    #[test]
    fn collect_sql_with_schema_filter_uses_any_operator() {
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("ANY($1::text[])"));
    }

    // --- Batch SQL builder ----------------------------------------------

    #[test]
    fn build_batch_setval_sql_produces_valid_plpgsql() {
        let seqs = [
            SourceSequence {
                schema: "public".into(),
                name: "users_id_seq".into(),
                last_value: Some(42),
            },
            SourceSequence {
                schema: "public".into(),
                name: "orders_id_seq".into(),
                last_value: Some(100),
            },
        ];
        let refs: Vec<&SourceSequence> = seqs.iter().collect();
        let sql = build_batch_setval_sql(&refs).unwrap();
        assert!(sql.contains("CREATE TEMP TABLE IF NOT EXISTS _seq_sync_result"));
        // The result table must survive until query_batch_applied reads it.
        assert!(!sql.contains("ON COMMIT DROP"));
        assert!(sql.contains("DO $__pg_dbmigrator_seq_sync__$"));
        assert!(sql.ends_with("$__pg_dbmigrator_seq_sync__$;"));
        assert!(sql.contains("PERFORM setval"));
        assert!(sql.contains("42::bigint"));
        assert!(sql.contains("100::bigint"));
        assert!(sql.contains("EXCEPTION WHEN OTHERS"));
        assert!(sql.contains("_applied := _applied + 1"));
        assert!(sql.contains("INSERT INTO _seq_sync_result VALUES (_applied)"));
    }

    #[test]
    fn build_batch_setval_sql_escapes_special_chars() {
        let seqs = [SourceSequence {
            schema: "my\"schema".into(),
            name: "o'reilly_seq".into(),
            last_value: Some(7),
        }];
        let refs: Vec<&SourceSequence> = seqs.iter().collect();
        let sql = build_batch_setval_sql(&refs).unwrap();
        assert!(sql.contains("\"my\"\"schema\""));
        assert!(sql.contains("7::bigint"));
    }

    #[test]
    fn build_batch_setval_sql_escapes_percent_in_raise_warning() {
        let seqs = [SourceSequence {
            schema: "public".into(),
            name: "pct%seq".into(),
            last_value: Some(1),
        }];
        let refs: Vec<&SourceSequence> = seqs.iter().collect();
        let sql = build_batch_setval_sql(&refs).unwrap();
        assert!(
            sql.contains("%%"),
            "percent signs in identifiers must be doubled for RAISE WARNING"
        );
    }

    #[test]
    fn build_batch_setval_sql_empty_input() {
        let refs: Vec<&SourceSequence> = vec![];
        let sql = build_batch_setval_sql(&refs).unwrap();
        assert!(sql.contains("DO $__pg_dbmigrator_seq_sync__$"));
        assert!(!sql.contains("PERFORM setval"));
        assert!(sql.contains("INSERT INTO _seq_sync_result VALUES (_applied)"));
    }

    // --- Mock target for exercising apply_* without a database ----------

    use crate::error::MigrationError;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    struct MockTarget {
        // Scripted outcomes for execute_setval, consumed front-to-back; an
        // exhausted queue defaults to success (Ok(1)).
        setval_results: Mutex<VecDeque<Result<u64>>>,
        // One-shot outcome for batch_execute_sql (None => Ok(())).
        batch_exec_result: Mutex<Option<Result<()>>>,
        // One-shot outcome for query_batch_applied (None => Ok(0)).
        batch_applied: Mutex<Option<Result<i32>>>,
    }

    impl MockTarget {
        // Target whose batch path succeeds and reports `applied_count`.
        fn ok(applied_count: i32) -> Self {
            Self {
                setval_results: Mutex::new(VecDeque::new()),
                batch_exec_result: Mutex::new(Some(Ok(()))),
                batch_applied: Mutex::new(Some(Ok(applied_count))),
            }
        }
        // Target whose batch path errors, forcing the individual fallback.
        fn batch_fails() -> Self {
            Self {
                setval_results: Mutex::new(VecDeque::new()),
                batch_exec_result: Mutex::new(Some(Err(MigrationError::config(
                    "batch not supported",
                )))),
                batch_applied: Mutex::new(None),
            }
        }
        // Builder-style override for the scripted setval outcomes.
        fn with_setval_results(mut self, results: Vec<Result<u64>>) -> Self {
            self.setval_results = Mutex::new(results.into());
            self
        }
    }

    #[async_trait]
    impl SeqSyncTarget for MockTarget {
        async fn execute_setval(&self, _sql: &str, _last_value: i64) -> Result<u64> {
            self.setval_results
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or(Ok(1))
        }
        async fn batch_execute_sql(&self, _sql: &str) -> Result<()> {
            self.batch_exec_result
                .lock()
                .unwrap()
                .take()
                .unwrap_or(Ok(()))
        }
        async fn query_batch_applied(&self) -> Result<i32> {
            self.batch_applied.lock().unwrap().take().unwrap_or(Ok(0))
        }
    }

    // Shorthand constructor for test sequences.
    fn seq(schema: &str, name: &str, val: Option<i64>) -> SourceSequence {
        SourceSequence {
            schema: schema.into(),
            name: name.into(),
            last_value: val,
        }
    }

    // --- apply_sequences_impl routing and counting ----------------------

    #[tokio::test]
    async fn apply_all_none_last_value_returns_zero() {
        let target = MockTarget::ok(0);
        let seqs = vec![seq("public", "s1", None), seq("public", "s2", None)];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 0);
    }

    #[tokio::test]
    async fn apply_empty_sequences_returns_zero() {
        let target = MockTarget::ok(0);
        let applied = apply_sequences_impl(&target, &[]).await.unwrap();
        assert_eq!(applied, 0);
    }

    #[tokio::test]
    async fn apply_single_sequence_success() {
        // Exactly one actionable sequence => single-statement path.
        let target = MockTarget::ok(0).with_setval_results(vec![Ok(1)]);
        let seqs = vec![seq("public", "users_id_seq", Some(42))];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 1);
    }

    #[tokio::test]
    async fn apply_single_sequence_failure_returns_zero() {
        // Single-path failures are swallowed and reported as 0 applied.
        let target = MockTarget::ok(0)
            .with_setval_results(vec![Err(MigrationError::config("permission denied"))]);
        let seqs = vec![seq("public", "users_id_seq", Some(42))];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 0);
    }

    #[tokio::test]
    async fn apply_batch_success() {
        let target = MockTarget::ok(2);
        let seqs = vec![seq("public", "s1", Some(10)), seq("public", "s2", Some(20))];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 2);
    }

    #[tokio::test]
    async fn apply_batch_reports_partial_success() {
        // Batch executes fine but reports fewer applied than requested; the
        // server-side count is what the caller sees.
        let target = MockTarget::ok(1);
        let seqs = vec![seq("public", "s1", Some(10)), seq("public", "s2", Some(20))];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 1);
    }

    #[tokio::test]
    async fn apply_batch_failure_falls_back_to_individual() {
        let target = MockTarget::batch_fails().with_setval_results(vec![Ok(1), Ok(1)]);
        let seqs = vec![seq("public", "s1", Some(10)), seq("public", "s2", Some(20))];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 2);
    }

    #[tokio::test]
    async fn apply_individually_mixed_results() {
        // Fallback path: one failure in the middle must not stop the rest.
        let target = MockTarget::batch_fails().with_setval_results(vec![
            Ok(1),
            Err(MigrationError::config("fail")),
            Ok(1),
        ]);
        let seqs = vec![
            seq("public", "s1", Some(1)),
            seq("public", "s2", Some(2)),
            seq("public", "s3", Some(3)),
        ];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 2);
    }

    #[tokio::test]
    async fn apply_filters_none_and_routes_remaining() {
        // One None is filtered out, leaving a single sequence => single path.
        let target = MockTarget::ok(0).with_setval_results(vec![Ok(1)]);
        let seqs = vec![
            seq("public", "never_used", None),
            seq("public", "used_seq", Some(99)),
        ];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 1);
    }

    #[tokio::test]
    async fn apply_filters_none_and_routes_to_batch() {
        // One None is filtered out, leaving two sequences => batch path.
        let target = MockTarget::ok(2);
        let seqs = vec![
            seq("public", "skip_me", None),
            seq("public", "s1", Some(10)),
            seq("public", "s2", Some(20)),
        ];
        let applied = apply_sequences_impl(&target, &seqs).await.unwrap();
        assert_eq!(applied, 2);
    }
}