#![allow(deprecated)]
use super::types::ShardId;
use crate::core::hlc::HybridTimestamp;
use crate::core::id::TxId;
use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Retry budget given to every transaction created by `DistributedTransaction::new`
/// (stored in its `retries_remaining` field).
const DEFAULT_RETRIES: u32 = 3;
/// Lifecycle phase of a distributed transaction driven through two-phase commit.
///
/// The `Display` form is the bare variant name (e.g. `"Committing"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionPhase {
    /// Freshly created; no prepare round has started.
    Pending,
    /// Prepare requests are outstanding.
    Preparing,
    /// Every participant has reported `Prepared`.
    Prepared,
    /// Commit requests are outstanding.
    Committing,
    /// Every participant has reported `Committed`.
    Committed,
    /// The transaction was rolled back.
    Aborted,
    /// The transaction was marked as failed.
    Failed,
}

impl fmt::Display for TransactionPhase {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Pending => "Pending",
            Self::Preparing => "Preparing",
            Self::Prepared => "Prepared",
            Self::Committing => "Committing",
            Self::Committed => "Committed",
            Self::Aborted => "Aborted",
            Self::Failed => "Failed",
        };
        f.write_str(label)
    }
}
/// Last known state of a single participant shard in a distributed transaction.
///
/// The `Display` form is the bare variant name (e.g. `"Unreachable"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParticipantState {
    /// No response recorded yet (initial state).
    Unknown,
    /// The shard reported a successful prepare.
    Prepared,
    /// The shard reported a successful commit.
    Committed,
    /// The shard failed to prepare or the transaction was aborted.
    Aborted,
    /// The shard could not be reached.
    Unreachable,
}

impl fmt::Display for ParticipantState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Unknown => "Unknown",
            Self::Prepared => "Prepared",
            Self::Committed => "Committed",
            Self::Aborted => "Aborted",
            Self::Unreachable => "Unreachable",
        };
        f.write_str(label)
    }
}
/// Errors produced while driving a distributed transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedTxError {
    /// The prepare round finished with some shards not in `Prepared`.
    PrepareFailed {
        failed_participants: Vec<ShardId>,
    },
    /// The commit round for `tx_id` finished with some shards not in `Committed`.
    CommitFailed {
        tx_id: TxId,
        failed_participants: Vec<ShardId>,
    },
    /// The transaction exceeded its time budget while in `phase`.
    Timeout {
        phase: TransactionPhase,
        duration: Duration,
    },
    /// The shard `shard_id` could not be contacted.
    ParticipantUnavailable {
        shard_id: ShardId,
    },
    /// The transaction was aborted for the given `reason`.
    Aborted {
        reason: String,
    },
    /// A deadlock was detected among the listed transactions.
    Deadlock {
        involved_transactions: Vec<TxId>,
    },
    /// A phase change from `from` to `to` was requested but is not allowed.
    InvalidStateTransition {
        from: TransactionPhase,
        to: TransactionPhase,
    },
}

impl fmt::Display for DistributedTxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::PrepareFailed { failed_participants } => write!(
                f,
                "Prepare failed for participants: {:?}",
                failed_participants
            ),
            Self::CommitFailed {
                tx_id,
                failed_participants,
            } => write!(
                f,
                "Commit failed for transaction {} on participants: {:?}",
                tx_id, failed_participants
            ),
            Self::Timeout { phase, duration } => write!(
                f,
                "Transaction timeout in {} phase after {:?}",
                phase, duration
            ),
            Self::ParticipantUnavailable { shard_id } => {
                write!(f, "Participant {} is unavailable", shard_id)
            }
            Self::Aborted { reason } => write!(f, "Transaction aborted: {}", reason),
            Self::Deadlock {
                involved_transactions,
            } => write!(
                f,
                "Deadlock detected involving transactions: {:?}",
                involved_transactions
            ),
            Self::InvalidStateTransition { from, to } => {
                write!(f, "Invalid state transition from {} to {}", from, to)
            }
        }
    }
}

