acme_tensor/shape/
axis.rs

/*
    Appellation: axis <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use core::ops::Deref;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9pub trait IntoAxis {
10    fn into_axis(self) -> Axis;
11}
12
13impl IntoAxis for usize {
14    fn into_axis(self) -> Axis {
15        Axis::new(self)
16    }
17}
18
/// A single dimension of a tensor, identified by its index.
///
/// `Axis` is a `Copy` newtype around `usize`. Equality, total ordering and
/// hashing are all derived from the wrapped index, and `Axis::default()` is
/// `Axis(0)`. When the `serde` feature is enabled it also derives
/// `Serialize` and `Deserialize`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Axis(
    /// The raw index of the dimension.
    pub(crate) usize,
);
23
24impl Axis {
25    pub fn new(axis: usize) -> Self {
26        Self(axis)
27    }
28
29    pub fn into_inner(self) -> usize {
30        self.0
31    }
32
33    pub fn axis(&self) -> usize {
34        self.0
35    }
36
37    pub fn dec(&self) -> Axis {
38        self - 1
39    }
40
41    pub fn inc(&self) -> Axis {
42        self + 1
43    }
44}
45
46impl AsRef<usize> for Axis {
47    fn as_ref(&self) -> &usize {
48        &self.0
49    }
50}
51
52impl Deref for Axis {
53    type Target = usize;
54
55    fn deref(&self) -> &Self::Target {
56        &self.0
57    }
58}
59
60impl std::fmt::Display for Axis {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "{}", self.0)
63    }
64}
65
66impl From<usize> for Axis {
67    fn from(axis: usize) -> Self {
68        Axis(axis)
69    }
70}
71
72impl From<Axis> for usize {
73    fn from(axis: Axis) -> Self {
74        axis.0
75    }
76}
77
78macro_rules! impl_std_ops {
79    ($(($trait:tt, $method:ident, $e:tt)),*) => {
80        $(
81           impl_std_ops!($trait, $method, $e);
82        )*
83    };
84    ($trait:tt, $method:ident, $e:tt) => {
85        impl core::ops::$trait<usize> for Axis {
86            type Output = Axis;
87
88            fn $method(self, rhs: usize) -> Self::Output {
89                Axis(self.0 $e rhs)
90            }
91        }
92
93        impl<'a> core::ops::$trait<usize> for &'a Axis {
94            type Output = Axis;
95
96            fn $method(self, rhs: usize) -> Self::Output {
97                Axis(self.0 $e rhs)
98            }
99        }
100
101        impl core::ops::$trait for Axis {
102            type Output = Axis;
103
104            fn $method(self, rhs: Axis) -> Self::Output {
105                Axis(self.0 $e rhs.0)
106            }
107        }
108
109        impl<'a> core::ops::$trait<Axis> for &'a Axis {
110            type Output = Axis;
111
112            fn $method(self, rhs: Axis) -> Self::Output {
113                Axis(self.0 $e rhs.0)
114            }
115        }
116
117        impl<'a> core::ops::$trait<&'a Axis> for Axis {
118            type Output = Axis;
119
120            fn $method(self, rhs: &'a Axis) -> Self::Output {
121                Axis(self.0 $e rhs.0)
122            }
123        }
124
125        impl<'a> core::ops::$trait<&'a Axis> for &'a Axis {
126            type Output = Axis;
127
128            fn $method(self, rhs: &'a Axis) -> Self::Output {
129                Axis(self.0 $e rhs.0)
130            }
131        }
132    };
133}
134
135impl_std_ops!((Add, add, +), (Sub, sub, -), (Mul, mul, *), (Div, div, /), (Rem, rem, %));