use std::sync::Arc;
/// Outcome of executing a [`Transaction`].
///
/// The tests in this file expect `execute_transaction` (defined elsewhere) to return
/// `Committed(())` when all recorded operations are applied. Which conditions produce
/// `Aborted` versus `Conflict` is decided by that executor, not here.
#[derive(Debug, Clone)]
pub enum TransactionResult<T> {
/// The transaction was applied; carries the value produced by the executor.
Committed(T),
/// NOTE(review): presumably an explicit/voluntary abort — confirm against the executor.
Aborted,
/// NOTE(review): presumably a clash with a concurrent modification — confirm against the executor.
Conflict,
}
/// A single operation recorded by a [`Transaction`]; operations are stored in call order.
#[derive(Debug, Clone)]
pub enum TxnOp<K, V> {
/// Read of the given key (recorded by `Transaction::read`, which also adds the key to `read_set`).
Read(K),
/// Write of a value under a key (recorded by `Transaction::write`, which also adds the key to `write_set`).
Write(K, V),
/// Removal of a key (recorded by `Transaction::remove`, which also adds the key to `write_set`).
Remove(K),
}
/// A recorded sequence of key/value operations.
///
/// This type only records operations (via `read`, `write` and `remove`); applying them
/// is done by an executor outside this file (the tests use `execute_transaction`).
pub struct Transaction<K, V> {
/// Every recorded operation, in call order.
pub(crate) ops: Vec<TxnOp<K, V>>,
/// Keys passed to `read`, in call order; duplicates are kept.
pub(crate) read_set: Vec<K>,
/// Keys passed to `write` or `remove`, in call order; duplicates are kept.
pub(crate) write_set: Vec<K>,
/// Version number; initialised to `0` by `new`.
// No visible code reads this field, hence the `dead_code` allowance.
// NOTE(review): presumably reserved for version checks by the executor — confirm or remove.
#[allow(dead_code)]
pub(crate) version: u64,
}
/// Constructor and inspection helpers for [`Transaction`].
///
/// None of these touch a key or value, so this impl places no bounds on `K` or `V`.
impl<K, V> Transaction<K, V> {
    /// Creates an empty transaction: no recorded operations, empty read and write sets,
    /// and `version` set to `0`.
    #[tracing::instrument(level = "trace")]
    pub fn new() -> Self {
        Self {
            ops: Vec::new(),
            read_set: Vec::new(),
            write_set: Vec::new(),
            version: 0,
        }
    }

