use sqlx::pool::PoolConnection;
use sqlx::{MySql, MySqlPool, Row};
use std::sync::Arc;
use crate::backends::error::{DatabaseError, Result};
/// A pooled connection on which `XA START` has been issued for `xid`.
///
/// Run the branch's statements on [`connection`](Self::connection), then hand the session to
/// [`MySqlTwoPhaseParticipant::end`] or [`MySqlTwoPhaseParticipant::rollback_started`].
#[must_use = "dropping an XA session returns its connection to the pool with the XA transaction \
              still open; finish it through MySqlTwoPhaseParticipant"]
pub struct XaSessionStarted {
    /// Connection on which `XA START` was issued; the branch's work must run here.
    pub connection: PoolConnection<MySql>,
    /// Global transaction identifier that was passed to `XA START`.
    pub xid: String,
}
/// A session on which `XA END` has been issued for `xid`.
///
/// Next step: [`MySqlTwoPhaseParticipant::prepare`] for a two-phase commit, or
/// [`MySqlTwoPhaseParticipant::commit_one_phase`] to commit without a prepare round.
#[must_use = "dropping an XA session returns its connection to the pool with the XA transaction \
              still open; finish it through MySqlTwoPhaseParticipant"]
pub struct XaSessionEnded {
    /// Connection that issued `XA START` / `XA END` for this branch.
    pub connection: PoolConnection<MySql>,
    /// Global transaction identifier of the branch.
    pub xid: String,
}
/// A session on which `XA PREPARE` has been issued for `xid`.
///
/// Finish it with [`MySqlTwoPhaseParticipant::commit`] or
/// [`MySqlTwoPhaseParticipant::rollback_prepared`]. The participant's `commit_by_xid` /
/// `rollback_by_xid` can also address the branch by xid from a freshly acquired connection.
#[must_use = "dropping an XA session returns its connection to the pool with the XA transaction \
              still open; finish it through MySqlTwoPhaseParticipant"]
