use std::{
fmt, hint,
marker::PhantomData,
mem,
ops::{Bound, Deref, DerefMut, RangeBounds},
ptr::NonNull,
slice,
};
use crossbeam_utils::atomic::AtomicCell;
/// A `NonNull<T>` that may be moved between threads, so it can be stored in the global `ROOT`.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, PartialEq)]
struct SendNonNull<T> {
    ptr: NonNull<T>,
}

impl<T> From<NonNull<T>> for SendNonNull<T> {
    /// Wraps `ptr` without touching the pointee.
    fn from(ptr: NonNull<T>) -> Self {
        SendNonNull { ptr }
    }
}

// SAFETY: the wrapper only carries an address; it is dereferenced solely by `Slice`'s
// `Deref`/`DerefMut` impls, whose own `Send`/`Sync` impls are bounded on the element type.
unsafe impl<T> Send for SendNonNull<T> {}
/// Process-global base pointer of the buffer currently lent out by `SliceCache::access`.
///
/// `None` means no `Ref` guard is alive. `access` spins until it can exchange `None` for its
/// buffer's pointer, and dropping the `Ref` stores `None` again, so at most one buffer is lent
/// at a time across all caches and threads. Every `Slice` dereference resolves its element
/// offset against this pointer.
static ROOT: AtomicCell<Option<SendNonNull<()>>> = AtomicCell::new(None);
/// A typed view into the buffer currently published in `ROOT`, addressed by element offset.
///
/// It holds no pointer of its own: `offset` and `len` are in units of `T` and are resolved
/// against `ROOT` on every dereference. The layout is `#[repr(C)]` and does not depend on `T`
/// (`T` only appears inside `PhantomData`). `SliceCache::access` relies on this when it
/// reinterprets its cached `Slice<'static, ()>` values as `Slice<'s, T>`.
#[repr(C)]
pub struct Slice<'a, T> {
    /// Start of the view, in elements from the root pointer.
    offset: isize,
    /// Number of elements in the view.
    len: usize,
    /// Ties the view's lifetime and element type to the lent `&'a mut [T]`.
    _phantom: PhantomData<&'a mut [T]>,
}
// SAFETY: a `Slice` stands in for a `&'a mut [T]`, so it is `Send`/`Sync` under the same
// element-type bounds as that reference.
unsafe impl<'a, T: Send> Send for Slice<'a, T> {}
unsafe impl<'a, T: Sync> Sync for Slice<'a, T> {}
impl<'a, T> Deref for Slice<'a, T> {
    type Target = [T];

    /// Resolves the view against the buffer currently published in `ROOT`.
    ///
    /// The returned slice borrows `self` rather than claiming the full `'a`. The view is only
    /// meaningful while the guard that published `ROOT` is alive, so it must not outlive the
    /// borrow it was reached through.
    ///
    /// # Panics
    ///
    /// Panics if `ROOT` is unset, i.e. no `SliceCache::access` guard is alive.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let root: NonNull<T> = ROOT
            .load()
            .expect("Slice dereferenced while no SliceCache access is active")
            .ptr
            .cast();
        // SAFETY: `Slice`s are only reachable through the `Ref` guard that published `ROOT`,
        // which points at a `[T]` borrowed for `'a` and at least as long as the cache. The range
        // `offset..offset + len` is in bounds provided every cached span lies within `0..len`.
        unsafe { slice::from_raw_parts(root.as_ptr().offset(self.offset), self.len) }
    }
}
impl<'a, T> DerefMut for Slice<'a, T> {
    /// Resolves the view mutably against the buffer currently published in `ROOT`.
    ///
    /// The returned slice reborrows `self` instead of living for `'a`. This means two calls
    /// can never yield simultaneously live `&mut` slices over the same elements.
    ///
    /// # Panics
    ///
    /// Panics if `ROOT` is unset, i.e. no `SliceCache::access` guard is alive.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let root: NonNull<T> = ROOT
            .load()
            .expect("Slice dereferenced while no SliceCache access is active")
            .ptr
            .cast();
        // SAFETY: the root points at the `[T]` lent to the live `Ref` guard for `'a`. The range
        // is in bounds and unaliased provided the cached spans lie within `0..len` and do not
        // overlap, and `&mut self` gives exclusive use of this particular view.
        unsafe { slice::from_raw_parts_mut(root.as_ptr().offset(self.offset), self.len) }
    }
}
impl<T: fmt::Debug> fmt::Debug for Slice<'_, T> {
    /// Formats the viewed elements exactly as the underlying `[T]` would be formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let elems: &[T] = self;
        f.debug_list().entries(elems.iter()).finish()
    }
}
/// An untyped element range (`offset`, `len`) used to describe a `SliceCache` layout.
///
/// It is `#[repr(transparent)]` over `Slice<'a, ()>`, so `SliceCache::new` can reinterpret the
/// spans returned by its layout closure as cached slices.
#[repr(transparent)]
pub struct Span<'a>(Slice<'a, ()>);
impl<'a> Span<'a> {
    /// Copies the range of `slice` into a new `Span` (used when formatting cached slices).
    fn from_slice(slice: &Slice<'a, ()>) -> Self {
        Self(Slice {
            offset: slice.offset,
            len: slice.len,
            _phantom: PhantomData,
        })
    }

