use std::fs::{metadata, OpenOptions};
use std::io;
use std::io::{Read, Write};
use std::mem::transmute;
use std::ops::Range;
use std::path::Path;
use memmap2::{Mmap, MmapMut, MmapOptions};
use crate::bitvec::BitVector;
/// The memory map backing an `MmapBitVec`: either writable or read-only.
pub enum MmapKind {
    /// A writable mapping (file-backed or anonymous).
    MmapMut(MmapMut),
    /// A read-only mapping.
    Mmap(Mmap),
}
impl MmapKind {
    /// Pointer to the first byte of the mapping.
    #[inline]
    pub fn as_ptr(&self) -> *const u8 {
        self.as_slice().as_ptr()
    }
    /// Mutable pointer to the first byte of the mapping.
    ///
    /// # Errors
    /// Returns an `Other` I/O error when the mapping is read-only.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> Result<*mut u8, io::Error> {
        if let MmapKind::MmapMut(writable) = self {
            return Ok(writable.as_mut_ptr());
        }
        Err(io::Error::new(
            io::ErrorKind::Other,
            "attempted to get a mutable pointer to a read-only mmap",
        ))
    }
    /// Flushes outstanding changes of a writable mapping; a no-op for read-only ones.
    #[inline]
    pub fn flush(&mut self) -> Result<(), io::Error> {
        if let MmapKind::MmapMut(writable) = self {
            writable.flush()
        } else {
            Ok(())
        }
    }
    /// The whole mapping viewed as a byte slice.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        match self {
            MmapKind::Mmap(read_only) => &read_only[..],
            MmapKind::MmapMut(writable) => &writable[..],
        }
    }
}
/// A bit vector whose bytes live in a memory map (file-backed or anonymous).
///
/// Bit `i` is stored in byte `i >> 3`, most-significant bit first.
pub struct MmapBitVec {
    /// The mapped data bytes (only the data region; any file header precedes the mapping).
    pub mmap: MmapKind,
    /// Number of bits in the vector.
    pub size: usize,
    /// `true` for vectors created by `from_memory` (not backed by a file).
    is_map_anonymous: bool,
    /// User header bytes read from / written to the file (empty when there is none).
    header: Box<[u8]>,
}
/// Creates (or truncates) `filename` and writes the bit-vector file header.
///
/// File layout: optional 2-byte `magic`, the header length as a big-endian `u16`,
/// the `header` bytes, the bit count `size` as a big-endian `u64`, and finally
/// `ceil(size / 8)` zero bytes of bit data.
///
/// Returns the open file (positioned right after the header) and the header length
/// in bytes, i.e. the offset at which the bit data starts.
///
/// # Errors
/// `InvalidInput` if `header` is longer than `u16::MAX` bytes (its length could not be
/// stored); otherwise any I/O error from opening, resizing, or writing the file.
fn create_bitvec_file(
    filename: &Path,
    size: usize,
    magic: Option<[u8; 2]>,
    header: &[u8],
) -> Result<(std::fs::File, u64), io::Error> {
    // The header length is serialized as a u16; reject instead of silently truncating it.
    if header.len() > usize::from(u16::MAX) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Headers longer than 65535 bytes not supported",
        ));
    }
    // ceil(size / 8), written so that `size == 0` does not underflow.
    let byte_size = (size / 8 + usize::from(size % 8 != 0)) as u64;
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        // Without truncation, stale bytes of a pre-existing file would survive `set_len`
        // below and show up as set bits in the "new" vector; truncating makes the data
        // region all zeros.
        .truncate(true)
        .open(filename)?;
    let magic_len = magic.map_or(0, |m| m.len());
    let total_header_size = (magic_len + 2 + header.len() + 8) as u64;
    file.set_len(total_header_size + byte_size)?;
    if let Some(m) = magic {
        file.write_all(&m)?;
    }
    file.write_all(&(header.len() as u16).to_be_bytes())?;
    file.write_all(header)?;
    file.write_all(&(size as u64).to_be_bytes())?;
    Ok((file, total_header_size))
}
impl MmapBitVec {
    /// Number of bytes needed to hold `bits` bits (rounded up; `0` bits need `0` bytes).
    #[inline]
    fn bytes_for_bits(bits: usize) -> usize {
        bits / 8 + usize::from(bits % 8 != 0)
    }

