use std::collections::{HashMap, HashSet};
use crate::connection::{Timestamp, SochConnection, TxnId};
use crate::error::{ClientError, Result};
/// Isolation level requested for a [`ClientTransaction`].
///
/// Within this file the only level with dedicated handling is
/// [`IsolationLevel::Serializable`]; the other two are stored on the
/// transaction and reported back through `ClientTransaction::isolation`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Read-committed isolation.
    ReadCommitted,
    /// Snapshot isolation; this is the value produced by `IsolationLevel::default()`.
    #[default]
    SnapshotIsolation,
    /// Serializable isolation: transactional reads are recorded in a read set
    /// and a conflict check is invoked at commit time.
    Serializable,
}
/// Lifecycle state of a [`ClientTransaction`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TxnState {
    /// The transaction has been started and not yet finished.
    Active,
    /// The transaction was committed successfully.
    Committed,
    /// The transaction was rolled back.
    Aborted,
}
/// Identifies one key that a transaction has read.
///
/// Hashable and comparable so it can be stored in a `HashSet` read set.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct TxnRead {
    /// Name of the table the key was read from.
    pub table: String,
    /// Raw key bytes.
    pub key: Vec<u8>,
}
/// One buffered mutation: either a put or a delete of a single key.
#[derive(Clone, Debug)]
pub struct TxnWrite {
    /// Name of the table the mutation targets.
    pub table: String,
    /// Raw key bytes.
    pub key: Vec<u8>,
    /// `Some(bytes)` for a put, `None` for a delete.
    pub value: Option<Vec<u8>>,
}
/// A client-side transaction bound to a [`SochConnection`].
///
/// Puts and deletes are buffered in `writes` and mirrored into `local_cache`
/// so that later reads inside the same transaction observe them. Dropping a
/// transaction that is still active aborts it in the storage layer.
pub struct ClientTransaction<'a> {
    /// Connection whose storage layer backs this transaction.
    pub(crate) conn: &'a SochConnection,
    /// Identifier returned by the storage layer when the transaction began.
    txn_id: TxnId,
    /// Start timestamp; initialised from the transaction id in `begin`.
    start_ts: Timestamp,
    /// Current lifecycle state.
    state: TxnState,
    /// Isolation level requested at `begin`.
    isolation: IsolationLevel,
    /// Buffered puts and deletes, in the order they were issued.
    writes: Vec<TxnWrite>,
    /// Keys read from storage; only populated under `Serializable`.
    read_set: HashSet<TxnRead>,
    /// Read-your-writes cache keyed by `(table, key)`; `None` marks a key
    /// deleted by this transaction.
    local_cache: HashMap<(String, Vec<u8>), Option<Vec<u8>>>,
    /// Set once the transaction has been finished explicitly; consulted by
    /// the `Drop` implementation before it aborts.
    committed: bool,
}
impl<'a> ClientTransaction<'a> {
pub fn begin(conn: &'a SochConnection, isolation: IsolationLevel) -> Result<Self> {
let txn_id = conn.storage
.begin_transaction()
.map_err(|e| ClientError::Storage(e.to_string()))?;
let start_ts = txn_id;
Ok(Self {
conn,
txn_id,
start_ts,
state: TxnState::Active,
isolation,
writes: Vec::new(),
read_set: HashSet::new(),
local_cache: HashMap::new(),
committed: false,
})
}
pub fn id(&self) -> TxnId {
self.txn_id
}
pub fn start_ts(&self) -> Timestamp {
self.start_ts
}
pub fn state(&self) -> TxnState {
self.state
}
pub fn isolation(&self) -> IsolationLevel {
self.isolation
}
pub fn is_read_only(&self) -> bool {
self.writes.is_empty()
}
pub fn get(&mut self, table: &str, key: &[u8]) -> Result<Option<Vec<u8>>> {
let cache_key = (table.to_string(), key.to_vec());
if let Some(value) = self.local_cache.get(&cache_key) {
return Ok(value.clone());
}
if self.isolation == IsolationLevel::Serializable {
self.read_set.insert(TxnRead {
table: table.to_string(),
key: key.to_vec(),
});
}
self.conn
.storage
.read(self.txn_id, key)
.map_err(|e| ClientError::Storage(format!("Read failed: {}", e)))
}
pub fn put(&mut self, table: &str, key: Vec<u8>, value: Vec<u8>) {
self.local_cache
.insert((table.to_string(), key.clone()), Some(value.clone()));
self.writes.push(TxnWrite {
table: table.to_string(),
key,
value: Some(value),
});
}
pub fn delete(&mut self, table: &str, key: Vec<u8>) {
self.local_cache
.insert((table.to_string(), key.clone()), None);
self.writes.push(TxnWrite {
table: table.to_string(),
key,
value: None,
});
}
pub fn commit(mut self) -> Result<CommitResult> {
if self.state != TxnState::Active {
return Err(ClientError::Transaction("Transaction not active".into()));
}
if self.isolation == IsolationLevel::Serializable && !self.read_set.is_empty() {
self.check_conflicts()?;
}
let commit_ts = self.conn.storage
.commit(self.txn_id)
.map_err(|e| ClientError::Storage(e.to_string()))?;
self.committed = true;
self.state = TxnState::Committed;
Ok(CommitResult {
txn_id: self.txn_id,
commit_ts,
writes_count: self.writes.len(),
})
}
pub fn rollback(mut self) -> Result<()> {
if self.state != TxnState::Active {
return Err(ClientError::Transaction("Transaction not active".into()));
}
self.conn.storage
.abort(self.txn_id)
.map_err(|e| ClientError::Storage(e.to_string()))?;
self.committed = true; self.state = TxnState::Aborted;
Ok(())
}
fn check_conflicts(&self) -> Result<()> {
Ok(())
}
}
impl<'a> Drop for ClientTransaction<'a> {
    /// Aborts the transaction in storage if it is dropped while still active
    /// and unfinished. The abort is best-effort: its result is discarded
    /// because `drop` has no way to report an error.
    fn drop(&mut self) {
        let unfinished = self.state == TxnState::Active && !self.committed;
        if !unfinished {
            return;
        }
        let _ = self.conn.storage.abort(self.txn_id);
    }
}
/// Summary returned by a successful `ClientTransaction::commit`.
#[derive(Clone, Debug)]
pub struct CommitResult {
    /// Id of the transaction that was committed.
    pub txn_id: TxnId,
    /// Timestamp returned by the storage layer's commit call.
    pub commit_ts: Timestamp,
    /// Number of puts and deletes the transaction had buffered.
    pub writes_count: usize,
}
/// Read-only view of storage pinned to a single timestamp.
pub struct SnapshotReader<'a> {
    /// Connection whose storage layer serves the reads.
    conn: &'a SochConnection,
    /// Timestamp passed to every storage read.
    snapshot_ts: Timestamp,
    /// Whether visibility tracking was requested via `with_visibility_tracking`.
    track_visibility: bool,
    /// Visibility records exposed through `visibility_diagnostics`.
    visibility_log: Vec<VisibilityCheck>,
}
/// One visibility decision recorded for a key.
#[derive(Clone, Debug)]
pub struct VisibilityCheck {
    /// Raw bytes of the key that was checked.
    pub key: Vec<u8>,
    /// Whether the key's version was visible.
    pub visible: bool,
    /// Explanation attached to the decision.
    pub reason: VisibilityReason,
}
/// Explanation attached to a [`VisibilityCheck`].
///
/// NOTE(review): nothing in this file constructs these variants, so the
/// descriptions below are read off the variant and field names — confirm
/// against the code that produces them.
#[derive(Clone, Debug)]
pub enum VisibilityReason {
    /// Presumably: the version was committed before the snapshot started.
    CommittedBeforeStart { commit_ts: Timestamp },
    /// Presumably: the writing transaction was still running at snapshot start.
    InProgressAtStart { txn_id: TxnId },
    /// Presumably: the version was committed after the snapshot started.
    CommittedAfterStart { commit_ts: Timestamp },
    /// Presumably: the key was deleted before the snapshot started.
    DeletedBeforeStart { delete_ts: Timestamp },
    /// Presumably: the version belongs to a transaction that has not committed.
    Uncommitted,
}
impl<'a> SnapshotReader<'a> {
    /// Creates a reader pinned to a timestamp obtained from the storage layer.
    ///
    /// The timestamp is whatever `begin_transaction` returns.
    /// NOTE(review): nothing in this type commits or aborts that transaction
    /// afterwards — confirm the storage layer tolerates this.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Storage`] if `begin_transaction` fails.
    pub fn now(conn: &'a SochConnection) -> Result<Self> {
        let begun = conn.storage.begin_transaction();
        let ts = begun.map_err(|e| ClientError::Storage(e.to_string()))?;
        Self::at_timestamp(conn, ts)
    }

