concision_linear/model/layout/
features.rs

1/*
2   Appellation: features <mod>
3   Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use nd::prelude::{Dimension, Ix2, ShapeBuilder};
6use nd::{ErrorKind, IntoDimension, RemoveAxis, ShapeError};
7
8pub(crate) fn features<D>(dim: D) -> Result<Features, ShapeError>
9where
10    D: Dimension,
11{
12    if dim.ndim() == 1 {
13        Ok(Features::new(1, dim[0]))
14    } else if dim.ndim() >= 2 {
15        Ok(Features::new(dim[0], dim[1]))
16    } else {
17        Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
18    }
19}
20
/// A two-sized layout: the number of inputs (`dmodel`) and the number of outputs.
///
/// Derived comparisons follow field declaration order, so `dmodel` is compared
/// before `outputs`.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Features {
    /// Number of inputs.
    pub dmodel: usize,
    /// Number of outputs.
    pub outputs: usize,
}
27
28impl Features {
29    pub fn new(outputs: usize, dmodel: usize) -> Self {
30        Self { dmodel, outputs }
31    }
32
33    pub fn from_dim<D>(dim: D) -> Self
34    where
35        D: RemoveAxis,
36    {
37        features(dim).unwrap()
38    }
39
40    pub fn from_shape<D, Sh>(shape: Sh) -> Self
41    where
42        D: RemoveAxis,
43        Sh: ShapeBuilder<Dim = D>,
44    {
45        let shape = shape.into_shape();
46        let dim = shape.raw_dim().clone();
47        features(dim).unwrap()
48    }
49
50    pub fn check_dim<D>(&self, dim: D) -> bool
51    where
52        D: Dimension,
53    {
54        if dim.ndim() == 1 {
55            self.dmodel == dim[0]
56        } else if dim.ndim() >= 2 {
57            self.outputs == dim[0] && self.dmodel == dim[1]
58        } else {
59            false
60        }
61    }
62
63    pub fn into_pattern(self) -> (usize, usize) {
64        (self.outputs, self.dmodel)
65    }
66
67    pub fn neuron(d_model: usize) -> Self {
68        Self::new(1, d_model)
69    }
70
71    pub fn dmodel(&self) -> usize {
72        self.dmodel
73    }
74
75    pub fn features(&self) -> usize {
76        self.outputs
77    }
78
79    pub fn uniform_scale<T: num::Float>(&self) -> T {
80        T::from(self.dmodel()).unwrap().recip().sqrt()
81    }
82}
83
84impl core::fmt::Display for Features {
85    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86        write!(f, "({}, {})", self.features(), self.dmodel(),)
87    }
88}
89
90impl IntoDimension for Features {
91    type Dim = Ix2;
92
93    fn into_dimension(self) -> Self::Dim {
94        ndarray::Ix2(self.features(), self.dmodel())
95    }
96}
97
98impl PartialEq<(usize, usize)> for Features {
99    fn eq(&self, other: &(usize, usize)) -> bool {
100        self.features() == other.0 && self.dmodel() == other.1
101    }
102}
103
// Generates `From` implementations from a comma-separated list of
// `Target: Source { converter }` entries, where `converter` is any expression
// callable with a single `Source` argument. A trailing comma is accepted.
macro_rules! impl_from {
    // Internal rule: emits a single `impl From<$src> for $dst` whose body
    // calls `$convert` on the incoming value.
    (@impl $dst:ty: $src:ty { $convert:expr }) => {
        impl From<$src> for $dst {
            fn from(value: $src) -> Self {
                $convert(value)
            }
        }
    };
    // Entry rule: forwards every listed entry to the internal rule.
    ($($dst:ty: $src:ty { $convert:expr }),* $(,)?) => {
        $(impl_from!(@impl $dst: $src { $convert });)*
    };
}
116
117impl_from!(
118    Features: usize { |f: usize| Features::new(1, f) },
119    Features: [usize; 2] {| shape: [usize; 2] | Features::new(shape[0], shape[1])},
120    Features: (usize, usize) {| shape: (usize, usize) | Features::new(shape.0, shape.1)},
121    Features: nd::Ix1 {| shape: nd::Ix1 | Features::from(&shape)},
122    Features: nd::Ix2 {| shape: nd::Ix2 | Features::from(&shape)},
123    Features: nd::IxDyn {| shape: nd::IxDyn | Features::from(&shape)},
124);
125
126impl_from!(
127    nd::Ix2: Features { |f: Features| f.into_dimension() },
128    nd::IxDyn: Features { |f: Features| f.into_dimension().into_dyn() },
129    [usize; 2]: Features { |f: Features| [f.outputs, f.dmodel] },
130    (usize, usize): Features { |f: Features| (f.outputs, f.dmodel) },
131);
132
133impl<'a, D> From<&'a D> for Features
134where
135    D: RemoveAxis,
136{
137    fn from(dim: &'a D) -> Features {
138        features(dim.clone()).unwrap()
139    }
140}