use dashmap::DashMap;
use std::sync::atomic::{AtomicU32, Ordering};
/// Number of 32-bit words stored in each [`Segment`].
const U32_COUNT: u64 = 1024;
/// Number of bits addressable inside one [`Segment`] (32 per word).
const BIT_COUNT: u64 = 32 * U32_COUNT;

/// Fixed-size bitmap of atomic words that also tracks how many bits are set.
struct Segment {
    /// Number of bits currently set; adjusted by `set` and `clear` whenever
    /// they actually flip a bit.
    count: AtomicU32,
    /// Bit storage: bit `n` lives in word `n / 32` at bit position `n % 32`.
    arr: [AtomicU32; U32_COUNT as usize],
}

/// Zero-valued atomic used as the repeat operand when building arrays of
/// atomics (array-repeat expressions need a constant for non-`Copy` elements).
const ZERO: AtomicU32 = AtomicU32::new(0);

impl Segment {
    /// Builds a segment with every bit cleared and a zero population count.
    fn new() -> Self {
        Self {
            arr: [ZERO; U32_COUNT as usize],
            count: ZERO,
        }
    }

    /// Splits bit index `n` into `(word index, single-bit mask)`.
    fn locate(n: usize) -> (usize, u32) {
        (n / 32, 1u32 << (n % 32))
    }

    /// Reports whether bit `n` is currently set.
    ///
    /// Panics if `n / 32` is outside the word array.
    fn get(&self, n: usize) -> bool {
        let (word, bit) = Self::locate(n);
        (self.arr[word].load(Ordering::Relaxed) & bit) != 0
    }

