1use super::{Config, Layout};
6use crate::{Biased, LinearParams, ParamMode, ParamsBase, Unbiased};
7use concision::prelude::{Predict, Result};
8use nd::prelude::*;
9use nd::{DataOwned, OwnedRepr, RawData, RemoveAxis};
10
/// A linear layer: a [`Config`] paired with the [`ParamsBase`] holding its
/// weights (and, in [`Biased`] mode, its bias).
///
/// Type parameters:
/// - `A`: element type of the parameter storage (`S: RawData<Elem = A>`); defaults to `f64`.
/// - `K`: parameter mode, e.g. [`Biased`] (the default) or [`Unbiased`].
/// - `D`: dimensionality of the weight array; defaults to [`Ix2`].
/// - `S`: storage representation of the parameters; defaults to owned storage.
pub struct Linear<A = f64, K = Biased, D = Ix2, S = OwnedRepr<A>>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// Layer configuration (shape/layout, bias mode, name).
    pub(crate) config: Config<K, D>,
    /// Parameter storage (weights, plus bias when the mode provides one).
    pub(crate) params: ParamsBase<S, D, K>,
}
24
25impl<A, K> Linear<A, K, Ix2, OwnedRepr<A>>
26where
27 K: ParamMode,
28{
29 pub fn std(inputs: usize, outputs: usize) -> Self
30 where
31 A: Default,
32 {
33 let config = Config::<K, Ix2>::new().with_shape((inputs, outputs));
34 let params = ParamsBase::new(config.features());
35 Linear { config, params }
36 }
37}
38
39impl<A, S, D, K> Linear<A, K, D, S>
40where
41 D: RemoveAxis,
42 K: ParamMode,
43 S: RawData<Elem = A>,
44{
45 mbuilder!(new where A: Default, S: DataOwned);
46 mbuilder!(ones where A: Clone + num::One, S: DataOwned);
47 mbuilder!(zeros where A: Clone + num::Zero, S: DataOwned);
48
49 pub fn from_config(config: Config<K, D>) -> Self
50 where
51 A: Clone + Default,
52 K: ParamMode,
53 S: DataOwned,
54 {
55 let params = ParamsBase::new(config.dim());
56 Self { config, params }
57 }
58
59 pub fn from_layout(layout: Layout<D>) -> Self
60 where
61 A: Clone + Default,
62 K: ParamMode,
63 S: DataOwned,
64 {
65 let config = Config::<K, D>::new().with_layout(layout);
66 let params = ParamsBase::new(config.dim());
67 Self { config, params }
68 }
69
70 pub fn from_params(params: ParamsBase<S, D, K>) -> Self {
71 let config = Config::<K, D>::new().with_shape(params.raw_dim());
72 Self { config, params }
73 }
74
75 pub fn activate<X, Y, F>(&self, args: &X, func: F) -> Result<Y>
77 where
78 F: Fn(Y) -> Y,
79 Self: Predict<X, Output = Y>,
80 {
81 Ok(func(self.predict(args)?))
82 }
83
84 pub const fn config(&self) -> &Config<K, D> {
85 &self.config
86 }
87
88 pub fn weights(&self) -> &ArrayBase<S, D> {
89 self.params.weights()
90 }
91
92 pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
93 self.params.weights_mut()
94 }
95
96 pub const fn params(&self) -> &ParamsBase<S, D, K> {
97 &self.params
98 }
99
100 pub fn params_mut(&mut self) -> &mut ParamsBase<S, D, K> {
101 &mut self.params
102 }
103
104 pub fn into_biased(self) -> Linear<A, Biased, D, S>
105 where
106 A: Default,
107 K: 'static,
108 S: DataOwned,
109 {
110 Linear {
111 config: self.config.into_biased(),
112 params: self.params.into_biased(),
113 }
114 }
115
116 pub fn into_unbiased(self) -> Linear<A, Unbiased, D, S>
117 where
118 A: Default,
119 K: 'static,
120 S: DataOwned,
121 {
122 Linear {
123 config: self.config.into_unbiased(),
124 params: self.params.into_unbiased(),
125 }
126 }
127
128 pub fn is_biased(&self) -> bool
129 where
130 K: 'static,
131 {
132 self.config().is_biased()
133 }
134
135 pub fn with_params<E>(self, params: LinearParams<A, K, E>) -> Linear<A, K, E>
136 where
137 E: RemoveAxis,
138 {
139 let config = self.config.into_dimensionality(params.raw_dim()).unwrap();
140 Linear { config, params }
141 }
142
143 pub fn with_name(self, name: impl ToString) -> Self {
144 Self {
145 config: self.config.with_name(name),
146 ..self
147 }
148 }
149
150 concision::dimensional!(params());
151}
152
153impl<A, S, D> Linear<A, Biased, D, S>
154where
155 D: RemoveAxis,
156 S: RawData<Elem = A>,
157{
158 pub fn biased<Sh>(shape: Sh) -> Self
159 where
160 A: Default,
161 S: DataOwned,
162 Sh: ShapeBuilder<Dim = D>,
163 {
164 let config = Config::<Biased, D>::new().with_shape(shape);
165 let params = ParamsBase::biased(config.dim());
166 Linear { config, params }
167 }
168
169 pub fn bias(&self) -> &ArrayBase<S, D::Smaller> {
170 self.params().bias()
171 }
172
173 pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
174 self.params_mut().bias_mut()
175 }
176}
177
178impl<A, S, D> Linear<A, Unbiased, D, S>
179where
180 D: RemoveAxis,
181 S: RawData<Elem = A>,
182{
183 pub fn unbiased<Sh>(shape: Sh) -> Self
184 where
185 A: Default,
186 S: DataOwned,
187 Sh: ShapeBuilder<Dim = D>,
188 {
189 let config = Config::<Unbiased, D>::new().with_shape(shape);
190 let params = ParamsBase::unbiased(config.dim());
191 Linear { config, params }
192 }
193}