use crate::Shape;
use crate::indexing::AsIndex;
use alloc::format;
use alloc::vec::Vec;
use core::fmt::{Display, Formatter};
use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use core::str::FromStr;
/// Conversion of a slice-like argument into slices for every dimension of a shape.
///
/// The implementations in this module produce one [`Slice`] per dimension of the
/// shape: dimensions not covered by the argument are taken in full, and each range
/// is resolved against its dimension size.
pub trait SliceArg {
    /// Converts `self` into the per-dimension slices to apply to `shape`.
    fn into_slices(self, shape: &Shape) -> Vec<Slice>;
}
impl<S: Into<Slice> + Clone> SliceArg for &[S] {
fn into_slices(self, shape: &Shape) -> Vec<Slice> {
assert!(
self.len() <= shape.num_dims(),
"Too many slices provided for shape, got {} but expected at most {}",
self.len(),
shape.num_dims()
);
shape
.iter()
.enumerate()
.map(|(i, dim_size)| {
let slice = if i >= self.len() {
Slice::full()
} else {
self[i].clone().into()
};
let clamped_range = slice.to_range(*dim_size);
Slice::new(
clamped_range.start as isize,
Some(clamped_range.end as isize),
slice.step(),
)
})
.collect::<Vec<_>>()
}
}
impl SliceArg for &Vec<Slice> {
    /// Delegates to the `&[Slice]` implementation (the vector derefs to a slice).
    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
        <&[Slice] as SliceArg>::into_slices(self, shape)
    }
}
impl<const R: usize, T> SliceArg for [T; R]
where
    T: Into<Slice> + Clone,
{
    /// Views the fixed-size array as a slice and delegates to the `&[T]` implementation.
    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
        let items: &[T] = &self;
        items.into_slices(shape)
    }
}
impl<T> SliceArg for T
where
    T: Into<Slice>,
{
    /// Treats a single slice-like value as a one-element list.
    ///
    /// It applies to the first dimension, and every other dimension is taken in full.
    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
        let single: Slice = self.into();
        core::slice::from_ref(&single).into_slices(shape)
    }
}
/// Builds slice specifications for indexing a tensor.
///
/// - `s![r]` expands to a single `Slice`, built with `Slice::from(r)`.
/// - `s![r; step]` expands to a single `Slice`, built with `Slice::from_range_stepped`.
/// - `s![a, b; step, ..]` (a comma-separated list) expands to an array `[Slice; N]`,
///   converting each element in the same way.
///
/// Every `step` is converted with `as isize`, so a step of any integer type is
/// accepted regardless of its position in the list.
///
/// An empty invocation, `s![]`, is a compile error.
#[macro_export]
macro_rules! s {
    [] => {
        compile_error!("Empty slice specification")
    };
    // `clippy::reversed_empty_ranges` is silenced so that reversed ranges written by
    // the caller (e.g. `5..2`) do not warn at the call site.
    [$range:expr; $step:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from_range_stepped($range, $step as isize)
            }
        }
    };
    [$range:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from($range)
            }
        }
    };
    [$range:expr; $step:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
            }
        }
    };
    [$range:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
            }
        }
    };
    // Internal accumulator: converts the remaining elements one at a time, then emits
    // the collected `Slice` expressions as an array.
    (@internal [$($acc:expr),*]) => {
        [$($acc),*]
    };
    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
    };
    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
    };
    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
    };
    (@internal [$($acc:expr),*] $range:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
    };
}
/// A slice specification for a single tensor dimension.
///
/// `start` and `end` may be negative. In that case [`Slice::to_range`] counts them
/// back from the end of the dimension, and it clamps both to the dimension size.
/// `end` is exclusive, and `None` means "through the end of the dimension". A
/// negative `step` marks the slice as reversed (see [`Slice::is_reversed`]).
///
/// The constructors reject a zero `step`. However, the fields are public, so that
/// invariant is not enforced for values built directly.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Slice {
    /// First index; may be negative (counted from the end of the dimension).
    pub start: isize,
    /// Exclusive end index; may be negative. `None` means unbounded.
    pub end: Option<isize>,
    /// Distance between selected indices; non-zero when built via the constructors.
    pub step: isize,
}
/// Iterator over the raw indices of a [`Slice`], produced by `Slice::into_iter`.
///
/// Indices are yielded as-is: they are not resolved against any dimension size,
/// so negative values can appear. The iterator is unbounded when the slice has
/// no `end`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SliceIter {
    // The slice being iterated; supplies the `step` and the `end` bound.
    slice: Slice,
    // The next index to yield.
    current: isize,
}
impl Iterator for SliceIter {
    type Item = isize;

