use crate::tensors::Dimension;
/// Returns the total number of elements described by a shape, which is the
/// product of every dimension length.
///
/// A shape with no dimensions yields 1, the empty product.
pub fn elements<const D: usize>(shape: &[(Dimension, usize); D]) -> usize {
    shape.iter().fold(1, |total, &(_, length)| total * length)
}
/// Returns the position in the shape of the first entry whose name equals
/// `dimension`, or `None` if no entry has that name.
pub fn position_of<const D: usize>(
    shape: &[(Dimension, usize); D],
    dimension: Dimension,
) -> Option<usize> {
    shape
        .iter()
        .enumerate()
        .find_map(|(index, &(name, _))| (name == dimension).then_some(index))
}
/// Returns true if any entry in the shape has a name equal to `dimension`.
pub fn contains<const D: usize>(shape: &[(Dimension, usize); D], dimension: Dimension) -> bool {
    for &(name, _) in shape {
        if name == dimension {
            return true;
        }
    }
    false
}
/// Returns the length paired with the first entry in the shape whose name
/// equals `dimension`, or `None` if no entry has that name.
pub fn length_of<const D: usize>(
    shape: &[(Dimension, usize); D],
    dimension: Dimension,
) -> Option<usize> {
    shape.iter().find_map(|&(name, length)| {
        if name == dimension {
            Some(length)
        } else {
            None
        }
    })
}
/// Returns one less than the length of the named dimension, or `None` if the
/// shape has no entry with that name.
///
/// The subtraction saturates, so a dimension of length 0 yields `Some(0)`
/// rather than underflowing.
pub fn last_index_of<const D: usize>(
    shape: &[(Dimension, usize); D],
    dimension: Dimension,
) -> Option<usize> {
    let length = length_of(shape, dimension)?;
    Some(length.saturating_sub(1))
}
/// A pair of lookup tables relating the dimension order of a source shape to
/// a requested dimension order over the same names.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct DimensionMappings<const D: usize> {
    /// For each position in the source shape, the position of that dimension
    /// name in the requested order.
    source_to_requested: [usize; D],
    /// For each position in the requested order, the position of that
    /// dimension name in the source shape.
    requested_to_source: [usize; D],
}

impl<const D: usize> DimensionMappings<D> {
    /// Creates the identity mapping, where every position maps to itself in
    /// both directions.
    pub(crate) fn no_op_mapping() -> DimensionMappings<D> {
        let identity: [usize; D] = std::array::from_fn(|position| position);
        DimensionMappings {
            source_to_requested: identity,
            requested_to_source: identity,
        }
    }

    /// Builds the mappings between a source shape and a requested order of
    /// dimension names.
    ///
    /// Positions where both orders already hold the same name map to
    /// themselves. For any other position, the source name is looked up in
    /// `requested` and the requested name is looked up in `source`; if either
    /// lookup finds nothing, `None` is returned.
    ///
    /// Neither input is checked for duplicate names here.
    pub(crate) fn new(
        source: &[(Dimension, usize); D],
        requested: &[Dimension; D],
    ) -> Option<DimensionMappings<D>> {
        // Start from the identity so positions that already agree need no work.
        let mut mappings = Self::no_op_mapping();
        let paired = source.iter().zip(requested.iter()).enumerate();
        for (d, (&(source_name, _), &requested_name)) in paired {
            if source_name == requested_name {
                continue;
            }
            mappings.source_to_requested[d] =
                requested.iter().position(|&name| name == source_name)?;
            mappings.requested_to_source[d] = source
                .iter()
                .position(|&(name, _)| name == requested_name)?;
        }
        Some(mappings)
    }

    /// Reorders indexes given in the requested order into the source order.
    ///
    /// Entry `d` of the result is the index at the position that the `d`th
    /// source dimension occupies in the requested order.
    #[inline]
    pub(crate) fn map_dimensions_to_source(&self, indexes: &[usize; D]) -> [usize; D] {
        self.source_to_requested.map(|position| indexes[position])
    }

    /// Reorders a shape given in the source order into the requested order.
    ///
    /// Entry `d` of the result is the source entry at the position that the
    /// `d`th requested dimension occupies in the source.
    #[inline]
    pub(crate) fn map_shape_to_requested(
        &self,
        source: &[(Dimension, usize); D],
    ) -> [(Dimension, usize); D] {
        self.requested_to_source.map(|position| source[position])
    }

