use itertools::izip;
use serde::{Deserialize, Serialize};
use wincode::{SchemaRead, SchemaWrite};
use crate::circuit::errors::SliceError;
/// A compact description of an ordered sequence of `u32` indices (repeated
/// indices are allowed).
///
/// Build one with [`Slice::single`], [`Slice::range`], [`Slice::range2d`] or
/// [`Slice::from_indices`], and expand it with [`Slice::get_indices`].
#[derive(
Debug,
Clone,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
)]
pub struct Slice(SliceEnum);
/// Internal representation of a [`Slice`].
///
/// NOTE(review): variant order and field order determine the derived
/// `PartialOrd`/`Ord` ordering and presumably also the encoding produced by the
/// serde and wincode derives; treat reordering as a breaking change until
/// confirmed.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
)]
#[repr(C)]
enum SliceEnum {
/// Exactly one index.
Single(u32),
/// The indices `start + k * step` for `k in 0..size`.
Range { start: u32, size: u32, step: i64 },
/// `size1 * size2` indices in row-major order: for each `i in 0..size1`,
/// then each `j in 0..size2`, the index `start + i * step1 + j * step2`.
Range2d {
start: u32,
size1: u32,
step1: i64,
size2: u32,
step2: i64,
},
/// The indices of each part, concatenated in order.
RangeVec(Vec<SliceEnum>),
}
impl Slice {
pub fn empty() -> Self {
Self(SliceEnum::RangeVec(vec![]))
}
pub fn single(index: u32) -> Self {
Self(SliceEnum::Single(index))
}
pub fn range(start: u32, size: u32, step: i64) -> Result<Self, SliceError> {
validate_range_bounds(start, size, step)?;
Ok(Self(SliceEnum::Range { start, size, step }))
}
pub fn shift_start(&mut self, delta: u32) {
self.0.shift_start(delta);
}
pub fn range2d(
start: u32,
size1: u32,
size2: u32,
step1: i64,
step2: i64,
) -> Result<Self, SliceError> {
validate_range_2d_bounds(start, size1, step1, size2, step2)?;
Ok(Self(SliceEnum::Range2d {
start,
size1,
step1,
size2,
step2,
}))
}
pub fn append(&mut self, other: Self) {
match (&mut self.0, other.0) {
(SliceEnum::RangeVec(v), SliceEnum::RangeVec(v1)) => v.extend(v1),
(SliceEnum::RangeVec(v), slice) => v.push(slice),
(slice, SliceEnum::RangeVec(mut v1)) => {
v1.insert(0, slice.clone());
*slice = SliceEnum::RangeVec(v1);
}
(slice, slice1) => *slice = SliceEnum::RangeVec(vec![slice.clone(), slice1]),
}
}
pub fn get_indices(&self) -> Vec<u32> {
self.0.get_indices()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn len(&self) -> u32 {
self.0.len()
}
pub fn from_indices(indices: Vec<u32>) -> Self {
Self(SliceEnum::from_indices(indices))
}
pub fn optimize(self) -> Self {
Self::from_indices(self.get_indices())
}
}
/// Checks that an index span `[min_index, max_index]` fits in `u32`.
///
/// # Errors
///
/// `SliceError::NegativeIndex(min_index)` when `min_index < 0` (checked first),
/// otherwise `SliceError::IndexOutOfBounds` when `max_index > u32::MAX`.
fn validate_bounds(min_index: i128, max_index: i128) -> Result<(), SliceError> {
    let ceiling = i128::from(u32::MAX);
    match (min_index < 0, max_index > ceiling) {
        (true, _) => Err(SliceError::NegativeIndex(min_index)),
        (false, true) => Err(SliceError::IndexOutOfBounds {
            found: max_index,
            max: u32::MAX,
        }),
        (false, false) => Ok(()),
    }
}
/// Returns the `i`-th element `start + step * i` of a 1-D range, computed in
/// `i128` so that results outside `u32` are representable.
#[inline]
fn range_index(start: u32, step: i64, i: i64) -> i128 {
    let (base, stride, offset) = (i128::from(start), i128::from(step), i128::from(i));
    base + stride * offset
}
/// Returns the element at row `i`, column `j` of a 2-D range,
/// `start + step1 * i + step2 * j`, computed in `i128`.
#[inline]
fn range_2d_index(start: u32, step1: i64, i: i64, step2: i64, j: i64) -> i128 {
    let row_offset = i128::from(step1) * i128::from(i);
    let col_offset = i128::from(step2) * i128::from(j);
    i128::from(start) + row_offset + col_offset
}
/// Checks that every index `start + k * step`, `k in 0..size`, fits in `u32`.
///
/// # Errors
///
/// Propagates the error from `validate_bounds` for the range's extremes.
fn validate_range_bounds(start: u32, size: u32, step: i64) -> Result<(), SliceError> {
    match size.checked_sub(1) {
        // An empty range selects nothing, so any start/step is acceptable.
        None => Ok(()),
        Some(last) => {
            // The index is monotone in k, so its extremes are the two endpoints.
            let head = i128::from(start);
            let tail = range_index(start, step, i64::from(last));
            let (lo, hi) = if head <= tail { (head, tail) } else { (tail, head) };
            validate_bounds(lo, hi)
        }
    }
}
/// Checks that every index of a row-major 2-D range fits in `u32`.
///
/// # Errors
///
/// Propagates the error from `validate_bounds` for the grid's extremes.
fn validate_range_2d_bounds(
    start: u32,
    size1: u32,
    step1: i64,
    size2: u32,
    step2: i64,
) -> Result<(), SliceError> {
    // An empty grid selects nothing, so there is nothing to check.
    let (i_last, j_last) = match (size1.checked_sub(1), size2.checked_sub(1)) {
        (Some(i), Some(j)) => (i64::from(i), i64::from(j)),
        _ => return Ok(()),
    };
    // The index is affine in (i, j), so its extremes over the grid occur at the
    // four corners.
    let corner_values = [(0, 0), (i_last, 0), (0, j_last), (i_last, j_last)]
        .map(|(i, j)| range_2d_index(start, step1, i, step2, j));
    let (lowest, highest) = corner_values
        .iter()
        .fold((corner_values[0], corner_values[0]), |(lo, hi), &value| {
            (lo.min(value), hi.max(value))
        });
    validate_bounds(lowest, highest)
}
/// Narrows a computed index to `u32`.
///
/// # Panics
///
/// Panics with `slice index out of bounds: {index}` if `index` is negative or
/// exceeds `u32::MAX`.
#[inline]
fn to_u32_index(index: i128) -> u32 {
    match u32::try_from(index) {
        Ok(narrowed) => narrowed,
        Err(_) => panic!("slice index out of bounds: {index}"),
    }
}
fn generate_range_indices(start: u32, size: u32, step: i64) -> impl Iterator<Item = u32> {
(0..i64::from(size)).map(move |i| to_u32_index(range_index(start, step, i)))
}
fn generate_range_2d_indices(
start: u32,
size1: u32,
step1: i64,
size2: u32,
step2: i64,
) -> impl Iterator<Item = u32> {
(0..i64::from(size1)).flat_map(move |i| {
(0..i64::from(size2)).map(move |j| to_u32_index(range_2d_index(start, step1, i, step2, j)))
})
}
impl SliceEnum {
    /// Materializes the selected indices, in order.
    ///
    /// # Panics
    ///
    /// Panics if a computed index does not fit in `u32`.
    fn get_indices(&self) -> Vec<u32> {
        match self {
            SliceEnum::Single(idx) => vec![*idx],
            SliceEnum::Range { start, size, step } => {
                generate_range_indices(*start, *size, *step).collect()
            }
            SliceEnum::Range2d {
                start,
                size1,
                size2,
                step1,
                step2,
            } => generate_range_2d_indices(*start, *size1, *step1, *size2, *step2).collect(),
            SliceEnum::RangeVec(v) => v.iter().flat_map(|r| r.get_indices()).collect(),
        }
    }

