use crate::network::{NetworkService, RpcMessage};
use crate::raft::OxirsNodeId;
use crate::shard::{ShardId, ShardRouter};
use crate::shard_manager::ShardManager;
use crate::storage::StorageBackend;
use anyhow::Result;
use chrono::{DateTime, Utc};
use oxirs_core::model::Triple;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use uuid::Uuid;
/// Identifier of a distributed transaction. `TransactionCoordinator::begin_transaction`
/// generates these as UUID v4 strings.
pub type TransactionId = String;
/// Lifecycle state of a transaction as tracked by the coordinator.
///
/// The two-phase path moves `Active -> Preparing -> Prepared -> Committing -> Committed`,
/// or into `Aborting -> Aborted`; the optimizer-selected shortcut moves straight from
/// `Active` to `Committed`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransactionState {
/// Open: `add_operation` only accepts operations in this state.
Active,
/// Prepare phase running: resource locks are taken and prepare requests sent.
Preparing,
/// Prepare requests sent; participant votes are being collected/evaluated.
Prepared,
/// Commit decision logged; commit requests are being sent to participants.
Committing,
/// Terminal: the transaction committed.
Committed,
/// Abort decision logged; abort requests are being sent to participants.
Aborting,
/// Terminal: the transaction aborted.
Aborted,
}
/// A single operation buffered inside a transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionOp {
/// Add `triple`; the coordinator routes it with `ShardRouter::route_triple` and
/// write-locks its subject during prepare.
Insert { triple: Triple },
/// Remove `triple`; routed and locked exactly like `Insert`.
Delete { triple: Triple },
/// Triple-pattern query. NOTE(review): `None` fields presumably act as wildcards — confirm.
/// The coordinator assigns every query to shard 0 and, during prepare, read-locks
/// only `subject` (and only when it is `Some`).
Query {
subject: Option<String>,
predicate: Option<String>,
object: Option<String>,
},
}
/// One shard taking part in a transaction, together with the node the coordinator
/// talks to for that shard.
#[derive(Debug, Clone)]
pub struct TransactionParticipant {
/// Node returned by `ShardManager::get_primary_node` for `shard_id`.
pub node_id: OxirsNodeId,
/// Shard this participant represents.
pub shard_id: ShardId,
/// Prepare vote: `None` until `handle_prepare_vote` records `Some(yes/no)`.
pub vote: Option<bool>,
/// Set when the participant is added and refreshed whenever a vote arrives; a missing
/// vote older than the transaction timeout is treated as a "no".
pub last_contact: Instant,
}
/// Coordinator-side record of one distributed transaction.
#[derive(Debug, Clone)]
pub struct Transaction {
/// Unique id (a UUID v4 string).
pub id: TransactionId,
/// Current lifecycle state.
pub state: TransactionState,
/// Buffered operations in insertion order, each tagged with the shard it targets.
pub operations: Vec<(ShardId, TransactionOp)>,
/// Participants keyed by shard; at most one entry per shard.
pub participants: HashMap<ShardId, TransactionParticipant>,
/// Creation time; the transaction is considered timed out once
/// `created_at.elapsed() > timeout`.
pub created_at: Instant,
/// Maximum lifetime, copied from `TransactionConfig::default_timeout` at begin.
pub timeout: Duration,
/// Isolation level requested by the caller of `begin_transaction`.
pub isolation_level: IsolationLevel,
}
/// Isolation level requested for a transaction.
///
/// The level is stored on each `Transaction`. NOTE(review): the coordinator in this
/// module does not consult it when choosing which locks to take — confirm whether
/// enforcement lives elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub enum IsolationLevel {
    /// SQL-standard `READ UNCOMMITTED`.
    ReadUncommitted,
    /// SQL-standard `READ COMMITTED`; the default level.
    #[default]
    ReadCommitted,
    /// SQL-standard `REPEATABLE READ`.
    RepeatableRead,
    /// SQL-standard `SERIALIZABLE`.
    Serializable,
}
/// Tunable limits and feature switches for `TransactionCoordinator`.
#[derive(Debug, Clone)]
pub struct TransactionConfig {
    /// Lifetime given to every new transaction.
    pub default_timeout: Duration,
    /// Limit on concurrently tracked transactions, enforced by `begin_transaction`.
    pub max_concurrent_transactions: usize,
    /// Optimistic concurrency-control switch.
    /// NOTE(review): not referenced by the coordinator in this module — confirm wiring.
    pub enable_optimistic_cc: bool,
    /// Deadlock-detection switch.
    /// NOTE(review): not referenced by the coordinator in this module — confirm wiring.
    pub enable_deadlock_detection: bool,
    /// Checkpoint period.
    /// NOTE(review): not referenced by the coordinator in this module — confirm wiring.
    pub checkpoint_interval: Duration,
}

