use std::cmp::Ordering;
use std::fmt::Debug;
use std::mem;
/// One tree node: a value plus cached height and balance information for the
/// subtree rooted at this node.
#[derive(Debug, Clone, Eq, PartialEq)]
struct Node<T: Ord + Debug + PartialEq + Eq + Clone> {
    value: T,
    // Height of the subtree rooted here: 0 for a leaf; a missing child counts as -1.
    height: i32,
    // Height of the right subtree minus height of the left subtree.
    balance_factor: i8,
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
}

impl<T: Ord + Debug + PartialEq + Eq + Clone> Node<T> {
    /// Creates a childless node holding `value`.
    fn new(value: T) -> Self {
        Node {
            left: None,
            right: None,
            height: 0,
            balance_factor: 0,
            value,
        }
    }

    /// Recomputes `height` and `balance_factor` from the cached heights of the
    /// two direct children. The children's own cached values must already be
    /// current, because they are read rather than recomputed.
    fn update(&mut self) {
        fn height_of<T: Ord + Debug + PartialEq + Eq + Clone>(child: &Option<Box<Node<T>>>) -> i32 {
            match child {
                Some(node) => node.height,
                None => -1,
            }
        }

        let left = height_of(&self.left);
        let right = height_of(&self.right);
        self.height = 1 + left.max(right);
        self.balance_factor = (right - left) as i8;
    }
}
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct AvlTree<T: Ord + Debug + PartialEq + Eq + Clone> {
root: Option<Box<Node<T>>>,
size: usize,
}
impl<T: Ord + Debug + PartialEq + Eq + Clone> AvlTree<T> {
pub fn new() -> Self {
Self {
root: None,
size: 0,
}
}
pub fn height(&self) -> Option<i32> {
self.root.as_ref().map(|node| node.height)
}
pub fn len(&self) -> usize {
self.size
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn contains(&self, value: &T) -> bool {
fn _contains<T: Ord + Debug + Clone>(node: &Option<Box<Node<T>>>, value: &T) -> bool {
node.as_ref().map_or(false, |node| {
match value.cmp(&node.value) {
Ordering::Less => _contains(&node.left, value),
Ordering::Greater => _contains(&node.right, value),
Ordering::Equal => true,
}
})
}
_contains(&self.root, value)
}
pub fn insert(&mut self, value: T) -> bool {
fn _insert<T: Ord + Debug + Clone>(node: &mut Option<Box<Node<T>>>, value: T) -> bool {
let sucess = match node.as_mut() {
None => {
*node = Some(Box::new(Node::new(value)));
return true;
}
Some(node) => {
match value.cmp(&node.value) {
Ordering::Less => _insert(&mut node.left, value),
Ordering::Greater => _insert(&mut node.right, value),
Ordering::Equal => false,
}
}
};
let node = node.as_mut().unwrap();
node.update();
AvlTree::balance(node);
sucess
}
_insert(&mut self.root, value)
}
fn balance(node: &mut Box<Node<T>>) {
match node.balance_factor {
-2 => {
if node.left.as_ref().unwrap().balance_factor < 0 {
Self::rotate_right(node);
} else {
Self::rotate_left(&mut node.left.as_mut().unwrap());
Self::rotate_right(node);
}
}
2 => {
if node.right.as_ref().unwrap().balance_factor > 0 {
Self::rotate_left(node);
} else {
Self::rotate_right(&mut node.right.as_mut().unwrap());
Self::rotate_left(node);
}
}
_ => {}
}
}
fn rotate_left(node: &mut Box<Node<T>>) {
let right_left = node.right.as_mut().unwrap().left.take();
let new_parent = mem::replace(&mut node.right, right_left).unwrap();
let new_left_child = mem::replace(node, new_parent);
node.left = Some(new_left_child);
node.update();
node.left.as_mut().unwrap().update();
}
fn rotate_right(node: &mut Box<Node<T>>) {
let left_right = node.left.as_mut().unwrap().right.take();
let new_parent = mem::replace(&mut node.left, left_right).unwrap();
let new_right_child = mem::replace(node, new_parent);
node.right = Some(new_right_child);
node.update();
node.right.as_mut().unwrap().update();
}
pub fn remove(&mut self, elem: &T) {
fn _remove<T: Ord + Debug + Clone>(
node: Option<Box<Node<T>>>,
elem: &T,
) -> Option<Box<Node<T>>> {
match node {
None => None,
Some(mut node) => {
match elem.cmp(&node.value) {
Ordering::Less => node.left = _remove(node.left, elem),
Ordering::Greater => node.right = _remove(node.right, elem),
Ordering::Equal => {
if node.left.is_none() {
return node.right;
}
else if node.right.is_none() {
return node.left;
}
else {
let left = node.left.as_ref().unwrap();
let right = node.right.as_ref().unwrap();
if left.height >= right.height {
let successor_value = AvlTree::find_max(&left).clone();
node.value = successor_value.clone();
node.left = _remove(node.left, &successor_value);
} else {
let successor_value = AvlTree::find_min(&right).clone();
node.value = successor_value.clone();
node.right = _remove(node.right, &successor_value);
}
}
}
}
node.update();
AvlTree::balance(&mut node);
Some(node)
}
}
}
let root = mem::replace(&mut self.root, None);
self.root = _remove(root, elem);
}
pub fn remove_efficient(&mut self, elem: &T) {
fn _remove<T: Ord + Debug + Clone>(_node: &mut Option<Box<Node<T>>>, elem: &T) {
match _node {
None => {}
Some(node) => {
match elem.cmp(&node.value) {
Ordering::Less => {
_remove(&mut node.left, elem);
}
Ordering::Greater => {
_remove(&mut node.right, elem);
}
Ordering::Equal => {
*_node = match (node.left.take(), node.right.take()) {
(None, None) => None,
(None, Some(right)) => Some(right),
(Some(left), None) => Some(left),
(Some(left), Some(right)) => {
if left.height >= right.height {
let mut x = AvlTree::remove_max(left);
x.right = Some(right);
Some(x)
} else {
let mut x = AvlTree::remove_min(right);
x.left = Some(left);
Some(x)
}
}
};
}
}
let mut node = _node.as_mut().unwrap();
node.update();
AvlTree::balance(&mut node);
}
}
}
_remove(&mut self.root, elem);
}
fn find_min(mut node: &Node<T>) -> &T {
while let Some(next_node) = node.left.as_ref() {
node = &next_node;
}
&node.value
}
fn find_max(mut node: &Node<T>) -> &T {
while let Some(next_node) = node.right.as_ref() {
node = &next_node;
}
&node.value
}
fn remove_min(mut node: Box<Node<T>>) -> Box<Node<T>> {
fn _remove_min<T: Ord + Debug + PartialEq + Eq + Clone>(
node: &mut Node<T>,
) -> Option<Box<Node<T>>> {
if let Some(next_node) = node.left.as_mut() {
let res = _remove_min(next_node);
if res.is_none() {
node.left.take()
} else {
res
}
} else {
None
}
}
_remove_min(&mut node).unwrap_or(node)
}
fn remove_max(mut node: Box<Node<T>>) -> Box<Node<T>> {
fn _remove_max<T: Ord + Debug + PartialEq + Eq + Clone>(
node: &mut Node<T>,
) -> Option<Box<Node<T>>> {
if let Some(next_node) = node.right.as_mut() {
let res = _remove_max(next_node);
if res.is_none() {
node.right.take()
} else {
res
}
} else {
None
}
}
_remove_max(&mut node).unwrap_or(node)
}
pub fn iter(&self) -> AvlIter<T> {
if let Some(trav) = self.root.as_ref() {
AvlIter {
stack: Some(vec![trav]),
trav: Some(trav),
}
} else {
AvlIter {
stack: None,
trav: None,
}
}
}
}
/// In-order (ascending) iterator over the values of a tree of `Node`s.
pub struct AvlIter<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> {
    // Nodes whose value has not been yielded yet; `None` for an empty tree.
    stack: Option<Vec<&'a Node<T>>>,
    // Node whose left spine is descended on the next call; `None` for an empty tree.
    trav: Option<&'a Node<T>>,
}

impl<'a, T: 'a + Ord + Debug + PartialEq + Eq + Clone> Iterator for AvlIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        let (pending, cursor) = match (self.stack.as_mut(), self.trav.as_mut()) {
            (Some(pending), Some(cursor)) => (pending, cursor),
            _ => return None,
        };

