use std::fmt;
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct Axes(pub u32);
impl fmt::Debug for Axes {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let bits = self.0;
if bits == 0 {
return f.write_str("Axes()");
}
let indices: Vec<u32> = (0u32..32).filter(|&i| (bits & (1u32 << i)) != 0).collect();
f.write_str("Axes(")?;
for (i, &axis) in indices.iter().enumerate() {
if i > 0 {
f.write_str("|")?;
}
write!(f, "{}", axis)?;
}
f.write_str(")")
}
}
impl Axes {
pub const fn empty() -> Self {
Axes(0)
}
#[inline]
pub const fn from_bits(bits: u32) -> Self {
Axes(bits)
}
#[inline]
pub const fn new<const N: usize>(indices: [u32; N]) -> Self {
let mut bits = 0u32;
let mut i = 0;
while i < N {
bits |= 1u32.wrapping_shl(indices[i]);
i += 1;
}
Axes(bits)
}
pub fn from_slice(indices: &[u32]) -> Self {
let mut bits = 0u32;
for &index in indices {
bits |= 1u32.wrapping_shl(index);
}
Axes(bits)
}
#[inline]
pub const fn with_axis(self, axis: u32) -> Self {
Axes(self.0 | 1u32.wrapping_shl(axis))
}
#[inline]
pub const fn to_bits(self) -> u32 {
self.0
}
}
impl From<u32> for Axes {
#[inline]
fn from(bits: u32) -> Self {
Axes(bits)
}
}
impl From<Axes> for u32 {
#[inline]
fn from(axes: Axes) -> u32 {
axes.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn empty_is_zero() {
assert_eq!(Axes::empty().to_bits(), 0);
assert_eq!(Axes::default().to_bits(), 0);
}
#[test]
fn from_bits_roundtrip() {
assert_eq!(Axes::from_bits(0).to_bits(), 0);
assert_eq!(Axes::from_bits(0b111).to_bits(), 0b111);
assert_eq!(Axes::from_bits(0xFFFF_FFFF).to_bits(), 0xFFFF_FFFF);
}
#[test]
fn from_indices_empty() {
let axes = Axes::new([]);
assert_eq!(axes.to_bits(), 0);
}
#[test]
fn from_indices_single() {
assert_eq!(Axes::new([0]).to_bits(), 1);
assert_eq!(Axes::new([3]).to_bits(), 1 << 3);
assert_eq!(Axes::new([31]).to_bits(), 1 << 31);
}
#[test]
fn from_indices_multiple() {
let axes = Axes::new([1, 2]);
assert_eq!(axes.to_bits(), 0b110);
let axes = Axes::new([0, 2, 4]);
assert_eq!(axes.to_bits(), 0b10101);
}
#[test]
fn from_indices_duplicate_same_as_unique() {
let a = Axes::new([1, 1, 2]);
let b = Axes::new([1, 2]);
assert_eq!(a.to_bits(), b.to_bits());
}
#[test]
fn with_axis_builder() {
let axes = Axes::empty().with_axis(0).with_axis(2);
assert_eq!(axes.to_bits(), 0b101);
let axes = Axes::from_bits(1).with_axis(1);
assert_eq!(axes.to_bits(), 0b11);
}
#[test]
fn from_u32_into_u32() {
let raw: u32 = 0b10110;
let axes: Axes = raw.into();
assert_eq!(axes.to_bits(), raw);
let back: u32 = axes.into();
assert_eq!(back, raw);
}
#[test]
fn eq_and_clone() {
let a = Axes::new([1, 2]);
let b = a;
let c = a;
assert_eq!(a, b);
assert_eq!(a, c);
assert_eq!(b.to_bits(), 0b110);
}
#[test]
fn debug_shows_indices_separated_by_pipe() {
assert_eq!(format!("{:?}", Axes::empty()), "Axes()");
assert_eq!(format!("{:?}", Axes::new([0])), "Axes(0)");
assert_eq!(format!("{:?}", Axes::new([1, 2])), "Axes(1|2)");
assert_eq!(format!("{:?}", Axes::new([0, 2, 4])), "Axes(0|2|4)");
}
}