con-art-rust 0.2.0

A Rust implementation of the ART-OLC concurrent adaptive radix tree.
Documentation
use std::alloc;

use crate::{
    base_node::{BaseNode, Node, NodeType},
    child_ptr::NodePtr,
};

/// An adaptive-radix-tree inner node with room for up to 16 children.
///
/// `#[repr(C)]` fixes the field order and places `base` at offset 0.
#[repr(C)]
pub(crate) struct Node16 {
    // Common node header; `base.count` is the number of live slots.
    base: BaseNode,

    // Key bytes of the live children in slots `0..base.count`, each stored with its
    // top bit flipped (`Node16::flip_sign`). Slots at or past `count` may hold stale
    // bytes left behind by `remove`.
    keys: [u8; 16],
    // `children[i]` is the child stored under `keys[i]`.
    children: [NodePtr; 16],
}

impl Node16 {
    fn flip_sign(val: u8) -> u8 {
        val ^ 128
    }

    fn ctz(val: u16) -> u16 {
        std::intrinsics::cttz(val)
    }

    fn get_insert_pos(&self, key: u8) -> usize {
        let flipped = Self::flip_sign(key);

        #[cfg(all(target_feature = "sse2", not(miri)))]
        {
            unsafe {
                use std::arch::x86_64::{
                    __m128i, _mm_cmplt_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
                };
                let cmp = _mm_cmplt_epi8(
                    _mm_set1_epi8(flipped as i8),
                    _mm_loadu_si128(&self.keys as *const [u8; 16] as *const __m128i),
                );
                let bit_field = _mm_movemask_epi8(cmp) & (0xFFFF >> (16 - self.base.count));
                let pos = if bit_field > 0 {
                    Self::ctz(bit_field as u16)
                } else {
                    self.base.count as u16
                };
                pos as usize
            }
        }

        #[cfg(any(not(target_feature = "sse2"), miri))]
        {
            let mut pos = 0;
            while pos < self.base.count {
                if self.keys[pos as usize] >= flipped {
                    return pos as usize;
                }
                pos += 1;
            }
            pos as usize
        }
    }

    fn get_child_pos(&self, key: u8) -> Option<usize> {
        #[cfg(all(target_feature = "sse2", not(miri)))]
        unsafe {
            self.get_child_pos_sse2(key)
        }

        #[cfg(any(not(target_feature = "sse2"), miri))]
        self.get_child_pos_linear(key)
    }

    #[cfg(any(not(target_feature = "sse2"), miri))]
    fn get_child_pos_linear(&self, key: u8) -> Option<usize> {
        for i in 0..self.base.count {
            if self.keys[i as usize] == Self::flip_sign(key) {
                return Some(i as usize);
            }
        }
        None
    }

    #[target_feature(enable = "sse2")]
    #[allow(dead_code)]
    unsafe fn get_child_pos_sse2(&self, key: u8) -> Option<usize> {
        use std::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
        };
        let cmp = _mm_cmpeq_epi8(
            _mm_set1_epi8(Self::flip_sign(key) as i8),
            _mm_loadu_si128(&self.keys as *const [u8; 16] as *const __m128i),
        );
        let bit_field = _mm_movemask_epi8(cmp) & ((1 << self.base.count) - 1);
        if bit_field > 0 {
            Some(Self::ctz(bit_field as u16) as usize)
        } else {
            None
        }
    }
}

