use alloc::boxed::Box;
use core::{cmp::Ordering, fmt::Debug, ops::Range};
use crate::interval::Interval;
/// Outcome of a successful `Node::remove` call.
#[derive(Debug)]
pub(super) enum RemoveResult<T> {
    /// The matching entry was removed and its value is returned. The node
    /// `remove` was called on is still present, although its contents may
    /// have been replaced (by a successor or by its only child).
    Removed(T),
    /// The node `remove` was called on is itself the matching entry and is a
    /// leaf; the caller must detach it from its parent (as `remove_recurse`
    /// does) to complete the removal.
    ParentUnlink,
}
/// A node of an AVL-balanced interval tree, keyed by an `Interval<R>` and
/// holding a value of type `V`.
#[derive(Debug, Clone)]
pub(crate) struct Node<R, V> {
    /// Subtree of intervals that order before `interval`.
    left: Option<Box<Node<R, V>>>,
    /// Subtree of intervals that order after `interval`.
    right: Option<Box<Node<R, V>>>,
    /// Cached height: 0 for a leaf, otherwise one more than the taller child.
    height: u8,
    /// Cached largest interval end point anywhere in this subtree (this node
    /// included); used to skip subtrees that cannot contain a match.
    subtree_max: R,
    /// The key of this node.
    interval: Interval<R>,
    /// The value stored under `interval`.
    value: V,
}
impl<R, V> Node<R, V> {
    /// Construct a leaf node for `interval` holding `value`.
    ///
    /// The cached subtree maximum starts as the interval's own end point and
    /// the cached height as 0.
    pub(crate) fn new(interval: Interval<R>, value: V) -> Self
    where
        R: Clone,
    {
        Self {
            subtree_max: interval.end().clone(),
            interval,
            value,
            left: None,
            right: None,
            height: 0,
        }
    }

    /// Insert `value` under `interval` into the subtree rooted at `self`,
    /// rebalancing each node on the path back up.
    ///
    /// Returns the previous value if an equal interval was already present
    /// (only the value is swapped; the tree shape is untouched), or `None` if
    /// a new node was added.
    pub(crate) fn insert(self: &mut Box<Self>, interval: Interval<R>, value: V) -> Option<V>
    where
        R: Ord + Clone,
    {
        let child = match interval.cmp(&self.interval) {
            Ordering::Less => &mut self.left,
            Ordering::Equal => {
                return Some(core::mem::replace(&mut self.value, value));
            }
            Ordering::Greater => &mut self.right,
        };

        let inserted = match child {
            Some(v) => v.insert(interval, value),
            None => {
                // A new leaf replaces a missing child. `height` reports 0 for
                // both, so `balance(self)` is unchanged and no rotation is
                // needed; only the cached height and subtree max change.
                *child = Some(Box::new(Self::new(interval, value)));
                update_height(self);
                update_subtree_max(self);
                return None;
            }
        };

        // An existing entry had its value replaced: no node was added, so no
        // cached height or subtree maximum changed.
        if inserted.is_some() {
            return inserted;
        }

        update_height(self);
        match (balance(self), self.left(), self.right()) {
            // Left-left: a single right rotation restores balance.
            (2, Some(l), _) if balance(l) >= 0 => {
                rotate_right(self);
            }
            // Left-right: straighten the left child, then rotate right.
            (2, Some(_), _) => {
                rotate_left(self.left_mut().unwrap());
                rotate_right(self);
            }
            // Right-right: a single left rotation restores balance. Mirrors
            // the left-heavy case above (and `rebalance_after_remove`): a
            // right child with balance 0 must take the single rotation, as a
            // double rotation would leave an unbalanced node behind.
            (-2, _, Some(r)) if balance(r) <= 0 => {
                rotate_left(self);
            }
            // Right-left: straighten the right child, then rotate left.
            (-2, _, Some(_)) => {
                rotate_right(self.right_mut().unwrap());
                rotate_left(self);
            }
            (-1..=1, _, _) => {}
            _ => unreachable!(),
        };
        update_subtree_max(self);
        debug_assert!(balance(self).abs() <= 1);
        None
    }

