#![allow(unsafe_code)]
use std::cmp::Ordering;
use std::marker::PhantomData;
use std::mem;
use std::ptr::NonNull;
/// A single tree node, heap-allocated as a `Box` (see `Node::into_ptr`) and linked to
/// its neighbours through raw pointers.
struct Node<T> {
    /// The stored element.
    data: T,
    /// Cached height of the subtree rooted here: 1 for a leaf, with a missing child
    /// counting as 0. Refreshed by `update_height`.
    height: i32,
    /// Left child; holds elements the comparator orders before `data`.
    left: Option<NonNull<Node<T>>>,
    /// Right child; holds elements the comparator orders after `data`.
    right: Option<NonNull<Node<T>>>,
    /// Parent node, or `None` for the root.
    parent: Option<NonNull<Node<T>>>,
}
/// A self-balancing (AVL) binary search tree ordered by a [`Comparator`].
///
/// Elements that compare `Equal` are treated as the same key: inserting one replaces
/// the stored element. Nodes are heap allocations linked by raw pointers; the tree owns
/// them and frees them (dropping every element) when the tree itself is dropped.
pub struct BinaryTree<T, C = DefaultCompare>
where
    C: Comparator<T>,
{
    /// Topmost node, or `None` when the tree is empty.
    root: Option<NonNull<Node<T>>>,
    /// Number of elements currently stored.
    size: usize,
    /// Ordering used by every insert, lookup, and removal.
    comparator: C,
    /// Tells the drop checker that the tree owns (and drops) values of type `T`.
    marker: PhantomData<T>,
}

impl<T, C> Drop for BinaryTree<T, C>
where
    C: Comparator<T>,
{
    /// Frees every node and drops every stored element.
    fn drop(&mut self) {
        // Use an explicit work list instead of recursion, so teardown does not consume
        // call-stack depth proportional to the tree height.
        let mut pending: Vec<NonNull<Node<T>>> = self.root.take().into_iter().collect();
        while let Some(node) = pending.pop() {
            // SAFETY: every node is a `Box` allocation reachable from `root` through
            // exactly one child link, so each is reclaimed here exactly once. `root` has
            // been taken, so nothing can reach the freed nodes afterwards.
            let node = unsafe { Box::from_raw(node.as_ptr()) };
            pending.extend(node.left);
            pending.extend(node.right);
        }
        self.size = 0;
    }
}
/// Comparator that orders elements by their [`Ord`] implementation.
///
/// This is the default comparator of `BinaryTree`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DefaultCompare;

/// Strategy used by `BinaryTree` to order its elements.
///
/// The tree descends left on `Less` and right on `Greater`, and treats `Equal` as "same
/// key". Implementations should therefore describe a consistent total order; an
/// inconsistent comparator makes lookups and removals unreliable.
pub trait Comparator<T> {
    /// Returns how `a` orders relative to `b`.
    fn compare(&self, a: &T, b: &T) -> Ordering;
}

impl<T: Ord> Comparator<T> for DefaultCompare {
    fn compare(&self, a: &T, b: &T) -> Ordering {
        a.cmp(b)
    }
}
impl<T> Node<T> {
    /// Creates a detached leaf node (height 1, no links) holding `data`.
    fn new(data: T) -> Self {
        Node {
            data,
            height: 1,
            left: None,
            right: None,
            parent: None,
        }
    }

    /// Moves the node into a `Box` allocation and returns an owning raw pointer to it.
    ///
    /// The caller must eventually reclaim the allocation with `Box::from_raw` exactly once.
    fn into_ptr(self) -> NonNull<Self> {
        // `Box::leak` yields a non-null `&mut`, so no null check or `unwrap` is needed.
        NonNull::from(Box::leak(Box::new(self)))
    }

    /// Returns the cached height of the subtree rooted at this node.
    fn height(&self) -> i32 {
        self.height
    }

    /// Cached height of an optional child subtree; an absent child counts as 0.
    fn child_height(child: Option<NonNull<Self>>) -> i32 {
        // SAFETY: child links only ever point at live nodes of the same tree.
        child.map_or(0, |c| unsafe { (*c.as_ptr()).height() })
    }

