concision_neural/params/impls/
impl_model_params.rs1use crate::params::ModelParamsBase;
6
7use crate::{DeepModelRepr, RawHidden};
8use cnc::params::ParamsBase;
9use ndarray::{Data, Dimension, RawDataClone};
10
11impl<A, S, D, H> Clone for ModelParamsBase<S, D, H>
12where
13 D: Dimension,
14 H: RawHidden<S, D> + Clone,
15 S: RawDataClone<Elem = A>,
16 A: Clone,
17{
18 fn clone(&self) -> Self {
19 Self {
20 input: self.input().clone(),
21 hidden: self.hidden().clone(),
22 output: self.output().clone(),
23 }
24 }
25}
26
27impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H>
28where
29 D: Dimension,
30 H: RawHidden<S, D> + core::fmt::Debug,
31 S: Data<Elem = A>,
32 A: core::fmt::Debug,
33{
34 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
35 f.debug_struct("ModelParams")
36 .field("input", self.input())
37 .field("hidden", self.hidden())
38 .field("output", self.output())
39 .finish()
40 }
41}
42
43impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H>
44where
45 D: Dimension,
46 H: RawHidden<S, D> + core::fmt::Debug,
47 S: Data<Elem = A>,
48 A: core::fmt::Display,
49{
50 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51 write!(
52 f,
53 "{{ input: {i}, hidden: {h:?}, output: {o} }}",
54 i = self.input(),
55 h = self.hidden(),
56 o = self.output()
57 )
58 }
59}
60
61impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H>
62where
63 D: Dimension,
64 S: Data<Elem = A>,
65 H: DeepModelRepr<S, D>,
66 A: Clone,
67{
68 type Output = ParamsBase<S, D>;
69
70 fn index(&self, index: usize) -> &Self::Output {
71 if index == 0 {
72 self.input()
73 } else if index == self.count_hidden() + 1 {
74 self.output()
75 } else {
76 &self.hidden().as_slice()[index - 1]
77 }
78 }
79}
80
81impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H>
82where
83 D: Dimension,
84 S: Data<Elem = A>,
85 H: DeepModelRepr<S, D>,
86 A: Clone,
87{
88 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
89 if index == 0 {
90 self.input_mut()
91 } else if index == self.count_hidden() + 1 {
92 self.output_mut()
93 } else {
94 &mut self.hidden_mut().as_mut_slice()[index - 1]
95 }
96 }
97}