use std::{
collections::Bound,
hash::RandomState,
mem::transmute,
sync::{
Arc,
atomic::{
AtomicBool,
AtomicU64,
Ordering::Relaxed,
},
},
thread,
};
use crate::bloom::{
Bloom2,
BloomFilterBuilder,
CompressedBitmap,
FilterSize::KeyBytes3,
};
use bytes::Bytes;
use crossbeam_channel::{
Sender,
bounded,
unbounded,
};
use crossbeam_skiplist::{
SkipMap,
map::{
Entry,
Range,
},
};
use gxhash::gxhash64;
use parking_lot::Mutex;
use rand::random;
use rayon::prelude::*;
use tracing::instrument;
use crate::{
errs::{
MemtableError,
MemtableError::{
DataExceedsMaximum,
MemtableIsFrozen,
},
},
keypair::{
KeyBytes,
ValueBytes,
map_key_bound,
},
peek::Peekable,
stats::STATS,
utils::{
Deserializer,
Serializer,
},
};
/// Default memtable byte budget: 64 MiB (`64 * 1024 * 1024`).
pub const DEFAULT_MEMTABLE_SIZE_IN_BYTES: u64 = 64 * 1024 * 1024;
/// In-memory write buffer backed by a [`SkipMap`].
///
/// Every accepted write stores two map entries: the serialized key mapped to
/// the serialized value, and a "pointer" key mapped to the serialized key that
/// was written most recently. Pointer keys are also forwarded over a channel
/// to a background thread that records their hashes in a bloom filter.
#[derive(Debug)]
pub struct Memtable {
    /// Identifier supplied to `Memtable::new` and returned by `Memtable::id`.
    id: u64,
    /// Random seed generated in `new`; passed to `gxhash64` when hashing keys
    /// for the bloom filter.
    gx_seed: Arc<i64>,
    /// Sender used by `put_batch` to hand pointer keys to the bloom-filter thread.
    tx: Sender<Bytes>,
    /// Bloom filter populated asynchronously with hashes of pointer keys.
    bloom: Arc<Mutex<Bloom2<RandomState, CompressedBitmap, u64>>>,
    /// Ordered storage for both key -> value and pointer key -> key entries.
    map: Arc<SkipMap<Bytes, Bytes>>,
    /// Sum of the estimated payload sizes accepted so far.
    size: AtomicU64,
    /// Byte budget; a write that would push `size` past it is rejected.
    max_size: AtomicU64,
    /// Number of key/value pairs accepted so far.
    entry_count: AtomicU64,
    /// Cumulative payload bytes; used to derive the average entry size.
    total_bytes_written: AtomicU64,
    /// Entry budget; starts from an estimate and is re-derived in `put_batch`
    /// from the observed average entry size.
    max_entries: AtomicU64,
    /// Set by `freeze` and on drop; once set, `put_batch` rejects writes.
    frozen: Arc<AtomicBool>,
}
impl Memtable {
/// Creates an empty, unfrozen memtable.
///
/// * `id` - identifier later returned by `id()`.
/// * `max_size` - byte budget enforced by `put_batch`.
///
/// A background thread is spawned that receives pointer keys from
/// `put_batch` and inserts their `gxhash64` hashes into the bloom filter in
/// batches. The thread keeps draining the channel until every sender has been
/// dropped, so keys that were already accepted are never lost from the filter
/// when the memtable is frozen or dropped while the queue is non-empty.
///
/// # Panics
///
/// Panics if the operating system fails to spawn the thread
/// (see `std::thread::spawn`).
pub fn new(id: u64, max_size: u64) -> Self {
    // Upper bound on the number of keys inserted per bloom-filter lock acquisition.
    const BLOOM_BATCH: usize = 1000;
    // Assumed bytes per entry, used only for the initial entry budget.
    const ESTIMATED_ENTRY_BYTES: f64 = 1536.0;

    let frozen = Arc::new(AtomicBool::new(false));
    let (tx, rx) = unbounded::<Bytes>();
    let gx_seed: Arc<i64> = Arc::new(random());
    let bloom = Arc::new(Mutex::new(
        BloomFilterBuilder::default().size(KeyBytes3).build(),
    ));
    let bloom_clone = bloom.clone();
    let seed_clone = gx_seed.clone();

    thread::spawn(move || {
        let mut batch = Vec::with_capacity(BLOOM_BATCH);
        // `recv` only fails once the channel is empty AND disconnected, so
        // every key queued by `put_batch` is inserted before the thread exits.
        // Gating this loop on the `frozen` flag would drop queued keys and
        // produce bloom-filter false negatives.
        while let Ok(key) = rx.recv() {
            batch.push(key);
            // Opportunistically grab whatever else is already queued so the
            // lock is taken once per batch instead of once per key.
            batch.extend(rx.try_iter().take(BLOOM_BATCH - 1));
            {
                let mut bloom = bloom_clone.lock();
                for key_ptr in batch.drain(..) {
                    bloom.insert(&gxhash64(&key_ptr, *seed_clone));
                }
            }
        }
        STATS.current_threads.fetch_sub(1, Relaxed);
    });
    STATS.current_threads.fetch_add(1, Relaxed);

    // Reserve half of the byte budget for entries of the estimated size.
    // Clamp to one entry: without the clamp any `max_size` below 3072 bytes
    // truncates to zero and the memtable would reject every write, even ones
    // that fit the byte budget.
    let initial_max_entries = (((max_size as f64 / ESTIMATED_ENTRY_BYTES) * 0.5) as u64).max(1);

    Memtable {
        id,
        gx_seed,
        tx,
        bloom,
        map: Arc::new(SkipMap::new()),
        size: AtomicU64::new(0),
        max_size: AtomicU64::new(max_size),
        entry_count: AtomicU64::new(0),
        total_bytes_written: AtomicU64::new(0),
        max_entries: AtomicU64::new(initial_max_entries),
        frozen,
    }
}
/// Returns the identifier this memtable was constructed with.
pub fn id(&self) -> u64 {
self.id
}
/// Returns the current value of the `size` counter, i.e. the sum of the
/// payload sizes accumulated by `put_batch` (relaxed atomic load).
#[inline]
pub fn size(&self) -> u64 {
self.size.load(Relaxed)
}
/// Looks up the value most recently written for `key`.
///
/// The lookup is two-step: the pointer key produced by
/// `serialize_for_latest` is resolved to the serialized key stored under it,
/// and that key is then used to fetch and deserialize the value. Returns
/// `None` when either step finds no entry.
#[instrument(level = "debug")]
#[inline]
pub fn get(&self, key: &KeyBytes) -> Option<ValueBytes> {
    let pointer_key = key.serialize_for_latest();
    let pointer_entry = self.map.get(&pointer_key)?;
    let value_entry = self.map.get(pointer_entry.value())?;
    Some(ValueBytes::deserialize(value_entry.value().clone()))
}
/// Writes a single key/value pair by delegating to `put_batch`.
///
/// # Errors
///
/// Propagates whatever error `put_batch` returns for the one-element batch.
#[instrument(level = "debug")]
#[inline]
pub fn put(&self, key: KeyBytes, val: ValueBytes) -> Result<(), MemtableError> {
    let single = [(key, val)];
    self.put_batch(&single).map(|_written| ())
}
/// Writes the pairs in `data` in order until the batch is exhausted or a
/// limit is reached, returning how many pairs were stored.
///
/// For each accepted pair two map entries are inserted (serialized key ->
/// serialized value, pointer key -> serialized key), the size and entry
/// counters are advanced, and the pointer key is sent to the bloom-filter
/// channel (send failures are ignored).
///
/// # Errors
///
/// * `MemtableIsFrozen` - the frozen flag was set when the call started.
/// * `DataExceedsMaximum` - the entry or byte budget rejected the very first
///   pair, so nothing was written. If a budget is hit after at least one
///   pair was stored, the call returns `Ok` with the partial count instead.
#[instrument(level = "debug")]
#[inline]
pub fn put_batch(&self, data: &[(KeyBytes, ValueBytes)]) -> Result<usize, MemtableError> {
    if self.frozen.load(Relaxed) {
        return Err(MemtableIsFrozen);
    }

    // Both budgets are sampled once per call.
    let entry_budget = self.max_entries.load(Relaxed);
    let byte_budget = self.max_size.load(Relaxed);

    let mut accepted: usize = 0;
    let mut budget_hit = false;

    for (key, val) in data {
        let entries_before = self.entry_count.load(Relaxed);
        if entries_before >= entry_budget {
            budget_hit = true;
            break;
        }

        let stored_key = key.serialize();
        let pointer_key = key.serialize_for_latest();
        let stored_val = val.serialize();

        // Estimated footprint: three times the key length, the value, and a u128.
        let payload_size =
            ((stored_key.len() * 3) + stored_val.len() + size_of::<u128>()) as u64;
        if payload_size + self.size.load(Relaxed) > byte_budget {
            budget_hit = true;
            break;
        }

        self.map.insert(stored_key.clone(), stored_val);
        self.map.insert(pointer_key.clone(), stored_key);
        self.size.fetch_add(payload_size, Relaxed);
        self.total_bytes_written.fetch_add(payload_size, Relaxed);
        self.entry_count.fetch_add(1, Relaxed);

        // Every 8192 entries, re-derive the entry budget from the observed
        // average payload size, again targeting half of the byte budget.
        let at_checkpoint = entries_before > 0 && entries_before % 8192 == 0;
        if at_checkpoint {
            let total_bytes = self.total_bytes_written.load(Relaxed);
            let avg_entry_size = total_bytes / entries_before;
            let recomputed = ((byte_budget as f64 * 0.5) / avg_entry_size as f64) as u64;
            self.max_entries.store(recomputed, Relaxed);
        }

        let _ = self.tx.send(pointer_key);
        accepted += 1;
    }

    if budget_hit && accepted == 0 {
        Err(DataExceedsMaximum)
    } else {
        Ok(accepted)
    }
}
#[instrument(level = "debug")]
#[inline]
pub fn scan(&self, lower: Bound<KeyBytes>, upper: Bound<KeyBytes>) -> MemtableIterator {
let (_lower, _upper) = (map_key_bound(lower), map_key_bound(upper));
let map_clone = self.map.clone();
let ranger = map_clone.range((_lower, _upper));
let range = unsafe { transmute(ranger) };
MemtableIterator::new(map_clone, range)
}
/// Sets the frozen flag (relaxed atomic store). `put_batch` checks this flag
/// on entry and returns `MemtableIsFrozen` once it is set.
pub fn freeze(&self) {
self.frozen.store(true, Relaxed);
}
/// Returns the current value of the frozen flag (relaxed atomic load).
pub fn is_frozen(&self) -> bool {
self.frozen.load(Relaxed)
}
/// Returns `true` when `get` finds a value for `key`.
pub fn contains(&self, key: &KeyBytes) -> bool {
    let found = self.get(key);
    found.is_some()
}
}
// Dropping a memtable marks it as frozen.
impl Drop for Memtable {
fn drop(&mut self) {
// `frozen` is an `Arc<AtomicBool>`, so any other holder of a clone of the
// flag observes this store.
self.frozen.store(true, Relaxed);
}
}
/// Iterator returned by `Memtable::scan`.
///
/// Field order is load-bearing: Rust drops struct fields in declaration
/// order, so `inner` (which borrows from the map) is dropped before `_map`
/// (which keeps the map alive).
#[derive(Debug)]
pub struct MemtableIterator {
// Range over the skip map; its `'static` lifetime is produced by the
// `transmute` in `Memtable::scan` and is only valid while `_map` is alive.
inner: Range<'static, Bytes, (Bound<Bytes>, Bound<Bytes>), Bytes, Bytes>,
// Shared handle to the map that `inner` points into; never read directly.
_map: Arc<SkipMap<Bytes, Bytes>>,
}
impl MemtableIterator {
    /// Bundles a range with the `Arc` that keeps the ranged-over map alive.
    ///
    /// `inner` must have been created from the map that `map` points to.
    #[instrument(level = "trace")]
    fn new(
        map: Arc<SkipMap<Bytes, Bytes>>,
        inner: Range<'static, Bytes, (Bound<Bytes>, Bound<Bytes>), Bytes, Bytes>,
    ) -> Self {
        Self { inner, _map: map }
    }

