use std::ops::{Bound, Index, IndexMut, RangeBounds};
/// Owned storage for `shard_count` shards of `shard_bytes` bytes each, laid out back to back
/// in a single `Vec<u8>`.
pub(crate) struct Shards {
    /// How many shards are currently held.
    shard_count: usize,
    /// Length of every individual shard, in bytes.
    shard_bytes: usize,
    /// Backing buffer; shard `i` occupies `data[i * shard_bytes..(i + 1) * shard_bytes]`.
    data: Vec<u8>,
}
impl Shards {
    /// Returns a [`ShardsRefMut`] view over every shard held by `self`.
    ///
    /// The view mutably borrows `self` for as long as it is alive.
    pub(crate) fn as_ref_mut(&mut self) -> ShardsRefMut<'_> {
        ShardsRefMut::new(self.shard_count, self.shard_bytes, &mut self.data)
    }

    /// Creates an empty container: zero shards, zero shard size, no backing storage.
    pub(crate) fn new() -> Self {
        Self {
            shard_count: 0,
            shard_bytes: 0,
            data: Vec::new(),
        }
    }

    /// Resizes the container to hold `shard_count` shards of `shard_bytes` bytes each.
    ///
    /// The backing buffer is resized in place (`Vec::resize`). Bytes that fit within the new
    /// length are kept, and any newly added bytes are zero-filled.
    ///
    /// # Panics
    ///
    /// Panics if `shard_bytes` is zero or not a multiple of 64, or if
    /// `shard_count * shard_bytes` overflows `usize`.
    pub(crate) fn resize(&mut self, shard_count: usize, shard_bytes: usize) {
        assert!(
            shard_bytes > 0 && shard_bytes & 63 == 0,
            "shard_bytes must be a non-zero multiple of 64, got {}",
            shard_bytes
        );
        // Compute the byte length before touching any field. An overflow must not leave
        // `shard_count`/`shard_bytes` out of sync with `data`; release builds would
        // otherwise wrap silently.
        let len = shard_count
            .checked_mul(shard_bytes)
            .expect("shard_count * shard_bytes overflows usize");
        self.shard_count = shard_count;
        self.shard_bytes = shard_bytes;
        self.data.resize(len, 0);
    }
}
impl Index<usize> for Shards {
    type Output = [u8];

    /// Borrows shard `index` as a slice exactly `shard_bytes` long.
    ///
    /// Panics if the shard's byte range lies outside the backing buffer.
    fn index(&self, index: usize) -> &[u8] {
        let first = index * self.shard_bytes;
        let past_last = (index + 1) * self.shard_bytes;
        &self.data[first..past_last]
    }
}
impl IndexMut<usize> for Shards {
    /// Mutably borrows shard `index` as a slice exactly `shard_bytes` long.
    ///
    /// Panics if the shard's byte range lies outside the backing buffer.
    fn index_mut(&mut self, index: usize) -> &mut [u8] {
        let first = index * self.shard_bytes;
        let past_last = (index + 1) * self.shard_bytes;
        &mut self.data[first..past_last]
    }
}
/// Mutable, borrowed view of `shard_count` equally sized shards stored back to back in one
/// byte slice.
pub struct ShardsRefMut<'a> {
    /// How many shards the view covers.
    shard_count: usize,
    /// Length of every individual shard, in bytes.
    shard_bytes: usize,
    /// Borrowed bytes; shard `i` occupies `data[i * shard_bytes..(i + 1) * shard_bytes]`.
    data: &'a mut [u8],
}
impl<'a> ShardsRefMut<'a> {
    /// Returns mutable views of two shards that are `dist` shards apart: shard `pos` and
    /// shard `pos + dist`.
    ///
    /// # Panics
    ///
    /// Panics if shard `pos + dist` extends past the end of the data. Also panics if `dist`
    /// is zero while `shard_bytes` is non-zero, because the two views would overlap.
    pub fn dist2_mut(&mut self, mut pos: usize, mut dist: usize) -> (&mut [u8], &mut [u8]) {
        // Switch from shard units to byte offsets.
        pos *= self.shard_bytes;
        dist *= self.shard_bytes;
        // Splitting `dist` bytes after `pos` yields two disjoint borrows. The first
        // `shard_bytes` bytes of each are the requested shards.
        let (a, b) = self.data[pos..].split_at_mut(dist);
        (&mut a[..self.shard_bytes], &mut b[..self.shard_bytes])
    }

    /// Returns mutable views of four shards spaced `dist` shards apart: shards `pos`,
    /// `pos + dist`, `pos + 2 * dist` and `pos + 3 * dist`, in that order.
    ///
    /// # Panics
    ///
    /// Panics if shard `pos + 3 * dist` extends past the end of the data. Also panics if
    /// `dist` is zero while `shard_bytes` is non-zero, because the views would overlap.
    pub fn dist4_mut(
        &mut self,
        mut pos: usize,
        mut dist: usize,
    ) -> (&mut [u8], &mut [u8], &mut [u8], &mut [u8]) {
        // Switch from shard units to byte offsets.
        pos *= self.shard_bytes;
        dist *= self.shard_bytes;
        // First split off the pair (pos, pos + dist) from the pair (pos + 2d, pos + 3d),
        // then split each pair again; all four borrows are disjoint.
        let (ab, cd) = self.data[pos..].split_at_mut(dist * 2);
        let (a, b) = ab.split_at_mut(dist);
        let (c, d) = cd.split_at_mut(dist);
        (
            &mut a[..self.shard_bytes],
            &mut b[..self.shard_bytes],
            &mut c[..self.shard_bytes],
            &mut d[..self.shard_bytes],
        )
    }

