// concision_params/traits/wnb.rs
use ndarray::iter as nditer;
use ndarray::{ArrayBase, Data, DataMut, Dimension, RawData};

9pub trait Weighted<S, D, A = <S as RawData>::Elem>: Sized
11where
12 D: Dimension,
13 S: RawData<Elem = A>,
14{
15 type Tensor<_S, _D, _A>
16 where
17 _D: Dimension,
18 _S: RawData<Elem = _A>;
19 fn weights(&self) -> &Self::Tensor<S, D, A>;
21 fn weights_mut(&mut self) -> &mut Self::Tensor<S, D, A>;
23 fn replace_weights(&mut self, weights: Self::Tensor<S, D, A>) -> Self::Tensor<S, D, A> {
25 core::mem::replace(self.weights_mut(), weights)
26 }
27 fn set_weights(&mut self, weights: Self::Tensor<S, D, A>) -> &mut Self {
29 *self.weights_mut() = weights;
30 self
31 }
32}
33
34pub trait Biased<S, D, A = <S as RawData>::Elem>: Weighted<S, D, A>
35where
36 D: Dimension,
37 S: RawData<Elem = A>,
38{
39 fn bias(&self) -> &ArrayBase<S, D::Smaller, A>;
41 fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller, A>;
43 fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller, A>) -> &mut Self
45 where
46 S: DataMut,
47 S::Elem: Clone,
48 {
49 self.bias_mut().assign(bias);
50 self
51 }
52 fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> ArrayBase<S, D::Smaller, A> {
54 core::mem::replace(self.bias_mut(), bias)
55 }
56 fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller, A>) -> &mut Self {
58 *self.bias_mut() = bias;
59 self
60 }
61 fn iter_bias<'a>(&'a self) -> nditer::Iter<'a, S::Elem, D::Smaller>
63 where
64 S: Data + 'a,
65 D: 'a,
66 {
67 self.bias().iter()
68 }
69 fn iter_bias_mut<'a>(&'a mut self) -> nditer::IterMut<'a, S::Elem, D::Smaller>
71 where
72 S: DataMut + 'a,
73 D: 'a,
74 {
75 self.bias_mut().iter_mut()
76 }
77}