concision_neural/layers/attention/
qkv.rs

1/*
2    Appellation: qkv <module>
3    Contrib: @FL03
4*/
5use cnc::Forward;
6use ndarray::linalg::Dot;
7use ndarray::{ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, ShapeBuilder};
8use num_traits::{One, Zero};
9
10pub type Qkv<A = f64, D = Ix2> = QkvParamsBase<ndarray::OwnedRepr<A>, D>;
11
12/// This object is designed to store the parameters of the QKV (Query, Key, Value)
13pub struct QkvParamsBase<S, D = Ix2>
14where
15    D: Dimension,
16    S: RawData,
17{
18    pub(crate) query: ArrayBase<S, D>,
19    pub(crate) key: ArrayBase<S, D>,
20    pub(crate) value: ArrayBase<S, D>,
21}
22
23impl<A, S, D> QkvParamsBase<S, D>
24where
25    D: Dimension,
26    S: RawData<Elem = A>,
27{
28    pub fn new(query: ArrayBase<S, D>, key: ArrayBase<S, D>, value: ArrayBase<S, D>) -> Self {
29        Self { query, key, value }
30    }
31    pub fn from_elem<Sh: ShapeBuilder<Dim = D>>(shape: Sh, elem: A) -> Self
32    where
33        A: Clone,
34        S: DataOwned,
35    {
36        let shape = shape.into_shape_with_order();
37        let dim = shape.raw_dim().clone();
38        let query = ArrayBase::from_elem(dim.clone(), elem.clone());
39        let key = ArrayBase::from_elem(dim.clone(), elem.clone());
40        let value = ArrayBase::from_elem(dim.clone(), elem);
41        Self::new(query, key, value)
42    }
43
44    pub fn default<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
45    where
46        A: Clone + Default,
47        S: DataOwned,
48    {
49        Self::from_elem(shape, A::default())
50    }
51
52    pub fn ones<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
53    where
54        A: Clone + One,
55        S: DataOwned,
56    {
57        Self::from_elem(shape, A::one())
58    }
59
60    pub fn zeros<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
61    where
62        A: Clone + Zero,
63        S: DataOwned,
64    {
65        Self::from_elem(shape, A::zero())
66    }
67    /// returns an immutable reference to the key parameters
68    pub const fn key(&self) -> &ArrayBase<S, D> {
69        &self.key
70    }
71    /// returns a mutable reference to the key parameters
72    pub fn key_mut(&mut self) -> &mut ArrayBase<S, D> {
73        &mut self.key
74    }
75    /// returns an immutable reference to the query parameters
76    pub const fn query(&self) -> &ArrayBase<S, D> {
77        &self.query
78    }
79    /// returns a mutable reference to the query parameters
80    pub fn query_mut(&mut self) -> &mut ArrayBase<S, D> {
81        &mut self.query
82    }
83    /// returns an immutable reference to the value parameters
84    pub const fn value(&self) -> &ArrayBase<S, D> {
85        &self.value
86    }
87    /// returns a mutable reference to the value parameters
88    pub fn value_mut(&mut self) -> &mut ArrayBase<S, D> {
89        &mut self.value
90    }
91
92    pub fn set_key(&mut self, key: ArrayBase<S, D>) -> &mut Self {
93        *self.key_mut() = key;
94        self
95    }
96
97    pub fn set_query(&mut self, query: ArrayBase<S, D>) -> &mut Self {
98        *self.query_mut() = query;
99        self
100    }
101
102    pub fn set_value(&mut self, value: ArrayBase<S, D>) -> &mut Self {
103        *self.value_mut() = value;
104        self
105    }
106
107    pub fn with_key(self, key: ArrayBase<S, D>) -> Self {
108        Self { key, ..self }
109    }
110
111    pub fn with_query(self, query: ArrayBase<S, D>) -> Self {
112        Self { query, ..self }
113    }
114
115    pub fn with_value(self, value: ArrayBase<S, D>) -> Self {
116        Self { value, ..self }
117    }
118}
119
120/// This trait is used to implement the forward pass for the QKV parameters.
121impl<X, Z, A, S, D> Forward<X> for QkvParamsBase<S, D>
122where
123    A: Clone,
124    D: Dimension,
125    S: Data<Elem = A>,
126    X: Dot<ArrayBase<S, D>, Output = Z>,
127    Z: core::ops::Add<Output = Z>,
128    for<'a> Z: core::ops::Add<&'a Z, Output = Z>,
129{
130    type Output = Z;
131
132    fn forward(&self, input: &X) -> cnc::Result<Self::Output> {
133        let query = input.dot(&self.query);
134        let key = input.dot(&self.key);
135        let value = input.dot(&self.value);
136        let output = query + key + value;
137        Ok(output)
138    }
139}