use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::Mutex as TokioMutex;
#[cfg(feature = "postgres")]
use crate::backends::PostgresTwoPhaseParticipant;
#[cfg(feature = "mysql")]
use crate::backends::{MySqlTwoPhaseParticipant, XaSessionPrepared, XaSessionStarted};
/// Errors produced by the two-phase commit coordinators and their participants.
///
/// Variants carrying two strings hold `(participant id, reason)`; single-string variants
/// hold a free-form description.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TwoPhaseError {
    /// Phase one failed for a participant: `(participant, reason)`.
    PrepareFailed(String, String),
    /// Phase two (commit) failed for a participant: `(participant, reason)`.
    CommitFailed(String, String),
    /// Rolling back failed for a participant: `(participant, reason)`.
    RollbackFailed(String, String),
    /// An operation was attempted in a state that does not allow it.
    InvalidState(String),
    /// Prepare was requested with no registered participants.
    NoParticipants,
    /// A transaction with the given id already exists.
    DuplicateTransactionId(String),
    /// A participant with the given id is already registered.
    DuplicateParticipant(String),
    /// A connection-level failure.
    ConnectionError(String),
    /// An operation timed out.
    Timeout(String),
    /// Listing or resolving in-doubt transactions failed.
    RecoveryFailed(String),
    /// Reading or writing the transaction log failed.
    LogError(String),
    /// A database backend reported an error.
    DatabaseError(String),
}

impl std::fmt::Display for TwoPhaseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PrepareFailed(who, why) => write!(f, "Prepare failed for '{}': {}", who, why),
            Self::CommitFailed(who, why) => write!(f, "Commit failed for '{}': {}", who, why),
            Self::RollbackFailed(who, why) => {
                write!(f, "Rollback failed for '{}': {}", who, why)
            }
            Self::InvalidState(detail) => write!(f, "Invalid state: {}", detail),
            Self::NoParticipants => f.write_str("No participants registered"),
            Self::DuplicateTransactionId(xid) => {
                write!(f, "Transaction ID '{}' already exists", xid)
            }
            Self::DuplicateParticipant(who) => {
                write!(f, "Participant '{}' already registered", who)
            }
            Self::ConnectionError(detail) => write!(f, "Connection error: {}", detail),
            Self::Timeout(detail) => write!(f, "Timeout: {}", detail),
            Self::RecoveryFailed(detail) => write!(f, "Recovery failed: {}", detail),
            Self::LogError(detail) => write!(f, "Transaction log error: {}", detail),
            Self::DatabaseError(detail) => write!(f, "Database error: {}", detail),
        }
    }
}

