1use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72 cmp::{Ordering, PartialEq},
73 fmt, mem,
74 ptr::{NonNull, null},
75};
76
77pub unsafe trait AvlItem<Tag>: Sized {
89 fn get_node(&self) -> &mut AvlNode<Self, Tag>;
90}
91
/// Which child slot of a node: `Left` holds items ordering before the node,
/// `Right` holds items ordering after it.
///
/// The explicit discriminants (`Left = 0`, `Right = 1`) allow `as` casts to
/// an index. `Eq`/`Hash` are derived because equality is total.
#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)]
pub enum AvlDirection {
    Left = 0,
    Right = 1,
}
97
98impl AvlDirection {
99 #[inline(always)]
100 fn reverse(self) -> AvlDirection {
101 match self {
102 AvlDirection::Left => AvlDirection::Right,
103 AvlDirection::Right => AvlDirection::Left,
104 }
105 }
106}
107
// Balance delta contributed by growth on side `$dir`: `Left` -> -1,
// `Right` -> +1 (a negative balance means left-heavy). The literal's integer
// type is inferred from the surrounding expression.
macro_rules! avlchild_to_balance {
    ( $dir: expr ) => {
        if matches!($dir, AvlDirection::Left) { -1 } else { 1 }
    };
}
116
/// Intrusive link node embedded in every tree item.
///
/// All links are raw pointers to neighbouring *items* (not nodes); null means
/// "no such neighbour".
pub struct AvlNode<T, Tag> {
    /// Left child (orders before this item), or null.
    pub left: *const T,
    /// Right child (orders after this item), or null.
    pub right: *const T,
    /// Parent item, or null for the root and for detached nodes.
    pub parent: *const T,
    /// AVL balance factor: negative when the left subtree is taller.
    pub balance: i8,
    // Carries the tree tag type; a `fn(&Tag)` marker adds no auto-trait
    // requirements on `Tag`.
    _phan: PhantomData<fn(&Tag)>,
}
124
125unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128 #[inline(always)]
129 pub fn detach(&mut self) {
130 self.left = null();
131 self.right = null();
132 self.parent = null();
133 self.balance = 0;
134 }
135
136 #[inline(always)]
137 fn get_child(&self, dir: AvlDirection) -> *const T {
138 match dir {
139 AvlDirection::Left => self.left,
140 AvlDirection::Right => self.right,
141 }
142 }
143
144 #[inline(always)]
145 fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146 match dir {
147 AvlDirection::Left => self.left = child,
148 AvlDirection::Right => self.right = child,
149 }
150 }
151
152 #[inline(always)]
153 fn get_parent(&self) -> *const T {
154 self.parent
155 }
156
157 #[inline(always)]
159 pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160 mem::swap(&mut self.left, &mut other.left);
161 mem::swap(&mut self.right, &mut other.right);
162 mem::swap(&mut self.parent, &mut other.parent);
163 mem::swap(&mut self.balance, &mut other.balance);
164 }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168 fn default() -> Self {
169 Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170 }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176 write!(f, "(")?;
177
178 if !self.left.is_null() {
179 write!(f, "left: {:p}", self.left)?;
180 } else {
181 write!(f, "left: none ")?;
182 }
183
184 if !self.right.is_null() {
185 write!(f, "right: {:p}", self.right)?;
186 } else {
187 write!(f, "right: none ")?;
188 }
189 write!(f, ")")
190 }
191}
192
/// Comparison callback for lookups: orders a search key `K` relative to a
/// stored item `T` (`Less` means the key belongs to the item's left).
pub type AvlCmpFunc<K, T> = for<'k, 't> fn(&'k K, &'t T) -> Ordering;
194
195pub struct AvlTree<P, Tag>
200where
201 P: Pointer,
202 P::Target: AvlItem<Tag>,
203{
204 pub root: *const P::Target,
205 count: i64,
206 _phan: PhantomData<fn(P, &Tag)>,
207}
208
209pub struct AvlSearchResult<'a, P: Pointer> {
221 pub node: *const P::Target,
223 pub direction: Option<AvlDirection>,
225 _phan: PhantomData<&'a P::Target>,
226}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229 fn default() -> Self {
230 AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231 }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235 #[inline(always)]
237 pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238 if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239 }
240
241 #[inline(always)]
243 pub fn is_exact(&self) -> bool {
244 self.direction.is_none() && !self.node.is_null()
245 }
246
247 #[inline(always)]
277 pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278 AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279 }
280
281 #[inline(always)]
283 pub fn get_nearest(&self) -> Option<&P::Target> {
284 if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285 }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289 pub fn get_exact(&self) -> Option<Arc<T>> {
291 if self.is_exact() {
292 unsafe {
293 Arc::increment_strong_count(self.node);
294 Some(Arc::from_raw(self.node))
295 }
296 } else {
297 None
298 }
299 }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303 pub fn get_exact(&self) -> Option<Rc<T>> {
305 if self.is_exact() {
306 unsafe {
307 Rc::increment_strong_count(self.node);
308 Some(Rc::from_raw(self.node))
309 }
310 } else {
311 None
312 }
313 }
314}
315
// Evaluates to the extreme item of `$tree` in direction `$dir` (leftmost for
// `Left`, rightmost for `Right`), or null when the tree is empty.
macro_rules! return_end {
    ($tree: expr, $dir: expr) => {{
        let root = $tree.root;
        if !root.is_null() { $tree.bottom_child_ref(root, $dir) } else { null() }
    }};
}
319
// Side to take a replacement neighbour from when removing a node with two
// children. Takes `balance + 1`: 0 (left-heavy) and 1 (balanced) map to
// `Left`; any other value (right-heavy) maps to `Right`.
macro_rules! balance_to_child {
    ($balance: expr) => {
        if matches!($balance, 0 | 1) { AvlDirection::Left } else { AvlDirection::Right }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331 P: Pointer,
332 P::Target: AvlItem<Tag>,
333{
334 pub fn new() -> Self {
336 AvlTree { count: 0, root: null(), _phan: Default::default() }
337 }
338
339 #[inline]
344 pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345 AvlDrain { tree: self, parent: null(), dir: None }
346 }
347
348 pub fn get_count(&self) -> i64 {
349 self.count
350 }
351
352 pub fn first(&self) -> Option<&P::Target> {
353 unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354 }
355
356 #[inline]
357 pub fn last(&self) -> Option<&P::Target> {
358 unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359 }
360
361 #[inline]
406 pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407 debug_assert!(w.direction.is_some());
408 self._insert(new_data, w.node, w.direction.unwrap());
409 }
410
411 #[allow(clippy::not_unsafe_ptr_arg_deref)]
412 pub fn _insert(
413 &mut self,
414 new_data: P,
415 here: *const P::Target, mut which_child: AvlDirection,
417 ) {
418 let mut new_balance: i8;
419 let new_ptr = new_data.into_raw();
420
421 if here.is_null() {
422 if self.count > 0 {
423 panic!("insert into a tree size {} with empty where.node", self.count);
424 }
425 self.root = new_ptr;
426 self.count += 1;
427 return;
428 }
429
430 let parent = unsafe { &*here };
431 let node = unsafe { (*new_ptr).get_node() };
432 let parent_node = parent.get_node();
433 node.parent = here;
434 parent_node.set_child(which_child, new_ptr);
435 self.count += 1;
436
437 let mut data: *const P::Target = here;
444 loop {
445 let node = unsafe { (*data).get_node() };
446 let old_balance = node.balance;
447 new_balance = old_balance + avlchild_to_balance!(which_child);
448 if new_balance == 0 {
449 node.balance = 0;
450 return;
451 }
452 if old_balance != 0 {
453 self.rotate(data, new_balance);
454 return;
455 }
456 node.balance = new_balance;
457 let parent_ptr = node.get_parent();
458 if parent_ptr.is_null() {
459 return;
460 }
461 which_child = self.parent_direction(data, parent_ptr);
462 data = parent_ptr;
463 }
464 }
465
466 pub unsafe fn insert_here(
481 &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
482 ) {
483 let mut dir_child = direction;
484 assert!(!here.node.is_null());
485 let here_node = here.node;
486 let child = unsafe { (*here_node).get_node().get_child(dir_child) };
487 if !child.is_null() {
488 dir_child = dir_child.reverse();
489 let node = self.bottom_child_ref(child, dir_child);
490 self._insert(new_data, node, dir_child);
491 } else {
492 self._insert(new_data, here_node, dir_child);
493 }
494 }
495
496 #[inline(always)]
498 fn set_child2(
499 &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
500 parent: *const P::Target,
501 ) {
502 if !child.is_null() {
503 unsafe { (*child).get_node().parent = parent };
504 }
505 node.set_child(dir, child);
506 }
507
508 #[inline(always)]
509 fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
510 if !parent.is_null() {
511 let parent_node = unsafe { (*parent).get_node() };
512 if parent_node.left == data {
513 return AvlDirection::Left;
514 }
515 if parent_node.right == data {
516 return AvlDirection::Right;
517 }
518 panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
519 }
520 AvlDirection::Left
522 }
523
524 #[inline(always)]
525 fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
526 let node = unsafe { (*data).get_node() };
527 let parent = node.get_parent();
528 if !parent.is_null() {
529 return self.parent_direction(data, parent);
530 }
531 AvlDirection::Left
533 }
534
    /// Rebalances the subtree rooted at `data`, whose balance factor has
    /// reached `balance` (`_insert` and `remove` call this with +2 / -2),
    /// using a single or double rotation. The new subtree root is linked into
    /// `data`'s former parent, or becomes the tree root.
    ///
    /// Returns `true` if the rotation reduced the subtree's height, `false` if
    /// the height is unchanged. The `false` case is a single rotation over a
    /// balanced heavy child. `remove` stops propagating on `false`; `_insert`
    /// ignores the result.
    #[inline]
    fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
        // `dir` is the heavy side of `data`.
        let dir = if balance < 0 { AvlDirection::Left } else { AvlDirection::Right };
        let node = unsafe { (*data).get_node() };

        let parent = node.get_parent();
        let dir_inverse = dir.reverse();
        // NOTE: despite the names, `left_heavy` is the balance value leaning
        // towards `dir` (+1 or -1) and `right_heavy` the value leaning
        // towards `dir_inverse`.
        let left_heavy = balance >> 1;
        let right_heavy = -left_heavy;

        // The child on the heavy side.
        let child = node.get_child(dir);
        let child_node = unsafe { (*child).get_node() };
        let mut child_balance = child_node.balance;

        // Captured before any links change.
        let which_child = self.parent_direction(data, parent);

        if child_balance != right_heavy {
            // Single rotation: `child` is balanced or leans towards `dir`, so
            // it is lifted into `data`'s place and `data` becomes its
            // `dir_inverse` child.
            child_balance += right_heavy;

            // `child`'s inner (`dir_inverse`) subtree moves across to `data`.
            let c_right = child_node.get_child(dir_inverse);
            self.set_child2(node, dir, c_right, data);
            node.balance = -child_balance;

            node.parent = child;
            child_node.set_child(dir_inverse, data);
            child_node.balance = child_balance;
            // Hook the new subtree root into the old parent slot.
            if !parent.is_null() {
                child_node.parent = parent;
                unsafe { (*parent).get_node() }.set_child(which_child, child);
            } else {
                child_node.parent = null();
                self.root = child;
            }
            // Both balanced again means the subtree got shorter.
            return child_balance == 0;
        }
        // Double rotation: `child` leans towards `dir_inverse`. Its inner
        // child `g_child` becomes the new subtree root, with `child` and
        // `data` as its two children.
        let g_child = child_node.get_child(dir_inverse);
        let g_child_node = unsafe { (*g_child).get_node() };
        let g_left = g_child_node.get_child(dir);
        let g_right = g_child_node.get_child(dir_inverse);

        // Hand `g_child`'s subtrees over to `data` and `child`.
        self.set_child2(node, dir, g_right, data);
        self.set_child2(child_node, dir_inverse, g_left, child);

        // The new balances depend on which way `g_child` leaned.
        let g_child_balance = g_child_node.balance;
        if g_child_balance == right_heavy {
            child_node.balance = left_heavy;
        } else {
            child_node.balance = 0;
        }
        child_node.parent = g_child;
        g_child_node.set_child(dir, child);

        if g_child_balance == left_heavy {
            node.balance = right_heavy;
        } else {
            node.balance = 0;
        }
        g_child_node.balance = 0;

        node.parent = g_child;
        g_child_node.set_child(dir_inverse, data);

        // Hook the new subtree root into the old parent slot.
        if !parent.is_null() {
            g_child_node.parent = parent;
            unsafe { (*parent).get_node() }.set_child(which_child, g_child);
        } else {
            g_child_node.parent = null();
            self.root = g_child;
        }
        // A double rotation always shortens the subtree.
        true
    }
