use std::{
cmp,
ops::{Index, IndexMut, RangeBounds, Bound},
slice::{Iter, IterMut},
};
/// Builds a [`Shape`] from a comma-separated list of dimensions.
///
/// Every argument must evaluate to a `usize`. A trailing comma is accepted, and
/// `shape![]` produces a shape with no stored dimensions.
#[macro_export]
macro_rules! shape {
    [ $( $dim:expr ),* $(,)? ] => {
        $crate::Shape::from(&[ $( $dim ),* ] as &[usize])
    };
}
/// Returns the length of `slice` once its trailing run of `1`s is dropped.
///
/// Despite the name, this is not a count of non-`1` entries: interior `1`s are
/// kept, so `[1, 2, 1, 3, 1, 1]` yields `4`. An empty or all-`1` slice yields `0`.
fn count_non_one(slice: &[usize]) -> usize {
    match slice.iter().rposition(|&dim| dim != 1) {
        Some(last_significant) => last_significant + 1,
        None => 0,
    }
}
/// Borrows the prefix of `slice` that remains after removing its trailing `1`s.
fn trim_slice(slice: &[usize]) -> &[usize] {
    let (kept, _trailing_ones) = slice.split_at(count_non_one(slice));
    kept
}
/// Mutably borrows the prefix of `slice` that remains after removing its
/// trailing `1`s.
fn trim_mut_slice(slice: &mut [usize]) -> &mut [usize] {
    let keep = count_non_one(slice);
    let (kept, _trailing_ones) = slice.split_at_mut(keep);
    kept
}
/// Removes the trailing run of `1`s from `vec` in place; interior `1`s stay.
fn trim_vec(vec: &mut Vec<usize>) {
    while vec.last() == Some(&1) {
        vec.pop();
    }
}
/// A list of dimension sizes in which trailing `1`s are insignificant.
///
/// The backing vector may still hold trailing `1`s. The public accessors and
/// equality ignore them, so `[2, 3]` and `[2, 3, 1, 1]` describe the same shape.
#[derive(Debug, Clone)]
pub struct Shape {
    /// Raw stored dimensions, possibly including trailing `1`s.
    vec: Vec<usize>,
}
impl From<Vec<usize>> for Shape {
fn from(vec: Vec<usize>) -> Self {
Self { vec }
}
}
impl From<&[usize]> for Shape {
fn from(slice: &[usize]) -> Self {
Self::from(slice.iter().cloned().collect::<Vec<_>>())
}
}
/// Converts a [`Shape`] into its dimensions with trailing `1`s removed.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl then
/// also provides `Into<Vec<usize>> for Shape`, so existing `.into()` calls keep
/// working.
impl From<Shape> for Vec<usize> {
    fn from(shape: Shape) -> Self {
        let mut dims = shape.vec;
        trim_vec(&mut dims);
        dims
    }
}
/// Two shapes are equal when their dimensions match after trailing `1`s are
/// ignored, so `[2, 3]` equals `[2, 3, 1]`.
impl PartialEq for Shape {
    fn eq(&self, other: &Shape) -> bool {
        trim_slice(&self.vec) == trim_slice(&other.vec)
    }
}

/// Comparing trimmed `usize` slices is reflexive, symmetric and transitive,
/// so the equality above is a full equivalence relation.
impl Eq for Shape {}
impl Shape {
    /// Number of dimensions once trailing `1`s are dropped.
    pub fn len(&self) -> usize {
        count_non_one(&self.vec)
    }

    /// Returns `true` when no significant dimension remains, i.e. every stored
    /// dimension is `1` or none is stored (`len() == 0`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the dimensions without trailing `1`s.
    pub fn as_slice(&self) -> &[usize] {
        trim_slice(&self.vec)
    }

    /// Mutably borrows the dimensions without trailing `1`s.
    pub fn as_mut_slice(&mut self) -> &mut [usize] {
        trim_mut_slice(&mut self.vec)
    }

    /// Iterates over the dimensions without trailing `1`s.
    pub fn iter(&self) -> Iter<'_, usize> {
        self.as_slice().iter()
    }

    /// Mutably iterates over the dimensions without trailing `1`s.
    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
        self.as_mut_slice().iter_mut()
    }

    /// Product of all dimensions; `1` for a shape with no significant dimension.
    pub fn content(&self) -> usize {
        self.iter().product()
    }
}
impl Index<usize> for Shape {
    type Output = usize;