impl std::error::Error for TwoPhaseError {}
/// Lifecycle state of a two-phase commit transaction.
///
/// `TwoPhaseCommit` only uses `NotStarted`, `Active`, `Prepared`, `Committed` and `Aborted`.
/// `TwoPhaseCoordinator` additionally passes through the transient `Preparing`, `Committing`
/// and `Aborting` states while it talks to participants.
///
/// NOTE(review): serde derives are presumably for the transaction log entries — confirm
/// before reordering or renaming variants, since that could change the persisted form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TransactionState {
    /// Created, `begin` not yet called.
    NotStarted,
    /// `begin` succeeded; participants may do work.
    Active,
    /// Phase one in progress (coordinator only).
    Preparing,
    /// Every participant voted yes in phase one.
    Prepared,
    /// Phase two commit in progress (coordinator only).
    Committing,
    /// All participants committed.
    Committed,
    /// Rollback in progress (coordinator only).
    Aborting,
    /// All participants rolled back.
    Aborted,
}
/// Per-participant progress as tracked by the coordinators.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParticipantStatus {
    /// Registered and not yet prepared.
    Active,
    /// Prepare succeeded.
    Prepared,
    /// Commit succeeded.
    Committed,
    /// Rollback succeeded.
    Aborted,
}
/// A resource taking part in a transaction driven by `TwoPhaseCoordinator`.
///
/// The coordinator calls `begin` on every participant, then `prepare` with the global
/// transaction id, and finally either `commit` or `rollback` with that same id. After each
/// successful phase it records the outcome through `set_status`.
///
/// NOTE(review): the futures are declared `?Send`, so futures awaiting these methods are not
/// `Send` even though the trait requires `Send + Sync` — confirm this is intended.
#[async_trait(?Send)]
pub trait TwoPhaseParticipant: Send + Sync {
    /// Identifier of this participant; the coordinator rejects duplicates by this value.
    fn id(&self) -> &str;
    /// Starts this participant's local work.
    async fn begin(&self) -> Result<(), TwoPhaseError>;
    /// Phase one: durably prepare the transaction identified by `xid`.
    async fn prepare(&self, xid: String) -> Result<(), TwoPhaseError>;
    /// Phase two: commit the prepared transaction identified by `xid`.
    async fn commit(&self, xid: String) -> Result<(), TwoPhaseError>;
    /// Roll back the transaction identified by `xid`, whether or not it was prepared.
    async fn rollback(&self, xid: String) -> Result<(), TwoPhaseError>;
    /// Lists the ids of transactions this participant reports as prepared (in doubt).
    async fn recover(&self) -> Result<Vec<String>, TwoPhaseError>;
    /// Returns the status last recorded for this participant.
    fn status(&self) -> ParticipantStatus;
    /// Records a new status; called by the coordinator after a successful phase.
    fn set_status(&mut self, status: ParticipantStatus);
}
/// A participant tracked by `TwoPhaseCommit`, identified by its database alias.
#[derive(Debug, Clone)]
pub struct Participant {
    /// Alias of the database this participant represents (also its key in the map).
    pub db_alias: String,
    /// Current progress of this participant.
    pub status: ParticipantStatus,
}
impl Participant {
pub fn new(db_alias: impl Into<String>) -> Self {
Self {
db_alias: db_alias.into(),
status: ParticipantStatus::Active,
}
}
pub fn is_prepared(&self) -> bool {
matches!(self.status, ParticipantStatus::Prepared)
}
}
/// Synchronous two-phase commit state machine over participants keyed by database alias.
///
/// Each phase returns the SQL statements it would issue, formatted as `"alias: SQL"`. The
/// methods in this type only build those strings; they do not execute them.
#[derive(Debug)]
pub struct TwoPhaseCommit {
    /// Global transaction identifier embedded in the generated SQL.
    transaction_id: String,
    /// Current lifecycle state.
    state: Arc<StdMutex<TransactionState>>,
    /// Registered participants keyed by database alias.
    participants: Arc<StdMutex<HashMap<String, Participant>>>,
}
impl TwoPhaseCommit {
pub fn new(transaction_id: impl Into<String>) -> Self {
Self {
transaction_id: transaction_id.into(),
state: Arc::new(StdMutex::new(TransactionState::NotStarted)),
participants: Arc::new(StdMutex::new(HashMap::new())),
}
}
pub fn transaction_id(&self) -> &str {
&self.transaction_id
}
pub fn state(&self) -> Result<TransactionState, TwoPhaseError> {
self.state
.lock()
.map(|s| *s)
.map_err(|_| TwoPhaseError::InvalidState("Failed to acquire state lock".to_string()))
}
pub fn begin(&mut self) -> Result<(), TwoPhaseError> {
let mut state = self
.state
.lock()
.map_err(|_| TwoPhaseError::InvalidState("Failed to acquire state lock".to_string()))?;
if *state != TransactionState::NotStarted {
return Err(TwoPhaseError::InvalidState(
"Transaction already started".to_string(),
));
}
*state = TransactionState::Active;
Ok(())
}
pub fn add_participant(&mut self, db_alias: impl Into<String>) -> Result<(), TwoPhaseError> {
let state = self.state().unwrap();
if state != TransactionState::Active {
return Err(TwoPhaseError::InvalidState(
"Can only add participants to active transaction".to_string(),
));
}
let db_alias = db_alias.into();
let mut participants = self.participants.lock().map_err(|_| {
TwoPhaseError::InvalidState("Failed to acquire participants lock".to_string())
})?;
if participants.contains_key(&db_alias) {
return Err(TwoPhaseError::DuplicateParticipant(db_alias));
}
participants.insert(db_alias.clone(), Participant::new(db_alias));
Ok(())
}
pub fn participant_count(&self) -> usize {
self.participants.lock().map(|p| p.len()).unwrap_or(0)
}
pub fn prepare(&mut self) -> Result<Vec<String>, TwoPhaseError> {
let state = self.state().unwrap();
if state != TransactionState::Active {
return Err(TwoPhaseError::InvalidState(
"Can only prepare active transaction".to_string(),
));
}
let mut participants = self.participants.lock().map_err(|_| {
TwoPhaseError::InvalidState("Failed to acquire participants lock".to_string())
})?;
if participants.is_empty() {
return Err(TwoPhaseError::NoParticipants);
}
let mut prepared_sqls = Vec::new();
for (db_alias, participant) in participants.iter_mut() {
let sql = format!("PREPARE TRANSACTION '{}'", self.transaction_id);
prepared_sqls.push(format!("{}: {}", db_alias, sql));
participant.status = ParticipantStatus::Prepared;
}
let mut state = self
.state
.lock()
.map_err(|_| TwoPhaseError::InvalidState("Failed to acquire state lock".to_string()))?;
*state = TransactionState::Prepared;
Ok(prepared_sqls)
}
pub fn commit(&mut self) -> Result<Vec<String>, TwoPhaseError> {
let state = self.state().unwrap();
if state != TransactionState::Prepared {
return Err(TwoPhaseError::InvalidState(
"Can only commit prepared transaction".to_string(),
));
}
let mut participants = self.participants.lock().map_err(|_| {
TwoPhaseError::InvalidState("Failed to acquire participants lock".to_string())
})?;
let mut commit_sqls = Vec::new();
for (db_alias, participant) in participants.iter_mut() {
if !participant.is_prepared() {
return Err(TwoPhaseError::CommitFailed(
db_alias.clone(),
"Participant not prepared".to_string(),
));
}
let sql = format!("COMMIT PREPARED '{}'", self.transaction_id);
commit_sqls.push(format!("{}: {}", db_alias, sql));
participant.status = ParticipantStatus::Committed;
}
let mut state = self
.state
.lock()
.map_err(|_| TwoPhaseError::InvalidState("Failed to acquire state lock".to_string()))?;
*state = TransactionState::Committed;
Ok(commit_sqls)
}
pub fn rollback(&mut self) -> Result<Vec<String>, TwoPhaseError> {
let state = self.state().unwrap();
if state != TransactionState::Active
&& state != TransactionState::Prepared
&& state != TransactionState::Preparing
{
return Err(TwoPhaseError::InvalidState(
"Can only rollback active, prepared, or preparing transaction".to_string(),
));
}
let mut participants = self.participants.lock().map_err(|_| {
TwoPhaseError::InvalidState("Failed to acquire participants lock".to_string())
})?;
let mut rollback_sqls = Vec::new();
for (db_alias, participant) in participants.iter_mut() {
let sql = if participant.is_prepared() {
format!("ROLLBACK PREPARED '{}'", self.transaction_id)
} else {
"ROLLBACK".to_string()
};
rollback_sqls.push(format!("{}: {}", db_alias, sql));
participant.status = ParticipantStatus::Aborted;
}
let mut state = self
.state
.lock()
.map_err(|_| TwoPhaseError::InvalidState("Failed to acquire state lock".to_string()))?;
*state = TransactionState::Aborted;
Ok(rollback_sqls)
}
pub fn participants(&self) -> HashMap<String, Participant> {
self.participants
.lock()
.map(|p| p.clone())
.unwrap_or_default()
}
pub fn all_prepared(&self) -> bool {
self.participants
.lock()
.map(|p| p.values().all(|participant| participant.is_prepared()))
.unwrap_or(false)
}
}
impl Default for TwoPhaseCommit {
    /// Builds a transaction whose id is a freshly generated UUIDv7 string.
    fn default() -> Self {
        let generated_id = uuid::Uuid::now_v7().to_string();
        Self::new(generated_id)
    }
}
/// Asynchronous two-phase commit coordinator.
///
/// It drives boxed [`TwoPhaseParticipant`]s through begin, prepare, and commit or rollback,
/// and can record state transitions in a transaction log.
pub struct TwoPhaseCoordinator {
    /// Global transaction id passed to every participant's `prepare`/`commit`/`rollback`.
    transaction_id: String,
    /// Current lifecycle state.
    state: Arc<TokioMutex<TransactionState>>,
    /// Registered participants, visited in registration order.
    participants: Arc<TokioMutex<Vec<Box<dyn TwoPhaseParticipant>>>>,
    /// Optional log written on state transitions.
    ///
    /// The entry is deleted, on a best-effort basis, once the transaction commits or aborts.
    transaction_log: Option<Arc<dyn super::transaction_log::TransactionLog>>,
}
impl TwoPhaseCoordinator {
pub fn new(transaction_id: impl Into<String>) -> Self {
Self {
transaction_id: transaction_id.into(),
state: Arc::new(TokioMutex::new(TransactionState::NotStarted)),
participants: Arc::new(TokioMutex::new(Vec::new())),
transaction_log: None,
}
}
pub fn with_log(
transaction_id: impl Into<String>,
log: Arc<dyn super::transaction_log::TransactionLog>,
) -> Self {
Self {
transaction_id: transaction_id.into(),
state: Arc::new(TokioMutex::new(TransactionState::NotStarted)),
participants: Arc::new(TokioMutex::new(Vec::new())),
transaction_log: Some(log),
}
}
async fn log_state(&self, state: TransactionState) -> Result<(), TwoPhaseError> {
if let Some(log) = &self.transaction_log {
let participants = self.participants.lock().await;
let participant_ids: Vec<String> =
participants.iter().map(|p| p.id().to_string()).collect();
let entry = super::transaction_log::TransactionLogEntry::new(
&self.transaction_id,
state,
participant_ids,
);
log.write(&entry)?;
}
Ok(())
}
pub fn transaction_id(&self) -> &str {
&self.transaction_id
}
pub async fn state(&self) -> Result<TransactionState, TwoPhaseError> {
let state = self.state.lock().await;
Ok(*state)
}
pub async fn add_participant(
&mut self,
participant: Box<dyn TwoPhaseParticipant>,
) -> Result<(), TwoPhaseError> {
let mut participants = self.participants.lock().await;
let id = participant.id().to_string();
if participants.iter().any(|p| p.id() == id) {
return Err(TwoPhaseError::DuplicateParticipant(id));
}
participants.push(participant);
Ok(())
}
pub async fn begin(&mut self) -> Result<(), TwoPhaseError> {
{
let mut state = self.state.lock().await;
if *state != TransactionState::NotStarted {
return Err(TwoPhaseError::InvalidState(
"Transaction already started".to_string(),
));
}
*state = TransactionState::Active;
}
self.log_state(TransactionState::Active).await?;
{
let mut participants = self.participants.lock().await;
for participant in participants.iter_mut() {
participant.begin().await.map_err(|e| {
TwoPhaseError::PrepareFailed(participant.id().to_string(), e.to_string())
})?;
}
}
Ok(())
}
pub async fn prepare(&mut self) -> Result<(), TwoPhaseError> {
{
let mut state = self.state.lock().await;
if *state != TransactionState::Active {
return Err(TwoPhaseError::InvalidState(
"Can only prepare active transaction".to_string(),
));
}
*state = TransactionState::Preparing;
}
{
let mut participants = self.participants.lock().await;
if participants.is_empty() {
return Err(TwoPhaseError::NoParticipants);
}
for participant in participants.iter_mut() {
match participant.prepare(self.transaction_id.clone()).await {
Ok(_) => {
participant.set_status(ParticipantStatus::Prepared);
}
Err(e) => {
let participant_id = participant.id().to_string();
let error_msg = e.to_string();
drop(participants);
self.rollback().await?;
return Err(TwoPhaseError::PrepareFailed(participant_id, error_msg));
}
}
}
}
{
let mut state = self.state.lock().await;
*state = TransactionState::Prepared;
}
self.log_state(TransactionState::Prepared).await?;
Ok(())
}
pub async fn commit(&mut self) -> Result<(), TwoPhaseError> {
{
let mut state = self.state.lock().await;
if *state != TransactionState::Prepared {
return Err(TwoPhaseError::InvalidState(
"Can only commit prepared transaction".to_string(),
));
}
*state = TransactionState::Committing;
}
let mut failed_participants = Vec::new();
{
let mut participants = self.participants.lock().await;
for participant in participants.iter_mut() {
match participant.commit(self.transaction_id.clone()).await {
Ok(_) => {
participant.set_status(ParticipantStatus::Committed);
}
Err(e) => {
failed_participants.push((participant.id().to_string(), e.to_string()));
}
}
}
}
if !failed_participants.is_empty() {
let mut state = self.state.lock().await;
*state = TransactionState::Prepared; let error_msg = failed_participants
.iter()
.map(|(id, err)| format!("{}: {}", id, err))
.collect::<Vec<_>>()
.join(", ");
return Err(TwoPhaseError::CommitFailed(
"Multiple participants".to_string(),
error_msg,
));
}
{
let mut state = self.state.lock().await;
*state = TransactionState::Committed;
}
self.log_state(TransactionState::Committed).await?;
if let Some(log) = &self.transaction_log {
let _ = log.delete(&self.transaction_id);
}
Ok(())
}
pub async fn rollback(&mut self) -> Result<(), TwoPhaseError> {
{
let mut state = self.state.lock().await;
if *state != TransactionState::Active
&& *state != TransactionState::Prepared
&& *state != TransactionState::Preparing
{
return Err(TwoPhaseError::InvalidState(
"Can only rollback active, prepared, or preparing transaction".to_string(),
));
}
*state = TransactionState::Aborting;
}
let mut failed_participants = Vec::new();
{
let mut participants = self.participants.lock().await;
for participant in participants.iter_mut() {
match participant.rollback(self.transaction_id.clone()).await {
Ok(_) => {
participant.set_status(ParticipantStatus::Aborted);
}
Err(e) => {
failed_participants.push((participant.id().to_string(), e.to_string()));
}
}
}
}
if !failed_participants.is_empty() {
let error_msg = failed_participants
.iter()
.map(|(id, err)| format!("{}: {}", id, err))
.collect::<Vec<_>>()
.join(", ");
return Err(TwoPhaseError::RollbackFailed(
"Multiple participants".to_string(),
error_msg,
));
}
{
let mut state = self.state.lock().await;
*state = TransactionState::Aborted;
}
self.log_state(TransactionState::Aborted).await?;
if let Some(log) = &self.transaction_log {
let _ = log.delete(&self.transaction_id);
}
Ok(())
}
pub async fn recover_prepared_transactions(&mut self) -> Result<Vec<String>, TwoPhaseError> {
let mut all_xids = Vec::new();
{
let mut participants = self.participants.lock().await;
for participant in participants.iter_mut() {
let xids = participant
.recover()
.await
.map_err(|e| TwoPhaseError::RecoveryFailed(e.to_string()))?;
all_xids.extend(xids);
}
}
Ok(all_xids)
}
pub async fn participant_count(&self) -> usize {
self.participants.lock().await.len()
}
}
/// Adapts `PostgresTwoPhaseParticipant` to the [`TwoPhaseParticipant`] trait.
#[cfg(feature = "postgres")]
pub struct PostgresParticipantAdapter {
    /// Participant id (also used as the key passed to `begin_by_xid`).
    id: String,
    /// Backend that issues the PostgreSQL statements.
    backend: PostgresTwoPhaseParticipant,
    /// Status recorded by the coordinator through `set_status`.
    status: ParticipantStatus,
}
#[cfg(feature = "postgres")]
impl PostgresParticipantAdapter {
    /// Builds an adapter named `id` that owns `pool`.
    pub fn new(id: impl Into<String>, pool: sqlx::PgPool) -> Self {
        Self::from_backend(id.into(), PostgresTwoPhaseParticipant::new(pool))
    }

