concision_core/params/impls/impl_params_iter.rs
1/*
2 appellation: impl_params_iter <module>
3 authors: @FL03
4*/
5use crate::params::ParamsBase;
6
7use crate::params::iter::{Iter, IterMut};
8use ndarray::iter as nditer;
9use ndarray::{Axis, Data, DataMut, Dimension, RawData, RemoveAxis};
10
11/// Here, we implement various iterators for the parameters and its constituents. The _core_
12/// iterators are:
13///
14/// - immutable and mutable iterators over each parameter (weights and bias) respectively;
15/// - an iterator over the parameters, which zips together an axis iterator over the columns of
16/// the weights and an iterator over the bias;
17impl<S, D, A> ParamsBase<S, D>
18where
19 S: RawData<Elem = A>,
20 D: Dimension,
21{
22 /// an iterator of the parameters; the created iterator zips together an axis iterator over
23 /// the columns of the weights and an iterator over the bias
24 pub fn iter(&self) -> Iter<'_, A, D>
25 where
26 D: RemoveAxis,
27 S: Data,
28 {
29 Iter {
30 bias: self.bias().iter(),
31 weights: self.weights().axis_iter(Axis(1)),
32 }
33 }
34 /// returns a mutable iterator of the parameters, [`IterMut`], which essentially zips
35 /// together a mutable axis iterator over the columns of the weights against a mutable
36 /// iterator over the elements of the bias
37 pub fn iter_mut(&mut self) -> IterMut<'_, A, D>
38 where
39 D: RemoveAxis,
40 S: DataMut,
41 {
42 IterMut {
43 bias: self.bias.iter_mut(),
44 weights: self.weights.axis_iter_mut(Axis(1)),
45 }
46 }
47 /// returns an iterator over the bias
48 pub fn iter_bias(&self) -> nditer::Iter<'_, A, D::Smaller>
49 where
50 S: Data,
51 {
52 self.bias().iter()
53 }
54 /// returns a mutable iterator over the bias
55 pub fn iter_bias_mut(&mut self) -> nditer::IterMut<'_, A, D::Smaller>
56 where
57 S: DataMut,
58 {
59 self.bias_mut().iter_mut()
60 }
61 /// returns an iterator over the weights
62 pub fn iter_weights(&self) -> nditer::Iter<'_, A, D>
63 where
64 S: Data,
65 {
66 self.weights().iter()
67 }
68 /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
69 pub fn iter_weights_mut(&mut self) -> nditer::IterMut<'_, A, D>
70 where
71 S: DataMut,
72 {
73 self.weights_mut().iter_mut()
74 }
75}