use crate::distributed::sharding::{ShardId, ShardManager};
use crate::model::{BlankNode, Literal, NamedNode, Triple};
use anyhow::{anyhow, Result};
use dashmap::DashMap;
use parking_lot::RwLock;
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::{mpsc, oneshot};
use uuid::Uuid;
pub type TransactionId = Uuid;
pub type NodeId = u64;
/// Tuning knobs for [`TransactionCoordinator`].
#[derive(Debug, Clone)]
pub struct TransactionConfig {
/// Maximum transaction age; `commit_transaction` aborts a transaction older than this.
pub timeout: Duration,
/// Let transactions flagged read-only complete without contacting any participant.
pub enable_read_only_optimization: bool,
/// Let transactions flagged single-shard commit directly, skipping the prepare phase.
pub enable_single_shard_optimization: bool,
/// NOTE(review): not read by any code shown here.
pub max_retries: usize,
/// Await prepare votes together via `join_all` instead of one after another.
pub enable_parallel_prepare: bool,
/// NOTE(review): not read by any code shown here; presumably meant for deadlock resolution.
pub deadlock_timeout: Duration,
}
impl Default for TransactionConfig {
    /// Default settings: a 30 s transaction timeout, a 10 s deadlock timeout, three
    /// retries, and the read-only, single-shard and parallel-prepare optimizations enabled.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let deadlock_timeout = Duration::from_secs(10);
        let optimizations_on = true;
        Self {
            timeout,
            deadlock_timeout,
            max_retries: 3,
            enable_read_only_optimization: optimizations_on,
            enable_single_shard_optimization: optimizations_on,
            enable_parallel_prepare: optimizations_on,
        }
    }
}
/// Lifecycle of a transaction as tracked by the coordinator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
/// Accepting operations; the state a transaction starts in.
Active,
/// Prepare phase (vote collection) in progress.
Preparing,
/// Every participant voted yes.
Prepared,
/// Commit phase in progress.
Committing,
/// Finished successfully.
Committed,
/// Finished without committing.
Aborted,
}
/// One operation recorded in a transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransactionOp {
/// Add a triple; routed to the shard the shard manager assigns to it.
Insert(SerializableTriple),
/// Delete a triple; routed the same way as `Insert`.
Remove(SerializableTriple),
/// Pattern read; the coordinator fans it out to shards `0..16`.
Read(ReadQuery),
}
/// Wire-friendly triple: every term is a string, and `object_type` says how to read `object`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableTriple {
/// Subject IRI (parsed as a named node).
pub subject: String,
/// Predicate IRI (parsed as a named node).
pub predicate: String,
/// Object lexical form, IRI, or blank-node id, depending on `object_type`.
pub object: String,
/// How to interpret `object`.
pub object_type: ObjectType,
}
/// Kind of term stored in [`SerializableTriple::object`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ObjectType {
/// `object` is an IRI.
NamedNode,
/// `object` is a blank-node identifier.
BlankNode,
/// `object` is a literal's lexical form. When both are set, `language` takes
/// precedence over `datatype` during deserialization.
Literal {
datatype: Option<String>,
language: Option<String>,
},
}
/// Triple-pattern read request.
/// NOTE(review): `None` presumably means "any"; the coordinator shown here never inspects these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadQuery {
pub subject: Option<String>,
pub predicate: Option<String>,
pub object: Option<String>,
}
/// Coordinates transactions across shards, falling back to two-phase commit when
/// neither the read-only nor the single-shard fast path applies.
// `participant_tx` is stored but never read by the code shown here, hence the allow.
#[allow(dead_code)]
pub struct TransactionCoordinator {
/// Timeouts and optimization switches.
config: TransactionConfig,
/// In-flight transactions; an entry is removed when its transaction completes.
transactions: Arc<DashMap<TransactionId, Transaction>>,
/// Maps triples to the shard that owns them.
shard_manager: Arc<ShardManager>,
/// Record of transaction events (entries are only ever appended).
transaction_log: Arc<RwLock<TransactionLog>>,
/// Lock bookkeeping; a transaction's locks are released when it completes.
lock_manager: Arc<LockManager>,
/// Sender for participant messages; its receiver is dropped in `new`.
participant_tx: mpsc::UnboundedSender<ParticipantMessage>,
}
/// Coordinator-side state of one transaction.
pub struct Transaction {
/// Identifier (a random v4 UUID assigned by `begin_transaction`).
pub id: TransactionId,
/// Current protocol state, behind its own lock so it can change through `&Transaction`.
pub state: Arc<RwLock<TransactionState>>,
/// Operations in the order they were added.
pub operations: Arc<RwLock<Vec<TransactionOp>>>,
/// Shards touched by any operation; each one takes part in commit/abort.
pub participants: Arc<RwLock<HashSet<ShardId>>>,
/// Prepare-phase vote recorded per participant shard.
pub votes: Arc<DashMap<ShardId, Vote>>,
/// Creation time, compared against `TransactionConfig::timeout` at commit.
pub start_time: Instant,
/// One-shot completion signal, consumed when the transaction completes.
pub completion_tx: Option<oneshot::Sender<Result<()>>>,
/// Flag consulted by `commit_transaction` to take the read-only fast path.
pub is_read_only: bool,
/// Flag consulted by `commit_transaction` to skip the prepare phase.
pub is_single_shard: bool,
}
/// A participant's answer in the prepare phase.
#[derive(Debug, Clone, Copy)]
pub enum Vote {
/// Ready to commit.
Yes,
/// Refuses to commit, with the reason.
No(AbortReason),
}
/// Why a participant voted no.
#[derive(Debug, Clone, Copy)]
pub enum AbortReason {
/// Conflicting lock (the reason produced by the simulated vote).
LockConflict,
/// Operations failed validation.
ValidationFailure,
/// Participant did not answer in time.
Timeout,
/// Participant unreachable; used when a vote channel closes without a reply.
NodeFailure,
/// Any other reason.
Other,
}
/// In-memory log of transaction events.
pub struct TransactionLog {
/// Entries in the order they were added.
entries: Vec<LogEntry>,
/// Optional persistence path; `new` leaves it `None` and `add_entry` does not write to it.
log_path: Option<String>,
}
/// One timestamped event for one transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
/// Wall-clock time when the event was logged.
pub timestamp: SystemTime,
/// Transaction the event belongs to.
pub transaction_id: TransactionId,
/// What happened.
pub event: LogEvent,
}
/// Events recorded over a transaction's lifetime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogEvent {
/// Transaction created.
Started,
/// Prepare phase began with these participant shards.
PrepareStarted { participants: Vec<ShardId> },
/// A participant voted; `vote` is `true` for yes.
ParticipantVoted { shard: ShardId, vote: bool },
/// Outcome of the prepare phase; `commit` is `true` when every vote was yes.
GlobalDecision { commit: bool },
/// Transaction finished (committed or aborted).
Completed,
}
/// Coordinator-to-participant protocol messages.
/// NOTE(review): the code shown here builds these values but never sends them;
/// participant replies are simulated instead.
#[derive(Debug)]
pub enum ParticipantMessage {
/// Ask a shard to prepare its share of the operations and reply with a vote.
Prepare {
transaction_id: TransactionId,
operations: Vec<TransactionOp>,
reply_tx: oneshot::Sender<Vote>,
},
/// Tell a shard to make the transaction durable.
Commit { transaction_id: TransactionId },
/// Tell a shard to roll the transaction back.
Abort { transaction_id: TransactionId },
}
/// Non-blocking lock table with a wait-for graph for deadlock detection.
pub struct LockManager {
/// Locks granted to each transaction, so they can all be released together.
transaction_locks: Arc<DashMap<TransactionId, HashSet<LockId>>>,
/// Wait-for edges: waiting transaction -> transactions it waits on.
wait_graph: Arc<RwLock<HashMap<TransactionId, HashSet<TransactionId>>>>,
/// Current holder, waiters and mode for every lock ever requested.
lock_table: Arc<DashMap<LockId, LockInfo>>,
}
/// Identifies a lockable resource within a shard.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct LockId {
/// Shard the resource lives on.
pub shard_id: ShardId,
/// Free-form resource name within that shard.
pub resource: String,
}
/// Bookkeeping for one lock in the lock table.
#[derive(Debug, Clone)]
pub struct LockInfo {
/// Transaction recorded as holding the lock; only one holder is tracked, even in shared mode.
pub holder: Option<TransactionId>,
/// Transactions that were refused; on release the lock passes to the first of them.
pub waiters: Vec<TransactionId>,
/// Mode of the current grant.
pub lock_type: LockType,
}
/// Lock mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockType {
/// Compatible with other shared requests.
Shared,
/// Incompatible with requests from any other transaction.
Exclusive,
}
impl TransactionCoordinator {
    /// Creates a coordinator that routes operations through `shard_manager`.
    ///
    /// The receiving half of the participant channel is dropped when this function
    /// returns, so nothing sent on `participant_tx` can be delivered. Participant
    /// behaviour is simulated by the `simulate_participant_*` helpers instead.
    pub fn new(config: TransactionConfig, shard_manager: Arc<ShardManager>) -> Self {
        let (participant_tx, _participant_rx) = mpsc::unbounded_channel();
        Self {
            config,
            transactions: Arc::new(DashMap::new()),
            shard_manager,
            transaction_log: Arc::new(RwLock::new(TransactionLog::new())),
            lock_manager: Arc::new(LockManager::new()),
            participant_tx,
        }
    }

