concision_transformer/ops/split.rs
1/*
2 Appellation: split <mod>
3 Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use ndarray::{Array, ArrayBase, Data, Dimension, RemoveAxis, ShapeError};
6
/// Splits one axis of a dimension into two axes.
///
/// The resulting dimension type is chosen by the implementor through
/// [`DimSplit::Output`].
pub trait DimSplit {
    /// Dimension type produced by the split.
    type Output;

    /// Divide an axis of `self`, using `h` to control the split.
    ///
    /// NOTE(review): `h` is forwarded verbatim to the implementation; it
    /// presumably denotes the number of heads/parts — confirm against
    /// `utils::_split_dim`.
    fn split(&self, h: usize) -> Self::Output;
}
13
14pub trait SplitHead {
15 type Output;
16
17 fn split(&self, heads: usize) -> Result<Self::Output, ShapeError>;
18}
19
20/*
21 ************* Implementations *************
22*/
23
24impl<D, E> DimSplit for D
25where
26 D: Dimension<Larger = E>,
27 E: RemoveAxis<Smaller = D>,
28{
29 type Output = E;
30
31 fn split(&self, h: usize) -> Self::Output {
32 super::utils::_split_dim(self, h)
33 }
34}
35
36impl<A, S, D, E> SplitHead for ArrayBase<S, D>
37where
38 A: Clone,
39 D: Dimension<Larger = E>,
40 E: RemoveAxis<Smaller = D>,
41 S: Data<Elem = A>,
42 ArrayBase<S, D>: Clone,
43{
44 type Output = Array<A, E>;
45
46 fn split(&self, h: usize) -> Result<Self::Output, ShapeError> {
47 super::_split(self, h, super::ORDER)
48 }
49}
50
51// impl<T: Clone> Split for Array2<T> {
52// type Output = Array3<T>;
53
54// fn split(&self, heads: usize) -> Result<Self::Output, ShapeError> {
55// let (seq, model) = self.dim();
56// let query = model / heads;
57// // reshape the qkv matrix into a 3d array
58// let mut res = self.clone().into_shape((seq, heads, query))?;
59// // swap the sequence and head axes
60// res.swap_axes(0, 1);
61// Ok(res)
62// }
63// }
64
65// impl<T: Clone> Split for Array3<T> {
66// type Output = Array4<T>;
67
68// fn split(&self, heads: usize) -> Result<Self::Output, ShapeError> {
69// let (batch, seq, model) = self.dim();
70// let query = model / heads;
// reshape the qkv matrix into a 4d array
72// let mut res = self.clone().into_shape((batch, seq, heads, query))?;
73// // swap the sequence and head axes
74// res.swap_axes(1, 2);
75// Ok(res)
76// }
77// }