concision_linear/model/
config.rs1use super::layout::{Features, Layout};
6use crate::params::{Biased, Unbiased};
7use core::marker::PhantomData;
8use nd::prelude::*;
9use nd::{IntoDimension, RemoveAxis, ShapeError};
10
11#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
12#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
13pub struct Config<K = Biased, D = Ix2> {
14 pub layout: Layout<D>,
15 pub name: String,
16 _biased: PhantomData<K>,
17}
18
19impl<K, D> Config<K, D>
20where
21 D: Dimension,
22{
23 pub fn new() -> Self {
24 Self {
25 layout: Layout::default(),
26 name: String::new(),
27 _biased: PhantomData::<K>,
28 }
29 }
30
31 pub fn from_dim(dim: D) -> Result<Self, ShapeError>
32 where
33 D: Dimension,
34 {
35 let layout = Layout::from_dim(dim)?;
36 let res = Self::new().with_layout(layout);
37 Ok(res)
38 }
39
40 pub fn from_shape<Sh>(shape: Sh) -> Self
41 where
42 D: RemoveAxis,
43 Sh: ShapeBuilder<Dim = D>,
44 {
45 let layout = Layout::from_shape(shape);
46 Self::new().with_layout(layout)
47 }
48
49 pub fn into_biased(self) -> Config<Biased, D> {
50 Config {
51 layout: self.layout,
52 name: self.name,
53 _biased: PhantomData::<Biased>,
54 }
55 }
56
57 pub fn into_unbiased(self) -> Config<Unbiased, D> {
58 Config {
59 layout: self.layout,
60 name: self.name,
61 _biased: PhantomData::<Unbiased>,
62 }
63 }
64
65 pub fn with_name(self, name: impl ToString) -> Self {
66 Self {
67 name: name.to_string(),
68 ..self
69 }
70 }
71
72 pub fn with_layout<E>(self, layout: Layout<E>) -> Config<K, E>
73 where
74 E: Dimension,
75 {
76 Config {
77 layout,
78 name: self.name,
79 _biased: self._biased,
80 }
81 }
82
83 pub fn with_shape<E, Sh>(self, shape: Sh) -> Config<K, E>
84 where
85 E: RemoveAxis,
86 Sh: ShapeBuilder<Dim = E>,
87 {
88 Config {
89 layout: self.layout.with_shape(shape),
90 name: self.name,
91 _biased: self._biased,
92 }
93 }
94
95 pub fn into_dimensionality<E>(self, dim: E) -> Result<Config<K, E>, nd::ShapeError>
97 where
98 E: Dimension,
99 {
100 let tmp = Config {
101 layout: self.layout.into_dimensionality(dim)?,
102 name: self.name,
103 _biased: self._biased,
104 };
105 Ok(tmp)
106 }
107 pub fn is_biased(&self) -> bool
110 where
111 K: 'static,
112 {
113 use core::any::TypeId;
114
115 TypeId::of::<K>() == TypeId::of::<Biased>()
116 }
117 pub fn features(&self) -> Features {
119 self.layout().features()
120 }
121 pub const fn layout(&self) -> &Layout<D> {
123 &self.layout
124 }
125 pub fn name(&self) -> &str {
127 &self.name
128 }
129
130 pub fn dim(&self) -> D {
132 self.layout().dim()
133 }
134
135 pub fn into_pattern(self) -> D::Pattern {
136 self.dim().into_pattern()
137 }
138
139 pub fn ndim(&self) -> usize {
140 self.layout().ndim()
141 }
142}
143
144impl<K> Config<K, Ix2> {
145 pub fn std(inputs: usize, outputs: usize) -> Self {
146 Self {
147 layout: Layout::new((outputs, inputs).into_dimension()),
148 name: String::new(),
149 _biased: PhantomData::<K>,
150 }
151 }
152}
153
154impl<D> Config<Biased, D>
155where
156 D: Dimension,
157{
158 pub fn biased() -> Self {
160 Self::new()
161 }
162}
163
164impl<D> Config<Unbiased, D>
165where
166 D: Dimension,
167{
168 pub fn unbiased() -> Self {
169 Self::new()
170 }
171}
172
173impl<K, D> concision::Config for Config<K, D> where D: Dimension {}
174
175impl<D> Default for Config<Biased, D>
176where
177 D: Dimension,
178{
179 fn default() -> Self {
180 Self::new()
181 }
182}