1use crate::prelude_dev::*;
2
/// A selection of tensor axes: nothing, a single axis, or a list of axes.
///
/// `T` is the integer type used for axis indices; the conversions in this file are provided for
/// `usize` and `isize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AxesIndex<T> {
    /// No axes were given. What this means is up to the consumer; e.g. `normalize_axes_index`
    /// rejects it with an error.
    None,
    /// Exactly one axis.
    Val(T),
    /// Any number of axes (possibly zero), in the order given.
    Vec(Vec<T>),
}

impl<T> AsRef<[T]> for AxesIndex<T> {
    /// Borrows the selected axes as a slice; `Val` becomes a one-element slice.
    ///
    /// # Panics
    ///
    /// Panics on `AxesIndex::None`, which has no slice representation. Callers must handle `None`
    /// before borrowing.
    fn as_ref(&self) -> &[T] {
        match self {
            AxesIndex::Val(v) => core::slice::from_ref(v),
            AxesIndex::Vec(v) => v.as_slice(),
            AxesIndex::None => panic!("AxesIndex::None cannot be converted to a slice. This is developer's error; if encountered, please report it to github issue."),
        }
    }
}

/// A single axis value selects exactly that axis.
impl<T> From<T> for AxesIndex<T> {
    fn from(value: T) -> Self {
        AxesIndex::Val(value)
    }
}

/// A borrowed single axis is cloned into `AxesIndex::Val`.
impl<T> From<&T> for AxesIndex<T>
where
    T: Clone,
{
    fn from(value: &T) -> Self {
        AxesIndex::Val(value.clone())
    }
}

/// An owned vector is taken over as-is, without copying.
impl<T> From<Vec<T>> for AxesIndex<T> {
    fn from(value: Vec<T>) -> Self {
        AxesIndex::Vec(value)
    }
}

/// An owned array is moved into a vector. No element is cloned, so `T` does not need `Clone`.
impl<T, const N: usize> From<[T; N]> for AxesIndex<T> {
    fn from(value: [T; N]) -> Self {
        AxesIndex::Vec(Vec::from(value))
    }
}

/// A borrowed vector is copied element by element.
impl<T> From<&Vec<T>> for AxesIndex<T>
where
    T: Clone,
{
    fn from(value: &Vec<T>) -> Self {
        AxesIndex::Vec(value.to_vec())
    }
}

/// A borrowed slice is copied element by element.
impl<T> From<&[T]> for AxesIndex<T>
where
    T: Clone,
{
    fn from(value: &[T]) -> Self {
        AxesIndex::Vec(value.to_vec())
    }
}

