use alloc::{collections::btree_map::BTreeMap, vec::Vec};
use core::ops::Range;
/// A per-index reference counter over `usize` indices.
///
/// Every index conceptually starts with a count of zero. [`add`](Self::add)
/// increments the count of each index in a range and [`remove`](Self::remove)
/// decrements it. Both report which sub-ranges crossed the zero boundary.
pub(crate) struct RangeCounter {
    // Invariant: every stored value is >= 1. An index that is absent from the
    // map has a count of zero. `remove` relies on this to check coverage cheaply.
    counters: BTreeMap<usize, usize>,
}

impl RangeCounter {
    /// Creates a counter in which every index has a count of zero.
    pub(crate) const fn new() -> Self {
        Self {
            counters: BTreeMap::new(),
        }
    }

    /// Returns the current count of `index` (zero if it is not stored).
    #[cfg(ktest)]
    pub(crate) fn get(&self, index: usize) -> usize {
        self.counters.get(&index).copied().unwrap_or(0)
    }

    /// Increments the count of every index in `range`.
    ///
    /// Returns, in ascending order, the maximal contiguous sub-ranges of
    /// `range` whose counts went from zero to one. These are the parts of
    /// `range` that were not covered before this call. The counters are
    /// updated eagerly; the returned iterator only reports what changed.
    ///
    /// # Panics
    ///
    /// Panics if `range.start > range.end`.
    pub(crate) fn add(&mut self, range: &Range<usize>) -> impl Iterator<Item = Range<usize>> {
        assert!(range.start <= range.end);

        let mut updated_ranges = Vec::new();
        for index in range.clone() {
            let counter = self.counters.entry(index).or_insert(0);
            if *counter == 0 {
                Self::extend_ranges(&mut updated_ranges, index);
            }
            *counter += 1;
        }
        updated_ranges.into_iter()
    }

    /// Decrements the count of every index in `range`.
    ///
    /// Returns, in ascending order, the maximal contiguous sub-ranges of
    /// `range` whose counts dropped from one to zero. The counters are updated
    /// eagerly; the returned iterator only reports what changed.
    ///
    /// # Panics
    ///
    /// Panics if `range.start > range.end`, or with "Removing a zero counter"
    /// if any index in `range` currently has a count of zero. In the latter
    /// case no counter is modified.
    pub(crate) fn remove(&mut self, range: &Range<usize>) -> impl Iterator<Item = Range<usize>> {
        assert!(range.start <= range.end);

        // Validate before mutating, so that a bad request cannot leave the
        // counters half-decremented. Zero counts are never stored, so `range`
        // is fully covered exactly when it contains `end - start` stored keys.
        assert!(
            self.counters.range(range.clone()).count() == range.end - range.start,
            "Removing a zero counter"
        );

        let mut updated_ranges = Vec::new();
        for index in range.clone() {
            let counter = self
                .counters
                .get_mut(&index)
                .expect("coverage of the range was validated above");
            if *counter == 1 {
                // Drop the entry instead of storing a zero, preserving the invariant.
                self.counters.remove(&index);
                Self::extend_ranges(&mut updated_ranges, index);
            } else {
                *counter -= 1;
            }
        }
        updated_ranges.into_iter()
    }

    /// Records `index` in `ranges`, extending the last range when `index`
    /// directly follows it and starting a new one-element range otherwise.
    ///
    /// Indices must be supplied in strictly ascending order. That keeps every
    /// range in `ranges` maximal and the list sorted.
    fn extend_ranges(ranges: &mut Vec<Range<usize>>, index: usize) {
        match ranges.last_mut() {
            Some(last) if last.end == index => last.end = index + 1,
            _ => ranges.push(index..index + 1),
        }
    }
}
#[cfg(ktest)]
mod test {
    use alloc::vec;

    use super::*;
    use crate::prelude::*;

    type Span = core::ops::Range<usize>;

    /// Checks each `(span, expected)` pair: every index inside `span` must
    /// currently hold a count equal to `expected`.
    fn assert_counts(counter: &RangeCounter, expectations: &[(Span, usize)]) {
        for (span, expected) in expectations {
            for index in span.clone() {
                let actual = counter.get(index);
                assert_eq!(
                    actual, *expected,
                    "Counter at index {} should be {}, but got {}",
                    index, expected, actual
                );
            }
        }
    }

    /// Calls `add` with `span` and gathers the ranges it reports.
    fn added(counter: &mut RangeCounter, span: Span) -> Vec<Span> {
        counter.add(&span).collect()
    }

    /// Calls `remove` with `span` and gathers the ranges it reports.
    fn removed(counter: &mut RangeCounter, span: Span) -> Vec<Span> {
        counter.remove(&span).collect()
    }

    #[ktest]
    fn add_remove_range() {
        let mut counter = RangeCounter::new();

        assert_eq!(added(&mut counter, 0..5), vec![0..5]);
        assert_counts(&counter, &[(0..5, 1)]);

        assert_eq!(removed(&mut counter, 0..5), vec![0..5]);
        assert_counts(&counter, &[(0..5, 0)]);
    }

    #[ktest]
    fn add_remove_overlapping_beginning() {
        let mut counter = RangeCounter::new();

        assert_eq!(added(&mut counter, 10..15), vec![10..15]);
        assert_eq!(added(&mut counter, 3..13), vec![3..10]);
        assert_counts(&counter, &[(3..10, 1), (10..13, 2), (13..15, 1)]);

        assert_eq!(removed(&mut counter, 3..13), vec![3..10]);
        assert_counts(&counter, &[(3..10, 0), (10..13, 1), (13..15, 1)]);
    }

    #[ktest]
    fn add_remove_overlapping_end() {
        let mut counter = RangeCounter::new();

        assert_eq!(added(&mut counter, 10..15), vec![10..15]);
        assert_eq!(added(&mut counter, 12..18), vec![15..18]);
        assert_counts(&counter, &[(10..12, 1), (12..15, 2), (15..18, 1)]);

        assert_eq!(removed(&mut counter, 12..18), vec![15..18]);
        assert_counts(&counter, &[(10..12, 1), (12..15, 1), (15..18, 0)]);
    }

    #[ktest]
    fn add_remove_covering() {
        let mut counter = RangeCounter::new();

        assert_eq!(added(&mut counter, 20..30), vec![20..30]);
        assert_eq!(added(&mut counter, 15..35), vec![15..20, 30..35]);
        assert_counts(&counter, &[(15..20, 1), (20..30, 2), (30..35, 1)]);

        assert_eq!(removed(&mut counter, 15..35), vec![15..20, 30..35]);
        assert_counts(&counter, &[(15..20, 0), (20..30, 1), (30..35, 0)]);
    }

    #[ktest]
    fn add_remove_partial_overlap() {
        let mut counter = RangeCounter::new();

        assert_eq!(added(&mut counter, 5..15), vec![5..15]);
        assert_eq!(added(&mut counter, 10..20), vec![15..20]);
        assert_counts(&counter, &[(5..10, 1), (10..15, 2), (15..20, 1)]);

        assert_eq!(removed(&mut counter, 8..12), vec![8..10]);
        assert_counts(
            &counter,
            &[
                (5..8, 1),
                (8..10, 0),
                (10..12, 1),
                (12..15, 2),
                (15..20, 1),
            ],
        );
    }
}