    /// Yields the current index, then advances the cursor by `step`.
    ///
    /// Iteration stops once the index reaches `end` (exclusive) in the direction of
    /// travel: `index >= end` for positive steps, `index <= end` for negative steps.
    /// Without an `end`, iteration never stops on its own.
    fn next(&mut self) -> Option<Self::Item> {
        let candidate = self.current;
        // The cursor advances even when `candidate` turns out to be past the end.
        self.current += self.slice.step;

        let past_end = match self.slice.end {
            None => false,
            Some(end) if self.slice.is_reversed() => candidate <= end,
            Some(end) => candidate >= end,
        };

        if past_end { None } else { Some(candidate) }
    }
}
impl IntoIterator for Slice {
    type Item = isize;
    type IntoIter = SliceIter;

    /// Starts iterating at `start`; see `SliceIter` for the stopping rule.
    fn into_iter(self) -> Self::IntoIter {
        let current = self.start;
        SliceIter {
            current,
            slice: self,
        }
    }
}
impl Default for Slice {
fn default() -> Self {
Self::full()
}
}
impl Slice {
    /// Creates a slice from its raw components.
    ///
    /// [`Slice::to_range`] resolves `start` and `end` against a dimension size.
    /// Negative values count back from the end, and `end == None` means "up to the
    /// end of the dimension".
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero.
    pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
        assert!(step != 0, "Step cannot be zero");
        Self { start, end, step }
    }

    /// The slice selecting a whole dimension (`..`) with a unit step.
    pub const fn full() -> Self {
        Self {
            start: 0,
            end: None,
            step: 1,
        }
    }

    /// A unit-step slice selecting the single index `idx`; `-1` selects the last element.
    pub fn index(idx: isize) -> Self {
        let end = handle_signed_inclusive_end(idx);
        Self {
            start: idx,
            end,
            step: 1,
        }
    }

    /// Collects the raw indices produced by iterating this slice.
    ///
    /// # Panics
    ///
    /// Panics if `end` is `None`, because the iteration would never terminate.
    pub fn into_vec(self) -> Vec<isize> {
        assert!(
            self.end.is_some(),
            "Slice must have an end to convert to a vector: {self:?}"
        );
        self.into_iter().collect::<Vec<_>>()
    }

    /// Replaces the end with a concrete value bounded by `size`.
    ///
    /// - Positive ends are capped at `size`.
    /// - Zero and negative ends are floored at `-(size + 1)`.
    /// - A missing end becomes `size` for forward steps and `-(size + 1)` for
    ///   reversed steps.
    pub fn bound_to(self, size: usize) -> Self {
        let limit = size as isize;
        let end = match self.end {
            Some(end) if end > 0 => end.min(limit),
            Some(end) => end.max(-(limit + 1)),
            None if self.is_reversed() => -(limit + 1),
            None => limit,
        };
        Self {
            end: Some(end),
            ..self
        }
    }

    /// Same as [`Slice::new`], but not usable in const contexts.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero.
    pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
        Self::new(start, end, step)
    }

    /// Converts `range` into a slice, then replaces its step with `step`.
    ///
    /// # Panics
    ///
    /// Panics if `step` is zero.
    pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
        assert!(step != 0, "Step cannot be zero");
        Self {
            step,
            ..range.into()
        }
    }

    /// The distance between selected indices.
    pub fn step(&self) -> isize {
        self.step
    }

    /// Alias of [`Slice::to_range`].
    pub fn range(&self, size: usize) -> Range<usize> {
        self.to_range(size)
    }

    /// Resolves `start`/`end` against a dimension of `size` elements.
    ///
    /// Negative values count from the end, and both bounds are clamped to
    /// `0..=size`. A missing end becomes `size`. The step is ignored, and the
    /// result may be empty (`start >= end`).
    pub fn to_range(&self, size: usize) -> Range<usize> {
        let end = self
            .end
            .map_or(size, |end| convert_signed_index(end, size));
        convert_signed_index(self.start, size)..end
    }

    /// [`Slice::to_range`] together with the step.
    pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
        (self.to_range(size), self.step)
    }

    /// Whether the step is negative.
    pub fn is_reversed(&self) -> bool {
        self.step.is_negative()
    }

    /// Number of elements selected along a dimension of `dim_size` elements.
    ///
    /// Computed as the resolved range's length divided by `|step|`, rounded up.
    /// It is zero when the resolved range is empty.
    pub fn output_size(&self, dim_size: usize) -> usize {
        let Range { start, end } = self.to_range(dim_size);
        if start >= end {
            return 0;
        }
        // `div_ceil(1)` is the identity, so unit steps need no special case.
        (end - start).div_ceil(self.step.unsigned_abs())
    }
}
/// Resolves a possibly-negative index against a dimension of `size` elements.
///
/// Non-negative indices are capped at `size`. Negative indices count back from
/// `size` and are floored at `0`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    match usize::try_from(index) {
        Ok(forward) => forward.min(size),
        Err(_) => (size as isize + index).max(0) as usize,
    }
}
/// Converts an inclusive end index into the exclusive `end` stored in a `Slice`.
///
/// - `-1` (the last element) maps to `None`, i.e. "through the end of the
///   dimension". `-1 + 1 == 0` would instead resolve to index 0 and produce an
///   empty range.
/// - `isize::MAX` has no representable successor. It also maps to `None`, because
///   `to_range` clamps both `Some(isize::MAX)` and `None` to the dimension size.
///   This replaces an arithmetic overflow (a panic in debug builds).
/// - Every other value `e` maps to `Some(e + 1)`.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        end => end.checked_add(1),
    }
}
impl<I: AsIndex> From<Range<I>> for Slice {
    /// `a..b` selects from `a` up to, but excluding, `b` with a unit step.
    fn from(range: Range<I>) -> Self {
        let Range { start, end } = range;
        Self::new(start.as_index(), Some(end.as_index()), 1)
    }
}
impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
    /// `a..=b` selects from `a` through `b` with a unit step; `b == -1` means "through the end".
    fn from(range: RangeInclusive<I>) -> Self {
        let (first, last) = range.into_inner();
        Self::new(
            first.as_index(),
            handle_signed_inclusive_end(last.as_index()),
            1,
        )
    }
}
impl<I: AsIndex> From<RangeFrom<I>> for Slice {
    /// `a..` selects from `a` through the end of the dimension.
    fn from(range: RangeFrom<I>) -> Self {
        Self {
            start: range.start.as_index(),
            ..Self::full()
        }
    }
}
impl<I: AsIndex> From<RangeTo<I>> for Slice {
    /// `..b` selects from index 0 up to, but excluding, `b`.
    fn from(range: RangeTo<I>) -> Self {
        Self {
            end: Some(range.end.as_index()),
            ..Self::full()
        }
    }
}
impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
    /// `..=b` selects from index 0 through `b`; `b == -1` means "through the end".
    fn from(range: RangeToInclusive<I>) -> Self {
        Self {
            end: handle_signed_inclusive_end(range.end.as_index()),
            ..Self::full()
        }
    }
}
impl From<RangeFull> for Slice {
    /// `..` selects the whole dimension.
    fn from(_: RangeFull) -> Self {
        Self::full()
    }
}
impl From<usize> for Slice {
fn from(i: usize) -> Self {
Slice::index(i as isize)
}
}
impl From<isize> for Slice {
fn from(i: isize) -> Self {
Slice::index(i)
}
}
impl From<i32> for Slice {
fn from(i: i32) -> Self {
Slice::index(i as isize)
}
}
impl From<i64> for Slice {
fn from(i: i64) -> Self {
Slice::index(i as isize)
}
}
impl Display for Slice {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
if self.step == 1
&& let Some(end) = self.end
&& self.start == end - 1
{
f.write_fmt(format_args!("{}", self.start))
} else {
if self.start != 0 {
f.write_fmt(format_args!("{}", self.start))?;
}
f.write_str("..")?;
if let Some(end) = self.end {
f.write_fmt(format_args!("{}", end))?;
}
if self.step != 1 {
f.write_fmt(format_args!(";{}", self.step))?;
}
Ok(())
}
}
}
impl FromStr for Slice {
type Err = crate::ExpressionError;
fn from_str(source: &str) -> Result<Self, Self::Err> {
let mut s = source.trim();
let parse_int = |v: &str| -> Result<isize, Self::Err> {
v.parse::<isize>().map_err(|e| {
crate::ExpressionError::parse_error(
format!("Invalid integer: '{v}': {}", e),
source,
)
})
};
let mut start: isize = 0;
let mut end: Option<isize> = None;
let mut step: isize = 1;
if let Some((head, tail)) = s.split_once(";") {
step = parse_int(tail)?;
s = head;
}
if s.is_empty() {
return Err(crate::ExpressionError::parse_error(
"Empty expression",
source,
));
}
if let Some((start_s, end_s)) = s.split_once("..") {
if !start_s.is_empty() {
start = parse_int(start_s)?;
}
if !end_s.is_empty() {
if let Some(end_s) = end_s.strip_prefix('=') {
end = Some(parse_int(end_s)? + 1);
} else {
end = Some(parse_int(end_s)?);
}
}
} else {
start = parse_int(s)?;
end = Some(start + 1);
}
if step == 0 {
return Err(crate::ExpressionError::invalid_expression(
"Step cannot be zero",
source,
));
}
Ok(Slice::new(start, end, step))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn test_slice_to_str() {
        let cases = [
            (Slice::new(0, None, 1), ".."),
            (Slice::new(0, Some(1), 1), "0"),
            (Slice::new(0, Some(10), 1), "..10"),
            (Slice::new(1, Some(10), 1), "1..10"),
            (Slice::new(-3, Some(10), -2), "-3..10;-2"),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.to_string(), expected, "display of {slice:?}");
        }
    }

