use std::cell::UnsafeCell;
use std::fs::File;
use std::hint::spin_loop;
use std::marker::{PhantomData, Sync};
use std::mem::size_of;
use std::path::Path;
use std::slice;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use anyhow::Result;
use byte_slice_cast::{AsSliceOf, FromByteSlice};
use log::{info, warn};
use mapr::{Mmap, MmapMut, MmapOptions};
pub struct CacheReader<T> {
file: File,
bufs: UnsafeCell<[Mmap; 2]>,
size: usize,
degree: usize,
window_size: usize,
cursor: IncrementingCursor,
consumer: AtomicU64,
_t: PhantomData<T>,
}
unsafe impl<T> Sync for CacheReader<T> {}
struct IncrementingCursor {
cur: AtomicUsize,
cur_safe: AtomicUsize,
}
fn compare_and_swap(atomic: &AtomicUsize, before: usize, after: usize) -> usize {
match atomic.compare_exchange_weak(before, after, Ordering::SeqCst, Ordering::SeqCst) {
Ok(x) => {
assert_eq!(x, before);
before
}
_ => after,
}
}
impl IncrementingCursor {
fn new(val: usize) -> Self {
Self {
cur: AtomicUsize::new(val),
cur_safe: AtomicUsize::new(val),
}
}
fn store(&self, val: usize) {
self.cur.store(val, Ordering::SeqCst);
self.cur_safe.store(val, Ordering::SeqCst);
}
fn compare_and_swap(&self, before: usize, after: usize) {
compare_and_swap(&self.cur, before, after);
compare_and_swap(&self.cur_safe, before, after);
}
fn increment<F: Fn() -> bool, G: Fn()>(&self, target: usize, wait_fn: F, advance_fn: G) {
let cur = self.cur_safe.load(Ordering::SeqCst);
if target > cur {
let instant_cur = compare_and_swap(&self.cur, cur, cur + 1);
if instant_cur == cur {
{
while wait_fn() {
spin_loop()
}
}
advance_fn();
self.cur_safe.fetch_add(1, Ordering::SeqCst);
} else {
while self.cur_safe.load(Ordering::SeqCst) != cur + 1 {
spin_loop()
}
}
}
}
}
impl<T: FromByteSlice> CacheReader<T> {
pub fn new(filename: &Path, window_size: Option<usize>, degree: usize) -> Result<Self> {
info!("initializing cache");
let file = File::open(filename)?;
let size = File::metadata(&file)?.len() as usize;
let window_size = match window_size {
Some(s) => {
if s < size {
assert_eq!(
0,
size % degree * size_of::<T>(),
"window size is not multiple of element size"
);
};
s
}
None => {
let num_windows = 8;
assert_eq!(0, size % num_windows);
size / num_windows
}
};
let buf0 = Self::map_buf(0, window_size, &file)?;
let buf1 = Self::map_buf(window_size as u64, window_size, &file)?;
Ok(Self {
file,
bufs: UnsafeCell::new([buf0, buf1]),
size,
degree,
window_size,
cursor: IncrementingCursor::new(0),
consumer: AtomicU64::new(0),
_t: PhantomData::<T>,
})
}
pub fn size(&self) -> usize {
self.size
}
pub fn window_nodes(&self) -> usize {
self.size() / (size_of::<T>() * self.degree)
}
pub unsafe fn increment_consumer(&self) {
self.consumer.fetch_add(1, Ordering::SeqCst);
}
pub fn store_consumer(&self, val: u64) {
self.consumer.store(val, Ordering::SeqCst);
}
pub fn get_consumer(&self) -> u64 {
self.consumer.load(Ordering::SeqCst)
}
#[inline]
fn get_bufs(&self) -> &[Mmap] {
unsafe { std::slice::from_raw_parts((*self.bufs.get()).as_ptr(), 2) }
}
#[inline]
#[allow(clippy::mut_from_ref)]
unsafe fn get_mut_bufs(&self) -> &mut [Mmap] {
slice::from_raw_parts_mut((*self.bufs.get()).as_mut_ptr(), 2)
}
#[allow(dead_code)]
pub fn reset(&self) -> Result<()> {
self.start_reset()?;
self.finish_reset()
}
pub fn start_reset(&self) -> Result<()> {
let buf0 = Self::map_buf(0, self.window_size, &self.file)?;
let bufs = unsafe { self.get_mut_bufs() };
bufs[0] = buf0;
Ok(())
}
pub fn finish_reset(&self) -> Result<()> {
let buf1 = Self::map_buf(self.window_size as u64, self.window_size, &self.file)?;
let bufs = unsafe { self.get_mut_bufs() };
bufs[1] = buf1;
self.cursor.store(0);
Ok(())
}
fn map_buf(offset: u64, len: usize, file: &File) -> Result<Mmap> {
unsafe {
MmapOptions::new()
.offset(offset)
.len(len)
.private()
.map(file)
.map_err(|e| e.into())
}
}
#[inline]
fn window_element_count(&self) -> usize {
self.window_size / size_of::<T>()
}
#[inline]
pub unsafe fn consumer_slice_at(&self, pos: usize) -> &[T] {
assert!(
pos < self.size,
"pos {} out of range for buffer of size {}",
pos,
self.size
);
let window = pos / self.window_element_count();
let pos = pos % self.window_element_count();
let targeted_buf = &self.get_bufs()[window % 2];
&targeted_buf.as_slice_of::<T>().expect("as_slice_of failed")[pos..]
}
#[inline]
pub unsafe fn slice_at(&self, pos: usize) -> &[T] {
assert!(
pos < self.size,
"pos {} out of range for buffer of size {}",
pos,
self.size
);
let window = pos / self.window_element_count();
if window == 1 {
self.cursor.compare_and_swap(0, 1);
}
let pos = pos % self.window_element_count();
let wait_fn = || {
let safe_consumer = (window - 1) * (self.window_element_count() / self.degree);
(self.consumer.load(Ordering::SeqCst) as usize) < safe_consumer
};
self.cursor
.increment(window, &wait_fn, &|| self.advance_rear_window(window));
let targeted_buf = &self.get_bufs()[window % 2];
&targeted_buf.as_slice_of::<T>().expect("as_slice_of failed")[pos..]
}
fn advance_rear_window(&self, new_window: usize) {
assert!(new_window as usize * self.window_size < self.size);
let replace_idx = (new_window % 2) as usize;
let new_buf = Self::map_buf(
(new_window * self.window_size) as u64,
self.window_size as usize,
&self.file,
)
.expect("map_buf failed");
unsafe {
self.get_mut_bufs()[replace_idx] = new_buf;
}
}
}
fn allocate_layer(sector_size: usize) -> Result<MmapMut> {
match MmapOptions::new()
.len(sector_size)
.private()
.clone()
.lock()
.map_anon()
.and_then(|mut layer| {
layer.mlock()?;
Ok(layer)
}) {
Ok(layer) => Ok(layer),
Err(err) => {
warn!("failed to lock map {:?}, falling back", err);
let layer = MmapOptions::new().len(sector_size).private().map_anon()?;
Ok(layer)
}
}
}
pub fn setup_create_label_memory(
sector_size: usize,
degree: usize,
window_size: Option<usize>,
cache_path: &Path,
) -> Result<(CacheReader<u32>, MmapMut, MmapMut)> {
let parents_cache = CacheReader::new(cache_path, window_size, degree)?;
let layer_labels = allocate_layer(sector_size)?;
let exp_labels = allocate_layer(sector_size)?;
Ok((parents_cache, layer_labels, exp_labels))
}