Skip to main content

embed_collections/
avl.rs

1//! An intrusive AVL tree implementation.
2//!
//! The algorithm originates from OpenZFS.
4//!
5//! # Examples
6//!
7//! ## Using `Box` to search and remove by key
8//!
9//! ```rust
10//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
11//! use core::cell::UnsafeCell;
12//! use core::cmp::Ordering;
13//!
14//! struct MyNode {
15//!     value: i32,
16//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
17//! }
18//!
19//! unsafe impl AvlItem<()> for MyNode {
20//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
21//!         unsafe { &mut *self.avl_node.get() }
22//!     }
23//! }
24//!
25//! let mut tree = AvlTree::<Box<MyNode>, ()>::new();
26//! tree.add(Box::new(MyNode { value: 10, avl_node: UnsafeCell::new(Default::default()) }), |a, b| a.value.cmp(&b.value));
27//!
28//! // Search and remove
29//! if let Some(node) = tree.remove_by_key(&10, |key, node| key.cmp(&node.value)) {
30//!     assert_eq!(node.value, 10);
31//! }
32//! ```
33//!
34//! ## Using `Arc` for multiple ownership
35//!
//! `remove_ref` is only available when the tree holds `Arc` or `Rc` pointers.
37//!
38//! ```rust
39//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
40//! use core::cell::UnsafeCell;
41//! use std::sync::Arc;
42//!
43//! struct MyNode {
44//!     value: i32,
45//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
46//! }
47//!
48//! unsafe impl AvlItem<()> for MyNode {
49//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
50//!         unsafe { &mut *self.avl_node.get() }
51//!     }
52//! }
53//!
54//! let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
55//! let node = Arc::new(MyNode { value: 42, avl_node: UnsafeCell::new(Default::default()) });
56//!
57//! tree.add(node.clone(), |a, b| a.value.cmp(&b.value));
58//! assert_eq!(tree.get_count(), 1);
59//!
60//! // Remove by reference (detach from avl tree)
61//! tree.remove_ref(&node);
62//! assert_eq!(tree.get_count(), 0);
63//! ```
64//!
65
66use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72    cmp::{Ordering, PartialEq},
73    fmt, mem,
74    ptr::{NonNull, null},
75};
76
77/// A trait to return internal mutable AvlNode for specified list.
78///
79/// The tag is used to distinguish different AvlNodes within the same item,
80/// allowing an item to belong to multiple lists simultaneously.
81/// For only one ownership, you can use `()`.
82///
83/// # Safety
84///
85/// Implementors must ensure `get_node` returns a valid reference to the `AvlNode`
86/// embedded within `Self`. Users must use `UnsafeCell` to hold `AvlNode` to support
87/// interior mutability required by list operations.
88pub unsafe trait AvlItem<Tag>: Sized {
89    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
90}
91
/// A child slot of a tree node.
///
/// `Left` holds items that compare smaller than the node and `Right` holds larger ones.
/// The discriminants are explicit: `Left = 0`, `Right = 1`.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum AvlDirection {
    Left = 0,
    Right = 1,
}

impl AvlDirection {
    /// Returns the opposite direction.
    #[inline(always)]
    fn reverse(self) -> AvlDirection {
        match self {
            AvlDirection::Left => AvlDirection::Right,
            AvlDirection::Right => AvlDirection::Left,
        }
    }
}
107
// Converts a child direction into its balance-factor contribution:
// growth on the left counts as -1, growth on the right as +1.
macro_rules! avlchild_to_balance {
    ($dir:expr) => {
        match $dir {
            AvlDirection::Right => 1,
            AvlDirection::Left => -1,
        }
    };
}
116
/// Intrusive link fields that every tree item embeds (see [`AvlItem`]).
///
/// After `Default` or `detach`, all three links are null and the balance is 0.
pub struct AvlNode<T, Tag> {
    /// Left child, or null when the slot is empty.
    pub left: *const T,
    /// Right child, or null when the slot is empty.
    pub right: *const T,
    /// Parent node, or null for the root.
    pub parent: *const T,
    /// Balance factor, i.e. height(right) - height(left); insertion adds -1 for the left
    /// side and +1 for the right side.
    pub balance: i8,
    _phan: PhantomData<fn(&Tag)>,
}

