1use crate::{build_bias, Biased, Features, Node, ParamMode, Unbiased};
6use concision::dimensional;
7use core::marker::PhantomData;
8use nd::*;
9use num::{One, Zero};
10
11pub struct ParamsBase<S = OwnedRepr<f64>, D = Ix2, K = Biased>
16where
17 D: Dimension,
18 S: RawData,
19{
20 pub(crate) bias: Option<ArrayBase<S, D::Smaller>>,
21 pub(crate) weight: ArrayBase<S, D>,
22 pub(crate) _mode: PhantomData<K>,
23}
24
25impl<A, S, D, K> ParamsBase<S, D, K>
26where
27 D: RemoveAxis,
28 S: RawData<Elem = A>,
29{
30 pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
31 where
32 A: Clone,
33 K: ParamMode,
34 S: DataOwned,
35 Sh: ShapeBuilder<Dim = D>,
36 {
37 let dim = shape.into_shape().raw_dim().clone();
38 let bias = if K::BIASED {
39 Some(ArrayBase::from_elem(
40 crate::bias_dim(dim.clone()),
41 elem.clone(),
42 ))
43 } else {
44 None
45 };
46 Self {
47 bias,
48 weight: ArrayBase::from_elem(dim, elem),
49 _mode: PhantomData::<K>,
50 }
51 }
52
53 pub fn into_biased(self) -> ParamsBase<S, D, Biased>
54 where
55 A: Default,
56 K: 'static,
57 S: DataOwned,
58 {
59 if self.is_biased() {
60 return ParamsBase {
61 bias: self.bias,
62 weight: self.weight,
63 _mode: PhantomData::<Biased>,
64 };
65 }
66 let sm = crate::bias_dim(self.raw_dim());
67 ParamsBase {
68 bias: Some(ArrayBase::default(sm)),
69 weight: self.weight,
70 _mode: PhantomData::<Biased>,
71 }
72 }
73
74 pub fn into_unbiased(self) -> ParamsBase<S, D, Unbiased> {
75 ParamsBase {
76 bias: None,
77 weight: self.weight,
78 _mode: PhantomData::<Unbiased>,
79 }
80 }
81
82 pub const fn weights(&self) -> &ArrayBase<S, D> {
83 &self.weight
84 }
85
86 pub fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
87 &mut self.weight
88 }
89
90 pub fn features(&self) -> Features {
91 Features::from_shape(self.shape())
92 }
93
94 pub fn in_features(&self) -> usize {
95 self.features().dmodel()
96 }
97
98 pub fn out_features(&self) -> usize {
99 if self.ndim() == 1 {
100 return 1;
101 }
102 self.shape()[1]
103 }
104 pub fn is_biased(&self) -> bool
107 where
108 K: 'static,
109 {
110 crate::is_biased::<K>()
111 }
112
113 pbuilder!(new.default where A: Default, S: DataOwned);
114
115 pbuilder!(ones where A: Clone + One, S: DataOwned);
116
117 pbuilder!(zeros where A: Clone + Zero, S: DataOwned);
118
119 dimensional!(weights());
120
121 wnbview!(into_owned::<OwnedRepr>(self) where A: Clone, S: Data);
122
123 wnbview!(into_shared::<OwnedArcRepr>(self) where A: Clone, S: DataOwned);
124
125 wnbview!(to_owned::<OwnedRepr>(&self) where A: Clone, S: Data);
126
127 wnbview!(to_shared::<OwnedArcRepr>(&self) where A: Clone, S: DataOwned);
128
129 wnbview!(view::<'a, ViewRepr>(&self) where A: Clone, S: Data);
130
131 wnbview!(view_mut::<'a, ViewRepr>(&mut self) where A: Clone, S: DataMut);
132}
133
134impl<A, S, D> ParamsBase<S, D, Biased>
135where
136 D: RemoveAxis,
137 S: RawData<Elem = A>,
138{
139 pub fn biased<Sh>(shape: Sh) -> Self
141 where
142 A: Default,
143 S: DataOwned,
144 Sh: ShapeBuilder<Dim = D>,
145 {
146 let dim = shape.into_shape().raw_dim().clone();
147 Self {
148 bias: build_bias(true, dim.clone(), ArrayBase::default),
149 weight: ArrayBase::default(dim),
150 _mode: PhantomData::<Biased>,
151 }
152 }
153 pub fn bias(&self) -> &ArrayBase<S, D::Smaller> {
155 self.bias.as_ref().unwrap()
156 }
157 pub fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
159 self.bias.as_mut().unwrap()
160 }
161}
162
163impl<A, S, D> ParamsBase<S, D, Unbiased>
164where
165 D: Dimension,
166 S: RawData<Elem = A>,
167{
168 pub fn unbiased<Sh>(shape: Sh) -> Self
170 where
171 A: Default,
172 S: DataOwned,
173 Sh: ShapeBuilder<Dim = D>,
174 {
175 Self {
176 bias: None,
177 weight: ArrayBase::default(shape),
178 _mode: PhantomData::<Unbiased>,
179 }
180 }
181}
182impl<A, S, K> ParamsBase<S, Ix2, K>
183where
184 K: 'static,
185 S: RawData<Elem = A>,
186{
187 pub fn set_node(&mut self, idx: usize, node: Node<A>)
188 where
189 A: Clone + Default,
190 S: DataMut + DataOwned,
191 {
192 let (weight, bias) = node;
193 if let Some(bias) = bias {
194 if !self.is_biased() {
195 let mut tmp = ArrayBase::default(self.out_features());
196 tmp.index_axis_mut(Axis(0), idx).assign(&bias);
197 self.bias = Some(tmp);
198 }
199 self.bias
200 .as_mut()
201 .unwrap()
202 .index_axis_mut(Axis(0), idx)
203 .assign(&bias);
204 }
205
206 self.weights_mut()
207 .index_axis_mut(Axis(0), idx)
208 .assign(&weight);
209 }
210}
211
212impl<A, S, D> Default for ParamsBase<S, D, Biased>
213where
214 A: Default,
215 D: Dimension,
216 S: DataOwned<Elem = A>,
217{
218 fn default() -> Self {
219 Self {
220 bias: Some(Default::default()),
221 weight: Default::default(),
222 _mode: PhantomData::<Biased>,
223 }
224 }
225}
226
227impl<A, S, D> Default for ParamsBase<S, D, Unbiased>
228where
229 A: Default,
230 D: Dimension,
231 S: DataOwned<Elem = A>,
232{
233 fn default() -> Self {
234 Self {
235 bias: None,
236 weight: Default::default(),
237 _mode: PhantomData::<Unbiased>,
238 }
239 }
240}