use crate::{XdlError, MAXRANK};
use serde::{Deserialize, Serialize};
/// Shape of an array: the extent (length) of each axis.
///
/// An empty extent list represents a scalar (rank 0, one element).
/// `Dimension::from_vec` and `Dimension::from_size` reject zero extents, and
/// `from_vec` also rejects more than `MAXRANK` axes. Serde (de)serialization is
/// derived and does not go through those constructors, so deserialized values
/// are not validated by them.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Dimension {
/// Per-axis extents; empty for a scalar. The field name doubles as the key in
/// the derived serde representation, so renaming it changes the wire format.
dimensions: Vec<usize>,
}
impl Dimension {
    /// Creates the shape of a scalar: rank 0, exactly one element.
    pub fn scalar() -> Self {
        Self { dimensions: vec![] }
    }

    /// Creates a shape from per-axis extents.
    ///
    /// # Errors
    ///
    /// Returns `XdlError::DimensionError` if `dims` has more than `MAXRANK`
    /// entries, if any extent is zero, or if the product of the extents (the
    /// total element count) does not fit in a `usize`.
    pub fn from_vec(dims: Vec<usize>) -> Result<Self, XdlError> {
        if dims.len() > MAXRANK {
            return Err(XdlError::DimensionError(format!(
                "Too many dimensions: {} > {}",
                dims.len(),
                MAXRANK
            )));
        }
        if dims.contains(&0) {
            return Err(XdlError::DimensionError(
                "Zero dimensions not allowed".to_string(),
            ));
        }
        // Guarantee the element count is representable. Without this check a
        // shape such as `[usize::MAX, 2]` is accepted here and later makes
        // `n_elements` / `linear_index` overflow (a panic in debug builds, a
        // silently wrong count in release builds).
        let fits = dims
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .is_some();
        if !fits {
            return Err(XdlError::DimensionError(format!(
                "Too many elements: product of dimensions {:?} overflows usize",
                dims
            )));
        }
        Ok(Self { dimensions: dims })
    }

    /// Creates a rank-1 shape holding `size` elements.
    ///
    /// # Errors
    ///
    /// Returns `XdlError::DimensionError` if `size` is zero.
    pub fn from_size(size: usize) -> Result<Self, XdlError> {
        if size == 0 {
            return Err(XdlError::DimensionError(
                "Zero size not allowed".to_string(),
            ));
        }
        Ok(Self {
            dimensions: vec![size],
        })
    }

    /// Number of axes (0 for a scalar).
    pub fn rank(&self) -> usize {
        self.dimensions.len()
    }

    /// All per-axis extents, in axis order.
    pub fn dims(&self) -> &[usize] {
        &self.dimensions
    }

    /// Extent of axis `index`, or `None` if `index >= rank()`.
    pub fn dim(&self, index: usize) -> Option<usize> {
        self.dimensions.get(index).copied()
    }

    /// Total number of elements, i.e. the product of all extents.
    ///
    /// A scalar has one element because the empty product is 1.
    pub fn n_elements(&self) -> usize {
        self.dimensions.iter().product()
    }

    /// Returns `true` for a rank-0 shape.
    pub fn is_scalar(&self) -> bool {
        self.dimensions.is_empty()
    }

    /// Returns `true` for a rank-1 shape.
    pub fn is_vector(&self) -> bool {
        self.dimensions.len() == 1
    }

    /// Converts a per-axis index into a flat offset.
    ///
    /// The layout is row-major: the last axis has stride 1 and varies fastest.
    ///
    /// # Errors
    ///
    /// Returns `XdlError::DimensionError` if `indices.len()` differs from the
    /// rank, and `XdlError::IndexError` if an index is not below its extent
    /// (the last offending axis is reported first).
    pub fn linear_index(&self, indices: &[usize]) -> Result<usize, XdlError> {
        if indices.len() != self.dimensions.len() {
            return Err(XdlError::DimensionError(format!(
                "Index rank {} doesn't match array rank {}",
                indices.len(),
                self.dimensions.len()
            )));
        }
        // Walk axes from last to first so the stride starts at 1 and grows by
        // each extent in turn.
        let mut linear_idx = 0;
        let mut stride = 1;
        for (i, (&idx, &dim)) in indices.iter().zip(&self.dimensions).enumerate().rev() {
            if idx >= dim {
                return Err(XdlError::IndexError(format!(
                    "Index {} out of range for dimension {} (size {})",
                    idx, i, dim
                )));
            }
            linear_idx += idx * stride;
            stride *= dim;
        }
        Ok(linear_idx)
    }

