Skip to main content

rstsr_common/
axis_index.rs

1use crate::prelude_dev::*;
2
/// Axis argument accepted by axis-aware operations: no axis, one axis, or a list of axes.
///
/// `T` is the integer type of an axis (`usize` and `isize` are used in this module).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AxesIndex<T> {
    /// No axis; produced from `Option::None`. It cannot be viewed as a slice.
    None,
    /// Exactly one axis.
    Val(T),
    /// Zero or more axes, in the order given.
    Vec(Vec<T>),
}
10
11impl<T> AsRef<[T]> for AxesIndex<T> {
12    fn as_ref(&self) -> &[T] {
13        match self {
14            AxesIndex::Val(v) => core::slice::from_ref(v),
15            AxesIndex::Vec(v) => v.as_slice(),
16            AxesIndex::None => panic!("AxesIndex::None cannot be converted to a slice. This is developer's error; if encountered, please report it to github issue."),
17        }
18    }
19}
20
21/* #region AxesIndex self-type from */
22
23impl<T> From<T> for AxesIndex<T> {
24    fn from(value: T) -> Self {
25        AxesIndex::Val(value)
26    }
27}
28
29impl<T> From<&T> for AxesIndex<T>
30where
31    T: Clone,
32{
33    fn from(value: &T) -> Self {
34        AxesIndex::Val(value.clone())
35    }
36}
37
38impl<T> From<Vec<T>> for AxesIndex<T> {
39    fn from(value: Vec<T>) -> Self {
40        AxesIndex::Vec(value)
41    }
42}
43
44impl<T, const N: usize> From<[T; N]> for AxesIndex<T>
45where
46    T: Clone,
47{
48    fn from(value: [T; N]) -> Self {
49        AxesIndex::Vec(value.to_vec())
50    }
51}
52
53impl<T> From<&Vec<T>> for AxesIndex<T>
54where
55    T: Clone,
56{
57    fn from(value: &Vec<T>) -> Self {
58        AxesIndex::Vec(value.clone())
59    }
60}
61
62impl<T> From<&[T]> for AxesIndex<T>
63where
64    T: Clone,
65{
66    fn from(value: &[T]) -> Self {
67        AxesIndex::Vec(value.to_vec())
68    }
69}
70
71impl<T, const N: usize> From<&[T; N]> for AxesIndex<T>
72where
73    T: Clone,
74{
75    fn from(value: &[T; N]) -> Self {
76        AxesIndex::Vec(value.to_vec())
77    }
78}
79
80#[duplicate_item(T; [usize]; [isize])]
81impl From<()> for AxesIndex<T> {
82    fn from(_: ()) -> Self {
83        AxesIndex::Vec(vec![])
84    }
85}
86
87#[duplicate_item(T; [usize]; [isize])]
88impl TryFrom<Option<T>> for AxesIndex<T> {
89    type Error = Error;
90
91    fn try_from(value: Option<T>) -> Result<Self> {
92        match value {
93            Some(v) => Ok(AxesIndex::Val(v)),
94            None => Ok(AxesIndex::None),
95        }
96    }
97}
98
99/* #endregion AxesIndex self-type from */
100
101/* #region AxesIndex other-type from */
102
103macro_rules! impl_try_from_axes_index {
104    ($t1:ty, $($t2:ty),*) => {
105        $(
106            impl TryFrom<$t2> for AxesIndex<$t1> {
107                type Error = Error;
108
109                fn try_from(value: $t2) -> Result<Self> {
110                    Ok(AxesIndex::Val(value.try_into()?))
111                }
112            }
113
114            impl TryFrom<&$t2> for AxesIndex<$t1> {
115                type Error = Error;
116
117                fn try_from(value: &$t2) -> Result<Self> {
118                    Ok(AxesIndex::Val((*value).try_into()?))
119                }
120            }
121
122            impl TryFrom<Vec<$t2>> for AxesIndex<$t1> {
123                type Error = Error;
124
125                fn try_from(value: Vec<$t2>) -> Result<Self> {
126                    let value = value
127                        .into_iter()
128                        .map(|v| v.try_into().map_err(|_| rstsr_error!(TryFromIntError)))
129                        .collect::<Result<Vec<$t1>>>()?;
130                    Ok(AxesIndex::Vec(value))
131                }
132            }
133
134            impl<const N: usize> TryFrom<[$t2; N]> for AxesIndex<$t1> {
135                type Error = Error;
136
137                fn try_from(value: [$t2; N]) -> Result<Self> {
138                    value.to_vec().try_into()
139                }
140            }
141
142            impl TryFrom<&Vec<$t2>> for AxesIndex<$t1> {
143                type Error = Error;
144
145                fn try_from(value: &Vec<$t2>) -> Result<Self> {
146                    value.to_vec().try_into()
147                }
148            }
149
150            impl TryFrom<&[$t2]> for AxesIndex<$t1> {
151                type Error = Error;
152
153                fn try_from(value: &[$t2]) -> Result<Self> {
154                    value.to_vec().try_into()
155                }
156            }
157
158            impl<const N: usize> TryFrom<&[$t2; N]> for AxesIndex<$t1> {
159                type Error = Error;
160
161                fn try_from(value: &[$t2; N]) -> Result<Self> {
162                    value.to_vec().try_into()
163                }
164            }
165        )*
166    };
167}
168
169impl_try_from_axes_index!(usize, isize, u32, u64, i32, i64);
170impl_try_from_axes_index!(isize, usize, u32, u64, i32, i64);
171
172/* #endregion AxesIndex other-type from */
173
174/* #region AxesIndex tuple-type from */
175
// Implementing these tuple conversions generically for an arbitrary `AxesIndex<T>` would
// conflict with the blanket conversions (e.g. when `T` is itself the tuple type), so a macro
// generates them for each concrete axis type instead.
178
179#[macro_export]
180macro_rules! impl_from_tuple_to_axes_index {
181    ($t: ty) => {
182        impl<F1, F2> TryFrom<(F1, F2)> for AxesIndex<$t>
183        where
184            $t: TryFrom<F1> + TryFrom<F2>,
185        {
186            type Error = Error;
187
188            fn try_from(value: (F1, F2)) -> Result<Self> {
189                Ok(AxesIndex::Vec(vec![value.0.try_into().ok().unwrap(), value.1.try_into().ok().unwrap()]))
190            }
191        }
192
193        impl<F1, F2, F3> TryFrom<(F1, F2, F3)> for AxesIndex<$t>
194        where
195            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3>,
196        {
197            type Error = Error;
198
199            fn try_from(value: (F1, F2, F3)) -> Result<Self> {
200                Ok(AxesIndex::Vec(vec![
201                    value.0.try_into().ok().unwrap(),
202                    value.1.try_into().ok().unwrap(),
203                    value.2.try_into().ok().unwrap(),
204                ]))
205            }
206        }
207
208        impl<F1, F2, F3, F4> TryFrom<(F1, F2, F3, F4)> for AxesIndex<$t>
209        where
210            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4>,
211        {
212            type Error = Error;
213
214            fn try_from(value: (F1, F2, F3, F4)) -> Result<Self> {
215                Ok(AxesIndex::Vec(vec![
216                    value.0.try_into().ok().unwrap(),
217                    value.1.try_into().ok().unwrap(),
218                    value.2.try_into().ok().unwrap(),
219                    value.3.try_into().ok().unwrap(),
220                ]))
221            }
222        }
223
224        impl<F1, F2, F3, F4, F5> TryFrom<(F1, F2, F3, F4, F5)> for AxesIndex<$t>
225        where
226            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5>,
227        {
228            type Error = Error;
229
230            fn try_from(value: (F1, F2, F3, F4, F5)) -> Result<Self> {
231                Ok(AxesIndex::Vec(vec![
232                    value.0.try_into().ok().unwrap(),
233                    value.1.try_into().ok().unwrap(),
234                    value.2.try_into().ok().unwrap(),
235                    value.3.try_into().ok().unwrap(),
236                    value.4.try_into().ok().unwrap(),
237                ]))
238            }
239        }
240
241        impl<F1, F2, F3, F4, F5, F6> TryFrom<(F1, F2, F3, F4, F5, F6)> for AxesIndex<$t>
242        where
243            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5> + TryFrom<F6>,
244        {
245            type Error = Error;
246
247            fn try_from(value: (F1, F2, F3, F4, F5, F6)) -> Result<Self> {
248                Ok(AxesIndex::Vec(vec![
249                    value.0.try_into().ok().unwrap(),
250                    value.1.try_into().ok().unwrap(),
251                    value.2.try_into().ok().unwrap(),
252                    value.3.try_into().ok().unwrap(),
253                    value.4.try_into().ok().unwrap(),
254                    value.5.try_into().ok().unwrap(),
255                ]))
256            }
257        }
258
259        impl<F1, F2, F3, F4, F5, F6, F7> TryFrom<(F1, F2, F3, F4, F5, F6, F7)> for AxesIndex<$t>
260        where
261            $t: TryFrom<F1> + TryFrom<F2> + TryFrom<F3> + TryFrom<F4> + TryFrom<F5> + TryFrom<F6> + TryFrom<F7>,
262        {
263            type Error = Error;
264
265            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7)) -> Result<Self> {
266                Ok(AxesIndex::Vec(vec![
267                    value.0.try_into().ok().unwrap(),
268                    value.1.try_into().ok().unwrap(),
269                    value.2.try_into().ok().unwrap(),
270                    value.3.try_into().ok().unwrap(),
271                    value.4.try_into().ok().unwrap(),
272                    value.5.try_into().ok().unwrap(),
273                    value.6.try_into().ok().unwrap(),
274                ]))
275            }
276        }
277
278        impl<F1, F2, F3, F4, F5, F6, F7, F8> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8)> for AxesIndex<$t>
279        where
280            $t: TryFrom<F1>
281                + TryFrom<F2>
282                + TryFrom<F3>
283                + TryFrom<F4>
284                + TryFrom<F5>
285                + TryFrom<F6>
286                + TryFrom<F7>
287                + TryFrom<F8>,
288        {
289            type Error = Error;
290
291            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8)) -> Result<Self> {
292                Ok(AxesIndex::Vec(vec![
293                    value.0.try_into().ok().unwrap(),
294                    value.1.try_into().ok().unwrap(),
295                    value.2.try_into().ok().unwrap(),
296                    value.3.try_into().ok().unwrap(),
297                    value.4.try_into().ok().unwrap(),
298                    value.5.try_into().ok().unwrap(),
299                    value.6.try_into().ok().unwrap(),
300                    value.7.try_into().ok().unwrap(),
301                ]))
302            }
303        }
304
305        impl<F1, F2, F3, F4, F5, F6, F7, F8, F9> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8, F9)> for AxesIndex<$t>
306        where
307            $t: TryFrom<F1>
308                + TryFrom<F2>
309                + TryFrom<F3>
310                + TryFrom<F4>
311                + TryFrom<F5>
312                + TryFrom<F6>
313                + TryFrom<F7>
314                + TryFrom<F8>
315                + TryFrom<F9>,
316        {
317            type Error = Error;
318
319            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8, F9)) -> Result<Self> {
320                Ok(AxesIndex::Vec(vec![
321                    value.0.try_into().ok().unwrap(),
322                    value.1.try_into().ok().unwrap(),
323                    value.2.try_into().ok().unwrap(),
324                    value.3.try_into().ok().unwrap(),
325                    value.4.try_into().ok().unwrap(),
326                    value.5.try_into().ok().unwrap(),
327                    value.6.try_into().ok().unwrap(),
328                    value.7.try_into().ok().unwrap(),
329                    value.8.try_into().ok().unwrap(),
330                ]))
331            }
332        }
333
334        impl<F1, F2, F3, F4, F5, F6, F7, F8, F9, F10> TryFrom<(F1, F2, F3, F4, F5, F6, F7, F8, F9, F10)>
335            for AxesIndex<$t>
336        where
337            $t: TryFrom<F1>
338                + TryFrom<F2>
339                + TryFrom<F3>
340                + TryFrom<F4>
341                + TryFrom<F5>
342                + TryFrom<F6>
343                + TryFrom<F7>
344                + TryFrom<F8>
345                + TryFrom<F9>
346                + TryFrom<F10>,
347        {
348            type Error = Error;
349
350            fn try_from(value: (F1, F2, F3, F4, F5, F6, F7, F8, F9, F10)) -> Result<Self> {
351                Ok(AxesIndex::Vec(vec![
352                    value.0.try_into().ok().unwrap(),
353                    value.1.try_into().ok().unwrap(),
354                    value.2.try_into().ok().unwrap(),
355                    value.3.try_into().ok().unwrap(),
356                    value.4.try_into().ok().unwrap(),
357                    value.5.try_into().ok().unwrap(),
358                    value.6.try_into().ok().unwrap(),
359                    value.7.try_into().ok().unwrap(),
360                    value.8.try_into().ok().unwrap(),
361                    value.9.try_into().ok().unwrap(),
362                ]))
363            }
364        }
365    };
366}
367
368impl_from_tuple_to_axes_index!(isize);
369impl_from_tuple_to_axes_index!(usize);
370
371/* #endregion AxesIndex tuple-type from */
372
373/* #region utilities for AxesIndex */
374
375/// Normalize axes argument into a tuple of non-negative integer axes.
376///
377/// Though the returned vector will be of type `isize` for convenience, the values will be actually
378/// non-negative (`usize`-compatible).
379pub fn normalize_axes_index(
380    axes: AxesIndex<isize>,
381    ndim: usize,
382    allow_duplicate: bool,
383    sort: bool,
384) -> Result<Vec<isize>> {
385    // generate the normalized axes vector
386    let vec = match axes {
387        AxesIndex::None => rstsr_raise!(InvalidValue, "Axes argument cannot be None for this operation.")?,
388        AxesIndex::Val(axis) => {
389            let axis = if axis < 0 { (ndim as isize) + axis } else { axis };
390            rstsr_pattern!(
391                axis,
392                0..(ndim as isize),
393                InvalidValue,
394                "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
395            )?;
396            vec![axis]
397        },
398        AxesIndex::Vec(axes) => {
399            let mut normalized_axes = Vec::with_capacity(axes.len());
400            for &axis in axes.iter() {
401                let norm_axis = if axis < 0 { (ndim as isize) + axis } else { axis };
402                rstsr_pattern!(
403                    norm_axis,
404                    0..(ndim as isize),
405                    InvalidValue,
406                    "Axis index {axis} is out of bounds for tensor with {ndim} dimensions."
407                )?;
408                normalized_axes.push(norm_axis);
409            }
410            if sort {
411                normalized_axes.sort();
412            }
413            normalized_axes
414        },
415    };
416    if !allow_duplicate {
417        let vec_sorted = if sort { vec.clone() } else { vec.iter().copied().sorted().collect() };
418        // check for duplicates in sorted vector
419        if vec_sorted.windows(2).any(|w| w[0] == w[1]) {
420            rstsr_raise!(InvalidValue, "Duplicate axes are not allowed.")?;
421        }
422    }
423    Ok(vec)
424}
425
426/* #endregion */