impl std::error::Error for DistributedTxError {}
/// State of one distributed transaction as it moves through two-phase commit:
/// its current phase, the last recorded state of each participating shard,
/// and timeout/retry bookkeeping.
#[derive(Debug)]
pub struct DistributedTransaction {
/// Identifier of this transaction.
pub tx_id: TxId,
/// Current lifecycle phase; `new` starts it at `Pending`.
pub phase: TransactionPhase,
/// Last recorded state per participating shard; `new` starts every entry at `Unknown`.
pub participants: HashMap<ShardId, ParticipantState>,
/// Creation instant; `elapsed` and `is_timed_out` measure from here.
pub start_time: Instant,
/// Time budget compared against `elapsed()` by `is_timed_out`.
pub timeout: Duration,
/// Retries left; starts at `DEFAULT_RETRIES`, lowered by `decrement_retries`.
pub retries_remaining: u32,
/// Starts `false` and is not changed by this type's own methods.
/// NOTE(review): presumably set by the coordinator once the commit decision is persisted — confirm.
pub commit_decision_logged: bool,
/// Starts `None` and is not changed by this type's own methods.
/// NOTE(review): presumably the HLC timestamp chosen for the commit — confirm where it is set.
pub commit_timestamp: Option<HybridTimestamp>,
}
impl DistributedTransaction {
pub fn new(tx_id: TxId, participants: Vec<ShardId>, timeout: Duration) -> Self {
let mut participant_map = HashMap::new();
for shard in participants {
participant_map.insert(shard, ParticipantState::Unknown);
}
Self {
tx_id,
phase: TransactionPhase::Pending,
participants: participant_map,
start_time: Instant::now(),
timeout,
retries_remaining: DEFAULT_RETRIES,
commit_decision_logged: false,
commit_timestamp: None,
}
}
pub fn is_timed_out(&self) -> bool {
self.start_time.elapsed() > self.timeout
}
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
pub fn begin_prepare(&mut self) -> Result<(), DistributedTxError> {
if self.phase != TransactionPhase::Pending {
return Err(DistributedTxError::InvalidStateTransition {
from: self.phase,
to: TransactionPhase::Preparing,
});
}
self.phase = TransactionPhase::Preparing;
Ok(())
}
pub fn record_prepare_success(&mut self, shard_id: ShardId) {
if let Some(state) = self.participants.get_mut(&shard_id) {
*state = ParticipantState::Prepared;
}
}
pub fn record_prepare_failure(&mut self, shard_id: ShardId) {
if let Some(state) = self.participants.get_mut(&shard_id) {
*state = ParticipantState::Aborted;
}
}
pub fn record_unreachable(&mut self, shard_id: ShardId) {
if let Some(state) = self.participants.get_mut(&shard_id) {
*state = ParticipantState::Unreachable;
}
}
pub fn all_prepared(&self) -> bool {
self.participants
.values()
.all(|s| *s == ParticipantState::Prepared)
}
pub fn any_aborted(&self) -> bool {
self.participants
.values()
.any(|s| *s == ParticipantState::Aborted)
}
pub fn any_unreachable(&self) -> bool {
self.participants
.values()
.any(|s| *s == ParticipantState::Unreachable)
}
pub fn mark_prepared(&mut self) -> Result<(), DistributedTxError> {
if self.phase != TransactionPhase::Preparing {
return Err(DistributedTxError::InvalidStateTransition {
from: self.phase,
to: TransactionPhase::Prepared,
});
}
if !self.all_prepared() {
let failed: Vec<ShardId> = self
.participants
.iter()
.filter(|(_, s)| **s != ParticipantState::Prepared)
.map(|(id, _)| *id)
.collect();
return Err(DistributedTxError::PrepareFailed {
failed_participants: failed,
});
}
self.phase = TransactionPhase::Prepared;
Ok(())
}
pub fn begin_commit(&mut self) -> Result<(), DistributedTxError> {
if self.phase != TransactionPhase::Prepared {
return Err(DistributedTxError::InvalidStateTransition {
from: self.phase,
to: TransactionPhase::Committing,
});
}
self.phase = TransactionPhase::Committing;
Ok(())
}
pub fn record_commit_success(&mut self, shard_id: ShardId) {
if let Some(state) = self.participants.get_mut(&shard_id) {
*state = ParticipantState::Committed;
}
}
pub fn all_committed(&self) -> bool {
self.participants
.values()
.all(|s| *s == ParticipantState::Committed)
}
pub fn uncommitted_participants(&self) -> Vec<ShardId> {
self.participants
.iter()
.filter(|(_, s)| **s != ParticipantState::Committed)
.map(|(id, _)| *id)
.collect()
}
pub fn mark_committed(&mut self) -> Result<(), DistributedTxError> {
if self.phase != TransactionPhase::Committing {
return Err(DistributedTxError::InvalidStateTransition {
from: self.phase,
to: TransactionPhase::Committed,
});
}
if !self.all_committed() {
let failed = self.uncommitted_participants();
return Err(DistributedTxError::CommitFailed {
tx_id: self.tx_id,
failed_participants: failed,
});
}
self.phase = TransactionPhase::Committed;
Ok(())
}
pub fn abort(&mut self, reason: &str) {
self.phase = TransactionPhase::Aborted;
for state in self.participants.values_mut() {
if *state != ParticipantState::Committed {
*state = ParticipantState::Aborted;
}
}
let _ = reason; }
pub fn mark_failed(&mut self) {
self.phase = TransactionPhase::Failed;
}
pub fn can_retry(&self) -> bool {
self.retries_remaining > 0
&& !self.is_timed_out()
&& self.phase != TransactionPhase::Committed
&& self.phase != TransactionPhase::Aborted
}
pub fn decrement_retries(&mut self) {
if self.retries_remaining > 0 {
self.retries_remaining -= 1;
}
}
pub fn participant_shards(&self) -> Vec<ShardId> {
self.participants.keys().copied().collect()
}
}
/// In-memory record of commit/abort decisions, keyed by transaction id, with a
/// log sequence number (LSN) assigned to every recorded decision.
/// The methods on this type do no I/O, so nothing here survives a restart;
/// the deprecation note names `PersistentCommitLog` as the replacement.
#[derive(Debug)]
#[deprecated(note = "Use PersistentCommitLog instead")]
pub struct TwoPhaseCommitLog {
/// Latest decision per transaction; logging again for the same id replaces the entry.
pending_decisions: HashMap<TxId, CommitDecision>,
/// Next LSN to hand out; `fetch_add` yields 0, 1, 2, ...
lsn_generator: AtomicU64,
}
/// One decision held by `TwoPhaseCommitLog`.
#[derive(Debug, Clone)]
pub struct CommitDecision {
/// Transaction the decision applies to.
pub tx_id: TxId,
/// Log sequence number assigned when the decision was recorded.
pub lsn: u64,
/// Shards passed in when the decision was logged (presumably the transaction's participants).
pub participants: Vec<ShardId>,
/// `true` when recorded by `log_commit`, `false` when recorded by `log_abort`.
pub decision: bool,
/// `Instant::now()` at the moment the decision was recorded.
pub timestamp: Instant,
/// Commit timestamp supplied to `log_commit`; always `None` for aborts.
pub commit_timestamp: Option<HybridTimestamp>,
}
impl TwoPhaseCommitLog {
    /// Creates an empty log; the first recorded decision receives LSN `0`.
    pub fn new() -> Self {
        Self {
            pending_decisions: HashMap::new(),
            lsn_generator: AtomicU64::new(0),
        }
    }