    /// Creates a reader pinned to the caller-supplied timestamp `ts`.
    ///
    /// Always succeeds; visibility tracking starts disabled with an empty log.
    pub fn at_timestamp(conn: &'a SochConnection, ts: Timestamp) -> Result<Self> {
        let reader = Self {
            conn,
            snapshot_ts: ts,
            track_visibility: false,
            visibility_log: Vec::new(),
        };
        Ok(reader)
    }

    /// Returns the reader with the visibility-tracking flag switched on.
    pub fn with_visibility_tracking(self) -> Self {
        Self {
            track_visibility: true,
            ..self
        }
    }

    /// Returns the timestamp this reader is pinned to.
    pub fn timestamp(&self) -> Timestamp {
        self.snapshot_ts
    }

    /// Reads `key` from storage at the snapshot timestamp.
    ///
    /// `table` is accepted but not forwarded to the storage read, and this
    /// method does not consult the tracking flag or append to the log.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Storage`] if the storage read fails.
    pub fn get(&mut self, table: &str, key: &[u8]) -> Result<Option<Vec<u8>>> {
        match self.conn.storage.read(self.snapshot_ts, key) {
            Ok(found) => Ok(found),
            Err(e) => Err(ClientError::Storage(format!("Read failed: {}", e))),
        }
    }

