concision_transformer/params/
store.rs

1/*
2    Appellation: params <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::attention::{Score, _attention};
6use concision::nn::DropoutLayer;
7use concision::{dimensional, getters};
8use nd::linalg::Dot;
9use nd::*;
10use num::complex::ComplexFloat;
11use num::traits::{One, Zero};
12
13/// [QkvBase] is a container for the query, key, and value arrays used in the
14/// attention mechanism of the transformer model.
15pub struct QkvBase<S = OwnedRepr<f64>, D = Ix2>
16where
17    D: Dimension,
18    S: RawData,
19{
20    pub(crate) q: ArrayBase<S, D>,
21    pub(crate) k: ArrayBase<S, D>,
22    pub(crate) v: ArrayBase<S, D>,
23}
24
25impl<A, S, D> QkvBase<S, D>
26where
27    D: Dimension,
28    S: RawData<Elem = A>,
29{
30    pub fn builder<Sh, F>(shape: Sh, builder: F) -> Self
31    where
32        F: Fn(D) -> ArrayBase<S, D>,
33        Sh: ShapeBuilder<Dim = D>,
34    {
35        let dim = shape.into_shape().raw_dim().clone();
36        Self {
37            q: builder(dim.clone()),
38            k: builder(dim.clone()),
39            v: builder(dim),
40        }
41    }
42
43    pub fn from_elem<Sh>(shape: Sh, value: A) -> Self
44    where
45        Sh: ShapeBuilder<Dim = D>,
46        A: Clone,
47        S: DataOwned,
48    {
49        let dim = shape.into_shape().raw_dim().clone();
50        Self {
51            q: ArrayBase::from_elem(dim.clone(), value.clone()),
52            k: ArrayBase::from_elem(dim.clone(), value.clone()),
53            v: ArrayBase::from_elem(dim, value),
54        }
55    }
56
57    pub fn as_qkv(&self) -> (ArrayView<A, D>, ArrayView<A, D>, ArrayView<A, D>)
58    where
59        S: Data,
60    {
61        (self.q.view(), self.k.view(), self.v.view())
62    }
63
64    /// Consumes the store and returns a three-tuple consisting of the query, key, and value arrays respectively.
65    pub fn into_qkv(self) -> (ArrayBase<S, D>, ArrayBase<S, D>, ArrayBase<S, D>) {
66        (self.q, self.k, self.v)
67    }
68
69    pub fn qkv(&self) -> (&ArrayBase<S, D>, &ArrayBase<S, D>, &ArrayBase<S, D>) {
70        (&self.q, &self.k, &self.v)
71    }
72
73    ndbuilder!(new::default() where A: Default, S: DataOwned);
74    ndbuilder!(ones() where A: Clone + One, S: DataOwned);
75    ndbuilder!(zeros() where A: Clone + Zero, S: DataOwned);
76
77    getters!(q, k, v => ArrayBase<S, D>);
78
79    dimensional!(q());
80
81    qkv_view!(into_owned::<OwnedRepr>(self) where A: Clone, S: Data);
82    qkv_view!(to_owned::<OwnedRepr>(&self) where A: Clone, S: Data);
83
84    qkv_view!(into_shared::<OwnedArcRepr>(self) where A: Clone, S: DataOwned);
85    qkv_view!(to_shared::<OwnedArcRepr>(&self) where A: Clone, S: DataShared);
86
87    qkv_view!(view::<'a, ViewRepr>(&self) where S: Data);
88    qkv_view!(view_mut::<'a, ViewRepr>(&mut self) where S: DataMut);
89}
90
91#[cfg(not(feature = "rand"))]
92impl<A, S, D> QkvBase<S, D>
93where
94    D: Dimension,
95    S: RawData<Elem = A>,
96    A: Clone,
97{
98    /// Computes the [Score] using scaled dot-product attention.
99    pub fn attention(&self, dropout: Option<f64>, mask: Option<&Array<bool, D>>) -> Score<A, D>
100    where
101        A: ComplexFloat + ScalarOperand,
102        S: Data,
103        ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
104        Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
105    {
106        let (q, k, v) = self.qkv();
107        _attention(q, k, v, mask, None)
108    }
109}
110
#[cfg(feature = "rand")]
impl<A, S, D> QkvBase<S, D>
where
    S: RawData<Elem = A>,
    D: Dimension,
    A: Clone,
{
    /// Computes the [Score] using scaled dot-product attention.
    ///
    /// This variant is compiled when the `rand` feature is enabled. If a
    /// `dropout` value is supplied, it is used to construct a [DropoutLayer]
    /// that is handed to `_attention` along with the stored query, key, and
    /// value arrays and the optional `mask`.
    pub fn attention(&self, dropout: Option<f64>, mask: Option<&Array<bool, D>>) -> Score<A, D>
    where
        A: ComplexFloat + ScalarOperand,
        S: Data,
        ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
        Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
    {
        // build the layer only when a dropout value was actually provided
        let layer = dropout.map(DropoutLayer::new);
        _attention(&self.q, &self.k, &self.v, mask, layer.as_ref())
    }
}