use crate::error::DbxResult;
use crate::storage::StorageBackend;
use crossbeam_skiplist::SkipMap;
use dashmap::DashMap;
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
/// Default number of buffered write operations after which `should_flush` reports true.
const DEFAULT_FLUSH_THRESHOLD: usize = 10_000;
/// In-memory write buffer that accumulates `(key, value)` rows per table
/// until a configurable threshold suggests flushing them elsewhere
/// (see `should_flush` / `drain_all`).
pub struct DeltaStore {
    /// Per-table sorted maps; values are wrapped in `Arc` so the skip map can
    /// hand out shared handles.
    #[allow(clippy::type_complexity)]
    tables: DashMap<String, Arc<SkipMap<Vec<u8>, Arc<Vec<u8>>>>>,
    /// Buffered-write count at which `should_flush` becomes true.
    flush_threshold: usize,
    /// Write operations buffered since the last `drain_all`. Counts
    /// operations, not distinct keys: overwriting a key still increments it,
    /// and `delete` decrements it.
    entry_count: std::sync::atomic::AtomicUsize,
}
impl DeltaStore {
    /// Creates a store with [`DEFAULT_FLUSH_THRESHOLD`].
    pub fn new() -> Self {
        Self::with_threshold(DEFAULT_FLUSH_THRESHOLD)
    }