    /// Wraps this iterator in the crate's `Peekable` adapter.
    #[instrument(level = "trace")]
    fn peekable(self) -> Peekable<Self> {
        let source = self;
        Peekable::new(source)
    }
}
impl Iterator for MemtableIterator {
    type Item = (KeyBytes, ValueBytes);

    /// Returns the next key/value pair in range, skipping every entry whose
    /// deserialized key reports `is_pointer_key()`.
    #[instrument(level = "trace")]
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.find_map(|entry| {
            let key = KeyBytes::deserialize(entry.key().clone());
            if key.is_pointer_key() {
                // Pointer entries are internal bookkeeping, not user data.
                return None;
            }
            let value = ValueBytes::deserialize(entry.value().clone());
            Some((key, value))
        })
    }

    /// Reports `(0, inner_upper)`.
    ///
    /// `next` filters out pointer keys, so any number of the remaining inner
    /// entries may be skipped. Forwarding the inner lower bound could
    /// over-promise and violate the `Iterator::size_hint` contract; only the
    /// upper bound carries over.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (_, upper) = self.inner.size_hint();
        (0, upper)
    }
}
#[cfg(test)]
mod tests {
#[cfg(not(loom))]
use std::{
sync::Arc,
thread,
};
use bytes::Bytes;
#[cfg(loom)]
use loom::{
sync::{
Arc,
atomic::{
AtomicBool,
AtomicU64,
Ordering,
},
},
thread,
};
use rand::{
Rng,
RngCore,
};
use crate::{
hlc::{
HLC,
HybridLogicalClock,
},
keypair::{
DEFAULT_NS,
KeyBytes,
ValueBytes,
},
memtable::{
DEFAULT_MEMTABLE_SIZE_IN_BYTES,
Memtable,
},
};
/// A value stored with `put` is returned unchanged by `get` for the same key.
#[test]
fn test_memtable_basic() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    let original_key = KeyBytes::new(DEFAULT_NS, Bytes::from("test"), clock.time());
    let original_val = ValueBytes::new(DEFAULT_NS, Bytes::from("value"));

