Skip to main content

core_utils/circuit/v2/
slice.rs

1use itertools::izip;
2use serde::{Deserialize, Serialize};
3use wincode::{SchemaRead, SchemaWrite};
4
5use crate::circuit::errors::SliceError;
6
7/// A general slicing structure which can represent a single index, a strided 1d range, a strided 2d
8/// range or a vector of slices.
9#[derive(
10    Debug,
11    Clone,
12    PartialEq,
13    Eq,
14    PartialOrd,
15    Ord,
16    Hash,
17    Serialize,
18    Deserialize,
19    SchemaRead,
20    SchemaWrite,
21)]
22pub struct Slice(SliceEnum);
23
24#[derive(
25    Debug,
26    Clone,
27    PartialEq,
28    Eq,
29    PartialOrd,
30    Ord,
31    Hash,
32    Serialize,
33    Deserialize,
34    SchemaRead,
35    SchemaWrite,
36)]
37#[repr(C)]
38enum SliceEnum {
39    /// A single index slice
40    Single(u32),
41    /// A slice with indices given by
42    /// ```text
43    /// (0..size).map(|i| start + i * step)
44    /// ```
45    Range { start: u32, size: u32, step: i64 },
46    /// A slice with indices given by
47    /// ```text
48    /// (0..size1).flat_map(|i| (0..size2).map(|j| start + step1 * i + step2 * j))
49    /// ```
50    Range2d {
51        start: u32,
52        size1: u32,
53        step1: i64,
54        size2: u32,
55        step2: i64,
56    },
57    /// A slice with indices given by a vector of slices.
58    RangeVec(Vec<SliceEnum>),
59}
60
61impl Slice {
62    pub fn empty() -> Self {
63        Self(SliceEnum::RangeVec(vec![]))
64    }
65
66    pub fn single(index: u32) -> Self {
67        Self(SliceEnum::Single(index))
68    }
69
70    pub fn range(start: u32, size: u32, step: i64) -> Result<Self, SliceError> {
71        validate_range_bounds(start, size, step)?;
72        Ok(Self(SliceEnum::Range { start, size, step }))
73    }
74
75    pub fn shift_start(&mut self, delta: u32) {
76        self.0.shift_start(delta);
77    }
78
79    pub fn range2d(
80        start: u32,
81        size1: u32,
82        size2: u32,
83        step1: i64,
84        step2: i64,
85    ) -> Result<Self, SliceError> {
86        validate_range_2d_bounds(start, size1, step1, size2, step2)?;
87        Ok(Self(SliceEnum::Range2d {
88            start,
89            size1,
90            step1,
91            size2,
92            step2,
93        }))
94    }
95
96    pub fn append(&mut self, other: Self) {
97        match (&mut self.0, other.0) {
98            (SliceEnum::RangeVec(v), SliceEnum::RangeVec(v1)) => v.extend(v1),
99            (SliceEnum::RangeVec(v), slice) => v.push(slice),
100            (slice, SliceEnum::RangeVec(mut v1)) => {
101                v1.insert(0, slice.clone());
102                *slice = SliceEnum::RangeVec(v1);
103            }
104            (slice, slice1) => *slice = SliceEnum::RangeVec(vec![slice.clone(), slice1]),
105        }
106    }
107
108    pub fn get_indices(&self) -> Vec<u32> {
109        self.0.get_indices()
110    }
111
112    pub fn is_empty(&self) -> bool {
113        self.len() == 0
114    }
115
116    pub fn len(&self) -> u32 {
117        self.0.len()
118    }
119
120    pub fn from_indices(indices: Vec<u32>) -> Self {
121        Self(SliceEnum::from_indices(indices))
122    }
123
124    pub fn optimize(self) -> Self {
125        Self::from_indices(self.get_indices())
126    }
127}
128
129fn validate_bounds(min_index: i128, max_index: i128) -> Result<(), SliceError> {
130    if min_index < 0 {
131        return Err(SliceError::NegativeIndex(min_index));
132    }
133    if max_index > i128::from(u32::MAX) {
134        return Err(SliceError::IndexOutOfBounds {
135            found: max_index,
136            max: u32::MAX,
137        });
138    }
139    Ok(())
140}
141
/// Returns `start + step * i`, computed in `i128` so that the product of two `i64` values
/// cannot overflow.
#[inline]
fn range_index(start: u32, step: i64, i: i64) -> i128 {
    let offset = i128::from(step) * i128::from(i);
    i128::from(start) + offset
}
146
/// Returns `start + step1 * i + step2 * j`, computed in `i128`.
#[inline]
fn range_2d_index(start: u32, step1: i64, i: i64, step2: i64, j: i64) -> i128 {
    let outer_offset = i128::from(step1) * i128::from(i);
    let inner_offset = i128::from(step2) * i128::from(j);
    i128::from(start) + outer_offset + inner_offset
}
151
152fn validate_range_bounds(start: u32, size: u32, step: i64) -> Result<(), SliceError> {
153    if size == 0 {
154        return Ok(());
155    }
156    let last = i64::from(size - 1);
157    let first = i128::from(start);
158    let end = range_index(start, step, last);
159    validate_bounds(first.min(end), first.max(end))
160}
161
162fn validate_range_2d_bounds(
163    start: u32,
164    size1: u32,
165    step1: i64,
166    size2: u32,
167    step2: i64,
168) -> Result<(), SliceError> {
169    if size1 == 0 || size2 == 0 {
170        return Ok(());
171    }
172    let i_last = i64::from(size1 - 1);
173    let j_last = i64::from(size2 - 1);
174
175    let corners = [
176        range_2d_index(start, step1, 0, step2, 0),
177        range_2d_index(start, step1, i_last, step2, 0),
178        range_2d_index(start, step1, 0, step2, j_last),
179        range_2d_index(start, step1, i_last, step2, j_last),
180    ];
181
182    let min_index = corners.into_iter().min().unwrap_or(0);
183    let max_index = corners.into_iter().max().unwrap_or(0);
184    validate_bounds(min_index, max_index)
185}
186
/// Narrows a widened index back to `u32`.
///
/// # Panics
///
/// Panics if `index` is negative or greater than `u32::MAX`.
#[inline]
fn to_u32_index(index: i128) -> u32 {
    match u32::try_from(index) {
        Ok(value) => value,
        Err(_) => panic!("slice index out of bounds: {index}"),
    }
}
191
192fn generate_range_indices(start: u32, size: u32, step: i64) -> impl Iterator<Item = u32> {
193    (0..i64::from(size)).map(move |i| to_u32_index(range_index(start, step, i)))
194}
195
196fn generate_range_2d_indices(
197    start: u32,
198    size1: u32,
199    step1: i64,
200    size2: u32,
201    step2: i64,
202) -> impl Iterator<Item = u32> {
203    (0..i64::from(size1)).flat_map(move |i| {
204        (0..i64::from(size2)).map(move |j| to_u32_index(range_2d_index(start, step1, i, step2, j)))
205    })
206}
207
208impl SliceEnum {
209    fn get_indices(&self) -> Vec<u32> {
210        match self {
211            SliceEnum::Single(idx) => vec![*idx],
212            SliceEnum::Range { start, size, step } => {
213                generate_range_indices(*start, *size, *step).collect()
214            }
215            SliceEnum::Range2d {
216                start,
217                size1,
218                size2,
219                step1,
220                step2,
221            } => generate_range_2d_indices(*start, *size1, *step1, *size2, *step2).collect(),
222            SliceEnum::RangeVec(v) => v.iter().flat_map(|r| r.get_indices()).collect(),
223        }
224    }
225
226    pub fn len(&self) -> u32 {
227        match self {
228            SliceEnum::Single(_) => 1,
229            SliceEnum::Range { size, .. } => *size,
230            SliceEnum::Range2d { size1, size2, .. } => size1
231                .checked_mul(*size2)
232                .expect("slice length overflow for range2d"),
233            SliceEnum::RangeVec(v) => v.iter().fold(0u32, |acc, r| {
234                acc.checked_add(r.len())
235                    .expect("slice length overflow for range vector")
236            }),
237        }
238    }
239
240    /// Given a start index and a vector of deltas tries to find a slice (`Single`, `Range` or
241    /// `Range2d`) which generates the longest sequence `[index0, index0 + deltas[0], index0 +
242    /// deltas[0] + deltas[1], ...]`
243    fn match_largest_slice(start: u32, deltas: &[i64]) -> Self {
244        if deltas.is_empty() {
245            return Self::Single(start);
246        }
247
248        // The longest sequence of equal deltas generates a 1d range slice
249        // A 1d slice verifies: `deltas[..] = deltas[0] | deltas[0] | .. | deltas[0]`
250        let step_j = deltas[0];
251        let n_j = deltas.iter().skip(1).take_while(|&&d| d == step_j).count() + 2;
252
253        let mut res_slice = Self::Range {
254            start,
255            size: n_j as u32,
256            step: step_j,
257        };
258
259        if n_j < deltas.len() + 1 {
260            // If the sequence of deltas is not finished, try to match a 2d slice.
261            // A 2d slice verifies:
262            //  `deltas[..] = deltas[0..n_j] | deltas[0..n_j] | .. | deltas[0..n_j - 1]`
263            let exp_chunk = &deltas[0..n_j];
264            let chunks = deltas.chunks(n_j).skip(1);
265            let mut n_i = chunks
266                .take_while(|chunk| {
267                    izip!(exp_chunk, *chunk).take_while(|(e, d)| e == d).count() == n_j
268                })
269                .count()
270                + 1;
271            if let Some(chunk) = deltas.chunks(n_j).nth(n_i) {
272                if izip!(exp_chunk, chunk).take_while(|(e, d)| e == d).count() == n_j - 1 {
273                    n_i += 1;
274                }
275            }
276
277            if n_i > 1 {
278                let step_i = exp_chunk.iter().sum::<i64>();
279                res_slice = Self::Range2d {
280                    start,
281                    size1: n_i as u32,
282                    size2: n_j as u32,
283                    step1: step_i,
284                    step2: step_j,
285                };
286            }
287        }
288
289        res_slice
290    }
291
292    /// Reduces the current slice to a slice with at most `new_size` indices.
293    fn reduce(&mut self, max_size: u32) {
294        assert!(max_size > 0);
295        match self {
296            SliceEnum::Single(_) => {}
297            SliceEnum::Range { start, size, .. } => {
298                if max_size < *size {
299                    if max_size == 1 {
300                        *self = SliceEnum::Single(*start);
301                    } else {
302                        *size = max_size;
303                    }
304                }
305            }
306            SliceEnum::Range2d {
307                start,
308                size1,
309                size2,
310                step2,
311                ..
312            } => {
313                if max_size < *size1 * *size2 {
314                    if max_size == 1 {
315                        *self = SliceEnum::Single(*start);
316                    } else if max_size <= *size2 {
317                        *self = SliceEnum::Range {
318                            start: *start,
319                            size: max_size,
320                            step: *step2,
321                        }
322                    } else if max_size / *size2 == 1 {
323                        *self = SliceEnum::Range {
324                            start: *start,
325                            size: *size2,
326                            step: *step2,
327                        }
328                    } else {
329                        *size1 = max_size / *size2;
330                    }
331                }
332            }
333            SliceEnum::RangeVec(_) => {}
334        }
335    }
336
337    fn match_slices(mut max_len_slices: Vec<Self>) -> Vec<Self> {
338        let mut res = vec![]; // result slices with absolute start indices
339        let mut ranges_to_visit = vec![(0, max_len_slices.len())]; // start with full range
340        while let Some((start, end)) = ranges_to_visit.pop() {
341            // Find the slice which generates the longest sequence of indices in the current range
342            // `[start, end)`
343            let (slice_pos, slice) = max_len_slices[start..end]
344                .iter()
345                .enumerate()
346                .max_by_key(|(pos, slice)| (slice.len(), end - pos)) // `end - pos` is used to return the first maximum
347                .unwrap();
348            let slice_start = start + slice_pos; // to absolute position
349            let slice_end = slice_start + slice.len() as usize;
350
351            // Store the max slice for the result
352            res.push((slice_start, slice.clone()));
353
354            // Add left and right ranges to visit if they are not empty
355            if start < slice_start {
356                // Reduce the length of the slices on before the max slice to not overlap with the
357                // max slice
358                max_len_slices[start..slice_start]
359                    .iter_mut()
360                    .enumerate()
361                    .for_each(|(pos, slice)| slice.reduce((slice_pos - pos) as u32));
362
363                ranges_to_visit.push((start, slice_start));
364            }
365            if slice_end < end {
366                ranges_to_visit.push((slice_end, end));
367            }
368        }
369
370        res.sort_by_key(|(start, _)| *start);
371        res.into_iter().map(|(_, slice)| slice).collect()
372    }
373
374    /// Given a vector of indices tries to find a minimal number of slices
375    /// (`Single`, `Range` or `Range2d`) which generates the same sequence of indices.
376    ///
377    /// The algorithm finds the largest slice in input sequence `indices` and then
378    /// recursively matches slices in the left-hand and right-hand size indices which are not
379    /// covered by the largest slice.
380    pub fn from_indices(indices: Vec<u32>) -> Self {
381        if indices.is_empty() {
382            return Self::RangeVec(vec![]);
383        }
384
385        let deltas = indices
386            .windows(2)
387            .map(|w| w[1] as i64 - w[0] as i64)
388            .collect::<Vec<_>>();
389        let max_slice_vec: Vec<_> = (0..indices.len())
390            .map(|i| Self::match_largest_slice(indices[i], &deltas[i..]))
391            .collect();
392
393        let optimized_slices = SliceEnum::match_slices(max_slice_vec);
394        if optimized_slices.len() == 1 {
395            optimized_slices[0].clone()
396        } else {
397            SliceEnum::RangeVec(optimized_slices)
398        }
399    }
400
401    pub fn shift_start(&mut self, delta: u32) {
402        match self {
403            SliceEnum::Single(idx) => {
404                *idx = idx
405                    .checked_add(delta)
406                    .expect("slice start overflow for single index");
407            }
408            SliceEnum::Range { start, .. } => {
409                *start = start
410                    .checked_add(delta)
411                    .expect("slice start overflow for range");
412            }
413            SliceEnum::Range2d { start, .. } => {
414                *start = start
415                    .checked_add(delta)
416                    .expect("slice start overflow for range2d");
417            }
418            SliceEnum::RangeVec(v) => v.iter_mut().for_each(|slice| slice.shift_start(delta)),
419        }
420    }
421}
422
423#[cfg(test)]
424mod tests {
425    use super::SliceEnum;
426    use crate::circuit::{errors::SliceError, Slice};
427
428    #[test]
429    fn test_slice_range() {
430        let range = SliceEnum::Range2d {
431            start: 0,
432            size1: 2,
433            size2: 3,
434            step1: 6,
435            step2: 1,
436        };
437        let expected = vec![0, 1, 2, 6, 7, 8];
438        assert_eq!(range.get_indices(), expected);
439
440        let range = SliceEnum::Range2d {
441            start: 0,
442            size1: 4,
443            size2: 2,
444            step1: 3,
445            step2: 1,
446        };
447        let expected = vec![0, 1, 3, 4, 6, 7, 9, 10];
448        assert_eq!(range.get_indices(), expected);
449
450        let range = SliceEnum::Range2d {
451            start: 0,
452            size1: 4,
453            size2: 2,
454            step1: 3,
455            step2: 2,
456        };
457        let expected = vec![0, 2, 3, 5, 6, 8, 9, 11];
458        assert_eq!(range.get_indices(), expected);
459
460        let range = SliceEnum::Range2d {
461            start: 2,
462            size1: 1,
463            size2: 4,
464            step1: 1,
465            step2: 3,
466        };
467        let expected = vec![2, 5, 8, 11];
468        assert_eq!(range.get_indices(), expected);
469    }
470
471    #[test]
472    fn test_slice_match_largest_slice() {
473        fn match_largest_slice(indices: &[u32]) -> SliceEnum {
474            SliceEnum::match_largest_slice(
475                indices[0],
476                &indices
477                    .windows(2)
478                    .map(|w| w[1] as i64 - w[0] as i64)
479                    .collect::<Vec<_>>(),
480            )
481        }
482
483        //// Full match
484        // single point slices
485        let indices = vec![0];
486        let slice = match_largest_slice(&indices);
487        assert_eq!(slice.get_indices(), indices);
488
489        let indices = vec![3];
490        let slice = match_largest_slice(&indices);
491        assert_eq!(slice.get_indices(), indices);
492
493        // 1d slices
494        let indices = vec![0, 1, 2, 3, 4];
495        let slice = match_largest_slice(&indices);
496        assert_eq!(slice.get_indices(), indices);
497
498        let indices = vec![5, 7, 9, 11, 13];
499        let slice = match_largest_slice(&indices);
500        assert_eq!(slice.get_indices(), indices);
501
502        let indices = vec![5, 6];
503        let slice = match_largest_slice(&indices);
504        assert_eq!(slice.get_indices(), indices);
505
506        let indices = vec![5, 2];
507        let slice = match_largest_slice(&indices);
508        assert_eq!(slice.get_indices(), indices[..2].to_vec());
509
510        // 2d slices
511        let indices = vec![0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17]; // A[0..4][0..3] in a 4x5 matrix (row-major order)
512        let slice = match_largest_slice(&indices);
513        assert_eq!(slice.get_indices(), indices);
514
515        let indices = vec![2, 3, 4, 7, 8, 9]; // A[0..2][2..5] in a 2x5 matrix (row-major order)
516        let slice = match_largest_slice(&indices);
517        assert_eq!(slice.get_indices(), indices);
518
519        let indices = vec![0, 2, 8, 10]; // A[(0..3).step_by(2)][(0..4).step_by(2)] in a 3x4 matrix (row-major order)
520        let slice = match_largest_slice(&indices);
521        assert_eq!(slice.get_indices(), indices);
522
523        let indices = vec![10, 12, 5, 7, 0, 2]; // A[(0..3).reverse()][(0..3).step_by(2)] in a 3x5 matrix (row-major order)
524        let slice = match_largest_slice(&indices);
525        assert_eq!(slice.get_indices(), indices.to_vec());
526
527        //// Partial matches
528        // 1d slices
529        let indices = vec![0, 2, 4, 4, 5];
530        let slice = match_largest_slice(&indices);
531        assert_eq!(slice.get_indices(), indices[..3].to_vec());
532
533        // 2d slices
534        let indices = vec![0, 1, 3, 4, 5];
535        let slice = match_largest_slice(&indices);
536        assert_eq!(slice.get_indices(), indices[..4].to_vec());
537
538        let indices = vec![10, 12, 5, 7, 0, 2, 1];
539        let slice = match_largest_slice(&indices);
540        assert_eq!(slice.get_indices(), indices[..6].to_vec());
541
542        // Special cases
543        let indices = vec![1, 1, 0, 0, 1, 1, 0, 0];
544        let slice = match_largest_slice(&indices);
545        assert_eq!(slice.get_indices(), indices[..4].to_vec());
546    }
547
548    #[test]
549    fn test_slice_optimize() {
550        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
551        let slice = Slice::from_indices(indices.clone());
552        assert_eq!(slice.get_indices(), indices);
553        assert_eq!(
554            slice.0,
555            SliceEnum::Range {
556                start: 0,
557                size: 12,
558                step: 1,
559            }
560        );
561
562        let indices = vec![19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
563        let slice = Slice::from_indices(indices.clone());
564        assert_eq!(slice.get_indices(), indices);
565        assert_eq!(
566            slice.0,
567            SliceEnum::RangeVec(vec![
568                SliceEnum::Single(19),
569                SliceEnum::Range {
570                    start: 3,
571                    size: 9,
572                    step: 1
573                }
574            ])
575        );
576
577        let indices = vec![0, 1, 2, 19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
578        let slice = Slice::from_indices(indices.clone());
579        assert_eq!(slice.get_indices(), indices);
580        assert_eq!(
581            slice.0,
582            SliceEnum::RangeVec(vec![
583                SliceEnum::Range {
584                    start: 0,
585                    size: 3,
586                    step: 1
587                },
588                SliceEnum::Single(19),
589                SliceEnum::Range {
590                    start: 3,
591                    size: 9,
592                    step: 1
593                }
594            ])
595        );
596
597        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19];
598        let slice = Slice::from_indices(indices.clone());
599        assert_eq!(slice.get_indices(), indices);
600        assert_eq!(
601            slice.0,
602            SliceEnum::RangeVec(vec![
603                SliceEnum::Range {
604                    start: 0,
605                    size: 10,
606                    step: 1
607                },
608                SliceEnum::Single(19),
609            ])
610        );
611
612        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 10, 11];
613        let slice = Slice::from_indices(indices.clone());
614        assert_eq!(slice.get_indices(), indices);
615        assert_eq!(
616            slice.0,
617            SliceEnum::RangeVec(vec![
618                SliceEnum::Range {
619                    start: 0,
620                    size: 10,
621                    step: 1
622                },
623                SliceEnum::Range {
624                    start: 19,
625                    size: 2,
626                    step: -9
627                },
628                SliceEnum::Single(11),
629            ])
630        );
631
632        // Large example, 4000 indices
633        let mut indices = Vec::new();
634        for _i in 0..1000 {
635            indices.extend(vec![0, 1, 1, 0]);
636        }
637        let slice = Slice::from_indices(indices.clone());
638        assert_eq!(slice.get_indices(), indices);
639    }
640
641    #[test]
642    fn test_slice_checked_range_bounds() {
643        assert_eq!(Slice::range(0, 2, -1), Err(SliceError::NegativeIndex(-1)));
644        assert_eq!(
645            Slice::range(u32::MAX, 2, 1),
646            Err(SliceError::IndexOutOfBounds {
647                found: i128::from(u32::MAX) + 1,
648                max: u32::MAX
649            })
650        );
651
652        let slice = Slice::range(u32::MAX - 1, 2, 1).unwrap();
653        assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
654    }
655
656    #[test]
657    fn test_slice_checked_range2d_bounds() {
658        assert_eq!(
659            Slice::range2d(0, 2, 2, -1, 0),
660            Err(SliceError::NegativeIndex(-1))
661        );
662        assert_eq!(
663            Slice::range2d(u32::MAX, 2, 1, 1, 0),
664            Err(SliceError::IndexOutOfBounds {
665                found: i128::from(u32::MAX) + 1,
666                max: u32::MAX
667            })
668        );
669
670        let slice = Slice::range2d(u32::MAX - 1, 1, 2, 1, 1).unwrap();
671        assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
672    }
673}