    /// Returns the sub-span selected by `range`, with indices relative to the start of `self`.
    ///
    /// Returns `None` in any of these cases:
    /// - the range is inverted;
    /// - the range ends past `self`'s length;
    /// - a bound overflows `usize` when converted to a half-open range.
    ///
    /// NOTE: this takes `&self`, so the result may overlap `self` or other sub-spans. A layout
    /// closure passed to `SliceCache::new` must not return overlapping spans.
    #[inline]
    pub fn slice<R: RangeBounds<usize>>(&self, range: R) -> Option<Self> {
        let start = match range.start_bound() {
            Bound::Included(&i) => i,
            Bound::Excluded(&i) => i.checked_add(1)?,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(&i) => i.checked_add(1)?,
            Bound::Excluded(&i) => i,
            Bound::Unbounded => self.0.len,
        };
        // Built lazily: `end - start` would underflow if the range were inverted.
        (start <= end && end <= self.0.len).then(|| {
            Span(Slice {
                offset: self.0.offset + start as isize,
                // The sub-span's length, not its end index.
                len: end - start,
                _phantom: PhantomData,
            })
        })
    }

    /// Splits into the first `mid` elements and the remainder, like `<[T]>::split_at`.
    ///
    /// NOTE: takes `&self`, so repeated calls yield overlapping spans. A layout closure passed
    /// to `SliceCache::new` must not return overlapping spans.
    ///
    /// # Panics
    ///
    /// Panics if `mid` exceeds the span's length.
    #[inline]
    pub fn split_at(&self, mid: usize) -> (Self, Self) {
        assert!(mid <= self.0.len, "Span::split_at: mid > len");
        let head = Span(Slice {
            offset: self.0.offset,
            len: mid,
            _phantom: PhantomData,
        });
        let tail = Span(Slice {
            offset: self.0.offset + mid as isize,
            len: self.0.len - mid,
            _phantom: PhantomData,
        });
        (head, tail)
    }