    /// Returns the visibility records collected so far.
    pub fn visibility_diagnostics(&self) -> &[VisibilityCheck] {
        self.visibility_log.as_slice()
    }
}
/// Accumulates puts and deletes and commits them together on `flush`.
pub struct BatchWriter<'a> {
    /// Connection whose storage layer receives the batch.
    conn: &'a SochConnection,
    // Only ever initialised to an empty vector in this file and never read,
    // hence the lint suppression.
    #[allow(dead_code)]
    pending_txns: Vec<TxnId>,
    /// Mutations queued since the last successful flush, in call order.
    pending_writes: Vec<TxnWrite>,
}
impl<'a> BatchWriter<'a> {
pub fn new(conn: &'a SochConnection) -> Self {
Self {
conn,
pending_txns: Vec::new(),
pending_writes: Vec::new(),
}
}
pub fn write(&mut self, table: &str, key: Vec<u8>, value: Vec<u8>) {
self.pending_writes.push(TxnWrite {
table: table.to_string(),
key,
value: Some(value),
});
}
pub fn delete(&mut self, table: &str, key: Vec<u8>) {
self.pending_writes.push(TxnWrite {
table: table.to_string(),
key,
value: None,
});
}
pub fn flush(&mut self) -> Result<BatchCommitResult> {
if self.pending_writes.is_empty() {
return Ok(BatchCommitResult::default());
}
let start = std::time::Instant::now();
let txn_id = self.conn.storage
.begin_transaction()
.map_err(|e| ClientError::Storage(e.to_string()))?;
for write in &self.pending_writes {
match &write.value {
Some(value) => {
self.conn.storage
.write(txn_id, write.key.clone(), value.clone())
.map_err(|e| ClientError::Storage(e.to_string()))?;
}
None => {
self.conn.storage
.delete(txn_id, write.key.clone())
.map_err(|e| ClientError::Storage(e.to_string()))?;
}
}
}
let _commit_ts = self.conn.storage
.commit(txn_id)
.map_err(|e| ClientError::Storage(e.to_string()))?;
let duration = start.elapsed();
let count = self.pending_writes.len();
self.pending_writes.clear();
Ok(BatchCommitResult {
transactions_committed: 1,
writes_committed: count,
fsync_latency: duration,
})
}
pub fn pending_count(&self) -> usize {
self.pending_writes.len()
}
}
impl<'a> Drop for BatchWriter<'a> {
    /// Flushes any still-queued mutations when the writer goes out of scope.
    ///
    /// This is best-effort: a flush error is discarded because `drop` cannot
    /// report it. Call `flush` explicitly to observe failures.
    fn drop(&mut self) {
        if self.pending_writes.is_empty() {
            return;
        }
        let _ = self.flush();
    }
}
/// Summary returned by `BatchWriter::flush`.
///
/// The `Default` value (all zeros) is what a flush with nothing queued returns.
#[derive(Clone, Debug, Default)]
pub struct BatchCommitResult {
    /// Number of storage transactions committed by the flush.
    pub transactions_committed: usize,
    /// Number of queued puts and deletes that were committed.
    pub writes_committed: usize,
    /// Wall-clock duration measured around the flush's storage work.
    pub fsync_latency: std::time::Duration,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new transaction is active and read-only; buffering a put flips the
    /// read-only flag and is counted in the commit result.
    #[test]
    fn test_transaction_lifecycle() {
        let connection = SochConnection::open("./test").unwrap();
        let mut tx =
            ClientTransaction::begin(&connection, IsolationLevel::SnapshotIsolation).unwrap();
        assert_eq!(tx.state(), TxnState::Active);
        assert!(tx.is_read_only());

        tx.put("test", b"key".to_vec(), b"value".to_vec());
        assert!(!tx.is_read_only());

        let outcome = tx.commit().unwrap();
        assert!(outcome.writes_count > 0);
    }