    /// Remove the entry whose interval equals `range` from the subtree rooted
    /// at `self`.
    ///
    /// Returns:
    /// - `None` if no entry matches `range`;
    /// - `Some(RemoveResult::Removed(v))` if the entry was removed, either from
    ///   a descendant or by replacing `self` with its in-order successor or its
    ///   only (left) child;
    /// - `Some(RemoveResult::ParentUnlink)` if `self` is the matching entry and
    ///   is a leaf; the caller must then detach `self`.
    ///
    /// `self`'s own cached height and subtree maximum are not refreshed here;
    /// the caller is expected to rebalance `self` afterwards (as
    /// `remove_recurse` does).
    pub(super) fn remove(self: &mut Box<Self>, range: &Range<R>) -> Option<RemoveResult<V>>
    where
        R: Ord + Clone + Debug,
    {
        match self.interval.partial_cmp(range).unwrap() {
            Ordering::Greater => return remove_recurse(&mut self.left, range),
            Ordering::Less => return remove_recurse(&mut self.right, range),
            Ordering::Equal => {
                debug_assert_eq!(self.interval, *range);
            }
        };

        // `self` is the entry to remove. Prefer replacing it with its in-order
        // successor (the minimum of the right subtree), then with its left
        // child, and finally ask the parent to unlink this leaf.
        let old = if let Some(mut right) = self.right.take() {
            debug_assert_ne!(self.height, 0);
            match extract_subtree_min(&mut right) {
                // The successor sits deeper in the right subtree: detach it and
                // let it adopt both of `self`'s subtrees.
                Some(mut min) => {
                    debug_assert!(min.left.is_none());
                    debug_assert!(min.right.is_none());
                    min.left = self.left.take();
                    min.right = Some(right);
                    core::mem::replace(self, min)
                }
                // The right child has no left child, so it is the successor
                // itself: it keeps its right subtree and adopts `self.left`.
                None => {
                    debug_assert!(right.left.is_none());
                    right.left = self.left.take();
                    core::mem::replace(self, right)
                }
            }
        } else if let Some(left) = self.left.take() {
            debug_assert!(self.right.is_none());
            debug_assert_ne!(self.height, 0);
            core::mem::replace(self, left)
        } else {
            debug_assert!(self.left.is_none());
            debug_assert!(self.right.is_none());
            debug_assert_eq!(self.height, 0);
            return Some(RemoveResult::ParentUnlink);
        };

        debug_assert!(old.right.is_none());
        debug_assert!(old.left.is_none());
        debug_assert_eq!(old.interval, *range);
        debug_assert_ne!(self.interval, *range);
        Some(RemoveResult::Removed(old.value))
    }

    /// Return the value stored under an interval exactly equal to `range`.
    ///
    /// Child subtrees whose cached maximum end point is below `range.end`
    /// cannot contain a match and are not descended into.
    pub(crate) fn get(&self, range: &Range<R>) -> Option<&V>
    where
        R: Ord + Eq,
    {
        let node = match self.interval.partial_cmp(range).unwrap() {
            Ordering::Greater => self.left(),
            Ordering::Equal => return Some(&self.value),
            Ordering::Less => self.right(),
        }?;
        if *node.subtree_max() < range.end {
            return None;
        }
        node.get(range)
    }

    /// Mutable counterpart of [`Node::get`], with the same pruning rule.
    pub(crate) fn get_mut(&mut self, range: &Range<R>) -> Option<&mut V>
    where
        R: Ord + Eq,
    {
        let node = match self.interval.partial_cmp(range).unwrap() {
            Ordering::Greater => self.left_mut(),
            Ordering::Equal => return Some(&mut self.value),
            Ordering::Less => self.right_mut(),
        }?;
        if *node.subtree_max() < range.end {
            return None;
        }
        node.get_mut(range)
    }

    /// The value stored in this node.
    pub(crate) fn value(&self) -> &V {
        &self.value
    }

    /// The interval key of this node.
    pub(crate) fn interval(&self) -> &Interval<R> {
        &self.interval
    }

    /// Cached largest interval end point in this subtree.
    pub(crate) fn subtree_max(&self) -> &R {
        &self.subtree_max
    }

    /// Cached height of this node (0 for a leaf).
    pub(crate) fn height(&self) -> u8 {
        self.height
    }

    /// Shared reference to the left child, if any.
    pub(crate) fn left(&self) -> Option<&Self> {
        self.left.as_deref()
    }

