use bytes::{Buf, BufMut};
use crate::error::{Error, Result};
use crate::memtable::iter::OrderedMemtableKvIterator;
use crate::memtable::vlog::{RewrittenValuePlan, encode_rewritten_value};
use crate::memtable::{Memtable, MemtableReclaimer};
use crate::sst::row_codec::{encode_key_ref_into, encode_value_ref_into};
use crate::r#type::{RefKey, RefValue};
use std::collections::BTreeMap;
/// Ordered key/value iterator produced by `HashMemtable::iter`; an alias for
/// the shared `OrderedMemtableKvIterator`.
pub(crate) type MemtableKvIterator<'a> = OrderedMemtableKvIterator<'a>;
/// Hash-indexed memtable that keeps all of its state in one byte buffer.
///
/// Buffer layout, from low to high offsets:
/// - `[0, data_end)`: KV records appended upward, each encoded as
///   `key_len: u32 BE | value_len: u32 BE | key | value`.
/// - `[data_end, index_cursor)`: free space.
/// - `[index_cursor, bucket_base)`: index nodes and blobs allocated downward;
///   an index node is `hash: u64 BE | record offset: u32 BE | next node: u32 BE`.
/// - `[bucket_base, buffer.len())`: bucket table of little-endian `u32` chain
///   heads, where `u32::MAX` marks an empty chain.
pub(crate) struct HashMemtable {
/// Backing storage for records, index nodes, blobs and the bucket table.
buffer: Vec<u8>,
/// Exclusive end of the upward-growing record region.
data_end: usize,
/// Start of the downward-growing index/blob region.
index_cursor: usize,
/// Offset where the bucket table begins; not modified after construction.
bucket_base: usize,
/// Number of hash buckets in the bucket table.
bucket_count: usize,
/// Optional callback invoked on drop with the buffer length in bytes.
reclaimer: Option<MemtableReclaimer>,
}
/// Iterator over every value stored for one key, newest first, as returned by
/// `get_all`. It walks the key's bucket chain lazily.
pub(crate) struct MemtableValueIter<'a> {
/// Memtable whose bucket chain is being walked.
mem: &'a HashMemtable,
/// Owned copy of the key being looked up.
key: Vec<u8>,
/// Offset of the next index node to visit; `u32::MAX` once exhausted.
next_node: u32,
/// Bucket the key hashes to. NOTE(review): not read anywhere in this file.
bucket: usize,
}
impl HashMemtable {
    /// Creates a memtable backed by a zero-filled buffer of `capacity` bytes,
    /// sizing the bucket table from the capacity (see `default_bucket_count`).
    ///
    /// # Panics
    /// Panics if `capacity` does not exceed the bucket table size, or if it is
    /// larger than `u32::MAX` (offsets and lengths are stored as `u32`).
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        let bucket_count = Self::default_bucket_count(capacity);
        Self::with_capacity_and_buckets(capacity, bucket_count)
    }

    /// Creates a memtable that reuses `buffer` as its backing storage, using
    /// `buffer.len()` as the capacity.
    ///
    /// Only the bucket table is reinitialised; every other region is written
    /// before it is read, so stale bytes from a recycled buffer are harmless.
    ///
    /// # Panics
    /// Same conditions as `with_capacity`, applied to `buffer.len()`.
    pub(crate) fn with_buffer(buffer: Vec<u8>) -> Self {
        let bucket_count = Self::default_bucket_count(buffer.len());
        let (bucket_count, bucket_base) = Self::bucket_layout(buffer.len(), bucket_count);
        Self::from_layout(buffer, bucket_count, bucket_base)
    }

    /// Like `with_buffer`, but also registers a reclaimer that is invoked with
    /// the buffer length when the memtable is dropped.
    pub(crate) fn with_buffer_and_reclaimer(buffer: Vec<u8>, reclaimer: MemtableReclaimer) -> Self {
        let mut memtable = Self::with_buffer(buffer);
        memtable.reclaimer = Some(reclaimer);
        memtable
    }

    /// Creates a memtable with an explicit bucket count (raised to at least 1).
    ///
    /// The geometry is validated before the buffer is allocated, so an invalid
    /// request panics without first allocating `capacity` bytes.
    fn with_capacity_and_buckets(capacity: usize, bucket_count: usize) -> Self {
        let (bucket_count, bucket_base) = Self::bucket_layout(capacity, bucket_count);
        Self::from_layout(vec![0u8; capacity], bucket_count, bucket_base)
    }

    /// Validates the buffer geometry and returns `(bucket_count, bucket_base)`.
    ///
    /// # Panics
    /// Panics if the bucket table leaves no room for data, or if `capacity`
    /// exceeds `u32::MAX`.
    fn bucket_layout(capacity: usize, bucket_count: usize) -> (usize, usize) {
        let bucket_count = bucket_count.max(1);
        // Saturate so an absurd bucket count trips the assertion below instead
        // of wrapping around to a small table size.
        let bucket_table_bytes = bucket_count.saturating_mul(4);
        assert!(
            capacity > bucket_table_bytes,
            "capacity must exceed bucket table bytes"
        );
        // Record offsets, index-node offsets and record lengths are stored as
        // u32, and u32::MAX marks an empty chain. Every offset is strictly below
        // `bucket_base < capacity`, so this bound keeps them all representable
        // and distinct from the sentinel.
        assert!(
            u32::try_from(capacity).is_ok(),
            "capacity must fit in u32 offsets"
        );
        (bucket_count, capacity - bucket_table_bytes)
    }

    /// Wraps `buffer` with empty record and index regions and a freshly
    /// initialised bucket table starting at `bucket_base`.
    fn from_layout(mut buffer: Vec<u8>, bucket_count: usize, bucket_base: usize) -> Self {
        Self::init_bucket_table(&mut buffer, bucket_base);
        Self {
            buffer,
            data_end: 0,
            index_cursor: bucket_base,
            bucket_base,
            bucket_count,
            reclaimer: None,
        }
    }

    /// Roughly one bucket per 128 bytes of capacity, clamped to `[4, 1024]`.
    fn default_bucket_count(capacity: usize) -> usize {
        (capacity / 128).clamp(4, 1024)
    }

    /// Marks every bucket in the table starting at `bucket_base` as empty
    /// (`u32::MAX`, little-endian).
    fn init_bucket_table(buffer: &mut [u8], bucket_base: usize) {
        // The table spans exactly `bucket_count * 4` bytes, so every slot is full.
        for slot in buffer[bucket_base..].chunks_exact_mut(4) {
            slot.copy_from_slice(&u32::MAX.to_le_bytes());
        }
    }

    /// 64-bit FNV-1a hash of `key`.
    fn hash_key(key: &[u8]) -> u64 {
        const FNV_OFFSET: u64 = 0xcbf29ce484222325;
        const FNV_PRIME: u64 = 0x100000001b3;
        key.iter()
            .fold(FNV_OFFSET, |hash, &b| (hash ^ u64::from(b)).wrapping_mul(FNV_PRIME))
    }

    /// Bytes a record occupies in the data region: two `u32` length headers
    /// followed by the key and value bytes.
    fn entry_size(key_len: usize, value_len: usize) -> usize {
        4 + 4 + key_len + value_len
    }

    /// Size of one index node: `hash: u64`, record offset `u32`, next node `u32`.
    fn index_entry_size() -> usize {
        8 + 4 + 4
    }

    /// Checks that a record of `data_len` bytes plus its index node fit in the
    /// free gap between the record region and the index/blob region.
    ///
    /// # Errors
    /// Returns `Error::MemtableFull` (with the bytes needed and remaining) when
    /// they do not fit.
    fn has_space(&self, data_len: usize) -> Result<()> {
        let need = data_len + Self::index_entry_size();
        if self.data_end + need > self.index_cursor {
            return Err(Error::MemtableFull {
                needed: need,
                remaining: self.index_cursor.saturating_sub(self.data_end),
            });
        }
        Ok(())
    }

    /// Appends `key_len | value_len | key | value` (lengths big-endian) at
    /// `data_end` and returns the record's offset. The caller must have
    /// checked `has_space` first.
    fn write_data(&mut self, key: &[u8], value: &[u8]) -> usize {
        let start = self.data_end;
        let end = start + Self::entry_size(key.len(), value.len());
        let mut slice = &mut self.buffer[start..end];
        // The record fits in a buffer whose size is bounded by u32::MAX at
        // construction, so both lengths convert losslessly.
        slice.put_u32(key.len() as u32);
        slice.put_u32(value.len() as u32);
        slice.put_slice(key);
        slice.put_slice(value);
        self.data_end = end;
        start
    }

    /// Writes a record whose key and value are encoded in place by the row
    /// codec, using the pre-computed encoded lengths for the headers.
    ///
    /// Returns `(record_offset, key_offset)`; the key starts right after the
    /// two 4-byte length headers. The caller must have checked `has_space`.
    fn write_data_ref(
        &mut self,
        key: &RefKey<'_>,
        value: &RefValue<'_>,
        num_columns: usize,
        key_len: usize,
        value_len: usize,
    ) -> (usize, usize) {
        let start = self.data_end;
        let end = start + Self::entry_size(key_len, value_len);
        let mut slice = &mut self.buffer[start..end];
        slice.put_u32(key_len as u32);
        slice.put_u32(value_len as u32);
        encode_key_ref_into(key, &mut slice);
        encode_value_ref_into(value, num_columns, &mut slice);
        self.data_end = end;
        (start, start + 8)
    }

    /// Reads the (little-endian) chain head stored for `bucket`.
    fn bucket_head(&self, bucket: usize) -> u32 {
        let pos = self.bucket_base + bucket * 4;
        let mut slot = &self.buffer[pos..pos + 4];
        slot.get_u32_le()
    }

    /// Stores `head` (little-endian) as the chain head for `bucket`.
    fn set_bucket_head(&mut self, bucket: usize, head: u32) {
        let pos = self.bucket_base + bucket * 4;
        let mut slot = &mut self.buffer[pos..pos + 4];
        slot.put_u32_le(head);
    }

    /// Allocates an index node just below `index_cursor`. The node points at
    /// the record at `key_offset` and links to the bucket's current head.
    ///
    /// Returns the node's offset. The caller publishes it with
    /// `set_bucket_head`, which keeps chains ordered newest-first.
    fn write_index(&mut self, bucket: usize, hash: u64, key_offset: u32) -> u32 {
        let start = self.index_cursor - Self::index_entry_size();
        let head = self.bucket_head(bucket);
        let mut node = &mut self.buffer[start..self.index_cursor];
        node.put_u64(hash);
        node.put_u32(key_offset);
        node.put_u32(head);
        self.index_cursor = start;
        // `start < bucket_base < capacity <= u32::MAX` (checked at construction),
        // so the cast is lossless and can never equal the empty-chain sentinel.
        start as u32
    }

    /// Maps a hash to a bucket index. The modulo is taken in `u64` so all hash
    /// bits participate even where `usize` is only 32 bits wide.
    fn bucket_index_from_hash(&self, hash: u64) -> usize {
        (hash % self.bucket_count as u64) as usize
    }
}
/// `Memtable` implementation over a single buffer: KV records grow upward from
/// offset 0, while index nodes and blobs grow downward toward them from the
/// bucket table at the end of the buffer.
impl Memtable for HashMemtable {
/// Appends a raw `key`/`value` record and links it at the head of its bucket
/// chain, so the newest version of a key is found first.
///
/// Returns `Error::MemtableFull`, without writing anything, when the record
/// plus its index node do not fit.
fn put(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
let data_len = Self::entry_size(key.len(), value.len());
self.has_space(data_len)?;
let data_offset = self.write_data(key, value);
let hash = Self::hash_key(key);
let bucket = self.bucket_index_from_hash(hash);
let node_off = self.write_index(bucket, hash, data_offset as u32);
// The new node already links to the previous head; publish it as the head.
self.set_bucket_head(bucket, node_off);
Ok(())
}
/// Encodes `key` and `value` directly into the record region and indexes the
/// record like `put`.
///
/// The hash is computed over the *encoded* key bytes as written to the
/// buffer. NOTE(review): lookups via `get`/`get_all` presumably need the same
/// encoded key form — confirm with callers.
fn put_ref(
&mut self,
key: &RefKey<'_>,
value: &RefValue<'_>,
num_columns: usize,
) -> Result<()> {
// Assumes the row codec writes exactly `encoded_len` bytes for each part —
// TODO confirm; the headers below are written from these lengths.
let key_len = key.encoded_len();
let value_len = value.encoded_len(num_columns);
let data_len = Self::entry_size(key_len, value_len);
self.has_space(data_len)?;
let (data_offset, key_offset) =
self.write_data_ref(key, value, num_columns, key_len, value_len);
// Hash the encoded key bytes as they now sit in the buffer.
let hash = Self::hash_key(&self.buffer[key_offset..key_offset + key_len]);
let bucket = self.bucket_index_from_hash(hash);
let node_off = self.write_index(bucket, hash, data_offset as u32);
self.set_bucket_head(bucket, node_off);
Ok(())
}
/// Like `put_ref`, but the value bytes come from a `RewrittenValuePlan`
/// encoded via `encode_rewritten_value`.
fn put_ref_rewritten(
&mut self,
key: &RefKey<'_>,
plan: &RewrittenValuePlan<'_>,
num_columns: usize,
) -> Result<()> {
let value_len = plan.encoded_len(num_columns);
let key_len = key.encoded_len();
let data_len = Self::entry_size(key_len, value_len);
self.has_space(data_len)?;
let start = self.data_end;
let end = start + data_len;
let mut slice = &mut self.buffer[start..end];
// Record header: big-endian key and value lengths.
slice.put_u32(key_len as u32);
slice.put_u32(value_len as u32);
encode_key_ref_into(key, &mut slice);
// Assumes the key encoder consumed exactly `key_len` bytes, leaving exactly
// `value_len` bytes in `slice` — TODO confirm.
encode_rewritten_value(plan, num_columns, &mut slice[..value_len]);
self.data_end = end;
// The key starts right after the two 4-byte length headers.
let key_offset = start + 8;
let hash = Self::hash_key(&self.buffer[key_offset..key_offset + key_len]);
let bucket = self.bucket_index_from_hash(hash);
let node_off = self.write_index(bucket, hash, start as u32);
self.set_bucket_head(bucket, node_off);
Ok(())
}
/// Returns the newest value stored for `key`, or `None`.
///
/// Walks the key's bucket chain from its head (newest node first) and returns
/// the first record whose hash and key bytes both match. Node and record
/// offsets are bounds-checked, so a bad offset ends or skips the search
/// instead of panicking.
fn get(&self, key: &[u8]) -> Option<&[u8]> {
let hash = Self::hash_key(key);
let bucket = self.bucket_index_from_hash(hash);
let mut node_off = self.bucket_head(bucket);
while node_off != u32::MAX {
let start = node_off as usize;
if start + Self::index_entry_size() > self.buffer.len() {
break;
}
// Index node: hash (u64 BE), record offset (u32 BE), next node (u32 BE).
let mut node_slice = &self.buffer[start..start + Self::index_entry_size()];
let h = node_slice.get_u64();
let key_off = node_slice.get_u32() as usize;
let next = node_slice.get_u32();
// Compare the cheap hash first, then confirm on the actual key bytes.
if h == hash && key_off + 8 <= self.data_end {
let mut slice = &self.buffer[key_off..self.data_end];
let key_len = slice.get_u32() as usize;
let value_len = slice.get_u32() as usize;
if key_len + value_len <= slice.remaining() && slice[..key_len] == *key {
let value_start = key_len;
let value_end = value_start + value_len;
return Some(&slice[value_start..value_end]);
}
}
node_off = next;
}
None
}
/// Returns an iterator over every value stored for `key`, newest first.
/// The key is copied, so the iterator does not borrow the caller's slice.
fn get_all(&self, key: &[u8]) -> MemtableValueIter<'_> {
let bucket = self.bucket_index_from_hash(Self::hash_key(key));
let head = self.bucket_head(bucket);
MemtableValueIter {
mem: self,
key: key.to_vec(),
next_node: head,
bucket,
}
}
/// Free bytes between the end of the record region and the start of the
/// index/blob region.
fn remaining_capacity(&self) -> usize {
self.index_cursor.saturating_sub(self.data_end)
}
/// True when no KV record has been written. Blobs appended via `append_blob`
/// do not count.
fn is_empty(&self) -> bool {
self.data_end == 0
}
/// Copies `data` into the downward-growing region just below the current
/// index cursor (shared with index nodes) and returns its offset.
///
/// Returns `Error::MemtableFull` if the payload does not fit in the free gap.
fn append_blob(&mut self, data: &[u8]) -> Result<usize> {
if self.data_end + data.len() > self.index_cursor {
return Err(Error::MemtableFull {
needed: data.len(),
remaining: self.index_cursor.saturating_sub(self.data_end),
});
}
let start = self.index_cursor - data.len();
self.buffer[start..self.index_cursor].copy_from_slice(data);
self.index_cursor = start;
Ok(start)
}
/// Returns `len` bytes at `offset` if the range lies within the index/blob
/// region `[index_cursor, bucket_base)`, and `None` otherwise (including on
/// arithmetic overflow).
///
/// That region also holds index nodes, so this guards against reading free
/// space or the bucket table, not against reading a node.
fn read_blob(&self, offset: usize, len: usize) -> Option<&[u8]> {
let end = offset.checked_add(len)?;
if offset < self.index_cursor || end > self.bucket_base {
return None;
}
Some(&self.buffer[offset..end])
}
/// Writes each recorded blob payload to `writer` via `add_value`, in
/// ascending key order of `entries`.
///
/// Each `entries` value is `(payload_start, payload_len)` within this
/// memtable; the key is presumably a VLOG entry offset — confirm with the
/// recorder.
///
/// # Errors
/// `Error::IoError` if a payload range lies outside the blob region, or any
/// error returned by `add_value`.
fn flush_blobs_to_vlog_writer(
&self,
entries: &BTreeMap<u32, (usize, usize)>,
writer: &mut crate::vlog::VlogWriter<Box<dyn crate::file::SequentialWriteFile>>,
) -> Result<()> {
for (payload_start, payload_len) in entries.values() {
let payload = self
.read_blob(*payload_start, *payload_len)
.ok_or_else(|| {
Error::IoError(format!(
"VLOG recorder payload out of range at {} (len {})",
payload_start, payload_len
))
})?;
writer.add_value(payload)?;
}
Ok(())
}
/// Writes every recorded payload whose entry key is `>= offset` to `writer`
/// as a little-endian `u32` length prefix followed by the payload bytes.
///
/// Returns the total number of bytes written. Earlier entries may already
/// have been written when an error is returned.
///
/// # Errors
/// - `Error::IoError` on an overflowing or out-of-range payload range.
/// - `Error::IoError` on a payload longer than `u32::MAX` bytes.
/// - Any error returned by `writer.write`.
fn write_vlog_data_since(
&self,
entries: &BTreeMap<u32, (usize, usize)>,
offset: u32,
writer: &mut dyn crate::file::SequentialWriteFile,
) -> Result<usize> {
let mut written = 0usize;
for (_entry_offset, (payload_start, payload_len)) in entries.range(offset..) {
let end = payload_start
.checked_add(*payload_len)
.ok_or_else(|| Error::IoError("VLOG payload range overflow".to_string()))?;
// Same bounds as `read_blob`: the payload must lie in the index/blob region.
if *payload_start < self.index_cursor || end > self.bucket_base {
return Err(Error::IoError(format!(
"VLOG recorder payload out of range at {} (len {})",
payload_start, payload_len
)));
}
let len_u32 = u32::try_from(*payload_len).map_err(|_| {
Error::IoError(format!("VLOG value too large: {} bytes", payload_len))
})?;
writer.write(&len_u32.to_le_bytes())?;
writer.write(&self.buffer[*payload_start..end])?;
written = written.saturating_add(4 + *payload_len);
}
Ok(written)
}
/// Returns the current index/blob cursor, so that a later
/// `rollback_blob_cursor` can discard blobs appended after this point.
fn blob_cursor_checkpoint(&self) -> usize {
self.index_cursor
}
/// Restores the index/blob cursor saved by `blob_cursor_checkpoint`.
///
/// NOTE(review): index nodes share this region. Suppose a `put*` succeeded
/// after the checkpoint:
/// - its node lies below `checkpoint`;
/// - it is still reachable from a bucket head;
/// - later allocations would overwrite it.
/// Presumably callers only roll back when no put succeeded in between —
/// confirm.
fn rollback_blob_cursor(&mut self, checkpoint: usize) {
self.index_cursor = checkpoint;
}
/// Current exclusive end of the record region.
fn data_offset(&self) -> usize {
self.data_end
}
/// Writes the raw record bytes `[offset, data_end)` to `writer` and returns
/// the number of bytes written.
///
/// # Errors
/// `Error::InvalidState` if `offset` is past the end of the record region,
/// or any error returned by `writer.write`.
fn write_data_since(
&self,
offset: usize,
writer: &mut dyn crate::file::SequentialWriteFile,
) -> Result<usize> {
if offset > self.data_end {
return Err(Error::InvalidState(format!(
"invalid memtable data offset {} > {}",
offset, self.data_end
)));
}
let bytes = &self.buffer[offset..self.data_end];
writer.write(bytes)?;
Ok(bytes.len())
}
/// Scans the record region from offset 0 and returns an iterator over all
/// `(key, value)` records, each passed along with its record offset.
///
/// Ordering is decided by `MemtableKvIterator::new`; the tests in this file
/// expect ascending keys with newer duplicates first. The scan stops at the
/// first truncated record.
fn iter(&self) -> MemtableKvIterator<'_> {
let mut entries: Vec<(&[u8], &[u8], usize)> = Vec::new();
let mut offset = 0;
while offset < self.data_end {
if offset + 8 > self.data_end {
break;
}
let mut slice = &self.buffer[offset..self.data_end];
let key_len = slice.get_u32() as usize;
let value_len = slice.get_u32() as usize;
if key_len + value_len > slice.remaining() {
break;
}
entries.push((
&slice[..key_len],
&slice[key_len..key_len + value_len],
offset,
));
offset += Self::entry_size(key_len, value_len);
}
MemtableKvIterator::new(entries)
}
/// Iterator returned by `get_all`.
type ValueIter<'a>
= MemtableValueIter<'a>
where
Self: 'a;
/// Iterator returned by `iter`.
type KvIter<'a> = MemtableKvIterator<'a>;
}
impl<'a> Iterator for MemtableValueIter<'a> {
    type Item = &'a [u8];

