use crate::mvcc::tx_manager::TxManager;
use crate::types::{MvccTxStatus, TxId, TxKey};
use std::collections::HashSet;
/// Error describing a transaction whose keys conflict with concurrent transactions.
///
/// Normally built through `ConflictError::new`, which renders `message` from the
/// other two fields. The `Display` implementation prints `message` as-is.
#[derive(Debug, Clone)]
pub struct ConflictError {
/// Human-readable description of the conflict; this is the text `Display` writes.
pub message: String,
/// Id of the transaction the error is about.
pub txid: TxId,
/// String form of each key reported as conflicting.
pub conflicting_keys: Vec<String>,
}
impl ConflictError {
    /// Builds a conflict error for transaction `txid`.
    ///
    /// `keys` is stored unchanged in `conflicting_keys`; the same keys, joined
    /// with `", "`, are embedded in the pre-rendered `message`.
    pub fn new(txid: TxId, keys: Vec<String>) -> Self {
        // Render the key list first so the message can be built inline below.
        let key_list = keys.join(", ");
        Self {
            message: format!(
                "Transaction {} conflicts with concurrent transactions on keys: {}",
                txid, key_list
            ),
            txid,
            conflicting_keys: keys,
        }
    }
}
impl std::fmt::Display for ConflictError {
    /// Writes the stored `message` verbatim; width/fill flags of the caller's
    /// format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
// Lets `ConflictError` be used wherever a `std::error::Error` is expected.
// The body is empty, so every trait method keeps its default implementation;
// the required `Debug` and `Display` impls come from the derive and the impl above.
impl std::error::Error for ConflictError {}
/// Conflict checker for MVCC transactions.
///
/// The struct has no fields: every method receives the `TxManager` to inspect
/// as an argument and keeps nothing between calls.
#[derive(Debug)]
pub struct ConflictDetector {
}
impl ConflictDetector {
    /// Creates a detector. The detector holds no state of its own.
    pub fn new() -> Self {
        Self {}
    }

    /// Lists the keys of transaction `txid` for which `tx_manager` reports a
    /// conflicting write relative to the transaction's start timestamp.
    ///
    /// Both the read set and the write set are examined; a key present in both
    /// is reported once. Returns an empty vector when the transaction is
    /// unknown or its status is not `Active`.
    ///
    /// The returned keys are sorted, so the result (and any [`ConflictError`]
    /// message built from it) is identical from run to run instead of following
    /// hash-set iteration order.
    pub fn check_conflicts(&self, tx_manager: &TxManager, txid: TxId) -> Vec<String> {
        let tx = match tx_manager.tx(txid) {
            Some(tx) => tx,
            None => return Vec::new(),
        };
        if tx.status != MvccTxStatus::Active {
            return Vec::new();
        }
        let snapshot_ts = tx.start_ts;

        // The set de-duplicates keys that appear in both the read and write set.
        let conflicting: HashSet<String> = tx
            .read_set
            .iter()
            .chain(tx.write_set.iter())
            .filter(|&key| tx_manager.has_conflicting_write(key, snapshot_ts))
            .map(|key| key.to_string())
            .collect();

        // Hash-set iteration order is arbitrary; sort for a stable, reproducible result.
        let mut keys: Vec<String> = conflicting.into_iter().collect();
        keys.sort_unstable();
        keys
    }

    /// Returns `true` as soon as any key in the transaction's read or write set
    /// is reported by `tx_manager` as having a conflicting write.
    ///
    /// Cheaper than [`check_conflicts`](Self::check_conflicts) when only a
    /// yes/no answer is needed: it stops at the first hit and allocates nothing.
    /// Returns `false` when the transaction is unknown or not `Active`.
    pub fn has_conflicts(&self, tx_manager: &TxManager, txid: TxId) -> bool {
        let tx = match tx_manager.tx(txid) {
            Some(tx) => tx,
            None => return false,
        };
        if tx.status != MvccTxStatus::Active {
            return false;
        }
        let snapshot_ts = tx.start_ts;
        tx.read_set
            .iter()
            .chain(tx.write_set.iter())
            .any(|key| tx_manager.has_conflicting_write(key, snapshot_ts))
    }

