use std::{
borrow::Borrow,
cmp::Ordering,
fmt::Debug,
hash::Hash,
iter::{
successors,
FromIterator,
},
mem::swap,
ops::Index,
};
/// A sequence container whose elements are stored in a binary tree of
/// `Node`s and addressed by position.
///
/// Each node caches the size and height of its subtree (see `Node`); the
/// index-based operations navigate by the cached sizes.
#[derive(Clone)]
pub struct AvlTree<T> {
// Root node of the tree; `None` when the container holds no elements.
root: Option<Box<Node<T>>>,
}
impl<T> AvlTree<T> {
pub fn new() -> Self {
Self::default()
}
pub fn is_empty(&self) -> bool {
self.root.is_none()
}
pub fn len(&self) -> usize {
len(self.root.as_deref())
}
pub fn push_back(
&mut self,
value: T,
) {
self.append(&mut Self { root: Some(new(value)) })
}
pub fn push_front(
&mut self,
value: T,
) {
let mut swp = Self { root: Some(new(value)) };
swp.append(self);
*self = swp;
}
pub fn pop_back(&mut self) -> Option<T> {
let root = self.root.take()?;
let last_index = root.len - 1;
let (left, center, _right) = split_delete(root, last_index);
self.root = left;
Some(center.value)
}
pub fn pop_front(&mut self) -> Option<T> {
let (_left, center, right) = split_delete(self.root.take()?, 0);
self.root = right;
Some(center.value)
}
pub fn back(&self) -> Option<&T> {
self.get(self.len().checked_sub(1)?)
}
pub fn front(&self) -> Option<&T> {
self.get(0)
}
pub fn back_mut(&mut self) -> Option<&mut T> {
self.get_mut(self.len().checked_sub(1)?)
}
pub fn front_mut(&mut self) -> Option<&mut T> {
self.get_mut(0)
}
pub fn append(
&mut self,
other: &mut Self,
) {
self.root = merge(self.root.take(), other.root.take());
}
pub fn split_off(
&mut self,
index: usize,
) -> Self {
assert!(index <= self.len());
let (left, right) = split(self.root.take(), index);
self.root = left;
Self { root: right }
}
pub fn insert(
&mut self,
index: usize,
value: T,
) {
assert!(index <= self.len());
let other = self.split_off(index);
self.root =
Some(merge_with_root(self.root.take(), new(value), other.root));
}
pub fn remove(
&mut self,
index: usize,
) -> Option<T> {
if index < self.len() {
let (left, center, right) =
split_delete(self.root.take().unwrap(), index);
self.root = merge(left, right);
Some(center.value)
} else {
None
}
}
pub fn get(
&self,
index: usize,
) -> Option<&T> {
if index < self.len() {
Some(&get(self.root.as_ref().unwrap(), index).value)
} else {
None
}
}
pub fn get_mut(
&mut self,
index: usize,
) -> Option<&mut T> {
if index < self.len() {
Some(&mut get_mut(self.root.as_mut().unwrap(), index).value)
} else {
None
}
}
pub fn binary_search_by(
&self,
f: impl FnMut(&T) -> Ordering,
) -> Result<usize, usize> {
binary_search_by(self.root.as_deref(), f)
}
pub fn binary_search_by_key<B: Ord>(
&self,
b: &B,
mut f: impl FnMut(&T) -> B,
) -> Result<usize, usize> {
self.binary_search_by(|x| f(x).cmp(b))
}
pub fn binary_search<Q: Ord>(
&self,
value: &Q,
) -> Result<usize, usize>
where
T: Borrow<Q>,
{
self.binary_search_by(|x| x.borrow().cmp(value))
}
pub fn partition_point(
&self,
mut is_right: impl FnMut(&T) -> bool,
) -> usize {
partition_point(self.root.as_deref(), |node| is_right(&node.value))
}
pub fn lower_bound<Q: Ord>(
&self,
value: &Q,
) -> usize
where
T: Borrow<Q>,
{
partition_point(self.root.as_deref(), |node| {
value <= node.value.borrow()
})
}
pub fn upper_bound<Q: Ord>(
&self,
value: &Q,
) -> usize
where
T: Borrow<Q>,
{
partition_point(self.root.as_deref(), |node| {
value < node.value.borrow()
})
}
pub fn iter(&self) -> Iter<'_, T> {
Iter {
stack: successors(self.root.as_deref(), |current| {
current.left.as_deref()
})
.collect(),
rstack: successors(self.root.as_deref(), |current| {
current.right.as_deref()
})
.collect(),
}
}
}
impl<T> Default for AvlTree<T> {
fn default() -> Self {
Self { root: None }
}
}
impl<T: PartialEq> PartialEq for AvlTree<T> {
    /// Two trees are equal when they yield equal elements in the same order.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.iter();
        let rhs = other.iter();
        lhs.eq(rhs)
    }
}
impl<T: PartialEq, A> PartialEq<[A]> for AvlTree<T>
where
T: PartialEq<A>,
{
fn eq(
&self,
other: &[A],
) -> bool {
self.iter().eq(other)
}
}
// Marker impl: declares tree equality (element-wise, see `PartialEq`) to be a
// full equivalence relation whenever the element type is `Eq`.
impl<T: Eq> Eq for AvlTree<T> {}
impl<T: PartialOrd> PartialOrd for AvlTree<T> {
    /// Compares the two element sequences lexicographically.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs = self.iter();
        let rhs = other.iter();
        lhs.partial_cmp(rhs)
    }
}
impl<T: Ord> Ord for AvlTree<T> {
    /// Compares the two element sequences lexicographically.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = self.iter();
        let rhs = other.iter();
        lhs.cmp(rhs)
    }
}
impl<T: Debug> Debug for AvlTree<T> {
    /// Formats the tree as a list of its elements in iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.iter());
        list.finish()
    }
}
impl<T: Hash> Hash for AvlTree<T> {
    /// Feeds the length and then every element, in iteration order, into
    /// `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Hash the length first so that adjacent trees cannot shift elements
        // between each other without changing the hash (`([1], [2, 3])` vs
        // `([1, 2], [3])`). Equal trees have equal lengths, so the
        // `Eq`/`Hash` agreement is preserved.
        self.len().hash(state);
        for elm in self.iter() {
            elm.hash(state);
        }
    }
}
impl<T> IntoIterator for AvlTree<T> {
    type IntoIter = IntoIter<T>;
    type Item = T;