    /// Creates a file-backed bit vector of `size` bits at `filename` and maps its data
    /// region read-write.
    ///
    /// The file holds an optional 2-byte `magic`, the length-prefixed `header`, and the
    /// bit count before the data (see `create_bitvec_file`).
    ///
    /// # Panics
    /// If `header` is 65536 bytes or longer.
    ///
    /// # Errors
    /// Any I/O error from creating, sizing, writing, or mapping the file.
    pub fn create<P: AsRef<Path>>(
        filename: P,
        size: usize,
        magic: Option<[u8; 2]>,
        header: &[u8],
    ) -> Result<Self, io::Error> {
        // The on-disk header length is a u16, so 65535 bytes is the real limit.
        assert!(
            header.len() < 65_536,
            "Headers longer than 65535 bytes not supported"
        );
        let (file, total_header_size) = create_bitvec_file(filename.as_ref(), size, magic, header)?;
        // SAFETY: memmap2 mappings are `unsafe` because external modification of the file
        // is UB; this type assumes it has exclusive use of the file while mapped.
        let mmap = unsafe { MmapOptions::new().offset(total_header_size).map_mut(&file) }?;
        Ok(MmapBitVec {
            mmap: MmapKind::MmapMut(mmap),
            size,
            header: header.to_vec().into_boxed_slice(),
            is_map_anonymous: false,
        })
    }

