use crate::db::{EngineError, error::EngineErrorKind, sql_common};
use super::DbEngine;
use sqlparser;
use sqlx::{PgPool, Pool, Postgres, Transaction};
use sqlx::postgres::PgPoolOptions;
use std::process;
/// Migration engine backed by a PostgreSQL database, accessed through `sqlx`.
///
/// Both the connection pool and the transaction are opened lazily, on the
/// first operation that needs them.
pub struct PostgresEngine {
    /// Connection string used to build the `sqlx` pool; it is also passed to
    /// `pg_dump` when taking a schema snapshot.
    conn_str: String,
    /// Shared connection pool; `None` until a connection is first required.
    pool: Option<Pool<Postgres>>,
    /// Open transaction used for statements while transactions are enabled;
    /// `None` before one is started and after `commit`/`rollback` consume it.
    tx: Option<Transaction<'static, Postgres>>,
    /// When `true`, statements run directly on the pool (outside any
    /// transaction) and the migration lock is emulated with a marker row.
    flag_no_transaction: bool,
}
impl PostgresEngine {
    /// Upper bound on the number of connections the shared pool may open.
    const MAX_POOL_CONNECTIONS: u32 = 5;

    /// Creates an engine for `conn_str` without connecting.
    ///
    /// The pool is created on first use. Statements run inside a transaction
    /// unless [`PostgresEngine::disable_transactions`] is called.
    pub fn new(conn_str: &str) -> Self {
        Self {
            conn_str: conn_str.to_string(),
            pool: None,
            tx: None,
            flag_no_transaction: false,
        }
    }

    /// Makes every subsequent statement run directly on the pool instead of
    /// inside a transaction.
    ///
    /// Postgres rejects some statements inside a transaction block
    /// (e.g. `CREATE INDEX CONCURRENTLY`); this mode allows running them.
    pub fn disable_transactions(&mut self) {
        self.flag_no_transaction = true;
    }

    /// Returns a handle to the shared pool, connecting on the first call.
    ///
    /// `PgPool` is a reference-counted handle, so the returned clone is cheap
    /// and refers to the same underlying pool.
    ///
    /// # Errors
    /// Propagates the connection error if the pool cannot be created.
    async fn pool(&mut self) -> Result<PgPool, EngineError> {
        if let Some(pool) = &self.pool {
            return Ok(pool.clone());
        }
        let pool = PgPoolOptions::new()
            .max_connections(Self::MAX_POOL_CONNECTIONS)
            .connect(&self.conn_str)
            .await?;
        self.pool = Some(pool.clone());
        Ok(pool)
    }

