//! skippydb 0.2.2
//!
//! A high-performance verifiable key-value store with SHA256 Merkle trees and
//! optional CUDA GPU acceleration, designed for blockchain state storage.
use dashmap::DashMap;
use std::sync::atomic::{AtomicU32, Ordering};

/// Number of 32-bit words in one segment's bitmap.
const U32_COUNT: u64 = 1024;
/// Bits addressable within a single segment: 32 * U32_COUNT = 32_768.
const BIT_COUNT: u64 = 32 * U32_COUNT;

/// A fixed-size concurrent bitmap of `BIT_COUNT` bits with a running
/// population count, stored as an array of `AtomicU32` words.
struct Segment {
    // Number of bits currently set; maintained by `set`/`clear`.
    count: AtomicU32,
    // Bit storage: `U32_COUNT` words of 32 bits each.
    arr: [AtomicU32; U32_COUNT as usize],
}

// Zero-valued atomic used as an array initializer. A `const` item is inlined
// at every use site, so `[ZERO; N]` produces N independent fresh atomics.
const ZERO: AtomicU32 = AtomicU32::new(0);

impl Segment {
    /// Creates an all-zero segment with a zero population count.
    fn new() -> Self {
        Self {
            count: ZERO,
            arr: [ZERO; U32_COUNT as usize],
        }
    }

    /// Returns whether bit `n` is set.
    ///
    /// `n` must be `< BIT_COUNT`; a larger value panics on the word index.
    fn get(&self, n: usize) -> bool {
        let word = n / 32;
        let bit = n % 32;
        self.arr[word].load(Ordering::Relaxed) & (1u32 << bit) != 0
    }