    /// Builds an adapter named `id` that shares an already reference-counted pool.
    pub fn from_pool_arc(id: impl Into<String>, pool: std::sync::Arc<sqlx::PgPool>) -> Self {
        Self::from_backend(id.into(), PostgresTwoPhaseParticipant::from_pool_arc(pool))
    }

    /// Shared constructor: every adapter starts in [`ParticipantStatus::Active`].
    fn from_backend(id: String, backend: PostgresTwoPhaseParticipant) -> Self {
        Self {
            id,
            backend,
            status: ParticipantStatus::Active,
        }
    }
}
#[cfg(feature = "postgres")]
#[async_trait(?Send)]
impl TwoPhaseParticipant for PostgresParticipantAdapter {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Opens this participant's transaction on the backend.
    ///
    /// NOTE(review): the transaction is keyed by the participant id (`self.id`), while
    /// `prepare`/`commit`/`rollback` receive the coordinator's transaction id as `xid`.
    /// Confirm that `begin_by_xid` and `prepare_by_xid` agree on which key they use.
    async fn begin(&self) -> Result<(), TwoPhaseError> {
        let started = self.backend.begin_by_xid(&self.id).await;
        started.map_err(|err| TwoPhaseError::DatabaseError(err.to_string()))
    }

    async fn prepare(&self, xid: String) -> Result<(), TwoPhaseError> {
        let outcome = self.backend.prepare_by_xid(&xid).await;
        outcome.map_err(|err| TwoPhaseError::PrepareFailed(self.id.clone(), err.to_string()))
    }

