concision_transformer/impls/
impl_params.rs

1/*
2    Appellation: impl_params <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::params::QkvBase;
6use nd::prelude::*;
7use nd::{Data, DataOwned, RawDataClone};
8
/// Crate-internal triple alias.
///
/// `ThreeTuple<T>` expands to `(T, T, T)`. The defaulted parameters make
/// `ThreeTuple<A, B>` expand to `(A, B, B)`, and `ThreeTuple<A, B, C>` to `(A, B, C)`.
pub(crate) type ThreeTuple<A, B = A, C = B> = (A, B, C);
10
11impl<A, S, D> Clone for QkvBase<S, D>
12where
13    D: Dimension,
14    S: RawDataClone<Elem = A>,
15{
16    fn clone(&self) -> Self {
17        Self {
18            q: self.q.clone(),
19            k: self.k.clone(),
20            v: self.v.clone(),
21        }
22    }
23}
24
25impl<A, S, D> Copy for QkvBase<S, D>
26where
27    D: Copy + Dimension,
28    S: Copy + RawDataClone<Elem = A>,
29{
30}
31
32impl<A, S, D> Default for QkvBase<S, D>
33where
34    A: Default,
35    D: Dimension,
36    S: DataOwned<Elem = A>,
37{
38    fn default() -> Self {
39        Self {
40            q: Default::default(),
41            k: Default::default(),
42            v: Default::default(),
43        }
44    }
45}
46
47impl<A, S, D> PartialEq for QkvBase<S, D>
48where
49    A: PartialEq,
50    D: Dimension,
51    S: Data<Elem = A>,
52{
53    fn eq(&self, other: &Self) -> bool {
54        self.q() == other.q() && self.k() == other.k() && self.v() == other.v()
55    }
56}
57
58impl<A, B, S, D, S2, D2> PartialEq<ArrayBase<S2, D2>> for QkvBase<S, D>
59where
60    A: PartialEq,
61    B: PartialEq,
62    D: Dimension,
63    S: Data<Elem = A>,
64    S2: Data<Elem = B>,
65    D2: Dimension,
66    ArrayBase<S, D>: PartialEq<ArrayBase<S2, D2>>,
67{
68    fn eq(&self, other: &ArrayBase<S2, D2>) -> bool {
69        self.q() == other && self.k() == other && self.v() == other
70    }
71}
72
73impl<A, B, S, D, S2, D2> PartialEq<ThreeTuple<ArrayBase<S2, D2>>> for QkvBase<S, D>
74where
75    A: PartialEq,
76    B: PartialEq,
77    D: Dimension,
78    S: Data<Elem = A>,
79    S2: Data<Elem = B>,
80    D2: Dimension,
81    ArrayBase<S, D>: PartialEq<ArrayBase<S2, D2>>,
82{
83    fn eq(&self, (q, k, v): &ThreeTuple<ArrayBase<S2, D2>>) -> bool {
84        self.q() == q && self.k() == k && self.v() == v
85    }
86}