use ndarray::Dimension;
use ndarray::iter::{AxisIter, AxisIterMut};
use ndarray::iter::{Iter as NdIter, IterMut as NdIterMut};
pub(crate) type ItemRef<'a, A, D> = (
<AxisIter<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
&'a A,
);
pub(crate) type ItemMut<'a, A, D> = (
<AxisIterMut<'a, A, <D as Dimension>::Smaller> as Iterator>::Item,
&'a mut A,
);
/// Borrowing iterator that walks `weights` (one sub-view per step along an
/// axis) and `bias` (one element per step) together.
pub struct ParamsIter<'a, A, D: Dimension> {
    /// Axis iterator over the weights; each step yields a `D::Smaller` view.
    pub(crate) weights: AxisIter<'a, A, <D as Dimension>::Smaller>,
    /// Element iterator over a `D::Smaller`-dimensional bias array.
    pub(crate) bias: NdIter<'a, A, <D as Dimension>::Smaller>,
}
/// Mutably-borrowing iterator that walks `weights` (one mutable sub-view per
/// step along an axis) and `bias` (one mutable element per step) together.
pub struct ParamsIterMut<'a, A, D: Dimension> {
    /// Mutable axis iterator over the weights; each step yields a `D::Smaller` view.
    pub(crate) weights: AxisIterMut<'a, A, <D as Dimension>::Smaller>,
    /// Mutable element iterator over a `D::Smaller`-dimensional bias array.
    pub(crate) bias: NdIterMut<'a, A, <D as Dimension>::Smaller>,
}
impl<'a, A, D> Iterator for ParamsIter<'a, A, D>
where
    D: Dimension,
{
    /// A `(weight sub-view, &bias element)` pair.
    type Item = ItemRef<'a, A, D>;

    /// Advances `weights` and `bias` together.
    ///
    /// Both underlying iterators are pulled on every call. `None` is returned
    /// as soon as either side is exhausted, so iteration ends at the shorter
    /// of the two.
    fn next(&mut self) -> Option<Self::Item> {
        // Both arguments are evaluated before `zip`, matching the original
        // behaviour of always advancing both sides.
        self.weights.next().zip(self.bias.next())
    }

    /// Reports the exact number of pairs still to be yielded.
    ///
    /// Without this override the default `(0, None)` would be returned. That
    /// contradicts the `ExactSizeIterator` contract for this type and prevents
    /// `collect` and similar consumers from preallocating. Because `next`
    /// stops at the shorter side, the exact count is the minimum of the two
    /// remaining lengths.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.weights.len().min(self.bias.len());
        (remaining, Some(remaining))
    }
}
impl<A, D> ExactSizeIterator for ParamsIter<'_, A, D>
where
    D: Dimension,
{
    /// Number of `(weight, bias)` pairs still to be yielded.
    ///
    /// `next` returns `None` once either `weights` or `bias` runs out, so the
    /// remaining count is the smaller of the two lengths. Reporting
    /// `weights.len()` alone would over-count whenever `bias` is shorter.
    fn len(&self) -> usize {
        self.weights.len().min(self.bias.len())
    }
}
impl<'a, A, D> Iterator for ParamsIterMut<'a, A, D>
where
    D: Dimension,
{
    /// A `(mutable weight sub-view, &mut bias element)` pair.
    type Item = ItemMut<'a, A, D>;

    /// Advances `weights` and `bias` together.
    ///
    /// Both underlying iterators are pulled on every call. `None` is returned
    /// as soon as either side is exhausted, so iteration ends at the shorter
    /// of the two.
    fn next(&mut self) -> Option<Self::Item> {
        // Both arguments are evaluated before `zip`, matching the original
        // behaviour of always advancing both sides.
        self.weights.next().zip(self.bias.next())
    }

    /// Reports the exact number of pairs still to be yielded.
    ///
    /// The default `(0, None)` gives consumers no information. Since `next`
    /// stops at the shorter side, the exact count is the minimum of the two
    /// remaining lengths.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.weights.len().min(self.bias.len());
        (remaining, Some(remaining))
    }
}