// concision_neural/model/params/impl_params_shallow.rs
use crate::model::params::{ModelParamsBase, ShallowParamsBase};

use crate::model::ModelFeatures;
use crate::traits::ShallowNeuralStore;
use cnc::{ParamsBase, ReLU, Sigmoid};
use ndarray::{
    Array1, ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, RemoveAxis, ScalarOperand,
};
use num_traits::Float;

15impl<S, D, H, A> ModelParamsBase<S, D, H>
16where
17 D: Dimension,
18 S: RawData<Elem = A>,
19 H: ShallowNeuralStore<S, D>,
20{
21 pub const fn shallow(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
23 Self {
24 input,
25 hidden,
26 output,
27 }
28 }
29}
30
31impl<S, D, A> ShallowParamsBase<S, D>
32where
33 S: RawData<Elem = A>,
34 D: Dimension,
35{
36 #[allow(clippy::should_implement_trait)]
37 pub fn default(input: D, hidden: D, output: D) -> Self
40 where
41 A: Clone + Default,
42 S: DataOwned,
43 D: RemoveAxis,
44 {
45 Self {
46 hidden: ParamsBase::default(hidden),
47 input: ParamsBase::default(input),
48 output: ParamsBase::default(output),
49 }
50 }
51 #[inline]
53 pub fn size(&self) -> usize {
54 let mut size = self.input().count_weight();
55 size += self.hidden().count_weight();
56 size + self.output().count_weight()
57 }
58 pub const fn hidden_weights(&self) -> &ArrayBase<S, D> {
60 self.hidden().weights()
61 }
62 pub const fn hidden_weights_mut(&mut self) -> &mut ArrayBase<S, D> {
64 self.hidden_mut().weights_mut()
65 }
66}
67
68impl<S, A> ShallowParamsBase<S, Ix2>
69where
70 S: RawData<Elem = A>,
71{
72 pub fn from_features(features: ModelFeatures) -> Self
73 where
74 A: Clone + Default,
75 S: DataOwned,
76 {
77 Self {
78 hidden: ParamsBase::default(features.dim_hidden()),
79 input: ParamsBase::default(features.dim_input()),
80 output: ParamsBase::default(features.dim_output()),
81 }
82 }
83 pub fn forward(&self, input: &Array1<A>) -> cnc::Result<Array1<A>>
85 where
86 A: Float + ScalarOperand,
87 S: Data,
88 {
89 let mut output = self.input().forward(input)?.relu();
91 output = self.hidden().forward(&output)?.relu();
93 output = self.output().forward(&output)?.sigmoid();
95
96 Ok(output)
97 }
98}
99
100impl<A, S> Default for ShallowParamsBase<S, Ix2>
101where
102 S: DataOwned<Elem = A>,
103 A: Clone + Default,
104{
105 fn default() -> Self {
106 Self::from_features(ModelFeatures::default())
107 }
108}