    /// Left subtree height minus right subtree height, using the children's cached heights.
    fn balance_factor(&self) -> i32 {
        Self::child_height(self.left) - Self::child_height(self.right)
    }

    /// Recomputes this node's cached height from its children's cached heights.
    fn update_height(&mut self) {
        self.height = 1 + Self::child_height(self.left).max(Self::child_height(self.right));
    }
}
impl<T, C> BinaryTree<T, C>
where
    C: Comparator<T>,
{
    /// Creates an empty tree that orders its elements with `comparator`.
    pub fn new(comparator: C) -> Self {
        BinaryTree {
            root: None,
            size: 0,
            comparator,
            marker: PhantomData,
        }
    }

    /// Creates an empty tree whose comparator is `C::default()`.
    pub fn new_default() -> Self
    where
        C: Default,
    {
        Self::new(C::default())
    }

    /// Returns the number of elements stored in the tree.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` when the tree holds no elements.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Inserts `data`, rebalancing so the tree stays AVL-balanced.
    ///
    /// If an element comparing `Equal` to `data` is already stored, it is replaced by
    /// `data` (the previous value is dropped) and the length is unchanged.
    pub fn insert(&mut self, data: T) {
        // Find the attachment point first, so replacing an existing element never
        // allocates (and immediately frees) a node.
        let mut parent: Option<NonNull<Node<T>>> = None;
        let mut attach_left = false;
        let mut current = self.root;
        while let Some(node) = current {
            let node_ptr = node.as_ptr();
            // SAFETY: pointers reachable from `self.root` refer to live nodes owned by
            // this tree, and `&mut self` rules out any outstanding borrow of them.
            let ordering = self.comparator.compare(&data, unsafe { &(*node_ptr).data });
            match ordering {
                Ordering::Less => {
                    attach_left = true;
                    // SAFETY: as above.
                    current = unsafe { (*node_ptr).left };
                }
                Ordering::Greater => {
                    attach_left = false;
                    // SAFETY: as above.
                    current = unsafe { (*node_ptr).right };
                }
                Ordering::Equal => {
                    // SAFETY: as above; the assignment drops the previously stored value.
                    unsafe { (*node_ptr).data = data };
                    return;
                }
            }
            parent = Some(node);
        }

        let mut leaf = Node::new(data);
        leaf.parent = parent;
        let leaf = leaf.into_ptr();
        match parent {
            None => self.root = Some(leaf),
            // SAFETY: `p` is a live node reached during the descent above.
            Some(p) if attach_left => unsafe { (*p.as_ptr()).left = Some(leaf) },
            // SAFETY: as above.
            Some(p) => unsafe { (*p.as_ptr()).right = Some(leaf) },
        }
        self.size += 1;
        // SAFETY: `parent` is `None` or a live node, and all links are consistent again.
        unsafe { self.rebalance_from(parent) };
    }

    /// Returns a reference to the stored element comparing `Equal` to `data`, if any.
    pub fn find(&self, data: &T) -> Option<&T> {
        // SAFETY: the node is owned by `self`, so the returned borrow (tied to `&self`)
        // cannot outlive it, and mutation requires `&mut self`, which this borrow excludes.
        self.find_node(data)
            .map(|node| unsafe { &(*node.as_ptr()).data })
    }

    /// Removes and returns the element comparing `Equal` to `data`, rebalancing
    /// afterwards. Returns `None` when no such element exists.
    pub fn remove(&mut self, data: &T) -> Option<T> {
        let node = self.find_node(data)?;
        // SAFETY: `node` was just located in this tree.
        Some(unsafe { self.remove_node(node) })
    }

    /// Returns the node whose element compares `Equal` to `data`, if any.
    fn find_node(&self, data: &T) -> Option<NonNull<Node<T>>> {
        let mut current = self.root;
        while let Some(node) = current {
            // SAFETY: nodes reachable from `self.root` are live while `self` is borrowed.
            let node_ref = unsafe { &*node.as_ptr() };
            current = match self.comparator.compare(data, &node_ref.data) {
                Ordering::Less => node_ref.left,
                Ordering::Greater => node_ref.right,
                Ordering::Equal => return Some(node),
            };
        }
        None
    }

    /// Walks from `start` up to the root, refreshing cached heights and performing the
    /// single or double rotation needed wherever a balance factor leaves `-1..=1`.
    ///
    /// # Safety
    /// `start` must be `None` or a live node of this tree, and all parent/child links
    /// must be consistent.
    unsafe fn rebalance_from(&mut self, start: Option<NonNull<Node<T>>>) {
        let mut cursor = start;
        while let Some(node) = cursor {
            // SAFETY: the caller guarantees `node` is live; parent and child links of
            // live nodes point at live nodes.
            unsafe {
                let n = node.as_ptr();
                (*n).update_height();
                let balance = (*n).balance_factor();
                // Root of the subtree that now occupies `node`'s former position.
                let subtree_root = if balance > 1 {
                    let left = (*n).left.expect("left-heavy node has a left child");
                    if (*left.as_ptr()).balance_factor() < 0 {
                        // Left-right case: straighten the left child first.
                        self.rotate_left(left);
                    }
                    self.rotate_right(node)
                } else if balance < -1 {
                    let right = (*n).right.expect("right-heavy node has a right child");
                    if (*right.as_ptr()).balance_factor() > 0 {
                        // Right-left case: straighten the right child first.
                        self.rotate_right(right);
                    }
                    self.rotate_left(node)
                } else {
                    node
                };
                // Rotations already refreshed `subtree_root`'s height, so continue upward
                // from its parent.
                cursor = (*subtree_root.as_ptr()).parent;
            }
        }
    }

    /// Rotates the subtree rooted at `node` to the left and returns the new subtree root
    /// (`node`'s former right child). If `node` has no right child, nothing changes and
    /// `node` is returned. The heights of both nodes involved are refreshed.
    ///
    /// # Safety
    /// `node` must be a live node of this tree with consistent parent/child links.
    unsafe fn rotate_left(&mut self, node: NonNull<Node<T>>) -> NonNull<Node<T>> {
        // SAFETY: the caller guarantees `node` is live; its links point at live nodes.
        unsafe {
            let n = node.as_ptr();
            let pivot = match (*n).right {
                Some(pivot) => pivot,
                None => return node,
            };
            let p = pivot.as_ptr();
            // The pivot's left subtree moves across to become `node`'s right subtree.
            (*n).right = (*p).left;
            if let Some(inner) = (*p).left {
                (*inner.as_ptr()).parent = Some(node);
            }
            // Hook the pivot into `node`'s old position before `node.parent` is overwritten.
            self.replace_child(node, Some(pivot));
            (*p).parent = (*n).parent;
            (*p).left = Some(node);
            (*n).parent = Some(pivot);
            (*n).update_height();
            (*p).update_height();
            pivot
        }
    }

    /// Mirror image of [`Self::rotate_left`]: `node`'s left child becomes the subtree
    /// root, which is returned. If `node` has no left child, nothing changes.
    ///
    /// # Safety
    /// `node` must be a live node of this tree with consistent parent/child links.
    unsafe fn rotate_right(&mut self, node: NonNull<Node<T>>) -> NonNull<Node<T>> {
        // SAFETY: the caller guarantees `node` is live; its links point at live nodes.
        unsafe {
            let n = node.as_ptr();
            let pivot = match (*n).left {
                Some(pivot) => pivot,
                None => return node,
            };
            let p = pivot.as_ptr();
            // The pivot's right subtree moves across to become `node`'s left subtree.
            (*n).left = (*p).right;
            if let Some(inner) = (*p).right {
                (*inner.as_ptr()).parent = Some(node);
            }
            // Hook the pivot into `node`'s old position before `node.parent` is overwritten.
            self.replace_child(node, Some(pivot));
            (*p).parent = (*n).parent;
            (*p).right = Some(node);
            (*n).parent = Some(pivot);
            (*n).update_height();
            (*p).update_height();
            pivot
        }
    }

    /// Unlinks `node`, frees it, rebalances, and returns its element.
    ///
    /// # Safety
    /// `node` must be a live node of this tree.
    unsafe fn remove_node(&mut self, node: NonNull<Node<T>>) -> T {
        // SAFETY: the caller guarantees `node` is live; every node touched below is
        // reachable from it and therefore live as well.
        unsafe {
            let mut target = node;
            if let (Some(_), Some(right)) = ((*node.as_ptr()).left, (*node.as_ptr()).right) {
                // Two children: move the in-order successor's element into `node` and
                // unlink the successor instead. Being leftmost in the right subtree, it
                // has no left child.
                let successor = Self::leftmost(right);
                mem::swap(&mut (*node.as_ptr()).data, &mut (*successor.as_ptr()).data);
                target = successor;
            }
            // `target` now has at most one child, which takes its place.
            let child = (*target.as_ptr()).left.or((*target.as_ptr()).right);
            let parent = (*target.as_ptr()).parent;
            self.replace_child(target, child);
            if let Some(child) = child {
                (*child.as_ptr()).parent = parent;
            }
            let boxed = Box::from_raw(target.as_ptr());
            self.size -= 1;
            self.rebalance_from(parent);
            boxed.data
        }
    }

    /// Makes `new_child` take `node`'s place under `node`'s parent (or as the root).
    ///
    /// Only the parent's child pointer (or `self.root`) is updated; `new_child`'s own
    /// `parent` field is left for the caller to set.
    ///
    /// # Safety
    /// `node` must be a live node of this tree whose `parent` link is still accurate.
    unsafe fn replace_child(
        &mut self,
        node: NonNull<Node<T>>,
        new_child: Option<NonNull<Node<T>>>,
    ) {
        // SAFETY: the caller guarantees `node` and therefore its parent are live.
        unsafe {
            match (*node.as_ptr()).parent {
                None => self.root = new_child,
                Some(parent) => {
                    let parent = parent.as_ptr();
                    if (*parent).left == Some(node) {
                        (*parent).left = new_child;
                    } else {
                        (*parent).right = new_child;
                    }
                }
            }
        }
    }

    /// Returns the leftmost (minimum) node of the subtree rooted at `start`.
    ///
    /// # Safety
    /// `start` must be a live node.
    unsafe fn leftmost(start: NonNull<Node<T>>) -> NonNull<Node<T>> {
        let mut current = start;
        // SAFETY: child links of live nodes point at live nodes.
        while let Some(left) = unsafe { (*current.as_ptr()).left } {
            current = left;
        }
        current
    }

    /// Iterates over the elements in ascending comparator order.
    pub fn iter(&self) -> InorderIter<'_, T> {
        InorderIter::new(self)
    }

    /// Iterates mutably over the elements in ascending comparator order.
    ///
    /// Mutating an element so that it compares differently relative to the others breaks
    /// the search-tree ordering, and later lookups may then miss it.
    pub fn iter_mut(&mut self) -> InorderIterMut<'_, T> {
        InorderIterMut::new(self)
    }

    /// Iterates in pre-order: each node, then its left subtree, then its right subtree.
    pub fn iter_preorder(&self) -> PreorderIter<'_, T> {
        PreorderIter::new(self)
    }

    /// Iterates in post-order: left subtree, right subtree, then the node itself.
    pub fn iter_postorder(&self) -> PostorderIter<'_, T> {
        PostorderIter::new(self)
    }
}
/// Borrowing in-order iterator over a `BinaryTree`, yielding elements in ascending
/// comparator order.
pub struct InorderIter<'a, T> {
    /// Nodes whose left subtree has been scheduled but which have not been yielded yet;
    /// the top of the stack is the next element.
    pending: Vec<NonNull<Node<T>>>,
    marker: PhantomData<&'a T>,
}