    /// Returns the open transaction, beginning one on the shared pool if none
    /// is active.
    ///
    /// # Errors
    /// Propagates failures from connecting or from `BEGIN`.
    async fn transaction(&mut self) -> Result<&mut Transaction<'static, Postgres>, EngineError> {
        let tx = match self.tx.take() {
            Some(tx) => tx,
            None => self.pool().await?.begin().await?,
        };
        // `Option::insert` returns a reference to the stored value, so there is
        // no "impossible" `None` branch to invent an error for.
        Ok(self.tx.insert(tx))
    }

    /// Executes raw, unparameterised SQL (one or more statements).
    ///
    /// The SQL runs on the pool when transactions are disabled, and in the
    /// current transaction (started on demand) otherwise.
    async fn _execute(&mut self, sql: &str) -> Result<(), EngineError> {
        if self.flag_no_transaction {
            let pool = self.pool().await?;
            sqlx::raw_sql(sql).execute(&pool).await?;
        } else {
            let tx = self.transaction().await?;
            sqlx::raw_sql(sql).execute(&mut **tx).await?;
        }
        Ok(())
    }

    /// Connects and, unless transactions are disabled, starts the transaction.
    ///
    /// In no-transaction mode no statement ever uses `self.tx`. Opening one
    /// there would only hold a pooled connection idle in a transaction until
    /// `commit`/`rollback`.
    pub async fn begin(&mut self) -> Result<(), EngineError> {
        self.pool().await?;
        if !self.flag_no_transaction {
            self.transaction().await?;
        }
        Ok(())
    }

    /// Rolls back the open transaction, if any; a no-op otherwise.
    pub async fn rollback(&mut self) -> Result<(), EngineError> {
        if let Some(tx) = self.tx.take() {
            tx.rollback().await?;
        }
        Ok(())
    }

    /// Commits the open transaction, if any; a no-op otherwise.
    pub async fn commit(&mut self) -> Result<(), EngineError> {
        if let Some(tx) = self.tx.take() {
            tx.commit().await?;
        }
        Ok(())
    }

    /// Returns whether `sql` produces at least one row.
    ///
    /// The first column is decoded as `i32` (Postgres `INT4`), so the query is
    /// expected to select an integer, e.g. `SELECT 1 FROM ...`.
    async fn exists(&mut self, sql: &str) -> Result<bool, EngineError> {
        if self.flag_no_transaction {
            let pool = self.pool().await?;
            Ok(sqlx::query_scalar::<_, i32>(sql)
                .fetch_optional(&pool)
                .await?
                .is_some())
        } else {
            let tx = self.transaction().await?;
            Ok(sqlx::query_scalar::<_, i32>(sql)
                .fetch_optional(&mut **tx)
                .await?
                .is_some())
        }
    }

    /// Runs `sql` and returns the first column of its first row as an `i64`.
    ///
    /// Returns `Ok(None)` both when the query yields no rows and when the value
    /// is SQL `NULL` (e.g. `MAX(...)` over an empty table). Previously, an
    /// empty result failed with `RowNotFound`.
    async fn fetch_optional_i64(&mut self, sql: &str) -> Result<Option<i64>, EngineError> {
        let row = if self.flag_no_transaction {
            let pool = self.pool().await?;
            sqlx::query_scalar::<_, Option<i64>>(sql)
                .fetch_optional(&pool)
                .await?
        } else {
            let tx = self.transaction().await?;
            sqlx::query_scalar::<_, Option<i64>>(sql)
                .fetch_optional(&mut **tx)
                .await?
        };
        Ok(row.flatten())
    }
}
impl DbEngine for PostgresEngine {
    /// Creates the `swellow` schema, the `pgcrypto` extension and the
    /// `swellow.records` bookkeeping table if they do not exist yet.
    ///
    /// The DDL runs directly on the shared pool (not in the migration
    /// transaction), so each statement is committed on its own.
    async fn ensure_table(&mut self) -> Result<(), EngineError> {
        // Reuse the engine's pool rather than building (and discarding) a new,
        // differently-sized pool on every call.
        let pool = self.pool().await?;
        sqlx::query("CREATE SCHEMA IF NOT EXISTS swellow;")
            .execute(&pool)
            .await?;
        sqlx::query("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
            .execute(&pool)
            .await?;
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS swellow.records (
                oid OID,
                version_id BIGINT NOT NULL,
                object_type TEXT NOT NULL,
                object_name_before TEXT NOT NULL,
                object_name_after TEXT NOT NULL,
                status TEXT NOT NULL,
                checksum TEXT NOT NULL,
                dtm_created_at TIMESTAMP DEFAULT now(),
                dtm_updated_at TIMESTAMP DEFAULT now(),
                PRIMARY KEY (version_id, object_type, object_name_before, object_name_after)
            );
            "#,
        )
        .execute(&pool)
        .await?;
        Ok(())
    }

    /// Executes raw SQL on the pool or in the current transaction.
    async fn execute(&mut self, sql: &str) -> Result<(), EngineError> {
        self._execute(sql).await
    }

    /// Returns the value produced by `sql_common::QUERY_LATEST_VERSION`, or
    /// `None` when it yields no row or `NULL`.
    async fn fetch_latest_applied_version(&mut self) -> Result<Option<i64>, EngineError> {
        self.fetch_optional_i64(sql_common::QUERY_LATEST_VERSION).await
    }

    /// Takes the migration lock.
    ///
    /// * Transactional mode: `LOCK TABLE swellow.records IN ACCESS EXCLUSIVE
    ///   MODE`, which Postgres holds until the transaction commits or rolls back.
    /// * No-transaction mode: inserts a marker row
    ///   (`version_id = 0`, `'LOCK'` in the key columns) after checking
    ///   `sql_common::QUERY_LOCK_EXISTS`.
    ///
    /// # Errors
    /// `LockConflict` if `QUERY_LOCK_EXISTS` returns a row (no-transaction mode).
    async fn acquire_lock(&mut self) -> Result<(), EngineError> {
        if !self.flag_no_transaction {
            self.execute("LOCK TABLE swellow.records IN ACCESS EXCLUSIVE MODE;")
                .await?;
            return Ok(());
        }

        if self.exists(sql_common::QUERY_LOCK_EXISTS).await? {
            return Err(EngineError {
                kind: EngineErrorKind::LockConflict,
            });
        }
        // The check above and this INSERT are not atomic. The marker row always
        // has the same primary key, so if a concurrent runner inserts first,
        // this INSERT fails and at most one runner holds the lock.
        self.execute(
            r#"
            INSERT INTO swellow.records (
                version_id,
                object_type,
                object_name_before,
                object_name_after,
                status,
                checksum,
                dtm_created_at,
                dtm_updated_at
            )
            VALUES (
                0,
                'LOCK',
                'LOCK',
                'LOCK',
                'LOCKED',
                'LOCK',
                now(),
                now()
            )
            "#,
        )
        .await?;
        Ok(())
    }

    /// Releases the migration lock.
    ///
    /// Only no-transaction mode has work to do: it runs
    /// `sql_common::QUERY_DELETE_LOCK`. The table lock taken in transactional
    /// mode is released by `commit`/`rollback`.
    async fn release_lock(&mut self) -> Result<(), EngineError> {
        if self.flag_no_transaction {
            self.execute(sql_common::QUERY_DELETE_LOCK).await?;
        }
        Ok(())
    }

