use bytes::Bytes;
use parking_lot::{Mutex, RwLock};
use std::collections::HashSet;
use std::ops::RangeBounds;
use std::sync::Arc;
use uuid::Uuid;
use crate::batch::WriteBatch;
use crate::bytes_range::BytesRange;
use crate::config::{MergeOptions, PutOptions, ReadOptions, ScanOptions, WriteOptions};
use crate::db::DbInner;
use crate::db::WriteHandle;
use crate::db_iter::{DbIterator, DbIteratorRangeTracker};
use crate::error::SlateDBError;
use crate::transaction_manager::{IsolationLevel, TransactionManager};
use crate::types::KeyValue;
use crate::DbRead;
/// A transaction over the database, created with a chosen [`IsolationLevel`].
///
/// Writes are buffered in an in-memory [`WriteBatch`] and only applied on commit.
/// Reads see the database as of the transaction's starting sequence number, overlaid
/// with the transaction's own buffered writes. Dropping the transaction (including via
/// commit or rollback) deregisters it from the [`TransactionManager`].
pub struct DbTransaction {
/// Identifier assigned by the transaction manager when the transaction was created.
txn_id: Uuid,
/// Sequence number returned by the transaction manager at creation; passed to the
/// reader on every get/scan (presumably as the snapshot upper bound — confirm in the reader).
started_seq: u64,
/// Manager used to record read keys/ranges and write keys for conflict detection.
txn_manager: Arc<TransactionManager>,
/// Uncommitted writes of this transaction, tagged with `txn_id`; a copy is handed to
/// every read so the transaction observes its own writes.
write_batch: RwLock<WriteBatch>,
/// Shared database internals used for reads and the final write.
db_inner: Arc<DbInner>,
/// Isolation level chosen at creation; `SerializableSnapshot` enables read tracking.
isolation_level: IsolationLevel,
/// One tracker per scan started under `SerializableSnapshot`; the ranges they record are
/// registered as reads at commit time.
range_trackers: Mutex<Vec<Arc<DbIteratorRangeTracker>>>,
/// Keys excluded (via `unmark_write`) from write-conflict tracking; they are still written.
untracked_write_keys: RwLock<HashSet<Bytes>>,
}
impl DbTransaction {
    /// Creates a transaction and registers it with `txn_manager`.
    ///
    /// The transaction id and the starting sequence number both come from
    /// `TransactionManager::new_transaction`. The internal write batch is tagged with the
    /// transaction id before any writes are buffered into it.
    // NOTE(review): presumably `unused` is allowed because not every build configuration
    // constructs transactions through this path — confirm the allow is still needed.
    #[allow(unused)]
    pub(crate) fn new(
        db_inner: Arc<DbInner>,
        txn_manager: Arc<TransactionManager>,
        isolation_level: IsolationLevel,
    ) -> Self {
        let (txn_id, started_seq) = txn_manager.new_transaction();
        Self {
            txn_id,
            started_seq,
            txn_manager,
            write_batch: RwLock::new(WriteBatch::new().with_txn_id(txn_id)),
            db_inner,
            isolation_level,
            range_trackers: Mutex::new(Vec::new()),
            untracked_write_keys: RwLock::new(HashSet::new()),
        }
    }

    /// Reads the value stored under `key` using default [`ReadOptions`].
    ///
    /// See [`DbTransaction::get_key_value_with_options`] for visibility and
    /// conflict-tracking rules.
    pub async fn get<K: AsRef<[u8]> + Send>(&self, key: K) -> Result<Option<Bytes>, crate::Error> {
        self.get_with_options(key, &ReadOptions::default()).await
    }

    /// Reads the value stored under `key` using the given read options.
    ///
    /// Returns `Ok(None)` when the key is absent (or deleted in this transaction).
    pub async fn get_with_options<K: AsRef<[u8]> + Send>(
        &self,
        key: K,
        options: &ReadOptions,
    ) -> Result<Option<Bytes>, crate::Error> {
        self.get_key_value_with_options(key, options)
            .await
            .map(|kv_opt| kv_opt.map(|kv| kv.value))
    }

    /// Reads `key` together with its value using default [`ReadOptions`].
    pub async fn get_key_value<K: AsRef<[u8]> + Send>(
        &self,
        key: K,
    ) -> Result<Option<KeyValue>, crate::Error> {
        self.get_key_value_with_options(key, &ReadOptions::default())
            .await
    }

    /// Reads `key` together with its value.
    ///
    /// The lookup is bounded by the transaction's starting sequence number and overlaid
    /// with this transaction's own uncommitted writes, so an earlier `put`/`delete` in the
    /// same transaction is visible while later writes by others are not.
    ///
    /// Under [`IsolationLevel::SerializableSnapshot`] the key is recorded as read before the
    /// lookup, so a concurrent write to it makes the commit fail with a conflict.
    ///
    /// # Errors
    /// Returns an error if the database is closed or the underlying read fails.
    pub async fn get_key_value_with_options<K: AsRef<[u8]> + Send>(
        &self,
        key: K,
        options: &ReadOptions,
    ) -> Result<Option<KeyValue>, crate::Error> {
        self.db_inner.check_closed()?;
        if self.isolation_level == IsolationLevel::SerializableSnapshot {
            // Record the read so a conflicting write can be detected at commit time.
            self.txn_manager.track_read_keys(
                &self.txn_id,
                std::iter::once(Bytes::copy_from_slice(key.as_ref())),
            );
        }
        let db_state = self.db_inner.state.read().view();
        let write_batch_cloned = self.write_batch.read().clone();
        let kv = self
            .db_inner
            .reader
            .get_key_value_with_options(
                key,
                options,
                &db_state,
                Some(write_batch_cloned),
                Some(self.started_seq),
            )
            .await
            .map_err(crate::Error::from)?;
        Ok(kv)
    }