    async fn commit(&self, xid: String) -> Result<(), TwoPhaseError> {
        let outcome = self.backend.commit_managed(&xid).await;
        outcome.map_err(|err| TwoPhaseError::CommitFailed(self.id.clone(), err.to_string()))
    }

    async fn rollback(&self, xid: String) -> Result<(), TwoPhaseError> {
        let outcome = self.backend.rollback_managed(&xid).await;
        outcome.map_err(|err| TwoPhaseError::RollbackFailed(self.id.clone(), err.to_string()))
    }

    /// Returns the gids of the prepared transactions the backend lists.
    async fn recover(&self) -> Result<Vec<String>, TwoPhaseError> {
        let prepared = match self.backend.list_prepared_transactions().await {
            Ok(txns) => txns,
            Err(err) => return Err(TwoPhaseError::RecoveryFailed(err.to_string())),
        };
        let gids: Vec<String> = prepared.into_iter().map(|txn| txn.gid).collect();
        Ok(gids)
    }

    fn status(&self) -> ParticipantStatus {
        self.status.clone()
    }

    fn set_status(&mut self, status: ParticipantStatus) {
        self.status = status;
    }
}
/// The XA session held by `MySqlParticipantAdapter`, tagged by how far it has progressed.
#[cfg(feature = "mysql")]
enum XaSessionState {
    /// Session returned by the backend's `begin`; not yet prepared.
    Started(XaSessionStarted),
    /// Session returned by the backend's `prepare`; awaiting commit or rollback.
    Prepared(XaSessionPrepared),
}
/// Adapts `MySqlTwoPhaseParticipant` (XA sessions) to the [`TwoPhaseParticipant`] trait.
#[cfg(feature = "mysql")]
pub struct MySqlParticipantAdapter {
    /// Participant id; also passed to the backend's `begin`.
    id: String,
    /// Backend that runs the XA statements.
    backend: MySqlTwoPhaseParticipant,
    /// Status recorded by the coordinator through `set_status`.
    status: ParticipantStatus,
    /// Current XA session, if any.
    ///
    /// Behind a mutex because the trait methods take `&self` but must move the
    /// type-state session values in and out.
    session: Arc<StdMutex<Option<XaSessionState>>>,
}
#[cfg(feature = "mysql")]
impl MySqlParticipantAdapter {
    /// Builds an adapter named `id` that owns `pool`.
    pub fn new(id: impl Into<String>, pool: sqlx::MySqlPool) -> Self {
        Self::from_backend(id.into(), MySqlTwoPhaseParticipant::new(pool))
    }