    /// Converts a flat offset back into a per-axis index (inverse of
    /// [`Dimension::linear_index`]).
    ///
    /// # Errors
    ///
    /// Returns `XdlError::IndexError` if `linear_idx >= n_elements()`.
    pub fn multi_index(&self, linear_idx: usize) -> Result<Vec<usize>, XdlError> {
        if linear_idx >= self.n_elements() {
            return Err(XdlError::IndexError(format!(
                "Linear index {} out of range for array with {} elements",
                linear_idx,
                self.n_elements()
            )));
        }
        if self.is_scalar() {
            return Ok(vec![]);
        }
        let mut indices = vec![0; self.dimensions.len()];
        let mut remaining = linear_idx;
        // Peel off the fastest-varying (last) axis first, mirroring linear_index.
        for i in (0..self.dimensions.len()).rev() {
            let dim_size = self.dimensions[i];
            indices[i] = remaining % dim_size;
            remaining /= dim_size;
        }
        Ok(indices)
    }

    /// Returns a new shape with the same element count but extents `new_dims`.
    ///
    /// # Errors
    ///
    /// Returns `XdlError::DimensionError` if `new_dims` is not a valid shape
    /// (see [`Dimension::from_vec`]) or holds a different number of elements.
    pub fn reform(&self, new_dims: Vec<usize>) -> Result<Self, XdlError> {
        // Validate first (rank, zero extents, overflow) so that computing the
        // new element count below cannot overflow on untrusted input.
        let reformed = Self::from_vec(new_dims)?;
        let new_n_elements = reformed.n_elements();
        if new_n_elements != self.n_elements() {
            return Err(XdlError::DimensionError(format!(
                "Cannot reform array of {} elements to {} elements",
                self.n_elements(),
                new_n_elements
            )));
        }
        Ok(reformed)
    }

