concision_transformer/impls/
impl_head.rs1use crate::attention::{Attention, AttentionHead, Score};
6use crate::params::QkvBase;
7use core::borrow::{Borrow, BorrowMut};
8use nd::linalg::Dot;
9use nd::prelude::*;
10use nd::{Data, DataOwned, RawData, RawDataClone, ScalarOperand};
11use num::complex::ComplexFloat;
12
13impl<A, S, D> Attention for AttentionHead<A, D, S>
14where
15 A: ComplexFloat + ScalarOperand,
16 D: Dimension,
17 S: Data<Elem = A>,
18 ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
19 Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
20{
21 type Output = Score<A, D>;
22
23 fn attention(&self) -> Self::Output {
24 self.attention()
25 }
26}
27
28impl<A, S, D> Borrow<QkvBase<S, D>> for AttentionHead<A, D, S>
29where
30 D: Dimension,
31 S: RawData<Elem = A>,
32{
33 fn borrow(&self) -> &QkvBase<S, D> {
34 self.params()
35 }
36}
37
38impl<A, S, D> BorrowMut<QkvBase<S, D>> for AttentionHead<A, D, S>
39where
40 D: Dimension,
41 S: RawData<Elem = A>,
42{
43 fn borrow_mut(&mut self) -> &mut QkvBase<S, D> {
44 self.params_mut()
45 }
46}
47
48impl<A, S, D> Clone for AttentionHead<A, D, S>
49where
50 A: Copy,
51 D: Dimension,
52 S: RawDataClone<Elem = A>,
53{
54 fn clone(&self) -> Self {
55 Self {
56 #[cfg(feature = "rand")]
57 dropout: self.dropout.clone(),
58 mask: self.mask.clone(),
59 params: self.params.clone(),
60 }
61 }
62}
63
64impl<A, S, D> Copy for AttentionHead<A, D, S>
65where
66 A: Copy,
67 D: Copy + Dimension,
68 S: Copy + RawDataClone<Elem = A>,
69 Array<bool, D>: Copy,
70{
71}
72
73impl<A, S, D> Default for AttentionHead<A, D, S>
74where
75 A: Default,
76 D: Dimension,
77 S: DataOwned<Elem = A>,
78{
79 fn default() -> Self {
80 Self::from_params(QkvBase::default())
81 }
82}
83
84impl<A, S, D> From<QkvBase<S, D>> for AttentionHead<A, D, S>
85where
86 D: Dimension,
87 S: RawData<Elem = A>,
88{
89 fn from(params: QkvBase<S, D>) -> Self {
90 Self::from_params(params)
91 }
92}