impl Default for TransactionConfig {
    /// 30 s transaction timeout, at most 1000 concurrent transactions, optimistic
    /// CC and deadlock detection enabled, 60 s checkpoint interval.
    fn default() -> Self {
        let default_timeout = Duration::from_secs(30);
        let checkpoint_interval = Duration::from_secs(60);
        Self {
            default_timeout,
            max_concurrent_transactions: 1_000,
            enable_optimistic_cc: true,
            enable_deadlock_detection: true,
            checkpoint_interval,
        }
    }
}
/// Coordinates distributed transactions across shards using two-phase commit,
/// with an optimizer-selected shortcut that skips the prepare round.
pub struct TransactionCoordinator {
// Stored for future use; no method in this module reads it, hence the allow.
#[allow(dead_code)]
node_id: OxirsNodeId,
/// Maps inserted/deleted triples to the shard that owns them.
shard_router: Arc<ShardRouter>,
/// Resolves the primary node of each participating shard.
shard_manager: Arc<ShardManager>,
// Stored for future use; no method in this module reads it, hence the allow.
#[allow(dead_code)]
storage: Arc<dyn StorageBackend>,
/// Transport for prepare/commit/abort messages to participant nodes.
network: Arc<NetworkService>,
/// Timeouts and limits.
config: TransactionConfig,
/// Every known transaction; finished ones stay until `cleanup_transactions` runs.
transactions: Arc<RwLock<HashMap<TransactionId, Transaction>>>,
/// In-memory record of protocol steps.
transaction_log: Arc<Mutex<TransactionLog>>,
/// Resource locks taken during the prepare phase.
lock_manager: Arc<LockManager>,
/// Decides whether a transaction may skip full two-phase commit.
optimizer: crate::transaction_optimizer::TwoPhaseOptimizer,
}
impl TransactionCoordinator {
pub fn new(
node_id: OxirsNodeId,
shard_router: Arc<ShardRouter>,
shard_manager: Arc<ShardManager>,
storage: Arc<dyn StorageBackend>,
network: Arc<NetworkService>,
config: TransactionConfig,
) -> Self {
Self {
node_id,
shard_router,
shard_manager,
storage,
network,
config,
transactions: Arc::new(RwLock::new(HashMap::new())),
transaction_log: Arc::new(Mutex::new(TransactionLog::new())),
lock_manager: Arc::new(LockManager::new()),
optimizer: crate::transaction_optimizer::TwoPhaseOptimizer::new(),
}
}
pub async fn begin_transaction(
&self,
isolation_level: IsolationLevel,
) -> Result<TransactionId> {
let tx_id = Uuid::new_v4().to_string();
let transaction = Transaction {
id: tx_id.clone(),
state: TransactionState::Active,
operations: Vec::new(),
participants: HashMap::new(),
created_at: Instant::now(),
timeout: self.config.default_timeout,
isolation_level,
};
{
let transactions = self.transactions.read().await;
if transactions.len() >= self.config.max_concurrent_transactions {
return Err(anyhow::anyhow!("Maximum concurrent transactions exceeded"));
}
}
self.transaction_log.lock().await.log_begin(&tx_id).await?;
self.transactions
.write()
.await
.insert(tx_id.clone(), transaction);
Ok(tx_id)
}
pub async fn add_operation(&self, tx_id: &str, operation: TransactionOp) -> Result<()> {
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
if transaction.state != TransactionState::Active {
return Err(anyhow::anyhow!("Transaction is not active"));
}
let shard_id = match &operation {
TransactionOp::Insert { triple } | TransactionOp::Delete { triple } => {
self.shard_router.route_triple(triple).await?
}
TransactionOp::Query { .. } => {
0 }
};
if let std::collections::hash_map::Entry::Vacant(e) =
transaction.participants.entry(shard_id)
{
let node_id = self.shard_manager.get_primary_node(shard_id).await?;
e.insert(TransactionParticipant {
node_id,
shard_id,
vote: None,
last_contact: Instant::now(),
});
}
transaction.operations.push((shard_id, operation));
Ok(())
}
pub async fn commit_transaction(&self, tx_id: &str) -> Result<()> {
let transaction = {
let transactions = self.transactions.read().await;
transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?
.clone()
};
if transaction.created_at.elapsed() > transaction.timeout {
self.abort_transaction(tx_id).await?;
return Err(anyhow::anyhow!("Transaction timed out"));
}
let optimization = self.optimizer.analyze_transaction(&transaction).await;
if optimization.skip_2pc {
self.commit_optimized_transaction(tx_id, optimization)
.await?;
} else {
self.prepare_phase(tx_id).await?;
let should_commit = self.check_votes(tx_id).await?;
if should_commit {
self.commit_phase(tx_id).await?;
} else {
self.abort_phase(tx_id).await?;
}
}
Ok(())
}
pub async fn abort_transaction(&self, tx_id: &str) -> Result<()> {
self.abort_phase(tx_id).await
}
async fn commit_optimized_transaction(
&self,
tx_id: &str,
_optimization: crate::transaction_optimizer::TransactionOptimization,
) -> Result<()> {
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Committed;
}
self.transaction_log.lock().await.log_commit(tx_id).await?;
self.transaction_log
.lock()
.await
.log_complete(tx_id, true)
.await?;
Ok(())
}
async fn prepare_phase(&self, tx_id: &str) -> Result<()> {
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Preparing;
}
self.transaction_log.lock().await.log_prepare(tx_id).await?;
let participants = {
let transactions = self.transactions.read().await;
let transaction = transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.participants.clone()
};
let operations = {
let transactions = self.transactions.read().await;
let transaction = transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.operations.clone()
};
for (shard_id, op) in &operations {
match op {
TransactionOp::Insert { triple } | TransactionOp::Delete { triple } => {
self.lock_manager
.acquire_write_lock(tx_id, *shard_id, &triple.subject().to_string())
.await?;
}
TransactionOp::Query { subject, .. } => {
if let Some(subj) = subject {
self.lock_manager
.acquire_read_lock(tx_id, *shard_id, subj)
.await?;
}
}
}
}
let participant_shard_ids: Vec<_> = participants.keys().copied().collect();
for (shard_id, participant) in participants {
let ops: Vec<_> = operations
.iter()
.filter(|(s, _)| *s == shard_id)
.map(|(_, op)| op.clone())
.collect();
self.send_prepare_request(tx_id, participant.node_id, shard_id, ops)
.await?;
}
if std::env::var("OXIRS_TEST_MODE").is_ok() || cfg!(test) {
for shard_id in participant_shard_ids {
self.handle_prepare_vote(tx_id, shard_id, true).await?;
}
}
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Prepared;
}
Ok(())
}
async fn check_votes(&self, tx_id: &str) -> Result<bool> {
let transactions = self.transactions.read().await;
let transaction = transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
for participant in transaction.participants.values() {
match participant.vote {
Some(true) => continue,
Some(false) => return Ok(false),
None => {
if participant.last_contact.elapsed() > transaction.timeout {
return Ok(false);
}
return Err(anyhow::anyhow!("Not all participants have voted"));
}
}
}
Ok(true)
}
async fn commit_phase(&self, tx_id: &str) -> Result<()> {
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Committing;
}
self.transaction_log.lock().await.log_commit(tx_id).await?;
let participants = {
let transactions = self.transactions.read().await;
let transaction = transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.participants.clone()
};
for (shard_id, participant) in participants {
self.send_commit_request(tx_id, participant.node_id, shard_id)
.await?;
}
self.lock_manager.release_transaction_locks(tx_id).await;
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Committed;
}
self.transaction_log
.lock()
.await
.log_complete(tx_id, true)
.await?;
Ok(())
}
async fn abort_phase(&self, tx_id: &str) -> Result<()> {
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Aborting;
}
self.transaction_log.lock().await.log_abort(tx_id).await?;
let participants = {
let transactions = self.transactions.read().await;
let transaction = transactions
.get(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.participants.clone()
};
for (shard_id, participant) in participants {
self.send_abort_request(tx_id, participant.node_id, shard_id)
.await?;
}
self.lock_manager.release_transaction_locks(tx_id).await;
{
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
transaction.state = TransactionState::Aborted;
}
self.transaction_log
.lock()
.await
.log_complete(tx_id, false)
.await?;
Ok(())
}
async fn send_prepare_request(
&self,
tx_id: &str,
node_id: OxirsNodeId,
shard_id: ShardId,
operations: Vec<TransactionOp>,
) -> Result<()> {
let message = RpcMessage::TransactionPrepare {
tx_id: tx_id.to_string(),
shard_id,
operations,
};
self.network.send_message(node_id, message).await?;
Ok(())
}
async fn send_commit_request(
&self,
tx_id: &str,
node_id: OxirsNodeId,
shard_id: ShardId,
) -> Result<()> {
let message = RpcMessage::TransactionCommit {
tx_id: tx_id.to_string(),
shard_id,
};
self.network.send_message(node_id, message).await?;
Ok(())
}
async fn send_abort_request(
&self,
tx_id: &str,
node_id: OxirsNodeId,
shard_id: ShardId,
) -> Result<()> {
let message = RpcMessage::TransactionAbort {
tx_id: tx_id.to_string(),
shard_id,
};
self.network.send_message(node_id, message).await?;
Ok(())
}
pub async fn handle_prepare_vote(
&self,
tx_id: &str,
shard_id: ShardId,
vote: bool,
) -> Result<()> {
let mut transactions = self.transactions.write().await;
let transaction = transactions
.get_mut(tx_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
if let Some(participant) = transaction.participants.get_mut(&shard_id) {
participant.vote = Some(vote);
participant.last_contact = Instant::now();
}
Ok(())
}
pub async fn get_statistics(&self) -> TransactionStatistics {
let transactions = self.transactions.read().await;
let mut stats = TransactionStatistics::default();
for transaction in transactions.values() {
stats.total_transactions += 1;
match transaction.state {
TransactionState::Active => stats.active_transactions += 1,
TransactionState::Committed => stats.committed_transactions += 1,
TransactionState::Aborted => stats.aborted_transactions += 1,
_ => {}
}
}
stats
}
pub async fn get_optimizer_statistics(
&self,
) -> crate::transaction_optimizer::OptimizationStats {
self.optimizer.get_statistics().await
}
pub async fn cleanup_transactions(&self, retention: Duration) {
let mut transactions = self.transactions.write().await;
let now = Instant::now();
transactions.retain(|_, tx| match tx.state {
TransactionState::Committed | TransactionState::Aborted => {
now.duration_since(tx.created_at) < retention
}
_ => true,
});
}
}
/// Append-only, in-memory log of transaction protocol steps.
/// NOTE(review): nothing in this module persists `entries`, so the log does not
/// survive a restart and cannot drive recovery as-is.
#[derive(Debug)]
struct TransactionLog {
/// Records in append order.
entries: Vec<LogEntry>,
}
/// One timestamped record in the `TransactionLog`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LogEntry {
/// Wall-clock time the record was appended (`Utc::now()`).
timestamp: DateTime<Utc>,
/// Transaction the record belongs to.
tx_id: String,
/// Which protocol step was recorded.
entry_type: LogEntryType,
}
/// Protocol step recorded by a `LogEntry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum LogEntryType {
/// Transaction created.
Begin,
/// Prepare phase started.
Prepare,
/// Commit decided.
Commit,
/// Abort decided.
Abort,
/// Outcome fully applied; `committed` says which outcome it was.
Complete { committed: bool },
}
impl TransactionLog {
fn new() -> Self {
Self {
entries: Vec::new(),
}
}
async fn log_begin(&mut self, tx_id: &str) -> Result<()> {
self.entries.push(LogEntry {
timestamp: Utc::now(),
tx_id: tx_id.to_string(),
entry_type: LogEntryType::Begin,
});
Ok(())
}
async fn log_prepare(&mut self, tx_id: &str) -> Result<()> {
self.entries.push(LogEntry {
timestamp: Utc::now(),
tx_id: tx_id.to_string(),
entry_type: LogEntryType::Prepare,
});
Ok(())
}
async fn log_commit(&mut self, tx_id: &str) -> Result<()> {
self.entries.push(LogEntry {
timestamp: Utc::now(),
tx_id: tx_id.to_string(),
entry_type: LogEntryType::Commit,
});
Ok(())
}
async fn log_abort(&mut self, tx_id: &str) -> Result<()> {
self.entries.push(LogEntry {
timestamp: Utc::now(),
tx_id: tx_id.to_string(),
entry_type: LogEntryType::Abort,
});
Ok(())
}
async fn log_complete(&mut self, tx_id: &str, committed: bool) -> Result<()> {
self.entries.push(LogEntry {
timestamp: Utc::now(),
tx_id: tx_id.to_string(),
entry_type: LogEntryType::Complete { committed },
});
Ok(())
}
}
/// Per-transaction set of every `(shard, resource)` key it holds, so all of a
/// transaction's locks can be released in one call.
type TransactionLockMap = HashMap<TransactionId, HashSet<(ShardId, String)>>;

