1use crate::Params;
7use crate::params_base::ParamsBase;
8use crate::utils::extract_bias_dim;
9use ndarray::{
10 ArrayBase, Axis, Data, DataMut, DataOwned, Dimension, LayoutRef, RawData, RawRef, RemoveAxis,
11 ShapeArg, ShapeBuilder,
12};
13
14impl<A, S, D> ParamsBase<S, D, A>
15where
16 D: Dimension,
17 S: RawData<Elem = A>,
18{
19 pub const fn new(bias: ArrayBase<S, D::Smaller, A>, weights: ArrayBase<S, D, A>) -> Self {
21 Self { bias, weights }
22 }
23 pub fn init_from_fn<Sh, F>(shape: Sh, init: F) -> Self
25 where
26 A: Clone,
27 D: RemoveAxis,
28 S: DataOwned,
29 Sh: ShapeBuilder<Dim = D>,
30 F: Fn() -> A,
31 {
32 let weights = ArrayBase::from_shape_fn(shape, |_| init());
33 let bias = ArrayBase::from_shape_fn(extract_bias_dim(&weights), |_| init());
35 Self::new(bias, weights)
37 }
38 pub fn from_shape_fn<Sh, F1, F2>(shape: Sh, w: F1, b: F2) -> Self
40 where
41 A: Clone,
42 D: RemoveAxis,
43 S: DataOwned,
44 Sh: ShapeBuilder<Dim = D>,
45 D::Smaller: Dimension + ShapeArg,
46 F1: Fn(<D as Dimension>::Pattern) -> A,
47 F2: Fn(<D::Smaller as Dimension>::Pattern) -> A,
48 {
49 let weights = ArrayBase::from_shape_fn(shape, w);
51 let bias = ArrayBase::from_shape_fn(extract_bias_dim(&weights), b);
53 Self::new(bias, weights)
55 }
56 pub fn from_bias<Sh>(shape: Sh, bias: ArrayBase<S, D::Smaller, A>) -> Self
58 where
59 A: Clone + Default,
60 D: RemoveAxis,
61 S: DataOwned,
62 Sh: ShapeBuilder<Dim = D>,
63 {
64 let weights = ArrayBase::from_elem(shape, A::default());
65 let bdim = extract_bias_dim(&weights);
66 if bias.raw_dim() != bdim {
67 panic!("the given bias shape is invalid");
68 }
69 Self::new(bias, weights)
70 }
71 pub fn from_weights(weights: ArrayBase<S, D, A>) -> Self
74 where
75 A: Clone + Default,
76 D: RemoveAxis,
77 S: DataOwned,
78 {
79 let bias = ArrayBase::from_elem(extract_bias_dim(&weights), A::default());
80 Self::new(bias, weights)
81 }
82 pub fn from_elem<Sh: ShapeBuilder<Dim = D>>(shape: Sh, elem: A) -> Self
84 where
85 A: Clone,
86 D: RemoveAxis,
87 S: DataOwned,
88 {
89 let weights = ArrayBase::from_elem(shape, elem.clone());
90 let bias = ArrayBase::from_elem(extract_bias_dim(&weights), elem);
91 Self::new(bias, weights)
92 }
93 #[allow(clippy::should_implement_trait)]
94 pub fn default<Sh>(shape: Sh) -> Self
96 where
97 A: Clone + Default,
98 D: RemoveAxis,
99 S: DataOwned,
100 Sh: ShapeBuilder<Dim = D>,
101 {
102 Self::from_elem(shape, A::default())
103 }
104 pub fn ones<Sh>(shape: Sh) -> Self
106 where
107 A: Clone + num_traits::One,
108 D: RemoveAxis,
109 S: DataOwned,
110 Sh: ShapeBuilder<Dim = D>,
111 {
112 Self::from_elem(shape, A::one())
113 }
114 pub fn zeros<Sh>(shape: Sh) -> Self
116 where
117 A: Clone + num_traits::Zero,
118 D: RemoveAxis,
119 S: DataOwned,
120 Sh: ShapeBuilder<Dim = D>,
121 {
122 Self::from_elem(shape, A::zero())
123 }
124 pub const fn bias(&self) -> &ArrayBase<S, D::Smaller, A> {
126 &self.bias
127 }
128 pub const fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
130 &mut self.bias
131 }
132 pub const fn weights(&self) -> &ArrayBase<S, D, A> {
134 &self.weights
135 }
136 pub const fn weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
138 &mut self.weights
139 }
140 pub fn bias_as_raw_ref(&self) -> &RawRef<A, D::Smaller>
142 where
143 S: Data,
144 {
145 self.bias().as_raw_ref()
146 }
147 pub fn weights_as_raw_ref(&self) -> &RawRef<A, D>
149 where
150 S: Data,
151 {
152 self.weights().as_raw_ref()
153 }
154 pub fn bias_layout_ref(&self) -> &LayoutRef<A, D::Smaller>
156 where
157 S: Data,
158 {
159 self.bias().as_layout_ref()
160 }
161 pub fn bias_layout_ref_mut(&mut self) -> &mut LayoutRef<A, D::Smaller>
163 where
164 S: DataMut,
165 {
166 self.bias_mut().as_layout_ref_mut()
167 }
168 pub fn weights_layout_ref(&self) -> &LayoutRef<A, D>
170 where
171 S: Data,
172 {
173 self.weights().as_layout_ref()
174 }
175 pub fn weights_layout_ref_mut(&mut self) -> &mut LayoutRef<A, D>
177 where
178 S: DataMut,
179 {
180 self.weights_mut().as_layout_ref_mut()
181 }
182 pub fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller, A>) -> &mut Self
184 where
185 A: Clone,
186 S: DataMut,
187 {
188 self.bias_mut().assign(bias);
189 self
190 }
191 pub fn assign_weights(&mut self, weights: &ArrayBase<S, D, A>) -> &mut Self
193 where
194 A: Clone,
195 S: DataMut,
196 {
197 self.weights_mut().assign(weights);
198 self
199 }
200 pub fn replace_bias(
202 &mut self,
203 bias: ArrayBase<S, D::Smaller, A>,
204 ) -> ArrayBase<S, D::Smaller, A> {
205 core::mem::replace(&mut self.bias, bias)
206 }
207 pub fn replace_weights(&mut self, weights: ArrayBase<S, D, A>) -> ArrayBase<S, D, A> {
209 core::mem::replace(&mut self.weights, weights)
210 }
211 pub fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> &mut Self {
213 *self.bias_mut() = bias;
214 self
215 }
216 pub fn set_weights(&mut self, weights: ArrayBase<S, D, A>) -> &mut Self {
218 *self.weights_mut() = weights;
219 self
220 }
221 pub fn dim(&self) -> D::Pattern {
223 self.weights().dim()
224 }
225 pub fn is_empty(&self) -> bool {
227 self.is_weights_empty() && self.is_bias_empty()
228 }
229 pub fn is_weights_empty(&self) -> bool {
231 self.weights().is_empty()
232 }
233 pub fn is_bias_empty(&self) -> bool {
235 self.bias().is_empty()
236 }
237 pub fn count_weights(&self) -> usize {
239 self.weights().len()
240 }
241 pub fn count_bias(&self) -> usize {
243 self.bias().len()
244 }
245 pub fn raw_dim(&self) -> D {
247 self.weights().raw_dim()
248 }
249 pub fn shape<'a>(&'a self) -> &'a [usize]
251 where
252 A: 'a,
253 {
254 self.weights.shape()
255 }
256 pub fn shape_bias(&self) -> &[usize]
259 where
260 A: 'static,
261 {
262 self.bias.shape()
263 }
264 pub fn size(&self) -> usize {
266 self.weights().len() + self.bias().len()
267 }
268 pub fn to_owned(&self) -> Params<A, D>
270 where
271 A: Clone,
272 S: DataOwned,
273 {
274 ParamsBase::new(self.bias().to_owned(), self.weights().to_owned())
275 }
276 pub fn to_shape<Sh>(
279 &self,
280 shape: Sh,
281 ) -> crate::Result<ParamsBase<ndarray::CowRepr<'_, A>, Sh::Dim>>
282 where
283 A: Clone,
284 S: DataOwned,
285 Sh: ShapeBuilder<Dim = D>,
286 Sh::Dim: Dimension + RemoveAxis,
287 {
288 let shape = shape.into_shape_with_order();
289 let dim = shape.raw_dim().clone();
290 let bias = self.bias().to_shape(dim.remove_axis(Axis(0)))?;
291 let weights = self.weights().to_shape(dim)?;
292 Ok(ParamsBase::new(bias, weights))
293 }
294 pub fn to_shared(&self) -> ParamsBase<ndarray::OwnedArcRepr<A>, D>
297 where
298 A: Clone,
299 S: Data,
300 {
301 ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
302 }
303 pub fn view(&self) -> ParamsBase<ndarray::ViewRepr<&'_ A>, D>
305 where
306 S: Data,
307 {
308 ParamsBase::new(self.bias().view(), self.weights().view())
309 }
310 pub fn view_mut(&mut self) -> ParamsBase<ndarray::ViewRepr<&'_ mut A>, D>
312 where
313 S: DataMut,
314 {
315 ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
316 }
317 pub fn clamp(&mut self, min: A, max: A) -> Params<A, D>
319 where
320 A: 'static + Clone + PartialOrd,
321 S: Data,
322 {
323 ParamsBase {
324 bias: self.bias().clamp(min.clone(), max.clone()),
325 weights: self.weights().clamp(min, max),
326 }
327 }
328 pub fn mapv<F, U>(&self, f: F) -> Params<U, D>
330 where
331 A: Clone,
332 S: Data,
333 F: Fn(A) -> U,
334 {
335 ParamsBase {
336 bias: self.bias().mapv(&f),
337 weights: self.weights().mapv(&f),
338 }
339 }
340}