619
    /// Unlinks `del` from the tree and rebalances; `del`'s links are then
    /// reset with `detach`.
    ///
    /// The tree's owned reference to `del` is *not* released. The caller takes
    /// it back, as `remove_with` does with `P::from_raw`.
    ///
    /// # Safety
    /// `del` must point at a live item linked into this tree. An item whose
    /// links are all null falls through without modifying the tree. An item
    /// linked into a *different* tree would corrupt both trees.
    ///
    /// # Panics
    /// Panics if the tree structure is found to be inconsistent.
    pub unsafe fn remove(&mut self, del: *const P::Target) {
        if self.count == 0 {
            return;
        }
        // Removing the only item: clear the root directly.
        if self.count == 1 && self.root == del {
            self.root = null();
            self.count = 0;
            unsafe { (*del).get_node().detach() };
            return;
        }
        // Side of its parent that the removed position hangs from.
        let mut which_child: AvlDirection;

        let del_node = unsafe { (*del).get_node() };

        // A node with two children cannot be unlinked directly.
        let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();

        if node_swap_flag {
            // Swap `del` with its in-order neighbour `child` from the taller
            // side (left when balanced). `child` has at most one child, so
            // after the swap `del` sits in an easy-to-remove position. The
            // order of the remaining items is unchanged.
            let dir: AvlDirection = balance_to_child!(del_node.balance + 1);
            let child_temp = del_node.get_child(dir);

            let dir_inverse: AvlDirection = dir.reverse();
            let child = self.bottom_child_ref(child_temp, dir_inverse);

            // Where `child` hangs before the swap: directly under `del` on
            // side `dir`, or at the `dir_inverse` end of that subtree.
            let dir_child_temp =
                if child == child_temp { dir } else { self.parent_direction2(child) };

            let parent = del_node.get_parent();
            // Side of its parent that `del` hangs from (unused for the root).
            let dir_child_del = if !parent.is_null() {
                self.parent_direction(del, parent)
            } else {
                AvlDirection::Left
            };

            // Exchange links and balances: `child` takes over `del`'s place.
            let child_node = unsafe { (*child).get_node() };
            child_node.swap(del_node);

            // If `child` was `del`'s direct child, the swap left it pointing
            // at itself; point it at `del` instead.
            if child_node.get_child(dir) == child {
                child_node.set_child(dir, del);
            }

            // Point the children now owned by `child` back at it.
            let c_dir = child_node.get_child(dir);
            if c_dir == del {
                del_node.parent = child;
            } else if !c_dir.is_null() {
                unsafe { (*c_dir).get_node() }.parent = child;
            }

            let c_inv = child_node.get_child(dir_inverse);
            if c_inv == del {
                del_node.parent = child;
            } else if !c_inv.is_null() {
                unsafe { (*c_inv).get_node() }.parent = child;
            }

            // Hook `child` into `del`'s former parent, or make it the root.
            let parent = child_node.get_parent();
            if !parent.is_null() {
                unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
            } else {
                self.root = child;
            }

            // Hook `del` into `child`'s former slot and fix the parent link
            // of its (at most one) child.
            let parent = del_node.get_parent();
            unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
            if !del_node.right.is_null() {
                which_child = AvlDirection::Right;
            } else {
                which_child = AvlDirection::Left;
            }
            let child = del_node.get_child(which_child);
            if !child.is_null() {
                unsafe { (*child).get_node() }.parent = del;
            }
            which_child = dir_child_temp;
        } else {
            let parent = del_node.get_parent();
            if !parent.is_null() {
                which_child = self.parent_direction(del, parent);
            } else {
                which_child = AvlDirection::Left;
            }
        }

        // `del` now has at most one child, `imm_data`, which replaces it.
        let parent: *const P::Target = del_node.get_parent();

        let imm_data: *const P::Target =
            if !del_node.left.is_null() { del_node.left } else { del_node.right };

        if !imm_data.is_null() {
            let imm_node = unsafe { (*imm_data).get_node() };
            imm_node.parent = parent;
        }

        if !parent.is_null() {
            assert!(self.count > 0);
            self.count -= 1;

            let parent_node = unsafe { (*parent).get_node() };
            parent_node.set_child(which_child, imm_data);

            // Walk upward: the `which_child` subtree of `node_data` just lost
            // one level of height.
            let mut node_data: *const P::Target = parent;
            let mut old_balance: i8;
            let mut new_balance: i8;
            loop {
                let node = unsafe { (*node_data).get_node() };
                old_balance = node.balance;
                new_balance = old_balance - avlchild_to_balance!(which_child);

                // It was balanced: it now leans but keeps its height, so stop.
                if old_balance == 0 {
                    node.balance = new_balance;
                    break;
                }

                // Capture the next step before a rotation relinks `node_data`.
                let parent = node.get_parent();
                which_child = self.parent_direction(node_data, parent);

                if new_balance == 0 {
                    // It became balanced: the height shrank, keep going up.
                    node.balance = new_balance;
                } else if !self.rotate(node_data, new_balance) {
                    // The rotation kept the subtree height: nothing above changes.
                    break;
                }

                if !parent.is_null() {
                    node_data = parent;
                    continue;
                }
                break;
            }
        } else if !imm_data.is_null() {
            // `del` was the root with a single child, which becomes the root.
            assert!(self.count > 0);
            self.count -= 1;
            self.root = imm_data;
        }
        if self.root.is_null() && self.count > 0 {
            panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
        }
        del_node.detach();
    }
