1use crate::params_base::ParamsBase;
7
8use crate::Params;
9use crate::traits::{Biased, Weighted};
10use concision_traits::{Apply, FillLike, MapInto, MapTo, OnesLike, ZerosLike};
11use core::iter::Once;
12use ndarray::{ArrayBase, Data, DataOwned, Dimension, OwnedRepr, RawData};
13use num_traits::{One, Zero};
14use rspace_traits::RawSpace;
15
16impl<A, S, D> RawSpace for ParamsBase<S, D, A>
17where
18 D: Dimension,
19 S: RawData<Elem = A>,
20{
21 type Elem = A;
22}
23
24impl<A, S, D> Weighted<S, D, A> for ParamsBase<S, D, A>
25where
26 S: RawData<Elem = A>,
27 D: Dimension,
28{
29 type Tensor<_S, _D, _A>
30 = ArrayBase<_S, _D, _A>
31 where
32 _D: Dimension,
33 _S: RawData<Elem = _A>;
34
35 fn weights(&self) -> &ArrayBase<S, D, A> {
36 self.weights()
37 }
38
39 fn weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
40 self.weights_mut()
41 }
42}
43
44impl<A, S, D> Biased<S, D, A> for ParamsBase<S, D, A>
45where
46 S: RawData<Elem = A>,
47 D: Dimension,
48{
49 fn bias(&self) -> &ArrayBase<S, D::Smaller, A> {
50 self.bias()
51 }
52
53 fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A> {
54 self.bias_mut()
55 }
56}
57
58impl<S, D> core::ops::Deref for ParamsBase<S, D>
59where
60 D: Dimension,
61 S: RawData,
62{
63 type Target = ndarray::LayoutRef<S::Elem, D>;
64
65 fn deref(&self) -> &Self::Target {
66 self.weights().as_layout_ref()
67 }
68}
69
70impl<A, S, D> core::fmt::Debug for ParamsBase<S, D, A>
71where
72 D: Dimension,
73 S: Data<Elem = A>,
74 A: core::fmt::Debug,
75{
76 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
77 f.debug_struct("ModelParams")
78 .field("bias", self.bias())
79 .field("weights", self.weights())
80 .finish()
81 }
82}
83
84impl<A, S, D> core::fmt::Display for ParamsBase<S, D, A>
85where
86 D: Dimension,
87 S: Data<Elem = A>,
88 A: core::fmt::Display,
89{
90 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
91 write!(
92 f,
93 "{{ bias: {}, weights: {} }}",
94 self.bias(),
95 self.weights()
96 )
97 }
98}
99
100impl<A, S, D> Clone for ParamsBase<S, D, A>
101where
102 D: Dimension,
103 S: ndarray::RawDataClone<Elem = A>,
104 A: Clone,
105{
106 fn clone(&self) -> Self {
107 Self::new(self.bias().clone(), self.weights().clone())
108 }
109}
110
111impl<A, S, D> Copy for ParamsBase<S, D, A>
112where
113 D: Dimension + Copy,
114 <D as Dimension>::Smaller: Copy,
115 S: ndarray::RawDataClone<Elem = A> + Copy,
116 A: Copy,
117{
118}
119
120impl<A, S, D> PartialEq for ParamsBase<S, D, A>
121where
122 D: Dimension,
123 S: Data<Elem = A>,
124 A: PartialEq,
125{
126 fn eq(&self, other: &Self) -> bool {
127 self.bias() == other.bias() && self.weights() == other.weights()
128 }
129}
130
131impl<A, S, D> PartialEq<&ParamsBase<S, D, A>> for ParamsBase<S, D, A>
132where
133 D: Dimension,
134 S: Data<Elem = A>,
135 A: PartialEq,
136{
137 fn eq(&self, other: &&ParamsBase<S, D, A>) -> bool {
138 self.bias() == other.bias() && self.weights() == other.weights()
139 }
140}
141
142impl<A, S, D> PartialEq<&mut ParamsBase<S, D, A>> for ParamsBase<S, D, A>
143where
144 D: Dimension,
145 S: Data<Elem = A>,
146 A: PartialEq,
147{
148 fn eq(&self, other: &&mut ParamsBase<S, D, A>) -> bool {
149 self.bias() == other.bias() && self.weights() == other.weights()
150 }
151}
152
153impl<A, S, D> Eq for ParamsBase<S, D, A>
154where
155 D: Dimension,
156 S: Data<Elem = A>,
157 A: Eq,
158{
159}
160
161impl<A, S, D> IntoIterator for ParamsBase<S, D, A>
162where
163 D: Dimension,
164 S: RawData<Elem = A>,
165{
166 type Item = ParamsBase<S, D, A>;
167 type IntoIter = Once<ParamsBase<S, D, A>>;
168
169 fn into_iter(self) -> Self::IntoIter {
170 core::iter::once(self)
171 }
172}
173
174impl<A, B, S, D, F> Apply<F> for ParamsBase<S, D, A>
175where
176 A: Clone,
177 D: Dimension,
178 S: Data<Elem = A>,
179 F: Fn(A) -> B,
180{
181 type Output = ParamsBase<OwnedRepr<B>, D>;
182
183 fn apply(&self, func: F) -> Self::Output {
184 ParamsBase {
185 bias: self.bias().mapv(&func),
186 weights: self.weights().mapv(&func),
187 }
188 }
189}
190
191impl<A, B, S, D, F> MapInto<F, B> for ParamsBase<S, D, A>
192where
193 A: Clone,
194 D: Dimension,
195 S: Data<Elem = A>,
196 F: Fn(A) -> B,
197{
198 type Elem = A;
199 type Cont<T> = Params<T, D>;
200
201 fn mapi(self, func: F) -> Self::Cont<B> {
202 ParamsBase {
203 bias: self.bias().mapv(&func),
204 weights: self.weights().mapv(&func),
205 }
206 }
207}
208
209impl<A, B, S, D, F> MapInto<F, B> for &ParamsBase<S, D, A>
210where
211 A: Clone,
212 D: Dimension,
213 S: Data<Elem = A>,
214 F: Fn(A) -> B,
215{
216 type Elem = A;
217 type Cont<T> = Params<T, D>;
218
219 fn mapi(self, func: F) -> Self::Cont<B> {
220 ParamsBase {
221 bias: self.bias().mapv(&func),
222 weights: self.weights().mapv(&func),
223 }
224 }
225}
226
227impl<A, B, S, D, F> MapTo<F, B> for ParamsBase<S, D, A>
228where
229 A: Clone,
230 D: Dimension,
231 S: Data<Elem = A>,
232 F: Fn(A) -> B,
233{
234 type Cont<V> = Params<V, D>;
235 type Elem = A;
236
237 fn mapt(&self, func: F) -> Self::Cont<B> {
238 ParamsBase {
239 bias: self.bias().mapv(&func),
240 weights: self.weights().mapv(&func),
241 }
242 }
243}
244
245impl<A, S, D> OnesLike for ParamsBase<S, D, A>
246where
247 D: Dimension,
248 S: DataOwned<Elem = A>,
249 A: Clone + One,
250{
251 type Output = ParamsBase<S, D, A>;
252
253 fn ones_like(&self) -> Self::Output {
254 ParamsBase {
255 bias: self.bias().ones_like(),
256 weights: self.weights().ones_like(),
257 }
258 }
259}
260
261impl<A, S, D> ZerosLike for ParamsBase<S, D, A>
262where
263 D: Dimension,
264 S: DataOwned<Elem = A>,
265 A: Clone + Zero,
266{
267 type Output = ParamsBase<S, D, A>;
268
269 fn zeros_like(&self) -> Self::Output {
270 ParamsBase {
271 bias: self.bias().zeros_like(),
272 weights: self.weights().zeros_like(),
273 }
274 }
275}
276
277impl<A, S, D> FillLike<A> for ParamsBase<S, D, A>
278where
279 D: Dimension,
280 S: DataOwned<Elem = A>,
281 A: Clone,
282{
283 type Output = ParamsBase<S, D, A>;
284
285 fn fill_like(&self, elem: A) -> Self::Output {
286 ParamsBase {
287 bias: self.bias().fill_like(elem.clone()),
288 weights: self.weights().fill_like(elem),
289 }
290 }
291}