impl<'a, T> InorderIter<'a, T> {
    fn new<C: Comparator<T>>(tree: &'a BinaryTree<T, C>) -> Self {
        let mut iter = InorderIter {
            pending: Vec::new(),
            marker: PhantomData,
        };
        iter.descend_left(tree.root);
        iter
    }

    /// Pushes `start` and every node along its chain of left children.
    fn descend_left(&mut self, start: Option<NonNull<Node<T>>>) {
        let mut cursor = start;
        while let Some(node) = cursor {
            self.pending.push(node);
            // SAFETY: the tree is borrowed for `'a`, so its nodes stay alive.
            cursor = unsafe { (*node.as_ptr()).left };
        }
    }
}

impl<'a, T> Iterator for InorderIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        let node = self.pending.pop()?;
        // SAFETY: the tree is borrowed for `'a`, so the node and its element outlive the
        // returned reference, and no `&mut` to them can exist meanwhile.
        unsafe {
            // Everything smaller than the right subtree's contents has now been yielded.
            self.descend_left((*node.as_ptr()).right);
            Some(&(*node.as_ptr()).data)
        }
    }
}
/// Mutable in-order iterator over a `BinaryTree`, yielding `&mut T` in ascending
/// comparator order.
///
/// The ordering is not re-checked: changing an element so that it compares differently
/// relative to its neighbours breaks the search-tree invariant.
pub struct InorderIterMut<'a, T> {
    /// Nodes whose left subtree has been scheduled but which have not been yielded yet.
    pending: Vec<NonNull<Node<T>>>,
    marker: PhantomData<&'a mut T>,
}

