// concision_params/utils/shape.rs
use ndarray::{Axis, LayoutRef, RemoveAxis};

8pub fn extract_bias_dim<A, D>(layout: impl AsRef<LayoutRef<A, D>>) -> D::Smaller
11where
12 D: RemoveAxis,
13{
14 let layout = layout.as_ref();
15 let dim = layout.raw_dim();
16 dim.remove_axis(Axis(0))
17}
18
#[cfg(test)]
mod tests {
    use super::extract_bias_dim;
    use ndarray::{Array, Array3, ArrayD, IxDyn, array};

    #[test]
    fn test_extract_bias_dim() {
        // A 1-D layout loses its only axis, leaving a zero-dimensional bias.
        let layout = Array::linspace(0f32, 1f32, 100);
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, ndarray::Ix0());

        // A (2, 3) layout drops axis 0, leaving a bias of length 3.
        let layout = array![[1., 2., 3.], [4., 5., 6.]];
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, ndarray::Ix1(3));
    }

    #[test]
    fn test_extract_bias_dim_higher_rank() {
        // Only the leading axis is removed; the trailing axes keep their order.
        let layout = Array3::<f32>::zeros((2, 3, 4));
        assert_eq!(extract_bias_dim(&layout), ndarray::Ix2(3, 4));
    }

    #[test]
    fn test_extract_bias_dim_dynamic_rank() {
        // Dynamic-rank layouts go through the `IxDyn` implementation of `RemoveAxis`.
        let layout = ArrayD::<f32>::zeros(IxDyn(&[2, 3, 4]));
        assert_eq!(extract_bias_dim(&layout), IxDyn(&[3, 4]));
    }
}