1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
    Appellation: store <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use super::{ParamKind, Parameter};
use crate::prelude::Map;
use ndarray::prelude::{Dimension, Ix2};
use num::Float;

/// A keyed collection of model parameters holding at most one [`Parameter`]
/// per [`ParamKind`].
///
/// The element type defaults to `f64` and the dimensionality to `Ix2`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ParamStore<T: Float = f64, D: Dimension = Ix2> {
    /// Backing map from parameter kind to the parameter itself.
    store: Map<ParamKind, Parameter<T, D>>,
}

impl<T, D> ParamStore<T, D>
where
    D: Dimension,
    T: Float,
{
    /// Creates an empty parameter store.
    pub fn new() -> Self {
        Self { store: Map::new() }
    }

    /// Returns the parameter registered under `kind`, if any.
    pub fn get(&self, kind: &ParamKind) -> Option<&Parameter<T, D>> {
        self.store.get(kind)
    }

    /// Returns a mutable reference to the parameter registered under `kind`, if any.
    pub fn get_mut(&mut self, kind: &ParamKind) -> Option<&mut Parameter<T, D>> {
        self.store.get_mut(kind)
    }

    /// Returns `true` if a parameter is registered under `kind`.
    pub fn contains(&self, kind: &ParamKind) -> bool {
        self.store.contains_key(kind)
    }

    /// Returns the number of parameters currently stored.
    pub fn len(&self) -> usize {
        self.store.len()
    }

    /// Returns `true` if the store holds no parameters.
    pub fn is_empty(&self) -> bool {
        self.store.is_empty()
    }

    /// Iterates over `(kind, parameter)` pairs by shared reference.
    ///
    /// Iteration order is whatever the backing `Map` yields.
    pub fn iter(&self) -> impl Iterator<Item = (&ParamKind, &Parameter<T, D>)> + '_ {
        self.store.iter()
    }

    /// Inserts `param`, keyed by a clone of its own `kind()`.
    ///
    /// If a parameter of the same kind is already stored, it is replaced and the
    /// previous value is dropped. Call [`remove`](Self::remove) first if the old
    /// parameter must be recovered.
    pub fn insert(&mut self, param: Parameter<T, D>) {
        self.store.insert(param.kind().clone(), param);
    }

    /// Removes and returns the parameter registered under `kind`, if any.
    pub fn remove(&mut self, kind: &ParamKind) -> Option<Parameter<T, D>> {
        self.store.remove(kind)
    }
}

/// Bulk-inserts parameters into the store.
///
/// Each parameter goes through [`ParamStore::insert`], so a later parameter
/// replaces an earlier one of the same kind.
impl<T, D> Extend<Parameter<T, D>> for ParamStore<T, D>
where
    T: Float,
    D: Dimension,
{
    fn extend<It>(&mut self, params: It)
    where
        It: IntoIterator<Item = Parameter<T, D>>,
    {
        params.into_iter().for_each(|p| self.insert(p));
    }
}

// Implements `IntoIterator` for `ParamStore`, `&ParamStore` and `&mut ParamStore`
// by delegating to the backing map.
//
// The argument is the path of the module that defines the map's iterator types
// (`IntoIter`, `Iter`, `IterMut`). It must be the module of the concrete type
// behind `Map`; otherwise the declared `IntoIter` types will not match what the
// map's methods return.
macro_rules! impl_into_iter {
    ($($p:ident)::*) => {
        impl_into_iter!(@impl $($p)::*);
    };
    (@impl $($p:ident)::*) => {
        // Consuming iteration: yields owned `(kind, parameter)` pairs.
        impl<T, D> IntoIterator for ParamStore<T, D>
        where
            D: Dimension,
            T: Float,
        {
            type Item = (ParamKind, Parameter<T, D>);
            type IntoIter = $($p)::*::IntoIter<ParamKind, Parameter<T, D>>;

            fn into_iter(self) -> Self::IntoIter {
                self.store.into_iter()
            }
        }

        // Shared iteration: enables `for (kind, param) in &store`.
        impl<'a, T, D> IntoIterator for &'a ParamStore<T, D>
        where
            D: Dimension,
            T: Float,
        {
            type Item = (&'a ParamKind, &'a Parameter<T, D>);
            type IntoIter = $($p)::*::Iter<'a, ParamKind, Parameter<T, D>>;

            fn into_iter(self) -> Self::IntoIter {
                self.store.iter()
            }
        }

        // Mutable iteration: keys stay shared, parameters are mutable.
        impl<'a, T, D> IntoIterator for &'a mut ParamStore<T, D>
        where
            D: Dimension,
            T: Float,
        {
            type Item = (&'a ParamKind, &'a mut Parameter<T, D>);
            type IntoIter = $($p)::*::IterMut<'a, ParamKind, Parameter<T, D>>;

            fn into_iter(self) -> Self::IntoIter {
                self.store.iter_mut()
            }
        }
    };
}
// The module named in each configuration must be the one that defines the
// concrete map behind `Map`. The `IntoIterator` impls name that module's
// iterator types directly: a B-tree map without `std`, a hash map with it.
#[cfg(not(feature = "std"))]
impl_into_iter!(alloc::collections::btree_map);
#[cfg(feature = "std")]
impl_into_iter!(std::collections::hash_map);

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed store (via `new` or `Default`) holds no parameters,
    /// whichever iteration form is used to inspect it.
    #[test]
    fn test_model_store() {
        let mut params = ParamStore::<f64>::new();
        assert!((&mut params).into_iter().next().is_none());
        assert!(params.into_iter().next().is_none());

        // `Default` is derived and must agree with `new`.
        let defaulted = ParamStore::<f64>::default();
        assert!(defaulted.into_iter().next().is_none());
    }
}