Skip to main content

candle_core/
nditer.rs

1use std::iter::FusedIterator;
2
3use crate::Layout;
4
/// Capacity of the fixed-size dimension buffers held inside [`NdIter`]
/// (outer dims, outer strides and coordinates). See [`NdIter::new`] for how
/// this bounds the layouts it accepts.
pub const MAX_DIMS: usize = 8;
7
8/// Multi-dimensional iterator that walks `N` tensor layouts simultaneously.
9///
10/// Internally, adjacent dimensions are merged whenever their strides allow it,
11/// reducing the number of outer iterations. Each call to [`Iterator::next`] yields
12/// `[usize; N]`, one base offset per layout, for one inner slice of `inner_size`
13/// elements. Callers iterate that slice using `inner_strides`:
14///
15/// ```ignore
16/// let nd_iter = NdIter::new([lhs_l, rhs_l]);
17/// let inner_size = nd_iter.inner_size;
18/// let [inner_ls, inner_rs] = nd_iter.inner_strides;
19/// for [lhs_off, rhs_off] in nd_iter {
20///     for i in 0..inner_size {
21///         let lhs = lhs_buf[lhs_off + i * inner_ls];
22///         let rhs = rhs_buf[rhs_off + i * inner_rs];
23///         ...
24///     }
25/// }
26/// ```
27///
28/// Note: `inner_strides[n]` can be 1 contiguous, but also 0 (broadcast/scalar) or > 1 (non-contiguous),
29/// so callers must not assume contiguous element access within a slice.
30pub struct NdIter<const N: usize> {
31    pub inner_size: usize,
32    /// Per-layout element stride within each inner slice: `element_offset = base + i * stride`.
33    pub inner_strides: [usize; N],
34
35    outer_dims: [usize; MAX_DIMS],
36    outer_strides: [[usize; MAX_DIMS]; N],
37    outer_len: usize,
38
39    offsets: [usize; N],
40    coords: [usize; MAX_DIMS],
41    remaining: usize,
42}
43
44impl<const N: usize> NdIter<N> {
45    pub fn new(layouts: [&Layout; N]) -> NdIter<N> {
46        let dims = layouts[0].dims();
47        debug_assert!(
48            dims.len() <= MAX_DIMS,
49            "rank {} exceeds MAX_DIMS={}",
50            dims.len(),
51            MAX_DIMS
52        );
53        #[cfg(debug_assertions)]
54        for l in &layouts {
55            debug_assert_eq!(l.dims(), dims);
56        }
57
58        let rank = dims.len();
59        let mut out_dims = [0usize; MAX_DIMS];
60        let mut out_strides = [[0usize; MAX_DIMS]; N];
61        let mut out_len;
62
63        if rank == 0 {
64            out_dims[0] = 1;
65            out_len = 1;
66        } else {
67            out_dims[0] = dims[0];
68            for n in 0..N {
69                out_strides[n][0] = layouts[n].stride()[0];
70            }
71            out_len = 1;
72
73            for (i, d) in dims.iter().enumerate().take(rank).skip(1) {
74                let top = out_len - 1;
75                let last_d = out_dims[top];
76
77                let (can_merge, use_inner) = if last_d == 1 {
78                    (true, true)
79                } else if *d == 1 {
80                    (true, false)
81                } else {
82                    let can_merge =
83                        (0..N).all(|n| out_strides[n][top] == layouts[n].stride()[i] * d);
84                    (can_merge, true)
85                };
86                if can_merge {
87                    out_dims[top] = last_d * d;
88                    if use_inner {
89                        for n in 0..N {
90                            out_strides[n][top] = layouts[n].stride()[i];
91                        }
92                    }
93                } else {
94                    out_dims[out_len] = *d;
95                    for n in 0..N {
96                        out_strides[n][out_len] = layouts[n].stride()[i];
97                    }
98                    out_len += 1;
99                }
100            }
101        }
102
103        // Update inner strides
104        let inner_idx = out_len - 1;
105        let inner_size = out_dims[inner_idx];
106        let mut inner_strides = [0usize; N];
107        for n in 0..N {
108            inner_strides[n] = out_strides[n][inner_idx];
109        }
110
111        // Update outer dims
112        let outer_len = inner_idx;
113        let mut outer_dims = [0usize; MAX_DIMS];
114        outer_dims[..outer_len].copy_from_slice(&out_dims[..outer_len]);
115
116        // Update outer strides
117        let mut outer_strides = [[0usize; MAX_DIMS]; N];
118        for n in 0..N {
119            outer_strides[n][..outer_len].copy_from_slice(&out_strides[n][..outer_len]);
120        }
121
122        // Update offsets
123        let mut offsets = [0usize; N];
124        for n in 0..N {
125            offsets[n] = layouts[n].start_offset();
126        }
127
128        // Number of outer blocks (product of all dims except the innermost).
129        let remaining = out_dims[..outer_len].iter().product();
130
131        NdIter {
132            inner_size,
133            inner_strides,
134            outer_dims,
135            outer_strides,
136            outer_len,
137            offsets,
138            coords: [0; MAX_DIMS],
139            remaining,
140        }
141    }
142}
143
144impl<const N: usize> Iterator for NdIter<N> {
145    type Item = [usize; N];
146
147    #[inline]
148    fn next(&mut self) -> Option<Self::Item> {
149        if self.remaining == 0 {
150            return None;
151        }
152        let item = self.offsets;
153        self.remaining -= 1;
154
155        for k in (0..self.outer_len).rev() {
156            self.coords[k] += 1;
157            for n in 0..N {
158                self.offsets[n] += self.outer_strides[n][k];
159            }
160            if self.coords[k] < self.outer_dims[k] {
161                break;
162            }
163            self.coords[k] = 0;
164            for n in 0..N {
165                self.offsets[n] -= self.outer_dims[k] * self.outer_strides[n][k];
166            }
167        }
168
169        Some(item)
170    }
171
172    #[inline]
173    fn size_hint(&self) -> (usize, Option<usize>) {
174        (self.remaining, Some(self.remaining))
175    }
176}
177
178impl<const N: usize> ExactSizeIterator for NdIter<N> {}
179impl<const N: usize> FusedIterator for NdIter<N> {}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    fn layout(dims: &[usize], strides: &[usize]) -> Layout {
        Layout::new(Shape::from(dims.to_vec()), strides.to_vec(), 0)
    }

    #[test]
    fn rank0_scalar() {
        let l = Layout::contiguous(());
        let mut it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 1);
        assert_eq!(it.inner_strides, [0, 0]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0, 0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn rank1_contiguous_single_block() {
        let l = Layout::contiguous(&[5]);
        let mut it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 5);
        assert_eq!(it.inner_strides, [1, 1]);
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some([0, 0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn rank2_contiguous_merges_to_one_block() {
        // [3, 4] strides[4, 1] -> fully merged, inner_size=12
        let l = Layout::contiguous(&[3, 4]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1, 1]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn rank3_contiguous_fully_merged() {
        let l = Layout::contiguous(&[2, 3, 4]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 24);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn rank3_outer_gap_partial_merge() {
        // shape [2, 3, 4]
        // strides [24, 4, 1]
        // dims 1 + 2 merge (4 == 1 * 4), dim 0 stays (24 != 1 * 12)
        let l = layout(&[2, 3, 4], &[24, 4, 1]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 12); // 3 * 4 merged
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 2); // outer: dim 0 has size 2
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0], [24]]);
    }

    #[test]
    fn rank2_no_merge() {
        // shape [3, 4]
        // strides [1, 3]
        // stride[0] = 1 != stride[1] * dim[1] = 12 -> no merge
        let l = layout(&[3, 4], &[1, 3]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 4);
        assert_eq!(it.inner_strides, [3, 3]);
        assert_eq!(it.len(), 3);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0, 0], [1, 1], [2, 2]]);
    }

    #[test]
    fn broadcast_zeros_merge() {
        // shape [3, 4]
        // strides [0, 0]
        // 0 == 0 * 4 -> dims merge
        let l = layout(&[3, 4], &[0, 0]);
        let it = NdIter::new([&l, &l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [0, 0]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn mixed_contiguous_and_broadcast_merge() {
        // lhs shape [3, 4] contiguous, strides [4, 1]
        // rhs shape [3, 4] broadcast, strides [0, 0]
        // both conditions pass -> merge
        let lhs = Layout::contiguous(&[3, 4]);
        let rhs = layout(&[3, 4], &[0, 0]);
        let it = NdIter::new([&lhs, &rhs]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1, 0]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn offsets_lhs_contiguous_rhs_strided() {
        // lhs row-major [2,3] strides [3,1]
        // rhs col-major [2,3] strides [1,2]
        // lhs can merge (3 == 1 * 3)
        // rhs can not (1 ≠ 2 * 3 = 6) -> no merge for the pair
        let lhs = Layout::contiguous(&[2, 3]);
        let rhs = layout(&[2, 3], &[1, 2]);
        let it = NdIter::new([&lhs, &rhs]);
        assert_eq!(it.inner_size, 3);
        assert_eq!(it.inner_strides, [1, 2]);
        assert_eq!(it.len(), 2);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0, 0], [3, 1]]);
    }

    #[test]
    fn start_offset_reflected_in_first_iter() {
        let l = Layout::contiguous_with_offset(4, 7);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.next(), Some([7]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn start_offset_advances_with_outer_dims() {
        // shape [2, 3]
        // strides [4, 1]
        // start_offset=10
        // strides can't merge (4 != 1 * 3)
        // outer blocks at 10 and 14 (10 + 1 * outer_stride)
        let l = Layout::new(Shape::from(vec![2, 3]), vec![4, 1], 10);
        let offsets: Vec<_> = NdIter::new([&l]).collect();
        assert_eq!(offsets, vec![[10], [14]]);
    }

    #[test]
    fn size_one_dim_is_absorbed_regardless_of_stride() {
        // shape [3, 1, 4]
        // strides [4, 100, 1]
        // the size-1 dim merges without changing strides, then 4 == 1 * 4 merges
        let l = layout(&[3, 1, 4], &[4, 100, 1]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 12);
        assert_eq!(it.inner_strides, [1]);
        assert_eq!(it.len(), 1);
    }

    #[test]
    fn leading_size_one_dim_takes_inner_stride() {
        // shape [1, 5]
        // strides [7, 2]
        // the leading size-1 dim merges and the result adopts the inner stride 2
        let l = layout(&[1, 5], &[7, 2]);
        let mut it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 5);
        assert_eq!(it.inner_strides, [2]);
        assert_eq!(it.next(), Some([0]));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn rank3_no_merge_carries_across_outer_dims() {
        // shape [2, 2, 2]
        // strides [1, 2, 4] (reversed) -> nothing merges, two outer dims
        // outer walk (c0, c1): (0,0)=0, (0,1)=2, (1,0)=1, (1,1)=3
        let l = layout(&[2, 2, 2], &[1, 2, 4]);
        let it = NdIter::new([&l]);
        assert_eq!(it.inner_size, 2);
        assert_eq!(it.inner_strides, [4]);
        assert_eq!(it.len(), 4);
        let offsets: Vec<_> = it.collect();
        assert_eq!(offsets, vec![[0], [2], [1], [3]]);
    }
}