1use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72 cmp::{Ordering, PartialEq},
73 fmt, mem,
74 ptr::{NonNull, null},
75};
76
/// An element that can be linked into an [`AvlTree`] under the tag `Tag`.
///
/// The tree is intrusive: each element embeds its own [`AvlNode`], and the tree stores only raw
/// pointers between elements. Because the trait is generic over `Tag`, one element type can
/// implement it several times and so carry several independent nodes.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible here. Presumably implementors must:
/// - return the same embedded node on every call;
/// - keep that node alive exactly as long as the element;
/// - never let one node be linked into two trees at once.
///
/// `get_node` hands out `&mut` from `&self`, so the node is expected to live behind interior
/// mutability (the tests in this file use `UnsafeCell`).
pub unsafe trait AvlItem<Tag>: Sized {
    /// Returns the embedded tree links of this element.
    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
}
91
/// A child slot of a node, or the direction in which to walk the tree.
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum AvlDirection {
    /// The left child; lookups descend here when the key compares `Less`.
    Left = 0,
    /// The right child; lookups descend here when the key compares `Greater`.
    Right = 1,
}
97
98impl AvlDirection {
99 #[inline(always)]
100 fn reverse(self) -> AvlDirection {
101 match self {
102 AvlDirection::Left => AvlDirection::Right,
103 AvlDirection::Right => AvlDirection::Left,
104 }
105 }
106}
107
/// Balance delta contributed by growth on the given side: `Left` gives -1, `Right` gives +1.
macro_rules! avlchild_to_balance {
    ( $dir: expr ) => {
        if $dir == AvlDirection::Left { -1 } else { 1 }
    };
}
116
/// The intrusive AVL links embedded in every element of a tree.
///
/// Every link points at the *containing element* (`T`), not at another `AvlNode`. A null link
/// means that neighbour does not exist.
pub struct AvlNode<T: Sized, Tag> {
    /// Left child, or null.
    pub left: *const T,
    /// Right child, or null.
    pub right: *const T,
    /// Parent element; null for the root or for an unlinked node.
    pub parent: *const T,
    /// Balance factor.
    /// - Growth on the right adds +1 and growth on the left adds -1.
    /// - The value stays in -1..=1 between operations.
    /// - A would-be value of +/-2 triggers a rotation.
    pub balance: i8,
    // Marks `Tag` as used without storing one. `fn` pointer types are always `Send`/`Sync`, so
    // `Tag` does not influence the auto traits of the node.
    _phan: PhantomData<fn(&Tag)>,
}
124
// NOTE(review): the raw-pointer fields would make `AvlNode` `!Send`; this impl opts back in for
// every `T`, even a `T` that is not `Send` itself. Presumably elements only cross threads
// together with the tree that links them — confirm this is sound for all `T`.
unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128 #[inline(always)]
129 pub fn detach(&mut self) {
130 self.left = null();
131 self.right = null();
132 self.parent = null();
133 self.balance = 0;
134 }
135
136 #[inline(always)]
137 fn get_child(&self, dir: AvlDirection) -> *const T {
138 match dir {
139 AvlDirection::Left => self.left,
140 AvlDirection::Right => self.right,
141 }
142 }
143
144 #[inline(always)]
145 fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146 match dir {
147 AvlDirection::Left => self.left = child,
148 AvlDirection::Right => self.right = child,
149 }
150 }
151
152 #[inline(always)]
153 fn get_parent(&self) -> *const T {
154 return self.parent;
155 }
156
157 #[inline(always)]
159 pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160 mem::swap(&mut self.left, &mut other.left);
161 mem::swap(&mut self.right, &mut other.right);
162 mem::swap(&mut self.parent, &mut other.parent);
163 mem::swap(&mut self.balance, &mut other.balance);
164 }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168 fn default() -> Self {
169 Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170 }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176 write!(f, "(")?;
177
178 if !self.left.is_null() {
179 write!(f, "left: {:p}", self.left)?;
180 } else {
181 write!(f, "left: none ")?;
182 }
183
184 if !self.right.is_null() {
185 write!(f, "right: {:p}", self.right)?;
186 } else {
187 write!(f, "right: none ")?;
188 }
189 write!(f, ")")
190 }
191}
192
/// Comparison callback used by lookups: orders a search key `K` against a stored element `T`.
///
/// `Less` makes the search descend left, `Greater` descend right, and `Equal` is a match.
pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
194
/// An intrusive AVL tree that owns its elements through the pointer type `P`.
///
/// - Inserting converts the `P` with `into_raw`.
/// - `remove_with`, `drain` and `remove_ref` convert it back with `from_raw`.
/// - So the tree holds exactly one `P` for each linked element.
pub struct AvlTree<P, Tag>
where
    P: Pointer,
    P::Target: AvlItem<Tag>,
{
    /// Root element, or null when the tree is empty.
    pub root: *const P::Target,
    /// Number of linked elements.
    count: i64,
    // Type marker for `P` and `Tag`; no values of either are stored here.
    _phan: PhantomData<fn(P, &Tag)>,
}
208
/// The result of a tree lookup; it borrows the tree for `'a`.
///
/// * `node` non-null and `direction == None`: an "exact" result (see `is_exact`). `find` uses
///   this form for a match. `find_larger_eq`, `find_nearest` and `nearest` also use it for the
///   element they settle on, which need not compare equal.
/// * `direction == Some(d)`: no match. The key would become the `d` child of `node`, and that
///   slot is empty. This is the form `AvlTree::insert` consumes.
/// * `node` null: nothing was found (for example, the tree is empty or no element qualified).
pub struct AvlSearchResult<'a, P: Pointer> {
    /// The matched element, the would-be parent of a missing key, or null.
    pub node: *const P::Target,
    /// `None` for an exact result; otherwise the empty child slot of `node` for the key.
    pub direction: Option<AvlDirection>,
    // Ties the result to the `&'a` borrow of the tree taken by the lookup, so the tree cannot
    // be mutated while the result is alive (unless it is `detach`ed).
    _phan: PhantomData<&'a P::Target>,
}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229 fn default() -> Self {
230 AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231 }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235 #[inline(always)]
237 pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238 if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239 }
240
241 #[inline(always)]
243 pub fn is_exact(&self) -> bool {
244 self.direction.is_none() && !self.node.is_null()
245 }
246
247 #[inline(always)]
277 pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278 AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279 }
280
281 #[inline(always)]
283 pub fn get_nearest(&self) -> Option<&P::Target> {
284 if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285 }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289 pub fn get_exact(&self) -> Option<Arc<T>> {
291 if self.is_exact() {
292 unsafe {
293 Arc::increment_strong_count(self.node);
294 Some(Arc::from_raw(self.node))
295 }
296 } else {
297 None
298 }
299 }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303 pub fn get_exact(&self) -> Option<Rc<T>> {
305 if self.is_exact() {
306 unsafe {
307 Rc::increment_strong_count(self.node);
308 Some(Rc::from_raw(self.node))
309 }
310 } else {
311 None
312 }
313 }
314}
315
/// Extreme element of `$tree` on side `$dir` (left-most or right-most), or null if empty.
macro_rules! return_end {
    ($tree: expr, $dir: expr) => {{
        let root = $tree.root;
        if root.is_null() { null() } else { $tree.bottom_child_ref(root, $dir) }
    }};
}
319
/// Chooses the side to take a swap neighbour from when removing a node with two children.
///
/// Callers pass `balance + 1`:
/// - `0 | 1` (balance -1 or 0) selects `Left`;
/// - anything else (balance +1) selects `Right`.
macro_rules! balance_to_child {
    ($balance: expr) => {
        if matches!($balance, 0 | 1) { AvlDirection::Left } else { AvlDirection::Right }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331 P: Pointer,
332 P::Target: AvlItem<Tag>,
333{
334 pub fn new() -> Self {
336 AvlTree { count: 0, root: null(), _phan: Default::default() }
337 }
338
339 #[inline]
344 pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345 AvlDrain { tree: self, parent: null(), dir: None }
346 }
347
348 pub fn get_count(&self) -> i64 {
349 self.count
350 }
351
352 pub fn first(&self) -> Option<&P::Target> {
353 unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354 }
355
356 #[inline]
357 pub fn last(&self) -> Option<&P::Target> {
358 unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359 }
360
361 #[inline]
401 pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
402 debug_assert!(w.direction.is_some());
403 self._insert(new_data, w.node, w.direction.unwrap());
404 }
405
406 pub fn _insert(
407 &mut self,
408 new_data: P,
409 here: *const P::Target, mut which_child: AvlDirection,
411 ) {
412 let mut new_balance: i8;
413 let new_ptr = new_data.into_raw();
414
415 if here.is_null() {
416 if self.count > 0 {
417 panic!("insert into a tree size {} with empty where.node", self.count);
418 }
419 self.root = new_ptr;
420 self.count += 1;
421 return;
422 }
423
424 let parent = unsafe { &*here };
425 let node = unsafe { (*new_ptr).get_node() };
426 let parent_node = parent.get_node();
427 node.parent = here;
428 parent_node.set_child(which_child, new_ptr);
429 self.count += 1;
430
431 let mut data: *const P::Target = here;
438 loop {
439 let node = unsafe { (*data).get_node() };
440 let old_balance = node.balance;
441 new_balance = old_balance + avlchild_to_balance!(which_child);
442 if new_balance == 0 {
443 node.balance = 0;
444 return;
445 }
446 if old_balance != 0 {
447 self.rotate(data, new_balance);
448 return;
449 }
450 node.balance = new_balance;
451 let parent_ptr = node.get_parent();
452 if parent_ptr.is_null() {
453 return;
454 }
455 which_child = self.parent_direction(data, parent_ptr);
456 data = parent_ptr;
457 }
458 }
459
460 pub unsafe fn insert_here(
468 &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
469 ) {
470 let mut dir_child = direction;
471 assert_eq!(here.node.is_null(), false);
472 let here_node = here.node;
473 let child = unsafe { (*here_node).get_node().get_child(dir_child) };
474 if !child.is_null() {
475 dir_child = dir_child.reverse();
476 let node = self.bottom_child_ref(child, dir_child);
477 self._insert(new_data, node, dir_child);
478 } else {
479 self._insert(new_data, here_node, dir_child);
480 }
481 }
482
483 #[inline(always)]
485 fn set_child2(
486 &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
487 parent: *const P::Target,
488 ) {
489 if !child.is_null() {
490 unsafe { (*child).get_node().parent = parent };
491 }
492 node.set_child(dir, child);
493 }
494
495 #[inline(always)]
496 fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
497 if !parent.is_null() {
498 let parent_node = unsafe { (*parent).get_node() };
499 if parent_node.left == data {
500 return AvlDirection::Left;
501 }
502 if parent_node.right == data {
503 return AvlDirection::Right;
504 }
505 panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
506 }
507 AvlDirection::Left
509 }
510
511 #[inline(always)]
512 fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
513 let node = unsafe { (*data).get_node() };
514 let parent = node.get_parent();
515 if !parent.is_null() {
516 return self.parent_direction(data, parent);
517 }
518 AvlDirection::Left
520 }
521
/// Rebalances the subtree rooted at `data`, whose balance would become `balance`.
///
/// Callers (`_insert`, `remove`) only pass the would-be new balance of a node that already
/// leaned by one in that direction, so `balance` is -2 or +2. A negative `balance` means the left
/// subtree is too tall.
///
/// - If the taller child does not lean towards the inner side, a single rotation lifts that
///   child into `data`'s place.
/// - Otherwise a double rotation lifts the child's inner child (the grandchild).
/// - Either way, the new subtree root is re-attached to `data`'s old parent, or becomes the
///   tree root.
///
/// Returns `true` when the rotated subtree ends up one level shorter than it was just before
/// the rotation. `remove` keeps rebalancing upward only in that case; `_insert` ignores the
/// result.
#[inline]
fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
    // The heavy side, i.e. the side of the child that will be rotated up.
    let dir: AvlDirection;
    if balance < 0 {
        dir = AvlDirection::Left;
    } else {
        dir = AvlDirection::Right;
    }
    let node = unsafe { (*data).get_node() };

    let parent = node.get_parent();
    let dir_inverse = dir.reverse();
    // `left_heavy` is -1 or +1 with the sign of `balance` ("leaning to the heavy side");
    // `right_heavy` is the opposite lean.
    let left_heavy = balance >> 1;
    let right_heavy = -left_heavy;

    let child = node.get_child(dir);
    let child_node = unsafe { (*child).get_node() };
    let mut child_balance = child_node.balance;

    // Capture `data`'s slot in its parent before any link is rewritten.
    let which_child = self.parent_direction(data, parent);

    if child_balance != right_heavy {
        // Single rotation: `child` leans to the heavy side or is balanced.
        // Shift its balance one step towards the opposite lean. The result is 0 exactly when
        // `child` leaned the same way as `data`.
        child_balance += right_heavy;

        // `data` adopts `child`'s inner subtree, then hangs below `child` on the inner side.
        let c_right = child_node.get_child(dir_inverse);
        self.set_child2(node, dir, c_right, data);
        node.balance = -child_balance;

        node.parent = child;
        child_node.set_child(dir_inverse, data);
        child_node.balance = child_balance;
        // Put `child` where `data` used to be.
        if !parent.is_null() {
            child_node.parent = parent;
            unsafe { (*parent).get_node() }.set_child(which_child, child);
        } else {
            child_node.parent = null();
            self.root = child;
        }
        // A balanced new root means the subtree lost one level of height.
        return child_balance == 0;
    }
    // Double rotation: `child` leans towards the inner side. Its inner child `g_child` becomes
    // the new subtree root, with `child` on the `dir` side and `data` on the `dir_inverse` side.
    let g_child = child_node.get_child(dir_inverse);
    let g_child_node = unsafe { (*g_child).get_node() };
    let g_left = g_child_node.get_child(dir);
    let g_right = g_child_node.get_child(dir_inverse);

    // Split `g_child`'s two subtrees between `data` and `child`.
    self.set_child2(node, dir, g_right, data);
    self.set_child2(child_node, dir_inverse, g_left, child);

    // The new balances of `child` and `data` depend on which way `g_child` leaned.
    let g_child_balance = g_child_node.balance;
    if g_child_balance == right_heavy {
        child_node.balance = left_heavy;
    } else {
        child_node.balance = 0;
    }
    child_node.parent = g_child;
    g_child_node.set_child(dir, child);

    if g_child_balance == left_heavy {
        node.balance = right_heavy;
    } else {
        node.balance = 0;
    }
    g_child_node.balance = 0;

    node.parent = g_child;
    g_child_node.set_child(dir_inverse, data);

    // Put `g_child` where `data` used to be.
    if !parent.is_null() {
        g_child_node.parent = parent;
        unsafe { (*parent).get_node() }.set_child(which_child, g_child);
    } else {
        g_child_node.parent = null();
        self.root = g_child;
    }
    // A double rotation always reports a shortened subtree.
    return true;
}
611
/// Unlinks the element `del` from the tree and rebalances.
///
/// Ownership is not handed back. The caller must reclaim the pointer, e.g. with `P::from_raw`
/// as `remove_with` does. Before returning, `del`'s own links are reset with `detach`.
///
/// # Safety
///
/// `del` must point to a live element. NOTE(review): the original contract is not visible
/// here; the code assumes `del` is linked into *this* tree.
/// - Tracing the code, an element whose links are all null and which is not this tree's root
///   leaves the tree unchanged.
/// - An element linked into a different tree would likely corrupt both trees and this tree's
///   count.
///
/// # Panics
///
/// Panics if the tree ends up with a null root while its count is still positive, which
/// indicates corruption.
pub unsafe fn remove(&mut self, del: *const P::Target) {
    if self.count == 0 {
        return;
    }
    // Fast path: removing the only element.
    if self.count == 1 && self.root == del {
        self.root = null();
        self.count = 0;
        unsafe { (*del).get_node().detach() };
        return;
    }
    let mut which_child: AvlDirection;
    let imm_data: *const P::Target;
    let parent: *const P::Target;
    let del_node = unsafe { (*del).get_node() };

    // A node with two children cannot be unlinked directly. Swap it with its in-order
    // neighbour on the taller side, which has at most one child, then unlink `del` from the
    // neighbour's old position. The order of the remaining elements is unchanged.
    let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();

    if node_swap_flag {
        let dir: AvlDirection;
        let dir_child_temp: AvlDirection;
        let dir_child_del: AvlDirection;
        let dir_inverse: AvlDirection;
        dir = balance_to_child!(del_node.balance + 1);

        let child_temp = del_node.get_child(dir);

        dir_inverse = dir.reverse();
        // `child` is the in-order neighbour: one step to `dir`, then all the way to `dir_inverse`.
        let child = self.bottom_child_ref(child_temp, dir_inverse);

        // The slot `child` occupies in its parent. When `child` is `del`'s direct child, that
        // slot is `dir`.
        if child == child_temp {
            dir_child_temp = dir;
        } else {
            dir_child_temp = self.parent_direction2(child);
        }

        let parent = del_node.get_parent();
        if !parent.is_null() {
            dir_child_del = self.parent_direction(del, parent);
        } else {
            dir_child_del = AvlDirection::Left;
        }

        // Exchange all links: `child` takes over `del`'s position and `del` takes `child`'s.
        let child_node = unsafe { (*child).get_node() };
        child_node.swap(del_node);

        // If `child` was `del`'s direct child, the swap left `child` pointing at itself.
        if child_node.get_child(dir) == child {
            child_node.set_child(dir, del);
        }

        // Re-parent `child`'s new children (one of them may be `del`).
        let c_dir = child_node.get_child(dir);
        if c_dir == del {
            del_node.parent = child;
        } else if !c_dir.is_null() {
            unsafe { (*c_dir).get_node() }.parent = child;
        }

        let c_inv = child_node.get_child(dir_inverse);
        if c_inv == del {
            del_node.parent = child;
        } else if !c_inv.is_null() {
            unsafe { (*c_inv).get_node() }.parent = child;
        }

        // Hook `child` into `del`'s old parent, or make it the root.
        let parent = child_node.get_parent();
        if !parent.is_null() {
            unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
        } else {
            self.root = child;
        }

        // Hook `del` into `child`'s old slot and re-parent its (at most one) child.
        let parent = del_node.get_parent();
        unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
        if !del_node.right.is_null() {
            which_child = AvlDirection::Right;
        } else {
            which_child = AvlDirection::Left;
        }
        let child = del_node.get_child(which_child);
        if !child.is_null() {
            unsafe { (*child).get_node() }.parent = del;
        }
        which_child = dir_child_temp;
    } else {
        let parent = del_node.get_parent();
        if !parent.is_null() {
            which_child = self.parent_direction(del, parent);
        } else {
            which_child = AvlDirection::Left;
        }
    }

    // From here `del` has at most one child, and `which_child` is `del`'s slot in `parent`.
    parent = del_node.get_parent();

    if !del_node.left.is_null() {
        imm_data = del_node.left;
    } else {
        imm_data = del_node.right;
    }

    // Splice `del`'s only child (if any) into `del`'s place.
    if !imm_data.is_null() {
        let imm_node = unsafe { (*imm_data).get_node() };
        imm_node.parent = parent;
    }

    if !parent.is_null() {
        assert!(self.count > 0);
        self.count -= 1;

        let parent_node = unsafe { (*parent).get_node() };
        parent_node.set_child(which_child, imm_data);

        // The `which_child` subtree of `node_data` just lost one level. Walk upward:
        // - a node that was balanced absorbs the loss: record the lean and stop;
        // - a node that becomes balanced got shorter: continue with its parent;
        // - otherwise rotate, and stop unless the rotation shortened the subtree.
        let mut node_data: *const P::Target = parent;
        let mut old_balance: i8;
        let mut new_balance: i8;
        loop {
            let node = unsafe { (*node_data).get_node() };
            old_balance = node.balance;
            new_balance = old_balance - avlchild_to_balance!(which_child);

            if old_balance == 0 {
                node.balance = new_balance;
                break;
            }

            // Capture the next step before `rotate` can move `node_data`.
            let parent = node.get_parent();
            which_child = self.parent_direction(node_data, parent);

            if new_balance == 0 {
                node.balance = new_balance;
            } else if !self.rotate(node_data, new_balance) {
                break;
            }

            if !parent.is_null() {
                node_data = parent;
                continue;
            }
            break;
        }
    } else {
        // `del` was the root with a single child: that child becomes the root.
        if !imm_data.is_null() {
            assert!(self.count > 0);
            self.count -= 1;
            self.root = imm_data;
        }
    }
    if self.root.is_null() {
        if self.count > 0 {
            panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
        }
    }
    del_node.detach();
}
836
837 #[inline]
842 pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
843 let result = self.find(val, cmp_func);
844 self.remove_with(unsafe { result.detach() })
845 }
846
847 #[inline]
857 pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
858 if result.is_exact() {
859 unsafe {
860 let p = result.node;
861 self.remove(p);
862 Some(P::from_raw(p))
863 }
864 } else {
865 None
866 }
867 }
868
869 #[inline]
875 pub fn find<'a, K>(
876 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
877 ) -> AvlSearchResult<'a, P> {
878 if self.root.is_null() {
879 return AvlSearchResult::default();
880 }
881 let mut node_data = self.root;
882 loop {
883 let diff = cmp_func(val, unsafe { &*node_data });
884 match diff {
885 Ordering::Equal => {
886 return AvlSearchResult {
887 node: node_data,
888 direction: None,
889 _phan: PhantomData,
890 };
891 }
892 Ordering::Less => {
893 let node = unsafe { (*node_data).get_node() };
894 let left = node.get_child(AvlDirection::Left);
895 if left.is_null() {
896 return AvlSearchResult {
897 node: node_data,
898 direction: Some(AvlDirection::Left),
899 _phan: PhantomData,
900 };
901 }
902 node_data = left;
903 }
904 Ordering::Greater => {
905 let node = unsafe { (*node_data).get_node() };
906 let right = node.get_child(AvlDirection::Right);
907 if right.is_null() {
908 return AvlSearchResult {
909 node: node_data,
910 direction: Some(AvlDirection::Right),
911 _phan: PhantomData,
912 };
913 }
914 node_data = right;
915 }
916 }
917 }
918 }
919
920 #[inline]
922 pub fn find_contained<'a, K>(
923 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
924 ) -> Option<&'a P::Target> {
925 if self.root.is_null() {
926 return None;
927 }
928 let mut node_data = self.root;
929 let mut result_node: *const P::Target = null();
930 loop {
931 let diff = cmp_func(val, unsafe { &*node_data });
932 match diff {
933 Ordering::Equal => {
934 let node = unsafe { (*node_data).get_node() };
935 let left = node.get_child(AvlDirection::Left);
936 result_node = node_data;
937 if left.is_null() {
938 break;
939 } else {
940 node_data = left;
941 }
942 }
943 Ordering::Less => {
944 let node = unsafe { (*node_data).get_node() };
945 let left = node.get_child(AvlDirection::Left);
946 if left.is_null() {
947 break;
948 }
949 node_data = left;
950 }
951 Ordering::Greater => {
952 let node = unsafe { (*node_data).get_node() };
953 let right = node.get_child(AvlDirection::Right);
954 if right.is_null() {
955 break;
956 }
957 node_data = right;
958 }
959 }
960 }
961 if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
962 }
963
964 #[inline]
966 pub fn find_larger_eq<'a, K>(
967 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
968 ) -> AvlSearchResult<'a, P> {
969 if self.root.is_null() {
970 return AvlSearchResult::default();
971 }
972 let mut node_data = self.root;
973 loop {
974 let diff = cmp_func(val, unsafe { &*node_data });
975 match diff {
976 Ordering::Equal => {
977 return AvlSearchResult {
978 node: node_data,
979 direction: None,
980 _phan: PhantomData,
981 };
982 }
983 Ordering::Less => {
984 return AvlSearchResult {
985 node: node_data,
986 direction: None,
987 _phan: PhantomData,
988 };
989 }
990 Ordering::Greater => {
991 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
992 if right.is_null() {
993 return AvlSearchResult {
994 node: null(),
995 direction: None,
996 _phan: PhantomData,
997 };
998 }
999 node_data = right;
1000 }
1001 }
1002 }
1003 }
1004
1005 #[inline]
1007 pub fn find_nearest<'a, K>(
1008 &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1009 ) -> AvlSearchResult<'a, P> {
1010 if self.root.is_null() {
1011 return AvlSearchResult::default();
1012 }
1013
1014 let mut node_data = self.root;
1015 let mut nearest_node = null();
1016 loop {
1017 let diff = cmp_func(val, unsafe { &*node_data });
1018 match diff {
1019 Ordering::Equal => {
1020 return AvlSearchResult {
1021 node: node_data,
1022 direction: None,
1023 _phan: PhantomData,
1024 };
1025 }
1026 Ordering::Less => {
1027 nearest_node = node_data;
1028 let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1029 if left.is_null() {
1030 break;
1031 }
1032 node_data = left;
1033 }
1034 Ordering::Greater => {
1035 let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1036 if right.is_null() {
1037 break;
1038 }
1039 node_data = right;
1040 }
1041 }
1042 }
1043 return AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData };
1044 }
1045
1046 #[inline(always)]
1047 fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1048 loop {
1049 let child = unsafe { (*data).get_node() }.get_child(dir);
1050 if !child.is_null() {
1051 data = child;
1052 } else {
1053 return data;
1054 }
1055 }
1056 }
1057
1058 pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1059 let mut node = self.first();
1060 while let Some(n) = node {
1061 cb(n);
1062 node = self.next(n);
1063 }
1064 }
1065
1066 #[inline]
1067 pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1068 if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1069 Some(unsafe { p.as_ref() })
1070 } else {
1071 None
1072 }
1073 }
1074
1075 #[inline]
1076 pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1077 if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1078 Some(unsafe { p.as_ref() })
1079 } else {
1080 None
1081 }
1082 }
1083
1084 #[inline]
1085 fn walk_dir(
1086 &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1087 ) -> Option<NonNull<P::Target>> {
1088 let dir_inverse = dir.reverse();
1089 let node = unsafe { (*data_ptr).get_node() };
1090 let temp = node.get_child(dir);
1091 if !temp.is_null() {
1092 unsafe {
1093 Some(NonNull::new_unchecked(
1094 self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1095 ))
1096 }
1097 } else {
1098 let mut parent = node.parent;
1099 if parent.is_null() {
1100 return None;
1101 }
1102 loop {
1103 let pdir = self.parent_direction(data_ptr, parent);
1104 if pdir == dir_inverse {
1105 return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1106 }
1107 data_ptr = parent;
1108 parent = unsafe { (*parent).get_node() }.parent;
1109 if parent.is_null() {
1110 return None;
1111 }
1112 }
1113 }
1114 }
1115
1116 #[inline]
1117 fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1118 let node = unsafe { (*data).get_node() };
1119 let left = node.left;
1120 if !left.is_null() {
1121 assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
1122 assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1123 }
1124 let right = node.right;
1125 if !right.is_null() {
1126 assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
1127 assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1128 }
1129 }
1130
1131 #[inline]
1132 pub fn nearest<'a>(
1133 &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1134 ) -> AvlSearchResult<'a, P> {
1135 if !current.node.is_null() {
1136 if current.direction.is_some() && current.direction != Some(direction) {
1137 return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1138 }
1139 if let Some(node) = self.walk_dir(current.node, direction) {
1140 return AvlSearchResult {
1141 node: node.as_ptr(),
1142 direction: None,
1143 _phan: PhantomData,
1144 };
1145 }
1146 }
1147 return AvlSearchResult::default();
1148 }
1149
1150 pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1151 let c = {
1152 #[cfg(feature = "std")]
1153 {
1154 ((self.get_count() + 10) as f32).log2() as usize
1155 }
1156 #[cfg(not(feature = "std"))]
1157 {
1158 100
1159 }
1160 };
1161 let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1162 if self.root.is_null() {
1163 assert_eq!(self.count, 0);
1164 return;
1165 }
1166 let mut data = self.root;
1167 let mut visited = 0;
1168 loop {
1169 if !data.is_null() {
1170 let left = {
1171 let node = unsafe { (*data).get_node() };
1172 node.get_child(AvlDirection::Left)
1173 };
1174 if !left.is_null() {
1175 stack.push(data);
1176 data = left;
1177 continue;
1178 }
1179 visited += 1;
1180 self.validate_node(data, cmp_func);
1181 data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1182 } else if stack.len() > 0 {
1183 let _data = stack.pop().unwrap();
1184 self.validate_node(_data, cmp_func);
1185 visited += 1;
1186 let node = unsafe { (*_data).get_node() };
1187 data = node.get_child(AvlDirection::Right);
1188 } else {
1189 break;
1190 }
1191 }
1192 assert_eq!(visited, self.count);
1193 }
1194
1195 #[inline]
1201 pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1202 if self.count == 0 && self.root.is_null() {
1203 self.root = node.into_raw();
1204 self.count = 1;
1205 return true;
1206 }
1207
1208 let w = self.find(node.as_ref(), cmp_func);
1209 if w.direction.is_none() {
1210 drop(node);
1213 return false;
1214 }
1215
1216 let w_node = w.node;
1219 let w_dir = w.direction;
1220 drop(w);
1222
1223 let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1224
1225 self.insert(node, w_detached);
1226 return true;
1227 }
1228}
1229
1230impl<P, Tag> Drop for AvlTree<P, Tag>
1231where
1232 P: Pointer,
1233 P::Target: AvlItem<Tag>,
1234{
1235 fn drop(&mut self) {
1236 if mem::needs_drop::<P>() {
1237 for _ in self.drain() {}
1238 }
1239 }
1240}
1241
1242pub struct AvlDrain<'a, P: Pointer, Tag>
1243where
1244 P::Target: AvlItem<Tag>,
1245{
1246 tree: &'a mut AvlTree<P, Tag>,
1247 parent: *const P::Target,
1248 dir: Option<AvlDirection>,
1249}
1250
impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
where
    P::Target: AvlItem<Tag>,
{
    type Item = P;

    /// Unlinks and returns the next element. That element is always a leaf of what remains,
    /// so no rotations are needed while the tree is taken apart.
    ///
    /// Elements come out left-most first, each after its children. Balance factors are never
    /// updated, so the tree is only consistent again once it is empty.
    fn next(&mut self) -> Option<Self::Item> {
        if self.tree.root.is_null() {
            return None;
        }

        let mut node: *const P::Target;
        let parent: *const P::Target;

        if self.dir.is_none() && self.parent.is_null() {
            // First call: start at the left-most element.
            let mut curr = self.tree.root;
            while unsafe { !(*curr).get_node().left.is_null() } {
                curr = unsafe { (*curr).get_node().left };
            }
            node = curr;
        } else {
            parent = self.parent;
            if parent.is_null() {
                return None;
            }

            let child_dir = self.dir.unwrap();
            if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
                // Both subtrees of `parent` are gone: `parent` itself is now a leaf.
                node = parent;
            } else {
                // The left subtree is gone: continue with the left-most element of the right
                // subtree.
                node = unsafe { (*parent).get_node().right };
                while unsafe { !(*node).get_node().left.is_null() } {
                    node = unsafe { (*node).get_node().left };
                }
            }
        }

        // `node` has no left child. In an AVL tree its right child, if any, is then a single
        // leaf, so one step right suffices to reach a leaf.
        if unsafe { !(*node).get_node().right.is_null() } {
            node = unsafe { (*node).get_node().right };
        }

        // Cut the leaf off its parent and remember where to resume.
        let next_parent = unsafe { (*node).get_node().parent };
        if next_parent.is_null() {
            // The root was the last element.
            self.tree.root = null();
            self.parent = null();
            self.dir = Some(AvlDirection::Left);
        } else {
            self.parent = next_parent;
            self.dir = Some(self.tree.parent_direction(node, next_parent));
            unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
        }

        self.tree.count -= 1;
        // SAFETY: `node` was inserted via `into_raw`; it is now unlinked, so ownership can be
        // handed back with `from_raw`.
        unsafe {
            (*node).get_node().detach();
            Some(P::from_raw(node))
        }
    }
}
1321
1322impl<T, Tag> AvlTree<Arc<T>, Tag>
1323where
1324 T: AvlItem<Tag>,
1325{
1326 pub fn remove_ref(&mut self, node: &Arc<T>) {
1327 let p = Arc::as_ptr(node);
1328 unsafe { self.remove(p) };
1329 unsafe { drop(Arc::from_raw(p)) };
1330 }
1331}
1332
1333impl<T, Tag> AvlTree<Rc<T>, Tag>
1334where
1335 T: AvlItem<Tag>,
1336{
1337 pub fn remove_ref(&mut self, node: &Rc<T>) {
1338 let p = Rc::as_ptr(node);
1339 unsafe { self.remove(p) };
1340 unsafe { drop(Rc::from_raw(p)) };
1341 }
1342}
1343
1344#[cfg(test)]
1345mod tests {
1346 use super::*;
1347 use core::cell::UnsafeCell;
1348 use rand::Rng;
1349 use std::time::Instant;
1350
1351 struct IntAvlNode {
1352 pub value: i64,
1353 pub node: UnsafeCell<AvlNode<Self, ()>>,
1354 }
1355
1356 impl fmt::Debug for IntAvlNode {
1357 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1358 write!(f, "{} {:#?}", self.value, self.node)
1359 }
1360 }
1361
1362 impl fmt::Display for IntAvlNode {
1363 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1364 write!(f, "{}", self.value)
1365 }
1366 }
1367
1368 unsafe impl AvlItem<()> for IntAvlNode {
1369 fn get_node(&self) -> &mut AvlNode<Self, ()> {
1370 unsafe { &mut *self.node.get() }
1371 }
1372 }
1373
1374 type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;
1375
1376 fn new_intnode(i: i64) -> Box<IntAvlNode> {
1377 Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
1378 }
1379
1380 fn new_inttree() -> IntAvlTree {
1381 AvlTree::<Box<IntAvlNode>, ()>::new()
1382 }
1383
1384 fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
1385 a.value.cmp(&b.value)
1386 }
1387
1388 fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
1389 a.cmp(&b.value)
1390 }
1391
1392 impl AvlTree<Box<IntAvlNode>, ()> {
1393 fn remove_int(&mut self, i: i64) -> bool {
1394 if let Some(_node) = self.remove_by_key(&i, cmp_int) {
1395 return true;
1397 }
1398 println!("not found {}", i);
1400 false
1401 }
1402
1403 fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
1404 self.add(node, cmp_int_node)
1405 }
1406
1407 fn validate_tree(&self) {
1408 self.validate(cmp_int_node);
1409 }
1410
1411 fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1412 self.find(&i, cmp_int)
1413 }
1414
1415 fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1416 self.find(node, cmp_int_node)
1417 }
1418 }
1419
1420 #[test]
1421 fn int_avl_node() {
1422 let mut tree = new_inttree();
1423
1424 assert_eq!(tree.get_count(), 0);
1425 assert!(tree.first().is_none());
1426 assert!(tree.last().is_none());
1427
1428 let node1 = new_intnode(1);
1429 let node2 = new_intnode(2);
1430 let node3 = new_intnode(3);
1431
1432 let p1 = &*node1 as *const IntAvlNode;
1433 let p2 = &*node2 as *const IntAvlNode;
1434 let p3 = &*node3 as *const IntAvlNode;
1435
1436 tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
1437 tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);
1438
1439 assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
1440 assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
1443 }
1444
1445 #[test]
1446 fn int_avl_tree_basic() {
1447 let mut tree = new_inttree();
1448
1449 let temp_node = new_intnode(0);
1450 let temp_node_val = Pointer::as_ref(&temp_node);
1451 assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
1452 assert_eq!(
1453 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
1454 false
1455 );
1456 assert_eq!(
1457 tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
1458 false
1459 );
1460 drop(temp_node);
1461
1462 tree.add_int_node(new_intnode(0));
1463 let result = tree.find_int(0);
1464 assert!(result.get_node_ref().is_some());
1465 assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
1466 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1467
1468 let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
1469 assert!(rs.is_some());
1470 let found_value = rs.unwrap().value;
1471 assert_eq!(found_value, 0);
1472
1473 let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
1474 assert!(rs.is_none());
1475
1476 let result = tree.find_int(1);
1477 let left = tree.nearest(&result, AvlDirection::Left);
1478 assert_eq!(left.is_exact(), true);
1479 assert_eq!(left.get_nearest().unwrap().value, 0);
1480 assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1481
1482 tree.add_int_node(new_intnode(2));
1483 let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
1484 assert!(rs.is_some());
1485 let found_value = rs.unwrap().value;
1486 assert_eq!(found_value, 2);
1487 }
1488
1489 #[test]
1490 fn int_avl_tree_order() {
1491 let max;
1492 #[cfg(miri)]
1493 {
1494 max = 2000;
1495 }
1496 #[cfg(not(miri))]
1497 {
1498 max = 200000;
1499 }
1500 let mut tree = new_inttree();
1501 assert!(tree.first().is_none());
1502 let start_ts = Instant::now();
1503 for i in 0..max {
1504 tree.add_int_node(new_intnode(i));
1505 }
1506 tree.validate_tree();
1507 assert_eq!(tree.get_count(), max as i64);
1508
1509 let mut count = 0;
1510 let mut current = tree.first();
1511 let last = tree.last();
1512 while let Some(c) = current {
1513 assert_eq!(c.value, count);
1514 count += 1;
1515 if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
1516 current = None;
1517 } else {
1518 current = tree.next(c);
1519 }
1520 }
1521 assert_eq!(count, max);
1522
1523 {
1524 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1525 assert!(rs.is_some());
1526 let found_value = rs.unwrap().value;
1527 println!("found larger_eq {}", found_value);
1528 assert!(found_value >= 5);
1529 tree.remove_int(5);
1530 let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1531 assert!(rs.is_some());
1532 assert!(rs.unwrap().value >= 6);
1533 tree.add_int_node(new_intnode(5));
1534 }
1535
1536 for i in 0..max {
1537 assert!(tree.remove_int(i));
1538 }
1539 assert_eq!(tree.get_count(), 0);
1540
1541 let end_ts = Instant::now();
1542 println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
1543 }
1544
1545 #[test]
1546 fn int_avl_tree_fixed1() {
1547 let mut tree = new_inttree();
1548 let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
1549 for i in arr.iter() {
1550 let node = new_intnode(*i);
1551 tree.add_int_node(node);
1552 let rs = tree.find_int(*i);
1553 assert!(rs.get_node_ref().is_some(), "add error {}", i);
1554 }
1555 assert_eq!(tree.get_count(), arr.len() as i64);
1556 for i in arr.iter() {
1557 assert!(tree.remove_int(*i));
1558 }
1559 assert_eq!(tree.get_count(), 0);
1560 }
1561
1562 #[test]
1563 fn int_avl_tree_fixed2() {
1564 let mut tree = new_inttree();
1565 tree.validate_tree();
1566 let node1 = new_intnode(536872960);
1567 {
1568 tree.add_int_node(node1);
1569 tree.validate_tree();
1570 tree.remove_int(536872960);
1571 tree.validate_tree();
1572 tree.add_int_node(new_intnode(536872960));
1573 tree.validate_tree();
1574 }
1575
1576 assert!(tree.find_int(536872960).get_node_ref().is_some());
1577 let node2 = new_intnode(12288);
1578 tree.add_int_node(node2);
1579 tree.validate_tree();
1580 tree.remove_int(536872960);
1581 tree.validate_tree();
1582 tree.add_int_node(new_intnode(536872960));
1583 tree.validate_tree();
1584 let node3 = new_intnode(22528);
1585 tree.add_int_node(node3);
1586 tree.validate_tree();
1587 tree.remove_int(12288);
1588 assert!(tree.find_int(12288).get_node_ref().is_none());
1589 tree.validate_tree();
1590 tree.remove_int(22528);
1591 assert!(tree.find_int(22528).get_node_ref().is_none());
1592 tree.validate_tree();
1593 tree.add_int_node(new_intnode(22528));
1594 tree.validate_tree();
1595 }
1596
1597 #[test]
1598 fn int_avl_tree_random() {
1599 let count = 1000;
1600 let mut test_list: Vec<i64> = Vec::with_capacity(count);
1601 let mut rng = rand::thread_rng();
1602 let mut tree = new_inttree();
1603 tree.validate_tree();
1604 for _ in 0..count {
1605 let node_value: i64 = rng.r#gen();
1606 if !test_list.contains(&node_value) {
1607 test_list.push(node_value);
1608 assert!(tree.add_int_node(new_intnode(node_value)))
1609 }
1610 }
1611 tree.validate_tree();
1612 test_list.sort();
1613 for index in 0..test_list.len() {
1614 let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
1615 let prev = tree.prev(node_ptr);
1616 let next = tree.next(node_ptr);
1617 if index == 0 {
1618 assert!(prev.is_none());
1620 assert!(next.is_some());
1621 assert_eq!(next.unwrap().value, test_list[index + 1]);
1622 } else if index == test_list.len() - 1 {
1623 assert!(prev.is_some());
1625 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1626 assert!(next.is_none());
1627 } else {
1628 assert!(prev.is_some());
1630 assert_eq!(prev.unwrap().value, test_list[index - 1]);
1631 assert!(next.is_some());
1632 assert_eq!(next.unwrap().value, test_list[index + 1]);
1633 }
1634 }
1635 for index in 0..test_list.len() {
1636 assert!(tree.remove_int(test_list[index]));
1637 }
1638 tree.validate_tree();
1639 assert_eq!(0, tree.get_count());
1640 }
1641
1642 #[test]
1643 fn int_avl_tree_insert_here() {
1644 let mut tree = new_inttree();
1645 let node1 = new_intnode(10);
1646 tree.add_int_node(node1);
1647 let rs = tree.find_int(10);
1649 let here = unsafe { rs.detach() };
1650 unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
1651 tree.validate_tree();
1652 assert_eq!(tree.get_count(), 2);
1653 assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);
1654
1655 let rs = tree.find_int(10);
1657 let here = unsafe { rs.detach() };
1658 unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
1659 tree.validate_tree();
1660 assert_eq!(tree.get_count(), 3);
1661 assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);
1662
1663 let rs = tree.find_int(5);
1665 let here = unsafe { rs.detach() };
1666 unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
1667 tree.validate_tree();
1668 assert_eq!(tree.get_count(), 4);
1669
1670 let rs = tree.find_int(5);
1672 let here = unsafe { rs.detach() };
1673 unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
1674 tree.validate_tree();
1675 assert_eq!(tree.get_count(), 5);
1676 }
1677
1678 #[test]
1679 fn test_arc_avl_tree_get_exact() {
1680 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1681 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
1683 tree.add(node.clone(), cmp_int_node);
1684
1685 let result_search = tree.find(&100, cmp_int);
1687
1688 let exact = result_search.get_exact();
1690 assert!(exact.is_some());
1691 let exact_arc = exact.unwrap();
1692 assert_eq!(exact_arc.value, 100);
1693 assert!(Arc::ptr_eq(&node, &exact_arc));
1694 assert_eq!(Arc::strong_count(&node), 3);
1696 }
1697
1698 #[test]
1699 fn test_arc_avl_tree_remove_ref() {
1700 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1701 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
1702 tree.add(node.clone(), cmp_int_node);
1703 assert_eq!(tree.get_count(), 1);
1704 assert_eq!(Arc::strong_count(&node), 2);
1705
1706 tree.remove_ref(&node);
1707 assert_eq!(tree.get_count(), 0);
1708 assert_eq!(Arc::strong_count(&node), 1);
1709 }
1710
1711 #[test]
1712 fn test_arc_avl_tree_remove_with() {
1713 let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1714 let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
1715 tree.add(node.clone(), cmp_int_node);
1716
1717 let removed = tree.remove_by_key(&300, cmp_int);
1718 assert!(removed.is_some());
1719 let removed_arc = removed.unwrap();
1720 assert_eq!(removed_arc.value, 300);
1721 assert_eq!(tree.get_count(), 0);
1722 assert_eq!(Arc::strong_count(&node), 2);
1724
1725 drop(removed_arc);
1726 assert_eq!(Arc::strong_count(&node), 1);
1727 }
1728
1729 #[test]
1730 fn test_avl_drain() {
1731 let mut tree = new_inttree();
1732 for i in 0..100 {
1733 tree.add_int_node(new_intnode(i));
1734 }
1735 assert_eq!(tree.get_count(), 100);
1736
1737 let mut count = 0;
1738 for node in tree.drain() {
1739 assert!(node.value >= 0 && node.value < 100);
1740 count += 1;
1741 }
1742 assert_eq!(count, 100);
1743 assert_eq!(tree.get_count(), 0);
1744 assert!(tree.first().is_none());
1745 }
1746}