1use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72 cmp::{Ordering, PartialEq},
73 fmt, mem,
74 ptr::{NonNull, null},
75};
76
/// An item that can be linked into an intrusive [`AvlTree`] through an embedded [`AvlNode`].
///
/// `Tag` selects which embedded node is used, presumably so that one item type can embed
/// several nodes and be linked into several trees at once — confirm against callers.
///
/// # Safety
///
/// NOTE(review): the tree mutates the links through the `&mut` returned by `get_node` while
/// only holding `&self`. Implementors therefore presumably must keep the node behind interior
/// mutability (the tests in this file use `UnsafeCell`). They must also always return the same
/// node for the same item. Confirm this contract.
pub unsafe trait AvlItem<Tag>: Sized {
    /// Returns the embedded tree links of this item.
    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
}
91
/// A child slot of a tree node.
///
/// Lookups descend `Left` when the key compares `Less` than a node and `Right` when it
/// compares `Greater`.
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum AvlDirection {
    Left = 0,
    Right = 1,
}
97
98impl AvlDirection {
99 #[inline(always)]
100 fn reverse(self) -> AvlDirection {
101 match self {
102 AvlDirection::Left => AvlDirection::Right,
103 AvlDirection::Right => AvlDirection::Left,
104 }
105 }
106}
107
// Balance-factor delta caused by growing the subtree in slot `$dir`:
// the left slot counts as -1, the right slot as +1.
macro_rules! avlchild_to_balance {
    ($dir:expr) => {
        if $dir == AvlDirection::Left { -1 } else { 1 }
    };
}
116
/// Intrusive link fields embedded in every item stored in an [`AvlTree`].
///
/// `Default::default()` and `detach()` leave all three pointers null and the balance at zero.
pub struct AvlNode<T: Sized, Tag> {
    /// Left child (items ordered before this one), or null.
    pub left: *const T,
    /// Right child (items ordered after this one), or null.
    pub right: *const T,
    /// Parent item, or null for the root and for unlinked items.
    pub parent: *const T,
    /// Right-subtree height minus left-subtree height; the insert/remove code keeps it in -1..=1.
    pub balance: i8,
    _phan: PhantomData<fn(&Tag)>,
}
124
// SAFETY(review): raw pointers are `!Send`, so this opts every `AvlNode` in manually, with no
// bound on `T`. It is only sound if linked items only ever move between threads together with
// the tree that links them. NOTE(review): confirm that invariant holds for all users.
unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128 #[inline(always)]
129 pub fn detach(&mut self) {
130 self.left = null();
131 self.right = null();
132 self.parent = null();
133 self.balance = 0;
134 }
135
136 #[inline(always)]
137 fn get_child(&self, dir: AvlDirection) -> *const T {
138 match dir {
139 AvlDirection::Left => self.left,
140 AvlDirection::Right => self.right,
141 }
142 }
143
144 #[inline(always)]
145 fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146 match dir {
147 AvlDirection::Left => self.left = child,
148 AvlDirection::Right => self.right = child,
149 }
150 }
151
152 #[inline(always)]
153 fn get_parent(&self) -> *const T {
154 return self.parent;
155 }
156
157 #[inline(always)]
159 pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160 mem::swap(&mut self.left, &mut other.left);
161 mem::swap(&mut self.right, &mut other.right);
162 mem::swap(&mut self.parent, &mut other.parent);
163 mem::swap(&mut self.balance, &mut other.balance);
164 }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168 fn default() -> Self {
169 Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170 }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176 write!(f, "(")?;
177
178 if !self.left.is_null() {
179 write!(f, "left: {:p}", self.left)?;
180 } else {
181 write!(f, "left: none ")?;
182 }
183
184 if !self.right.is_null() {
185 write!(f, "right: {:p}", self.right)?;
186 } else {
187 write!(f, "right: none ")?;
188 }
189 write!(f, ")")
190 }
191}
192
/// Comparator used by lookups: orders a search key `K` against a stored item `T`.
///
/// `Less` sends the search to the left child, `Greater` to the right child, and `Equal` is a hit.
pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
194
/// Intrusive AVL tree of items owned through the pointer type `P` (the tests use `Box`; `Rc`
/// and `Arc` get extra helpers below), linked through each item's embedded [`AvlNode`] for `Tag`.
///
/// Inserting turns the `P` into a raw pointer held by the tree; removal turns it back.
pub struct AvlTree<P, Tag>
where
    P: Pointer,
    P::Target: AvlItem<Tag>,
{
    /// Root item, or null when the tree is empty.
    pub root: *const P::Target,
    /// Number of linked items.
    count: i64,
    _phan: PhantomData<fn(P, &Tag)>,
}
208
/// Outcome of a lookup on an [`AvlTree`], borrowing the tree for `'a`.
///
/// * `direction == None`: `node` is the item the lookup settled on. This is an exact match
///   for `find`, or a neighbour for `find_nearest` / `find_larger_eq` / `nearest`. `node` is
///   null when no such item exists.
/// * `direction == Some(dir)`: no match; a new item would go into the empty `dir` slot of
///   `node`. A null `node` means the tree is empty.
pub struct AvlSearchResult<'a, P: Pointer> {
    /// Located item, would-be parent of an insertion, or null.
    pub node: *const P::Target,
    /// `None` when `node` is the located item, otherwise the free child slot of `node`.
    pub direction: Option<AvlDirection>,
    _phan: PhantomData<&'a P::Target>,
}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229 fn default() -> Self {
230 AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231 }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235 #[inline(always)]
237 pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238 if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239 }
240
241 #[inline(always)]
243 pub fn is_exact(&self) -> bool {
244 self.direction.is_none() && !self.node.is_null()
245 }
246
247 #[inline(always)]
277 pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278 AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279 }
280
281 #[inline(always)]
283 pub fn get_nearest(&self) -> Option<&P::Target> {
284 if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285 }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289 pub fn get_exact(&self) -> Option<Arc<T>> {
291 if self.is_exact() {
292 unsafe {
293 Arc::increment_strong_count(self.node);
294 Some(Arc::from_raw(self.node))
295 }
296 } else {
297 None
298 }
299 }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303 pub fn get_exact(&self) -> Option<Rc<T>> {
305 if self.is_exact() {
306 unsafe {
307 Rc::increment_strong_count(self.node);
308 Some(Rc::from_raw(self.node))
309 }
310 } else {
311 None
312 }
313 }
314}
315
// Evaluates to the extreme item of `$tree` in direction `$dir`: the leftmost item for
// `Left`, the rightmost for `Right`. Evaluates to null when the tree is empty.
macro_rules! return_end {
    ($tree: expr, $dir: expr) => {{ if $tree.root.is_null() { null() } else { $tree.bottom_child_ref($tree.root, $dir) } }};
}
319
// Picks a child slot from a balance factor shifted by one. `remove` calls this as
// `balance_to_child!(balance + 1)`. 0 or 1 (left-heavy or balanced) gives `Left`; anything
// else (right-heavy) gives `Right`.
macro_rules! balance_to_child {
    ($balance: expr) => {
        match $balance {
            0 | 1 => AvlDirection::Left,
            _ => AvlDirection::Right,
        }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331 P: Pointer,
332 P::Target: AvlItem<Tag>,
333{
334 pub fn new() -> Self {
336 AvlTree { count: 0, root: null(), _phan: Default::default() }
337 }
338
339 #[inline]
344 pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345 AvlDrain { tree: self, parent: null(), dir: None }
346 }
347
348 pub fn get_count(&self) -> i64 {
349 self.count
350 }
351
352 pub fn first(&self) -> Option<&P::Target> {
353 unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354 }
355
356 #[inline]
357 pub fn last(&self) -> Option<&P::Target> {
358 unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359 }
360
361 #[inline]
406 pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407 debug_assert!(w.direction.is_some());
408 self._insert(new_data, w.node, w.direction.unwrap());
409 }
410
411 pub fn _insert(
412 &mut self,
413 new_data: P,
414 here: *const P::Target, mut which_child: AvlDirection,
416 ) {
417 let mut new_balance: i8;
418 let new_ptr = new_data.into_raw();
419
420 if here.is_null() {
421 if self.count > 0 {
422 panic!("insert into a tree size {} with empty where.node", self.count);
423 }
424 self.root = new_ptr;
425 self.count += 1;
426 return;
427 }
428
429 let parent = unsafe { &*here };
430 let node = unsafe { (*new_ptr).get_node() };
431 let parent_node = parent.get_node();
432 node.parent = here;
433 parent_node.set_child(which_child, new_ptr);
434 self.count += 1;
435
436 let mut data: *const P::Target = here;
443 loop {
444 let node = unsafe { (*data).get_node() };
445 let old_balance = node.balance;
446 new_balance = old_balance + avlchild_to_balance!(which_child);
447 if new_balance == 0 {
448 node.balance = 0;
449 return;
450 }
451 if old_balance != 0 {
452 self.rotate(data, new_balance);
453 return;
454 }
455 node.balance = new_balance;
456 let parent_ptr = node.get_parent();
457 if parent_ptr.is_null() {
458 return;
459 }
460 which_child = self.parent_direction(data, parent_ptr);
461 data = parent_ptr;
462 }
463 }
464
465 pub unsafe fn insert_here(
480 &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
481 ) {
482 let mut dir_child = direction;
483 assert_eq!(here.node.is_null(), false);
484 let here_node = here.node;
485 let child = unsafe { (*here_node).get_node().get_child(dir_child) };
486 if !child.is_null() {
487 dir_child = dir_child.reverse();
488 let node = self.bottom_child_ref(child, dir_child);
489 self._insert(new_data, node, dir_child);
490 } else {
491 self._insert(new_data, here_node, dir_child);
492 }
493 }
494
495 #[inline(always)]
497 fn set_child2(
498 &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
499 parent: *const P::Target,
500 ) {
501 if !child.is_null() {
502 unsafe { (*child).get_node().parent = parent };
503 }
504 node.set_child(dir, child);
505 }
506
507 #[inline(always)]
508 fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
509 if !parent.is_null() {
510 let parent_node = unsafe { (*parent).get_node() };
511 if parent_node.left == data {
512 return AvlDirection::Left;
513 }
514 if parent_node.right == data {
515 return AvlDirection::Right;
516 }
517 panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
518 }
519 AvlDirection::Left
521 }
522
523 #[inline(always)]
524 fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
525 let node = unsafe { (*data).get_node() };
526 let parent = node.get_parent();
527 if !parent.is_null() {
528 return self.parent_direction(data, parent);
529 }
530 AvlDirection::Left
532 }
533
534 #[inline]
535 fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
536 let dir: AvlDirection;
537 if balance < 0 {
538 dir = AvlDirection::Left;
539 } else {
540 dir = AvlDirection::Right;
541 }
542 let node = unsafe { (*data).get_node() };
543
544 let parent = node.get_parent();
545 let dir_inverse = dir.reverse();
546 let left_heavy = balance >> 1;
547 let right_heavy = -left_heavy;
548
549 let child = node.get_child(dir);
550 let child_node = unsafe { (*child).get_node() };
551 let mut child_balance = child_node.balance;
552
553 let which_child = self.parent_direction(data, parent);
554
555 if child_balance != right_heavy {
557 child_balance += right_heavy;
558
559 let c_right = child_node.get_child(dir_inverse);
560 self.set_child2(node, dir, c_right, data);
561 node.balance = -child_balance;
563
564 node.parent = child;
565 child_node.set_child(dir_inverse, data);
566 child_node.balance = child_balance;
569 if !parent.is_null() {
570 child_node.parent = parent;
571 unsafe { (*parent).get_node() }.set_child(which_child, child);
572 } else {
573 child_node.parent = null();
574 self.root = child;
575 }
576 return child_balance == 0;
577 }
578 let g_child = child_node.get_child(dir_inverse);
582 let g_child_node = unsafe { (*g_child).get_node() };
583 let g_left = g_child_node.get_child(dir);
584 let g_right = g_child_node.get_child(dir_inverse);
585
586 self.set_child2(node, dir, g_right, data);
587 self.set_child2(child_node, dir_inverse, g_left, child);
588
589 let g_child_balance = g_child_node.balance;
596 if g_child_balance == right_heavy {
597 child_node.balance = left_heavy;
598 } else {
599 child_node.balance = 0;
600 }
601 child_node.parent = g_child;
602 g_child_node.set_child(dir, child);
603
604 if g_child_balance == left_heavy {
605 node.balance = right_heavy;
606 } else {
607 node.balance = 0;
608 }
609 g_child_node.balance = 0;
610
611 node.parent = g_child;
612 g_child_node.set_child(dir_inverse, data);
613
614 if !parent.is_null() {
615 g_child_node.parent = parent;
616 unsafe { (*parent).get_node() }.set_child(which_child, g_child);
617 } else {
618 g_child_node.parent = null();
619 self.root = g_child;
620 }
621 return true;
622 }
623
624 pub unsafe fn remove(&mut self, del: *const P::Target) {
661 if self.count == 0 {
672 return;
673 }
674 if self.count == 1 && self.root == del {
675 self.root = null();
676 self.count = 0;
677 unsafe { (*del).get_node().detach() };
678 return;
679 }
680 let mut which_child: AvlDirection;
681 let imm_data: *const P::Target;
682 let parent: *const P::Target;
683 let del_node = unsafe { (*del).get_node() };
685
686 let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();
687
688 if node_swap_flag {
689 let dir: AvlDirection;
690 let dir_child_temp: AvlDirection;
691 let dir_child_del: AvlDirection;
692 let dir_inverse: AvlDirection;
693 dir = balance_to_child!(del_node.balance + 1);
694
695 let child_temp = del_node.get_child(dir);
696
697 dir_inverse = dir.reverse();
698 let child = self.bottom_child_ref(child_temp, dir_inverse);
699
700 if child == child_temp {
703 dir_child_temp = dir;
704 } else {
705 dir_child_temp = self.parent_direction2(child);
706 }
707
708 let parent = del_node.get_parent();
711 if !parent.is_null() {
712 dir_child_del = self.parent_direction(del, parent);
713 } else {
714 dir_child_del = AvlDirection::Left;
715 }
716
717 let child_node = unsafe { (*child).get_node() };
718 child_node.swap(del_node);
719
720 if child_node.get_child(dir) == child {
722 child_node.set_child(dir, del);
724 }
725
726 let c_dir = child_node.get_child(dir);
727 if c_dir == del {
728 del_node.parent = child;
729 } else if !c_dir.is_null() {
730 unsafe { (*c_dir).get_node() }.parent = child;
731 }
732
733 let c_inv = child_node.get_child(dir_inverse);
734 if c_inv == del {
735 del_node.parent = child;
736 } else if !c_inv.is_null() {
737 unsafe { (*c_inv).get_node() }.parent = child;
738 }
739
740 let parent = child_node.get_parent();
741 if !parent.is_null() {
742 unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
743 } else {
744 self.root = child;
745 }
746
747 let parent = del_node.get_parent();
750 unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
751 if !del_node.right.is_null() {
752 which_child = AvlDirection::Right;
753 } else {
754 which_child = AvlDirection::Left;
755 }
756 let child = del_node.get_child(which_child);
757 if !child.is_null() {
758 unsafe { (*child).get_node() }.parent = del;
759 }
760 which_child = dir_child_temp;
761 } else {
762 let parent = del_node.get_parent();
764 if !parent.is_null() {
765 which_child = self.parent_direction(del, parent);
766 } else {
767 which_child = AvlDirection::Left;
768 }
769 }
770
771 parent = del_node.get_parent();
774
775 if !del_node.left.is_null() {
776 imm_data = del_node.left;
777 } else {
778 imm_data = del_node.right;
779 }
780
781 if !imm_data.is_null() {
783 let imm_node = unsafe { (*imm_data).get_node() };
784 imm_node.parent = parent;
785 }
786
787 if !parent.is_null() {
788 assert!(self.count > 0);
789 self.count -= 1;
790
791 let parent_node = unsafe { (*parent).get_node() };
792 parent_node.set_child(which_child, imm_data);
793
794 let mut node_data: *const P::Target = parent;
797 let mut old_balance: i8;
798 let mut new_balance: i8;
799 loop {
800 let node = unsafe { (*node_data).get_node() };
804 old_balance = node.balance;
805 new_balance = old_balance - avlchild_to_balance!(which_child);
806
807 if old_balance == 0 {
811 node.balance = new_balance;
812 break;
813 }
814
815 let parent = node.get_parent();
816 which_child = self.parent_direction(node_data, parent);
817
818 if new_balance == 0 {
824 node.balance = new_balance;
825 } else if !self.rotate(node_data, new_balance) {
826 break;
827 }
828
829 if !parent.is_null() {
830 node_data = parent;
831 continue;
832 }
833 break;
834 }
835 } else {
836 if !imm_data.is_null() {
837 assert!(self.count > 0);
838 self.count -= 1;
839 self.root = imm_data;
840 }
841 }
842 if self.root.is_null() {
843 if self.count > 0 {
844 panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
845 }
846 }
847 del_node.detach();
848 }
849
850 #[inline]
855 pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
856 let result = self.find(val, cmp_func);
857 self.remove_with(unsafe { result.detach() })
858 }
859
860 #[inline]
872 pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
873 if result.is_exact() {
874 unsafe {
875 let p = result.node;
876 self.remove(p);
877 Some(P::from_raw(p))
878 }
879 } else {
880 None
881 }
882 }
883
884 #[inline]
890 pub fn find<'a, K>(
891 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
892 ) -> AvlSearchResult<'a, P> {
893 if self.root.is_null() {
894 return AvlSearchResult::default();
895 }
896 let mut node_data = self.root;
897 loop {
898 let diff = cmp_func(val, unsafe { &*node_data });
899 match diff {
900 Ordering::Equal => {
901 return AvlSearchResult {
902 node: node_data,
903 direction: None,
904 _phan: PhantomData,
905 };
906 }
907 Ordering::Less => {
908 let node = unsafe { (*node_data).get_node() };
909 let left = node.get_child(AvlDirection::Left);
910 if left.is_null() {
911 return AvlSearchResult {
912 node: node_data,
913 direction: Some(AvlDirection::Left),
914 _phan: PhantomData,
915 };
916 }
917 node_data = left;
918 }
919 Ordering::Greater => {
920 let node = unsafe { (*node_data).get_node() };
921 let right = node.get_child(AvlDirection::Right);
922 if right.is_null() {
923 return AvlSearchResult {
924 node: node_data,
925 direction: Some(AvlDirection::Right),
926 _phan: PhantomData,
927 };
928 }
929 node_data = right;
930 }
931 }
932 }
933 }
934
935 #[inline]
937 pub fn find_contained<'a, K>(
938 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
939 ) -> Option<&'a P::Target> {
940 if self.root.is_null() {
941 return None;
942 }
943 let mut node_data = self.root;
944 let mut result_node: *const P::Target = null();
945 loop {
946 let diff = cmp_func(val, unsafe { &*node_data });
947 match diff {
948 Ordering::Equal => {
949 let node = unsafe { (*node_data).get_node() };
950 let left = node.get_child(AvlDirection::Left);
951 result_node = node_data;
952 if left.is_null() {
953 break;
954 } else {
955 node_data = left;
956 }
957 }
958 Ordering::Less => {
959 let node = unsafe { (*node_data).get_node() };
960 let left = node.get_child(AvlDirection::Left);
961 if left.is_null() {
962 break;
963 }
964 node_data = left;
965 }
966 Ordering::Greater => {
967 let node = unsafe { (*node_data).get_node() };
968 let right = node.get_child(AvlDirection::Right);
969 if right.is_null() {
970 break;
971 }
972 node_data = right;
973 }
974 }
975 }
976 if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
977 }
978
979 #[inline]
981 pub fn find_larger_eq<'a, K>(
982 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
983 ) -> AvlSearchResult<'a, P> {
984 if self.root.is_null() {
985 return AvlSearchResult::default();
986 }
987 let mut node_data = self.root;
988 loop {
989 let diff = cmp_func(val, unsafe { &*node_data });
990 match diff {
991 Ordering::Equal => {
992 return AvlSearchResult {
993 node: node_data,
994 direction: None,
995 _phan: PhantomData,
996 };
997 }
998 Ordering::Less => {
999 return AvlSearchResult {
1000 node: node_data,
1001 direction: None,
1002 _phan: PhantomData,
1003 };
1004 }
1005 Ordering::Greater => {
1006 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1007 if right.is_null() {
1008 return AvlSearchResult {
1009 node: null(),
1010 direction: None,
1011 _phan: PhantomData,
1012 };
1013 }
1014 node_data = right;
1015 }
1016 }
1017 }
1018 }
1019
1020 #[inline]
1022 pub fn find_nearest<'a, K>(
1023 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1024 ) -> AvlSearchResult<'a, P> {
1025 if self.root.is_null() {
1026 return AvlSearchResult::default();
1027 }
1028
1029 let mut node_data = self.root;
1030 let mut nearest_node = null();
1031 loop {
1032 let diff = cmp_func(val, unsafe { &*node_data });
1033 match diff {
1034 Ordering::Equal => {
1035 return AvlSearchResult {
1036 node: node_data,
1037 direction: None,
1038 _phan: PhantomData,
1039 };
1040 }
1041 Ordering::Less => {
1042 nearest_node = node_data;
1043 let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1044 if left.is_null() {
1045 break;
1046 }
1047 node_data = left;
1048 }
1049 Ordering::Greater => {
1050 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1051 if right.is_null() {
1052 break;
1053 }
1054 node_data = right;
1055 }
1056 }
1057 }
1058 return AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData };
1059 }
1060
1061 #[inline(always)]
1062 fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1063 loop {
1064 let child = unsafe { (*data).get_node() }.get_child(dir);
1065 if !child.is_null() {
1066 data = child;
1067 } else {
1068 return data;
1069 }
1070 }
1071 }
1072
1073 pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1074 let mut node = self.first();
1075 while let Some(n) = node {
1076 cb(n);
1077 node = self.next(n);
1078 }
1079 }
1080
1081 #[inline]
1082 pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1083 if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1084 Some(unsafe { p.as_ref() })
1085 } else {
1086 None
1087 }
1088 }
1089
1090 #[inline]
1091 pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1092 if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1093 Some(unsafe { p.as_ref() })
1094 } else {
1095 None
1096 }
1097 }
1098
1099 #[inline]
1100 fn walk_dir(
1101 &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1102 ) -> Option<NonNull<P::Target>> {
1103 let dir_inverse = dir.reverse();
1104 let node = unsafe { (*data_ptr).get_node() };
1105 let temp = node.get_child(dir);
1106 if !temp.is_null() {
1107 unsafe {
1108 Some(NonNull::new_unchecked(
1109 self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1110 ))
1111 }
1112 } else {
1113 let mut parent = node.parent;
1114 if parent.is_null() {
1115 return None;
1116 }
1117 loop {
1118 let pdir = self.parent_direction(data_ptr, parent);
1119 if pdir == dir_inverse {
1120 return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1121 }
1122 data_ptr = parent;
1123 parent = unsafe { (*parent).get_node() }.parent;
1124 if parent.is_null() {
1125 return None;
1126 }
1127 }
1128 }
1129 }
1130
1131 #[inline]
1132 fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1133 let node = unsafe { (*data).get_node() };
1134 let left = node.left;
1135 if !left.is_null() {
1136 assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
1137 assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1138 }
1139 let right = node.right;
1140 if !right.is_null() {
1141 assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
1142 assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1143 }
1144 }
1145
1146 #[inline]
1147 pub fn nearest<'a>(
1148 &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1149 ) -> AvlSearchResult<'a, P> {
1150 if !current.node.is_null() {
1151 if current.direction.is_some() && current.direction != Some(direction) {
1152 return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1153 }
1154 if let Some(node) = self.walk_dir(current.node, direction) {
1155 return AvlSearchResult {
1156 node: node.as_ptr(),
1157 direction: None,
1158 _phan: PhantomData,
1159 };
1160 }
1161 }
1162 return AvlSearchResult::default();
1163 }
1164
1165 pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1166 let c = {
1167 #[cfg(feature = "std")]
1168 {
1169 ((self.get_count() + 10) as f32).log2() as usize
1170 }
1171 #[cfg(not(feature = "std"))]
1172 {
1173 100
1174 }
1175 };
1176 let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1177 if self.root.is_null() {
1178 assert_eq!(self.count, 0);
1179 return;
1180 }
1181 let mut data = self.root;
1182 let mut visited = 0;
1183 loop {
1184 if !data.is_null() {
1185 let left = {
1186 let node = unsafe { (*data).get_node() };
1187 node.get_child(AvlDirection::Left)
1188 };
1189 if !left.is_null() {
1190 stack.push(data);
1191 data = left;
1192 continue;
1193 }
1194 visited += 1;
1195 self.validate_node(data, cmp_func);
1196 data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1197 } else if stack.len() > 0 {
1198 let _data = stack.pop().unwrap();
1199 self.validate_node(_data, cmp_func);
1200 visited += 1;
1201 let node = unsafe { (*_data).get_node() };
1202 data = node.get_child(AvlDirection::Right);
1203 } else {
1204 break;
1205 }
1206 }
1207 assert_eq!(visited, self.count);
1208 }
1209
1210 #[inline]
1216 pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1217 if self.count == 0 && self.root.is_null() {
1218 self.root = node.into_raw();
1219 self.count = 1;
1220 return true;
1221 }
1222
1223 let w = self.find(node.as_ref(), cmp_func);
1224 if w.direction.is_none() {
1225 drop(node);
1228 return false;
1229 }
1230
1231 let w_node = w.node;
1234 let w_dir = w.direction;
1235 drop(w);
1237
1238 let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1239
1240 self.insert(node, w_detached);
1241 return true;
1242 }
1243}
1244
1245impl<P, Tag> Drop for AvlTree<P, Tag>
1246where
1247 P: Pointer,
1248 P::Target: AvlItem<Tag>,
1249{
1250 fn drop(&mut self) {
1251 if mem::needs_drop::<P>() {
1252 for _ in self.drain() {}
1253 }
1254 }
1255}
1256
/// Iterator returned by [`AvlTree::drain`]: unlinks items one at a time and yields them as `P`.
///
/// NOTE(review): items are unlinked without updating the balance factors of the items left
/// behind. Dropping this iterator early therefore leaves a tree whose links are intact but
/// whose balances are stale. Confirm that callers do not keep inserting into or removing from
/// such a tree.
pub struct AvlDrain<'a, P: Pointer, Tag>
where
    P::Target: AvlItem<Tag>,
{
    tree: &'a mut AvlTree<P, Tag>,
    /// Parent of the item yielded last (its slot has already been cleared). Null before the
    /// first item and after the root has been yielded.
    parent: *const P::Target,
    /// Which child of `parent` was yielded last; `None` before the first call.
    dir: Option<AvlDirection>,
}
1265
1266impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
1267where
1268 P::Target: AvlItem<Tag>,
1269{
1270 type Item = P;
1271
1272 fn next(&mut self) -> Option<Self::Item> {
1273 if self.tree.root.is_null() {
1274 return None;
1275 }
1276
1277 let mut node: *const P::Target;
1278 let parent: *const P::Target;
1279
1280 if self.dir.is_none() && self.parent.is_null() {
1281 let mut curr = self.tree.root;
1283 while unsafe { !(*curr).get_node().left.is_null() } {
1284 curr = unsafe { (*curr).get_node().left };
1285 }
1286 node = curr;
1287 } else {
1288 parent = self.parent;
1289 if parent.is_null() {
1290 return None;
1292 }
1293
1294 let child_dir = self.dir.unwrap();
1295 if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
1299 node = parent;
1300 } else {
1301 node = unsafe { (*parent).get_node().right };
1303 while unsafe { !(*node).get_node().left.is_null() } {
1304 node = unsafe { (*node).get_node().left };
1305 }
1306 }
1307 }
1308
1309 if unsafe { !(*node).get_node().right.is_null() } {
1311 node = unsafe { (*node).get_node().right };
1314 }
1315
1316 let next_parent = unsafe { (*node).get_node().parent };
1318 if next_parent.is_null() {
1319 self.tree.root = null();
1320 self.parent = null();
1321 self.dir = Some(AvlDirection::Left);
1322 } else {
1323 self.parent = next_parent;
1324 self.dir = Some(self.tree.parent_direction(node, next_parent));
1325 unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
1327 }
1328
1329 self.tree.count -= 1;
1330 unsafe {
1331 (*node).get_node().detach();
1332 Some(P::from_raw(node))
1333 }
1334 }
1335}
1336
1337impl<T, Tag> AvlTree<Arc<T>, Tag>
1338where
1339 T: AvlItem<Tag>,
1340{
1341 pub fn remove_ref(&mut self, node: &Arc<T>) {
1342 let p = Arc::as_ptr(node);
1343 unsafe { self.remove(p) };
1344 unsafe { drop(Arc::from_raw(p)) };
1345 }
1346}
1347
1348impl<T, Tag> AvlTree<Rc<T>, Tag>
1349where
1350 T: AvlItem<Tag>,
1351{
1352 pub fn remove_ref(&mut self, node: &Rc<T>) {
1353 let p = Rc::as_ptr(node);
1354 unsafe { self.remove(p) };
1355 unsafe { drop(Rc::from_raw(p)) };
1356 }
1357}
1358
1359#[cfg(test)]
1360mod tests {
1361 use super::*;
1362 use core::cell::UnsafeCell;
1363 use rand::Rng;
1364 use std::time::Instant;
1365
/// Test item: an `i64` key plus one embedded AVL link node under the `()` tag.
struct IntAvlNode {
    pub value: i64,
    // Kept in an `UnsafeCell` because `AvlItem::get_node` hands out `&mut` links from `&self`.
    pub node: UnsafeCell<AvlNode<Self, ()>>,
}
1370
1371 impl fmt::Debug for IntAvlNode {
1372 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1373 write!(f, "{} {:#?}", self.value, self.node)
1374 }
1375 }
1376
1377 impl fmt::Display for IntAvlNode {
1378 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1379 write!(f, "{}", self.value)
1380 }
1381 }
1382
// SAFETY: always returns the same embedded node; the `UnsafeCell` makes handing out `&mut`
// links from `&self` permissible as long as the tree never holds two such borrows at once.
unsafe impl AvlItem<()> for IntAvlNode {
    fn get_node(&self) -> &mut AvlNode<Self, ()> {
        unsafe { &mut *self.node.get() }
    }
}
1388
/// Tree of boxed integer test items.
type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;
1390
1391 fn new_intnode(i: i64) -> Box<IntAvlNode> {
1392 Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
1393 }
1394
1395 fn new_inttree() -> IntAvlTree {
1396 AvlTree::<Box<IntAvlNode>, ()>::new()
1397 }
1398
1399 fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
1400 a.value.cmp(&b.value)
1401 }
1402
1403 fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
1404 a.cmp(&b.value)
1405 }
1406
1407 impl AvlTree<Box<IntAvlNode>, ()> {
1408 fn remove_int(&mut self, i: i64) -> bool {
1409 if let Some(_node) = self.remove_by_key(&i, cmp_int) {
1410 return true;
1412 }
1413 println!("not found {}", i);
1415 false
1416 }
1417
1418 fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
1419 self.add(node, cmp_int_node)
1420 }
1421
1422 fn validate_tree(&self) {
1423 self.validate(cmp_int_node);
1424 }
1425
1426 fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1427 self.find(&i, cmp_int)
1428 }
1429
1430 fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1431 self.find(node, cmp_int_node)
1432 }
1433 }
1434
/// Checks `set_child2` / `parent_direction2` on hand-linked nodes that are not in the tree:
/// 1 -(left)-> 2 -(right)-> 3.
#[test]
fn int_avl_node() {
    let mut tree = new_inttree();

    assert_eq!(tree.get_count(), 0);
    assert!(tree.first().is_none());
    assert!(tree.last().is_none());

    let node1 = new_intnode(1);
    let node2 = new_intnode(2);
    let node3 = new_intnode(3);

    let p1 = &*node1 as *const IntAvlNode;
    let p2 = &*node2 as *const IntAvlNode;
    let p3 = &*node3 as *const IntAvlNode;

    tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
    tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);

    assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
    assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
}
1459
/// Lookups, `nearest` and `find_larger_eq` on an empty, one-item and two-item tree.
#[test]
fn int_avl_tree_basic() {
    let mut tree = new_inttree();

    // Empty tree: no hit and no neighbours in either direction.
    let temp_node = new_intnode(0);
    let temp_node_val = Pointer::as_ref(&temp_node);
    assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
    assert_eq!(
        tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
        false
    );
    assert_eq!(
        tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
        false
    );
    drop(temp_node);

    // Single item {0}: an exact hit has no neighbours.
    tree.add_int_node(new_intnode(0));
    let result = tree.find_int(0);
    assert!(result.get_node_ref().is_some());
    assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
    assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

    let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
    assert!(rs.is_some());
    let found_value = rs.unwrap().value;
    assert_eq!(found_value, 0);

    let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
    assert!(rs.is_none());

    // The insertion point for 1 is to the right of 0, so its left neighbour is 0.
    let result = tree.find_int(1);
    let left = tree.nearest(&result, AvlDirection::Left);
    assert_eq!(left.is_exact(), true);
    assert_eq!(left.get_nearest().unwrap().value, 0);
    assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

    // {0, 2}: the first value >= 1 is 2.
    tree.add_int_node(new_intnode(2));
    let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
    assert!(rs.is_some());
    let found_value = rs.unwrap().value;
    assert_eq!(found_value, 2);
}
1503
/// Sequential inserts, in-order traversal, `find_larger_eq` around a removed key, then
/// removal of every item. Prints the elapsed time (smaller size under Miri).
#[test]
fn int_avl_tree_order() {
    let max;
    #[cfg(miri)]
    {
        max = 2000;
    }
    #[cfg(not(miri))]
    {
        max = 200000;
    }
    let mut tree = new_inttree();
    assert!(tree.first().is_none());
    let start_ts = Instant::now();
    for i in 0..max {
        tree.add_int_node(new_intnode(i));
    }
    tree.validate_tree();
    assert_eq!(tree.get_count(), max as i64);

    // In-order walk must visit 0..max exactly once, stopping at `last()`.
    let mut count = 0;
    let mut current = tree.first();
    let last = tree.last();
    while let Some(c) = current {
        assert_eq!(c.value, count);
        count += 1;
        if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
            current = None;
        } else {
            current = tree.next(c);
        }
    }
    assert_eq!(count, max);

    {
        let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
        assert!(rs.is_some());
        let found_value = rs.unwrap().value;
        println!("found larger_eq {}", found_value);
        assert!(found_value >= 5);
        tree.remove_int(5);
        let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
        assert!(rs.is_some());
        assert!(rs.unwrap().value >= 6);
        tree.add_int_node(new_intnode(5));
    }

    for i in 0..max {
        assert!(tree.remove_int(i));
    }
    assert_eq!(tree.get_count(), 0);

    let end_ts = Instant::now();
    println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
}
1559
1560 #[test]
1561 fn int_avl_tree_fixed1() {
1562 let mut tree = new_inttree();
1563 let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
1564 for i in arr.iter() {
1565 let node = new_intnode(*i);
1566 tree.add_int_node(node);
1567 let rs = tree.find_int(*i);
1568 assert!(rs.get_node_ref().is_some(), "add error {}", i);
1569 }
1570 assert_eq!(tree.get_count(), arr.len() as i64);
1571 for i in arr.iter() {
1572 assert!(tree.remove_int(*i));
1573 }
1574 assert_eq!(tree.get_count(), 0);
1575 }
1576
1577 #[test]
1578 fn int_avl_tree_fixed2() {
1579 let mut tree = new_inttree();
1580 tree.validate_tree();
1581 let node1 = new_intnode(536872960);
1582 {
1583 tree.add_int_node(node1);
1584 tree.validate_tree();
1585 tree.remove_int(536872960);
1586 tree.validate_tree();
1587 tree.add_int_node(new_intnode(536872960));
1588 tree.validate_tree();
1589 }
1590
1591 assert!(tree.find_int(536872960).get_node_ref().is_some());
1592 let node2 = new_intnode(12288);
1593 tree.add_int_node(node2);
1594 tree.validate_tree();
1595 tree.remove_int(536872960);
1596 tree.validate_tree();
1597 tree.add_int_node(new_intnode(536872960));
1598 tree.validate_tree();
1599 let node3 = new_intnode(22528);
1600 tree.add_int_node(node3);
1601 tree.validate_tree();
1602 tree.remove_int(12288);
1603 assert!(tree.find_int(12288).get_node_ref().is_none());
1604 tree.validate_tree();
1605 tree.remove_int(22528);
1606 assert!(tree.find_int(22528).get_node_ref().is_none());
1607 tree.validate_tree();
1608 tree.add_int_node(new_intnode(22528));
1609 tree.validate_tree();
1610 }
1611
1612 #[test]
1613 fn int_avl_tree_random() {
1614 let count = 1000;
1615 let mut test_list: Vec<i64> = Vec::with_capacity(count);
1616 let mut rng = rand::thread_rng();
1617 let mut tree = new_inttree();
1618 tree.validate_tree();
1619 for _ in 0..count {
1620 let node_value: i64 = rng.r#gen();
1621 if !test_list.contains(&node_value) {
1622 test_list.push(node_value);
1623 assert!(tree.add_int_node(new_intnode(node_value)))
1624 }
1625 }
1626 tree.validate_tree();
1627 test_list.sort();
1628 for index in 0..test_list.len() {
1629 let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
1630 let prev = tree.prev(node_ptr);
1631 let next = tree.next(node_ptr);
1632 if index == 0 {
1633 assert!(prev.is_none());
1635 assert!(next.is_some());
1636 assert_eq!(next.unwrap().value, test_list[index + 1]);
1637 } else if index == test_list.len() - 1 {
1638 assert!(prev.is_some());
1640 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1641 assert!(next.is_none());
1642 } else {
1643 assert!(prev.is_some());
1645 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1646 assert!(next.is_some());
1647 assert_eq!(next.unwrap().value, test_list[index + 1]);
1648 }
1649 }
1650 for index in 0..test_list.len() {
1651 assert!(tree.remove_int(test_list[index]));
1652 }
1653 tree.validate_tree();
1654 assert_eq!(0, tree.get_count());
1655 }
1656
1657 #[test]
1658 fn int_avl_tree_insert_here() {
1659 let mut tree = new_inttree();
1660 let node1 = new_intnode(10);
1661 tree.add_int_node(node1);
1662 let rs = tree.find_int(10);
1664 let here = unsafe { rs.detach() };
1665 unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
1666 tree.validate_tree();
1667 assert_eq!(tree.get_count(), 2);
1668 assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);
1669
1670 let rs = tree.find_int(10);
1672 let here = unsafe { rs.detach() };
1673 unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
1674 tree.validate_tree();
1675 assert_eq!(tree.get_count(), 3);
1676 assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);
1677
1678 let rs = tree.find_int(5);
1680 let here = unsafe { rs.detach() };
1681 unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
1682 tree.validate_tree();
1683 assert_eq!(tree.get_count(), 4);
1684
1685 let rs = tree.find_int(5);
1687 let here = unsafe { rs.detach() };
1688 unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
1689 tree.validate_tree();
1690 assert_eq!(tree.get_count(), 5);
1691 }
1692
1693 #[test]
1694 fn test_arc_avl_tree_get_exact() {
1695 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1696 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
1698 tree.add(node.clone(), cmp_int_node);
1699
1700 let result_search = tree.find(&100, cmp_int);
1702
1703 let exact = result_search.get_exact();
1705 assert!(exact.is_some());
1706 let exact_arc = exact.unwrap();
1707 assert_eq!(exact_arc.value, 100);
1708 assert!(Arc::ptr_eq(&node, &exact_arc));
1709 assert_eq!(Arc::strong_count(&node), 3);
1711 }
1712
1713 #[test]
1714 fn test_arc_avl_tree_remove_ref() {
1715 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1716 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
1717 tree.add(node.clone(), cmp_int_node);
1718 assert_eq!(tree.get_count(), 1);
1719 assert_eq!(Arc::strong_count(&node), 2);
1720
1721 tree.remove_ref(&node);
1722 assert_eq!(tree.get_count(), 0);
1723 assert_eq!(Arc::strong_count(&node), 1);
1724 }
1725
1726 #[test]
1727 fn test_arc_avl_tree_remove_with() {
1728 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1729 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
1730 tree.add(node.clone(), cmp_int_node);
1731
1732 let removed = tree.remove_by_key(&300, cmp_int);
1733 assert!(removed.is_some());
1734 let removed_arc = removed.unwrap();
1735 assert_eq!(removed_arc.value, 300);
1736 assert_eq!(tree.get_count(), 0);
1737 assert_eq!(Arc::strong_count(&node), 2);
1739
1740 drop(removed_arc);
1741 assert_eq!(Arc::strong_count(&node), 1);
1742 }
1743
1744 #[test]
1745 fn test_avl_drain() {
1746 let mut tree = new_inttree();
1747 for i in 0..100 {
1748 tree.add_int_node(new_intnode(i));
1749 }
1750 assert_eq!(tree.get_count(), 100);
1751
1752 let mut count = 0;
1753 for node in tree.drain() {
1754 assert!(node.value >= 0 && node.value < 100);
1755 count += 1;
1756 }
1757 assert_eq!(count, 100);
1758 assert_eq!(tree.get_count(), 0);
1759 assert!(tree.first().is_none());
1760 }
1761}