    /// Marks every record with `version_id > current_version_id` as `DISABLED`
    /// and stamps `dtm_updated_at`.
    async fn disable_records(&mut self, current_version_id: i64) -> Result<(), EngineError> {
        let query = r#"
            UPDATE swellow.records
            SET
                status = 'DISABLED',
                dtm_updated_at = now()
            WHERE version_id > $1
        "#;
        if self.flag_no_transaction {
            let pool = self.pool().await?;
            sqlx::query(query)
                .bind(current_version_id)
                .execute(&pool)
                .await?;
        } else {
            let tx = self.transaction().await?;
            sqlx::query(query)
                .bind(current_version_id)
                .execute(&mut **tx)
                .await?;
        }
        Ok(())
    }

    /// Inserts a `READY` record for the given object/version, or resets the
    /// status and checksum of an existing record with the same primary key.
    ///
    /// `checksum` is hashed server-side with `md5()` before being stored.
    /// `dtm_updated_at` is refreshed when an existing row is updated.
    async fn upsert_record(
        &mut self,
        object_type: &sqlparser::ast::ObjectType,
        object_name_before: &str,
        object_name_after: &str,
        version_id: i64,
        checksum: &str,
    ) -> Result<(), EngineError> {
        let query = r#"
            INSERT INTO swellow.records(
                object_type,
                object_name_before,
                object_name_after,
                version_id,
                status,
                checksum
            )
            VALUES (
                $1,
                $2,
                $3,
                $4,
                'READY',
                md5($5)
            )
            ON CONFLICT (version_id, object_type, object_name_before, object_name_after)
            DO UPDATE SET
                status = EXCLUDED.status,
                checksum = EXCLUDED.checksum,
                dtm_updated_at = now()
        "#;
        // `ObjectType` is stored through its `Display` text.
        let object_type = object_type.to_string();
        if self.flag_no_transaction {
            let pool = self.pool().await?;
            sqlx::query(query)
                .bind(object_type)
                .bind(object_name_before)
                .bind(object_name_after)
                .bind(version_id)
                .bind(checksum)
                .execute(&pool)
                .await?;
        } else {
            let tx = self.transaction().await?;
            sqlx::query(query)
                .bind(object_type)
                .bind(object_name_before)
                .bind(object_name_after)
                .bind(version_id)
                .bind(checksum)
                .execute(&mut **tx)
                .await?;
        }
        Ok(())
    }

    /// Sets `status` (and `dtm_updated_at`) on every record of `version_id`.
    async fn update_record(&mut self, status: &str, version_id: i64) -> Result<(), EngineError> {
        let query = r#"
            UPDATE swellow.records
            SET
                status = $1,
                dtm_updated_at = now()
            WHERE
                version_id = $2
        "#;
        if self.flag_no_transaction {
            let pool = self.pool().await?;
            sqlx::query(query)
                .bind(status)
                .bind(version_id)
                .execute(&pool)
                .await?;
        } else {
            let tx = self.transaction().await?;
            sqlx::query(query)
                .bind(status)
                .bind(version_id)
                .execute(&mut **tx)
                .await?;
        }
        Ok(())
    }

    /// Returns a schema-only `pg_dump` of the database, without ownership or
    /// privilege statements.
    ///
    /// # Errors
    /// * `Process` if `pg_dump` cannot be started (e.g. not installed or not on `PATH`).
    /// * `PGDump` carrying pg_dump's stderr if it exits unsuccessfully.
    /// * `Utf8` if the dump written to stdout is not valid UTF-8.
    async fn snapshot(&mut self) -> Result<String, EngineError> {
        // NOTE(review): `std::process::Command::output` blocks the executor
        // thread while pg_dump runs. The connection string (which may contain a
        // password) is also visible in process listings.

        // Probe for the binary first. A missing pg_dump is reported as an error
        // instead of terminating the host process with `exit(1)`.
        if let Err(source) = process::Command::new("pg_dump").arg("--version").output() {
            tracing::error!("pg_dump not installed or not in PATH.");
            return Err(EngineError {
                kind: EngineErrorKind::Process {
                    source,
                    cmd: "pg_dump --version".to_string(),
                },
            });
        }

        let output = process::Command::new("pg_dump")
            .args(["--schema-only", "--no-owner", "--no-privileges"])
            .arg(&self.conn_str)
            .output()
            .map_err(|source| EngineError {
                kind: EngineErrorKind::Process {
                    source,
                    cmd: "pg_dump --schema-only --no-owner --no-privileges".to_string(),
                },
            })?;

        if !output.status.success() {
            // Decode lossily so that a stray non-UTF-8 byte in the diagnostics
            // cannot replace the real pg_dump failure with a UTF-8 error.
            let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
            return Err(EngineError {
                kind: EngineErrorKind::PGDump { stderr },
            });
        }

        String::from_utf8(output.stdout).map_err(|e| EngineError {
            kind: EngineErrorKind::Utf8(e),
        })
    }
}