use anyhow::{anyhow, Result};
use rust_decimal::Decimal;
use std::str::FromStr;
use std::sync::Arc;
use testcontainers::{ContainerAsync, GenericImage, ImageExt};
use tokio_postgres::{Client, NoTls};
/// Connection settings for a Postgres instance used in replication tests.
#[derive(Debug, Clone)]
pub struct ReplicationPostgresConfig {
    /// Hostname or IP address of the server.
    pub host: String,
    /// TCP port the server listens on (for a container, the host-mapped port).
    pub port: u16,
    /// Name of the database to connect to.
    pub database: String,
    /// Role name used to authenticate.
    pub user: String,
    /// Password for `user`.
    pub password: String,
}

impl ReplicationPostgresConfig {
    /// Renders these settings as a libpq-style `key=value` connection string suitable
    /// for `tokio_postgres::connect`.
    ///
    /// Values that are empty or contain whitespace, single quotes, or backslashes are
    /// wrapped in single quotes with `'` and `\` backslash-escaped. Without this, a value
    /// such as a password containing a space would be split apart. All other values are
    /// emitted verbatim.
    pub fn connection_string(&self) -> String {
        format!(
            "host={} port={} user={} password={} dbname={}",
            Self::quote_value(&self.host),
            self.port,
            Self::quote_value(&self.user),
            Self::quote_value(&self.password),
            Self::quote_value(&self.database),
        )
    }

