use std::sync::atomic::{AtomicUsize, Ordering};
/// A fixed-size, thread-safe "visited" bitmap.
///
/// One bit per element, packed into `AtomicUsize` words, so the map can be
/// queried and updated through `&self` without external locking. Individual
/// bit operations are atomic; multi-word operations (`set_range`,
/// `get_range`, `get_first`) are not atomic as a whole.
#[derive(Debug)]
pub struct VisitedMap {
    // Packed bit storage, one word per `bitfield_size` elements.
    data: Vec<AtomicUsize>,
    // Number of addressable elements; valid indices are `0..elements`.
    elements: usize,
    // Bits per storage word (`usize::BITS`).
    bitfield_size: usize,
}
impl VisitedMap {
    /// Creates a map for `elements` elements, all initially unvisited.
    pub fn new(elements: usize) -> VisitedMap {
        let bitfield_size = usize::BITS as usize;
        let num_bitfields = elements.div_ceil(bitfield_size);
        let data = (0..num_bitfields).map(|_| AtomicUsize::new(0)).collect();
        VisitedMap {
            data,
            elements,
            bitfield_size,
        }
    }
    /// Returns the number of elements tracked by this map.
    pub fn len(&self) -> usize {
        self.elements
    }
    /// Returns `true` if the map tracks no elements.
    pub fn is_empty(&self) -> bool {
        self.elements == 0
    }
    /// Returns whether `element` is visited. Out-of-range indices read as
    /// unvisited (`false`).
    pub fn get(&self, element: usize) -> bool {
        // Fix: was `element > self.elements`, which accepted the
        // one-past-the-end index; valid indices are `0..elements`.
        if element >= self.elements {
            return false;
        }
        if let Some(bitfield) = self.data.get(element / self.bitfield_size) {
            // The modulo result is < usize::BITS, so the cast cannot truncate.
            let shift_amount = (element % self.bitfield_size) as u32;
            let current_value = bitfield.load(Ordering::Acquire);
            return (current_value >> shift_amount) & 1 != 0;
        }
        false
    }
    /// Counts the consecutive run of UNVISITED elements starting at
    /// `element`, stopping at the first visited element or the end of the
    /// map. Returns 0 for an out-of-range `element`.
    pub fn get_range(&self, element: usize) -> usize {
        if element >= self.elements {
            return 0;
        }
        let mut counter = 0;
        while element + counter < self.elements {
            let pos = element + counter;
            let Some(bitfield) = self.data.get(pos / self.bitfield_size) else {
                break;
            };
            let current_value = bitfield.load(Ordering::Acquire);
            if current_value == 0 {
                // Fix: the fast path previously skipped a whole word when it
                // was all ONES (fully visited) — exactly when the run must
                // stop. A word of all ZEROS is fully unvisited, so skip the
                // remaining bits of this word (not a full word, which would
                // silently skip unchecked bits of the next word when `pos`
                // is mid-word).
                counter += self.bitfield_size - (pos % self.bitfield_size);
            } else {
                let shift_amount = (pos % self.bitfield_size) as u32;
                if (current_value >> shift_amount) & 1 == 0 {
                    counter += 1;
                } else {
                    break;
                }
            }
        }
        // Fix: clamp the count so the word-skip fast path cannot report a
        // run extending past the last element of the map.
        counter.min(self.elements - element)
    }
    /// Returns the index of the first element whose state equals `visited`,
    /// or `self.len()` if no such element exists.
    pub fn get_first(&self, visited: bool) -> usize {
        // A word with this value contains no matching bit and can be skipped.
        let skip_value = if visited { 0 } else { usize::MAX };
        let mut counter = 0;
        while counter < self.elements {
            let Some(bitfield) = self.data.get(counter / self.bitfield_size) else {
                break;
            };
            let current_value = bitfield.load(Ordering::Acquire);
            // `counter` is word-aligned whenever it first reaches a word, so
            // skipping the whole word here never jumps over unchecked bits.
            if counter % self.bitfield_size == 0 && current_value == skip_value {
                counter += self.bitfield_size;
                continue;
            }
            let shift_amount = (counter % self.bitfield_size) as u32;
            let bit_set = (current_value >> shift_amount) & 1 != 0;
            if bit_set == visited {
                return counter;
            }
            counter += 1;
        }
        // Fix: previously returned 0 when nothing matched, which is
        // indistinguishable from a match at element 0. Return `len()`
        // instead, matching the value already produced when the search ran
        // off the end of a partially filled final word.
        self.elements
    }
    /// Sets the visited state of a single element.
    pub fn set(&self, element: usize, visited: bool) {
        self.set_range(element, visited, 1);
    }
    /// Sets `len` consecutive elements starting at `element` to `state`.
    /// Aligned full words are written with a single store. Out-of-range
    /// ranges are rejected (debug-asserts, then returns without writing).
    pub fn set_range(&self, element: usize, state: bool, len: usize) {
        // Fix: use checked_add so `element + len` cannot overflow; the end
        // bound alone also covers the old `element > self.elements` test.
        let in_bounds = element
            .checked_add(len)
            .is_some_and(|end| end <= self.elements);
        if !in_bounds {
            debug_assert!(false, "Invalid element!");
            return;
        }
        let mut counter = 0;
        while counter < len {
            let current_pos = element + counter;
            let bit_in_field = current_pos % self.bitfield_size;
            let remaining = len - counter;
            if let Some(bitfield) = self.data.get(current_pos / self.bitfield_size) {
                if bit_in_field == 0 && remaining >= self.bitfield_size {
                    // Aligned full word: one store covers bitfield_size bits.
                    let word = if state { usize::MAX } else { 0 };
                    bitfield.store(word, Ordering::Release);
                    counter += self.bitfield_size;
                } else {
                    // Partial word: flip a single bit with an atomic RMW so
                    // concurrent writers to other bits are not clobbered.
                    let bit_mask = 1_usize << bit_in_field;
                    if state {
                        bitfield.fetch_or(bit_mask, Ordering::AcqRel);
                    } else {
                        bitfield.fetch_and(!bit_mask, Ordering::AcqRel);
                    }
                    counter += 1;
                }
            } else {
                debug_assert!(false);
                return;
            }
        }
    }
    /// Marks a single element as unvisited.
    pub fn clear(&self, element: usize) {
        self.set(element, false);
    }
    /// Marks every element as unvisited.
    pub fn clear_all(&self) {
        self.set_range(0, false, self.elements);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Word width in bits, matching the map's internal packing.
    fn word_bits() -> usize {
        std::mem::size_of::<usize>() * 8
    }

    #[test]
    fn create_small() {
        let map = VisitedMap::new(4096);
        assert_eq!(map.len(), 4096);
    }

    #[test]
    fn create_big() {
        let count = 4 * 1024 * 1024;
        let map = VisitedMap::new(count);
        assert_eq!(map.len(), count);
    }

    #[test]
    fn use_one() {
        let map = VisitedMap::new(4096);
        map.set(1, true);
        assert!(map.get(1));
        map.set(1, false);
        assert!(!map.get(1));
        assert!(!map.get(2));
    }

    #[test]
    fn use_many() {
        let map = VisitedMap::new(4096);
        for idx in [0, 2, 4, 8, 100, 101, 102, 104, 103] {
            map.set(idx, true);
        }
        for idx in [0, 4, 8, 100, 101, 102, 104, 103] {
            assert!(map.get(idx));
        }
    }

    #[test]
    fn clear_one() {
        let map = VisitedMap::new(4096);
        map.set(4, true);
        assert!(map.get(4));
        map.clear(4);
        assert!(!map.get(4));
    }

    #[test]
    fn clear_many() {
        let map = VisitedMap::new(4096);
        let indices = [0, 4, 8, 100, 101, 102, 104, 103];
        for idx in indices {
            map.set(idx, true);
        }
        for idx in indices {
            assert!(map.get(idx));
        }
        map.clear_all();
        for idx in indices {
            assert!(!map.get(idx));
        }
    }

    #[test]
    fn get_range() {
        let map = VisitedMap::new(4096);
        for idx in (0..4).chain(10..13) {
            map.set(idx, true);
        }
        // Elements 4..=9 are unvisited, 10 is the first visited one.
        assert_eq!(map.get_range(4), 6);
    }

    #[test]
    fn set_range_long() {
        let map = VisitedMap::new(4096);
        map.set_range(0, true, 1001);
        for idx in [0, 4, 8, 100, 101, 444, 666, 1000] {
            assert!(map.get(idx));
        }
        assert!(!map.get(1001));
    }

    #[test]
    fn set_range_small() {
        let map = VisitedMap::new(4096);
        map.set_range(0, true, 32);
        for idx in [0, 4, 8, 24] {
            assert!(map.get(idx));
        }
        assert!(!map.get(35));
        assert!(!map.get(33));
    }

    #[test]
    fn get_first_true() {
        let map = VisitedMap::new(4096);
        map.clear_all();
        map.set_range(0, true, 64);
        assert_eq!(map.get_first(true), 0);
        assert_eq!(map.get_first(false), 64);

        map.clear_all();
        map.set_range(1, true, 64);
        assert_eq!(map.get_first(true), 1);
        assert_eq!(map.get_first(false), 0);
    }

    #[test]
    fn bitfield_boundary() {
        // Exercise maps that end a few bits into a second word.
        for offset in 1..8 {
            let elements = word_bits() + offset;
            let map = VisitedMap::new(elements);
            for i in 0..elements {
                map.set(i, true);
                assert!(map.get(i), "Element {i} should be set to true");
            }
            let last_element = elements - 1;
            map.set(last_element, false);
            assert!(
                !map.get(last_element),
                "Last element should be set to false"
            );
            map.set(last_element, true);
            assert!(
                map.get(last_element),
                "Last element should be set to true again"
            );
        }
    }

    #[test]
    fn bitfield_boundary_exact() {
        // A map spanning exactly one word.
        let size = word_bits();
        let map = VisitedMap::new(size);
        for i in 0..size {
            map.set(i, true);
            assert!(map.get(i));
        }
    }

    #[test]
    fn set_range_non_aligned_start() {
        let word = word_bits();
        let map = VisitedMap::new(word * 3);
        let start = 10;
        let len = word + 40;
        map.set_range(start, true, len);
        for i in 0..start {
            assert!(!map.get(i), "Byte {} before range should NOT be visited", i);
        }
        for i in start..(start + len) {
            assert!(map.get(i), "Byte {} in range should be visited", i);
        }
        for i in (start + len)..(word * 3) {
            assert!(!map.get(i), "Byte {} after range should NOT be visited", i);
        }
    }
}