    /// Returns the number of selected indices (repeats counted).
    ///
    /// # Panics
    ///
    /// Panics if the count overflows `u32`.
    pub fn len(&self) -> u32 {
        match self {
            SliceEnum::Single(_) => 1,
            SliceEnum::Range { size, .. } => *size,
            SliceEnum::Range2d { size1, size2, .. } => size1
                .checked_mul(*size2)
                .expect("slice length overflow for range2d"),
            SliceEnum::RangeVec(v) => v.iter().fold(0u32, |acc, r| {
                acc.checked_add(r.len())
                    .expect("slice length overflow for range vector")
            }),
        }
    }

    /// Returns the largest index this slice selects, or `None` if it selects
    /// nothing. Computed in `i128` so that out-of-range values are representable.
    fn max_index(&self) -> Option<i128> {
        match self {
            SliceEnum::Single(idx) => Some(i128::from(*idx)),
            SliceEnum::Range { start, size, step } => {
                let last = i128::from(size.checked_sub(1)?);
                let first = i128::from(*start);
                Some(first.max(first + i128::from(*step) * last))
            }
            SliceEnum::Range2d {
                start,
                size1,
                step1,
                size2,
                step2,
            } => {
                let i_last = i128::from(size1.checked_sub(1)?);
                let j_last = i128::from(size2.checked_sub(1)?);
                // The index is affine in (i, j), so each axis contributes its
                // positive extreme independently.
                let along_i = (i128::from(*step1) * i_last).max(0);
                let along_j = (i128::from(*step2) * j_last).max(0);
                Some(i128::from(*start) + along_i + along_j)
            }
            SliceEnum::RangeVec(v) => v.iter().filter_map(Self::max_index).max(),
        }
    }

