con-art-rust 0.2.0

A Rust implementation of the ART-OLC concurrent adaptive radix tree.
Documentation
#[cfg(shuttle)]
use shuttle::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(all(shuttle)))]
use std::sync::atomic::{AtomicUsize, Ordering};

use crossbeam_epoch::Guard;

use crate::{
    child_ptr::NodePtr,
    lock::{ConcreteReadGuard, ReadGuard, WriteGuard},
    node_16::Node16,
    node_256::Node256,
    node_4::Node4,
    node_48::Node48,
    utils::convert_type_to_version,
};

/// Maximum number of prefix bytes a node header stores inline.
/// `BaseNode::new` asserts that the prefix it is given is no longer than this.
pub(crate) const MAX_STORED_PREFIX_LEN: usize = 10;
/// Fixed-size inline buffer holding a node's stored prefix bytes; only the
/// first `BaseNode::prefix_cnt` bytes are meaningful.
pub(crate) type Prefix = [u8; MAX_STORED_PREFIX_LEN];

/// Tag identifying which concrete node struct a `BaseNode` header belongs to.
///
/// `BaseNode::get_type` recovers this tag from the top two bits of the header
/// word (`value >> 62`), so the discriminants must stay within `0..=3`.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum NodeType {
    /// Tag for `Node4`.
    N4 = 0,
    /// Tag for `Node16`.
    N16 = 1,
    /// Tag for `Node48`.
    N48 = 2,
    /// Tag for `Node256`.
    N256 = 3,
}

impl NodeType {
    /// Returns the allocation layout (size and alignment) of the concrete node
    /// struct this tag stands for.
    ///
    /// `Layout::new::<T>()` yields exactly `size_of::<T>()` / `align_of::<T>()`
    /// and cannot fail, so no `unwrap` is needed.
    fn node_layout(&self) -> std::alloc::Layout {
        match *self {
            NodeType::N4 => std::alloc::Layout::new::<Node4>(),
            NodeType::N16 => std::alloc::Layout::new::<Node16>(),
            NodeType::N48 => std::alloc::Layout::new::<Node48>(),
            NodeType::N256 => std::alloc::Layout::new::<Node256>(),
        }
    }
}

/// Interface shared by the concrete node structs (`Node4`, `Node16`, `Node48`,
/// `Node256`).
///
/// NOTE(review): the per-method descriptions below are inferred from the
/// signatures and from how `BaseNode::insert_grow` uses them; confirm against
/// the implementations.
pub(crate) trait Node {
    /// Allocates a new node on the heap, initialised with `prefix`.
    fn new(prefix: &[u8]) -> Box<Self>;
    /// Returns a shared reference to this node's `BaseNode` header.
    fn base(&self) -> &BaseNode;
    /// Returns an exclusive reference to this node's `BaseNode` header.
    fn base_mut(&mut self) -> &mut BaseNode;
    /// Reports whether the node has no room left; `insert_grow` replaces a
    /// full node with a bigger one instead of inserting into it.
    fn is_full(&self) -> bool;
    /// Presumably reports whether the node is sparse enough to be shrunk —
    /// TODO confirm; not used in this file.
    fn is_under_full(&self) -> bool;
    /// Adds the child `node` under key byte `key`.
    fn insert(&mut self, key: u8, node: NodePtr);
    /// Replaces the child stored under key byte `key` with `val`.
    fn change(&mut self, key: u8, val: NodePtr);
    /// Looks up the child stored under key byte `key`, if any.
    fn get_child(&self, key: u8) -> Option<NodePtr>;
    /// Collects `(key byte, child)` pairs for keys between `start` and `end`.
    /// NOTE(review): whether `end` is inclusive is not visible here — confirm.
    fn get_children(&self, start: u8, end: u8) -> Vec<(u8, NodePtr)>;
    /// Removes the child stored under key byte `k`.
    fn remove(&mut self, k: u8);
    /// Copies this node's children into `dst`; used when growing a node.
    fn copy_to<N: Node>(&self, dst: &mut N);
    /// Returns the `NodeType` tag of the implementing struct.
    fn get_type() -> NodeType;
}

/// Header common to every node.
///
/// `#[repr(C)]` fixes the field order. Code in this file casts `*BaseNode` to
/// pointers to the concrete node structs and back.
/// NOTE(review): that presumably relies on each node struct being `#[repr(C)]`
/// with a `BaseNode` as its first field — confirm in the node modules.
#[repr(C)]
pub(crate) struct BaseNode {
    // 2b type | 60b version | 1b lock | 1b obsolete
    // As read by the accessors in this file: the type tag is `word >> 62`,
    // the lock flag is bit 1 (`0b10`) and the obsolete flag is bit 0.
    pub(crate) type_version_lock_obsolete: AtomicUsize,
    // Number of valid bytes at the start of `prefix`.
    pub(crate) prefix_cnt: u32,
    // Returned by `get_count`; presumably the number of children currently
    // stored in the node — confirm in the node implementations.
    pub(crate) count: u16, // TODO: we only need u8
    // Inline prefix storage; bytes past `prefix_cnt` are not meaningful.
    pub(crate) prefix: Prefix,
}

