use super::{Key, Value, LSMConfig};
use crate::{Result, StorageError};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::collections::BTreeMap;
/// Sorted in-memory table of key/value entries for the LSM engine.
///
/// Entries (tombstones included) live in a `BTreeMap` behind an `RwLock`. An
/// approximate byte count is tracked next to the map so callers can decide when
/// the table should be flushed.
pub struct MemTable {
    /// Sorted entries, keyed by `Key`; every read and write goes through the lock.
    data: Arc<RwLock<BTreeMap<Key, Value>>>,
    /// Flush threshold, in the same approximate units as `size`
    /// (taken from `LSMConfig::memtable_size`).
    max_size: usize,
    /// Approximate bytes held by `data`; adjusted by the write paths.
    size: AtomicUsize,
    /// Incremented once per written entry; not read anywhere in this file.
    next_seq: AtomicUsize,
}
impl MemTable {
pub fn new(config: &LSMConfig) -> Self {
Self {
data: Arc::new(RwLock::new(BTreeMap::new())),
size: AtomicUsize::new(0),
max_size: config.memtable_size,
next_seq: AtomicUsize::new(0),
}
}
pub fn put(&self, key: Key, value: Value) -> Result<()> {
let key_size = 8; let value_size = value.data.len() + 16; let entry_size = key_size + value_size;
let mut data = self.data.write()
.map_err(|_| StorageError::Lock("MemTable lock poisoned".into()))?;
if let Some(old_value) = data.get(&key) {
let old_size = key_size + old_value.data.len() + 16;
self.size.fetch_sub(old_size, Ordering::Relaxed);
}
data.insert(key, value);
self.size.fetch_add(entry_size, Ordering::Relaxed);
self.next_seq.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn batch_put(&self, kvs: &[(Key, Value)]) -> Result<()> {
if kvs.is_empty() {
return Ok(());
}
let mut data = self.data.write()
.map_err(|_| StorageError::Lock("MemTable lock poisoned".into()))?;
let mut total_size_change: i64 = 0;
for (key, value) in kvs {
let key_size = 8;
let value_size = value.data.len() + 16;
let entry_size = key_size + value_size;
if let Some(old_value) = data.get(key) {
let old_size = key_size + old_value.data.len() + 16;
total_size_change -= old_size as i64;
}
data.insert(*key, value.clone());
total_size_change += entry_size as i64;
}
if total_size_change > 0 {
self.size.fetch_add(total_size_change as usize, Ordering::Relaxed);
} else if total_size_change < 0 {
self.size.fetch_sub((-total_size_change) as usize, Ordering::Relaxed);
}
self.next_seq.fetch_add(kvs.len(), Ordering::Relaxed);
Ok(())
}
pub fn get(&self, key: Key) -> Result<Option<Value>> {
let data = self.data.read()
.map_err(|_| StorageError::Lock("MemTable lock poisoned".into()))?;
Ok(data.get(&key).cloned())
}
pub fn delete(&self, key: Key, timestamp: u64) -> Result<()> {
self.put(key, Value::tombstone(timestamp))
}
pub fn should_flush(&self) -> bool {
self.size.load(Ordering::Relaxed) >= self.max_size
}
pub fn size(&self) -> usize {
self.size.load(Ordering::Relaxed)
}
pub fn len(&self) -> usize {
self.data.read()
.map(|data| data.len())
.unwrap_or(0) }
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn iter(&self) -> MemTableIteratorOptimized {
MemTableIteratorOptimized::new(self.data.clone())
}
pub fn snapshot(&self) -> Vec<(Key, Value)> {
let data = self.data.read()
.expect("MemTable snapshot: lock poisoned (unrecoverable in test)");
data.iter()
.map(|(k, v)| (*k, v.clone()))
.collect()
}
pub fn scan_with<F>(&self, start: Key, end: Key, mut f: F) -> Result<()>
where
F: FnMut(Key, &Value) -> Result<()>,
{
let data = self.data.read()
.map_err(|_| StorageError::Lock("MemTable lock poisoned".into()))?;
use std::ops::Bound;
let range = data.range((
Bound::Included(&start),
Bound::Excluded(&end)
));
for (k, v) in range {
if !v.deleted {
f(*k, v)?; }
}
Ok(())
}
pub fn scan(&self, start: Key, end: Key) -> Result<Vec<(Key, Value)>> {
let estimated_size = ((end - start) as usize).min(1000);
let mut results = Vec::with_capacity(estimated_size);
self.scan_with(start, end, |k, v| {
results.push((k, v.clone()));
Ok(())
})?;
Ok(results)
}
pub fn scan_all_with<F>(&self, mut f: F) -> Result<()>
where
F: FnMut(Key, &Value) -> Result<()>,
{
let data = self.data.read()
.map_err(|_| StorageError::Lock("MemTable lock poisoned".into()))?;
for (k, v) in data.iter() {
if !v.deleted {
f(*k, v)?; }
}
Ok(())
}
pub fn scan_all(&self) -> Result<Vec<(Key, Value)>> {
let mut results = Vec::with_capacity(1000);
self.scan_all_with(|k, v| {
results.push((k, v.clone()));
Ok(())
})?;
Ok(results)
}
}
/// Lazily walks a shared memtable map in ascending key order.
///
/// The map is re-read under a fresh read lock on every step instead of being copied
/// up front. As a result, entries inserted during iteration are observed if their key
/// sorts after the last key already returned.
// Not constructed anywhere in this file; kept as a non-copying alternative to
// `MemTableIteratorOptimized`, so silence the unused-item lint.
#[allow(dead_code)]
pub struct MemTableIterator {
    data: Arc<RwLock<BTreeMap<Key, Value>>>,
    /// Last key yielded; the next step resumes strictly after it. `None` before the first call.
    last_key: Option<Key>,
}

#[allow(dead_code)]
impl MemTableIterator {
    /// Creates an iterator positioned before the smallest key in `data`.
    pub fn new(data: Arc<RwLock<BTreeMap<Key, Value>>>) -> Self {
        Self { data, last_key: None }
    }
}

impl Iterator for MemTableIterator {
    type Item = (Key, Value);

