use crate::IndexKind;
pub use crate::errors::BoundsError;
#[allow(unused_imports)]
use alloc::format;
#[allow(unused_imports)]
use alloc::string::{String, ToString};
use core::fmt::Debug;
pub use crate::tensor::index_conversion::AsIndex;
#[derive(Debug)]
pub struct IndexWrap {
kind: IndexKind,
wrap_scalar: bool,
}
impl IndexWrap {
    /// Creates a wrapper for element indices (`IndexKind::Element`) with scalar wrapping off.
    #[must_use]
    pub fn index() -> Self {
        Self {
            kind: IndexKind::Element,
            wrap_scalar: false,
        }
    }

    /// Creates a wrapper for dimension indices (`IndexKind::Dimension`) with scalar wrapping off.
    #[must_use]
    pub fn dim() -> Self {
        Self {
            kind: IndexKind::Dimension,
            wrap_scalar: false,
        }
    }

    /// Returns this wrapper with its `wrap_scalar` flag replaced.
    ///
    /// When the flag is set, a `size` of `0` is resolved as an axis of length 1 instead of
    /// being rejected.
    #[must_use]
    pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
        Self {
            wrap_scalar,
            ..self
        }
    }

    /// Resolves `idx` against an axis of length `size` using this wrapper's settings.
    ///
    /// Delegates to the free function `try_wrap`.
    ///
    /// # Errors
    ///
    /// Returns a `BoundsError` (tagged with this wrapper's kind) when `idx` does not resolve
    /// to a position in the axis.
    pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
        try_wrap(idx, size, self.kind, self.wrap_scalar)
    }

    /// Resolves `idx` against an axis of length `size`, panicking when it is out of range.
    ///
    /// # Panics
    ///
    /// Panics when `IndexWrap::try_wrap` would return an error.
    #[track_caller]
    pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
        expect_wrap(idx, size, self.kind, self.wrap_scalar)
    }

    /// Shorthand for `IndexWrap::index().expect_wrap(idx, size)`.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is out of range for an element axis of length `size`.
    #[track_caller]
    pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
        Self::index().expect_wrap(idx, size)
    }

    /// Shorthand for `IndexWrap::dim().expect_wrap(idx, size)`.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is out of range for a dimension count of `size`.
    #[track_caller]
    pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
        Self::dim().expect_wrap(idx, size)
    }
}
/// Resolves `idx` against an axis of length `size`, panicking if it is out of range.
///
/// This is the panicking counterpart of `try_wrap`; the wrapping rules are identical.
/// `#[track_caller]` makes the panic point at the code that supplied the bad index rather
/// than at this helper.
///
/// # Panics
///
/// Panics with a message beginning `valid index`, followed by the `Debug` form of the
/// `BoundsError`, whenever `try_wrap` returns an error.
#[track_caller]
pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
where
    I: AsIndex,
{
    try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
}
/// Resolves a possibly-negative index against an axis of length `size`.
///
/// Non-negative indices must already be `< size`. Negative indices count back from the end,
/// so `-1` maps to `size - 1` and `-size` maps to `0`. Anything further out is rejected
/// rather than wrapped a second time.
///
/// A `size` of `0` is rejected outright unless `wrap_scalar` is set. In that case the axis
/// is treated as having length 1, so only `0` and `-1` resolve, both to `0`.
///
/// # Errors
///
/// Returns a `BoundsError` tagged with `kind` and the original index when it does not
/// resolve. The reported bounds are `0..0` for a rejected empty axis, and otherwise
/// `0..size` (using the scalar-adjusted size).
pub fn try_wrap<I>(
    idx: I,
    size: usize,
    kind: IndexKind,
    wrap_scalar: bool,
) -> Result<usize, BoundsError>
where
    I: AsIndex,
{
    let requested = idx.as_index();

    let extent = match size {
        0 if wrap_scalar => 1,
        0 => return Err(BoundsError::index(kind, requested, 0..0)),
        len => len,
    };
    let upper = extent as isize;

    // Negative indices are offsets from the end of the axis.
    let resolved = if requested < 0 {
        requested + upper
    } else {
        requested
    };

    // `try_from` fails exactly when `resolved` is still negative.
    match usize::try_from(resolved) {
        Ok(position) if position < extent => Ok(position),
        _ => Err(BoundsError::index(kind, requested, 0..upper)),
    }
}
/// Wraps `idx` into the range `0..size` using a Euclidean (never negative) remainder.
///
/// Unlike `try_wrap`, every integer is accepted. Values outside `-size..size` are reduced
/// modulo `size`, e.g. `wrap_index(5, 3) == 2` and `wrap_index(-4, 3) == 2`.
///
/// Returns `0` when `size == 0`, since there is no range to wrap into.
#[inline]
#[must_use]
pub fn wrap_index<I>(idx: I, size: usize) -> usize
where
    I: AsIndex,
{
    if size == 0 {
        return 0;
    }
    // `rem_euclid` always yields a non-negative remainder, so the cast to `usize` cannot wrap
    // and no extra adjustment for negative results is needed.
    idx.as_index().rem_euclid(size as isize) as usize
}
/// Converts a multi-dimensional coordinate into a row-major (C-order) flat offset.
///
/// The last axis has stride 1, and each earlier axis has a stride equal to the product of the
/// extents after it. Each coordinate is resolved with `IndexWrap::index()`, so negative
/// coordinates count back from the end of their axis.
///
/// # Panics
///
/// Panics if `indices.len() != shape.len()`. Also panics if any coordinate is out of range
/// for its axis, including any axis of length 0.
pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
    assert_eq!(
        shape.len(),
        indices.len(),
        "Coordinate rank mismatch: expected {}, got {}",
        shape.len(),
        indices.len(),
    );
    // Walk axes from innermost to outermost, carrying (offset so far, current stride).
    let (offset, _) = shape.iter().zip(indices).rev().fold(
        (0_usize, 1_usize),
        |(offset, stride), (&extent, &coordinate)| {
            let position = IndexWrap::index().expect_wrap(coordinate, extent);
            (offset + position * stride, stride * extent)
        },
    );
    offset
}
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];
        // Row-major strides: each axis steps over the product of all later extents.
        let expected = 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4;
        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        assert_eq!(ravel_index(&[1, 2, 3, 4], &shape), expected);
    }

    #[test]
    fn test_wrap_idx() {
        // Every residue shifted by any multiple of the size must wrap back to itself.
        for (expected, residue) in (0_i32..3).enumerate() {
            for shift in [0, 3, 2 * 3, -3, -2 * 3] {
                assert_eq!(wrap_index(residue + shift, 3), expected);
            }
        }
    }

    #[test]
    fn test_negative_wrap() {
        // (input, resolved) pairs for an axis of length 3.
        let in_range = [(0, 0), (1, 1), (2, 2), (-1, 2), (-2, 1), (-3, 0)];
        for wrap in [IndexWrap::index(), IndexWrap::dim()] {
            for (idx, expected) in in_range {
                assert_eq!(wrap.expect_wrap(idx, 3), expected);
            }
        }

        let out_of_range = |kind: IndexKind, index| BoundsError::Index {
            kind,
            index,
            bounds: 0..3,
        };
        assert_eq!(
            IndexWrap::index().try_wrap(3, 3),
            Err(out_of_range(IndexKind::Element, 3))
        );
        assert_eq!(
            IndexWrap::index().try_wrap(-4, 3),
            Err(out_of_range(IndexKind::Element, -4))
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(3, 3),
            Err(out_of_range(IndexKind::Dimension, 3))
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(-4, 3),
            Err(out_of_range(IndexKind::Dimension, -4))
        );
    }

    #[test]
    fn test_negative_wrap_scalar() {
        let empty_axis = |index| BoundsError::Index {
            kind: IndexKind::Element,
            index,
            bounds: 0..0,
        };
        assert_eq!(IndexWrap::index().try_wrap(0, 0), Err(empty_axis(0)));

        let scalar = IndexWrap::index().with_wrap_scalar(true);
        assert_eq!(scalar.expect_wrap(0, 0), 0);
        assert_eq!(scalar.expect_wrap(-1, 0), 0);

        assert_eq!(
            IndexWrap::index().with_wrap_scalar(false).try_wrap(1, 0),
            Err(empty_axis(1))
        );
    }
}