use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{Result, TxnError};
use crate::sync::{self, RwLock, RwLockWriteGuard};
use crate::timestamp::Timestamp;
/// A single buffered write: the key and its new value.
///
/// A value of `None` records a deletion; it is stored as a tombstone version
/// that reads back as absent.
pub type WriteEntry = (Arc<[u8]>, Option<Arc<[u8]>>);
/// Shard count used by [`MemoryStore::new`] in regular builds.
#[cfg(not(loom))]
const DEFAULT_SHARDS: usize = 16;
/// Shard count used by [`MemoryStore::new`] when compiled with `cfg(loom)`.
// NOTE(review): presumably kept small to limit the state space explored
// under loom model checking — confirm.
#[cfg(loom)]
const DEFAULT_SHARDS: usize = 2;
/// Storage backend that keeps timestamped versions of each key.
///
/// Implementations must be usable from multiple threads (`Send + Sync`).
pub trait VersionStore: Send + Sync {
/// Reads `key` as of the snapshot taken at `read_ts`.
///
/// As implemented by [`MemoryStore`], this yields the value of the newest
/// version committed at or before `read_ts`, and `Ok(None)` when no such
/// version exists or that version is a tombstone.
fn get(&self, key: &[u8], read_ts: Timestamp) -> Result<Option<Arc<[u8]>>>;
/// Attempts to commit `writes` at `commit_ts` for a transaction whose
/// snapshot was taken at `read_ts`; `reads` lists the keys it read.
///
/// As implemented by [`MemoryStore`], the commit is rejected with a
/// conflict error when any written or read key already has a version newer
/// than `read_ts`; otherwise every write is appended as a new version
/// stamped with `commit_ts`.
fn try_commit(
&self,
read_ts: Timestamp,
commit_ts: Timestamp,
writes: Vec<WriteEntry>,
reads: &[Arc<[u8]>],
) -> Result<()>;
/// Discards versions that are obsolete with respect to `low_watermark` and
/// returns how many were removed.
///
/// The default implementation removes nothing and returns `0`.
fn collect_garbage(&self, low_watermark: Timestamp) -> usize {
// Explicitly discard the argument: the default is a no-op.
let _ = low_watermark;
0
}
}
/// One committed version of a key.
#[derive(Debug, Clone)]
struct Version {
/// Timestamp at which this version was committed.
commit_ts: Timestamp,
/// Payload of this version; `None` marks a deletion (tombstone).
value: Option<Arc<[u8]>>,
}
/// Version history per key.
///
/// Lookups in this file use `partition_point` and `last` on a chain, so each
/// chain is expected to be ordered by ascending `commit_ts`.
type Chains = HashMap<Arc<[u8]>, Vec<Version>>;
/// One independently locked partition of the key space.
struct Shard {
/// Version chains of the keys that hash to this shard.
chains: RwLock<Chains>,
}
/// In-memory [`VersionStore`] that spreads keys across a fixed set of
/// individually locked shards.
pub struct MemoryStore {
/// The shards; `with_shards` always allocates a power-of-two number of them.
shards: Box<[Shard]>,
/// Shard count minus one, ANDed with a key's hash to pick its shard.
mask: usize,
}
impl Default for MemoryStore {
fn default() -> Self {
MemoryStore::new()
}
}
impl MemoryStore {
#[must_use]
pub fn new() -> Self {
MemoryStore::with_shards(DEFAULT_SHARDS)
}
#[must_use]
pub fn with_shards(shards: usize) -> Self {
let count = shards.max(1).next_power_of_two();
let shards = (0..count)
.map(|_| Shard {
chains: RwLock::new(HashMap::new()),
})
.collect::<Vec<_>>()
.into_boxed_slice();
MemoryStore {
shards,
mask: count - 1,
}
}
#[must_use]
pub fn key_count(&self) -> usize {
self.shards
.iter()
.map(|shard| sync::read(&shard.chains).len())
.sum()
}
#[inline]
fn shard_of(&self, key: &[u8]) -> usize {
(hash_key(key) as usize) & self.mask
}
#[cfg(feature = "durability")]
pub(crate) fn install_recovered(&self, commit_ts: Timestamp, writes: Vec<WriteEntry>) {
for (key, value) in writes {
let shard = self.shard_of(&key);
sync::write(&self.shards[shard].chains)
.entry(key)
.or_default()
.push(Version { commit_ts, value });
}
}
}
impl VersionStore for MemoryStore {
    /// Returns the value of the newest version of `key` committed at or
    /// before `read_ts`, or `Ok(None)` if there is none or it is a tombstone.
    fn get(&self, key: &[u8], read_ts: Timestamp) -> Result<Option<Arc<[u8]>>> {
        let idx = self.shard_of(key);
        let chains = sync::read(&self.shards[idx].chains);
        let value = visible_value(chains.get(key), read_ts);
        Ok(value)
    }