    /// Permutes the axes.
    ///
    /// With `perm == None` the axis order is reversed. Otherwise axis `k` of
    /// the result is axis `perm[k]` of `self`. A scalar is returned unchanged
    /// and `perm` is not inspected.
    ///
    /// # Errors
    ///
    /// Returns `XdlError::DimensionError` if `perm` has the wrong length,
    /// contains an out-of-range axis, or repeats an axis.
    pub fn transpose(&self, perm: Option<&[usize]>) -> Result<Self, XdlError> {
        if self.is_scalar() {
            return Ok(self.clone());
        }
        let perm = if let Some(p) = perm {
            if p.len() != self.dimensions.len() {
                return Err(XdlError::DimensionError(
                    "Permutation length doesn't match array rank".to_string(),
                ));
            }
            p.to_vec()
        } else {
            (0..self.dimensions.len()).rev().collect()
        };
        // Each axis must appear exactly once.
        let mut seen = vec![false; self.dimensions.len()];
        for &p in &perm {
            if p >= self.dimensions.len() {
                return Err(XdlError::DimensionError(
                    "Invalid permutation index".to_string(),
                ));
            }
            if seen[p] {
                return Err(XdlError::DimensionError(
                    "Duplicate in permutation".to_string(),
                ));
            }
            seen[p] = true;
        }
        let new_dims = perm.iter().map(|&i| self.dimensions[i]).collect();
        Ok(Self {
            dimensions: new_dims,
        })
    }
}
impl Default for Dimension {
fn default() -> Self {
Self::scalar()
}
}
impl std::fmt::Display for Dimension {
    /// Renders a scalar as `scalar` and any other shape as a bracketed,
    /// comma-separated list of extents, e.g. `[3, 4, 5]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let extents = match self.dimensions.as_slice() {
            [] => return write!(f, "scalar"),
            extents => extents,
        };
        // Stream each extent straight into the formatter rather than building
        // an intermediate joined String.
        f.write_str("[")?;
        for (position, extent) in extents.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{}", extent)?;
        }
        f.write_str("]")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_dimension() {
        let dim = Dimension::scalar();
        assert!(dim.is_scalar());
        assert_eq!(dim.rank(), 0);
        assert_eq!(dim.n_elements(), 1);
    }

    #[test]
    fn test_vector_dimension() {
        let dim = Dimension::from_size(10).unwrap();
        assert!(dim.is_vector());
        assert_eq!(dim.rank(), 1);
        assert_eq!(dim.n_elements(), 10);
        assert_eq!(dim.dim(0), Some(10));
        assert_eq!(dim.dim(1), None);
    }

    #[test]
    fn test_multi_dimension() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        assert_eq!(dim.rank(), 3);
        assert_eq!(dim.n_elements(), 60);
        assert_eq!(dim.dims(), &[3, 4, 5]);
    }

    #[test]
    fn test_indexing() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();
        assert_eq!(dim.linear_index(&[0, 0]).unwrap(), 0);
        assert_eq!(dim.linear_index(&[2, 3]).unwrap(), 11);
        assert_eq!(dim.multi_index(0).unwrap(), vec![0, 0]);
        assert_eq!(dim.multi_index(11).unwrap(), vec![2, 3]);
    }

    #[test]
    fn test_index_round_trip() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        for linear in 0..dim.n_elements() {
            let multi = dim.multi_index(linear).unwrap();
            assert_eq!(dim.linear_index(&multi).unwrap(), linear);
        }
    }

    #[test]
    fn test_scalar_indexing() {
        let dim = Dimension::scalar();
        assert_eq!(dim.linear_index(&[]).unwrap(), 0);
        assert_eq!(dim.multi_index(0).unwrap(), Vec::<usize>::new());
        assert!(dim.multi_index(1).is_err());
    }

    #[test]
    fn test_indexing_errors() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();
        // Wrong number of indices.
        assert!(dim.linear_index(&[1]).is_err());
        assert!(dim.linear_index(&[1, 1, 1]).is_err());
        // An index equal to its extent is out of range, on either axis.
        assert!(dim.linear_index(&[3, 0]).is_err());
        assert!(dim.linear_index(&[0, 4]).is_err());
        // Linear index one past the last element.
        assert!(dim.multi_index(12).is_err());
    }

    #[test]
    fn test_reform() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();
        let reformed = dim.reform(vec![2, 6]).unwrap();
        assert_eq!(reformed.dims(), &[2, 6]);
        assert_eq!(reformed.n_elements(), 12);
        // A one-element array can be reformed into a scalar.
        let single = Dimension::from_size(1).unwrap();
        assert!(single.reform(vec![]).unwrap().is_scalar());
    }

    #[test]
    fn test_transpose() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        let transposed = dim.transpose(None).unwrap();
        assert_eq!(transposed.dims(), &[5, 4, 3]);
        let custom_transpose = dim.transpose(Some(&[1, 0, 2])).unwrap();
        assert_eq!(custom_transpose.dims(), &[4, 3, 5]);
        assert!(Dimension::scalar().transpose(None).unwrap().is_scalar());
    }

    #[test]
    fn test_transpose_errors() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        // Wrong length.
        assert!(dim.transpose(Some(&[0, 1])).is_err());
        // Axis out of range.
        assert!(dim.transpose(Some(&[0, 1, 3])).is_err());
        // Repeated axis.
        assert!(dim.transpose(Some(&[0, 0, 1])).is_err());
    }

    #[test]
    fn test_display_and_default() {
        assert_eq!(Dimension::scalar().to_string(), "scalar");
        assert_eq!(Dimension::from_vec(vec![3, 4]).unwrap().to_string(), "[3, 4]");
        assert_eq!(Dimension::default(), Dimension::scalar());
    }

    #[test]
    fn test_error_cases() {
        assert!(Dimension::from_vec(vec![3, 0, 5]).is_err());
        assert!(Dimension::from_vec(vec![1; MAXRANK + 1]).is_err());
        assert!(Dimension::from_size(0).is_err());
        let dim = Dimension::from_size(10).unwrap();
        assert!(dim.reform(vec![3, 4]).is_err());
        assert!(dim.reform(vec![10, 0]).is_err());
    }
}