1use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72 cmp::{Ordering, PartialEq},
73 fmt, mem,
74 ptr::{NonNull, null},
75};
76
77pub unsafe trait AvlItem<Tag>: Sized {
89 fn get_node(&self) -> &mut AvlNode<Self, Tag>;
90}
91
92#[derive(PartialEq, Debug, Copy, Clone)]
93pub enum AvlDirection {
94 Left = 0,
95 Right = 1,
96}
97
98impl AvlDirection {
99 #[inline(always)]
100 fn reverse(self) -> AvlDirection {
101 match self {
102 AvlDirection::Left => AvlDirection::Right,
103 AvlDirection::Right => AvlDirection::Left,
104 }
105 }
106}
107
108macro_rules! avlchild_to_balance {
109 ( $dir: expr ) => {
110 match $dir {
111 AvlDirection::Left => -1,
112 AvlDirection::Right => 1,
113 }
114 };
115}
116
117pub struct AvlNode<T: Sized, Tag> {
118 pub left: *const T,
119 pub right: *const T,
120 pub parent: *const T,
121 pub balance: i8,
122 _phan: PhantomData<fn(&Tag)>,
123}
124
125unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128 #[inline(always)]
129 pub fn detach(&mut self) {
130 self.left = null();
131 self.right = null();
132 self.parent = null();
133 self.balance = 0;
134 }
135
136 #[inline(always)]
137 fn get_child(&self, dir: AvlDirection) -> *const T {
138 match dir {
139 AvlDirection::Left => self.left,
140 AvlDirection::Right => self.right,
141 }
142 }
143
144 #[inline(always)]
145 fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146 match dir {
147 AvlDirection::Left => self.left = child,
148 AvlDirection::Right => self.right = child,
149 }
150 }
151
152 #[inline(always)]
153 fn get_parent(&self) -> *const T {
154 self.parent
155 }
156
157 #[inline(always)]
159 pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160 mem::swap(&mut self.left, &mut other.left);
161 mem::swap(&mut self.right, &mut other.right);
162 mem::swap(&mut self.parent, &mut other.parent);
163 mem::swap(&mut self.balance, &mut other.balance);
164 }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168 fn default() -> Self {
169 Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170 }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176 write!(f, "(")?;
177
178 if !self.left.is_null() {
179 write!(f, "left: {:p}", self.left)?;
180 } else {
181 write!(f, "left: none ")?;
182 }
183
184 if !self.right.is_null() {
185 write!(f, "right: {:p}", self.right)?;
186 } else {
187 write!(f, "right: none ")?;
188 }
189 write!(f, ")")
190 }
191}
192
193pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
194
195pub struct AvlTree<P, Tag>
200where
201 P: Pointer,
202 P::Target: AvlItem<Tag>,
203{
204 pub root: *const P::Target,
205 count: i64,
206 _phan: PhantomData<fn(P, &Tag)>,
207}
208
209pub struct AvlSearchResult<'a, P: Pointer> {
221 pub node: *const P::Target,
223 pub direction: Option<AvlDirection>,
225 _phan: PhantomData<&'a P::Target>,
226}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229 fn default() -> Self {
230 AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231 }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235 #[inline(always)]
237 pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238 if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239 }
240
241 #[inline(always)]
243 pub fn is_exact(&self) -> bool {
244 self.direction.is_none() && !self.node.is_null()
245 }
246
247 #[inline(always)]
277 pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278 AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279 }
280
281 #[inline(always)]
283 pub fn get_nearest(&self) -> Option<&P::Target> {
284 if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285 }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289 pub fn get_exact(&self) -> Option<Arc<T>> {
291 if self.is_exact() {
292 unsafe {
293 Arc::increment_strong_count(self.node);
294 Some(Arc::from_raw(self.node))
295 }
296 } else {
297 None
298 }
299 }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303 pub fn get_exact(&self) -> Option<Rc<T>> {
305 if self.is_exact() {
306 unsafe {
307 Rc::increment_strong_count(self.node);
308 Some(Rc::from_raw(self.node))
309 }
310 } else {
311 None
312 }
313 }
314}
315
316macro_rules! return_end {
317 ($tree: expr, $dir: expr) => {{ if $tree.root.is_null() { null() } else { $tree.bottom_child_ref($tree.root, $dir) } }};
318}
319
320macro_rules! balance_to_child {
321 ($balance: expr) => {
322 match $balance {
323 0 | 1 => AvlDirection::Left,
324 _ => AvlDirection::Right,
325 }
326 };
327}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331 P: Pointer,
332 P::Target: AvlItem<Tag>,
333{
334 pub fn new() -> Self {
336 AvlTree { count: 0, root: null(), _phan: Default::default() }
337 }
338
339 #[inline]
344 pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345 AvlDrain { tree: self, parent: null(), dir: None }
346 }
347
348 pub fn get_count(&self) -> i64 {
349 self.count
350 }
351
352 pub fn first(&self) -> Option<&P::Target> {
353 unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354 }
355
356 #[inline]
357 pub fn last(&self) -> Option<&P::Target> {
358 unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359 }
360
361 #[inline]
406 pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407 debug_assert!(w.direction.is_some());
408 self._insert(new_data, w.node, w.direction.unwrap());
409 }
410
411 #[allow(clippy::not_unsafe_ptr_arg_deref)]
412 pub fn _insert(
413 &mut self,
414 new_data: P,
415 here: *const P::Target, mut which_child: AvlDirection,
417 ) {
418 let mut new_balance: i8;
419 let new_ptr = new_data.into_raw();
420
421 if here.is_null() {
422 if self.count > 0 {
423 panic!("insert into a tree size {} with empty where.node", self.count);
424 }
425 self.root = new_ptr;
426 self.count += 1;
427 return;
428 }
429
430 let parent = unsafe { &*here };
431 let node = unsafe { (*new_ptr).get_node() };
432 let parent_node = parent.get_node();
433 node.parent = here;
434 parent_node.set_child(which_child, new_ptr);
435 self.count += 1;
436
437 let mut data: *const P::Target = here;
444 loop {
445 let node = unsafe { (*data).get_node() };
446 let old_balance = node.balance;
447 new_balance = old_balance + avlchild_to_balance!(which_child);
448 if new_balance == 0 {
449 node.balance = 0;
450 return;
451 }
452 if old_balance != 0 {
453 self.rotate(data, new_balance);
454 return;
455 }
456 node.balance = new_balance;
457 let parent_ptr = node.get_parent();
458 if parent_ptr.is_null() {
459 return;
460 }
461 which_child = self.parent_direction(data, parent_ptr);
462 data = parent_ptr;
463 }
464 }
465
466 pub unsafe fn insert_here(
481 &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
482 ) {
483 let mut dir_child = direction;
484 assert!(!here.node.is_null());
485 let here_node = here.node;
486 let child = unsafe { (*here_node).get_node().get_child(dir_child) };
487 if !child.is_null() {
488 dir_child = dir_child.reverse();
489 let node = self.bottom_child_ref(child, dir_child);
490 self._insert(new_data, node, dir_child);
491 } else {
492 self._insert(new_data, here_node, dir_child);
493 }
494 }
495
496 #[inline(always)]
498 fn set_child2(
499 &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
500 parent: *const P::Target,
501 ) {
502 if !child.is_null() {
503 unsafe { (*child).get_node().parent = parent };
504 }
505 node.set_child(dir, child);
506 }
507
508 #[inline(always)]
509 fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
510 if !parent.is_null() {
511 let parent_node = unsafe { (*parent).get_node() };
512 if parent_node.left == data {
513 return AvlDirection::Left;
514 }
515 if parent_node.right == data {
516 return AvlDirection::Right;
517 }
518 panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
519 }
520 AvlDirection::Left
522 }
523
524 #[inline(always)]
525 fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
526 let node = unsafe { (*data).get_node() };
527 let parent = node.get_parent();
528 if !parent.is_null() {
529 return self.parent_direction(data, parent);
530 }
531 AvlDirection::Left
533 }
534
535 #[inline]
536 fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
537 let dir = if balance < 0 { AvlDirection::Left } else { AvlDirection::Right };
538 let node = unsafe { (*data).get_node() };
539
540 let parent = node.get_parent();
541 let dir_inverse = dir.reverse();
542 let left_heavy = balance >> 1;
543 let right_heavy = -left_heavy;
544
545 let child = node.get_child(dir);
546 let child_node = unsafe { (*child).get_node() };
547 let mut child_balance = child_node.balance;
548
549 let which_child = self.parent_direction(data, parent);
550
551 if child_balance != right_heavy {
553 child_balance += right_heavy;
554
555 let c_right = child_node.get_child(dir_inverse);
556 self.set_child2(node, dir, c_right, data);
557 node.balance = -child_balance;
559
560 node.parent = child;
561 child_node.set_child(dir_inverse, data);
562 child_node.balance = child_balance;
565 if !parent.is_null() {
566 child_node.parent = parent;
567 unsafe { (*parent).get_node() }.set_child(which_child, child);
568 } else {
569 child_node.parent = null();
570 self.root = child;
571 }
572 return child_balance == 0;
573 }
574 let g_child = child_node.get_child(dir_inverse);
578 let g_child_node = unsafe { (*g_child).get_node() };
579 let g_left = g_child_node.get_child(dir);
580 let g_right = g_child_node.get_child(dir_inverse);
581
582 self.set_child2(node, dir, g_right, data);
583 self.set_child2(child_node, dir_inverse, g_left, child);
584
585 let g_child_balance = g_child_node.balance;
592 if g_child_balance == right_heavy {
593 child_node.balance = left_heavy;
594 } else {
595 child_node.balance = 0;
596 }
597 child_node.parent = g_child;
598 g_child_node.set_child(dir, child);
599
600 if g_child_balance == left_heavy {
601 node.balance = right_heavy;
602 } else {
603 node.balance = 0;
604 }
605 g_child_node.balance = 0;
606
607 node.parent = g_child;
608 g_child_node.set_child(dir_inverse, data);
609
610 if !parent.is_null() {
611 g_child_node.parent = parent;
612 unsafe { (*parent).get_node() }.set_child(which_child, g_child);
613 } else {
614 g_child_node.parent = null();
615 self.root = g_child;
616 }
617 true
618 }
619
620 pub unsafe fn remove(&mut self, del: *const P::Target) {
659 if self.count == 0 {
670 return;
671 }
672 if self.count == 1 && self.root == del {
673 self.root = null();
674 self.count = 0;
675 unsafe { (*del).get_node().detach() };
676 return;
677 }
678 let mut which_child: AvlDirection;
679
680 let del_node = unsafe { (*del).get_node() };
682
683 let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();
684
685 if node_swap_flag {
686 let dir: AvlDirection = balance_to_child!(del_node.balance + 1);
687 let child_temp = del_node.get_child(dir);
688
689 let dir_inverse: AvlDirection = dir.reverse();
690 let child = self.bottom_child_ref(child_temp, dir_inverse);
691
692 let dir_child_temp =
695 if child == child_temp { dir } else { self.parent_direction2(child) };
696
697 let parent = del_node.get_parent();
700 let dir_child_del = if !parent.is_null() {
701 self.parent_direction(del, parent)
702 } else {
703 AvlDirection::Left
704 };
705
706 let child_node = unsafe { (*child).get_node() };
707 child_node.swap(del_node);
708
709 if child_node.get_child(dir) == child {
711 child_node.set_child(dir, del);
713 }
714
715 let c_dir = child_node.get_child(dir);
716 if c_dir == del {
717 del_node.parent = child;
718 } else if !c_dir.is_null() {
719 unsafe { (*c_dir).get_node() }.parent = child;
720 }
721
722 let c_inv = child_node.get_child(dir_inverse);
723 if c_inv == del {
724 del_node.parent = child;
725 } else if !c_inv.is_null() {
726 unsafe { (*c_inv).get_node() }.parent = child;
727 }
728
729 let parent = child_node.get_parent();
730 if !parent.is_null() {
731 unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
732 } else {
733 self.root = child;
734 }
735
736 let parent = del_node.get_parent();
739 unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
740 if !del_node.right.is_null() {
741 which_child = AvlDirection::Right;
742 } else {
743 which_child = AvlDirection::Left;
744 }
745 let child = del_node.get_child(which_child);
746 if !child.is_null() {
747 unsafe { (*child).get_node() }.parent = del;
748 }
749 which_child = dir_child_temp;
750 } else {
751 let parent = del_node.get_parent();
753 if !parent.is_null() {
754 which_child = self.parent_direction(del, parent);
755 } else {
756 which_child = AvlDirection::Left;
757 }
758 }
759
760 let parent: *const P::Target = del_node.get_parent();
763
764 let imm_data: *const P::Target =
765 if !del_node.left.is_null() { del_node.left } else { del_node.right };
766
767 if !imm_data.is_null() {
769 let imm_node = unsafe { (*imm_data).get_node() };
770 imm_node.parent = parent;
771 }
772
773 if !parent.is_null() {
774 assert!(self.count > 0);
775 self.count -= 1;
776
777 let parent_node = unsafe { (*parent).get_node() };
778 parent_node.set_child(which_child, imm_data);
779
780 let mut node_data: *const P::Target = parent;
783 let mut old_balance: i8;
784 let mut new_balance: i8;
785 loop {
786 let node = unsafe { (*node_data).get_node() };
790 old_balance = node.balance;
791 new_balance = old_balance - avlchild_to_balance!(which_child);
792
793 if old_balance == 0 {
797 node.balance = new_balance;
798 break;
799 }
800
801 let parent = node.get_parent();
802 which_child = self.parent_direction(node_data, parent);
803
804 if new_balance == 0 {
810 node.balance = new_balance;
811 } else if !self.rotate(node_data, new_balance) {
812 break;
813 }
814
815 if !parent.is_null() {
816 node_data = parent;
817 continue;
818 }
819 break;
820 }
821 } else if !imm_data.is_null() {
822 assert!(self.count > 0);
823 self.count -= 1;
824 self.root = imm_data;
825 }
826 if self.root.is_null() && self.count > 0 {
827 panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
828 }
829 del_node.detach();
830 }
831
832 #[inline]
837 pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
838 let result = self.find(val, cmp_func);
839 self.remove_with(unsafe { result.detach() })
840 }
841
842 #[inline]
854 pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
855 if result.is_exact() {
856 unsafe {
857 let p = result.node;
858 self.remove(p);
859 Some(P::from_raw(p))
860 }
861 } else {
862 None
863 }
864 }
865
866 #[inline]
872 pub fn find<'a, K>(
873 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
874 ) -> AvlSearchResult<'a, P> {
875 if self.root.is_null() {
876 return AvlSearchResult::default();
877 }
878 let mut node_data = self.root;
879 loop {
880 let diff = cmp_func(val, unsafe { &*node_data });
881 match diff {
882 Ordering::Equal => {
883 return AvlSearchResult {
884 node: node_data,
885 direction: None,
886 _phan: PhantomData,
887 };
888 }
889 Ordering::Less => {
890 let node = unsafe { (*node_data).get_node() };
891 let left = node.get_child(AvlDirection::Left);
892 if left.is_null() {
893 return AvlSearchResult {
894 node: node_data,
895 direction: Some(AvlDirection::Left),
896 _phan: PhantomData,
897 };
898 }
899 node_data = left;
900 }
901 Ordering::Greater => {
902 let node = unsafe { (*node_data).get_node() };
903 let right = node.get_child(AvlDirection::Right);
904 if right.is_null() {
905 return AvlSearchResult {
906 node: node_data,
907 direction: Some(AvlDirection::Right),
908 _phan: PhantomData,
909 };
910 }
911 node_data = right;
912 }
913 }
914 }
915 }
916
917 #[inline]
919 pub fn find_contained<'a, K>(
920 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
921 ) -> Option<&'a P::Target> {
922 if self.root.is_null() {
923 return None;
924 }
925 let mut node_data = self.root;
926 let mut result_node: *const P::Target = null();
927 loop {
928 let diff = cmp_func(val, unsafe { &*node_data });
929 match diff {
930 Ordering::Equal => {
931 let node = unsafe { (*node_data).get_node() };
932 let left = node.get_child(AvlDirection::Left);
933 result_node = node_data;
934 if left.is_null() {
935 break;
936 } else {
937 node_data = left;
938 }
939 }
940 Ordering::Less => {
941 let node = unsafe { (*node_data).get_node() };
942 let left = node.get_child(AvlDirection::Left);
943 if left.is_null() {
944 break;
945 }
946 node_data = left;
947 }
948 Ordering::Greater => {
949 let node = unsafe { (*node_data).get_node() };
950 let right = node.get_child(AvlDirection::Right);
951 if right.is_null() {
952 break;
953 }
954 node_data = right;
955 }
956 }
957 }
958 if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
959 }
960
961 #[inline]
963 pub fn find_larger_eq<'a, K>(
964 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
965 ) -> AvlSearchResult<'a, P> {
966 if self.root.is_null() {
967 return AvlSearchResult::default();
968 }
969 let mut node_data = self.root;
970 loop {
971 let diff = cmp_func(val, unsafe { &*node_data });
972 match diff {
973 Ordering::Equal => {
974 return AvlSearchResult {
975 node: node_data,
976 direction: None,
977 _phan: PhantomData,
978 };
979 }
980 Ordering::Less => {
981 return AvlSearchResult {
982 node: node_data,
983 direction: None,
984 _phan: PhantomData,
985 };
986 }
987 Ordering::Greater => {
988 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
989 if right.is_null() {
990 return AvlSearchResult {
991 node: null(),
992 direction: None,
993 _phan: PhantomData,
994 };
995 }
996 node_data = right;
997 }
998 }
999 }
1000 }
1001
1002 #[inline]
1004 pub fn find_nearest<'a, K>(
1005 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1006 ) -> AvlSearchResult<'a, P> {
1007 if self.root.is_null() {
1008 return AvlSearchResult::default();
1009 }
1010
1011 let mut node_data = self.root;
1012 let mut nearest_node = null();
1013 loop {
1014 let diff = cmp_func(val, unsafe { &*node_data });
1015 match diff {
1016 Ordering::Equal => {
1017 return AvlSearchResult {
1018 node: node_data,
1019 direction: None,
1020 _phan: PhantomData,
1021 };
1022 }
1023 Ordering::Less => {
1024 nearest_node = node_data;
1025 let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1026 if left.is_null() {
1027 break;
1028 }
1029 node_data = left;
1030 }
1031 Ordering::Greater => {
1032 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1033 if right.is_null() {
1034 break;
1035 }
1036 node_data = right;
1037 }
1038 }
1039 }
1040 AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData }
1041 }
1042
1043 #[inline(always)]
1044 fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1045 loop {
1046 let child = unsafe { (*data).get_node() }.get_child(dir);
1047 if !child.is_null() {
1048 data = child;
1049 } else {
1050 return data;
1051 }
1052 }
1053 }
1054
1055 pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1056 let mut node = self.first();
1057 while let Some(n) = node {
1058 cb(n);
1059 node = self.next(n);
1060 }
1061 }
1062
1063 #[inline]
1064 pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1065 if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1066 Some(unsafe { p.as_ref() })
1067 } else {
1068 None
1069 }
1070 }
1071
1072 #[inline]
1073 pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1074 if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1075 Some(unsafe { p.as_ref() })
1076 } else {
1077 None
1078 }
1079 }
1080
1081 #[inline]
1082 fn walk_dir(
1083 &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1084 ) -> Option<NonNull<P::Target>> {
1085 let dir_inverse = dir.reverse();
1086 let node = unsafe { (*data_ptr).get_node() };
1087 let temp = node.get_child(dir);
1088 if !temp.is_null() {
1089 unsafe {
1090 Some(NonNull::new_unchecked(
1091 self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1092 ))
1093 }
1094 } else {
1095 let mut parent = node.parent;
1096 if parent.is_null() {
1097 return None;
1098 }
1099 loop {
1100 let pdir = self.parent_direction(data_ptr, parent);
1101 if pdir == dir_inverse {
1102 return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1103 }
1104 data_ptr = parent;
1105 parent = unsafe { (*parent).get_node() }.parent;
1106 if parent.is_null() {
1107 return None;
1108 }
1109 }
1110 }
1111 }
1112
1113 #[inline]
1114 fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1115 let node = unsafe { (*data).get_node() };
1116 let left = node.left;
1117 if !left.is_null() {
1118 assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
1119 assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1120 }
1121 let right = node.right;
1122 if !right.is_null() {
1123 assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
1124 assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1125 }
1126 }
1127
1128 #[inline]
1129 pub fn nearest<'a>(
1130 &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1131 ) -> AvlSearchResult<'a, P> {
1132 if !current.node.is_null() {
1133 if current.direction.is_some() && current.direction != Some(direction) {
1134 return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1135 }
1136 if let Some(node) = self.walk_dir(current.node, direction) {
1137 return AvlSearchResult {
1138 node: node.as_ptr(),
1139 direction: None,
1140 _phan: PhantomData,
1141 };
1142 }
1143 }
1144 AvlSearchResult::default()
1145 }
1146
1147 pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1148 let c = {
1149 #[cfg(feature = "std")]
1150 {
1151 ((self.get_count() + 10) as f32).log2() as usize
1152 }
1153 #[cfg(not(feature = "std"))]
1154 {
1155 100
1156 }
1157 };
1158 let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1159 if self.root.is_null() {
1160 assert_eq!(self.count, 0);
1161 return;
1162 }
1163 let mut data = self.root;
1164 let mut visited = 0;
1165 loop {
1166 if !data.is_null() {
1167 let left = {
1168 let node = unsafe { (*data).get_node() };
1169 node.get_child(AvlDirection::Left)
1170 };
1171 if !left.is_null() {
1172 stack.push(data);
1173 data = left;
1174 continue;
1175 }
1176 visited += 1;
1177 self.validate_node(data, cmp_func);
1178 data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1179 } else if !stack.is_empty() {
1180 let _data = stack.pop().unwrap();
1181 self.validate_node(_data, cmp_func);
1182 visited += 1;
1183 let node = unsafe { (*_data).get_node() };
1184 data = node.get_child(AvlDirection::Right);
1185 } else {
1186 break;
1187 }
1188 }
1189 assert_eq!(visited, self.count);
1190 }
1191
1192 #[inline]
1198 pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1199 if self.count == 0 && self.root.is_null() {
1200 self.root = node.into_raw();
1201 self.count = 1;
1202 return true;
1203 }
1204
1205 let w = self.find(node.as_ref(), cmp_func);
1206 if w.direction.is_none() {
1207 drop(node);
1210 return false;
1211 }
1212
1213 let w_node = w.node;
1216 let w_dir = w.direction;
1217
1218 let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1219
1220 self.insert(node, w_detached);
1221 true
1222 }
1223}
1224
1225impl<P, Tag> Drop for AvlTree<P, Tag>
1226where
1227 P: Pointer,
1228 P::Target: AvlItem<Tag>,
1229{
1230 fn drop(&mut self) {
1231 if mem::needs_drop::<P>() {
1232 for _ in self.drain() {}
1233 }
1234 }
1235}
1236
1237pub struct AvlDrain<'a, P: Pointer, Tag>
1238where
1239 P::Target: AvlItem<Tag>,
1240{
1241 tree: &'a mut AvlTree<P, Tag>,
1242 parent: *const P::Target,
1243 dir: Option<AvlDirection>,
1244}
1245
1246impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
1247where
1248 P::Target: AvlItem<Tag>,
1249{
1250 type Item = P;
1251
1252 fn next(&mut self) -> Option<Self::Item> {
1253 if self.tree.root.is_null() {
1254 return None;
1255 }
1256
1257 let mut node: *const P::Target;
1258 let parent: *const P::Target;
1259
1260 if self.dir.is_none() && self.parent.is_null() {
1261 let mut curr = self.tree.root;
1263 while unsafe { !(*curr).get_node().left.is_null() } {
1264 curr = unsafe { (*curr).get_node().left };
1265 }
1266 node = curr;
1267 } else {
1268 parent = self.parent;
1269 if parent.is_null() {
1270 return None;
1272 }
1273
1274 let child_dir = self.dir.unwrap();
1275 if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
1279 node = parent;
1280 } else {
1281 node = unsafe { (*parent).get_node().right };
1283 while unsafe { !(*node).get_node().left.is_null() } {
1284 node = unsafe { (*node).get_node().left };
1285 }
1286 }
1287 }
1288
1289 if unsafe { !(*node).get_node().right.is_null() } {
1291 node = unsafe { (*node).get_node().right };
1294 }
1295
1296 let next_parent = unsafe { (*node).get_node().parent };
1298 if next_parent.is_null() {
1299 self.tree.root = null();
1300 self.parent = null();
1301 self.dir = Some(AvlDirection::Left);
1302 } else {
1303 self.parent = next_parent;
1304 self.dir = Some(self.tree.parent_direction(node, next_parent));
1305 unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
1307 }
1308
1309 self.tree.count -= 1;
1310 unsafe {
1311 (*node).get_node().detach();
1312 Some(P::from_raw(node))
1313 }
1314 }
1315}
1316
1317impl<T, Tag> AvlTree<Arc<T>, Tag>
1318where
1319 T: AvlItem<Tag>,
1320{
1321 pub fn remove_ref(&mut self, node: &Arc<T>) {
1322 let p = Arc::as_ptr(node);
1323 unsafe { self.remove(p) };
1324 unsafe { drop(Arc::from_raw(p)) };
1325 }
1326}
1327
1328impl<T, Tag> AvlTree<Rc<T>, Tag>
1329where
1330 T: AvlItem<Tag>,
1331{
1332 pub fn remove_ref(&mut self, node: &Rc<T>) {
1333 let p = Rc::as_ptr(node);
1334 unsafe { self.remove(p) };
1335 unsafe { drop(Rc::from_raw(p)) };
1336 }
1337}
1338
1339#[cfg(test)]
1340mod tests {
1341 use super::*;
1342 use core::cell::UnsafeCell;
1343 use rand::Rng;
1344 use std::time::Instant;
1345
1346 struct IntAvlNode {
1347 pub value: i64,
1348 pub node: UnsafeCell<AvlNode<Self, ()>>,
1349 }
1350
1351 impl fmt::Debug for IntAvlNode {
1352 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1353 write!(f, "{} {:#?}", self.value, self.node)
1354 }
1355 }
1356
1357 impl fmt::Display for IntAvlNode {
1358 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1359 write!(f, "{}", self.value)
1360 }
1361 }
1362
1363 unsafe impl AvlItem<()> for IntAvlNode {
1364 fn get_node(&self) -> &mut AvlNode<Self, ()> {
1365 unsafe { &mut *self.node.get() }
1366 }
1367 }
1368
1369 type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;
1370
1371 fn new_intnode(i: i64) -> Box<IntAvlNode> {
1372 Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
1373 }
1374
1375 fn new_inttree() -> IntAvlTree {
1376 AvlTree::<Box<IntAvlNode>, ()>::new()
1377 }
1378
1379 fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
1380 a.value.cmp(&b.value)
1381 }
1382
1383 fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
1384 a.cmp(&b.value)
1385 }
1386
1387 impl AvlTree<Box<IntAvlNode>, ()> {
1388 fn remove_int(&mut self, i: i64) -> bool {
1389 if let Some(_node) = self.remove_by_key(&i, cmp_int) {
1390 return true;
1392 }
1393 println!("not found {}", i);
1395 false
1396 }
1397
1398 fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
1399 self.add(node, cmp_int_node)
1400 }
1401
1402 fn validate_tree(&self) {
1403 self.validate(cmp_int_node);
1404 }
1405
1406 fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1407 self.find(&i, cmp_int)
1408 }
1409
1410 fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1411 self.find(node, cmp_int_node)
1412 }
1413 }
1414
1415 #[test]
1416 fn int_avl_node() {
1417 let mut tree = new_inttree();
1418
1419 assert_eq!(tree.get_count(), 0);
1420 assert!(tree.first().is_none());
1421 assert!(tree.last().is_none());
1422
1423 let node1 = new_intnode(1);
1424 let node2 = new_intnode(2);
1425 let node3 = new_intnode(3);
1426
1427 let p1 = &*node1 as *const IntAvlNode;
1428 let p2 = &*node2 as *const IntAvlNode;
1429 let p3 = &*node3 as *const IntAvlNode;
1430
1431 tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
1432 tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);
1433
1434 assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
1435 assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
1438 }
1439
1440 #[test]
1441 fn int_avl_tree_basic() {
1442 let mut tree = new_inttree();
1443
1444 let temp_node = new_intnode(0);
1445 let temp_node_val = Pointer::as_ref(&temp_node);
1446 assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
1447 assert_eq!(
1448 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
1449 false
1450 );
1451 assert_eq!(
1452 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
1453 false
1454 );
1455 drop(temp_node);
1456
1457 tree.add_int_node(new_intnode(0));
1458 let result = tree.find_int(0);
1459 assert!(result.get_node_ref().is_some());
1460 assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
1461 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1462
1463 let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
1464 assert!(rs.is_some());
1465 let found_value = rs.unwrap().value;
1466 assert_eq!(found_value, 0);
1467
1468 let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
1469 assert!(rs.is_none());
1470
1471 let result = tree.find_int(1);
1472 let left = tree.nearest(&result, AvlDirection::Left);
1473 assert_eq!(left.is_exact(), true);
1474 assert_eq!(left.get_nearest().unwrap().value, 0);
1475 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1476
1477 tree.add_int_node(new_intnode(2));
1478 let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
1479 assert!(rs.is_some());
1480 let found_value = rs.unwrap().value;
1481 assert_eq!(found_value, 2);
1482 }
1483
1484 #[test]
1485 fn int_avl_tree_order() {
1486 let max;
1487 #[cfg(miri)]
1488 {
1489 max = 2000;
1490 }
1491 #[cfg(not(miri))]
1492 {
1493 max = 200000;
1494 }
1495 let mut tree = new_inttree();
1496 assert!(tree.first().is_none());
1497 let start_ts = Instant::now();
1498 for i in 0..max {
1499 tree.add_int_node(new_intnode(i));
1500 }
1501 tree.validate_tree();
1502 assert_eq!(tree.get_count(), max as i64);
1503
1504 let mut count = 0;
1505 let mut current = tree.first();
1506 let last = tree.last();
1507 while let Some(c) = current {
1508 assert_eq!(c.value, count);
1509 count += 1;
1510 if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
1511 current = None;
1512 } else {
1513 current = tree.next(c);
1514 }
1515 }
1516 assert_eq!(count, max);
1517
1518 {
1519 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1520 assert!(rs.is_some());
1521 let found_value = rs.unwrap().value;
1522 println!("found larger_eq {}", found_value);
1523 assert!(found_value >= 5);
1524 tree.remove_int(5);
1525 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1526 assert!(rs.is_some());
1527 assert!(rs.unwrap().value >= 6);
1528 tree.add_int_node(new_intnode(5));
1529 }
1530
1531 for i in 0..max {
1532 assert!(tree.remove_int(i));
1533 }
1534 assert_eq!(tree.get_count(), 0);
1535
1536 let end_ts = Instant::now();
1537 println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
1538 }
1539
1540 #[test]
1541 fn int_avl_tree_fixed1() {
1542 let mut tree = new_inttree();
1543 let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
1544 for i in arr.iter() {
1545 let node = new_intnode(*i);
1546 tree.add_int_node(node);
1547 let rs = tree.find_int(*i);
1548 assert!(rs.get_node_ref().is_some(), "add error {}", i);
1549 }
1550 assert_eq!(tree.get_count(), arr.len() as i64);
1551 for i in arr.iter() {
1552 assert!(tree.remove_int(*i));
1553 }
1554 assert_eq!(tree.get_count(), 0);
1555 }
1556
1557 #[test]
1558 fn int_avl_tree_fixed2() {
1559 let mut tree = new_inttree();
1560 tree.validate_tree();
1561 let node1 = new_intnode(536872960);
1562 {
1563 tree.add_int_node(node1);
1564 tree.validate_tree();
1565 tree.remove_int(536872960);
1566 tree.validate_tree();
1567 tree.add_int_node(new_intnode(536872960));
1568 tree.validate_tree();
1569 }
1570
1571 assert!(tree.find_int(536872960).get_node_ref().is_some());
1572 let node2 = new_intnode(12288);
1573 tree.add_int_node(node2);
1574 tree.validate_tree();
1575 tree.remove_int(536872960);
1576 tree.validate_tree();
1577 tree.add_int_node(new_intnode(536872960));
1578 tree.validate_tree();
1579 let node3 = new_intnode(22528);
1580 tree.add_int_node(node3);
1581 tree.validate_tree();
1582 tree.remove_int(12288);
1583 assert!(tree.find_int(12288).get_node_ref().is_none());
1584 tree.validate_tree();
1585 tree.remove_int(22528);
1586 assert!(tree.find_int(22528).get_node_ref().is_none());
1587 tree.validate_tree();
1588 tree.add_int_node(new_intnode(22528));
1589 tree.validate_tree();
1590 }
1591
1592 #[test]
1593 fn int_avl_tree_random() {
1594 let count = 1000;
1595 let mut test_list: Vec<i64> = Vec::with_capacity(count);
1596 let mut rng = rand::thread_rng();
1597 let mut tree = new_inttree();
1598 tree.validate_tree();
1599 for _ in 0..count {
1600 let node_value: i64 = rng.r#gen();
1601 if !test_list.contains(&node_value) {
1602 test_list.push(node_value);
1603 assert!(tree.add_int_node(new_intnode(node_value)))
1604 }
1605 }
1606 tree.validate_tree();
1607 test_list.sort();
1608 for index in 0..test_list.len() {
1609 let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
1610 let prev = tree.prev(node_ptr);
1611 let next = tree.next(node_ptr);
1612 if index == 0 {
1613 assert!(prev.is_none());
1615 assert!(next.is_some());
1616 assert_eq!(next.unwrap().value, test_list[index + 1]);
1617 } else if index == test_list.len() - 1 {
1618 assert!(prev.is_some());
1620 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1621 assert!(next.is_none());
1622 } else {
1623 assert!(prev.is_some());
1625 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1626 assert!(next.is_some());
1627 assert_eq!(next.unwrap().value, test_list[index + 1]);
1628 }
1629 }
1630 for index in 0..test_list.len() {
1631 assert!(tree.remove_int(test_list[index]));
1632 }
1633 tree.validate_tree();
1634 assert_eq!(0, tree.get_count());
1635 }
1636
1637 #[test]
1638 fn int_avl_tree_insert_here() {
1639 let mut tree = new_inttree();
1640 let node1 = new_intnode(10);
1641 tree.add_int_node(node1);
1642 let rs = tree.find_int(10);
1644 let here = unsafe { rs.detach() };
1645 unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
1646 tree.validate_tree();
1647 assert_eq!(tree.get_count(), 2);
1648 assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);
1649
1650 let rs = tree.find_int(10);
1652 let here = unsafe { rs.detach() };
1653 unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
1654 tree.validate_tree();
1655 assert_eq!(tree.get_count(), 3);
1656 assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);
1657
1658 let rs = tree.find_int(5);
1660 let here = unsafe { rs.detach() };
1661 unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
1662 tree.validate_tree();
1663 assert_eq!(tree.get_count(), 4);
1664
1665 let rs = tree.find_int(5);
1667 let here = unsafe { rs.detach() };
1668 unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
1669 tree.validate_tree();
1670 assert_eq!(tree.get_count(), 5);
1671 }
1672
1673 #[test]
1674 fn test_arc_avl_tree_get_exact() {
1675 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1676 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
1678 tree.add(node.clone(), cmp_int_node);
1679
1680 let result_search = tree.find(&100, cmp_int);
1682
1683 let exact = result_search.get_exact();
1685 assert!(exact.is_some());
1686 let exact_arc = exact.unwrap();
1687 assert_eq!(exact_arc.value, 100);
1688 assert!(Arc::ptr_eq(&node, &exact_arc));
1689 assert_eq!(Arc::strong_count(&node), 3);
1691 }
1692
1693 #[test]
1694 fn test_arc_avl_tree_remove_ref() {
1695 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1696 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
1697 tree.add(node.clone(), cmp_int_node);
1698 assert_eq!(tree.get_count(), 1);
1699 assert_eq!(Arc::strong_count(&node), 2);
1700
1701 tree.remove_ref(&node);
1702 assert_eq!(tree.get_count(), 0);
1703 assert_eq!(Arc::strong_count(&node), 1);
1704 }
1705
1706 #[test]
1707 fn test_arc_avl_tree_remove_with() {
1708 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1709 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
1710 tree.add(node.clone(), cmp_int_node);
1711
1712 let removed = tree.remove_by_key(&300, cmp_int);
1713 assert!(removed.is_some());
1714 let removed_arc = removed.unwrap();
1715 assert_eq!(removed_arc.value, 300);
1716 assert_eq!(tree.get_count(), 0);
1717 assert_eq!(Arc::strong_count(&node), 2);
1719
1720 drop(removed_arc);
1721 assert_eq!(Arc::strong_count(&node), 1);
1722 }
1723
1724 #[test]
1725 fn test_avl_drain() {
1726 let mut tree = new_inttree();
1727 for i in 0..100 {
1728 tree.add_int_node(new_intnode(i));
1729 }
1730 assert_eq!(tree.get_count(), 100);
1731
1732 let mut count = 0;
1733 for node in tree.drain() {
1734 assert!(node.value >= 0 && node.value < 100);
1735 count += 1;
1736 }
1737 assert_eq!(count, 100);
1738 assert_eq!(tree.get_count(), 0);
1739 assert!(tree.first().is_none());
1740 }
1741}