concision_transformer/ops/
split.rs

1/*
2   Appellation: split <mod>
3   Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use ndarray::{Array, ArrayBase, Data, Dimension, RemoveAxis, ShapeError};
6
/// Splitting of a single dimension into two parts.
///
/// The shape that results from the split is chosen by the implementor through
/// the associated [`DimSplit::Output`] type.
pub trait DimSplit {
    /// The type returned by [`split`](DimSplit::split).
    type Output;

    /// Splits `self` using `h` and returns the resulting value.
    ///
    /// NOTE(review): `h` is presumably the number of heads to split into;
    /// confirm against the implementations.
    fn split(&self, h: usize) -> Self::Output;
}
13
14pub trait SplitHead {
15    type Output;
16
17    fn split(&self, heads: usize) -> Result<Self::Output, ShapeError>;
18}
19
20/*
21 ************* Implementations *************
22*/
23
24impl<D, E> DimSplit for D
25where
26    D: Dimension<Larger = E>,
27    E: RemoveAxis<Smaller = D>,
28{
29    type Output = E;
30
31    fn split(&self, h: usize) -> Self::Output {
32        super::utils::_split_dim(self, h)
33    }
34}
35
36impl<A, S, D, E> SplitHead for ArrayBase<S, D>
37where
38    A: Clone,
39    D: Dimension<Larger = E>,
40    E: RemoveAxis<Smaller = D>,
41    S: Data<Elem = A>,
42    ArrayBase<S, D>: Clone,
43{
44    type Output = Array<A, E>;
45
46    fn split(&self, h: usize) -> Result<Self::Output, ShapeError> {
47        super::_split(self, h, super::ORDER)
48    }
49}
50
51// impl<T: Clone> Split for Array2<T> {
52//     type Output = Array3<T>;
53
54//     fn split(&self, heads: usize) -> Result<Self::Output, ShapeError> {
55//         let (seq, model) = self.dim();
56//         let query = model / heads;
57//         // reshape the qkv matrix into a 3d array
58//         let mut res = self.clone().into_shape((seq, heads, query))?;
59//         // swap the sequence and head axes
60//         res.swap_axes(0, 1);
61//         Ok(res)
62//     }
63// }
64
65// impl<T: Clone> Split for Array3<T> {
66//     type Output = Array4<T>;
67
68//     fn split(&self, heads: usize) -> Result<Self::Output, ShapeError> {
69//         let (batch, seq, model) = self.dim();
70//         let query = model / heads;
//         // reshape the qkv matrix into a 4d array
72//         let mut res = self.clone().into_shape((batch, seq, heads, query))?;
73//         // swap the sequence and head axes
74//         res.swap_axes(1, 2);
75//         Ok(res)
76//     }
77// }