    /// Scans `range` using default [`ScanOptions`].
    pub async fn scan<K, T>(&self, range: T) -> Result<DbIterator, crate::Error>
    where
        K: AsRef<[u8]> + Send,
        T: RangeBounds<K> + Send,
    {
        self.scan_with_options(range, &ScanOptions::default()).await
    }

    /// Scans `range` with the given options.
    ///
    /// The iterator sees the same view as point reads: data bounded by the starting
    /// sequence number plus this transaction's buffered writes. The buffered writes are
    /// copied when the iterator is created, so writes made afterwards in this transaction
    /// are not visible to it.
    ///
    /// Under [`IsolationLevel::SerializableSnapshot`] a range tracker is attached to the
    /// iterator; if it recorded data, its range is registered as read when the
    /// transaction commits.
    ///
    /// # Errors
    /// Returns an error if the database is closed or the scan cannot be created.
    pub async fn scan_with_options<K, T>(
        &self,
        range: T,
        options: &ScanOptions,
    ) -> Result<DbIterator, crate::Error>
    where
        K: AsRef<[u8]> + Send,
        T: RangeBounds<K> + Send,
    {
        // Fail fast before registering any tracker, matching the point-read path.
        self.db_inner.check_closed()?;
        let start = range
            .start_bound()
            .map(|b| Bytes::copy_from_slice(b.as_ref()));
        let end = range
            .end_bound()
            .map(|b| Bytes::copy_from_slice(b.as_ref()));
        let range = (start, end);
        let range_tracker = (self.isolation_level == IsolationLevel::SerializableSnapshot)
            .then(|| {
                let tracker = Arc::new(DbIteratorRangeTracker::new());
                self.range_trackers.lock().push(Arc::clone(&tracker));
                tracker
            });
        let db_state = self.db_inner.state.read().view();
        let write_batch_cloned = self.write_batch.read().clone();
        self.db_inner
            .reader
            .scan_with_options(
                BytesRange::from(range),
                options,
                &db_state,
                Some(write_batch_cloned),
                Some(self.started_seq),
                range_tracker,
            )
            .await
            .map_err(Into::into)
    }

    /// Scans every key starting with `prefix` using default [`ScanOptions`].
    pub async fn scan_prefix<P>(&self, prefix: P) -> Result<DbIterator, crate::Error>
    where
        P: AsRef<[u8]> + Send,
    {
        self.scan_prefix_with_options(prefix, &ScanOptions::default())
            .await
    }

    /// Scans every key starting with `prefix` with the given options.
    ///
    /// Equivalent to [`DbTransaction::scan_with_options`] over the prefix's key range.
    pub async fn scan_prefix_with_options<P>(
        &self,
        prefix: P,
        options: &ScanOptions,
    ) -> Result<DbIterator, crate::Error>
    where
        P: AsRef<[u8]> + Send,
    {
        self.scan_with_options(BytesRange::from_prefix(prefix.as_ref()), options)
            .await
    }

