1use itertools::izip;
2use serde::{Deserialize, Serialize};
3use wincode::{SchemaRead, SchemaWrite};
4
5use crate::circuit::errors::SliceError;
6
/// A selection of `u32` indices stored in a compact, structured form.
///
/// This is the public wrapper around the private `SliceEnum` representation;
/// every method on `Slice` forwards to it. Values are created through
/// `empty`, `single`, `range`, `range2d` or `from_indices`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
)]
pub struct Slice(SliceEnum);
23
/// Internal representation behind `Slice`.
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order and
/// fields in field order, so reordering anything here changes comparisons.
/// NOTE(review): the serialized layout presumably depends on this order as
/// well — confirm before reordering.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
)]
#[repr(C)]
enum SliceEnum {
    /// Exactly one index.
    Single(u32),
    /// `size` indices: `start + step * i` for `i` in `0..size`.
    Range { start: u32, size: u32, step: i64 },
    /// `size1 * size2` indices: `start + step1 * i + step2 * j`, with `i` in
    /// `0..size1` as the outer loop and `j` in `0..size2` as the inner loop.
    Range2d {
        start: u32,
        size1: u32,
        step1: i64,
        size2: u32,
        step2: i64,
    },
    /// Concatenation of the contained slices, in order. An empty vector is
    /// the empty slice.
    RangeVec(Vec<SliceEnum>),
}
60
61impl Slice {
62 pub fn empty() -> Self {
63 Self(SliceEnum::RangeVec(vec![]))
64 }
65
66 pub fn single(index: u32) -> Self {
67 Self(SliceEnum::Single(index))
68 }
69
70 pub fn range(start: u32, size: u32, step: i64) -> Result<Self, SliceError> {
71 validate_range_bounds(start, size, step)?;
72 Ok(Self(SliceEnum::Range { start, size, step }))
73 }
74
75 pub fn shift_start(&mut self, delta: u32) {
76 self.0.shift_start(delta);
77 }
78
79 pub fn range2d(
80 start: u32,
81 size1: u32,
82 size2: u32,
83 step1: i64,
84 step2: i64,
85 ) -> Result<Self, SliceError> {
86 validate_range_2d_bounds(start, size1, step1, size2, step2)?;
87 Ok(Self(SliceEnum::Range2d {
88 start,
89 size1,
90 step1,
91 size2,
92 step2,
93 }))
94 }
95
96 pub fn append(&mut self, other: Self) {
97 match (&mut self.0, other.0) {
98 (SliceEnum::RangeVec(v), SliceEnum::RangeVec(v1)) => v.extend(v1),
99 (SliceEnum::RangeVec(v), slice) => v.push(slice),
100 (slice, SliceEnum::RangeVec(mut v1)) => {
101 v1.insert(0, slice.clone());
102 *slice = SliceEnum::RangeVec(v1);
103 }
104 (slice, slice1) => *slice = SliceEnum::RangeVec(vec![slice.clone(), slice1]),
105 }
106 }
107
108 pub fn get_indices(&self) -> Vec<u32> {
109 self.0.get_indices()
110 }
111
112 pub fn is_empty(&self) -> bool {
113 self.len() == 0
114 }
115
116 pub fn len(&self) -> u32 {
117 self.0.len()
118 }
119
120 pub fn from_indices(indices: Vec<u32>) -> Self {
121 Self(SliceEnum::from_indices(indices))
122 }
123
124 pub fn optimize(self) -> Self {
125 Self::from_indices(self.get_indices())
126 }
127}
128
129fn validate_bounds(min_index: i128, max_index: i128) -> Result<(), SliceError> {
130 if min_index < 0 {
131 return Err(SliceError::NegativeIndex(min_index));
132 }
133 if max_index > i128::from(u32::MAX) {
134 return Err(SliceError::IndexOutOfBounds {
135 found: max_index,
136 max: u32::MAX,
137 });
138 }
139 Ok(())
140}
141
/// Returns the `i`-th index of a strided range, `start + step * i`.
///
/// The arithmetic is done in `i128` so it cannot overflow for any inputs.
#[inline]
fn range_index(start: u32, step: i64, i: i64) -> i128 {
    let offset = i128::from(i) * i128::from(step);
    offset + i128::from(start)
}
146
/// Returns the index at position `(i, j)` of a two-dimensional range,
/// `start + step1 * i + step2 * j`, computed in `i128` to avoid overflow.
#[inline]
fn range_2d_index(start: u32, step1: i64, i: i64, step2: i64, j: i64) -> i128 {
    let outer = i128::from(i) * i128::from(step1);
    let inner = i128::from(j) * i128::from(step2);
    i128::from(start) + outer + inner
}
151
152fn validate_range_bounds(start: u32, size: u32, step: i64) -> Result<(), SliceError> {
153 if size == 0 {
154 return Ok(());
155 }
156 let last = i64::from(size - 1);
157 let first = i128::from(start);
158 let end = range_index(start, step, last);
159 validate_bounds(first.min(end), first.max(end))
160}
161
162fn validate_range_2d_bounds(
163 start: u32,
164 size1: u32,
165 step1: i64,
166 size2: u32,
167 step2: i64,
168) -> Result<(), SliceError> {
169 if size1 == 0 || size2 == 0 {
170 return Ok(());
171 }
172 let i_last = i64::from(size1 - 1);
173 let j_last = i64::from(size2 - 1);
174
175 let corners = [
176 range_2d_index(start, step1, 0, step2, 0),
177 range_2d_index(start, step1, i_last, step2, 0),
178 range_2d_index(start, step1, 0, step2, j_last),
179 range_2d_index(start, step1, i_last, step2, j_last),
180 ];
181
182 let min_index = corners.into_iter().min().unwrap_or(0);
183 let max_index = corners.into_iter().max().unwrap_or(0);
184 validate_bounds(min_index, max_index)
185}
186
/// Narrows a computed index to `u32`.
///
/// # Panics
/// Panics if `index` is negative or larger than `u32::MAX`.
#[inline]
fn to_u32_index(index: i128) -> u32 {
    match u32::try_from(index) {
        Ok(narrowed) => narrowed,
        Err(_) => panic!("slice index out of bounds: {index}"),
    }
}
191
192fn generate_range_indices(start: u32, size: u32, step: i64) -> impl Iterator<Item = u32> {
193 (0..i64::from(size)).map(move |i| to_u32_index(range_index(start, step, i)))
194}
195
196fn generate_range_2d_indices(
197 start: u32,
198 size1: u32,
199 step1: i64,
200 size2: u32,
201 step2: i64,
202) -> impl Iterator<Item = u32> {
203 (0..i64::from(size1)).flat_map(move |i| {
204 (0..i64::from(size2)).map(move |j| to_u32_index(range_2d_index(start, step1, i, step2, j)))
205 })
206}
207
208impl SliceEnum {
209 fn get_indices(&self) -> Vec<u32> {
210 match self {
211 SliceEnum::Single(idx) => vec![*idx],
212 SliceEnum::Range { start, size, step } => {
213 generate_range_indices(*start, *size, *step).collect()
214 }
215 SliceEnum::Range2d {
216 start,
217 size1,
218 size2,
219 step1,
220 step2,
221 } => generate_range_2d_indices(*start, *size1, *step1, *size2, *step2).collect(),
222 SliceEnum::RangeVec(v) => v.iter().flat_map(|r| r.get_indices()).collect(),
223 }
224 }
225
226 pub fn len(&self) -> u32 {
227 match self {
228 SliceEnum::Single(_) => 1,
229 SliceEnum::Range { size, .. } => *size,
230 SliceEnum::Range2d { size1, size2, .. } => size1
231 .checked_mul(*size2)
232 .expect("slice length overflow for range2d"),
233 SliceEnum::RangeVec(v) => v.iter().fold(0u32, |acc, r| {
234 acc.checked_add(r.len())
235 .expect("slice length overflow for range vector")
236 }),
237 }
238 }
239
240 fn match_largest_slice(start: u32, deltas: &[i64]) -> Self {
244 if deltas.is_empty() {
245 return Self::Single(start);
246 }
247
248 let step_j = deltas[0];
251 let n_j = deltas.iter().skip(1).take_while(|&&d| d == step_j).count() + 2;
252
253 let mut res_slice = Self::Range {
254 start,
255 size: n_j as u32,
256 step: step_j,
257 };
258
259 if n_j < deltas.len() + 1 {
260 let exp_chunk = &deltas[0..n_j];
264 let chunks = deltas.chunks(n_j).skip(1);
265 let mut n_i = chunks
266 .take_while(|chunk| {
267 izip!(exp_chunk, *chunk).take_while(|(e, d)| e == d).count() == n_j
268 })
269 .count()
270 + 1;
271 if let Some(chunk) = deltas.chunks(n_j).nth(n_i) {
272 if izip!(exp_chunk, chunk).take_while(|(e, d)| e == d).count() == n_j - 1 {
273 n_i += 1;
274 }
275 }
276
277 if n_i > 1 {
278 let step_i = exp_chunk.iter().sum::<i64>();
279 res_slice = Self::Range2d {
280 start,
281 size1: n_i as u32,
282 size2: n_j as u32,
283 step1: step_i,
284 step2: step_j,
285 };
286 }
287 }
288
289 res_slice
290 }
291
292 fn reduce(&mut self, max_size: u32) {
294 assert!(max_size > 0);
295 match self {
296 SliceEnum::Single(_) => {}
297 SliceEnum::Range { start, size, .. } => {
298 if max_size < *size {
299 if max_size == 1 {
300 *self = SliceEnum::Single(*start);
301 } else {
302 *size = max_size;
303 }
304 }
305 }
306 SliceEnum::Range2d {
307 start,
308 size1,
309 size2,
310 step2,
311 ..
312 } => {
313 if max_size < *size1 * *size2 {
314 if max_size == 1 {
315 *self = SliceEnum::Single(*start);
316 } else if max_size <= *size2 {
317 *self = SliceEnum::Range {
318 start: *start,
319 size: max_size,
320 step: *step2,
321 }
322 } else if max_size / *size2 == 1 {
323 *self = SliceEnum::Range {
324 start: *start,
325 size: *size2,
326 step: *step2,
327 }
328 } else {
329 *size1 = max_size / *size2;
330 }
331 }
332 }
333 SliceEnum::RangeVec(_) => {}
334 }
335 }
336
337 fn match_slices(mut max_len_slices: Vec<Self>) -> Vec<Self> {
338 let mut res = vec![]; let mut ranges_to_visit = vec![(0, max_len_slices.len())]; while let Some((start, end)) = ranges_to_visit.pop() {
341 let (slice_pos, slice) = max_len_slices[start..end]
344 .iter()
345 .enumerate()
346 .max_by_key(|(pos, slice)| (slice.len(), end - pos)) .unwrap();
348 let slice_start = start + slice_pos; let slice_end = slice_start + slice.len() as usize;
350
351 res.push((slice_start, slice.clone()));
353
354 if start < slice_start {
356 max_len_slices[start..slice_start]
359 .iter_mut()
360 .enumerate()
361 .for_each(|(pos, slice)| slice.reduce((slice_pos - pos) as u32));
362
363 ranges_to_visit.push((start, slice_start));
364 }
365 if slice_end < end {
366 ranges_to_visit.push((slice_end, end));
367 }
368 }
369
370 res.sort_by_key(|(start, _)| *start);
371 res.into_iter().map(|(_, slice)| slice).collect()
372 }
373
374 pub fn from_indices(indices: Vec<u32>) -> Self {
381 if indices.is_empty() {
382 return Self::RangeVec(vec![]);
383 }
384
385 let deltas = indices
386 .windows(2)
387 .map(|w| w[1] as i64 - w[0] as i64)
388 .collect::<Vec<_>>();
389 let max_slice_vec: Vec<_> = (0..indices.len())
390 .map(|i| Self::match_largest_slice(indices[i], &deltas[i..]))
391 .collect();
392
393 let optimized_slices = SliceEnum::match_slices(max_slice_vec);
394 if optimized_slices.len() == 1 {
395 optimized_slices[0].clone()
396 } else {
397 SliceEnum::RangeVec(optimized_slices)
398 }
399 }
400
401 pub fn shift_start(&mut self, delta: u32) {
402 match self {
403 SliceEnum::Single(idx) => {
404 *idx = idx
405 .checked_add(delta)
406 .expect("slice start overflow for single index");
407 }
408 SliceEnum::Range { start, .. } => {
409 *start = start
410 .checked_add(delta)
411 .expect("slice start overflow for range");
412 }
413 SliceEnum::Range2d { start, .. } => {
414 *start = start
415 .checked_add(delta)
416 .expect("slice start overflow for range2d");
417 }
418 SliceEnum::RangeVec(v) => v.iter_mut().for_each(|slice| slice.shift_start(delta)),
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::SliceEnum;
    use crate::circuit::{errors::SliceError, Slice};

    /// Shorthand for a `SliceEnum::Range`.
    fn run(start: u32, size: u32, step: i64) -> SliceEnum {
        SliceEnum::Range { start, size, step }
    }

    /// Shorthand for a `SliceEnum::Range2d`.
    fn grid(start: u32, size1: u32, size2: u32, step1: i64, step2: i64) -> SliceEnum {
        SliceEnum::Range2d {
            start,
            size1,
            size2,
            step1,
            step2,
        }
    }

    /// Runs `match_largest_slice` on the first index and the deltas of `indices`.
    fn largest_from(indices: &[u32]) -> SliceEnum {
        let deltas: Vec<i64> = indices
            .windows(2)
            .map(|pair| pair[1] as i64 - pair[0] as i64)
            .collect();
        SliceEnum::match_largest_slice(indices[0], &deltas)
    }

    /// Checks that `indices` round-trip and are encoded exactly as `expected`.
    fn assert_optimizes_to(indices: Vec<u32>, expected: SliceEnum) {
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(slice.0, expected);
    }

    /// The error reported for an index one past `u32::MAX`.
    fn one_past_u32_max() -> SliceError {
        SliceError::IndexOutOfBounds {
            found: i128::from(u32::MAX) + 1,
            max: u32::MAX,
        }
    }

    #[test]
    fn test_slice_range() {
        let cases: Vec<(SliceEnum, Vec<u32>)> = vec![
            (grid(0, 2, 3, 6, 1), vec![0, 1, 2, 6, 7, 8]),
            (grid(0, 4, 2, 3, 1), vec![0, 1, 3, 4, 6, 7, 9, 10]),
            (grid(0, 4, 2, 3, 2), vec![0, 2, 3, 5, 6, 8, 9, 11]),
            (grid(2, 1, 4, 1, 3), vec![2, 5, 8, 11]),
        ];

        for (range, expected) in &cases {
            assert_eq!(&range.get_indices(), expected);
        }
    }

    #[test]
    fn test_slice_match_largest_slice() {
        // (input indices, length of the prefix the largest slice must cover)
        let cases: Vec<(Vec<u32>, usize)> = vec![
            // Single elements.
            (vec![0], 1),
            (vec![3], 1),
            // One-dimensional runs covering the whole input.
            (vec![0, 1, 2, 3, 4], 5),
            (vec![5, 7, 9, 11, 13], 5),
            (vec![5, 6], 2),
            (vec![5, 2], 2),
            // Two-dimensional patterns covering the whole input.
            (vec![0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17], 12),
            (vec![2, 3, 4, 7, 8, 9], 6),
            (vec![0, 2, 8, 10], 4),
            (vec![10, 12, 5, 7, 0, 2], 6),
            // Patterns that stop before the end of the input.
            (vec![0, 2, 4, 4, 5], 3),
            (vec![0, 1, 3, 4, 5], 4),
            (vec![10, 12, 5, 7, 0, 2, 1], 6),
            (vec![1, 1, 0, 0, 1, 1, 0, 0], 4),
        ];

        for (indices, covered) in &cases {
            let slice = largest_from(indices);
            assert_eq!(
                slice.get_indices(),
                indices[..*covered].to_vec(),
                "input: {:?}",
                indices
            );
        }
    }

    #[test]
    fn test_slice_optimize() {
        assert_optimizes_to(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], run(0, 12, 1));

        assert_optimizes_to(
            vec![19, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            SliceEnum::RangeVec(vec![SliceEnum::Single(19), run(3, 9, 1)]),
        );

        assert_optimizes_to(
            vec![0, 1, 2, 19, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            SliceEnum::RangeVec(vec![run(0, 3, 1), SliceEnum::Single(19), run(3, 9, 1)]),
        );

        assert_optimizes_to(
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19],
            SliceEnum::RangeVec(vec![run(0, 10, 1), SliceEnum::Single(19)]),
        );

        assert_optimizes_to(
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 10, 11],
            SliceEnum::RangeVec(vec![run(0, 10, 1), run(19, 2, -9), SliceEnum::Single(11)]),
        );

        // A long repetitive input only has to round-trip.
        let indices: Vec<u32> = [0u32, 1, 1, 0]
            .iter()
            .copied()
            .cycle()
            .take(4 * 1000)
            .collect();
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
    }

    #[test]
    fn test_slice_checked_range_bounds() {
        assert_eq!(
            Slice::range(0, 2, -1).unwrap_err(),
            SliceError::NegativeIndex(-1)
        );
        assert_eq!(
            Slice::range(u32::MAX, 2, 1).unwrap_err(),
            one_past_u32_max()
        );

        let slice = Slice::range(u32::MAX - 1, 2, 1).expect("range ending at u32::MAX is valid");
        assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
    }

    #[test]
    fn test_slice_checked_range2d_bounds() {
        assert_eq!(
            Slice::range2d(0, 2, 2, -1, 0).unwrap_err(),
            SliceError::NegativeIndex(-1)
        );
        assert_eq!(
            Slice::range2d(u32::MAX, 2, 1, 1, 0).unwrap_err(),
            one_past_u32_max()
        );

        let slice = Slice::range2d(u32::MAX - 1, 1, 2, 1, 1)
            .expect("range2d ending at u32::MAX is valid");
        assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
    }
}