use std::{
cmp::Ordering,
iter::Sum,
ops::{Add, Sub},
};
use num_traits::{ConstOne, PrimInt, Unsigned};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum RemoveError<T> {
#[error("Value {0} not in set")]
ValueNotInSet(T),
}
/// Error returned when inserting a value into a `RangeSet` fails.
#[derive(Debug, Error)]
pub enum InsertError<T> {
/// The value is already contained in the set; carries that value.
#[error("value {0} already in set")]
ValueAlreadyInSet(T),
}
/// Error returned by `RangeSet::new` when the supplied ranges are rejected.
#[derive(Debug, Error)]
pub enum NewRangeSetError {
/// Returned when a range starts before the previous range ends, and also
/// when any single range has `start >= end`.
#[error("ranges must be non-overlapping and sorted by start")]
InvalidRanges,
}
/// A set of unsigned integers stored as a list of half-open `Range`s.
///
/// `RangeSet::new` only accepts non-empty ranges that are sorted by `start`
/// and do not overlap.
///
/// NOTE(review): `Deserialize` is derived without any serde attributes, so
/// deserialized data does not pass through the checks in `RangeSet::new` —
/// confirm that only trusted data is deserialized.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)]
pub struct RangeSet<T: PrimInt + ConstOne + Unsigned + Sum<T>> {
// Backing storage, searched with a binary search by `insert`, `contains`
// and `remove`.
ranges: Vec<Range<T>>,
}
impl<T: PrimInt + ConstOne + Unsigned + Sum<T>> RangeSet<T> {
    /// Builds a set from `ranges`.
    ///
    /// # Errors
    ///
    /// Returns [`NewRangeSetError::InvalidRanges`] if any range has
    /// `start >= end`, or if a range starts before the previous one ends.
    /// Ranges that merely touch (`prev.end == next.start`) are accepted as
    /// given and are not merged.
    pub fn new(ranges: Vec<Range<T>>) -> Result<Self, NewRangeSetError> {
        let all_non_empty = ranges.iter().all(|range| range.start < range.end);
        let sorted_and_disjoint = ranges
            .windows(2)
            .all(|pair| pair[0].end <= pair[1].start);
        if all_non_empty && sorted_and_disjoint {
            Ok(Self { ranges })
        } else {
            Err(NewRangeSetError::InvalidRanges)
        }
    }

    /// Returns `true` when the set holds no ranges.
    pub fn is_empty(&self) -> bool {
        self.ranges.is_empty()
    }

    /// Adds `value` to the set, extending or merging neighbouring ranges so
    /// that touching ranges are joined.
    ///
    /// # Errors
    ///
    /// Returns [`InsertError::ValueAlreadyInSet`] if `value` is already
    /// contained in the set.
    ///
    /// # Panics
    ///
    /// Panics if `value` is the maximum value of `T`. Ranges are half-open,
    /// so the exclusive end `value + 1` would not be representable; without
    /// this check the addition would wrap in release builds and corrupt the
    /// set.
    pub fn insert(&mut self, value: T) -> Result<(), InsertError<T>> {
        let idx = match self.ranges.binary_search_by(|range| range.compare(&value)) {
            Ok(_) => return Err(InsertError::ValueAlreadyInSet(value)),
            Err(idx) => idx,
        };
        let next_value = value
            .checked_add(&T::ONE)
            .expect("cannot insert the maximum value of T: the exclusive range end would overflow");

        // `idx` is the position where a range starting at `value` would go,
        // so `idx - 1` is the range just below and `idx` the range just above.
        let touches_prev = idx > 0 && self.ranges[idx - 1].end == value;
        let touches_next = idx < self.ranges.len() && self.ranges[idx].start == next_value;
        match (touches_prev, touches_next) {
            // `value` fills the one-element gap between two ranges: fuse them.
            (true, true) => {
                let absorbed = self.ranges.remove(idx);
                self.ranges[idx - 1].end = absorbed.end;
            }
            (true, false) => self.ranges[idx - 1].end = next_value,
            (false, true) => self.ranges[idx].start = value,
            (false, false) => self.ranges.insert(idx, Range::new(value, next_value)),
        }
        Ok(())
    }