    /// Greedily matches the longest `Single`, `Range` or `Range2d` it can find
    /// that reproduces a prefix of the index sequence beginning at `start`.
    ///
    /// `deltas` are the consecutive differences of that sequence
    /// (`deltas[k] = indices[k + 1] - indices[k]`), so the sequence has
    /// `deltas.len() + 1` elements. The row length of a `Range2d` is the length
    /// of the initial constant-step run; at least two rows must match.
    fn match_largest_slice(start: u32, deltas: &[i64]) -> Self {
        if deltas.is_empty() {
            return Self::Single(start);
        }
        // Inner (row) range: the first delta repeated as long as possible.
        // `n_j` counts indices, i.e. one more than the number of equal deltas.
        let step_j = deltas[0];
        let n_j = deltas.iter().skip(1).take_while(|&&d| d == step_j).count() + 2;
        let mut res_slice = Self::Range {
            start,
            size: n_j as u32,
            step: step_j,
        };
        if n_j < deltas.len() + 1 {
            // `exp_chunk` is one row's n_j - 1 in-row deltas followed by the
            // jump to the next row; every further complete row repeats it.
            let exp_chunk = &deltas[0..n_j];
            let chunks = deltas.chunks(n_j).skip(1);
            let mut n_i = chunks
                .take_while(|chunk| {
                    izip!(exp_chunk, *chunk).take_while(|(e, d)| e == d).count() == n_j
                })
                .count()
                + 1;
            // The final row needs no trailing jump: accept it when only its
            // n_j - 1 in-row deltas match.
            if let Some(chunk) = deltas.chunks(n_j).nth(n_i) {
                if izip!(exp_chunk, chunk).take_while(|(e, d)| e == d).count() == n_j - 1 {
                    n_i += 1;
                }
            }
            if n_i > 1 {
                let step_i = exp_chunk.iter().sum::<i64>();
                res_slice = Self::Range2d {
                    start,
                    size1: n_i as u32,
                    size2: n_j as u32,
                    step1: step_i,
                    step2: step_j,
                };
            }
        }
        res_slice
    }