impl<'a, T> InorderIterMut<'a, T> {
    fn new<C: Comparator<T>>(tree: &'a mut BinaryTree<T, C>) -> Self {
        let mut iter = InorderIterMut {
            pending: Vec::new(),
            marker: PhantomData,
        };
        iter.descend_left(tree.root);
        iter
    }

    /// Pushes `start` and every node along its chain of left children.
    fn descend_left(&mut self, start: Option<NonNull<Node<T>>>) {
        let mut cursor = start;
        while let Some(node) = cursor {
            self.pending.push(node);
            // SAFETY: the tree is exclusively borrowed for `'a`, so its nodes stay alive.
            cursor = unsafe { (*node.as_ptr()).left };
        }
    }
}

impl<'a, T> Iterator for InorderIterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        let node = self.pending.pop()?;
        // SAFETY: each node is popped exactly once, so the `&mut` handed out never aliases
        // another; reading the `right` link touches a different field than `data`.
        unsafe {
            self.descend_left((*node.as_ptr()).right);
            Some(&mut (*node.as_ptr()).data)
        }
    }
}
/// Borrowing pre-order iterator over a `BinaryTree`: yields each node's element, then
/// its left subtree, then its right subtree.
pub struct PreorderIter<'a, T> {
    /// Subtree roots still to visit; the top of the stack is visited next.
    stack: Vec<NonNull<Node<T>>>,
    marker: PhantomData<&'a T>,
}