        // Walk down the left spine below the cursor, remembering every node.
        while let Some(left) = cursor.left.as_deref() {
            pending.push(left);
            *cursor = left;
        }

        // The top of the stack is the smallest value not yet yielded.
        let current = pending.pop()?;
        if let Some(right) = current.right.as_deref() {
            // Its right subtree comes next; its left spine is walked on the next call.
            pending.push(right);
            *cursor = right;
        }
        Some(&current.value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree by inserting `values` one after another.
    fn tree_from(values: &[i32]) -> AvlTree<i32> {
        let mut tree = AvlTree::new();
        for &value in values {
            tree.insert(value);
        }
        tree
    }

    /// Returns the value stored in a child slot; panics if the slot is empty.
    fn value_in(slot: &Option<Box<Node<i32>>>) -> i32 {
        slot.as_ref().expect("expected a node in this slot").value
    }

    #[test]
    fn test_avl() {
        let mut avl = tree_from(&[2, 5, 7, 10, 15]);

        // Expected shape after the insertions: 5 (2, 10 (7, 15)).
        {
            let root = avl.root.as_ref().unwrap();
            assert_eq!(root.value, 5);
            assert_eq!(value_in(&root.left), 2);
            assert_eq!(value_in(&root.right), 10);
            let right = root.right.as_ref().unwrap();
            assert_eq!(value_in(&right.left), 7);
            assert_eq!(value_in(&right.right), 15);
        }

        // Rotating the root left promotes 10: 10 (5 (2, 7), 15).
        AvlTree::rotate_left(avl.root.as_mut().unwrap());
        {
            let root = avl.root.as_ref().unwrap();
            assert_eq!(root.value, 10);
            assert_eq!(value_in(&root.left), 5);
            assert_eq!(value_in(&root.right), 15);
            let left = root.left.as_ref().unwrap();
            assert_eq!(value_in(&left.left), 2);
            assert_eq!(value_in(&left.right), 7);
        }

        // After removing 5 the expected shape is 10 (2 (-, 7), 15).
        avl.remove_efficient(&5);
        let root = avl.root.as_ref().unwrap();
        assert_eq!(root.value, 10);
        assert_eq!(value_in(&root.left), 2);
        assert_eq!(value_in(&root.right), 15);
        let left = root.left.as_ref().unwrap();
        assert!(left.left.is_none());
        assert_eq!(value_in(&left.right), 7);
    }

    #[test]
    fn test_avl_iter() {
        let avl = tree_from(&[2, 5, 7, 10, 15]);
        let in_order: Vec<i32> = avl.iter().copied().collect();
        assert_eq!(in_order, vec![2, 5, 7, 10, 15]);
    }
}