use serde::Deserialize;
use serde::Serialize;
/// Errors reported when constructing or querying a [`Slice`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SliceError {
/// Two dimension counts that must agree do not: `Slice::new` reports the
/// number of sizes as `expected` and the number of strides as `got`;
/// `Slice::location` reports the slice's dimensionality as `expected` and
/// the coordinate length as `got`.
#[error("invalid dims: expected {expected}, got {got}")]
InvalidDims { expected: usize, got: usize },
/// Returned by `Slice::new` when a stride is not a multiple of the
/// next-smaller stride.
#[error("nonrectangular shape")]
NonrectangularShape,
/// Returned by `Slice::new` when two dimensions share the same stride and
/// neither of them has size 1.
#[error("nonunique strides")]
NonuniqueStrides,
/// Returned by `Slice::new` when `stride` is smaller than `space`, the
/// `stride * size` product of the dimension with the next-smaller stride
/// (`space` is 1 for the smallest stride).
#[error("stride {stride} must be larger than size of previous space {space}")]
StrideTooSmall { stride: usize, space: usize },
/// Returned by `Slice::get` when `index` is not below `total`, the product
/// of all sizes.
#[error("index {index} out of range {total}")]
IndexOutOfRange { index: usize, total: usize },
/// Returned by `Slice::coordinates` (and therefore `Slice::index`) when
/// `value` is not a location addressed by the slice.
#[error("value {value} not in slice")]
ValueNotInSlice { value: usize },
}
/// A strided, n-dimensional view over flat `usize` locations.
///
/// The location of a coordinate is `offset + sum(coord[i] * strides[i])`
/// (see `Slice::location`). `sizes` and `strides` hold one entry per
/// dimension; `Slice::new` rejects inputs where their lengths differ.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Slice {
/// Location of the all-zero coordinate.
offset: usize,
/// Extent of each dimension.
sizes: Vec<usize>,
/// Distance in locations between consecutive coordinates of each dimension.
strides: Vec<usize>,
}
impl Slice {
    /// Builds a slice from `offset` and per-dimension `sizes` and `strides`,
    /// validating that the layout is well formed.
    ///
    /// Dimensions are checked in order of increasing stride (ties ordered by
    /// size), independent of the order in which they were supplied.
    ///
    /// # Errors
    ///
    /// * [`SliceError::InvalidDims`] if `sizes` and `strides` differ in length.
    /// * [`SliceError::NonrectangularShape`] if a stride is not a multiple of
    ///   the next-smaller stride.
    /// * [`SliceError::NonuniqueStrides`] if two dimensions share a stride and
    ///   neither has size 1.
    /// * [`SliceError::StrideTooSmall`] if a stride is smaller than the
    ///   `stride * size` span of the next-smaller-stride dimension (or smaller
    ///   than 1 for the smallest stride).
    pub fn new(offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> Result<Self, SliceError> {
        if sizes.len() != strides.len() {
            return Err(SliceError::InvalidDims {
                expected: sizes.len(),
                got: strides.len(),
            });
        }
        let mut combined: Vec<(usize, usize)> =
            strides.iter().copied().zip(sizes.iter().copied()).collect();
        // Tuples compare on both fields, so an unstable sort yields the same order.
        combined.sort_unstable();
        // (stride, size) of the previously visited (next-smaller-stride) dimension.
        let mut prev: Option<(usize, usize)> = None;
        // Span covered by the dimensions visited so far.
        let mut total: usize = 1;
        for (stride, size) in combined {
            if let Some((prev_stride, prev_size)) = prev {
                // `prev_stride` is never zero here: a zero stride sorts first and
                // is rejected by the `total > stride` check below (total starts at 1).
                if stride % prev_stride != 0 {
                    return Err(SliceError::NonrectangularShape);
                }
                if stride == prev_stride && size != 1 && prev_size != 1 {
                    return Err(SliceError::NonuniqueStrides);
                }
            }
            if total > stride {
                return Err(SliceError::StrideTooSmall {
                    stride,
                    space: total,
                });
            }
            // Saturate instead of overflowing: a huge outermost stride can still
            // describe locations that fit in `usize`.
            total = stride.saturating_mul(size);
            prev = Some((stride, size));
        }
        Ok(Slice {
            offset,
            sizes,
            strides,
        })
    }

    /// Builds a slice with offset 0 over `sizes` whose stride for each
    /// dimension is the product of the sizes of all later dimensions
    /// (e.g. sizes `[4, 4, 4]` give strides `[16, 4, 1]`).
    pub fn new_row_major(sizes: impl Into<Vec<usize>>) -> Self {
        let sizes = sizes.into();
        let mut strides: Vec<usize> = vec![0; sizes.len()];
        let mut acc: usize = 1;
        // Walk from the last dimension to the first, accumulating the product.
        for (stride, &size) in strides.iter_mut().zip(&sizes).rev() {
            *stride = acc;
            acc *= size;
        }
        Self {
            offset: 0,
            sizes,
            strides,
        }
    }

    /// Builds a slice with `dims` dimensions, each of size 1 and stride 1, at offset 0.
    pub fn new_single_multi_dim_cell(dims: usize) -> Self {
        Self {
            offset: 0,
            sizes: vec![1; dims],
            strides: vec![1; dims],
        }
    }

    /// Number of dimensions.
    pub fn num_dim(&self) -> usize {
        self.sizes.len()
    }

    /// Location of the all-zero coordinate.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Size of each dimension.
    pub fn sizes(&self) -> &[usize] {
        &self.sizes
    }

    /// Stride of each dimension.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Returns the location of `coord`: `offset + sum(coord[i] * strides[i])`.
    ///
    /// Individual coordinates are not range-checked against the sizes.
    ///
    /// # Errors
    ///
    /// [`SliceError::InvalidDims`] if `coord` does not have one entry per dimension.
    pub fn location(&self, coord: &[usize]) -> Result<usize, SliceError> {
        if coord.len() != self.sizes.len() {
            return Err(SliceError::InvalidDims {
                expected: self.sizes.len(),
                got: coord.len(),
            });
        }
        Ok(self.offset
            + coord
                .iter()
                .zip(&self.strides)
                .map(|(pos, stride)| pos * stride)
                .sum::<usize>())
    }

    /// Returns the coordinates whose location is `value`; the inverse of
    /// [`Slice::location`] for locations addressed by the slice.
    ///
    /// # Errors
    ///
    /// [`SliceError::ValueNotInSlice`] if `value` is below the offset or is
    /// not addressed by any coordinate.
    pub fn coordinates(&self, value: usize) -> Result<Vec<usize>, SliceError> {
        let mut pos = value
            .checked_sub(self.offset)
            .ok_or(SliceError::ValueNotInSlice { value })?;
        let mut result = vec![0; self.sizes.len()];
        // (stride, dimension index, size), ordered by stride. The sort is stable
        // so dimensions with equal strides keep their declaration order.
        let mut by_stride: Vec<(usize, usize, usize)> = self
            .strides
            .iter()
            .zip(&self.sizes)
            .enumerate()
            .map(|(dim, (&stride, &size))| (stride, dim, size))
            .collect();
        by_stride.sort_by_key(|&(stride, _, _)| stride);
        // Peel off the largest stride first, carrying the remainder inwards.
        for &(stride, dim, size) in by_stride.iter().rev() {
            // Size-1 dimensions always have coordinate 0 and consume nothing.
            let (index, new_pos) = if size > 1 {
                (pos / stride, pos % stride)
            } else {
                (0, pos)
            };
            if index >= size {
                return Err(SliceError::ValueNotInSlice { value });
            }
            result[dim] = index;
            pos = new_pos;
        }
        if pos != 0 {
            return Err(SliceError::ValueNotInSlice { value });
        }
        Ok(result)
    }

    /// Returns the location of the `index`-th element, enumerating coordinates
    /// in row-major order (the last dimension varies fastest).
    ///
    /// # Errors
    ///
    /// [`SliceError::IndexOutOfRange`] if `index >= self.len()`. This includes
    /// every index when any dimension has size 0.
    pub fn get(&self, index: usize) -> Result<usize, SliceError> {
        let total = self.len();
        // Check the range first: an in-range index implies every size is
        // non-zero, so the divisions below cannot divide by zero.
        if index >= total {
            return Err(SliceError::IndexOutOfRange { index, total });
        }
        let mut val = self.offset;
        let mut rest = index;
        for (size, stride) in self.sizes.iter().zip(&self.strides).rev() {
            val += (rest % size) * stride;
            rest /= size;
        }
        Ok(val)
    }

    /// Total number of elements: the product of all sizes.
    pub fn len(&self) -> usize {
        self.sizes.iter().product()
    }

    /// Whether the slice has no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the locations of all elements in row-major coordinate order.
    pub fn iter(&self) -> SliceIterator<'_> {
        SliceIterator {
            slice: self,
            pos: CartesianIterator::new(&self.sizes),
        }
    }