impl<'a, T> PreorderIter<'a, T> {
    fn new<C: Comparator<T>>(tree: &'a BinaryTree<T, C>) -> Self {
        PreorderIter {
            stack: tree.root.into_iter().collect(),
            marker: PhantomData,
        }
    }
}

impl<'a, T> Iterator for PreorderIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        let node = self.stack.pop()?;
        // SAFETY: the tree is borrowed for `'a`, so every node it links to stays alive and
        // unmutated for that long.
        let node = unsafe { &*node.as_ptr() };
        // Push right before left so the left subtree is popped (visited) first.
        self.stack.extend(node.right);
        self.stack.extend(node.left);
        Some(&node.data)
    }
}
/// Borrowing post-order iterator over a `BinaryTree`: yields a node's left subtree, then
/// its right subtree, then the node itself.
pub struct PostorderIter<'a, T> {
    /// Pending nodes; the flag becomes `true` once the node's right subtree is scheduled.
    stack: Vec<(NonNull<Node<T>>, bool)>,
    marker: PhantomData<&'a T>,
}

impl<'a, T> PostorderIter<'a, T> {
    fn new<C: Comparator<T>>(tree: &'a BinaryTree<T, C>) -> Self {
        let mut iter = PostorderIter {
            stack: Vec::new(),
            marker: PhantomData,
        };
        if let Some(root) = tree.root {
            iter.push_left_path(root);
        }
        iter
    }

