use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use crossbeam_skiplist::SkipMap;
use tracing::info;
use super::types::{MemTableConfig, MemTableEntry, MemTableStats};
/// Convenience alias for fallible memtable operations.
pub type Result<T> = std::result::Result<T, MemTableError>;

/// Errors returned by write operations on a `MemTable`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum MemTableError {
    /// The table has been frozen via `set_read_only` and rejects writes.
    #[error("MemTable is read-only (being flushed)")]
    ReadOnly,
    /// A flush threshold (size, entry count, or age) has been reached.
    #[error("MemTable is full")]
    Full,
}
/// In-memory write buffer backed by a lock-free skip list.
///
/// Keys map to `MemTableEntry` values; deletions are recorded as tombstone
/// entries rather than physical removals so they can shadow older versions
/// in lower storage levels. All bookkeeping is held in `Arc`ed atomics so
/// the table can be shared across threads.
pub struct MemTable {
    /// Ordered key -> entry map; safe for concurrent readers and writers.
    pub(crate) data: Arc<SkipMap<Vec<u8>, MemTableEntry>>,
    /// Approximate bytes consumed by keys, values, and per-entry overhead.
    pub(crate) size_bytes: Arc<AtomicUsize>,
    /// Number of keys stored (tombstones included).
    pub(crate) entry_count: Arc<AtomicUsize>,
    /// Number of keys whose latest entry is a tombstone.
    pub(crate) tombstone_count: Arc<AtomicUsize>,
    /// Monotonic sequence counter; each write gets the next value.
    pub(crate) sequence: Arc<AtomicU64>,
    /// Creation time, used for the age-based flush threshold.
    pub(crate) created_at: Instant,
    /// Flush thresholds (max size / entries / age).
    pub(crate) config: MemTableConfig,
    /// Set once the table is handed to the flush path; blocks writes.
    read_only: Arc<AtomicBool>,
}
impl MemTable {
    /// Approximate per-entry bookkeeping overhead (skip-list node, sequence
    /// number, timestamp) charged on top of raw key/value bytes.
    const ENTRY_OVERHEAD: usize = 32;