/// Dropping a header frees the whole node allocation it lives in.
///
/// NOTE(review): `dealloc` is called on `self`'s own address with the layout
/// of the concrete node type. This is only sound when the header sits at the
/// start of a heap allocation made with that exact layout, and nothing else
/// frees that allocation afterwards. Dropping a `BaseNode` that lives on the
/// stack, or letting a `Box` of a node struct run its normal drop, looks like
/// it would be undefined behaviour — confirm how nodes are created and freed.
impl Drop for BaseNode {
    fn drop(&mut self) {
        // The type tag in the header selects the size/alignment to free with.
        let layout = self.get_type().node_layout();
        unsafe {
            std::alloc::dealloc(self as *mut BaseNode as *mut u8, layout);
        }
    }
}

// Generates `BaseNode::$method_name(&self, ...)`, which forwards the call to
// the same-named method on the concrete node struct selected by
// `self.get_type()`.
//
// Macro arguments: the method name, a parenthesised `name: Type` parameter
// list, and the return type.
//
// NOTE(review): each arm reinterprets `&BaseNode` as a reference to the full
// node struct. That is only valid if `self` really is the header at the start
// of a node of the tagged type — confirm the node structs' layout.
macro_rules! gen_method {
    ($method_name:ident, ($($arg_n:ident : $args:ty),*), $return:ty) => {
        impl BaseNode {
            pub(crate) fn $method_name(&self, $($arg_n : $args),*) -> $return {
                // One arm per node type; all arms do the same cast-and-forward.
                match self.get_type() {
                    NodeType::N4 => {
                        let node = unsafe{&* (self as *const BaseNode as *const Node4)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N16 => {
                        let node = unsafe{&* (self as *const BaseNode as *const Node16)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N48 => {
                        let node = unsafe{&* (self as *const BaseNode as *const Node48)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N256 => {
                        let node = unsafe{&* (self as *const BaseNode as *const Node256)};
                        node.$method_name($($arg_n),*)
                    },
                }
            }
        }
    };
}

// Mutable counterpart of `gen_method`: generates
// `BaseNode::$method_name(&mut self, ...)`, which forwards the call to the
// same-named `&mut self` method on the concrete node struct selected by
// `self.get_type()`.
//
// Macro arguments: the method name, a parenthesised `name: Type` parameter
// list, and the return type.
//
// NOTE(review): each arm reinterprets `&mut BaseNode` as an exclusive
// reference to the full node struct. That is only valid if `self` really is
// the header at the start of a node of the tagged type — confirm the node
// structs' layout.
macro_rules! gen_method_mut {
    ($method_name:ident, ($($arg_n:ident : $args:ty),*), $return:ty) => {
        impl BaseNode {
            pub(crate) fn $method_name(&mut self, $($arg_n : $args),*) -> $return {
                // One arm per node type; all arms do the same cast-and-forward.
                match self.get_type() {
                    NodeType::N4 => {
                        let node = unsafe{&mut * (self as *mut BaseNode as *mut Node4)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N16 => {
                        let node = unsafe{&mut * (self as *mut BaseNode as *mut Node16)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N48 => {
                        let node = unsafe{&mut * (self as *mut BaseNode as *mut Node48)};
                        node.$method_name($($arg_n),*)
                    },
                    NodeType::N256 => {
                        let node = unsafe{&mut * (self as *mut BaseNode as *mut Node256)};
                        node.$method_name($($arg_n),*)
                    },
                }
            }
        }
    };
}

// Type-dispatching wrappers on `BaseNode`: two read-only lookups (`&self`)
// and two mutators (`&mut self`), each forwarding to the concrete node struct.
gen_method!(get_child, (k: u8), Option<NodePtr>);
gen_method!(get_children, (start: u8, end: u8), Vec<(u8, NodePtr)>);
gen_method_mut!(change, (key: u8, val: NodePtr), ());
gen_method_mut!(remove, (key: u8), ());

impl BaseNode {
    pub(crate) fn new(n_type: NodeType, prefix: &[u8]) -> Self {
        let val = convert_type_to_version(n_type);
        let mut prefix_v: [u8; MAX_STORED_PREFIX_LEN] = [0; MAX_STORED_PREFIX_LEN];

        assert!(prefix.len() <= MAX_STORED_PREFIX_LEN);
        for (i, v) in prefix.iter().enumerate() {
            prefix_v[i] = *v;
        }

        BaseNode {
            type_version_lock_obsolete: AtomicUsize::new(val),
            prefix_cnt: prefix.len() as u32,
            count: 0,
            prefix: prefix_v,
        }
    }

    #[allow(dead_code)]
    fn set_type(&self, n_type: NodeType) {
        let val = convert_type_to_version(n_type);
        self.type_version_lock_obsolete
            .fetch_add(val, Ordering::Release);
    }

    pub(crate) fn get_type(&self) -> NodeType {
        let val = self.type_version_lock_obsolete.load(Ordering::Relaxed);
        let val = val >> 62;
        debug_assert!(val < 4);
        unsafe { std::mem::transmute(val as u8) }
    }

    pub(crate) fn set_prefix(&mut self, prefix: &[u8]) {
        let len = prefix.len();
        self.prefix_cnt = len as u32;

        for (i, v) in prefix.iter().enumerate() {
            self.prefix[i] = *v;
        }
    }

    pub(crate) fn read_lock(&self) -> Result<ReadGuard, usize> {
        let version = self.type_version_lock_obsolete.load(Ordering::Acquire);
        if Self::is_locked(version) || Self::is_obsolete(version) {
            return Err(version);
        }

        Ok(ReadGuard::new(version, self))
    }

    #[allow(dead_code)]
    pub(crate) fn write_lock(&self) -> Result<WriteGuard, usize> {
        let read = self.read_lock()?;
        read.upgrade().map_err(|v| v.1)
    }

    fn is_locked(version: usize) -> bool {
        (version & 0b10) == 0b10
    }

    pub(crate) fn get_count(&self) -> usize {
        self.count as usize
    }

    fn is_obsolete(version: usize) -> bool {
        (version & 1) == 1
    }

    pub(crate) fn has_prefix(&self) -> bool {
        self.prefix_cnt > 0
    }

    pub(crate) fn prefix_len(&self) -> u32 {
        self.prefix_cnt
    }

    pub(crate) fn prefix(&self) -> &[u8] {
        self.prefix[..self.prefix_cnt as usize].as_ref()
    }

    pub(crate) fn insert_grow<CurT: Node, BiggerT: Node>(
        n: ConcreteReadGuard<CurT>,
        parent_node: Option<ReadGuard>,
        key_parent: u8,
        key: u8,
        val: NodePtr,
        guard: &Guard,
    ) -> Result<(), ()> {
        if !n.as_ref().is_full() {
            if let Some(p) = parent_node {
                p.unlock().map_err(|_| ())?;
            }

            let mut write_n = n.upgrade_to_write_lock().map_err(|_| ())?;

            write_n.as_mut().insert(key, val);
            return Ok(());
        }

        let p = parent_node.expect("parent node must present when current node is full");

        let mut write_p = p.upgrade().map_err(|_| ())?;

        let mut write_n = n.upgrade_to_write_lock().map_err(|_| ())?;

        let mut n_big = BiggerT::new(write_n.as_ref().base().prefix());
        write_n.as_ref().copy_to(n_big.as_mut());
        n_big.insert(key, val);

        write_p.as_mut().change(
            key_parent,
            NodePtr::from_node(Box::into_raw(n_big) as *mut BaseNode),
        );

        write_n.mark_obsolete();
        let delete_n = write_n.as_mut() as *mut CurT as usize;
        std::mem::forget(write_n);
        guard.defer(move || unsafe {
            std::ptr::drop_in_place(delete_n as *mut BaseNode);
        });
        Ok(())
    }

    pub(crate) fn insert_and_unlock(
        node: ReadGuard,
        parent: Option<ReadGuard>,
        key_parent: u8,
        key: u8,
        val: NodePtr,
        guard: &Guard,
    ) -> Result<(), ()> {
        match node.as_ref().get_type() {
            NodeType::N4 => Self::insert_grow::<Node4, Node16>(
                node.into_concrete(),
                parent,
                key_parent,
                key,
                val,
                guard,
            ),
            NodeType::N16 => Self::insert_grow::<Node16, Node48>(
                node.into_concrete(),
                parent,
                key_parent,
                key,
                val,
                guard,
            ),
            NodeType::N48 => Self::insert_grow::<Node48, Node256>(
                node.into_concrete(),
                parent,
                key_parent,
                key,
                val,
                guard,
            ),
            NodeType::N256 => Self::insert_grow::<Node256, Node256>(
                node.into_concrete(),
                parent,
                key_parent,
                key,
                val,
                guard,
            ),
        }
    }
}