    /// Registers a new `Active` transaction and logs [`LogEvent::Started`].
    ///
    /// The transaction starts flagged as read-only and single-shard. `add_operation`
    /// clears those flags as writes and additional shards are added.
    pub async fn begin_transaction(&self) -> Result<TransactionId> {
        let transaction_id = Uuid::new_v4();
        // NOTE(review): the receiver is dropped when this function returns, so the
        // completion signal sent by `complete_transaction` is never observed.
        let (completion_tx, _completion_rx) = oneshot::channel();
        let transaction = Transaction {
            id: transaction_id,
            state: Arc::new(RwLock::new(TransactionState::Active)),
            operations: Arc::new(RwLock::new(Vec::new())),
            participants: Arc::new(RwLock::new(HashSet::new())),
            votes: Arc::new(DashMap::new()),
            start_time: Instant::now(),
            completion_tx: Some(completion_tx),
            is_read_only: true,
            is_single_shard: true,
        };
        self.transactions.insert(transaction_id, transaction);
        self.log_event(LogEntry {
            timestamp: SystemTime::now(),
            transaction_id,
            event: LogEvent::Started,
        });
        Ok(transaction_id)
    }

    /// Appends `operation` to an active transaction and records the shards it touches.
    ///
    /// An insert or remove clears `is_read_only`. `is_single_shard` is kept equal to
    /// "at most one participant", so `commit_transaction` picks the right protocol.
    ///
    /// # Errors
    /// Fails if the transaction does not exist, is no longer `Active`, or the
    /// operation's triple cannot be converted into a model triple.
    pub async fn add_operation(
        &self,
        transaction_id: TransactionId,
        operation: TransactionOp,
    ) -> Result<()> {
        // `get_mut` because the optimization flags are plain fields. Nothing in this
        // function awaits, so the map guard is never held across an `.await`.
        let mut transaction = self
            .transactions
            .get_mut(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        if *transaction.state.read() != TransactionState::Active {
            return Err(anyhow!("Transaction is not active"));
        }
        let affected_shards = self.get_affected_shards(&operation)?;
        let is_write = matches!(
            operation,
            TransactionOp::Insert(_) | TransactionOp::Remove(_)
        );
        transaction.operations.write().push(operation);
        let participant_count = {
            let mut participants = transaction.participants.write();
            participants.extend(affected_shards);
            participants.len()
        };
        if is_write {
            transaction.is_read_only = false;
        }
        transaction.is_single_shard = participant_count <= 1;
        Ok(())
    }

    /// Commits a transaction using the cheapest protocol its flags allow.
    ///
    /// Read-only transactions complete immediately, single-shard transactions commit
    /// directly, and everything else goes through two-phase commit.
    ///
    /// # Errors
    /// Fails if the transaction does not exist, has exceeded `config.timeout` (it is
    /// aborted first), or was aborted because a participant voted no.
    pub async fn commit_transaction(&self, transaction_id: TransactionId) -> Result<()> {
        let transaction = self
            .transactions
            .get(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        let timed_out = transaction.start_time.elapsed() > self.config.timeout;
        let is_read_only = transaction.is_read_only;
        let is_single_shard = transaction.is_single_shard;
        // Every path below ends in `complete_transaction`, which removes the map entry.
        // That would block forever while this guard is still alive.
        drop(transaction);

        if timed_out {
            self.abort_transaction(transaction_id).await?;
            return Err(anyhow!("Transaction timeout"));
        }
        if is_read_only && self.config.enable_read_only_optimization {
            return self.complete_transaction(transaction_id, true).await;
        }
        if is_single_shard && self.config.enable_single_shard_optimization {
            return self.commit_single_shard(transaction_id).await;
        }
        self.two_phase_commit(transaction_id).await
    }

    /// Runs prepare, then commit or abort depending on the votes.
    ///
    /// # Errors
    /// Returns an error when the transaction was aborted, so callers never mistake
    /// an abort for a successful commit.
    async fn two_phase_commit(&self, transaction_id: TransactionId) -> Result<()> {
        if self.prepare_phase(transaction_id).await? {
            self.commit_phase(transaction_id).await
        } else {
            self.abort_phase(transaction_id).await?;
            Err(anyhow!("Transaction aborted: a participant voted no"))
        }
    }

    /// Prepare phase: moves the transaction to `Preparing` and collects a (simulated)
    /// vote from every participant.
    ///
    /// Each vote is recorded and logged. Returns `true`, and sets the state to
    /// `Prepared`, only when every vote is yes.
    async fn prepare_phase(&self, transaction_id: TransactionId) -> Result<bool> {
        let transaction = self
            .transactions
            .get(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        *transaction.state.write() = TransactionState::Preparing;
        // A Vec fixes the participant order so each vote can be zipped back to its shard.
        let participants: Vec<ShardId> =
            transaction.participants.read().iter().copied().collect();
        let operations = transaction.operations.read().clone();
        let state = Arc::clone(&transaction.state);
        let vote_table = Arc::clone(&transaction.votes);
        // Release the map guard before any `.await`.
        drop(transaction);

        self.log_event(LogEntry {
            timestamp: SystemTime::now(),
            transaction_id,
            event: LogEvent::PrepareStarted {
                participants: participants.clone(),
            },
        });

        let mut prepare_futures = Vec::with_capacity(participants.len());
        for &shard_id in &participants {
            let (reply_tx, _reply_rx) = oneshot::channel();
            let _message = ParticipantMessage::Prepare {
                transaction_id,
                operations: self.filter_operations_for_shard(shard_id, &operations),
                reply_tx,
            };
            let vote = self.simulate_participant_vote(shard_id, &operations);
            let (sim_tx, sim_rx) = oneshot::channel();
            let _ = sim_tx.send(vote);
            prepare_futures
                .push(async move { sim_rx.await.unwrap_or(Vote::No(AbortReason::NodeFailure)) });
        }

        let collected_votes = if self.config.enable_parallel_prepare {
            futures::future::join_all(prepare_futures).await
        } else {
            let mut votes = Vec::with_capacity(prepare_futures.len());
            for future in prepare_futures {
                votes.push(future.await);
            }
            votes
        };

        let mut all_yes = true;
        for (&shard_id, vote) in participants.iter().zip(collected_votes) {
            vote_table.insert(shard_id, vote);
            let voted_yes = matches!(vote, Vote::Yes);
            self.log_event(LogEntry {
                timestamp: SystemTime::now(),
                transaction_id,
                event: LogEvent::ParticipantVoted {
                    shard: shard_id,
                    vote: voted_yes,
                },
            });
            all_yes &= voted_yes;
        }
        if all_yes {
            *state.write() = TransactionState::Prepared;
        }
        self.log_event(LogEntry {
            timestamp: SystemTime::now(),
            transaction_id,
            event: LogEvent::GlobalDecision { commit: all_yes },
        });
        Ok(all_yes)
    }

    /// Commit phase: moves the transaction to `Committing`, tells every participant
    /// to commit (simulated), then completes the transaction as committed.
    async fn commit_phase(&self, transaction_id: TransactionId) -> Result<()> {
        let transaction = self
            .transactions
            .get(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        *transaction.state.write() = TransactionState::Committing;
        let participants = transaction.participants.read().clone();
        // `complete_transaction` removes the entry; do not hold the guard into it.
        drop(transaction);
        for shard_id in participants {
            let _message = ParticipantMessage::Commit { transaction_id };
            self.simulate_participant_commit(shard_id, transaction_id)?;
        }
        self.complete_transaction(transaction_id, true).await
    }

    /// Tells every participant to abort (simulated), then completes the transaction
    /// as aborted.
    async fn abort_phase(&self, transaction_id: TransactionId) -> Result<()> {
        let transaction = self
            .transactions
            .get(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        let participants = transaction.participants.read().clone();
        // `complete_transaction` removes the entry; do not hold the guard into it.
        drop(transaction);
        for shard_id in participants {
            let _message = ParticipantMessage::Abort { transaction_id };
            self.simulate_participant_abort(shard_id, transaction_id)?;
        }
        self.complete_transaction(transaction_id, false).await
    }

    /// Aborts a transaction explicitly.
    ///
    /// # Errors
    /// Fails if the transaction does not exist.
    pub async fn abort_transaction(&self, transaction_id: TransactionId) -> Result<()> {
        self.abort_phase(transaction_id).await
    }

    /// Finalizes a transaction.
    ///
    /// Removes it from the in-flight map, sets the terminal state, logs `Completed`,
    /// signals the completion channel (`Ok` on commit, `Err` on abort) and releases
    /// its locks. Completing an unknown or already-completed transaction is a no-op.
    ///
    /// Callers must not hold a guard into `self.transactions` while awaiting this,
    /// because the `remove` below needs that shard's write lock.
    async fn complete_transaction(
        &self,
        transaction_id: TransactionId,
        committed: bool,
    ) -> Result<()> {
        if let Some((_, transaction)) = self.transactions.remove(&transaction_id) {
            let (final_state, outcome) = if committed {
                (TransactionState::Committed, Ok(()))
            } else {
                (TransactionState::Aborted, Err(anyhow!("Transaction aborted")))
            };
            *transaction.state.write() = final_state;
            self.log_event(LogEntry {
                timestamp: SystemTime::now(),
                transaction_id,
                event: LogEvent::Completed,
            });
            if let Some(tx) = transaction.completion_tx {
                // The receiver may already be gone; completion is best-effort.
                let _ = tx.send(outcome);
            }
            self.lock_manager.release_transaction_locks(transaction_id);
        }
        Ok(())
    }

    /// Commits a transaction with one participant directly, without a prepare phase.
    ///
    /// # Errors
    /// Fails if the transaction does not exist or has no participants.
    async fn commit_single_shard(&self, transaction_id: TransactionId) -> Result<()> {
        let transaction = self
            .transactions
            .get(&transaction_id)
            .ok_or_else(|| anyhow!("Transaction not found"))?;
        let first_participant = transaction.participants.read().iter().next().copied();
        // `complete_transaction` removes the entry; do not hold the guard into it.
        drop(transaction);
        let shard_id = first_participant.ok_or_else(|| anyhow!("No participants"))?;
        let _message = ParticipantMessage::Commit { transaction_id };
        self.simulate_participant_commit(shard_id, transaction_id)?;
        self.complete_transaction(transaction_id, true).await
    }

    /// Returns the shards an operation touches.
    ///
    /// Inserts and removes go to the shard `shard_manager` assigns to the triple.
    /// Reads are fanned out to shards `0..16`.
    /// NOTE(review): the read fan-out is hard-coded to 16 shards rather than taken
    /// from `shard_manager`; confirm it matches the configured shard count.
    fn get_affected_shards(&self, operation: &TransactionOp) -> Result<Vec<ShardId>> {
        match operation {
            TransactionOp::Insert(triple) | TransactionOp::Remove(triple) => {
                let t = self.deserialize_triple(triple)?;
                Ok(vec![self.shard_manager.get_shard_for_triple(&t)])
            }
            TransactionOp::Read(_query) => Ok((0..16).collect()),
        }
    }

    /// Returns the subset of `operations` that touch `shard_id`.
    ///
    /// Operations whose shards cannot be determined are skipped.
    fn filter_operations_for_shard(
        &self,
        shard_id: ShardId,
        operations: &[TransactionOp],
    ) -> Vec<TransactionOp> {
        operations
            .iter()
            .filter(|op| {
                matches!(self.get_affected_shards(op), Ok(shards) if shards.contains(&shard_id))
            })
            .cloned()
            .collect()
    }

    /// Simulated participant vote: yes when the random draw is below 0.95.
    ///
    /// Otherwise the vote is no with `LockConflict`. Assuming `random::<f32>()` is
    /// uniform on [0, 1), that gives roughly a 95% yes rate.
    fn simulate_participant_vote(&self, _shard_id: ShardId, _operations: &[TransactionOp]) -> Vote {
        let mut rng = Random::default();
        if rng.random::<f32>() < 0.95 {
            Vote::Yes
        } else {
            Vote::No(AbortReason::LockConflict)
        }
    }

    /// Simulated participant commit; always succeeds.
    fn simulate_participant_commit(
        &self,
        _shard_id: ShardId,
        _transaction_id: TransactionId,
    ) -> Result<()> {
        Ok(())
    }

    /// Simulated participant abort; always succeeds.
    fn simulate_participant_abort(
        &self,
        _shard_id: ShardId,
        _transaction_id: TransactionId,
    ) -> Result<()> {
        Ok(())
    }

    /// Appends `entry` to the transaction log.
    fn log_event(&self, entry: LogEntry) {
        self.transaction_log.write().add_entry(entry);
    }

    /// Converts a [`SerializableTriple`] into a model [`Triple`].
    ///
    /// For literals, a language tag takes precedence over a datatype.
    ///
    /// # Errors
    /// Fails if any IRI, blank-node id or language-tagged literal is rejected by
    /// the model constructors.
    fn deserialize_triple(&self, st: &SerializableTriple) -> Result<Triple> {
        let subject = NamedNode::new(&st.subject)?;
        let predicate = NamedNode::new(&st.predicate)?;
        let object = match &st.object_type {
            ObjectType::NamedNode => crate::model::Object::NamedNode(NamedNode::new(&st.object)?),
            ObjectType::BlankNode => crate::model::Object::BlankNode(BlankNode::new(&st.object)?),
            ObjectType::Literal { datatype, language } => {
                if let Some(lang) = language {
                    crate::model::Object::Literal(Literal::new_language_tagged_literal(
                        &st.object, lang,
                    )?)
                } else if let Some(dt) = datatype {
                    crate::model::Object::Literal(Literal::new_typed(
                        &st.object,
                        NamedNode::new(dt)?,
                    ))
                } else {
                    crate::model::Object::Literal(Literal::new(&st.object))
                }
            }
        };
        Ok(Triple::new(subject, predicate, object))
    }
}
impl Default for TransactionLog {
fn default() -> Self {
Self::new()
}
}
impl TransactionLog {
    /// Creates an empty log with no persistence path.
    pub fn new() -> Self {
        let entries = Vec::new();
        Self {
            entries,
            log_path: None,
        }
    }

    /// Appends `entry` to the in-memory log.
    pub fn add_entry(&mut self, entry: LogEntry) {
        self.entries.push(entry);
        if let Some(_path) = &self.log_path {
            // NOTE(review): persistence to `log_path` is not implemented; entries
            // stay in memory only.
        }
    }

    /// Returns every entry recorded for `transaction_id`, in insertion order.
    pub fn get_transaction_entries(&self, transaction_id: TransactionId) -> Vec<&LogEntry> {
        let mut matching = Vec::new();
        for entry in &self.entries {
            if entry.transaction_id == transaction_id {
                matching.push(entry);
            }
        }
        matching
    }
}
impl Default for LockManager {
fn default() -> Self {
Self::new()
}
}
impl LockManager {
    /// Creates a lock manager with an empty lock table and an empty wait-for graph.
    pub fn new() -> Self {
        Self {
            transaction_locks: Arc::new(DashMap::new()),
            wait_graph: Arc::new(RwLock::new(HashMap::new())),
            lock_table: Arc::new(DashMap::new()),
        }
    }

    /// Tries to grant `lock_id` to `transaction_id` in `lock_type` mode, without blocking.
    ///
    /// A request is granted when any of these holds:
    /// - the lock is free;
    /// - it is a shared request against a shared lock;
    /// - the requester already holds the lock in a mode that covers the request
    ///   (re-entrant).
    ///
    /// A holder's exclusive lock is never downgraded by a later shared request.
    ///
    /// NOTE(review): `LockInfo` records a single holder, so a shared grant to a second
    /// transaction replaces the recorded holder. For the same reason, upgrading a held
    /// shared lock to exclusive is refused: other readers may still exist. Tracking all
    /// shared holders would require a holder set.
    ///
    /// # Errors
    /// Returns `"Lock not available"` on conflict. A request from another transaction
    /// is queued once in the lock's waiters and adds a wait-for edge to the current
    /// holder. A refused upgrade by the holder itself is not queued, because queueing it
    /// would create a self-edge and hand the lock back to the holder on release.
    pub fn acquire_lock(
        &self,
        transaction_id: TransactionId,
        lock_id: LockId,
        lock_type: LockType,
    ) -> Result<()> {
        let mut lock_info = self.lock_table.entry(lock_id.clone()).or_insert(LockInfo {
            holder: None,
            waiters: Vec::new(),
            lock_type: LockType::Shared,
        });
        let current_holder = lock_info.holder;
        let held_type = lock_info.lock_type;
        let can_grant = match current_holder {
            None => true,
            Some(holder) if holder == transaction_id => {
                lock_type == LockType::Shared || held_type == LockType::Exclusive
            }
            Some(_) => lock_type == LockType::Shared && held_type == LockType::Shared,
        };

        if can_grant {
            let keeps_exclusive =
                current_holder == Some(transaction_id) && held_type == LockType::Exclusive;
            lock_info.holder = Some(transaction_id);
            lock_info.lock_type = if keeps_exclusive {
                LockType::Exclusive
            } else {
                lock_type
            };
            // A transaction that now holds the lock is no longer waiting for it.
            lock_info.waiters.retain(|&waiter| waiter != transaction_id);
            self.transaction_locks
                .entry(transaction_id)
                .or_default()
                .insert(lock_id);
            return Ok(());
        }

        if let Some(holder) = current_holder.filter(|&holder| holder != transaction_id) {
            if !lock_info.waiters.contains(&transaction_id) {
                lock_info.waiters.push(transaction_id);
            }
            self.wait_graph
                .write()
                .entry(transaction_id)
                .or_default()
                .insert(holder);
        }
        Err(anyhow!("Lock not available"))
    }

    /// Releases every lock held by `transaction_id` and removes it from the wait-for graph.
    ///
    /// Its own outgoing edges are dropped, and so are all edges pointing at it.
    pub fn release_transaction_locks(&self, transaction_id: TransactionId) {
        if let Some((_, held)) = self.transaction_locks.remove(&transaction_id) {
            for lock_id in &held {
                self.release_lock(transaction_id, lock_id);
            }
        }
        let mut wait_graph = self.wait_graph.write();
        wait_graph.remove(&transaction_id);
        for waiting_on in wait_graph.values_mut() {
            waiting_on.remove(&transaction_id);
        }
    }

    /// Releases `lock_id` if `transaction_id` is its recorded holder.
    ///
    /// Ownership passes to the first waiter, if any. The new holder is recorded in
    /// `transaction_locks`, so the lock is freed again when that transaction finishes.
    /// The remaining waiters are re-pointed at the new holder in the wait-for graph.
    fn release_lock(&self, transaction_id: TransactionId, lock_id: &LockId) {
        let Some(mut lock_info) = self.lock_table.get_mut(lock_id) else {
            return;
        };
        if lock_info.holder != Some(transaction_id) {
            return;
        }
        if lock_info.waiters.is_empty() {
            lock_info.holder = None;
            return;
        }
        let next_holder = lock_info.waiters.remove(0);
        lock_info.holder = Some(next_holder);
        self.transaction_locks
            .entry(next_holder)
            .or_default()
            .insert(lock_id.clone());
        let mut wait_graph = self.wait_graph.write();
        if let Some(waiting_on) = wait_graph.get_mut(&next_holder) {
            waiting_on.remove(&transaction_id);
        }
        for &waiter in &lock_info.waiters {
            wait_graph.entry(waiter).or_default().insert(next_holder);
        }
    }

    /// Finds cycles in the wait-for graph.
    ///
    /// Each cycle is returned as the transactions along it, starting at the node
    /// where the back edge lands.
    pub fn detect_deadlocks(&self) -> Vec<Vec<TransactionId>> {
        let wait_graph = self.wait_graph.read();
        let mut cycles = Vec::new();
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        let mut path = Vec::new();
        for &node in wait_graph.keys() {
            if !visited.contains(&node) {
                Self::detect_cycle_dfs(
                    &wait_graph,
                    node,
                    &mut visited,
                    &mut rec_stack,
                    &mut path,
                    &mut cycles,
                );
            }
        }
        cycles
    }

    /// Depth-first search from `node` that records every back edge as a cycle.
    ///
    /// Returns `true` if any cycle was found below `node`. The search always unwinds
    /// `path` and `rec_stack` before returning, so state never leaks into a later
    /// root. Leaked state would make an already-finished node look like it is still
    /// on the current path.
    fn detect_cycle_dfs(
        graph: &HashMap<TransactionId, HashSet<TransactionId>>,
        node: TransactionId,
        visited: &mut HashSet<TransactionId>,
        rec_stack: &mut HashSet<TransactionId>,
        path: &mut Vec<TransactionId>,
        cycles: &mut Vec<Vec<TransactionId>>,
    ) -> bool {
        visited.insert(node);
        rec_stack.insert(node);
        path.push(node);
        let mut found = false;
        if let Some(neighbors) = graph.get(&node) {
            for &neighbor in neighbors {
                if !visited.contains(&neighbor) {
                    found |= Self::detect_cycle_dfs(graph, neighbor, visited, rec_stack, path, cycles);
                } else if rec_stack.contains(&neighbor) {
                    // `rec_stack` always mirrors `path`, so the neighbor is on the path.
                    let cycle_start = path
                        .iter()
                        .position(|&n| n == neighbor)
                        .expect("neighbor should exist in path when cycle detected");
                    cycles.push(path[cycle_start..].to_vec());
                    found = true;
                }
            }
        }
        path.pop();
        rec_stack.remove(&node);
        found
    }
}
use futures;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::sharding::ShardingStrategy;

    /// Begins a transaction, adds one insert, and commits it end to end.
    // NOTE(review): still ignored. It depends on `ShardManager` routing and on the
    // coordinator releasing its `transactions` guards before completing a transaction.
    #[tokio::test]
    #[ignore]
    async fn test_basic_transaction() {
        use tokio::time::{timeout, Duration};
        let config = TransactionConfig {
            timeout: Duration::from_secs(5),
            ..Default::default()
        };
        let shard_config = crate::distributed::sharding::ShardingConfig::default();
        let shard_manager = Arc::new(ShardManager::new(shard_config, ShardingStrategy::Hash));
        let coordinator = TransactionCoordinator::new(config, shard_manager);
        let tx_id = timeout(Duration::from_secs(2), coordinator.begin_transaction())
            .await
            .expect("begin_transaction timed out")
            .expect("begin_transaction failed");
        let op = TransactionOp::Insert(SerializableTriple {
            subject: "http://example.org/s".to_string(),
            predicate: "http://example.org/p".to_string(),
            object: "value".to_string(),
            object_type: ObjectType::Literal {
                datatype: None,
                language: None,
            },
        });
        timeout(Duration::from_secs(2), coordinator.add_operation(tx_id, op))
            .await
            .expect("add_operation timed out")
            .expect("add_operation failed");
        // Scope the map guard: `commit_transaction` removes this entry, which blocks
        // while any `Ref` into `coordinator.transactions` is still alive.
        {
            let transaction = coordinator
                .transactions
                .get(&tx_id)
                .expect("Transaction should exist");
            let participants = transaction.participants.read();
            assert!(
                !participants.is_empty(),
                "Transaction should have participants after adding operation"
            );
            println!("Participants: {:?}", *participants);
        }
        timeout(
            Duration::from_secs(2),
            coordinator.commit_transaction(tx_id),
        )
        .await
        .expect("commit_transaction timed out")
        .expect("commit_transaction failed");
        assert!(
            coordinator.transactions.get(&tx_id).is_none(),
            "Completed transaction should be removed from the in-flight map"
        );
    }

    /// An exclusive lock blocks a shared request until the holder releases it.
    #[test]
    fn test_lock_manager() {
        let lock_manager = LockManager::new();
        let tx1 = Uuid::new_v4();
        let tx2 = Uuid::new_v4();
        let lock1 = LockId {
            shard_id: 0,
            resource: "resource1".to_string(),
        };
        assert!(lock_manager
            .acquire_lock(tx1, lock1.clone(), LockType::Exclusive)
            .is_ok());
        assert!(lock_manager
            .acquire_lock(tx2, lock1.clone(), LockType::Shared)
            .is_err());
        lock_manager.release_transaction_locks(tx1);
        assert!(lock_manager
            .acquire_lock(tx2, lock1, LockType::Shared)
            .is_ok());
    }

    /// A three-transaction wait cycle is reported exactly once.
    #[test]
    fn test_deadlock_detection() {
        let lock_manager = LockManager::new();
        let mut wait_graph = lock_manager.wait_graph.write();
        let tx1 = Uuid::new_v4();
        let tx2 = Uuid::new_v4();
        let tx3 = Uuid::new_v4();
        wait_graph.insert(tx1, vec![tx2].into_iter().collect());
        wait_graph.insert(tx2, vec![tx3].into_iter().collect());
        wait_graph.insert(tx3, vec![tx1].into_iter().collect());
        drop(wait_graph);
        let cycles = lock_manager.detect_deadlocks();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 3);
    }
}