use serde::{Deserialize, Serialize};
use tokio_postgres::Client;
use tracing::{debug, info, warn};
use crate::error::Result;
use crate::tls::connect_with_sslmode;
/// A sequence discovered on the source database together with its last value.
///
/// `schema` and `name` hold the raw (unquoted) catalog names; quoting happens
/// later when the `setval` statement is built. `last_value` is `None` when the
/// source reported SQL `NULL` for the sequence, which this module treats as
/// "never advanced" and skips.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct SourceSequence {
    /// Schema (namespace) that owns the sequence, unquoted.
    pub schema: String,
    /// Sequence name inside `schema`, unquoted.
    pub name: String,
    /// Value reported by `pg_sequence_last_value` on the source, if any.
    pub last_value: Option<i64>,
}
/// Lists every sequence (`relkind = 'S'`) outside the system schemas.
///
/// Excluded: `pg_catalog`, `information_schema`, `pg_toast`, and the per-session
/// `pg_temp_*` / `pg_toast_temp_*` schemas. The SELECT list is, in order:
/// schema name, sequence name, and `pg_sequence_last_value(...)` (SQL `NULL`
/// when the sequence has not been called since creation/reset).
///
/// Note: `_` is a `LIKE` wildcard, so the temp-schema patterns are slightly
/// broader than literal prefixes; this is harmless in practice because the
/// `pg_` prefix is reserved for system schemas.
pub const COLLECT_SEQUENCES_SQL_NO_FILTER: &str = concat!(
    "SELECT n.nspname::text, c.relname::text, ",
    "pg_sequence_last_value(c.oid::regclass) AS last_value ",
    "FROM pg_class c ",
    "JOIN pg_namespace n ON n.oid = c.relnamespace ",
    "WHERE c.relkind = 'S' ",
    "AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') ",
    "AND n.nspname NOT LIKE 'pg_temp_%' ",
    "AND n.nspname NOT LIKE 'pg_toast_temp_%'",
);
/// Lists sequences (`relkind = 'S'`) whose schema name is in the `text[]`
/// bound to `$1`.
///
/// No system-schema exclusion is applied; the caller's explicit list decides.
/// The SELECT list matches [`COLLECT_SEQUENCES_SQL_NO_FILTER`]: schema name,
/// sequence name, `pg_sequence_last_value(...)`.
pub const COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER: &str = concat!(
    "SELECT n.nspname::text, c.relname::text, ",
    "pg_sequence_last_value(c.oid::regclass) AS last_value ",
    "FROM pg_class c ",
    "JOIN pg_namespace n ON n.oid = c.relnamespace ",
    "WHERE c.relkind = 'S' ",
    "AND n.nspname = ANY($1::text[])",
);
pub async fn collect_source_sequences(
source: &Client,
schema_filter: &[String],
) -> Result<Vec<SourceSequence>> {
let rows = if schema_filter.is_empty() {
source.query(COLLECT_SEQUENCES_SQL_NO_FILTER, &[]).await?
} else {
source
.query(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER, &[&schema_filter])
.await?
};
let mut out = Vec::with_capacity(rows.len());
for row in rows {
out.push(SourceSequence {
schema: row.get(0),
name: row.get(1),
last_value: row.get(2),
});
}
Ok(out)
}
/// Pushes each collected sequence value to `target` via `setval(..., true)`.
///
/// Sequences whose `last_value` is `None` are left untouched on the target.
/// A failing `setval` statement is logged at `warn` level and does not stop the
/// loop, so the sync is best-effort per sequence.
///
/// Returns the number of sequences for which `setval` succeeded.
///
/// # Errors
///
/// Returns an error, and stops processing the remaining sequences, only if
/// [`build_setval_sql`] fails to quote a schema or sequence name.
pub async fn apply_sequences_to_target(
    target: &Client,
    sequences: &[SourceSequence],
) -> Result<usize> {
    let mut synced = 0usize;
    for seq in sequences {
        let last_value = match seq.last_value {
            Some(value) => value,
            None => {
                debug!(
                    schema = %seq.schema,
                    name = %seq.name,
                    "skipping sequence: never advanced on source"
                );
                continue;
            }
        };

        let statement = build_setval_sql(&seq.schema, &seq.name)?;
        if let Err(e) = target.execute(&statement, &[&last_value]).await {
            warn!(
                schema = %seq.schema,
                name = %seq.name,
                error = %e,
                "failed to sync sequence (continuing)"
            );
            continue;
        }

        synced += 1;
        debug!(schema = %seq.schema, name = %seq.name, last_value, "synced sequence");
    }
    Ok(synced)
}
/// Builds the `setval` statement for one sequence.
///
/// The result has the shape
/// `SELECT setval('<schema>.<name>'::regclass, $1::bigint, true)`: both
/// identifiers are quoted with `pg_walstream::quote_ident`, and the qualified
/// name is then quoted as a string literal so it can be cast to `regclass`. The
/// caller must bind the new value (an `i64`) to `$1`. Passing `true` as the
/// third argument marks the value as already used, so the next `nextval`
/// returns the value after it.
///
/// # Errors
///
/// Propagates any error from `pg_walstream::quote_ident` or
/// `pg_walstream::quote_literal`.
pub fn build_setval_sql(schema: &str, name: &str) -> Result<String> {
    let quoted_schema = pg_walstream::quote_ident(schema)?;
    let quoted_name = pg_walstream::quote_ident(name)?;
    let regclass_literal =
        pg_walstream::quote_literal(&format!("{quoted_schema}.{quoted_name}"))?;
    Ok(format!(
        "SELECT setval({regclass_literal}::regclass, $1::bigint, true)"
    ))
}
/// Copies sequence positions from the source database to the target database.
///
/// Opens one connection to each side with `connect_with_sslmode`, collects the
/// sequences selected by `schema_filter` from the source, and applies them to
/// the target. Returns how many sequences were successfully updated.
///
/// The `skipped` count in the final log line is `total - applied`. It therefore
/// includes both sequences with no last value and sequences whose `setval`
/// failed on the target.
///
/// # Errors
///
/// Returns an error if either connection cannot be established, if collecting
/// sequences from the source fails, or if a `setval` statement cannot be built.
pub async fn sync_sequences(
    source_conn: &str,
    target_conn: &str,
    schema_filter: &[String],
) -> Result<usize> {
    info!("syncing sequences from source to target");

    let source_client = connect_with_sslmode(source_conn).await?;
    let target_client = connect_with_sslmode(target_conn).await?;

    let collected = collect_source_sequences(&source_client, schema_filter).await?;
    let total = collected.len();
    info!(total, "collected sequences from source");

    let applied = apply_sequences_to_target(&target_client, &collected).await?;
    let skipped = total - applied;
    info!(applied, skipped, "sequence sync complete");

    Ok(applied)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SELECT list shared by both collection queries. `collect_source_sequences`
    /// decodes columns 0, 1 and 2 by index, so this order must not drift.
    const EXPECTED_SELECT_PREFIX: &str = "SELECT n.nspname::text, c.relname::text, \
        pg_sequence_last_value(c.oid::regclass) AS last_value ";

    #[test]
    fn collect_sql_no_filter_includes_pg_sequence_last_value_function() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_sequence_last_value"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("c.relkind = 'S'"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_catalog"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("information_schema"));
    }

    #[test]
    fn collect_sql_no_filter_excludes_pg_toast() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("'pg_toast'"));
    }

    #[test]
    fn collect_sql_queries_select_columns_in_decode_order() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.starts_with(EXPECTED_SELECT_PREFIX));
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.starts_with(EXPECTED_SELECT_PREFIX));
    }

    #[test]
    fn collect_sql_with_schema_filter_uses_parameterised_array() {
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("$1::text[]"));
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("ANY"));
    }

    #[test]
    fn source_sequence_serde_roundtrip() {
        let s = SourceSequence {
            schema: "public".into(),
            name: "widgets_id_seq".into(),
            last_value: Some(100),
        };
        let json = serde_json::to_string(&s).unwrap();
        let s2: SourceSequence = serde_json::from_str(&json).unwrap();
        assert_eq!(s, s2);
    }

    #[test]
    fn source_sequence_handles_never_advanced() {
        // A never-advanced sequence (`None`) must stay distinguishable from one
        // whose last value is 0; only the former is skipped during sync.
        let fresh = SourceSequence {
            schema: "public".into(),
            name: "fresh_seq".into(),
            last_value: None,
        };
        let at_zero = SourceSequence {
            last_value: Some(0),
            ..fresh.clone()
        };
        assert!(fresh.last_value.is_none());
        assert_ne!(fresh, at_zero);
    }

    #[test]
    fn build_setval_sql_quotes_identifiers_and_literal() {
        let sql = build_setval_sql("public", "widgets_id_seq").unwrap();
        assert_eq!(
            sql,
            "SELECT setval('\"public\".\"widgets_id_seq\"'::regclass, $1::bigint, true)"
        );
    }

    #[test]
    fn build_setval_sql_escapes_embedded_double_quote() {
        let sql = build_setval_sql("we\"ird", "seq").unwrap();
        assert!(sql.contains("\"we\"\"ird\""));
        assert!(sql.contains("'\"we\"\"ird\".\"seq\"'::regclass"));
    }

    #[test]
    fn build_setval_sql_escapes_embedded_single_quote() {
        let sql = build_setval_sql("public", "o'reilly").unwrap();
        assert!(sql.contains("''"));
        assert!(sql.starts_with("SELECT setval('"));
        assert!(sql.contains("::regclass, $1::bigint, true)"));
        // The single quote is harmless inside the identifier but must be
        // doubled once the qualified name is embedded in the string literal.
        assert_eq!(
            sql,
            "SELECT setval('\"public\".\"o''reilly\"'::regclass, $1::bigint, true)"
        );
    }

    #[test]
    fn build_setval_sql_handles_spaces_in_identifiers() {
        let sql = build_setval_sql("my schema", "my seq").unwrap();
        assert!(sql.contains("\"my schema\""));
        assert!(sql.contains("\"my seq\""));
    }

    #[test]
    fn build_setval_sql_handles_backtick_in_identifiers() {
        let sql = build_setval_sql("pub`lic", "seq`name").unwrap();
        assert!(sql.contains("\"pub`lic\""));
        assert!(sql.contains("\"seq`name\""));
    }

    #[test]
    fn source_sequence_none_last_value_serde_roundtrip() {
        let s = SourceSequence {
            schema: "public".into(),
            name: "fresh".into(),
            last_value: None,
        };
        let json = serde_json::to_string(&s).unwrap();
        assert!(json.contains("null"));
        let s2: SourceSequence = serde_json::from_str(&json).unwrap();
        assert_eq!(s, s2);
        assert!(s2.last_value.is_none());
    }

    #[test]
    fn collect_sql_no_filter_excludes_temp_schemas() {
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_temp_%"));
        assert!(COLLECT_SEQUENCES_SQL_NO_FILTER.contains("pg_toast_temp_%"));
    }

    #[test]
    fn collect_sql_with_schema_filter_uses_any_operator() {
        assert!(COLLECT_SEQUENCES_SQL_WITH_SCHEMA_FILTER.contains("ANY($1::text[])"));
    }
}