use crate::{Tensor, ZyxError, tensor::Axis};
use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo};
/// Placeholder [`std::ops::Index`] implementation whose only job is to give a helpful
/// error when someone writes `tensor[...]`.
///
/// `Index::index` has to return a reference (`&Self::Output`), whereas slicing a tensor
/// yields a new `Tensor` value (see `Tensor::slice`), so `[]` indexing cannot be supported.
///
/// # Panics
///
/// `index` always panics, with a message that points the caller to `.slice(...)`.
impl<I> std::ops::Index<I> for Tensor {
    type Output = Tensor;

    fn index(&self, _index: I) -> &Self::Output {
        panic!(
            "Tensor does not support indexing with `[]` because rust only allows indexing on reference types. \
            Use `.slice(...)` instead, which supports ranges, integers, and tuples. \
            Example: tensor.slice((0..3, -1))"
        );
    }
}
impl Tensor {
pub fn slice(&self, index: impl IntoIndex) -> Result<Tensor, ZyxError> {
let shape = self.shape();
let rank = shape.len();
let mut squeeze_axes: Vec<Axis> = Vec::new();
let index = index.into_index();
let padding_len = index.len();
if rank < padding_len {
return Err(ZyxError::shape_error(
format!("Slice with {padding_len} indices, but tensor has rank {rank}").into(),
));
}
let padding = index
.zip(shape)
.enumerate()
.map(|(axis, (dim_index, dim_size))| {
let dim_size = dim_size as i64;
match dim_index {
DimIndex::Range { start, end } => {
let s = if start < 0 { (start + dim_size).max(0) } else { start };
let s = s.min(dim_size);
let e = if end > dim_size {
dim_size
} else if end < 0 {
(end + dim_size).max(0)
} else {
end
};
let e = e.min(dim_size).max(0);
if e < s {
return Err(ZyxError::shape_error(
format!("Slice range end {e} is less than start {s} for dimension {axis}").into(),
));
}
Ok((-(s as i64), -((dim_size as i64) - e as i64)))
}
DimIndex::Index(i) => {
squeeze_axes.push(axis as i32);
let i = if i < 0 { i + dim_size } else { i };
if i < 0 || i >= dim_size {
return Err(ZyxError::shape_error(
format!("Index {i} out of bounds for dimension {axis} of size {dim_size}").into(),
));
}
Ok((-(i as i64), -((dim_size as i64) - i as i64 - 1)))
}
DimIndex::RangeFull => Ok((0i64, 0i64)),
DimIndex::RangeFrom { start } => {
let s = if start < 0 { (start + dim_size).max(0) } else { start };
let s = s.min(dim_size);
Ok((-(s as i64), 0i64))
}
DimIndex::RangeTo { end } => {
let e = if end > dim_size {
dim_size
} else if end < 0 {
(end + dim_size).max(0)
} else {
end
};
let e = e.min(dim_size).max(0);
Ok((0i64, -((dim_size as i64) - e as i64)))
}
}
})
.collect::<Result<Vec<_>, _>>()?;
let mut result = self.pad_zeros(padding)?;
result = result.squeeze(squeeze_axes);
Ok(result)
}
#[allow(clippy::missing_panics_doc)]
pub fn rslice(&self, index: impl IntoIndex) -> Result<Tensor, ZyxError> {
let shape = self.shape();
let rank = shape.len();
let mut squeeze_axes: Vec<Axis> = Vec::new();
let index = index.into_index();
let padding_len = index.len();
if padding_len > rank {
return Err(ZyxError::shape_error(
format!("Index length {padding_len} > rank {rank}").into(),
));
}
let padding = index
.zip(shape.into_iter().rev())
.enumerate()
.map(|(axis, (dim_index, dim_size))| {
let dim_size = dim_size as i64;
match dim_index {
DimIndex::Range { start, end } => {
let s = if start < 0 { (start + dim_size).max(0) } else { start };
let s = s.min(dim_size);
let e = if end > dim_size {
dim_size
} else if end < 0 {
(end + dim_size).max(0)
} else {
end
};
let e = e.min(dim_size).max(0);
if e < s {
return Err(ZyxError::shape_error(
format!("Slice range end {e} is less than start {s} for dimension {axis}").into(),
));
}
Ok((-(s as i64), -((dim_size as i64) - e as i64)))
}
DimIndex::Index(i) => {
squeeze_axes.push(axis as i32);
let i = if i < 0 { i + dim_size } else { i };
if i < 0 || i >= dim_size {
return Err(ZyxError::shape_error(
format!("Index {i} out of bounds for dimension {axis} of size {dim_size}").into(),
));
}
Ok((-(i as i64), -((dim_size as i64) - i as i64 - 1)))
}
DimIndex::RangeFull => Ok((0i64, 0i64)),
DimIndex::RangeFrom { start } => {
let s = if start < 0 { (start + dim_size).max(0) } else { start };
let s = s.min(dim_size);
Ok((-(s as i64), 0i64))
}
DimIndex::RangeTo { end } => {
let e = if end > dim_size {
dim_size
} else if end < 0 {
(end + dim_size).max(0)
} else {
end
};
let e = e.min(dim_size).max(0);
Ok((0i64, -((dim_size as i64) - e as i64)))
}
}
})
.collect::<Result<Vec<_>, _>>()?;
let padding = padding
.into_iter()
.chain(std::iter::repeat_n((0i64, 0i64), rank - padding_len));
let mut padding_vec: Vec<(i64, i64)> = padding.into_iter().collect();
padding_vec.reverse();
let mut result = self.pad_zeros(padding_vec)?;
result = result.squeeze(squeeze_axes);
Ok(result)
}
#[allow(clippy::missing_panics_doc)]
#[must_use]
pub fn diagonal(&self) -> Tensor {
let n = *self.shape().last().expect("Shape in invalid state. Internal bug.");
self.flatten(..)
.unwrap()
.rpad_zeros([(0i64, i64::try_from(n).unwrap())])
.unwrap()
.reshape([n, n + 1])
.unwrap()
.slice((.., 0))
.unwrap()
.flatten(..)
.unwrap()
}
}
/// Selection applied to a single tensor dimension when slicing.
///
/// Negative values count from the end of the dimension.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DimIndex {
    /// A single position; the dimension is removed from the sliced result.
    Index(i64),
    /// Half-open span `start..end`.
    Range { start: i64, end: i64 },
    /// Everything from `start` to the end of the dimension.
    RangeFrom { start: i64 },
    /// Everything from the beginning of the dimension up to, but excluding, `end`.
    RangeTo { end: i64 },
    /// The whole dimension.
    RangeFull,
}
/// Conversion of a value into a sequence of per-dimension [`DimIndex`] entries.
///
/// Implemented in this file for single index values, arrays, slices, vectors and tuples.
pub trait IntoIndex {
    /// Consumes `self` and yields one [`DimIndex`] per indexed dimension.
    ///
    /// The iterator reports its exact length and can be traversed from both ends.
    fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator;
}
impl From<i64> for DimIndex {
fn from(val: i64) -> DimIndex {
DimIndex::Index(val)
}
}
impl From<i32> for DimIndex {
fn from(val: i32) -> DimIndex {
DimIndex::Index(i64::from(val))
}
}
/// A bare `usize` is an integer index into a single dimension.
///
/// Values above `i64::MAX` saturate to `i64::MAX`. A plain `as` cast would wrap them to
/// negative numbers, which this file's slicing code treats as counting from the end.
impl From<usize> for DimIndex {
    fn from(val: usize) -> DimIndex {
        DimIndex::Index(i64::try_from(val).unwrap_or(i64::MAX))
    }
}
/// A bare `u64` is an integer index into a single dimension.
///
/// Values above `i64::MAX` saturate to `i64::MAX`. A plain `as` cast would wrap them to
/// negative numbers, which this file's slicing code treats as counting from the end.
impl From<u64> for DimIndex {
    fn from(val: u64) -> DimIndex {
        DimIndex::Index(i64::try_from(val).unwrap_or(i64::MAX))
    }
}
impl From<Range<i64>> for DimIndex {
fn from(val: Range<i64>) -> DimIndex {
DimIndex::Range { start: val.start, end: val.end }
}
}
impl From<Range<i32>> for DimIndex {
fn from(val: Range<i32>) -> DimIndex {
DimIndex::Range { start: i64::from(val.start), end: i64::from(val.end) }
}
}
/// `start..end` over `usize` becomes a [`DimIndex::Range`].
///
/// Bounds above `i64::MAX` saturate to `i64::MAX`. A plain `as` cast would wrap them to
/// negative numbers, which this file's slicing code treats as counting from the end.
impl From<Range<usize>> for DimIndex {
    fn from(val: Range<usize>) -> DimIndex {
        let saturate = |bound: usize| i64::try_from(bound).unwrap_or(i64::MAX);
        DimIndex::Range { start: saturate(val.start), end: saturate(val.end) }
    }
}
/// `start..end` over `u64` becomes a [`DimIndex::Range`].
///
/// Bounds above `i64::MAX` saturate to `i64::MAX`. A plain `as` cast would wrap them to
/// negative numbers, which this file's slicing code treats as counting from the end.
impl From<Range<u64>> for DimIndex {
    fn from(val: Range<u64>) -> DimIndex {
        let saturate = |bound: u64| i64::try_from(bound).unwrap_or(i64::MAX);
        DimIndex::Range { start: saturate(val.start), end: saturate(val.end) }
    }
}
/// `start..=end` over `i64` becomes the equivalent half-open selection.
///
/// Two cases need care:
/// * An inclusive end of `-1` means "through the last element". Adding one would give an
///   exclusive end of `0`, which denotes the start of the dimension, so this case is
///   expressed as [`DimIndex::RangeFrom`] instead.
/// * `end + 1` is computed with saturation so that `..=i64::MAX` cannot overflow.
impl From<RangeInclusive<i64>> for DimIndex {
    fn from(val: RangeInclusive<i64>) -> DimIndex {
        let (start, end) = val.into_inner();
        if end == -1 {
            DimIndex::RangeFrom { start }
        } else {
            DimIndex::Range { start, end: end.saturating_add(1) }
        }
    }
}
/// `start..=end` over `i32` becomes the equivalent half-open selection.
///
/// An inclusive end of `-1` means "through the last element". Adding one would give an
/// exclusive end of `0`, which denotes the start of the dimension, so this case is
/// expressed as [`DimIndex::RangeFrom`] instead.
impl From<RangeInclusive<i32>> for DimIndex {
    fn from(val: RangeInclusive<i32>) -> DimIndex {
        let (start, end) = val.into_inner();
        let start = i64::from(start);
        if end == -1 {
            DimIndex::RangeFrom { start }
        } else {
            // Widening first means `+ 1` cannot overflow, even for `i32::MAX`.
            DimIndex::Range { start, end: i64::from(end) + 1 }
        }
    }
}
/// `start..=end` over `usize` becomes the half-open [`DimIndex::Range`] `start..end + 1`.
///
/// Bounds above `i64::MAX` saturate instead of wrapping to negative numbers, and the
/// `+ 1` is saturating so it cannot overflow.
impl From<RangeInclusive<usize>> for DimIndex {
    fn from(val: RangeInclusive<usize>) -> DimIndex {
        let saturate = |bound: usize| i64::try_from(bound).unwrap_or(i64::MAX);
        let (start, end) = val.into_inner();
        DimIndex::Range { start: saturate(start), end: saturate(end).saturating_add(1) }
    }
}
/// `start..=end` over `u64` becomes the half-open [`DimIndex::Range`] `start..end + 1`.
///
/// Bounds above `i64::MAX` saturate instead of wrapping to negative numbers, and the
/// `+ 1` is saturating so it cannot overflow.
impl From<RangeInclusive<u64>> for DimIndex {
    fn from(val: RangeInclusive<u64>) -> DimIndex {
        let saturate = |bound: u64| i64::try_from(bound).unwrap_or(i64::MAX);
        let (start, end) = val.into_inner();
        DimIndex::Range { start: saturate(start), end: saturate(end).saturating_add(1) }
    }
}
impl From<RangeFrom<i64>> for DimIndex {
fn from(val: RangeFrom<i64>) -> DimIndex {
DimIndex::RangeFrom { start: val.start }
}
}
impl From<RangeFrom<i32>> for DimIndex {
fn from(val: RangeFrom<i32>) -> DimIndex {
DimIndex::RangeFrom { start: i64::from(val.start) }
}
}
/// `start..` over `usize` becomes a [`DimIndex::RangeFrom`].
///
/// A start above `i64::MAX` saturates to `i64::MAX`. A plain `as` cast would wrap it to
/// a negative number, which this file's slicing code treats as counting from the end.
impl From<RangeFrom<usize>> for DimIndex {
    fn from(val: RangeFrom<usize>) -> DimIndex {
        DimIndex::RangeFrom { start: i64::try_from(val.start).unwrap_or(i64::MAX) }
    }
}
/// `start..` over `u64` becomes a [`DimIndex::RangeFrom`].
///
/// A start above `i64::MAX` saturates to `i64::MAX`. A plain `as` cast would wrap it to
/// a negative number, which this file's slicing code treats as counting from the end.
impl From<RangeFrom<u64>> for DimIndex {
    fn from(val: RangeFrom<u64>) -> DimIndex {
        DimIndex::RangeFrom { start: i64::try_from(val.start).unwrap_or(i64::MAX) }
    }
}
impl From<RangeTo<i64>> for DimIndex {
fn from(val: RangeTo<i64>) -> DimIndex {
DimIndex::RangeTo { end: val.end }
}
}
impl From<RangeTo<i32>> for DimIndex {
fn from(val: RangeTo<i32>) -> DimIndex {
DimIndex::RangeTo { end: i64::from(val.end) }
}
}
/// `..end` over `usize` becomes a [`DimIndex::RangeTo`].
///
/// An end above `i64::MAX` saturates to `i64::MAX`. A plain `as` cast would wrap it to
/// a negative number, which this file's slicing code treats as counting from the end.
impl From<RangeTo<usize>> for DimIndex {
    fn from(val: RangeTo<usize>) -> DimIndex {
        DimIndex::RangeTo { end: i64::try_from(val.end).unwrap_or(i64::MAX) }
    }
}
/// `..end` over `u64` becomes a [`DimIndex::RangeTo`].
///
/// An end above `i64::MAX` saturates to `i64::MAX`. A plain `as` cast would wrap it to
/// a negative number, which this file's slicing code treats as counting from the end.
impl From<RangeTo<u64>> for DimIndex {
    fn from(val: RangeTo<u64>) -> DimIndex {
        DimIndex::RangeTo { end: i64::try_from(val.end).unwrap_or(i64::MAX) }
    }
}
impl From<RangeFull> for DimIndex {
fn from(_val: RangeFull) -> DimIndex {
DimIndex::RangeFull
}
}
/// Any single value convertible into a [`DimIndex`] is an index with exactly one entry.
impl<I: Into<DimIndex>> IntoIndex for I {
    fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
        let only_entry: DimIndex = self.into();
        std::iter::once(only_entry)
    }
}
/// A fixed-size array supplies one entry per element, in order.
impl<I: Into<DimIndex>, const N: usize> IntoIndex for [I; N] {
    fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
        let entries = self.into_iter();
        entries.map(|entry| entry.into())
    }
}
/// A borrowed slice supplies one entry per element, in order; each element is cloned
/// before it is converted.
impl<I: Into<DimIndex> + Clone> IntoIndex for &[I] {
    fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
        self.iter().cloned().map(|entry| entry.into())
    }
}
/// An owned vector supplies one entry per element, in order.
impl<I: Into<DimIndex>> IntoIndex for Vec<I> {
    fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
        let entries = self.into_iter();
        entries.map(|entry| entry.into())
    }
}
impl<I0: Into<DimIndex>, I1: Into<DimIndex>> IntoIndex for (I0, I1) {
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[self.0.into(), self.1.into()].into_iter()
}
}
impl<I0: Into<DimIndex>, I1: Into<DimIndex>, I2: Into<DimIndex>> IntoIndex for (I0, I1, I2) {
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[self.0.into(), self.1.into(), self.2.into()].into_iter()
}
}
impl<I0: Into<DimIndex>, I1: Into<DimIndex>, I2: Into<DimIndex>, I3: Into<DimIndex>> IntoIndex for (I0, I1, I2, I3) {
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[self.0.into(), self.1.into(), self.2.into(), self.3.into()].into_iter()
}
}
impl<I0: Into<DimIndex>, I1: Into<DimIndex>, I2: Into<DimIndex>, I3: Into<DimIndex>, I4: Into<DimIndex>> IntoIndex
for (I0, I1, I2, I3, I4)
{
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[self.0.into(), self.1.into(), self.2.into(), self.3.into(), self.4.into()].into_iter()
}
}
impl<I0: Into<DimIndex>, I1: Into<DimIndex>, I2: Into<DimIndex>, I3: Into<DimIndex>, I4: Into<DimIndex>, I5: Into<DimIndex>>
IntoIndex for (I0, I1, I2, I3, I4, I5)
{
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[
self.0.into(),
self.1.into(),
self.2.into(),
self.3.into(),
self.4.into(),
self.5.into(),
]
.into_iter()
}
}
impl<
I0: Into<DimIndex>,
I1: Into<DimIndex>,
I2: Into<DimIndex>,
I3: Into<DimIndex>,
I4: Into<DimIndex>,
I5: Into<DimIndex>,
I6: Into<DimIndex>,
> IntoIndex for (I0, I1, I2, I3, I4, I5, I6)
{
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[
self.0.into(),
self.1.into(),
self.2.into(),
self.3.into(),
self.4.into(),
self.5.into(),
self.6.into(),
]
.into_iter()
}
}
impl<
I0: Into<DimIndex>,
I1: Into<DimIndex>,
I2: Into<DimIndex>,
I3: Into<DimIndex>,
I4: Into<DimIndex>,
I5: Into<DimIndex>,
I6: Into<DimIndex>,
I7: Into<DimIndex>,
> IntoIndex for (I0, I1, I2, I3, I4, I5, I6, I7)
{
fn into_index(self) -> impl ExactSizeIterator<Item = DimIndex> + DoubleEndedIterator {
[
self.0.into(),
self.1.into(),
self.2.into(),
self.3.into(),
self.4.into(),
self.5.into(),
self.6.into(),
self.7.into(),
]
.into_iter()
}
}