    /// Shrinks this slice in place so that it selects at most `max_size`
    /// leading indices.
    ///
    /// A `Range2d` may keep fewer than `max_size`: it is cut to whole rows (or
    /// to a prefix of its first row). `Single` and `RangeVec` are left as-is.
    ///
    /// # Panics
    ///
    /// Panics if `max_size` is zero.
    fn reduce(&mut self, max_size: u32) {
        assert!(max_size > 0);
        match self {
            SliceEnum::Single(_) => {}
            SliceEnum::Range { start, size, .. } => {
                if max_size < *size {
                    if max_size == 1 {
                        *self = SliceEnum::Single(*start);
                    } else {
                        *size = max_size;
                    }
                }
            }
            SliceEnum::Range2d {
                start,
                size1,
                size2,
                step2,
                ..
            } => {
                if max_size < *size1 * *size2 {
                    if max_size == 1 {
                        *self = SliceEnum::Single(*start);
                    } else if max_size <= *size2 {
                        // Fits inside the first row.
                        *self = SliceEnum::Range {
                            start: *start,
                            size: max_size,
                            step: *step2,
                        }
                    } else if max_size / *size2 == 1 {
                        // Exactly one whole row fits.
                        *self = SliceEnum::Range {
                            start: *start,
                            size: *size2,
                            step: *step2,
                        }
                    } else {
                        // Keep as many whole rows as fit.
                        *size1 = max_size / *size2;
                    }
                }
            }
            SliceEnum::RangeVec(_) => {}
        }
    }

    /// Chooses a non-overlapping cover of positions `0..max_len_slices.len()`
    /// and returns the chosen slices ordered by position.
    ///
    /// `max_len_slices[p]` must reproduce the index sequence starting at
    /// position `p`. Each pending range is filled greedily: the longest
    /// candidate (earliest on ties) is taken, candidates to its left are
    /// shortened to end before it, and the gaps on either side are queued.
    ///
    /// # Panics
    ///
    /// Panics if `max_len_slices` is empty.
    fn match_slices(mut max_len_slices: Vec<Self>) -> Vec<Self> {
        let mut res = vec![];
        let mut ranges_to_visit = vec![(0, max_len_slices.len())];
        while let Some((start, end)) = ranges_to_visit.pop() {
            let (slice_pos, slice) = max_len_slices[start..end]
                .iter()
                .enumerate()
                .max_by_key(|(pos, slice)| (slice.len(), std::cmp::Reverse(*pos)))
                .expect("ranges_to_visit only holds non-empty ranges");
            let slice_start = start + slice_pos;
            let slice_end = slice_start + slice.len() as usize;
            res.push((slice_start, slice.clone()));
            if start < slice_start {
                // Candidates left of the chosen slice must stop before it.
                max_len_slices[start..slice_start]
                    .iter_mut()
                    .enumerate()
                    .for_each(|(pos, slice)| slice.reduce((slice_pos - pos) as u32));
                ranges_to_visit.push((start, slice_start));
            }
            if slice_end < end {
                ranges_to_visit.push((slice_end, end));
            }
        }
        res.sort_by_key(|(start, _)| *start);
        res.into_iter().map(|(_, slice)| slice).collect()
    }

