// concision_transformer/impls/impl_head.rs

/*
    Appellation: impl_head <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::attention::{Attention, AttentionHead, Score};
6use crate::params::QkvBase;
7use core::borrow::{Borrow, BorrowMut};
8use nd::linalg::Dot;
9use nd::prelude::*;
10use nd::{Data, DataOwned, RawData, RawDataClone, ScalarOperand};
11use num::complex::ComplexFloat;
12
/// Implements the [`Attention`] trait for [`AttentionHead`] by delegating to
/// the head's own `attention` method, producing a [`Score`].
///
/// NOTE(review): the body calls `self.attention()`. Inherent methods take
/// precedence over trait methods during method resolution, so this is expected
/// to resolve to an inherent `AttentionHead::attention` defined elsewhere in
/// the crate (not visible in this file). Confirm that method exists and that
/// its bounds are satisfied by the `where` clause below. If it were removed or
/// its bounds became stricter, this call would resolve back to this trait
/// method and recurse unconditionally.
impl<A, S, D> Attention for AttentionHead<A, D, S>
where
    A: ComplexFloat + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
    ArrayBase<S, D>: for<'a> Dot<ArrayView<'a, A, D>, Output = Array<A, D>>,
    Array<A, D>: Dot<ArrayBase<S, D>, Output = Array<A, D>>,
{
    type Output = Score<A, D>;

    fn attention(&self) -> Self::Output {
        // Intended to hit the inherent method (see NOTE above), not
        // `Attention::attention` itself.
        self.attention()
    }
}
27
28impl<A, S, D> Borrow<QkvBase<S, D>> for AttentionHead<A, D, S>
29where
30    D: Dimension,
31    S: RawData<Elem = A>,
32{
33    fn borrow(&self) -> &QkvBase<S, D> {
34        self.params()
35    }
36}
37
38impl<A, S, D> BorrowMut<QkvBase<S, D>> for AttentionHead<A, D, S>
39where
40    D: Dimension,
41    S: RawData<Elem = A>,
42{
43    fn borrow_mut(&mut self) -> &mut QkvBase<S, D> {
44        self.params_mut()
45    }
46}
47
48impl<A, S, D> Clone for AttentionHead<A, D, S>
49where
50    A: Copy,
51    D: Dimension,
52    S: RawDataClone<Elem = A>,
53{
54    fn clone(&self) -> Self {
55        Self {
56            #[cfg(feature = "rand")]
57            dropout: self.dropout.clone(),
58            mask: self.mask.clone(),
59            params: self.params.clone(),
60        }
61    }
62}
63
64impl<A, S, D> Copy for AttentionHead<A, D, S>
65where
66    A: Copy,
67    D: Copy + Dimension,
68    S: Copy + RawDataClone<Elem = A>,
69    Array<bool, D>: Copy,
70{
71}
72
73impl<A, S, D> Default for AttentionHead<A, D, S>
74where
75    A: Default,
76    D: Dimension,
77    S: DataOwned<Elem = A>,
78{
79    fn default() -> Self {
80        Self::from_params(QkvBase::default())
81    }
82}
83
84impl<A, S, D> From<QkvBase<S, D>> for AttentionHead<A, D, S>
85where
86    D: Dimension,
87    S: RawData<Elem = A>,
88{
89    fn from(params: QkvBase<S, D>) -> Self {
90        Self::from_params(params)
91    }
92}