concision_transformer/ops/
mod.rs1pub use self::prelude::*;
6
7mod merge;
8mod split;
9
10pub(crate) mod prelude {
11 pub use super::merge::*;
12 pub use super::split::*;
13 pub(crate) use super::utils::*;
14}
15
16pub(crate) const ORDER: nd::Order = nd::Order::RowMajor;
17
18pub(crate) mod utils {
19 use concision::NdResult;
20 use nd::prelude::*;
21 use nd::{Data, Order, RemoveAxis};
22
23 pub(crate) fn _merge<A, S, D>(
24 arr: &ArrayBase<S, D>,
25 src: usize,
26 tgt: usize,
27 order: Order,
28 ) -> NdResult<Array<A, D::Smaller>>
29 where
30 A: Clone,
31 D: RemoveAxis,
32 S: Data<Elem = A>,
33 D::Smaller: Dimension,
34 ArrayBase<S, D>: Clone,
35 {
36 let shape = _merge_dim(&arr.raw_dim(), src);
37 let mut head = arr.clone();
38 head.swap_axes(src, tgt);
39 head.to_shape((shape, order)).map(|x| x.to_owned())
40 }
41
42 pub(crate) fn _split<A, S, D, E>(
43 arr: &ArrayBase<S, D>,
44 h: usize,
45 order: Order,
46 ) -> NdResult<Array<A, E>>
47 where
48 A: Clone,
49 D: Dimension<Larger = E>,
50 E: RemoveAxis<Smaller = D>,
51 S: Data<Elem = A>,
52 ArrayBase<S, D>: Clone,
53 {
54 let src = if arr.ndim() >= 2 { arr.ndim() - 2 } else { 0 };
55 let tgt = src + 1;
56 let shape: E = _split_dim(&arr.raw_dim(), h);
57 let mut head = arr.to_shape((shape, order))?.to_owned();
58 head.swap_axes(src, tgt);
59 Ok(head)
60 }
61 pub(crate) fn _merge_dim<D>(dim: &D, axis: usize) -> D::Smaller
63 where
64 D: RemoveAxis,
65 D::Smaller: Dimension,
66 {
67 let mut dn = <D as Dimension>::Smaller::zeros(dim.ndim() - 1);
69 let mut shape = dim.slice().to_vec();
71 shape[dn.ndim()] *= shape[axis];
73 shape.remove(axis);
75
76 dn.slice_mut().copy_from_slice(&shape);
77 dn
78 }
79
80 pub(crate) fn _split_dim<D>(dim: &D::Smaller, h: usize) -> D
81 where
82 D: RemoveAxis,
83 D::Smaller: Dimension,
84 {
85 let rank = dim.ndim() + 1;
86 let mut new_dim = D::zeros(rank);
88 let mut shape = dim.slice().to_vec();
90 let bx = shape.pop().unwrap() / h;
92 shape.push(h);
94 shape.push(bx);
95 new_dim.slice_mut().copy_from_slice(&shape);
98 new_dim
99 }
100}