    /// Pushes `start` and every node along its chain of left children, all unexpanded.
    fn push_left_path(&mut self, start: NonNull<Node<T>>) {
        let mut cursor = Some(start);
        while let Some(node) = cursor {
            self.stack.push((node, false));
            // SAFETY: the tree is borrowed for `'a`, so its nodes stay alive.
            cursor = unsafe { (*node.as_ptr()).left };
        }
    }
}

impl<'a, T> Iterator for PostorderIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let top = self.stack.last_mut()?;
            let node = top.0;
            if top.1 {
                // Both subtrees are finished: emit the node itself.
                self.stack.pop();
                // SAFETY: the tree is borrowed for `'a`, so the element outlives the borrow.
                return Some(unsafe { &(*node.as_ptr()).data });
            }
            // Left subtree is finished; schedule the right subtree before emitting.
            top.1 = true;
            // SAFETY: as above.
            if let Some(right) = unsafe { (*node.as_ptr()).right } {
                self.push_left_path(right);
            }
        }
    }
}
impl<'a, T, C> IntoIterator for &'a BinaryTree<T, C>
where
C: Comparator<T>,
{
type Item = &'a T;
type IntoIter = InorderIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T, C> IntoIterator for &'a mut BinaryTree<T, C>
where
C: Comparator<T>,
{
type Item = &'a mut T;
type IntoIter = InorderIterMut<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recursively checks the structural invariants below `node` and returns the subtree
    /// height: parent links match, cached heights are accurate, and every balance factor
    /// is within `-1..=1`.
    fn check_invariants<T>(
        node: Option<NonNull<Node<T>>>,
        parent: Option<NonNull<Node<T>>>,
    ) -> i32 {
        match node {
            None => 0,
            Some(n) => {
                // SAFETY: test-only traversal of a live tree that is not being mutated.
                let node_ref = unsafe { &*n.as_ptr() };
                assert_eq!(node_ref.parent, parent, "parent link out of sync");
                let left = check_invariants(node_ref.left, Some(n));
                let right = check_invariants(node_ref.right, Some(n));
                assert!((left - right).abs() <= 1, "AVL balance violated");
                assert_eq!(node_ref.height, 1 + left.max(right), "stale cached height");
                node_ref.height
            }
        }
    }

    #[test]
    fn test_insert_and_find() {
        let mut tree = BinaryTree::<i32>::new_default();
        assert!(tree.is_empty());
        tree.insert(5);
        tree.insert(3);
        tree.insert(7);
        tree.insert(1);
        tree.insert(9);
        assert_eq!(tree.len(), 5);
        assert!(!tree.is_empty());
        assert!(tree.find(&5).is_some());
        assert!(tree.find(&3).is_some());
        assert!(tree.find(&7).is_some());
        assert!(tree.find(&1).is_some());
        assert!(tree.find(&9).is_some());
        assert!(tree.find(&4).is_none());
    }

    #[test]
    fn test_balance() {
        let mut tree = BinaryTree::<i32>::new_default();
        for i in 1..=7 {
            tree.insert(i);
        }
        unsafe {
            let root = tree.root.unwrap();
            assert!((*root.as_ptr()).balance_factor().abs() <= 1);
            fn check_balance<T>(node: Option<NonNull<Node<T>>>) -> bool {
                match node {
                    None => true,
                    Some(n) => unsafe {
                        let balance = (*n.as_ptr()).balance_factor();
                        balance.abs() <= 1
                            && check_balance((*n.as_ptr()).left)
                            && check_balance((*n.as_ptr()).right)
                    },
                }
            }
            assert!(check_balance(Some(root)));
        }
    }

    #[test]
    fn test_custom_comparator() {
        struct ReverseCompare;
        impl Comparator<i32> for ReverseCompare {
            fn compare(&self, a: &i32, b: &i32) -> Ordering {
                b.cmp(a)
            }
        }
        let mut tree = BinaryTree::new(ReverseCompare);
        tree.insert(5);
        tree.insert(3);
        tree.insert(7);
        assert!(tree.find(&5).is_some());
        assert!(tree.find(&3).is_some());
        assert!(tree.find(&7).is_some());
    }

    #[test]
    fn test_remove() {
        let mut tree = BinaryTree::<i32>::new_default();
        tree.insert(5);
        tree.insert(3);
        tree.insert(7);
        tree.insert(1);
        tree.insert(9);
        assert_eq!(tree.remove(&3), Some(3));
        assert_eq!(tree.len(), 4);
        assert!(tree.find(&3).is_none());
        assert_eq!(tree.remove(&5), Some(5));
        assert_eq!(tree.len(), 3);
        assert!(tree.find(&5).is_none());
        assert_eq!(tree.remove(&10), None);
    }

    #[test]
    fn test_iterators() {
        let mut tree = BinaryTree::<i32>::new_default();
        for i in &[5, 3, 7, 1, 9] {
            tree.insert(*i);
        }
        let inorder: Vec<_> = tree.iter().copied().collect();
        assert_eq!(inorder, vec![1, 3, 5, 7, 9]);
        let preorder: Vec<_> = tree.iter_preorder().copied().collect();
        assert_eq!(preorder, vec![5, 3, 1, 7, 9]);
        let postorder: Vec<_> = tree.iter_postorder().copied().collect();
        assert_eq!(postorder, vec![1, 3, 9, 7, 5]);
        for x in &mut tree {
            *x += 10;
        }
        let modified: Vec<_> = tree.iter().copied().collect();
        assert_eq!(modified, vec![11, 13, 15, 17, 19]);
    }

    #[test]
    fn test_duplicate_insert_replaces_value() {
        struct ByKey;
        impl Comparator<(i32, &'static str)> for ByKey {
            fn compare(&self, a: &(i32, &'static str), b: &(i32, &'static str)) -> Ordering {
                a.0.cmp(&b.0)
            }
        }
        let mut tree = BinaryTree::new(ByKey);
        tree.insert((1, "old"));
        tree.insert((2, "two"));
        tree.insert((1, "new"));
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.find(&(1, "")), Some(&(1, "new")));
    }

    #[test]
    fn test_remove_keeps_tree_balanced() {
        let mut tree = BinaryTree::<i32>::new_default();
        // 37 is coprime with 101, so this inserts a shuffled permutation of 0..101.
        for i in 0..101 {
            tree.insert(i * 37 % 101);
        }
        assert_eq!(tree.len(), 101);
        check_invariants(tree.root, None);

        for i in (0..101).filter(|i| i % 3 != 0) {
            assert_eq!(tree.remove(&i), Some(i));
            check_invariants(tree.root, None);
        }
        let expected: Vec<i32> = (0..101).filter(|i| i % 3 == 0).collect();
        let remaining: Vec<i32> = tree.iter().copied().collect();
        assert_eq!(remaining, expected);
        assert_eq!(tree.len(), expected.len());

        for i in expected.iter().rev() {
            assert_eq!(tree.remove(i), Some(*i));
            check_invariants(tree.root, None);
        }
        assert!(tree.is_empty());
        assert!(tree.root.is_none());
    }

    #[test]
    fn test_empty_tree() {
        let mut tree = BinaryTree::<i32>::new_default();
        assert_eq!(tree.iter().next(), None);
        assert_eq!(tree.iter_preorder().next(), None);
        assert_eq!(tree.iter_postorder().next(), None);
        assert!(tree.find(&1).is_none());
        assert_eq!(tree.remove(&1), None);
        assert!(tree.is_empty());
    }
}