use std::collections::HashSet;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicU8, AtomicU64, AtomicUsize, Ordering};
use dashmap::DashMap;
use parking_lot::Mutex;
use sochdb_core::{Result, SochDBError};
/// Number of hazard-pointer slots carried by each `HazardRecord`.
const HP_PER_THREAD: usize = 2;
/// Number of hazard records `LockFreeMemTable::new` allocates in its `HazardDomain`.
const MAX_THREADS: usize = 128;
/// Length of the retired list at which `HazardDomain::retire` runs a reclamation pass.
const RECLAMATION_THRESHOLD: usize = 64;
/// Number of version slots held by a single `FatNode`.
const FAT_NODE_SLOTS: usize = 8;
/// Largest value, in bytes, that `ValueStorage::new` keeps inline instead of boxing it.
pub const INLINE_VALUE_SIZE: usize = 56;
/// Payload of one version.
///
/// `ValueStorage::new` chooses the variant: `None` becomes `Tombstone`, a slice of at
/// most `INLINE_VALUE_SIZE` bytes becomes `Inline`, anything longer becomes `Heap`.
#[repr(C)]
pub enum ValueStorage {
/// Short value copied into a fixed-size buffer inside the enum itself.
Inline {
/// Number of meaningful bytes at the start of `data`.
len: u8,
/// Backing buffer; `ValueStorage::new` leaves the bytes past `len` zeroed.
data: [u8; INLINE_VALUE_SIZE],
},
/// Value longer than `INLINE_VALUE_SIZE` bytes, held in its own allocation.
Heap(Box<[u8]>),
/// "No value" marker produced from `None` (presumably a delete marker; confirm with callers).
Tombstone,
}
/// Compact debug rendering: prints the variant name and the stored length,
/// never the payload bytes themselves.
impl std::fmt::Debug for ValueStorage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Tombstones carry no length, so they are handled up front.
        let (variant, stored_len) = match self {
            ValueStorage::Tombstone => return write!(f, "Tombstone"),
            ValueStorage::Heap(bytes) => ("Heap", bytes.len()),
            ValueStorage::Inline { len, .. } => ("Inline", usize::from(*len)),
        };
        write!(f, "{}(len={})", variant, stored_len)
    }
}
impl ValueStorage {
    /// Builds the storage for `value`.
    ///
    /// * `None` becomes [`ValueStorage::Tombstone`].
    /// * A slice of at most `INLINE_VALUE_SIZE` bytes is copied into an
    ///   [`ValueStorage::Inline`] buffer (unused tail bytes stay zero).
    /// * Anything longer is copied into a boxed slice ([`ValueStorage::Heap`]).
    #[inline]
    pub fn new(value: Option<&[u8]>) -> Self {
        let bytes = match value {
            Some(bytes) => bytes,
            None => return ValueStorage::Tombstone,
        };
        if bytes.len() > INLINE_VALUE_SIZE {
            return ValueStorage::Heap(Box::from(bytes));
        }
        let mut data = [0u8; INLINE_VALUE_SIZE];
        data[..bytes.len()].copy_from_slice(bytes);
        ValueStorage::Inline {
            // Cannot truncate: the length was just checked against INLINE_VALUE_SIZE.
            len: bytes.len() as u8,
            data,
        }
    }

