concision_neural/params/impls/
impl_model_params.rs

/*
    appellation: impl_model_params <module>
    authors: @FL03
*/
5use crate::params::ModelParamsBase;
6
7use crate::{DeepModelRepr, RawHidden};
8use cnc::params::ParamsBase;
9use ndarray::{Data, Dimension, RawDataClone};
10
11impl<A, S, D, H> Clone for ModelParamsBase<S, D, H>
12where
13    D: Dimension,
14    H: RawHidden<S, D> + Clone,
15    S: RawDataClone<Elem = A>,
16    A: Clone,
17{
18    fn clone(&self) -> Self {
19        Self {
20            input: self.input().clone(),
21            hidden: self.hidden().clone(),
22            output: self.output().clone(),
23        }
24    }
25}
26
27impl<A, S, D, H> core::fmt::Debug for ModelParamsBase<S, D, H>
28where
29    D: Dimension,
30    H: RawHidden<S, D> + core::fmt::Debug,
31    S: Data<Elem = A>,
32    A: core::fmt::Debug,
33{
34    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
35        f.debug_struct("ModelParams")
36            .field("input", self.input())
37            .field("hidden", self.hidden())
38            .field("output", self.output())
39            .finish()
40    }
41}
42
43impl<A, S, D, H> core::fmt::Display for ModelParamsBase<S, D, H>
44where
45    D: Dimension,
46    H: RawHidden<S, D> + core::fmt::Debug,
47    S: Data<Elem = A>,
48    A: core::fmt::Display,
49{
50    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51        write!(
52            f,
53            "{{ input: {i}, hidden: {h:?}, output: {o} }}",
54            i = self.input(),
55            h = self.hidden(),
56            o = self.output()
57        )
58    }
59}
60
61impl<A, S, D, H> core::ops::Index<usize> for ModelParamsBase<S, D, H>
62where
63    D: Dimension,
64    S: Data<Elem = A>,
65    H: DeepModelRepr<S, D>,
66    A: Clone,
67{
68    type Output = ParamsBase<S, D>;
69
70    fn index(&self, index: usize) -> &Self::Output {
71        if index == 0 {
72            self.input()
73        } else if index == self.count_hidden() + 1 {
74            self.output()
75        } else {
76            &self.hidden().as_slice()[index - 1]
77        }
78    }
79}
80
81impl<A, S, D, H> core::ops::IndexMut<usize> for ModelParamsBase<S, D, H>
82where
83    D: Dimension,
84    S: Data<Elem = A>,
85    H: DeepModelRepr<S, D>,
86    A: Clone,
87{
88    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
89        if index == 0 {
90            self.input_mut()
91        } else if index == self.count_hidden() + 1 {
92            self.output_mut()
93        } else {
94            &mut self.hidden_mut().as_mut_slice()[index - 1]
95        }
96    }
97}