/// Current holders of each `(shard, resource)` key.
type ResourceLockMap = HashMap<(ShardId, String), Vec<Lock>>;

/// In-memory shared/exclusive lock table keyed by `(shard, resource)`.
///
/// Each key is held either by any number of readers or by exactly one writer.
/// A transaction that is the sole holder of a key may upgrade from read to write.
/// Requests never wait: a conflict is reported immediately as an error.
///
/// Lock ordering: `locks` is always acquired before `tx_locks`.
#[derive(Debug)]
struct LockManager {
    locks: Arc<RwLock<ResourceLockMap>>,
    tx_locks: Arc<RwLock<TransactionLockMap>>,
}

/// One transaction's hold on a lock key.
#[derive(Debug)]
struct Lock {
    lock_type: LockType,
    tx_id: TransactionId,
    // Kept for diagnostics; nothing in this module reads it yet, hence the allow.
    #[allow(dead_code)]
    acquired_at: Instant,
}

/// Lock mode: shared (`Read`) or exclusive (`Write`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum LockType {
    Read,
    Write,
}

impl LockManager {
    /// Creates an empty lock table.
    fn new() -> Self {
        Self {
            locks: Arc::new(RwLock::new(HashMap::new())),
            tx_locks: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Takes a shared lock on `resource` in `shard_id` for `tx_id`.
    ///
    /// Succeeds alongside other readers. If `tx_id` already holds the key (read or
    /// write) this is a no-op, so its own write lock is never downgraded.
    ///
    /// # Errors
    /// "Resource is write-locked by another transaction".
    async fn acquire_read_lock(
        &self,
        tx_id: &str,
        shard_id: ShardId,
        resource: &str,
    ) -> Result<()> {
        let key = (shard_id, resource.to_string());
        let mut locks = self.locks.write().await;
        // Only reaches the error path when holders already exist, so no empty
        // entries are left behind.
        let holders = locks.entry(key.clone()).or_default();
        if holders
            .iter()
            .any(|held| held.lock_type == LockType::Write && held.tx_id != tx_id)
        {
            return Err(anyhow::anyhow!(
                "Resource is write-locked by another transaction"
            ));
        }
        if !holders.iter().any(|held| held.tx_id == tx_id) {
            holders.push(Lock {
                lock_type: LockType::Read,
                tx_id: tx_id.to_string(),
                acquired_at: Instant::now(),
            });
        }
        self.record(tx_id, key).await;
        Ok(())
    }

    /// Takes an exclusive lock on `resource` in `shard_id` for `tx_id`, upgrading
    /// the transaction's own read lock if it is the only holder.
    ///
    /// # Errors
    /// "Resource is locked by another transaction" if any other transaction holds
    /// the key in any mode.
    async fn acquire_write_lock(
        &self,
        tx_id: &str,
        shard_id: ShardId,
        resource: &str,
    ) -> Result<()> {
        let key = (shard_id, resource.to_string());
        let mut locks = self.locks.write().await;
        let holders = locks.entry(key.clone()).or_default();
        if holders.iter().any(|held| held.tx_id != tx_id) {
            return Err(anyhow::anyhow!("Resource is locked by another transaction"));
        }
        if holders.is_empty() {
            holders.push(Lock {
                lock_type: LockType::Write,
                tx_id: tx_id.to_string(),
                acquired_at: Instant::now(),
            });
        } else {
            // Only this transaction's own entry can be present (checked above).
            for held in holders.iter_mut() {
                held.lock_type = LockType::Write;
            }
        }
        self.record(tx_id, key).await;
        Ok(())
    }

    /// Remembers that `tx_id` holds `key`. Callers invoke this while holding
    /// `locks`, preserving the `locks` -> `tx_locks` ordering.
    async fn record(&self, tx_id: &str, key: (ShardId, String)) {
        self.tx_locks
            .write()
            .await
            .entry(tx_id.to_string())
            .or_default()
            .insert(key);
    }

    /// Releases every lock held by `tx_id`. Other holders of shared keys keep
    /// their locks; keys with no holders left are removed.
    async fn release_transaction_locks(&self, tx_id: &str) {
        // `locks` before `tx_locks`, matching the acquire paths; the reverse order
        // can deadlock against a concurrent acquire.
        let mut locks = self.locks.write().await;
        let held_keys = self.tx_locks.write().await.remove(tx_id);
        for key in held_keys.into_iter().flatten() {
            if let Some(holders) = locks.get_mut(&key) {
                holders.retain(|held| held.tx_id != tx_id);
                if holders.is_empty() {
                    locks.remove(&key);
                }
            }
        }
    }
}
/// Counts over the coordinator's transaction table, as produced by
/// `TransactionCoordinator::get_statistics`.
#[derive(Debug, Default, Clone)]
pub struct TransactionStatistics {
/// Every transaction currently in the table, in any state (intermediate states
/// such as `Preparing` are only counted here).
pub total_transactions: usize,
/// Transactions in `Active` state.
pub active_transactions: usize,
/// Transactions in `Committed` state that have not been cleaned up yet.
pub committed_transactions: usize,
/// Transactions in `Aborted` state that have not been cleaned up yet.
pub aborted_transactions: usize,
}
/// Messages of the two-phase-commit protocol between coordinator and participants.
/// NOTE(review): the coordinator in this module sends `crate::network::RpcMessage`
/// variants instead and never references this enum — confirm whether it is still used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionRpcMessage {
/// Coordinator -> participant: prepare these operations for `shard_id`.
TransactionPrepare {
tx_id: TransactionId,
shard_id: ShardId,
operations: Vec<TransactionOp>,
},
/// Participant -> coordinator: prepare vote (`true` = yes).
TransactionVote {
tx_id: TransactionId,
shard_id: ShardId,
vote: bool,
},
/// Coordinator -> participant: commit `tx_id` on `shard_id`.
TransactionCommit {
tx_id: TransactionId,
shard_id: ShardId,
},
/// Coordinator -> participant: abort `tx_id` on `shard_id`.
TransactionAbort {
tx_id: TransactionId,
shard_id: ShardId,
},
/// Acknowledgement for `tx_id` on `shard_id`.
TransactionAck {
tx_id: TransactionId,
shard_id: ShardId,
},
}