concision_transformer/impls/
impl_linalg.rs1use crate::params::{Qkv, QkvBase};
6use concision::Matmul;
7use nd::linalg::Dot;
8use nd::*;
9
10impl<A, S, T, D, E, F> Matmul<QkvBase<T, E>> for QkvBase<S, D>
11where
12 A: LinalgScalar,
13 D: Dimension,
14 E: Dimension,
15 F: Dimension,
16 S: Data<Elem = A>,
17 T: Data<Elem = A>,
18 ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
19{
20 type Output = Qkv<A, F>;
21
22 fn matmul(&self, rhs: &QkvBase<T, E>) -> Self::Output {
23 QkvBase {
24 q: self.q().dot(rhs.q()),
25 k: self.k().dot(rhs.k()),
26 v: self.v().dot(rhs.v()),
27 }
28 }
29}
30
31impl<A, S, T, D, E, F> Matmul<ArrayBase<T, E>> for QkvBase<S, D>
32where
33 A: LinalgScalar,
34 D: Dimension,
35 E: Dimension,
36 F: Dimension,
37 S: Data<Elem = A>,
38 T: Data<Elem = A>,
39 ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
40{
41 type Output = Qkv<A, F>;
42
43 fn matmul(&self, rhs: &ArrayBase<T, E>) -> Self::Output {
44 QkvBase {
45 q: self.q().dot(rhs),
46 k: self.k().dot(rhs),
47 v: self.v().dot(rhs),
48 }
49 }
50}