// candle_core/strided_index.rs
use crate::Layout;

/// Iterator over the storage offsets of the elements of an N-dimensional view
/// described by per-dimension sizes (`dims`), per-dimension offset steps
/// (`stride`) and a starting storage offset.
///
/// Elements are visited with the last dimension varying fastest; each call to
/// `next` yields the storage offset of the current element.
#[derive(Debug)]
pub struct StridedIndex<'a> {
    /// Storage offset to yield on the next call to `next`, or `None` once the
    /// walk is finished (or when the view has no elements at all).
    next_storage_index: Option<usize>,
    /// Current position along each dimension; has one entry per entry of `dims`.
    multi_index: Vec<usize>,
    /// Size of each dimension.
    dims: &'a [usize],
    /// Amount added to the storage offset for a unit step along each dimension.
    stride: &'a [usize],
    /// Number of offsets still to be yielded; backs `size_hint` and `len`.
    remaining: usize,
}

14impl<'a> StridedIndex<'a> {
15 pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
16 let elem_count: usize = dims.iter().product();
17 let next_storage_index = if elem_count == 0 {
18 None
19 } else {
20 Some(start_offset)
22 };
23 StridedIndex {
24 next_storage_index,
25 multi_index: vec![0; dims.len()],
26 dims,
27 stride,
28 remaining: elem_count,
29 }
30 }
31
32 pub(crate) fn from_layout(l: &'a Layout) -> Self {
33 Self::new(l.dims(), l.stride(), l.start_offset())
34 }
35}
36
37impl Iterator for StridedIndex<'_> {
38 type Item = usize;
39
40 #[inline]
41 fn next(&mut self) -> Option<Self::Item> {
42 let storage_index = self.next_storage_index?;
43 let mut updated = false;
44 let mut next_storage_index = storage_index;
45 for ((multi_i, max_i), stride_i) in self
46 .multi_index
47 .iter_mut()
48 .zip(self.dims.iter())
49 .zip(self.stride.iter())
50 .rev()
51 {
52 let next_i = *multi_i + 1;
53 if next_i < *max_i {
54 *multi_i = next_i;
55 updated = true;
56 next_storage_index += stride_i;
57 break;
58 } else {
59 next_storage_index -= *multi_i * stride_i;
60 *multi_i = 0
61 }
62 }
63 self.remaining -= 1;
64 self.next_storage_index = if updated {
65 Some(next_storage_index)
66 } else {
67 None
68 };
69 Some(storage_index)
70 }
71
72 #[inline]
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 (self.remaining, Some(self.remaining))
75 }
76}
77
78impl ExactSizeIterator for StridedIndex<'_> {
79 fn len(&self) -> usize {
80 self.remaining
81 }
82}
83
84#[derive(Debug)]
85pub enum StridedBlocks<'a> {
86 SingleBlock {
87 start_offset: usize,
88 len: usize,
89 },
90 MultipleBlocks {
91 block_start_index: StridedIndex<'a>,
92 block_len: usize,
93 },
94}