use chrono::{DateTime, Utc};
use sqlx::{PgPool, Postgres, Row, pool::PoolConnection};
use std::sync::Arc;
use crate::backends::error::{DatabaseError, Result};
/// A single two-phase-commit session bound to one pooled connection.
///
/// Built by `PostgresTwoPhaseParticipant::begin`, which acquires the
/// connection from the pool and issues `BEGIN` on it before returning.
pub struct PgSession {
/// Pooled connection on which the session's statements are executed.
pub connection: PoolConnection<Postgres>,
/// Global transaction identifier used in `PREPARE TRANSACTION`,
/// `COMMIT PREPARED` and `ROLLBACK PREPARED` statements.
pub xid: String,
/// Current phase of the session (see `PgTwoPhaseState`).
pub state: PgTwoPhaseState,
}
/// Phase of a `PgSession` within the two-phase-commit protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgTwoPhaseState {
/// `BEGIN` has been issued; `PREPARE TRANSACTION` has not yet succeeded.
Active,
/// `PREPARE TRANSACTION` has completed successfully for the session's XID.
Prepared,
}
/// PostgreSQL participant in a two-phase commit.
///
/// Cloning is cheap: both fields are `Arc`s, so clones share the same pool
/// and the same table of managed sessions.
#[derive(Clone)]
pub struct PostgresTwoPhaseParticipant {
// Connection pool used to acquire connections for sessions and queries.
pool: Arc<PgPool>,
// Sessions created through the `*_by_xid` / `*_managed` API, keyed by XID.
// A std (blocking) mutex is used; guards are not held across `.await`.
sessions: Arc<std::sync::Mutex<std::collections::HashMap<String, PgSession>>>,
}
impl PostgresTwoPhaseParticipant {
pub fn new(pool: PgPool) -> Self {
Self {
pool: Arc::new(pool),
sessions: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
}
}
pub fn from_pool_arc(pool: Arc<PgPool>) -> Self {
Self {
pool,
sessions: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
}
}
pub fn pool(&self) -> &PgPool {
self.pool.as_ref()
}
pub async fn begin(&self, xid: &str) -> Result<PgSession> {
let mut connection = self.pool.acquire().await.map_err(DatabaseError::from)?;
sqlx::query("BEGIN")
.execute(&mut *connection)
.await
.map_err(DatabaseError::from)?;
Ok(PgSession {
connection,
xid: xid.to_string(),
state: PgTwoPhaseState::Active,
})
}
pub async fn prepare(&self, session: &mut PgSession) -> Result<()> {
let xid_escaped = pg_escape::quote_literal(&session.xid);
let sql = format!("PREPARE TRANSACTION {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
session.state = PgTwoPhaseState::Prepared;
Ok(())
}
pub async fn commit(&self, mut session: PgSession) -> Result<()> {
let xid_escaped = pg_escape::quote_literal(&session.xid);
let sql = format!("COMMIT PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn commit_by_xid(&self, xid: &str) -> Result<()> {
let mut conn = self.pool.acquire().await.map_err(DatabaseError::from)?;
let xid_escaped = pg_escape::quote_literal(xid);
let sql = format!("COMMIT PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *conn)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback(&self, mut session: PgSession) -> Result<()> {
let xid_escaped = pg_escape::quote_literal(&session.xid);
let sql = format!("ROLLBACK PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback_by_xid(&self, xid: &str) -> Result<()> {
let mut conn = self.pool.acquire().await.map_err(DatabaseError::from)?;
let xid_escaped = pg_escape::quote_literal(xid);
let sql = format!("ROLLBACK PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *conn)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn list_prepared_transactions(&self) -> Result<Vec<PreparedTransactionInfo>> {
let rows = sqlx::query(
"SELECT gid, prepared, owner, database FROM pg_prepared_xacts ORDER BY prepared",
)
.fetch_all(self.pool.as_ref())
.await
.map_err(DatabaseError::from)?;
let mut transactions = Vec::new();
for row in rows {
transactions.push(PreparedTransactionInfo {
gid: row.try_get("gid").map_err(DatabaseError::from)?,
prepared: row.try_get("prepared").map_err(DatabaseError::from)?,
owner: row.try_get("owner").map_err(DatabaseError::from)?,
database: row.try_get("database").map_err(DatabaseError::from)?,
});
}
Ok(transactions)
}
pub async fn find_prepared_transaction(
&self,
xid: &str,
) -> Result<Option<PreparedTransactionInfo>> {
let row = sqlx::query(
"SELECT gid, prepared, owner, database FROM pg_prepared_xacts WHERE gid = $1",
)
.bind(xid)
.fetch_optional(self.pool.as_ref())
.await
.map_err(DatabaseError::from)?;
if let Some(row) = row {
Ok(Some(PreparedTransactionInfo {
gid: row.try_get("gid").map_err(DatabaseError::from)?,
prepared: row.try_get("prepared").map_err(DatabaseError::from)?,
owner: row.try_get("owner").map_err(DatabaseError::from)?,
database: row.try_get("database").map_err(DatabaseError::from)?,
}))
} else {
Ok(None)
}
}
pub async fn cleanup_stale_transactions(&self, max_age: std::time::Duration) -> Result<usize> {
let max_age_secs = max_age.as_secs() as i32;
let rows = sqlx::query(
"SELECT gid FROM pg_prepared_xacts
WHERE EXTRACT(EPOCH FROM (NOW() - prepared)) > $1",
)
.bind(max_age_secs)
.fetch_all(self.pool.as_ref())
.await
.map_err(DatabaseError::from)?;
let mut cleaned = 0;
for row in rows {
let gid: String = row.try_get("gid").map_err(DatabaseError::from)?;
if self.rollback_by_xid(&gid).await.is_ok() {
cleaned += 1;
}
}
Ok(cleaned)
}
fn acquire_sessions_lock(
sessions: &std::sync::Mutex<std::collections::HashMap<String, PgSession>>,
) -> Result<std::sync::MutexGuard<'_, std::collections::HashMap<String, PgSession>>> {
match sessions.lock() {
Ok(guard) => Ok(guard),
Err(poisoned) => {
tracing::warn!(
"2PC sessions mutex was poisoned, clearing all sessions to prevent inconsistent state"
);
let mut guard = poisoned.into_inner();
guard.clear();
Ok(guard)
}
}
}
pub async fn begin_by_xid(&self, xid: &str) -> Result<()> {
let session = self.begin(xid).await?;
let mut sessions = Self::acquire_sessions_lock(&self.sessions)?;
sessions.insert(xid.to_string(), session);
Ok(())
}
pub async fn prepare_by_xid(&self, xid: &str) -> Result<()> {
let mut session = {
let mut sessions = Self::acquire_sessions_lock(&self.sessions)?;
sessions.remove(xid).ok_or_else(|| {
DatabaseError::QueryError(format!("No active session for XID: {}", xid))
})?
};
let xid_escaped = pg_escape::quote_literal(xid);
let sql = format!("PREPARE TRANSACTION {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
session.state = PgTwoPhaseState::Prepared;
let mut sessions = Self::acquire_sessions_lock(&self.sessions)?;
sessions.insert(xid.to_string(), session);
Ok(())
}
pub async fn commit_managed(&self, xid: &str) -> Result<()> {
let mut session = Self::acquire_sessions_lock(&self.sessions)?
.remove(xid)
.ok_or_else(|| {
DatabaseError::QueryError(format!("No active session for XID: {}", xid))
})?;
let xid_escaped = pg_escape::quote_literal(xid);
let sql = format!("COMMIT PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback_managed(&self, xid: &str) -> Result<()> {
let mut session = Self::acquire_sessions_lock(&self.sessions)?
.remove(xid)
.ok_or_else(|| {
DatabaseError::QueryError(format!("No active session for XID: {}", xid))
})?;
let xid_escaped = pg_escape::quote_literal(xid);
let sql = format!("ROLLBACK PREPARED {}", xid_escaped);
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
}
/// One row of PostgreSQL's `pg_prepared_xacts` view, as selected by
/// `list_prepared_transactions` and `find_prepared_transaction`.
#[derive(Debug, Clone, PartialEq)]
pub struct PreparedTransactionInfo {
/// Value of the `gid` column (the global transaction identifier).
pub gid: String,
/// Value of the `prepared` column (when the transaction was prepared).
pub prepared: DateTime<Utc>,
/// Value of the `owner` column.
pub owner: String,
/// Value of the `database` column.
pub database: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values placed into a `PreparedTransactionInfo` are read back unchanged.
    #[test]
    fn test_prepared_transaction_info_creation() {
        let gid = String::from("txn_001");
        let owner = String::from("postgres");
        let database = String::from("testdb");
        let record = PreparedTransactionInfo {
            gid,
            prepared: DateTime::UNIX_EPOCH,
            owner,
            database,
        };
        assert_eq!(record.gid, "txn_001");
        assert_eq!(record.owner, "postgres");
        assert_eq!(record.database, "testdb");
    }

    /// A cloned participant points at the very same pool allocation.
    #[tokio::test]
    async fn test_participant_clone() {
        // A lazy pool does not connect until first use, so no server is needed.
        let lazy_pool = PgPool::connect_lazy("postgresql://localhost/testdb")
            .expect("Failed to create lazy pool");
        let shared_pool = Arc::new(lazy_pool);
        let original = PostgresTwoPhaseParticipant::from_pool_arc(Arc::clone(&shared_pool));
        let duplicate = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &duplicate.pool));
    }
}