    /// Sets bit `n`, bumping the population count only when the bit was
    /// previously clear.
    ///
    /// Panics if `n / 32` is outside the word array.
    fn set(&self, n: usize) {
        let (word, bit) = Self::locate(n);
        let previous = self.arr[word].fetch_or(bit, Ordering::AcqRel);
        let newly_set = (previous & bit) == 0;
        if newly_set {
            self.count.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Clears bit `n`.
    ///
    /// Returns `true` only when this call flipped the bit from set to clear
    /// and the population count was `1` just before the decrement, i.e. this
    /// call removed the last counted bit. Clearing an already-clear bit
    /// leaves the count untouched and returns `false`.
    ///
    /// Panics if `n / 32` is outside the word array.
    fn clear(&self, n: usize) -> bool {
        let (word, bit) = Self::locate(n);
        let previous = self.arr[word].fetch_and(!bit, Ordering::AcqRel);
        let was_set = (previous & bit) != 0;
        // Short-circuit: the counter is only decremented when a bit was flipped.
        was_set && self.count.fetch_sub(1, Ordering::Relaxed) == 1
    }
}
/// Sparse, concurrently usable bitmap addressed by `u64` bit index.
///
/// Bits are grouped into [`Segment`]s of `BIT_COUNT` bits. A segment is
/// created on the first `set` that touches it and removed again by `clear`
/// once its population count drops to zero.
pub struct ActiveBits {
    /// Maps segment index (`n / BIT_COUNT`) to the segment holding those bits.
    m: DashMap<u64, Box<Segment>>,
}

impl ActiveBits {
    /// Creates an empty bitmap whose segment map is pre-sized via
    /// `DashMap::with_capacity(n)`.
    pub fn with_capacity(n: usize) -> Self {
        Self {
            m: DashMap::with_capacity(n),
        }
    }

    /// Returns `true` if bit `n` is set; a missing segment reads as all-zero.
    pub fn get(&self, n: u64) -> bool {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        match self.m.get(&i) {
            // `j < BIT_COUNT`, so the cast to `usize` cannot truncate.
            Some(seg) => seg.get(j as usize),
            None => false,
        }
    }

    /// Sets bit `n`, creating its segment if it does not exist yet.
    ///
    /// The bit is written while the entry guard returned by
    /// `entry(..).or_insert_with(..)` is still alive.
    pub fn set(&self, n: u64) {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        self.m
            .entry(i)
            .or_insert_with(|| Box::new(Segment::new()))
            .set(j as usize);
    }

    /// Clears bit `n` and drops its segment from the map if it became empty.
    ///
    /// Clearing a bit in a segment that does not exist is a no-op.
    pub fn clear(&self, n: u64) {
        let (i, j) = (n / BIT_COUNT, n % BIT_COUNT);
        // The guard returned by `get` is a temporary of this statement, so it
        // is released before `remove_if` below (DashMap documents that calling
        // a removing method while holding a reference into the map may
        // deadlock).
        let emptied = match self.m.get(&i) {
            Some(seg) => seg.clear(j as usize),
            None => false,
        };
        if emptied {
            // `emptied` is only a hint: between releasing the guard above and
            // reaching this point another thread may have set a bit in this
            // segment (or replaced it with a fresh, non-empty one). Removing
            // unconditionally would silently discard those bits, so re-check
            // the population count as part of the removal itself.
            self.m
                .remove_if(&i, |_, seg| seg.count.load(Ordering::Acquire) == 0);
        }
    }
}
#[cfg(test)]
mod segments_tests {
    use super::*;

    /// Reads the segment's population counter.
    fn popcount(seg: &Segment) -> u32 {
        seg.count.load(Ordering::SeqCst)
    }

    /// Basic set/get/clear round trip, including bits on either side of a
    /// word boundary and the "last bit cleared" return value.
    #[test]
    fn test_segment() {
        let seg = Segment::new();

        // Single bit: set, observe, clear.
        assert!(!seg.get(0));
        seg.set(0);
        assert!(seg.get(0));
        assert!(!seg.get(1));
        // Clearing an unset bit never reports "became empty".
        assert!(!seg.clear(1));
        // Clearing the only set bit does.
        assert!(seg.clear(0));
        assert!(!seg.get(0));

        // Bits 31 and 32 sit in adjacent words.
        seg.set(31);
        seg.set(32);
        for &(bit, expected) in [(31usize, true), (32, true), (33, false)].iter() {
            assert!(seg.get(bit) == expected);
        }
        assert_eq!(popcount(&seg), 2);

        // The counter follows each successful clear.
        for &(bit, remaining) in [(31usize, 1u32), (32, 0)].iter() {
            seg.clear(bit);
            assert_eq!(popcount(&seg), remaining);
        }
    }

    /// First and last addressable bits, idempotent `set`, and clearing a bit
    /// that was never set.
    #[test]
    fn test_segment_edge_cases() {
        let seg = Segment::new();
        let first = 0usize;
        let last = BIT_COUNT as usize - 1;

        seg.set(first);
        seg.set(last);
        assert!(seg.get(first));
        assert!(seg.get(last));
        assert_eq!(popcount(&seg), 2);

        // Setting an already-set bit must not be counted twice.
        seg.set(first);
        assert_eq!(popcount(&seg), 2);

        // Clearing an unset bit leaves the counter alone.
        assert!(!seg.clear(1));
        assert_eq!(popcount(&seg), 2);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bit can be set, observed and cleared without touching its neighbour.
    #[test]
    fn test_get_set_clear() {
        let ab = ActiveBits::with_capacity(10);
        ab.set(42);
        assert!(ab.get(42));
        assert!(!ab.get(43));
        ab.clear(42);
        assert!(!ab.get(42));
    }

    /// Indices near `u64::MAX` are handled like any other index.
    #[test]
    fn test_large_numbers() {
        let ab = ActiveBits::with_capacity(10);
        let large_num = u64::MAX - 1;
        ab.set(large_num);
        assert!(ab.get(large_num));
        assert!(!ab.get(large_num - 1));
        assert!(!ab.get(large_num + 1));
        ab.clear(large_num);
        assert!(!ab.get(large_num));
    }

    /// Bits on either side of a segment boundary are independent.
    #[test]
    fn test_multiple_segments() {
        let ab = ActiveBits::with_capacity(10);
        let num1 = BIT_COUNT - 1;
        let num2 = BIT_COUNT;
        let num3 = BIT_COUNT + 1;
        ab.set(num1);
        ab.set(num2);
        ab.set(num3);
        assert!(ab.get(num1));
        assert!(ab.get(num2));
        assert!(ab.get(num3));
        ab.clear(num2);
        assert!(ab.get(num1));
        assert!(!ab.get(num2));
        assert!(ab.get(num3));
    }

    /// A segment is dropped from the map once its last bit is cleared, and is
    /// recreated transparently by a later `set`.
    #[test]
    fn test_empty_segment_is_removed() {
        let ab = ActiveBits::with_capacity(4);
        ab.set(1);
        ab.set(2);
        assert_eq!(ab.m.len(), 1);

        // One bit remains, so the segment must stay.
        ab.clear(1);
        assert_eq!(ab.m.len(), 1);
        assert!(ab.get(2));

        // Last bit gone: the segment is evicted.
        ab.clear(2);
        assert!(ab.m.is_empty());

        // Clearing in a missing segment is a no-op.
        ab.clear(2);
        assert!(ab.m.is_empty());

        ab.set(1);
        assert!(ab.get(1));
        assert_eq!(ab.m.len(), 1);
    }

    /// Threads setting disjoint ranges concurrently must not lose any bit.
    #[test]
    fn test_concurrent_set_get() {
        use std::sync::Arc;
        use std::thread;

        let ab = Arc::new(ActiveBits::with_capacity(64));
        let thread_count = 8;
        let bits_per_thread = 1000;

        let handles: Vec<_> = (0..thread_count)
            .map(|t| {
                let ab = Arc::clone(&ab);
                thread::spawn(move || {
                    let base = t * bits_per_thread;
                    for i in 0..bits_per_thread {
                        ab.set(base + i);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }

        for t in 0..thread_count {
            let base = t * bits_per_thread;
            for i in 0..bits_per_thread {
                assert!(
                    ab.get(base + i),
                    "bit {} (thread={}, offset={}) not set",
                    base + i,
                    t,
                    i
                );
            }
        }
    }

    /// Setters and clearers race over the same range.
    ///
    /// The final value of each individual bit is not determined by the race,
    /// so the test checks what is determined: no thread panics, every
    /// operation ran, and the structure is fully usable afterwards.
    #[test]
    fn test_concurrent_set_clear() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::sync::Arc;
        use std::thread;

        let ab = Arc::new(ActiveBits::with_capacity(64));
        let range = 1000_u64;
        let setters = 4_u64;
        let clearers = 4_u64;
        for i in 0..range {
            ab.set(i);
        }

        let set_count = Arc::new(AtomicU64::new(0));
        let clear_count = Arc::new(AtomicU64::new(0));
        let mut handles = Vec::new();
        for t in 0..(setters + clearers) {
            let ab = Arc::clone(&ab);
            let sc = Arc::clone(&set_count);
            let cc = Arc::clone(&clear_count);
            handles.push(thread::spawn(move || {
                if t < setters {
                    for i in 0..range {
                        ab.set(i);
                        sc.fetch_add(1, Ordering::Relaxed);
                    }
                } else {
                    for i in 0..range {
                        ab.clear(i);
                        cc.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        for h in handles {
            h.join().expect("thread panicked");
        }

        // `join` orders these loads after every increment made by the workers.
        assert_eq!(set_count.load(Ordering::Relaxed), setters * range);
        assert_eq!(clear_count.load(Ordering::Relaxed), clearers * range);

        // After the race the bitmap must still behave correctly when driven
        // from a single thread.
        for i in 0..range {
            ab.set(i);
        }
        for i in 0..range {
            assert!(ab.get(i), "bit {} not set after re-setting", i);
        }
        for i in 0..range {
            ab.clear(i);
        }
        for i in 0..range {
            assert!(!ab.get(i), "bit {} still set after clearing", i);
        }
    }
}