    /// Creates a store that reports [`Self::should_flush`] once `threshold`
    /// buffered write operations have accumulated.
    pub fn with_threshold(threshold: usize) -> Self {
        Self {
            tables: DashMap::new(),
            flush_threshold: threshold,
            entry_count: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// True once the buffered write count has reached the flush threshold.
    pub fn should_flush(&self) -> bool {
        self.entry_count() >= self.flush_threshold
    }

    /// Number of write operations buffered since the last [`Self::drain_all`].
    ///
    /// This counts operations, not distinct live keys: overwriting an existing
    /// key still increments it (see `StorageBackend::insert`).
    pub fn entry_count(&self) -> usize {
        self.entry_count.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Removes every table and returns its contents as owned `(key, value)`
    /// pairs, then resets the write counter to zero.
    ///
    /// NOTE(review): writes that race with the drain (e.g. re-created tables,
    /// or inserts landing between the last `remove` and the counter reset)
    /// stay stored but are erased from `entry_count`, delaying the next flush
    /// trigger for them — confirm callers tolerate a late flush of stragglers.
    #[allow(clippy::type_complexity)]
    pub fn drain_all(&self) -> Vec<(String, Vec<(Vec<u8>, Vec<u8>)>)> {
        // Snapshot the names first: DashMap disallows removing while iterating.
        let table_names: Vec<String> = self.tables.iter().map(|e| e.key().clone()).collect();
        let mut result = Vec::with_capacity(table_names.len());
        for table_name in table_names {
            if let Some((_, table_map)) = self.tables.remove(&table_name) {
                let entries: Vec<(Vec<u8>, Vec<u8>)> = table_map
                    .iter()
                    .map(|e| (e.key().clone(), (**e.value()).clone()))
                    .collect();
                result.push((table_name, entries));
            }
        }
        self.entry_count
            .store(0, std::sync::atomic::Ordering::Relaxed);
        result
    }

    /// Returns the skip map for `table`, creating an empty one on first use.
    fn get_or_create_table(&self, table: &str) -> Arc<SkipMap<Vec<u8>, Arc<Vec<u8>>>> {
        self.tables
            .entry(table.to_string())
            .or_insert_with(|| Arc::new(SkipMap::new()))
            .value()
            .clone()
    }

    /// Converts a borrowed range bound into an owned one so it can be passed
    /// to `SkipMap::range`. Delegates to `Bound::cloned` (std) instead of
    /// hand-matching the three variants.
    fn convert_bound(bound: Bound<&Vec<u8>>) -> Bound<Vec<u8>> {
        bound.cloned()
    }
}
impl Default for DeltaStore {
fn default() -> Self {
Self::new()
}
}
impl StorageBackend for DeltaStore {
    /// Inserts (or overwrites) `key` -> `value` in `table`.
    ///
    /// The pending-write counter is incremented even when the key already
    /// existed, so it tracks write operations, not distinct live keys.
    fn insert(&self, table: &str, key: &[u8], value: &[u8]) -> DbxResult<()> {
        let table_map = self.get_or_create_table(table);
        table_map.insert(key.to_vec(), Arc::new(value.to_vec()));
        self.entry_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Ok(())
    }

    /// Inserts every row of `rows` into `table`.
    ///
    /// The pending-write counter is bumped once for the whole batch rather
    /// than once per row, avoiding `rows.len() - 1` atomic RMW operations.
    fn insert_batch(&self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<()> {
        let table_map = self.get_or_create_table(table);
        let batch_len = rows.len();
        for (key, value) in rows {
            table_map.insert(key, Arc::new(value));
        }
        self.entry_count
            .fetch_add(batch_len, std::sync::atomic::Ordering::Relaxed);
        Ok(())
    }

    /// Returns a copy of the value stored under `key`, or `None` when either
    /// the table or the key is absent.
    fn get(&self, table: &str, key: &[u8]) -> DbxResult<Option<Vec<u8>>> {
        let Some(table_map) = self.tables.get(table) else {
            return Ok(None);
        };
        Ok(table_map.get(key).map(|e| (**e.value()).clone()))
    }

    /// Removes `key` from `table`, reporting whether a value was present.
    /// A successful removal decrements the pending-write counter.
    fn delete(&self, table: &str, key: &[u8]) -> DbxResult<bool> {
        let Some(table_map) = self.tables.get(table) else {
            return Ok(false);
        };
        if table_map.remove(key).is_some() {
            self.entry_count
                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Returns all `(key, value)` pairs of `table` whose keys fall in `range`,
    /// in key order. Missing tables yield an empty vector.
    fn scan<R: RangeBounds<Vec<u8>> + Clone>(
        &self,
        table: &str,
        range: R,
    ) -> DbxResult<Vec<(Vec<u8>, Vec<u8>)>> {
        let Some(table_map) = self.tables.get(table) else {
            return Ok(Vec::new());
        };
        // Fast path: skip bound conversion and iterator setup for empty maps.
        if table_map.is_empty() {
            return Ok(Vec::new());
        }
        let start = Self::convert_bound(range.start_bound());
        let end = Self::convert_bound(range.end_bound());
        let entries: Vec<(Vec<u8>, Vec<u8>)> = table_map
            .range((start, end))
            .map(|e| (e.key().clone(), (**e.value()).clone()))
            .collect();
        Ok(entries)
    }

    /// Like `scan`, but returns only the first matching entry (if any).
    fn scan_one<R: RangeBounds<Vec<u8>> + Clone>(
        &self,
        table: &str,
        range: R,
    ) -> DbxResult<Option<(Vec<u8>, Vec<u8>)>> {
        let Some(table_map) = self.tables.get(table) else {
            return Ok(None);
        };
        let start = Self::convert_bound(range.start_bound());
        let end = Self::convert_bound(range.end_bound());
        Ok(table_map
            .range((start, end))
            .next()
            .map(|e| (e.key().clone(), (**e.value()).clone())))
    }

    /// No-op: the store is purely in-memory; data leaves via `drain_all`.
    fn flush(&self) -> DbxResult<()> {
        Ok(())
    }

    /// Number of live entries in `table`'s map (0 when the table is absent).
    fn count(&self, table: &str) -> DbxResult<usize> {
        let Some(table_map) = self.tables.get(table) else {
            return Ok(0);
        };
        Ok(table_map.len())
    }

    /// Names of all tables currently holding buffered data.
    fn table_names(&self) -> DbxResult<Vec<String>> {
        Ok(self.tables.iter().map(|e| e.key().clone()).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transaction::mvcc::version::VersionedKey;

    /// Round-trip: a value written via `insert` comes back from `get`.
    #[test]
    fn insert_and_get() {
        let store = DeltaStore::new();
        store.insert("users", b"key1", b"value1").unwrap();
        let fetched = store.get("users", b"key1").unwrap();
        assert_eq!(fetched, Some(b"value1".to_vec()));
    }

    /// Two versions of one logical key coexist under distinct encoded keys,
    /// and a full scan yields them newest-first (commit_ts 200 before 100).
    #[test]
    fn test_versioned_storage() {
        let store = DeltaStore::new();
        let older = VersionedKey::new(b"key1".to_vec(), 100);
        let newer = VersionedKey::new(b"key1".to_vec(), 200);
        store.insert("users", &older.encode(), b"v1").unwrap();
        store.insert("users", &newer.encode(), b"v2").unwrap();

        assert_eq!(
            store.get("users", &older.encode()).unwrap(),
            Some(b"v1".to_vec())
        );
        assert_eq!(
            store.get("users", &newer.encode()).unwrap(),
            Some(b"v2".to_vec())
        );

        let scanned = store.scan("users", Vec::<u8>::new()..).unwrap();
        let commit_order: Vec<_> = scanned
            .iter()
            .map(|(key, _)| VersionedKey::decode(key).unwrap().commit_ts)
            .collect();
        assert_eq!(commit_order, vec![200, 100]);
    }

    /// Deleting a present key reports true and removes the value.
    #[test]
    fn delete_existing_key() {
        let store = DeltaStore::new();
        store.insert("users", b"key1", b"value1").unwrap();
        let removed = store.delete("users", b"key1").unwrap();
        assert!(removed);
        assert!(store.get("users", b"key1").unwrap().is_none());
    }

    /// The pending-write counter sums inserts across all tables.
    #[test]
    fn entry_count_tracking() {
        let store = DeltaStore::new();
        assert_eq!(store.entry_count(), 0);
        let writes: [(&str, &[u8], &[u8]); 3] = [
            ("t1", b"a", b"1"),
            ("t1", b"b", b"2"),
            ("t2", b"c", b"3"),
        ];
        for (table, key, value) in writes {
            store.insert(table, key, value).unwrap();
        }
        assert_eq!(store.entry_count(), 3);
    }
}