    /// A value put inside a transaction is returned by a later get on it.
    #[test]
    fn test_read_your_writes() {
        let connection = SochConnection::open("./test").unwrap();
        let mut tx =
            ClientTransaction::begin(&connection, IsolationLevel::SnapshotIsolation).unwrap();

        tx.put("test", b"key".to_vec(), b"value".to_vec());

        let fetched = tx.get("test", b"key").unwrap();
        assert_eq!(fetched, Some(b"value".to_vec()));
    }

    /// Rolling back an active transaction with a buffered write succeeds.
    #[test]
    fn test_rollback() {
        let connection = SochConnection::open("./test").unwrap();
        let mut tx =
            ClientTransaction::begin(&connection, IsolationLevel::SnapshotIsolation).unwrap();

        tx.put("test", b"key".to_vec(), b"value".to_vec());

        tx.rollback().unwrap();
    }

    /// A snapshot taken "now" carries a positive timestamp.
    #[test]
    fn test_snapshot_reader() {
        let connection = SochConnection::open("./test").unwrap();
        let reader = SnapshotReader::now(&connection).unwrap();
        assert!(reader.timestamp() > 0);
    }

    /// Queued writes are counted while pending and reported once flushed.
    #[test]
    fn test_batch_writer() {
        let connection = SochConnection::open("./test").unwrap();
        let mut writer = BatchWriter::new(&connection);

        writer.write("test", b"k1".to_vec(), b"v1".to_vec());
        writer.write("test", b"k2".to_vec(), b"v2".to_vec());
        assert_eq!(writer.pending_count(), 2);

        let outcome = writer.flush().unwrap();
        assert_eq!(outcome.writes_committed, 2);
        assert_eq!(writer.pending_count(), 0);
    }

    /// Snapshot isolation is the default isolation level.
    #[test]
    fn test_isolation_levels() {
        assert_eq!(IsolationLevel::default(), IsolationLevel::SnapshotIsolation);
    }
}