    /// Returns the stored bytes, or `None` for a tombstone.
    #[inline]
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            ValueStorage::Tombstone => None,
            ValueStorage::Heap(boxed) => Some(&boxed[..]),
            ValueStorage::Inline { len, data } => Some(&data[..usize::from(*len)]),
        }
    }

    /// `true` when this is the tombstone marker.
    #[inline]
    pub fn is_tombstone(&self) -> bool {
        matches!(*self, ValueStorage::Tombstone)
    }

    /// `true` when the value lives in the inline buffer.
    #[inline]
    pub fn is_inline(&self) -> bool {
        matches!(*self, ValueStorage::Inline { .. })
    }

    /// Number of stored bytes; a tombstone reports 0.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            ValueStorage::Tombstone => 0,
            ValueStorage::Heap(boxed) => boxed.len(),
            ValueStorage::Inline { len, .. } => usize::from(*len),
        }
    }

    /// `true` when [`ValueStorage::len`] is 0, which includes tombstones.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let stored = self.len();
        stored == 0
    }
}
/// One version of a key's value, tagged with the transaction that wrote it.
#[derive(Debug)]
pub struct LockFreeVersion {
/// The value bytes, or a tombstone when the version was created from `None`.
pub storage: ValueStorage,
/// Identifier of the transaction that created this version.
pub txn_id: u64,
/// Commit timestamp; 0 means "not committed yet" (see `is_committed`).
pub commit_ts: AtomicU64,
/// Link to another version; the constructors set it to null and nothing else
/// in this file reads or writes it.
pub next: AtomicPtr<LockFreeVersion>,
}
impl LockFreeVersion {
#[inline]
pub fn new_from_slice(value: Option<&[u8]>, txn_id: u64) -> Self {
Self {
storage: ValueStorage::new(value),
txn_id,
commit_ts: AtomicU64::new(0),
next: AtomicPtr::new(ptr::null_mut()),
}
}
pub fn new(value: Option<Vec<u8>>, txn_id: u64) -> Self {
Self::new_from_slice(value.as_deref(), txn_id)
}
#[inline]
pub fn get_value(&self) -> Option<&[u8]> {
self.storage.as_bytes()
}
#[inline]
pub fn value_cloned(&self) -> Option<Vec<u8>> {
self.storage.as_bytes().map(|v| v.to_vec())
}
#[inline]
pub fn is_committed(&self) -> bool {
self.commit_ts.load(Ordering::Acquire) > 0
}
#[inline]
pub fn get_commit_ts(&self) -> u64 {
self.commit_ts.load(Ordering::Acquire)
}
#[inline]
pub fn set_commit_ts(&self, ts: u64) {
self.commit_ts.store(ts, Ordering::Release);
}
#[inline]
pub fn is_inline(&self) -> bool {
self.storage.is_inline()
}
}
/// A chain node that holds up to `FAT_NODE_SLOTS` version pointers.
pub struct FatNode {
// Number of slots that have been claimed so far (at most FAT_NODE_SLOTS).
count: AtomicU8,
// Version pointers in insertion order: slot 0 is the oldest in this node.
// A claimed slot may still read as null briefly, because `try_push` bumps
// `count` before it stores the pointer.
slots: [AtomicPtr<LockFreeVersion>; FAT_NODE_SLOTS],
// The node that was the chain head when this node was created (null for the first node).
next: AtomicPtr<FatNode>,
}
impl FatNode {
fn new_with_first(version: *mut LockFreeVersion, older: *mut FatNode) -> Self {
let slots = std::array::from_fn(|i| {
if i == 0 {
AtomicPtr::new(version)
} else {
AtomicPtr::new(ptr::null_mut())
}
});
Self {
count: AtomicU8::new(1),
slots,
next: AtomicPtr::new(older),
}
}
#[inline]
fn try_push(
&self,
version: *mut LockFreeVersion,
) -> std::result::Result<(), *mut LockFreeVersion> {
loop {
let c = self.count.load(Ordering::Acquire);
if c as usize >= FAT_NODE_SLOTS {
return Err(version); }
match self
.count
.compare_exchange(c, c + 1, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => {
self.slots[c as usize].store(version, Ordering::Release);
return Ok(());
}
Err(_) => continue, }
}
}
#[inline]
fn slot(&self, idx: u8) -> *mut LockFreeVersion {
self.slots[idx as usize].load(Ordering::Acquire)
}
#[inline]
fn iter_newest_first(&self) -> impl Iterator<Item = &LockFreeVersion> {
let count = self.count.load(Ordering::Acquire);
(0..count).rev().filter_map(move |i| {
let ptr = self.slots[i as usize].load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
Some(unsafe { &*ptr })
}
})
}
}
pub struct LockFreeVersionChain {
head: AtomicPtr<FatNode>,
}
impl Default for LockFreeVersionChain {
fn default() -> Self {
Self::new()
}
}
impl LockFreeVersionChain {
/// Creates an empty chain (null head).
pub fn new() -> Self {
Self {
head: AtomicPtr::new(ptr::null_mut()),
}
}
/// Appends an uncommitted version written by `txn_id`.
///
/// `value == None` records a tombstone. Returns
/// `SochDBError::Internal("Write-write conflict")` when the newest version in the
/// head node is uncommitted and belongs to a different transaction.
///
/// NOTE(review): the conflict check and the push are separate steps, so two
/// writers that both observe a committed newest version can both succeed —
/// confirm callers serialise writers per key if that matters.
pub fn add_uncommitted(&self, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
// Allocate up front; ownership is handed to the chain once the pointer is published.
let new_version = Box::into_raw(Box::new(LockFreeVersion::new(value, txn_id)));
loop {
let head = self.head.load(Ordering::Acquire);
if !head.is_null() {
// SAFETY: a non-null head was produced by `Box::into_raw` below, and this
// impl never frees a node after it has been published through `head`.
let fat = unsafe { &*head };
let count = fat.count.load(Ordering::Acquire);
if count > 0 {
// The most recently claimed slot; it can still be null if its writer has
// bumped `count` but not stored the pointer yet, in which case the
// conflict check is skipped.
let newest = fat.slot(count - 1);
if !newest.is_null() {
// SAFETY: non-null slot pointers come from `Box::into_raw` in this method.
let newest_ref = unsafe { &*newest };
if !newest_ref.is_committed() && newest_ref.txn_id != txn_id {
// The new version was never published, so it is still exclusively ours to free.
unsafe {
drop(Box::from_raw(new_version));
}
return Err(SochDBError::Internal("Write-write conflict".into()));
}
}
}
match fat.try_push(new_version) {
Ok(()) => return Ok(()),
Err(_) => {
// Head node is full: build a new node holding the version and linking to
// the old head, then try to install it as the new head.
let new_fat =
Box::into_raw(Box::new(FatNode::new_with_first(new_version, head)));
match self.head.compare_exchange(
head,
new_fat,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return Ok(()),
Err(_) => {
// Lost the race: detach the version from the unpublished node, free the
// node only, and retry with the same version against the new head.
unsafe {
(*new_fat).slots[0].store(ptr::null_mut(), Ordering::Relaxed);
(*new_fat).count.store(0, Ordering::Relaxed);
drop(Box::from_raw(new_fat));
}
continue; }
}
}
}
} else {
// Empty chain: install the first node.
let new_fat = Box::into_raw(Box::new(FatNode::new_with_first(
new_version,
ptr::null_mut(),
)));
match self
.head
.compare_exchange(head, new_fat, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => return Ok(()),
Err(_) => {
// Someone else installed a head first; discard our node (not the version) and retry.
unsafe {
(*new_fat).slots[0].store(ptr::null_mut(), Ordering::Relaxed);
(*new_fat).count.store(0, Ordering::Relaxed);
drop(Box::from_raw(new_fat));
}
continue;
}
}
}
}
}
/// Stamps `commit_ts` on the first uncommitted version written by `txn_id`,
/// searching newest-first across all nodes.
///
/// Returns `true` if a version was stamped, `false` if none matched.
/// Only one version is stamped per call; older uncommitted versions from the
/// same transaction are left untouched.
pub fn commit(&self, txn_id: u64, commit_ts: u64) -> bool {
let mut fat_ptr = self.head.load(Ordering::Acquire);
while !fat_ptr.is_null() {
// SAFETY: see `add_uncommitted` — published nodes are not freed by this impl.
let fat = unsafe { &*fat_ptr };
for ver in fat.iter_newest_first() {
if ver.txn_id == txn_id && !ver.is_committed() {
ver.set_commit_ts(commit_ts);
return true;
}
}
fat_ptr = fat.next.load(Ordering::Acquire);
}
false
}
/// Finds the version visible to a reader, scanning newest-first.
///
/// The first version that satisfies either rule is returned:
/// * it is an uncommitted write of `current_txn_id` (a transaction sees its own writes), or
/// * it is committed with a timestamp strictly less than `snapshot_ts`.
///
/// Returns `None` when no version qualifies. The returned version may be a tombstone.
pub fn read_at(
&self,
snapshot_ts: u64,
current_txn_id: Option<u64>,
) -> Option<&LockFreeVersion> {
let mut fat_ptr = self.head.load(Ordering::Acquire);
while !fat_ptr.is_null() {
// SAFETY: see `add_uncommitted` — published nodes are not freed by this impl.
let fat = unsafe { &*fat_ptr };
for version in fat.iter_newest_first() {
if let Some(txn_id) = current_txn_id
&& version.txn_id == txn_id
&& !version.is_committed()
{
return Some(version);
}
// A commit timestamp of 0 means the version is still uncommitted.
let commit_ts = version.get_commit_ts();
if commit_ts > 0 && commit_ts < snapshot_ts {
return Some(version);
}
}
fat_ptr = fat.next.load(Ordering::Acquire);
}
None
}
/// Reports whether the newest version in the head node is an uncommitted write
/// by a transaction other than `my_txn_id`.
///
/// Only that single slot is inspected; an empty chain, or a newest slot whose
/// pointer has not been stored yet, reports no conflict.
pub fn has_write_conflict(&self, my_txn_id: u64) -> bool {
let head = self.head.load(Ordering::Acquire);
if !head.is_null() {
// SAFETY: see `add_uncommitted` — published nodes are not freed by this impl.
let fat = unsafe { &*head };
let count = fat.count.load(Ordering::Acquire);
if count > 0 {
let newest = fat.slot(count - 1);
if !newest.is_null() {
// SAFETY: non-null slot pointers come from `Box::into_raw` in `add_uncommitted`.
let version = unsafe { &*newest };
return !version.is_committed() && version.txn_id != my_txn_id;
}
}
}
false
}
}
/// Hazard-pointer slots owned by one thread.
///
/// Aligned to 64 bytes (presumably one cache line, to avoid false sharing
/// between records — confirm for the target platforms).
#[repr(C, align(64))]
struct HazardRecord {
// Pointers the owning thread has announced as in use; null means "slot empty".
hazard: [AtomicPtr<LockFreeVersion>; HP_PER_THREAD],
// Thread id of the owner, or 0 when the record is unclaimed.
active: AtomicU64,
}
impl HazardRecord {
const fn new() -> Self {
Self {
hazard: [
AtomicPtr::new(ptr::null_mut()),
AtomicPtr::new(ptr::null_mut()),
],
active: AtomicU64::new(0),
}
}
fn try_acquire(&self, thread_id: u64) -> bool {
self.active
.compare_exchange(0, thread_id, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
}
#[allow(dead_code)]
fn release(&self) {
for hp in &self.hazard {
hp.store(ptr::null_mut(), Ordering::Release);
}
self.active.store(0, Ordering::Release);
}
}
/// Hazard-pointer bookkeeping: a fixed pool of per-thread records plus the list
/// of retired versions that are waiting to be freed.
pub struct HazardDomain {
// Fixed-size pool created in `new`; a thread claims one record on first use.
records: Vec<HazardRecord>,
// Versions handed to `retire` that have not been freed yet.
retired: Mutex<Vec<*mut LockFreeVersion>>,
}
impl HazardDomain {
    /// Creates a domain with `max_threads` hazard records.
    ///
    /// At most `max_threads` distinct threads can ever hold a record; records are
    /// not handed back by anything in this impl.
    pub fn new(max_threads: usize) -> Self {
        Self {
            records: (0..max_threads).map(|_| HazardRecord::new()).collect(),
            retired: Mutex::new(Vec::with_capacity(RECLAMATION_THRESHOLD * 2)),
        }
    }