    /// Mutable reference to the boxed left child, if any.
    pub(crate) fn left_mut(&mut self) -> Option<&mut Box<Self>> {
        self.left.as_mut()
    }

    /// Detach and return the left child, leaving `None` in its place. Cached
    /// height and subtree maximum are not updated.
    pub(crate) fn take_left(&mut self) -> Option<Box<Self>> {
        self.left.take()
    }

    /// Shared reference to the right child, if any.
    pub(crate) fn right(&self) -> Option<&Self> {
        self.right.as_deref()
    }

    /// Mutable reference to the boxed right child, if any.
    pub(crate) fn right_mut(&mut self) -> Option<&mut Box<Self>> {
        self.right.as_mut()
    }

    /// Detach and return the right child, leaving `None` in its place. Cached
    /// height and subtree maximum are not updated.
    pub(crate) fn take_right(&mut self) -> Option<Box<Self>> {
        self.right.take()
    }

    /// Consume the node, returning its interval as a `Range` and its value.
    /// Any children still attached are dropped along with the node.
    pub(crate) fn into_tuple(self) -> (Range<R>, V) {
        (self.interval.into_range(), self.value)
    }
}
/// Cached height of an optional child; an absent child counts as 0.
///
/// NOTE(review): a leaf's cached height is also 0 (see `Node::new`), so
/// `balance` treats a missing child and a leaf child alike.
/// `test_extract_subtree_min` depends on this convention — confirm it is
/// intended before changing it.
fn height<R, V>(n: Option<&Node<R, V>>) -> u8 {
    match n {
        Some(node) => node.height(),
        None => 0,
    }
}
/// Recompute `n`'s cached height from its children's cached heights.
///
/// The result is 0 when `n` has no children, otherwise one more than the
/// taller child.
fn update_height<R, V>(n: &mut Node<R, V>) {
    let via_left = n.left().map(|child| child.height() + 1);
    let via_right = n.right().map(|child| child.height() + 1);
    n.height = match (via_left, via_right) {
        (Some(l), Some(r)) => l.max(r),
        (Some(only), None) | (None, Some(only)) => only,
        (None, None) => 0,
    };
}
/// Recompute `n`'s cached subtree maximum.
///
/// The result is the largest of `n`'s own interval end and its children's
/// cached subtree maxima.
fn update_subtree_max<R, V>(n: &mut Node<R, V>)
where
    R: Ord + Clone,
{
    // Start from this node's own end point; a child's maximum only wins when
    // it is strictly larger.
    let mut highest = n.interval().end();
    for child in n.left().into_iter().chain(n.right()) {
        let candidate = child.subtree_max();
        if candidate > highest {
            highest = candidate;
        }
    }
    n.subtree_max = highest.clone();
}
/// Balance factor of `n`.
///
/// This is the left child's height minus the right child's, each as reported
/// by `height`. Positive values mean left-heavy, negative values right-heavy.
fn balance<R, V>(n: &Node<R, V>) -> i8 {
    let left_side = i16::from(height(n.left()));
    let right_side = i16::from(height(n.right()));
    (left_side - right_side) as i8
}
/// Rotate the subtree rooted at `x` to the left.
///
/// `x`'s right child (the pivot) becomes the new subtree root. The old root
/// becomes the pivot's left child and adopts the pivot's former left subtree
/// as its own right subtree. Cached height and subtree maximum are recomputed
/// for both moved nodes, lower one first.
///
/// # Panics
/// Panics if `x` has no right child.
fn rotate_left<R, V>(x: &mut Box<Node<R, V>>)
where
    R: Ord + Clone,
{
    let pivot = x.right.take().unwrap();
    let mut demoted = core::mem::replace(x, pivot);
    demoted.right = x.left.take();
    update_height(&mut demoted);
    update_subtree_max(&mut demoted);
    x.left = Some(demoted);
    update_height(x);
    update_subtree_max(x);
}
/// Rotate the subtree rooted at `y` to the right.
///
/// `y`'s left child (the pivot) becomes the new subtree root. The old root
/// becomes the pivot's right child and adopts the pivot's former right subtree
/// as its own left subtree. Cached height and subtree maximum are recomputed
/// for both moved nodes, lower one first.
///
/// # Panics
/// Panics if `y` has no left child.
fn rotate_right<R, V>(y: &mut Box<Node<R, V>>)
where
    R: Ord + Clone,
{
    let pivot = y.left.take().unwrap();
    let mut demoted = core::mem::replace(y, pivot);
    demoted.left = y.right.take();
    update_height(&mut demoted);
    update_subtree_max(&mut demoted);
    y.right = Some(demoted);
    update_height(y);
    update_subtree_max(y);
}
/// Detach and return the minimum node of `root`'s left subtree.
///
/// Every node on the path back up to `root` is rebalanced. If the minimum had
/// a right subtree, that subtree takes the minimum's former place, so the
/// returned node has no children.
///
/// Returns `None`, leaving the tree untouched, when `root` has no left child.
/// In that case `root` itself is the minimum, and the caller must handle it.
fn extract_subtree_min<R, V>(root: &mut Box<Node<R, V>>) -> Option<Box<Node<R, V>>>
where
    R: Ord + Clone,
{
    let left = root.left_mut()?;
    let min = match extract_subtree_min(left) {
        Some(found) => Some(found),
        None => {
            // `left` has no left child, so it is the minimum: splice its right
            // subtree into its slot and hand `left` back.
            let orphan = left.right.take();
            core::mem::replace(&mut root.left, orphan)
        }
    };
    rebalance_after_remove(root);
    debug_assert!(balance(root).abs() <= 1);
    min
}
/// Remove the entry equal to `interval` from the subtree held in `node`.
///
/// If the subtree root turns out to be a matching leaf, the slot is cleared.
/// Returns `Some(RemoveResult::Removed(value))` on success and `None` if the
/// slot is empty or no entry matches; `ParentUnlink` is never returned from
/// here.
pub(super) fn remove_recurse<R, V>(
    node: &mut Option<Box<Node<R, V>>>,
    interval: &Range<R>,
) -> Option<RemoveResult<V>>
where
    R: Ord + Clone + Debug,
{
    let subtree = node.as_mut()?;
    // Every end point in this subtree is below `interval.end`, so no entry
    // here can equal `interval`.
    if *subtree.subtree_max() < interval.end {
        return None;
    }
    let outcome = subtree.remove(interval)?;
    rebalance_after_remove(subtree);

    let value = match outcome {
        RemoveResult::Removed(value) => value,
        RemoveResult::ParentUnlink => {
            // The subtree root was the matching leaf: drop it from the slot.
            let leaf = node.take().unwrap();
            debug_assert_eq!(leaf.interval, *interval);
            leaf.value
        }
    };
    Some(RemoveResult::Removed(value))
}
/// Refresh `v`'s cached height, restore balance at `v`, and refresh its
/// cached subtree maximum.
///
/// Balance is restored with a single rotation when the heavy child leans the
/// same way (or is level), and a double rotation when it leans the other way.
fn rebalance_after_remove<R, V>(v: &mut Box<Node<R, V>>)
where
    R: Ord + Clone,
{
    update_height(v);
    let skew = balance(v);
    if skew >= 2 {
        // Left-heavy. A left child leaning right needs straightening first.
        let child_skew = v.left().map(balance).unwrap_or_default();
        if child_skew < 0 {
            if let Some(child) = v.left_mut() {
                rotate_left(child);
            }
        }
        rotate_right(v);
    } else if skew <= -2 {
        // Right-heavy. A right child leaning left needs straightening first.
        let child_skew = v.right().map(balance).unwrap_or_default();
        if child_skew > 0 {
            if let Some(child) = v.right_mut() {
                rotate_right(child);
            }
        }
        rotate_left(v);
    }
    update_subtree_max(v);
    debug_assert!(balance(v).abs() <= 1);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Attach a new leaf as the left child of `n` and return it.
    ///
    /// `n` must not already have a left child. The cached height and subtree
    /// maximum of `n` are NOT updated.
    fn add_left<R, V>(n: &mut Node<R, V>, interval: impl Into<Interval<R>>, v: V) -> &mut Node<R, V>
    where
        R: Clone,
    {
        assert!(n.left.is_none());
        n.left = Some(Box::new(Node::new(interval.into(), v)));
        n.left_mut().unwrap()
    }

    /// Attach a new leaf as the right child of `n` and return it.
    ///
    /// `n` must not already have a right child. The cached height and subtree
    /// maximum of `n` are NOT updated.
    fn add_right<R, V>(
        n: &mut Node<R, V>,
        interval: impl Into<Interval<R>>,
        v: V,
    ) -> &mut Node<R, V>
    where
        R: Clone,
    {
        assert!(n.right.is_none());
        n.right = Some(Box::new(Node::new(interval.into(), v)));
        n.right.as_mut().unwrap()
    }

    #[test]
    fn test_rotate_left() {
        let mut t = Node::new(Interval::from(2..2), 2);
        add_left(&mut t, 1..1, 1);
        let v = add_right(&mut t, 4..4, 4);
        add_left(v, 3..3, 3);
        let v = add_right(v, 6..6, 6);
        add_left(v, 5..5, 5);
        add_right(v, 7..7, 7);
        let mut t = Box::new(t);
        rotate_left(&mut t);
        assert_eq!(t.interval, 4..4);
        {
            let left_root = t.left().unwrap();
            assert_eq!(left_root.value, 2);
            let left = left_root.left().unwrap();
            assert_eq!(left.value, 1);
            let right = left_root.right().unwrap();
            assert_eq!(right.value, 3);
        }
        {
            let right_root = t.right().unwrap();
            assert_eq!(right_root.value, 6);
            let left = right_root.left().unwrap();
            assert_eq!(left.value, 5);
            let right = right_root.right().unwrap();
            assert_eq!(right.value, 7);
        }
    }

    #[test]
    fn test_rotate_right() {
        let mut t = Node::new(Interval::from(6..6), 6);
        add_right(&mut t, 7..7, 7);
        let v = add_left(&mut t, 4..4, 4);
        add_right(v, 5..5, 5);
        let v = add_left(v, 2..2, 2);
        add_right(v, 3..3, 3);
        add_left(v, 1..1, 1);
        let mut t = Box::new(t);
        rotate_right(&mut t);
        assert_eq!(t.interval, 4..4);
        {
            let left_root = t.left().unwrap();
            assert_eq!(left_root.value, 2);
            let left = left_root.left().unwrap();
            assert_eq!(left.value, 1);
            let right = left_root.right().unwrap();
            assert_eq!(right.value, 3);
        }
        {
            let right_root = t.right().unwrap();
            assert_eq!(right_root.value, 6);
            let left = right_root.left().unwrap();
            assert_eq!(left.value, 5);
            let right = right_root.right().unwrap();
            assert_eq!(right.value, 7);
        }
    }

    #[test]
    fn test_extract_subtree_min() {
        let mut t = Box::new(Node::new(Interval::from(6..6), 6));
        add_right(&mut t, 7..7, 7);
        // Every node's value mirrors its interval start.
        let v = add_left(&mut t, 4..4, 4);
        add_right(v, 5..5, 5);
        let v = add_left(v, 2..2, 2);
        add_right(v, 3..3, 3);
        add_left(v, 1..1, 1);
        for want in [1, 2, 3] {
            let n: Box<Node<_, _>> = extract_subtree_min(&mut t).unwrap();
            assert_eq!(n.value, want);
            assert!(n.right.is_none());
        }
        assert!(extract_subtree_min(&mut t).is_none());
        assert!(extract_subtree_min(&mut t).is_none());
        assert!(t.left.is_none());
        assert_eq!(t.interval, 4..4);
        assert_eq!(t.value, 4);
        let right = t.right().unwrap();
        assert_eq!(right.interval, 6..6);
        assert_eq!(right.left().unwrap().interval, 5..5);
        assert_eq!(right.right().unwrap().interval, 7..7);
    }

    #[test]
    fn test_insert_replace_and_get() {
        let mut t = Box::new(Node::new(Interval::from(5..10), "a"));
        assert_eq!(t.insert(Interval::from(1..3), "b"), None);
        assert_eq!(t.insert(Interval::from(7..8), "c"), None);
        // Re-inserting an equal interval swaps the value and returns the old one.
        assert_eq!(t.insert(Interval::from(1..3), "d"), Some("b"));

        assert_eq!(t.get(&(1..3)), Some(&"d"));
        assert_eq!(t.get(&(7..8)), Some(&"c"));
        assert_eq!(t.get(&(5..10)), Some(&"a"));
        assert_eq!(t.get(&(2..4)), None);
        assert_eq!(*t.subtree_max(), 10);
    }
}