// SAFETY: `AvlNode` only stores raw link pointers; its own methods never dereference them.
// NOTE(review): the impl is unconditional (no `T: Send` bound) — confirm this is intended.
unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128    #[inline(always)]
129    pub fn detach(&mut self) {
130        self.left = null();
131        self.right = null();
132        self.parent = null();
133        self.balance = 0;
134    }
135
136    #[inline(always)]
137    fn get_child(&self, dir: AvlDirection) -> *const T {
138        match dir {
139            AvlDirection::Left => self.left,
140            AvlDirection::Right => self.right,
141        }
142    }
143
144    #[inline(always)]
145    fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146        match dir {
147            AvlDirection::Left => self.left = child,
148            AvlDirection::Right => self.right = child,
149        }
150    }
151
152    #[inline(always)]
153    fn get_parent(&self) -> *const T {
154        self.parent
155    }
156
157    // Swap two node but not there value
158    #[inline(always)]
159    pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160        mem::swap(&mut self.left, &mut other.left);
161        mem::swap(&mut self.right, &mut other.right);
162        mem::swap(&mut self.parent, &mut other.parent);
163        mem::swap(&mut self.balance, &mut other.balance);
164    }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168    fn default() -> Self {
169        Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170    }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176        write!(f, "(")?;
177
178        if !self.left.is_null() {
179            write!(f, "left: {:p}", self.left)?;
180        } else {
181            write!(f, "left: none ")?;
182        }
183
184        if !self.right.is_null() {
185            write!(f, "right: {:p}", self.right)?;
186        } else {
187            write!(f, "right: none ")?;
188        }
189        write!(f, ")")
190    }
191}
192
/// Comparison callback used by searches: compares a lookup key `K` against a tree item `T`.
pub type AvlCmpFunc<K, T> = for<'k, 't> fn(&'k K, &'t T) -> Ordering;
194
195/// An intrusive AVL tree (balanced binary search tree).
196///
197/// Elements in the tree must implement the [`AvlItem`] trait.
198/// The tree supports various pointer types (`Box`, `Arc`, `Rc`, etc.) through the [`Pointer`] trait.
199pub struct AvlTree<P, Tag>
200where
201    P: Pointer,
202    P::Target: AvlItem<Tag>,
203{
204    pub root: *const P::Target,
205    count: i64,
206    _phan: PhantomData<fn(P, &Tag)>,
207}
208
209/// Result of a search operation in an [`AvlTree`].
210///
211/// An `AvlSearchResult` identifies either:
212/// 1. An exact match: `direction` is `None` and `node` points to the matching item.
213/// 2. An insertion point: `direction` is `Some(dir)` and `node` points to the parent
214///    where a new node should be attached as the `dir` child.
215///
216/// The lifetime `'a` ties the search result to the tree's borrow, ensuring safety.
217/// However, this lifetime often prevents further mutable operations on the tree
218/// (e.g., adding a node while holding the search result). Use [`detach`](Self::detach)
219/// to de-couple the result from the tree's lifetime when necessary.
220pub struct AvlSearchResult<'a, P: Pointer> {
221    /// The matching node or the parent for insertion.
222    pub node: *const P::Target,
223    /// `None` if exact match found, or `Some(direction)` indicating insertion point.
224    pub direction: Option<AvlDirection>,
225    _phan: PhantomData<&'a P::Target>,
226}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229    fn default() -> Self {
230        AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231    }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235    /// Returns a reference to the matching node if the search was an exact match.
236    #[inline(always)]
237    pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238        if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239    }
240
241    /// Returns `true` if the search result is an exact match.
242    #[inline(always)]
243    pub fn is_exact(&self) -> bool {
244        self.direction.is_none() && !self.node.is_null()
245    }
246
247    /// De-couple the lifetime of the search result from the tree.
248    ///
249    /// This method is essential for performing mutable operations on the tree
250    /// using search results. In Rust, a search result typically borrows the tree
251    /// immutably. If you need to modify the tree (e.g., call `insert` or `remove`)
252    /// based on that result, the borrow checker would normally prevent it.
253    ///
254    /// `detach` effectively "erases" the lifetime `'a`, returning a result with
255    /// an unbounded lifetime `'b`.
256    ///
257    /// # Examples
258    ///
259    /// Used in `RangeTree::add`:
260    /// ```ignore
261    /// let result = self.root.find(&rs_key, range_tree_segment_cmp);
262    /// // result is AvlSearchResult<'a, ...> and borrows self.root
263    ///
264    /// let detached = unsafe { result.detach() };
265    /// // detached has no lifetime bound to self.root
266    ///
267    /// self.space += size; // Mutable operation on self permitted
268    /// self.merge_seg(start, end, detached); // Mutation on tree permitted
269    /// ```
270    ///
271    /// # Safety
272    /// This is an unsafe operation. The compiler no longer protects the validity
273    /// of the internal pointer via lifetimes. You must ensure that the tree
274    /// structure is not modified in a way that invalidates `node` (e.g., the
275    /// parent node being removed) before using the detached result.
276    #[inline(always)]
277    pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278        AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279    }
280
281    /// Return the nearest node in the search result
282    #[inline(always)]
283    pub fn get_nearest(&self) -> Option<&P::Target> {
284        if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285    }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289    /// Returns the matching Arc node if this is an exact match.
290    pub fn get_exact(&self) -> Option<Arc<T>> {
291        if self.is_exact() {
292            unsafe {
293                Arc::increment_strong_count(self.node);
294                Some(Arc::from_raw(self.node))
295            }
296        } else {
297            None
298        }
299    }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303    /// Returns the matching Rc node if this is an exact match.
304    pub fn get_exact(&self) -> Option<Rc<T>> {
305        if self.is_exact() {
306            unsafe {
307                Rc::increment_strong_count(self.node);
308                Some(Rc::from_raw(self.node))
309            }
310        } else {
311            None
312        }
313    }
314}
315
// Yields the node reached by `bottom_child_ref` from the root in direction `$dir`
// (the leftmost / rightmost item), or null when the tree is empty.
macro_rules! return_end {
    ($tree:expr, $dir:expr) => {{
        if !$tree.root.is_null() { $tree.bottom_child_ref($tree.root, $dir) } else { null() }
    }};
}
319
// Picks a child direction from a shifted balance factor (`balance + 1`):
// 0 or 1 (left-heavy or balanced) selects `Left`, anything else selects `Right`.
macro_rules! balance_to_child {
    ($balance:expr) => {
        if matches!($balance, 0 | 1) { AvlDirection::Left } else { AvlDirection::Right }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331    P: Pointer,
332    P::Target: AvlItem<Tag>,
333{
334    /// Creates a new, empty `AvlTree`.
335    pub fn new() -> Self {
336        AvlTree { count: 0, root: null(), _phan: Default::default() }
337    }
338
339    /// Returns an iterator that removes all elements from the tree in post-order.
340    ///
341    /// This is an optimized, non-recursive, and stack-less traversal that preserves
342    /// tree invariants during destruction.
343    #[inline]
344    pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345        AvlDrain { tree: self, parent: null(), dir: None }
346    }
347
348    pub fn get_count(&self) -> i64 {
349        self.count
350    }
351
352    pub fn first(&self) -> Option<&P::Target> {
353        unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354    }
355
356    #[inline]
357    pub fn last(&self) -> Option<&P::Target> {
358        unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359    }
360
361    /// Inserts a new node into the tree at the location specified by a search result.
362    ///
363    /// This is typically used after a [`find`](Self::find) operation didn't find an exact match.
364    ///
365    /// # Safety
366    ///
367    /// Once the tree structure changed, previous search result is not safe to use anymore.
368    ///
369    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
370    /// to avoid the borrowing issue.
371    ///
372    /// # Panics
373    /// Panics if the search result is an exact match (i.e. node already exists).
374    ///
375    /// # Examples
376    ///
377    /// ```rust
378    /// use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
379    /// use core::cell::UnsafeCell;
380    /// use std::sync::Arc;
381    ///
382    /// struct MyNode {
383    ///     value: i32,
384    ///     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
385    /// }
386    ///
387    /// unsafe impl AvlItem<()> for MyNode {
388    ///     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
389    ///         unsafe { &mut *self.avl_node.get() }
390    ///     }
391    /// }
392    ///
393    /// let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
394    /// let key = 42;
395    /// let result = tree.find(&key, |k, n| k.cmp(&n.value));
396    ///
397    /// if !result.is_exact() {
398    ///     let new_node = Arc::new(MyNode {
399    ///         value: key,
400    ///         avl_node: UnsafeCell::new(Default::default()),
401    ///     });
402    ///     tree.insert(new_node, unsafe{result.detach()});
403    /// }
404    /// ```
405    #[inline]
406    pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407        debug_assert!(w.direction.is_some());
408        self._insert(new_data, w.node, w.direction.unwrap());
409    }
410
411    #[allow(clippy::not_unsafe_ptr_arg_deref)]
412    pub fn _insert(
413        &mut self,
414        new_data: P,
415        here: *const P::Target, // parent
416        mut which_child: AvlDirection,
417    ) {
418        let mut new_balance: i8;
419        let new_ptr = new_data.into_raw();
420
421        if here.is_null() {
422            if self.count > 0 {
423                panic!("insert into a tree size {} with empty where.node", self.count);
424            }
425            self.root = new_ptr;
426            self.count += 1;
427            return;
428        }
429
430        let parent = unsafe { &*here };
431        let node = unsafe { (*new_ptr).get_node() };
432        let parent_node = parent.get_node();
433        node.parent = here;
434        parent_node.set_child(which_child, new_ptr);
435        self.count += 1;
436
437        /*
438         * Now, back up the tree modifying the balance of all nodes above the
439         * insertion point. If we get to a highly unbalanced ancestor, we
440         * need to do a rotation.  If we back out of the tree we are done.
441         * If we brought any subtree into perfect balance (0), we are also done.
442         */
443        let mut data: *const P::Target = here;
444        loop {
445            let node = unsafe { (*data).get_node() };
446            let old_balance = node.balance;
447            new_balance = old_balance + avlchild_to_balance!(which_child);
448            if new_balance == 0 {
449                node.balance = 0;
450                return;
451            }
452            if old_balance != 0 {
453                self.rotate(data, new_balance);
454                return;
455            }
456            node.balance = new_balance;
457            let parent_ptr = node.get_parent();
458            if parent_ptr.is_null() {
459                return;
460            }
461            which_child = self.parent_direction(data, parent_ptr);
462            data = parent_ptr;
463        }
464    }
465
466    /// Insert "new_data" in "tree" in the given "direction" either after or
467    /// before AvlDirection::After, AvlDirection::Before) the data "here".
468    ///
469    /// Insertions can only be done at empty leaf points in the tree, therefore
470    /// if the given child of the node is already present we move to either
471    /// the AVL_PREV or AVL_NEXT and reverse the insertion direction. Since
472    /// every other node in the tree is a leaf, this always works.
473    ///
474    /// # Safety
475    ///
476    /// Once the tree structure changed, previous search result is not safe to use anymore.
477    ///
478    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
479    /// to avoid the borrowing issue.
480    pub unsafe fn insert_here(
481        &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
482    ) {
483        let mut dir_child = direction;
484        assert!(!here.node.is_null());
485        let here_node = here.node;
486        let child = unsafe { (*here_node).get_node().get_child(dir_child) };
487        if !child.is_null() {
488            dir_child = dir_child.reverse();
489            let node = self.bottom_child_ref(child, dir_child);
490            self._insert(new_data, node, dir_child);
491        } else {
492            self._insert(new_data, here_node, dir_child);
493        }
494    }
495
496    // set child and both child's parent
497    #[inline(always)]
498    fn set_child2(
499        &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
500        parent: *const P::Target,
501    ) {
502        if !child.is_null() {
503            unsafe { (*child).get_node().parent = parent };
504        }
505        node.set_child(dir, child);
506    }
507
508    #[inline(always)]
509    fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
510        if !parent.is_null() {
511            let parent_node = unsafe { (*parent).get_node() };
512            if parent_node.left == data {
513                return AvlDirection::Left;
514            }
515            if parent_node.right == data {
516                return AvlDirection::Right;
517            }
518            panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
519        }
520        // this just follow zfs
521        AvlDirection::Left
522    }
523
524    #[inline(always)]
525    fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
526        let node = unsafe { (*data).get_node() };
527        let parent = node.get_parent();
528        if !parent.is_null() {
529            return self.parent_direction(data, parent);
530        }
531        // this just follow zfs
532        AvlDirection::Left
533    }
534
    /// Restores balance at `data`, whose balance factor has just reached `balance`
    /// (±2 at both call sites: `_insert` and `remove` only call this when an already
    /// non-zero balance moved further away from zero).
    ///
    /// Like the ZFS original, the code reads as the left-heavy case. The right-heavy case
    /// is its mirror image, obtained by swapping `dir`/`dir_inverse` and
    /// `left_heavy`/`right_heavy`.
    /// - When the heavy child leans the same way (or is balanced), a single rotation lifts
    ///   that child into `data`'s place.
    /// - Otherwise a double rotation lifts the grandchild instead.
    ///
    /// Returns `true` when the rebalanced subtree ended up shorter than before. `remove`
    /// stops propagating the rebalance when this is `false`; `_insert` ignores it.
    #[inline]
    fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
        // `dir` is the heavy side of `data`.
        let dir = if balance < 0 { AvlDirection::Left } else { AvlDirection::Right };
        let node = unsafe { (*data).get_node() };

        let parent = node.get_parent();
        let dir_inverse = dir.reverse();
        // ±2 >> 1 == ±1: the balance value meaning "leaning toward `dir`".
        let left_heavy = balance >> 1;
        let right_heavy = -left_heavy;

        let child = node.get_child(dir);
        let child_node = unsafe { (*child).get_node() };
        let mut child_balance = child_node.balance;

        // Slot of `parent` that currently holds `data`; it will hold the new subtree root.
        let which_child = self.parent_direction(data, parent);

        // node is overly left heavy, the left child is balanced or also left heavy:
        // single rotation, `child` replaces `node`.
        if child_balance != right_heavy {
            child_balance += right_heavy;

            // child's inner subtree becomes node's `dir` subtree
            let c_right = child_node.get_child(dir_inverse);
            self.set_child2(node, dir, c_right, data);
            // move node to be child's right (`dir_inverse`) child
            node.balance = -child_balance;

            node.parent = child;
            child_node.set_child(dir_inverse, data);
            // update the pointer into this subtree

            child_node.balance = child_balance;
            if !parent.is_null() {
                child_node.parent = parent;
                unsafe { (*parent).get_node() }.set_child(which_child, child);
            } else {
                child_node.parent = null();
                self.root = child;
            }
            // A formerly heavy child (now 0) means the subtree got shorter.
            return child_balance == 0;
        }
        // When node is left heavy, but child is right heavy we use
        // a different rotation: the grandchild replaces `node`.

        let g_child = child_node.get_child(dir_inverse);
        let g_child_node = unsafe { (*g_child).get_node() };
        let g_left = g_child_node.get_child(dir);
        let g_right = g_child_node.get_child(dir_inverse);

        // Hand the grandchild's two subtrees to `node` and `child`.
        self.set_child2(node, dir, g_right, data);
        self.set_child2(child_node, dir_inverse, g_left, child);

        /*
         * move child to left child of gchild and
         * move node to right child of gchild and
         * fixup parent of all this to point to gchild
         */

        let g_child_balance = g_child_node.balance;
        if g_child_balance == right_heavy {
            child_node.balance = left_heavy;
        } else {
            child_node.balance = 0;
        }
        child_node.parent = g_child;
        g_child_node.set_child(dir, child);

        if g_child_balance == left_heavy {
            node.balance = right_heavy;
        } else {
            node.balance = 0;
        }
        g_child_node.balance = 0;

        node.parent = g_child;
        g_child_node.set_child(dir_inverse, data);

        if !parent.is_null() {
            g_child_node.parent = parent;
            unsafe { (*parent).get_node() }.set_child(which_child, g_child);
        } else {
            g_child_node.parent = null();
            self.root = g_child;
        }
        // A double rotation always reduces the subtree height.
        true
    }