    /// Returns the record already owned by `thread_id`, if any, without claiming one.
    fn find_record(&self, thread_id: u64) -> Option<&HazardRecord> {
        self.records
            .iter()
            .find(|record| record.active.load(Ordering::Acquire) == thread_id)
    }

    /// Returns the calling thread's record, claiming a free one on first use.
    /// `None` means every record is owned by some other thread.
    fn get_record(&self) -> Option<&HazardRecord> {
        let thread_id = thread_id::get() as u64;
        self.find_record(thread_id).or_else(|| {
            self.records
                .iter()
                .find(|record| record.try_acquire(thread_id))
        })
    }

    /// Announces `ptr` in hazard slot `slot` of the calling thread's record.
    ///
    /// Returns `false` (and announces nothing) when `slot` is out of range or no
    /// record is available for this thread; the caller is then unprotected.
    #[inline]
    pub fn protect(&self, ptr: *mut LockFreeVersion, slot: usize) -> bool {
        // Validate first so an invalid call cannot consume a record from the pool.
        if slot >= HP_PER_THREAD {
            return false;
        }
        match self.get_record() {
            Some(record) => {
                record.hazard[slot].store(ptr, Ordering::Release);
                // Make the announcement visible before the caller goes on to use
                // the pointer; pairs with the fence in `try_reclaim`.
                std::sync::atomic::fence(Ordering::SeqCst);
                true
            }
            None => false,
        }
    }

