concision_transformer/impls/
impl_linalg.rs

/*
    Appellation: impl_linalg <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::params::{Qkv, QkvBase};
6use concision::Matmul;
7use nd::linalg::Dot;
8use nd::*;
9
10impl<A, S, T, D, E, F> Matmul<QkvBase<T, E>> for QkvBase<S, D>
11where
12    A: LinalgScalar,
13    D: Dimension,
14    E: Dimension,
15    F: Dimension,
16    S: Data<Elem = A>,
17    T: Data<Elem = A>,
18    ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
19{
20    type Output = Qkv<A, F>;
21
22    fn matmul(&self, rhs: &QkvBase<T, E>) -> Self::Output {
23        QkvBase {
24            q: self.q().dot(rhs.q()),
25            k: self.k().dot(rhs.k()),
26            v: self.v().dot(rhs.v()),
27        }
28    }
29}
30
31impl<A, S, T, D, E, F> Matmul<ArrayBase<T, E>> for QkvBase<S, D>
32where
33    A: LinalgScalar,
34    D: Dimension,
35    E: Dimension,
36    F: Dimension,
37    S: Data<Elem = A>,
38    T: Data<Elem = A>,
39    ArrayBase<S, D>: Dot<ArrayBase<T, E>, Output = Array<A, F>>,
40{
41    type Output = Qkv<A, F>;
42
43    fn matmul(&self, rhs: &ArrayBase<T, E>) -> Self::Output {
44        QkvBase {
45            q: self.q().dot(rhs),
46            k: self.k().dot(rhs),
47            v: self.v().dot(rhs),
48        }
49    }
50}