    /// Iterates over all coordinate tuples of the first `dims` dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `dims` exceeds [`Slice::num_dim`].
    pub fn dim_iter(&self, dims: usize) -> DimSliceIterator<'_> {
        DimSliceIterator {
            pos: CartesianIterator::new(&self.sizes[0..dims]),
        }
    }

    /// Returns the row-major rank of the element located at `value`; the
    /// inverse of [`Slice::get`].
    ///
    /// # Errors
    ///
    /// [`SliceError::ValueNotInSlice`] if `value` is not addressed by the slice.
    pub fn index(&self, value: usize) -> Result<usize, SliceError> {
        let coords = self.coordinates(value)?;
        let mut stride = 1;
        let mut result = 0;
        // Re-linearise the coordinates using dense row-major strides.
        for (idx, size) in coords.iter().rev().zip(self.sizes.iter().rev()) {
            result += *idx * stride;
            stride *= size;
        }
        Ok(result)
    }

    /// Returns a view that applies `mapper` to every location produced by
    /// this slice's lookups.
    pub fn map<T, F>(&self, mapper: F) -> MapSlice<'_, T, F>
    where
        F: Fn(usize) -> T,
    {
        MapSlice {
            slice: self,
            mapper,
        }
    }
}
/// Renders the slice for display using its `Debug` representation.
impl std::fmt::Display for Slice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Going through `write!` (rather than forwarding to `Debug::fmt`)
        // formats with default options regardless of the caller's flags.
        write!(f, "{self:?}")
    }
}
impl<'a> IntoIterator for &'a Slice {
type Item = usize;
type IntoIter = SliceIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over the locations of a [`Slice`], one per coordinate tuple
/// produced by the inner coordinate iterator.
pub struct SliceIterator<'a> {
    /// Slice used to translate coordinates into locations.
    slice: &'a Slice,
    /// Source of coordinate tuples.
    pos: CartesianIterator<'a>,
}

