#[cfg(test)]
mod tests;
use std::{
cmp::Reverse,
collections::BTreeMap,
path::Path,
sync::{
Arc, RwLock,
atomic::{AtomicU64, Ordering},
},
time::{SystemTime, UNIX_EPOCH},
};
use crate::engine::Record;
use crate::wal::{Wal, WalError};
use thiserror::Error;
use tracing::{error, info, trace};
/// Errors returned by [`Memtable`] operations.
#[derive(Debug, Error)]
pub enum MemtableError {
/// A write-ahead-log operation failed; converted automatically from [`WalError`].
#[error("WAL error: {0}")]
Wal(#[from] WalError),
/// The pending write would push the tracked size past the configured write buffer size.
#[error("Flush required")]
FlushRequired,
/// A caller-supplied key, value, or range failed validation.
#[error("Invalid argument: {0}")]
InvalidArgument(String),
/// An internal failure; in this file it is raised when the state lock is poisoned.
#[error("Internal error: {0}")]
Internal(String),
}
/// Write buffer that pairs a lock-protected in-memory index with a write-ahead log.
pub struct Memtable {
/// Versioned point entries and range tombstones, guarded by a reader-writer lock.
inner: Arc<RwLock<MemtableInner>>,
/// Write-ahead log; each mutation is appended here before the in-memory state is updated.
pub wal: Wal<Record>,
/// LSN handed to the next write; seeded from the highest LSN seen during WAL replay.
next_lsn: AtomicU64,
}
/// One stored version of a single key: either a value or a point tombstone.
#[derive(Debug, PartialEq, Clone)]
pub enum MemtablePointEntry {
    /// A value written for the key.
    Put {
        value: Vec<u8>,
        timestamp: u64,
        lsn: u64,
    },
    /// A point tombstone marking the key as deleted.
    Delete { timestamp: u64, lsn: u64 },
}

impl MemtablePointEntry {
    /// Returns the log sequence number carried by this version.
    pub fn lsn(&self) -> u64 {
        match *self {
            Self::Put { lsn, .. } => lsn,
            Self::Delete { lsn, .. } => lsn,
        }
    }

    /// Returns the timestamp recorded when this version was written.
    // Not referenced by every build configuration, hence the lint allowance.
    #[allow(dead_code)]
    pub fn timestamp(&self) -> u64 {
        match *self {
            Self::Put { timestamp, .. } => timestamp,
            Self::Delete { timestamp, .. } => timestamp,
        }
    }

    /// Returns `true` when this version is a point tombstone.
    #[allow(dead_code)]
    pub fn is_delete(&self) -> bool {
        match self {
            Self::Delete { .. } => true,
            Self::Put { .. } => false,
        }
    }

