//! armdb 0.1.11
//!
//! Sharded bitcask key-value storage optimized for NVMe.
pub mod iter;
pub mod node;

use std::ptr;

#[cfg(feature = "loom")]
use loom::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(feature = "loom"))]
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::sync::{self, Mutex};
use seize::Collector;

pub use self::node::SkipNode;

/// Strip the mark tag bit from a tower pointer.
///
/// Tower pointers reserve bit 0 as a deletion tag; this returns the same
/// address with that bit cleared so the pointer can be dereferenced.
#[inline(always)]
pub(crate) fn strip_mark<T>(ptr: *mut T) -> *mut T {
    let addr = ptr as usize;
    (addr & (usize::MAX << 1)) as *mut T
}

/// Result of an insert operation.
pub enum InsertResult<'g, N> {
    /// The key was newly inserted.
    Inserted,
    /// The key already existed. Returns a reference to the existing node.
    /// The reference is valid for the lifetime `'g` of the seize guard
    /// that was passed to `insert`.
    Exists(&'g N),
}

/// A concurrent skip list using `seize` for memory reclamation.
///
/// **Concurrency model:**
/// - Reads (`get`, iterators) are lock-free, protected by `seize::Guard`.
/// - Writes (`insert`, `remove`) are serialized by an internal `write_lock`.
///   During normal operation the shard Mutex already provides per-key serialization,
///   so the write_lock is uncontended. During parallel recovery it ensures safety.
pub struct SkipList<N: SkipNode> {
    /// Sentinel head node. Has MAX_HEIGHT, never contains real data, never removed.
    head: *mut N,
    /// Reclamation domain; removed nodes are retired here (see `remove`).
    collector: Collector,
    /// Count of live entries, maintained with Relaxed atomics (may be
    /// momentarily stale under concurrency).
    len: AtomicUsize,
    /// Highest tower level currently in use; grows monotonically in `insert`.
    height: AtomicUsize,
    /// Serializes all structural writers (`insert`/`remove`).
    write_lock: Mutex<()>,
    /// When true, `key_cmp` flips the byte ordering of keys.
    reversed: bool,
}

// SAFETY: SkipList is designed for concurrent access. Head is a stable allocation.
// All node access is protected by seize guards.
// The raw `head: *mut N` field suppresses the auto traits, so they are
// asserted manually here.
// NOTE(review): soundness also presumes every `SkipNode` impl is itself
// safe to share across threads — that bound is not stated; confirm.
unsafe impl<N: SkipNode> Send for SkipList<N> {}
unsafe impl<N: SkipNode> Sync for SkipList<N> {}

impl<N: SkipNode> SkipList<N> {
    /// Create an empty skip list.
    ///
    /// `reversed` flips the key comparison so the list orders keys in
    /// descending byte order instead of ascending.
    pub fn new(reversed: bool) -> Self {
        Self {
            head: N::alloc_head(),
            collector: Collector::new(),
            len: AtomicUsize::new(0),
            height: AtomicUsize::new(1),
            write_lock: Mutex::new(()),
            reversed,
        }
    }

    /// Compare keys respecting the `reversed` flag.
    ///
    /// Reversal is implemented by swapping the operands, which flips the
    /// total order without touching the comparison itself.
    #[inline(always)]
    fn key_cmp(&self, a: &[u8], b: &[u8]) -> std::cmp::Ordering {
        if self.reversed {
            b.cmp(a)
        } else {
            a.cmp(b)
        }
    }

    /// The `seize::Collector` backing this list; callers use it to create
    /// guards for lock-free reads.
    pub fn collector(&self) -> &Collector {
        &self.collector
    }

    /// Raw pointer to the head sentinel — the starting point for forward
    /// iteration. Never null, never freed while the list lives.
    pub fn head_ptr(&self) -> *mut N {
        self.head
    }