    /// Returns `true` if the view covers no shards.
    pub fn is_empty(&self) -> bool {
        self.shard_count == 0
    }

    /// Returns the number of shards covered by the view.
    pub fn len(&self) -> usize {
        self.shard_count
    }

    /// Creates a view of `shard_count` shards of `shard_bytes` bytes each over the start of
    /// `data`.
    ///
    /// Any bytes of `data` beyond `shard_count * shard_bytes` are excluded from the view.
    ///
    /// # Panics
    ///
    /// Panics if `shard_count * shard_bytes` overflows `usize` or exceeds `data.len()`.
    pub fn new(shard_count: usize, shard_bytes: usize, data: &'a mut [u8]) -> Self {
        // Checked so that a wrapped product in release builds can never produce a view whose
        // `len()` disagrees with the bytes it actually holds.
        let len = shard_count
            .checked_mul(shard_bytes)
            .expect("shard_count * shard_bytes overflows usize");
        Self {
            shard_count,
            shard_bytes,
            data: &mut data[..len],
        }
    }

    /// Splits the view into two disjoint views: shards `0..mid` and shards
    /// `mid..self.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `mid > self.len()`.
    pub fn split_at_mut(&mut self, mid: usize) -> (ShardsRefMut<'_>, ShardsRefMut<'_>) {
        // Checked explicitly: with zero-sized shards the byte split below cannot detect an
        // out-of-range `mid`, and `shard_count - mid` would underflow.
        assert!(
            mid <= self.shard_count,
            "split point {} exceeds shard count {}",
            mid,
            self.shard_count
        );
        let (a, b) = self.data.split_at_mut(mid * self.shard_bytes);
        (
            ShardsRefMut::new(mid, self.shard_bytes, a),
            ShardsRefMut::new(self.shard_count - mid, self.shard_bytes, b),
        )
    }

    /// Overwrites every byte of the shards selected by `range` with zero.
    ///
    /// `range` is expressed in shard indices and accepts any range syntax (`a..b`, `a..=b`,
    /// `..b`, `a..`, `..`).
    ///
    /// # Panics
    ///
    /// Panics if the range reaches past the last shard or its start is after its end.
    pub fn zero<R: RangeBounds<usize>>(&mut self, range: R) {
        let start = match range.start_bound() {
            Bound::Included(start) => start * self.shard_bytes,
            Bound::Excluded(start) => (start + 1) * self.shard_bytes,
            Bound::Unbounded => 0,
        };
        let end = match range.end_bound() {
            Bound::Included(end) => (end + 1) * self.shard_bytes,
            Bound::Excluded(end) => end * self.shard_bytes,
            Bound::Unbounded => self.shard_count * self.shard_bytes,
        };
        self.data[start..end].fill(0);
    }
}
impl Index<usize> for ShardsRefMut<'_> {
    type Output = [u8];

    /// Borrows shard `index` as a slice exactly `shard_bytes` long.
    ///
    /// Panics if the shard's byte range lies outside the viewed bytes.
    fn index(&self, index: usize) -> &[u8] {
        let first = index * self.shard_bytes;
        let past_last = (index + 1) * self.shard_bytes;
        &self.data[first..past_last]
    }
}
impl IndexMut<usize> for ShardsRefMut<'_> {
    /// Mutably borrows shard `index` as a slice exactly `shard_bytes` long.
    ///
    /// Panics if the shard's byte range lies outside the viewed bytes.
    fn index_mut(&mut self, index: usize) -> &mut [u8] {
        let first = index * self.shard_bytes;
        let past_last = (index + 1) * self.shard_bytes;
        &mut self.data[first..past_last]
    }
}
impl ShardsRefMut<'_> {
    /// Copies `count` whole shards, starting at shard `src`, so that they start at shard
    /// `dest`.
    ///
    /// Source and destination may overlap, since this delegates to `slice::copy_within`.
    /// Panics if either shard range reaches past the end of the data.
    pub(crate) fn copy_within(&mut self, src: usize, dest: usize, count: usize) {
        let width = self.shard_bytes;
        let (from, to, len) = (src * width, dest * width, count * width);
        self.data.copy_within(from..from + len, to);
    }

    /// Returns two disjoint mutable byte runs, each spanning `count` consecutive shards.
    ///
    /// The first run starts at shard `x` and the second at shard `y`; the returned tuple
    /// keeps that `(x, y)` order. Panics if the runs overlap (when `count > 0`) or reach
    /// past the end of the data.
    pub(crate) fn flat2_mut(&mut self, x: usize, y: usize, count: usize) -> (&mut [u8], &mut [u8]) {
        let width = self.shard_bytes;
        let (x, y, count) = (x * width, y * width, count * width);
        if y <= x {
            // Split at `x`: the `x` run starts the upper part, the `y` run sits in the lower.
            let (lower, upper) = self.data.split_at_mut(x);
            (&mut upper[..count], &mut lower[y..y + count])
        } else {
            // Split at `y`: the `x` run sits in the lower part, the `y` run starts the upper.
            let (lower, upper) = self.data.split_at_mut(y);
            (&mut lower[x..x + count], &mut upper[..count])
        }
    }
}