candle_core/
strided_index.rs

1use crate::Layout;
2
/// An iterator over the storage offsets of the elements of an N-dimensional array that
/// is stored in a flat buffer using (possibly non-contiguous) strides.
///
/// Elements are visited with the last dimension advancing fastest.
#[derive(Debug)]
pub struct StridedIndex<'a> {
    /// Offset that the next call to `next` will yield; `None` once the iteration is
    /// exhausted (or from the start, when the array has no elements).
    next_storage_index: Option<usize>,
    /// Per-dimension coordinates of the element located at `next_storage_index`.
    multi_index: Vec<usize>,
    /// Extent of each dimension.
    dims: &'a [usize],
    /// Amount added to the storage offset when the matching coordinate grows by one.
    stride: &'a [usize],
}
12
13impl<'a> StridedIndex<'a> {
14    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
15        let elem_count: usize = dims.iter().product();
16        let next_storage_index = if elem_count == 0 {
17            None
18        } else {
19            // This applies to the scalar case.
20            Some(start_offset)
21        };
22        StridedIndex {
23            next_storage_index,
24            multi_index: vec![0; dims.len()],
25            dims,
26            stride,
27        }
28    }
29
30    pub(crate) fn from_layout(l: &'a Layout) -> Self {
31        Self::new(l.dims(), l.stride(), l.start_offset())
32    }
33}
34
35impl Iterator for StridedIndex<'_> {
36    type Item = usize;
37
38    fn next(&mut self) -> Option<Self::Item> {
39        let storage_index = self.next_storage_index?;
40        let mut updated = false;
41        let mut next_storage_index = storage_index;
42        for ((multi_i, max_i), stride_i) in self
43            .multi_index
44            .iter_mut()
45            .zip(self.dims.iter())
46            .zip(self.stride.iter())
47            .rev()
48        {
49            let next_i = *multi_i + 1;
50            if next_i < *max_i {
51                *multi_i = next_i;
52                updated = true;
53                next_storage_index += stride_i;
54                break;
55            } else {
56                next_storage_index -= *multi_i * stride_i;
57                *multi_i = 0
58            }
59        }
60        self.next_storage_index = if updated {
61            Some(next_storage_index)
62        } else {
63            None
64        };
65        Some(storage_index)
66    }
67}
68
/// Storage offsets of a strided array, grouped into blocks.
// NOTE(review): the code that builds these variants is not visible here; the field
// meanings below are inferred from their names and types — confirm against the producer.
#[derive(Debug)]
pub enum StridedBlocks<'a> {
    /// All elements form one block.
    SingleBlock {
        /// Storage offset at which the block starts.
        start_offset: usize,
        /// Number of elements in the block (presumably a contiguous run — confirm).
        len: usize,
    },
    /// Elements are split across several blocks of equal length.
    MultipleBlocks {
        /// Iterator over storage offsets; presumably one offset per block start.
        block_start_index: StridedIndex<'a>,
        /// Length shared by every block (assumed contiguous elements — confirm).
        block_len: usize,
    },
}