concision_transformer/attention/
score.rs1use core::fmt;
6use nd::{Array, Dimension};
7
8#[derive(Clone, Eq, Hash, PartialEq)]
13pub struct Score<A, D>
14where
15 D: Dimension,
16{
17 pub(crate) attention: Array<A, D>,
18 pub(crate) score: Array<A, D>,
19}
20
21impl<A, D> Score<A, D>
22where
23 D: Dimension,
24{
25 pub(crate) fn new(attention: Array<A, D>, score: Array<A, D>) -> Self {
26 Self { attention, score }
27 }
28 pub fn into_attention(self) -> Array<A, D> {
30 self.attention
31 }
32 pub fn into_score(self) -> Array<A, D> {
34 self.score
35 }
36
37 pub fn attention(&self) -> &Array<A, D> {
39 &self.attention
40 }
41 pub fn score(&self) -> &Array<A, D> {
43 &self.score
44 }
45}
46
47impl<A, D> Copy for Score<A, D>
48where
49 A: Copy,
50 D: Copy + Dimension,
51 Array<A, D>: Copy,
52{
53}
54
55impl<A, D> fmt::Debug for Score<A, D>
56where
57 A: fmt::Debug,
58 D: Dimension,
59{
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("Score")
62 .field("attention", &self.attention)
63 .field("score", &self.score)
64 .finish()
65 }
66}
67
68impl<A, D> fmt::Display for Score<A, D>
69where
70 A: fmt::Display,
71 D: Dimension,
72{
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "({}, {})", self.attention, self.score)
75 }
76}
77
78impl<A, D> From<(Array<A, D>, Array<A, D>)> for Score<A, D>
79where
80 D: Dimension,
81{
82 fn from((attention, score): (Array<A, D>, Array<A, D>)) -> Self {
83 Self::new(attention, score)
84 }
85}
86
87impl<A, D> From<Score<A, D>> for (Array<A, D>, Array<A, D>)
88where
89 D: Dimension,
90{
91 fn from(score: Score<A, D>) -> Self {
92 (score.attention, score.score)
93 }
94}