    /// Builds an adapter named `id` that shares an already reference-counted pool.
    pub fn from_pool_arc(id: impl Into<String>, pool: std::sync::Arc<sqlx::MySqlPool>) -> Self {
        Self::from_backend(id.into(), MySqlTwoPhaseParticipant::from_pool_arc(pool))
    }

    /// Shared constructor: `Active` status and no XA session yet.
    fn from_backend(id: String, backend: MySqlTwoPhaseParticipant) -> Self {
        let session = Arc::new(StdMutex::new(None));
        Self {
            id,
            backend,
            status: ParticipantStatus::Active,
            session,
        }
    }
}
#[cfg(feature = "mysql")]
#[async_trait(?Send)]
impl TwoPhaseParticipant for MySqlParticipantAdapter {
    fn id(&self) -> &str {
        &self.id
    }

    /// Starts an XA session via the backend and stores it as `Started`.
    ///
    /// NOTE(review): the backend is given the participant id, and the `xid` argument of the
    /// later phases is ignored (`_xid`). Confirm that ids returned by `recover` can be
    /// matched back to the coordinator's transaction.
    ///
    /// # Errors
    ///
    /// - [`TwoPhaseError::InvalidState`] if a session is already held. Overwriting it would
    ///   drop that session, possibly an already-prepared one, without finishing it.
    /// - [`TwoPhaseError::DatabaseError`] if the backend fails or the lock is poisoned.
    async fn begin(&self) -> Result<(), TwoPhaseError> {
        {
            let session_guard = self.session.lock().map_err(|e| {
                TwoPhaseError::DatabaseError(format!("Failed to acquire session lock: {}", e))
            })?;
            if session_guard.is_some() {
                return Err(TwoPhaseError::InvalidState(format!(
                    "Session already active for participant '{}'",
                    self.id
                )));
            }
        }
        let session = self
            .backend
            .begin(self.id.clone())
            .await
            .map_err(|e| TwoPhaseError::DatabaseError(e.to_string()))?;
        let mut session_guard = self.session.lock().map_err(|e| {
            TwoPhaseError::DatabaseError(format!("Failed to acquire session lock: {}", e))
        })?;
        *session_guard = Some(XaSessionState::Started(session));
        Ok(())
    }

    /// Ends and prepares the started session, then stores it as `Prepared`.
    ///
    /// If the session is already prepared, it is left in place and an error is returned,
    /// so that `commit` or `rollback` can still finish it.
    ///
    /// NOTE(review): if the backend's `end`/`prepare` fails, the session value has already
    /// been consumed and a later `rollback` will find no session — confirm whether the
    /// backend cleans up on error.
    async fn prepare(&self, _xid: String) -> Result<(), TwoPhaseError> {
        let session = {
            let mut session_guard = self.session.lock().map_err(|e| {
                TwoPhaseError::PrepareFailed(
                    self.id.clone(),
                    format!("Failed to acquire session lock: {}", e),
                )
            })?;
            match session_guard.take() {
                Some(XaSessionState::Started(s)) => s,
                Some(already_prepared @ XaSessionState::Prepared(_)) => {
                    // Put it back: dropping it would orphan a prepared XA branch.
                    *session_guard = Some(already_prepared);
                    return Err(TwoPhaseError::InvalidState(format!(
                        "Session already prepared for participant '{}'",
                        self.id
                    )));
                }
                None => {
                    return Err(TwoPhaseError::InvalidState(format!(
                        "No active session for participant '{}'",
                        self.id
                    )));
                }
            }
        };
        let ended_session = self
            .backend
            .end(session)
            .await
            .map_err(|e| TwoPhaseError::PrepareFailed(self.id.clone(), e.to_string()))?;
        let prepared_session = self
            .backend
            .prepare(ended_session)
            .await
            .map_err(|e| TwoPhaseError::PrepareFailed(self.id.clone(), e.to_string()))?;
        let mut session_guard = self.session.lock().map_err(|e| {
            TwoPhaseError::PrepareFailed(
                self.id.clone(),
                format!("Failed to acquire session lock: {}", e),
            )
        })?;
        *session_guard = Some(XaSessionState::Prepared(prepared_session));
        Ok(())
    }