    /// Consumes the tree and returns an iterator over its elements by value.
    ///
    /// The iterator's stack is seeded with the path from the root to the
    /// leftmost node; each pushed node has its left child detached.
    fn into_iter(self) -> Self::IntoIter {
        let mut stack = Vec::new();
        let mut cursor = self.root;
        while let Some(mut node) = cursor {
            cursor = node.left.take();
            stack.push(node);
        }
        IntoIter { stack }
    }
}
impl<'a, T> IntoIterator for &'a AvlTree<T> {
type IntoIter = Iter<'a, T>;
type Item = &'a T;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T> Index<usize> for AvlTree<T> {
    type Output = T;

    /// Returns a reference to the element at position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`. The message reports both the length
    /// and the offending index.
    fn index(&self, index: usize) -> &Self::Output {
        match self.get(index) {
            Some(value) => value,
            None => panic!(
                "index out of bounds: the len is {} but the index is {}",
                self.len(),
                index
            ),
        }
    }
}
impl<T> FromIterator<T> for AvlTree<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
fn from_slice_of_nodes<T>(
nodes: &mut [Option<Box<Node<T>>>]
) -> Option<Box<Node<T>>> {
if nodes.is_empty() {
None
} else {
let i = nodes.len() / 2;
Some(merge_with_root(
from_slice_of_nodes(&mut nodes[..i]),
nodes[i].take().unwrap(),
from_slice_of_nodes(&mut nodes[i + 1..]),
))
}
}
Self {
root: from_slice_of_nodes(
iter.into_iter()
.map(new)
.map(Some)
.collect::<Vec<_>>()
.as_mut_slice(),
),
}
}
}
/// Borrowing iterator over the elements of an `AvlTree`, usable from both
/// ends.
pub struct Iter<'a, T> {
// Nodes still to be visited from the front; the top of the stack is the next
// node returned by `next`.
stack: Vec<&'a Node<T>>,
// Nodes still to be visited from the back; the top of the stack is the next
// node returned by `next_back`.
rstack: Vec<&'a Node<T>>,
}
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Returns the next element from the front, or `None` once the front and
    /// back cursors have met.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.stack.pop()?;
        // The in-order successor is the leftmost node of the right subtree,
        // so queue the left spine of that subtree.
        self.stack.extend(successors(current.right.as_deref(), |node| {
            node.left.as_deref()
        }));
        // If the node just taken is also the next one due from the back, it
        // was the last remaining element: empty both stacks so that neither
        // end yields anything further.
        let met_back = self
            .rstack
            .last()
            .map_or(true, |&back| std::ptr::eq(current, back));
        if met_back {
            self.stack.clear();
            self.rstack.clear();
        }
        Some(&current.value)
    }
}
impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    /// Returns the next element from the back, or `None` once the front and
    /// back cursors have met.
    fn next_back(&mut self) -> Option<Self::Item> {
        let current = self.rstack.pop()?;
        // The in-order predecessor is the rightmost node of the left subtree,
        // so queue the right spine of that subtree.
        self.rstack.extend(successors(current.left.as_deref(), |node| {
            node.right.as_deref()
        }));
        // If the node just taken is also the next one due from the front, it
        // was the last remaining element: empty both stacks so that neither
        // end yields anything further.
        let met_front = self
            .stack
            .last()
            .map_or(true, |&front| std::ptr::eq(current, front));
        if met_front {
            self.stack.clear();
            self.rstack.clear();
        }
        Some(&current.value)
    }
}
/// Owning iterator over the elements of an `AvlTree`, front to back.
pub struct IntoIter<T> {
// Nodes still to be visited; the top of the stack is the next node returned
// by `next`. Nodes are pushed with their left child already detached.
stack: Vec<Box<Node<T>>>,
}
impl<T> Iterator for IntoIter<T> {
    type Item = T;