    let stored = memtable.put(original_key.clone(), original_val.clone());
    assert!(stored.is_ok());

    let fetched = memtable.get(&original_key);
    assert!(fetched.is_some());
    assert_eq!(Some(original_val), fetched);
}
/// After writing many versions of one key in a single batch, `get` returns
/// the value of the version written last.
#[test]
fn test_memtable_versioning() {
    const VERSIONS: usize = 1_000;

    let memtable = Memtable::new(0, 2 << 23);
    let clock = HybridLogicalClock::new();
    let mut rng = rand::rng();
    let ns = rng.random();
    let key = Bytes::from("test-key");

    // Each version stores its own index as the little-endian payload.
    let batch: Vec<(KeyBytes, ValueBytes)> = (0..VERSIONS)
        .map(|i| {
            let versioned_key = KeyBytes::new(ns, key.clone(), clock.time());
            let payload = ValueBytes::new(ns, Bytes::copy_from_slice(&i.to_le_bytes()));
            (versioned_key, payload)
        })
        .collect();
    assert!(memtable.put_batch(&batch).is_ok());

    let val = memtable.get(&KeyBytes::new(ns, key.clone(), 0));
    assert!(val.is_some());

    let mut raw = [0_u8; 8];
    raw.copy_from_slice(&val.unwrap().value.as_ref()[0..8]);
    assert_eq!(usize::from_le_bytes(raw), VERSIONS - 1);
}
/// A single pair whose value alone is as large as the byte budget is rejected.
#[test]
fn test_exceeds_max_size() {
    const MAX_SIZE: u64 = 2 << 6;

    let memtable = Memtable::new(0, MAX_SIZE);
    let clock = HybridLogicalClock::new();
    let mut rng = rand::rng();

    let mut oversized = [0_u8; MAX_SIZE as usize];
    rng.fill_bytes(&mut oversized);

    let key = KeyBytes::new(DEFAULT_NS, Bytes::from("test-key"), clock.time());
    let val = ValueBytes::new(DEFAULT_NS, Bytes::copy_from_slice(&oversized));

    let outcome = memtable.put(key, val);
    assert!(
        outcome.is_err(),
        "there must be an error inserting a key pair larger than the max configured size"
    );
}
/// Writes are rejected once the memtable has been frozen.
#[test]
fn test_frozen() {
    const MAX_SIZE: u64 = 2 << 6;

    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    memtable.freeze();

    let clock = HybridLogicalClock::new();
    let mut rng = rand::rng();
    let mut payload = [0_u8; MAX_SIZE as usize];
    rng.fill_bytes(&mut payload);

    let key = KeyBytes::new(DEFAULT_NS, Bytes::from("test-key"), clock.time());
    let val = ValueBytes::new(DEFAULT_NS, Bytes::copy_from_slice(&payload));

    let outcome = memtable.put(key, val);
    assert!(
        outcome.is_err(),
        "there must be an error inserting a key pair while the memtable is frozen"
    );
}
/// Looking up a key that was never written yields `None`.
#[test]
fn test_get_nonexistent_key() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let missing = KeyBytes::new(DEFAULT_NS, Bytes::from("nonexistent"), 0);

    let lookup = memtable.get(&missing);
    assert!(
        lookup.is_none(),
        "get on nonexistent key should return None"
    );
}
/// The size counter starts at zero and grows after a successful write.
#[test]
fn test_size_tracking() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();

    let size_before = memtable.size();
    assert_eq!(size_before, 0, "initial size should be 0");

    let key = KeyBytes::new(DEFAULT_NS, Bytes::from("key"), clock.time());
    let val = ValueBytes::new(DEFAULT_NS, Bytes::from("value"));
    assert!(memtable.put(key, val).is_ok());

    let size_after = memtable.size();
    assert!(size_after > 0, "size should increase after put");
    assert!(
        size_after > size_before,
        "size should be greater than initial"
    );
}
/// An unbounded scan over a memtable that has received no writes yields nothing.
#[test]
fn test_scan_empty_memtable() {
    use std::collections::Bound;

    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    // The previously constructed but never used `key` local was removed; the
    // scan below is unbounded and needs no key.
    let mut iter = memtable.scan(Bound::Unbounded, Bound::Unbounded);
    assert!(
        iter.next().is_none(),
        "scan on empty memtable should return no items"
    );
}
/// An unbounded scan returns something after a single write.
#[test]
fn test_scan_single_key() {
    use std::collections::Bound;

    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    let key = KeyBytes::new(DEFAULT_NS, Bytes::from("key"), clock.time());
    let val = ValueBytes::new(DEFAULT_NS, Bytes::from("value"));
    assert!(memtable.put(key, val).is_ok());

    let items: Vec<_> = memtable
        .scan(Bound::Unbounded, Bound::Unbounded)
        .collect();
    assert!(!items.is_empty(), "scan should return at least one item");
}
/// A bounded scan over `key-03 .. key-07` finds at least one of the ten
/// written keys.
#[test]
fn test_scan_with_bounds() {
    use std::collections::Bound;

    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();

    for i in 0..10 {
        let name = Bytes::from(format!("key-{:02}", i));
        let key = KeyBytes::new(DEFAULT_NS, name, clock.time());
        let val = ValueBytes::new(DEFAULT_NS, Bytes::from(format!("value-{}", i)));
        assert!(memtable.put(key, val).is_ok());
    }

    let lower = Bound::Included(KeyBytes::new(DEFAULT_NS, Bytes::from("key-03"), u128::MAX));
    let upper = Bound::Excluded(KeyBytes::new(DEFAULT_NS, Bytes::from("key-07"), u128::MIN));

    let items: Vec<_> = memtable.scan(lower, upper).collect();
    assert!(
        !items.is_empty(),
        "scan with bounds should return items in range"
    );
}
/// Every one of 100 distinct keys can be read back after all were written.
#[test]
fn test_multiple_gets() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();

    // Phase 1: write all keys.
    for i in 0..100 {
        let name = Bytes::from(format!("key-{}", i));
        let payload = Bytes::from(format!("value-{}", i));
        let stored = memtable.put(
            KeyBytes::new(DEFAULT_NS, name, clock.time()),
            ValueBytes::new(DEFAULT_NS, payload),
        );
        assert!(stored.is_ok());
    }

    // Phase 2: read each key back.
    for i in 0..100 {
        let probe = KeyBytes::new(DEFAULT_NS, Bytes::from(format!("key-{}", i)), 0);
        assert!(
            memtable.get(&probe).is_some(),
            "all inserted keys should be retrievable"
        );
    }
}
/// Writing an empty batch is not an error.
#[test]
fn test_put_batch_empty() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let nothing: Vec<(KeyBytes, ValueBytes)> = Vec::new();

    let outcome = memtable.put_batch(&nothing);
    assert!(outcome.is_ok(), "empty batch should succeed");
}
/// A key written ten times in one batch is retrievable afterwards.
#[test]
fn test_put_batch_versioned_keys() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    let key_name = Bytes::from("versioned-key");

    let batch: Vec<(KeyBytes, ValueBytes)> = (0..10)
        .map(|i| {
            let key = KeyBytes::new(DEFAULT_NS, key_name.clone(), clock.time());
            let val = ValueBytes::new(DEFAULT_NS, Bytes::from(format!("version-{}", i)));
            (key, val)
        })
        .collect();
    assert!(memtable.put_batch(&batch).is_ok());

    let probe = KeyBytes::new(DEFAULT_NS, key_name, 0);
    assert!(
        memtable.get(&probe).is_some(),
        "versioned key should be retrievable"
    );
}
/// The identifier reported by `id()` is the same before and after a write.
#[test]
fn test_memtable_id_immutable() {
    let memtable = Memtable::new(42, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    assert_eq!(memtable.id(), 42);

    let clock = HybridLogicalClock::new();
    let stored = memtable.put(
        KeyBytes::new(DEFAULT_NS, Bytes::from("key"), clock.time()),
        ValueBytes::new(DEFAULT_NS, Bytes::from("value")),
    );
    assert!(stored.is_ok());

    assert_eq!(memtable.id(), 42);
}
/// The same key name written under five namespaces is retrievable in each.
#[test]
fn test_put_different_namespaces() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    let key_name = Bytes::from("key");

    // Write one value per namespace.
    for ns in 0..5 {
        let payload = Bytes::from(format!("value-ns-{}", ns));
        let stored = memtable.put(
            KeyBytes::new(ns, key_name.clone(), clock.time()),
            ValueBytes::new(ns, payload),
        );
        assert!(stored.is_ok());
    }

    // Read each namespace back.
    for ns in 0..5 {
        let probe = KeyBytes::new(ns, key_name.clone(), 0);
        let lookup = memtable.get(&probe);
        assert!(
            lookup.is_some(),
            "key in namespace {} should be retrievable",
            ns
        );
    }
}
/// `size_hint` must honour the `Iterator` contract: the lower bound may not
/// exceed the number of items actually yielded, and the upper bound (when
/// present) may not be below it.
#[test]
fn test_iterator_size_hint() {
    use std::collections::Bound;

    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    for i in 0..10 {
        let key = KeyBytes::new(DEFAULT_NS, Bytes::from(format!("key-{}", i)), clock.time());
        let val = ValueBytes::new(DEFAULT_NS, Bytes::from("value"));
        assert!(memtable.put(key, val).is_ok());
    }

    let iter = memtable.scan(Bound::Unbounded, Bound::Unbounded);
    let (lower, upper) = iter.size_hint();
    let yielded = iter.count();

    assert!(
        lower <= yielded,
        "size_hint lower bound {} exceeds the {} items actually yielded",
        lower,
        yielded
    );
    if let Some(upper) = upper {
        assert!(
            yielded <= upper,
            "size_hint upper bound {} is below the {} items actually yielded",
            upper,
            yielded
        );
    }
}
/// Dropping a memtable that was written to and then frozen does not panic.
#[test]
fn test_drop_frozen_memtable() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();

    let stored = memtable.put(
        KeyBytes::new(DEFAULT_NS, Bytes::from("key"), clock.time()),
        ValueBytes::new(DEFAULT_NS, Bytes::from("value")),
    );
    assert!(stored.is_ok());

    memtable.freeze();
    drop(memtable);
}
/// A key overwritten 100 times through individual `put` calls is still
/// retrievable.
#[test]
fn test_get_after_multiple_versions() {
    let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
    let clock = HybridLogicalClock::new();
    let key_name = Bytes::from("multi-version");

    for i in 0..100 {
        let payload = Bytes::from(format!("v{}", i));
        let stored = memtable.put(
            KeyBytes::new(DEFAULT_NS, key_name.clone(), clock.time()),
            ValueBytes::new(DEFAULT_NS, payload),
        );
        assert!(stored.is_ok());
    }

    let probe = KeyBytes::new(DEFAULT_NS, key_name, 0);
    assert!(
        memtable.get(&probe).is_some(),
        "should retrieve latest version efficiently"
    );
}
/// A batch far larger than the byte budget is either rejected outright or
/// only partially written.
#[test]
fn test_batch_exceeds_max_size() {
    const SMALL_MAX: u64 = 1000;

    let memtable = Memtable::new(0, SMALL_MAX);
    let clock = HybridLogicalClock::new();

    let batch: Vec<(KeyBytes, ValueBytes)> = (0..100)
        .map(|i| {
            let key = KeyBytes::new(DEFAULT_NS, Bytes::from(format!("key-{}", i)), clock.time());
            let val = ValueBytes::new(DEFAULT_NS, Bytes::from(vec![b'x'; 100]));
            (key, val)
        })
        .collect();

    // An `Err` (nothing fit at all) is an acceptable outcome and needs no checks.
    if let Ok(written) = memtable.put_batch(&batch) {
        assert!(
            written < batch.len(),
            "should not write all entries when exceeding max"
        );
        assert!(written > 0, "should write at least some entries");
    }
}
/// An iterator obtained from `scan` remains usable after the memtable that
/// produced it has gone out of scope.
#[test]
#[cfg(not(loom))]
fn test_iterator_outlives_memtable() {
    use std::collections::Bound;

    let clock = HybridLogicalClock::new();

    // The memtable lives only inside this block; the iterator escapes it.
    let iter = {
        let memtable = Memtable::new(0, DEFAULT_MEMTABLE_SIZE_IN_BYTES);
        for i in 0..5 {
            let name = Bytes::from(format!("key-{}", i));
            let key = KeyBytes::new(DEFAULT_NS, name, clock.time());
            let val = ValueBytes::new(DEFAULT_NS, Bytes::from(format!("value-{}", i)));
            assert!(memtable.put(key, val).is_ok());
        }
        memtable.scan(Bound::Unbounded, Bound::Unbounded)
    };

    let items: Vec<_> = iter.collect();
    assert!(
        !items.is_empty(),
        "iterator should work after memtable is dropped"
    );
}
/// Models a writer that checks the frozen flag and then writes, racing with a
/// thread that sets the flag.
///
/// The check-then-write sequence is not atomic, so a write may land after the
/// flag was set; that outcome is expected and deliberately not asserted
/// against. The assertions below are the invariants that must hold in every
/// interleaving.
#[test]
#[cfg(loom)]
fn loom_frozen_flag_race() {
    use std::sync::atomic::Ordering::Relaxed;
    loom::model(|| {
        let frozen = Arc::new(AtomicBool::new(false));
        let writes_succeeded = Arc::new(AtomicU64::new(0));
        let f1 = frozen.clone();
        let w1 = writes_succeeded.clone();
        let f2 = frozen.clone();
        let w2 = writes_succeeded.clone();

        // Writer: check the flag, then (non-atomically) record a write.
        let t1 = thread::spawn(move || {
            if !f1.load(Relaxed) {
                thread::yield_now();
                w1.fetch_add(1, Relaxed);
            }
        });
        // Freezer: set the flag and report how many writes it can see.
        let t2 = thread::spawn(move || {
            f2.store(true, Relaxed);
            w2.load(Relaxed)
        });

        t1.join().unwrap();
        let writes_seen_by_freezer = t2.join().unwrap();
        let total_writes = writes_succeeded.load(Relaxed);

        // Joining the freezer makes its store visible here.
        assert!(
            frozen.load(Relaxed),
            "flag must be set once the freezing thread has been joined"
        );
        // The single writer increments at most once.
        assert!(total_writes <= 1, "writer recorded {} writes", total_writes);
        // The counter only grows, so an earlier read cannot exceed the final value.
        assert!(
            writes_seen_by_freezer <= total_writes,
            "freezer saw {} writes but the final count is {}",
            writes_seen_by_freezer,
            total_writes
        );
    });
}
/// Models two writers that each check the size budget and then add their
/// payload. When both checks pass before either add, both writes land and
/// the total is the sum of both payloads.
#[test]
#[cfg(loom)]
fn loom_size_tracking_race() {
    use std::sync::atomic::Ordering::Relaxed;
    loom::model(|| {
        let size = Arc::new(AtomicU64::new(0));
        let max_size = 100u64;
        let size_for_small = size.clone();
        let size_for_large = size.clone();

        let small_writer = thread::spawn(move || {
            let payload_size = 30u64;
            if payload_size + size_for_small.load(Relaxed) > max_size {
                return false;
            }
            thread::yield_now();
            size_for_small.fetch_add(payload_size, Relaxed);
            true
        });
        let large_writer = thread::spawn(move || {
            let payload_size = 80u64;
            if payload_size + size_for_large.load(Relaxed) > max_size {
                return false;
            }
            thread::yield_now();
            size_for_large.fetch_add(payload_size, Relaxed);
            true
        });

        let small_wrote = small_writer.join().unwrap();
        let large_wrote = large_writer.join().unwrap();
        let final_size = size.load(Relaxed);

        if small_wrote && large_wrote {
            assert!(
                final_size == 110,
                "Both writes succeeded, total = {}",
                final_size
            );
        }
    });
}
/// Two concurrent `fetch_add` calls on the same counter are both applied.
#[test]
#[cfg(loom)]
fn loom_concurrent_size_updates() {
    use std::sync::atomic::Ordering::Relaxed;
    loom::model(|| {
        let size = Arc::new(AtomicU64::new(0));
        let first = size.clone();
        let second = size.clone();

        let add_ten = thread::spawn(move || {
            first.fetch_add(10, Relaxed);
        });
        let add_twenty = thread::spawn(move || {
            second.fetch_add(20, Relaxed);
        });

        add_ten.join().unwrap();
        add_twenty.join().unwrap();

        let total = size.load(Relaxed);
        assert_eq!(total, 30);
    });
}
/// Setting the frozen flag from two threads leaves it set.
#[test]
#[cfg(loom)]
fn loom_freeze_idempotent() {
    use std::sync::atomic::Ordering::Relaxed;
    loom::model(|| {
        let frozen = Arc::new(AtomicBool::new(false));
        let first = frozen.clone();
        let second = frozen.clone();

        let freezer_a = thread::spawn(move || {
            first.store(true, Relaxed);
        });
        let freezer_b = thread::spawn(move || {
            second.store(true, Relaxed);
        });

        freezer_a.join().unwrap();
        freezer_b.join().unwrap();

        let is_frozen = frozen.load(Relaxed);
        assert!(is_frozen);
    });
}
/// A reader racing with the freezing thread may observe either value of the
/// flag; once the freezing thread has been joined the flag must be set.
#[test]
#[cfg(loom)]
fn loom_read_frozen_while_freezing() {
    use std::sync::atomic::Ordering::Relaxed;
    loom::model(|| {
        let frozen = Arc::new(AtomicBool::new(false));
        let f1 = frozen.clone();
        let f2 = frozen.clone();
        let t1 = thread::spawn(move || {
            f1.store(true, Relaxed);
        });
        let t2 = thread::spawn(move || f2.load(Relaxed));
        t1.join().unwrap();
        // Both `true` and `false` are legal results for the racing read, so
        // the value is intentionally not asserted; joining only checks that
        // the reader did not panic.
        let _saw_frozen: bool = t2.join().unwrap();
        let final_frozen = frozen.load(Relaxed);
        assert!(final_frozen);
    });
}
}