//! Strides for `N` dimensional arrays: conversion between flat (contiguous)
//! indices and `N` dimensional indices.
use std::ops::{Deref, DerefMut, Range};
/// The strides of an `N` dimensional array.
///
/// Each element is the distance, in flat (contiguous) elements, between two
/// neighbouring indices along the corresponding axis.
///
/// `Strides` makes no assumptions about the shape of an array. The caller must
/// ensure that expanded/flattened indices are valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Strides<const N: usize>([usize; N]);
impl<const N: usize> Strides<N> {
    /// Create `Strides` for the given `shape`.
    ///
    /// The first axis has a stride of `1`; every later axis has the stride of
    /// the axis before it multiplied by that axis' length. The length of the
    /// last axis is therefore never read.
    pub fn new(shape: &[usize; N]) -> Self {
        let mut strides = [1; N];
        let mut running = 1;
        // `skip(1)` leaves the first stride at `1`; `zip` stops as soon as the
        // strides run out, so the final axis length is never multiplied in.
        for (stride, &len) in strides.iter_mut().skip(1).zip(shape) {
            running *= len;
            *stride = running;
        }
        Strides(strides)
    }

    /// Expand an `index` into an `N` dimensional index.
    ///
    /// For `N == 0` the result is the empty array.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if any stride is zero.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// assert_eq!(array.strides().expand_index(4), [1, 1]);
    /// assert_eq!(array.strides().expand_index(5), [2, 1]);
    /// # }
    /// ```
    // NOTE(review): presumably only reachable with the `strides` feature — confirm.
    #[allow(dead_code)]
    pub fn expand_index(&self, index: usize) -> [usize; N] {
        let mut expanded = [0; N];
        for (axis, slot) in expanded.iter_mut().enumerate() {
            // Drop the contribution of every later axis (the last axis has
            // none), then count the whole strides of this axis that remain.
            // Using `get` avoids the `N - 1` arithmetic that underflows when
            // `N == 0`.
            let within = match self.0.get(axis + 1) {
                Some(&outer) => index % outer,
                None => index,
            };
            *slot = within / self.0[axis];
        }
        expanded
    }

    /// Flatten an `N` dimensional `index` into a contiguous index.
    ///
    /// The result is the sum of every axis index multiplied by the stride of
    /// its axis.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// assert_eq!(array.strides().flatten_index([1, 1]), 4);
    /// assert_eq!(array.strides().flatten_index([2, 1]), 5);
    /// # }
    /// ```
    // NOTE(review): presumably only reachable with the `strides` feature — confirm.
    #[allow(dead_code)]
    pub fn flatten_index(&self, index: [usize; N]) -> usize {
        index
            .iter()
            .zip(self.0.iter())
            .map(|(idx, stride)| idx * stride)
            .sum()
    }

    /// Flatten a **contiguous** `N` dimensional index range into a contiguous
    /// `Range<usize>`.
    ///
    /// The returned range starts at the flattened lower corner and ends one
    /// past the flattened upper (inclusive) corner. If the range of any axis is
    /// empty, the `N` dimensional range contains no elements and an empty range
    /// positioned at the flattened lower corner is returned.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// // A contiguous range from [0, 0] to [1, 1].
    /// assert_eq!(array.strides().flatten_range([0..2, 0..2]), 0..5);
    /// // A contiguous range from [1, 0] to [1, 2].
    /// assert_eq!(array.strides().flatten_range([1..2, 0..3]), 1..8);
    /// # }
    /// ```
    // NOTE(review): presumably only reachable with the `strides` feature — confirm.
    #[allow(dead_code)]
    pub fn flatten_range(&self, index_range: [Range<usize>; N]) -> Range<usize> {
        let start: usize = index_range
            .iter()
            .zip(self.0.iter())
            .map(|(range, stride)| range.start * stride)
            .sum();

        // An empty axis has no last index; `range.end - 1` would underflow for
        // `x..0` and describe a bogus corner otherwise.
        if index_range.iter().any(|range| range.is_empty()) {
            return start..start;
        }

        // Flattened index of the inclusive upper corner.
        let last: usize = index_range
            .iter()
            .zip(self.0.iter())
            .map(|(range, stride)| (range.end - 1) * stride)
            .sum();

        start..last + 1
    }
}
impl<const N: usize> Deref for Strides<N> {
type Target = [usize; N];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<const N: usize> DerefMut for Strides<N> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Strides are the running product of the preceding axis lengths.
    #[test]
    fn new() {
        assert_eq!(Strides::new(&[5, 4]).0, [1, 5]);
        assert_eq!(Strides::new(&[5, 4, 3]).0, [1, 5, 20]);
        assert_eq!(Strides::new(&[5, 4, 3, 2]).0, [1, 5, 20, 60]);
    }

    /// A flat index expands into one index per axis.
    #[test]
    fn expand_index() {
        let two = Strides::new(&[5, 4]);
        assert_eq!(two.expand_index(11), [1, 2]);

        let three = Strides::new(&[5, 4, 3]);
        assert_eq!(three.expand_index(31), [1, 2, 1]);

        let four = Strides::new(&[5, 4, 3, 2]);
        assert_eq!(four.expand_index(81), [1, 0, 1, 1]);
    }

    /// Flattening is the inverse of the expansions checked in `expand_index`.
    #[test]
    fn flatten_index() {
        let two = Strides::new(&[5, 4]);
        assert_eq!(two.flatten_index([1, 2]), 11);

        let three = Strides::new(&[5, 4, 3]);
        assert_eq!(three.flatten_index([1, 2, 1]), 31);

        let four = Strides::new(&[5, 4, 3, 2]);
        assert_eq!(four.flatten_index([1, 0, 1, 1]), 81);
    }

    /// Two neighbouring elements along the first axis flatten to a range of
    /// length two.
    #[test]
    fn flatten_range() {
        let two = Strides::new(&[5, 4]);
        assert_eq!(two.flatten_range([1..3, 2..3]), 11..13);

        let three = Strides::new(&[5, 4, 3]);
        assert_eq!(three.flatten_range([1..3, 2..3, 1..2]), 31..33);

        let four = Strides::new(&[5, 4, 3, 2]);
        assert_eq!(four.flatten_range([1..3, 0..1, 1..2, 1..2]), 81..83);
    }
}