concision_core/traits/
wnb.rs

1use ndarray::{ArrayBase, Data, DataMut, Dimension, RawData};
2
3pub trait Weighted<S, D>
4where
5    D: Dimension,
6    S: RawData,
7{
8    /// returns the weights of the model
9    fn weights(&self) -> &ArrayBase<S, D>;
10    /// returns a mutable reference to the weights of the model
11    fn weights_mut(&mut self) -> &mut ArrayBase<S, D>;
12    /// assigns the given bias to the current weight
13    fn assign_weights(&mut self, weights: &ArrayBase<S, D>) -> &mut Self
14    where
15        S: DataMut,
16        S::Elem: Clone,
17    {
18        self.weights_mut().assign(weights);
19        self
20    }
21    /// replaces the current weights with the given weights
22    fn replace_weights(&mut self, weights: ArrayBase<S, D>) -> ArrayBase<S, D> {
23        core::mem::replace(self.weights_mut(), weights)
24    }
25    /// sets the weights of the model
26    fn set_weights(&mut self, weights: ArrayBase<S, D>) -> &mut Self {
27        *self.weights_mut() = weights;
28        self
29    }
30    /// returns an iterator over the weights
31    fn iter_weights<'a>(&'a self) -> ndarray::iter::Iter<'a, S::Elem, D>
32    where
33        S: Data + 'a,
34        D: 'a,
35    {
36        self.weights().iter()
37    }
38    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
39    fn iter_weights_mut<'a>(&'a mut self) -> ndarray::iter::IterMut<'a, S::Elem, D>
40    where
41        S: DataMut + 'a,
42        D: 'a,
43    {
44        self.weights_mut().iter_mut()
45    }
46}
47
48pub trait Biased<S, D>: Weighted<S, D>
49where
50    D: Dimension,
51    S: RawData,
52{
53    /// returns the bias of the model
54    fn bias(&self) -> &ArrayBase<S, D::Smaller>;
55    /// returns a mutable reference to the bias of the model
56    fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller>;
57    /// assigns the given bias to the current bias
58    fn assign_bias(&mut self, bias: &ArrayBase<S, D::Smaller>) -> &mut Self
59    where
60        S: DataMut,
61        S::Elem: Clone,
62    {
63        self.bias_mut().assign(bias);
64        self
65    }
66    /// replaces the current bias with the given bias
67    fn replace_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> ArrayBase<S, D::Smaller> {
68        core::mem::replace(self.bias_mut(), bias)
69    }
70    /// sets the bias of the model
71    fn set_bias(&mut self, bias: ArrayBase<S, D::Smaller>) -> &mut Self {
72        *self.bias_mut() = bias;
73        self
74    }
75    /// returns an iterator over the bias
76    fn iter_bias<'a>(&'a self) -> ndarray::iter::Iter<'a, S::Elem, D::Smaller>
77    where
78        S: Data + 'a,
79        D: 'a,
80    {
81        self.bias().iter()
82    }
83    /// returns a mutable iterator over the bias
84    fn iter_bias_mut<'a>(&'a mut self) -> ndarray::iter::IterMut<'a, S::Elem, D::Smaller>
85    where
86        S: DataMut + 'a,
87        D: 'a,
88    {
89        self.bias_mut().iter_mut()
90    }
91}
92
93/*
94 ************* Implementations *************
95*/
96use crate::params::ParamsBase;
97
98impl<S, D> Weighted<S, D> for ParamsBase<S, D>
99where
100    S: RawData,
101    D: Dimension,
102{
103    fn weights(&self) -> &ArrayBase<S, D> {
104        &self.weights
105    }
106
107    fn weights_mut(&mut self) -> &mut ArrayBase<S, D> {
108        &mut self.weights
109    }
110}
111
112impl<S, D> Biased<S, D> for ParamsBase<S, D>
113where
114    S: RawData,
115    D: Dimension,
116{
117    fn bias(&self) -> &ArrayBase<S, D::Smaller> {
118        &self.bias
119    }
120
121    fn bias_mut(&mut self) -> &mut ArrayBase<S, D::Smaller> {
122        &mut self.bias
123    }
124}