    /// Validates the write and read sets against `read_ts` and, if nothing
    /// newer exists, appends every write as a version stamped `commit_ts`.
    ///
    /// Returns a conflict error (built from the offending key's length) for
    /// the first key, writes before reads, whose latest version is newer than
    /// `read_ts`. Nothing is written in that case.
    fn try_commit(
        &self,
        read_ts: Timestamp,
        commit_ts: Timestamp,
        writes: Vec<WriteEntry>,
        reads: &[Arc<[u8]>],
    ) -> Result<()> {
        // Fast path: a lone write with no read set touches exactly one shard.
        if reads.is_empty() && writes.len() == 1 {
            if let Some((key, value)) = writes.into_iter().next() {
                let mut chains = sync::write(&self.shards[self.shard_of(&key)].chains);
                if newer_than(chains.get(key.as_ref()), read_ts) {
                    return Err(TxnError::conflict(key.len()));
                }
                chains
                    .entry(key)
                    .or_default()
                    .push(Version { commit_ts, value });
            }
            return Ok(());
        }

        let write_slots: Vec<usize> = writes.iter().map(|(key, _)| self.shard_of(key)).collect();
        let read_slots: Vec<usize> = reads.iter().map(|key| self.shard_of(key)).collect();

        // Every involved shard is locked once, in ascending index order, so
        // that concurrent commits always acquire shards in the same order.
        let mut lock_order = Vec::with_capacity(write_slots.len() + read_slots.len());
        lock_order.extend_from_slice(&write_slots);
        lock_order.extend_from_slice(&read_slots);
        lock_order.sort_unstable();
        lock_order.dedup();

        let mut guards: Vec<RwLockWriteGuard<'_, Chains>> = lock_order
            .iter()
            .map(|&idx| sync::write(&self.shards[idx].chains))
            .collect();

        // Validate written keys first, then read keys, in their given order.
        let validated = writes
            .iter()
            .map(|(key, _)| key)
            .zip(&write_slots)
            .chain(reads.iter().zip(&read_slots));
        for (key, slot) in validated {
            if let Ok(pos) = lock_order.binary_search(slot) {
                if newer_than(guards[pos].get(key.as_ref()), read_ts) {
                    return Err(TxnError::conflict(key.len()));
                }
            }
        }

        // Validation passed: install the writes while all locks are held.
        for ((key, value), slot) in writes.into_iter().zip(write_slots) {
            if let Ok(pos) = lock_order.binary_search(&slot) {
                guards[pos]
                    .entry(key)
                    .or_default()
                    .push(Version { commit_ts, value });
            }
        }
        Ok(())
    }