    #[test]
    fn test_slice_from_str() {
        let valid = [
            ("1", Slice::new(1, Some(2), 1)),
            ("..", Slice::new(0, None, 1)),
            ("..3", Slice::new(0, Some(3), 1)),
            ("..=3", Slice::new(0, Some(4), 1)),
            ("-12..3", Slice::new(-12, Some(3), 1)),
            ("..;-1", Slice::new(0, None, -1)),
            ("..=3;-2", Slice::new(0, Some(4), -2)),
        ];
        for (text, expected) in valid {
            assert_eq!(text.parse::<Slice>(), Ok(expected), "parsing {text:?}");
        }

        assert_eq!(
            "..;0".parse::<Slice>(),
            Err(crate::ExpressionError::invalid_expression(
                "Step cannot be zero",
                "..;0"
            ))
        );
        assert_eq!(
            "".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error("Empty expression", ""))
        );

        let bad_integers = [
            ("a", "Invalid integer: 'a': invalid digit found in string"),
            ("..a", "Invalid integer: 'a': invalid digit found in string"),
            (
                "a:b:c",
                "Invalid integer: 'a:b:c': invalid digit found in string",
            ),
        ];
        for (text, message) in bad_integers {
            assert_eq!(
                text.parse::<Slice>(),
                Err(crate::ExpressionError::parse_error(message, text))
            );
        }
    }

    #[test]
    fn test_slice_output_size() {
        // (slice, expected number of selected elements in a dimension of size 10)
        let cases = [
            (Slice::new(0, Some(10), 1), 10),
            (Slice::new(0, Some(10), 2), 5),
            (Slice::new(0, Some(10), 3), 4),
            (Slice::new(0, Some(10), -1), 10),
            (Slice::new(0, Some(10), -2), 5),
            (Slice::new(2, Some(8), -3), 2),
            (Slice::new(5, Some(5), 1), 0),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.output_size(10), expected, "output size of {slice:?}");
        }
    }

    #[test]
    fn test_bound_to() {
        let cases = [
            (Slice::new(0, None, 1), Slice::new(0, Some(10), 1)),
            (Slice::new(0, Some(5), 1), Slice::new(0, Some(5), 1)),
            (Slice::new(0, None, -1), Slice::new(0, Some(-11), -1)),
            (Slice::new(0, Some(-5), -1), Slice::new(0, Some(-5), -1)),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.bound_to(10), expected, "bounding {slice:?}");
        }
    }

    #[test]
    fn test_slice_iter() {
        let single: Vec<_> = Slice::new(2, Some(3), 1).into_iter().collect();
        assert_eq!(single, vec![2]);

        let reversed = Slice::new(3, Some(-1), -1);
        assert_eq!(reversed.into_iter().collect::<Vec<_>>(), vec![3, 2, 1, 0]);
        assert_eq!(reversed.into_vec(), vec![3, 2, 1, 0]);

        let unbounded = Slice::new(3, None, 2);
        assert_eq!(
            unbounded.into_iter().take(3).collect::<Vec<_>>(),
            vec![3, 5, 7]
        );
        assert_eq!(
            unbounded.bound_to(8).into_iter().collect::<Vec<_>>(),
            vec![3, 5, 7]
        );
    }

    #[test]
    #[should_panic(
        expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }"
    )]
    fn test_unbound_slice_into_vec() {
        Slice::new(0, None, 1).into_vec();
    }

    #[test]
    fn into_slices_should_return_for_all_shape_dims() {
        let shape = Shape::new([2, 3, 1]);

        let from_index = s![1].into_slices(&shape);
        assert_eq!(from_index.len(), shape.len());
        assert_eq!(
            from_index,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(3), 1),
                Slice::new(0, Some(1), 1),
            ]
        );

        let from_pair = s![1, 0..2].into_slices(&shape);
        assert_eq!(from_pair.len(), shape.len());
        assert_eq!(
            from_pair,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );

        let from_full = s![..].into_slices(&shape);
        assert_eq!(from_full.len(), shape.len());
        assert_eq!(
            from_full,
            vec![
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(3), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    fn into_slices_all_dimensions() {
        let shape = Shape::new([2, 3, 1]);
        let slices = s![1, ..2, ..].into_slices(&shape);
        assert_eq!(slices.len(), shape.len());
        assert_eq!(
            slices,
            vec![
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    fn into_slices_supports_empty_dimensions() {
        let shape = Shape::new([0, 3, 1]);
        let slices = s![.., 1, ..].into_slices(&shape);
        assert_eq!(slices.len(), shape.len());
        assert_eq!(
            slices,
            vec![
                Slice::new(0, Some(0), 1),
                Slice::new(1, Some(2), 1),
                Slice::new(0, Some(1), 1),
            ]
        );
    }

    #[test]
    #[should_panic = "Too many slices provided for shape"]
    fn into_slices_should_match_shape_rank() {
        let shape = Shape::new([3, 1]);
        let _ = s![.., 1, ..].into_slices(&shape);
    }

    #[test]
    fn should_support_const_and_full() {
        static SLICES: [Slice; 2] = [Slice::full(), Slice::new(2, None, 1)];
        assert_eq!(SLICES, [Slice::new(0, None, 1), Slice::new(2, None, 1)]);
    }

    #[test]
    fn should_support_default() {
        assert_eq!(Slice::default(), Slice::new(0, None, 1));
    }

    #[test]
    fn should_support_copy() {
        let original = Slice::new(1, Some(3), 2);
        let mut modified = original;
        modified.end = Some(4);
        assert_eq!(modified, Slice::new(1, Some(4), 2));
        assert_eq!(original, Slice::new(1, Some(3), 2));
    }
}