    /// Records a commit decision for `tx_id` and returns the LSN assigned to it.
    ///
    /// A decision already held for the same `tx_id` is replaced; its old LSN
    /// is not reused.
    pub fn log_commit(
        &mut self,
        tx_id: TxId,
        participants: Vec<ShardId>,
        commit_timestamp: Option<HybridTimestamp>,
    ) -> u64 {
        self.record(tx_id, participants, true, commit_timestamp)
    }

    /// Records an abort decision for `tx_id` and returns the LSN assigned to it.
    ///
    /// A decision already held for the same `tx_id` is replaced.
    pub fn log_abort(&mut self, tx_id: TxId, participants: Vec<ShardId>) -> u64 {
        self.record(tx_id, participants, false, None)
    }

    /// Removes and returns the decision for `tx_id`, if any.
    pub fn clear_decision(&mut self, tx_id: TxId) -> Option<CommitDecision> {
        self.pending_decisions.remove(&tx_id)
    }

    /// Looks up the decision for `tx_id` without removing it.
    pub fn get_decision(&self, tx_id: TxId) -> Option<&CommitDecision> {
        self.pending_decisions.get(&tx_id)
    }

    /// All held decisions, ordered by ascending LSN (i.e. the order they were logged).
    pub fn pending_decisions(&self) -> Vec<&CommitDecision> {
        Self::sorted_by_lsn(self.pending_decisions.values())
    }

