use std::mem;
/// Ways to mutably borrow several distinct elements of a slice at once.
///
/// A checked request is rejected when any index is `>= self.len()` or when the
/// same index appears more than once; the checked methods differ only in how
/// that rejection is reported.
pub trait SliceExt<T> {
    /// Returns mutable references to the elements at `indices`, in the same
    /// order as `indices`, without validating them.
    ///
    /// # Safety
    ///
    /// Every index must be `< self.len()`, and the indices must be pairwise
    /// distinct. An out-of-bounds index is an out-of-bounds access and a
    /// repeated index would produce two aliasing `&mut T`; either is undefined
    /// behavior.
    unsafe fn get_many_unchecked_mut<const N: usize>(&mut self, indices: [usize; N])
        -> [&mut T; N];

    /// Returns the references, or `None` if any index is out of bounds or
    /// repeated.
    fn get_many_mut_opt<const N: usize>(&mut self, indices: [usize; N]) -> Option<[&mut T; N]>;

    /// Returns the references, or an opaque [`ErrorSimple`] carrying no details.
    fn get_many_mut_res_simple<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorSimple<N>>;

    /// Returns the references, or an [`ErrorKind`] saying whether an index was
    /// out of bounds or repeated.
    fn get_many_mut_res_direct<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorKind>;

    /// Returns the references, or an [`Error`] that stores the indices and the
    /// slice length so the reason can be recomputed with [`Error::kind`].
    fn get_many_mut_res_indirect<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], Error<N>>;

    /// Like [`SliceExt::get_many_mut_res_indirect`], but the returned
    /// [`ErrorNiche`] stores the length in an offset encoding whose low raw
    /// values are invalid, leaving niches in the error type.
    fn get_many_mut_res_indirect_niche<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorNiche<N>>;
}
impl<T> SliceExt<T> for [T] {
    #[inline]
    unsafe fn get_many_unchecked_mut<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> [&mut T; N] {
        // Every element reference is derived from one raw pointer whose
        // provenance covers the whole slice, so the N references do not
        // reborrow `self` (or each other).
        let base: *mut T = self.as_mut_ptr();
        let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit();
        // Write through a pointer to the first element slot instead of through
        // `(*arr.as_mut_ptr())[i]`, which would create a `&mut` to the
        // not-yet-initialised array.
        let slots: *mut &mut T = arr.as_mut_ptr().cast::<&mut T>();
        // SAFETY: the caller guarantees each index is `< self.len()` and that no
        // index repeats, so every `base.add(idx)` is in bounds and the produced
        // `&mut T`s are disjoint. `i < N`, so `slots.add(i)` stays inside `arr`,
        // and after the loop all `N` slots are written, as `assume_init` requires.
        unsafe {
            for (i, &idx) in indices.iter().enumerate() {
                slots.add(i).write(&mut *base.add(idx));
            }
            arr.assume_init()
        }
    }

    #[inline]
    fn get_many_mut_opt<const N: usize>(&mut self, indices: [usize; N]) -> Option<[&mut T; N]> {
        if !get_many_check_valid(&indices, self.len()) {
            return None;
        }
        // SAFETY: `get_many_check_valid` confirmed all indices are in bounds and unique.
        unsafe { Some(self.get_many_unchecked_mut(indices)) }
    }

    fn get_many_mut_res_simple<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorSimple<N>> {
        if !get_many_check_valid(&indices, self.len()) {
            return Err(ErrorSimple { _private: () });
        }
        // SAFETY: `get_many_check_valid` confirmed all indices are in bounds and unique.
        unsafe { Ok(self.get_many_unchecked_mut(indices)) }
    }

    fn get_many_mut_res_direct<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorKind> {
        get_many_check_valid_kinds(&indices, self.len())?;
        // SAFETY: `get_many_check_valid_kinds` returned `Ok`, so all indices are
        // in bounds and unique.
        unsafe { Ok(self.get_many_unchecked_mut(indices)) }
    }