    /// Returns the stored bytes for a `Put`, or `None` for a `Delete`.
    #[allow(dead_code)]
    pub fn value(&self) -> Option<&[u8]> {
        if let Self::Put { value, .. } = self {
            Some(value.as_slice())
        } else {
            None
        }
    }
}
// Discriminant byte written first when encoding `MemtablePointEntry::Put`.
const POINT_ENTRY_TAG_PUT: u8 = 0;
// Discriminant byte written first when encoding `MemtablePointEntry::Delete`.
const POINT_ENTRY_TAG_DELETE: u8 = 1;
impl crate::encoding::Encode for MemtablePointEntry {
fn encode_to(&self, buf: &mut Vec<u8>) -> Result<(), crate::encoding::EncodingError> {
match self {
Self::Put {
value,
timestamp,
lsn,
} => {
crate::encoding::Encode::encode_to(&POINT_ENTRY_TAG_PUT, buf)?;
crate::encoding::Encode::encode_to(value, buf)?;
crate::encoding::Encode::encode_to(timestamp, buf)?;
crate::encoding::Encode::encode_to(lsn, buf)?;
}
Self::Delete { timestamp, lsn } => {
crate::encoding::Encode::encode_to(&POINT_ENTRY_TAG_DELETE, buf)?;
crate::encoding::Encode::encode_to(timestamp, buf)?;
crate::encoding::Encode::encode_to(lsn, buf)?;
}
}
Ok(())
}
}
impl crate::encoding::Decode for MemtablePointEntry {
fn decode_from(buf: &[u8]) -> Result<(Self, usize), crate::encoding::EncodingError> {
let (tag, mut offset) = <u8 as crate::encoding::Decode>::decode_from(buf)?;
match tag {
POINT_ENTRY_TAG_PUT => {
let (value, n) = <Vec<u8> as crate::encoding::Decode>::decode_from(&buf[offset..])?;
offset += n;
let (timestamp, n) = <u64 as crate::encoding::Decode>::decode_from(&buf[offset..])?;
offset += n;
let (lsn, n) = <u64 as crate::encoding::Decode>::decode_from(&buf[offset..])?;
offset += n;
Ok((
Self::Put {
value,
timestamp,
lsn,
},
offset,
))
}
POINT_ENTRY_TAG_DELETE => {
let (timestamp, n) = <u64 as crate::encoding::Decode>::decode_from(&buf[offset..])?;
offset += n;
let (lsn, n) = <u64 as crate::encoding::Decode>::decode_from(&buf[offset..])?;
offset += n;
Ok((Self::Delete { timestamp, lsn }, offset))
}
_ => Err(crate::encoding::EncodingError::InvalidTag {
tag: tag as u32,
type_name: "MemtablePointEntry",
}),
}
}
}
use crate::engine::RangeTombstone;
/// Outcome of a point lookup in the memtable.
#[derive(Debug, PartialEq)]
pub enum MemtableGetResult {
/// The newest visible version of the key is a value; carries a copy of its bytes.
Put(Vec<u8>),
/// The newest visible version of the key is a point tombstone.
Delete,
/// A covering range tombstone hides the key (no point entry, or the tombstone has a higher LSN).
RangeDelete,
/// No point entry exists for the key and no range tombstone covers it.
NotFound,
}
/// Snapshot of memtable counters as produced by `Memtable::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub struct MemtableStats {
/// The tracked approximate size of the memtable, in bytes.
pub size_bytes: usize,
/// Number of distinct keys that have at least one point entry.
pub key_count: usize,
/// Total number of point-entry versions across all keys.
pub entry_count: usize,
/// Number of point-entry versions that are deletes.
pub tombstone_count: usize,
/// Total number of range tombstones.
pub range_tombstone_count: usize,
}
// Mutable state that `Memtable` keeps behind its reader-writer lock.
struct MemtableInner {
// key -> versions of that key; `Reverse<u64>` orders versions from highest LSN to lowest.
tree: BTreeMap<Vec<u8>, BTreeMap<Reverse<u64>, MemtablePointEntry>>,
// range start key -> tombstones inserted under that start key, highest LSN first.
range_tombstones: BTreeMap<Vec<u8>, BTreeMap<Reverse<u64>, RangeTombstone>>,
// Running sum of the per-record size estimates added during replay and on every write.
approximate_size: usize,
// Limit compared against `approximate_size` plus the incoming record size before a write.
write_buffer_size: usize,
}
impl Memtable {
/// Opens the WAL at `wal_path` and rebuilds the in-memory state by replaying every record.
///
/// The next LSN is seeded to one past the highest LSN found during replay.
///
/// # Errors
/// Returns [`MemtableError::Wal`] when the WAL cannot be opened or a record fails to replay.
pub fn new<P: AsRef<Path>>(
    wal_path: P,
    max_record_size: Option<u32>,
    write_buffer_size: usize,
) -> Result<Self, MemtableError> {
    info!("Initializing Memtable with WAL replay");
    let wal = Wal::open(&wal_path, max_record_size)?;

    let mut state = MemtableInner {
        tree: BTreeMap::new(),
        range_tombstones: BTreeMap::new(),
        approximate_size: 0,
        write_buffer_size,
    };
    let mut highest_lsn: u64 = 0;

    for replayed in wal.replay_iter()? {
        let replayed: Record = replayed?;
        match replayed {
            Record::Put {
                key,
                value,
                lsn,
                timestamp,
            } => {
                state.approximate_size +=
                    std::mem::size_of::<MemtablePointEntry>() + key.len() + value.len();
                highest_lsn = highest_lsn.max(lsn);
                let version = MemtablePointEntry::Put {
                    value,
                    timestamp,
                    lsn,
                };
                state
                    .tree
                    .entry(key)
                    .or_default()
                    .insert(Reverse(lsn), version);
            }
            Record::Delete {
                key,
                lsn,
                timestamp,
            } => {
                state.approximate_size += std::mem::size_of::<MemtablePointEntry>() + key.len();
                highest_lsn = highest_lsn.max(lsn);
                let version = MemtablePointEntry::Delete { timestamp, lsn };
                state
                    .tree
                    .entry(key)
                    .or_default()
                    .insert(Reverse(lsn), version);
            }
            Record::RangeDelete {
                start,
                end,
                lsn,
                timestamp,
            } => {
                state.approximate_size +=
                    std::mem::size_of::<RangeTombstone>() + start.len() + end.len();
                highest_lsn = highest_lsn.max(lsn);
                // Tombstones are indexed by their start key.
                let index_key = start.clone();
                let tombstone = RangeTombstone {
                    start,
                    end,
                    lsn,
                    timestamp,
                };
                state
                    .range_tombstones
                    .entry(index_key)
                    .or_default()
                    .insert(Reverse(lsn), tombstone);
            }
        }
    }

    info!(
        "Memtable initialized successfully with LSN: {}",
        highest_lsn
    );
    Ok(Self {
        inner: Arc::new(RwLock::new(state)),
        wal,
        next_lsn: AtomicU64::new(highest_lsn.saturating_add(1)),
    })
}
pub fn put(&self, key: Vec<u8>, value: Vec<u8>) -> Result<(), MemtableError> {
trace!("put() started, key: {}", HexKey(&key));
if key.is_empty() || value.is_empty() {
return Err(MemtableError::InvalidArgument(
"Key or value is empty".to_string(),
));
}
let record_size = std::mem::size_of::<MemtablePointEntry>() + key.len() + value.len();
let key_for_wal = key.clone();
let value_for_wal = value.clone();
let lsn = self.apply_write(
record_size,
"put",
|lsn, timestamp| Record::Put {
key: key_for_wal,
value: value_for_wal,
timestamp,
lsn,
},
|inner, lsn, timestamp| {
let entry = MemtablePointEntry::Put {
value,
timestamp,
lsn,
};
inner
.tree
.entry(key)
.or_default()
.insert(Reverse(lsn), entry);
},
)?;
trace!("Put operation completed with LSN: {}", lsn);
Ok(())
}
pub fn delete(&self, key: Vec<u8>) -> Result<(), MemtableError> {
trace!("delete() started, key: {}", HexKey(&key));
if key.is_empty() {
return Err(MemtableError::InvalidArgument("Key is empty".to_string()));
}
let record_size = std::mem::size_of::<MemtablePointEntry>() + key.len();
let key_for_wal = key.clone();
let lsn = self.apply_write(
record_size,
"delete",
|lsn, timestamp| Record::Delete {
key: key_for_wal,
lsn,
timestamp,
},
|inner, lsn, timestamp| {
let entry = MemtablePointEntry::Delete { timestamp, lsn };
inner
.tree
.entry(key)
.or_default()
.insert(Reverse(lsn), entry);
},
)?;
trace!("Delete operation completed with LSN: {}", lsn);
Ok(())
}
pub fn delete_range(&self, start: Vec<u8>, end: Vec<u8>) -> Result<(), MemtableError> {
trace!(
"delete_range() started, start key: {}, end key: {}",
HexKey(&start),
HexKey(&end)
);
if start.is_empty() || end.is_empty() {
return Err(MemtableError::InvalidArgument(
"Start or end key is empty".to_string(),
));
}
if start >= end {
return Err(MemtableError::InvalidArgument(
"Start key must be less than end key".to_string(),
));
}
let record_size = std::mem::size_of::<RangeTombstone>() + start.len() + end.len();
let start_for_wal = start.clone();
let end_for_wal = end.clone();
let lsn = self.apply_write(
record_size,
"delete_range",
|lsn, timestamp| Record::RangeDelete {
start: start_for_wal,
end: end_for_wal,
lsn,
timestamp,
},
|inner, lsn, timestamp| {
let entry_key = start.clone();
let tombstone = RangeTombstone {
start,
end,
lsn,
timestamp,
};
inner
.range_tombstones
.entry(entry_key)
.or_default()
.insert(Reverse(lsn), tombstone);
},
)?;
trace!("delete_range completed with LSN: {}", lsn);
Ok(())
}
/// Shared write path: capacity check, LSN + timestamp assignment, WAL append, then in-memory apply.
///
/// `build_record` produces the WAL record from the assigned `(lsn, timestamp)`;
/// `apply_to_inner` mutates the locked state with the same pair. Returns the assigned LSN.
///
/// # Errors
/// * [`MemtableError::FlushRequired`] if the tracked size plus `record_size` exceeds the buffer limit.
/// * [`MemtableError::Wal`] if the WAL append fails (the LSN has already been consumed by then).
/// * [`MemtableError::Internal`] if the state lock is poisoned.
///
/// NOTE(review): the capacity check runs under a read lock that is released before the write
/// lock is taken, so concurrent writers can each pass the check and together overshoot
/// `write_buffer_size` — confirm this soft limit is intended.
fn apply_write<F, G>(
    &self,
    record_size: usize,
    op_name: &str,
    build_record: F,
    apply_to_inner: G,
) -> Result<u64, MemtableError>
where
    F: FnOnce(u64, u64) -> Record,
    G: FnOnce(&mut MemtableInner, u64, u64),
{
    let lock_poisoned = || {
        error!("Read-write lock poisoned during {}", op_name);
        MemtableError::Internal("Read-write lock poisoned".into())
    };

    // Read-locked pre-check; the guard is dropped before any WAL I/O happens.
    let has_room = {
        let state = self.inner.read().map_err(|_| lock_poisoned())?;
        state.approximate_size + record_size <= state.write_buffer_size
    };
    if !has_room {
        return Err(MemtableError::FlushRequired);
    }

    let lsn = self.next_lsn.fetch_add(1, Ordering::SeqCst);
    let timestamp = Self::current_timestamp();

    // Log first, then make the write visible in memory.
    let wal_record = build_record(lsn, timestamp);
    self.wal.append(&wal_record)?;

    let mut state = self.inner.write().map_err(|_| lock_poisoned())?;
    apply_to_inner(&mut state, lsn, timestamp);
    state.approximate_size += record_size;
    Ok(lsn)
}
/// Looks up the newest visible state of `key`.
///
/// The newest point entry for the key is compared with the highest-LSN range tombstone
/// whose `[start, end)` interval contains the key; whichever has the higher LSN wins.
///
/// # Errors
/// Returns [`MemtableError::Internal`] if the state lock is poisoned.
pub fn get(&self, key: &[u8]) -> Result<MemtableGetResult, MemtableError> {
    trace!("get() started, key: {}", HexKey(key));
    let guard = self.inner.read().map_err(|_| {
        error!("Read-write lock poisoned during get");
        MemtableError::Internal("RwLock poisoned".into())
    })?;

    // Versions are ordered by `Reverse(lsn)`, so the first value is the newest.
    let newest_point = guard
        .tree
        .get(key)
        .and_then(|versions| versions.values().next());

    // Only tombstones whose start key is <= `key` can cover it. Borrowed bounds
    // avoid copying `key` into a fresh `Vec` on every lookup.
    let candidate_starts = (
        std::ops::Bound::<&[u8]>::Unbounded,
        std::ops::Bound::Included(key),
    );
    let covering_tombstone_lsn = guard
        .range_tombstones
        .range::<[u8], _>(candidate_starts)
        .filter_map(|(_start, versions)| {
            // Newest-first order: the first covering tombstone has the highest LSN
            // among those sharing this start key.
            versions
                .values()
                .find(|tombstone| {
                    tombstone.start.as_slice() <= key && key < tombstone.end.as_slice()
                })
                .map(|tombstone| tombstone.lsn)
        })
        .max();

    Ok(match (newest_point, covering_tombstone_lsn) {
        (None, None) => MemtableGetResult::NotFound,
        (None, Some(_)) => MemtableGetResult::RangeDelete,
        // A strictly newer range tombstone shadows the point entry.
        (Some(point), Some(tombstone_lsn)) if tombstone_lsn > point.lsn() => {
            MemtableGetResult::RangeDelete
        }
        (Some(MemtablePointEntry::Delete { .. }), _) => MemtableGetResult::Delete,
        (Some(MemtablePointEntry::Put { value, .. }), _) => MemtableGetResult::Put(value.clone()),
    })
}
/// Returns every record relevant to the half-open key range `[start, end)`.
///
/// The result holds all versions of each point key in the range plus every range tombstone
/// overlapping it, sorted by `Record::key()` ascending and then by LSN descending.
/// An empty or inverted range (`start >= end`) yields an empty iterator.
///
/// # Errors
/// Returns [`MemtableError::Internal`] if the state lock is poisoned.
pub fn scan(
    &self,
    start: &[u8],
    end: &[u8],
) -> Result<impl Iterator<Item = Record>, MemtableError> {
    trace!(
        "scan() started with range. Start key: {} end key: {}",
        HexKey(start),
        HexKey(end)
    );
    if start >= end {
        return Ok(Vec::new().into_iter());
    }
    let guard = self.inner.read().map_err(|_| {
        error!("Read-write lock poisoned during scan");
        MemtableError::Internal("RwLock poisoned".into())
    })?;

    let mut out = Vec::new();

    // Borrowed bounds avoid allocating owned copies of `start` and `end`.
    // `start < end` holds here, so `range` cannot panic on inverted bounds.
    let point_bounds = (
        std::ops::Bound::Included(start),
        std::ops::Bound::Excluded(end),
    );
    for (key, versions) in guard.tree.range::<[u8], _>(point_bounds) {
        out.extend(versions.values().map(|entry| match entry {
            MemtablePointEntry::Delete { lsn, timestamp } => Record::Delete {
                key: key.clone(),
                lsn: *lsn,
                timestamp: *timestamp,
            },
            MemtablePointEntry::Put {
                value,
                lsn,
                timestamp,
            } => Record::Put {
                key: key.clone(),
                value: value.clone(),
                lsn: *lsn,
                timestamp: *timestamp,
            },
        }));
    }

    // Tombstones are indexed by start key, and one starting at or after `end` cannot
    // overlap the range, so the walk stops there instead of visiting every tombstone.
    let tombstone_bounds = (
        std::ops::Bound::<&[u8]>::Unbounded,
        std::ops::Bound::Excluded(end),
    );
    for (_tombstone_start, versions) in guard.range_tombstones.range::<[u8], _>(tombstone_bounds)
    {
        for tombstone in versions.values() {
            if tombstone.end.as_slice() <= start || tombstone.start.as_slice() >= end {
                continue;
            }
            out.push(Record::RangeDelete {
                start: tombstone.start.clone(),
                end: tombstone.end.clone(),
                lsn: tombstone.lsn,
                timestamp: tombstone.timestamp,
            });
        }
    }

    // Merge points and tombstones: key ascending, newest LSN first within a key.
    out.sort_by(|a, b| a.key().cmp(b.key()).then_with(|| b.lsn().cmp(&a.lsn())));
    Ok(out.into_iter())
}
/// Collects the records to persist when flushing this memtable.
///
/// Emits only the newest version of every point key (in key order), followed by all
/// range tombstones (all versions, in start-key order).
///
/// # Errors
/// Returns [`MemtableError::Internal`] if the state lock is poisoned.
pub fn iter_for_flush(&self) -> Result<impl Iterator<Item = Record>, MemtableError> {
    let guard = self.inner.read().map_err(|_| {
        error!("Read-write lock poisoned during iter_for_flush");
        MemtableError::Internal("Read-write lock poisoned".into())
    })?;

    // Newest point version per key; keys without any version contribute nothing.
    let mut records: Vec<Record> = guard
        .tree
        .iter()
        .filter_map(|(key, versions)| {
            let newest = versions.values().next()?;
            Some(match newest {
                MemtablePointEntry::Put {
                    value,
                    lsn,
                    timestamp,
                } => Record::Put {
                    key: key.clone(),
                    value: value.clone(),
                    lsn: *lsn,
                    timestamp: *timestamp,
                },
                MemtablePointEntry::Delete { lsn, timestamp } => Record::Delete {
                    key: key.clone(),
                    lsn: *lsn,
                    timestamp: *timestamp,
                },
            })
        })
        .collect();

    // Every range tombstone is kept, using the index key as the range start.
    for (start, versions) in &guard.range_tombstones {
        records.extend(versions.values().map(|tombstone| Record::RangeDelete {
            start: start.clone(),
            end: tombstone.end.clone(),
            lsn: tombstone.lsn,
            timestamp: tombstone.timestamp,
        }));
    }

    Ok(records.into_iter())
}
#[allow(dead_code)]
pub fn stats(&self) -> Result<MemtableStats, MemtableError> {
let guard = self.inner.read().map_err(|_| {
error!("Read-write lock poisoned during stats");
MemtableError::Internal("Read-write lock poisoned".into())
})?;
let mut entry_count: usize = 0;
let mut tombstone_count: usize = 0;
for versions in guard.tree.values() {
for entry in versions.values() {
entry_count += 1;
if entry.is_delete() {
tombstone_count += 1;
}
}
}
let range_tombstone_count: usize = guard
.range_tombstones
.values()
.map(|versions| versions.len())
.sum();
Ok(MemtableStats {
size_bytes: guard.approximate_size,
key_count: guard.tree.len(),
entry_count,
tombstone_count,
range_tombstone_count,
})
}
/// Consumes this memtable and wraps it in a [`FrozenMemtable`].
///
/// The visible implementation always returns `Ok`.
pub fn frozen(self) -> Result<FrozenMemtable, MemtableError> {
Ok(FrozenMemtable::new(self))
}
/// Overwrites the LSN counter so the next write is assigned `lsn + 1` (saturating at `u64::MAX`).
///
/// The store is unconditional: a value lower than the current counter moves it backwards.
pub fn inject_max_lsn(&self, lsn: u64) {
self.next_lsn.store(lsn.saturating_add(1), Ordering::SeqCst);
}
/// Returns the highest LSN handed out so far, or `None` while the counter is still at 0 or 1.
pub fn max_lsn(&self) -> Option<u64> {
    match self.next_lsn.load(Ordering::SeqCst) {
        // The counter starts at 1 for an empty memtable, so nothing has been assigned yet.
        0 | 1 => None,
        next => Some(next - 1),
    }
}
/// Returns the sequence number reported by the underlying WAL.
pub fn wal_seq(&self) -> u64 {
self.wal.wal_seq()
}
/// Returns nanoseconds since the Unix epoch, or 0 if the system clock reads before the epoch.
fn current_timestamp() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // A clock set before 1970 falls back to a zero duration.
        Err(_) => std::time::Duration::default(),
    };
    // The nanosecond count is a u128; the cast keeps only the low 64 bits.
    since_epoch.as_nanos() as u64
}
}
/// A memtable taken by value and exposed through read-only accessors.
pub struct FrozenMemtable {
/// The wrapped memtable that all reads delegate to.
memtable: Memtable,
/// Timestamp captured when this wrapper was constructed.
#[allow(dead_code)]
creation_timestamp: u64,
}
impl FrozenMemtable {
/// Wraps `memtable` and records the current system-clock timestamp as the creation time.
pub fn new(memtable: Memtable) -> Self {
Self {
memtable,
creation_timestamp: Memtable::current_timestamp(),
}
}
/// Returns the sequence number reported by the wrapped memtable's WAL.
pub fn wal_seq(&self) -> u64 {
self.memtable.wal.wal_seq()
}
/// Returns the timestamp captured when this wrapper was constructed.
#[allow(dead_code)]
pub fn creation_timestamp(&self) -> u64 {
self.creation_timestamp
}
/// Point lookup; delegates to `Memtable::get`.
pub fn get(&self, key: &[u8]) -> Result<MemtableGetResult, MemtableError> {
self.memtable.get(key)
}
/// Range scan over `[start, end)`; delegates to `Memtable::scan`.
pub fn scan(
&self,
start: &[u8],
end: &[u8],
) -> Result<impl Iterator<Item = Record>, MemtableError> {
self.memtable.scan(start, end)
}
/// Records to persist on flush; delegates to `Memtable::iter_for_flush`.
pub fn iter_for_flush(&self) -> Result<impl Iterator<Item = Record>, MemtableError> {
self.memtable.iter_for_flush()
}
/// Highest LSN handed out by the wrapped memtable; delegates to `Memtable::max_lsn`.
pub fn max_lsn(&self) -> Option<u64> {
self.memtable.max_lsn()
}
}
/// Read-only view shared by [`Memtable`] and [`FrozenMemtable`].
#[allow(dead_code)]
pub trait ReadMemtable {
/// Looks up the newest visible state of `key`.
fn get(&self, key: &[u8]) -> Result<MemtableGetResult, MemtableError>;
/// Returns the records relevant to the key range given by `start` and `end` as a boxed iterator.
fn scan(
&self,
start: &[u8],
end: &[u8],
) -> Result<Box<dyn Iterator<Item = Record>>, MemtableError>;
/// Returns the highest LSN handed out so far, if any.
fn max_lsn(&self) -> Option<u64>;
}
impl ReadMemtable for Memtable {
fn get(&self, key: &[u8]) -> Result<MemtableGetResult, MemtableError> {
self.get(key)
}
fn scan(
&self,
start: &[u8],
end: &[u8],
) -> Result<Box<dyn Iterator<Item = Record>>, MemtableError> {
let records: Vec<_> = self.scan(start, end)?.collect();
Ok(Box::new(records.into_iter()))
}
fn max_lsn(&self) -> Option<u64> {
self.max_lsn()
}
}
impl ReadMemtable for FrozenMemtable {
fn get(&self, key: &[u8]) -> Result<MemtableGetResult, MemtableError> {
self.get(key)
}
fn scan(
&self,
start: &[u8],
end: &[u8],
) -> Result<Box<dyn Iterator<Item = Record>>, MemtableError> {
let records: Vec<_> = self.scan(start, end)?.collect();
Ok(Box::new(records.into_iter()))
}
fn max_lsn(&self) -> Option<u64> {
self.max_lsn()
}
}
/// Display adapter that renders a byte key as lowercase hex for log output.
///
/// Keys of up to 32 bytes are printed in full; longer keys show the first 16 bytes
/// followed by a `...[N bytes]` suffix with the total length.
struct HexKey<'a>(&'a [u8]);

impl<'a> std::fmt::Display for HexKey<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bytes = self.0;
        let truncated = bytes.len() > 32;
        let shown = if truncated { &bytes[..16] } else { bytes };

        shown
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))?;

        if truncated {
            write!(f, "...[{} bytes]", bytes.len())?;
        }
        Ok(())
    }
}