use std::{cmp, convert::TryFrom, fmt::Display};
pub struct VecChunkIter<Idx> {
inner: NonOverlappingIntegerPairIter<Idx>,
}
impl<Idx: PartialOrd + Copy + Display> VecChunkIter<Idx>
where u64: From<Idx>
{
pub fn new(start: Idx, end_exclusive: Idx, chunk_size: usize) -> Result<Self, String> {
Ok(Self {
inner: NonOverlappingIntegerPairIter::new(start, end_exclusive, chunk_size)?,
})
}
}
macro_rules! vec_chunk_impl {
($ty:ty) => {
impl Iterator for VecChunkIter<$ty> {
type Item = Vec<$ty>;
fn next(&mut self) -> Option<Self::Item> {
let (start, end) = self.inner.next()?;
Some((start..=end).collect())
}
}
};
}
vec_chunk_impl!(u32);
vec_chunk_impl!(u64);
vec_chunk_impl!(usize);
/// Iterator over consecutive, non-overlapping inclusive `(first, last)` pairs that together
/// cover the half-open range `start..end_exclusive`.
///
/// Chunks hold `chunk_size` values each and are aligned to `start`, so only the chunk that
/// touches `end_exclusive` may be shorter. Iterating from the back yields exactly the same
/// chunks in reverse order. Both ends share their progress: mixing [`Iterator::next`] and
/// [`DoubleEndedIterator::next_back`] never yields a value twice. A `chunk_size` of `0`
/// yields nothing.
pub struct NonOverlappingIntegerPairIter<Idx> {
    // First value not yet handed out from the front. Always `start + k * size` for some `k`,
    // or equal to `current_end` once the iterator is exhausted.
    current: Idx,
    // One past the last value not yet handed out from the back. Either `end_exclusive` or a
    // chunk boundary `start + k * size`.
    current_end: Idx,
    // Requested number of values per chunk.
    size: usize,
}

impl<Idx: PartialOrd + Copy + Display> NonOverlappingIntegerPairIter<Idx> {
    /// Creates an iterator over `start..end_exclusive` split into chunks of `chunk_size`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message if `start > end_exclusive`.
    pub fn new(start: Idx, end_exclusive: Idx, chunk_size: usize) -> Result<Self, String> {
        if start > end_exclusive {
            return Err(format!(
                "`start` {start} must be less than or equal to `end_exclusive` {end_exclusive}"
            ));
        }
        Ok(Self {
            current: start,
            current_end: end_exclusive,
            size: chunk_size,
        })
    }
}

// Implements `Iterator` and `DoubleEndedIterator` for one concrete unsigned integer type.
// Both directions work on the shared `current..current_end` window, so they meet in the
// middle instead of overlapping. Every step is clamped to the remaining span, so no
// arithmetic can overflow, even when `end_exclusive == <$ty>::MAX`.
macro_rules! non_overlapping_iter_impl {
    ($ty:ty) => {
        impl Iterator for NonOverlappingIntegerPairIter<$ty> {
            type Item = ($ty, $ty);

            fn next(&mut self) -> Option<Self::Item> {
                if self.size == 0 || self.current >= self.current_end {
                    return None;
                }
                // A chunk size that does not fit in `$ty` exceeds every possible span.
                let size = <$ty>::try_from(self.size).unwrap_or(<$ty>::MAX);
                let step = cmp::min(size, self.current_end - self.current);
                let first = self.current;
                self.current += step;
                Some((first, self.current - 1))
            }
        }

        impl DoubleEndedIterator for NonOverlappingIntegerPairIter<$ty> {
            fn next_back(&mut self) -> Option<Self::Item> {
                if self.size == 0 || self.current >= self.current_end {
                    return None;
                }
                let size = <$ty>::try_from(self.size).unwrap_or(<$ty>::MAX);
                // Chunks are aligned to the front. The remaining span is therefore a whole
                // number of chunks plus, at most, one short trailing chunk of
                // `remaining % size` values, which must come out first from the back.
                let remaining = self.current_end - self.current;
                let step = match remaining % size {
                    0 => size,
                    partial => partial,
                };
                let last = self.current_end - 1;
                self.current_end -= step;
                Some((self.current_end, last))
            }
        }
    };
}

non_overlapping_iter_impl!(u8);
non_overlapping_iter_impl!(u16);
non_overlapping_iter_impl!(u32);
non_overlapping_iter_impl!(u64);
non_overlapping_iter_impl!(usize);
#[cfg(test)]
mod test {
    use rand::RngExt;
    use super::*;

    #[test]
    fn start_equals_end() {
        // Empty range with a zero chunk size: nothing to yield from either iterator.
        assert_eq!(NonOverlappingIntegerPairIter::new(10u32, 10, 0).unwrap().next(), None);
        assert_eq!(VecChunkIter::new(10u32, 10, 0).unwrap().next(), None);
    }

    #[test]
    fn start_gt_end() {
        // A reversed range is rejected by both constructors.
        let (start, end) = (11u32, 10u32);
        assert!(NonOverlappingIntegerPairIter::new(start, end, 0).is_err());
        assert!(VecChunkIter::new(start, end, 0).is_err());
    }