    fn get_many_mut_res_indirect<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], Error<N>> {
        if !get_many_check_valid(&indices, self.len()) {
            return Err(Error {
                indices,
                slice_len: self.len(),
            });
        }
        // SAFETY: `get_many_check_valid` confirmed all indices are in bounds and unique.
        unsafe { Ok(self.get_many_unchecked_mut(indices)) }
    }

    fn get_many_mut_res_indirect_niche<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], ErrorNiche<N>> {
        if !get_many_check_valid(&indices, self.len()) {
            // `len + 2` can overflow: a slice of a zero-sized type may have a
            // length up to `usize::MAX`, and `unchecked_add` would then be UB.
            // Saturating keeps the encoded value `>= 2` and never raises the
            // decoded length above the real one, so every violation seen here is
            // still seen by `ErrorNiche::kind` (its `unwrap_err` cannot fire).
            // Only for lengths `>= usize::MAX - 1` is the decoded length clamped
            // to `usize::MAX - 2`, which may report `OutOfBounds` for an index in
            // that top range.
            let encoded = self.len().saturating_add(2);
            return Err(ErrorNiche {
                indices,
                // SAFETY: `encoded >= 2`, the start of the valid range declared
                // on `SliceLenWithNiche`.
                slice_len: unsafe { SliceLenWithNiche(encoded) },
            });
        }
        // SAFETY: `get_many_check_valid` confirmed all indices are in bounds and unique.
        unsafe { Ok(self.get_many_unchecked_mut(indices)) }
    }
}
/// Returns `true` when every index is `< len` and no index occurs twice.
///
/// Every comparison is evaluated unconditionally (`&`, not `&&`), so there is
/// no early exit. Presumably this keeps the check branch-free.
fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> bool {
    indices
        .iter()
        .enumerate()
        .fold(true, |all_ok, (pos, &current)| {
            let distinct_from_earlier = indices[..pos]
                .iter()
                .fold(true, |acc, &earlier| acc & (earlier != current));
            all_ok & (current < len) & distinct_from_earlier
        })
}
/// Opaque error from `get_many_mut_res_simple`; it carries no details about
/// which index was rejected or why.
pub struct ErrorSimple<const N: usize> {
    // Private unit field: prevents construction outside this module.
    _private: (),
}

impl<const N: usize> std::fmt::Debug for ErrorSimple<N> {
    /// Prints `ErrorSimple { .. }`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("ErrorSimple");
        builder.finish_non_exhaustive()
    }
}
/// Error from `get_many_mut_res_indirect`. It keeps the rejected indices and
/// the slice length so the reason can be recomputed later.
pub struct Error<const N: usize> {
    /// The indices passed to the failing call.
    indices: [usize; N],
    /// Length of the slice at the time of the call.
    slice_len: usize,
}

impl<const N: usize> std::fmt::Debug for Error<N> {
    /// Prints `Error { .. }`; the stored indices and length are not shown.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = out.debug_struct("Error");
        builder.finish_non_exhaustive()
    }
}
impl<const N: usize> Error<N> {
    /// Recomputes why the request was rejected from the stored indices and
    /// slice length.
    ///
    /// # Panics
    ///
    /// Panics if the stored indices actually pass validation. Within this
    /// module an `Error` is only built after validation failed, so that does
    /// not happen for errors from `get_many_mut_res_indirect`.
    pub fn kind(&self) -> ErrorKind {
        let verdict = get_many_check_valid_kinds(&self.indices, self.slice_len);
        verdict.unwrap_err()
    }
}
/// Slice length stored offset by 2, so the raw values 0 and 1 never occur.
///
/// `rustc_layout_scalar_valid_range_start(2)` declares the wrapped `usize` to
/// be `>= 2`, which leaves 0 and 1 free as niche values for enclosing types.
/// Constructing a value therefore requires `unsafe`, and the constructor must
/// uphold `>= 2`. `ErrorNiche::kind` relies on this when it subtracts 2.
// NOTE(review): `rustc_nonnull_optimization_guaranteed` is the attribute std
// places on `NonNull`; it presumably marks the niche optimisation as
// guaranteed. Both attributes are rustc-internal and need
// `#![feature(rustc_attrs)]` at the crate root, which is not visible here.
// Confirm.
#[rustc_layout_scalar_valid_range_start(2)]
#[rustc_nonnull_optimization_guaranteed]
struct SliceLenWithNiche(usize);
/// Error from `get_many_mut_res_indirect_niche`. It keeps the rejected
/// indices and an offset-encoded slice length, so the reason can be
/// recomputed later.
pub struct ErrorNiche<const N: usize> {
    /// The indices passed to the failing call.
    indices: [usize; N],
    /// Slice length, stored offset by 2 (decoded in `ErrorNiche::kind`).
    slice_len: SliceLenWithNiche,
}