    /// Buffers a put of `key` → `value` using default [`PutOptions`].
    ///
    /// Nothing is written to the database until the transaction commits.
    pub fn put<K, V>(&self, key: K, value: V) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        self.put_with_options(key, value, &PutOptions::default())
    }

    /// Buffers a put of `key` → `value` with the given options.
    pub fn put_with_options<K, V>(
        &self,
        key: K,
        value: V,
        options: &PutOptions,
    ) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        self.write_batch
            .write()
            .put_with_options(key, value, options);
        Ok(())
    }

    /// Records `keys` as read by this transaction without reading them.
    ///
    /// The keys are tracked regardless of isolation level, so a concurrent write to any of
    /// them causes this transaction's commit to fail with a conflict.
    pub fn mark_read<K, I>(&self, keys: I) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        I: IntoIterator<Item = K>,
    {
        let read_keys = keys.into_iter().map(|k| Bytes::copy_from_slice(k.as_ref()));
        self.txn_manager.track_read_keys(&self.txn_id, read_keys);
        Ok(())
    }

    /// Excludes `keys` from write-conflict tracking at commit.
    ///
    /// Buffered writes to these keys are still applied on commit. They are simply not
    /// registered with the transaction manager as this transaction's write set.
    pub fn unmark_write<K, I>(&self, keys: I) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        I: IntoIterator<Item = K>,
    {
        let mut untracked_keys = self.untracked_write_keys.write();
        untracked_keys.extend(keys.into_iter().map(|k| Bytes::copy_from_slice(k.as_ref())));
        Ok(())
    }

    /// Buffers a merge operand for `key` using default [`MergeOptions`].
    pub fn merge<K, V>(&self, key: K, value: V) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        self.merge_with_options(key, value, &MergeOptions::default())
    }

    /// Buffers a merge operand for `key` with the given options.
    pub fn merge_with_options<K, V>(
        &self,
        key: K,
        value: V,
        options: &MergeOptions,
    ) -> Result<(), crate::Error>
    where
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        self.write_batch
            .write()
            .merge_with_options(key, value, options);
        Ok(())
    }

    /// Buffers a delete of `key`.
    pub fn delete<K: AsRef<[u8]>>(&self, key: K) -> Result<(), crate::Error> {
        self.write_batch.write().delete(key);
        Ok(())
    }

    /// Commits the transaction using default [`WriteOptions`].
    pub async fn commit(self) -> Result<Option<WriteHandle>, crate::Error> {
        self.commit_with_options(&WriteOptions::default()).await
    }

    /// Commits the transaction with the given write options.
    ///
    /// Under [`IsolationLevel::SerializableSnapshot`], the ranges recorded by this
    /// transaction's scans are first registered as reads.
    ///
    /// If nothing was buffered, only the conflict check for tracked reads runs.
    /// `Ok(None)` is returned when that check passes. Otherwise the buffered writes are
    /// applied and `Ok(Some(handle))` is returned. Keys passed to
    /// [`DbTransaction::unmark_write`] are written but excluded from write-conflict tracking.
    ///
    /// # Errors
    /// - A transaction-conflict error if a read-only transaction conflicts.
    /// - Any error from the underlying write, including conflicts detected when the batch
    ///   is applied.
    pub async fn commit_with_options(
        mut self,
        options: &WriteOptions,
    ) -> Result<Option<WriteHandle>, crate::Error> {
        // `self` is consumed, so move the batch out instead of cloning it. The empty
        // replacement is dropped together with the transaction.
        let write_batch = std::mem::replace(self.write_batch.get_mut(), WriteBatch::new());

        if self.isolation_level == IsolationLevel::SerializableSnapshot {
            for tracker in self.range_trackers.lock().iter() {
                if tracker.has_data() {
                    if let Some(range) = tracker.get_range() {
                        self.txn_manager.track_read_range(&self.txn_id, range);
                    }
                }
            }
        }

        if write_batch.is_empty() {
            // Nothing to write, but the reads tracked for this transaction may still
            // conflict with writes by others. Use our own id directly rather than relying
            // on the batch being tagged.
            if self.txn_manager.check_has_conflict(&self.txn_id) {
                return Err(SlateDBError::TransactionConflict.into());
            }
            return Ok(None);
        }

        // Untracked keys are still written below; they are only left out of the write set
        // used for conflict detection.
        let tracked_write_keys = {
            let untracked_write_keys = self.untracked_write_keys.read();
            write_batch
                .keys()
                .into_iter()
                .filter(|key| !untracked_write_keys.contains(key))
                .collect()
        };
        self.txn_manager
            .track_write_keys(&self.txn_id, &tracked_write_keys);
        self.db_inner
            .write_with_options(write_batch, options)
            .await
            .map(Some)
            .map_err(Into::into)
    }

    /// Abandons the transaction, discarding all buffered writes.
    ///
    /// Consuming `self` triggers the `Drop` impl, which deregisters the transaction.
    pub fn rollback(self) {
        // Intentionally empty: all cleanup happens when `self` is dropped here.
    }

    /// Returns the sequence number the transaction started at.
    pub fn seqnum(&self) -> u64 {
        self.started_seq
    }

    /// Returns the transaction's identifier.
    pub fn id(&self) -> Uuid {
        self.txn_id
    }
}
#[async_trait::async_trait]
impl DbRead for DbTransaction {
async fn get_with_options<K: AsRef<[u8]> + Send>(
&self,
key: K,
options: &ReadOptions,
) -> Result<Option<Bytes>, crate::Error> {
self.get_with_options(key, options).await
}
async fn get_key_value_with_options<K: AsRef<[u8]> + Send>(
&self,
key: K,
options: &ReadOptions,
) -> Result<Option<KeyValue>, crate::Error> {
self.get_key_value_with_options(key, options).await
}
async fn scan_with_options<K, T>(
&self,
range: T,
options: &ScanOptions,
) -> Result<DbIterator, crate::Error>
where
K: AsRef<[u8]> + Send,
T: RangeBounds<K> + Send,
{
self.scan_with_options(range, options).await
}
}
/// Deregisters the transaction when it goes away.
///
/// `commit`, `commit_with_options` and `rollback` all take the transaction by value, so
/// this runs on every exit path as well as when a transaction simply goes out of scope.
impl Drop for DbTransaction {
    fn drop(&mut self) {
        // Hand our id to the manager, presumably so it can release its bookkeeping.
        let manager = &self.txn_manager;
        let id = &self.txn_id;
        manager.drop_txn(id);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::merge_operator::{MergeOperator, MergeOperatorError};
use crate::object_store::memory::InMemory;
use rstest::rstest;
use std::sync::Arc;
struct CounterMergeOperator;
impl MergeOperator for CounterMergeOperator {
fn merge(
&self,
_key: &Bytes,
existing_value: Option<Bytes>,
value: Bytes,
) -> Result<Bytes, MergeOperatorError> {
let existing = existing_value
.map(|v| u64::from_le_bytes(v.as_ref().try_into().unwrap()))
.unwrap_or(0);
let operand = u64::from_le_bytes(value.as_ref().try_into().unwrap());
Ok(Bytes::copy_from_slice(&(existing + operand).to_le_bytes()))
}
fn merge_batch(
&self,
_key: &Bytes,
existing_value: Option<Bytes>,
operands: &[Bytes],
) -> Result<Bytes, MergeOperatorError> {
let mut total = existing_value
.map(|v| u64::from_le_bytes(v.as_ref().try_into().unwrap()))
.unwrap_or(0);
for operand in operands {
total += u64::from_le_bytes(operand.as_ref().try_into().unwrap());
}
Ok(Bytes::copy_from_slice(&total.to_le_bytes()))
}
}
#[tokio::test]
async fn test_txn_basic_visibility() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
let txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
db.put(b"k2", b"v2").await.unwrap();
let value = txn.get(b"k1").await.unwrap();
assert_eq!(value, Some(Bytes::from_static(b"v1")));
txn.commit().await.unwrap();
}
#[tokio::test]
async fn test_txn_write_visibility_in_txn() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
let txn = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
txn.put(b"k1", b"v2").unwrap();
let value = txn.get(b"k1").await.unwrap();
assert_eq!(value, Some(Bytes::from_static(b"v2")));
txn.commit().await.unwrap();
}
#[tokio::test]
async fn test_txn_si_commit_conflict() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
let txn1 = db.begin(IsolationLevel::Snapshot).await.unwrap();
txn1.put(b"k1", b"v2").unwrap();
let txn2 = db.begin(IsolationLevel::Snapshot).await.unwrap();
txn2.put(b"k1", b"v3").unwrap();
txn1.commit().await.unwrap();
let result = txn2.commit().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_txn_si_commit_conflict_with_db_writes() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
let txn1 = db.begin(IsolationLevel::Snapshot).await.unwrap();
txn1.put(b"k1", b"v2").unwrap();
db.put(b"k1", b"v3").await.unwrap();
let result = txn1.commit().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_txn_ssi_commit_conflict() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
db.put(b"k2", b"v2.1").await.unwrap();
let txn1 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
txn1.put(b"k1", b"v2").unwrap();
txn1.put(b"k2", b"v2.2").unwrap();
let txn2 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
let val2 = txn2.get(b"k2").await.unwrap();
assert_eq!(val2, Some(Bytes::from_static(b"v2.1")));
txn2.put(b"k3", b"v3").unwrap();
txn1.commit().await.unwrap();
let result = txn2.commit().await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_txn_ssi_commit_conflit_with_ranges() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
db.put(b"k2", b"v2.1").await.unwrap();
db.put(b"k3", b"v3").await.unwrap();
let txn1 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
let txn2 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
{
let mut iter = txn2.scan(&b"k2"[..]..=&b"k3"[..]).await.unwrap();
while let Some(_kv) = iter.next().await.unwrap() {
}
}
txn1.put(b"k2", b"v2.2").unwrap();
txn1.commit().await.unwrap();
txn2.put(b"k4", b"v4").unwrap();
let result = txn2.commit().await;
assert!(result.is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_txn_commit_await_durable_false() {
use crate::config::{DurabilityLevel::*, ReadOptions, WriteOptions};
use fail_parallel::FailPointRegistry;
let fp_registry = Arc::new(FailPointRegistry::new());
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::builder("/tmp/test_txn_commit_await_durable_false", object_store)
.with_fp_registry(fp_registry.clone())
.build()
.await
.unwrap();
fail_parallel::cfg(fp_registry.clone(), "write-wal-sst-io-error", "pause").unwrap();
let txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
txn.put(b"k", b"v").unwrap();
txn.commit_with_options(&WriteOptions {
await_durable: false,
})
.await
.unwrap();
let val = db
.get_with_options(b"k", &ReadOptions::new().with_durability_filter(Memory))
.await
.unwrap();
assert_eq!(val, Some(Bytes::from_static(b"v")));
let val = db
.get_with_options(b"k", &ReadOptions::new().with_durability_filter(Remote))
.await
.unwrap();
assert_eq!(val, None);
fail_parallel::cfg(fp_registry.clone(), "write-wal-sst-io-error", "off").unwrap();
db.close().await.unwrap();
}
#[derive(Debug, Clone)]
struct TransactionTestCase {
name: &'static str,
isolation_level: IsolationLevel,
initial_data: Vec<(Bytes, Bytes)>,
operations: Vec<TransactionTestOp>,
expected_results: Vec<TransactionTestOpResult>,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum TransactionTestOp {
TxnGet(Bytes),
TxnScan(Bytes, Bytes),
TxnPut(Bytes, Bytes),
TxnDelete(Bytes),
TxnMarkRead(Bytes),
TxnCommit,
TxnRollback,
DbPut(Bytes, Bytes),
DbGet(Bytes),
}
#[derive(Debug, Clone, PartialEq)]
enum TransactionTestOpResult {
Got(Option<Bytes>),
Scanned(Vec<Bytes>),
Empty,
Conflicted,
Invalid,
}
async fn execute_transaction_test_ops(
db: crate::Db,
operations: Vec<TransactionTestOp>,
initial_data: Vec<(Bytes, Bytes)>,
isolation_level: IsolationLevel,
) -> Vec<TransactionTestOpResult> {
for (key, value) in initial_data {
db.put(key, value).await.unwrap();
}
let mut txn_opt = Some(db.begin(isolation_level).await.unwrap());
let mut results = Vec::new();
for operation in operations.iter() {
let result = match (txn_opt.as_mut(), operation) {
(Some(txn), TransactionTestOp::TxnGet(key)) => {
let val = txn.get(key).await.unwrap();
TransactionTestOpResult::Got(val)
}
(Some(txn), TransactionTestOp::TxnScan(start, end)) => {
let mut iter = txn.scan(&start[..]..=&end[..]).await.unwrap();
let mut scanned_keys = Vec::new();
while let Some(kv) = iter.next().await.unwrap() {
scanned_keys.push(kv.key);
}
TransactionTestOpResult::Scanned(scanned_keys)
}
(Some(txn), TransactionTestOp::TxnPut(key, value)) => {
txn.put(key, value).unwrap();
TransactionTestOpResult::Empty
}
(Some(txn), TransactionTestOp::TxnDelete(key)) => {
txn.delete(key).unwrap();
TransactionTestOpResult::Empty
}
(Some(txn), TransactionTestOp::TxnMarkRead(key)) => {
txn.mark_read([key]).unwrap();
TransactionTestOpResult::Empty
}
(Some(_txn), TransactionTestOp::TxnCommit) => {
let txn = txn_opt.take().unwrap();
match txn.commit().await {
Ok(_) => TransactionTestOpResult::Empty,
Err(_) => TransactionTestOpResult::Conflicted,
}
}
(Some(_txn), TransactionTestOp::TxnRollback) => {
let txn = txn_opt.take().unwrap();
txn.rollback();
TransactionTestOpResult::Empty
}
(_, TransactionTestOp::DbPut(key, value)) => {
db.put(key, value).await.unwrap();
TransactionTestOpResult::Empty
}
(_, TransactionTestOp::DbGet(key)) => {
let val = db.get(key).await.unwrap();
TransactionTestOpResult::Got(val)
}
(None, TransactionTestOp::TxnGet(_))
| (None, TransactionTestOp::TxnScan(_, _))
| (None, TransactionTestOp::TxnPut(_, _))
| (None, TransactionTestOp::TxnDelete(_))
| (None, TransactionTestOp::TxnMarkRead(_))
| (None, TransactionTestOp::TxnCommit)
| (None, TransactionTestOp::TxnRollback) => TransactionTestOpResult::Invalid,
};
results.push(result);
}
results
}
#[rstest]
#[case::ssi_basic_visibility(
TransactionTestCase {
name: "ssi_basic_visibility",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_write_visibility_in_txn(
TransactionTestCase {
name: "ssi_write_visibility_in_txn",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(Some(Bytes::from("v2"))),
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_delete_visibility_in_txn(
TransactionTestCase {
name: "ssi_delete_visibility_in_txn",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnDelete(Bytes::from("k1")),
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(None),
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_rollback_visibility(
TransactionTestCase {
name: "ssi_rollback_visibility",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnRollback,
TransactionTestOp::DbGet(Bytes::from("k1")),
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
]
}
)]
#[case::si_concurrent_read_snapshot(
TransactionTestCase {
name: "si_concurrent_read_snapshot",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_write_write_conflict(
TransactionTestCase {
name: "ssi_write_write_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v3")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::ssi_read_write_conflict(
TransactionTestCase {
name: "ssi_read_write_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v1")),
TransactionTestOp::TxnPut(Bytes::from("k2"), Bytes::from("v2.1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::si_read_write_no_conflict(
TransactionTestCase {
name: "si_read_write_no_conflict",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_write_read_conflict(
TransactionTestCase {
name: "ssi_write_read_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnPut(Bytes::from("k3"), Bytes::from("v3")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::si_write_read_no_conflict(
TransactionTestCase {
name: "si_write_read_no_conflict",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_range_write_conflict(
TransactionTestCase {
name: "ssi_range_write_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![
(Bytes::from("k1"), Bytes::from("v1")),
(Bytes::from("k2"), Bytes::from("v2")),
(Bytes::from("k3"), Bytes::from("v3")),
(Bytes::from("k4"), Bytes::from("v4")),
(Bytes::from("k5"), Bytes::from("v5"))
],
operations: vec![
TransactionTestOp::TxnScan(Bytes::from("k1"), Bytes::from("k5")),
TransactionTestOp::DbPut(Bytes::from("k3"), Bytes::from("v3_new")),
TransactionTestOp::TxnPut(Bytes::from("k100"), Bytes::from("v100")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Scanned(vec![Bytes::from("k1"), Bytes::from("k2"), Bytes::from("k3"), Bytes::from("k4"), Bytes::from("k5")]),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::si_range_write_no_conflict(
TransactionTestCase {
name: "si_range_write_no_conflict",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![
(Bytes::from("k1"), Bytes::from("v1")),
(Bytes::from("k2"), Bytes::from("v2")),
(Bytes::from("k3"), Bytes::from("v3")),
(Bytes::from("k4"), Bytes::from("v4")),
(Bytes::from("k5"), Bytes::from("v5"))
],
operations: vec![
TransactionTestOp::TxnScan(Bytes::from("k1"), Bytes::from("k5")),
TransactionTestOp::DbPut(Bytes::from("k3"), Bytes::from("v3_new")),
TransactionTestOp::TxnPut(Bytes::from("k100"), Bytes::from("v100")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Scanned(vec![Bytes::from("k1"), Bytes::from("k2"), Bytes::from("k3"), Bytes::from("k4"), Bytes::from("k5")]),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_mark_read_conflict(
TransactionTestCase {
name: "ssi_mark_read_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnPut(Bytes::from("k2"), Bytes::from("v2")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::ssi_mark_read_multiple_keys_conflict(
TransactionTestCase {
name: "ssi_mark_read_multiple_keys_conflict",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![
(Bytes::from("k1"), Bytes::from("v1")),
(Bytes::from("k2"), Bytes::from("v2")),
(Bytes::from("k3"), Bytes::from("v3"))
],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::TxnMarkRead(Bytes::from("k2")),
TransactionTestOp::DbPut(Bytes::from("k2"), Bytes::from("v2_new")),
TransactionTestOp::TxnPut(Bytes::from("k4"), Bytes::from("v4")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::ssi_mark_read_no_conflict_on_different_key(
TransactionTestCase {
name: "ssi_mark_read_no_conflict_on_different_key",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![
(Bytes::from("k1"), Bytes::from("v1")),
(Bytes::from("k2"), Bytes::from("v2"))
],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k2"), Bytes::from("v2_new")),
TransactionTestOp::TxnPut(Bytes::from("k3"), Bytes::from("v3")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
]
}
)]
#[case::ssi_mark_read_conflict_without_write(
TransactionTestCase {
name: "ssi_mark_read_conflict_without_write",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::ssi_get_conflict_without_write(
TransactionTestCase {
name: "ssi_get_conflict_without_write",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnGet(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Got(Some(Bytes::from("v1"))),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::ssi_scan_conflict_without_write(
TransactionTestCase {
name: "ssi_scan_conflict_without_write",
isolation_level: IsolationLevel::SerializableSnapshot,
initial_data: vec![
(Bytes::from("k1"), Bytes::from("v1")),
(Bytes::from("k2"), Bytes::from("v2")),
(Bytes::from("k3"), Bytes::from("v3")),
],
operations: vec![
TransactionTestOp::TxnScan(Bytes::from("k1"), Bytes::from("k3")),
TransactionTestOp::DbPut(Bytes::from("k2"), Bytes::from("v2_new")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Scanned(vec![
Bytes::from("k1"),
Bytes::from("k2"),
Bytes::from("k3"),
]),
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::si_mark_read_conflict(
TransactionTestCase {
name: "si_mark_read_conflict",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::DbPut(Bytes::from("k1"), Bytes::from("v2")),
TransactionTestOp::TxnPut(Bytes::from("k2"), Bytes::from("v2")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[case::si_mark_read_write_write_conflict(
TransactionTestCase {
name: "si_mark_read_write_write_conflict",
isolation_level: IsolationLevel::Snapshot,
initial_data: vec![(Bytes::from("k1"), Bytes::from("v1"))],
operations: vec![
TransactionTestOp::TxnMarkRead(Bytes::from("k1")),
TransactionTestOp::TxnPut(Bytes::from("k2"), Bytes::from("v2")),
TransactionTestOp::DbPut(Bytes::from("k2"), Bytes::from("v2_db")),
TransactionTestOp::TxnCommit,
],
expected_results: vec![
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Empty,
TransactionTestOpResult::Conflicted,
]
}
)]
#[tokio::test]
async fn test_txn_table_driven(#[case] test_case: TransactionTestCase) {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open(test_case.name, object_store).await.unwrap();
let initial_data_bytes: Vec<(Bytes, Bytes)> = test_case.initial_data.clone();
let results = execute_transaction_test_ops(
db,
test_case.operations,
initial_data_bytes,
test_case.isolation_level,
)
.await;
for (i, (result, expected)) in results
.iter()
.zip(test_case.expected_results.iter())
.enumerate()
{
assert_eq!(
result, expected,
"Test '{}' failed at operation {}: expected {:?}, got {:?}",
test_case.name, i, expected, result
);
}
}
#[tokio::test]
async fn test_txn_scan_sees_concurrent_put_in_same_txn() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_db", object_store).await.unwrap();
db.put(b"k1", b"v1").await.unwrap();
db.put(b"k3", b"v3").await.unwrap();
let txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
{
let mut iter = txn.scan(&b"k1"[..]..=&b"k3"[..]).await.unwrap();
txn.put(b"k2", b"v2").unwrap();
txn.put(b"k3", b"v3_updated").unwrap();
let mut results = Vec::new();
while let Some(kv) = iter.next().await.unwrap() {
results.push((kv.key.clone(), kv.value.clone()));
}
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, Bytes::from_static(b"k1"));
assert_eq!(results[0].1, Bytes::from_static(b"v1"));
assert_eq!(results[1].0, Bytes::from_static(b"k3"));
assert_eq!(results[1].1, Bytes::from_static(b"v3"));
}
{
let mut iter2 = txn.scan(&b"k1"[..]..=&b"k3"[..]).await.unwrap();
let mut results2 = Vec::new();
while let Some(kv) = iter2.next().await.unwrap() {
results2.push((kv.key.clone(), kv.value.clone()));
}
assert_eq!(results2.len(), 3);
assert_eq!(results2[0].0, Bytes::from_static(b"k1"));
assert_eq!(results2[1].0, Bytes::from_static(b"k2"));
assert_eq!(results2[1].1, Bytes::from_static(b"v2"));
assert_eq!(results2[2].0, Bytes::from_static(b"k3"));
assert_eq!(results2[2].1, Bytes::from_static(b"v3_updated"));
}
txn.commit().await.unwrap();
let value = db.get(b"k2").await.unwrap();
assert_eq!(value, Some(Bytes::from_static(b"v2")));
}
#[tokio::test]
async fn test_mark_read_equivalent_to_get_in_ssi() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_mark_read_equivalent", object_store)
.await
.unwrap();
db.put(b"k1", b"v1").await.unwrap();
let txn1 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
txn1.mark_read([b"k1"]).unwrap();
db.put(b"k1", b"v2").await.unwrap();
txn1.put(b"k2", b"v2").unwrap();
let result1 = txn1.commit().await;
assert!(
result1.is_err(),
"Transaction with mark_read() should conflict"
);
db.put(b"k1", b"v1").await.unwrap();
let txn2 = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
let _ = txn2.get(b"k1").await.unwrap();
db.put(b"k1", b"v2_again").await.unwrap();
txn2.put(b"k3", b"v3").unwrap();
let result2 = txn2.commit().await;
assert!(result2.is_err(), "Transaction with get() should conflict");
}
#[tokio::test]
async fn test_mark_read_multiple_keys_at_once() {
let object_store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
let db = crate::Db::open("test_mark_read_multiple", object_store)
.await
.unwrap();
db.put(b"k1", b"v1").await.unwrap();
db.put(b"k2", b"v2").await.unwrap();
db.put(b"k3", b"v3").await.unwrap();
let txn = db
.begin(IsolationLevel::SerializableSnapshot)
.await
.unwrap();
txn.mark_read([b"k1", b"k2", b"k3"]).unwrap();
db.put(b"k2", b"v2_modified").await.unwrap();
txn.put(b"k4", b"v4").unwrap();
let result = txn.commit().await;
assert!(
result.is_err(),
"Transaction should conflict because k2 was modified"
);
}
/// Under `Snapshot` isolation, a key removed from write tracking with
/// `unmark_write` must not cause a commit failure when the same key is
/// overwritten outside the transaction. After the commit, the transaction's
/// value (`v2`) is what `db.get` returns.
#[tokio::test]
async fn test_unmark_write_ignores_write_write_conflicts() {
    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let db = crate::Db::open("test_unmark_write_ww", store)
        .await
        .unwrap();
    db.put(b"k1", b"v1").await.unwrap();

    let txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    txn.put(b"k1", b"v2").unwrap();
    txn.unmark_write([b"k1"]).unwrap();
    // Outside overwrite of the same key while the transaction is open.
    db.put(b"k1", b"v3").await.unwrap();

    let outcome = txn.commit().await;
    assert!(
        outcome.is_ok(),
        "Transaction should not conflict for untracked write key"
    );
    let stored = db.get(b"k1").await.unwrap();
    assert_eq!(stored, Some(Bytes::from_static(b"v2")));
}
/// `unmark_write` applies only to the keys it is given. Untracking `k1` leaves
/// `k2` tracked, so an outside overwrite of `k2` must still make the commit fail.
#[tokio::test]
async fn test_unmark_write_only_excludes_selected_keys() {
    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let db = crate::Db::open("test_unmark_write_partial", store)
        .await
        .unwrap();
    db.put(b"k1", b"v1").await.unwrap();
    db.put(b"k2", b"v2").await.unwrap();

    let txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    for (key, value) in [(b"k1", b"v1_txn"), (b"k2", b"v2_txn")] {
        txn.put(key, value).unwrap();
    }
    txn.unmark_write([b"k1"]).unwrap();
    // Overwrite the key that is still tracked.
    db.put(b"k2", b"v2_db").await.unwrap();

    let outcome = txn.commit().await;
    assert!(
        outcome.is_err(),
        "Transaction should still conflict on tracked key k2"
    );
}
/// A write that its transaction untracks with `unmark_write` must not cause a
/// commit failure for another transaction that read the same key. Here a
/// `SerializableSnapshot` reader reads `k1`, a `Snapshot` writer overwrites `k1`
/// (untracked) and commits, and the reader is still expected to commit.
#[tokio::test]
async fn test_unmark_write_avoids_read_write_conflicts_for_others() {
    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let db = crate::Db::open("test_unmark_write_rw", store)
        .await
        .unwrap();
    db.put(b"k1", b"v1").await.unwrap();

    // The reader records k1 in its read set before the writer starts.
    let reader = db
        .begin(IsolationLevel::SerializableSnapshot)
        .await
        .unwrap();
    reader.get(b"k1").await.unwrap();

    // The writer overwrites k1 without tracking it, then commits first.
    let writer = db.begin(IsolationLevel::Snapshot).await.unwrap();
    writer.put(b"k1", b"v2").unwrap();
    writer.unmark_write([b"k1"]).unwrap();
    writer.commit().await.unwrap();

    reader.put(b"k2", b"v2").unwrap();
    let outcome = reader.commit().await;
    assert!(
        outcome.is_ok(),
        "Reader transaction should not conflict with untracked write key"
    );
}
/// Stress test: in each round, `CONCURRENT_TXNS` `SerializableSnapshot`
/// transactions are released together by a barrier. Each one merges one
/// increment into `counter` and untracks that key with `unmark_write`.
///
/// Every commit must succeed, and the final counter must equal the total number
/// of transactions across all rounds.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_unmark_write_merge_counter_aggregates_under_high_concurrency() {
    const CONCURRENT_TXNS: usize = 32;
    const ROUNDS: usize = 20;
    const MERGE_INCREMENT: [u8; 8] = 1u64.to_le_bytes();
    const EXPECTED: u64 = (CONCURRENT_TXNS * ROUNDS) as u64;

    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let db = crate::Db::builder("test_unmark_write_merge_counter", store)
        .with_merge_operator(Arc::new(CounterMergeOperator))
        .build()
        .await
        .unwrap();

    for _ in 0..ROUNDS {
        // The barrier makes all transactions in a round begin at roughly the same time.
        let start_line = Arc::new(tokio::sync::Barrier::new(CONCURRENT_TXNS));
        let workers: Vec<_> = (0..CONCURRENT_TXNS)
            .map(|_| {
                let db = db.clone();
                let start_line = Arc::clone(&start_line);
                tokio::spawn(async move {
                    start_line.wait().await;
                    let txn = db
                        .begin(IsolationLevel::SerializableSnapshot)
                        .await
                        .unwrap();
                    txn.merge(b"counter", MERGE_INCREMENT).unwrap();
                    txn.unmark_write([b"counter"]).unwrap();
                    txn.commit().await.unwrap();
                })
            })
            .collect();
        // Join every task; any panic inside a task surfaces here.
        for worker in workers {
            worker.await.unwrap();
        }
    }

    let raw = db.get(b"counter").await.unwrap().unwrap();
    let total = u64::from_le_bytes(raw.as_ref().try_into().unwrap());
    assert_eq!(total, EXPECTED);
}
/// Builds a fully specified `Settings` value for the tests in this module.
///
/// Callers choose the filter key threshold, the L0 SST size and the compactor
/// options. Every other field is set to the fixed value listed below.
fn test_db_options(
    min_filter_keys: u32,
    l0_sst_size_bytes: usize,
    compactor_options: Option<crate::config::CompactorOptions>,
) -> crate::config::Settings {
    crate::config::Settings {
        // Caller-supplied values.
        min_filter_keys,
        l0_sst_size_bytes,
        compactor_options,
        // Fixed values shared by every caller.
        flush_interval: None,
        #[cfg(feature = "wal_disable")]
        wal_enabled: true,
        manifest_poll_interval: std::time::Duration::from_secs(3600),
        manifest_update_timeout: std::time::Duration::from_secs(300),
        max_unflushed_bytes: 134_217_728, // 128 MiB
        l0_max_ssts: 8,
        l0_flush_parallelism: 1,
        filter_bits_per_key: 10,
        compression_codec: None,
        object_store_cache_options: crate::config::ObjectStoreCacheOptions::default(),
        garbage_collector_options: None,
        default_ttl: None,
        block_format: None,
    }
}
/// `commit_with_options` on a transaction with staged writes returns a
/// `WriteHandle`. The handle's `seqnum` goes 1, 2, 3 across successive commits,
/// and its `create_ts` matches the mock clock value set just before each
/// transaction began.
///
/// The three commits cover a plain put, a put with a TTL, and a delete.
#[tokio::test]
async fn test_txn_commit_returns_write_handle() {
    use slatedb_common::clock::MockSystemClock;

    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let clock = Arc::new(MockSystemClock::new());
    let db = crate::Db::builder("/tmp/test_txn_commit_returns_write_handle", store)
        .with_settings(test_db_options(0, 1024, None))
        .with_system_clock(clock.clone())
        .build()
        .await
        .unwrap();

    // All three commits are issued with await_durable set to false.
    let no_wait = WriteOptions {
        await_durable: false,
    };

    // Commit 1: plain put at t=100.
    clock.set(100);
    let put_txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    put_txn.put(b"key1", b"value1").unwrap();
    let first = put_txn
        .commit_with_options(&no_wait)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first.seqnum(), 1);
    assert_eq!(first.create_ts(), 100);

    // Commit 2: put with an explicit TTL at t=200.
    clock.set(200);
    let ttl_txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    let put_opts = PutOptions {
        ttl: crate::config::Ttl::ExpireAfter(1000),
    };
    ttl_txn
        .put_with_options(b"key2", b"value2", &put_opts)
        .unwrap();
    let second = ttl_txn
        .commit_with_options(&no_wait)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second.seqnum(), 2);
    assert_eq!(second.create_ts(), 200);

    // Commit 3: delete at t=300.
    clock.set(300);
    let delete_txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    delete_txn.delete(b"key1").unwrap();
    let third = delete_txn
        .commit_with_options(&no_wait)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(third.seqnum(), 3);
    assert_eq!(third.create_ts(), 300);
}
/// Committing a transaction that staged no writes returns `Ok(None)`, not a
/// `WriteHandle`.
#[tokio::test]
async fn test_txn_commit_with_options_empty_batch_returns_none() {
    let store: Arc<dyn object_store::ObjectStore> = Arc::new(InMemory::new());
    let db = crate::Db::open("test_txn_commit_with_options_empty_batch", store)
        .await
        .unwrap();

    let idle_txn = db.begin(IsolationLevel::Snapshot).await.unwrap();
    let opts = WriteOptions {
        await_durable: false,
    };
    let maybe_handle = idle_txn.commit_with_options(&opts).await.unwrap();
    assert!(maybe_handle.is_none());
}
}