    /// Consumes the span and iterates over consecutive sub-spans of `chunk_size` elements.
    ///
    /// The last sub-span is shorter when the length is not a multiple of `chunk_size`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0, matching `<[T]>::chunks`. A zero size would otherwise make
    /// the iterator yield empty spans forever.
    #[inline]
    pub fn chunks(self, chunk_size: usize) -> Chunks<'a> {
        assert!(chunk_size != 0, "chunk size must be non-zero");
        Chunks {
            slice: self.0,
            size: chunk_size,
        }
    }
}
impl fmt::Debug for Span<'_> {
    /// Formats the span as the half-open element range `start..end` it covers.
    ///
    /// The previous form printed `offset..len`, which reads as a range but is not one
    /// (e.g. a 3-element span at offset 2 showed as `2..3`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let start = self.0.offset;
        let end = start + self.0.len as isize;
        write!(f, "{:?}..{:?}", start, end)
    }
}
/// Iterator over consecutive sub-spans of at most `size` elements, created by `Span::chunks`.
///
/// The final sub-span may be shorter than `size`.
pub struct Chunks<'a> {
    /// Maximum number of elements per yielded span.
    size: usize,
    /// The part of the range that has not been yielded yet.
    slice: Slice<'a, ()>,
}
impl<'a> Iterator for Chunks<'a> {
type Item = Span<'a>;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
(self.slice.len > 0).then(|| {
let span = Span(Slice {
len: self.size.min(self.slice.len),
..self.slice
});
self.slice.offset += self.size as isize;
self.slice.len = self.slice.len.saturating_sub(self.size);
span
})
}
}
impl fmt::Debug for Chunks<'_> {
    /// Shows the not-yet-yielded range (formatted as a `Span`) and the chunk size.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let remaining = Span::from_slice(&self.slice);
        let mut out = f.debug_struct("Chunks");
        out.field("span", &remaining);
        out.field("size", &self.size);
        out.finish()
    }
}
/// Guard returned by `SliceCache::access` that dereferences to the cache's typed slices.
///
/// Dropping it clears the global `ROOT`, which lets the next `access` proceed. It is
/// `#[repr(transparent)]` over `&'a mut T`, so it has exactly that reference's layout. Its field
/// is private, so only this module can construct one.
#[repr(transparent)]
#[derive(Debug)]
pub struct Ref<'a, T: ?Sized>(&'a mut T);
impl<'a, T: ?Sized> Ref<'a, T> {
    /// Returns a mutable reference to the guarded value, borrowed from this guard.
    ///
    /// The borrow is tied to `&mut self`, not to `'a`. Because `Ref` implements `Drop`, a
    /// `&'a mut self` receiver would keep the guard borrowed up to its own drop. The borrow
    /// checker rejects that, so the method could not actually be called.
    pub fn get(&mut self) -> &mut T {
        &mut *self.0
    }
}
impl<T: ?Sized> Deref for Ref<'_, T> {
    type Target = T;

    /// Shared access to the guarded value for as long as the guard is borrowed.
    #[inline]
    fn deref(&self) -> &T {
        &*self.0
    }
}
impl<T: ?Sized> DerefMut for Ref<'_, T> {
    /// Exclusive access to the guarded value for as long as the guard is mutably borrowed.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.0
    }
}
impl<T: ?Sized> Drop for Ref<'_, T> {
    /// Releases the process-global `ROOT` so another `SliceCache::access` can claim it.
    ///
    /// NOTE(review): this unconditionally stores `None`. It assumes this guard is the one that
    /// set `ROOT`, which holds as long as `Ref` values are created only after a successful
    /// exchange on `ROOT` in `SliceCache`.
    #[inline]
    fn drop(&mut self) {
        ROOT.store(None);
    }
}
/// A precomputed layout of element ranges, lent out as typed `Slice` views over any buffer
/// holding at least `len` elements.
pub struct SliceCache {
    /// Type-erased layout produced by the closure given to `new`, stored with an erased lifetime.
    slices: Box<[Slice<'static, ()>]>,
    /// Minimum buffer length accepted by `access`.
    len: usize,
}
impl SliceCache {
#[inline]
pub fn new<F>(len: usize, f: F) -> Self
where
F: Fn(Span<'_>) -> Box<[Span<'_>]> + 'static,
{
let span = Span(Slice {
offset: 0,
len,
_phantom: PhantomData,
});
Self {
len,
slices: unsafe { mem::transmute(f(span)) },
}
}
#[inline]
pub fn access<'c, 's, T>(&'c mut self, slice: &'s mut [T]) -> Option<Ref<'c, [Slice<'s, T>]>> {
if slice.len() >= self.len {
while ROOT
.compare_exchange(
None,
Some(NonNull::new(slice.as_mut_ptr()).unwrap().cast().into()),
)
.is_err()
{
hint::spin_loop();
}
return Some(unsafe { mem::transmute(&mut *self.slices) });
}
None
}
#[cfg(test)]
fn try_access<'c, 's, T>(&'c mut self, slice: &'s mut [T]) -> Option<Ref<'c, [Slice<'s, T>]>> {
if slice.len() >= self.len
&& ROOT
.compare_exchange(
None,
Some(NonNull::new(slice.as_mut_ptr()).unwrap().cast().into()),
)
.is_ok()
{
return Some(unsafe { mem::transmute(&mut *self.slices) });
}
None
}
}
impl fmt::Debug for SliceCache {
    /// Lists the cached layout, one `Span` per cached slice.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for cached in self.slices.iter() {
            list.entry(&Span::from_slice(cached));
        }
        list.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Increments every element reachable through `cache` laid over `data`.
    fn bump_all(cache: &mut SliceCache, data: &mut [i32]) {
        cache
            .access(data)
            .unwrap()
            .iter_mut()
            .flat_map(|part| part.iter_mut())
            .for_each(|val| *val += 1);
    }

    /// A 5-element layout split into `[0, 2)` and `[2, 5)`.
    fn halves() -> SliceCache {
        SliceCache::new(5, |whole| {
            let (head, tail) = whole.split_at(2);
            Box::new([head, tail])
        })
    }

    #[test]
    fn split_at() {
        let mut cache = halves();
        let mut data = [1, 2, 3, 4, 5];
        bump_all(&mut cache, &mut data);
        assert_eq!(data, [2, 3, 4, 5, 6]);
    }

    #[test]
    fn chunks() {
        let mut cache = SliceCache::new(5, |whole| whole.chunks(2).collect());
        let mut data = [1, 2, 3, 4, 5];
        bump_all(&mut cache, &mut data);
        assert_eq!(data, [2, 3, 4, 5, 6]);
    }

    #[test]
    fn ref_twice() {
        let mut cache = halves();
        let mut data = [1, 2, 3, 4, 5];
        bump_all(&mut cache, &mut data);
        bump_all(&mut cache, &mut data);
        assert_eq!(data, [3, 4, 5, 6, 7]);
    }

    #[test]
    fn access_twice() {
        let mut first = SliceCache::new(5, |whole| Box::new([whole]));
        let mut second = SliceCache::new(5, |whole| Box::new([whole]));
        let mut data0 = [1, 2, 3, 4, 5];
        let mut data1 = [1, 2, 3, 4, 5];
        let _held = first.access(&mut data0).unwrap();
        assert!(second.try_access(&mut data1).is_none());
    }
}