    /// Returns dimension `i`. Any index past the stored data reads as an
    /// implicit `1`.
    fn index(&self, i: usize) -> &usize {
        self.vec.get(i).unwrap_or(&1)
    }
}
impl IndexMut<usize> for Shape {
    /// Mutably borrows dimension `i`.
    ///
    /// If `i` is not stored yet, storage is first padded with `1`s up to and
    /// including index `i`. The padding is invisible through `len`, `as_slice`
    /// and equality, because those ignore trailing `1`s.
    fn index_mut(&mut self, i: usize) -> &mut usize {
        if i < self.vec.len() {
            return &mut self.vec[i];
        }
        self.vec.resize(i + 1, 1);
        &mut self.vec[i]
    }
}
impl Shape {
    /// Returns a new `Shape` holding the dimensions selected by `range`.
    ///
    /// The range is applied to the trimmed dimensions (`as_slice()`). Bounds
    /// past the end are clamped to `len()`; those positions are implicit `1`s
    /// that would be trimmed anyway. A range whose start lies past its end
    /// yields an empty shape instead of panicking. Bounds at `usize::MAX`
    /// (e.g. `..=usize::MAX`) no longer overflow.
    pub fn slice<R: RangeBounds<usize>>(&self, range: R) -> Shape {
        let dims = self.as_slice();
        let len = dims.len();
        // `saturating_add` keeps `Excluded(usize::MAX)` / `Included(usize::MAX)`
        // from overflowing before the clamp to `len`.
        let start = cmp::min(
            match range.start_bound() {
                Bound::Included(&i) => i,
                Bound::Excluded(&i) => i.saturating_add(1),
                Bound::Unbounded => 0,
            },
            len,
        );
        let end = cmp::min(
            match range.end_bound() {
                Bound::Included(&i) => i.saturating_add(1),
                Bound::Excluded(&i) => i,
                Bound::Unbounded => len,
            },
            len,
        );
        // A reversed range selects nothing, whatever the value of `len`.
        let end = cmp::max(end, start);
        Self::from(&dims[start..end])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `From<&[usize]>` keeps the given dimensions verbatim.
    #[test]
    fn from() {
        let dims: &[usize] = &[1, 2, 3];
        let shape = Shape::from(dims);
        assert_eq!(shape.as_slice(), [1, 2, 3]);
    }

    /// The `shape!` macro builds the same dimensions.
    #[test]
    fn macro_() {
        let built = shape![1, 2, 3];
        assert_eq!(built.as_slice(), [1, 2, 3]);
    }

    /// Trailing `1`s count neither towards `len` nor towards equality.
    #[test]
    fn trim() {
        let padded = shape![1, 2, 1, 3, 1, 1];
        let expected = shape![1, 2, 1, 3];
        assert_eq!(padded.len(), 4);
        assert_eq!(padded, expected);
    }

    /// A single trailing `1` does not affect equality.
    #[test]
    fn eq() {
        let with_trailing_one = shape![1, 2, 1, 3, 1];
        let without = shape![1, 2, 1, 3];
        assert_eq!(with_trailing_one, without);
    }

    /// Indexing reads stored and implicit dimensions; writing past the end
    /// only changes the shape when the written value is not `1`.
    #[test]
    fn index() {
        let mut shape = shape![1, 2, 1, 3, 1];
        assert_eq!(shape[1], 2);
        assert_eq!(shape[5], 1);
        assert_eq!(shape, shape![1, 2, 1, 3]);

        shape[5] = 1;
        assert_eq!(shape, shape![1, 2, 1, 3]);

        shape[5] = 4;
        assert_eq!(shape, shape![1, 2, 1, 3, 1, 4]);
    }

    /// `iter` skips trailing `1`s and `iter_mut` writes through.
    #[test]
    fn iter() {
        let mut shape = shape![1, 2, 1];

        let mut dims = shape.iter();
        assert_eq!(dims.next(), Some(&1));
        assert_eq!(dims.next(), Some(&2));
        assert_eq!(dims.next(), None);

        let first = shape.iter_mut().next().unwrap();
        *first = 3;
        assert_eq!(shape, shape![3, 2, 1]);
    }

    /// Range slicing over the trimmed dimensions, with out-of-range bounds
    /// clamped.
    #[test]
    fn slice() {
        let shape = shape![1, 2, 1, 3, 1];
        assert_eq!(shape.slice(..), shape);
        assert_eq!(shape.slice(..3), shape![1, 2]);
        assert_eq!(shape.slice(1..=3), shape![2, 1, 3]);
        assert_eq!(shape.slice(2..5), shape![1, 3]);
        assert_eq!(shape.slice(5..10), shape![]);
        assert_eq!(shape.slice(5..), shape![]);
    }

    /// `content` is the product of the dimensions.
    #[test]
    fn content() {
        let shape = shape![1, 2, 3];
        assert_eq!(shape.content(), 6);
    }
}