    /// Opens a bit-vector file in the layout written by `create` / `save_to_disk`.
    ///
    /// When `magic` is given, the first two bytes of the file must match it. The data
    /// region is mapped read-only when `read_only` is set, read-write otherwise.
    ///
    /// # Errors
    /// `InvalidData` on a magic mismatch or when the file length disagrees with the
    /// stored bit count; any I/O error from reading or mapping the file.
    pub fn open<P>(filename: P, magic: Option<&[u8; 2]>, read_only: bool) -> Result<Self, io::Error>
    where
        P: AsRef<Path>,
    {
        let mut file = OpenOptions::new()
            .read(true)
            .write(!read_only)
            .open(filename)?;
        if let Some(m) = magic {
            let mut file_magic = [0; 2];
            file.read_exact(&mut file_magic)?;
            if &file_magic != m {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "file has wrong magic bytes {:x?} (expected {:x?})",
                        file_magic, m
                    ),
                ));
            }
        }
        let mut serialized_header_size = [0; 2];
        file.read_exact(&mut serialized_header_size)?;
        // SAFETY: `[u8; 2]` and `u16` have the same size and every bit pattern is a valid u16.
        let header_size: usize =
            u16::from_be(unsafe { transmute(serialized_header_size) }) as usize;
        let mut header = vec![0; header_size];
        file.read_exact(&mut header)?;
        let mut serialized_size = [0; 8];
        file.read_exact(&mut serialized_size)?;
        // SAFETY: `[u8; 8]` and `u64` have the same size and every bit pattern is a valid u64.
        let size: u64 = u64::from_be(unsafe { transmute(serialized_size) });
        let magic_len = magic.map_or(0, |m| m.len());
        let total_header_size = (magic_len + 2 + header_size + 8) as u64;
        // ceil(size / 8); `size` comes from the file, so avoid `size - 1` underflowing on 0.
        let byte_size = (size >> 3) + u64::from(size & 7 != 0);
        let file_len = file.metadata()?.len();
        if file_len != total_header_size + byte_size {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "file should be {} bytes (with {} header), but file is {} bytes",
                    byte_size + total_header_size,
                    total_header_size,
                    file_len,
                ),
            ));
        }
        // SAFETY: see `create`; the file is assumed not to be modified externally.
        let mmap = if read_only {
            let mmap = unsafe { MmapOptions::new().offset(total_header_size).map(&file) }?;
            MmapKind::Mmap(mmap)
        } else {
            let mmap = unsafe { MmapOptions::new().offset(total_header_size).map_mut(&file) }?;
            MmapKind::MmapMut(mmap)
        };
        Ok(MmapBitVec {
            mmap,
            size: size as usize,
            header: header.into_boxed_slice(),
            is_map_anonymous: false,
        })
    }

    /// Maps `filename` read-only, starting at byte `offset`, treating every remaining
    /// byte as bit data (so the vector has `8 * (file_len - offset)` bits).
    ///
    /// # Errors
    /// `InvalidInput` when `offset` is past the end of the file; any I/O error from
    /// reading metadata, opening, or mapping the file.
    pub fn open_no_header<P>(filename: P, offset: usize) -> Result<Self, io::Error>
    where
        P: AsRef<Path>,
    {
        let file_size = metadata(&filename)?.len() as usize;
        // A plain subtraction would underflow (panic or wrap) for an offset past EOF.
        let byte_size = file_size.checked_sub(offset).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "offset {} is past the end of the file ({} bytes)",
                    offset, file_size
                ),
            )
        })?;
        let f = OpenOptions::new().read(true).open(&filename)?;
        // SAFETY: see `create`; the file is assumed not to be modified externally.
        let mmap = unsafe { MmapOptions::new().offset(offset as u64).map(&f) }?;
        Ok(MmapBitVec {
            mmap: MmapKind::Mmap(mmap),
            size: byte_size * 8,
            header: Box::new([]),
            is_map_anonymous: false,
        })
    }

    /// Creates a writable, anonymous (not file-backed) bit vector of `size` bits.
    ///
    /// # Errors
    /// Any error `memmap2` reports while creating the anonymous mapping.
    pub fn from_memory(size: usize) -> Result<Self, io::Error> {
        // Rounded-up byte count; unlike `((size - 1) >> 3) + 1` this cannot underflow.
        let byte_size = Self::bytes_for_bits(size);
        let mmap = MmapOptions::new().len(byte_size).map_anon()?;
        Ok(MmapBitVec {
            mmap: MmapKind::MmapMut(mmap),
            size,
            header: vec![].into_boxed_slice(),
            is_map_anonymous: true,
        })
    }

    /// Writes an anonymous (`from_memory`) bit vector to `filename` in the `create` layout,
    /// using the given `magic` and `header`.
    ///
    /// For file-backed vectors this is a no-op returning `Ok(())`: nothing is written,
    /// even when `filename` differs from the backing file.
    ///
    /// # Errors
    /// `InvalidInput` if `header` is longer than 65535 bytes; any I/O error from
    /// creating or writing the file.
    pub fn save_to_disk<P: AsRef<Path>>(
        &self,
        filename: P,
        magic: Option<[u8; 2]>,
        header: &[u8],
    ) -> Result<(), io::Error> {
        if !self.is_map_anonymous {
            return Ok(());
        }
        // The header length is stored as a u16; a longer header would corrupt the file.
        if header.len() > usize::from(u16::MAX) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Headers longer than 65535 bytes not supported",
            ));
        }
        let (mut file, _) = create_bitvec_file(filename.as_ref(), self.size, magic, header)?;
        // `create_bitvec_file` leaves the cursor at the start of the data region.
        file.write_all(self.mmap.as_slice())?;
        Ok(())
    }

    /// The user header stored with this vector (empty for headerless/anonymous vectors).
    pub fn header(&self) -> &[u8] {
        &self.header
    }

    /// Returns the bits in `r` packed into `ceil(r.len() / 8)` bytes, right-aligned:
    /// the last bit of the range is the least-significant bit of the last byte, and
    /// unused high bits of the first byte are zero. This is the layout
    /// `set_range_bytes` accepts. An empty range yields an empty vector.
    ///
    /// # Panics
    /// If `r.end > self.size`.
    pub fn get_range_bytes(&self, r: Range<usize>) -> Vec<u8> {
        if r.end > self.size {
            panic!("Range ends outside of BitVec")
        }
        let bit_len = r.len();
        let out_len = Self::bytes_for_bits(bit_len);
        if out_len == 0 {
            return Vec::new();
        }
        let bytes = self.mmap.as_slice();
        let last = (r.end - 1) >> 3;
        // Output byte j covers the 8 bits ending 8 * (out_len - 1 - j) bits before `r.end`;
        // its low part always comes from byte `first + j`.
        let first = last + 1 - out_len;
        // Number of range bits that live in the final (partial) source byte; 0 = aligned.
        let tail = r.end & 7;
        let mut out: Vec<u8> = (first..=last)
            .map(|idx| {
                if tail == 0 {
                    bytes[idx]
                } else {
                    // Low (8 - tail) bits of the previous byte form the high part, the
                    // top `tail` bits of this byte form the low part.
                    let prev = if idx == 0 { 0 } else { bytes[idx - 1] };
                    (prev << tail) | (bytes[idx] >> (8 - tail))
                }
            })
            .collect();
        // Drop bits that precede `r.start` from the (possibly partial) leading byte.
        let lead_bits = bit_len - 8 * (out_len - 1);
        out[0] &= 0xFFu8 >> (8 - lead_bits);
        out
    }

    /// ORs the right-aligned bytes `x` (layout as returned by `get_range_bytes`) into the
    /// bits of `r`. Bits are only ever set, never cleared; padding bits in the high part
    /// of `x[0]` are ignored so nothing outside `r` is touched.
    ///
    /// # Panics
    /// If `r.end > self.size`, if `x.len() != ceil(r.len() / 8)`, or if the mapping is
    /// read-only (and `x` is non-empty).
    pub fn set_range_bytes(&mut self, r: Range<usize>, x: &[u8]) {
        if r.end > self.size {
            panic!("Range ends outside of BitVec")
        }
        let bit_len = r.len();
        if Self::bytes_for_bits(bit_len) != x.len() {
            panic!("Range and array passed are different sizes")
        }
        if x.is_empty() {
            return;
        }
        let last = (r.end - 1) >> 3;
        // x[j] lands (low part) in byte `first + j`; always >= 0 because a range of
        // `bit_len` bits ending at `r.end` spans at least `x.len()` bytes.
        let first = last + 1 - x.len();
        let tail = r.end & 7;
        // Only the low `bit_len - 8 * (x.len() - 1)` bits of x[0] belong to the range.
        let lead_mask = 0xFFu8 >> (8 * x.len() - bit_len);
        let mmap: *mut u8 = self
            .mmap
            .as_mut_ptr()
            .expect("set_range_bytes can only be called on a mutable mmap");
        for (j, &raw) in x.iter().enumerate() {
            let val = if j == 0 { raw & lead_mask } else { raw };
            let idx = first + j;
            // SAFETY: idx <= last = (r.end - 1) >> 3 < mapping length because r.end <= size.
            unsafe {
                if tail == 0 {
                    *mmap.add(idx) |= val;
                } else {
                    // After masking, high bits are non-zero only when byte idx - 1 is
                    // inside the range, so idx > 0 whenever they matter.
                    if idx > 0 {
                        *mmap.add(idx - 1) |= val >> tail;
                    }
                    *mmap.add(idx) |= val << (8 - tail);
                }
            }
        }
    }
}
impl BitVector for MmapBitVec {
    /// Returns bit `i` (bits are stored most-significant-bit first within each byte).
    ///
    /// # Panics
    /// If `i > self.size` or bit `i` lies outside the mapped bytes.
    fn get(&self, i: usize) -> bool {
        // NOTE(review): `i == self.size` is still accepted (existing tests read it); the
        // mapping-length check keeps that index from reading past the end of the map.
        let byte = match self.mmap.as_slice().get(i >> 3) {
            Some(&b) if i <= self.size => b,
            _ => panic!("Invalid bit vector index"),
        };
        byte & (0x80u8 >> (i & 7)) != 0
    }