    /// Held commit decisions, ordered by ascending LSN so replay follows log order.
    pub fn decisions_to_replay(&self) -> Vec<&CommitDecision> {
        Self::sorted_by_lsn(self.pending_decisions.values().filter(|d| d.decision))
    }

    /// Held abort decisions, ordered by ascending LSN.
    pub fn aborts_to_process(&self) -> Vec<&CommitDecision> {
        Self::sorted_by_lsn(self.pending_decisions.values().filter(|d| !d.decision))
    }

    /// `true` if a decision for `tx_id` is currently held.
    pub fn has_pending_decision(&self, tx_id: TxId) -> bool {
        self.pending_decisions.contains_key(&tx_id)
    }

    /// The LSN the next decision will receive. This is also the number of
    /// decisions ever logged, including cleared ones.
    pub fn current_lsn(&self) -> u64 {
        self.lsn_generator.load(Ordering::SeqCst)
    }

    /// Assigns the next LSN and stores the decision, replacing any existing
    /// entry for `tx_id`.
    fn record(
        &mut self,
        tx_id: TxId,
        participants: Vec<ShardId>,
        decision: bool,
        commit_timestamp: Option<HybridTimestamp>,
    ) -> u64 {
        let lsn = self.lsn_generator.fetch_add(1, Ordering::SeqCst);
        self.pending_decisions.insert(
            tx_id,
            CommitDecision {
                tx_id,
                lsn,
                participants,
                decision,
                timestamp: Instant::now(),
                commit_timestamp,
            },
        );
        lsn
    }