pub struct XaSessionPrepared {
    /// Connection that prepared this branch.
    pub connection: PoolConnection<MySql>,
    /// Global transaction identifier of the branch.
    pub xid: String,
}
/// Two-phase-commit (XA) participant backed by a MySQL connection pool.
///
/// Cloning is cheap: every clone shares the same pool through the inner [`Arc`].
#[derive(Clone)]
pub struct MySqlTwoPhaseParticipant {
    /// Shared pool from which session and recovery connections are acquired.
    pool: Arc<MySqlPool>,
}
impl MySqlTwoPhaseParticipant {
pub fn new(pool: MySqlPool) -> Self {
Self {
pool: Arc::new(pool),
}
}
pub fn from_pool_arc(pool: Arc<MySqlPool>) -> Self {
Self { pool }
}
pub fn pool(&self) -> &MySqlPool {
self.pool.as_ref()
}
pub async fn begin(&self, xid: impl Into<String>) -> Result<XaSessionStarted> {
let mut connection = self.pool.acquire().await.map_err(DatabaseError::from)?;
let xid = xid.into();
let sql = format!("XA START '{}'", Self::escape_xid(&xid));
sqlx::raw_sql(&sql)
.execute(&mut *connection)
.await
.map_err(DatabaseError::from)?;
Ok(XaSessionStarted { connection, xid })
}
pub async fn end(&self, mut session: XaSessionStarted) -> Result<XaSessionEnded> {
let sql = format!("XA END '{}'", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(XaSessionEnded {
connection: session.connection,
xid: session.xid,
})
}
pub async fn prepare(&self, mut session: XaSessionEnded) -> Result<XaSessionPrepared> {
let sql = format!("XA PREPARE '{}'", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(XaSessionPrepared {
connection: session.connection,
xid: session.xid,
})
}
pub async fn commit(&self, mut session: XaSessionPrepared) -> Result<()> {
let sql = format!("XA COMMIT '{}'", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn commit_by_xid(&self, xid: &str) -> Result<()> {
let mut conn = self.pool.acquire().await.map_err(DatabaseError::from)?;
let sql = format!("XA COMMIT '{}'", Self::escape_xid(xid));
sqlx::raw_sql(&sql)
.execute(&mut *conn)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn commit_one_phase(&self, mut session: XaSessionEnded) -> Result<()> {
let sql = format!("XA COMMIT '{}' ONE PHASE", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback_started(&self, mut session: XaSessionStarted) -> Result<()> {
let sql = format!("XA ROLLBACK '{}'", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback_prepared(&self, mut session: XaSessionPrepared) -> Result<()> {
let sql = format!("XA ROLLBACK '{}'", Self::escape_xid(&session.xid));
sqlx::raw_sql(&sql)
.execute(&mut *session.connection)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn rollback_by_xid(&self, xid: &str) -> Result<()> {
let mut conn = self.pool.acquire().await.map_err(DatabaseError::from)?;
let sql = format!("XA ROLLBACK '{}'", Self::escape_xid(xid));
sqlx::raw_sql(&sql)
.execute(&mut *conn)
.await
.map_err(DatabaseError::from)?;
Ok(())
}
pub async fn list_prepared_transactions(&self) -> Result<Vec<XaTransactionInfo>> {
let rows = sqlx::raw_sql("XA RECOVER")
.fetch_all(self.pool.as_ref())
.await
.map_err(DatabaseError::from)?;
let mut transactions = Vec::new();
for row in rows {
let format_id: i32 = row.try_get("formatID").map_err(DatabaseError::from)?;
let gtrid_length: i32 = row.try_get("gtrid_length").map_err(DatabaseError::from)?;
let bqual_length: i32 = row.try_get("bqual_length").map_err(DatabaseError::from)?;
let data: Vec<u8> = row.try_get("data").map_err(DatabaseError::from)?;
transactions.push(XaTransactionInfo {
format_id,
gtrid_length,
bqual_length,
data: data.clone(),
xid: String::from_utf8_lossy(&data).to_string(),
});
}
Ok(transactions)
}
pub async fn find_prepared_transaction(&self, xid: &str) -> Result<Option<XaTransactionInfo>> {
let all_txns = self.list_prepared_transactions().await?;
Ok(all_txns.into_iter().find(|txn| txn.xid == xid))
}
pub async fn cleanup_stale_transactions(&self, prefix: &str) -> Result<usize> {
let all_txns = self.list_prepared_transactions().await?;
let mut cleaned = 0;
for txn in all_txns {
if txn.xid.starts_with(prefix) && self.rollback_by_xid(&txn.xid).await.is_ok() {
cleaned += 1;
}
}
Ok(cleaned)
}
fn escape_xid(xid: &str) -> String {
xid.replace('\'', "''")
}
}
/// One row of MySQL's `XA RECOVER` output.
///
/// All fields have total equality, so the type derives `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct XaTransactionInfo {
    /// The `formatID` column.
    pub format_id: i32,
    /// The `gtrid_length` column: byte length of the gtrid prefix of `data`.
    pub gtrid_length: i32,
    /// The `bqual_length` column: byte length of the bqual suffix of `data`.
    pub bqual_length: i32,
    /// The raw `data` column (gtrid followed by bqual).
    pub data: Vec<u8>,
    /// `data` decoded as UTF-8 with invalid sequences replaced.
    pub xid: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Struct fields keep exactly the values they were built with.
    #[test]
    fn test_xa_transaction_info_creation() {
        let payload = b"txn_001".to_vec();
        let info = XaTransactionInfo {
            format_id: 1,
            gtrid_length: 7,
            bqual_length: 0,
            data: payload,
            xid: "txn_001".to_string(),
        };
        assert_eq!(info.format_id, 1);
        assert_eq!(info.gtrid_length, 7);
        assert_eq!(info.bqual_length, 0);
        assert_eq!(info.xid, "txn_001");
    }

    /// Single quotes are doubled; plain text passes through unchanged.
    #[test]
    fn test_escape_xid() {
        let cases = [
            ("simple", "simple"),
            ("it's", "it''s"),
            ("a'b'c", "a''b''c"),
        ];
        for (raw, escaped) in cases {
            assert_eq!(MySqlTwoPhaseParticipant::escape_xid(raw), escaped);
        }
    }

    /// Clones share one pool allocation rather than copying it.
    #[tokio::test]
    async fn test_participant_clone() {
        let shared_pool = MySqlPool::connect_lazy("mysql://localhost/testdb")
            .map(Arc::new)
            .expect("Failed to create lazy pool");
        let original = MySqlTwoPhaseParticipant::from_pool_arc(Arc::clone(&shared_pool));
        let duplicate = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &duplicate.pool));
    }
}