    /// Returns the next element by value, or `None` when the stack is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let mut node = self.stack.pop()?;
        // Queue the left spine of the right subtree; its last node is the
        // in-order successor of `node`.
        let mut pending = node.right.take();
        while let Some(mut sub) = pending {
            pending = sub.left.take();
            self.stack.push(sub);
        }
        Some(node.value)
    }
}
/// A tree node holding one element plus cached metadata about the subtree
/// rooted at it.
#[derive(Clone)]
struct Node<T> {
// Left child: the elements that precede `value` in sequence order.
left: Option<Box<Self>>,
// Right child: the elements that follow `value` in sequence order.
right: Option<Box<Self>>,
// The element stored in this node.
value: T,
// Number of nodes in the subtree rooted here, as last computed by `update`.
len: usize,
// Height of the subtree rooted here (1 for a node without children), as last
// computed by `update`.
ht: u8,
}
/// Allocates a node without children that holds `value`; its cached size and
/// height are both 1.
fn new<T>(value: T) -> Box<Node<T>> {
    let leaf = Node {
        value,
        len: 1,
        ht: 1,
        left: None,
        right: None,
    };
    Box::new(leaf)
}
impl<T> Node<T> {
    /// Recomputes the cached `len` and `ht` from the current children.
    ///
    /// The children's own caches are read as they are, so they must already
    /// be up to date.
    fn update(&mut self) {
        let left = self.left.as_deref();
        let right = self.right.as_deref();
        let size = len(left) + 1 + len(right);
        let height = 1 + ht(left).max(ht(right));
        self.len = size;
        self.ht = height;
    }
}
/// Returns the cached node count of `tree`, or 0 for an empty subtree.
fn len<T>(tree: Option<&Node<T>>) -> usize {
    match tree {
        Some(node) => node.len,
        None => 0,
    }
}
/// Returns the cached height of `tree`, or 0 for an empty subtree.
fn ht<T>(tree: Option<&Node<T>>) -> u8 {
    match tree {
        Some(node) => node.ht,
        None => 0,
    }
}
/// Restores the height balance at `node` and refreshes its cached metadata.
///
/// If one child is more than one level taller than the other, that child is
/// promoted to the top by a single or double rotation. Otherwise only the
/// cached size and height of `node` are recomputed.
fn balance<T>(node: &mut Box<Node<T>>) {
    /// Makes the left child of `top` the new top; the old top becomes its
    /// right child and adopts the promoted node's former right subtree.
    fn promote_left<T>(top: &mut Box<Node<T>>) {
        let mut child = top.left.take().unwrap();
        top.left = child.right.take();
        top.update();
        // After the swap `top` is the promoted child and `child` is the old
        // top.
        swap(top, &mut child);
        top.right = Some(child);
        top.update();
    }

    /// Mirror image of `promote_left`: makes the right child the new top.
    fn promote_right<T>(top: &mut Box<Node<T>>) {
        let mut child = top.right.take().unwrap();
        top.right = child.left.take();
        top.update();
        swap(top, &mut child);
        top.left = Some(child);
        top.update();
    }

    let left_ht = ht(node.left.as_deref());
    let right_ht = ht(node.right.as_deref());
    if left_ht > 1 + right_ht {
        // Left-heavy. If the left child leans to the right, straighten it
        // first so that the promotion below is sufficient.
        let left = node.left.as_mut().unwrap();
        if ht(left.left.as_deref()) < ht(left.right.as_deref()) {
            promote_right(left);
        }
        promote_left(node);
    } else if left_ht + 1 < right_ht {
        // Right-heavy: the mirrored case.
        let right = node.right.as_mut().unwrap();
        if ht(right.left.as_deref()) > ht(right.right.as_deref()) {
            promote_left(right);
        }
        promote_right(node);
    } else {
        node.update();
    }
}
fn merge_with_root<T>(
mut left: Option<Box<Node<T>>>,
mut center: Box<Node<T>>,
mut right: Option<Box<Node<T>>>,
) -> Box<Node<T>> {
match ht(left.as_deref()).cmp(&ht(right.as_deref())) {
Ordering::Less => {
let mut root = right.take().unwrap();
root.left = Some(merge_with_root(left, center, root.left.take()));
balance(&mut root);
root
}
Ordering::Equal => {
center.left = left;
center.right = right;
center.update();
center
}
Ordering::Greater => {
let mut root = left.take().unwrap();
root.right =
Some(merge_with_root(root.right.take(), center, right));
balance(&mut root);
root
}
}
}
/// Concatenates two trees: the result holds the elements of `left` followed
/// by the elements of `right`.
fn merge<T>(
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
) -> Option<Box<Node<T>>> {
    if let Some(right_root) = right {
        // Detach the first node of `right` and use it as the joining node.
        let (_, center, rest) = split_delete(right_root, 0);
        Some(merge_with_root(left, center, rest))
    } else {
        left
    }
}
#[allow(clippy::type_complexity)]
fn split_delete<T>(
mut root: Box<Node<T>>,
index: usize,
) -> (Option<Box<Node<T>>>, Box<Node<T>>, Option<Box<Node<T>>>) {
debug_assert!((0..root.len).contains(&index));
let left = root.left.take();
let right = root.right.take();
let lsize = len(left.as_deref());
match lsize.cmp(&index) {
Ordering::Less => {
let mut res = split_delete(right.unwrap(), index - lsize - 1);
res.0 = Some(merge_with_root(left, root, res.0));
res
}
Ordering::Equal => (left, root, right),
Ordering::Greater => {
let mut res = split_delete(left.unwrap(), index);
res.2 = Some(merge_with_root(res.2, root, right));
res
}
}
}
/// Splits `tree` into the elements at positions `..index` and the elements at
/// positions `index..`.
///
/// For a non-empty tree, `index` is expected to be at most the tree's length.
#[allow(clippy::type_complexity)]
fn split<T>(
    tree: Option<Box<Node<T>>>,
    index: usize,
) -> (Option<Box<Node<T>>>, Option<Box<Node<T>>>) {
    match tree {
        None => (None, None),
        // Splitting at the very end keeps the whole tree on the left.
        Some(root) if root.len == index => (Some(root), None),
        Some(root) => {
            // Detach the node at `index`, then put it back at the front of
            // the right part.
            let (before, node, after) = split_delete(root, index);
            (before, Some(merge_with_root(None, node, after)))
        }
    }
}
/// Searches `tree` using the comparator `f`, which reports how a visited
/// element compares to the target.
///
/// Returns `Ok(position)` of the first visited element for which `f` returns
/// `Ordering::Equal`, or `Err(position)` of the empty spot where the descent
/// ended.
fn binary_search_by<T>(
    tree: Option<&Node<T>>,
    mut f: impl FnMut(&T) -> Ordering,
) -> Result<usize, usize> {
    // Number of elements known to lie before the subtree under inspection.
    let mut skipped = 0;
    let mut cursor = tree;
    while let Some(node) = cursor {
        let lsize = len(node.left.as_deref());
        match f(&node.value) {
            Ordering::Less => {
                // The node and its whole left subtree precede the target.
                skipped += lsize + 1;
                cursor = node.right.as_deref();
            }
            Ordering::Equal => return Ok(skipped + lsize),
            Ordering::Greater => cursor = node.left.as_deref(),
        }
    }
    Err(skipped)
}
/// Counts the nodes of `tree` that lie before the partition boundary defined
/// by `is_right`.
///
/// The descent goes left at a node where `is_right` returns `true`;
/// otherwise it counts that node together with its left subtree and goes
/// right.
fn partition_point<T>(
    tree: Option<&Node<T>>,
    mut is_right: impl FnMut(&Node<T>) -> bool,
) -> usize {
    let mut count = 0;
    let mut cursor = tree;
    while let Some(node) = cursor {
        if is_right(node) {
            cursor = node.left.as_deref();
        } else {
            count += len(node.left.as_deref()) + 1;
            cursor = node.right.as_deref();
        }
    }
    count
}
/// Returns the node at position `index` within the subtree rooted at `node`.
///
/// Panics (by unwrapping a missing child) if `index` is not less than the
/// subtree's size.
fn get<T>(
    node: &Node<T>,
    index: usize,
) -> &Node<T> {
    let mut current = node;
    // Position of the target relative to the subtree rooted at `current`.
    let mut remaining = index;
    loop {
        let lsize = len(current.left.as_deref());
        match lsize.cmp(&remaining) {
            Ordering::Less => {
                // Skip the left subtree and `current` itself.
                remaining -= lsize + 1;
                current = current.right.as_deref().unwrap();
            }
            Ordering::Equal => return current,
            Ordering::Greater => current = current.left.as_deref().unwrap(),
        }
    }
}
/// Returns a mutable reference to the node at position `index` within the
/// subtree rooted at `node`.
///
/// Panics (by unwrapping a missing child) if `index` is not less than the
/// subtree's size.
fn get_mut<T>(
    node: &mut Node<T>,
    index: usize,
) -> &mut Node<T> {
    let mut current = node;
    // Position of the target relative to the subtree rooted at `current`.
    let mut remaining = index;
    loop {
        let lsize = len(current.left.as_deref());
        match lsize.cmp(&remaining) {
            Ordering::Less => {
                // Skip the left subtree and `current` itself.
                remaining -= lsize + 1;
                current = current.right.as_deref_mut().unwrap();
            }
            Ordering::Equal => return current,
            Ordering::Greater => {
                current = current.left.as_deref_mut().unwrap();
            }
        }
    }
}