    /// Sets bit `i` to `x`.
    ///
    /// # Panics
    /// If `i > self.size`, bit `i` lies outside the mapped bytes, or the mapping is read-only.
    fn set(&mut self, i: usize, x: bool) {
        let byte_idx = i >> 3;
        // NOTE(review): `i == self.size` stays accepted for compatibility with `get`.
        if i > self.size || byte_idx >= self.mmap.as_slice().len() {
            panic!("Invalid bit vector index");
        }
        let mask = 0x80u8 >> (i & 7);
        let mmap: *mut u8 = self
            .mmap
            .as_mut_ptr()
            .expect("set can only be called on a mutable mmap");
        // SAFETY: byte_idx < mapping length was checked above.
        unsafe {
            if x {
                *mmap.add(byte_idx) |= mask
            } else {
                *mmap.add(byte_idx) &= !mask
            }
        }
    }

    /// Number of bits in the vector.
    fn size(&self) -> usize {
        self.size
    }

    /// Counts the set bits in `r`; an empty range counts 0.
    ///
    /// # Panics
    /// If `r` reaches past the mapped bytes.
    fn rank(&self, r: Range<usize>) -> usize {
        if r.start >= r.end {
            return 0;
        }
        let bytes = self.mmap.as_slice();
        let first = r.start >> 3;
        let last = (r.end - 1) >> 3;
        // Bits at or after `r.start` within the first byte.
        let head_mask = 0xFFu8 >> (r.start & 7);
        // Bits before `r.end` within the last byte (all of them when `r.end` is aligned).
        let tail_mask = match r.end & 7 {
            0 => 0xFF,
            e => !(0xFFu8 >> e),
        };
        let ones = |b: u8| b.count_ones() as usize;
        if first == last {
            return ones(bytes[first] & head_mask & tail_mask);
        }
        let middle: usize = bytes[first + 1..last].iter().map(|&b| ones(b)).sum();
        ones(bytes[first] & head_mask) + middle + ones(bytes[last] & tail_mask)
    }

