use num_traits::AsPrimitive;
use crate::error::{NiftiError, Result};
use crate::util::{validate_dim, validate_dimensionality};
/// A multi-dimensional index stored in a fixed 8-element `u16` array.
///
/// The first element is the rank (number of axes in use) and the elements
/// that follow hold one value per axis; see [`Idx::rank`] and the
/// `AsRef<[u16]>` implementation, which exposes only the used axes.
#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct Idx(
/// Raw storage: rank at position 0, per-axis values starting at position 1.
[u16; 8],
);
impl Idx {
    /// Builds an index from a raw 8-element array (rank first, then the
    /// per-axis values), after checking it with `validate_dimensionality`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_dimensionality` reports for `idx`.
    pub fn new(idx: [u16; 8]) -> Result<Self> {
        validate_dimensionality(&idx)?;
        Ok(Idx(idx))
    }

    /// Wraps a raw 8-element array without any validation.
    ///
    /// # Safety
    ///
    /// The array is stored as-is. Callers should only pass an array that
    /// [`Idx::new`] would accept; in particular, the slice views of this type
    /// index the storage with `1..=idx[0]`, which is out of bounds when the
    /// first element exceeds 7.
    pub unsafe fn new_unchecked(idx: [u16; 8]) -> Self {
        Idx(idx)
    }

    /// Builds an index from a slice holding one value per axis.
    ///
    /// The slice length becomes the rank and the values are copied right
    /// after it; the remaining slots of the raw array are zero. The values
    /// themselves are not validated.
    ///
    /// # Errors
    ///
    /// Returns `NiftiError::InconsistentDim(0, len)` when the slice is empty
    /// or holds more than 7 values.
    pub fn from_slice(idx: &[u16]) -> Result<Self> {
        match idx.len() {
            len @ 1..=7 => {
                let mut raw = [0u16; 8];
                raw[0] = len as u16;
                raw[1..=len].copy_from_slice(idx);
                Ok(Idx(raw))
            }
            len => Err(NiftiError::InconsistentDim(0, len as u16)),
        }
    }

    /// Borrows the full raw array, including the rank slot and unused slots.
    pub fn raw(&self) -> &[u16; 8] {
        let raw = &self.0;
        raw
    }

    /// Number of axes in use, read from the first element of the raw array.
    pub fn rank(&self) -> usize {
        let stored = self.0[0];
        usize::from(stored)
    }
}
impl AsRef<[u16]> for Idx {
    /// Borrows only the per-axis values: the `rank()` elements that follow
    /// the rank slot of the raw array.
    fn as_ref(&self) -> &[u16] {
        let used = self.rank();
        &self.0[1..=used]
    }
}
impl AsMut<[u16]> for Idx {
    /// Mutably borrows only the per-axis values; the rank slot and the
    /// unused trailing slots are not reachable through this view.
    fn as_mut(&mut self) -> &mut [u16] {
        // Exclusive end of the used region: the values occupy 1..=rank.
        let end = self.rank() + 1;
        &mut self.0[1..end]
    }
}
/// The extents of a multi-dimensional array, stored in the same raw layout
/// as [`Idx`] (rank first, then one extent per axis).
///
/// The checked constructors (`Dim::new`, `Dim::from_slice`) run the raw
/// array through `validate_dim` before wrapping it.
#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct Dim(Idx);
impl Dim {
    /// Builds a shape from a raw 8-element array (rank first, then one
    /// extent per axis), after checking it with `validate_dim`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_dim` reports for `dim`.
    pub fn new(dim: [u16; 8]) -> Result<Self> {
        validate_dim(&dim)?;
        Ok(Dim(Idx(dim)))
    }

    /// Wraps a raw 8-element array without any validation.
    ///
    /// # Safety
    ///
    /// The array is stored as-is. Callers should only pass an array that
    /// [`Dim::new`] would accept; other methods of this type assume the rank
    /// slot is consistent with the extents that follow it.
    pub unsafe fn new_unchecked(dim: [u16; 8]) -> Self {
        Dim(Idx(dim))
    }

    /// Builds a shape from a slice holding one extent per axis.
    ///
    /// The slice length becomes the rank, each extent is converted to `u16`
    /// through `AsPrimitive::as_`, and the unused trailing slots of the raw
    /// array are filled with 1. The resulting array is then checked with
    /// `validate_dim`.
    ///
    /// NOTE(review): `as_` presumably has `as`-cast semantics, so extents that
    /// do not fit in a `u16` would be truncated rather than rejected — confirm
    /// against the `num_traits` documentation.
    ///
    /// # Errors
    ///
    /// Returns `NiftiError::InconsistentDim(0, len)` when the slice is empty
    /// or holds more than 7 extents, and propagates any error reported by
    /// `validate_dim`.
    pub fn from_slice<T>(dim: &[T]) -> Result<Self>
    where
        T: 'static + Copy + AsPrimitive<u16>,
    {
        if dim.is_empty() || dim.len() > 7 {
            return Err(NiftiError::InconsistentDim(0, dim.len() as u16));
        }
        let mut raw = [1u16; 8];
        raw[0] = dim.len() as u16;
        for (slot, extent) in raw[1..].iter_mut().zip(dim) {
            *slot = extent.as_();
        }
        validate_dim(&raw)?;
        Ok(Dim(Idx(raw)))
    }

    /// Borrows the full raw array, including the rank slot and unused slots.
    pub fn raw(&self) -> &[u16; 8] {
        self.0.raw()
    }

    /// Number of axes in this shape.
    pub fn rank(&self) -> usize {
        self.0.rank()
    }