619
620    /*
621    fn replace(&mut self, old: *const P::Target, node: P) {
622        let old_node = unsafe { (*old).get_node() };
623        let new_ptr = node.into_raw();
624        let new_node = unsafe { (*new_ptr).get_node() };
625
626        let left = old_node.get_child(AvlDirection::Left);
627        if !left.is_null() {
628            self.set_child2(new_node, AvlDirection::Left, left, new_ptr);
629        }
630        let right = old_node.get_child(AvlDirection::Right);
631        if !right.is_null() {
632            self.set_child2(new_node, AvlDirection::Right, right, new_ptr);
633        }
634
635        new_node.balance = old_node.balance;
636        old_node.balance = 0;
637        let parent = old_node.get_parent();
638        if !parent.is_null() {
639            let dir = self.parent_direction(old, parent);
640            self.set_child2(unsafe { (*parent).get_node() }, dir, new_ptr, parent);
641            old_node.parent = null();
642        } else {
643            debug_assert_eq!(self.root, old);
644            self.root = new_ptr;
645        }
646    }
647    */
648
    /// Unlinks the item at `del` from the tree and rebalances.
    ///
    /// The item itself is not dropped: the caller is responsible for re-taking ownership
    /// (e.g. via `Pointer::from_raw`, as `remove_with` does) and dropping it if needed.
    /// After unlinking, `del`'s embedded node is reset with `detach`. On an empty tree this
    /// does nothing.
    ///
    /// # Safety
    ///
    /// `del` must be a valid pointer to a node currently linked into *this* tree.
    ///
    /// # Panics
    ///
    /// Panics if the tree turns out to be inconsistent: a parent that does not link back to
    /// its child (via `parent_direction`), or a null root while `count` is still positive.
    pub unsafe fn remove(&mut self, del: *const P::Target) {
        /*
         * Deletion is easiest with a node that has at most 1 child.
         * We swap a node with 2 children with a sequentially valued
         * neighbor node. That node will have at most 1 child. Note this
         * has no effect on the ordering of the remaining nodes.
         *
         * As an optimization, we choose the greater neighbor if the tree
         * is right heavy, otherwise the left neighbor. This reduces the
         * number of rotations needed.
         */
        if self.count == 0 {
            return;
        }
        // Fast path: removing the only node.
        if self.count == 1 && self.root == del {
            self.root = null();
            self.count = 0;
            unsafe { (*del).get_node().detach() };
            return;
        }
        // Which child slot of its (final) parent `del` occupies when it is cut out.
        let mut which_child: AvlDirection;

        // The one `&mut` to `del`'s link node used throughout; later code avoids creating a
        // second one (see the Miri notes below).
        let del_node = unsafe { (*del).get_node() };

        let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();

        if node_swap_flag {
            // Pick the neighbour on the heavier side (balance -1/0 -> Left, +1 -> Right).
            let dir: AvlDirection = balance_to_child!(del_node.balance + 1);
            let child_temp = del_node.get_child(dir);

            let dir_inverse: AvlDirection = dir.reverse();
            // In-order neighbour: the extreme node of the `dir` subtree, toward `del`.
            let child = self.bottom_child_ref(child_temp, dir_inverse);

            // Fix Miri UB: Avoid calling parent_direction2(child) if child's parent is del,
            // because that would create a aliasing &mut ref to del while we hold del_node.
            let dir_child_temp =
                if child == child_temp { dir } else { self.parent_direction2(child) };

            // Fix Miri UB: Do not call parent_direction2(del) as it creates a new &mut AvlNode
            // alias while we hold del_node. Use del_node to find parent direction.
            let parent = del_node.get_parent();
            let dir_child_del = if !parent.is_null() {
                self.parent_direction(del, parent)
            } else {
                AvlDirection::Left
            };

            // Exchange tree positions of `del` and its neighbour `child` (links only).
            let child_node = unsafe { (*child).get_node() };
            child_node.swap(del_node);

            // move 'node' to delete's spot in the tree
            if child_node.get_child(dir) == child {
                // if node(d) left child is node(c)
                // (`child` was `del`'s direct child, so after the swap it points at itself)
                child_node.set_child(dir, del);
            }

            // Re-point the children now hanging off `child` back at `child`.
            let c_dir = child_node.get_child(dir);
            if c_dir == del {
                del_node.parent = child;
            } else if !c_dir.is_null() {
                unsafe { (*c_dir).get_node() }.parent = child;
            }

            let c_inv = child_node.get_child(dir_inverse);
            if c_inv == del {
                del_node.parent = child;
            } else if !c_inv.is_null() {
                unsafe { (*c_inv).get_node() }.parent = child;
            }

            // Hook `child` into the slot `del` used to occupy (or make it the root).
            let parent = child_node.get_parent();
            if !parent.is_null() {
                unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
            } else {
                self.root = child;
            }

            // Put tmp where node used to be (just temporary).
            // It always has a parent and at most 1 child.
            let parent = del_node.get_parent();
            unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
            if !del_node.right.is_null() {
                which_child = AvlDirection::Right;
            } else {
                which_child = AvlDirection::Left;
            }
            let child = del_node.get_child(which_child);
            if !child.is_null() {
                unsafe { (*child).get_node() }.parent = del;
            }
            which_child = dir_child_temp;
        } else {
            // Fix Miri UB here as well
            let parent = del_node.get_parent();
            if !parent.is_null() {
                which_child = self.parent_direction(del, parent);
            } else {
                which_child = AvlDirection::Left;
            }
        }

        // Here we know "delete" is at least partially a leaf node. It can
        // be easily removed from the tree.
        let parent: *const P::Target = del_node.get_parent();

        // The (at most one) child that takes `del`'s place, possibly null.
        let imm_data: *const P::Target =
            if !del_node.left.is_null() { del_node.left } else { del_node.right };

        // Connect parent directly to node (leaving out delete).
        if !imm_data.is_null() {
            let imm_node = unsafe { (*imm_data).get_node() };
            imm_node.parent = parent;
        }

        if !parent.is_null() {
            assert!(self.count > 0);
            self.count -= 1;

            let parent_node = unsafe { (*parent).get_node() };
            parent_node.set_child(which_child, imm_data);

            //Since the subtree is now shorter, begin adjusting parent balances
            //and performing any needed rotations.
            let mut node_data: *const P::Target = parent;
            let mut old_balance: i8;
            let mut new_balance: i8;
            loop {
                // Move up the tree and adjust the balance.
                // Capture the parent and which_child values for the next
                // iteration before any rotations occur.
                let node = unsafe { (*node_data).get_node() };
                old_balance = node.balance;
                // The `which_child` side got shorter, so shift the balance away from it.
                new_balance = old_balance - avlchild_to_balance!(which_child);

                //If a node was in perfect balance but isn't anymore then
                //we can stop, since the height didn't change above this point
                //due to a deletion.
                if old_balance == 0 {
                    node.balance = new_balance;
                    break;
                }

                let parent = node.get_parent();
                which_child = self.parent_direction(node_data, parent);

                //If the new balance is zero, we don't need to rotate
                //else
                //need a rotation to fix the balance.
                //If the rotation doesn't change the height
                //of the sub-tree we have finished adjusting.
                if new_balance == 0 {
                    node.balance = new_balance;
                } else if !self.rotate(node_data, new_balance) {
                    break;
                }

                if !parent.is_null() {
                    node_data = parent;
                    continue;
                }
                break;
            }
        } else if !imm_data.is_null() {
            // `del` was the root with one child: that child becomes the new root.
            assert!(self.count > 0);
            self.count -= 1;
            self.root = imm_data;
        }
        if self.root.is_null() && self.count > 0 {
            panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
        }
        // Leave `del` fully unlinked so it can be inserted again.
        del_node.detach();
    }