impl Node for Node16 {
    /// Allocates a zero-initialised `Node16` whose header carries `prefix`.
    ///
    /// Diverges through `std::alloc::handle_alloc_error` if allocation fails.
    fn new(prefix: &[u8]) -> Box<Self> {
        let layout = alloc::Layout::new::<Node16>();
        let base = BaseNode::new(NodeType::N16, prefix);
        // SAFETY: `layout` has non-zero size, and the null case is handled before
        // any write. The allocation is zeroed, so `keys`/`children` start as all-zero
        // bytes. NOTE(review): this relies on all-zero being a valid `NodePtr`, as
        // the code always has; confirm in `child_ptr`. `base` is initialised in
        // place before the `Box` is formed, and `Box` frees with the same global
        // allocator and layout.
        unsafe {
            let mem = alloc::alloc_zeroed(layout) as *mut Node16;
            if mem.is_null() {
                alloc::handle_alloc_error(layout);
            }
            std::ptr::addr_of_mut!((*mem).base).write(base);
            Box::from_raw(mem)
        }
    }

    fn get_type() -> NodeType {
        NodeType::N16
    }

    /// Returns `(key, child)` pairs for every child whose key lies in `start..=end`.
    ///
    /// Pairs come out in slot order. `insert` positions keys via `get_insert_pos`,
    /// which is meant to keep slots sorted by key. Bounds need not be present in
    /// the node.
    fn get_children(&self, start: u8, end: u8) -> Vec<(u8, NodePtr)> {
        // FIXME: the node may be empty due to deletion, this is not intended, we should
        // fix the delete logic. An empty node simply yields an empty Vec here.
        let count = self.base.count as usize;
        self.keys[..count]
            .iter()
            .zip(&self.children[..count])
            .map(|(&k, &child)| (Self::flip_sign(k), child))
            .filter(|&(k, _)| start <= k && k <= end)
            .collect()
    }

    /// Removes the child stored under `k`, shifting later slots left.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not present in this node.
    fn remove(&mut self, k: u8) {
        let pos = self
            .get_child_pos(k)
            .expect("trying to delete a non-existing key");
        let count = self.base.count as usize;
        self.keys.copy_within(pos + 1..count, pos);
        self.children.copy_within(pos + 1..count, pos);
        self.base.count -= 1;
        debug_assert!(self.get_child(k).is_none());
    }

    /// Inserts every live `(key, child)` pair of this node into `dst`.
    fn copy_to<N: Node>(&self, dst: &mut N) {
        let count = self.base.count as usize;
        for (&k, &child) in self.keys[..count].iter().zip(&self.children[..count]) {
            dst.insert(Self::flip_sign(k), child);
        }
    }

    fn base(&self) -> &BaseNode {
        &self.base
    }

    fn base_mut(&mut self) -> &mut BaseNode {
        &mut self.base
    }

    /// True when all 16 slots are occupied.
    fn is_full(&self) -> bool {
        self.base.count == 16
    }

    /// True when exactly 3 slots are occupied.
    // NOTE(review): presumably the shrink threshold used by the delete logic; confirm.
    fn is_under_full(&self) -> bool {
        self.base.count == 3
    }

    /// Inserts `node` under `key`, shifting later slots right so keys stay ordered.
    ///
    /// Keys are kept sorted so that `get_children` and `copy_to` visit children in
    /// key order.
    ///
    /// # Panics
    ///
    /// Panics if the node already holds 16 children.
    fn insert(&mut self, key: u8, node: NodePtr) {
        let count = self.base.count as usize;
        // Must be checked before shifting: shifting a full node would write past
        // the end of both arrays.
        assert!(count < 16, "inserting into a full Node16");

        let pos = self.get_insert_pos(key);
        self.keys.copy_within(pos..count, pos + 1);
        self.children.copy_within(pos..count, pos + 1);

        self.keys[pos] = Self::flip_sign(key);
        self.children[pos] = node;
        self.base.count += 1;
    }

    /// Replaces the child stored under `key` with `val`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not present in this node.
    fn change(&mut self, key: u8, val: NodePtr) {
        let pos = self
            .get_child_pos(key)
            .expect("trying to change a non-existing key");
        self.children[pos] = val;
    }

    /// Returns the child stored under `key`, if any.
    fn get_child(&self, key: u8) -> Option<NodePtr> {
        let pos = self.get_child_pos(key)?;
        Some(self.children[pos])
    }
}