    /// Collects decisions and orders them by LSN. LSNs are unique per log,
    /// so an unstable sort is deterministic.
    fn sorted_by_lsn<'a, I>(decisions: I) -> Vec<&'a CommitDecision>
    where
        I: Iterator<Item = &'a CommitDecision>,
    {
        let mut ordered: Vec<&'a CommitDecision> = decisions.collect();
        ordered.sort_unstable_by_key(|d| d.lsn);
        ordered
    }
}
impl Default for TwoPhaseCommitLog {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the phase/state Display impls, the DistributedTransaction
// state machine, the error messages, and the in-memory TwoPhaseCommitLog.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for building transaction ids in tests.
fn make_tx_id(id: u64) -> TxId {
TxId::new(id)
}
#[test]
fn test_transaction_phase_display() {
assert_eq!(format!("{}", TransactionPhase::Pending), "Pending");
assert_eq!(format!("{}", TransactionPhase::Preparing), "Preparing");
assert_eq!(format!("{}", TransactionPhase::Prepared), "Prepared");
assert_eq!(format!("{}", TransactionPhase::Committing), "Committing");
assert_eq!(format!("{}", TransactionPhase::Committed), "Committed");
assert_eq!(format!("{}", TransactionPhase::Aborted), "Aborted");
assert_eq!(format!("{}", TransactionPhase::Failed), "Failed");
}
#[test]
fn test_participant_state_display() {
assert_eq!(format!("{}", ParticipantState::Unknown), "Unknown");
assert_eq!(format!("{}", ParticipantState::Prepared), "Prepared");
assert_eq!(format!("{}", ParticipantState::Committed), "Committed");
assert_eq!(format!("{}", ParticipantState::Aborted), "Aborted");
assert_eq!(format!("{}", ParticipantState::Unreachable), "Unreachable");
}
#[test]
fn test_distributed_tx_creation() {
let tx_id = make_tx_id(1);
let shards = vec![ShardId::new(0).unwrap(), ShardId::new(1).unwrap()];
let tx = DistributedTransaction::new(tx_id, shards, Duration::from_secs(30));
assert_eq!(tx.tx_id, tx_id);
assert_eq!(tx.phase, TransactionPhase::Pending);
assert_eq!(tx.participants.len(), 2);
assert!(!tx.all_prepared());
}
// Happy path: Pending -> Preparing -> Prepared once both shards report prepared.
#[test]
fn test_distributed_tx_prepare_flow() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let mut tx =
DistributedTransaction::new(tx_id, vec![shard0, shard1], Duration::from_secs(30));
assert!(tx.begin_prepare().is_ok());
assert_eq!(tx.phase, TransactionPhase::Preparing);
assert!(!tx.all_prepared());
tx.record_prepare_success(shard0);
assert!(!tx.all_prepared());
tx.record_prepare_success(shard1);
assert!(tx.all_prepared());
assert!(tx.mark_prepared().is_ok());
assert_eq!(tx.phase, TransactionPhase::Prepared);
}
// Happy path continued: Prepared -> Committing -> Committed once both shards commit.
#[test]
fn test_distributed_tx_commit_flow() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let mut tx =
DistributedTransaction::new(tx_id, vec![shard0, shard1], Duration::from_secs(30));
tx.begin_prepare().unwrap();
tx.record_prepare_success(shard0);
tx.record_prepare_success(shard1);
tx.mark_prepared().unwrap();
assert!(tx.begin_commit().is_ok());
assert_eq!(tx.phase, TransactionPhase::Committing);
assert!(!tx.all_committed());
tx.record_commit_success(shard0);
assert!(!tx.all_committed());
tx.record_commit_success(shard1);
assert!(tx.all_committed());
assert!(tx.mark_committed().is_ok());
assert_eq!(tx.phase, TransactionPhase::Committed);
}
// A single failed prepare must surface as PrepareFailed naming that shard.
#[test]
fn test_distributed_tx_prepare_failure() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let mut tx =
DistributedTransaction::new(tx_id, vec![shard0, shard1], Duration::from_secs(30));
tx.begin_prepare().unwrap();
tx.record_prepare_success(shard0);
tx.record_prepare_failure(shard1);
assert!(tx.any_aborted());
assert!(!tx.all_prepared());
let result = tx.mark_prepared();
assert!(result.is_err());
if let Err(DistributedTxError::PrepareFailed {
failed_participants,
}) = result
{
assert!(failed_participants.contains(&shard1));
} else {
panic!("Expected PrepareFailed error");
}
}
#[test]
fn test_distributed_tx_unreachable() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let mut tx =
DistributedTransaction::new(tx_id, vec![shard0, shard1], Duration::from_secs(30));
tx.begin_prepare().unwrap();
tx.record_prepare_success(shard0);
tx.record_unreachable(shard1);
assert!(tx.any_unreachable());
assert!(!tx.all_prepared());
}
#[test]
fn test_distributed_tx_abort() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let mut tx =
DistributedTransaction::new(tx_id, vec![shard0, shard1], Duration::from_secs(30));
tx.begin_prepare().unwrap();
tx.record_prepare_success(shard0);
tx.abort("test abort");
assert_eq!(tx.phase, TransactionPhase::Aborted);
}
// Out-of-order phase calls must be rejected, including a repeated begin_prepare.
#[test]
fn test_distributed_tx_invalid_transitions() {
let tx_id = make_tx_id(1);
let mut tx = DistributedTransaction::new(
tx_id,
vec![ShardId::new(0).unwrap()],
Duration::from_secs(30),
);
assert!(tx.begin_commit().is_err());
assert!(tx.mark_prepared().is_err());
assert!(tx.mark_committed().is_err());
tx.begin_prepare().unwrap();
assert!(tx.begin_prepare().is_err());
}
// The retry budget starts at DEFAULT_RETRIES (3) and retries stop at zero.
#[test]
fn test_distributed_tx_retry() {
let tx_id = make_tx_id(1);
let mut tx = DistributedTransaction::new(
tx_id,
vec![ShardId::new(0).unwrap()],
Duration::from_secs(30),
);
assert_eq!(tx.retries_remaining, 3);
assert!(tx.can_retry());
tx.decrement_retries();
assert_eq!(tx.retries_remaining, 2);
assert!(tx.can_retry());
tx.decrement_retries();
tx.decrement_retries();
assert_eq!(tx.retries_remaining, 0);
assert!(!tx.can_retry());
}
// A 1 ms timeout plus a 10 ms sleep makes expiry observable without a long wait.
#[test]
fn test_distributed_tx_timeout() {
let tx_id = make_tx_id(1);
let tx = DistributedTransaction::new(
tx_id,
vec![ShardId::new(0).unwrap()],
Duration::from_millis(1),
);
std::thread::sleep(Duration::from_millis(10));
assert!(tx.is_timed_out());
assert!(!tx.can_retry());
}
#[test]
fn test_two_phase_commit_log() {
let mut log = TwoPhaseCommitLog::new();
let tx_id = make_tx_id(1);
let shards = vec![ShardId::new(0).unwrap(), ShardId::new(1).unwrap()];
let lsn = log.log_commit(tx_id, shards.clone(), None);
assert_eq!(lsn, 0);
assert!(log.has_pending_decision(tx_id));
let decision = log.get_decision(tx_id).unwrap();
assert_eq!(decision.tx_id, tx_id);
assert!(decision.decision);
assert_eq!(decision.participants.len(), 2);
let cleared = log.clear_decision(tx_id);
assert!(cleared.is_some());
assert!(!log.has_pending_decision(tx_id));
}
#[test]
fn test_two_phase_commit_log_abort() {
let mut log = TwoPhaseCommitLog::new();
let tx_id = make_tx_id(1);
let shards = vec![ShardId::new(0).unwrap()];
let lsn = log.log_abort(tx_id, shards);
assert_eq!(lsn, 0);
let decision = log.get_decision(tx_id).unwrap();
assert!(!decision.decision);
let aborts = log.aborts_to_process();
assert_eq!(aborts.len(), 1);
let commits = log.decisions_to_replay();
assert!(commits.is_empty());
}
// Mixed commits and aborts are split correctly between replay and abort lists.
#[test]
fn test_two_phase_commit_log_recovery() {
let mut log = TwoPhaseCommitLog::new();
let tx1 = make_tx_id(1);
let tx2 = make_tx_id(2);
let tx3 = make_tx_id(3);
let shards = vec![ShardId::new(0).unwrap()];
log.log_commit(tx1, shards.clone(), None);
log.log_abort(tx2, shards.clone());
log.log_commit(tx3, shards.clone(), None);
let pending = log.pending_decisions();
assert_eq!(pending.len(), 3);
let commits = log.decisions_to_replay();
assert_eq!(commits.len(), 2);
let aborts = log.aborts_to_process();
assert_eq!(aborts.len(), 1);
}
// LSNs are handed out 0, 1, 2, ...; current_lsn() reports the next one.
#[test]
fn test_two_phase_commit_log_lsn_ordering() {
let mut log = TwoPhaseCommitLog::new();
let shards = vec![ShardId::new(0).unwrap()];
let lsn1 = log.log_commit(make_tx_id(1), shards.clone(), None);
let lsn2 = log.log_commit(make_tx_id(2), shards.clone(), None);
let lsn3 = log.log_commit(make_tx_id(3), shards.clone(), None);
assert!(lsn1 < lsn2);
assert!(lsn2 < lsn3);
assert_eq!(log.current_lsn(), 3);
}
#[test]
fn test_distributed_tx_uncommitted_participants() {
let tx_id = make_tx_id(1);
let shard0 = ShardId::new(0).unwrap();
let shard1 = ShardId::new(1).unwrap();
let shard2 = ShardId::new(2).unwrap();
let mut tx = DistributedTransaction::new(
tx_id,
vec![shard0, shard1, shard2],
Duration::from_secs(30),
);
tx.begin_prepare().unwrap();
tx.record_prepare_success(shard0);
tx.record_prepare_success(shard1);
tx.record_prepare_success(shard2);
tx.mark_prepared().unwrap();
tx.begin_commit().unwrap();
tx.record_commit_success(shard0);
let uncommitted = tx.uncommitted_participants();
assert_eq!(uncommitted.len(), 2);
assert!(uncommitted.contains(&shard1));
assert!(uncommitted.contains(&shard2));
}
#[test]
fn test_distributed_tx_error_display() {
let err = DistributedTxError::PrepareFailed {
failed_participants: vec![ShardId::new(0).unwrap()],
};
assert!(format!("{}", err).contains("Prepare failed"));
let err = DistributedTxError::Timeout {
phase: TransactionPhase::Preparing,
duration: Duration::from_secs(30),
};
assert!(format!("{}", err).contains("timeout"));
assert!(format!("{}", err).contains("Preparing"));
let err = DistributedTxError::Deadlock {
involved_transactions: vec![make_tx_id(1), make_tx_id(2)],
};
assert!(format!("{}", err).contains("Deadlock"));
}
}