    /// Commits the prepared session.
    ///
    /// A session that is only started is left in place, so that `rollback` can still
    /// finish it.
    async fn commit(&self, _xid: String) -> Result<(), TwoPhaseError> {
        let session = {
            let mut session_guard = self.session.lock().map_err(|e| {
                TwoPhaseError::CommitFailed(
                    self.id.clone(),
                    format!("Failed to acquire session lock: {}", e),
                )
            })?;
            match session_guard.take() {
                Some(XaSessionState::Prepared(s)) => s,
                Some(started @ XaSessionState::Started(_)) => {
                    // Put it back so the open XA branch can still be rolled back.
                    *session_guard = Some(started);
                    return Err(TwoPhaseError::InvalidState(format!(
                        "Session not prepared for participant '{}'",
                        self.id
                    )));
                }
                None => {
                    return Err(TwoPhaseError::InvalidState(format!(
                        "No active session for participant '{}'",
                        self.id
                    )));
                }
            }
        };
        self.backend
            .commit(session)
            .await
            .map_err(|e| TwoPhaseError::CommitFailed(self.id.clone(), e.to_string()))
    }

    /// Rolls back whichever session is held, using the backend call that matches its phase.
    async fn rollback(&self, _xid: String) -> Result<(), TwoPhaseError> {
        let session_state = {
            let mut session_guard = self.session.lock().map_err(|e| {
                TwoPhaseError::RollbackFailed(
                    self.id.clone(),
                    format!("Failed to acquire session lock: {}", e),
                )
            })?;
            session_guard.take().ok_or_else(|| {
                TwoPhaseError::InvalidState(format!(
                    "No active session for participant '{}'",
                    self.id
                ))
            })?
        };
        match session_state {
            XaSessionState::Started(s) => self
                .backend
                .rollback_started(s)
                .await
                .map_err(|e| TwoPhaseError::RollbackFailed(self.id.clone(), e.to_string())),
            XaSessionState::Prepared(s) => self
                .backend
                .rollback_prepared(s)
                .await
                .map_err(|e| TwoPhaseError::RollbackFailed(self.id.clone(), e.to_string())),
        }
    }

    /// Returns the xids of the prepared XA transactions the backend lists.
    async fn recover(&self) -> Result<Vec<String>, TwoPhaseError> {
        let txns = self
            .backend
            .list_prepared_transactions()
            .await
            .map_err(|e| TwoPhaseError::RecoveryFailed(e.to_string()))?;
        Ok(txns.into_iter().map(|t| t.xid).collect())
    }

    fn status(&self) -> ParticipantStatus {
        self.status.clone()
    }

    fn set_status(&mut self, status: ParticipantStatus) {
        self.status = status;
    }
}
/// Test double for [`TwoPhaseParticipant`] whose phases can be configured to fail.
#[cfg(test)]
pub struct MockParticipant {
    /// Participant id.
    id: String,
    /// Last recorded status; a mutex because trait phases take `&self`.
    status: std::sync::Mutex<ParticipantStatus>,
    /// When set, `prepare` returns a simulated error.
    should_fail_prepare: bool,
    /// When set, `commit` returns a simulated error.
    should_fail_commit: bool,
    /// When set, `rollback` returns a simulated error.
    should_fail_rollback: bool,
}
#[cfg(test)]
impl MockParticipant {
    /// Builds a mock named `id` whose phases all succeed.
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self {
            id,
            status: std::sync::Mutex::new(ParticipantStatus::Active),
            should_fail_prepare: false,
            should_fail_commit: false,
            should_fail_rollback: false,
        }
    }

    /// Makes `prepare` fail.
    pub fn with_prepare_failure(self) -> Self {
        Self {
            should_fail_prepare: true,
            ..self
        }
    }

    /// Makes `commit` fail.
    pub fn with_commit_failure(self) -> Self {
        Self {
            should_fail_commit: true,
            ..self
        }
    }

    /// Makes `rollback` fail.
    pub fn with_rollback_failure(self) -> Self {
        Self {
            should_fail_rollback: true,
            ..self
        }
    }
}
#[cfg(test)]
#[async_trait(?Send)]
impl TwoPhaseParticipant for MockParticipant {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Always succeeds.
    async fn begin(&self) -> Result<(), TwoPhaseError> {
        Ok(())
    }

    /// Fails if configured to; otherwise records `Prepared`.
    async fn prepare(&self, _xid: String) -> Result<(), TwoPhaseError> {
        if self.should_fail_prepare {
            return Err(TwoPhaseError::PrepareFailed(
                self.id.clone(),
                String::from("Simulated prepare failure"),
            ));
        }
        *self.status.lock().unwrap() = ParticipantStatus::Prepared;
        Ok(())
    }

    /// Fails if configured to; otherwise records `Committed`.
    async fn commit(&self, _xid: String) -> Result<(), TwoPhaseError> {
        if self.should_fail_commit {
            return Err(TwoPhaseError::CommitFailed(
                self.id.clone(),
                String::from("Simulated commit failure"),
            ));
        }
        *self.status.lock().unwrap() = ParticipantStatus::Committed;
        Ok(())
    }