    /// Returns the position of the `n`-th (1-based) set bit at or after `start`, or
    /// `None` if `n == 0` or fewer than `n` bits are set in `start..self.size`.
    fn select(&self, n: usize, start: usize) -> Option<usize> {
        if n == 0 || start >= self.size {
            return None;
        }
        let bytes = self.mmap.as_slice();
        let first = start >> 3;
        let last = (self.size - 1) >> 3;
        let mut remaining = n;
        for byte_idx in first..=last {
            let mut byte = bytes[byte_idx];
            if byte_idx == first {
                // Ignore bits before `start`.
                byte &= 0xFFu8 >> (start & 7);
            }
            if byte_idx == last && self.size & 7 != 0 {
                // Ignore padding bits beyond `self.size`.
                byte &= !(0xFFu8 >> (self.size & 7));
            }
            let ones = byte.count_ones() as usize;
            if ones < remaining {
                remaining -= ones;
                continue;
            }
            // The answer is in this byte: walk its bits most-significant first.
            for bit_idx in 0..8 {
                if byte & (0x80u8 >> bit_idx) != 0 {
                    remaining -= 1;
                    if remaining == 0 {
                        return Some((byte_idx << 3) + bit_idx);
                    }
                }
            }
        }
        None
    }

    /// Returns the bits of `r` as an integer whose least-significant bit is bit
    /// `r.end - 1`. An empty range yields 0.
    ///
    /// # Panics
    /// If `r` is longer than 128 bits or `r.end > self.size`.
    fn get_range(&self, r: Range<usize>) -> u128 {
        if r.len() > 128usize {
            panic!("Range too large (>128)")
        } else if r.end > self.size {
            panic!("Range ends outside of BitVec")
        }
        let len = r.len();
        if len == 0 {
            return 0;
        }
        let bytes = self.mmap.as_slice();
        let first = r.start >> 3;
        let last = (r.end - 1) >> 3;
        // Bits of the first byte that precede the range.
        let lead = r.start & 7;
        // Last byte, shifted so bit `r.end - 1` becomes bit 0.
        let mut v = u128::from(bytes[last]) >> (7 - ((r.end - 1) & 7));
        if first + 16 <= bytes.len() {
            // Fast path: one big-endian 16-byte window starting at `first` covers every
            // range bit up to byte `first + 15`; the last byte above fills in the rest.
            let mut window = [0u8; 16];
            window.copy_from_slice(&bytes[first..first + 16]);
            v |= u128::from_be_bytes(window) << lead >> (128 - len);
        } else {
            // Near the end of the mapping: OR in bytes first..last one at a time.
            let top = len + lead;
            for (i, &byte) in bytes[first..last].iter().enumerate() {
                v |= u128::from(byte) << (top - 8 * (i + 1));
            }
        }
        v & (u128::MAX >> (128 - len))
    }

    /// ORs the low `r.len()` bits of `x` into `r` (bit `r.end - 1` receives the
    /// least-significant bit). Bits are only ever set, never cleared.
    ///
    /// # Panics
    /// If `r.end > self.size`, `r` is longer than 128 bits, or the mapping is read-only
    /// (and `r` is non-empty).
    fn set_range(&mut self, r: Range<usize>, x: u128) {
        if r.end > self.size {
            panic!("Range ends outside of BitVec")
        }
        let len = r.len();
        if len > 128 {
            panic!("Range too large (>128)")
        }
        if len == 0 {
            return;
        }
        // Bits of `x` above the range width must not leak into neighbouring bits.
        let x = x & (u128::MAX >> (128 - len));
        let first = r.start >> 3;
        let last = (r.end - 1) >> 3;
        // Range bits that sit in the first byte when the range continues past it.
        let front_bits = 8 - (r.start & 7);
        let front_byte = if front_bits >= len {
            (x << (front_bits - len)) as u8
        } else {
            (x >> (len - front_bits)) as u8
        };
        let mmap: *mut u8 = self
            .mmap
            .as_mut_ptr()
            .expect("set_range can only be called on a mutable mmap");
        // SAFETY (all writes below): first <= idx <= last = (r.end - 1) >> 3, which is
        // inside the mapping because r.end <= self.size.
        unsafe {
            *mmap.add(first) |= front_byte;
        }
        if first == last {
            return;
        }
        let back_bits = match r.end & 7 {
            0 => 8,
            b => b,
        };
        let back_byte = (x << (128 - back_bits) >> 120) as u8;
        unsafe {
            *mmap.add(last) |= back_byte;
        }
        if first + 1 == last {
            return;
        }
        let main_bits = len - front_bits;
        // Left-justify the bits after the front byte; big-endian bytes give byte order
        // on every target (the old `to_be().to_le_bytes()` only did on little-endian).
        let chunk = (x << (128 - main_bits)).to_be_bytes();
        for (byte_idx, byte) in ((first + 1)..last).zip(chunk.iter()) {
            unsafe {
                *mmap.add(byte_idx) |= *byte;
            }
        }
    }

