concision_core/params/impls/
impl_params.rs

1/*
2    appellation: impl_params <module>
3    authors: @FL03
4*/
5use crate::params::ParamsBase;
6
7use core::iter::Once;
8use ndarray::{ArrayBase, Data, DataOwned, Dimension, Ix1, Ix2, RawData};
9
10impl<A, S> ParamsBase<S, Ix1>
11where
12    S: RawData<Elem = A>,
13{
14    /// returns a new instance of the [`ParamsBase`] initialized using a _scalar_ bias along
15    /// with the given, one-dimensional weight tensor.
16    pub fn from_scalar_bias(bias: A, weights: ArrayBase<S, Ix1>) -> Self
17    where
18        A: Clone,
19        S: DataOwned,
20    {
21        Self {
22            bias: ArrayBase::from_elem((), bias),
23            weights,
24        }
25    }
26    /// returns the number of rows in the weights matrix
27    pub fn nrows(&self) -> usize {
28        self.weights().len()
29    }
30}
31
32impl<A, S> ParamsBase<S, Ix2>
33where
34    S: RawData<Elem = A>,
35{
36    /// returns the number of columns in the weights matrix
37    pub fn ncols(&self) -> usize {
38        self.weights().ncols()
39    }
40    /// returns the number of rows in the weights matrix
41    pub fn nrows(&self) -> usize {
42        self.weights().nrows()
43    }
44}
45
46impl<A, S, D> core::fmt::Debug for ParamsBase<S, D>
47where
48    D: Dimension,
49    S: Data<Elem = A>,
50    A: core::fmt::Debug,
51{
52    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
53        f.debug_struct("ModelParams")
54            .field("bias", self.bias())
55            .field("weights", self.weights())
56            .finish()
57    }
58}
59
60impl<A, S, D> core::fmt::Display for ParamsBase<S, D>
61where
62    D: Dimension,
63    S: Data<Elem = A>,
64    A: core::fmt::Display,
65{
66    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
67        write!(
68            f,
69            "{{ bias: {}, weights: {} }}",
70            self.bias(),
71            self.weights()
72        )
73    }
74}
75
76impl<A, S, D> Clone for ParamsBase<S, D>
77where
78    D: Dimension,
79    S: ndarray::RawDataClone<Elem = A>,
80    A: Clone,
81{
82    fn clone(&self) -> Self {
83        Self::new(self.bias().clone(), self.weights().clone())
84    }
85}
86
87impl<A, S, D> Copy for ParamsBase<S, D>
88where
89    D: Dimension + Copy,
90    <D as Dimension>::Smaller: Copy,
91    S: ndarray::RawDataClone<Elem = A> + Copy,
92    A: Copy,
93{
94}
95
96impl<A, S, D> PartialEq for ParamsBase<S, D>
97where
98    D: Dimension,
99    S: Data<Elem = A>,
100    A: PartialEq,
101{
102    fn eq(&self, other: &Self) -> bool {
103        self.bias() == other.bias() && self.weights() == other.weights()
104    }
105}
106
107impl<A, S, D> PartialEq<&ParamsBase<S, D>> for ParamsBase<S, D>
108where
109    D: Dimension,
110    S: Data<Elem = A>,
111    A: PartialEq,
112{
113    fn eq(&self, other: &&ParamsBase<S, D>) -> bool {
114        self.bias() == other.bias() && self.weights() == other.weights()
115    }
116}
117
118impl<A, S, D> PartialEq<&mut ParamsBase<S, D>> for ParamsBase<S, D>
119where
120    D: Dimension,
121    S: Data<Elem = A>,
122    A: PartialEq,
123{
124    fn eq(&self, other: &&mut ParamsBase<S, D>) -> bool {
125        self.bias() == other.bias() && self.weights() == other.weights()
126    }
127}
128
129impl<A, S, D> Eq for ParamsBase<S, D>
130where
131    D: Dimension,
132    S: Data<Elem = A>,
133    A: Eq,
134{
135}
136
137impl<A, S, D> IntoIterator for ParamsBase<S, D>
138where
139    D: Dimension,
140    S: RawData<Elem = A>,
141{
142    type Item = ParamsBase<S, D>;
143    type IntoIter = Once<ParamsBase<S, D>>;
144
145    fn into_iter(self) -> Self::IntoIter {
146        core::iter::once(self)
147    }
148}