use crate::transaction::{IsolationLevel, TransactionId};
use anyhow::Result;
use dashmap::DashMap;
use oxirs_core::model::Triple;
#[cfg(test)]
use oxirs_core::vocab::xsd;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};
/// Hybrid logical clock (HLC): combines wall-clock milliseconds with a
/// logical counter so timestamps are monotonically increasing and totally
/// ordered across nodes even when physical clocks stall or skew.
#[derive(Debug)]
pub struct HybridLogicalClock {
// Last observed physical time, in ms since the UNIX epoch.
physical_time: AtomicU64,
// Tie-breaker bumped when the physical component does not advance.
logical_counter: AtomicU64,
// Identifier of this node; final ordering tie-breaker in timestamps.
node_id: u64,
}
impl HybridLogicalClock {
/// Create a clock for `node_id`, seeded with the current wall-clock time
/// in milliseconds since the UNIX epoch.
pub fn new(node_id: u64) -> Self {
let physical_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_millis() as u64;
Self {
physical_time: AtomicU64::new(physical_time),
logical_counter: AtomicU64::new(0),
node_id,
}
}
/// Produce the next local timestamp (HLC "send/local event" rule).
///
/// If wall-clock time has advanced past the last recorded physical time,
/// the physical component moves forward and the logical counter resets to
/// zero; otherwise the last physical value is reused and the logical
/// counter is incremented, so successive calls always produce strictly
/// increasing timestamps even if the wall clock stalls or steps backwards.
///
/// NOTE(review): the load followed by separate stores on `physical_time`
/// and `logical_counter` is not one atomic step, so two threads calling
/// this concurrently could emit duplicate timestamps; a CAS loop over
/// packed state would be required for full thread-safety — confirm
/// whether concurrent callers are possible here.
pub fn now(&self) -> HLCTimestamp {
let current_physical = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_millis() as u64;
let last_physical = self.physical_time.load(AtomicOrdering::SeqCst);
let (physical, logical) = if current_physical > last_physical {
// Wall clock moved forward: adopt it and reset the counter.
self.physical_time
.store(current_physical, AtomicOrdering::SeqCst);
self.logical_counter.store(0, AtomicOrdering::SeqCst);
(current_physical, 0)
} else {
// Wall clock stalled or regressed: keep the old physical value
// and advance the logical counter instead.
let logical = self.logical_counter.fetch_add(1, AtomicOrdering::SeqCst) + 1;
(last_physical, logical)
};
HLCTimestamp {
physical,
logical,
node_id: self.node_id,
}
}
/// Merge a timestamp received from a remote node (HLC "receive" rule).
///
/// The new physical component is the maximum of local wall-clock time,
/// the last local physical time, and the remote physical time; the
/// logical counter is chosen so the returned timestamp is greater than
/// both the local clock's previous output and `received`.
pub fn update(&self, received: &HLCTimestamp) -> HLCTimestamp {
let current_physical = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_millis() as u64;
let last_physical = self.physical_time.load(AtomicOrdering::SeqCst);
// Receive rule: take the maximum over all three physical sources.
let max_physical = current_physical.max(last_physical).max(received.physical);
let (physical, logical) =
if max_physical > last_physical && max_physical > received.physical {
// Local wall clock is strictly ahead of both: reset the counter.
self.physical_time
.store(max_physical, AtomicOrdering::SeqCst);
self.logical_counter.store(0, AtomicOrdering::SeqCst);
(max_physical, 0)
} else if max_physical == received.physical {
// Tied with the remote physical time: advance past the larger of
// the local and remote logical counters.
let logical = if max_physical == last_physical {
self.logical_counter
.load(AtomicOrdering::SeqCst)
.max(received.logical)
+ 1
} else {
received.logical + 1
};
self.physical_time
.store(max_physical, AtomicOrdering::SeqCst);
self.logical_counter.store(logical, AtomicOrdering::SeqCst);
(max_physical, logical)
} else {
// Tied with our own last physical time only: bump the counter.
let logical = self.logical_counter.fetch_add(1, AtomicOrdering::SeqCst) + 1;
(max_physical, logical)
};
HLCTimestamp {
physical,
logical,
node_id: self.node_id,
}
}
}
/// A single HLC timestamp: `(physical ms, logical counter, node id)`.
/// Ordering compares the components in exactly that sequence (see `Ord`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HLCTimestamp {
/// Wall-clock milliseconds since the UNIX epoch.
pub physical: u64,
/// Logical counter distinguishing events within the same millisecond.
pub logical: u64,
/// Originating node; breaks ties between concurrent events on different nodes.
pub node_id: u64,
}
impl PartialOrd for HLCTimestamp {
/// Delegates to `Ord`; the ordering is total, so this never returns `None`.
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for HLCTimestamp {
    /// Total order over timestamps: physical time first, then the logical
    /// counter, with the node id as the final tie-breaker so that
    /// timestamps produced by distinct nodes are never considered equal.
    fn cmp(&self, other: &Self) -> Ordering {
        self.physical
            .cmp(&other.physical)
            .then_with(|| self.logical.cmp(&other.logical))
            .then_with(|| self.node_id.cmp(&other.node_id))
    }
}
/// One versioned value of a key in the MVCC store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Version {
/// HLC timestamp at which this version was written.
pub timestamp: HLCTimestamp,
/// Transaction that produced this version.
pub transaction_id: TransactionId,
/// True when this version is a tombstone (logical delete).
pub is_deleted: bool,
/// The triple payload; `None` for tombstones.
pub data: Option<Triple>,
}
/// Tuning knobs for the MVCC manager.
#[derive(Debug, Clone)]
pub struct MVCCConfig {
/// Enable snapshot-isolation support.
pub enable_snapshot_isolation: bool,
/// How often the background garbage collector runs.
pub gc_interval: Duration,
/// Minimum age a version must reach before GC may reclaim it.
pub gc_min_age: Duration,
/// Hard cap on retained versions per key (oldest dropped first).
pub max_versions_per_key: usize,
/// Track read/write sets and check for conflicts at commit time.
pub enable_conflict_detection: bool,
}
impl Default for MVCCConfig {
fn default() -> Self {
Self {
enable_snapshot_isolation: true,
gc_interval: Duration::from_secs(60),
gc_min_age: Duration::from_secs(300), max_versions_per_key: 100,
enable_conflict_detection: true,
}
}
}
/// Per-transaction state captured at `begin_transaction` time.
#[derive(Debug, Clone)]
pub struct TransactionSnapshot {
/// Identifier of the owning transaction.
pub transaction_id: TransactionId,
/// HLC timestamp at which the transaction began; defines its snapshot.
pub timestamp: HLCTimestamp,
/// Isolation level requested by the caller.
pub isolation_level: IsolationLevel,
/// Keys read so far (populated only when conflict detection is enabled).
pub read_set: Arc<RwLock<HashSet<String>>>,
/// Keys written so far (populated only when conflict detection is enabled).
pub write_set: Arc<RwLock<HashSet<String>>>,
}
/// Multi-version concurrency control manager keyed by string.
///
/// Versions are kept per key in a `BTreeMap` ordered by HLC timestamp
/// (newest last); a background task garbage-collects old versions once
/// `start` has been called.
pub struct MVCCManager {
// Static configuration.
config: MVCCConfig,
// Shared hybrid logical clock used for all timestamps.
clock: Arc<HybridLogicalClock>,
// key -> (timestamp -> version).
versions: Arc<DashMap<String, BTreeMap<HLCTimestamp, Version>>>,
// Currently active (uncommitted) transactions.
transactions: Arc<RwLock<HashMap<TransactionId, TransactionSnapshot>>>,
// Commit index: snapshot timestamp -> committed transaction id.
committed_transactions: Arc<RwLock<BTreeMap<HLCTimestamp, TransactionId>>>,
// Handle of the background GC task, if started.
gc_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
impl MVCCManager {
pub fn new(node_id: u64, config: MVCCConfig) -> Self {
Self {
config,
clock: Arc::new(HybridLogicalClock::new(node_id)),
versions: Arc::new(DashMap::new()),
transactions: Arc::new(RwLock::new(HashMap::new())),
committed_transactions: Arc::new(RwLock::new(BTreeMap::new())),
gc_handle: Arc::new(Mutex::new(None)),
}
}
/// Start the background garbage-collection task.
///
/// The GC loop periodically drops versions whose physical timestamp is
/// older than `gc_min_age`, trims each key to at most
/// `max_versions_per_key` entries (oldest first — `BTreeMap` keys iterate
/// in ascending timestamp order), and prunes the committed-transaction
/// index by the same cutoff.
///
/// Fix: if `start` is called more than once, the previously spawned GC
/// task is now aborted before being replaced; the original overwrote the
/// handle and leaked a second,永-running collection loop.
pub async fn start(&self) -> Result<()> {
    let gc_interval = self.config.gc_interval;
    let gc_min_age = self.config.gc_min_age;
    let max_versions = self.config.max_versions_per_key;
    let versions = Arc::clone(&self.versions);
    let committed_transactions = Arc::clone(&self.committed_transactions);
    let clock = Arc::clone(&self.clock);
    let gc_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(gc_interval);
        loop {
            interval.tick().await;
            let current_time = clock.now();
            // Versions with physical time older than this are reclaimable.
            let cutoff_physical = current_time
                .physical
                .saturating_sub(gc_min_age.as_millis() as u64);
            for mut entry in versions.iter_mut() {
                let versions_map = entry.value_mut();
                versions_map.retain(|timestamp, _| timestamp.physical >= cutoff_physical);
                // Enforce the per-key cap by dropping the oldest versions.
                if versions_map.len() > max_versions {
                    let to_remove: Vec<_> = versions_map
                        .keys()
                        .take(versions_map.len() - max_versions)
                        .cloned()
                        .collect();
                    for timestamp in to_remove {
                        versions_map.remove(&timestamp);
                    }
                }
            }
            let mut committed = committed_transactions.write().await;
            committed.retain(|timestamp, _| timestamp.physical >= cutoff_physical);
            debug!("MVCC garbage collection completed");
        }
    });
    // Replace (and abort) any GC task from a previous `start` call so two
    // collection loops never run concurrently.
    let mut handle = self.gc_handle.lock().await;
    if let Some(previous) = handle.replace(gc_task) {
        previous.abort();
    }
    info!("MVCC manager started with garbage collection");
    Ok(())
}
/// Stop the background garbage-collection task, if one is running.
/// Idempotent: a second call finds no handle and is a no-op.
pub async fn stop(&self) -> Result<()> {
    let mut gc_slot = self.gc_handle.lock().await;
    if let Some(gc_task) = gc_slot.take() {
        gc_task.abort();
    }
    drop(gc_slot);
    info!("MVCC manager stopped");
    Ok(())
}
/// Register a new transaction and capture its snapshot.
///
/// The snapshot records the HLC timestamp at which the transaction began
/// plus empty read/write sets that are populated later for conflict
/// detection. A copy of the snapshot is returned to the caller.
pub async fn begin_transaction(
    &self,
    transaction_id: TransactionId,
    isolation_level: IsolationLevel,
) -> Result<TransactionSnapshot> {
    let timestamp = self.clock.now();
    let snapshot = TransactionSnapshot {
        transaction_id: transaction_id.clone(),
        timestamp,
        isolation_level,
        read_set: Arc::new(RwLock::new(HashSet::new())),
        write_set: Arc::new(RwLock::new(HashSet::new())),
    };
    let mut active = self.transactions.write().await;
    active.insert(transaction_id, snapshot.clone());
    drop(active);
    debug!(
        "Started MVCC transaction {} at {:?}",
        snapshot.transaction_id, timestamp
    );
    Ok(snapshot)
}
/// Read the value of `key` visible to `transaction_id`.
///
/// Visibility depends on the transaction's isolation level:
/// - `ReadUncommitted`: the newest version of the key, committed or not.
/// - `ReadCommitted`: the transaction's own pending write if any, else the
///   latest committed version at or before the snapshot timestamp.
/// - `RepeatableRead` / `Serializable`: the latest committed version as of
///   the snapshot timestamp.
///
/// Returns `Ok(None)` for missing keys or tombstoned versions. Errors if
/// the transaction is unknown.
///
/// NOTE(review): the ReadCommitted fallback is bounded by the snapshot
/// timestamp, so commits landing after the transaction began are not
/// observed — closer to snapshot semantics than classic read-committed;
/// confirm this is intended.
/// NOTE(review): under RepeatableRead/Serializable the transaction does
/// not see its own uncommitted writes (only committed versions are
/// consulted) — verify against callers.
pub async fn read(&self, transaction_id: &TransactionId, key: &str) -> Result<Option<Triple>> {
let transactions = self.transactions.read().await;
let snapshot = transactions
.get(transaction_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
// Record the read for commit-time conflict checking.
if self.config.enable_conflict_detection {
snapshot.read_set.write().await.insert(key.to_string());
}
let version = match snapshot.isolation_level {
IsolationLevel::ReadUncommitted => {
self.get_latest_version(key).await
}
IsolationLevel::ReadCommitted => {
// Prefer the transaction's own pending write (read-your-writes).
if let Some(version) = self.get_version_from_transaction(key, transaction_id).await
{
Some(version)
} else {
self.get_latest_committed_version(key, &snapshot.timestamp)
.await
}
}
IsolationLevel::RepeatableRead | IsolationLevel::Serializable => {
self.get_version_at_timestamp(key, &snapshot.timestamp)
.await
}
};
// Tombstones carry `data: None`, so deletes naturally read as absent.
Ok(version.and_then(|v| v.data))
}
/// Record a new version of `key` on behalf of `transaction_id`.
///
/// Passing `None` for `triple` writes a tombstone (logical delete). The
/// version is timestamped with a fresh HLC value and becomes visible to
/// other transactions only per their isolation rules once committed.
pub async fn write(
    &self,
    transaction_id: &TransactionId,
    key: &str,
    triple: Option<Triple>,
) -> Result<()> {
    let transactions = self.transactions.read().await;
    let snapshot = transactions
        .get(transaction_id)
        .ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
    // Record the write for commit-time conflict checking.
    if self.config.enable_conflict_detection {
        let mut write_set = snapshot.write_set.write().await;
        write_set.insert(key.to_string());
    }
    let timestamp = self.clock.now();
    let new_version = Version {
        timestamp,
        transaction_id: transaction_id.clone(),
        is_deleted: triple.is_none(),
        data: triple,
    };
    let mut versions_for_key = self.versions.entry(key.to_string()).or_default();
    versions_for_key.insert(timestamp, new_version);
    debug!(
        "Wrote version for key {} in transaction {} at {:?}",
        key, transaction_id, timestamp
    );
    Ok(())
}
pub async fn check_conflicts(&self, transaction_id: &TransactionId) -> Result<bool> {
if !self.config.enable_conflict_detection {
return Ok(false);
}
let transactions = self.transactions.read().await;
let snapshot = transactions
.get(transaction_id)
.ok_or_else(|| anyhow::anyhow!("Transaction not found"))?;
let read_set = snapshot.read_set.read().await;
let write_set = snapshot.write_set.read().await;
let committed = self.committed_transactions.read().await;
for key in write_set.iter() {
if let Some(versions) = self.versions.get(key) {
let has_conflict = versions.range(snapshot.timestamp..).any(|(ts, v)| {
ts > &snapshot.timestamp &&
v.transaction_id != *transaction_id &&
committed.values().any(|tx_id| tx_id == &v.transaction_id)
});
if has_conflict {
warn!(
"Write-write conflict detected for key {} in transaction {}",
key, transaction_id
);
return Ok(true);
}
}
}
if snapshot.isolation_level == IsolationLevel::Serializable {
for key in read_set.iter() {
if let Some(versions) = self.versions.get(key) {
let has_conflict = versions.range(snapshot.timestamp..).any(|(ts, v)| {
ts > &snapshot.timestamp &&
v.transaction_id != *transaction_id &&
committed.values().any(|tx_id| tx_id == &v.transaction_id)
});
if has_conflict {
warn!(
"Read-write conflict detected for key {} in transaction {}",
key, transaction_id
);
return Ok(true);
}
}
}
}
Ok(false)
}
/// Commit `transaction_id`: verify it is conflict-free, record it in the
/// committed index under its snapshot timestamp, then drop it from the
/// active set. Fails if conflicts are detected or the id is unknown.
pub async fn commit_transaction(&self, transaction_id: &TransactionId) -> Result<()> {
    if self.check_conflicts(transaction_id).await? {
        return Err(anyhow::anyhow!("Transaction conflicts detected"));
    }
    // Look up the snapshot timestamp, releasing the read guard right away.
    let timestamp = {
        let active = self.transactions.read().await;
        active
            .get(transaction_id)
            .map(|snapshot| snapshot.timestamp)
            .ok_or_else(|| anyhow::anyhow!("Transaction not found"))?
    };
    {
        let mut committed = self.committed_transactions.write().await;
        committed.insert(timestamp, transaction_id.clone());
    }
    {
        let mut active = self.transactions.write().await;
        active.remove(transaction_id);
    }
    info!(
        "Committed transaction {} at {:?}",
        transaction_id, timestamp
    );
    Ok(())
}
/// Roll back `transaction_id`: purge every version it wrote and drop it
/// from the active set. Nothing is recorded in the committed index, so
/// its writes become permanently invisible.
pub async fn rollback_transaction(&self, transaction_id: &TransactionId) -> Result<()> {
    // Sweep all keys, stripping this transaction's versions in place.
    for mut entry in self.versions.iter_mut() {
        let versions_map = entry.value_mut();
        versions_map.retain(|_, version| version.transaction_id != *transaction_id);
    }
    self.transactions.write().await.remove(transaction_id);
    info!("Rolled back transaction {}", transaction_id);
    Ok(())
}
/// Newest version of `key` regardless of commit status, if any.
async fn get_latest_version(&self, key: &str) -> Option<Version> {
    let versions = self.versions.get(key)?;
    let newest = versions.values().next_back()?;
    Some(newest.clone())
}
/// Newest version of `key` at or before `before_timestamp` that was
/// written by a committed transaction, if any.
async fn get_latest_committed_version(
    &self,
    key: &str,
    before_timestamp: &HLCTimestamp,
) -> Option<Version> {
    let committed = self.committed_transactions.read().await;
    let versions = self.versions.get(key)?;
    // Walk newest-to-oldest within the timestamp bound.
    for (_ts, version) in versions.range(..=before_timestamp).rev() {
        let is_committed = committed
            .values()
            .any(|tx_id| tx_id == &version.transaction_id);
        if is_committed {
            return Some(version.clone());
        }
    }
    None
}
/// This transaction's own most recent (possibly uncommitted) write to
/// `key`, used for read-your-writes under `ReadCommitted`.
async fn get_version_from_transaction(
    &self,
    key: &str,
    transaction_id: &TransactionId,
) -> Option<Version> {
    let versions = self.versions.get(key)?;
    versions
        .values()
        .rev()
        .find(|version| version.transaction_id == *transaction_id)
        .cloned()
}
/// Newest committed version of `key` at or before `timestamp` —
/// snapshot-read semantics for `RepeatableRead` / `Serializable`.
async fn get_version_at_timestamp(
    &self,
    key: &str,
    timestamp: &HLCTimestamp,
) -> Option<Version> {
    let committed = self.committed_transactions.read().await;
    let versions = self.versions.get(key)?;
    versions
        .range(..=timestamp)
        .rev()
        .find_map(|(_, version)| {
            let is_committed = committed
                .values()
                .any(|tx_id| tx_id == &version.transaction_id);
            is_committed.then(|| version.clone())
        })
}
/// Every stored version of `key`, oldest first; empty when the key has
/// never been written (or all versions were garbage-collected).
pub async fn get_all_versions(&self, key: &str) -> Vec<Version> {
    match self.versions.get(key) {
        Some(versions) => versions.values().cloned().collect(),
        None => Vec::new(),
    }
}
/// Aggregate counters describing the current MVCC state.
pub async fn get_statistics(&self) -> MVCCStatistics {
    let total_keys = self.versions.len();
    // Single pass over all keys for total and per-key-maximum counts.
    let (total_versions, max_versions_per_key) = self
        .versions
        .iter()
        .fold((0usize, 0usize), |(total, max), entry| {
            let count = entry.value().len();
            (total + count, max.max(count))
        });
    let active_transactions = self.transactions.read().await.len();
    let committed_transactions = self.committed_transactions.read().await.len();
    MVCCStatistics {
        total_keys,
        total_versions,
        max_versions_per_key,
        active_transactions,
        committed_transactions,
    }
}
/// Merge a timestamp received from another node into the local clock and
/// return the resulting (advanced) local timestamp.
pub fn update_clock(&self, timestamp: &HLCTimestamp) -> HLCTimestamp {
self.clock.update(timestamp)
}
/// Generate a fresh timestamp from the local hybrid logical clock.
pub fn current_timestamp(&self) -> HLCTimestamp {
self.clock.now()
}
/// Return every `(key, triple)` visible to `transaction_id` whose key
/// starts with `prefix`. Tombstoned entries are skipped; errors if the
/// transaction is unknown.
pub async fn scan_prefix(
    &self,
    transaction_id: &TransactionId,
    prefix: &str,
) -> Result<Vec<(String, Triple)>> {
    let transactions = self.transactions.read().await;
    let snapshot = transactions
        .get(transaction_id)
        .ok_or_else(|| anyhow::anyhow!("Transaction {} not found", transaction_id))?;
    // ReadUncommitted is the only level that may observe uncommitted data.
    let include_uncommitted = snapshot.isolation_level == IsolationLevel::ReadUncommitted;
    let mut results = Vec::new();
    for entry in self.versions.iter() {
        let key = entry.key();
        if !key.starts_with(prefix) {
            continue;
        }
        let visible = self
            .get_visible_version(key, &snapshot.timestamp, include_uncommitted)
            .await;
        if let Some(version) = visible {
            if let Some(triple) = version.data {
                results.push((key.clone(), triple));
            }
        }
    }
    Ok(results)
}
async fn get_visible_version(
&self,
key: &str,
timestamp: &HLCTimestamp,
include_uncommitted: bool,
) -> Option<Version> {
if let Some(versions) = self.versions.get(key) {
for (ts, version) in versions.iter().rev() {
if ts <= timestamp {
if include_uncommitted {
return Some(version.clone());
} else {
let committed = self.committed_transactions.read().await;
if committed
.values()
.any(|tx_id| tx_id == &version.transaction_id)
{
return Some(version.clone());
}
}
}
}
}
None
}
}
/// Point-in-time counters reported by `MVCCManager::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MVCCStatistics {
/// Number of distinct keys with at least one stored version.
pub total_keys: usize,
/// Total stored versions across all keys.
pub total_versions: usize,
/// Largest version count held by any single key.
pub max_versions_per_key: usize,
/// Transactions begun but not yet committed or rolled back.
pub active_transactions: usize,
/// Entries in the committed-transaction index (after GC pruning).
pub committed_transactions: usize,
}
#[cfg(test)]
mod tests {
use super::*;
// Ordering must compare physical time, then logical counter, then node id.
#[test]
fn test_hlc_timestamp_ordering() {
let ts1 = HLCTimestamp {
physical: 100,
logical: 0,
node_id: 1,
};
let ts2 = HLCTimestamp {
physical: 100,
logical: 1,
node_id: 1,
};
let ts3 = HLCTimestamp {
physical: 101,
logical: 0,
node_id: 1,
};
let ts4 = HLCTimestamp {
physical: 100,
logical: 0,
node_id: 2,
};
assert!(ts1 < ts2);
assert!(ts2 < ts3);
// Equal physical and logical: node id breaks the tie.
assert!(ts1 < ts4); }
// Successive local timestamps are strictly increasing and tagged with
// the node id given at construction.
#[test]
fn test_hlc_generation() {
let clock = HybridLogicalClock::new(1);
let ts1 = clock.now();
let ts2 = clock.now();
assert!(ts2 > ts1);
assert_eq!(ts1.node_id, 1);
assert_eq!(ts2.node_id, 1);
}
// Receiving a timestamp from the future advances the local clock past it.
#[test]
fn test_hlc_update() {
let clock = HybridLogicalClock::new(1);
let ts1 = clock.now();
let received = HLCTimestamp {
physical: ts1.physical + 1000,
logical: 5,
node_id: 2,
};
let ts2 = clock.update(&received);
assert!(ts2.physical >= received.physical);
assert!(ts2 > ts1);
}
// Write then read within one transaction, commit, and verify statistics.
#[tokio::test]
async fn test_mvcc_basic_operations() {
let mvcc = MVCCManager::new(1, MVCCConfig::default());
mvcc.start().await.unwrap();
let tx_id = "tx1".to_string();
let _snapshot = mvcc
.begin_transaction(tx_id.clone(), IsolationLevel::ReadCommitted)
.await
.unwrap();
let triple = Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::Literal::new_typed_literal("value", xsd::STRING.clone()),
);
mvcc.write(&tx_id, "key1", Some(triple.clone()))
.await
.unwrap();
// ReadCommitted reads the transaction's own pending write.
let read_value = mvcc.read(&tx_id, "key1").await.unwrap();
assert!(read_value.is_some());
mvcc.commit_transaction(&tx_id).await.unwrap();
let stats = mvcc.get_statistics().await;
assert_eq!(stats.total_keys, 1);
assert_eq!(stats.total_versions, 1);
mvcc.stop().await.unwrap();
}
// A RepeatableRead transaction keeps seeing data as of its snapshot even
// after a later transaction commits a new version of the same key.
#[tokio::test]
async fn test_mvcc_isolation_levels() {
let mvcc = MVCCManager::new(1, MVCCConfig::default());
mvcc.start().await.unwrap();
let tx1 = "tx1".to_string();
mvcc.begin_transaction(tx1.clone(), IsolationLevel::ReadCommitted)
.await
.unwrap();
let triple = Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::Literal::new_typed_literal("value1", xsd::STRING.clone()),
);
mvcc.write(&tx1, "key1", Some(triple.clone()))
.await
.unwrap();
mvcc.commit_transaction(&tx1).await.unwrap();
let tx2 = "tx2".to_string();
mvcc.begin_transaction(tx2.clone(), IsolationLevel::RepeatableRead)
.await
.unwrap();
let value = mvcc.read(&tx2, "key1").await.unwrap();
assert!(value.is_some());
let tx3 = "tx3".to_string();
mvcc.begin_transaction(tx3.clone(), IsolationLevel::ReadCommitted)
.await
.unwrap();
let triple2 = Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::Literal::new_typed_literal("value2", xsd::STRING.clone()),
);
mvcc.write(&tx3, "key1", Some(triple2)).await.unwrap();
mvcc.commit_transaction(&tx3).await.unwrap();
// tx2 (RepeatableRead) still sees the pre-tx3 state of key1.
let value2 = mvcc.read(&tx2, "key1").await.unwrap();
assert!(value2.is_some());
mvcc.stop().await.unwrap();
}
// Two Serializable transactions read and write the same key; the second
// commit must be rejected as a conflict.
#[tokio::test]
async fn test_mvcc_conflict_detection() {
let config = MVCCConfig {
enable_conflict_detection: true,
..Default::default()
};
let mvcc = MVCCManager::new(1, config);
mvcc.start().await.unwrap();
let tx1 = "tx1".to_string();
let tx2 = "tx2".to_string();
mvcc.begin_transaction(tx1.clone(), IsolationLevel::Serializable)
.await
.unwrap();
mvcc.begin_transaction(tx2.clone(), IsolationLevel::Serializable)
.await
.unwrap();
let triple = Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::Literal::new_typed_literal("value", xsd::STRING.clone()),
);
mvcc.read(&tx1, "key1").await.unwrap();
mvcc.read(&tx2, "key1").await.unwrap();
mvcc.write(&tx1, "key1", Some(triple.clone()))
.await
.unwrap();
mvcc.write(&tx2, "key1", Some(triple)).await.unwrap();
mvcc.commit_transaction(&tx1).await.unwrap();
// tx1's committed write conflicts with tx2's read/write set.
let result = mvcc.commit_transaction(&tx2).await;
assert!(result.is_err());
mvcc.stop().await.unwrap();
}
// Rolled-back writes must be invisible to subsequent transactions.
#[tokio::test]
async fn test_mvcc_rollback() {
let mvcc = MVCCManager::new(1, MVCCConfig::default());
mvcc.start().await.unwrap();
let tx_id = "tx1".to_string();
mvcc.begin_transaction(tx_id.clone(), IsolationLevel::ReadCommitted)
.await
.unwrap();
let triple = Triple::new(
oxirs_core::model::NamedNode::new("http://example.org/s").unwrap(),
oxirs_core::model::NamedNode::new("http://example.org/p").unwrap(),
oxirs_core::model::Literal::new_typed_literal("value", xsd::STRING.clone()),
);
mvcc.write(&tx_id, "key1", Some(triple)).await.unwrap();
mvcc.rollback_transaction(&tx_id).await.unwrap();
let tx2 = "tx2".to_string();
mvcc.begin_transaction(tx2.clone(), IsolationLevel::ReadCommitted)
.await
.unwrap();
let value = mvcc.read(&tx2, "key1").await.unwrap();
assert!(value.is_none());
mvcc.stop().await.unwrap();
}
}