concision_params/iter/
iter_params.rs

1/*
2    Appellation: iter <module>
3    Contrib: @FL03
4*/
5//! Iterators for parameters within a neural network
6use ndarray::Dimension;
7use ndarray::iter::{AxisIter, AxisIterMut};
8use ndarray::iter::{Iter as NdIter, IterMut as NdIterMut};
9
/// Item type yielded by [`ParamsIter`]: a pair of one item from the weights' axis iterator
/// and a shared reference to the matching bias element.
// The fully-qualified `<D as Dimension>::Smaller` is required here: type aliases cannot rely
// on a `D: Dimension` bound, so the `D::Smaller` shorthand would be ambiguous.
pub(crate) type ItemRef<'a, A, D> = (
    <AxisIter<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
    &'a A,
);
/// Item type yielded by [`ParamsIterMut`]: a pair of one item from the weights' mutable axis
/// iterator and a mutable reference to the matching bias element.
// As with `ItemRef`, the associated type must be spelled out fully because type aliases do
// not carry trait bounds on their parameters.
pub(crate) type ItemMut<'a, A, D> = (
    <AxisIterMut<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
    &'a mut A,
);
18/// The [`ParamsIter`] defines a iterator over owned parameters within a neural network by
19/// zipping together an axis iterator over the columns of the weights and an iterator over the
20/// bias.
21pub struct ParamsIter<'a, A, D>
22where
23    D: Dimension,
24{
25    pub(crate) weights: AxisIter<'a, A, D::Smaller>,
26    pub(crate) bias: NdIter<'a, A, D::Smaller>,
27}
28/// The [`ParamsIterMut`] type provides a mutable iterator over the parameters of a neural
29/// network layer by zipping together a mutable axis iterator over the columns of the weights
30/// and a mutable iterator over the bias.
31pub struct ParamsIterMut<'a, A, D>
32where
33    D: Dimension,
34{
35    pub(crate) weights: AxisIterMut<'a, A, D::Smaller>,
36    pub(crate) bias: NdIterMut<'a, A, D::Smaller>,
37}
38
39/*
40 ************* Implementations *************
41*/
42impl<'a, A, D> Iterator for ParamsIter<'a, A, D>
43where
44    D: Dimension,
45{
46    type Item = ItemRef<'a, A, D>;
47
48    fn next(&mut self) -> Option<Self::Item> {
49        match (self.weights.next(), self.bias.next()) {
50            (Some(w), Some(b)) => Some((w, b)),
51            _ => None,
52        }
53    }
54}
55
56impl<'a, A, D> ExactSizeIterator for ParamsIter<'a, A, D>
57where
58    D: Dimension,
59{
60    fn len(&self) -> usize {
61        self.weights.len()
62    }
63}
64
65impl<'a, A, D> Iterator for ParamsIterMut<'a, A, D>
66where
67    D: Dimension,
68{
69    type Item = ItemMut<'a, A, D>;
70
71    fn next(&mut self) -> Option<Self::Item> {
72        match (self.weights.next(), self.bias.next()) {
73            (Some(w), Some(b)) => Some((w, b)),
74            _ => None,
75        }
76    }
77}