impl<'a> Iterator for SliceIterator<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Stop as soon as the coordinate source is exhausted.
        let coord = self.pos.next()?;
        Some(self.slice.location(&coord).unwrap())
    }
}
/// Iterator over coordinate tuples for a prefix of a slice's dimensions
/// (constructed by `Slice::dim_iter`).
pub struct DimSliceIterator<'a> {
    /// Source of coordinate tuples.
    pos: CartesianIterator<'a>,
}

impl<'a> Iterator for DimSliceIterator<'a> {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        // Pure delegation: the coordinates are yielded unchanged.
        let Self { pos } = self;
        pos.next()
    }
}
/// Iterator over every coordinate tuple of a grid with extents `dims`,
/// with the last dimension varying fastest.
struct CartesianIterator<'a> {
    /// Extent of each dimension.
    dims: &'a [usize],
    /// Rank of the next tuple to yield.
    index: usize,
}

impl<'a> CartesianIterator<'a> {
    /// Starts an iteration over `dims` at the all-zero coordinate.
    fn new(dims: &'a [usize]) -> Self {
        Self { index: 0, dims }
    }
}

impl<'a> Iterator for CartesianIterator<'a> {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        let total: usize = self.dims.iter().product();
        if self.index >= total {
            return None;
        }
        // Decompose the rank into mixed-radix digits, least significant
        // (last dimension) first, then flip into dimension order.
        let mut remaining = self.index;
        let mut coords: Vec<usize> = self
            .dims
            .iter()
            .rev()
            .map(|&extent| {
                let digit = remaining % extent;
                remaining /= extent;
                digit
            })
            .collect();
        coords.reverse();
        self.index += 1;
        Some(coords)
    }
}
/// A view over a [`Slice`] that passes every looked-up location through
/// `mapper` before returning it.
pub struct MapSlice<'a, T, F>
where
    F: Fn(usize) -> T,
{
    /// Slice that supplies the raw locations.
    slice: &'a Slice,
    /// Function applied to each location.
    mapper: F,
}

