1use itertools::izip;
2use serde::{Deserialize, Serialize};
3use wincode::{SchemaRead, SchemaWrite};
4
5use crate::circuit::errors::SliceError;
6
/// An ordered sequence of `u32` indices, stored in a compressed form.
///
/// Outside this module the wrapped representation (`SliceEnum`) is private. Slices are built
/// with the constructors in `impl Slice` (`empty`, `single`, `range`, `range2d`,
/// `from_indices`) and combined with `append`.
///
/// NOTE(review): `PartialEq`, `Ord` and `Hash` are derived from the representation. Two slices
/// that select the same indices with different shapes therefore compare unequal, e.g.
/// `single(3)` versus `empty()` followed by `append(single(3))`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
)]
pub struct Slice(SliceEnum);
23
/// Internal representation of a `Slice`: one of several compressed index patterns.
///
/// NOTE(review): the derived `PartialOrd`/`Ord` compare by variant order first, then fields in
/// declaration order. The serde/wincode derives presumably encode variants and fields in that
/// same order too. Reordering variants or fields is therefore likely a format-breaking change —
/// confirm before editing.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
)]
#[repr(C)]
enum SliceEnum {
    /// Exactly one index.
    Single(u32),
    /// The `size` indices `start + step * i` for `i` in `0..size` (`step` may be negative).
    Range { start: u32, size: u32, step: i32 },
    /// The `size1 * size2` indices `start + step1 * i + step2 * j`.
    ///
    /// Iteration is row-major: `i` in `0..size1` is the outer loop and `j` in `0..size2` is
    /// the inner loop.
    Range2d {
        start: u32,
        size1: u32,
        step1: i32,
        size2: u32,
        step2: i32,
    },
    /// The concatenation of the listed slices, in order (empty for `Slice::empty()`).
    RangeVec(Vec<SliceEnum>),
}
60
61impl Slice {
62 pub fn empty() -> Self {
63 Self(SliceEnum::RangeVec(vec![]))
64 }
65
66 pub fn single(index: u32) -> Self {
67 Self(SliceEnum::Single(index))
68 }
69
70 pub fn range(start: u32, size: u32, step: i32) -> Result<Self, SliceError> {
71 let min_index = generate_range_indices(start, size, step).min().unwrap_or(0);
72 if min_index < 0 {
73 return Err(SliceError::NegativeIndex(min_index));
74 }
75
76 Ok(Self(SliceEnum::Range { start, size, step }))
77 }
78
79 pub fn shift_start(&mut self, delta: u32) {
80 self.0.shift_start(delta);
81 }
82
83 pub fn range2d(
84 start: u32,
85 size1: u32,
86 size2: u32,
87 step1: i32,
88 step2: i32,
89 ) -> Result<Self, SliceError> {
90 let min_index = generate_range_2d_indices(start, size1, step1, size2, step2)
91 .min()
92 .unwrap_or(0);
93 if min_index < 0 {
94 return Err(SliceError::NegativeIndex(min_index));
95 }
96
97 Ok(Self(SliceEnum::Range2d {
98 start,
99 size1,
100 step1,
101 size2,
102 step2,
103 }))
104 }
105
106 pub fn append(&mut self, other: Self) {
107 match (&mut self.0, other.0) {
108 (SliceEnum::RangeVec(v), SliceEnum::RangeVec(v1)) => v.extend(v1),
109 (SliceEnum::RangeVec(v), slice) => v.push(slice),
110 (slice, SliceEnum::RangeVec(mut v1)) => {
111 v1.insert(0, slice.clone());
112 *slice = SliceEnum::RangeVec(v1);
113 }
114 (slice, slice1) => *slice = SliceEnum::RangeVec(vec![slice.clone(), slice1]),
115 }
116 }
117
118 pub fn get_indices(&self) -> Vec<u32> {
119 self.0.get_indices()
120 }
121
122 pub fn is_empty(&self) -> bool {
123 self.len() == 0
124 }
125
126 pub fn len(&self) -> u32 {
127 self.0.len()
128 }
129
130 pub fn from_indices(indices: Vec<u32>) -> Self {
131 Self(SliceEnum::from_indices(indices))
132 }
133
134 pub fn optimize(self) -> Self {
135 Self::from_indices(self.get_indices())
136 }
137}
138
/// Lazily yields `start + i * step` for `i` in `0..size`, computed in `i32` arithmetic.
///
/// `start` and `size` are reinterpreted with `as i32`, so a `size` above `i32::MAX` yields
/// nothing.
fn generate_range_indices(start: u32, size: u32, step: i32) -> impl Iterator<Item = i32> {
    let first = start as i32;
    let len = size as i32;
    (0..len).map(move |offset| first + offset * step)
}
142
/// Lazily yields `start + step1 * i + step2 * j` in row-major order.
///
/// `i` runs over `0..size1` (outer loop) and `j` over `0..size2` (inner loop), computed in
/// `i32` arithmetic.
fn generate_range_2d_indices(
    start: u32,
    size1: u32,
    step1: i32,
    size2: u32,
    step2: i32,
) -> impl Iterator<Item = i32> {
    let base = start as i32;
    let (rows, cols) = (size1 as i32, size2 as i32);
    (0..rows)
        // Pair each row number with every column number of that row.
        .flat_map(move |row| std::iter::repeat(row).zip(0..cols))
        .map(move |(row, col)| base + step1 * row + step2 * col)
}
153
154impl SliceEnum {
155 fn get_indices(&self) -> Vec<u32> {
156 match self {
157 SliceEnum::Single(idx) => vec![*idx],
158 SliceEnum::Range { start, size, step } => generate_range_indices(*start, *size, *step)
159 .map(|i| i as u32)
160 .collect(),
161 SliceEnum::Range2d {
162 start,
163 size1,
164 size2,
165 step1,
166 step2,
167 } => generate_range_2d_indices(*start, *size1, *step1, *size2, *step2)
168 .map(|i| i as u32)
169 .collect(),
170 SliceEnum::RangeVec(v) => v.iter().flat_map(|r| r.get_indices()).collect(),
171 }
172 }
173
174 pub fn len(&self) -> u32 {
175 match self {
176 SliceEnum::Single(_) => 1,
177 SliceEnum::Range { size, .. } => *size,
178 SliceEnum::Range2d { size1, size2, .. } => size1 * size2,
179 SliceEnum::RangeVec(v) => v.iter().map(|r| r.len()).sum(),
180 }
181 }
182
183 fn match_largest_slice(start: u32, deltas: &[i32]) -> Self {
187 if deltas.is_empty() {
188 return Self::Single(start);
189 }
190
191 let step_j = deltas[0];
194 let n_j = deltas.iter().skip(1).take_while(|&&d| d == step_j).count() + 2;
195
196 let mut res_slice = Self::Range {
197 start,
198 size: n_j as u32,
199 step: step_j,
200 };
201
202 if n_j < deltas.len() + 1 {
203 let exp_chunk = &deltas[0..n_j];
207 let chunks = deltas.chunks(n_j).skip(1);
208 let mut n_i = chunks
209 .take_while(|chunk| {
210 izip!(exp_chunk, *chunk).take_while(|(e, d)| e == d).count() == n_j
211 })
212 .count()
213 + 1;
214 if let Some(chunk) = deltas.chunks(n_j).nth(n_i) {
215 if izip!(exp_chunk, chunk).take_while(|(e, d)| e == d).count() == n_j - 1 {
216 n_i += 1;
217 }
218 }
219
220 if n_i > 1 {
221 let step_i = exp_chunk.iter().sum::<i32>();
222 res_slice = Self::Range2d {
223 start,
224 size1: n_i as u32,
225 size2: n_j as u32,
226 step1: step_i,
227 step2: step_j,
228 };
229 }
230 }
231
232 res_slice
233 }
234
235 fn reduce(&mut self, max_size: u32) {
237 assert!(max_size > 0);
238 match self {
239 SliceEnum::Single(_) => {}
240 SliceEnum::Range { start, size, .. } => {
241 if max_size < *size {
242 if max_size == 1 {
243 *self = SliceEnum::Single(*start);
244 } else {
245 *size = max_size;
246 }
247 }
248 }
249 SliceEnum::Range2d {
250 start,
251 size1,
252 size2,
253 step2,
254 ..
255 } => {
256 if max_size < *size1 * *size2 {
257 if max_size == 1 {
258 *self = SliceEnum::Single(*start);
259 } else if max_size <= *size2 {
260 *self = SliceEnum::Range {
261 start: *start,
262 size: max_size,
263 step: *step2,
264 }
265 } else if max_size / *size2 == 1 {
266 *self = SliceEnum::Range {
267 start: *start,
268 size: *size2,
269 step: *step2,
270 }
271 } else {
272 *size1 = max_size / *size2;
273 }
274 }
275 }
276 SliceEnum::RangeVec(_) => {}
277 }
278 }
279
280 fn match_slices(mut max_len_slices: Vec<Self>) -> Vec<Self> {
281 let mut res = vec![]; let mut ranges_to_visit = vec![(0, max_len_slices.len())]; while let Some((start, end)) = ranges_to_visit.pop() {
284 let (slice_pos, slice) = max_len_slices[start..end]
287 .iter()
288 .enumerate()
289 .max_by_key(|(pos, slice)| (slice.len(), end - pos)) .unwrap();
291 let slice_start = start + slice_pos; let slice_end = slice_start + slice.len() as usize;
293
294 res.push((slice_start, slice.clone()));
296
297 if start < slice_start {
299 max_len_slices[start..slice_start]
302 .iter_mut()
303 .enumerate()
304 .for_each(|(pos, slice)| slice.reduce((slice_pos - pos) as u32));
305
306 ranges_to_visit.push((start, slice_start));
307 }
308 if slice_end < end {
309 ranges_to_visit.push((slice_end, end));
310 }
311 }
312
313 res.sort_by_key(|(start, _)| *start);
314 res.into_iter().map(|(_, slice)| slice).collect()
315 }
316
317 pub fn from_indices(indices: Vec<u32>) -> Self {
324 if indices.is_empty() {
325 return Self::RangeVec(vec![]);
326 }
327
328 let deltas = indices
329 .windows(2)
330 .map(|w| w[1] as i32 - w[0] as i32)
331 .collect::<Vec<_>>();
332 let max_slice_vec: Vec<_> = (0..indices.len())
333 .map(|i| Self::match_largest_slice(indices[i], &deltas[i..]))
334 .collect();
335
336 let optimized_slices = SliceEnum::match_slices(max_slice_vec);
337 if optimized_slices.len() == 1 {
338 optimized_slices[0].clone()
339 } else {
340 SliceEnum::RangeVec(optimized_slices)
341 }
342 }
343
344 pub fn shift_start(&mut self, delta: u32) {
345 match self {
346 SliceEnum::Single(idx) => {
347 *idx += delta;
348 }
349 SliceEnum::Range { start, .. } => {
350 *start += delta;
351 }
352 SliceEnum::Range2d { start, .. } => {
353 *start += delta;
354 }
355 SliceEnum::RangeVec(v) => v.iter_mut().for_each(|slice| slice.shift_start(delta)),
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::SliceEnum;
    use crate::circuit::errors::SliceError;
    use crate::circuit::Slice;

    #[test]
    fn test_slice_range() {
        let range = SliceEnum::Range2d {
            start: 0,
            size1: 2,
            size2: 3,
            step1: 6,
            step2: 1,
        };
        let expected = vec![0, 1, 2, 6, 7, 8];
        assert_eq!(range.get_indices(), expected);

        let range = SliceEnum::Range2d {
            start: 0,
            size1: 4,
            size2: 2,
            step1: 3,
            step2: 1,
        };
        let expected = vec![0, 1, 3, 4, 6, 7, 9, 10];
        assert_eq!(range.get_indices(), expected);

        let range = SliceEnum::Range2d {
            start: 0,
            size1: 4,
            size2: 2,
            step1: 3,
            step2: 2,
        };
        let expected = vec![0, 2, 3, 5, 6, 8, 9, 11];
        assert_eq!(range.get_indices(), expected);

        let range = SliceEnum::Range2d {
            start: 2,
            size1: 1,
            size2: 4,
            step1: 1,
            step2: 3,
        };
        let expected = vec![2, 5, 8, 11];
        assert_eq!(range.get_indices(), expected);
    }

    #[test]
    fn test_slice_match_largest_slice() {
        fn match_largest_slice(indices: &[u32]) -> SliceEnum {
            SliceEnum::match_largest_slice(
                indices[0],
                &indices
                    .windows(2)
                    .map(|w| w[1] as i32 - w[0] as i32)
                    .collect::<Vec<_>>(),
            )
        }

        let indices = vec![0];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![3];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![0, 1, 2, 3, 4];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![5, 7, 9, 11, 13];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![5, 6];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![5, 2];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices[..2].to_vec());

        let indices = vec![0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![2, 3, 4, 7, 8, 9];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![0, 2, 8, 10];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices);

        let indices = vec![10, 12, 5, 7, 0, 2];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices.to_vec());

        let indices = vec![0, 2, 4, 4, 5];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices[..3].to_vec());

        let indices = vec![0, 1, 3, 4, 5];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices[..4].to_vec());

        let indices = vec![10, 12, 5, 7, 0, 2, 1];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices[..6].to_vec());

        let indices = vec![1, 1, 0, 0, 1, 1, 0, 0];
        let slice = match_largest_slice(&indices);
        assert_eq!(slice.get_indices(), indices[..4].to_vec());
    }

    #[test]
    fn test_slice_optimize() {
        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(
            slice.0,
            SliceEnum::Range {
                start: 0,
                size: 12,
                step: 1,
            }
        );

        let indices = vec![19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![
                SliceEnum::Single(19),
                SliceEnum::Range {
                    start: 3,
                    size: 9,
                    step: 1
                }
            ])
        );

        let indices = vec![0, 1, 2, 19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![
                SliceEnum::Range {
                    start: 0,
                    size: 3,
                    step: 1
                },
                SliceEnum::Single(19),
                SliceEnum::Range {
                    start: 3,
                    size: 9,
                    step: 1
                }
            ])
        );

        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19];
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![
                SliceEnum::Range {
                    start: 0,
                    size: 10,
                    step: 1
                },
                SliceEnum::Single(19),
            ])
        );

        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 10, 11];
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![
                SliceEnum::Range {
                    start: 0,
                    size: 10,
                    step: 1
                },
                SliceEnum::Range {
                    start: 19,
                    size: 2,
                    step: -9
                },
                SliceEnum::Single(11),
            ])
        );

        let mut indices = Vec::new();
        for _i in 0..1000 {
            indices.extend(vec![0, 1, 1, 0]);
        }
        let slice = Slice::from_indices(indices.clone());
        assert_eq!(slice.get_indices(), indices);
    }

    #[test]
    fn test_slice_range_validation() {
        // 5, 3, 1: every index is non-negative, so the range is accepted.
        let slice = match Slice::range(5, 3, -2) {
            Ok(slice) => slice,
            Err(_) => panic!("range 5, 3, 1 should be accepted"),
        };
        assert_eq!(slice.get_indices(), vec![5, 3, 1]);
        assert_eq!(slice.len(), 3);

        // 5, 2, -1: the last index would be negative.
        assert!(matches!(
            Slice::range(5, 3, -3),
            Err(SliceError::NegativeIndex(-1))
        ));

        // An empty range has no indices to reject.
        assert!(Slice::range(0, 0, -1).is_ok());

        // Rows [10, 12] and [5, 7] (sizes come before steps in `range2d`).
        let slice = match Slice::range2d(10, 2, 2, -5, 2) {
            Ok(slice) => slice,
            Err(_) => panic!("range2d 10, 12, 5, 7 should be accepted"),
        };
        assert_eq!(slice.get_indices(), vec![10, 12, 5, 7]);

        // Rows [4, 5] and [-1, 0]: the second row starts below zero.
        assert!(matches!(
            Slice::range2d(4, 2, 2, -5, 1),
            Err(SliceError::NegativeIndex(-1))
        ));
    }

    #[test]
    fn test_slice_append() {
        let slice = Slice::from_indices(vec![]);
        assert!(slice.is_empty());
        assert_eq!(slice.0, SliceEnum::RangeVec(vec![]));

        // RangeVec self, single other: the part is pushed.
        let mut slice = Slice::empty();
        assert!(slice.is_empty());
        slice.append(Slice::single(3));
        slice.append(Slice::single(0));
        assert_eq!(slice.get_indices(), vec![3, 0]);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![SliceEnum::Single(3), SliceEnum::Single(0)])
        );

        // Single self, single other: both are wrapped, in order.
        let mut slice = Slice::single(7);
        slice.append(Slice::single(8));
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![SliceEnum::Single(7), SliceEnum::Single(8)])
        );

        // Single self, RangeVec other: self goes first, no nesting.
        let mut tail = Slice::empty();
        tail.append(Slice::single(2));
        tail.append(Slice::single(4));
        let mut slice = Slice::single(1);
        slice.append(tail);
        assert_eq!(slice.get_indices(), vec![1, 2, 4]);
        assert_eq!(
            slice.0,
            SliceEnum::RangeVec(vec![
                SliceEnum::Single(1),
                SliceEnum::Single(2),
                SliceEnum::Single(4),
            ])
        );

        // RangeVec self, RangeVec other: concatenated.
        let mut head = Slice::empty();
        head.append(Slice::single(9));
        let mut tail = Slice::empty();
        tail.append(Slice::single(5));
        tail.append(Slice::single(6));
        head.append(tail);
        assert_eq!(head.get_indices(), vec![9, 5, 6]);
        assert_eq!(head.len(), 3);
        assert!(!head.is_empty());

        // Optimizing a list of consecutive singles collapses it into one range.
        let mut slice = Slice::empty();
        for i in 0..5 {
            slice.append(Slice::single(i));
        }
        let slice = slice.optimize();
        assert_eq!(
            slice.0,
            SliceEnum::Range {
                start: 0,
                size: 5,
                step: 1
            }
        );
    }

    #[test]
    fn test_slice_shift_start() {
        let indices = vec![0, 1, 2, 19];
        let mut slice = Slice::from_indices(indices.clone());
        slice.shift_start(5);
        assert_eq!(
            slice.get_indices(),
            indices.iter().map(|i| i + 5).collect::<Vec<_>>()
        );

        let mut slice = Slice::from_indices(vec![0, 1, 2, 5, 6, 7]);
        slice.shift_start(3);
        assert_eq!(slice.get_indices(), vec![3, 4, 5, 8, 9, 10]);
    }

    #[test]
    fn test_slice_reduce() {
        let range2d = SliceEnum::Range2d {
            start: 0,
            size1: 3,
            size2: 4,
            step1: 10,
            step2: 1,
        };

        // Already fits: unchanged.
        let mut slice = range2d.clone();
        slice.reduce(12);
        assert_eq!(slice, range2d);

        // Two whole rows fit.
        let mut slice = range2d.clone();
        slice.reduce(9);
        assert_eq!(slice.get_indices(), vec![0, 1, 2, 3, 10, 11, 12, 13]);

        // Exactly one whole row fits.
        let mut slice = range2d.clone();
        slice.reduce(5);
        assert_eq!(
            slice,
            SliceEnum::Range {
                start: 0,
                size: 4,
                step: 1
            }
        );

        // Only part of the first row fits.
        let mut slice = range2d.clone();
        slice.reduce(3);
        assert_eq!(
            slice,
            SliceEnum::Range {
                start: 0,
                size: 3,
                step: 1
            }
        );

        let mut slice = range2d.clone();
        slice.reduce(1);
        assert_eq!(slice, SliceEnum::Single(0));

        let mut slice = SliceEnum::Range {
            start: 5,
            size: 4,
            step: 2,
        };
        slice.reduce(2);
        assert_eq!(slice.get_indices(), vec![5, 7]);
        slice.reduce(1);
        assert_eq!(slice, SliceEnum::Single(5));
    }
}