    /// Advances along the bucket chain and yields the next value whose key
    /// equals `self.key`, newest first.
    ///
    /// Chains are prepended on insert, which is why results come newest first.
    /// A node offset that would read past the end of the buffer terminates the
    /// iteration instead of panicking.
    fn next(&mut self) -> Option<Self::Item> {
        if self.next_node == u32::MAX {
            return None;
        }
        // The target hash is loop-invariant: compute it once per call rather
        // than once for every chain node visited.
        let target_hash = HashMemtable::hash_key(&self.key);
        let mem = self.mem;
        let node_size = HashMemtable::index_entry_size();
        while self.next_node != u32::MAX {
            let start = self.next_node as usize;
            if start + node_size > mem.buffer.len() {
                self.next_node = u32::MAX;
                return None;
            }
            // Index node: hash (u64 BE), record offset (u32 BE), next node (u32 BE).
            let mut node = &mem.buffer[start..start + node_size];
            let hash = node.get_u64();
            let key_off = node.get_u32() as usize;
            self.next_node = node.get_u32();
            // Compare the cheap hash first, then confirm on the key bytes.
            if hash == target_hash && key_off + 8 <= mem.data_end {
                let mut record = &mem.buffer[key_off..mem.data_end];
                let key_len = record.get_u32() as usize;
                let value_len = record.get_u32() as usize;
                if key_len + value_len <= record.remaining()
                    && record[..key_len] == self.key[..]
                {
                    return Some(&record[key_len..key_len + value_len]);
                }
            }
        }
        None
    }
}
impl Drop for HashMemtable {
    /// Hands the size of the backing buffer, in bytes, to the registered
    /// reclaimer (presumably for memory accounting). Does nothing when no
    /// reclaimer is set.
    fn drop(&mut self) {
        let Some(reclaim) = self.reclaimer.as_ref() else {
            return;
        };
        let released_bytes = self.buffer.len() as u64;
        reclaim(released_bytes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iterator::KvIterator;

    #[test]
    fn put_and_get() {
        let mut mem = HashMemtable::with_capacity(1024);
        mem.put(b"key1", b"value1").unwrap();
        mem.put(b"key2", b"value2").unwrap();
        assert_eq!(mem.get(b"key1").unwrap(), b"value1");
        assert_eq!(mem.get(b"key2").unwrap(), b"value2");
        assert!(mem.get(b"missing").is_none());
    }

    #[test]
    fn overwrite_updates_value() {
        let mut mem = HashMemtable::with_capacity(1024);
        mem.put(b"key", b"old").unwrap();
        mem.put(b"key", b"new").unwrap();
        assert_eq!(mem.get(b"key").unwrap(), b"new");
    }

    #[test]
    fn capacity_enforced() {
        let mut mem = HashMemtable::with_capacity(64);
        mem.put(b"k1", b"v1").unwrap();
        let remaining = mem.remaining_capacity();
        let err = mem.put(b"k2", b"value_too_long").unwrap_err();
        assert!(
            matches!(err, Error::MemtableFull { .. }),
            "unexpected error type"
        );
        // A rejected put must leave the memtable untouched.
        assert_eq!(mem.remaining_capacity(), remaining);
        assert_eq!(mem.get(b"k1").unwrap(), b"v1");
        assert!(mem.get(b"k2").is_none());
    }

    #[test]
    fn remaining_capacity_updates() {
        let mut mem = HashMemtable::with_capacity(100);
        let before = mem.remaining_capacity();
        mem.put(b"k", b"v").unwrap();
        assert!(mem.remaining_capacity() < before);
    }

    #[test]
    fn bucket_distribution_and_lookup() {
        let mut mem = HashMemtable::with_capacity_and_buckets(256, 4);
        mem.put(b"key1", b"v1").unwrap();
        mem.put(b"key2", b"v2").unwrap();
        mem.put(b"key3", b"v3").unwrap();
        assert_eq!(mem.get(b"key1").unwrap(), b"v1");
        assert_eq!(mem.get(b"key2").unwrap(), b"v2");
        assert_eq!(mem.get(b"key3").unwrap(), b"v3");
    }

    #[test]
    fn with_buffer_ignores_stale_contents() {
        // A recycled buffer may hold arbitrary bytes; the bucket table must be
        // reinitialised so lookups never follow stale chains.
        let mut mem = HashMemtable::with_buffer(vec![0xAB; 1024]);
        assert!(mem.is_empty());
        assert!(mem.get(b"key").is_none());
        mem.put(b"key", b"value").unwrap();
        assert_eq!(mem.get(b"key").unwrap(), b"value");
    }

    #[test]
    fn get_all_returns_latest_first() {
        let mut mem = HashMemtable::with_capacity(512);
        mem.put(b"key", b"v1").unwrap();
        mem.put(b"key", b"v2").unwrap();
        mem.put(b"key", b"v3").unwrap();
        let mut iter = mem.get_all(b"key");
        assert_eq!(iter.next().unwrap(), b"v3");
        assert_eq!(iter.next().unwrap(), b"v2");
        assert_eq!(iter.next().unwrap(), b"v1");
        assert!(iter.next().is_none());
    }

    #[test]
    fn get_all_for_missing_key_is_empty() {
        let mut mem = HashMemtable::with_capacity(512);
        mem.put(b"present", b"v").unwrap();
        assert!(mem.get_all(b"absent").next().is_none());
    }

    #[test]
    fn kv_iterator_orders_keys_and_values() {
        let mut mem = HashMemtable::with_capacity(1024);
        mem.put(b"b", b"v1").unwrap();
        mem.put(b"a", b"x1").unwrap();
        mem.put(b"a", b"x2").unwrap();
        mem.put(b"c", b"z1").unwrap();
        let mut iter = mem.iter();
        iter.seek_to_first().unwrap();
        let mut collected = Vec::new();
        while iter.next().unwrap() {
            let k = iter.take_key().unwrap().unwrap();
            let v = iter.take_value().unwrap().unwrap().unwrap_encoded();
            collected.push((k, v));
        }
        let expected: Vec<(&[u8], &[u8])> =
            vec![(b"a", b"x2"), (b"a", b"x1"), (b"b", b"v1"), (b"c", b"z1")];
        assert_eq!(collected.len(), expected.len());
        for (got, exp) in collected.iter().zip(expected.iter()) {
            assert_eq!(got.0.as_ref(), exp.0);
            assert_eq!(got.1.as_ref(), exp.1);
        }
    }

    #[test]
    fn blob_storage_does_not_affect_kv_iteration() {
        let mut mem = HashMemtable::with_capacity(512);
        mem.put(b"k1", b"v1").unwrap();
        let blob_offset = mem.append_blob(b"blob-payload").unwrap();
        mem.put(b"k2", b"v2").unwrap();
        assert_eq!(
            mem.read_blob(blob_offset, "blob-payload".len()).unwrap(),
            b"blob-payload"
        );
        let mut iter = mem.iter();
        iter.seek_to_first().unwrap();
        let mut entries = Vec::new();
        while iter.next().unwrap() {
            let key = iter.take_key().unwrap().unwrap();
            let value = iter.take_value().unwrap().unwrap().unwrap_encoded();
            entries.push((key, value));
        }
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].0.as_ref(), b"k1");
        assert_eq!(entries[0].1.as_ref(), b"v1");
        assert_eq!(entries[1].0.as_ref(), b"k2");
        assert_eq!(entries[1].1.as_ref(), b"v2");
    }

