concision_core/params/
iter.rs

1/*
2    Appellation: iter <module>
3    Contrib: @FL03
4*/
5use ndarray::Dimension;
6use ndarray::iter::{AxisIter, AxisIterMut};
7use ndarray::iter::{Iter as NdIter, IterMut as NdIterMut};
8
9pub(crate) type ItemRef<'a, A, D> = (
10    <AxisIter<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
11    &'a A,
12);
13pub(crate) type ItemMut<'a, A, D> = (
14    <AxisIterMut<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
15    &'a mut A,
16);
17/// The [`Iter`] type provides an iterator over the parameters of a neural network layer by
18/// zipping together an axis iterator over the columns of the weights and an iterator over the
19/// bias.
20pub struct Iter<'a, A, D>
21where
22    D: Dimension,
23{
24    pub(crate) weights: AxisIter<'a, A, D::Smaller>,
25    pub(crate) bias: NdIter<'a, A, D::Smaller>,
26}
27/// The [`IterMut`] type provides a mutable iterator over the parameters of a neural network
28/// layer by zipping together a mutable axis iterator over the columns of the weights and
29/// a mutable iterator over the bias.
30pub struct IterMut<'a, A, D>
31where
32    D: Dimension,
33{
34    pub(crate) weights: AxisIterMut<'a, A, D::Smaller>,
35    pub(crate) bias: NdIterMut<'a, A, D::Smaller>,
36}
37
38/*
39 ************* Implementations *************
40*/
41impl<'a, A, D> Iterator for Iter<'a, A, D>
42where
43    D: Dimension,
44{
45    type Item = ItemRef<'a, A, D>;
46
47    fn next(&mut self) -> Option<Self::Item> {
48        match (self.weights.next(), self.bias.next()) {
49            (Some(w), Some(b)) => Some((w, b)),
50            _ => None,
51        }
52    }
53}
54
55impl<'a, A, D> Iterator for IterMut<'a, A, D>
56where
57    D: Dimension,
58{
59    type Item = ItemMut<'a, A, D>;
60
61    fn next(&mut self) -> Option<Self::Item> {
62        match (self.weights.next(), self.bias.next()) {
63            (Some(w), Some(b)) => Some((w, b)),
64            _ => None,
65        }
66    }
67}
68
69impl<'a, A, D> ExactSizeIterator for Iter<'a, A, D>
70where
71    D: Dimension,
72{
73    fn len(&self) -> usize {
74        self.weights.len()
75    }
76}