    /// Clears hazard slot `slot` of the calling thread's record.
    ///
    /// Does nothing when `slot` is out of range or the thread owns no record; in
    /// particular it never claims a record just to clear it.
    #[inline]
    pub fn clear(&self, slot: usize) {
        if slot >= HP_PER_THREAD {
            return;
        }
        let thread_id = thread_id::get() as u64;
        if let Some(record) = self.find_record(thread_id) {
            record.hazard[slot].store(ptr::null_mut(), Ordering::Release);
        }
    }

    /// Queues `ptr` for deferred freeing and runs a reclamation pass once
    /// `RECLAMATION_THRESHOLD` pointers are pending.
    ///
    /// `ptr` must come from `Box::into_raw` and must be retired only once: it is
    /// eventually released with `Box::from_raw`.
    pub fn retire(&self, ptr: *mut LockFreeVersion) {
        let mut retired = self.retired.lock();
        retired.push(ptr);
        if retired.len() >= RECLAMATION_THRESHOLD {
            self.try_reclaim(&mut retired);
        }
    }

    /// Frees every retired pointer that no claimed record currently announces;
    /// announced pointers stay queued for a later pass.
    fn try_reclaim(&self, retired: &mut Vec<*mut LockFreeVersion>) {
        // Counterpart of the fence in `protect`: order this scan after any
        // announcement that preceded it in the SeqCst order.
        std::sync::atomic::fence(Ordering::SeqCst);

        let protected: HashSet<usize> = self
            .records
            .iter()
            .filter(|record| record.active.load(Ordering::Acquire) != 0)
            .flat_map(|record| record.hazard.iter())
            .map(|hp| hp.load(Ordering::Acquire))
            .filter(|announced| !announced.is_null())
            .map(|announced| announced as usize)
            .collect();

        // Filter in place so the preallocated buffer is reused across passes
        // instead of being replaced by a fresh allocation every time.
        retired.retain(|&candidate| {
            if protected.contains(&(candidate as usize)) {
                return true;
            }
            // SAFETY: retired pointers originate from `Box::into_raw` (contract of
            // `retire`) and no hazard slot announces this one.
            unsafe { drop(Box::from_raw(candidate)) };
            false
        });
    }
}
impl Drop for HazardDomain {
    /// Frees every version that is still waiting in the retired list.
    fn drop(&mut self) {
        // Move the pending pointers out of the mutex, then release them one by one.
        let pending = std::mem::take(&mut *self.retired.lock());
        for version in pending {
            // SAFETY: retired pointers are released with `Box::from_raw`, matching
            // how `try_reclaim` frees them; the domain is being destroyed, so no
            // hazard announcement can refer to them any more.
            unsafe { drop(Box::from_raw(version)) };
        }
    }
}
/// Small process-wide thread numbering used to tag hazard records.
mod thread_id {
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Next identifier to hand out. Starts at 1, so 0 is not handed out and can
    /// serve as the "unclaimed" marker in `HazardRecord::active`.
    static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

