use core::marker::PhantomData;
use core::ptr::NonNull;
#[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use crate::{Block, SYNC_POINTS};
/// Abstraction over block memory that can be walked segment by segment.
///
/// NOTE(review): the `'a` lifetime parameter is not referenced by the method
/// declared here — confirm whether it is still needed.
pub(crate) trait Memory<'a> {
/// Calls `f` for segments of the memory when it is split into `lanes` lanes.
///
/// `f` receives a [`SegmentView`] followed by two `usize` values; the
/// implementation in this file passes the slice index first and the lane
/// index second.
///
/// `f` is required to be `Sync + Send` so that an implementation may invoke
/// it from several threads.
fn for_each_segment<F>(&mut self, lanes: usize, f: F)
where
F: Fn(SegmentView<'_>, usize, usize) + Sync + Send;
}
impl Memory<'_> for &mut [Block] {
    /// Sequential driver: visits every slice in order and, within a slice,
    /// every lane in order, calling `f(view, slice, lane)` once per pair.
    ///
    /// # Panics
    ///
    /// Panics if `lanes` is zero (the lane length is computed by division).
    #[cfg(not(feature = "parallel"))]
    fn for_each_segment<F>(&mut self, lanes: usize, f: F)
    where
        F: Fn(SegmentView<'_>, usize, usize) + Sync + Send,
    {
        let shared = MemoryInner::new(self, lanes);

        (0..SYNC_POINTS).for_each(|slice| {
            (0..lanes).for_each(|lane| {
                // SAFETY: `self` is borrowed mutably for the whole call, and the view is
                // moved into `f`, so only one view exists at any time in this loop.
                let view = unsafe { SegmentView::new(shared, slice, lane) };
                f(view, slice, lane);
            });
        });
    }

    /// Parallel driver: slices are still visited one after another, but the
    /// lanes of a slice are handed to rayon and may run concurrently.
    ///
    /// # Panics
    ///
    /// Panics if `lanes` is zero (the lane length is computed by division).
    #[cfg(feature = "parallel")]
    fn for_each_segment<F>(&mut self, lanes: usize, f: F)
    where
        F: Fn(SegmentView<'_>, usize, usize) + Sync + Send,
    {
        let shared = MemoryInner::new(self, lanes);

        (0..SYNC_POINTS).for_each(|slice| {
            // `for_each` on the parallel iterator returns only after every lane of this
            // slice has been processed, so slices never overlap in time.
            (0..lanes).into_par_iter().for_each(|lane| {
                // SAFETY: `self` is borrowed mutably for the whole call, and within one
                // slice each lane index is produced exactly once, so concurrently live
                // views always differ in their lane.
                let view = unsafe { SegmentView::new(shared, slice, lane) };
                f(view, slice, lane);
            });
        });
    }
}
/// Raw handle to the block buffer, plus the sizes needed to map a block index
/// to its lane and slice.
///
/// The handle is `Copy`, so it can be passed by value into every
/// [`SegmentView`].
#[derive(Clone, Copy)]
struct MemoryInner<'a> {
// Pointer to the first block of the buffer.
blocks: NonNull<Block>,
// Total number of blocks in the buffer.
block_count: usize,
// Number of blocks per lane (`block_count / lanes`).
lane_length: usize,
// Carries the `'a` lifetime of a `&'a mut Block` borrow without storing a reference.
phantom: PhantomData<&'a mut Block>,
}
impl<'a> MemoryInner<'a> {
    /// Builds a handle over `memory_blocks`, split evenly into `lanes` lanes.
    ///
    /// The handle's lifetime `'a` is tied to the mutable borrow of
    /// `memory_blocks`, so the borrow checker keeps the buffer exclusively
    /// borrowed for as long as the handle (or anything carrying `'a`) is alive.
    ///
    /// # Panics
    ///
    /// Panics if `lanes` is zero (integer division by zero).
    fn new(memory_blocks: &'a mut [Block], lanes: usize) -> Self {
        let block_count = memory_blocks.len();
        let lane_length = block_count / lanes;

        // The pointer is derived from the *mutable* reference because mutable
        // references to individual blocks are created from it later on.
        let blocks = NonNull::from(memory_blocks).cast::<Block>();

        MemoryInner {
            blocks,
            block_count,
            lane_length,
            phantom: PhantomData,
        }
    }

    /// Returns the lane that the block at `index` belongs to.
    ///
    /// # Panics
    ///
    /// Panics if the lane length is zero (integer division by zero).
    fn lane_of(&self, index: usize) -> usize {
        index / self.lane_length
    }

    /// Returns the slice (in `0..SYNC_POINTS`) that the block at `index`
    /// belongs to.
    ///
    /// # Panics
    ///
    /// Panics if `lane_length / SYNC_POINTS` is zero (integer division by zero).
    fn slice_of(&self, index: usize) -> usize {
        // Number of blocks in one segment, i.e. one lane's share of one slice.
        let segment_length = self.lane_length / SYNC_POINTS;
        (index / segment_length) % SYNC_POINTS
    }
}
// SAFETY: NOTE(review) — rationale inferred from this file, please confirm.
// `MemoryInner` is a raw pointer plus two lengths, so it is not `Send`/`Sync`
// automatically. Blocks behind the pointer are only dereferenced through
// `SegmentView::get_block` / `get_block_mut`, which assert which lane and
// slice an index may belong to before touching it.
unsafe impl Send for MemoryInner<'_> {}
unsafe impl Sync for MemoryInner<'_> {}
/// View of the block memory as seen from one segment, identified by a
/// (`slice`, `lane`) pair.
///
/// The accessors use `slice` and `lane` to decide which block indices may be
/// read and which may be written.
pub(crate) struct SegmentView<'a> {
// Shared handle to the whole block buffer.
inner: MemoryInner<'a>,
// Slice index of the segment this view stands for.
slice: usize,
// Lane index of the segment this view stands for.
lane: usize,
}
impl<'a> SegmentView<'a> {
unsafe fn new(inner: MemoryInner<'a>, slice: usize, lane: usize) -> Self {
SegmentView { inner, slice, lane }
}
pub fn get_block(&self, index: usize) -> &Block {
assert!(index < self.inner.block_count);
assert!(self.inner.lane_of(index) == self.lane || self.inner.slice_of(index) != self.slice);
unsafe { self.inner.blocks.add(index).as_ref() }
}
pub fn get_block_mut(&mut self, index: usize) -> &mut Block {
assert!(index < self.inner.block_count);
assert_eq!(self.inner.lane_of(index), self.lane);
assert_eq!(self.inner.slice_of(index), self.slice);
unsafe { self.inner.blocks.add(index).as_mut() }
}
}