    /// Total number of elements: the product of all extents, computed in
    /// `usize` with no overflow handling beyond the language defaults.
    pub fn element_count(&self) -> usize {
        self.as_ref()
            .iter()
            .map(|&extent| usize::from(extent))
            .product()
    }

    /// Splits the shape into the first `axis` axes and the remaining ones.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < axis < self.rank()`. A `Dim` cannot be empty, so
    /// splitting at either end has no valid result; this also means a shape
    /// of rank 1 cannot be split at all.
    pub fn split(&self, axis: u16) -> (Dim, Dim) {
        let axis = usize::from(axis);
        let rank = self.rank();
        // Checked up front so the caller sees which precondition was broken
        // instead of an `InconsistentDim` raised while building an empty half.
        assert!(
            0 < axis && axis < rank,
            "split axis {} is out of range: expected 0 < axis < rank ({})",
            axis,
            rank
        );
        let (left, right) = self.as_ref().split_at(axis);
        (
            Dim::from_slice(left).expect("leading axes of a shape must form a valid shape"),
            Dim::from_slice(right).expect("trailing axes of a shape must form a valid shape"),
        )
    }

    /// Returns an iterator over every index contained in this shape.
    pub fn index_iter(&self) -> DimIter {
        DimIter::new(*self)
    }
}
impl AsRef<[u16]> for Dim {
    /// Borrows the extents of the axes in use, delegating to the slice view
    /// of the wrapped index.
    fn as_ref(&self) -> &[u16] {
        AsRef::<[u16]>::as_ref(&self.0)
    }
}
/// Iterator over the indices of a [`Dim`], created by `Dim::index_iter`.
#[derive(Debug, Clone)]
pub struct DimIter {
// The shape whose indices are being enumerated.
shape: Dim,
// Where the iteration currently stands.
state: DimIterState,
}
/// Progress of a [`DimIter`].
#[derive(Debug, Copy, Clone)]
enum DimIterState {
// Nothing has been yielded yet.
First,
// Iteration is under way; holds the index that was yielded last.
Middle(Idx),
// Iteration is over; only `None` is returned from here on.
Fused,
}
impl DimIter {
    /// Creates an iterator over `shape` that has not yielded anything yet.
    fn new(shape: Dim) -> Self {
        let state = DimIterState::First;
        DimIter { shape, state }
    }
}
impl Iterator for DimIter {
    type Item = Idx;

    /// Yields the next index of the shape, or `None` once every index has
    /// been produced.
    ///
    /// Indices start at all zeros and advance with the first axis varying
    /// fastest: an axis that reaches its last position wraps back to 0 and
    /// carries into the next axis. A shape with any zero extent contains no
    /// elements, so it yields nothing. Once `None` has been returned, every
    /// later call returns `None` as well.
    fn next(&mut self) -> Option<Self::Item> {
        let (out, next_state) = match self.state {
            DimIterState::First => {
                if self.shape.as_ref().iter().any(|&extent| extent == 0) {
                    // Empty shape: there is no valid index to hand out.
                    (None, DimIterState::Fused)
                } else {
                    // The cast is lossless: `rank()` widens a stored `u16`.
                    let origin = Idx([self.shape.rank() as u16, 0, 0, 0, 0, 0, 0, 0]);
                    (Some(origin), DimIterState::Middle(origin))
                }
            }
            DimIterState::Fused => (None, DimIterState::Fused),
            DimIterState::Middle(mut current) => {
                let mut advanced = false;
                for (c, s) in current.as_mut().iter_mut().zip(self.shape.as_ref().iter()) {
                    // Written as `c + 1 < s` rather than `c < s - 1` so a zero
                    // extent can never underflow; `c` stays below `s`, hence
                    // the addition cannot overflow either.
                    if *c + 1 < *s {
                        *c += 1;
                        advanced = true;
                        break;
                    }
                    // This axis is exhausted: reset it and carry into the next.
                    *c = 0;
                }
                if advanced {
                    (Some(current), DimIterState::Middle(current))
                } else {
                    // Every axis wrapped around: the whole shape was visited.
                    (None, DimIterState::Fused)
                }
            }
        };
        self.state = next_state;
        out
    }
}
#[cfg(test)]
mod tests {
    use super::{Dim, Idx};

    /// A rank-3 shape exposes its three extents and their product.
    #[test]
    fn test_dim() {
        let dim = Dim::new([3, 256, 256, 100, 0, 0, 0, 0]).unwrap();
        assert_eq!(dim.as_ref(), &[256, 256, 100]);
        assert_eq!(dim.element_count(), 6553600);
    }

    /// A 3x4 shape yields exactly twelve indices, first axis varying fastest.
    #[test]
    fn test_dim_iter() {
        // Every index of a 3x4 shape, in the order the iterator must produce.
        const EXPECTED: [[u16; 2]; 12] = [
            [0, 0],
            [1, 0],
            [2, 0],
            [0, 1],
            [1, 1],
            [2, 1],
            [0, 2],
            [1, 2],
            [2, 2],
            [0, 3],
            [1, 3],
            [2, 3],
        ];

        let dim = Dim::new([2, 3, 4, 0, 0, 0, 0, 0]).unwrap();
        assert_eq!(dim.as_ref(), &[3, 4]);
        assert_eq!(dim.element_count(), 12);

        // Ask for one more than expected so an iterator that overruns is caught.
        let idx: Vec<_> = dim.index_iter().take(13).collect();
        assert_eq!(idx.len(), dim.element_count());

        for (i, (got, want)) in idx.into_iter().zip(EXPECTED.iter()).enumerate() {
            let expected = Idx::from_slice(want).unwrap();
            assert_eq!(got, expected, "#{} not ok", i);
        }
    }
}