use crate::{NodeType, OutOfBoundError};
use anyhow::{Context, Result};
use log::info;
use mmap_rs::{Mmap, MmapFlags, MmapMut};
use std::path::Path;
use sux::bits::BitFieldVec;
use value_traits::slices::{SliceByValue, SliceByValueMut};
pub struct Node2Type<B> {
data: BitFieldVec<usize, B>,
}
impl<B: AsRef<[usize]>> Node2Type<B> {
#[inline]
pub unsafe fn get_unchecked(&self, node_id: usize) -> NodeType {
NodeType::try_from(self.data.get_value_unchecked(node_id) as u8).unwrap()
}
#[inline]
pub fn get(&self, node_id: usize) -> Result<NodeType, OutOfBoundError> {
NodeType::try_from(self.data.index_value(node_id) as u8).map_err(|_| OutOfBoundError {
index: node_id,
len: self.data.len(),
})
}
}
impl<B: AsRef<[usize]> + AsMut<[usize]>> Node2Type<B> {
#[inline]
pub unsafe fn set_unchecked(&mut self, node_id: usize, node_type: NodeType) {
self.data.set_value_unchecked(node_id, node_type as usize);
}
#[inline]
pub fn set(&mut self, node_id: usize, node_type: NodeType) {
self.data.set_value(node_id, node_type as usize);
}
}
pub struct UsizeMmap<B>(B);
impl<B: AsRef<[u8]>> AsRef<[usize]> for UsizeMmap<B> {
fn as_ref(&self) -> &[usize] {
bytemuck::cast_slice(self.0.as_ref())
}
}
impl<B: AsRef<[u8]> + AsMut<[u8]>> AsMut<[usize]> for UsizeMmap<B> {
fn as_mut(&mut self) -> &mut [usize] {
bytemuck::cast_slice_mut(self.0.as_mut())
}
}
impl Node2Type<UsizeMmap<MmapMut>> {
pub fn new<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = ((num_nodes * NodeType::BITWIDTH) as u64).div_ceil(64) * 8;
info!("The resulting file will be {} bytes long.", file_len);
let node2type_file = std::fs::File::options()
.read(true)
.write(true)
.create_new(true)
.open(path)
.with_context(|| {
format!(
"While creating the .node2type.bin file: {}",
path.to_string_lossy()
)
})?;
node2type_file
.set_len(file_len)
.with_context(|| "While fallocating the file with zeros")?;
let mmap = unsafe {
mmap_rs::MmapOptions::new(file_len as _)
.context("Could not initialize mmap")?
.with_flags(MmapFlags::SHARED)
.with_file(&node2type_file, 0)
.map_mut()
.with_context(|| "While mmapping the file")?
};
let mmap = UsizeMmap(mmap);
let node2type = unsafe { BitFieldVec::from_raw_parts(mmap, NodeType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
pub fn load_mut<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = path
.metadata()
.with_context(|| format!("Could not stat {}", path.display()))?
.len();
let file = std::fs::File::open(path)
.with_context(|| format!("Could not open {}", path.display()))?;
let data = unsafe {
mmap_rs::MmapOptions::new(file_len as _)
.context("Could not initialize mmap")?
.with_flags(MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::RANDOM_ACCESS)
.with_file(&file, 0)
.map_mut()?
};
let data = UsizeMmap(data);
let node2type = unsafe { BitFieldVec::from_raw_parts(data, NodeType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
}
impl Node2Type<UsizeMmap<Mmap>> {
pub fn load<P: AsRef<Path>>(path: P, num_nodes: usize) -> Result<Self> {
let path = path.as_ref();
let file_len = path
.metadata()
.with_context(|| format!("Could not stat {}", path.display()))?
.len();
let expected_file_len = ((num_nodes * NodeType::BITWIDTH).div_ceil(64) * 8) as u64;
assert_eq!(
file_len,
expected_file_len,
"Expected {} to have size {} (because graph has {} nodes), but it has size {}",
path.display(),
expected_file_len,
num_nodes,
file_len,
);
let file = std::fs::File::open(path)
.with_context(|| format!("Could not open {}", path.display()))?;
let data = unsafe {
mmap_rs::MmapOptions::new(file_len as _)?
.with_flags(MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::RANDOM_ACCESS)
.with_file(&file, 0)
.map()?
};
let data = UsizeMmap(data);
let node2type = unsafe { BitFieldVec::from_raw_parts(data, NodeType::BITWIDTH, num_nodes) };
Ok(Self { data: node2type })
}
}
impl Node2Type<Vec<usize>> {
pub fn new_from_iter(types: impl ExactSizeIterator<Item = NodeType>) -> Self {
let num_nodes = types.len();
let file_len = ((num_nodes * NodeType::BITWIDTH) as u64).div_ceil(64) * 8;
let file_len = usize::try_from(file_len).expect("num_nodes overflowed usize");
let data = vec![0usize; file_len.div_ceil((usize::BITS / 8).try_into().unwrap())];
let data = unsafe { BitFieldVec::from_raw_parts(data, NodeType::BITWIDTH, num_nodes) };
let mut node2type = Node2Type { data };
for (i, type_) in types.enumerate() {
node2type.set(i, type_);
}
node2type
}
}