concision_transformer/params/
store.rs1use crate::attention::{Score, _attention};
6use concision::nn::DropoutLayer;
7use concision::{dimensional, getters};
8use nd::linalg::Dot;
9use nd::*;
10use num::complex::ComplexFloat;
11use num::traits::{One, Zero};
12
13pub struct QkvBase<S = OwnedRepr<f64>, D = Ix2>
16where
17 D: Dimension,
18 S: RawData,
19{
20 pub(crate) q: ArrayBase<S, D>,
21 pub(crate) k: ArrayBase<S, D>,
22 pub(crate) v: ArrayBase<S, D>,
23}
24
25impl<A, S, D> QkvBase<S, D>
26where
27 D: Dimension,
28 S: RawData<Elem = A>,
29{
30 pub fn builder<Sh, F>(shape: Sh, builder: F) -> Self
31 where
32 F: Fn(D) -> ArrayBase<S, D>,
33 Sh: ShapeBuilder<Dim = D>,
34 {
35 let dim = shape.into_shape().raw_dim().clone();
36 Self {
37 q: builder(dim.clone()),
38 k: builder(dim.clone()),
39 v: builder(dim),
40 }
41 }
42
43 pub fn from_elem<Sh>(shape: Sh, value: A) -> Self
44 where
45 Sh: ShapeBuilder<Dim = D>,
46 A: Clone,
47 S: DataOwned,
48 {
49 let dim = shape.into_shape().raw_dim().clone();
50 Self {
51 q: ArrayBase::from_elem(dim.clone(), value.clone()),
52 k: ArrayBase::from_elem(dim.clone(), value.clone()),
53 v: ArrayBase::from_elem(dim, value),
54 }
55 }
56
57 pub fn as_qkv(&self) -> (ArrayView<A, D>, ArrayView<A, D>, ArrayView<A, D>)
58 where
59 S: Data,
60 {
61 (self.q.view(), self.k.view(), self.v.view())
62 }
63
64 pub fn into_qkv(self) -> (ArrayBase<S, D>, ArrayBase<S, D>, ArrayBase<S, D>) {
66 (self.q, self.k, self.v)
67 }
68
69 pub fn qkv(&self) -> (&ArrayBase<S, D>, &ArrayBase<S, D>, &ArrayBase<S, D>) {
70 (&self.q, &self.k, &self.v)
71 }
72
73 ndbuilder!(new::default() where A: Default, S: DataOwned);
74 ndbuilder!(ones() where A: Clone + One, S: DataOwned);
75 ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned);
76
77 getters!(q, k, v => ArrayBase<S, D>);
78
79 dimensional!(q());
80
81 qkv_view!(into_owned::<OwnedRepr>(self) where A: Clone, S: Data);
82 qkv_view!(to_owned::<OwnedRepr>(&self) where A: Clone, S: Data);
83
84 qkv_view!(into_shared::<OwnedArcRepr>(self) where A: Clone, S: DataOwned);
85 qkv_view!(to_shared::<OwnedArcRepr>(&self) where A: Clone, S: DataShared);
86
87 qkv_view!(view::<'a, ViewRepr>(&self) where S: Data);
88 qkv_view!(view_mut::<'a, ViewRepr>(&mut self) where S: DataMut);
89}
90
91#[cfg(not(feature = "rand"))]
92impl<A, S, D> QkvBase<S, D>
93where
94 D: Dimension,
95 S: RawData<Elem = A>,
96 A: Clone,
97{
98 pub fn attention(&self, dropout: Option<f64>, mask: Option<&Array<bool, D>>) -> Score<A, D>
100 where
101 A: ComplexFloat + ScalarOperand,
102 S: Data,
103 ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
104 Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
105 {
106 let (q, k, v) = self.qkv();
107 _attention(q, k, v, mask, None)
108 }
109}
110
#[cfg(feature = "rand")]
impl<A, S, D> QkvBase<S, D>
where
    D: Dimension,
    S: RawData<Elem = A>,
    A: Clone,
{
    /// Computes the attention [`Score`] for the stored `q`, `k` and `v`
    /// arrays by delegating to `_attention`.
    ///
    /// This variant is compiled when the `rand` feature is enabled.
    ///
    /// * `dropout` - when `Some`, the value is used to construct a
    ///   [`DropoutLayer`] that is handed to `_attention`; when `None`, no
    ///   dropout layer is passed.
    /// * `mask` - optional boolean mask forwarded unchanged to `_attention`.
    pub fn attention(&self, dropout: Option<f64>, mask: Option<&Array<bool, D>>) -> Score<A, D>
    where
        A: ComplexFloat + ScalarOperand,
        S: Data,
        ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
        Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
    {
        // The layer is owned by this frame; `_attention` only borrows it.
        let layer = dropout.map(DropoutLayer::new);
        _attention(&self.q, &self.k, &self.v, mask, layer.as_ref())
    }
}