candle_core_temp/
strided_index.rs1use crate::Layout;
2
/// State for walking the storage offsets of a strided, multi-dimensional view.
///
/// The value tracks a per-dimension coordinate together with the storage
/// offset that corresponds to it; the stepping logic lives in the `Iterator`
/// implementation for this type.
#[derive(Debug)]
pub struct StridedIndex<'a> {
    /// Offset that the next call to `next` will return, or `None` once the
    /// iteration is finished (or when there was nothing to iterate over).
    next_storage_index: Option<usize>,
    /// Current coordinate along each dimension; one entry per element of `dims`.
    multi_index: Vec<usize>,
    /// Extent of each dimension; a coordinate stays strictly below its extent.
    dims: &'a [usize],
    /// Amount added to the storage offset when the coordinate along the
    /// matching dimension grows by one.
    stride: &'a [usize],
}
12
13impl<'a> StridedIndex<'a> {
14 pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
15 let elem_count: usize = dims.iter().product();
16 let next_storage_index = if elem_count == 0 {
17 None
18 } else {
19 Some(start_offset)
21 };
22 StridedIndex {
23 next_storage_index,
24 multi_index: vec![0; dims.len()],
25 dims,
26 stride,
27 }
28 }
29
30 pub(crate) fn from_layout(l: &'a Layout) -> Self {
31 Self::new(l.dims(), l.stride(), l.start_offset())
32 }
33}
34
35impl<'a> Iterator for StridedIndex<'a> {
36 type Item = usize;
37
38 fn next(&mut self) -> Option<Self::Item> {
39 let storage_index = match self.next_storage_index {
40 None => return None,
41 Some(storage_index) => storage_index,
42 };
43 let mut updated = false;
44 let mut next_storage_index = storage_index;
45 for ((multi_i, max_i), stride_i) in self
46 .multi_index
47 .iter_mut()
48 .zip(self.dims.iter())
49 .zip(self.stride.iter())
50 .rev()
51 {
52 let next_i = *multi_i + 1;
53 if next_i < *max_i {
54 *multi_i = next_i;
55 updated = true;
56 next_storage_index += stride_i;
57 break;
58 } else {
59 next_storage_index -= *multi_i * stride_i;
60 *multi_i = 0
61 }
62 }
63 self.next_storage_index = if updated {
64 Some(next_storage_index)
65 } else {
66 None
67 };
68 Some(storage_index)
69 }
70}
71
72#[derive(Debug)]
73pub enum StridedBlocks<'a> {
74 SingleBlock {
75 start_offset: usize,
76 len: usize,
77 },
78 MultipleBlocks {
79 block_start_index: StridedIndex<'a>,
80 block_len: usize,
81 },
82}