redstone_ml/util/
axis.rs

1use crate::util::haslength::HasLength;
2
/// A signed axis index into an `NdArray`.
///
/// Negative values count backwards from the last dimension: `Axis(-1)` is the last
/// axis, `Axis(-2)` the second-to-last, and so on. Use [`AxisType::as_absolute`] to
/// resolve it to a non-negative index for a concrete number of dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub isize);
4
pub trait AxisType {
    /// Returns the raw axis index, which may be negative (counting from the last axis).
    fn isize(&self) -> isize;

    /// Computes the absolute axis index for a given `NdArray` dimension.
    ///
    /// Negative axis values are normalized to represent their positive counterparts.
    /// For example, `-1` represents the last axis, `-2` the second-to-last axis, and so on.
    ///
    /// # Arguments
    ///
    /// * `ndims` - The total number of dimensions in the ndarray.
    ///
    /// # Panics
    /// * If the provided axis is less than `-ndims` (lower bound).
    /// * If the provided axis is greater than or equal to `ndims` (upper bound).
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    /// assert_eq!(Axis(-1).as_absolute(4), 3);
    /// assert_eq!(Axis(-2).as_absolute(4), 2);
    /// assert_eq!(Axis(1).as_absolute(4), 1);
    /// ```
    fn as_absolute(&self, ndims: usize) -> usize {
        let axis = self.isize();

        // Resolve entirely in `usize` arithmetic. Casting `ndims as isize` would wrap for
        // values above `isize::MAX`, and negating `ndims` or `axis` could overflow.
        let absolute = if axis < 0 {
            // `unsigned_abs` is exact even for `isize::MIN`; an underflow here means
            // `axis < -ndims`.
            ndims.checked_sub(axis.unsigned_abs())
        } else {
            let axis = axis as usize; // lossless: `axis` is non-negative
            (axis < ndims).then_some(axis)
        };

        absolute.unwrap_or_else(|| {
            panic!("axis '{}' out of bounds for tensor of dimension {}", axis, ndims)
        })
    }
}
40
41impl AxisType for Axis {
42    fn isize(&self) -> isize {
43        self.0
44    }
45}
46
47impl AxisType for isize {
48    fn isize(&self) -> isize {
49        *self
50    }
51}
52
53
54pub trait AxesType: IntoIterator<Item=usize> + HasLength + Clone {}
55
56impl<const N: usize> AxesType for [usize; N] {}
57
58impl AxesType for Vec<usize> {}