1use crate::prelude_dev::*;
2
/// Axis (or axes) argument for operations that act along tensor dimensions.
///
/// `Val` carries exactly one axis; `Vec` carries a list of axes, which may be
/// empty (the unit value `()` and `None` both convert to an empty `Vec`).
/// The `AsRef<[T]>` impl views either variant as a slice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AxesIndex<T> {
    /// A single axis.
    Val(T),
    /// Zero or more axes.
    Vec(Vec<T>),
}
8
9impl<T> AsRef<[T]> for AxesIndex<T> {
10 fn as_ref(&self) -> &[T] {
11 match self {
12 AxesIndex::Val(v) => core::slice::from_ref(v),
13 AxesIndex::Vec(v) => v.as_slice(),
14 }
15 }
16}
17
18impl<T> From<T> for AxesIndex<T> {
21 fn from(value: T) -> Self {
22 AxesIndex::Val(value)
23 }
24}
25
26impl<T> From<&T> for AxesIndex<T>
27where
28 T: Clone,
29{
30 fn from(value: &T) -> Self {
31 AxesIndex::Val(value.clone())
32 }
33}
34
35impl<T> TryFrom<Vec<T>> for AxesIndex<T> {
36 type Error = Error;
37
38 fn try_from(value: Vec<T>) -> Result<Self> {
39 Ok(AxesIndex::Vec(value))
40 }
41}
42
43impl<T, const N: usize> TryFrom<[T; N]> for AxesIndex<T>
44where
45 T: Clone,
46{
47 type Error = Error;
48
49 fn try_from(value: [T; N]) -> Result<Self> {
50 Ok(AxesIndex::Vec(value.to_vec()))
51 }
52}
53
54impl<T> TryFrom<&Vec<T>> for AxesIndex<T>
55where
56 T: Clone,
57{
58 type Error = Error;
59
60 fn try_from(value: &Vec<T>) -> Result<Self> {
61 Ok(AxesIndex::Vec(value.clone()))
62 }
63}
64
65impl<T> TryFrom<&[T]> for AxesIndex<T>
66where
67 T: Clone,
68{
69 type Error = Error;
70
71 fn try_from(value: &[T]) -> Result<Self> {
72 Ok(AxesIndex::Vec(value.to_vec()))
73 }
74}
75
76impl<T, const N: usize> TryFrom<&[T; N]> for AxesIndex<T>
77where
78 T: Clone,
79{
80 type Error = Error;
81
82 fn try_from(value: &[T; N]) -> Result<Self> {
83 Ok(AxesIndex::Vec(value.to_vec()))
84 }
85}
86
87#[duplicate_item(T; [usize]; [isize])]
88impl TryFrom<()> for AxesIndex<T> {
89 type Error = Error;
90
91 fn try_from(_: ()) -> Result<Self> {
92 Ok(AxesIndex::Vec(vec![]))
93 }
94}
95
96#[duplicate_item(T; [usize]; [isize])]
97impl TryFrom<Option<T>> for AxesIndex<T> {
98 type Error = Error;
99
100 fn try_from(value: Option<T>) -> Result<Self> {
101 match value {
102 Some(v) => Ok(AxesIndex::Val(v)),
103 None => Ok(AxesIndex::Vec(vec![])),
104 }
105 }
106}
107
108macro_rules! impl_try_from_axes_index {
113 ($t1:ty, $($t2:ty),*) => {
114 $(
115 impl TryFrom<$t2> for AxesIndex<$t1> {
116 type Error = Error;
117
118 fn try_from(value: $t2) -> Result<Self> {
119 Ok(AxesIndex::Val(value.try_into()?))
120 }
121 }
122
123 impl TryFrom<&$t2> for AxesIndex<$t1> {
124 type Error = Error;
125
126 fn try_from(value: &$t2) -> Result<Self> {
127 Ok(AxesIndex::Val((*value).try_into()?))
128 }
129 }
130
131 impl TryFrom<Vec<$t2>> for AxesIndex<$t1> {
132 type Error = Error;
133
134 fn try_from(value: Vec<$t2>) -> Result<Self> {
135 let value = value
136 .into_iter()
137 .map(|v| v.try_into().map_err(|_| rstsr_error!(TryFromIntError)))
138 .collect::<Result<Vec<$t1>>>()?;
139 Ok(AxesIndex::Vec(value))
140 }
141 }
142
143 impl<const N: usize> TryFrom<[$t2; N]> for AxesIndex<$t1> {
144 type Error = Error;
145
146 fn try_from(value: [$t2; N]) -> Result<Self> {
147 value.to_vec().try_into()
148 }
149 }
150
151 impl TryFrom<&Vec<$t2>> for AxesIndex<$t1> {
152 type Error = Error;
153
154 fn try_from(value: &Vec<$t2>) -> Result<Self> {
155 value.to_vec().try_into()
156 }
157 }
158
159 impl TryFrom<&[$t2]> for AxesIndex<$t1> {
160 type Error = Error;
161
162 fn try_from(value: &[$t2]) -> Result<Self> {
163 value.to_vec().try_into()
164 }
165 }
166
167 impl<const N: usize> TryFrom<&[$t2; N]> for AxesIndex<$t1> {
168 type Error = Error;
169
170 fn try_from(value: &[$t2; N]) -> Result<Self> {
171 value.to_vec().try_into()
172 }
173 }
174 )*
175 };
176}
177
178impl_try_from_axes_index!(usize, isize, u32, u64, i32, i64);
179impl_try_from_axes_index!(isize, usize, u32, u64, i32, i64);
180
/// Implements `TryFrom<(F1, .., Fk)>` for `AxesIndex<$t>` for every tuple
/// arity `k` from 2 to 10, so mixed-type tuples such as `(0, -1isize)` can be
/// passed as axes.
///
/// Each tuple element is converted with `$t: TryFrom<Fi>`. The first element
/// that fails to convert makes the conversion return a `TryFromIntError`
/// error instead of panicking. The expansion site must have `AxesIndex`,
/// `Error`, `Result` and `rstsr_error!` in scope.
///
/// The `@tuple` arm is an internal helper that emits one impl for the listed
/// `(type-parameter, tuple-index)` pairs.
#[macro_export]
macro_rules! impl_from_tuple_to_axes_index {
    (@tuple $t:ty; $($F:ident $idx:tt),+) => {
        impl<$($F),+> TryFrom<($($F,)+)> for AxesIndex<$t>
        where
            $($t: TryFrom<$F>,)+
        {
            type Error = Error;

            fn try_from(value: ($($F,)+)) -> Result<Self> {
                Ok(AxesIndex::Vec(vec![
                    $(
                        <$t as TryFrom<$F>>::try_from(value.$idx)
                            .map_err(|_| rstsr_error!(TryFromIntError))?
                    ),+
                ]))
            }
        }
    };
    ($t:ty) => {
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7, F9 8);
        $crate::impl_from_tuple_to_axes_index!(@tuple $t; F1 0, F2 1, F3 2, F4 3, F5 4, F6 5, F7 6, F8 7, F9 8, F10 9);
    };
}
376
377impl_from_tuple_to_axes_index!(isize);
378impl_from_tuple_to_axes_index!(usize);
379
380pub fn normalize_axes_index(axes: AxesIndex<isize>, ndim: usize, allow_duplicate: bool) -> Result<Vec<isize>> {
386 let vec = match axes {
388 AxesIndex::Val(axis) => {
389 let axis = if axis < 0 { (ndim as isize) + axis } else { axis };
390 if axis < 0 || axis >= ndim as isize {
391 rstsr_raise!(InvalidValue, "Axis index {axis} is out of bounds for tensor with {ndim} dimensions.")?;
392 }
393 vec![axis]
394 },
395 AxesIndex::Vec(axes) => {
396 let mut normalized_axes = Vec::with_capacity(axes.len());
397 for &axis in axes.iter() {
398 let norm_axis = if axis < 0 { (ndim as isize) + axis } else { axis };
399 if norm_axis < 0 || norm_axis >= ndim as isize {
400 rstsr_raise!(
401 InvalidValue,
402 "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
403 )?;
404 }
405 normalized_axes.push(norm_axis);
406 }
407 normalized_axes.sort();
408 normalized_axes
409 },
410 };
411 if !allow_duplicate {
412 for i in 1..vec.len() {
413 if vec[i] == vec[i - 1] {
414 rstsr_raise!(InvalidValue, "Duplicate axis index {} found in axes argument.", vec[i])?;
415 }
416 }
417 }
418 Ok(vec)
419}
420
421