concision_neural/layers/attention/
qkv.rs1use cnc::Forward;
6use ndarray::linalg::Dot;
7use ndarray::{ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, ShapeBuilder};
8use num_traits::{One, Zero};
9
10pub type Qkv<A = f64, D = Ix2> = QkvParamsBase<ndarray::OwnedRepr<A>, D>;
11
12pub struct QkvParamsBase<S, D = Ix2>
14where
15 D: Dimension,
16 S: RawData,
17{
18 pub(crate) query: ArrayBase<S, D>,
19 pub(crate) key: ArrayBase<S, D>,
20 pub(crate) value: ArrayBase<S, D>,
21}
22
23impl<A, S, D> QkvParamsBase<S, D>
24where
25 D: Dimension,
26 S: RawData<Elem = A>,
27{
28 pub fn new(query: ArrayBase<S, D>, key: ArrayBase<S, D>, value: ArrayBase<S, D>) -> Self {
29 Self { query, key, value }
30 }
31 pub fn from_elem<Sh: ShapeBuilder<Dim = D>>(shape: Sh, elem: A) -> Self
32 where
33 A: Clone,
34 S: DataOwned,
35 {
36 let shape = shape.into_shape_with_order();
37 let dim = shape.raw_dim().clone();
38 let query = ArrayBase::from_elem(dim.clone(), elem.clone());
39 let key = ArrayBase::from_elem(dim.clone(), elem.clone());
40 let value = ArrayBase::from_elem(dim.clone(), elem);
41 Self::new(query, key, value)
42 }
43
44 pub fn default<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
45 where
46 A: Clone + Default,
47 S: DataOwned,
48 {
49 Self::from_elem(shape, A::default())
50 }
51
52 pub fn ones<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
53 where
54 A: Clone + One,
55 S: DataOwned,
56 {
57 Self::from_elem(shape, A::one())
58 }
59
60 pub fn zeros<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
61 where
62 A: Clone + Zero,
63 S: DataOwned,
64 {
65 Self::from_elem(shape, A::zero())
66 }
67 pub const fn key(&self) -> &ArrayBase<S, D> {
69 &self.key
70 }
71 pub fn key_mut(&mut self) -> &mut ArrayBase<S, D> {
73 &mut self.key
74 }
75 pub const fn query(&self) -> &ArrayBase<S, D> {
77 &self.query
78 }
79 pub fn query_mut(&mut self) -> &mut ArrayBase<S, D> {
81 &mut self.query
82 }
83 pub const fn value(&self) -> &ArrayBase<S, D> {
85 &self.value
86 }
87 pub fn value_mut(&mut self) -> &mut ArrayBase<S, D> {
89 &mut self.value
90 }
91
92 pub fn set_key(&mut self, key: ArrayBase<S, D>) -> &mut Self {
93 *self.key_mut() = key;
94 self
95 }
96
97 pub fn set_query(&mut self, query: ArrayBase<S, D>) -> &mut Self {
98 *self.query_mut() = query;
99 self
100 }
101
102 pub fn set_value(&mut self, value: ArrayBase<S, D>) -> &mut Self {
103 *self.value_mut() = value;
104 self
105 }
106
107 pub fn with_key(self, key: ArrayBase<S, D>) -> Self {
108 Self { key, ..self }
109 }
110
111 pub fn with_query(self, query: ArrayBase<S, D>) -> Self {
112 Self { query, ..self }
113 }
114
115 pub fn with_value(self, value: ArrayBase<S, D>) -> Self {
116 Self { value, ..self }
117 }
118}
119
120impl<X, Z, A, S, D> Forward<X> for QkvParamsBase<S, D>
122where
123 A: Clone,
124 D: Dimension,
125 S: Data<Elem = A>,
126 X: Dot<ArrayBase<S, D>, Output = Z>,
127 Z: core::ops::Add<Output = Z>,
128 for<'a> Z: core::ops::Add<&'a Z, Output = Z>,
129{
130 type Output = Z;
131
132 fn forward(&self, input: &X) -> cnc::Result<Self::Output> {
133 let query = input.dot(&self.query);
134 let key = input.dot(&self.key);
135 let value = input.dot(&self.value);
136 let output = query + key + value;
137 Ok(output)
138 }
139}