impl<'a, T, F> MapSlice<'a, T, F>
where
    F: Fn(usize) -> T,
{
    /// Sizes of the underlying slice.
    pub fn sizes(&self) -> &[usize] {
        self.slice.sizes.as_slice()
    }

    /// Strides of the underlying slice.
    pub fn strides(&self) -> &[usize] {
        self.slice.strides.as_slice()
    }

    /// Looks up the location of `coord` in the underlying slice and maps it;
    /// lookup errors are returned unchanged.
    pub fn location(&self, coord: &[usize]) -> Result<T, SliceError> {
        let raw = self.slice.location(coord)?;
        Ok((self.mapper)(raw))
    }

    /// Looks up the location at `index` in the underlying slice and maps it;
    /// lookup errors are returned unchanged.
    pub fn get(&self, index: usize) -> Result<T, SliceError> {
        let raw = self.slice.get(index)?;
        Ok((self.mapper)(raw))
    }

    /// Number of elements in the underlying slice.
    pub fn len(&self) -> usize {
        self.slice.len()
    }

    /// Whether the underlying slice has no elements.
    pub fn is_empty(&self) -> bool {
        self.slice.is_empty()
    }
}
// Unit tests for `Slice`, its iterators, and `MapSlice`.
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::vec;
use super::*;
// The coordinate iterator enumerates a 2x2x2 grid with the last dimension
// varying fastest.
#[test]
fn test_cartesian_iterator() {
let dims = vec![2, 2, 2];
let iter = CartesianIterator::new(&dims);
let products: Vec<Vec<usize>> = iter.collect();
assert_eq!(
products,
vec![
vec![0, 0, 0],
vec![0, 0, 1],
vec![0, 1, 0],
vec![0, 1, 1],
vec![1, 0, 0],
vec![1, 0, 1],
vec![1, 1, 0],
vec![1, 1, 1],
]
);
}
// `get` and iteration on dense and permuted/padded stride layouts.
// The counter loop is written by hand so that each iterated location is
// compared against an independently maintained expected value.
#[test]
#[allow(clippy::explicit_counter_loop)]
fn test_slice() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
// Dense row-major layout at offset 0: index i is location i
// (only the first four of the six indices are checked here).
for i in 0..4 {
assert_eq!(s.get(i).unwrap(), i);
}
{
let mut current = 0;
for index in &s {
assert_eq!(index, current);
current += 1;
}
}
let s = Slice::new(0, vec![3, 4, 5], vec![20, 5, 1]).unwrap();
assert_eq!(s.get(3 * 4 + 1).unwrap(), 13);
// Strides need not be in decreasing order: the middle dimension has the
// largest stride here.
let s = Slice::new(0, vec![2, 2, 2], vec![4, 32, 1]).unwrap();
assert_eq!(s.get(0).unwrap(), 0);
assert_eq!(s.get(1).unwrap(), 1);
assert_eq!(s.get(2).unwrap(), 32);
assert_eq!(s.get(3).unwrap(), 33);
assert_eq!(s.get(4).unwrap(), 4);
assert_eq!(s.get(5).unwrap(), 5);
assert_eq!(s.get(6).unwrap(), 36);
assert_eq!(s.get(7).unwrap(), 37);
// Strides larger than the span of the inner dimensions leave gaps.
let s = Slice::new(0, vec![2, 2, 2], vec![32, 4, 1]).unwrap();
assert_eq!(s.get(0).unwrap(), 0);
assert_eq!(s.get(1).unwrap(), 1);
assert_eq!(s.get(2).unwrap(), 4);
assert_eq!(s.get(4).unwrap(), 32);
}
// `iter` yields the same locations as calling `get` for every index.
#[test]
fn test_slice_iter() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
assert!(s.iter().eq(0..6));
let s = Slice::new(10, vec![10, 2], vec![10, 5]).unwrap();
assert!(s.iter().eq((10..=105).step_by(5)));
assert!(s.iter().eq((0..s.len()).map(|i| s.get(i).unwrap())));
}
// `dim_iter(1)` enumerates coordinates of the first dimension only.
#[test]
fn test_dim_slice_iter() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
let sub_dims: Vec<_> = s.dim_iter(1).collect();
assert_eq!(sub_dims, vec![vec![0], vec![1]]);
}
// `coordinates` inverts locations and rejects values outside the slice,
// including values below the offset.
#[test]
fn test_slice_coordinates() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
assert_eq!(s.coordinates(0).unwrap(), vec![0, 0]);
assert_eq!(s.coordinates(3).unwrap(), vec![1, 0]);
assert_matches!(
s.coordinates(6),
Err(SliceError::ValueNotInSlice { value: 6 })
);
let s = Slice::new(10, vec![2, 3], vec![3, 1]).unwrap();
assert_matches!(
s.coordinates(6),
Err(SliceError::ValueNotInSlice { value: 6 })
);
assert_eq!(s.coordinates(10).unwrap(), vec![0, 0]);
assert_eq!(s.coordinates(13).unwrap(), vec![1, 0]);
// Size-1 dimensions may share a stride with another dimension.
let s = Slice::new(0, vec![2, 1, 1], vec![1, 1, 1]).unwrap();
assert_eq!(s.coordinates(1).unwrap(), vec![1, 0, 0]);
}
// `index` maps a location back to its row-major rank.
#[test]
fn test_slice_index() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
assert_eq!(s.index(3).unwrap(), 3);
assert!(s.index(14).is_err());
let s = Slice::new(0, vec![2, 2], vec![4, 2]).unwrap();
assert_eq!(s.index(2).unwrap(), 1);
}
// `map` applies the closure to the locations returned by `get`.
#[test]
fn test_slice_map() {
let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
let m = s.map(|i| i * 2);
assert_eq!(m.get(0).unwrap(), 0);
assert_eq!(m.get(3).unwrap(), 6);
assert_eq!(m.get(5).unwrap(), 10);
}
// Several size-1 dimensions with identical strides are accepted by `new`.
#[test]
fn test_slice_size_one() {
let s = Slice::new(0, vec![1, 1], vec![1, 1]).unwrap();
assert_eq!(s.get(0).unwrap(), 0);
}
// `new_row_major` derives strides from the sizes and uses offset 0.
#[test]
fn test_row_major() {
let s = Slice::new_row_major(vec![4, 4, 4]);
assert_eq!(s.offset(), 0);
assert_eq!(s.sizes(), &[4, 4, 4]);
assert_eq!(s.strides(), &[16, 4, 1]);
}
}