    /// Creates an empty, writable memtable with the given configuration.
    pub fn new(config: MemTableConfig) -> Self {
        Self {
            data: Arc::new(SkipMap::new()),
            size_bytes: Arc::new(AtomicUsize::new(0)),
            entry_count: Arc::new(AtomicUsize::new(0)),
            tombstone_count: Arc::new(AtomicUsize::new(0)),
            sequence: Arc::new(AtomicU64::new(0)),
            created_at: Instant::now(),
            config,
            read_only: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Bytes previously charged to `size_bytes` for `entry` stored under
    /// `key`, mirroring the accounting done by `insert` / `delete`.
    fn charged_size(key: &[u8], entry: &MemTableEntry) -> usize {
        key.len() + entry.value.as_ref().map_or(0, |v| v.len()) + Self::ENTRY_OVERHEAD
    }

    /// Snapshot of the previous state of `key`:
    /// `(existed, was_tombstone, previously_charged_bytes)`.
    fn previous_state(&self, key: &[u8]) -> (bool, bool, usize) {
        match self.data.get(key) {
            Some(e) => (
                true,
                e.value().is_tombstone(),
                Self::charged_size(key, e.value()),
            ),
            None => (false, false, 0),
        }
    }

    /// Inserts `key` -> `value`, returning the sequence number assigned to
    /// this write.
    ///
    /// # Errors
    /// `ReadOnly` if the table has been frozen for flushing; `Full` if a
    /// flush threshold (size, entry count, or age) has been reached.
    #[inline]
    pub fn insert(&self, key: &[u8], value: &[u8]) -> Result<u64> {
        if self.read_only.load(Ordering::Acquire) {
            return Err(MemTableError::ReadOnly);
        }
        if self.should_flush() {
            return Err(MemTableError::Full);
        }
        let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
        // NOTE(review): get-then-insert is not atomic, so under concurrent
        // writers to the SAME key the counters are best-effort. This matches
        // the original code's guarantees.
        let (existed, was_tombstone, prev_size) = self.previous_state(key);
        self.data
            .insert(key.to_vec(), MemTableEntry::new(value.to_vec(), sequence));
        self.size_bytes.fetch_add(
            key.len() + value.len() + Self::ENTRY_OVERHEAD,
            Ordering::Relaxed,
        );
        if existed {
            // Overwrite: the replaced entry no longer occupies memory, and
            // the key was already counted — entry_count must NOT grow.
            // (The original incremented entry_count on every overwrite of a
            // live key, and never released the old entry's bytes.)
            self.size_bytes.fetch_sub(prev_size, Ordering::Relaxed);
            if was_tombstone {
                // A previously deleted key came back to life.
                self.tombstone_count.fetch_sub(1, Ordering::Relaxed);
            }
        } else {
            self.entry_count.fetch_add(1, Ordering::Relaxed);
        }
        Ok(sequence)
    }

    /// Records a tombstone for `key`, returning the assigned sequence
    /// number. The key stays in the map so the deletion can shadow older
    /// versions in lower storage levels.
    ///
    /// # Errors
    /// `ReadOnly` if the table has been frozen; `Full` if a flush threshold
    /// has been reached.
    #[inline]
    pub fn delete(&self, key: &[u8]) -> Result<u64> {
        if self.read_only.load(Ordering::Acquire) {
            return Err(MemTableError::ReadOnly);
        }
        if self.should_flush() {
            return Err(MemTableError::Full);
        }
        let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
        let (existed, was_tombstone, prev_size) = self.previous_state(key);
        self.data
            .insert(key.to_vec(), MemTableEntry::tombstone(sequence));
        self.size_bytes
            .fetch_add(key.len() + Self::ENTRY_OVERHEAD, Ordering::Relaxed);
        if existed {
            self.size_bytes.fetch_sub(prev_size, Ordering::Relaxed);
            if !was_tombstone {
                // A live value was replaced by a tombstone.
                self.tombstone_count.fetch_add(1, Ordering::Relaxed);
            }
            // Re-deleting an already-tombstoned key changes no counters.
            // (The original bumped BOTH entry_count and tombstone_count
            // again in that case, double-counting the key.)
        } else {
            // Deleting a key this table has never seen still stores a
            // tombstone entry, so both counters grow.
            self.entry_count.fetch_add(1, Ordering::Relaxed);
            self.tombstone_count.fetch_add(1, Ordering::Relaxed);
        }
        Ok(sequence)
    }

    /// Returns the live value for `key`, or `None` if the key is absent or
    /// tombstoned.
    #[inline]
    pub fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
        self.data.get(key).and_then(|entry| {
            let e = entry.value();
            if e.is_tombstone() {
                None
            } else {
                e.value.clone()
            }
        })
    }

    /// True if `key` currently maps to a live (non-tombstone) value.
    #[inline]
    pub fn contains_key(&self, key: &[u8]) -> bool {
        self.data
            .get(key)
            .map(|e| !e.value().is_tombstone())
            .unwrap_or(false)
    }

    /// Returns the raw entry for `key` (tombstones included), if present.
    pub fn get_entry(&self, key: &[u8]) -> Option<MemTableEntry> {
        self.data.get(key).map(|e| e.value().clone())
    }

    /// Returns live key/value pairs in `[start, end)`, in key order.
    /// Tombstoned keys are skipped.
    pub fn range(&self, start: &[u8], end: &[u8]) -> Vec<(Vec<u8>, Vec<u8>)> {
        self.data
            .range(start.to_vec()..end.to_vec())
            .filter_map(|entry| {
                let e = entry.value();
                if e.is_tombstone() {
                    None
                } else {
                    e.value.clone().map(|v| (entry.key().clone(), v))
                }
            })
            .collect()
    }

    /// Returns all live key/value pairs whose key starts with `prefix`, in
    /// key order. An empty prefix returns every live entry.
    pub fn scan_prefix(&self, prefix: &[u8]) -> Vec<(Vec<u8>, Vec<u8>)> {
        use std::ops::Bound;
        // Upper bound: the smallest key strictly greater than every key
        // carrying this prefix. Strip trailing 0xFF bytes, then increment
        // the last remaining byte; if nothing remains (empty or all-0xFF
        // prefix) the range is unbounded above.
        // (The original appended 0x00 after a 0xFF byte, which EXCLUDED
        // longer keys sharing the prefix, and returned nothing for an
        // empty prefix.)
        let upper = {
            let mut end = prefix.to_vec();
            while end.last() == Some(&0xFF) {
                end.pop();
            }
            match end.last_mut() {
                Some(last) => {
                    // Cannot overflow: trailing 0xFF bytes were stripped.
                    *last += 1;
                    Bound::Excluded(end)
                }
                None => Bound::Unbounded,
            }
        };
        self.data
            .range((Bound::Included(prefix.to_vec()), upper))
            .filter_map(|entry| {
                let e = entry.value();
                if e.is_tombstone() {
                    None
                } else {
                    e.value.clone().map(|v| (entry.key().clone(), v))
                }
            })
            .collect()
    }

    /// True once any flush threshold — byte size, entry count, or table
    /// age — has been reached.
    pub fn should_flush(&self) -> bool {
        self.size_bytes.load(Ordering::Relaxed) >= self.config.max_size
            || self.entry_count.load(Ordering::Relaxed) >= self.config.max_entries
            || self.created_at.elapsed() >= self.config.max_age
    }

    /// Freezes the table: all subsequent writes fail with `ReadOnly`.
    /// Called when the table is handed to the flush path.
    pub fn set_read_only(&self) {
        self.read_only.store(true, Ordering::Release);
        info!(
            "MemTable marked as read-only (size: {} bytes, entries: {})",
            self.size_bytes.load(Ordering::Relaxed),
            self.entry_count.load(Ordering::Relaxed)
        );
    }

    /// True if the table has been frozen via `set_read_only`.
    pub fn is_read_only(&self) -> bool {
        self.read_only.load(Ordering::Acquire)
    }

    /// Snapshot of every entry (tombstones included) in key order; used by
    /// the flush path.
    pub fn get_all_entries(&self) -> Vec<(Vec<u8>, MemTableEntry)> {
        self.data
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().clone()))
            .collect()
    }

    /// Snapshot of every key with its optional value (`None` = tombstone).
    pub fn get_all_kv(&self) -> Vec<(Vec<u8>, Option<Vec<u8>>)> {
        self.data
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().value.clone()))
            .collect()
    }

    /// Empties the table after its contents have been flushed.
    pub fn clear(&self) {
        // Drain the skip list itself. The original only reset the counters,
        // which left the already-flushed entries visible to readers while
        // `is_empty()` reported true and the size accounting read zero.
        while self.data.pop_front().is_some() {}
        self.size_bytes.store(0, Ordering::Relaxed);
        self.entry_count.store(0, Ordering::Relaxed);
        self.tombstone_count.store(0, Ordering::Relaxed);
        info!("MemTable cleared after flush");
    }

    /// Next sequence number that will be assigned (alias of `sequence`).
    pub fn current_sequence(&self) -> u64 {
        self.sequence.load(Ordering::Relaxed)
    }

    /// Point-in-time statistics snapshot. O(n): scans all entries to find
    /// the true oldest/newest insertion timestamps.
    pub fn stats(&self) -> MemTableStats {
        let now = Instant::now();
        // The skip list is ordered by KEY, so front()/back() (which the
        // original used) yield the lexicographic extremes, not the time
        // extremes; scan for the real min/max timestamps instead.
        let mut min_ts: Option<Instant> = None;
        let mut max_ts: Option<Instant> = None;
        for entry in self.data.iter() {
            let ts = entry.value().timestamp;
            min_ts = Some(min_ts.map_or(ts, |m| m.min(ts)));
            max_ts = Some(max_ts.map_or(ts, |m| m.max(ts)));
        }
        // saturating_duration_since: an entry inserted concurrently after
        // `now` was captured must not panic the subtraction.
        MemTableStats {
            entry_count: self.entry_count.load(Ordering::Relaxed),
            size_bytes: self.size_bytes.load(Ordering::Relaxed),
            oldest_entry_age: min_ts.map(|t| now.saturating_duration_since(t)),
            newest_entry_age: max_ts.map(|t| now.saturating_duration_since(t)),
            tombstone_count: self.tombstone_count.load(Ordering::Relaxed),
        }
    }

    /// Current sequence counter (kept alongside `current_sequence` for
    /// backward compatibility with existing callers).
    pub fn sequence(&self) -> u64 {
        self.sequence.load(Ordering::Relaxed)
    }

    /// Approximate bytes held by keys, values, and per-entry overhead.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes.load(Ordering::Relaxed)
    }

    /// Number of keys stored, tombstones included.
    pub fn entry_count(&self) -> usize {
        self.entry_count.load(Ordering::Relaxed)
    }

    /// True if no entries (live or tombstone) are stored.
    pub fn is_empty(&self) -> bool {
        self.entry_count.load(Ordering::Relaxed) == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Limits generous enough that no test ever trips a flush threshold.
    fn test_config() -> MemTableConfig {
        MemTableConfig {
            max_size: 1024 * 1024,
            max_entries: 1000,
            ..Default::default()
        }
    }

    #[test]
    fn test_insert_and_get() {
        let mt = MemTable::new(test_config());
        mt.insert(b"key1", b"value1").unwrap();
        mt.insert(b"key2", b"value2").unwrap();
        assert_eq!(mt.get(b"key1"), Some(b"value1".to_vec()));
        assert_eq!(mt.get(b"key2"), Some(b"value2".to_vec()));
        // Unknown keys are absent.
        assert!(mt.get(b"key3").is_none());
    }

    #[test]
    fn test_delete() {
        let mt = MemTable::new(test_config());
        mt.insert(b"key1", b"value1").unwrap();
        assert!(mt.contains_key(b"key1"));
        mt.delete(b"key1").unwrap();
        // A tombstone hides the key from all read paths.
        assert!(!mt.contains_key(b"key1"));
        assert!(mt.get(b"key1").is_none());
    }

    #[test]
    fn test_range_scan() {
        let mt = MemTable::new(test_config());
        for (k, v) in [("a", "1"), ("b", "2"), ("c", "3"), ("d", "4")] {
            mt.insert(k.as_bytes(), v.as_bytes()).unwrap();
        }
        // End bound is exclusive: "d" itself must not be returned.
        let hits = mt.range(b"b", b"d");
        assert_eq!(
            hits,
            vec![
                (b"b".to_vec(), b"2".to_vec()),
                (b"c".to_vec(), b"3".to_vec()),
            ]
        );
    }

    #[test]
    fn test_prefix_scan() {
        let mt = MemTable::new(test_config());
        mt.insert(b"user:1", b"alice").unwrap();
        mt.insert(b"user:2", b"bob").unwrap();
        mt.insert(b"post:1", b"hello").unwrap();
        assert_eq!(mt.scan_prefix(b"user:").len(), 2);
        assert_eq!(mt.scan_prefix(b"post:").len(), 1);
    }

    #[test]
    fn test_overwrite() {
        let mt = MemTable::new(test_config());
        mt.insert(b"key", b"value1").unwrap();
        assert_eq!(mt.get(b"key"), Some(b"value1".to_vec()));
        // A later write to the same key shadows the earlier value.
        mt.insert(b"key", b"value2").unwrap();
        assert_eq!(mt.get(b"key"), Some(b"value2".to_vec()));
    }

    #[test]
    fn test_read_only() {
        let mt = MemTable::new(test_config());
        mt.insert(b"key", b"value").unwrap();
        mt.set_read_only();
        // Frozen tables reject every subsequent write.
        assert!(matches!(
            mt.insert(b"key2", b"value"),
            Err(MemTableError::ReadOnly)
        ));
    }
}