// candle_core/strided_index.rs

1use crate::Layout;
2
/// Iterator over the storage offsets of the elements of an N-dimensional array that
/// lives in a flat buffer and is addressed through (possibly non-contiguous) strides.
///
/// Elements are visited in row-major order of their multi-dimensional index, i.e. the
/// last dimension varies fastest.
#[derive(Debug)]
pub struct StridedIndex<'a> {
    /// Storage offset returned by the next call to `next`; `None` when nothing is left.
    next_storage_index: Option<usize>,
    /// Current position along each dimension (one entry per dimension).
    multi_index: Vec<usize>,
    /// Size of each dimension.
    dims: &'a [usize],
    /// Amount added to the storage offset when the matching dimension's index grows by one.
    stride: &'a [usize],
    /// Number of offsets still to be yielded.
    remaining: usize,
}
13
14impl<'a> StridedIndex<'a> {
15    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
16        let elem_count: usize = dims.iter().product();
17        let next_storage_index = if elem_count == 0 {
18            None
19        } else {
20            // This applies to the scalar case.
21            Some(start_offset)
22        };
23        StridedIndex {
24            next_storage_index,
25            multi_index: vec![0; dims.len()],
26            dims,
27            stride,
28            remaining: elem_count,
29        }
30    }
31
32    pub(crate) fn from_layout(l: &'a Layout) -> Self {
33        Self::new(l.dims(), l.stride(), l.start_offset())
34    }
35}
36
37impl Iterator for StridedIndex<'_> {
38    type Item = usize;
39
40    #[inline]
41    fn next(&mut self) -> Option<Self::Item> {
42        let storage_index = self.next_storage_index?;
43        let mut updated = false;
44        let mut next_storage_index = storage_index;
45        for ((multi_i, max_i), stride_i) in self
46            .multi_index
47            .iter_mut()
48            .zip(self.dims.iter())
49            .zip(self.stride.iter())
50            .rev()
51        {
52            let next_i = *multi_i + 1;
53            if next_i < *max_i {
54                *multi_i = next_i;
55                updated = true;
56                next_storage_index += stride_i;
57                break;
58            } else {
59                next_storage_index -= *multi_i * stride_i;
60                *multi_i = 0
61            }
62        }
63        self.remaining -= 1;
64        self.next_storage_index = if updated {
65            Some(next_storage_index)
66        } else {
67            None
68        };
69        Some(storage_index)
70    }
71
72    #[inline]
73    fn size_hint(&self) -> (usize, Option<usize>) {
74        (self.remaining, Some(self.remaining))
75    }
76}
77
78impl ExactSizeIterator for StridedIndex<'_> {
79    fn len(&self) -> usize {
80        self.remaining
81    }
82}
83
84#[derive(Debug)]
85pub enum StridedBlocks<'a> {
86    SingleBlock {
87        start_offset: usize,
88        len: usize,
89    },
90    MultipleBlocks {
91        block_start_index: StridedIndex<'a>,
92        block_len: usize,
93    },
94}