    /// Prunes versions made obsolete by `low_watermark` and returns how many
    /// versions were removed.
    ///
    /// For each key, only the newest version at or below the watermark is
    /// kept, together with everything above it. If all that remains is a
    /// single tombstone at or below the watermark, the key is dropped too.
    fn collect_garbage(&self, low_watermark: Timestamp) -> usize {
        let mut reclaimed = 0;
        for shard in self.shards.iter() {
            let mut chains = sync::write(&shard.chains);
            chains.retain(|_, history| {
                let at_or_below = history.partition_point(|v| v.commit_ts <= low_watermark);
                // All but the newest of the versions at or below the
                // watermark are unreachable by any snapshot from it onward.
                let stale = at_or_below.saturating_sub(1);
                if stale > 0 {
                    reclaimed += stale;
                    drop(history.drain(..stale));
                }
                let lone_dead_tombstone = match history.as_slice() {
                    [only] => only.commit_ts <= low_watermark && only.value.is_none(),
                    _ => false,
                };
                if lone_dead_tombstone {
                    reclaimed += 1;
                }
                !lone_dead_tombstone
            });
        }
        reclaimed
    }
}
/// Reports whether the most recent version in `versions` was committed
/// strictly after `read_ts`.
///
/// A missing or empty chain never counts as newer.
#[inline]
fn newer_than(versions: Option<&Vec<Version>>, read_ts: Timestamp) -> bool {
    match versions.and_then(|chain| chain.last()) {
        Some(latest) => latest.commit_ts > read_ts,
        None => false,
    }
}
/// Returns the value a snapshot at `read_ts` observes in `versions`.
///
/// This is the value of the newest version committed at or before `read_ts`.
/// The result is `None` when the chain is missing, when no version is that
/// old, or when the selected version is a tombstone.
#[inline]
fn visible_value(versions: Option<&Vec<Version>>, read_ts: Timestamp) -> Option<Arc<[u8]>> {
    let chain = versions?;
    // The chain is expected to be ordered by commit timestamp, so the
    // versions visible to the snapshot form a prefix.
    let visible_len = chain.partition_point(|v| v.commit_ts <= read_ts);
    chain[..visible_len]
        .last()
        .and_then(|newest| newest.value.clone())
}
/// Hashes `key` with 64-bit FNV-1a: starting from the offset basis, each byte
/// is XORed into the state, which is then multiplied by the FNV prime
/// (wrapping on overflow).
#[inline]
fn hash_key(key: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    key.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
#[cfg(all(test, not(loom)))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Wraps a raw integer as a `Timestamp`.
    fn ts(raw: u64) -> Timestamp {
        Timestamp::from_raw(raw)
    }

    /// Copies `raw` into a shared byte slice.
    fn bytes(raw: &[u8]) -> Arc<[u8]> {
        raw.into()
    }

    /// Write entry that sets `key` to `value`.
    fn put(key: &[u8], value: &[u8]) -> WriteEntry {
        (bytes(key), Some(bytes(value)))
    }

    /// Write entry that deletes `key`.
    fn del(key: &[u8]) -> WriteEntry {
        (bytes(key), None)
    }

    /// Commits `writes` at `raw_ts` with a snapshot taken one tick earlier and
    /// an empty read set; panics if the commit is rejected.
    fn commit_at(store: &MemoryStore, raw_ts: u64, writes: Vec<WriteEntry>) {
        store
            .try_commit(ts(raw_ts - 1), ts(raw_ts), writes, &[])
            .expect("commit");
    }

    /// Reads `key` at snapshot `raw_ts` and copies the visible value out.
    fn read_at(store: &MemoryStore, key: &[u8], raw_ts: u64) -> Option<Vec<u8>> {
        store.get(key, ts(raw_ts)).unwrap().map(|v| v.to_vec())
    }

    #[test]
    fn test_get_on_missing_key_returns_none() {
        let store = MemoryStore::new();
        assert_eq!(read_at(&store, b"absent", 10), None);
    }

    #[test]
    fn test_read_sees_only_versions_at_or_before_snapshot() {
        let store = MemoryStore::new();
        commit_at(&store, 2, vec![put(b"x", b"a")]);
        commit_at(&store, 4, vec![put(b"x", b"b")]);

        assert_eq!(read_at(&store, b"x", 1), None);
        assert_eq!(read_at(&store, b"x", 2), Some(b"a".to_vec()));
        assert_eq!(read_at(&store, b"x", 3), Some(b"a".to_vec()));
        assert_eq!(read_at(&store, b"x", 4), Some(b"b".to_vec()));
        assert_eq!(read_at(&store, b"x", 99), Some(b"b".to_vec()));
    }

    #[test]
    fn test_tombstone_reads_as_absent() {
        let store = MemoryStore::new();
        commit_at(&store, 1, vec![put(b"x", b"a")]);
        commit_at(&store, 2, vec![del(b"x")]);

        assert_eq!(read_at(&store, b"x", 1), Some(b"a".to_vec()));
        assert_eq!(read_at(&store, b"x", 2), None);
    }

    #[test]
    fn test_write_write_conflict_is_detected() {
        let store = MemoryStore::new();
        commit_at(&store, 5, vec![put(b"x", b"a")]);

        // The snapshot at 4 predates the commit at 5, so this must be rejected.
        let outcome = store.try_commit(ts(4), ts(6), vec![put(b"x", b"b")], &[]);
        assert!(matches!(outcome, Err(TxnError::Conflict { .. })));

        // The rejected write must not have been installed.
        assert_eq!(read_at(&store, b"x", 99), Some(b"a".to_vec()));
    }

    #[test]
    fn test_read_set_validation_detects_skew() {
        let store = MemoryStore::new();
        commit_at(&store, 5, vec![put(b"y", b"1")]);

        let read_set = [bytes(b"y")];
        let outcome = store.try_commit(ts(4), ts(6), vec![put(b"x", b"a")], &read_set);
        assert!(matches!(outcome, Err(TxnError::Conflict { .. })));
    }

    #[test]
    fn test_multi_shard_commit_applies_all_keys() {
        let store = MemoryStore::with_shards(8);
        let writes: Vec<WriteEntry> = (0u8..32).map(|i| put(&[i], &[i])).collect();
        commit_at(&store, 1, writes);

        for i in 0u8..32 {
            assert_eq!(read_at(&store, &[i], 1), Some(vec![i]));
        }
        assert_eq!(store.key_count(), 32);
    }

    #[test]
    fn test_with_shards_rounds_up_to_power_of_two() {
        let store = MemoryStore::with_shards(5);
        assert_eq!(store.shards.len(), 8);
        assert_eq!(store.mask, 7);
    }

    #[test]
    fn test_gc_prunes_versions_below_watermark_but_keeps_newest_visible() {
        let store = MemoryStore::new();
        commit_at(&store, 1, vec![put(b"x", b"a")]);
        commit_at(&store, 2, vec![put(b"x", b"b")]);
        commit_at(&store, 3, vec![put(b"x", b"c")]);

        assert_eq!(store.collect_garbage(ts(2)), 1);
        assert_eq!(read_at(&store, b"x", 2), Some(b"b".to_vec()));
        assert_eq!(read_at(&store, b"x", 3), Some(b"c".to_vec()));
    }

    #[test]
    fn test_gc_drops_key_whose_only_survivor_is_a_passed_tombstone() {
        let store = MemoryStore::new();
        commit_at(&store, 1, vec![put(b"x", b"a")]);
        commit_at(&store, 2, vec![del(b"x")]);

        assert_eq!(store.collect_garbage(ts(5)), 2);
        assert_eq!(store.key_count(), 0);
    }

    #[test]
    fn test_gc_keeps_everything_above_watermark() {
        let store = MemoryStore::new();
        commit_at(&store, 5, vec![put(b"x", b"a")]);
        commit_at(&store, 6, vec![put(b"x", b"b")]);

        assert_eq!(store.collect_garbage(ts(4)), 0);
        assert_eq!(read_at(&store, b"x", 5), Some(b"a".to_vec()));
    }

    #[test]
    fn test_default_trait_gc_is_noop() {
        /// Minimal store that relies on the trait's default `collect_garbage`.
        struct NoHistory;

        impl VersionStore for NoHistory {
            fn get(&self, _: &[u8], _: Timestamp) -> Result<Option<Arc<[u8]>>> {
                Ok(None)
            }

            fn try_commit(
                &self,
                _: Timestamp,
                _: Timestamp,
                _: Vec<WriteEntry>,
                _: &[Arc<[u8]>],
            ) -> Result<()> {
                Ok(())
            }
        }

        assert_eq!(NoHistory.collect_garbage(ts(100)), 0);
    }
}