    /// Fails if configured to; otherwise records `Aborted`.
    async fn rollback(&self, _xid: String) -> Result<(), TwoPhaseError> {
        if self.should_fail_rollback {
            return Err(TwoPhaseError::RollbackFailed(
                self.id.clone(),
                String::from("Simulated rollback failure"),
            ));
        }
        *self.status.lock().unwrap() = ParticipantStatus::Aborted;
        Ok(())
    }

    /// Reports no in-doubt transactions.
    async fn recover(&self) -> Result<Vec<String>, TwoPhaseError> {
        Ok(Vec::new())
    }

    fn status(&self) -> ParticipantStatus {
        let guard = self.status.lock().unwrap();
        (*guard).clone()
    }

    fn set_status(&mut self, status: ParticipantStatus) {
        // Exclusive access: no locking needed, only the poison check.
        *self.status.get_mut().unwrap() = status;
    }
}
// Unit tests for the synchronous `TwoPhaseCommit` SQL-plan builder and for the async
// `TwoPhaseCoordinator` driven by `MockParticipant`s.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TwoPhaseCommit ----

    #[test]
    fn test_new_transaction() {
        let tpc = TwoPhaseCommit::new("txn_test_001");
        assert_eq!(tpc.transaction_id(), "txn_test_001");
        assert_eq!(tpc.state().unwrap(), TransactionState::NotStarted);
        assert_eq!(tpc.participant_count(), 0);
    }
    #[test]
    fn test_begin_transaction() {
        let mut tpc = TwoPhaseCommit::new("txn_test_002");
        let result = tpc.begin();
        assert!(result.is_ok());
        assert_eq!(tpc.state().unwrap(), TransactionState::Active);
    }
    #[test]
    fn test_cannot_begin_twice() {
        let mut tpc = TwoPhaseCommit::new("txn_test_003");
        tpc.begin().unwrap();
        let result = tpc.begin();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::InvalidState(_)
        ));
    }
    #[test]
    fn test_add_participant() {
        let mut tpc = TwoPhaseCommit::new("txn_test_004");
        tpc.begin().unwrap();
        let result = tpc.add_participant("db1");
        assert!(result.is_ok());
        assert_eq!(tpc.participant_count(), 1);
    }
    #[test]
    fn test_cannot_add_duplicate_participant() {
        let mut tpc = TwoPhaseCommit::new("txn_test_005");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        let result = tpc.add_participant("db1");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::DuplicateParticipant(_)
        ));
    }
    #[test]
    fn test_prepare_phase() {
        let mut tpc = TwoPhaseCommit::new("txn_test_006");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        tpc.add_participant("db2").unwrap();
        let result = tpc.prepare();
        assert!(result.is_ok());
        assert_eq!(tpc.state().unwrap(), TransactionState::Prepared);
        assert!(tpc.all_prepared());
        // One generated statement per participant.
        let sqls = result.unwrap();
        assert_eq!(sqls.len(), 2);
    }
    #[test]
    fn test_prepare_without_participants() {
        let mut tpc = TwoPhaseCommit::new("txn_test_007");
        tpc.begin().unwrap();
        let result = tpc.prepare();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TwoPhaseError::NoParticipants));
    }
    #[test]
    fn test_commit_phase() {
        let mut tpc = TwoPhaseCommit::new("txn_test_008");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        tpc.prepare().unwrap();
        let result = tpc.commit();
        assert!(result.is_ok());
        assert_eq!(tpc.state().unwrap(), TransactionState::Committed);
        let sqls = result.unwrap();
        assert_eq!(sqls.len(), 1);
        assert!(sqls[0].contains("COMMIT PREPARED"));
    }
    #[test]
    fn test_cannot_commit_without_prepare() {
        let mut tpc = TwoPhaseCommit::new("txn_test_009");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        let result = tpc.commit();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::InvalidState(_)
        ));
    }
    #[test]
    fn test_rollback_active_transaction() {
        let mut tpc = TwoPhaseCommit::new("txn_test_010");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        let result = tpc.rollback();
        assert!(result.is_ok());
        assert_eq!(tpc.state().unwrap(), TransactionState::Aborted);
        let sqls = result.unwrap();
        assert_eq!(sqls.len(), 1);
        assert!(sqls[0].contains("ROLLBACK"));
    }
    #[test]
    fn test_rollback_prepared_transaction() {
        let mut tpc = TwoPhaseCommit::new("txn_test_011");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        tpc.prepare().unwrap();
        let result = tpc.rollback();
        assert!(result.is_ok());
        assert_eq!(tpc.state().unwrap(), TransactionState::Aborted);
        // A prepared participant is rolled back by transaction id.
        let sqls = result.unwrap();
        assert_eq!(sqls.len(), 1);
        assert!(sqls[0].contains("ROLLBACK PREPARED"));
    }
    #[test]
    fn test_multiple_participants_prepare_commit() {
        let mut tpc = TwoPhaseCommit::new("txn_test_012");
        tpc.begin().unwrap();
        tpc.add_participant("primary_db").unwrap();
        tpc.add_participant("secondary_db").unwrap();
        tpc.add_participant("cache_db").unwrap();
        assert_eq!(tpc.participant_count(), 3);
        tpc.prepare().unwrap();
        assert!(tpc.all_prepared());
        tpc.commit().unwrap();
        assert_eq!(tpc.state().unwrap(), TransactionState::Committed);
    }
    #[test]
    fn test_participant_new() {
        let participant = Participant::new("test_db");
        assert_eq!(participant.db_alias, "test_db");
        assert_eq!(participant.status, ParticipantStatus::Active);
        assert!(!participant.is_prepared());
    }
    #[test]
    fn test_participant_is_prepared() {
        let mut participant = Participant::new("test_db");
        assert!(!participant.is_prepared());
        participant.status = ParticipantStatus::Prepared;
        assert!(participant.is_prepared());
    }
    #[test]
    fn test_get_participants() {
        let mut tpc = TwoPhaseCommit::new("txn_test_013");
        tpc.begin().unwrap();
        tpc.add_participant("db1").unwrap();
        tpc.add_participant("db2").unwrap();
        let participants = tpc.participants();
        assert_eq!(participants.len(), 2);
        assert!(participants.contains_key("db1"));
        assert!(participants.contains_key("db2"));
    }
    #[test]
    fn test_default_transaction_id() {
        let tpc = TwoPhaseCommit::default();
        assert!(!tpc.transaction_id().is_empty());
        assert_eq!(tpc.state().unwrap(), TransactionState::NotStarted);
    }
    #[test]
    fn test_transaction_state_transitions() {
        let mut tpc = TwoPhaseCommit::new("txn_test_014");
        assert_eq!(tpc.state().unwrap(), TransactionState::NotStarted);
        tpc.begin().unwrap();
        assert_eq!(tpc.state().unwrap(), TransactionState::Active);
        tpc.add_participant("db1").unwrap();
        tpc.prepare().unwrap();
        assert_eq!(tpc.state().unwrap(), TransactionState::Prepared);
        tpc.commit().unwrap();
        assert_eq!(tpc.state().unwrap(), TransactionState::Committed);
    }
    #[test]
    fn test_error_display() {
        let err = TwoPhaseError::PrepareFailed("db1".to_string(), "Connection lost".to_string());
        assert_eq!(err.to_string(), "Prepare failed for 'db1': Connection lost");
        let err = TwoPhaseError::NoParticipants;
        assert_eq!(err.to_string(), "No participants registered");
        let err = TwoPhaseError::DuplicateParticipant("db1".to_string());
        assert_eq!(err.to_string(), "Participant 'db1' already registered");
    }
    #[test]
    fn test_cannot_add_participant_before_begin() {
        let mut tpc = TwoPhaseCommit::new("txn_test_015");
        let result = tpc.add_participant("db1");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::InvalidState(_)
        ));
    }

    // ---- TwoPhaseCoordinator ----

    #[tokio::test]
    async fn test_coordinator_basic_flow() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_001");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        coordinator
            .add_participant(Box::new(MockParticipant::new("db2")))
            .await
            .unwrap();
        assert_eq!(coordinator.participant_count().await, 2);
        coordinator.begin().await.unwrap();
        assert_eq!(coordinator.state().await.unwrap(), TransactionState::Active);
        coordinator.prepare().await.unwrap();
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Prepared
        );
        coordinator.commit().await.unwrap();
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Committed
        );
    }
    #[tokio::test]
    async fn test_coordinator_prepare_failure_triggers_rollback() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_002");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        coordinator
            .add_participant(Box::new(MockParticipant::new("db2").with_prepare_failure()))
            .await
            .unwrap();
        coordinator.begin().await.unwrap();
        let result = coordinator.prepare().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::PrepareFailed(_, _)
        ));
        // A failed prepare rolls everything back, ending in Aborted.
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Aborted
        );
    }
    #[tokio::test]
    async fn test_coordinator_rollback() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_003");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        coordinator
            .add_participant(Box::new(MockParticipant::new("db2")))
            .await
            .unwrap();
        coordinator.begin().await.unwrap();
        coordinator.prepare().await.unwrap();
        coordinator.rollback().await.unwrap();
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Aborted
        );
    }
    #[tokio::test]
    async fn test_coordinator_commit_failure() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_004");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        coordinator
            .add_participant(Box::new(MockParticipant::new("db2").with_commit_failure()))
            .await
            .unwrap();
        coordinator.begin().await.unwrap();
        coordinator.prepare().await.unwrap();
        let result = coordinator.commit().await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::CommitFailed(_, _)
        ));
        // A failed commit returns the coordinator to Prepared.
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Prepared
        );
    }
    #[tokio::test]
    async fn test_coordinator_no_participants() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_005");
        coordinator.begin().await.unwrap();
        let result = coordinator.prepare().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), TwoPhaseError::NoParticipants));
    }
    #[tokio::test]
    async fn test_coordinator_duplicate_participant() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_006");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        let result = coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TwoPhaseError::DuplicateParticipant(_)
        ));
    }
    #[tokio::test]
    async fn test_coordinator_recover() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_007");
        coordinator
            .add_participant(Box::new(MockParticipant::new("db1")))
            .await
            .unwrap();
        // MockParticipant::recover always reports no in-doubt transactions.
        let xids = coordinator.recover_prepared_transactions().await.unwrap();
        assert_eq!(xids.len(), 0);
    }
    #[tokio::test]
    async fn test_coordinator_multiple_participants() {
        let mut coordinator = TwoPhaseCoordinator::new("coord_txn_008");
        for i in 1..=5 {
            coordinator
                .add_participant(Box::new(MockParticipant::new(format!("db{}", i))))
                .await
                .unwrap();
        }
        assert_eq!(coordinator.participant_count().await, 5);
        coordinator.begin().await.unwrap();
        coordinator.prepare().await.unwrap();
        coordinator.commit().await.unwrap();
        assert_eq!(
            coordinator.state().await.unwrap(),
            TransactionState::Committed
        );
    }
}