    /// Returns the entry with the smallest key greater than the last one yielded.
    ///
    /// Resuming from the last key is an O(log n) seek. An index cursor over
    /// `iter().nth(i)` would rescan from the start on every call, which makes a full
    /// pass O(n²).
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    fn next(&mut self) -> Option<Self::Item> {
        use std::ops::Bound;

        let data = self.data.read()
            .expect("MemTableIterator: lock poisoned (test-only code)");
        let lower: Bound<Key> = match self.last_key {
            Some(k) => Bound::Excluded(k),
            None => Bound::Unbounded,
        };
        let (key, value) = data.range((lower, Bound::Unbounded)).next()?;
        self.last_key = Some(*key);
        Some((*key, value.clone()))
    }
}
/// Iterator over a point-in-time copy of a memtable's entries (tombstones included),
/// yielded in ascending key order.
///
/// The read lock is held only while the copy is taken, so iterating never blocks writers
/// and later writes are not observed.
pub struct MemTableIteratorOptimized {
    entries: std::vec::IntoIter<(Key, Value)>,
}

impl MemTableIteratorOptimized {
    /// Copies every entry of `data` under a single read-lock acquisition.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub fn new(data: Arc<RwLock<BTreeMap<Key, Value>>>) -> Self {
        let guard = data.read()
            .expect("MemTableIteratorOptimized: lock poisoned (unrecoverable)");
        let entries: Vec<(Key, Value)> = guard
            .iter()
            .map(|(k, v)| (*k, v.clone()))
            .collect();
        Self {
            entries: entries.into_iter(),
        }
    }
}

impl Iterator for MemTableIteratorOptimized {
    type Item = (Key, Value);

    fn next(&mut self) -> Option<Self::Item> {
        self.entries.next()
    }

    /// Exact bounds forwarded from the backing vector iterator. Without this override the
    /// default `(0, None)` would stop `collect` and friends from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.entries.size_hint()
    }
}