    thread_local! {
        /// Identifier of the current thread, assigned lazily on first access.
        static THREAD_ID: usize = allocate();
    }

    /// Takes the next unused identifier from the global counter.
    fn allocate() -> usize {
        NEXT_ID.fetch_add(1, Ordering::Relaxed)
    }

    /// Returns the calling thread's identifier; repeated calls on the same
    /// thread return the same value.
    pub fn get() -> usize {
        THREAD_ID.with(|&id| id)
    }
}
/// In-memory table that maps each key to a chain of versions.
pub struct LockFreeMemTable {
// Key -> version chain; the concurrent map supplies per-key lookup and insertion.
data: DashMap<Vec<u8>, LockFreeVersionChain>,
// Hazard-pointer domain used by `read` / `read_with` around value access.
hazard_domain: HazardDomain,
// Running size estimate, increased by `write`; nothing in this file decreases it.
size_bytes: AtomicUsize,
}
impl LockFreeMemTable {
    /// Creates an empty memtable with a hazard domain sized for `MAX_THREADS` threads.
    pub fn new() -> Self {
        Self {
            data: DashMap::new(),
            hazard_domain: HazardDomain::new(MAX_THREADS),
            size_bytes: AtomicUsize::new(0),
        }
    }

    /// Reads the value of `key` visible at `snapshot_ts` and returns a copy of it.
    ///
    /// `txn_id`, when given, also makes that transaction's own uncommitted write
    /// visible. Returns `None` when the key is unknown, no version is visible, or
    /// the visible version is a tombstone.
    pub fn read(&self, key: &[u8], snapshot_ts: u64, txn_id: Option<u64>) -> Option<Vec<u8>> {
        self.read_with(key, snapshot_ts, txn_id, |bytes| bytes.to_vec())
    }