    /// Returns an array whose `d`th entry is `order[source_to_requested[d]]`.
    ///
    /// NOTE(review): presumably `order` is the data layout of the source and
    /// the result is the layout after transposing to the requested order;
    /// confirm against callers.
    #[inline]
    pub(crate) fn map_linear_data_layout_to_transposed(
        &self,
        order: &[Dimension; D],
    ) -> [Dimension; D] {
        self.source_to_requested.map(|position| order[position])
    }
}
/// Returns true when every dimension in the shape has the same length.
///
/// Shapes with zero or one dimensions are always considered square.
pub fn is_square<const D: usize>(shape: &[(Dimension, usize); D]) -> bool {
    match shape.split_first() {
        // Nothing to compare against, so trivially square.
        None => true,
        Some((&(_, first), rest)) => rest.iter().all(|&(_, length)| length == first),
    }
}
/// Returns the dimension names of the shape, in the same order, with the
/// lengths dropped.
pub fn names_of<const D: usize>(shape: &[(Dimension, usize); D]) -> [Dimension; D] {
    std::array::from_fn(|d| shape[d].0)
}
/// Returns true if any dimension name appears more than once in the shape.
///
/// Only the names are compared; the lengths are ignored.
pub(crate) fn has_duplicates(shape: &[(Dimension, usize)]) -> bool {
    let mut remaining = shape;
    // Peel off one entry at a time and look for its name among the entries after it.
    while let Some((&(name, _), rest)) = remaining.split_first() {
        if rest.iter().any(|&(other, _)| other == name) {
            return true;
        }
        remaining = rest;
    }
    false
}
/// Returns true if any dimension name appears more than once in the list.
pub(crate) fn has_duplicates_names(dimensions: &[Dimension]) -> bool {
    // Each name only needs comparing against the names that come after it.
    dimensions
        .iter()
        .enumerate()
        .any(|(index, name)| dimensions[index + 1..].contains(name))
}
/// Returns true if any dimension name appears more than once in the list.
///
/// Only the names are compared; the `usize` paired with each name is ignored.
pub(crate) fn has_duplicates_extra_names(dimensions: &[(usize, Dimension)]) -> bool {
    dimensions.iter().enumerate().any(|(index, &(_, name))| {
        dimensions
            .iter()
            .skip(index + 1)
            .any(|&(_, other)| other == name)
    })
}
/// Checks the mappings built for five orderings of a three dimensional shape,
/// and that each mapping reorders the shape into the requested order.
#[test]
fn test_dimension_mappings() {
    let shape = [("x", 0), ("y", 0), ("z", 0)];
    // Each case is (requested order, expected source_to_requested, expected requested_to_source).
    let cases = [
        (["x", "y", "z"], [0, 1, 2], [0, 1, 2]),
        (["z", "y", "x"], [2, 1, 0], [2, 1, 0]),
        (["z", "x", "y"], [1, 2, 0], [2, 0, 1]),
        (["x", "z", "y"], [0, 2, 1], [0, 2, 1]),
        (["y", "z", "x"], [2, 0, 1], [1, 2, 0]),
    ];
    for (requested, source_to_requested, requested_to_source) in cases {
        let mapping = DimensionMappings::new(&shape, &requested).unwrap();
        assert_eq!(
            DimensionMappings {
                source_to_requested,
                requested_to_source,
            },
            mapping
        );
        // Every length in the source shape is 0, so the expected shape is just
        // the requested names paired with 0.
        assert_eq!(
            requested.map(|name| (name, 0)),
            mapping.map_shape_to_requested(&shape),
        );
    }
}
/// Checks `is_square` on shapes of zero to three dimensions, with both equal
/// and unequal lengths.
#[test]
fn test_is_square() {
    assert!(is_square(&[]));
    assert!(is_square(&[("x", 1)]));
    assert!(is_square(&[("x", 1), ("y", 1)]));
    assert!(is_square(&[("x", 4), ("y", 4)]));
    assert!(!is_square(&[("x", 4), ("y", 3)]));
    assert!(is_square(&[("x", 3), ("y", 3), ("z", 3)]));
    assert!(!is_square(&[("x", 3), ("y", 4), ("z", 3)]));
}
/// Checks `has_duplicates_names` on lists with no repeated name and with
/// repeats at adjacent, separated, leading and trailing positions.
#[test]
fn test_duplicate_names() {
    assert!(has_duplicates_names(&["a", "b", "b", "c"]));
    assert!(!has_duplicates_names(&["a", "b", "c", "d"]));
    assert!(has_duplicates_names(&["a", "b", "a", "c"]));
    assert!(has_duplicates_names(&["a", "a", "a", "a"]));
    assert!(has_duplicates_names(&["a", "b", "c", "c"]));
}