/// A borrowed array is copied element by element.
impl<T, const N: usize> From<&[T; N]> for AxesIndex<T>
where
    T: Clone,
{
    fn from(value: &[T; N]) -> Self {
        AxesIndex::Vec(value.to_vec())
    }
}
79
80#[duplicate_item(T; [usize]; [isize])]
81impl From<()> for AxesIndex<T> {
82 fn from(_: ()) -> Self {
83 AxesIndex::Vec(vec![])
84 }
85}
86
87#[duplicate_item(T; [usize]; [isize])]
88impl TryFrom<Option<T>> for AxesIndex<T> {
89 type Error = Error;
90
91 fn try_from(value: Option<T>) -> Result<Self> {
92 match value {
93 Some(v) => Ok(AxesIndex::Val(v)),
94 None => Ok(AxesIndex::None),
95 }
96 }
97}
98
/// Implements fallible conversions into `AxesIndex<$t1>` from every integer type listed in `$t2`.
///
/// For each source type `$t2` this generates `TryFrom` impls for:
/// - a single value, `$t2` or `&$t2`, producing `AxesIndex::Val`;
/// - a list, `&[$t2]`, `Vec<$t2>`, `&Vec<$t2>`, `[$t2; N]` or `&[$t2; N]`, producing
///   `AxesIndex::Vec`.
///
/// Every list form delegates to the slice impl. It converts element by element and returns an
/// error at the first value that does not fit into `$t1`.
macro_rules! impl_try_from_axes_index {
    ($t1:ty, $($t2:ty),*) => {
        $(
            impl TryFrom<$t2> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: $t2) -> Result<Self> {
                    let axis: $t1 = value.try_into()?;
                    Ok(AxesIndex::Val(axis))
                }
            }

            impl TryFrom<&$t2> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: &$t2) -> Result<Self> {
                    let owned: $t2 = *value;
                    Self::try_from(owned)
                }
            }

            impl TryFrom<&[$t2]> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: &[$t2]) -> Result<Self> {
                    let mut axes: Vec<$t1> = Vec::with_capacity(value.len());
                    for &item in value {
                        let axis: $t1 = item.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
                        axes.push(axis);
                    }
                    Ok(AxesIndex::Vec(axes))
                }
            }

            impl TryFrom<Vec<$t2>> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: Vec<$t2>) -> Result<Self> {
                    Self::try_from(value.as_slice())
                }
            }

            impl TryFrom<&Vec<$t2>> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: &Vec<$t2>) -> Result<Self> {
                    Self::try_from(value.as_slice())
                }
            }

            impl<const N: usize> TryFrom<[$t2; N]> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: [$t2; N]) -> Result<Self> {
                    Self::try_from(&value[..])
                }
            }

            impl<const N: usize> TryFrom<&[$t2; N]> for AxesIndex<$t1> {
                type Error = Error;

                fn try_from(value: &[$t2; N]) -> Result<Self> {
                    Self::try_from(&value[..])
                }
            }
        )*
    };
}
168
169impl_try_from_axes_index!(usize, isize, u32, u64, i32, i64);
170impl_try_from_axes_index!(isize, usize, u32, u64, i32, i64);
171
/// Implements `TryFrom<(F1, ..., Fk)>` for `AxesIndex<$t>` for every tuple arity `k` from 2 to 10.
///
/// Each tuple element is converted to `$t` independently through `TryFrom`. The results are
/// collected in tuple order into `AxesIndex::Vec`. If any element cannot be converted, the impl
/// returns an `Err` built with `rstsr_error!(TryFromIntError)` rather than panicking.
///
/// The generated code names `AxesIndex`, `Error`, `Result` and `rstsr_error!` unqualified, so they
/// must be in scope wherever this macro is invoked.
#[macro_export]
macro_rules! impl_from_tuple_to_axes_index {
    // Internal rule: emits the impl for one tuple arity. Each `$F $idx` pair is a generic type
    // parameter together with the tuple field index it is read from.
    (@tuple $t: ty; $($F: ident $idx: tt),+) => {
        impl<$($F),+> TryFrom<($($F,)+)> for AxesIndex<$t>
        where
            $($t: TryFrom<$F>,)+
        {
            type Error = Error;

            fn try_from(value: ($($F,)+)) -> Result<Self> {
                Ok(AxesIndex::Vec(vec![
                    $(
                        <$t as TryFrom<$F>>::try_from(value.$idx)
                            .map_err(|_| rstsr_error!(TryFromIntError))?,
                    )+
                ]))
            }
        }
    };
    ($t: ty) => {
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6);
        $crate::impl_from_tuple_to_axes_index!(
            @tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7
        );
        $crate::impl_from_tuple_to_axes_index!(
            @tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7, F9 8
        );
        $crate::impl_from_tuple_to_axes_index!(
            @tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7, F9 8, F10 9
        );
    };
}
367
368impl_from_tuple_to_axes_index!(isize);
369impl_from_tuple_to_axes_index!(usize);
370
371pub fn normalize_axes_index(
380 axes: AxesIndex<isize>,
381 ndim: usize,
382 allow_duplicate: bool,
383 sort: bool,
384) -> Result<Vec<isize>> {
385 let vec = match axes {
387 AxesIndex::None => rstsr_raise!(InvalidValue, "Axes argument cannot be None for this operation.")?,
388 AxesIndex::Val(axis) => {
389 let axis = if axis < 0 { (ndim as isize) + axis } else { axis };
390 rstsr_pattern!(
391 axis,
392 0..(ndim as isize),
393 InvalidValue,
394 "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
395 )?;
396 vec![axis]
397 },
398 AxesIndex::Vec(axes) => {
399 let mut normalized_axes = Vec::with_capacity(axes.len());
400 for &axis in axes.iter() {
401 let norm_axis = if axis < 0 { (ndim as isize) + axis } else { axis };
402 rstsr_pattern!(
403 norm_axis,
404 0..(ndim as isize),
405 InvalidValue,
406 "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
407 )?;
408 normalized_axes.push(norm_axis);
409 }
410 if sort {
411 normalized_axes.sort();
412 }
413 normalized_axes
414 },
415 };
416 if !allow_duplicate {
417 let vec_sorted = if sort { vec.clone() } else { vec.iter().copied().sorted().collect() };
418 if vec_sorted.windows(2).any(|w| w[0] == w[1]) {
420 rstsr_raise!(InvalidValue, "Duplicate axes are not allowed.")?;
421 }
422 }
423 Ok(vec)
424}
425
426