    /// Number of live entries. Relaxed load: the value may lag behind
    /// concurrent inserts/removes.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Lock-free lookup. Returns a reference valid for the lifetime of the guard.
    ///
    /// Traverses top-down: at each level we walk right while the next key is
    /// below `key`, then drop a level. The read path never unlinks anything,
    /// keeping it hardware-read-only to avoid cache-line bouncing; writers
    /// clean up marked nodes during their own traversals.
    pub fn get<'g>(&self, key: &[u8], guard: &'g seize::LocalGuard<'_>) -> Option<&'g N> {
        let _ = guard; // lifetime anchor
        let mut current = self.head;
        let h = self.height.load(Ordering::Relaxed);

        // Traverse from the top level down
        for level in (0..h).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                let next_ref = unsafe { &*next };
                if next_ref.is_marked() {
                    // BUGFIX: a marked (logically deleted) node may only be
                    // stepped over while its key is still below the target.
                    // The previous code advanced onto ANY marked node, which
                    // could move the search past `key` and produce a false
                    // negative for a live key sitting before the marked node
                    // at lower levels. If the marked key is >= target we
                    // descend instead; an Equal marked key means the entry
                    // was deleted, so descending eventually yields None.
                    if self.key_cmp(next_ref.key_bytes(), key) == std::cmp::Ordering::Less {
                        current = next;
                        continue;
                    }
                    break;
                }
                match self.key_cmp(next_ref.key_bytes(), key) {
                    std::cmp::Ordering::Less => current = next,
                    std::cmp::Ordering::Equal => {
                        // Unmarked when checked above; a concurrent mark after
                        // that check is an inherent race the caller tolerates.
                        return Some(next_ref);
                    }
                    std::cmp::Ordering::Greater => break,
                }
            }
        }
        None
    }

    /// Insert a node into the skip list.
    ///
    /// If the key already exists (and is not marked), returns `Exists` with a reference to it.
    ///
    /// All writers are serialized by `write_lock`; readers run concurrently.
    /// The node is pre-linked with Relaxed stores while unpublished, then
    /// published with Release stores so a reader that observes a link also
    /// observes the node's initialized tower entries.
    pub fn insert<'g>(
        &self,
        node: *mut N,
        guard: &'g seize::LocalGuard<'_>,
    ) -> InsertResult<'g, N> {
        let _lock = sync::lock(&self.write_lock);
        let _ = guard;
        let key = unsafe { (*node).key_bytes() };
        // NOTE(review): assumes the allocator bounded this by MAX_HEIGHT —
        // confirm in `node::alloc`.
        let node_height = unsafe { (*node).height() } as usize;

        // Update max height
        // (CAS loop raises the shared height to at least `node_height`;
        // a weak CAS may fail spuriously, in which case we simply retry.)
        let mut current_height = self.height.load(Ordering::Relaxed);
        while node_height > current_height {
            match self.height.compare_exchange_weak(
                current_height,
                node_height,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(h) => current_height = h,
            }
        }
        // Re-read so the search below covers at least `node_height` levels.
        let search_height = self.height.load(Ordering::Relaxed);

        // Find predecessors and successors at each level
        let mut preds: [*mut N; node::MAX_HEIGHT] = [ptr::null_mut(); node::MAX_HEIGHT];
        let mut succs: [*mut N; node::MAX_HEIGHT] = [ptr::null_mut(); node::MAX_HEIGHT];

        let mut current = self.head;
        for level in (0..search_height).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                let next_ref = unsafe { &*next };
                if next_ref.is_marked() {
                    // Help unlink marked nodes
                    // (the write lock is held, so the CAS is expected to
                    // succeed; on failure we re-read and retry the level).
                    let after = strip_mark(unsafe { (*next).tower(level).load(Ordering::Acquire) });
                    let _ = unsafe {
                        (*current).tower(level).compare_exchange(
                            next,
                            after,
                            Ordering::AcqRel,
                            Ordering::Relaxed,
                        )
                    };
                    continue;
                }
                match self.key_cmp(next_ref.key_bytes(), key) {
                    std::cmp::Ordering::Less => current = next,
                    std::cmp::Ordering::Equal => {
                        // Key exists — return it
                        return InsertResult::Exists(next_ref);
                    }
                    std::cmp::Ordering::Greater => break,
                }
            }
            // `current` is now the last node with key < `key` at this level.
            preds[level] = current;
            succs[level] = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
        }

        // Link the new node bottom-up
        // (Relaxed suffices: the node is invisible until the Release below.)
        #[allow(clippy::needless_range_loop)]
        for level in 0..node_height {
            unsafe {
                (*node).tower(level).store(succs[level], Ordering::Relaxed);
            }
        }
        // Publish the node by linking predecessors to it
        #[allow(clippy::needless_range_loop)]
        for level in 0..node_height {
            unsafe {
                (*preds[level]).tower(level).store(node, Ordering::Release);
            }
        }

        self.len.fetch_add(1, Ordering::Relaxed);
        InsertResult::Inserted
    }

    /// Remove a node by key. Returns the removed node pointer (for value extraction).
    ///
    /// Serialized with other writers by `write_lock`. Deletion proceeds in
    /// three steps: mark (logical delete, observable by readers), physical
    /// unlink at every level, then retirement to the collector so memory is
    /// reclaimed only after all current guards are released.
    pub fn remove(&self, key: &[u8], guard: &seize::LocalGuard<'_>) -> Option<*mut N> {
        let _lock = sync::lock(&self.write_lock);
        let _ = guard;
        let search_height = self.height.load(Ordering::Relaxed);

        // Find the node and its predecessors
        let mut preds: [*mut N; node::MAX_HEIGHT] = [ptr::null_mut(); node::MAX_HEIGHT];
        let mut found: *mut N = ptr::null_mut();

        let mut current = self.head;
        for level in (0..search_height).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                let next_ref = unsafe { &*next };
                if next_ref.is_marked() {
                    // Help unlink leftover marked nodes (write lock held, so
                    // the CAS is expected to succeed; retry the level if not).
                    let after = strip_mark(unsafe { (*next).tower(level).load(Ordering::Acquire) });
                    let _ = unsafe {
                        (*current).tower(level).compare_exchange(
                            next,
                            after,
                            Ordering::AcqRel,
                            Ordering::Relaxed,
                        )
                    };
                    continue;
                }
                match self.key_cmp(next_ref.key_bytes(), key) {
                    std::cmp::Ordering::Less => current = next,
                    std::cmp::Ordering::Equal => {
                        // Stop in front of the match so `current` remains the
                        // per-level predecessor recorded just below.
                        found = next;
                        break;
                    }
                    std::cmp::Ordering::Greater => break,
                }
            }
            preds[level] = current;
        }

        if found.is_null() {
            return None;
        }

        // Mark the node as logically deleted
        let node_ref = unsafe { &*found };
        if !node_ref.mark() {
            // Already marked by someone else
            return None;
        }

        // Physically unlink at each level (top-down)
        // (preds were recorded under the write lock; a failed CAS is ignored
        // because later writers' help-unlink path finishes the job.)
        let node_h = node_ref.height() as usize;
        for level in (0..node_h).rev() {
            let next = strip_mark(unsafe { (*found).tower(level).load(Ordering::Acquire) });
            let _ = unsafe {
                (*preds[level]).tower(level).compare_exchange(
                    found,
                    next,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
            };
        }

        // Retire the node via seize (no-op under loom: leak is acceptable)
        #[cfg(not(feature = "loom"))]
        unsafe {
            self.collector.retire(found, N::reclaim);
        }

        self.len.fetch_sub(1, Ordering::Relaxed);
        Some(found)
    }

    /// Find the first node whose key is >= `start_key`. Used by iterators.
    ///
    /// NOTE(review): the returned level-0 successor may itself be marked
    /// (logically deleted); iterators are assumed to filter marked nodes —
    /// confirm against `iter`.
    pub(crate) fn find_first_ge(&self, start_key: &[u8], guard: &seize::LocalGuard<'_>) -> *mut N {
        let _ = guard;
        let mut current = self.head;
        let h = self.height.load(Ordering::Relaxed);

        for level in (0..h).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                // BUGFIX: advance iff the next key is below the target,
                // marked or not. The previous code stepped onto ANY marked
                // node unconditionally, which could move the search past
                // `start_key` and skip live nodes at lower levels.
                let next_ref = unsafe { &*next };
                if self.key_cmp(next_ref.key_bytes(), start_key) == std::cmp::Ordering::Less {
                    current = next;
                } else {
                    break;
                }
            }
        }

        // current is the last node < start_key. The next on level 0 is >= start_key.
        strip_mark(unsafe { (*current).tower(0).load(Ordering::Acquire) })
    }

    /// Find the last node whose key is strictly less than `key` in list order.
    /// Returns null if no such node exists.
    /// Caller must hold a seize guard.
    ///
    /// NOTE(review): the returned node may be marked (logically deleted);
    /// callers are assumed to filter marked nodes — confirm against `iter`.
    pub(crate) fn find_last_lt(&self, key: &[u8]) -> *mut N {
        let mut current = self.head;
        let h = self.height.load(Ordering::Relaxed);

        for level in (0..h).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                // BUGFIX: advance iff the next key is below the target,
                // marked or not. The previous code stepped onto ANY marked
                // node unconditionally, which could carry the search past
                // `key` and return a node that is not actually < `key`.
                let next_ref = unsafe { &*next };
                if self.key_cmp(next_ref.key_bytes(), key) == std::cmp::Ordering::Less {
                    current = next;
                } else {
                    break;
                }
            }
        }

        if current == self.head {
            ptr::null_mut()
        } else {
            current
        }
    }

    /// Find the last node in list order. Returns null if the list is empty.
    /// Caller must hold a seize guard.
    ///
    /// NOTE(review): the returned node may be marked (logically deleted);
    /// callers are assumed to filter marked nodes — confirm against `iter`.
    pub(crate) fn find_last(&self) -> *mut N {
        let mut current = self.head;
        let h = self.height.load(Ordering::Relaxed);

        // Walk as far right as possible at each level, then descend. The old
        // code special-cased marked nodes but advanced onto them exactly as
        // it did for unmarked ones, so that branch was dead and is removed.
        for level in (0..h).rev() {
            loop {
                let next = strip_mark(unsafe { (*current).tower(level).load(Ordering::Acquire) });
                if next.is_null() {
                    break;
                }
                current = next;
            }
        }

        if current == self.head {
            ptr::null_mut()
        } else {
            current
        }
    }
}

impl<N: SkipNode> Drop for SkipList<N> {
    fn drop(&mut self) {
        // Tear the whole list down by walking the bottom lane (level 0),
        // which threads through every node including the head sentinel.
        // `&mut self` guarantees exclusive access, so Relaxed loads suffice.
        let mut cursor = self.head;
        while !cursor.is_null() {
            // Grab the successor before freeing the current node.
            let succ = strip_mark(unsafe { (*cursor).tower(0).load(Ordering::Relaxed) });
            unsafe { N::dealloc_node(cursor) };
            cursor = succ;
        }
    }
}