829
830    /// Removes a node from the tree by key.
831    ///
832    /// The `cmp_func` should compare the key `K` with the elements in the tree.
833    /// Returns `Some(P)` if an exact match was found and removed, `None` otherwise.
834    #[inline]
835    pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
836        let result = self.find(val, cmp_func);
837        self.remove_with(unsafe { result.detach() })
838    }
839
840    /// remove with a previous search result
841    ///
842    /// - If the result is exact match, return the removed element ownership
843    /// - If the result is not exact match, return None
844    ///
845    /// # Safety
846    ///
847    /// Once the tree structure changed, previous search result is not safe to use anymore.
848    ///
849    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
850    /// to avoid the borrowing issue.
851    #[inline]
852    pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
853        if result.is_exact() {
854            unsafe {
855                let p = result.node;
856                self.remove(p);
857                Some(P::from_raw(p))
858            }
859        } else {
860            None
861        }
862    }
863
864    /// Searches for an element in the tree.
865    ///
866    /// The `cmp_func` should compare the key `K` with the elements in the tree.
867    /// Returns an [`AvlSearchResult`] which indicates if an exact match was found,
868    /// or where a new element should be inserted.
869    #[inline]
870    pub fn find<'a, K>(
871        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
872    ) -> AvlSearchResult<'a, P> {
873        if self.root.is_null() {
874            return AvlSearchResult::default();
875        }
876        let mut node_data = self.root;
877        loop {
878            let diff = cmp_func(val, unsafe { &*node_data });
879            match diff {
880                Ordering::Equal => {
881                    return AvlSearchResult {
882                        node: node_data,
883                        direction: None,
884                        _phan: PhantomData,
885                    };
886                }
887                Ordering::Less => {
888                    let node = unsafe { (*node_data).get_node() };
889                    let left = node.get_child(AvlDirection::Left);
890                    if left.is_null() {
891                        return AvlSearchResult {
892                            node: node_data,
893                            direction: Some(AvlDirection::Left),
894                            _phan: PhantomData,
895                        };
896                    }
897                    node_data = left;
898                }
899                Ordering::Greater => {
900                    let node = unsafe { (*node_data).get_node() };
901                    let right = node.get_child(AvlDirection::Right);
902                    if right.is_null() {
903                        return AvlSearchResult {
904                            node: node_data,
905                            direction: Some(AvlDirection::Right),
906                            _phan: PhantomData,
907                        };
908                    }
909                    node_data = right;
910                }
911            }
912        }
913    }
914
915    // for range tree, val may overlap multiple range(node), ensure return the smallest
916    #[inline]
917    pub fn find_contained<'a, K>(
918        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
919    ) -> Option<&'a P::Target> {
920        if self.root.is_null() {
921            return None;
922        }
923        let mut node_data = self.root;
924        let mut result_node: *const P::Target = null();
925        loop {
926            let diff = cmp_func(val, unsafe { &*node_data });
927            match diff {
928                Ordering::Equal => {
929                    let node = unsafe { (*node_data).get_node() };
930                    let left = node.get_child(AvlDirection::Left);
931                    result_node = node_data;
932                    if left.is_null() {
933                        break;
934                    } else {
935                        node_data = left;
936                    }
937                }
938                Ordering::Less => {
939                    let node = unsafe { (*node_data).get_node() };
940                    let left = node.get_child(AvlDirection::Left);
941                    if left.is_null() {
942                        break;
943                    }
944                    node_data = left;
945                }
946                Ordering::Greater => {
947                    let node = unsafe { (*node_data).get_node() };
948                    let right = node.get_child(AvlDirection::Right);
949                    if right.is_null() {
950                        break;
951                    }
952                    node_data = right;
953                }
954            }
955        }
956        if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
957    }
958
959    // for slab, return any block larger or equal than search param
960    #[inline]
961    pub fn find_larger_eq<'a, K>(
962        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
963    ) -> AvlSearchResult<'a, P> {
964        if self.root.is_null() {
965            return AvlSearchResult::default();
966        }
967        let mut node_data = self.root;
968        loop {
969            let diff = cmp_func(val, unsafe { &*node_data });
970            match diff {
971                Ordering::Equal => {
972                    return AvlSearchResult {
973                        node: node_data,
974                        direction: None,
975                        _phan: PhantomData,
976                    };
977                }
978                Ordering::Less => {
979                    return AvlSearchResult {
980                        node: node_data,
981                        direction: None,
982                        _phan: PhantomData,
983                    };
984                }
985                Ordering::Greater => {
986                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
987                    if right.is_null() {
988                        return AvlSearchResult {
989                            node: null(),
990                            direction: None,
991                            _phan: PhantomData,
992                        };
993                    }
994                    node_data = right;
995                }
996            }
997        }
998    }
999
1000    /// For range tree
1001    #[inline]
1002    pub fn find_nearest<'a, K>(
1003        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1004    ) -> AvlSearchResult<'a, P> {
1005        if self.root.is_null() {
1006            return AvlSearchResult::default();
1007        }
1008
1009        let mut node_data = self.root;
1010        let mut nearest_node = null();
1011        loop {
1012            let diff = cmp_func(val, unsafe { &*node_data });
1013            match diff {
1014                Ordering::Equal => {
1015                    return AvlSearchResult {
1016                        node: node_data,
1017                        direction: None,
1018                        _phan: PhantomData,
1019                    };
1020                }
1021                Ordering::Less => {
1022                    nearest_node = node_data;
1023                    let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1024                    if left.is_null() {
1025                        break;
1026                    }
1027                    node_data = left;
1028                }
1029                Ordering::Greater => {
1030                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1031                    if right.is_null() {
1032                        break;
1033                    }
1034                    node_data = right;
1035                }
1036            }
1037        }
1038        AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData }
1039    }
1040
1041    #[inline(always)]
1042    fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1043        loop {
1044            let child = unsafe { (*data).get_node() }.get_child(dir);
1045            if !child.is_null() {
1046                data = child;
1047            } else {
1048                return data;
1049            }
1050        }
1051    }
1052
1053    pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1054        let mut node = self.first();
1055        while let Some(n) = node {
1056            cb(n);
1057            node = self.next(n);
1058        }
1059    }
1060
1061    #[inline]
1062    pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1063        if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1064            Some(unsafe { p.as_ref() })
1065        } else {
1066            None
1067        }
1068    }
1069
1070    #[inline]
1071    pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1072        if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1073            Some(unsafe { p.as_ref() })
1074        } else {
1075            None
1076        }
1077    }
1078
1079    #[inline]
1080    fn walk_dir(
1081        &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1082    ) -> Option<NonNull<P::Target>> {
1083        let dir_inverse = dir.reverse();
1084        let node = unsafe { (*data_ptr).get_node() };
1085        let temp = node.get_child(dir);
1086        if !temp.is_null() {
1087            unsafe {
1088                Some(NonNull::new_unchecked(
1089                    self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1090                ))
1091            }
1092        } else {
1093            let mut parent = node.parent;
1094            if parent.is_null() {
1095                return None;
1096            }
1097            loop {
1098                let pdir = self.parent_direction(data_ptr, parent);
1099                if pdir == dir_inverse {
1100                    return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1101                }
1102                data_ptr = parent;
1103                parent = unsafe { (*parent).get_node() }.parent;
1104                if parent.is_null() {
1105                    return None;
1106                }
1107            }
1108        }
1109    }
1110
    /// Checks the invariants between `data` and its direct children, panicking on failure.
    ///
    /// - A left child must not compare `Greater` than `data`.
    /// - A right child must not compare `Less` than `data`.
    /// - Each child's `get_parent()` must point back at `data`.
    ///
    /// Only the immediate children are inspected. Ordering against deeper descendants is
    /// not checked here.
    #[inline]
    fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
        let node = unsafe { (*data).get_node() };
        // Left child: ordered no later than `data`, and linked back to it.
        let left = node.left;
        if !left.is_null() {
            assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
            assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
        }
        // Right child: ordered no earlier than `data`, and linked back to it.
        let right = node.right;
        if !right.is_null() {
            assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
            assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
        }
    }