    /// Checks whether transaction `txid` currently has any conflicting key.
    ///
    /// # Errors
    ///
    /// Returns a [`ConflictError`] carrying the (sorted) conflicting keys when
    /// [`check_conflicts`](Self::check_conflicts) finds at least one.
    pub fn validate_commit(&self, tx_manager: &TxManager, txid: TxId) -> Result<(), ConflictError> {
        let conflicts = self.check_conflicts(tx_manager, txid);
        if conflicts.is_empty() {
            Ok(())
        } else {
            Err(ConflictError::new(txid, conflicts))
        }
    }

    /// Asks `tx_manager` whether `key` has a conflicting write relative to the
    /// start timestamp of transaction `txid`.
    ///
    /// The key does not have to be in the transaction's read or write set.
    /// Returns `false` when the transaction is unknown or not `Active`.
    pub fn check_key_conflict(&self, tx_manager: &TxManager, txid: TxId, key: &TxKey) -> bool {
        let tx = match tx_manager.tx(txid) {
            Some(tx) => tx,
            None => return false,
        };
        if tx.status != MvccTxStatus::Active {
            return false;
        }
        tx_manager.has_conflicting_write(key, tx.start_ts)
    }
}
impl Default for ConflictDetector {
fn default() -> Self {
Self::new()
}
}
/// Kind of conflict attached to a `ConflictInfo` entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConflictType {
/// Used by `conflict_details` for keys found in the transaction's read set.
ReadWrite,
/// Used by `conflict_details` for keys in the write set that are not also in the read set.
WriteWrite,
}
/// One conflicting key as reported by `ConflictDetector::conflict_details`.
#[derive(Debug, Clone)]
pub struct ConflictInfo {
/// String form (`to_string()`) of the conflicting key.
pub key: String,
/// Whether the key was taken from the read set or only from the write set.
pub conflict_type: ConflictType,
/// Value returned by `TxManager::committed_write_ts` for this key.
/// NOTE(review): presumably the commit timestamp of the conflicting write — confirm against `TxManager`.
pub conflicting_write_ts: u64,
}
impl ConflictDetector {
    /// Describes every conflicting key of transaction `txid`, with its kind and
    /// the timestamp returned by `TxManager::committed_write_ts`.
    ///
    /// Keys from the read set come first and are tagged `ReadWrite`; keys that
    /// are only in the write set follow and are tagged `WriteWrite`. A key in
    /// both sets is therefore reported once, as `ReadWrite`.
    /// Returns an empty vector when the transaction is unknown or not `Active`.
    pub fn conflict_details(&self, tx_manager: &TxManager, txid: TxId) -> Vec<ConflictInfo> {
        let tx = match tx_manager.tx(txid) {
            None => return Vec::new(),
            Some(found) => found,
        };
        if tx.status != MvccTxStatus::Active {
            return Vec::new();
        }
        let snapshot_ts = tx.start_ts;

        // Read-set keys for which the manager reports a write timestamp.
        let read_write = tx.read_set.iter().filter_map(|key| {
            tx_manager
                .committed_write_ts(key, snapshot_ts)
                .map(|write_ts| ConflictInfo {
                    key: key.to_string(),
                    conflict_type: ConflictType::ReadWrite,
                    conflicting_write_ts: write_ts,
                })
        });

        // Write-only keys; anything also read was already covered above.
        let write_write = tx
            .write_set
            .iter()
            .filter(|&key| !tx.read_set.contains(key))
            .filter_map(|key| {
                tx_manager
                    .committed_write_ts(key, snapshot_ts)
                    .map(|write_ts| ConflictInfo {
                        key: key.to_string(),
                        conflict_type: ConflictType::WriteWrite,
                        conflicting_write_ts: write_ts,
                    })
            });

        read_write.chain(write_write).collect()
    }
}
// Unit tests for `ConflictDetector`, driven through a real `TxManager`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `TxKey::Key` from a string literal.
fn key(name: &str) -> TxKey {
TxKey::Key(std::sync::Arc::from(name))
}
// Fresh transaction manager plus detector for each test.
fn setup() -> (TxManager, ConflictDetector) {
let tx_mgr = TxManager::new();
let detector = ConflictDetector::new();
(tx_mgr, detector)
}
// A transaction with no recorded reads or writes reports no conflicts.
#[test]
fn test_no_conflicts_empty_tx() {
let (mut tx_mgr, detector) = setup();
let (txid, _) = tx_mgr.begin_tx();
let conflicts = detector.check_conflicts(&tx_mgr, txid);
assert!(conflicts.is_empty());
}
// tx1 commits its write to key1 before tx2 even begins; tx2 then writes key1.
#[test]
fn test_no_conflicts_serial_commits() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.commit_tx(txid1).expect("expected value");
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid2, key("key1"));
let conflicts = detector.check_conflicts(&tx_mgr, txid2);
// NOTE(review): this accepts both "no conflict" and "key1 conflicts", so it does
// not pin whether a commit that precedes begin_tx counts as a conflict. Tighten
// once the TxManager timestamp semantics are confirmed.
assert!(conflicts.is_empty() || conflicts.contains(&key("key1").to_string()));
}
// Two transactions with the same start timestamp write the same key; after
// tx1 commits, tx2 must report that key.
#[test]
fn test_write_write_conflict() {
let (mut tx_mgr, detector) = setup();
let (txid1, start_ts1) = tx_mgr.begin_tx();
let (txid2, start_ts2) = tx_mgr.begin_tx();
assert_eq!(start_ts1, start_ts2);
tx_mgr.record_write(txid1, key("shared_key"));
tx_mgr.record_write(txid2, key("shared_key"));
tx_mgr.commit_tx(txid1).expect("expected value");
let conflicts = detector.check_conflicts(&tx_mgr, txid2);
assert!(conflicts.contains(&key("shared_key").to_string()));
}
// tx2 only reads key1; tx1's committed write to key1 must still be reported for tx2.
#[test]
fn test_read_write_conflict() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.record_read(txid2, key("key1"));
tx_mgr.commit_tx(txid1).expect("expected value");
let conflicts = detector.check_conflicts(&tx_mgr, txid2);
assert!(conflicts.contains(&key("key1").to_string()));
}
// `has_conflicts` sees the conflict even though tx2 records its write after tx1 committed.
#[test]
fn test_has_conflicts_fast_path() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.commit_tx(txid1).expect("expected value");
tx_mgr.record_write(txid2, key("key1"));
assert!(detector.has_conflicts(&tx_mgr, txid2));
}
// A lone writer with nothing committed around it validates cleanly.
#[test]
fn test_validate_commit_success() {
let (mut tx_mgr, detector) = setup();
let (txid, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid, key("unique_key"));
let result = detector.validate_commit(&tx_mgr, txid);
assert!(result.is_ok());
}
// Validation fails for tx2 after tx1 commits the same key; the error names tx2 and the key.
#[test]
fn test_validate_commit_failure() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key"));
tx_mgr.record_write(txid2, key("key"));
tx_mgr.commit_tx(txid1).expect("expected value");
let result = detector.validate_commit(&tx_mgr, txid2);
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.txid, txid2);
assert!(err.conflicting_keys.contains(&key("key").to_string()));
}
// `check_key_conflict` probes an arbitrary key: tx2 has no recorded reads or writes here.
#[test]
fn test_check_key_conflict() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.commit_tx(txid1).expect("expected value");
assert!(detector.check_key_conflict(&tx_mgr, txid2, &key("key1")));
assert!(!detector.check_key_conflict(&tx_mgr, txid2, &key("key2")));
}
// Concurrent writers touching disjoint keys do not conflict.
#[test]
fn test_no_conflict_different_keys() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.record_write(txid2, key("key2"));
tx_mgr.commit_tx(txid1).expect("expected value");
let conflicts = detector.check_conflicts(&tx_mgr, txid2);
assert!(conflicts.is_empty());
}
// The rendered error text contains the transaction id and every key.
#[test]
fn test_conflict_error_display() {
let err = ConflictError::new(42, vec!["key1".to_string(), "key2".to_string()]);
let display = err.to_string();
assert!(display.contains("42"));
assert!(display.contains("key1"));
assert!(display.contains("key2"));
}
// tx2 reads key1 and writes key2; tx1 writes both and commits. Expect one
// ReadWrite entry (key1) and one WriteWrite entry (key2).
#[test]
fn test_conflict_details() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("key1"));
tx_mgr.record_read(txid2, key("key1"));
tx_mgr.record_write(txid2, key("key2"));
tx_mgr.record_write(txid1, key("key2"));
tx_mgr.commit_tx(txid1).expect("expected value");
let details = detector.conflict_details(&tx_mgr, txid2);
assert_eq!(details.len(), 2);
// The expected strings rely on `TxKey::Key` rendering as "key:<name>".
let key1_conflict = details.iter().find(|c| c.key == "key:key1");
assert!(key1_conflict.is_some());
assert_eq!(
key1_conflict.expect("expected value").conflict_type,
ConflictType::ReadWrite
);
let key2_conflict = details.iter().find(|c| c.key == "key:key2");
assert!(key2_conflict.is_some());
assert_eq!(
key2_conflict.expect("expected value").conflict_type,
ConflictType::WriteWrite
);
}
// Sanity check of the derived equality on `ConflictType`.
#[test]
fn test_conflict_type_eq() {
assert_eq!(ConflictType::ReadWrite, ConflictType::ReadWrite);
assert_eq!(ConflictType::WriteWrite, ConflictType::WriteWrite);
assert_ne!(ConflictType::ReadWrite, ConflictType::WriteWrite);
}
// An aborted transaction yields no conflicts.
#[test]
fn test_detector_with_aborted_tx() {
let (mut tx_mgr, detector) = setup();
let (txid, _) = tx_mgr.begin_tx();
tx_mgr.abort_tx(txid);
let conflicts = detector.check_conflicts(&tx_mgr, txid);
assert!(conflicts.is_empty());
}
// An already-committed transaction yields no conflicts.
#[test]
fn test_detector_with_committed_tx() {
let (mut tx_mgr, detector) = setup();
let (txid, _) = tx_mgr.begin_tx();
tx_mgr.commit_tx(txid).expect("expected value");
let conflicts = detector.check_conflicts(&tx_mgr, txid);
assert!(conflicts.is_empty());
}
// Three concurrent writers on one key: the first validates before anything is
// committed; once it commits, the other two fail validation.
#[test]
fn test_multiple_concurrent_writers() {
let (mut tx_mgr, detector) = setup();
let (txid1, _) = tx_mgr.begin_tx();
let (txid2, _) = tx_mgr.begin_tx();
let (txid3, _) = tx_mgr.begin_tx();
tx_mgr.record_write(txid1, key("hot_key"));
tx_mgr.record_write(txid2, key("hot_key"));
tx_mgr.record_write(txid3, key("hot_key"));
assert!(detector.validate_commit(&tx_mgr, txid1).is_ok());
tx_mgr.commit_tx(txid1).expect("expected value");
assert!(detector.validate_commit(&tx_mgr, txid2).is_err());
assert!(detector.validate_commit(&tx_mgr, txid3).is_err());
}
}