concision_core/models/impls/
impl_params_shallow.rs1use crate::models::{ModelParamsBase, ShallowParamsBase};
6
7use crate::ModelFeatures;
8use crate::activate::{ReLUActivation, SigmoidActivation};
9use crate::models::traits::ShallowModelRepr;
10use concision_params::ParamsBase;
11use ndarray::{
12 Array1, ArrayBase, Data, DataOwned, Dimension, Ix2, RawData, RemoveAxis, ScalarOperand,
13};
14use num_traits::Float;
15
16impl<S, D, H, A> ModelParamsBase<S, D, H, A>
17where
18 D: Dimension,
19 S: RawData<Elem = A>,
20 H: ShallowModelRepr<S, D>,
21{
22 pub const fn shallow(input: ParamsBase<S, D>, hidden: H, output: ParamsBase<S, D>) -> Self {
24 Self {
25 input,
26 hidden,
27 output,
28 }
29 }
30}
31
32impl<S, D, A> ShallowParamsBase<S, D, A>
33where
34 S: RawData<Elem = A>,
35 D: Dimension,
36{
37 #[allow(clippy::should_implement_trait)]
38 pub fn default(input: D, hidden: D, output: D) -> Self
41 where
42 A: Clone + Default,
43 S: DataOwned,
44 D: RemoveAxis,
45 {
46 Self {
47 hidden: ParamsBase::default(hidden),
48 input: ParamsBase::default(input),
49 output: ParamsBase::default(output),
50 }
51 }
52 #[inline]
54 pub fn size(&self) -> usize {
55 let mut size = self.input().count_weights();
56 size += self.hidden().count_weights();
57 size + self.output().count_weights()
58 }
59 pub const fn hidden_weights(&self) -> &ArrayBase<S, D, A> {
61 self.hidden().weights()
62 }
63 pub const fn hidden_weights_mut(&mut self) -> &mut ArrayBase<S, D, A> {
65 self.hidden_mut().weights_mut()
66 }
67}
68
69impl<S, A> ShallowParamsBase<S, Ix2, A>
70where
71 S: RawData<Elem = A>,
72{
73 pub fn from_features(features: ModelFeatures) -> Self
74 where
75 A: Clone + Default,
76 S: DataOwned,
77 {
78 Self {
79 hidden: ParamsBase::default(features.dim_hidden()),
80 input: ParamsBase::default(features.dim_input()),
81 output: ParamsBase::default(features.dim_output()),
82 }
83 }
84 pub fn forward(&self, input: &Array1<A>) -> Array1<A>
86 where
87 A: Float + ScalarOperand,
88 S: Data,
89 {
90 use concision_traits::Forward;
91 let mut output = self.input().forward_then(input, |x| x.relu());
92 output = self.hidden().forward_then(&output, |x| x.relu());
93 self.output().forward_then(&output, |x| x.sigmoid())
94 }
95}
96
97impl<A, S> Default for ShallowParamsBase<S, Ix2, A>
98where
99 S: DataOwned<Elem = A>,
100 A: Clone + Default,
101{
102 fn default() -> Self {
103 Self::from_features(ModelFeatures::default())
104 }
105}