rstsr_common/
axis_index.rs

1use crate::prelude_dev::*;
2
/// Axes argument for axis-aware operations: either a single axis or a list of axes.
pub enum AxesIndex<T> {
    /// Exactly one axis.
    Val(T),
    /// A list of axes; may be empty.
    Vec(Vec<T>),
}

impl<T> AsRef<[T]> for AxesIndex<T> {
    /// Views the axes as a slice; a single axis is exposed as a one-element slice.
    fn as_ref(&self) -> &[T] {
        match self {
            Self::Vec(axes) => &axes[..],
            Self::Val(axis) => core::slice::from_ref(axis),
        }
    }
}
17
18/* #region AxesIndex self-type from */
19
20impl<T> From<T> for AxesIndex<T> {
21    fn from(value: T) -> Self {
22        AxesIndex::Val(value)
23    }
24}
25
26impl<T> From<&T> for AxesIndex<T>
27where
28    T: Clone,
29{
30    fn from(value: &T) -> Self {
31        AxesIndex::Val(value.clone())
32    }
33}
34
35impl<T> TryFrom<Vec<T>> for AxesIndex<T> {
36    type Error = Error;
37
38    fn try_from(value: Vec<T>) -> Result<Self> {
39        Ok(AxesIndex::Vec(value))
40    }
41}
42
43impl<T, const N: usize> TryFrom<[T; N]> for AxesIndex<T>
44where
45    T: Clone,
46{
47    type Error = Error;
48
49    fn try_from(value: [T; N]) -> Result<Self> {
50        Ok(AxesIndex::Vec(value.to_vec()))
51    }
52}
53
54impl<T> TryFrom<&Vec<T>> for AxesIndex<T>
55where
56    T: Clone,
57{
58    type Error = Error;
59
60    fn try_from(value: &Vec<T>) -> Result<Self> {
61        Ok(AxesIndex::Vec(value.clone()))
62    }
63}
64
65impl<T> TryFrom<&[T]> for AxesIndex<T>
66where
67    T: Clone,
68{
69    type Error = Error;
70
71    fn try_from(value: &[T]) -> Result<Self> {
72        Ok(AxesIndex::Vec(value.to_vec()))
73    }
74}
75
76impl<T, const N: usize> TryFrom<&[T; N]> for AxesIndex<T>
77where
78    T: Clone,
79{
80    type Error = Error;
81
82    fn try_from(value: &[T; N]) -> Result<Self> {
83        Ok(AxesIndex::Vec(value.to_vec()))
84    }
85}
86
87#[duplicate_item(T; [usize]; [isize])]
88impl TryFrom<()> for AxesIndex<T> {
89    type Error = Error;
90
91    fn try_from(_: ()) -> Result<Self> {
92        Ok(AxesIndex::Vec(vec![]))
93    }
94}
95
96#[duplicate_item(T; [usize]; [isize])]
97impl TryFrom<Option<T>> for AxesIndex<T> {
98    type Error = Error;
99
100    fn try_from(value: Option<T>) -> Result<Self> {
101        match value {
102            Some(v) => Ok(AxesIndex::Val(v)),
103            None => Ok(AxesIndex::Vec(vec![])),
104        }
105    }
106}
107
108/* #endregion AxesIndex self-type from */
109
110/* #region AxesIndex other-type from */
111
112macro_rules! impl_try_from_axes_index {
113    ($t1:ty, $($t2:ty),*) => {
114        $(
115            impl TryFrom<$t2> for AxesIndex<$t1> {
116                type Error = Error;
117
118                fn try_from(value: $t2) -> Result<Self> {
119                    Ok(AxesIndex::Val(value.try_into()?))
120                }
121            }
122
123            impl TryFrom<&$t2> for AxesIndex<$t1> {
124                type Error = Error;
125
126                fn try_from(value: &$t2) -> Result<Self> {
127                    Ok(AxesIndex::Val((*value).try_into()?))
128                }
129            }
130
131            impl TryFrom<Vec<$t2>> for AxesIndex<$t1> {
132                type Error = Error;
133
134                fn try_from(value: Vec<$t2>) -> Result<Self> {
135                    let value = value
136                        .into_iter()
137                        .map(|v| v.try_into().map_err(|_| rstsr_error!(TryFromIntError)))
138                        .collect::<Result<Vec<$t1>>>()?;
139                    Ok(AxesIndex::Vec(value))
140                }
141            }
142
143            impl<const N: usize> TryFrom<[$t2; N]> for AxesIndex<$t1> {
144                type Error = Error;
145
146                fn try_from(value: [$t2; N]) -> Result<Self> {
147                    value.to_vec().try_into()
148                }
149            }
150
151            impl TryFrom<&Vec<$t2>> for AxesIndex<$t1> {
152                type Error = Error;
153
154                fn try_from(value: &Vec<$t2>) -> Result<Self> {
155                    value.to_vec().try_into()
156                }
157            }
158
159            impl TryFrom<&[$t2]> for AxesIndex<$t1> {
160                type Error = Error;
161
162                fn try_from(value: &[$t2]) -> Result<Self> {
163                    value.to_vec().try_into()
164                }
165            }
166
167            impl<const N: usize> TryFrom<&[$t2; N]> for AxesIndex<$t1> {
168                type Error = Error;
169
170                fn try_from(value: &[$t2; N]) -> Result<Self> {
171                    value.to_vec().try_into()
172                }
173            }
174        )*
175    };
176}
177
178impl_try_from_axes_index!(usize, isize, u32, u64, i32, i64);
179impl_try_from_axes_index!(isize, usize, u32, u64, i32, i64);
180
181/* #endregion AxesIndex other-type from */
182
183/* #region AxesIndex tuple-type from */
184
// Implementing these tuple conversions generically for an arbitrary `AxesIndex<T>` appears to
// cause conflicting implementations, so a macro instantiates them for each concrete type.
187
188#[macro_export]
189macro_rules! impl_from_tuple_to_axes_index {
190    ($t: ty) => {
191        impl<F1, F2> TryFrom<(F1, F2)> for AxesIndex<$t>
192        where
193            $t: TryFrom<F1> + TryFrom<F2>,
194        {
195            type Error = Error;
196
197            fn try_from(value: (F1, F2)) -> Result<Self> {
198                Ok(AxesIndex::Vec(vec![value.0.try_into().ok().unwrap(), value.1.try_into().ok().unwrap()]))
199            }
200        }
201
202        impl<F1, F2, F3> TryFrom<(F1, F2, F3)> for AxesIndex<$t>
203        where
204            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3>,
205        {
206            type Error = Error;
207
208            fn try_from(value: (F1, F2, F3)) -> Result<Self> {
209                Ok(AxesIndex::Vec(vec![
210                    value.0.try_into().ok().unwrap(),
211                    value.1.try_into().ok().unwrap(),
212                    value.2.try_into().ok().unwrap(),
213                ]))
214            }
215        }
216
217        impl<F1, F2, F3, F4> TryFrom<(F1, F2, F3, F4)> for AxesIndex<$t>
218        where
219            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4>,
220        {
221            type Error = Error;
222
223            fn try_from(value: (F1, F2, F3, F4)) -> Result<Self> {
224                Ok(AxesIndex::Vec(vec![
225                    value.0.try_into().ok().unwrap(),
226                    value.1.try_into().ok().unwrap(),
227                    value.2.try_into().ok().unwrap(),
228                    value.3.try_into().ok().unwrap(),
229                ]))
230            }
231        }
232
233        impl<F1, F2, F3, F4, F5> TryFrom<(F1, F2, F3, F4, F5)> for AxesIndex<$t>
234        where
235            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5>,
236        {
237            type Error = Error;
238
239            fn try_from(value: (F1, F2, F3, F4, F5)) -> Result<Self> {
240                Ok(AxesIndex::Vec(vec![
241                    value.0.try_into().ok().unwrap(),
242                    value.1.try_into().ok().unwrap(),
243                    value.2.try_into().ok().unwrap(),
244                    value.3.try_into().ok().unwrap(),
245                    value.4.try_into().ok().unwrap(),
246                ]))
247            }
248        }
249
250        impl<F1, F2, F3, F4, F5, F6> TryFrom<(F1, F2, F3, F4, F5, F6)> for AxesIndex<$t>
251        where
252            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5> + TryFrom<F6>,
253        {
254            type Error = Error;
255
256            fn try_from(value: (F1, F2, F3, F4, F5, F6)) -> Result<Self> {
257                Ok(AxesIndex::Vec(vec![
258                    value.0.try_into().ok().unwrap(),
259                    value.1.try_into().ok().unwrap(),
260                    value.2.try_into().ok().unwrap(),
261                    value.3.try_into().ok().unwrap(),
262                    value.4.try_into().ok().unwrap(),
263                    value.5.try_into().ok().unwrap(),
264                ]))
265            }
266        }
267
268        impl<F1, F2, F3, F4, F5, F6, F7> TryFrom<(F1, F2, F3, F4, F5, F6, F7)> for AxesIndex<$t>
269        where
270            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5> + TryFrom<F6> + TryFrom<F7>,
271        {
272            type Error = Error;
273
274            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7)) -> Result<Self> {
275                Ok(AxesIndex::Vec(vec![
276                    value.0.try_into().ok().unwrap(),
277                    value.1.try_into().ok().unwrap(),
278                    value.2.try_into().ok().unwrap(),
279                    value.3.try_into().ok().unwrap(),
280                    value.4.try_into().ok().unwrap(),
281                    value.5.try_into().ok().unwrap(),
282                    value.6.try_into().ok().unwrap(),
283                ]))
284            }
285        }
286
287        impl<F1, F2, F3, F4, F5, F6, F7, F8> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8)> for AxesIndex<$t>
288        where
289            $t: TryFrom<F1>
290                + TryFrom<F2>
291                + TryFrom<F3>
292                + TryFrom<F4>
293                + TryFrom<F5>
294                + TryFrom<F6>
295                + TryFrom<F7>
296                + TryFrom<F8>,
297        {
298            type Error = Error;
299
300            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8)) -> Result<Self> {
301                Ok(AxesIndex::Vec(vec![
302                    value.0.try_into().ok().unwrap(),
303                    value.1.try_into().ok().unwrap(),
304                    value.2.try_into().ok().unwrap(),
305                    value.3.try_into().ok().unwrap(),
306                    value.4.try_into().ok().unwrap(),
307                    value.5.try_into().ok().unwrap(),
308                    value.6.try_into().ok().unwrap(),
309                    value.7.try_into().ok().unwrap(),
310                ]))
311            }
312        }
313
314        impl<F1, F2, F3, F4, F5, F6, F7, F8, F9> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8, F9)> for AxesIndex<$t>
315        where
316            $t: TryFrom<F1>
317                + TryFrom<F2>
318                + TryFrom<F3>
319                + TryFrom<F4>
320                + TryFrom<F5>
321                + TryFrom<F6>
322                + TryFrom<F7>
323                + TryFrom<F8>
324                + TryFrom<F9>,
325        {
326            type Error = Error;
327
328            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8, F9)) -> Result<Self> {
329                Ok(AxesIndex::Vec(vec![
330                    value.0.try_into().ok().unwrap(),
331                    value.1.try_into().ok().unwrap(),
332                    value.2.try_into().ok().unwrap(),
333                    value.3.try_into().ok().unwrap(),
334                    value.4.try_into().ok().unwrap(),
335                    value.5.try_into().ok().unwrap(),
336                    value.6.try_into().ok().unwrap(),
337                    value.7.try_into().ok().unwrap(),
338                    value.8.try_into().ok().unwrap(),
339                ]))
340            }
341        }
342
343        impl<F1, F2, F3, F4, F5, F6, F7, F8, F9, F10> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8, F9, F10)>
344            for AxesIndex<$t>
345        where
346            $t: TryFrom<F1>
347                + TryFrom<F2>
348                + TryFrom<F3>
349                + TryFrom<F4>
350                + TryFrom<F5>
351                + TryFrom<F6>
352                + TryFrom<F7>
353                + TryFrom<F8>
354                + TryFrom<F9>
355                + TryFrom<F10>,
356        {
357            type Error = Error;
358
359            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8, F9, F10)) -> Result<Self> {
360                Ok(AxesIndex::Vec(vec![
361                    value.0.try_into().ok().unwrap(),
362                    value.1.try_into().ok().unwrap(),
363                    value.2.try_into().ok().unwrap(),
364                    value.3.try_into().ok().unwrap(),
365                    value.4.try_into().ok().unwrap(),
366                    value.5.try_into().ok().unwrap(),
367                    value.6.try_into().ok().unwrap(),
368                    value.7.try_into().ok().unwrap(),
369                    value.8.try_into().ok().unwrap(),
370                    value.9.try_into().ok().unwrap(),
371                ]))
372            }
373        }
374    };
375}
376
377impl_from_tuple_to_axes_index!(isize);
378impl_from_tuple_to_axes_index!(usize);
379
380/* #endregion AxesIndex tuple-type from */
381
382/* #region utilities for AxesIndex */
383
384/// Normalize axes argument into a tuple of non-negative integer axes.
385pub fn normalize_axes_index(axes: AxesIndex<isize>, ndim: usize, allow_duplicate: bool) -> Result<Vec<isize>> {
386    // generate the normalized axes vector
387    let vec = match axes {
388        AxesIndex::Val(axis) => {
389            let axis = if axis < 0 { (ndim as isize) + axis } else { axis };
390            if axis < 0 || axis >= ndim as isize {
391                rstsr_raise!(InvalidValue, "Axis index {axis} is out of bounds for tensor with {ndim} dimensions.")?;
392            }
393            vec![axis]
394        },
395        AxesIndex::Vec(axes) => {
396            let mut normalized_axes = Vec::with_capacity(axes.len());
397            for &axis in axes.iter() {
398                let norm_axis = if axis < 0 { (ndim as isize) + axis } else { axis };
399                if norm_axis < 0 || norm_axis >= ndim as isize {
400                    rstsr_raise!(
401                        InvalidValue,
402                        "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
403                    )?;
404                }
405                normalized_axes.push(norm_axis);
406            }
407            normalized_axes.sort();
408            normalized_axes
409        },
410    };
411    if !allow_duplicate {
412        for i in 1..vec.len() {
413            if vec[i] == vec[i - 1] {
414                rstsr_raise!(InvalidValue, "Duplicate axis index {} found in axes argument.", vec[i])?;
415            }
416        }
417    }
418    Ok(vec)
419}
420
421/* #endregion */