concision_neural/model/
store.rs1use crate::ModelFeatures;
6use concision_core::params::Params;
7
8#[derive(Clone, Debug)]
16pub struct ModelParams<A = f64> {
17 pub(crate) input: Params<A>,
18 pub(crate) hidden: Vec<Params<A>>,
19 pub(crate) output: Params<A>,
20}
21
22impl<A> ModelParams<A> {
23 pub fn default(features: ModelFeatures) -> Self
26 where
27 A: Clone + Default,
28 {
29 let input = Params::default(features.d_input());
30 let hidden = (0..features.layers())
31 .map(|_| Params::default(features.d_hidden()))
32 .collect::<Vec<_>>();
33 let output = Params::default(features.d_output());
34 Self {
35 input,
36 hidden,
37 output,
38 }
39 }
40 pub fn ones(features: ModelFeatures) -> Self
43 where
44 A: Clone + num_traits::One,
45 {
46 let input = Params::ones(features.d_input());
47 let hidden = (0..features.layers())
48 .map(|_| Params::ones(features.d_hidden()))
49 .collect::<Vec<_>>();
50 let output = Params::ones(features.d_output());
51 Self {
52 input,
53 hidden,
54 output,
55 }
56 }
57 pub fn zeros(features: ModelFeatures) -> Self
60 where
61 A: Clone + num_traits::Zero,
62 {
63 let input = Params::zeros(features.d_input());
64 let hidden = (0..features.layers())
65 .map(|_| Params::zeros(features.d_hidden()))
66 .collect::<Vec<_>>();
67 let output = Params::zeros(features.d_output());
68 Self {
69 input,
70 hidden,
71 output,
72 }
73 }
74 pub fn is_shallow(&self) -> bool {
76 self.hidden.is_empty() || self.hidden.len() == 1
77 }
78 pub const fn input(&self) -> &Params<A> {
80 &self.input
81 }
82 #[inline]
84 pub fn input_mut(&mut self) -> &mut Params<A> {
85 &mut self.input
86 }
87 pub const fn hidden(&self) -> &Vec<Params<A>> {
89 &self.hidden
90 }
91 #[inline]
93 pub fn hidden_as_slice(&self) -> &[Params<A>] {
94 self.hidden.as_slice()
95 }
96 #[inline]
98 pub fn hidden_mut(&mut self) -> &mut Vec<Params<A>> {
99 &mut self.hidden
100 }
101 pub const fn output(&self) -> &Params<A> {
103 &self.output
104 }
105 #[inline]
107 pub fn output_mut(&mut self) -> &mut Params<A> {
108 &mut self.output
109 }
110 pub fn set_input(&mut self, input: Params<A>) {
112 self.input = input;
113 }
114 pub fn set_hidden<I>(&mut self, iter: I)
116 where
117 I: IntoIterator<Item = Params<A>>,
118 {
119 self.hidden = Vec::from_iter(iter);
120 }
121 pub fn set_output(&mut self, output: Params<A>) {
123 self.output = output;
124 }
125 pub fn with_input(self, input: Params<A>) -> Self {
127 Self { input, ..self }
128 }
129 pub fn with_hidden<I>(self, iter: I) -> Self
131 where
132 I: IntoIterator<Item = Params<A>>,
133 {
134 Self {
135 hidden: Vec::from_iter(iter),
136 ..self
137 }
138 }
139 pub fn with_output(self, output: Params<A>) -> Self {
141 Self { output, ..self }
142 }
143
144 pub fn dim_input(&self) -> (usize, usize) {
145 self.input().dim()
146 }
147
148 pub fn dim_hidden(&self) -> (usize, usize) {
149 assert!(self.hidden.iter().all(|p| p.dim() == self.hidden[0].dim()));
150 self.hidden[0].dim()
151 }
152
153 pub fn dim_output(&self) -> (usize, usize) {
154 self.output.dim()
155 }
156
157 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
158 where
159 A: Clone,
160 Params<A>: cnc::Forward<X, Output = Y> + cnc::Forward<Y, Output = Y>,
161 {
162 let mut output = self.input.forward(input)?;
163 for layer in &self.hidden {
164 output = layer.forward(&output)?;
165 }
166 self.output.forward(&output)
167 }
168}