concision_params/utils/
shape.rs

/*
    Appellation: shape <module>
    Created At: 2025.12.14:07:37:48
    Contrib: @FL03
*/
use ndarray::{Axis, LayoutRef, RemoveAxis};
7
8/// Extract a suitable dimension for a bias tensor from the given reference to the layout of
9/// the weight tensor.
10pub fn extract_bias_dim<A, D>(layout: impl AsRef<LayoutRef<A, D>>) -> D::Smaller
11where
12    D: RemoveAxis,
13{
14    let layout = layout.as_ref();
15    let dim = layout.raw_dim();
16    dim.remove_axis(Axis(0))
17}
18
#[cfg(test)]
mod tests {
    use super::extract_bias_dim;
    use ndarray::{Array, Ix0, Ix1, Ix2, IxDyn, array};

    /// A one-dimensional weight vector collapses to a zero-dimensional bias.
    #[test]
    fn test_extract_bias_dim_1d() {
        let layout = Array::linspace(0f32, 1f32, 100);
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, Ix0());
    }

    /// A `(2, 3)` weight matrix yields a bias of length 3.
    #[test]
    fn test_extract_bias_dim_2d() {
        let layout = array![[1., 2., 3.], [4., 5., 6.]];
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, Ix1(3));
    }

    /// Higher-rank weights keep every trailing axis, in order.
    #[test]
    fn test_extract_bias_dim_3d() {
        let layout = Array::<f32, _>::zeros((4, 2, 3));
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, Ix2(2, 3));
    }

    /// Dynamically-ranked weights also drop only the leading axis.
    #[test]
    fn test_extract_bias_dim_dyn() {
        let layout = Array::<f32, _>::zeros(IxDyn(&[4, 2, 3]));
        let bias_dim = extract_bias_dim(&layout);
        assert_eq!(bias_dim, IxDyn(&[2, 3]));
    }
}