concision_core/params/impls/
impl_params_iter.rs

1/*
2    appellation: impl_params_iter <module>
3    authors: @FL03
4*/
5use crate::params::ParamsBase;
6
7use crate::params::iter::{Iter, IterMut};
8use ndarray::iter as nditer;
9use ndarray::{Axis, Data, DataMut, Dimension, RawData, RemoveAxis};
10
11/// Here, we implement various iterators for the parameters and its constituents. The _core_
12/// iterators are:
13///
14/// - immutable and mutable iterators over each parameter (weights and bias) respectively;
15/// - an iterator over the parameters, which zips together an axis iterator over the columns of
16///   the weights and an iterator over the bias;
17impl<S, D, A> ParamsBase<S, D>
18where
19    S: RawData<Elem = A>,
20    D: Dimension,
21{
22    /// an iterator of the parameters; the created iterator zips together an axis iterator over
23    /// the columns of the weights and an iterator over the bias
24    pub fn iter(&self) -> Iter<'_, A, D>
25    where
26        D: RemoveAxis,
27        S: Data,
28    {
29        Iter {
30            bias: self.bias().iter(),
31            weights: self.weights().axis_iter(Axis(1)),
32        }
33    }
34    /// returns a mutable iterator of the parameters, [`IterMut`], which essentially zips
35    /// together a mutable axis iterator over the columns of the weights against a mutable
36    /// iterator over the elements of the bias
37    pub fn iter_mut(&mut self) -> IterMut<'_, A, D>
38    where
39        D: RemoveAxis,
40        S: DataMut,
41    {
42        IterMut {
43            bias: self.bias.iter_mut(),
44            weights: self.weights.axis_iter_mut(Axis(1)),
45        }
46    }
47    /// returns an iterator over the bias
48    pub fn iter_bias(&self) -> nditer::Iter<'_, A, D::Smaller>
49    where
50        S: Data,
51    {
52        self.bias().iter()
53    }
54    /// returns a mutable iterator over the bias
55    pub fn iter_bias_mut(&mut self) -> nditer::IterMut<'_, A, D::Smaller>
56    where
57        S: DataMut,
58    {
59        self.bias_mut().iter_mut()
60    }
61    /// returns an iterator over the weights
62    pub fn iter_weights(&self) -> nditer::Iter<'_, A, D>
63    where
64        S: Data,
65    {
66        self.weights().iter()
67    }
68    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
69    pub fn iter_weights_mut(&mut self) -> nditer::IterMut<'_, A, D>
70    where
71        S: DataMut,
72    {
73        self.weights_mut().iter_mut()
74    }
75}