Skip to main content

core_utils/circuit/
slice.rs

1use itertools::izip;
2use serde::{Deserialize, Serialize};
3use wincode::{SchemaRead, SchemaWrite};
4
5use crate::circuit::errors::SliceError;
6
7/// A general slicing structure which can represent a single index, a strided 1d range, a strided 2d
8/// range or a vector of slices.
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    PartialOrd,
15    Ord,
16    Hash,
17    Serialize,
18    Deserialize,
19    SchemaRead,
20    SchemaWrite,
21)]
22pub struct Slice(SliceEnum);
23
24#[derive(
25    Debug,
26    Clone,
27    PartialEq,
28    Eq,
29    PartialOrd,
30    Ord,
31    Hash,
32    Serialize,
33    Deserialize,
34    SchemaRead,
35    SchemaWrite,
36)]
37#[repr(C)]
38enum SliceEnum {
39    /// A single index slice
40    Single(u32),
41    /// A slice with indices given by
42    /// ```
43    /// (0..size).map(|i| start + i * step)
44    /// ```
45    Range { start: u32, size: u32, step: i32 },
46    /// A slice with indices given by
47    /// ```
48    /// (0..size1).flat_map(|i| (0..size2).map(|j| start + step1 * i + step2 * j))
49    /// ```
50    Range2d {
51        start: u32,
52        size1: u32,
53        step1: i32,
54        size2: u32,
55        step2: i32,
56    },
57    /// A slice with indices given by a vector of slices.
58    RangeVec(Vec<SliceEnum>),
59}
60
61impl Slice {
62    pub fn empty() -> Self {
63        Self(SliceEnum::RangeVec(vec![]))
64    }
65
66    pub fn single(index: u32) -> Self {
67        Self(SliceEnum::Single(index))
68    }
69
70    pub fn range(start: u32, size: u32, step: i32) -> Result<Self, SliceError> {
71        let min_index = generate_range_indices(start, size, step).min().unwrap_or(0);
72        if min_index < 0 {
73            return Err(SliceError::NegativeIndex(min_index));
74        }
75
76        Ok(Self(SliceEnum::Range { start, size, step }))
77    }
78
79    pub fn shift_start(&mut self, delta: u32) {
80        self.0.shift_start(delta);
81    }
82
83    pub fn range2d(
84        start: u32,
85        size1: u32,
86        size2: u32,
87        step1: i32,
88        step2: i32,
89    ) -> Result<Self, SliceError> {
90        let min_index = generate_range_2d_indices(start, size1, step1, size2, step2)
91            .min()
92            .unwrap_or(0);
93        if min_index < 0 {
94            return Err(SliceError::NegativeIndex(min_index));
95        }
96
97        Ok(Self(SliceEnum::Range2d {
98            start,
99            size1,
100            step1,
101            size2,
102            step2,
103        }))
104    }
105
106    pub fn append(&mut self, other: Self) {
107        match (&mut self.0, other.0) {
108            (SliceEnum::RangeVec(v), SliceEnum::RangeVec(v1)) => v.extend(v1),
109            (SliceEnum::RangeVec(v), slice) => v.push(slice),
110            (slice, SliceEnum::RangeVec(mut v1)) => {
111                v1.insert(0, slice.clone());
112                *slice = SliceEnum::RangeVec(v1);
113            }
114            (slice, slice1) => *slice = SliceEnum::RangeVec(vec![slice.clone(), slice1]),
115        }
116    }
117
118    pub fn get_indices(&self) -> Vec<u32> {
119        self.0.get_indices()
120    }
121
122    pub fn is_empty(&self) -> bool {
123        self.len() == 0
124    }
125
126    pub fn len(&self) -> u32 {
127        self.0.len()
128    }
129
130    pub fn from_indices(indices: Vec<u32>) -> Self {
131        Self(SliceEnum::from_indices(indices))
132    }
133
134    pub fn optimize(self) -> Self {
135        Self::from_indices(self.get_indices())
136    }
137}
138
/// Lazily yields `start + step * i` for `i` in `0..size`, using `i32` arithmetic.
///
/// Both `start` and `size` are reinterpreted with `as i32`; a `size` above `i32::MAX` therefore
/// yields nothing.
fn generate_range_indices(start: u32, size: u32, step: i32) -> impl Iterator<Item = i32> {
    let origin = start as i32;
    let count = size as i32;
    (0..count).map(move |offset| origin + step * offset)
}
142
/// Lazily yields `start + step1 * i + step2 * j` for `i` in `0..size1` (outer) and `j` in
/// `0..size2` (inner), using `i32` arithmetic.
///
/// `start`, `size1` and `size2` are reinterpreted with `as i32`. Each element is computed only
/// when it is produced.
fn generate_range_2d_indices(
    start: u32,
    size1: u32,
    step1: i32,
    size2: u32,
    step2: i32,
) -> impl Iterator<Item = i32> {
    let origin = start as i32;
    let (rows, cols) = (size1 as i32, size2 as i32);
    (0..rows).flat_map(move |row| (0..cols).map(move |col| origin + step1 * row + step2 * col))
}
153
154impl SliceEnum {
155    fn get_indices(&self) -> Vec<u32> {
156        match self {
157            SliceEnum::Single(idx) => vec![*idx],
158            SliceEnum::Range { start, size, step } => generate_range_indices(*start, *size, *step)
159                .map(|i| i as u32)
160                .collect(),
161            SliceEnum::Range2d {
162                start,
163                size1,
164                size2,
165                step1,
166                step2,
167            } => generate_range_2d_indices(*start, *size1, *step1, *size2, *step2)
168                .map(|i| i as u32)
169                .collect(),
170            SliceEnum::RangeVec(v) => v.iter().flat_map(|r| r.get_indices()).collect(),
171        }
172    }
173
174    pub fn len(&self) -> u32 {
175        match self {
176            SliceEnum::Single(_) => 1,
177            SliceEnum::Range { size, .. } => *size,
178            SliceEnum::Range2d { size1, size2, .. } => size1 * size2,
179            SliceEnum::RangeVec(v) => v.iter().map(|r| r.len()).sum(),
180        }
181    }
182
183    /// Given a start index and a vector of deltas tries to find a slice (`Single`, `Range` or
184    /// `Range2d`) which generates the longest sequence `[index0, index0 + deltas[0], index0 +
185    /// deltas[0] + deltas[1], ...]`
186    fn match_largest_slice(start: u32, deltas: &[i32]) -> Self {
187        if deltas.is_empty() {
188            return Self::Single(start);
189        }
190
191        // The longest sequence of equal deltas generates a 1d range slice
192        // A 1d slice verifies: `deltas[..] = deltas[0] | deltas[0] | .. | deltas[0]`
193        let step_j = deltas[0];
194        let n_j = deltas.iter().skip(1).take_while(|&&d| d == step_j).count() + 2;
195
196        let mut res_slice = Self::Range {
197            start,
198            size: n_j as u32,
199            step: step_j,
200        };
201
202        if n_j < deltas.len() + 1 {
203            // If the sequence of deltas is not finished, try to match a 2d slice.
204            // A 2d slice verifies:
205            //  `deltas[..] = deltas[0..n_j] | deltas[0..n_j] | .. | deltas[0..n_j - 1]`
206            let exp_chunk = &deltas[0..n_j];
207            let chunks = deltas.chunks(n_j).skip(1);
208            let mut n_i = chunks
209                .take_while(|chunk| {
210                    izip!(exp_chunk, *chunk).take_while(|(e, d)| e == d).count() == n_j
211                })
212                .count()
213                + 1;
214            if let Some(chunk) = deltas.chunks(n_j).nth(n_i) {
215                if izip!(exp_chunk, chunk).take_while(|(e, d)| e == d).count() == n_j - 1 {
216                    n_i += 1;
217                }
218            }
219
220            if n_i > 1 {
221                let step_i = exp_chunk.iter().sum::<i32>();
222                res_slice = Self::Range2d {
223                    start,
224                    size1: n_i as u32,
225                    size2: n_j as u32,
226                    step1: step_i,
227                    step2: step_j,
228                };
229            }
230        }
231
232        res_slice
233    }
234
235    /// Reduces the current slice to a slice with at most `new_size` indices.
236    fn reduce(&mut self, max_size: u32) {
237        assert!(max_size > 0);
238        match self {
239            SliceEnum::Single(_) => {}
240            SliceEnum::Range { start, size, .. } => {
241                if max_size < *size {
242                    if max_size == 1 {
243                        *self = SliceEnum::Single(*start);
244                    } else {
245                        *size = max_size;
246                    }
247                }
248            }
249            SliceEnum::Range2d {
250                start,
251                size1,
252                size2,
253                step2,
254                ..
255            } => {
256                if max_size < *size1 * *size2 {
257                    if max_size == 1 {
258                        *self = SliceEnum::Single(*start);
259                    } else if max_size <= *size2 {
260                        *self = SliceEnum::Range {
261                            start: *start,
262                            size: max_size,
263                            step: *step2,
264                        }
265                    } else if max_size / *size2 == 1 {
266                        *self = SliceEnum::Range {
267                            start: *start,
268                            size: *size2,
269                            step: *step2,
270                        }
271                    } else {
272                        *size1 = max_size / *size2;
273                    }
274                }
275            }
276            SliceEnum::RangeVec(_) => {}
277        }
278    }
279
280    fn match_slices(mut max_len_slices: Vec<Self>) -> Vec<Self> {
281        let mut res = vec![]; // result slices with absolute start indices
282        let mut ranges_to_visit = vec![(0, max_len_slices.len())]; // start with full range
283        while let Some((start, end)) = ranges_to_visit.pop() {
284            // Find the slice which generates the longest sequence of indices in the current range
285            // `[start, end)`
286            let (slice_pos, slice) = max_len_slices[start..end]
287                .iter()
288                .enumerate()
289                .max_by_key(|(pos, slice)| (slice.len(), end - pos)) // `end - pos` is used to return the first maximum
290                .unwrap();
291            let slice_start = start + slice_pos; // to absolute position
292            let slice_end = slice_start + slice.len() as usize;
293
294            // Store the max slice for the result
295            res.push((slice_start, slice.clone()));
296
297            // Add left and right ranges to visit if they are not empty
298            if start < slice_start {
299                // Reduce the length of the slices on before the max slice to not overlap with the
300                // max slice
301                max_len_slices[start..slice_start]
302                    .iter_mut()
303                    .enumerate()
304                    .for_each(|(pos, slice)| slice.reduce((slice_pos - pos) as u32));
305
306                ranges_to_visit.push((start, slice_start));
307            }
308            if slice_end < end {
309                ranges_to_visit.push((slice_end, end));
310            }
311        }
312
313        res.sort_by_key(|(start, _)| *start);
314        res.into_iter().map(|(_, slice)| slice).collect()
315    }
316
317    /// Given a vector of indices tries to find a minimal number of slices
318    /// (`Single`, `Range` or `Range2d`) which generates the same sequence of indices.
319    ///
320    /// The algorithm finds the largest slice in input sequence `indices` and then
321    /// recursively matches slices in the left-hand and right-hand size indices which are not
322    /// covered by the largest slice.
323    pub fn from_indices(indices: Vec<u32>) -> Self {
324        if indices.is_empty() {
325            return Self::RangeVec(vec![]);
326        }
327
328        let deltas = indices
329            .windows(2)
330            .map(|w| w[1] as i32 - w[0] as i32)
331            .collect::<Vec<_>>();
332        let max_slice_vec: Vec<_> = (0..indices.len())
333            .map(|i| Self::match_largest_slice(indices[i], &deltas[i..]))
334            .collect();
335
336        let optimized_slices = SliceEnum::match_slices(max_slice_vec);
337        if optimized_slices.len() == 1 {
338            optimized_slices[0].clone()
339        } else {
340            SliceEnum::RangeVec(optimized_slices)
341        }
342    }
343
344    pub fn shift_start(&mut self, delta: u32) {
345        match self {
346            SliceEnum::Single(idx) => {
347                *idx += delta;
348            }
349            SliceEnum::Range { start, .. } => {
350                *start += delta;
351            }
352            SliceEnum::Range2d { start, .. } => {
353                *start += delta;
354            }
355            SliceEnum::RangeVec(v) => v.iter_mut().for_each(|slice| slice.shift_start(delta)),
356        }
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::SliceEnum;
    use crate::circuit::Slice;

    /// Shorthand for a `SliceEnum::Range` value.
    fn range(start: u32, size: u32, step: i32) -> SliceEnum {
        SliceEnum::Range { start, size, step }
    }

    /// Shorthand for a `SliceEnum::Range2d` value.
    fn range2d(start: u32, size1: u32, size2: u32, step1: i32, step2: i32) -> SliceEnum {
        SliceEnum::Range2d {
            start,
            size1,
            size2,
            step1,
            step2,
        }
    }

    /// Calls `SliceEnum::match_largest_slice` with the first index of `indices` and the
    /// differences between its consecutive elements.
    fn match_largest_slice(indices: &[u32]) -> SliceEnum {
        let deltas: Vec<i32> = indices
            .windows(2)
            .map(|pair| pair[1] as i32 - pair[0] as i32)
            .collect();
        SliceEnum::match_largest_slice(indices[0], &deltas)
    }

    #[test]
    fn test_slice_range() {
        let cases: Vec<(SliceEnum, Vec<u32>)> = vec![
            (range2d(0, 2, 3, 6, 1), vec![0, 1, 2, 6, 7, 8]),
            (range2d(0, 4, 2, 3, 1), vec![0, 1, 3, 4, 6, 7, 9, 10]),
            (range2d(0, 4, 2, 3, 2), vec![0, 2, 3, 5, 6, 8, 9, 11]),
            (range2d(2, 1, 4, 1, 3), vec![2, 5, 8, 11]),
        ];
        for (slice, expected) in cases {
            assert_eq!(slice.get_indices(), expected);
        }
    }

    #[test]
    fn test_slice_match_largest_slice() {
        // (indices, number of leading indices the largest slice is expected to generate)
        let cases: Vec<(Vec<u32>, usize)> = vec![
            // Full match: single point slices
            (vec![0], 1),
            (vec![3], 1),
            // Full match: 1d slices
            (vec![0, 1, 2, 3, 4], 5),
            (vec![5, 7, 9, 11, 13], 5),
            (vec![5, 6], 2),
            (vec![5, 2], 2),
            // Full match: 2d slices
            // A[0..4][0..3] in a 4x5 matrix (row-major order)
            (vec![0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17], 12),
            // A[0..2][2..5] in a 2x5 matrix (row-major order)
            (vec![2, 3, 4, 7, 8, 9], 6),
            // A[(0..3).step_by(2)][(0..4).step_by(2)] in a 3x4 matrix (row-major order)
            (vec![0, 2, 8, 10], 4),
            // A[(0..3).reverse()][(0..3).step_by(2)] in a 3x5 matrix (row-major order)
            (vec![10, 12, 5, 7, 0, 2], 6),
            // Partial match: 1d slices
            (vec![0, 2, 4, 4, 5], 3),
            // Partial match: 2d slices
            (vec![0, 1, 3, 4, 5], 4),
            (vec![10, 12, 5, 7, 0, 2, 1], 6),
            // Special cases
            (vec![1, 1, 0, 0, 1, 1, 0, 0], 4),
        ];
        for (indices, covered) in cases {
            let slice = match_largest_slice(&indices);
            assert_eq!(slice.get_indices(), indices[..covered].to_vec());
        }
    }

    #[test]
    fn test_slice_optimize() {
        let cases: Vec<(Vec<u32>, SliceEnum)> = vec![
            (vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], range(0, 12, 1)),
            (
                vec![19, 3, 4, 5, 6, 7, 8, 9, 10, 11],
                SliceEnum::RangeVec(vec![SliceEnum::Single(19), range(3, 9, 1)]),
            ),
            (
                vec![0, 1, 2, 19, 3, 4, 5, 6, 7, 8, 9, 10, 11],
                SliceEnum::RangeVec(vec![
                    range(0, 3, 1),
                    SliceEnum::Single(19),
                    range(3, 9, 1),
                ]),
            ),
            (
                vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19],
                SliceEnum::RangeVec(vec![range(0, 10, 1), SliceEnum::Single(19)]),
            ),
            (
                vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 10, 11],
                SliceEnum::RangeVec(vec![
                    range(0, 10, 1),
                    range(19, 2, -9),
                    SliceEnum::Single(11),
                ]),
            ),
        ];
        for (indices, expected) in cases {
            let slice = Slice::from_indices(indices.clone());
            assert_eq!(slice.get_indices(), indices);
            assert_eq!(slice.0, expected);
        }

        // Large example, 4000 indices
        let indices: Vec<u32> = [0, 1, 1, 0].repeat(1000);
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
    }
}