use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use dashmap::DashMap;
use parking_lot::{Mutex, RwLock};
use smallvec::SmallVec;
use crate::durable_storage::InlineKey;
use sochdb_core::Result;
/// Hot-buffer capacity (in entries) used by `TieredMemTable::new`.
pub const DEFAULT_HOT_BUFFER_CAPACITY: usize = 100_000;
/// Fraction of the hot-buffer capacity at which the buffer becomes eligible
/// for an automatic flush (see `should_flush` / `try_flush`).
pub const FLUSH_THRESHOLD_RATIO: f64 = 0.8;
/// One versioned write held by the memtable.
#[derive(Debug, Clone)]
pub struct HotEntry {
/// Key bytes.
pub key: InlineKey,
/// Value bytes. Scans emit nothing for entries whose value is `None`
/// (presumably a deletion marker — confirm against callers).
pub value: Option<Vec<u8>>,
/// Identifier of the transaction that produced this write.
pub txn_id: u64,
/// Sequence number assigned when the write was accepted; a larger value
/// means a later write.
pub seq: u64,
}
impl HotEntry {
pub fn new(key: InlineKey, value: Option<Vec<u8>>, txn_id: u64, seq: u64) -> Self {
Self {
key,
value,
txn_id,
seq,
}
}
}
/// A batch of entries ordered by key (ascending) and, within one key, by
/// sequence number (descending), as produced by `SortedBatch::from_unsorted`.
#[derive(Debug)]
pub struct SortedBatch {
/// Entries in sorted order; several versions of the same key may be present.
entries: Vec<HotEntry>,
/// Maps the 64-bit hash of a key to the index of the first entry of that
/// key's run in `entries`. Distinct keys with the same hash share one slot,
/// so lookups must re-check the key.
key_index: DashMap<u64, usize>,
/// Smallest `seq` in the batch (`u64::MAX` when the batch is empty).
min_ts: u64,
/// Largest `seq` in the batch (`0` when the batch is empty).
max_ts: u64,
}
impl SortedBatch {
/// Builds a batch from entries given in arbitrary order.
///
/// Entries are sorted by key ascending and, for equal keys, by `seq`
/// descending, so the first entry of every key run is the most recent write.
/// The hash index records the position of that first entry for each key, and
/// the `seq` range of the whole batch is captured for `timestamp_range`.
pub fn from_unsorted(mut entries: Vec<HotEntry>) -> Self {
    if entries.is_empty() {
        return Self {
            entries: Vec::new(),
            key_index: DashMap::new(),
            min_ts: u64::MAX,
            max_ts: 0,
        };
    }

    entries.sort_unstable_by(|lhs, rhs| {
        lhs.key
            .as_slice()
            .cmp(rhs.key.as_slice())
            .then_with(|| rhs.seq.cmp(&lhs.seq))
    });

    // Index the first position of every run of equal keys.
    let key_index = DashMap::new();
    for (pos, current) in entries.iter().enumerate() {
        let starts_run =
            pos == 0 || entries[pos - 1].key.as_slice() != current.key.as_slice();
        if starts_run {
            key_index.insert(Self::hash_key(current.key.as_slice()), pos);
        }
    }

    // One pass for both ends of the sequence-number range.
    let (min_ts, max_ts) = entries
        .iter()
        .fold((u64::MAX, 0u64), |(lo, hi), e| (lo.min(e.seq), hi.max(e.seq)));

    Self {
        entries,
        key_index,
        min_ts,
        max_ts,
    }
}
/// Hashes a key with the 64-bit XXH3 function from `twox_hash`; the result is
/// the lookup key used by `key_index`.
#[inline]
fn hash_key(key: &[u8]) -> u64 {
twox_hash::xxh3::hash64(key)
}
/// Returns the newest entry (highest `seq`) stored for `key`, or `None` when
/// the batch holds no entry for it.
///
/// The hash index is tried first. If its slot is missing or belongs to a
/// different key with the same hash, the lookup falls back to a binary search
/// for the *first* entry of the key's run. `binary_search_by` would be wrong
/// here: with several versions of a key present it may land on any of them,
/// not necessarily the newest.
pub fn get(&self, key: &[u8]) -> Option<&HotEntry> {
    let hash = Self::hash_key(key);
    if let Some(slot) = self.key_index.get(&hash) {
        if let Some(entry) = self.entries.get(*slot) {
            if entry.key.as_slice() == key {
                return Some(entry);
            }
        }
    }

    // Entries are ordered by key ascending, then seq descending, so the lower
    // bound of the key's run is its newest version.
    let first = self.entries.partition_point(|e| e.key.as_slice() < key);
    self.entries
        .get(first)
        .filter(|entry| entry.key.as_slice() == key)
}
/// Iterates, in sorted order, over every entry whose key begins with `prefix`.
///
/// The start of the range is located by binary search; iteration stops at the
/// first key that no longer carries the prefix. All stored versions of a
/// matching key are yielded, newest first.
pub fn prefix_scan(&self, prefix: &[u8]) -> impl Iterator<Item = &HotEntry> {
    let lower_bound = self
        .entries
        .partition_point(|candidate| candidate.key.as_slice() < prefix);
    let tail = &self.entries[lower_bound..];
    tail.iter()
        .take_while(move |candidate| candidate.key.starts_with(prefix))
}
/// Number of entries in the batch; every stored version of a key counts.
pub fn len(&self) -> usize {
self.entries.len()
}
/// Returns `true` when the batch holds no entries.
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
/// Iterates over all entries in stored (sorted) order.
pub fn iter(&self) -> impl Iterator<Item = &HotEntry> {
self.entries.iter()
}
/// Returns `(min_ts, max_ts)`, the smallest and largest `seq` in the batch.
/// For an empty batch this is `(u64::MAX, 0)`.
pub fn timestamp_range(&self) -> (u64, u64) {
(self.min_ts, self.max_ts)
}
}
/// In-memory table with two tiers: an unsorted hot buffer that absorbs writes
/// and a list of sorted warm batches produced by flushing that buffer.
pub struct TieredMemTable {
/// Recent writes in arrival order; new entries are pushed at the end.
hot_buffer: RwLock<Vec<HotEntry>>,
/// Capacity used to pre-size the hot buffer and to derive the flush threshold.
hot_capacity: usize,
/// Sorted batches created by flushes. New batches are appended, and readers
/// walk the list back to front.
warm_batches: RwLock<Vec<Arc<SortedBatch>>>,
/// Not read or written by any method of this type at present.
#[allow(dead_code)]
point_index: DashMap<Vec<u8>, (usize, usize)>,
/// Source of per-write sequence numbers; starts at 1.
seq_counter: AtomicU64,
/// Running total of key + value bytes, increased by every write.
size_bytes: AtomicU64,
/// Running count of entries, increased by every write.
entry_count: AtomicUsize,
/// Transaction id -> commit timestamp. Filled by `commit`, cleared by
/// `abort`, and consulted by `is_visible`.
pending_commits: DashMap<u64, u64>,
/// Serialises `flush` and `try_flush`.
flush_lock: Mutex<()>,
}
impl TieredMemTable {
/// Creates a table whose hot buffer has `DEFAULT_HOT_BUFFER_CAPACITY` slots.
pub fn new() -> Self {
Self::with_capacity(DEFAULT_HOT_BUFFER_CAPACITY)
}
/// Creates an empty table whose hot buffer is pre-allocated for `capacity`
/// entries. The automatic flush threshold is derived from the same value.
pub fn with_capacity(capacity: usize) -> Self {
Self {
hot_buffer: RwLock::new(Vec::with_capacity(capacity)),
hot_capacity: capacity,
warm_batches: RwLock::new(Vec::new()),
point_index: DashMap::new(),
// Sequence numbers start at 1.
seq_counter: AtomicU64::new(1),
size_bytes: AtomicU64::new(0),
entry_count: AtomicUsize::new(0),
pending_commits: DashMap::new(),
flush_lock: Mutex::new(()),
}
}
/// Appends one write for `txn_id` to the hot buffer.
///
/// The entry receives the next sequence number, the size and entry counters
/// are updated, and a flush is attempted once the hot buffer has reached the
/// flush threshold.
///
/// # Errors
/// Propagates any error returned by `try_flush`.
pub fn write(&self, key: &[u8], value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
    let payload_len = value.as_ref().map_or(0, Vec::len);
    let footprint = (key.len() + payload_len) as u64;

    let seq = self.seq_counter.fetch_add(1, Ordering::Relaxed);
    let record = HotEntry::new(SmallVec::from_slice(key), value, txn_id, seq);

    // The write guard is a temporary, released at the end of this statement.
    self.hot_buffer.write().push(record);

    self.size_bytes.fetch_add(footprint, Ordering::Relaxed);
    self.entry_count.fetch_add(1, Ordering::Relaxed);

    if self.should_flush() {
        self.try_flush()
    } else {
        Ok(())
    }
}
pub fn write_batch(&self, writes: &[(&[u8], Option<Vec<u8>>)], txn_id: u64) -> Result<()> {
let mut total_size = 0u64;
let mut entries = Vec::with_capacity(writes.len());
for (key, value) in writes {
let seq = self.seq_counter.fetch_add(1, Ordering::Relaxed);
let value_size = value.as_ref().map(|v| v.len()).unwrap_or(0);
total_size += (key.len() + value_size) as u64;
entries.push(HotEntry::new(
SmallVec::from_slice(key),
value.clone(),
txn_id,
seq,
));
}
{
let mut buffer = self.hot_buffer.write();
buffer.extend(entries);
}
self.size_bytes.fetch_add(total_size, Ordering::Relaxed);
self.entry_count.fetch_add(writes.len(), Ordering::Relaxed);
if self.should_flush() {
self.try_flush()?;
}
Ok(())
}
/// Returns the value of the newest version of `key` that is visible to the
/// reader, or `None` when no visible version exists or the visible version
/// carries no value.
///
/// The hot buffer is searched first, newest write first, then the warm
/// batches from the most recently flushed to the oldest.
///
/// A warm batch can hold several versions of the same key. If the newest one
/// is not visible (for example, it was written by another uncommitted
/// transaction), the older versions in that batch are examined as well
/// instead of skipping the whole batch. This mirrors what the hot-buffer
/// search already does.
pub fn read(
    &self,
    key: &[u8],
    snapshot_ts: u64,
    current_txn_id: Option<u64>,
) -> Option<Vec<u8>> {
    {
        let buffer = self.hot_buffer.read();
        let hit = buffer.iter().rev().find(|entry| {
            entry.key.as_slice() == key
                && self.is_visible(entry, snapshot_ts, current_txn_id)
        });
        if let Some(entry) = hit {
            return entry.value.clone();
        }
    }

    let batches = self.warm_batches.read();
    for batch in batches.iter().rev() {
        // Fast path: the indexed entry is the newest version of the key.
        let newest = match batch.get(key) {
            Some(entry) => entry,
            None => continue,
        };
        if self.is_visible(newest, snapshot_ts, current_txn_id) {
            return newest.value.clone();
        }

        // Slow path: walk the remaining versions of exactly this key. An exact
        // match sorts before every longer key sharing the prefix, and versions
        // are ordered newest first.
        let older_visible = batch
            .prefix_scan(key)
            .take_while(|entry| entry.key.as_slice() == key)
            .find(|entry| self.is_visible(entry, snapshot_ts, current_txn_id));
        if let Some(entry) = older_visible {
            return entry.value.clone();
        }
    }
    None
}
/// Decides whether `entry` may be observed by a reader.
///
/// An entry is visible when it was written by the reader's own transaction
/// (`current_txn_id`), or when its transaction has a recorded commit
/// timestamp strictly below `snapshot_ts`. Entries of transactions with no
/// recorded commit timestamp are not visible to other transactions.
#[inline]
fn is_visible(&self, entry: &HotEntry, snapshot_ts: u64, current_txn_id: Option<u64>) -> bool {
    // A transaction always sees its own writes.
    if current_txn_id == Some(entry.txn_id) {
        return true;
    }
    match self.pending_commits.get(&entry.txn_id) {
        Some(commit_ts) => *commit_ts < snapshot_ts,
        None => false,
    }
}
/// Records `commit_ts` as the commit timestamp of `txn_id`, which makes the
/// transaction's entries eligible for visibility in `is_visible`.
/// `_write_set` is accepted but not used.
pub fn commit(&self, txn_id: u64, commit_ts: u64, _write_set: &HashSet<InlineKey>) {
self.pending_commits.insert(txn_id, commit_ts);
}
/// Discards the writes of `txn_id`.
///
/// Any recorded commit timestamp for the transaction is dropped and its
/// entries are removed from the hot buffer. The size and entry counters are
/// reduced by what was removed, so `size()` and `len()` stop reporting
/// discarded data.
///
/// Entries of the transaction that were already flushed into a warm batch
/// are left in place (and remain counted); they have no commit timestamp and
/// so are invisible to other transactions.
pub fn abort(&self, txn_id: u64) {
    self.pending_commits.remove(&txn_id);

    let mut removed_entries = 0usize;
    let mut removed_bytes = 0u64;
    {
        let mut buffer = self.hot_buffer.write();
        buffer.retain(|entry| {
            if entry.txn_id != txn_id {
                return true;
            }
            removed_entries += 1;
            // Same accounting as the write path: key bytes plus value bytes.
            removed_bytes +=
                (entry.key.len() + entry.value.as_ref().map_or(0, Vec::len)) as u64;
            false
        });
    }

    if removed_entries > 0 {
        // Atomic subtraction wraps, so a writer that has pushed its entry but
        // not yet added to the counters still nets out to the correct total.
        self.size_bytes.fetch_sub(removed_bytes, Ordering::Relaxed);
        self.entry_count.fetch_sub(removed_entries, Ordering::Relaxed);
    }
}
pub fn scan_prefix(
&self,
prefix: &[u8],
snapshot_ts: u64,
current_txn_id: Option<u64>,
) -> Vec<(Vec<u8>, Vec<u8>)> {
let mut results = Vec::new();
let mut seen_keys: HashSet<Vec<u8>> = HashSet::new();
{
let buffer = self.hot_buffer.read();
for entry in buffer.iter().rev() {
if entry.key.starts_with(prefix)
&& !seen_keys.contains(entry.key.as_slice())
&& self.is_visible(entry, snapshot_ts, current_txn_id)
{
if let Some(ref value) = entry.value {
results.push((entry.key.to_vec(), value.clone()));
seen_keys.insert(entry.key.to_vec());
}
}
}
}
{
let batches = self.warm_batches.read();
for batch in batches.iter().rev() {
for entry in batch.prefix_scan(prefix) {
if !seen_keys.contains(entry.key.as_slice())
&& self.is_visible(entry, snapshot_ts, current_txn_id)
{
if let Some(ref value) = entry.value {
results.push((entry.key.to_vec(), value.clone()));
seen_keys.insert(entry.key.to_vec());
}
}
}
}
}
results.sort_unstable_by(|a, b| a.0.cmp(&b.0));
results
}
pub fn scan_prefix_tournament(
&self,
prefix: &[u8],
snapshot_ts: u64,
current_txn_id: Option<u64>,
) -> Vec<(Vec<u8>, Vec<u8>)> {
use crate::tournament_tree::TournamentTree;
let mut sorted_sources: Vec<Vec<HotEntry>> = Vec::new();
{
let buffer = self.hot_buffer.read();
let mut hot_entries: Vec<HotEntry> = buffer
.iter()
.filter(|e| e.key.starts_with(prefix))
.cloned()
.collect();
hot_entries.sort_unstable_by(|a, b| {
match a.key.as_slice().cmp(b.key.as_slice()) {
std::cmp::Ordering::Equal => b.seq.cmp(&a.seq),
other => other,
}
});
let mut seen = HashSet::new();
hot_entries.retain(|e| seen.insert(e.key.to_vec()));
if !hot_entries.is_empty() {
sorted_sources.push(hot_entries);
}
}
{
let batches = self.warm_batches.read();
for batch in batches.iter().rev() {
let entries: Vec<HotEntry> = batch
.prefix_scan(prefix)
.cloned()
.collect();
if !entries.is_empty() {
sorted_sources.push(entries);
}
}
}
if sorted_sources.is_empty() {
return Vec::new();
}
if sorted_sources.len() == 1 {
return sorted_sources
.into_iter()
.next()
.unwrap()
.into_iter()
.filter(|e| self.is_visible(e, snapshot_ts, current_txn_id))
.filter_map(|e| e.value.map(|v| (e.key.to_vec(), v)))
.collect();
}
#[derive(Clone)]
struct KeyedEntry {
entry: HotEntry,
source_idx: usize,
}
impl PartialEq for KeyedEntry {
fn eq(&self, other: &Self) -> bool {
self.entry.key.as_slice() == other.entry.key.as_slice()
}
}
impl Eq for KeyedEntry {}
impl PartialOrd for KeyedEntry {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for KeyedEntry {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match self.entry.key.as_slice().cmp(other.entry.key.as_slice()) {
std::cmp::Ordering::Equal => self.source_idx.cmp(&other.source_idx),
other => other,
}
}
}
let iters: Vec<_> = sorted_sources
.into_iter()
.enumerate()
.map(|(source_idx, v)| {
v.into_iter().map(move |e| KeyedEntry { entry: e, source_idx })
})
.collect();
let mut tree = TournamentTree::new(iters);
let mut results = Vec::new();
let mut last_key: Option<Vec<u8>> = None;
while let Some((_, keyed)) = tree.pop() {
let entry = keyed.entry;
if let Some(ref last) = last_key {
if entry.key.as_slice() == last.as_slice() {
continue;
}
}
last_key = Some(entry.key.to_vec());
if !self.is_visible(&entry, snapshot_ts, current_txn_id) {
continue;
}
if let Some(value) = entry.value {
results.push((entry.key.to_vec(), value));
}
}
results
}
/// Reports whether the hot buffer has reached the flush threshold, i.e.
/// `hot_capacity * FLUSH_THRESHOLD_RATIO` entries (truncated to an integer).
fn should_flush(&self) -> bool {
    let threshold = (self.hot_capacity as f64 * FLUSH_THRESHOLD_RATIO) as usize;
    self.hot_buffer.read().len() >= threshold
}
/// Moves the hot buffer into a new sorted warm batch if it has reached the
/// flush threshold and no other flush is running.
///
/// Returns `Ok(())` without doing anything when another flush holds the
/// flush lock or the buffer is below the threshold.
///
/// The warm-batch write lock is taken *before* the hot buffer is emptied and
/// is held until the new batch is published. Readers consult the hot buffer
/// first and the warm batches second, so a reader that finds the hot buffer
/// already empty waits for the batch instead of observing the entries in
/// neither tier. The cost is that warm-tier readers are blocked while the
/// batch is being sorted.
pub fn try_flush(&self) -> Result<()> {
    let _flush_guard = match self.flush_lock.try_lock() {
        Some(guard) => guard,
        None => return Ok(()),
    };
    let threshold = (self.hot_capacity as f64 * FLUSH_THRESHOLD_RATIO) as usize;

    // Lock order: warm batches, then hot buffer. No reader or writer in this
    // type waits for the warm lock while holding the hot lock.
    let mut batches = self.warm_batches.write();
    let entries = {
        let mut buffer = self.hot_buffer.write();
        // Re-check under the lock: another flush may have drained the buffer.
        if buffer.is_empty() || buffer.len() < threshold {
            return Ok(());
        }
        std::mem::replace(&mut *buffer, Vec::with_capacity(self.hot_capacity))
    };

    batches.push(Arc::new(SortedBatch::from_unsorted(entries)));
    Ok(())
}
/// Unconditionally moves the current hot buffer into a new sorted warm
/// batch, waiting for any flush already in progress. Does nothing when the
/// hot buffer is empty.
///
/// The warm-batch write lock is taken *before* the hot buffer is emptied and
/// is held until the new batch is published. Readers consult the hot buffer
/// first and the warm batches second, so a reader that finds the hot buffer
/// already empty waits for the batch instead of observing the entries in
/// neither tier. The cost is that warm-tier readers are blocked while the
/// batch is being sorted.
pub fn flush(&self) -> Result<()> {
    let _flush_guard = self.flush_lock.lock();

    // Lock order: warm batches, then hot buffer. No reader or writer in this
    // type waits for the warm lock while holding the hot lock.
    let mut batches = self.warm_batches.write();
    let entries = {
        let mut buffer = self.hot_buffer.write();
        if buffer.is_empty() {
            // Nothing to move; keep the existing allocation.
            return Ok(());
        }
        std::mem::replace(&mut *buffer, Vec::with_capacity(self.hot_capacity))
    };

    batches.push(Arc::new(SortedBatch::from_unsorted(entries)));
    Ok(())
}
/// Current value of the byte counter (key + value bytes added by writes).
pub fn size(&self) -> u64 {
self.size_bytes.load(Ordering::Relaxed)
}
/// Current value of the entry counter; flushing does not change it.
pub fn len(&self) -> usize {
self.entry_count.load(Ordering::Relaxed)
}
/// Returns `true` when the entry counter is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Number of warm batches currently held.
pub fn batch_count(&self) -> usize {
self.warm_batches.read().len()
}
/// Number of entries currently in the hot buffer.
pub fn hot_buffer_len(&self) -> usize {
self.hot_buffer.read().len()
}
pub fn compact(&self) -> Result<()> {
let batches = {
let mut b = self.warm_batches.write();
std::mem::take(&mut *b)
};
if batches.len() <= 1 {
let mut b = self.warm_batches.write();
*b = batches;
return Ok(());
}
let all_entries: Vec<HotEntry> = batches
.iter()
.flat_map(|b| b.iter().cloned())
.collect();
let merged = Arc::new(SortedBatch::from_unsorted(all_entries));
{
let mut b = self.warm_batches.write();
b.clear();
b.push(merged);
}
Ok(())
}
}
/// `Default` is equivalent to `TieredMemTable::new()`.
impl Default for TieredMemTable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Committed writes are readable by a later snapshot that has no transaction of its own.
#[test]
fn test_tiered_memtable_basic() {
let table = TieredMemTable::new();
table.write(b"key1", Some(b"value1".to_vec()), 1).unwrap();
table.write(b"key2", Some(b"value2".to_vec()), 1).unwrap();
let mut write_set = HashSet::new();
write_set.insert(SmallVec::from_slice(b"key1"));
write_set.insert(SmallVec::from_slice(b"key2"));
// Commit timestamp 100 is below the snapshot timestamp 200 used for the reads.
table.commit(1, 100, &write_set);
let v1 = table.read(b"key1", 200, None);
let v2 = table.read(b"key2", 200, None);
assert_eq!(v1, Some(b"value1".to_vec()));
assert_eq!(v2, Some(b"value2".to_vec()));
}
// An uncommitted write is visible to its own transaction and hidden from others.
#[test]
fn test_tiered_memtable_uncommitted_own() {
let table = TieredMemTable::new();
table.write(b"key1", Some(b"value1".to_vec()), 1).unwrap();
let v = table.read(b"key1", 100, Some(1));
assert_eq!(v, Some(b"value1".to_vec()));
let v = table.read(b"key1", 100, Some(2));
assert_eq!(v, None);
}
// An explicit flush empties the hot buffer and leaves at least one warm batch.
#[test]
fn test_tiered_memtable_flush() {
let table = TieredMemTable::with_capacity(100);
// 90 writes exceed the 80-entry threshold, so an automatic flush may already have run.
for i in 0..90 {
table.write(
format!("key{:04}", i).as_bytes(),
Some(format!("value{}", i).into_bytes()),
1,
).unwrap();
}
table.flush().unwrap();
assert!(table.batch_count() >= 1);
assert_eq!(table.hot_buffer_len(), 0);
}
// A prefix scan returns only the keys carrying the prefix.
#[test]
fn test_tiered_memtable_scan_prefix() {
let table = TieredMemTable::new();
table.write(b"users:1", Some(b"alice".to_vec()), 1).unwrap();
table.write(b"users:2", Some(b"bob".to_vec()), 1).unwrap();
table.write(b"posts:1", Some(b"post1".to_vec()), 1).unwrap();
let mut write_set = HashSet::new();
write_set.insert(SmallVec::from_slice(b"users:1"));
write_set.insert(SmallVec::from_slice(b"users:2"));
write_set.insert(SmallVec::from_slice(b"posts:1"));
table.commit(1, 100, &write_set);
let results = table.scan_prefix(b"users:", 200, None);
assert_eq!(results.len(), 2);
}
// Point lookups succeed on a batch built from entries given out of order.
#[test]
fn test_sorted_batch() {
let entries = vec![
HotEntry::new(SmallVec::from_slice(b"c"), Some(b"3".to_vec()), 1, 3),
HotEntry::new(SmallVec::from_slice(b"a"), Some(b"1".to_vec()), 1, 1),
HotEntry::new(SmallVec::from_slice(b"b"), Some(b"2".to_vec()), 1, 2),
];
let batch = SortedBatch::from_unsorted(entries);
assert_eq!(batch.len(), 3);
assert_eq!(batch.get(b"a").unwrap().value, Some(b"1".to_vec()));
assert_eq!(batch.get(b"b").unwrap().value, Some(b"2".to_vec()));
assert_eq!(batch.get(b"c").unwrap().value, Some(b"3".to_vec()));
}
// A batch prefix scan includes the key equal to the prefix and stops at the first non-match.
#[test]
fn test_sorted_batch_prefix_scan() {
let entries = vec![
HotEntry::new(SmallVec::from_slice(b"ab"), Some(b"1".to_vec()), 1, 1),
HotEntry::new(SmallVec::from_slice(b"abc"), Some(b"2".to_vec()), 1, 2),
HotEntry::new(SmallVec::from_slice(b"abd"), Some(b"3".to_vec()), 1, 3),
HotEntry::new(SmallVec::from_slice(b"xyz"), Some(b"4".to_vec()), 1, 4),
];
let batch = SortedBatch::from_unsorted(entries);
let results: Vec<_> = batch.prefix_scan(b"ab").collect();
assert_eq!(results.len(), 3);
}
// The tournament scan merges three warm batches and the hot buffer, preferring the newest value per key.
#[test]
fn test_scan_prefix_tournament() {
let table = TieredMemTable::with_capacity(100);
// Three flushed batches, each rewriting the same ten keys.
for batch_idx in 0..3 {
for i in 0..10 {
let key = format!("users:{:02}", i);
let value = format!("value_batch{}_item{}", batch_idx, i);
table.write(key.as_bytes(), Some(value.into_bytes()), 1).unwrap();
}
table.flush().unwrap();
}
// Newer values for the first five keys stay in the hot buffer.
for i in 0..5 {
let key = format!("users:{:02}", i);
let value = format!("newest_value_{}", i);
table.write(key.as_bytes(), Some(value.into_bytes()), 1).unwrap();
}
let mut write_set = HashSet::new();
for i in 0..10 {
write_set.insert(SmallVec::from_slice(format!("users:{:02}", i).as_bytes()));
}
table.commit(1, 100, &write_set);
let results = table.scan_prefix_tournament(b"users:", 200, None);
assert_eq!(results.len(), 10);
// Results are sorted by key, so the first five are the keys rewritten in the hot buffer.
for (i, (key, value)) in results.iter().take(5).enumerate() {
let expected_key = format!("users:{:02}", i);
assert_eq!(key.as_slice(), expected_key.as_bytes());
assert!(String::from_utf8_lossy(value).starts_with("newest_value_"));
}
}
// One key present in two warm batches and the hot buffer is returned once, with the newest value.
#[test]
fn test_scan_tournament_deduplication() {
let table = TieredMemTable::with_capacity(100);
table.write(b"key:001", Some(b"old1".to_vec()), 1).unwrap();
table.flush().unwrap();
table.write(b"key:001", Some(b"old2".to_vec()), 1).unwrap();
table.flush().unwrap();
table.write(b"key:001", Some(b"newest".to_vec()), 1).unwrap();
let mut write_set = HashSet::new();
write_set.insert(SmallVec::from_slice(b"key:001"));
table.commit(1, 100, &write_set);
let results = table.scan_prefix_tournament(b"key:", 200, None);
assert_eq!(results.len(), 1);
assert_eq!(results[0].1.as_slice(), b"newest");
}
}