    /// Returns the number of values in the set, i.e. the sum of the lengths
    /// of all ranges.
    pub fn len(&self) -> T {
        self.ranges.iter().map(|range| range.len()).sum()
    }

    /// Returns `true` if `value` is covered by one of the ranges.
    pub fn contains(&self, value: T) -> bool {
        self.ranges
            .binary_search_by(|range| range.compare(&value))
            .is_ok()
    }

    /// Removes `value` from the set, shrinking, deleting, or splitting the
    /// range that contains it.
    ///
    /// # Errors
    ///
    /// Returns [`RemoveError::ValueNotInSet`] if no range contains `value`.
    pub fn remove(&mut self, value: T) -> Result<(), RemoveError<T>> {
        let idx = self
            .ranges
            .binary_search_by(|range| range.compare(&value))
            .map_err(|_| RemoveError::ValueNotInSet(value))?;

        // `value` lies inside `[start, end)`, so `value + 1 <= end` cannot
        // overflow here.
        let next_value = value + T::ONE;
        let range = &mut self.ranges[idx];
        if range.start == value {
            range.start = next_value;
            // A single-element range becomes empty and must be dropped.
            if range.is_empty() {
                self.ranges.remove(idx);
            }
        } else if range.end == next_value {
            // `value` is the last element and `start < value`, so the range
            // stays non-empty after trimming.
            range.end = value;
        } else {
            // `value` is strictly inside: split into `[start, value)` and
            // `[value + 1, end)`.
            let lower = Range::new(range.start, value);
            range.start = next_value;
            self.ranges.insert(idx, lower);
        }
        Ok(())
    }

    /// Returns the `n`-th value (zero-based) when walking the ranges in
    /// stored order, or `None` if `n` is not smaller than [`RangeSet::len`].
    pub fn nth(&self, mut n: T) -> Option<T> {
        for range in &self.ranges {
            let span = range.len();
            if n < span {
                return Some(range.start + n);
            }
            // Skip this range and keep counting in the next one.
            n = n - span;
        }
        None
    }

    /// Returns the underlying ranges as a slice.
    pub fn ranges(&self) -> &[Range<T>] {
        &self.ranges
    }
}
/// A half-open interval `[start, end)`: `compare` treats `start` as included
/// and `end` as excluded.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)]
pub struct Range<T: Ord + Sub<Output = T> + Add<Output = T> + Copy + Unsigned> {
/// Inclusive lower bound.
pub start: T,
/// Exclusive upper bound.
pub end: T,
}
impl<T: Ord + Sub<Output = T> + Add<Output = T> + Copy + Unsigned> Range<T> {
    /// Creates the half-open range `[start, end)`.
    ///
    /// No validation is performed; a range with `start >= end` is treated as
    /// empty by the other methods.
    pub fn new(start: T, end: T) -> Self {
        Self { start, end }
    }

    /// Orders this range relative to the value `other`, in the form expected
    /// by `binary_search_by`.
    ///
    /// Returns `Equal` when `other` lies inside the range, `Greater` when the
    /// range starts above `other`, and `Less` otherwise.
    pub fn compare(&self, other: &T) -> Ordering {
        if *other < self.start {
            Ordering::Greater
        } else if *other < self.end {
            Ordering::Equal
        } else {
            Ordering::Less
        }
    }

    /// Returns the number of values in the range.
    ///
    /// Empty and inverted (`start > end`) ranges have length zero, so the
    /// unsigned subtraction can never underflow.
    pub fn len(&self) -> T {
        if self.is_empty() {
            T::zero()
        } else {
            self.end - self.start
        }
    }

    /// Returns `true` when the range contains no values, which includes
    /// inverted ranges where `start > end`.
    pub fn is_empty(&self) -> bool {
        self.start >= self.end
    }

