use crate::error::{AmateRSError, ErrorContext, Result};
use crate::types::{CipherBlob, Key};
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
/// Tuning knobs for a `Memtable`.
#[derive(Clone, Debug)]
pub struct MemtableConfig {
    /// Flush threshold, in estimated bytes. `Memtable::should_flush` reports
    /// `true` once the tracked size is greater than or equal to this value.
    pub max_size_bytes: usize,
    /// Write-ahead-log toggle. NOTE(review): not consulted by any code visible
    /// in this file — presumably read by the owning storage engine; confirm.
    pub enable_wal: bool,
}
impl Default for MemtableConfig {
fn default() -> Self {
Self {
max_size_bytes: 64 * 1024 * 1024, enable_wal: true,
}
}
}
/// One slot in the memtable's map: either a live value or a deletion marker.
#[derive(Clone, Debug)]
enum MemtableEntry {
    /// The latest value written for the key.
    Value(CipherBlob),
    /// Records that the key was deleted. Lookups through `Memtable::get` report
    /// `None` for it, and `Memtable::entries` lists it with a `None` value.
    /// NOTE(review): presumably kept (rather than removing the key) so the
    /// deletion can shadow older on-disk data after a flush — confirm.
    Tombstone,
}
/// Key-ordered, in-memory write buffer of opaque `CipherBlob` values.
///
/// Each piece of mutable state sits behind its own `RwLock`, so every
/// operation takes `&self`.
pub struct Memtable {
    /// Sorted map from key to its latest value or tombstone.
    data: Arc<RwLock<BTreeMap<Key, MemtableEntry>>>,
    /// Running estimate, in bytes, of the memory held by `data`.
    size_bytes: Arc<RwLock<usize>>,
    /// Settings supplied at construction time.
    config: MemtableConfig,
    /// Count of mutations (puts and deletes) applied so far.
    sequence: Arc<RwLock<u64>>,
}
impl Memtable {
pub fn new() -> Self {
Self::with_config(MemtableConfig::default())
}
pub fn with_config(config: MemtableConfig) -> Self {
Self {
data: Arc::new(RwLock::new(BTreeMap::new())),
size_bytes: Arc::new(RwLock::new(0)),
config,
sequence: Arc::new(RwLock::new(0)),
}
}
pub fn put(&self, key: Key, value: CipherBlob) -> Result<()> {
let entry_size = Self::estimate_entry_size(&key, &value);
let mut data = self.data.write();
let mut size = self.size_bytes.write();
if let Some(old_entry) = data.get(&key) {
let old_size = match old_entry {
MemtableEntry::Value(v) => Self::estimate_entry_size(&key, v),
MemtableEntry::Tombstone => key.len() + 1,
};
*size = size.saturating_sub(old_size);
}
data.insert(key, MemtableEntry::Value(value));
*size += entry_size;
let mut seq = self.sequence.write();
*seq += 1;
Ok(())
}
pub fn get(&self, key: &Key) -> Result<Option<CipherBlob>> {
let data = self.data.read();
match data.get(key) {
Some(MemtableEntry::Value(v)) => Ok(Some(v.clone())),
Some(MemtableEntry::Tombstone) => Ok(None),
None => Ok(None),
}
}
pub fn delete(&self, key: Key) -> Result<()> {
let mut data = self.data.write();
let mut size = self.size_bytes.write();
if let Some(old_entry) = data.get(&key) {
let old_size = match old_entry {
MemtableEntry::Value(v) => Self::estimate_entry_size(&key, v),
MemtableEntry::Tombstone => key.len() + 1,
};
*size = size.saturating_sub(old_size);
}
let tombstone_size = key.len() + 1;
data.insert(key, MemtableEntry::Tombstone);
*size += tombstone_size;
let mut seq = self.sequence.write();
*seq += 1;
Ok(())
}
pub fn should_flush(&self) -> bool {
let size = *self.size_bytes.read();
size >= self.config.max_size_bytes
}
pub fn size_bytes(&self) -> usize {
*self.size_bytes.read()
}
pub fn len(&self) -> usize {
self.data.read().len()
}
pub fn is_empty(&self) -> bool {
self.data.read().is_empty()
}
pub fn sequence(&self) -> u64 {
*self.sequence.read()
}
pub fn entries(&self) -> Vec<(Key, Option<CipherBlob>)> {
let data = self.data.read();
data.iter()
.map(|(k, v)| {
let value = match v {
MemtableEntry::Value(blob) => Some(blob.clone()),
MemtableEntry::Tombstone => None,
};
(k.clone(), value)
})
.collect()
}
pub fn range(&self, start: &Key, end: &Key) -> Vec<(Key, CipherBlob)> {
let data = self.data.read();
data.range(start..end)
.filter_map(|(k, v)| match v {
MemtableEntry::Value(blob) => Some((k.clone(), blob.clone())),
MemtableEntry::Tombstone => None,
})
.collect()
}
#[cfg(test)]
pub fn clear(&self) {
let mut data = self.data.write();
let mut size = self.size_bytes.write();
data.clear();
*size = 0;
}
fn estimate_entry_size(key: &Key, value: &CipherBlob) -> usize {
key.len() + value.len() + 64
}
}
impl Default for Memtable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memtable_basic_operations() -> Result<()> {
        let memtable = Memtable::new();
        let key = Key::from_str("test_key");
        let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);
        memtable.put(key.clone(), value.clone())?;
        assert_eq!(memtable.len(), 1);
        let retrieved = memtable.get(&key)?;
        assert_eq!(retrieved, Some(value.clone()));
        memtable.delete(key.clone())?;
        let retrieved = memtable.get(&key)?;
        assert_eq!(retrieved, None);
        Ok(())
    }

    #[test]
    fn test_memtable_size_tracking() -> Result<()> {
        let memtable = Memtable::new();
        assert_eq!(memtable.size_bytes(), 0);
        let key = Key::from_str("key");
        let value = CipherBlob::new(vec![0u8; 1000]);
        memtable.put(key, value)?;
        assert!(memtable.size_bytes() > 1000);
        Ok(())
    }

    #[test]
    fn test_memtable_ordering() -> Result<()> {
        let memtable = Memtable::new();
        memtable.put(Key::from_str("key3"), CipherBlob::new(vec![3]))?;
        memtable.put(Key::from_str("key1"), CipherBlob::new(vec![1]))?;
        memtable.put(Key::from_str("key2"), CipherBlob::new(vec![2]))?;
        let entries = memtable.entries();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].0, Key::from_str("key1"));
        assert_eq!(entries[1].0, Key::from_str("key2"));
        assert_eq!(entries[2].0, Key::from_str("key3"));
        Ok(())
    }

    #[test]
    fn test_memtable_range() -> Result<()> {
        let memtable = Memtable::new();
        for i in 0..10 {
            let key = Key::from_str(&format!("key_{:02}", i));
            let value = CipherBlob::new(vec![i as u8]);
            memtable.put(key, value)?;
        }
        let start = Key::from_str("key_03");
        let end = Key::from_str("key_07");
        let range = memtable.range(&start, &end);
        assert_eq!(range.len(), 4);
        Ok(())
    }

    #[test]
    fn test_memtable_range_skips_tombstones() -> Result<()> {
        let memtable = Memtable::new();
        for i in 0..5u8 {
            memtable.put(Key::from_str(&format!("key_{}", i)), CipherBlob::new(vec![i]))?;
        }
        memtable.delete(Key::from_str("key_2"))?;
        let keys: Vec<Key> = memtable
            .range(&Key::from_str("key_0"), &Key::from_str("key_9"))
            .into_iter()
            .map(|(k, _)| k)
            .collect();
        assert_eq!(
            keys,
            vec![
                Key::from_str("key_0"),
                Key::from_str("key_1"),
                Key::from_str("key_3"),
                Key::from_str("key_4"),
            ]
        );
        Ok(())
    }

    #[test]
    fn test_memtable_delete_keeps_tombstone() -> Result<()> {
        let memtable = Memtable::new();
        let key = Key::from_str("key");
        memtable.put(key.clone(), CipherBlob::new(vec![0u8; 10]))?;
        let size_with_value = memtable.size_bytes();

        memtable.delete(key.clone())?;
        // The tombstone stays in the map, so the key still counts toward len().
        assert_eq!(memtable.len(), 1);
        assert!(!memtable.is_empty());
        // The tombstone's estimate replaces (and is smaller than) the value's.
        assert!(memtable.size_bytes() < size_with_value);
        assert_eq!(memtable.entries(), vec![(key, None)]);
        Ok(())
    }

    #[test]
    fn test_memtable_flush_threshold() -> Result<()> {
        let config = MemtableConfig {
            max_size_bytes: 1000,
            enable_wal: false,
        };
        let memtable = Memtable::with_config(config);
        assert!(!memtable.should_flush());
        for i in 0..100 {
            let key = Key::from_str(&format!("key_{}", i));
            let value = CipherBlob::new(vec![0u8; 100]);
            memtable.put(key, value)?;
            if memtable.should_flush() {
                break;
            }
        }
        assert!(memtable.should_flush());
        Ok(())
    }

    #[test]
    fn test_memtable_update() -> Result<()> {
        let memtable = Memtable::new();
        let key = Key::from_str("key");
        let value1 = CipherBlob::new(vec![1, 2, 3]);
        let value2 = CipherBlob::new(vec![4, 5, 6, 7, 8]);
        memtable.put(key.clone(), value1)?;
        let size1 = memtable.size_bytes();
        memtable.put(key.clone(), value2.clone())?;
        let size2 = memtable.size_bytes();
        assert_ne!(size1, size2);
        let retrieved = memtable.get(&key)?;
        assert_eq!(retrieved, Some(value2));
        Ok(())
    }

    #[test]
    fn test_memtable_sequence() -> Result<()> {
        let memtable = Memtable::new();
        assert_eq!(memtable.sequence(), 0);
        memtable.put(Key::from_str("key1"), CipherBlob::new(vec![1]))?;
        assert_eq!(memtable.sequence(), 1);
        memtable.put(Key::from_str("key2"), CipherBlob::new(vec![2]))?;
        assert_eq!(memtable.sequence(), 2);
        memtable.delete(Key::from_str("key1"))?;
        assert_eq!(memtable.sequence(), 3);
        Ok(())
    }
}