    /// Builds a slice whose `get_indices()` equals `indices`, compressing runs
    /// into `Range`/`Range2d` pieces where possible.
    ///
    /// An empty input yields an empty `RangeVec`, and a single-piece cover is
    /// returned unwrapped. Matching can take time quadratic in
    /// `indices.len()` in the worst case.
    pub fn from_indices(indices: Vec<u32>) -> Self {
        if indices.is_empty() {
            return Self::RangeVec(vec![]);
        }
        let deltas = indices
            .windows(2)
            .map(|w| i64::from(w[1]) - i64::from(w[0]))
            .collect::<Vec<_>>();
        let max_slice_vec: Vec<_> = indices
            .iter()
            .enumerate()
            .map(|(i, &idx)| Self::match_largest_slice(idx, &deltas[i..]))
            .collect();
        let mut optimized_slices = SliceEnum::match_slices(max_slice_vec);
        if optimized_slices.len() == 1 {
            optimized_slices.pop().expect("length checked above")
        } else {
            SliceEnum::RangeVec(optimized_slices)
        }
    }

    /// Shifts every selected index up by `delta`.
    ///
    /// # Panics
    ///
    /// Panics if a shifted start overflows `u32`, or if any other index of a
    /// shifted `Range`/`Range2d` would exceed `u32::MAX`.
    pub fn shift_start(&mut self, delta: u32) {
        match self {
            SliceEnum::Single(idx) => {
                *idx = idx
                    .checked_add(delta)
                    .expect("slice start overflow for single index");
            }
            SliceEnum::Range { start, .. } => {
                *start = start
                    .checked_add(delta)
                    .expect("slice start overflow for range");
            }
            SliceEnum::Range2d { start, .. } => {
                *start = start
                    .checked_add(delta)
                    .expect("slice start overflow for range2d");
            }
            SliceEnum::RangeVec(v) => {
                // Each part validates itself during its own shift.
                v.iter_mut().for_each(|slice| slice.shift_start(delta));
                return;
            }
        }
        // A positive step can carry later indices past `u32::MAX` even when the
        // new start fits; fail here instead of later inside `get_indices`.
        if let Some(max) = self.max_index() {
            assert!(
                max <= i128::from(u32::MAX),
                "slice index overflow after shift: {max}"
            );
        }
    }
}
#[cfg(test)]
mod tests {
// Unit tests for slice expansion, greedy matching/optimization and the
// checked range constructors.
use super::SliceEnum;
use crate::circuit::{errors::SliceError, Slice};
// `Range2d::get_indices` enumerates row-major: rows advance by `step1`,
// columns by `step2`.
#[test]
fn test_slice_range() {
let range = SliceEnum::Range2d {
start: 0,
size1: 2,
size2: 3,
step1: 6,
step2: 1,
};
let expected = vec![0, 1, 2, 6, 7, 8];
assert_eq!(range.get_indices(), expected);
let range = SliceEnum::Range2d {
start: 0,
size1: 4,
size2: 2,
step1: 3,
step2: 1,
};
let expected = vec![0, 1, 3, 4, 6, 7, 9, 10];
assert_eq!(range.get_indices(), expected);
let range = SliceEnum::Range2d {
start: 0,
size1: 4,
size2: 2,
step1: 3,
step2: 2,
};
let expected = vec![0, 2, 3, 5, 6, 8, 9, 11];
assert_eq!(range.get_indices(), expected);
let range = SliceEnum::Range2d {
start: 2,
size1: 1,
size2: 4,
step1: 1,
step2: 3,
};
let expected = vec![2, 5, 8, 11];
assert_eq!(range.get_indices(), expected);
}
// Each case checks that the matched slice reproduces the prefix of `indices`
// it claims to cover (the whole input, or an explicit `[..k]` prefix).
#[test]
fn test_slice_match_largest_slice() {
// Derives the consecutive deltas and matches from the first index.
fn match_largest_slice(indices: &[u32]) -> SliceEnum {
SliceEnum::match_largest_slice(
indices[0],
&indices
.windows(2)
.map(|w| w[1] as i64 - w[0] as i64)
.collect::<Vec<_>>(),
)
}
let indices = vec![0];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![3];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![0, 1, 2, 3, 4];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![5, 7, 9, 11, 13];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![5, 6];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![5, 2];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices[..2].to_vec());
// Full 2-D grids, including ones with negative row steps.
let indices = vec![0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17]; let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![2, 3, 4, 7, 8, 9]; let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![0, 2, 8, 10]; let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices);
let indices = vec![10, 12, 5, 7, 0, 2]; let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices.to_vec());
// Only a prefix is matchable: a step-2 run of three.
let indices = vec![0, 2, 4, 4, 5];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices[..3].to_vec());
// A 2x2 grid whose last row needs no trailing jump.
let indices = vec![0, 1, 3, 4, 5];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices[..4].to_vec());
let indices = vec![10, 12, 5, 7, 0, 2, 1];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices[..6].to_vec());
// Zero column step (repeated indices).
let indices = vec![1, 1, 0, 0, 1, 1, 0, 0];
let slice = match_largest_slice(&indices);
assert_eq!(slice.get_indices(), indices[..4].to_vec());
}
// `from_indices` must round-trip; the structural asserts pin the exact
// decomposition chosen by the greedy cover.
#[test]
fn test_slice_optimize() {
let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
assert_eq!(
slice.0,
SliceEnum::Range {
start: 0,
size: 12,
step: 1,
}
);
let indices = vec![19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
assert_eq!(
slice.0,
SliceEnum::RangeVec(vec![
SliceEnum::Single(19),
SliceEnum::Range {
start: 3,
size: 9,
step: 1
}
])
);
let indices = vec![0, 1, 2, 19, 3, 4, 5, 6, 7, 8, 9, 10, 11];
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
assert_eq!(
slice.0,
SliceEnum::RangeVec(vec![
SliceEnum::Range {
start: 0,
size: 3,
step: 1
},
SliceEnum::Single(19),
SliceEnum::Range {
start: 3,
size: 9,
step: 1
}
])
);
let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19];
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
assert_eq!(
slice.0,
SliceEnum::RangeVec(vec![
SliceEnum::Range {
start: 0,
size: 10,
step: 1
},
SliceEnum::Single(19),
])
);
let indices = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 10, 11];
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
assert_eq!(
slice.0,
SliceEnum::RangeVec(vec![
SliceEnum::Range {
start: 0,
size: 10,
step: 1
},
SliceEnum::Range {
start: 19,
size: 2,
step: -9
},
SliceEnum::Single(11),
])
);
// Larger periodic input: only the round trip is checked.
let mut indices = Vec::new();
for _i in 0..1000 {
indices.extend(vec![0, 1, 1, 0]);
}
let slice = Slice::from_indices(indices.clone());
assert_eq!(slice.get_indices(), indices);
}
// `Slice::range` rejects negative and above-`u32::MAX` endpoints and accepts
// a range ending exactly at `u32::MAX`.
#[test]
fn test_slice_checked_range_bounds() {
assert_eq!(Slice::range(0, 2, -1), Err(SliceError::NegativeIndex(-1)));
assert_eq!(
Slice::range(u32::MAX, 2, 1),
Err(SliceError::IndexOutOfBounds {
found: i128::from(u32::MAX) + 1,
max: u32::MAX
})
);
let slice = Slice::range(u32::MAX - 1, 2, 1).unwrap();
assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
}
// Same checks for `Slice::range2d` (argument order: start, size1, size2,
// step1, step2).
#[test]
fn test_slice_checked_range2d_bounds() {
assert_eq!(
Slice::range2d(0, 2, 2, -1, 0),
Err(SliceError::NegativeIndex(-1))
);
assert_eq!(
Slice::range2d(u32::MAX, 2, 1, 1, 0),
Err(SliceError::IndexOutOfBounds {
found: i128::from(u32::MAX) + 1,
max: u32::MAX
})
);
let slice = Slice::range2d(u32::MAX - 1, 1, 2, 1, 1).unwrap();
assert_eq!(slice.get_indices(), vec![u32::MAX - 1, u32::MAX]);
}
}