concision_linear/model/layout/
features.rs1use nd::prelude::{Dimension, Ix2, ShapeBuilder};
6use nd::{ErrorKind, IntoDimension, RemoveAxis, ShapeError};
7
8pub(crate) fn features<D>(dim: D) -> Result<Features, ShapeError>
9where
10 D: Dimension,
11{
12 if dim.ndim() == 1 {
13 Ok(Features::new(1, dim[0]))
14 } else if dim.ndim() >= 2 {
15 Ok(Features::new(dim[0], dim[1]))
16 } else {
17 Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
18 }
19}
20
/// Layout of a linear model expressed as the shape `(outputs, dmodel)`.
///
/// Field declaration order (`dmodel` first, then `outputs`) determines the derived
/// `PartialOrd`/`Ord` ordering, so it must not be changed casually.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Features {
    /// The model dimension — axis 1 of the `(outputs, dmodel)` shape.
    pub dmodel: usize,
    /// The number of outputs — axis 0 of the `(outputs, dmodel)` shape.
    pub outputs: usize,
}

28impl Features {
29 pub fn new(outputs: usize, dmodel: usize) -> Self {
30 Self { dmodel, outputs }
31 }
32
33 pub fn from_dim<D>(dim: D) -> Self
34 where
35 D: RemoveAxis,
36 {
37 features(dim).unwrap()
38 }
39
40 pub fn from_shape<D, Sh>(shape: Sh) -> Self
41 where
42 D: RemoveAxis,
43 Sh: ShapeBuilder<Dim = D>,
44 {
45 let shape = shape.into_shape();
46 let dim = shape.raw_dim().clone();
47 features(dim).unwrap()
48 }
49
50 pub fn check_dim<D>(&self, dim: D) -> bool
51 where
52 D: Dimension,
53 {
54 if dim.ndim() == 1 {
55 self.dmodel == dim[0]
56 } else if dim.ndim() >= 2 {
57 self.outputs == dim[0] && self.dmodel == dim[1]
58 } else {
59 false
60 }
61 }
62
63 pub fn into_pattern(self) -> (usize, usize) {
64 (self.outputs, self.dmodel)
65 }
66
67 pub fn neuron(d_model: usize) -> Self {
68 Self::new(1, d_model)
69 }
70
71 pub fn dmodel(&self) -> usize {
72 self.dmodel
73 }
74
75 pub fn features(&self) -> usize {
76 self.outputs
77 }
78
79 pub fn uniform_scale<T: num::Float>(&self) -> T {
80 T::from(self.dmodel()).unwrap().recip().sqrt()
81 }
82}
83
84impl core::fmt::Display for Features {
85 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86 write!(f, "({}, {})", self.features(), self.dmodel(),)
87 }
88}
89
90impl IntoDimension for Features {
91 type Dim = Ix2;
92
93 fn into_dimension(self) -> Self::Dim {
94 ndarray::Ix2(self.features(), self.dmodel())
95 }
96}
97
98impl PartialEq<(usize, usize)> for Features {
99 fn eq(&self, other: &(usize, usize)) -> bool {
100 self.features() == other.0 && self.dmodel() == other.1
101 }
102}
103
// Expands `Target: Source { convert }` entries into `impl From<Source> for Target` items, where
// `convert` is any expression that can be called with a `Source` value (usually a closure).
// The `@impl` arm emits a single impl; the public arm fans a comma-separated list out to it.
macro_rules! impl_from {
    ($($target:ty: $source:ty { $convert:expr }),* $(,)?) => {
        $(
            impl_from!(@impl $target: $source { $convert });
        )*
    };
    (@impl $target:ty: $source:ty { $convert:expr }) => {
        impl From<$source> for $target {
            fn from(value: $source) -> Self {
                $convert(value)
            }
        }
    };
}

117impl_from!(
118 Features: usize { |f: usize| Features::new(1, f) },
119 Features: [usize; 2] {| shape: [usize; 2] | Features::new(shape[0], shape[1])},
120 Features: (usize, usize) {| shape: (usize, usize) | Features::new(shape.0, shape.1)},
121 Features: nd::Ix1 {| shape: nd::Ix1 | Features::from(&shape)},
122 Features: nd::Ix2 {| shape: nd::Ix2 | Features::from(&shape)},
123 Features: nd::IxDyn {| shape: nd::IxDyn | Features::from(&shape)},
124);
125
126impl_from!(
127 nd::Ix2: Features { |f: Features| f.into_dimension() },
128 nd::IxDyn: Features { |f: Features| f.into_dimension().into_dyn() },
129 [usize; 2]: Features { |f: Features| [f.outputs, f.dmodel] },
130 (usize, usize): Features { |f: Features| (f.outputs, f.dmodel) },
131);
132
133impl<'a, D> From<&'a D> for Features
134where
135 D: RemoveAxis,
136{
137 fn from(dim: &'a D) -> Features {
138 features(dim.clone()).unwrap()
139 }
140}