concision_transformer/ops/
mod.rs

1/*
2   Appellation: ops <mod>
3   Contrib: FL03 <jo3mccain@icloud.com>
4*/
5pub use self::prelude::*;
6
7mod merge;
8mod split;
9
10pub(crate) mod prelude {
11    pub use super::merge::*;
12    pub use super::split::*;
13    pub(crate) use super::utils::*;
14}
15
/// Row-major [`nd::Order`] shared across the crate.
// NOTE(review): presumably the default order handed to the reshaping helpers
// by the `merge` / `split` submodules; confirm against their call sites.
pub(crate) const ORDER: nd::Order = nd::Order::RowMajor;
17
18pub(crate) mod utils {
19    use concision::NdResult;
20    use nd::prelude::*;
21    use nd::{Data, Order, RemoveAxis};
22
23    pub(crate) fn _merge<A, S, D>(
24        arr: &ArrayBase<S, D>,
25        src: usize,
26        tgt: usize,
27        order: Order,
28    ) -> NdResult<Array<A, D::Smaller>>
29    where
30        A: Clone,
31        D: RemoveAxis,
32        S: Data<Elem = A>,
33        D::Smaller: Dimension,
34        ArrayBase<S, D>: Clone,
35    {
36        let shape = _merge_dim(&arr.raw_dim(), src);
37        let mut head = arr.clone();
38        head.swap_axes(src, tgt);
39        head.to_shape((shape, order)).map(|x| x.to_owned())
40    }
41
42    pub(crate) fn _split<A, S, D, E>(
43        arr: &ArrayBase<S, D>,
44        h: usize,
45        order: Order,
46    ) -> NdResult<Array<A, E>>
47    where
48        A: Clone,
49        D: Dimension<Larger = E>,
50        E: RemoveAxis<Smaller = D>,
51        S: Data<Elem = A>,
52        ArrayBase<S, D>: Clone,
53    {
54        let src = if arr.ndim() >= 2 { arr.ndim() - 2 } else { 0 };
55        let tgt = src + 1;
56        let shape: E = _split_dim(&arr.raw_dim(), h);
57        let mut head = arr.to_shape((shape, order))?.to_owned();
58        head.swap_axes(src, tgt);
59        Ok(head)
60    }
61    /// Creates the new dimension after merging two axes.
62    pub(crate) fn _merge_dim<D>(dim: &D, axis: usize) -> D::Smaller
63    where
64        D: RemoveAxis,
65        D::Smaller: Dimension,
66    {
67        // create a new dimension with one less axis; initialized with zeros
68        let mut dn = <D as Dimension>::Smaller::zeros(dim.ndim() - 1);
69        // create a mutable vector from the slice
70        let mut shape = dim.slice().to_vec();
71        // multiply the last axis by the target
72        shape[dn.ndim()] *= shape[axis];
73        // remove the last dimension
74        shape.remove(axis);
75
76        dn.slice_mut().copy_from_slice(&shape);
77        dn
78    }
79
80    pub(crate) fn _split_dim<D>(dim: &D::Smaller, h: usize) -> D
81    where
82        D: RemoveAxis,
83        D::Smaller: Dimension,
84    {
85        let rank = dim.ndim() + 1;
86        // create a new dimension with one less axis; initialized with zeros
87        let mut new_dim = D::zeros(rank);
88        // create a mutable vector from the slice
89        let mut shape = dim.slice().to_vec();
90        // get and remove the last axis
91        let bx = shape.pop().unwrap() / h;
92        // extend the shape with the new axes
93        shape.push(h);
94        shape.push(bx);
95        // shape.swap(rank - 2, rank - 3);
96        // copy the values into the new dimension
97        new_dim.slice_mut().copy_from_slice(&shape);
98        new_dim
99    }
100}