use std::ops::RangeInclusive;
use ndarray::ArrayView1;
#[cfg(feature = "rayon")]
use rayon::slice::ParallelSliceMut as _;
/// Path-copying segment tree over value ranks that answers
/// "k-th smallest value inside an index range" queries.
///
/// One tree version is recorded per input prefix; a query compares the
/// versions at both ends of the requested index range.
pub struct KthSmallestTree {
    /// Number of input values; the tree covers the rank interval `1..=len`.
    len: u32,
    /// Input values in ascending `f64::total_cmp` order with duplicates
    /// removed; rank `r` (1-based) corresponds to `sorted[r - 1]`.
    sorted: Vec<f64>,
    /// `roots[i]` is the node index of the root of the version that contains
    /// the first `i` input values (`roots[0]` is the empty version).
    roots: Vec<u32>,
    /// Child links of every node of every version; node `0` is the shared
    /// empty node whose children point back at itself.
    siblings: Vec<Node>,
    /// `counts[i]` is the number of inserted values below node `i`
    /// (parallel to `siblings`).
    counts: Vec<u32>,
}
impl KthSmallestTree {
    /// Builds the tree from `values`.
    ///
    /// The values are ranked by `f64::total_cmp`, so `-0.0` and `0.0` (and
    /// NaNs with different payloads) are treated as distinct values. One tree
    /// version is created per input element by copying the root-to-leaf path
    /// of the previous version.
    ///
    /// # Panics
    ///
    /// Panics if `values` holds `u32::MAX - 1` or more elements, or if the
    /// total number of tree nodes no longer fits into a `u32` index.
    pub fn build(values: &ArrayView1<f64>) -> Self {
        assert!(values.len() < u32::MAX as usize - 1, "Input array too big");
        // Cannot truncate: guarded by the assertion above.
        let len = values.len() as u32;

        let mut sorted = values.to_vec();
        #[cfg(feature = "rayon")]
        sorted.par_sort_unstable_by(f64::total_cmp);
        #[cfg(not(feature = "rayon"))]
        sorted.sort_unstable_by(f64::total_cmp);
        // Deduplicate with the same total order that is used for sorting and
        // for the rank lookup below. Plain `dedup()` compares with `==`, which
        // merges `-0.0` with `0.0` and would make the lookup miss one of them.
        sorted.dedup_by(|a, b| a.total_cmp(b).is_eq());

        // Every insertion copies one root-to-leaf path, which has at most
        // `ceil(log2(len)) + 1` nodes; one extra slot is for the empty node.
        let nodes_per_insert = values.len().next_power_of_two().ilog2() as usize + 1;
        let node_estimate = values
            .len()
            .saturating_mul(nodes_per_insert)
            .saturating_add(1);

        let mut this = Self {
            // One root per prefix length `0..=len`.
            roots: Vec::with_capacity(values.len() + 1),
            siblings: Vec::with_capacity(node_estimate),
            counts: Vec::with_capacity(node_estimate),
            len,
            sorted,
        };

        // Node 0 is the empty tree: zero count, both children point to itself.
        this.siblings.push(Node {
            left_index: 0,
            right_index: 0,
        });
        this.counts.push(0);
        this.roots.push(0);

        for value in values.iter() {
            // Position of the first entry that is not smaller than `value`;
            // since every input value is present in `sorted`, that entry is
            // the value itself.
            let position = this
                .sorted
                .partition_point(|candidate| candidate.total_cmp(value).is_lt());
            debug_assert!(
                this.sorted
                    .get(position)
                    .is_some_and(|found| found.total_cmp(value).is_eq()),
                "value {value} missing from the sorted table",
            );
            // Ranks are 1-based; `position < len < u32::MAX - 1`, so this
            // neither truncates nor overflows.
            let rank = position as u32 + 1;

            let previous_root = *this.roots.last().expect("Building root failed");
            let root = this.insert(previous_root, 1..=len, rank);
            this.roots.push(root);
        }
        this
    }