    /// Quotes a single connection-string value, but only when the syntax requires it.
    fn quote_value(value: &str) -> String {
        let needs_quoting = value.is_empty()
            || value
                .chars()
                .any(|c| c.is_whitespace() || c == '\'' || c == '\\');
        if !needs_quoting {
            return value.to_owned();
        }
        let mut quoted = String::with_capacity(value.len() + 2);
        quoted.push('\'');
        for c in value.chars() {
            if c == '\'' || c == '\\' {
                quoted.push('\\');
            }
            quoted.push(c);
        }
        quoted.push('\'');
        quoted
    }
}
#[derive(Clone)]
pub struct ReplicationPostgresGuard {
inner: Arc<ReplicationPostgresGuardInner>,
}
struct ReplicationPostgresGuardInner {
container: std::sync::Mutex<Option<ContainerAsync<GenericImage>>>,
config: ReplicationPostgresConfig,
}
impl ReplicationPostgresGuard {
pub async fn new() -> Self {
let (container, config) = setup_postgres_raw().await;
Self {
inner: Arc::new(ReplicationPostgresGuardInner {
container: std::sync::Mutex::new(Some(container)),
config,
}),
}
}
pub fn config(&self) -> &ReplicationPostgresConfig {
&self.inner.config
}
pub async fn get_client(&self) -> Result<Client> {
let mut last_error = None;
for _ in 0..20 {
match tokio_postgres::connect(&self.config().connection_string(), NoTls).await {
Ok((client, connection)) => {
tokio::spawn(async move {
if let Err(e) = connection.await {
log::error!("PostgreSQL connection error: {e}");
}
});
return Ok(client);
}
Err(e) => {
last_error = Some(e);
tokio::time::sleep(std::time::Duration::from_millis(250)).await;
}
}
}
Err(anyhow!(
"Failed to connect to PostgreSQL after retries: {last_error:?}"
))
}
pub async fn cleanup(self) {
let container_to_stop = {
if let Ok(mut container_guard) = self.inner.container.lock() {
container_guard.take()
} else {
None
}
};
if let Some(container) = container_to_stop {
let container_id = container.id().to_string();
match container.stop().await {
Ok(_) => log::debug!("Successfully stopped PostgreSQL container: {container_id}"),
Err(e) => log::warn!("Error stopping container {container_id}: {e}"),
}
drop(container);
}
}
}
impl Drop for ReplicationPostgresGuardInner {
fn drop(&mut self) {
if let Ok(mut container_guard) = self.container.lock() {
if let Some(container) = container_guard.take() {
drop(container);
}
}
}
}
/// Starts a fresh Postgres test container and returns the guard that owns it.
///
/// Convenience wrapper around [`ReplicationPostgresGuard::new`]; panics under the same
/// conditions.
pub async fn setup_replication_postgres() -> ReplicationPostgresGuard {
    ReplicationPostgresGuard::new().await
}
/// Starts a `postgres:16-alpine` container configured for logical replication.
///
/// Returns the container together with connection settings for its mapped port.
/// The server runs with `wal_level=logical` and room for 10 replication slots and
/// 10 WAL senders.
///
/// # Panics
///
/// Panics if the container cannot be started or its host/port cannot be resolved.
/// Test setup cannot meaningfully continue without a database.
async fn setup_postgres_raw() -> (ContainerAsync<GenericImage>, ReplicationPostgresConfig) {
    use testcontainers::runners::AsyncRunner;

    let image = GenericImage::new("postgres", "16-alpine")
        .with_exposed_port(testcontainers::core::ContainerPort::Tcp(5432))
        .with_env_var("POSTGRES_PASSWORD", "postgres")
        .with_env_var("POSTGRES_USER", "postgres")
        .with_env_var("POSTGRES_DB", "postgres")
        .with_cmd([
            "-c",
            "wal_level=logical",
            "-c",
            "max_replication_slots=10",
            "-c",
            "max_wal_senders=10",
        ]);

    let container = image
        .start()
        .await
        .expect("Failed to start PostgreSQL container");

    // Ask testcontainers where the mapped port is reachable instead of assuming
    // "localhost"; the two differ when the Docker daemon is remote.
    let host = container
        .get_host()
        .await
        .expect("Failed to resolve Postgres host")
        .to_string();
    // IPv6 hosts are displayed in URL form (`[::1]`); the connection string needs the
    // bare address.
    let host = host
        .trim_start_matches('[')
        .trim_end_matches(']')
        .to_string();

    let pg_port = container
        .get_host_port_ipv4(5432)
        .await
        .expect("Failed to resolve Postgres port");

    let config = ReplicationPostgresConfig {
        host,
        port: pg_port,
        database: "postgres".to_string(),
        user: "postgres".to_string(),
        password: "postgres".to_string(),
    };

    // No readiness wait strategy is configured on the image, so give the server a
    // moment to come up. `get_client` retries on top of this.
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    (container, config)
}
pub async fn execute_sql(client: &Client, sql: &str) -> Result<u64> {
let result = client.execute(sql, &[]).await?;
Ok(result)
}
pub async fn create_test_table(client: &Client, table_name: &str) -> Result<()> {
let create_sql = format!(
"CREATE TABLE IF NOT EXISTS {table_name} (\n id INTEGER PRIMARY KEY,\n name TEXT NOT NULL\n)"
);
execute_sql(client, &create_sql).await?;
let replica_sql = format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL");
execute_sql(client, &replica_sql).await?;
Ok(())
}
pub async fn create_test_table_replica_identity_default(
client: &Client,
table_name: &str,
) -> Result<()> {
let create_sql = format!(
"CREATE TABLE IF NOT EXISTS {table_name} (\n id INTEGER PRIMARY KEY,\n name TEXT NOT NULL\n)"
);
execute_sql(client, &create_sql).await?;
let replica_sql = format!("ALTER TABLE {table_name} REPLICA IDENTITY DEFAULT");
execute_sql(client, &replica_sql).await?;
Ok(())
}
/// Quotes `ident` as a SQL identifier.
///
/// The result is wrapped in double quotes, and every embedded `"` is doubled so it
/// cannot terminate the identifier early.
fn quote_ident(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('"');
    for ch in ident.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Creates a publication named `publication_name` that covers `tables`.
///
/// The publication name and every table name are quoted with [`quote_ident`]. When
/// `tables` is empty, the `FOR TABLE` clause is omitted and an empty publication is
/// created; emitting `FOR TABLE` with no table list would be a syntax error.
///
/// # Errors
///
/// Returns any error from executing the statement.
pub async fn create_publication(
    client: &Client,
    publication_name: &str,
    tables: &[String],
) -> Result<()> {
    let mut sql = format!("CREATE PUBLICATION {}", quote_ident(publication_name));
    if !tables.is_empty() {
        let tables_list = tables
            .iter()
            .map(|t| quote_ident(t))
            .collect::<Vec<_>>()
            .join(", ");
        sql.push_str(" FOR TABLE ");
        sql.push_str(&tables_list);
    }
    execute_sql(client, &sql).await?;
    Ok(())
}
/// Inserts an `(id, name)` row into `table`, which is quoted via [`quote_ident`].
///
/// Both values are bound as query parameters rather than interpolated into the SQL.
///
/// # Errors
///
/// Returns any error from executing the statement (e.g. a duplicate primary key).
pub async fn insert_test_row(client: &Client, table: &str, id: i32, name: &str) -> Result<()> {
    let target = quote_ident(table);
    let statement = format!("INSERT INTO {target} (id, name) VALUES ($1, $2)");
    let _rows_inserted = client.execute(&statement, &[&id, &name]).await?;
    Ok(())
}
/// Sets `name` on the row with the given `id` in `table` (quoted via [`quote_ident`]).
///
/// The affected-row count is discarded, so matching no row is not treated as an error.
///
/// # Errors
///
/// Returns any error from executing the statement.
pub async fn update_test_row(client: &Client, table: &str, id: i32, name: &str) -> Result<()> {
    let target = quote_ident(table);
    let statement = format!("UPDATE {target} SET name = $1 WHERE id = $2");
    let _rows_updated = client.execute(&statement, &[&name, &id]).await?;
    Ok(())
}
/// Deletes the row with the given `id` from `table` (quoted via [`quote_ident`]).
///
/// The affected-row count is discarded, so matching no row is not treated as an error.
///
/// # Errors
///
/// Returns any error from executing the statement.
pub async fn delete_test_row(client: &Client, table: &str, id: i32) -> Result<()> {
    let target = quote_ident(table);
    let statement = format!("DELETE FROM {target} WHERE id = $1");
    let _rows_deleted = client.execute(&statement, &[&id]).await?;
    Ok(())
}
/// Grants the `REPLICATION` attribute to the role `user` (quoted via [`quote_ident`]).
///
/// # Errors
///
/// Returns any error from executing the `ALTER ROLE` statement.
pub async fn grant_replication(client: &Client, user: &str) -> Result<()> {
    let role = quote_ident(user);
    execute_sql(client, &format!("ALTER ROLE {role} WITH REPLICATION"))
        .await
        .map(|_| ())
}
/// Grants `SELECT` on `table` to the role `user`; both names are quoted via
/// [`quote_ident`].
///
/// # Errors
///
/// Returns any error from executing the `GRANT` statement.
pub async fn grant_table_access(client: &Client, table: &str, user: &str) -> Result<()> {
    let target = quote_ident(table);
    let grantee = quote_ident(user);
    let statement = format!("GRANT SELECT ON TABLE {target} TO {grantee}");
    execute_sql(client, &statement).await.map(|_| ())
}
/// Creates a logical replication slot named `slot_name` that uses the `pgoutput`
/// plugin.
///
/// The slot name is passed as a bound parameter. The row returned by the function
/// call is not needed and is discarded.
///
/// # Errors
///
/// Returns any error from the query, including when it does not yield exactly one row.
pub async fn create_logical_replication_slot(client: &Client, slot_name: &str) -> Result<()> {
    const CREATE_SLOT_SQL: &str = "SELECT pg_create_logical_replication_slot($1, 'pgoutput')";
    client.query_one(CREATE_SLOT_SQL, &[&slot_name]).await?;
    Ok(())
}
pub async fn create_decimal_test_table(client: &Client, table_name: &str) -> Result<()> {
let create_sql = format!(
"CREATE TABLE IF NOT EXISTS {table_name} (\n id INTEGER PRIMARY KEY,\n price NUMERIC(10, 2),\n quantity NUMERIC(15, 4),\n total NUMERIC(20, 6)\n)"
);
execute_sql(client, &create_sql).await?;
let replica_sql = format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL");
execute_sql(client, &replica_sql).await?;
Ok(())
}
pub async fn insert_decimal_test_row(
client: &Client,
table: &str,
id: i32,
price: &str,
quantity: &str,
total: &str,
) -> Result<()> {
let price_decimal = Decimal::from_str(price)?;
let quantity_decimal = Decimal::from_str(quantity)?;
let total_decimal = Decimal::from_str(total)?;
let sql = format!(
"INSERT INTO {} (id, price, quantity, total) VALUES ($1, $2, $3, $4)",
quote_ident(table)
);
client
.execute(
&sql,
&[&id, &price_decimal, &quantity_decimal, &total_decimal],
)
.await?;
Ok(())
}