// The remaining count is always known because the entries were copied up front.
impl ExactSizeIterator for MemTableIteratorOptimized {}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_memtable() -> MemTable {
        MemTable::new(&LSMConfig::default())
    }

    #[test]
    fn test_put_get() {
        let memtable = create_memtable();
        let key = 12345u64;
        let value = Value::new(b"test_value".to_vec(), 1);
        memtable.put(key, value.clone()).unwrap();
        let retrieved = memtable.get(key).unwrap().unwrap();
        assert_eq!(retrieved.data, value.data);
        assert_eq!(retrieved.timestamp, 1);
        assert!(!retrieved.deleted);
    }

    #[test]
    fn test_get_missing_key() {
        let memtable = create_memtable();
        assert!(memtable.get(42).unwrap().is_none());
    }

    #[test]
    fn test_delete() {
        let memtable = create_memtable();
        let key = 12345u64;
        memtable.put(key, Value::new(b"value".to_vec(), 1)).unwrap();
        memtable.delete(key, 2).unwrap();
        let retrieved = memtable.get(key).unwrap().unwrap();
        assert!(retrieved.deleted);
        assert_eq!(retrieved.timestamp, 2);
    }

    #[test]
    fn test_size_tracking() {
        let memtable = create_memtable();
        assert_eq!(memtable.size(), 0);
        let key = 123u64;
        memtable.put(key, Value::new(b"value".to_vec(), 1)).unwrap();
        assert!(memtable.size() > 0);

        // Overwriting must replace, not accumulate, the old entry's size.
        memtable.put(key, Value::new(b"new_value".to_vec(), 2)).unwrap();
        let fresh = create_memtable();
        fresh.put(key, Value::new(b"new_value".to_vec(), 2)).unwrap();
        assert_eq!(memtable.size(), fresh.size());
        assert_eq!(memtable.len(), 1);
    }

    #[test]
    fn test_batch_put_matches_sequential_puts() {
        let batch = vec![
            (1u64, Value::new(b"a".to_vec(), 1)),
            (2u64, Value::new(b"bb".to_vec(), 2)),
            (1u64, Value::new(b"ccc".to_vec(), 3)),
        ];
        let batched = create_memtable();
        batched.batch_put(&batch).unwrap();

        let sequential = create_memtable();
        for (key, value) in &batch {
            sequential.put(*key, value.clone()).unwrap();
        }

        assert_eq!(batched.size(), sequential.size());
        assert_eq!(batched.len(), 2);
        assert_eq!(batched.get(1).unwrap().unwrap().timestamp, 3);
    }

    #[test]
    fn test_scan_respects_bounds_and_skips_tombstones() {
        let memtable = create_memtable();
        for key in 0..10u64 {
            memtable.put(key, Value::new(vec![0u8; 4], 1)).unwrap();
        }
        memtable.delete(4, 2).unwrap();
        let keys: Vec<u64> = memtable
            .scan(2, 7)
            .unwrap()
            .into_iter()
            .map(|(k, _)| k)
            .collect();
        assert_eq!(keys, vec![2, 3, 5, 6]);
    }

    #[test]
    fn test_should_flush() {
        let mut config = LSMConfig::default();
        config.memtable_size = 100;
        let memtable = MemTable::new(&config);
        assert!(!memtable.should_flush());
        for i in 0..10u64 {
            memtable.put(i, Value::new(vec![0u8; 20], i)).unwrap();
        }
        assert!(memtable.should_flush());
    }

    #[test]
    fn test_iterator() {
        let memtable = create_memtable();
        for i in 0..5u64 {
            let value = Value::new(format!("value_{}", i).into_bytes(), i);
            memtable.put(i, value).unwrap();
        }
        let items: Vec<_> = memtable.iter().collect();
        assert_eq!(items.len(), 5);
        for (expected_key, (key, value)) in (0u64..).zip(&items) {
            assert_eq!(*key, expected_key);
            assert_eq!(value.timestamp, expected_key);
        }
    }
}