impl<const N: usize> std::fmt::Debug for ErrorNiche<N> {
    /// Prints `ErrorNiche { .. }`; neither field is shown.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = fmt.debug_struct("ErrorNiche");
        builder.finish_non_exhaustive()
    }
}
impl<const N: usize> ErrorNiche<N> {
    /// Recomputes why the request was rejected, after decoding the stored
    /// (offset-by-2) slice length.
    ///
    /// # Panics
    ///
    /// Panics if the stored indices pass validation against the decoded
    /// length. Within this module an `ErrorNiche` is only built after
    /// validation failed.
    pub fn kind(&self) -> ErrorKind {
        let encoded = self.slice_len.0;
        // SAFETY: `SliceLenWithNiche` is only valid for values `>= 2`, so
        // subtracting 2 cannot wrap.
        let slice_len = unsafe { encoded.unchecked_sub(2) };
        let verdict = get_many_check_valid_kinds(&self.indices, slice_len);
        verdict.unwrap_err()
    }
}
/// Why a request for several mutable elements of a slice was rejected.
///
/// Derives `PartialEq`/`Eq` so callers can compare kinds directly, and
/// `Clone`/`Copy`/`Hash` since it is a plain fieldless enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorKind {
    /// An index was `>=` the slice length.
    OutOfBounds,
    /// An index appeared more than once.
    NotUnique,
}
/// Validates `indices` against a slice of length `len`, reporting the first
/// problem found.
///
/// Indices are examined in order. For each one, the bounds check runs before
/// the comparison with earlier indices. So with `len == 1`, `[0, 0, 99]`
/// yields `NotUnique` while `[99, 0, 0]` yields `OutOfBounds`.
fn get_many_check_valid_kinds<const N: usize>(
    indices: &[usize; N],
    len: usize,
) -> Result<(), ErrorKind> {
    indices
        .iter()
        .enumerate()
        .try_for_each(|(pos, &current)| {
            if current >= len {
                Err(ErrorKind::OutOfBounds)
            } else if indices[..pos].contains(&current) {
                Err(ErrorKind::NotUnique)
            } else {
                Ok(())
            }
        })
}
/// Unwrapping the opaque error must panic. `expected` pins the panic to the
/// unwrapped `ErrorSimple` (its `Debug` output is part of the message), so an
/// unrelated panic no longer makes the test pass.
#[test]
#[should_panic(expected = "ErrorSimple")]
fn test1() {
    [1].get_many_mut_res_simple([99]).unwrap();
}
#[test]
fn test2() {
    // `Error` recomputes its kind from the stored indices and length.
    let out_of_bounds = [1].get_many_mut_res_indirect([99]).unwrap_err();
    assert!(matches!(out_of_bounds.kind(), ErrorKind::OutOfBounds));
    let duplicate = [1].get_many_mut_res_indirect([0, 0]).unwrap_err();
    assert!(matches!(duplicate.kind(), ErrorKind::NotUnique));

    // `ErrorNiche` must decode its offset-encoded length to the same answers.
    let out_of_bounds = [1].get_many_mut_res_indirect_niche([99]).unwrap_err();
    assert!(matches!(out_of_bounds.kind(), ErrorKind::OutOfBounds));
    let duplicate = [1].get_many_mut_res_indirect_niche([0, 0]).unwrap_err();
    assert!(matches!(duplicate.kind(), ErrorKind::NotUnique));
}