    /// Like [`LockFreeMemTable::read`], but hands the borrowed bytes to `f`
    /// instead of copying them, and returns `f`'s result.
    ///
    /// `f` is not called when there is no visible value (unknown key, nothing
    /// visible at `snapshot_ts`, or a tombstone).
    #[inline]
    pub fn read_with<F, R>(
        &self,
        key: &[u8],
        snapshot_ts: u64,
        txn_id: Option<u64>,
        f: F,
    ) -> Option<R>
    where
        F: FnOnce(&[u8]) -> R,
    {
        let chain = self.data.get(key)?;
        let version = chain.read_at(snapshot_ts, txn_id)?;
        // Announce the version for the duration of the callback. The result of
        // `protect` is ignored: when no hazard record is available the read
        // simply proceeds without an announcement.
        let raw = version as *const LockFreeVersion as *mut LockFreeVersion;
        self.hazard_domain.protect(raw, 0);
        let result = version.get_value().map(f);
        self.hazard_domain.clear(0);
        result
    }

    /// Records an uncommitted write of `value` to `key` by `txn_id`
    /// (`None` writes a tombstone).
    ///
    /// # Errors
    ///
    /// Propagates the write-write conflict error from
    /// `LockFreeVersionChain::add_uncommitted`; the size estimate is left
    /// unchanged in that case.
    pub fn write(&self, key: Vec<u8>, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
        // Capture the lengths first so `key` can be moved into the map instead of
        // being cloned just to measure it afterwards.
        let key_len = key.len();
        let value_len = value.as_ref().map_or(0, |bytes| bytes.len());
        let chain = self.data.entry(key).or_default();
        chain.add_uncommitted(value, txn_id)?;
        // The constant 64 is a fixed per-version charge added on top of the
        // payload sizes (presumably an allowance for version metadata).
        self.size_bytes
            .fetch_add(key_len + value_len + 64, Ordering::Relaxed);
        Ok(())
    }

    /// Stamps `commit_ts` on the uncommitted version written by `txn_id` for each
    /// key in `keys`. Keys that are not in the table are skipped.
    pub fn commit(&self, txn_id: u64, commit_ts: u64, keys: &[Vec<u8>]) {
        for key in keys {
            if let Some(chain) = self.data.get(key) {
                chain.commit(txn_id, commit_ts);
            }
        }
    }