    #[test]
    fn chunk_size_out_of_bounds() {
        // Non-zero chunk size over an empty range: still nothing.
        assert_eq!(NonOverlappingIntegerPairIter::new(10u32, 10, 10).unwrap().next(), None);
        assert_eq!(VecChunkIter::new(10u32, 10, 10).unwrap().next(), None);
        // Chunk larger than the whole range: the first chunk spans everything.
        assert_eq!(
            NonOverlappingIntegerPairIter::new(0u32, 10, 100).unwrap().next(),
            Some((0, 9))
        );
        assert_eq!(
            VecChunkIter::new(0u32, 10, 100).unwrap().next(),
            Some((0..10).collect::<Vec<u32>>())
        );
    }

    #[test]
    fn chunk_size_multiple_of_end() {
        let pairs: Vec<(u32, u32)> = NonOverlappingIntegerPairIter::new(0u32, 9, 3).unwrap().collect();
        assert_eq!(pairs, vec![(0, 2), (3, 5), (6, 8)]);
        let chunks: Vec<Vec<u32>> = VecChunkIter::new(0u32, 9, 3).unwrap().collect();
        assert_eq!(chunks, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
    }

    #[test]
    fn chunk_size_not_multiple_of_end() {
        let pairs: Vec<(u32, u32)> = NonOverlappingIntegerPairIter::new(0u32, 10, 3).unwrap().collect();
        assert_eq!(pairs, vec![(0, 2), (3, 5), (6, 8), (9, 9)]);
        let chunks: Vec<Vec<u32>> = VecChunkIter::new(0u32, 10, 3).unwrap().collect();
        assert_eq!(chunks, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]]);
        let pairs: Vec<(u32, u32)> = NonOverlappingIntegerPairIter::new(0u32, 16, 5).unwrap().collect();
        assert_eq!(pairs, vec![(0, 4), (5, 9), (10, 14), (15, 15)]);
    }

    #[test]
    fn non_zero_start() {
        let pairs: Vec<(u32, u32)> =
            NonOverlappingIntegerPairIter::new(1001u32, 4000, 1000).unwrap().collect();
        assert_eq!(pairs, vec![(1001, 2000), (2001, 3000), (3001, 3999)]);
        let chunks: Vec<Vec<u32>> = VecChunkIter::new(10u32, 21, 3).unwrap().collect();
        assert_eq!(
            chunks,
            vec![vec![10, 11, 12], vec![13, 14, 15], vec![16, 17, 18], vec![19, 20]]
        );
    }

    #[test]
    fn overflow() {
        let pairs: Vec<(u8, u8)> = NonOverlappingIntegerPairIter::new(250u8, 255, 3).unwrap().collect();
        assert_eq!(pairs, vec![(250, 252), (253, 254)]);
    }

    #[test]
    fn double_ended() {
        // Collects every pair, back to front.
        let rev_pairs = |start: u32, end: u32, size: usize| -> Vec<(u32, u32)> {
            NonOverlappingIntegerPairIter::new(start, end, size).unwrap().rev().collect()
        };
        assert_eq!(rev_pairs(0, 9, 3), vec![(6, 8), (3, 5), (0, 2)]);
        assert_eq!(rev_pairs(0, 10, 3), vec![(9, 9), (6, 8), (3, 5), (0, 2)]);
        assert_eq!(rev_pairs(0, 16, 5), vec![(15, 15), (10, 14), (5, 9), (0, 4)]);
        assert_eq!(
            rev_pairs(1001, 4000, 1000),
            vec![(3001, 3999), (2001, 3000), (1001, 2000)]
        );

        // Ranges ending at `u8::MAX`, where the upper boundary sits at the type limit.
        let rev_pairs_u8 = |start: u8, size: usize| -> Vec<(u8, u8)> {
            NonOverlappingIntegerPairIter::new(start, u8::MAX, size).unwrap().rev().collect()
        };
        assert_eq!(rev_pairs_u8(254, 1000), vec![(254, 254)]);
        let pairs = rev_pairs_u8(87, 6);
        assert_eq!(&pairs[..2], &[(249, 254), (243, 248)]);
        assert_eq!(pairs.len(), (255 - 87) / 6);
        assert!(rev_pairs_u8(255, 1000).is_empty());
    }

    #[test]
    fn iterator_symmetry() {
        let mut rng = rand::rng();
        let size = rng.random_range(3usize..=10);
        let rand_start = rng.random::<u8>();
        let rand_end = rng.random::<u8>().saturating_add(rand_start);
        eprintln!("iterator_symmetry: rand_start = {rand_start}, rand_end = {rand_end}, size = {size}");
        let forward: Vec<_> = NonOverlappingIntegerPairIter::<u8>::new(rand_start, rand_end, size)
            .unwrap()
            .take(1000)
            .collect();
        let mut backward: Vec<_> = NonOverlappingIntegerPairIter::<u8>::new(rand_start, rand_end, size)
            .unwrap()
            .rev()
            .take(1000)
            .collect();
        // Reverse iteration yields chunks back to front; flip them to compare orderings.
        backward.reverse();
        assert_eq!(
            forward, backward,
            "Failed with rand_start = {rand_start}, rand_end = {rand_end}, size = {size}"
        );
    }
}