alximo_core/
utils.rs

1use ndarray::{Array0, Array1, Array2, Axis};
2
3pub fn squeeze<A>(x: Array2<A>) -> Option<Array1<A>>
4where
5    A: Clone,
6{
7    let shape = x.shape();
8    if shape.contains(&1) {
9        if shape[0] == 1 {
10            Some(x.remove_axis(Axis(0)))
11        } else {
12            Some(x.remove_axis(Axis(1)))
13        }
14    } else {
15        None
16    }
17}
18
19pub fn squeeze_both<A>(x: Array2<A>) -> Option<Array0<A>>
20where
21    A: Clone,
22{
23    let shape = x.shape();
24    if shape[0] == 1 && shape[1] == 1 {
25        Some(x.remove_axis(Axis(0)).remove_axis(Axis(0)))
26    } else {
27        None
28    }
29}