acme_tensor/shape/
axis.rs1use core::ops::Deref;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9pub trait IntoAxis {
10 fn into_axis(self) -> Axis;
11}
12
13impl IntoAxis for usize {
14 fn into_axis(self) -> Axis {
15 Axis::new(self)
16 }
17}
18
/// Index of a single axis, stored as a plain `usize`.
///
/// Equality, ordering and hashing are derived and therefore follow the wrapped
/// index; [`Default`] yields axis `0`. With the `serde` feature enabled the
/// type additionally derives `Serialize` and `Deserialize`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Axis(pub(crate) usize);
23
24impl Axis {
25 pub fn new(axis: usize) -> Self {
26 Self(axis)
27 }
28
29 pub fn into_inner(self) -> usize {
30 self.0
31 }
32
33 pub fn axis(&self) -> usize {
34 self.0
35 }
36
37 pub fn dec(&self) -> Axis {
38 self - 1
39 }
40
41 pub fn inc(&self) -> Axis {
42 self + 1
43 }
44}
45
46impl AsRef<usize> for Axis {
47 fn as_ref(&self) -> &usize {
48 &self.0
49 }
50}
51
52impl Deref for Axis {
53 type Target = usize;
54
55 fn deref(&self) -> &Self::Target {
56 &self.0
57 }
58}
59
60impl std::fmt::Display for Axis {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.0)
63 }
64}
65
66impl From<usize> for Axis {
67 fn from(axis: usize) -> Self {
68 Axis(axis)
69 }
70}
71
72impl From<Axis> for usize {
73 fn from(axis: Axis) -> Self {
74 axis.0
75 }
76}
77
/// Implements binary `core::ops` operator traits for [`Axis`].
///
/// Two invocation forms are accepted:
///
/// * a comma-separated list of `(Trait, method, operator)` tuples, with an
///   optional trailing comma, e.g. `impl_std_ops!((Add, add, +), (Sub, sub, -));`
/// * a single bare triple, e.g. `impl_std_ops!(Add, add, +);`
///
/// For every triple, six impls are generated, each applying the operator to the
/// wrapped `usize` values and wrapping the result in a new `Axis`:
/// `Axis op usize`, `&Axis op usize`, `Axis op Axis`, `&Axis op Axis`,
/// `Axis op &Axis` and `&Axis op &Axis`.
///
/// The arithmetic is plain `usize` arithmetic, so overflow and division by zero
/// behave exactly as they do for `usize`.
macro_rules! impl_std_ops {
    // List form: expand each tuple through the single-triple arm below.
    ($(($trait:ident, $method:ident, $e:tt)),* $(,)?) => {
        $(
            impl_std_ops!($trait, $method, $e);
        )*
    };
    // Single-triple form: emit all owned/borrowed operand combinations.
    ($trait:ident, $method:ident, $e:tt) => {
        impl core::ops::$trait<usize> for Axis {
            type Output = Axis;

            fn $method(self, rhs: usize) -> Self::Output {
                Axis(self.0 $e rhs)
            }
        }

        impl<'a> core::ops::$trait<usize> for &'a Axis {
            type Output = Axis;

            fn $method(self, rhs: usize) -> Self::Output {
                Axis(self.0 $e rhs)
            }
        }

        impl core::ops::$trait for Axis {
            type Output = Axis;

            fn $method(self, rhs: Axis) -> Self::Output {
                Axis(self.0 $e rhs.0)
            }
        }

        impl<'a> core::ops::$trait<Axis> for &'a Axis {
            type Output = Axis;

            fn $method(self, rhs: Axis) -> Self::Output {
                Axis(self.0 $e rhs.0)
            }
        }

        impl<'a> core::ops::$trait<&'a Axis> for Axis {
            type Output = Axis;

            fn $method(self, rhs: &'a Axis) -> Self::Output {
                Axis(self.0 $e rhs.0)
            }
        }

        // The two operands get independent lifetimes so the impl also
        // satisfies bounds of the form `for<'a, 'b> &'a T: Op<&'b T>`.
        impl<'a, 'b> core::ops::$trait<&'b Axis> for &'a Axis {
            type Output = Axis;

            fn $method(self, rhs: &'b Axis) -> Self::Output {
                Axis(self.0 $e rhs.0)
            }
        }
    };
}
134
135impl_std_ops!((Add, add, +), (Sub, sub, -), (Mul, mul, *), (Div, div, /), (Rem, rem, %));