concision_linear/model/
config.rs

1/*
2    Appellation: config <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::layout::{Features, Layout};
6use crate::params::{Biased, Unbiased};
7use core::marker::PhantomData;
8use nd::prelude::*;
9use nd::{IntoDimension, RemoveAxis, ShapeError};
10
11#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
12#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
13pub struct Config<K = Biased, D = Ix2> {
14    pub layout: Layout<D>,
15    pub name: String,
16    _biased: PhantomData<K>,
17}
18
19impl<K, D> Config<K, D>
20where
21    D: Dimension,
22{
23    pub fn new() -> Self {
24        Self {
25            layout: Layout::default(),
26            name: String::new(),
27            _biased: PhantomData::<K>,
28        }
29    }
30
31    pub fn from_dim(dim: D) -> Result<Self, ShapeError>
32    where
33        D: Dimension,
34    {
35        let layout = Layout::from_dim(dim)?;
36        let res = Self::new().with_layout(layout);
37        Ok(res)
38    }
39
40    pub fn from_shape<Sh>(shape: Sh) -> Self
41    where
42        D: RemoveAxis,
43        Sh: ShapeBuilder<Dim = D>,
44    {
45        let layout = Layout::from_shape(shape);
46        Self::new().with_layout(layout)
47    }
48
49    pub fn into_biased(self) -> Config<Biased, D> {
50        Config {
51            layout: self.layout,
52            name: self.name,
53            _biased: PhantomData::<Biased>,
54        }
55    }
56
57    pub fn into_unbiased(self) -> Config<Unbiased, D> {
58        Config {
59            layout: self.layout,
60            name: self.name,
61            _biased: PhantomData::<Unbiased>,
62        }
63    }
64
65    pub fn with_name(self, name: impl ToString) -> Self {
66        Self {
67            name: name.to_string(),
68            ..self
69        }
70    }
71
72    pub fn with_layout<E>(self, layout: Layout<E>) -> Config<K, E>
73    where
74        E: Dimension,
75    {
76        Config {
77            layout,
78            name: self.name,
79            _biased: self._biased,
80        }
81    }
82
83    pub fn with_shape<E, Sh>(self, shape: Sh) -> Config<K, E>
84    where
85        E: RemoveAxis,
86        Sh: ShapeBuilder<Dim = E>,
87    {
88        Config {
89            layout: self.layout.with_shape(shape),
90            name: self.name,
91            _biased: self._biased,
92        }
93    }
94
95    /// This function attempts to convert the [layout](Layout) of the [Config] into a new [dimension](ndarray::Dimension)
96    pub fn into_dimensionality<E>(self, dim: E) -> Result<Config<K, E>, nd::ShapeError>
97    where
98        E: Dimension,
99    {
100        let tmp = Config {
101            layout: self.layout.into_dimensionality(dim)?,
102            name: self.name,
103            _biased: self._biased,
104        };
105        Ok(tmp)
106    }
107    /// Determine whether the [Config] is [Biased];
108    /// Returns true by comparing the [TypeId](core::any::TypeId) of `K` against the [TypeId](core::any::TypeId) of the [Biased] type
109    pub fn is_biased(&self) -> bool
110    where
111        K: 'static,
112    {
113        use core::any::TypeId;
114
115        TypeId::of::<K>() == TypeId::of::<Biased>()
116    }
117    /// Returns an instance to the [Features] of the [Layout]
118    pub fn features(&self) -> Features {
119        self.layout().features()
120    }
121    /// Returns an owned reference to the [Layout]
122    pub const fn layout(&self) -> &Layout<D> {
123        &self.layout
124    }
125    /// Returns an immutable reference to the `name` of the model.
126    pub fn name(&self) -> &str {
127        &self.name
128    }
129
130    /// Returns a cloned reference to the [dimension](ndarray::Dimension) of the [layout](Layout)
131    pub fn dim(&self) -> D {
132        self.layout().dim()
133    }
134
135    pub fn into_pattern(self) -> D::Pattern {
136        self.dim().into_pattern()
137    }
138
139    pub fn ndim(&self) -> usize {
140        self.layout().ndim()
141    }
142}
143
144impl<K> Config<K, Ix2> {
145    pub fn std(inputs: usize, outputs: usize) -> Self {
146        Self {
147            layout: Layout::new((outputs, inputs).into_dimension()),
148            name: String::new(),
149            _biased: PhantomData::<K>,
150        }
151    }
152}
153
154impl<D> Config<Biased, D>
155where
156    D: Dimension,
157{
158    /// The default constructor method for building [Biased] configurations.
159    pub fn biased() -> Self {
160        Self::new()
161    }
162}
163
164impl<D> Config<Unbiased, D>
165where
166    D: Dimension,
167{
168    pub fn unbiased() -> Self {
169        Self::new()
170    }
171}
172
173impl<K, D> concision::Config for Config<K, D> where D: Dimension {}
174
175impl<D> Default for Config<Biased, D>
176where
177    D: Dimension,
178{
179    fn default() -> Self {
180        Self::new()
181    }
182}