    /// `true` when the newest version of `key` is an uncommitted write by a
    /// transaction other than `txn_id`; `false` for unknown keys.
    pub fn has_write_conflict(&self, key: &[u8], txn_id: u64) -> bool {
        self.data
            .get(key)
            .is_some_and(|chain| chain.has_write_conflict(txn_id))
    }

    /// Current size estimate in bytes, as accumulated by `write`.
    pub fn size_bytes(&self) -> usize {
        self.size_bytes.load(Ordering::Relaxed)
    }

    /// Number of distinct keys in the table.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the table holds no keys.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
// SAFETY (review note): the auto traits are lost because `HazardDomain::retired`
// stores raw `*mut LockFreeVersion` pointers, and raw pointers are neither `Send`
// nor `Sync`. These impls assert that sharing the table across threads is sound.
// Within this file that list is only reached through its mutex and the remaining
// shared state is atomics or the concurrent map — TODO confirm this holds for
// every user of the type.
unsafe impl Send for LockFreeMemTable {}
unsafe impl Sync for LockFreeMemTable {}
impl Default for LockFreeMemTable {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the memtable: visibility rules, conflicts, inline storage and
// fat-node overflow.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
// An uncommitted write is visible to its own transaction only; after commit it
// is visible to any reader whose snapshot is later than the commit timestamp.
#[test]
fn test_basic_write_read() {
let memtable = LockFreeMemTable::new();
memtable
.write(b"key1".to_vec(), Some(b"value1".to_vec()), 1)
.unwrap();
let val = memtable.read(b"key1", 100, Some(1));
assert_eq!(val, Some(b"value1".to_vec()));
let val = memtable.read(b"key1", 100, Some(2));
assert!(val.is_none());
memtable.commit(1, 50, &[b"key1".to_vec()]);
let val = memtable.read(b"key1", 100, None);
assert_eq!(val, Some(b"value1".to_vec()));
}
// Readers see the newest version committed before their snapshot timestamp.
#[test]
fn test_snapshot_isolation() {
let memtable = LockFreeMemTable::new();
memtable
.write(b"key".to_vec(), Some(b"v1".to_vec()), 1)
.unwrap();
memtable.commit(1, 10, &[b"key".to_vec()]);
memtable
.write(b"key".to_vec(), Some(b"v2".to_vec()), 2)
.unwrap();
memtable.commit(2, 20, &[b"key".to_vec()]);
assert_eq!(memtable.read(b"key", 15, None), Some(b"v1".to_vec()));
assert_eq!(memtable.read(b"key", 25, None), Some(b"v2".to_vec()));
}
// A second transaction cannot write over another transaction's uncommitted
// version, but the owning transaction may write again.
#[test]
fn test_write_conflict() {
let memtable = LockFreeMemTable::new();
memtable
.write(b"key".to_vec(), Some(b"v1".to_vec()), 1)
.unwrap();
let result = memtable.write(b"key".to_vec(), Some(b"v2".to_vec()), 2);
assert!(result.is_err());
let result = memtable.write(b"key".to_vec(), Some(b"v1_updated".to_vec()), 1);
assert!(result.is_ok());
}
// Eight threads read 100 committed keys concurrently and must all see the
// committed values.
#[test]
fn test_concurrent_reads() {
let memtable = Arc::new(LockFreeMemTable::new());
for i in 0..100 {
let key = format!("key{}", i).into_bytes();
let val = format!("value{}", i).into_bytes();
memtable.write(key.clone(), Some(val), 1).unwrap();
}
memtable.commit(
1,
10,
&(0..100)
.map(|i| format!("key{}", i).into_bytes())
.collect::<Vec<_>>(),
);
let handles: Vec<_> = (0..8)
.map(|t| {
let mt = Arc::clone(&memtable);
thread::spawn(move || {
for i in 0..100 {
let key = format!("key{}", i).into_bytes();
let expected = format!("value{}", i).into_bytes();
let val = mt.read(&key, 100, None);
assert_eq!(val, Some(expected), "Thread {} failed at key{}", t, i);
}
})
})
.collect();
for h in handles {
h.join().unwrap();
}
}
// Small values are stored inline, large ones on the heap, `None` as a tombstone.
#[test]
fn test_inline_storage() {
let small_value = b"small".to_vec();
let version = LockFreeVersion::new(Some(small_value.clone()), 1);
assert!(version.is_inline(), "Small values should be inline");
assert_eq!(version.get_value(), Some(small_value.as_slice()));
let large_value = vec![42u8; 100]; let version = LockFreeVersion::new(Some(large_value.clone()), 2);
assert!(!version.is_inline(), "Large values should be on heap");
assert_eq!(version.get_value(), Some(large_value.as_slice()));
let version = LockFreeVersion::new(None, 3);
assert!(version.storage.is_tombstone());
assert_eq!(version.get_value(), None);
}
// Boundary check: exactly INLINE_VALUE_SIZE bytes is inline, one more byte is not.
#[test]
fn test_inline_threshold() {
let value = vec![0u8; INLINE_VALUE_SIZE];
let version = LockFreeVersion::new(Some(value.clone()), 1);
assert!(version.is_inline(), "Values at threshold should be inline");
let value = vec![0u8; INLINE_VALUE_SIZE + 1];
let version = LockFreeVersion::new(Some(value), 2);
assert!(
!version.is_inline(),
"Values over threshold should be on heap"
);
}
// `read_with` passes the borrowed bytes to the callback and returns its result.
#[test]
fn test_read_with_callback() {
let memtable = LockFreeMemTable::new();
memtable
.write(b"key1".to_vec(), Some(b"value1".to_vec()), 1)
.unwrap();
memtable.commit(1, 10, &[b"key1".to_vec()]);
let len = memtable.read_with(b"key1", 100, None, |v| v.len());
assert_eq!(len, Some(6));
let matches = memtable.read_with(b"key1", 100, None, |v| v == b"value1");
assert_eq!(matches, Some(true));
}
// Twelve versions of one key exceed FAT_NODE_SLOTS (8), so a second node is
// allocated; reads must still resolve correctly across both nodes.
#[test]
fn test_fat_node_overflow() {
let memtable = LockFreeMemTable::new();
for i in 0..12u64 {
memtable
.write(b"key".to_vec(), Some(format!("v{}", i).into_bytes()), i + 1)
.unwrap();
memtable.commit(i + 1, (i + 1) * 10, &[b"key".to_vec()]);
}
// Newest version overall (commit ts 120).
let val = memtable.read(b"key", 200, None);
assert_eq!(val, Some(b"v11".to_vec()));
// v4 was committed at 50 and v5 at 60, so a snapshot at 55 sees v4.
let val = memtable.read(b"key", 55, None);
assert_eq!(val, Some(b"v4".to_vec()));
// Nothing was committed before timestamp 5.
let val = memtable.read(b"key", 5, None);
assert_eq!(val, None);
}
// Four threads write and commit concurrently. Every thread uses its own keys,
// so this exercises concurrent map insertion rather than contention on one chain.
#[test]
fn test_fat_node_concurrent_writes() {
use std::sync::Arc;
use std::thread;
let memtable = Arc::new(LockFreeMemTable::new());
let mut handles = Vec::new();
for t in 0..4u64 {
let mt = Arc::clone(&memtable);
handles.push(thread::spawn(move || {
for i in 0..20u64 {
let key = format!("k{}-{}", t, i).into_bytes();
let val = format!("v{}-{}", t, i).into_bytes();
let txn_id = t * 1000 + i + 1;
mt.write(key.clone(), Some(val), txn_id).unwrap();
mt.commit(txn_id, txn_id * 10, &[key]);
}
}));
}
for h in handles {
h.join().unwrap();
}
for t in 0..4u64 {
for i in 0..20u64 {
let key = format!("k{}-{}", t, i).into_bytes();
let val = memtable.read(&key, u64::MAX, None);
assert_eq!(
val,
Some(format!("v{}-{}", t, i).into_bytes()),
"Missing key k{}-{}",
t,
i
);
}
}
}
}