    /// Clears every bit in `r`, leaving neighbouring bits untouched.
    ///
    /// # Panics
    /// If `r.end > self.size`, or the mapping is read-only (and `r` is non-empty).
    fn clear_range(&mut self, r: Range<usize>) {
        if r.end > self.size {
            panic!("Range ends outside of BitVec")
        }
        if r.start >= r.end {
            return;
        }
        let first = r.start >> 3;
        let last = (r.end - 1) >> 3;
        // Bits to preserve: those before `r.start` in the first byte...
        let keep_head: u8 = !(0xFFu8 >> (r.start & 7));
        // ...and those at or after `r.end` in the last byte (none when aligned).
        let keep_tail: u8 = match r.end & 7 {
            0 => 0,
            e => 0xFFu8 >> e,
        };
        let mmap: *mut u8 = self
            .mmap
            .as_mut_ptr()
            .expect("clear range can only be called on a mutable mmap");
        // SAFETY: every touched byte is in first..=last, inside the mapping since
        // r.end <= self.size.
        unsafe {
            if first == last {
                *mmap.add(first) &= keep_head | keep_tail;
            } else {
                *mmap.add(first) &= keep_head;
                std::ptr::write_bytes(mmap.add(first + 1), 0, last - first - 1);
                *mmap.add(last) &= keep_tail;
            }
        }
    }
}
impl Drop for MmapBitVec {
    /// Flushes the mapping on a best-effort basis when the vector is dropped.
    fn drop(&mut self) {
        // `drop` has no way to report failure, so the flush result is intentionally unused.
        let _flush_outcome = self.mmap.flush();
    }
}
#[cfg(test)]
mod test {
    use super::MmapBitVec;
    use crate::bitvec::BitVector;
    // Each test uses its own temporary directory so tests neither write into the
    // current working directory nor leave files behind when an assertion fails.
    #[test]
    fn test_bitvec() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test");
        let mut b = MmapBitVec::create(&path, 100, None, &[]).unwrap();
        b.set(2, true);
        assert!(!b.get(1));
        assert!(b.get(2));
        assert!(!b.get(100));
        drop(b);
        assert!(path.exists());
        let b = MmapBitVec::open(&path, None, true).unwrap();
        assert!(!b.get(1));
        assert!(b.get(2));
        assert!(!b.get(100));
    }
    #[test]
    fn test_open_no_header() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_headerless");
        // 80 bits -> 10 data bytes after a 10-byte header (2-byte length + 8-byte size).
        let _ = MmapBitVec::create(&path, 80, None, &[]).unwrap();
        assert!(path.exists());
        // Offset 12 skips the header and two data bytes, leaving 8 bytes = 64 bits.
        let b = MmapBitVec::open_no_header(&path, 12).unwrap();
        assert_eq!(b.size(), 64);
    }
    #[test]
    fn test_bitvec_get_range() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set(2, true);
        b.set(3, true);
        b.set(5, true);
        assert_eq!(b.get_range(0..8), 52, "indexing within a single byte");
        assert_eq!(b.get_range(0..16), 13312, "indexing multiple bytes");
        assert_eq!(
            b.get_range(0..64),
            3_746_994_889_972_252_672,
            "indexing the maximum # of bytes"
        );
        assert_eq!(
            b.get_range(64..128),
            0,
            "indexing the maximum # of bytes to the end"
        );
        assert_eq!(b.get_range(2..10), 208, "indexing across bytes");
        assert_eq!(
            b.get_range(2..66),
            14_987_979_559_889_010_688,
            "indexing the maximum # of bytes across bytes"
        );
        assert_eq!(b.get_range(115..128), 0, "indexing across bytes to the end");
    }
    #[test]
    fn test_bitvec_get_range_bytes() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set(2, true);
        b.set(3, true);
        b.set(5, true);
        assert_eq!(
            b.get_range_bytes(0..8),
            &[0x34],
            "indexing within a single byte"
        );
        assert_eq!(
            b.get_range_bytes(0..16),
            &[0x34, 0x00],
            "indexing multiple bytes"
        );
        assert_eq!(
            b.get_range_bytes(0..64),
            &[0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
            "indexing the maximum # of bytes"
        );
        assert_eq!(
            b.get_range_bytes(64..128),
            &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
            "indexing the maximum # of bytes to the end"
        );
        assert_eq!(b.get_range_bytes(2..10), &[0xD0], "indexing across bytes");
        assert_eq!(
            b.get_range_bytes(2..66),
            &[0xD0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
            "indexing the maximum # of bytes across bytes"
        );
        assert_eq!(
            b.get_range_bytes(115..128),
            &[0x00, 0x00],
            "indexing across bytes to the end"
        );
    }
    #[test]
    fn test_bitvec_set_range() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set_range(0..4, 0b0101);
        assert_eq!(b.get_range(0..4), 0b0101);
        b.set_range(5..8, 0b0101);
        assert_eq!(b.get_range(5..8), 0b0101);
        b.set_range(123..127, 0b0101);
        assert_eq!(b.get_range(123..127), 0b0101);
        b.set_range(6..9, 0b111);
        assert_eq!(b.get_range(6..9), 0b111);
        b.set_range(0..16, 0xFFFF);
        assert_eq!(b.get_range(0..16), 0xFFFF);
        b.clear_range(4..12);
        assert_eq!(b.get_range(0..16), 0xF00F);
        b.set_range(20..36, 0xFFFF);
        assert_eq!(b.get_range(16..20), 0x0);
        assert_eq!(b.get_range(20..36), 0xFFFF);
        assert_eq!(b.get_range(36..44), 0x0);
        assert_eq!(b.get_range(39..103), 0x0);
        b.set_range(39..103, 0xABCD1234);
        assert_eq!(b.get_range(39..103), 0xABCD1234);
    }
    #[test]
    fn test_bitvec_set_range_bytes() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set_range_bytes(0..4, &[0x05u8]);
        assert_eq!(b.get_range(0..4), 0b0101);
        b.set_range_bytes(5..8, &[0x05u8]);
        assert_eq!(b.get_range(5..8), 0b0101);
        b.clear_range(0..16);
        b.set_range_bytes(6..10, &[0x0Du8]);
        assert_eq!(b.get_range(6..10), 0x0D);
        b.set_range_bytes(0..16, &[0xFFu8, 0xFFu8]);
        assert_eq!(b.get_range(0..16), 0xFFFF);
        b.set_range_bytes(20..36, &[0xFFu8, 0xFFu8]);
        assert_eq!(b.get_range(20..36), 0xFFFF);
        b.set_range_bytes(64..80, &[0xA0u8, 0x0Au8]);
        assert_eq!(b.get_range(64..80), 0xA00A);
        b.set_range_bytes(64..80, &[0x0Bu8, 0xB0u8]);
        assert_eq!(b.get_range(64..80), 0xABBA);
    }
    #[test]
    fn test_rank_select() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set(7, true);
        b.set(56, true);
        b.set(127, true);
        assert_eq!(b.rank(0..8), 1);
        assert_eq!(b.rank(0..128), 3);
        assert_eq!(b.select(1, 0), Some(7));
        assert_eq!(b.select(3, 0), Some(127));
    }
    #[test]
    fn can_write_anon_mmap_to_disk() {
        let mut b = MmapBitVec::from_memory(128).unwrap();
        b.set(0, true);
        b.set(7, true);
        b.set(56, true);
        b.set(127, true);
        let dir = tempfile::tempdir().unwrap();
        b.save_to_disk(dir.path().join("test"), None, &[]).unwrap();
        let f = MmapBitVec::open(dir.path().join("test"), None, false).unwrap();
        assert!(f.get(0));
        assert!(f.get(7));
        assert!(f.get(56));
        assert!(f.get(127));
        assert!(!f.get(10));
    }
}