    /// Sets bit `n`, bumping `count` only when the bit was previously clear
    /// (so double-sets do not inflate the count).
    fn set(&self, n: usize) {
        let word = n / 32;
        let bit = n % 32;
        let mask = 1u32 << bit;
        let prev = self.arr[word].fetch_or(mask, Ordering::AcqRel);
        if prev & mask == 0 {
            self.count.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Clears bit `n`. Returns `true` only when this call cleared the last
    /// set bit (count dropped from 1 to 0), signalling the caller that the
    /// whole segment is now empty and may be discarded.
    fn clear(&self, n: usize) -> bool {
        let word = n / 32;
        let bit = n % 32;
        let mask = 1u32 << bit;
        let prev = self.arr[word].fetch_and(!mask, Ordering::AcqRel);
        if prev & mask == 0 {
            // Bit was already clear: nothing changed, segment not empty(er).
            return false;
        }
        // We just cleared a set bit; report whether it was the last one.
        self.count.fetch_sub(1, Ordering::Relaxed) == 1
    }
}

/// A sparse, concurrent bit set over the full `u64` index space.
///
/// Bits are grouped into fixed-size `Segment`s keyed by `index / BIT_COUNT`;
/// segments are created lazily by `set` and removed again by `clear` once a
/// segment's last bit is cleared, so unused regions cost no memory.
pub struct ActiveBits {
    // Segment index -> bitmap. Boxed so the large inline atomic array lives
    // on the heap rather than inside the map's buckets.
    m: DashMap<u64, Box<Segment>>,
}

impl ActiveBits {
    pub fn with_capacity(n: usize) -> Self {
        Self {
            m: DashMap::with_capacity(n),
        }
    }

    pub fn get(&self, n: u64) -> bool {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        if let Some(seg) = self.m.get(&i) {
            return seg.get(j as usize);
        }
        false
    }

    pub fn set(&self, n: u64) {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        self.m
            .entry(i)
            .or_insert_with(|| Box::new(Segment::new()))
            .set(j as usize);
    }

    pub fn clear(&self, n: u64) {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        let need_remove = {
            if let Some(seg) = self.m.get(&i) {
                seg.clear(j as usize)
            } else {
                false
            }
        };
        if need_remove {
            self.m.remove(&i);
        }
    }
}

#[cfg(test)]
mod segments_tests {
    use super::*;

    /// Exercises the set/get/clear round-trip and the running bit count
    /// on a single segment, including bits straddling a word boundary.
    #[test]
    fn test_segment() {
        let seg = Segment::new();

        // set / get round-trip
        assert!(!seg.get(0));
        seg.set(0);
        assert!(seg.get(0));
        assert!(!seg.get(1));

        // clear: an unset bit reports false; the last set bit reports true
        assert!(!seg.clear(1));
        assert!(seg.clear(0));
        assert!(!seg.get(0));

        // adjacent bits across the first 32-bit word boundary
        seg.set(31);
        seg.set(32);
        assert!(seg.get(31));
        assert!(seg.get(32));
        assert!(!seg.get(33));

        // population count decrements one clear at a time
        assert_eq!(seg.count.load(Ordering::SeqCst), 2);
        seg.clear(31);
        assert_eq!(seg.count.load(Ordering::SeqCst), 1);
        seg.clear(32);
        assert_eq!(seg.count.load(Ordering::SeqCst), 0);
    }

    /// Boundary bits, double-set, and clearing an unset bit: none of these
    /// may corrupt the population count.
    #[test]
    fn test_segment_edge_cases() {
        let seg = Segment::new();

        // first and last addressable bits
        seg.set(0);
        seg.set(BIT_COUNT as usize - 1);
        assert!(seg.get(0));
        assert!(seg.get(BIT_COUNT as usize - 1));
        assert_eq!(seg.count.load(Ordering::SeqCst), 2);

        // re-setting an already-set bit must not bump the count
        seg.set(0);
        assert_eq!(seg.count.load(Ordering::SeqCst), 2);

        // clearing an unset bit must not touch the count
        assert!(!seg.clear(1));
        assert_eq!(seg.count.load(Ordering::SeqCst), 2);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Basic single-threaded set/get/clear round-trip on one bit.
    #[test]
    fn test_get_set_clear() {
        let ab = ActiveBits::with_capacity(10);

        // set and get
        ab.set(42);
        assert!(ab.get(42));
        assert!(!ab.get(43));

        // clear
        ab.clear(42);
        assert!(!ab.get(42));
    }

    /// Indices near `u64::MAX` must address the highest segment correctly
    /// without overflow or neighbor interference.
    #[test]
    fn test_large_numbers() {
        let ab = ActiveBits::with_capacity(10);

        let large_num = u64::MAX - 1;
        ab.set(large_num);
        assert!(ab.get(large_num));
        assert!(!ab.get(large_num - 1));
        assert!(!ab.get(large_num + 1));

        ab.clear(large_num);
        assert!(!ab.get(large_num));
    }

    /// Bits on either side of a segment boundary live in different segments;
    /// clearing one must not disturb its neighbors.
    #[test]
    fn test_multiple_segments() {
        let ab = ActiveBits::with_capacity(10);

        let num1 = BIT_COUNT - 1; // last bit of segment 0
        let num2 = BIT_COUNT; // first bit of segment 1
        let num3 = BIT_COUNT + 1; // second bit of segment 1

        ab.set(num1);
        ab.set(num2);
        ab.set(num3);

        assert!(ab.get(num1));
        assert!(ab.get(num2));
        assert!(ab.get(num3));

        ab.clear(num2);
        assert!(ab.get(num1));
        assert!(!ab.get(num2));
        assert!(ab.get(num3));
    }

    // ========== QW-3: ActiveBits Concurrent Correctness Tests ==========

    /// Disjoint ranges set from 8 threads concurrently: every bit must be
    /// visible afterwards (no lost updates across segments).
    #[test]
    fn test_concurrent_set_get() {
        use std::sync::Arc;
        use std::thread;

        let ab = Arc::new(ActiveBits::with_capacity(64));
        let thread_count = 8;
        let bits_per_thread = 1000;

        let handles: Vec<_> = (0..thread_count)
            .map(|t| {
                let ab = Arc::clone(&ab);
                thread::spawn(move || {
                    let base = t * bits_per_thread;
                    for i in 0..bits_per_thread {
                        ab.set(base + i);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        // Verify all bits are set
        for t in 0..thread_count {
            let base = t * bits_per_thread;
            for i in 0..bits_per_thread {
                assert!(
                    ab.get(base + i),
                    "bit {} (thread={}, offset={}) not set",
                    base + i,
                    t,
                    i
                );
            }
        }
    }

    /// Setters and clearers hammer the same range concurrently. The final
    /// bit pattern depends on scheduling, so we only check consistency: all
    /// operations completed, and the observable state is within bounds.
    #[test]
    fn test_concurrent_set_clear() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::sync::Arc;
        use std::thread;

        let ab = Arc::new(ActiveBits::with_capacity(64));
        let range = 1000_u64;

        // Phase 1: Set all bits first so clear has something to clear
        for i in 0..range {
            ab.set(i);
        }

        let set_count = Arc::new(AtomicU64::new(0));
        let clear_count = Arc::new(AtomicU64::new(0));

        // 4 threads set bits [0..range), 4 threads clear bits [0..range)
        let mut handles = Vec::new();
        for t in 0..8_u64 {
            let ab = Arc::clone(&ab);
            let sc = Arc::clone(&set_count);
            let cc = Arc::clone(&clear_count);
            handles.push(thread::spawn(move || {
                if t < 4 {
                    for i in 0..range {
                        ab.set(i);
                        sc.fetch_add(1, Ordering::Relaxed);
                    }
                } else {
                    for i in 0..range {
                        ab.clear(i);
                        cc.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }

        // FIX: previously these counters were maintained but never checked.
        // Each of the 4 setter and 4 clearer threads performed `range` ops,
        // so the totals must be exact regardless of interleaving.
        assert_eq!(set_count.load(Ordering::Relaxed), 4 * range);
        assert_eq!(clear_count.load(Ordering::Relaxed), 4 * range);

        // After all threads finish, each bit is either set or clear
        // (no corruption, no panic). We just verify consistency.
        let mut set_bits = 0u64;
        for i in 0..range {
            if ab.get(i) {
                set_bits += 1;
            }
        }

        // The exact count depends on scheduling, but it must be
        // in [0, range]. Just verify no panics and reasonable bounds.
        assert!(
            set_bits <= range,
            "set_bits {} exceeds range {}", set_bits, range
        );
    }
}