redstone_ml/util/axis.rs
1use crate::util::haslength::HasLength;
2
/// A signed axis index into an n-dimensional array.
///
/// Negative values count backwards from the last axis (`Axis(-1)` is the last axis);
/// see [`AxisType::as_absolute`] for how they are resolved to a non-negative index.
// Common traits are derived so axes can be printed, compared, copied and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Axis(pub isize);
4
pub trait AxisType {
    /// Returns the signed axis value this type represents (may be negative).
    fn isize(&self) -> isize;

    /// Resolves this axis to a non-negative index for an `NdArray` with `ndims` dimensions.
    ///
    /// Non-negative axes are returned unchanged. Negative axes count backwards from the
    /// end: `-1` is the last axis, `-2` the second-to-last, and so on.
    ///
    /// # Arguments
    ///
    /// * `ndims` - The total number of dimensions in the ndarray.
    ///
    /// # Panics
    ///
    /// * If the axis is less than `-ndims` (lower bound).
    /// * If the axis is greater than or equal to `ndims` (upper bound).
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    /// assert_eq!(Axis(-1).as_absolute(4), 3);
    /// assert_eq!(Axis(-2).as_absolute(4), 2);
    /// assert_eq!(Axis(1).as_absolute(4), 1);
    /// ```
    fn as_absolute(&self, ndims: usize) -> usize {
        let rank = ndims as isize;

        match self.isize() {
            // Already a valid forward index.
            idx if (0..rank).contains(&idx) => idx as usize,
            // Backward index in [-rank, 0): shift it into the forward range.
            idx if (-rank..0).contains(&idx) => (idx + rank) as usize,
            // Anything else lies outside [-rank, rank).
            idx => panic!("axis '{}' out of bounds for tensor of dimension {}", idx, rank),
        }
    }
}
40
41impl AxisType for Axis {
42 fn isize(&self) -> isize {
43 self.0
44 }
45}
46
47impl AxisType for isize {
48 fn isize(&self) -> isize {
49 *self
50 }
51}
52
53
54pub trait AxesType: IntoIterator<Item=usize> + HasLength + Clone {}
55
56impl<const N: usize> AxesType for [usize; N] {}
57
58impl AxesType for Vec<usize> {}