concision_neural/model/impls/
impl_model_params.rs1use crate::model::{ModelFeatures, ModelParamsBase};
6
7use cnc::params::ParamsBase;
8use ndarray::{DataOwned, Dimension, RawData};
9use num_traits::{One, Zero};
10
11impl<A, S> ModelParamsBase<S>
12where
13 S: RawData<Elem = A>,
14{
15 pub fn default(features: ModelFeatures) -> Self
18 where
19 A: Clone + Default,
20 S: DataOwned,
21 {
22 let input = ParamsBase::default(features.dim_input());
23 let hidden = (0..features.layers())
24 .map(|_| ParamsBase::default(features.dim_hidden()))
25 .collect::<Vec<_>>();
26 let output = ParamsBase::default(features.dim_output());
27 Self::new(input, hidden, output)
28 }
29 pub fn ones(features: ModelFeatures) -> Self
32 where
33 A: Clone + One,
34 S: DataOwned,
35 {
36 let input = ParamsBase::ones(features.dim_input());
37 let hidden = (0..features.layers())
38 .map(|_| ParamsBase::ones(features.dim_hidden()))
39 .collect::<Vec<_>>();
40 let output = ParamsBase::ones(features.dim_output());
41 Self::new(input, hidden, output)
42 }
43 pub fn zeros(features: ModelFeatures) -> Self
46 where
47 A: Clone + Zero,
48 S: DataOwned,
49 {
50 let input = ParamsBase::zeros(features.dim_input());
51 let hidden = (0..features.layers())
52 .map(|_| ParamsBase::zeros(features.dim_hidden()))
53 .collect::<Vec<_>>();
54 let output = ParamsBase::zeros(features.dim_output());
55 Self::new(input, hidden, output)
56 }
57}
58
59impl<A, S, D> core::ops::Index<usize> for ModelParamsBase<S, D>
60where
61 A: Clone,
62 D: Dimension,
63 S: ndarray::Data<Elem = A>,
64{
65 type Output = ParamsBase<S, D>;
66
67 fn index(&self, index: usize) -> &Self::Output {
68 if index == 0 {
69 &self.input
70 } else if index == self.count_hidden() + 1 {
71 &self.output
72 } else {
73 &self.hidden[index - 1]
74 }
75 }
76}
77
78impl<A, S, D> core::ops::IndexMut<usize> for ModelParamsBase<S, D>
79where
80 A: Clone,
81 D: Dimension,
82 S: ndarray::Data<Elem = A>,
83{
84 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
85 if index == 0 {
86 &mut self.input
87 } else if index == self.count_hidden() + 1 {
88 &mut self.output
89 } else {
90 &mut self.hidden[index - 1]
91 }
92 }
93}