1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

use {Dimension, Axis, Ix, Ixs};

/// Build an `Axes` iterator over a dimension `d` and a `strides` value of
/// the same type.
///
/// The returned iterator borrows both arguments and initially covers the
/// axis indices `0..d.ndim()`.
pub fn axes_of<'a, D>(d: &'a D, strides: &'a D) -> Axes<'a, D>
    where D: Dimension,
{
    // The back cursor starts one past the last axis index.
    let axis_count = d.ndim();
    Axes {
        start: 0,
        end: axis_count,
        dim: d,
        strides: strides,
    }
}

/// An iterator over the length and stride of each axis of an array.
///
/// See [`.axes()`](../struct.ArrayBase.html#method.axes) for more information.
///
/// Iterator element type is `AxisDescription`.
///
/// # Examples
///
/// ```
/// use ndarray::Array3;
/// use ndarray::Axis;
///
/// let a = Array3::<f32>::zeros((3, 5, 4));
///
/// let largest_axis = a.axes()
///                     .max_by_key(|ax| ax.len())
///                     .unwrap().axis();
/// assert_eq!(largest_axis, Axis(1));
/// ```
#[derive(Debug)]
pub struct Axes<'a, D: 'a> {
    // Indexed by axis number; each entry is reported as that axis' length.
    dim: &'a D,
    // Indexed by axis number; each entry is cast to `Ixs` and reported as
    // that axis' stride.
    strides: &'a D,
    // Index of the next axis to yield from the front.
    start: usize,
    // One past the index of the next axis to yield from the back; the
    // iterator is exhausted once `start >= end`.
    end: usize,
}

/// Description of the axis, its length and its stride.
///
/// The tuple fields are, in order: the axis, its length, and its stride.
/// The same values are returned by the `axis`, `len` and `stride` methods.
#[derive(Debug)]
pub struct AxisDescription(pub Axis, pub Ix, pub Ixs);

// NOTE(review): `copy_and_clone!` is defined outside this view; presumably it
// implements `Copy` and `Clone` for `AxisDescription` — confirm against the
// macro definition.
copy_and_clone!(AxisDescription);

impl AxisDescription {
    /// Return axis
    #[inline(always)]
    pub fn axis(self) -> Axis { self.0 }
    /// Return length
    #[inline(always)]
    pub fn len(self) -> Ix { self.1 }
    /// Return stride
    #[inline(always)]
    pub fn stride(self) -> Ixs { self.2 }
}

// NOTE(review): `copy_and_clone!` is defined outside this view; presumably
// this invocation implements `Copy` and `Clone` for `Axes<'a, D>` — confirm
// against the macro definition.
copy_and_clone!(['a, D] Axes<'a, D>);

impl<'a, D> Iterator for Axes<'a, D>
    where D: Dimension,
{
    /// Description of the axis, its length and its stride.
    type Item = AxisDescription;

    /// Yield the description of the front-most remaining axis, or `None`
    /// once the front and back cursors have met.
    fn next(&mut self) -> Option<Self::Item> {
        if self.start >= self.end {
            return None;
        }
        // `post_inc` hands back the current front index and advances it.
        let index = self.start.post_inc();
        let length = self.dim[index];
        let stride = self.strides[index] as Ixs;
        Some(AxisDescription(Axis(index), length, stride))
    }

    /// Exact number of axes left between the two cursors.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end - self.start;
        (remaining, Some(remaining))
    }
}

impl<'a, D> DoubleEndedIterator for Axes<'a, D>
    where D: Dimension,
{
    /// Yield the description of the back-most remaining axis, or `None`
    /// once the front and back cursors have met.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.start >= self.end {
            return None;
        }
        // `end` is one past the back axis, so step it down before indexing.
        let index = self.end.pre_dec();
        let length = self.dim[index];
        let stride = self.strides[index] as Ixs;
        Some(AxisDescription(Axis(index), length, stride))
    }
}

/// Private helper trait providing C-style increment/decrement operations
/// that update a value in place and return a copy of it.
///
/// The behaviour described on each method is that of the `usize`
/// implementation in this file.
trait IncOps : Copy {
    /// Add one in place and return the value held *before* the increment.
    fn post_inc(&mut self) -> Self;
    /// Subtract one in place and return the value held *before* the
    /// decrement. Not called within the visible part of this file.
    fn post_dec(&mut self) -> Self;
    /// Subtract one in place and return the value held *after* the
    /// decrement.
    fn pre_dec(&mut self) -> Self;
}

impl IncOps for usize {
    /// Bump the value by one, returning what it was beforehand.
    #[inline(always)]
    fn post_inc(&mut self) -> Self {
        let before = *self;
        *self = before + 1;
        before
    }
    /// Lower the value by one, returning what it was beforehand.
    #[inline(always)]
    fn post_dec(&mut self) -> Self {
        let before = *self;
        *self = before - 1;
        before
    }
    /// Lower the value by one, returning the new value.
    #[inline(always)]
    fn pre_dec(&mut self) -> Self {
        let after = *self - 1;
        *self = after;
        after
    }
}