    /// Returns the number of recorded operations (reads, writes and removes combined).
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` if no operation has been recorded yet.
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }
}

/// Operation-recording methods for [`Transaction`].
///
/// Each method stores the key twice: once inside the `TxnOp` and once in the read or
/// write set. That requires `K: Clone`. Values are moved into the op and never cloned,
/// so `V` needs no bound.
impl<K: Clone, V> Transaction<K, V> {
    /// Records a read of `key` and adds it to the read set.
    #[tracing::instrument(skip(self, key), level = "trace")]
    pub fn read(&mut self, key: K) {
        self.read_set.push(key.clone());
        self.ops.push(TxnOp::Read(key));
    }

    /// Records a write of `value` under `key` and adds `key` to the write set.
    #[tracing::instrument(skip(self, key, value), level = "trace")]
    pub fn write(&mut self, key: K, value: V) {
        self.write_set.push(key.clone());
        self.ops.push(TxnOp::Write(key, value));
    }

    /// Records a removal of `key` and adds it to the write set.
    #[tracing::instrument(skip(self, key), level = "trace")]
    pub fn remove(&mut self, key: K) {
        self.write_set.push(key.clone());
        self.ops.push(TxnOp::Remove(key));
    }
}
impl<K: Clone, V: Clone> Default for Transaction<K, V> {
#[tracing::instrument(level = "trace")]
fn default() -> Self {
Self::new()
}
}
/// Outcome of a compare-and-swap style operation (judging by the name).
///
/// Both variants carry a value of type `V`. Use `is_success` to tell them apart and
/// `into_value` to extract the value from either one.
/// NOTE(review): what the carried value represents (new value, previous value, or the
/// current value on failure) is defined by the producing operation, not in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CasResult<V> {
/// The operation reported success.
Success(V),
/// The operation reported failure.
Failure(V),
}
impl<V> CasResult<V> {
    /// Returns `true` for [`CasResult::Success`] and `false` for [`CasResult::Failure`].
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn is_success(&self) -> bool {
        match self {
            CasResult::Success(_) => true,
            CasResult::Failure(_) => false,
        }
    }

    /// Consumes the result and returns the carried value, whichever variant holds it.
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn into_value(self) -> V {
        // Both variants bind the same name, so this or-pattern is irrefutable.
        let (CasResult::Success(inner) | CasResult::Failure(inner)) = self;
        inner
    }
}
/// A versioned list of key/value pairs held behind a `std::sync::Arc`.
///
/// The methods visible in this file only read the data.
/// NOTE(review): the name suggests copy-on-write semantics, but no write path is
/// visible here — confirm where (if anywhere) the shared vector is modified.
pub struct CowSnapshot<K, V> {
/// The pairs, reference-counted via `Arc`.
pub(crate) data: std::sync::Arc<Vec<(K, V)>>,
/// Version number supplied to `new`.
pub(crate) version: u64,
}
impl<K: Clone, V: Clone> CowSnapshot<K, V> {
#[tracing::instrument(skip(data), level = "trace")]
pub fn new(data: Vec<(K, V)>, version: u64) -> Self {
Self {
data: std::sync::Arc::new(data),
version,
}
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn version(&self) -> u64 {
self.version
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn iter(&self) -> impl Iterator<Item = &(K, V)> {
self.data.iter()
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn len(&self) -> usize {
self.data.len()
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
/// Error type returned by every [`Replica`] method.
///
/// The variant descriptions below follow the variant names. The code that constructs
/// these errors is not in this file.
#[derive(Debug, Clone)]
pub enum ReplicaError {
/// Connecting to the replica failed.
ConnectionFailed,
/// The replica did not respond in time.
Timeout,
/// The replica refused the request; the string carries a reason.
Rejected(String),
/// Too few replicas succeeded to satisfy the required quorum.
QuorumFailed,
}
/// A mutation sent to a replica through [`Replica::replicate`].
#[derive(Debug, Clone)]
pub enum ReplicationOp<K, V> {
/// Store `value` under `key`.
Insert {
key: K,
value: V,
},
/// Delete `key`.
Remove {
key: K,
},
/// NOTE(review): presumably removes every entry on the replica — confirm with implementors.
Clear,
}
/// A replica that accepts replicated mutations and can report its stored pairs.
///
/// Implementors must be `Send + Sync`, and `K`/`V` must be `Send`. `#[async_trait]`
/// turns each `async fn` here into a method returning a boxed future.
#[async_trait::async_trait]
pub trait Replica<K, V>: Send + Sync
where
K: Send,
V: Send,
{
/// Sends `op` to this replica; `Ok(())` means the replica accepted it.
async fn replicate(&self, op: ReplicationOp<K, V>) -> Result<(), ReplicaError>;
/// Returns the replica's key/value pairs.
/// NOTE(review): presumably the complete state — confirm with implementors.
async fn fetch_state(&self) -> Result<Vec<(K, V)>, ReplicaError>;
/// Reports whether the replica is healthy.
///
/// The default implementation performs no check and always returns `Ok(true)`.
/// Implementors that can detect failure should override it.
#[tracing::instrument(skip(self), level = "trace")]
async fn health_check(&self) -> Result<bool, ReplicaError> {
Ok(true)
}
}
/// Lock statistics for one shard.
///
/// The derived `Default` sets every field to zero. Field meanings are inferred from
/// their names; the code that fills them in is not in this file.
#[derive(Debug, Clone, Default)]
pub struct LockProfile {
/// Index of the shard these figures describe.
pub shard_id: usize,
/// Count of contended lock acquisitions (presumably).
pub contention_count: u64,
/// Average lock wait; nanoseconds per the `_ns` suffix.
pub avg_wait_time_ns: u64,
/// Longest observed lock wait; nanoseconds per the `_ns` suffix.
pub max_wait_time_ns: u64,
/// Number of read operations (presumably read-lock acquisitions).
pub reads: u64,
/// Number of write operations (presumably write-lock acquisitions).
pub writes: u64,
}
/// A versioned copy of key/value pairs that records the instant it was created.
///
/// Cloning a snapshot clones the `Arc` around `data`, not the pairs themselves.
#[derive(Debug, Clone)]
pub struct IsolatedSnapshot<K, V> {
/// Version number supplied to `new`.
pub version: u64,
/// Creation time, taken from `Instant::now()` in `new`; `age` measures from here.
pub timestamp: std::time::Instant,
/// The pairs, shared via `Arc`.
pub(crate) data: Arc<Vec<(K, V)>>,
}
impl<K, V> IsolatedSnapshot<K, V> {
#[tracing::instrument(skip(data), level = "trace")]
pub fn new(version: u64, data: Vec<(K, V)>) -> Self {
Self {
version,
timestamp: std::time::Instant::now(),
data: Arc::new(data),
}
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn version(&self) -> u64 {
self.version
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn age(&self) -> std::time::Duration {
std::time::Instant::now().saturating_duration_since(self.timestamp)
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn len(&self) -> usize {
self.data.len()
}
#[tracing::instrument(skip(self), level = "trace")]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
/// Replica count and quorum sizes for replicated operations.
///
/// The presets `QuorumConfig::strict` (every replica) and `QuorumConfig::majority`
/// (`replica_count / 2 + 1`) both use a 5-second timeout. Hand-built values can be
/// checked with `is_valid`.
#[derive(Debug, Clone)]
pub struct QuorumConfig {
/// Total number of replicas.
pub replica_count: usize,
/// Number of replicas in a write quorum.
pub write_quorum: usize,
/// Number of replicas in a read quorum.
pub read_quorum: usize,
/// Time limit for a quorum operation (how it is enforced is not shown in this file).
pub timeout: std::time::Duration,
}
impl QuorumConfig {
    /// Timeout used by the [`strict`](Self::strict) and [`majority`](Self::majority) presets.
    const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    /// Builds a configuration in which every replica takes part in both reads and writes.
    ///
    /// With `replica_count == 0` both quorums are zero, which [`is_valid`](Self::is_valid)
    /// rejects.
    #[tracing::instrument(level = "trace")]
    pub fn strict(replica_count: usize) -> Self {
        Self {
            replica_count,
            write_quorum: replica_count,
            read_quorum: replica_count,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Builds a configuration whose read and write quorums are a simple majority,
    /// `replica_count / 2 + 1`.
    ///
    /// With `replica_count == 0` the quorum is 1, which exceeds the replica count and
    /// is rejected by [`is_valid`](Self::is_valid).
    #[tracing::instrument(level = "trace")]
    pub fn majority(replica_count: usize) -> Self {
        let quorum = replica_count / 2 + 1;
        Self {
            replica_count,
            write_quorum: quorum,
            read_quorum: quorum,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Returns `true` if the quorum sizes provide overlapping quorums.
    ///
    /// With `N = replica_count`, `W = write_quorum` and `R = read_quorum`, all of the
    /// following must hold:
    ///
    /// * `W <= N` and `R <= N`: a quorum cannot exceed the number of replicas.
    /// * `R > 0`: a read must consult at least one replica.
    /// * `2W > N`: any two write quorums share a replica, so conflicting writes cannot
    ///   both succeed.
    /// * `R + W > N`: every read quorum shares a replica with every write quorum, so a
    ///   read reaches at least one replica that acknowledged the latest write.
    ///
    /// Both [`strict`](Self::strict) and [`majority`](Self::majority) satisfy these rules
    /// for any `replica_count > 0`.
    #[tracing::instrument(skip(self), level = "trace")]
    pub fn is_valid(&self) -> bool {
        let n = self.replica_count;
        let (w, r) = (self.write_quorum, self.read_quorum);
        // `w > n / 2` with integer division is exactly `2 * w > n`.
        // The last check is written as `r > n - w` rather than `r + w > n`. The
        // subtraction cannot underflow because `w <= n` is checked first, and the sum
        // could overflow for very large counts.
        w <= n && r <= n && r > 0 && w > n / 2 && r > n - w
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "async")]
    use crate::AsyncShardedHashMap;
    use crate::ShardedHashMap;

    #[test]
    fn transaction_creation() {
        let txn: Transaction<String, i32> = Transaction::new();
        assert!(txn.is_empty());
        assert_eq!(txn.len(), 0);
    }

    #[test]
    fn transaction_operations() {
        let mut txn: Transaction<String, i32> = Transaction::new();
        txn.read("a".into());
        txn.write("b".into(), 10);
        txn.remove("c".into());
        assert_eq!(txn.len(), 3);
        assert!(!txn.is_empty());
        // Reads land in the read set; writes and removes both land in the write set.
        assert_eq!(txn.read_set, vec!["a".to_string()]);
        assert_eq!(txn.write_set, vec!["b".to_string(), "c".to_string()]);
        // Operations are kept in call order.
        assert!(matches!(&txn.ops[0], TxnOp::Read(k) if k == "a"));
        assert!(matches!(&txn.ops[1], TxnOp::Write(k, v) if k == "b" && *v == 10));
        assert!(matches!(&txn.ops[2], TxnOp::Remove(k) if k == "c"));
    }

    #[test]
    fn cas_result_is_success() {
        let success: CasResult<i32> = CasResult::Success(42);
        assert!(success.is_success());
        let failure: CasResult<i32> = CasResult::Failure(42);
        assert!(!failure.is_success());
    }

    #[test]
    fn cas_result_into_value() {
        assert_eq!(CasResult::Success(7).into_value(), 7);
        assert_eq!(CasResult::Failure(9).into_value(), 9);
    }

    #[test]
    fn cow_snapshot_creation() {
        let data = vec![("a", 1), ("b", 2)];
        let snap = CowSnapshot::new(data, 1);
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.version(), 1);
        assert!(!snap.is_empty());
        let pairs: Vec<_> = snap.iter().copied().collect();
        assert_eq!(pairs, vec![("a", 1), ("b", 2)]);
    }

    #[test]
    fn cow_snapshot_empty() {
        let snap: CowSnapshot<&str, i32> = CowSnapshot::new(Vec::new(), 0);
        assert!(snap.is_empty());
        assert_eq!(snap.len(), 0);
    }

    #[test]
    fn quorum_config_majority() {
        let config = QuorumConfig::majority(5);
        assert!(config.is_valid());
        assert_eq!(config.write_quorum, 3);
    }

    #[test]
    fn quorum_config_strict() {
        let config = QuorumConfig::strict(3);
        assert!(config.is_valid());
        assert_eq!(config.write_quorum, 3);
        assert_eq!(config.read_quorum, 3);
    }

    #[test]
    fn quorum_config_rejects_invalid() {
        // No replicas at all.
        assert!(!QuorumConfig::majority(0).is_valid());
        assert!(!QuorumConfig::strict(0).is_valid());
        let base = QuorumConfig::majority(3);
        // Write quorum larger than the replica count.
        let oversized = QuorumConfig {
            write_quorum: 4,
            ..base.clone()
        };
        assert!(!oversized.is_valid());
        // Reads that consult no replica.
        let no_reads = QuorumConfig {
            read_quorum: 0,
            ..base.clone()
        };
        assert!(!no_reads.is_valid());
        // Minority writes: two disjoint write quorums could both succeed.
        let minority_writes = QuorumConfig {
            write_quorum: 1,
            read_quorum: 3,
            ..base
        };
        assert!(!minority_writes.is_valid());
    }

    #[test]
    fn isolated_snapshot() {
        let data = vec![("a", 1), ("b", 2)];
        let snap = IsolatedSnapshot::new(1, data);
        assert_eq!(snap.version(), 1);
        assert_eq!(snap.len(), 2);
        assert!(!snap.is_empty());
        // `Duration` is never negative, so checking `age() >= ZERO` proves nothing.
        // `Instant` is monotonic, so a later reading of the age can never be smaller.
        let first = snap.age();
        let second = snap.age();
        assert!(second >= first);
    }

    #[test]
    fn test_sync_transaction_execution() {
        let map: ShardedHashMap<String, i32> = ShardedHashMap::new(8);
        map.insert("a".into(), 1);
        map.insert("b".into(), 2);
        let mut txn = Transaction::new();
        txn.write("a".into(), 10);
        txn.write("c".into(), 30);
        txn.remove("b".into());
        let result = map.execute_transaction(txn);
        assert!(matches!(result, TransactionResult::Committed(())));
        assert_eq!(map.get(&"a".into()), Some(10));
        assert_eq!(map.get(&"b".into()), None);
        assert_eq!(map.get(&"c".into()), Some(30));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_transaction_execution() {
        let map: AsyncShardedHashMap<String, i32> = AsyncShardedHashMap::new(8);
        map.insert("a".into(), 1).await;
        map.insert("b".into(), 2).await;
        let mut txn = Transaction::new();
        txn.write("a".into(), 10);
        txn.write("c".into(), 30);
        txn.remove("b".into());
        let result = map.execute_transaction(txn).await;
        assert!(matches!(result, TransactionResult::Committed(())));
        assert_eq!(map.get(&"a".into()).await, Some(10));
        assert_eq!(map.get(&"b".into()).await, None);
        assert_eq!(map.get(&"c".into()).await, Some(30));
    }
}