use crate::shard::ShardId;
use crate::transaction::{
IsolationLevel, Transaction, TransactionId, TransactionOp, TransactionState,
};
use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
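/// Applies fast-path and tuning optimizations to two-phase commit (2PC)
/// transactions: read-only and single-shard fast paths, the presumed-abort
/// variant, and operation batching. Tracks how often each optimization fires.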
#[derive(Debug, Clone)]
pub struct TwoPhaseOptimizer {
    /// Skip 2PC for transactions that only contain queries.
    enable_readonly_opt: bool,
    /// Skip 2PC for transactions touching a single shard.
    enable_single_shard_opt: bool,
    /// Use the presumed-abort 2PC variant.
    enable_presumed_abort: bool,
    /// Batch operations for large transactions.
    enable_batching: bool,
    /// Threshold (and chunk size) for operation batching.
    batch_size: usize,
    /// Counters for how often each optimization has been applied.
    stats: Arc<RwLock<OptimizationStats>>,
}
impl Default for TwoPhaseOptimizer {
fn default() -> Self {
Self::new()
}
}
impl TwoPhaseOptimizer {
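    /// Creates an optimizer with all optimizations enabled and a default
    /// batch size of 100 operations.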
pub fn new() -> Self {
Self {
enable_readonly_opt: true,
enable_single_shard_opt: true,
enable_presumed_abort: true,
enable_batching: true,
batch_size: 100,
stats: Arc::new(RwLock::new(OptimizationStats::default())),
}
}
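    /// Inspects a transaction and decides which optimizations apply:
    /// read-only and single-shard transactions skip 2PC entirely, other
    /// multi-shard transactions may use presumed abort and batching, and
    /// the isolation level controls the locking flags.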
pub async fn analyze_transaction(&self, transaction: &Transaction) -> TransactionOptimization {
        self.stats.write().await.total_analyzed += 1;
        let mut optimization = TransactionOptimization::default();
let is_readonly = transaction
.operations
.iter()
.all(|(_, op)| matches!(op, TransactionOp::Query { .. }));
if is_readonly && self.enable_readonly_opt {
optimization.skip_2pc = true;
optimization.reason = "Read-only transaction".to_string();
self.stats.write().await.readonly_optimized += 1;
return optimization;
}
let affected_shards: HashSet<_> = transaction
.operations
.iter()
.map(|(shard_id, _)| *shard_id)
.collect();
if affected_shards.len() == 1 && self.enable_single_shard_opt {
optimization.skip_2pc = true;
optimization.single_shard = Some(
*affected_shards
.iter()
.next()
.expect("affected_shards should not be empty when len == 1"),
);
optimization.reason = "Single-shard transaction".to_string();
self.stats.write().await.single_shard_optimized += 1;
return optimization;
}
if self.enable_presumed_abort {
optimization.use_presumed_abort = true;
self.stats.write().await.presumed_abort_used += 1;
}
if self.enable_batching && transaction.operations.len() > self.batch_size {
optimization.batch_operations = true;
optimization.batch_size = self.batch_size;
self.stats.write().await.batched_transactions += 1;
}
match transaction.isolation_level {
IsolationLevel::ReadUncommitted => {
optimization.skip_locking = true;
}
IsolationLevel::ReadCommitted => {
optimization.release_locks_early = true;
}
_ => {}
}
optimization
}
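    /// Plans the prepare phase: groups operations per shard, derives the
    /// parallel prepare groups and the critical path, and computes a timeout
    /// scaled to the transaction's size and isolation level.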
pub async fn optimize_prepare_phase(&self, transaction: &Transaction) -> PrepareOptimization {
let mut optimization = PrepareOptimization::default();
let mut shard_ops: HashMap<ShardId, Vec<TransactionOp>> = HashMap::new();
for (shard_id, op) in &transaction.operations {
shard_ops.entry(*shard_id).or_default().push(op.clone());
}
let parallel_groups = self.compute_parallel_groups(&shard_ops);
optimization.parallel_groups = parallel_groups;
optimization.critical_path = self.compute_critical_path(&shard_ops);
optimization.optimized_timeout = self.compute_optimized_timeout(transaction);
optimization
}
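    /// Currently places all shards in a single group so that every
    /// participant is prepared in parallel.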
fn compute_parallel_groups(
&self,
shard_ops: &HashMap<ShardId, Vec<TransactionOp>>,
) -> Vec<Vec<ShardId>> {
vec![shard_ops.keys().cloned().collect()]
}
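    /// Orders shards by descending operation count (busiest shard first).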
fn compute_critical_path(
&self,
shard_ops: &HashMap<ShardId, Vec<TransactionOp>>,
) -> Vec<ShardId> {
let mut shards: Vec<_> = shard_ops
.iter()
.map(|(shard_id, ops)| (*shard_id, ops.len()))
.collect();
shards.sort_by_key(|(_, count)| std::cmp::Reverse(*count));
shards.into_iter().map(|(shard_id, _)| shard_id).collect()
}
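    /// Derives a prepare timeout from a 10s base plus 100ms per operation
    /// and 500ms per participant, scaled up for stricter isolation levels
    /// and capped at 60 seconds.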
fn compute_optimized_timeout(&self, transaction: &Transaction) -> std::time::Duration {
use std::time::Duration;
let mut timeout = Duration::from_secs(10);
timeout += Duration::from_millis(transaction.operations.len() as u64 * 100);
timeout += Duration::from_millis(transaction.participants.len() as u64 * 500);
match transaction.isolation_level {
IsolationLevel::Serializable => timeout *= 2,
IsolationLevel::RepeatableRead => timeout = timeout * 3 / 2,
_ => {}
}
        timeout.min(Duration::from_secs(60))
    }
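    /// Plans the commit phase: commits asynchronously for any isolation
    /// level other than Serializable, enables group commit when batching is
    /// on, and skips participant-side logging under presumed abort.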
pub async fn optimize_commit_phase(&self, transaction: &Transaction) -> CommitOptimization {
let mut optimization = CommitOptimization::default();
if transaction.isolation_level != IsolationLevel::Serializable {
optimization.async_commit = true;
}
optimization.group_commit = self.enable_batching;
if self.enable_presumed_abort {
optimization.skip_participant_logging = true;
}
optimization
}
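    /// Returns a snapshot of the current optimization counters.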
pub async fn get_statistics(&self) -> OptimizationStats {
self.stats.read().await.clone()
}
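    /// Resets all optimization counters to zero.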
pub async fn reset_statistics(&self) {
*self.stats.write().await = OptimizationStats::default();
}
}
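/// Optimization decisions produced by `TwoPhaseOptimizer::analyze_transaction`.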
#[derive(Debug, Default)]
pub struct TransactionOptimization {
    /// Bypass 2PC entirely (read-only or single-shard fast path).
    pub skip_2pc: bool,
    /// The sole shard involved when the single-shard fast path applies.
    pub single_shard: Option<ShardId>,
    /// Use the presumed-abort 2PC variant.
    pub use_presumed_abort: bool,
    /// Split operations into batches.
    pub batch_operations: bool,
    /// Batch size to use when `batch_operations` is set.
    pub batch_size: usize,
    /// Skip lock acquisition (ReadUncommitted).
    pub skip_locking: bool,
    /// Allow locks to be released early (ReadCommitted).
    pub release_locks_early: bool,
    /// Human-readable explanation of the chosen fast path.
    pub reason: String,
}
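/// Prepare-phase plan: which shards to prepare in parallel, the expected
/// critical path, and the computed timeout.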
#[derive(Debug, Default)]
pub struct PrepareOptimization {
pub parallel_groups: Vec<Vec<ShardId>>,
pub critical_path: Vec<ShardId>,
pub optimized_timeout: std::time::Duration,
}
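/// Commit-phase plan produced by `TwoPhaseOptimizer::optimize_commit_phase`.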
#[derive(Debug, Default)]
pub struct CommitOptimization {
pub async_commit: bool,
pub group_commit: bool,
pub skip_participant_logging: bool,
}
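/// Counters for how often each optimization has been applied.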
#[derive(Debug, Default, Clone)]
pub struct OptimizationStats {
pub readonly_optimized: u64,
pub single_shard_optimized: u64,
pub presumed_abort_used: u64,
pub batched_transactions: u64,
pub total_analyzed: u64,
}
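/// Detects deadlocks between transactions by maintaining a waits-for graph
/// and rejecting edges that would introduce a cycle.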
pub struct DeadlockDetector {
wait_graph: Arc<RwLock<HashMap<TransactionId, HashSet<TransactionId>>>>,
}
impl Default for DeadlockDetector {
fn default() -> Self {
Self::new()
}
}
impl DeadlockDetector {
pub fn new() -> Self {
Self {
wait_graph: Arc::new(RwLock::new(HashMap::new())),
}
}
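    /// Records that `waiter` is blocked on `holder`. Returns an error
    /// instead of inserting the edge if doing so would create a cycle in
    /// the waits-for graph (i.e. a deadlock).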
pub async fn add_wait(&self, waiter: &str, holder: &str) -> Result<()> {
let mut graph = self.wait_graph.write().await;
if self.would_create_cycle(&graph, waiter, holder) {
return Err(anyhow::anyhow!("Deadlock detected"));
}
        graph
            .entry(waiter.to_string())
            .or_default()
            .insert(holder.to_string());
Ok(())
}
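    /// Removes a finished transaction from the waits-for graph, both as a
    /// waiter and as a holder.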
pub async fn remove_transaction(&self, tx_id: &str) {
let mut graph = self.wait_graph.write().await;
graph.remove(tx_id);
for waiters in graph.values_mut() {
waiters.remove(tx_id);
}
}
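    /// Depth-first search starting at `to`: if `from` is reachable, adding
    /// the edge `from -> to` would close a cycle.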
fn would_create_cycle(
&self,
graph: &HashMap<TransactionId, HashSet<TransactionId>>,
from: &str,
to: &str,
) -> bool {
let mut visited = HashSet::new();
let mut stack = vec![to.to_string()];
while let Some(node) = stack.pop() {
            if node == from {
                return true;
            }
if visited.insert(node.clone()) {
if let Some(neighbors) = graph.get(&node) {
stack.extend(neighbors.iter().cloned());
}
}
}
false
}
}
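/// Schedules periodic checkpoints and classifies pending transactions into
/// recovery actions based on the state they were last observed in.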
pub struct RecoveryOptimizer {
checkpoint_interval: std::time::Duration,
last_checkpoint: Arc<RwLock<std::time::Instant>>,
}
impl RecoveryOptimizer {
pub fn new(checkpoint_interval: std::time::Duration) -> Self {
Self {
checkpoint_interval,
last_checkpoint: Arc::new(RwLock::new(std::time::Instant::now())),
}
}
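    /// Returns `true` once the configured interval has elapsed since the
    /// last checkpoint.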
pub async fn should_checkpoint(&self) -> bool {
let last = *self.last_checkpoint.read().await;
last.elapsed() >= self.checkpoint_interval
}
pub async fn update_checkpoint(&self) {
*self.last_checkpoint.write().await = std::time::Instant::now();
}
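    /// Buckets pending transactions by the action recovery must take:
    /// preparing/prepared transactions must be queried for their outcome,
    /// committing ones re-driven to commit, and aborting ones re-driven to
    /// abort.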
pub fn optimize_recovery_plan(
&self,
pending_transactions: Vec<(TransactionId, TransactionState)>,
) -> RecoveryPlan {
let mut plan = RecoveryPlan::default();
for (tx_id, state) in pending_transactions {
match state {
TransactionState::Preparing | TransactionState::Prepared => {
plan.transactions_to_query.push(tx_id);
}
TransactionState::Committing => {
plan.transactions_to_commit.push(tx_id);
}
TransactionState::Aborting => {
plan.transactions_to_abort.push(tx_id);
}
_ => {}
}
}
plan
}
}
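/// Recovery actions derived from the states of pending transactions.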
#[derive(Debug, Default)]
pub struct RecoveryPlan {
pub transactions_to_query: Vec<TransactionId>,
pub transactions_to_commit: Vec<TransactionId>,
pub transactions_to_abort: Vec<TransactionId>,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_readonly_optimization() {
let optimizer = TwoPhaseOptimizer::new();
let transaction = Transaction {
id: "test-tx".to_string(),
state: TransactionState::Active,
operations: vec![(
0,
TransactionOp::Query {
subject: Some("test".to_string()),
predicate: None,
object: None,
},
)],
participants: HashMap::new(),
created_at: std::time::Instant::now(),
timeout: std::time::Duration::from_secs(30),
isolation_level: IsolationLevel::ReadCommitted,
};
let optimization = optimizer.analyze_transaction(&transaction).await;
assert!(optimization.skip_2pc);
assert_eq!(optimization.reason, "Read-only transaction");
}
#[tokio::test]
async fn test_single_shard_optimization() {
let optimizer = TwoPhaseOptimizer::new();
let transaction = Transaction {
id: "test-tx".to_string(),
state: TransactionState::Active,
operations: vec![(
0,
TransactionOp::Insert {
triple: oxirs_core::model::Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/o").unwrap(),
),
},
)],
participants: HashMap::new(),
created_at: std::time::Instant::now(),
timeout: std::time::Duration::from_secs(30),
isolation_level: IsolationLevel::ReadCommitted,
};
let optimization = optimizer.analyze_transaction(&transaction).await;
assert!(optimization.skip_2pc);
assert_eq!(optimization.single_shard, Some(0));
assert_eq!(optimization.reason, "Single-shard transaction");
}
#[test]
fn test_deadlock_detection() {
let detector = DeadlockDetector::new();
let mut graph = HashMap::new();
graph.insert(
"tx1".to_string(),
vec!["tx2".to_string()].into_iter().collect(),
);
graph.insert(
"tx2".to_string(),
vec!["tx3".to_string()].into_iter().collect(),
);
assert!(detector.would_create_cycle(&graph, "tx3", "tx1"));
assert!(!detector.would_create_cycle(&graph, "tx3", "tx4"));
}
}