    #[test]
    fn read_blob_rejects_out_of_range() {
        let mut mem = HashMemtable::with_capacity(512);
        let payload = b"blob";
        let offset = mem.append_blob(payload).unwrap();
        assert_eq!(mem.read_blob(offset, payload.len()).unwrap(), payload);
        // Starting below the blob cursor would read unallocated space.
        assert!(mem.read_blob(offset - 1, 1).is_none());
        // Running past the blob region would read the bucket table.
        assert!(mem.read_blob(offset, payload.len() + 1).is_none());
        // An overflowing range is rejected rather than panicking.
        assert!(mem.read_blob(usize::MAX, 2).is_none());
    }

    #[test]
    fn append_blob_rejects_oversized_payload() {
        let mut mem = HashMemtable::with_capacity(64);
        let remaining = mem.remaining_capacity();
        let too_big = vec![0u8; remaining + 1];
        assert!(matches!(
            mem.append_blob(&too_big),
            Err(Error::MemtableFull { .. })
        ));
        assert_eq!(mem.remaining_capacity(), remaining);
    }

    #[test]
    fn rollback_blob_cursor_discards_blobs() {
        let mut mem = HashMemtable::with_capacity(512);
        let checkpoint = mem.blob_cursor_checkpoint();
        let before = mem.remaining_capacity();
        mem.append_blob(b"scratch").unwrap();
        assert!(mem.remaining_capacity() < before);
        mem.rollback_blob_cursor(checkpoint);
        assert_eq!(mem.remaining_capacity(), before);
    }
}