1125
1126    #[inline]
1127    pub fn nearest<'a>(
1128        &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1129    ) -> AvlSearchResult<'a, P> {
1130        if !current.node.is_null() {
1131            if current.direction.is_some() && current.direction != Some(direction) {
1132                return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1133            }
1134            if let Some(node) = self.walk_dir(current.node, direction) {
1135                return AvlSearchResult {
1136                    node: node.as_ptr(),
1137                    direction: None,
1138                    _phan: PhantomData,
1139                };
1140            }
1141        }
1142        AvlSearchResult::default()
1143    }
1144
1145    pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1146        let c = {
1147            #[cfg(feature = "std")]
1148            {
1149                ((self.get_count() + 10) as f32).log2() as usize
1150            }
1151            #[cfg(not(feature = "std"))]
1152            {
1153                100
1154            }
1155        };
1156        let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1157        if self.root.is_null() {
1158            assert_eq!(self.count, 0);
1159            return;
1160        }
1161        let mut data = self.root;
1162        let mut visited = 0;
1163        loop {
1164            if !data.is_null() {
1165                let left = {
1166                    let node = unsafe { (*data).get_node() };
1167                    node.get_child(AvlDirection::Left)
1168                };
1169                if !left.is_null() {
1170                    stack.push(data);
1171                    data = left;
1172                    continue;
1173                }
1174                visited += 1;
1175                self.validate_node(data, cmp_func);
1176                data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1177            } else if !stack.is_empty() {
1178                let _data = stack.pop().unwrap();
1179                self.validate_node(_data, cmp_func);
1180                visited += 1;
1181                let node = unsafe { (*_data).get_node() };
1182                data = node.get_child(AvlDirection::Right);
1183            } else {
1184                break;
1185            }
1186        }
1187        assert_eq!(visited, self.count);
1188    }
1189
1190    /// Adds a new element to the tree, takes the ownership of P.
1191    ///
1192    /// The `cmp_func` should compare two elements to determine their relative order.
1193    /// Returns `true` if the element was added, `false` if an equivalent element
1194    /// already exists (in which case the provided `node` is dropped).
1195    #[inline]
1196    pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1197        if self.count == 0 && self.root.is_null() {
1198            self.root = node.into_raw();
1199            self.count = 1;
1200            return true;
1201        }
1202
1203        let w = self.find(node.as_ref(), cmp_func);
1204        if w.direction.is_none() {
1205            // To prevent memory leak, we must drop the node.
1206            // But since we took ownership, we have to convert it back to P and drop it.
1207            drop(node);
1208            return false;
1209        }
1210
1211        // Safety: We need to decouple the lifetime of 'w' from 'self' to call 'insert'.
1212        // We extract the pointers and reconstruct the result.
1213        let w_node = w.node;
1214        let w_dir = w.direction;
1215
1216        let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1217
1218        self.insert(node, w_detached);
1219        true
1220    }
1221}
1222
1223impl<P, Tag> Drop for AvlTree<P, Tag>
1224where
1225    P: Pointer,
1226    P::Target: AvlItem<Tag>,
1227{
1228    fn drop(&mut self) {
1229        if mem::needs_drop::<P>() {
1230            for _ in self.drain() {}
1231        }
1232    }
1233}
1234
/// Draining iterator over an [`AvlTree`].
///
/// Yields every element in post-order (children before their parent). Each element is
/// unlinked from the tree and handed back as `P`. The tree stays mutably borrowed for
/// the lifetime of the iterator.
pub struct AvlDrain<'a, P: Pointer, Tag>
where
    P::Target: AvlItem<Tag>,
{
    /// The tree being emptied.
    tree: &'a mut AvlTree<P, Tag>,
    /// Parent of the most recently yielded node. The first `next()` call expects it to
    /// be null, and it is null again once the root has been yielded.
    parent: *const P::Target,
    /// Side of `parent` the most recently yielded node was unlinked from. The first
    /// `next()` call expects `None`.
    dir: Option<AvlDirection>,
}
1243
1244impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
1245where
1246    P::Target: AvlItem<Tag>,
1247{
1248    type Item = P;
1249
1250    fn next(&mut self) -> Option<Self::Item> {
1251        if self.tree.root.is_null() {
1252            return None;
1253        }
1254
1255        let mut node: *const P::Target;
1256        let parent: *const P::Target;
1257
1258        if self.dir.is_none() && self.parent.is_null() {
1259            // Initial call: find the leftmost node
1260            let mut curr = self.tree.root;
1261            while unsafe { !(*curr).get_node().left.is_null() } {
1262                curr = unsafe { (*curr).get_node().left };
1263            }
1264            node = curr;
1265        } else {
1266            parent = self.parent;
1267            if parent.is_null() {
1268                // Should not happen if root was nulled
1269                return None;
1270            }
1271
1272            let child_dir = self.dir.unwrap();
1273            // child_dir child of parent was just nulled in previous call?
1274            // NO, we null it in THIS call.
1275
1276            if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
1277                node = parent;
1278            } else {
1279                // Finished left, go to right sibling
1280                node = unsafe { (*parent).get_node().right };
1281                while unsafe { !(*node).get_node().left.is_null() } {
1282                    node = unsafe { (*node).get_node().left };
1283                }
1284            }
1285        }
1286
1287        // Goto check_right_side logic
1288        if unsafe { !(*node).get_node().right.is_null() } {
1289            // It has a right child, so we must yield that first (in post-order)
1290            // Note: in AVL, if left is null, right must be a leaf.
1291            node = unsafe { (*node).get_node().right };
1292        }
1293
1294        // Determine next state
1295        let next_parent = unsafe { (*node).get_node().parent };
1296        if next_parent.is_null() {
1297            self.tree.root = null();
1298            self.parent = null();
1299            self.dir = Some(AvlDirection::Left);
1300        } else {
1301            self.parent = next_parent;
1302            self.dir = Some(self.tree.parent_direction(node, next_parent));
1303            // Unlink from parent NOW
1304            unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
1305        }
1306
1307        self.tree.count -= 1;
1308        unsafe {
1309            (*node).get_node().detach();
1310            Some(P::from_raw(node))
1311        }
1312    }
1313}
1314
1315impl<T, Tag> AvlTree<Arc<T>, Tag>
1316where
1317    T: AvlItem<Tag>,
1318{
1319    pub fn remove_ref(&mut self, node: &Arc<T>) {
1320        let p = Arc::as_ptr(node);
1321        unsafe { self.remove(p) };
1322        unsafe { drop(Arc::from_raw(p)) };
1323    }
1324}
1325
1326impl<T, Tag> AvlTree<Rc<T>, Tag>
1327where
1328    T: AvlItem<Tag>,
1329{
1330    pub fn remove_ref(&mut self, node: &Rc<T>) {
1331        let p = Rc::as_ptr(node);
1332        unsafe { self.remove(p) };
1333        unsafe { drop(Rc::from_raw(p)) };
1334    }
1335}
1336
// Unit tests for the intrusive AVL tree, using boxed and `Arc`-owned `i64` elements.
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::UnsafeCell;
    use rand::Rng;
    use std::time::Instant;

    /// Test element: an `i64` key with an embedded AVL link node.
    struct IntAvlNode {
        pub value: i64,
        pub node: UnsafeCell<AvlNode<Self, ()>>,
    }

    impl fmt::Debug for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{} {:#?}", self.value, self.node)
        }
    }

    impl fmt::Display for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{}", self.value)
        }
    }

    // Exposes the embedded link node to the tree (interior mutability via `UnsafeCell`).
    unsafe impl AvlItem<()> for IntAvlNode {
        fn get_node(&self) -> &mut AvlNode<Self, ()> {
            unsafe { &mut *self.node.get() }
        }
    }

    type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;

    fn new_intnode(i: i64) -> Box<IntAvlNode> {
        Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
    }

    fn new_inttree() -> IntAvlTree {
        AvlTree::<Box<IntAvlNode>, ()>::new()
    }

    // Element-to-element ordering, used by `add` and `validate`.
    fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
        a.value.cmp(&b.value)
    }

    // Key-to-element ordering, used by searches with a bare `i64` key.
    fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
        a.cmp(&b.value)
    }

    // Test-only convenience wrappers around the generic API.
    impl AvlTree<Box<IntAvlNode>, ()> {
        fn remove_int(&mut self, i: i64) -> bool {
            if let Some(_node) = self.remove_by_key(&i, cmp_int) {
                // node is Box<IntAvlNode>, dropped automatically
                return true;
            }
            // else
            println!("not found {}", i);
            false
        }

        fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
            self.add(node, cmp_int_node)
        }

        fn validate_tree(&self) {
            self.validate(cmp_int_node);
        }

        fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
            self.find(&i, cmp_int)
        }

        fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
            self.find(node, cmp_int_node)
        }
    }

    // Exercises the low-level link helpers on nodes wired by hand, outside any tree.
    #[test]
    fn int_avl_node() {
        let mut tree = new_inttree();

        assert_eq!(tree.get_count(), 0);
        assert!(tree.first().is_none());
        assert!(tree.last().is_none());

        let node1 = new_intnode(1);
        let node2 = new_intnode(2);
        let node3 = new_intnode(3);

        let p1 = &*node1 as *const IntAvlNode;
        let p2 = &*node2 as *const IntAvlNode;
        let p3 = &*node3 as *const IntAvlNode;

        // Hand-wire node2 as node1's left child and node3 as node2's right child.
        // Presumably `set_child2` also sets the child's parent link; the
        // `parent_direction2` checks below rely on it (TODO confirm).
        tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
        tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);

        assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
        // This is tricky as node1 is not in a tree, its parent is not set.
        // assert_eq!(tree.parent_direction2(p1), AvlDirection::Left);
        assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
    }

    // Searches, neighbours and `find_larger_eq` on an empty, 1-element and 2-element tree.
    #[test]
    fn int_avl_tree_basic() {
        let mut tree = new_inttree();

        // Empty tree: nothing is found and there are no neighbours on either side.
        let temp_node = new_intnode(0);
        let temp_node_val = Pointer::as_ref(&temp_node);
        assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
        assert_eq!(
            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
            false
        );
        assert_eq!(
            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
            false
        );
        drop(temp_node);

        // Tree {0}: the exact match has no neighbours.
        tree.add_int_node(new_intnode(0));
        let result = tree.find_int(0);
        assert!(result.get_node_ref().is_some());
        assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

        let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
        assert!(rs.is_some());
        let found_value = rs.unwrap().value;
        assert_eq!(found_value, 0);

        // No element is >= 2.
        let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
        assert!(rs.is_none());

        // Searching 1 stops to the right of 0: 0 is the left neighbour, nothing is right.
        let result = tree.find_int(1);
        let left = tree.nearest(&result, AvlDirection::Left);
        assert_eq!(left.is_exact(), true);
        assert_eq!(left.get_nearest().unwrap().value, 0);
        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

        // Tree {0, 2}: the first element >= 1 is 2.
        tree.add_int_node(new_intnode(2));
        let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
        assert!(rs.is_some());
        let found_value = rs.unwrap().value;
        assert_eq!(found_value, 2);
    }

    // Sequential inserts, ordered iteration, `find_larger_eq` after a removal, then
    // removal of everything. Uses a smaller size under Miri.
    #[test]
    fn int_avl_tree_order() {
        let max;
        #[cfg(miri)]
        {
            max = 2000;
        }
        #[cfg(not(miri))]
        {
            max = 200000;
        }
        let mut tree = new_inttree();
        assert!(tree.first().is_none());
        let start_ts = Instant::now();
        for i in 0..max {
            tree.add_int_node(new_intnode(i));
        }
        tree.validate_tree();
        assert_eq!(tree.get_count(), max as i64);

        // Walk from first to last and check the values come out as 0, 1, 2, ...
        let mut count = 0;
        let mut current = tree.first();
        let last = tree.last();
        while let Some(c) = current {
            assert_eq!(c.value, count);
            count += 1;
            if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
                current = None;
            } else {
                current = tree.next(c);
            }
        }
        assert_eq!(count, max);

        {
            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
            assert!(rs.is_some());
            let found_value = rs.unwrap().value;
            println!("found larger_eq {}", found_value);
            assert!(found_value >= 5);
            tree.remove_int(5);
            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
            assert!(rs.is_some());
            assert!(rs.unwrap().value >= 6);
            tree.add_int_node(new_intnode(5));
        }

        for i in 0..max {
            assert!(tree.remove_int(i));
        }
        assert_eq!(tree.get_count(), 0);

        let end_ts = Instant::now();
        println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
    }

    // Regression case with fixed large keys: add then remove in insertion order.
    #[test]
    fn int_avl_tree_fixed1() {
        let mut tree = new_inttree();
        let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
        for i in arr.iter() {
            let node = new_intnode(*i);
            tree.add_int_node(node);
            let rs = tree.find_int(*i);
            assert!(rs.get_node_ref().is_some(), "add error {}", i);
        }
        assert_eq!(tree.get_count(), arr.len() as i64);
        for i in arr.iter() {
            assert!(tree.remove_int(*i));
        }
        assert_eq!(tree.get_count(), 0);
    }

    // Regression case: interleaved add/remove sequence, validating after every step.
    #[test]
    fn int_avl_tree_fixed2() {
        let mut tree = new_inttree();
        tree.validate_tree();
        let node1 = new_intnode(536872960);
        {
            tree.add_int_node(node1);
            tree.validate_tree();
            tree.remove_int(536872960);
            tree.validate_tree();
            tree.add_int_node(new_intnode(536872960));
            tree.validate_tree();
        }

        assert!(tree.find_int(536872960).get_node_ref().is_some());
        let node2 = new_intnode(12288);
        tree.add_int_node(node2);
        tree.validate_tree();
        tree.remove_int(536872960);
        tree.validate_tree();
        tree.add_int_node(new_intnode(536872960));
        tree.validate_tree();
        let node3 = new_intnode(22528);
        tree.add_int_node(node3);
        tree.validate_tree();
        tree.remove_int(12288);
        assert!(tree.find_int(12288).get_node_ref().is_none());
        tree.validate_tree();
        tree.remove_int(22528);
        assert!(tree.find_int(22528).get_node_ref().is_none());
        tree.validate_tree();
        tree.add_int_node(new_intnode(22528));
        tree.validate_tree();
    }

    // Random distinct keys: `prev`/`next` of every element must match the sorted key list.
    #[test]
    fn int_avl_tree_random() {
        let count = 1000;
        let mut test_list: Vec<i64> = Vec::with_capacity(count);
        let mut rng = rand::thread_rng();
        let mut tree = new_inttree();
        tree.validate_tree();
        for _ in 0..count {
            let node_value: i64 = rng.r#gen();
            if !test_list.contains(&node_value) {
                test_list.push(node_value);
                assert!(tree.add_int_node(new_intnode(node_value)))
            }
        }
        tree.validate_tree();
        test_list.sort();
        for index in 0..test_list.len() {
            let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
            let prev = tree.prev(node_ptr);
            let next = tree.next(node_ptr);
            if index == 0 {
                // first node
                assert!(prev.is_none());
                assert!(next.is_some());
                assert_eq!(next.unwrap().value, test_list[index + 1]);
            } else if index == test_list.len() - 1 {
                // last node
                assert!(prev.is_some());
                assert_eq!(prev.unwrap().value, test_list[index - 1]);
                assert!(next.is_none());
            } else {
                // middle node
                assert!(prev.is_some());
                assert_eq!(prev.unwrap().value, test_list[index - 1]);
                assert!(next.is_some());
                assert_eq!(next.unwrap().value, test_list[index + 1]);
            }
        }
        for index in 0..test_list.len() {
            assert!(tree.remove_int(test_list[index]));
        }
        tree.validate_tree();
        assert_eq!(0, tree.get_count());
    }

    // `insert_here` places a new element directly beside an existing one.
    #[test]
    fn int_avl_tree_insert_here() {
        let mut tree = new_inttree();
        let node1 = new_intnode(10);
        tree.add_int_node(node1);
        // Insert 5 before 10
        let rs = tree.find_int(10);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 2);
        assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);

        // Insert 15 after 10
        let rs = tree.find_int(10);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 3);
        assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);

        // Insert 3 before 5 (which is left child of 10)
        let rs = tree.find_int(5);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 4);

        // Insert 7 after 5
        let rs = tree.find_int(5);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 5);
    }

    // `get_exact` on an `Arc` tree hands out a new strong reference.
    #[test]
    fn test_arc_avl_tree_get_exact() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        // Manually constructing Arc node
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
        tree.add(node.clone(), cmp_int_node);

        // find returns AvlSearchResult<'a, Arc<IntAvlNode>>
        let result_search = tree.find(&100, cmp_int);

        // This should invoke the specialized get_exact for Arc<T>
        let exact = result_search.get_exact();
        assert!(exact.is_some());
        let exact_arc = exact.unwrap();
        assert_eq!(exact_arc.value, 100);
        assert!(Arc::ptr_eq(&node, &exact_arc));
        // Check ref count: 1 (original) + 1 (in tree) + 1 (exact_arc) = 3
        assert_eq!(Arc::strong_count(&node), 3);
    }

    // `remove_ref` unlinks by identity and drops only the tree's strong reference.
    #[test]
    fn test_arc_avl_tree_remove_ref() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
        tree.add(node.clone(), cmp_int_node);
        assert_eq!(tree.get_count(), 1);
        assert_eq!(Arc::strong_count(&node), 2);

        tree.remove_ref(&node);
        assert_eq!(tree.get_count(), 0);
        assert_eq!(Arc::strong_count(&node), 1);
    }

    // `remove_by_key` on an `Arc` tree transfers the tree's reference to the caller.
    #[test]
    fn test_arc_avl_tree_remove_with() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
        tree.add(node.clone(), cmp_int_node);

        let removed = tree.remove_by_key(&300, cmp_int);
        assert!(removed.is_some());
        let removed_arc = removed.unwrap();
        assert_eq!(removed_arc.value, 300);
        assert_eq!(tree.get_count(), 0);
        // count: 1 (node) + 1 (removed_arc) = 2. Tree dropped its count.
        assert_eq!(Arc::strong_count(&node), 2);

        drop(removed_arc);
        assert_eq!(Arc::strong_count(&node), 1);
    }

    // A full drain yields every element exactly once and leaves the tree empty.
    #[test]
    fn test_avl_drain() {
        let mut tree = new_inttree();
        for i in 0..100 {
            tree.add_int_node(new_intnode(i));
        }
        assert_eq!(tree.get_count(), 100);

        let mut count = 0;
        for node in tree.drain() {
            assert!(node.value >= 0 && node.value < 100);
            count += 1;
        }
        assert_eq!(count, 100);
        assert_eq!(tree.get_count(), 0);
        assert!(tree.first().is_none());
    }
}