use super::*;
use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
/// Number of intervals buffered by `StackBuildAcc` before their endpoints are
/// sorted into a `StackParts` batch. Each interval contributes two endpoints,
/// so the endpoint buffer is flushed once it holds `2 * BATCH_SIZE` entries.
const BATCH_SIZE: usize = 128;
impl<I> Default for IntCOStack<I>
where
    I: IntCO,
{
    /// Returns an empty stack: no change points, default height statistics,
    /// and a `covered` cache that has not been initialized yet.
    #[inline]
    fn default() -> Self {
        Self {
            height_stats: Default::default(),
            covered: OnceLock::new(),
            change_points: Arc::from([]),
        }
    }
}
/// Whether an [`Endpoint`] opens or closes an interval.
///
/// `StackBuildAcc::push_interval` emits `Enter` at an interval's `start()` and
/// `Leave` at its `end_excl()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EndpointKind {
/// Interval start: counted as +1 to the stack height at its coordinate.
Enter,
/// Exclusive interval end: counted as -1 to the stack height at its coordinate.
Leave,
}
/// One boundary of an interval: where it sits and whether it opens or closes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Endpoint<C> {
/// Coordinate of the boundary (`start()` for `Enter`, `end_excl()` for `Leave`).
at: C,
/// Whether this boundary opens or closes its interval.
kind: EndpointKind,
}
/// Partial or complete stack profile for some set of intervals.
///
/// Produced by `build_parts_from_endpoints` for one batch and combined with
/// `merge_parts`.
#[derive(Debug, Default)]
struct StackParts<C>
where
C: Default,
{
/// Coordinates where the stack height changes, in increasing order, each with
/// the height that holds after it. No entry repeats the previous height.
points: Vec<ChangePoint<C>>,
/// Statistics fed through `HeightStats::observe` while `points` was produced.
height_stats: HeightStats,
}
/// Sorts one batch of endpoints and sweeps it into a [`StackParts`].
///
/// After sorting, all endpoints sharing a coordinate form one contiguous run.
/// Each run moves the running height by `#Enter - #Leave`. Every run's
/// resulting height is passed to `HeightStats::observe`. A `ChangePoint` is
/// emitted only when that height differs from the previous one.
///
/// # Panics
///
/// Panics if the running height would go below zero. This happens when a
/// `Leave` is sorted before its matching `Enter`. In debug builds, also panics
/// if the height does not return to zero after the last endpoint.
fn build_parts_from_endpoints<C>(mut endpoints: Vec<Endpoint<C>>) -> StackParts<C>
where
    C: Default + Copy + Ord,
{
    endpoints.sort_unstable_by(|a, b| a.at.cmp(&b.at));

    let mut points = Vec::with_capacity(endpoints.len());
    let mut height_stats = HeightStats::default();
    let mut height = 0usize;

    let mut remaining = endpoints.as_slice();
    while let Some(first) = remaining.first() {
        let at = first.at;
        // Peel off the run of endpoints that share coordinate `at`.
        let run_len = remaining.iter().take_while(|e| e.at == at).count();
        let (run, rest) = remaining.split_at(run_len);
        remaining = rest;

        let enters = run
            .iter()
            .filter(|e| matches!(e.kind, EndpointKind::Enter))
            .count();
        let leaves = run.len() - enters;

        let next = if leaves > enters {
            height.checked_sub(leaves - enters)
        } else {
            height.checked_add(enters - leaves)
        }
        .expect("valid intervals must never produce a negative stack height");

        height_stats.observe(next);
        if next != height {
            points.push(ChangePoint {
                at,
                height_after: next,
            });
        }
        height = next;
    }

    debug_assert_eq!(
        height, 0,
        "all finite half-open intervals must eventually close"
    );
    StackParts {
        points,
        height_stats,
    }
}
/// Merges two stack profiles into the profile of their combined intervals.
///
/// Both change-point lists are walked together in coordinate order. At each
/// visited coordinate, the latest `height_after` seen on each side is summed.
/// The sum is passed to `HeightStats::observe` and is recorded as a new
/// `ChangePoint` only when it differs from the previous merged height. Height
/// statistics are recomputed during this walk, not combined from the inputs'
/// `height_stats`.
///
/// # Panics
///
/// Panics if the summed height overflows `usize`. In debug builds, also panics
/// if the merged height does not end at zero.
fn merge_parts<C>(lhs: &StackParts<C>, rhs: &StackParts<C>) -> StackParts<C>
where
    C: Default + Copy + Ord,
{
    let mut left = lhs.points.iter().peekable();
    let mut right = rhs.points.iter().peekable();

    let mut points = Vec::with_capacity(lhs.points.len() + rhs.points.len());
    let mut height_stats = HeightStats::default();
    let mut left_height = 0usize;
    let mut right_height = 0usize;
    let mut current = 0usize;

    loop {
        // Next coordinate to visit: the smaller of the two heads. On a tie,
        // `Ord::min` returns the left value.
        let at = match (left.peek(), right.peek()) {
            (Some(l), Some(r)) => l.at.min(r.at),
            (Some(l), None) => l.at,
            (None, Some(r)) => r.at,
            (None, None) => break,
        };
        // Advance every side whose head sits exactly at `at`.
        if let Some(l) = left.next_if(|p| p.at == at) {
            left_height = l.height_after;
        }
        if let Some(r) = right.next_if(|p| p.at == at) {
            right_height = r.height_after;
        }

        let next = left_height
            .checked_add(right_height)
            .expect("stack height overflow");
        height_stats.observe(next);
        if next != current {
            points.push(ChangePoint {
                at,
                height_after: next,
            });
            current = next;
        }
    }

    debug_assert_eq!(current, 0);
    StackParts {
        points,
        height_stats,
    }
}
/// Sequential accumulator that turns a stream of intervals into a [`StackParts`].
///
/// Intervals are buffered as endpoints. Once `2 * BATCH_SIZE` endpoints are
/// pending, they are swept into a `StackParts` batch. Batches are combined
/// through `levels`, which behaves like a binary counter: each slot holds at
/// most one partial result, and a new result is merged upward through the
/// occupied slots until it reaches an empty one.
#[derive(Debug)]
struct StackBuildAcc<C>
where
C: Default + Copy + Ord,
{
/// Unsorted endpoints of intervals pushed since the last flush.
endpoints: Vec<Endpoint<C>>,
/// Binary-counter slots of already-built partial results (`None` = empty slot).
levels: Vec<Option<StackParts<C>>>,
}
impl<C> StackBuildAcc<C>
where
    C: Default + Copy + Ord,
{
    /// Buffered endpoint count that triggers a flush (two endpoints per interval).
    const FLUSH_THRESHOLD: usize = BATCH_SIZE.saturating_mul(2);

    /// Creates an empty accumulator whose endpoint buffer is pre-sized for one batch.
    #[inline]
    fn new() -> Self {
        Self {
            endpoints: Vec::with_capacity(Self::FLUSH_THRESHOLD),
            levels: Vec::new(),
        }
    }

    /// Buffers the `Enter` (at `start()`) and `Leave` (at `end_excl()`)
    /// endpoints of `interval`. Flushes once the buffer reaches
    /// [`Self::FLUSH_THRESHOLD`].
    #[inline]
    fn push_interval<I>(&mut self, interval: I)
    where
        I: IntCO<CoordType = C>,
    {
        self.endpoints.push(Endpoint {
            at: interval.start(),
            kind: EndpointKind::Enter,
        });
        self.endpoints.push(Endpoint {
            at: interval.end_excl(),
            kind: EndpointKind::Leave,
        });
        if self.endpoints.len() >= Self::FLUSH_THRESHOLD {
            self.flush();
        }
    }

    /// Consumes the accumulator and returns the merged profile of every pushed
    /// interval. Returns `StackParts::default()` if nothing was pushed.
    #[inline]
    fn finish(mut self) -> StackParts<C> {
        // Take the tail batch directly rather than via `flush`. `flush` would
        // allocate a fresh `FLUSH_THRESHOLD`-sized buffer that is dropped
        // together with `self` right after. The tail still goes through
        // `push_points`, so the merge structure is unchanged.
        let tail = core::mem::take(&mut self.endpoints);
        if !tail.is_empty() {
            self.push_points(build_parts_from_endpoints(tail));
        }
        self.levels
            .into_iter()
            .flatten()
            .reduce(|lhs, rhs| merge_parts(&lhs, &rhs))
            .unwrap_or_default()
    }

    /// Sweeps the buffered endpoints into a batch and pushes it into `levels`.
    /// The buffer is replaced with a fresh pre-sized one. Does nothing if the
    /// buffer is empty.
    #[inline]
    fn flush(&mut self) {
        if self.endpoints.is_empty() {
            return;
        }
        let endpoints = core::mem::replace(
            &mut self.endpoints,
            Vec::with_capacity(Self::FLUSH_THRESHOLD),
        );
        self.push_points(build_parts_from_endpoints(endpoints));
    }

    /// Inserts `carry` into the binary-counter `levels`. Starting at slot 0,
    /// it merges with each occupied slot (emptying it) until it can be stored
    /// in an empty slot or a newly appended one.
    fn push_points(&mut self, mut carry: StackParts<C>) {
        let mut level = 0usize;
        loop {
            if level == self.levels.len() {
                self.levels.push(Some(carry));
                break;
            }
            match self.levels[level].take() {
                None => {
                    self.levels[level] = Some(carry);
                    break;
                }
                Some(parts) => {
                    carry = merge_parts(&parts, &carry);
                    level += 1;
                }
            }
        }
    }
}
impl<I> FromIterator<I> for IntCOStack<I>
where
    I: IntCO + Copy,
{
    /// Builds the stack sequentially. Intervals are fed into a
    /// `StackBuildAcc` in iteration order, and the finished parts supply the
    /// change points and height statistics. The `covered` cache starts empty.
    #[inline]
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = I>,
    {
        let parts = iter
            .into_iter()
            .fold(StackBuildAcc::new(), |mut acc, interval| {
                acc.push_interval(interval);
                acc
            })
            .finish();
        Self {
            change_points: parts.points.into(),
            covered: OnceLock::new(),
            height_stats: parts.height_stats,
        }
    }
}
impl<I> FromParallelIterator<I> for IntCOStack<I>
where
    I: IntCO + Copy + Send,
{
    /// Builds the stack with rayon. Each split folds its intervals into its
    /// own `StackBuildAcc`. The finished partial profiles are then reduced
    /// pairwise with `merge_parts`, starting from an empty `StackParts`. The
    /// `covered` cache starts empty.
    #[inline]
    fn from_par_iter<T>(par_iter: T) -> Self
    where
        T: IntoParallelIterator<Item = I>,
    {
        // One accumulator per rayon split.
        let per_split = par_iter
            .into_par_iter()
            .fold(StackBuildAcc::new, |mut acc, interval| {
                acc.push_interval(interval);
                acc
            });
        // Finish every accumulator, then combine the partial profiles.
        let merged = per_split
            .map(|acc| acc.finish())
            .reduce(StackParts::default, |lhs, rhs| merge_parts(&lhs, &rhs));
        Self {
            change_points: merged.points.into(),
            covered: OnceLock::new(),
            height_stats: merged.height_stats,
        }
    }
}
// Test-only modules, compiled under `cfg(test)`: a shared support module plus
// (judging by their names) one test module per builder component above.
#[cfg(test)]
pub(crate) mod test_support;
#[cfg(test)]
mod tests_for_build_parts_from_endpoints;
#[cfg(test)]
mod tests_for_merge_parts;
#[cfg(test)]
mod tests_for_stack_build_acc;
#[cfg(test)]
mod tests_for_from_iter_and_from_par_iter;