use crate::{E2eError, Result};
use sqlx::postgres::PgPoolOptions;
use sqlx::{PgPool, Row};
use std::time::Duration;
use tracing::info;
/// Handle to a PostgreSQL database used by the end-to-end tests.
///
/// Holds one pool connected to the target database and a second pool
/// connected through the admin URL. Instances built by `new` drop the
/// database when the value is dropped; instances built by
/// `connect_existing` leave it in place (see `should_drop`).
pub struct PostgresResource {
/// Pool connected to `database`; every query helper runs against it.
pool: PgPool,
/// Pool opened on the admin URL given to the constructor; used for
/// `CREATE DATABASE`, backend termination and `DROP DATABASE`.
admin_pool: PgPool,
/// Name of the target database.
pub database: String,
/// Host component of the admin URL, or `"localhost"` when the URL has none.
pub host: String,
/// Port component of the admin URL, or `5432` when the URL has none.
pub port: u16,
/// Username component of the admin URL, exactly as `url::Url` reports it.
pub user: String,
/// Password component of the admin URL (empty when absent), exactly as
/// `url::Url` reports it.
// NOTE(review): presumably still percent-encoded — confirm before using it
// anywhere other than inside a URL.
pub password: String,
/// `postgres://user:password@host:port`, without a database path or query.
base_url: String,
/// Query part of the admin URL including the leading `?`, or empty.
query_string: String,
/// Whether `Drop` should drop the database: `true` for `new`, `false` for
/// `connect_existing`.
should_drop: bool,
}
impl PostgresResource {
pub async fn connect_existing(admin_url: &str, database: &str) -> Result<Self> {
let parsed = url::Url::parse(admin_url)
.map_err(|e| E2eError::Postgres(sqlx::Error::Configuration(e.to_string().into())))?;
let host = parsed.host_str().unwrap_or("localhost").to_string();
let port = parsed.port().unwrap_or(5432);
let user = parsed.username().to_string();
let password = parsed.password().unwrap_or("").to_string();
let query_string = parsed
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
let base_url = format!("postgres://{}:{}@{}:{}", user, password, host, port);
let admin_pool = PgPoolOptions::new()
.max_connections(2)
.acquire_timeout(Duration::from_secs(30))
.connect(admin_url)
.await?;
let db_url = format!("{}/{}{}", base_url, database, query_string);
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(30))
.connect(&db_url)
.await?;
info!("Connected to existing PostgreSQL database: {}", database);
Ok(Self {
pool,
admin_pool,
database: database.to_string(),
host,
port,
user,
password,
base_url,
query_string,
should_drop: false, })
}
pub async fn new(admin_url: &str, database: &str) -> Result<Self> {
let parsed = url::Url::parse(admin_url)
.map_err(|e| E2eError::Postgres(sqlx::Error::Configuration(e.to_string().into())))?;
let host = parsed.host_str().unwrap_or("localhost").to_string();
let port = parsed.port().unwrap_or(5432);
let user = parsed.username().to_string();
let password = parsed.password().unwrap_or("").to_string();
let query_string = parsed
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
let base_url = format!("postgres://{}:{}@{}:{}", user, password, host, port);
let admin_pool = PgPoolOptions::new()
.max_connections(2)
.acquire_timeout(Duration::from_secs(30))
.connect(admin_url)
.await?;
sqlx::query(&format!(
"CREATE DATABASE \"{}\"",
database.replace('"', "\"\"")
))
.execute(&admin_pool)
.await
.ok();
let db_url = format!("{}/{}{}", base_url, database, query_string);
let pool = PgPoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(30))
.connect(&db_url)
.await?;
info!("Connected to PostgreSQL database: {}", database);
Ok(Self {
pool,
admin_pool,
database: database.to_string(),
host,
port,
user,
password,
base_url,
query_string,
should_drop: true, })
}
pub fn connection_string(&self) -> String {
format!("{}/{}{}", self.base_url, self.database, self.query_string)
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub async fn execute(&self, sql: &str) -> Result<()> {
sqlx::query(sql).execute(&self.pool).await?;
Ok(())
}
pub async fn count(&self, query: &str) -> Result<i64> {
let row = sqlx::query(query).fetch_one(&self.pool).await?;
let count: i64 = row.try_get(0)?;
Ok(count)
}
pub async fn query<T>(&self, query: &str) -> Result<Vec<T>>
where
T: for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> + Send + Unpin,
{
let rows = sqlx::query_as::<_, T>(query).fetch_all(&self.pool).await?;
Ok(rows)
}
pub async fn list_tables(&self) -> Result<Vec<String>> {
#[derive(sqlx::FromRow)]
struct TableRow {
tablename: String,
}
let tables: Vec<TableRow> = self
.query("SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename")
.await?;
Ok(tables.into_iter().map(|t| t.tablename).collect())
}
pub async fn get_sample_data(&self, table: &str, limit: usize) -> Result<Vec<Vec<String>>> {
let rows = sqlx::query(&format!(
"SELECT * FROM public.\"{}\" LIMIT {}",
table, limit
))
.fetch_all(&self.pool)
.await?;
let mut results = Vec::new();
for row in rows {
let values: Vec<String> = (0..row.len())
.map(|i| {
let val: Option<String> = row.try_get(i).ok();
val.unwrap_or_else(|| "NULL".to_string())
})
.collect();
results.push(values);
}
Ok(results)
}
pub async fn get_column_names(&self, table: &str) -> Result<Vec<String>> {
let rows = sqlx::query(&format!("SELECT * FROM public.\"{}\" LIMIT 1", table))
.fetch_all(&self.pool)
.await?;
if let Some(row) = rows.first() {
use sqlx::Column;
Ok(row.columns().iter().map(|c| c.name().to_string()).collect())
} else {
#[derive(sqlx::FromRow)]
struct ColumnRow {
column_name: String,
}
let columns: Vec<ColumnRow> = self
.query(&format!(
"SELECT column_name FROM information_schema.columns WHERE table_schema = 'public' AND table_name = '{}' ORDER BY ordinal_position",
table
))
.await?;
Ok(columns.into_iter().map(|c| c.column_name).collect())
}
}
#[allow(dead_code)]
pub async fn cleanup(&self) -> Result<()> {
self.pool.close().await;
let terminate_sql = format!(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid()",
self.database.replace('\'', "''")
);
let _ = sqlx::query(&terminate_sql).execute(&self.admin_pool).await;
tokio::time::sleep(Duration::from_millis(100)).await;
let drop_sql = format!(
"DROP DATABASE IF EXISTS \"{}\"",
self.database.replace('"', "\"\"")
);
sqlx::query(&drop_sql).execute(&self.admin_pool).await?;
info!("Dropped PostgreSQL database: {}", self.database);
Ok(())
}
}
impl Drop for PostgresResource {
    /// Drops the database in a background task when `should_drop` is set.
    ///
    /// The work is spawned on the current Tokio runtime and is not awaited, so
    /// it completes only if the runtime keeps running. When no runtime is
    /// current the database cannot be dropped from here; a warning is logged
    /// so the leftover database is visible.
    fn drop(&mut self) {
        if !self.should_drop {
            return;
        }

        let handle = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle,
            Err(_) => {
                tracing::warn!(
                    "No Tokio runtime available; database {} was not dropped",
                    self.database
                );
                return;
            }
        };

        // Clone what the task needs: it must own its data because `self` is
        // gone by the time the task runs.
        let database = self.database.clone();
        let admin_pool = self.admin_pool.clone();
        let pool = self.pool.clone();

        handle.spawn(async move {
            pool.close().await;

            let terminate_sql = format!(
                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid()",
                database.replace('\'', "''")
            );
            // Best effort: the drop below is still attempted if this fails.
            if let Err(e) = sqlx::query(&terminate_sql).execute(&admin_pool).await {
                tracing::warn!(
                    "Failed to terminate connections to database {}: {}",
                    database,
                    e
                );
            }

            // Short pause between terminating backends and issuing the drop.
            tokio::time::sleep(Duration::from_millis(200)).await;

            let drop_sql = format!(
                "DROP DATABASE IF EXISTS \"{}\"",
                database.replace('"', "\"\"")
            );
            match sqlx::query(&drop_sql).execute(&admin_pool).await {
                Ok(_) => info!("Dropped PostgreSQL database: {}", database),
                Err(e) => tracing::warn!("Failed to drop database {}: {}", database, e),
            }

            admin_pool.close().await;
        });
    }
}