829
830 #[inline]
835 pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
836 let result = self.find(val, cmp_func);
837 self.remove_with(unsafe { result.detach() })
838 }
839
840 #[inline]
852 pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
853 if result.is_exact() {
854 unsafe {
855 let p = result.node;
856 self.remove(p);
857 Some(P::from_raw(p))
858 }
859 } else {
860 None
861 }
862 }
863
864 #[inline]
870 pub fn find<'a, K>(
871 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
872 ) -> AvlSearchResult<'a, P> {
873 if self.root.is_null() {
874 return AvlSearchResult::default();
875 }
876 let mut node_data = self.root;
877 loop {
878 let diff = cmp_func(val, unsafe { &*node_data });
879 match diff {
880 Ordering::Equal => {
881 return AvlSearchResult {
882 node: node_data,
883 direction: None,
884 _phan: PhantomData,
885 };
886 }
887 Ordering::Less => {
888 let node = unsafe { (*node_data).get_node() };
889 let left = node.get_child(AvlDirection::Left);
890 if left.is_null() {
891 return AvlSearchResult {
892 node: node_data,
893 direction: Some(AvlDirection::Left),
894 _phan: PhantomData,
895 };
896 }
897 node_data = left;
898 }
899 Ordering::Greater => {
900 let node = unsafe { (*node_data).get_node() };
901 let right = node.get_child(AvlDirection::Right);
902 if right.is_null() {
903 return AvlSearchResult {
904 node: node_data,
905 direction: Some(AvlDirection::Right),
906 _phan: PhantomData,
907 };
908 }
909 node_data = right;
910 }
911 }
912 }
913 }
914
915 #[inline]
917 pub fn find_contained<'a, K>(
918 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
919 ) -> Option<&'a P::Target> {
920 if self.root.is_null() {
921 return None;
922 }
923 let mut node_data = self.root;
924 let mut result_node: *const P::Target = null();
925 loop {
926 let diff = cmp_func(val, unsafe { &*node_data });
927 match diff {
928 Ordering::Equal => {
929 let node = unsafe { (*node_data).get_node() };
930 let left = node.get_child(AvlDirection::Left);
931 result_node = node_data;
932 if left.is_null() {
933 break;
934 } else {
935 node_data = left;
936 }
937 }
938 Ordering::Less => {
939 let node = unsafe { (*node_data).get_node() };
940 let left = node.get_child(AvlDirection::Left);
941 if left.is_null() {
942 break;
943 }
944 node_data = left;
945 }
946 Ordering::Greater => {
947 let node = unsafe { (*node_data).get_node() };
948 let right = node.get_child(AvlDirection::Right);
949 if right.is_null() {
950 break;
951 }
952 node_data = right;
953 }
954 }
955 }
956 if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
957 }
958
959 #[inline]
961 pub fn find_larger_eq<'a, K>(
962 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
963 ) -> AvlSearchResult<'a, P> {
964 if self.root.is_null() {
965 return AvlSearchResult::default();
966 }
967 let mut node_data = self.root;
968 loop {
969 let diff = cmp_func(val, unsafe { &*node_data });
970 match diff {
971 Ordering::Equal => {
972 return AvlSearchResult {
973 node: node_data,
974 direction: None,
975 _phan: PhantomData,
976 };
977 }
978 Ordering::Less => {
979 return AvlSearchResult {
980 node: node_data,
981 direction: None,
982 _phan: PhantomData,
983 };
984 }
985 Ordering::Greater => {
986 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
987 if right.is_null() {
988 return AvlSearchResult {
989 node: null(),
990 direction: None,
991 _phan: PhantomData,
992 };
993 }
994 node_data = right;
995 }
996 }
997 }
998 }
999
1000 #[inline]
1002 pub fn find_nearest<'a, K>(
1003 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1004 ) -> AvlSearchResult<'a, P> {
1005 if self.root.is_null() {
1006 return AvlSearchResult::default();
1007 }
1008
1009 let mut node_data = self.root;
1010 let mut nearest_node = null();
1011 loop {
1012 let diff = cmp_func(val, unsafe { &*node_data });
1013 match diff {
1014 Ordering::Equal => {
1015 return AvlSearchResult {
1016 node: node_data,
1017 direction: None,
1018 _phan: PhantomData,
1019 };
1020 }
1021 Ordering::Less => {
1022 nearest_node = node_data;
1023 let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1024 if left.is_null() {
1025 break;
1026 }
1027 node_data = left;
1028 }
1029 Ordering::Greater => {
1030 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1031 if right.is_null() {
1032 break;
1033 }
1034 node_data = right;
1035 }
1036 }
1037 }
1038 AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData }
1039 }
1040
1041 #[inline(always)]
1042 fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1043 loop {
1044 let child = unsafe { (*data).get_node() }.get_child(dir);
1045 if !child.is_null() {
1046 data = child;
1047 } else {
1048 return data;
1049 }
1050 }
1051 }
1052
1053 pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1054 let mut node = self.first();
1055 while let Some(n) = node {
1056 cb(n);
1057 node = self.next(n);
1058 }
1059 }
1060
1061 #[inline]
1062 pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1063 if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1064 Some(unsafe { p.as_ref() })
1065 } else {
1066 None
1067 }
1068 }
1069
1070 #[inline]
1071 pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1072 if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1073 Some(unsafe { p.as_ref() })
1074 } else {
1075 None
1076 }
1077 }
1078
1079 #[inline]
1080 fn walk_dir(
1081 &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1082 ) -> Option<NonNull<P::Target>> {
1083 let dir_inverse = dir.reverse();
1084 let node = unsafe { (*data_ptr).get_node() };
1085 let temp = node.get_child(dir);
1086 if !temp.is_null() {
1087 unsafe {
1088 Some(NonNull::new_unchecked(
1089 self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1090 ))
1091 }
1092 } else {
1093 let mut parent = node.parent;
1094 if parent.is_null() {
1095 return None;
1096 }
1097 loop {
1098 let pdir = self.parent_direction(data_ptr, parent);
1099 if pdir == dir_inverse {
1100 return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1101 }
1102 data_ptr = parent;
1103 parent = unsafe { (*parent).get_node() }.parent;
1104 if parent.is_null() {
1105 return None;
1106 }
1107 }
1108 }
1109 }
1110
1111 #[inline]
1112 fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1113 let node = unsafe { (*data).get_node() };
1114 let left = node.left;
1115 if !left.is_null() {
1116 assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
1117 assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1118 }
1119 let right = node.right;
1120 if !right.is_null() {
1121 assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
1122 assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1123 }
1124 }
1125
1126 #[inline]
1127 pub fn nearest<'a>(
1128 &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1129 ) -> AvlSearchResult<'a, P> {
1130 if !current.node.is_null() {
1131 if current.direction.is_some() && current.direction != Some(direction) {
1132 return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1133 }
1134 if let Some(node) = self.walk_dir(current.node, direction) {
1135 return AvlSearchResult {
1136 node: node.as_ptr(),
1137 direction: None,
1138 _phan: PhantomData,
1139 };
1140 }
1141 }
1142 AvlSearchResult::default()
1143 }
1144
1145 pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1146 let c = {
1147 #[cfg(feature = "std")]
1148 {
1149 ((self.get_count() + 10) as f32).log2() as usize
1150 }
1151 #[cfg(not(feature = "std"))]
1152 {
1153 100
1154 }
1155 };
1156 let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1157 if self.root.is_null() {
1158 assert_eq!(self.count, 0);
1159 return;
1160 }
1161 let mut data = self.root;
1162 let mut visited = 0;
1163 loop {
1164 if !data.is_null() {
1165 let left = {
1166 let node = unsafe { (*data).get_node() };
1167 node.get_child(AvlDirection::Left)
1168 };
1169 if !left.is_null() {
1170 stack.push(data);
1171 data = left;
1172 continue;
1173 }
1174 visited += 1;
1175 self.validate_node(data, cmp_func);
1176 data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1177 } else if !stack.is_empty() {
1178 let _data = stack.pop().unwrap();
1179 self.validate_node(_data, cmp_func);
1180 visited += 1;
1181 let node = unsafe { (*_data).get_node() };
1182 data = node.get_child(AvlDirection::Right);
1183 } else {
1184 break;
1185 }
1186 }
1187 assert_eq!(visited, self.count);
1188 }
1189
1190 #[inline]
1196 pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1197 if self.count == 0 && self.root.is_null() {
1198 self.root = node.into_raw();
1199 self.count = 1;
1200 return true;
1201 }
1202
1203 let w = self.find(node.as_ref(), cmp_func);
1204 if w.direction.is_none() {
1205 drop(node);
1208 return false;
1209 }
1210
1211 let w_node = w.node;
1214 let w_dir = w.direction;
1215
1216 let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1217
1218 self.insert(node, w_detached);
1219 true
1220 }
1221}
1222
1223impl<P, Tag> Drop for AvlTree<P, Tag>
1224where
1225 P: Pointer,
1226 P::Target: AvlItem<Tag>,
1227{
1228 fn drop(&mut self) {
1229 if mem::needs_drop::<P>() {
1230 for _ in self.drain() {}
1231 }
1232 }
1233}
1234
1235pub struct AvlDrain<'a, P: Pointer, Tag>
1236where
1237 P::Target: AvlItem<Tag>,
1238{
1239 tree: &'a mut AvlTree<P, Tag>,
1240 parent: *const P::Target,
1241 dir: Option<AvlDirection>,
1242}
1243
1244impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
1245where
1246 P::Target: AvlItem<Tag>,
1247{
1248 type Item = P;
1249
1250 fn next(&mut self) -> Option<Self::Item> {
1251 if self.tree.root.is_null() {
1252 return None;
1253 }
1254
1255 let mut node: *const P::Target;
1256 let parent: *const P::Target;
1257
1258 if self.dir.is_none() && self.parent.is_null() {
1259 let mut curr = self.tree.root;
1261 while unsafe { !(*curr).get_node().left.is_null() } {
1262 curr = unsafe { (*curr).get_node().left };
1263 }
1264 node = curr;
1265 } else {
1266 parent = self.parent;
1267 if parent.is_null() {
1268 return None;
1270 }
1271
1272 let child_dir = self.dir.unwrap();
1273 if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
1277 node = parent;
1278 } else {
1279 node = unsafe { (*parent).get_node().right };
1281 while unsafe { !(*node).get_node().left.is_null() } {
1282 node = unsafe { (*node).get_node().left };
1283 }
1284 }
1285 }
1286
1287 if unsafe { !(*node).get_node().right.is_null() } {
1289 node = unsafe { (*node).get_node().right };
1292 }
1293
1294 let next_parent = unsafe { (*node).get_node().parent };
1296 if next_parent.is_null() {
1297 self.tree.root = null();
1298 self.parent = null();
1299 self.dir = Some(AvlDirection::Left);
1300 } else {
1301 self.parent = next_parent;
1302 self.dir = Some(self.tree.parent_direction(node, next_parent));
1303 unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
1305 }
1306
1307 self.tree.count -= 1;
1308 unsafe {
1309 (*node).get_node().detach();
1310 Some(P::from_raw(node))
1311 }
1312 }
1313}
1314
1315impl<T, Tag> AvlTree<Arc<T>, Tag>
1316where
1317 T: AvlItem<Tag>,
1318{
1319 pub fn remove_ref(&mut self, node: &Arc<T>) {
1320 let p = Arc::as_ptr(node);
1321 unsafe { self.remove(p) };
1322 unsafe { drop(Arc::from_raw(p)) };
1323 }
1324}
1325
1326impl<T, Tag> AvlTree<Rc<T>, Tag>
1327where
1328 T: AvlItem<Tag>,
1329{
1330 pub fn remove_ref(&mut self, node: &Rc<T>) {
1331 let p = Rc::as_ptr(node);
1332 unsafe { self.remove(p) };
1333 unsafe { drop(Rc::from_raw(p)) };
1334 }
1335}
1336
1337#[cfg(test)]
1338mod tests {
1339 use super::*;
1340 use core::cell::UnsafeCell;
1341 use rand::Rng;
1342 use std::time::Instant;
1343
1344 struct IntAvlNode {
1345 pub value: i64,
1346 pub node: UnsafeCell<AvlNode<Self, ()>>,
1347 }
1348
1349 impl fmt::Debug for IntAvlNode {
1350 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1351 write!(f, "{} {:#?}", self.value, self.node)
1352 }
1353 }
1354
1355 impl fmt::Display for IntAvlNode {
1356 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1357 write!(f, "{}", self.value)
1358 }
1359 }
1360
1361 unsafe impl AvlItem<()> for IntAvlNode {
1362 fn get_node(&self) -> &mut AvlNode<Self, ()> {
1363 unsafe { &mut *self.node.get() }
1364 }
1365 }
1366
1367 type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;
1368
1369 fn new_intnode(i: i64) -> Box<IntAvlNode> {
1370 Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
1371 }
1372
1373 fn new_inttree() -> IntAvlTree {
1374 AvlTree::<Box<IntAvlNode>, ()>::new()
1375 }
1376
1377 fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
1378 a.value.cmp(&b.value)
1379 }
1380
1381 fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
1382 a.cmp(&b.value)
1383 }
1384
1385 impl AvlTree<Box<IntAvlNode>, ()> {
1386 fn remove_int(&mut self, i: i64) -> bool {
1387 if let Some(_node) = self.remove_by_key(&i, cmp_int) {
1388 return true;
1390 }
1391 println!("not found {}", i);
1393 false
1394 }
1395
1396 fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
1397 self.add(node, cmp_int_node)
1398 }
1399
1400 fn validate_tree(&self) {
1401 self.validate(cmp_int_node);
1402 }
1403
1404 fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1405 self.find(&i, cmp_int)
1406 }
1407
1408 fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1409 self.find(node, cmp_int_node)
1410 }
1411 }
1412
1413 #[test]
1414 fn int_avl_node() {
1415 let mut tree = new_inttree();
1416
1417 assert_eq!(tree.get_count(), 0);
1418 assert!(tree.first().is_none());
1419 assert!(tree.last().is_none());
1420
1421 let node1 = new_intnode(1);
1422 let node2 = new_intnode(2);
1423 let node3 = new_intnode(3);
1424
1425 let p1 = &*node1 as *const IntAvlNode;
1426 let p2 = &*node2 as *const IntAvlNode;
1427 let p3 = &*node3 as *const IntAvlNode;
1428
1429 tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
1430 tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);
1431
1432 assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
1433 assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
1436 }
1437
1438 #[test]
1439 fn int_avl_tree_basic() {
1440 let mut tree = new_inttree();
1441
1442 let temp_node = new_intnode(0);
1443 let temp_node_val = Pointer::as_ref(&temp_node);
1444 assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
1445 assert_eq!(
1446 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
1447 false
1448 );
1449 assert_eq!(
1450 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
1451 false
1452 );
1453 drop(temp_node);
1454
1455 tree.add_int_node(new_intnode(0));
1456 let result = tree.find_int(0);
1457 assert!(result.get_node_ref().is_some());
1458 assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
1459 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1460
1461 let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
1462 assert!(rs.is_some());
1463 let found_value = rs.unwrap().value;
1464 assert_eq!(found_value, 0);
1465
1466 let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
1467 assert!(rs.is_none());
1468
1469 let result = tree.find_int(1);
1470 let left = tree.nearest(&result, AvlDirection::Left);
1471 assert_eq!(left.is_exact(), true);
1472 assert_eq!(left.get_nearest().unwrap().value, 0);
1473 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1474
1475 tree.add_int_node(new_intnode(2));
1476 let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
1477 assert!(rs.is_some());
1478 let found_value = rs.unwrap().value;
1479 assert_eq!(found_value, 2);
1480 }
1481
1482 #[test]
1483 fn int_avl_tree_order() {
1484 let max;
1485 #[cfg(miri)]
1486 {
1487 max = 2000;
1488 }
1489 #[cfg(not(miri))]
1490 {
1491 max = 200000;
1492 }
1493 let mut tree = new_inttree();
1494 assert!(tree.first().is_none());
1495 let start_ts = Instant::now();
1496 for i in 0..max {
1497 tree.add_int_node(new_intnode(i));
1498 }
1499 tree.validate_tree();
1500 assert_eq!(tree.get_count(), max as i64);
1501
1502 let mut count = 0;
1503 let mut current = tree.first();
1504 let last = tree.last();
1505 while let Some(c) = current {
1506 assert_eq!(c.value, count);
1507 count += 1;
1508 if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
1509 current = None;
1510 } else {
1511 current = tree.next(c);
1512 }
1513 }
1514 assert_eq!(count, max);
1515
1516 {
1517 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1518 assert!(rs.is_some());
1519 let found_value = rs.unwrap().value;
1520 println!("found larger_eq {}", found_value);
1521 assert!(found_value >= 5);
1522 tree.remove_int(5);
1523 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1524 assert!(rs.is_some());
1525 assert!(rs.unwrap().value >= 6);
1526 tree.add_int_node(new_intnode(5));
1527 }
1528
1529 for i in 0..max {
1530 assert!(tree.remove_int(i));
1531 }
1532 assert_eq!(tree.get_count(), 0);
1533
1534 let end_ts = Instant::now();
1535 println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
1536 }
1537
/// Regression test with three specific large keys: each one must be findable right after
/// it is added, and every one must then be removable, leaving the tree empty.
#[test]
fn int_avl_tree_fixed1() {
    let mut tree = new_inttree();
    let keys = [4719789032060327248, 7936680652950253153, 5197008094511783121];
    for &key in &keys {
        tree.add_int_node(new_intnode(key));
        assert!(tree.find_int(key).get_node_ref().is_some(), "add error {}", key);
    }
    assert_eq!(tree.get_count(), keys.len() as i64);

    keys.iter().for_each(|&key| assert!(tree.remove_int(key)));
    assert_eq!(tree.get_count(), 0);
}
1554
/// Regression test replaying a fixed add/remove sequence over three keys.
///
/// The tree structure is validated after every mutation, and every add/remove must report
/// success. Without that, a silently failed removal would go unnoticed.
#[test]
fn int_avl_tree_fixed2() {
    let big = 536872960;
    let small = 12288;
    let mid = 22528;

    let mut tree = new_inttree();
    tree.validate_tree();

    // Add, remove and re-add the same key on an otherwise empty tree.
    assert!(tree.add_int_node(new_intnode(big)));
    tree.validate_tree();
    assert!(tree.remove_int(big));
    tree.validate_tree();
    assert!(tree.add_int_node(new_intnode(big)));
    tree.validate_tree();
    assert!(tree.find_int(big).get_node_ref().is_some());

    assert!(tree.add_int_node(new_intnode(small)));
    tree.validate_tree();
    assert!(tree.remove_int(big));
    tree.validate_tree();
    assert!(tree.add_int_node(new_intnode(big)));
    tree.validate_tree();
    assert!(tree.add_int_node(new_intnode(mid)));
    tree.validate_tree();

    assert!(tree.remove_int(small));
    assert!(tree.find_int(small).get_node_ref().is_none());
    tree.validate_tree();
    assert!(tree.remove_int(mid));
    assert!(tree.find_int(mid).get_node_ref().is_none());
    tree.validate_tree();
    assert!(tree.add_int_node(new_intnode(mid)));
    tree.validate_tree();

    // `big` and `mid` remain.
    assert_eq!(tree.get_count(), 2);
}
1589
/// Inserts up to `count` distinct random keys, checks that `prev`/`next` of every node
/// match its neighbours in the sorted key list, then removes every key.
#[test]
fn int_avl_tree_random() {
    let count = 1000;
    let mut test_list: Vec<i64> = Vec::with_capacity(count);
    let mut rng = rand::thread_rng();
    let mut tree = new_inttree();
    tree.validate_tree();
    for _ in 0..count {
        let node_value: i64 = rng.r#gen();
        // Duplicates are skipped, so every insertion that does happen must succeed.
        if !test_list.contains(&node_value) {
            test_list.push(node_value);
            assert!(tree.add_int_node(new_intnode(node_value)));
        }
    }
    tree.validate_tree();
    assert_eq!(tree.get_count(), test_list.len() as i64);

    // Keys are unique, so an unstable sort gives the same order as a stable one.
    test_list.sort_unstable();
    for (index, &value) in test_list.iter().enumerate() {
        let node = tree
            .find_int(value)
            .get_node_ref()
            .expect("every inserted key must be found");
        // Expected neighbours are `None` at either end of the list. This also covers a
        // single-element list without indexing out of bounds.
        let expected_prev = index.checked_sub(1).map(|i| test_list[i]);
        let expected_next = test_list.get(index + 1).copied();
        assert_eq!(tree.prev(node).map(|n| n.value), expected_prev);
        assert_eq!(tree.next(node).map(|n| n.value), expected_next);
    }

    for &value in &test_list {
        assert!(tree.remove_int(value));
    }
    tree.validate_tree();
    assert_eq!(0, tree.get_count());
}
1634
/// Exercises `insert_here` through a fixed sequence of steps.
///
/// Each step looks up an existing anchor key, detaches the search result, and inserts a new
/// key on the given side of the anchor. After every step the tree is validated, the count
/// must have grown by one, and the new key must be findable.
#[test]
fn int_avl_tree_insert_here() {
    let mut tree = new_inttree();
    tree.add_int_node(new_intnode(10));

    // (anchor key, new key, side of the anchor to insert on)
    let steps = [
        (10, 5, AvlDirection::Left),
        (10, 15, AvlDirection::Right),
        (5, 3, AvlDirection::Left),
        (5, 7, AvlDirection::Right),
    ];
    for (step, (anchor, value, dir)) in steps.into_iter().enumerate() {
        let rs = tree.find_int(anchor);
        // NOTE(review): `detach` is assumed to release the search result's borrow of
        // `tree` so `insert_here` can take it mutably. Nothing is removed between the
        // lookup and the insert. Confirm against `detach`'s safety contract.
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(value), here, dir) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), step as i64 + 2);
        assert_eq!(tree.find_int(value).get_node_ref().unwrap().value, value);
    }
}
1670
/// On an `Arc`-backed tree, `get_exact` must return a clone of the stored `Arc`: the same
/// allocation, with the strong count raised by one. Dropping that clone must release the
/// extra reference again.
#[test]
fn test_arc_avl_tree_get_exact() {
    let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
    let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
    tree.add(node.clone(), cmp_int_node);

    let result_search = tree.find(&100, cmp_int);
    let exact_arc = result_search.get_exact().expect("key 100 was just inserted");
    assert_eq!(exact_arc.value, 100);
    assert!(Arc::ptr_eq(&node, &exact_arc));
    // Owners: `node`, the tree, and `exact_arc`.
    assert_eq!(Arc::strong_count(&node), 3);

    // Matches the release check in the remove tests: the returned clone must not leak.
    drop(exact_arc);
    assert_eq!(Arc::strong_count(&node), 2);
}
1690
/// Removing an `Arc` element by reference must unlink it from the tree and give up the
/// tree's strong reference, leaving only the caller's handle.
#[test]
fn test_arc_avl_tree_remove_ref() {
    let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
    let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
    tree.add(Arc::clone(&node), cmp_int_node);
    // One element in the tree; the Arc is owned by `node` and by the tree.
    assert_eq!((tree.get_count(), Arc::strong_count(&node)), (1, 2));

    tree.remove_ref(&node);
    // Tree is empty again, and only `node` still holds the Arc.
    assert_eq!((tree.get_count(), Arc::strong_count(&node)), (0, 1));
}
1703
/// `remove_by_key` must return the removed `Arc` and leave the tree empty.
///
/// The tree's reference moves to the caller, so the strong count stays the same until the
/// returned handle is dropped.
#[test]
fn test_arc_avl_tree_remove_with() {
    let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
    let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
    tree.add(Arc::clone(&node), cmp_int_node);

    let Some(removed_arc) = tree.remove_by_key(&300, cmp_int) else {
        panic!("key 300 should have been removed");
    };
    assert_eq!(removed_arc.value, 300);
    assert_eq!(tree.get_count(), 0);
    // Owners are now `node` and `removed_arc`; the tree holds nothing.
    assert_eq!(Arc::strong_count(&node), 2);

    drop(removed_arc);
    assert_eq!(Arc::strong_count(&node), 1);
}
1721
/// `drain` must yield every stored node exactly once and leave the tree empty.
#[test]
fn test_avl_drain() {
    let mut tree = new_inttree();
    for i in 0..100 {
        tree.add_int_node(new_intnode(i));
    }
    assert_eq!(tree.get_count(), 100);

    // Record which keys were yielded. A bare counter would miss a duplicated node
    // combined with a missing one.
    let mut seen = [false; 100];
    for node in tree.drain() {
        assert!(node.value >= 0 && node.value < 100);
        // The range check above makes this cast lossless.
        let slot = &mut seen[node.value as usize];
        assert!(!*slot, "value {} drained twice", node.value);
        *slot = true;
    }
    assert!(seen.iter().all(|&s| s), "drain skipped at least one value");
    assert_eq!(tree.get_count(), 0);
    assert!(tree.first().is_none());
}
1739}