    /// Returns the `n`-th value of the range (zero-based), or `None` if `n`
    /// is not smaller than [`Range::len`].
    pub fn nth(&self, n: T) -> Option<T> {
        (n < self.len()).then(|| self.start + n)
    }
}
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use super::*;
/// Asserts the structural invariants of `rangeset`: every range is non-empty,
/// and consecutive ranges are ordered by `start` and do not overlap.
fn check_rangeset_invariants<T: PrimInt + ConstOne + Unsigned + Sum<T> + std::fmt::Debug>(
    rangeset: &RangeSet<T>,
) {
    let ranges = rangeset.ranges();
    // Check every range on its own, including the first one, which a loop
    // starting at index 1 would skip.
    for range in ranges {
        assert!(range.start < range.end, "Invalid range: {:?}", range);
    }
    for pair in ranges.windows(2) {
        let (prev, next) = (&pair[0], &pair[1]);
        assert!(
            prev.start < next.start,
            "Ranges not ordered: {:?} and {:?}",
            prev,
            next
        );
        assert!(
            prev.end <= next.start,
            "Ranges not disjoint: {:?} and {:?}",
            prev,
            next
        );
    }
}
/// Removes 1000 randomly chosen values from the single range `[start, end)`
/// and then inserts them back, checking invariants, membership and length
/// after every step. The set must end up as one range of the original length.
fn test_random_operations_generic<T, R: Rng>(rng: &mut R, start: T, end: T, type_name: &str)
where
    T: PrimInt
        + ConstOne
        + Unsigned
        + Sum<T>
        + std::fmt::Debug
        + TryFrom<u128>
        + std::convert::Into<u128>,
    <T as std::convert::TryFrom<u128>>::Error: std::fmt::Debug,
{
    println!("Running random operations test for {type_name}");

    let mut set = RangeSet::new(vec![Range::new(start, end)]).unwrap();
    let full_len: T = end - start;
    let mut expected_len = full_len;
    let mut taken = Vec::new();

    // Phase 1: pull random elements out of the set.
    for _ in 0..1000 {
        let current_len: u128 = set.len().into();
        let pick = T::try_from(rng.random_range(0..current_len)).unwrap();
        let value = set.nth(pick).unwrap();

        set.remove(value).expect("Failed to remove value");
        check_rangeset_invariants(&set);
        assert!(
            !set.contains(value),
            "Value should not be in set after removal"
        );

        taken.push(value);
        expected_len = expected_len - T::ONE;
        assert_eq!(set.len(), expected_len, "Remaining capacity mismatch");
    }

    // Phase 2: put them back in the order they were removed.
    for &value in &taken {
        set.insert(value).expect("Failed to insert value");
        check_rangeset_invariants(&set);
        assert!(
            set.contains(value),
            "Value should be in set after insertion"
        );

        expected_len = expected_len + T::ONE;
        assert_eq!(
            set.len(),
            expected_len,
            "Remaining capacity mismatch after insertion"
        );
    }

    assert_eq!(
        set.ranges().len(),
        1,
        "Expected single range after all operations"
    );
    assert_eq!(
        set.len(),
        full_len,
        "Expected full capacity after all operations"
    );
}
/// Runs the randomized remove/insert round trip for `u16`, `u32` and `u128`,
/// each with its own fixed seed so the runs are reproducible.
#[test]
fn random_operations_maintain_invariants() {
    test_random_operations_generic(&mut ChaCha8Rng::seed_from_u64(42), 1, u16::MAX, "u16");
    test_random_operations_generic(&mut ChaCha8Rng::seed_from_u64(43), 1, u32::MAX, "u32");
    test_random_operations_generic(&mut ChaCha8Rng::seed_from_u64(44), 1, u128::MAX, "u128");
}
/// `nth` walks the ranges `[1, 5)` and `[10, 15)` in order and returns `None`
/// once the index reaches the total length.
#[test]
fn nth_returns_correct_element() {
    let rangeset: RangeSet<u16> =
        RangeSet::new(vec![Range::new(1, 5), Range::new(10, 15)]).unwrap();

    let expected: [u16; 9] = [1, 2, 3, 4, 10, 11, 12, 13, 14];
    for (index, &value) in (0u16..).zip(expected.iter()) {
        assert_eq!(rangeset.nth(index), Some(value));
    }
    assert_eq!(rangeset.nth(9), None);
}
/// Exercises `contains`, `remove` and `insert` at and around both ends of a
/// single half-open range, for several range sizes.
#[test]
fn boundary_values_handled_correctly() {
    for (start, end) in [(1u32, 10), (1u32, 100), (1u32, 200)] {
        let below = start - 1;
        let last = end - 1;
        let mut rangeset = RangeSet::new(vec![Range::new(start, end)]).unwrap();
        let original_len = rangeset.len();

        // Membership at the edges of the half-open interval.
        assert!(rangeset.contains(start), "Lower boundary should be in set");
        assert!(
            !rangeset.contains(below),
            "Value before lower boundary should not be in set"
        );
        assert!(
            !rangeset.contains(end),
            "Upper boundary should not be in set (half-open range)"
        );
        assert!(
            rangeset.contains(last),
            "Value just before upper boundary should be in set"
        );

        // Removing the first and last elements.
        rangeset
            .remove(start)
            .expect("Failed to remove lower boundary");
        assert!(
            !rangeset.contains(start),
            "Lower boundary should be removed"
        );
        assert_eq!(
            rangeset.len(),
            original_len - 1,
            "Length should decrease by 1"
        );
        rangeset
            .remove(last)
            .expect("Failed to remove upper boundary - 1");
        assert!(
            !rangeset.contains(last),
            "Upper boundary - 1 should be removed"
        );

        // Values outside the range cannot be removed.
        assert!(
            rangeset.remove(below).is_err(),
            "Should fail to remove value outside lower boundary"
        );
        assert!(
            rangeset.remove(end).is_err(),
            "Should fail to remove value at upper boundary"
        );
        assert!(
            rangeset.remove(end + 1).is_err(),
            "Should fail to remove value beyond upper boundary"
        );

        // Re-inserting both edge values restores the original set size.
        rangeset
            .insert(start)
            .expect("Failed to insert lower boundary");
        rangeset
            .insert(last)
            .expect("Failed to insert upper boundary - 1");
        assert_eq!(
            rangeset.len(),
            original_len,
            "Should restore original length"
        );
        check_rangeset_invariants(&rangeset);
    }
}
/// Filling the gap between `[1, 5)` and `[10, 15)` one value at a time must
/// end with a single merged range `[1, 15)`.
#[test]
fn adjacent_ranges_merge_on_insert() {
    let initial = vec![Range::new(1u32, 5), Range::new(10, 15)];
    let mut rangeset = RangeSet::new(initial).unwrap();

    // The first gap value only extends the lower range.
    rangeset
        .insert(5)
        .expect("Failed to insert value at start of gap");
    check_rangeset_invariants(&rangeset);
    assert_eq!(rangeset.ranges().len(), 2);

    for gap_value in 6..=9 {
        rangeset
            .insert(gap_value)
            .expect("Failed to insert gap value");
        check_rangeset_invariants(&rangeset);
    }

    let merged = rangeset.ranges();
    assert_eq!(merged.len(), 1, "Ranges should merge when gap is filled");
    assert_eq!(merged[0].start, 1);
    assert_eq!(merged[0].end, 15);
}
/// Removing an interior value splits a range in two, while removing values at
/// the outer edges only shrinks the existing ranges.
#[test]
fn remove_splits_ranges_correctly() {
    let mut rangeset = RangeSet::new(vec![Range::new(1u32, 10)]).unwrap();

    rangeset.remove(5).expect("Failed to remove middle value");
    check_rangeset_invariants(&rangeset);
    let halves = rangeset.ranges();
    assert_eq!(
        halves.len(),
        2,
        "Range should split when middle value is removed"
    );
    assert_eq!(halves[0].start, 1);
    assert_eq!(halves[0].end, 5);
    assert_eq!(halves[1].start, 6);
    assert_eq!(halves[1].end, 10);

    let edge_removals = [
        (1, "Failed to remove start value"),
        (9, "Failed to remove end value"),
    ];
    for (value, failure) in edge_removals {
        rangeset.remove(value).expect(failure);
        check_rangeset_invariants(&rangeset);
    }
    assert_eq!(
        rangeset.ranges().len(),
        2,
        "Should still have 2 ranges after edge removals"
    );
}
}