    /// Returns the `kth` smallest value (1-based, by `f64::total_cmp`) among
    /// the input elements whose positions lie in `range`.
    ///
    /// Duplicated input values count once per occurrence.
    ///
    /// # Panics
    ///
    /// Panics if `range.end()` is not a valid input position. In debug builds
    /// it additionally panics when `range` is reversed or when `kth` is not
    /// within `1..=range length`; in release builds such calls return an
    /// unspecified value or panic on an out-of-bounds index.
    pub fn kth(&self, range: RangeInclusive<usize>, mut kth: usize) -> f64 {
        debug_assert!(
            range.start() <= range.end() && *range.end() < self.len as usize,
            "range {range:?} is not a valid index range for {} values",
            self.len,
        );
        debug_assert!(
            kth >= 1 && kth <= range.end() - range.start() + 1,
            "k = {kth} is outside 1..=len of range {range:?}",
        );

        // Version containing positions `0..=end` versus the version containing
        // positions `0..start`; their difference describes exactly `range`.
        let mut current_node = &self.siblings[self.roots[*range.end() + 1] as usize];
        let mut previous_node = &self.siblings[self.roots[*range.start()] as usize];

        // Rank interval still under consideration.
        let mut start = 1_u32;
        let mut end = self.len;
        while start != end {
            // Number of values from `range` whose rank lies in the left half.
            let current_left_count = self.counts[current_node.left_index as usize];
            let previous_left_count = self.counts[previous_node.left_index as usize];
            let left_size =
                (current_left_count as usize).saturating_sub(previous_left_count as usize);

            let mid = start.midpoint(end);
            if kth <= left_size {
                current_node = &self.siblings[current_node.left_index as usize];
                previous_node = &self.siblings[previous_node.left_index as usize];
                end = mid;
            } else {
                current_node = &self.siblings[current_node.right_index as usize];
                previous_node = &self.siblings[previous_node.right_index as usize];
                start = mid + 1;
                // Skip everything that lives in the left half.
                kth -= left_size;
            }
        }
        self.sorted[start as usize - 1]
    }

    /// Creates a new version of the subtree rooted at `current_index` (which
    /// covers the rank interval `range`) with one more occurrence of rank
    /// `update_index`, and returns the index of the new subtree root.
    ///
    /// Only the nodes on the path to the updated leaf are copied; all other
    /// children are shared with the previous version.
    fn insert(&mut self, current_index: u32, range: RangeInclusive<u32>, update_index: u32) -> u32 {
        debug_assert!(update_index >= *range.start(), "{update_index} {range:?}");
        debug_assert!(update_index <= *range.end(), "{update_index} {range:?}");
        let (start, end) = (*range.start(), *range.end());

        // Copy the old node; its count grows by the newly inserted value.
        let mut node = self.siblings[current_index as usize];
        let count = self.counts[current_index as usize] + 1;

        if start != end {
            let mid = start.midpoint(end);
            if update_index <= mid {
                node.left_index = self.insert(node.left_index, start..=mid, update_index);
            } else {
                node.right_index = self.insert(node.right_index, (mid + 1)..=end, update_index);
            }
        }

        // Node indices are stored as `u32`; refuse to wrap around silently.
        assert!(
            self.siblings.len() <= u32::MAX as usize,
            "Tree node count exceeds the u32 index space",
        );
        let index = self.siblings.len() as u32;
        self.siblings.push(node);
        self.counts.push(count);
        index
    }
}
/// Child links of a single tree node.
///
/// Both fields are indices into the tree's node storage; index `0` refers to
/// the shared empty node.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Node {
    /// Node covering the lower half of this node's rank interval.
    left_index: u32,
    /// Node covering the upper half of this node's rank interval.
    right_index: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks full-range queries against a plain sort and spot-checks a
    /// sub-range of an input that contains a duplicated value.
    #[test]
    fn kth() {
        let input = ndarray::aview1(&[3.5, 1.2, 4.8, 2.1, 5.0, 1.2]);
        let tree = KthSmallestTree::build(&input);

        // Over the whole input, the k-th smallest is the k-th sorted element.
        let mut ascending = input.to_vec();
        ascending.sort_by(f64::total_cmp);
        let range = 0..=(input.len() - 1);
        for (position, value) in ascending.into_iter().enumerate() {
            let kth = position + 1;
            assert_eq!(
                tree.kth(range.clone(), kth),
                value,
                "{input:?} {range:?} K-th {kth}",
            );
        }

        // Positions 2..=4 hold 4.8, 2.1 and 5.0.
        for (kth, want) in [(1, 2.1), (2, 4.8), (3, 5.0)] {
            assert_eq!(tree.kth(2..=4, kth), want);
        }
    }
}