Skip to main content

embed_collections/
avl.rs

1//! An intrusive AVL tree implementation.
2//!
//! The algorithm originates from OpenZFS.
4//!
5//! # Examples
6//!
7//! ## Using `Box` to search and remove by key
8//!
9//! ```rust
10//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
11//! use core::cell::UnsafeCell;
12//! use core::cmp::Ordering;
13//!
14//! struct MyNode {
15//!     value: i32,
16//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
17//! }
18//!
19//! unsafe impl AvlItem<()> for MyNode {
20//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
21//!         unsafe { &mut *self.avl_node.get() }
22//!     }
23//! }
24//!
25//! let mut tree = AvlTree::<Box<MyNode>, ()>::new();
26//! tree.add(Box::new(MyNode { value: 10, avl_node: UnsafeCell::new(Default::default()) }), |a, b| a.value.cmp(&b.value));
27//!
28//! // Search and remove
29//! if let Some(node) = tree.remove_by_key(&10, |key, node| key.cmp(&node.value)) {
30//!     assert_eq!(node.value, 10);
31//! }
32//! ```
33//!
34//! ## Using `Arc` for multiple ownership
35//!
//! `remove_ref` is only available for `Arc` and `Rc`.
37//!
38//! ```rust
39//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
40//! use core::cell::UnsafeCell;
41//! use std::sync::Arc;
42//!
43//! struct MyNode {
44//!     value: i32,
45//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
46//! }
47//!
48//! unsafe impl AvlItem<()> for MyNode {
49//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
50//!         unsafe { &mut *self.avl_node.get() }
51//!     }
52//! }
53//!
54//! let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
55//! let node = Arc::new(MyNode { value: 42, avl_node: UnsafeCell::new(Default::default()) });
56//!
57//! tree.add(node.clone(), |a, b| a.value.cmp(&b.value));
58//! assert_eq!(tree.get_count(), 1);
59//!
60//! // Remove by reference (detach from avl tree)
61//! tree.remove_ref(&node);
62//! assert_eq!(tree.get_count(), 0);
63//! ```
64//!
65
66use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72    cmp::{Ordering, PartialEq},
73    fmt, mem,
74    ptr::{NonNull, null},
75};
76
/// Gives an [`AvlTree`] access to the [`AvlNode`] embedded in an item.
///
/// The tag is used to distinguish different `AvlNode`s within the same item,
/// allowing an item to belong to multiple trees simultaneously.
/// When an item is linked into only one tree, you can use `()` as the tag.
///
/// # Safety
///
/// Implementors must ensure `get_node` returns a valid reference to the `AvlNode`
/// embedded within `Self`. Because `get_node` hands out `&mut` from `&self`,
/// the `AvlNode` must be held in an `UnsafeCell` to provide the interior
/// mutability required by tree operations.
pub unsafe trait AvlItem<Tag>: Sized {
    /// Returns the link node (child/parent pointers and balance) embedded in this item.
    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
}
91
/// Selects one of the two child slots of a tree node.
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum AvlDirection {
    Left = 0,
    Right = 1,
}

impl AvlDirection {
    /// Returns the opposite direction (`Left` <-> `Right`).
    #[inline(always)]
    fn reverse(self) -> AvlDirection {
        if self == AvlDirection::Left { AvlDirection::Right } else { AvlDirection::Left }
    }
}
107
// Maps a child direction to its balance contribution:
// `-1` for the left child, `1` for the right child.
macro_rules! avlchild_to_balance {
    ($dir:expr) => {
        if let AvlDirection::Left = $dir { -1 } else { 1 }
    };
}
116
/// Link fields embedded in every item that is stored in an [`AvlTree`].
///
/// A detached node (see `detach` and the `Default` impl) has null links and a
/// zero balance.
pub struct AvlNode<T: Sized, Tag> {
    /// Left child, or null when there is none.
    pub left: *const T,
    /// Right child, or null when there is none.
    pub right: *const T,
    /// Parent item, or null when there is none.
    pub parent: *const T,
    /// Balance factor: insertion adds `-1` when the left side grows and `+1`
    /// when the right side grows.
    pub balance: i8,
    // Ties the node type to `Tag` without storing a value of that type.
    _phan: PhantomData<fn(&Tag)>,
}
124
// NOTE(review): `Send` is asserted for every `T` (there is no `T: Send` bound) even
// though the node stores raw pointers to `T`. Presumably soundness relies on the
// owning pointer type used by the tree — confirm.
unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128    #[inline(always)]
129    pub fn detach(&mut self) {
130        self.left = null();
131        self.right = null();
132        self.parent = null();
133        self.balance = 0;
134    }
135
136    #[inline(always)]
137    fn get_child(&self, dir: AvlDirection) -> *const T {
138        match dir {
139            AvlDirection::Left => self.left,
140            AvlDirection::Right => self.right,
141        }
142    }
143
144    #[inline(always)]
145    fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146        match dir {
147            AvlDirection::Left => self.left = child,
148            AvlDirection::Right => self.right = child,
149        }
150    }
151
152    #[inline(always)]
153    fn get_parent(&self) -> *const T {
154        self.parent
155    }
156
157    // Swap two node but not there value
158    #[inline(always)]
159    pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160        mem::swap(&mut self.left, &mut other.left);
161        mem::swap(&mut self.right, &mut other.right);
162        mem::swap(&mut self.parent, &mut other.parent);
163        mem::swap(&mut self.balance, &mut other.balance);
164    }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168    fn default() -> Self {
169        Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170    }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176        write!(f, "(")?;
177
178        if !self.left.is_null() {
179            write!(f, "left: {:p}", self.left)?;
180        } else {
181            write!(f, "left: none ")?;
182        }
183
184        if !self.right.is_null() {
185            write!(f, "right: {:p}", self.right)?;
186        } else {
187            write!(f, "right: none ")?;
188        }
189        write!(f, ")")
190    }
191}
192
/// Comparison callback used by lookups: orders a key `K` relative to an element `T`.
pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
194
/// An intrusive AVL tree (balanced binary search tree).
///
/// Elements in the tree must implement the [`AvlItem`] trait.
/// The tree supports various pointer types (`Box`, `Arc`, `Rc`, etc.) through the [`Pointer`] trait.
pub struct AvlTree<P, Tag>
where
    P: Pointer,
    P::Target: AvlItem<Tag>,
{
    /// Root element of the tree; null while the tree is empty.
    pub root: *const P::Target,
    /// Number of elements currently linked into the tree.
    count: i64,
    // Marker tying the tree to its pointer type `P` and `Tag`; stores no data.
    _phan: PhantomData<fn(P, &Tag)>,
}
208
/// Result of a search operation in an [`AvlTree`].
///
/// An `AvlSearchResult` identifies one of:
/// 1. An exact match: `direction` is `None` and `node` points to the matching item.
/// 2. An insertion point: `direction` is `Some(dir)` and `node` points to the parent
///    where a new node should be attached as the `dir` child.
/// 3. An empty tree: `node` is null and `direction` is `Some(Left)` (the `Default` value).
///
/// The lifetime `'a` ties the search result to the tree's borrow, ensuring safety.
/// However, this lifetime often prevents further mutable operations on the tree
/// (e.g., adding a node while holding the search result). Use [`detach`](Self::detach)
/// to de-couple the result from the tree's lifetime when necessary.
pub struct AvlSearchResult<'a, P: Pointer> {
    /// The matching node or the parent for insertion.
    pub node: *const P::Target,
    /// `None` if exact match found, or `Some(direction)` indicating insertion point.
    pub direction: Option<AvlDirection>,
    // Carries the borrow lifetime `'a`; stores no data.
    _phan: PhantomData<&'a P::Target>,
}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229    fn default() -> Self {
230        AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231    }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235    /// Returns a reference to the matching node if the search was an exact match.
236    #[inline(always)]
237    pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238        if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239    }
240
241    /// Returns `true` if the search result is an exact match.
242    #[inline(always)]
243    pub fn is_exact(&self) -> bool {
244        self.direction.is_none() && !self.node.is_null()
245    }
246
247    /// De-couple the lifetime of the search result from the tree.
248    ///
249    /// This method is essential for performing mutable operations on the tree
250    /// using search results. In Rust, a search result typically borrows the tree
251    /// immutably. If you need to modify the tree (e.g., call `insert` or `remove`)
252    /// based on that result, the borrow checker would normally prevent it.
253    ///
254    /// `detach` effectively "erases" the lifetime `'a`, returning a result with
255    /// an unbounded lifetime `'b`.
256    ///
257    /// # Examples
258    ///
259    /// Used in `RangeTree::add`:
260    /// ```ignore
261    /// let result = self.root.find(&rs_key, range_tree_segment_cmp);
262    /// // result is AvlSearchResult<'a, ...> and borrows self.root
263    ///
264    /// let detached = unsafe { result.detach() };
265    /// // detached has no lifetime bound to self.root
266    ///
267    /// self.space += size; // Mutable operation on self permitted
268    /// self.merge_seg(start, end, detached); // Mutation on tree permitted
269    /// ```
270    ///
271    /// # Safety
272    /// This is an unsafe operation. The compiler no longer protects the validity
273    /// of the internal pointer via lifetimes. You must ensure that the tree
274    /// structure is not modified in a way that invalidates `node` (e.g., the
275    /// parent node being removed) before using the detached result.
276    #[inline(always)]
277    pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278        AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279    }
280
281    /// Return the nearest node in the search result
282    #[inline(always)]
283    pub fn get_nearest(&self) -> Option<&P::Target> {
284        if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285    }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289    /// Returns the matching Arc node if this is an exact match.
290    pub fn get_exact(&self) -> Option<Arc<T>> {
291        if self.is_exact() {
292            unsafe {
293                Arc::increment_strong_count(self.node);
294                Some(Arc::from_raw(self.node))
295            }
296        } else {
297            None
298        }
299    }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303    /// Returns the matching Rc node if this is an exact match.
304    pub fn get_exact(&self) -> Option<Rc<T>> {
305        if self.is_exact() {
306            unsafe {
307                Rc::increment_strong_count(self.node);
308                Some(Rc::from_raw(self.node))
309            }
310        } else {
311            None
312        }
313    }
314}
315
// Evaluates to a null pointer for an empty tree, otherwise to whatever
// `bottom_child_ref` returns when started at the root in direction `$dir`.
macro_rules! return_end {
    ($tree:expr, $dir:expr) => {{
        match $tree.root.is_null() {
            true => null(),
            false => $tree.bottom_child_ref($tree.root, $dir),
        }
    }};
}
319
// Maps a shifted balance value to a child direction: `0` and `1` select the
// left child, every other value selects the right child.
macro_rules! balance_to_child {
    ($balance:expr) => {
        if matches!($balance, 0 | 1) { AvlDirection::Left } else { AvlDirection::Right }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331    P: Pointer,
332    P::Target: AvlItem<Tag>,
333{
334    /// Creates a new, empty `AvlTree`.
335    pub fn new() -> Self {
336        AvlTree { count: 0, root: null(), _phan: Default::default() }
337    }
338
339    /// Returns an iterator that removes all elements from the tree in post-order.
340    ///
341    /// This is an optimized, non-recursive, and stack-less traversal that preserves
342    /// tree invariants during destruction.
343    #[inline]
344    pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345        AvlDrain { tree: self, parent: null(), dir: None }
346    }
347
    /// Returns the number of elements currently linked into the tree.
    pub fn get_count(&self) -> i64 {
        self.count
    }
351
352    pub fn first(&self) -> Option<&P::Target> {
353        unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354    }
355
356    #[inline]
357    pub fn last(&self) -> Option<&P::Target> {
358        unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359    }
360
361    /// Inserts a new node into the tree at the location specified by a search result.
362    ///
363    /// This is typically used after a [`find`](Self::find) operation didn't find an exact match.
364    ///
365    /// # Safety
366    ///
367    /// Once the tree structure changed, previous search result is not safe to use anymore.
368    ///
369    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
370    /// to avoid the borrowing issue.
371    ///
372    /// # Panics
373    /// Panics if the search result is an exact match (i.e. node already exists).
374    ///
375    /// # Examples
376    ///
377    /// ```rust
378    /// use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
379    /// use core::cell::UnsafeCell;
380    /// use std::sync::Arc;
381    ///
382    /// struct MyNode {
383    ///     value: i32,
384    ///     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
385    /// }
386    ///
387    /// unsafe impl AvlItem<()> for MyNode {
388    ///     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
389    ///         unsafe { &mut *self.avl_node.get() }
390    ///     }
391    /// }
392    ///
393    /// let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
394    /// let key = 42;
395    /// let result = tree.find(&key, |k, n| k.cmp(&n.value));
396    ///
397    /// if !result.is_exact() {
398    ///     let new_node = Arc::new(MyNode {
399    ///         value: key,
400    ///         avl_node: UnsafeCell::new(Default::default()),
401    ///     });
402    ///     tree.insert(new_node, unsafe{result.detach()});
403    /// }
404    /// ```
405    #[inline]
406    pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407        debug_assert!(w.direction.is_some());
408        self._insert(new_data, w.node, w.direction.unwrap());
409    }
410
411    #[allow(clippy::not_unsafe_ptr_arg_deref)]
412    pub fn _insert(
413        &mut self,
414        new_data: P,
415        here: *const P::Target, // parent
416        mut which_child: AvlDirection,
417    ) {
418        let mut new_balance: i8;
419        let new_ptr = new_data.into_raw();
420
421        if here.is_null() {
422            if self.count > 0 {
423                panic!("insert into a tree size {} with empty where.node", self.count);
424            }
425            self.root = new_ptr;
426            self.count += 1;
427            return;
428        }
429
430        let parent = unsafe { &*here };
431        let node = unsafe { (*new_ptr).get_node() };
432        let parent_node = parent.get_node();
433        node.parent = here;
434        parent_node.set_child(which_child, new_ptr);
435        self.count += 1;
436
437        /*
438         * Now, back up the tree modifying the balance of all nodes above the
439         * insertion point. If we get to a highly unbalanced ancestor, we
440         * need to do a rotation.  If we back out of the tree we are done.
441         * If we brought any subtree into perfect balance (0), we are also done.
442         */
443        let mut data: *const P::Target = here;
444        loop {
445            let node = unsafe { (*data).get_node() };
446            let old_balance = node.balance;
447            new_balance = old_balance + avlchild_to_balance!(which_child);
448            if new_balance == 0 {
449                node.balance = 0;
450                return;
451            }
452            if old_balance != 0 {
453                self.rotate(data, new_balance);
454                return;
455            }
456            node.balance = new_balance;
457            let parent_ptr = node.get_parent();
458            if parent_ptr.is_null() {
459                return;
460            }
461            which_child = self.parent_direction(data, parent_ptr);
462            data = parent_ptr;
463        }
464    }
465
466    /// Insert "new_data" in "tree" in the given "direction" either after or
467    /// before AvlDirection::After, AvlDirection::Before) the data "here".
468    ///
469    /// Insertions can only be done at empty leaf points in the tree, therefore
470    /// if the given child of the node is already present we move to either
471    /// the AVL_PREV or AVL_NEXT and reverse the insertion direction. Since
472    /// every other node in the tree is a leaf, this always works.
473    ///
474    /// # Safety
475    ///
476    /// Once the tree structure changed, previous search result is not safe to use anymore.
477    ///
478    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
479    /// to avoid the borrowing issue.
480    pub unsafe fn insert_here(
481        &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
482    ) {
483        let mut dir_child = direction;
484        assert!(!here.node.is_null());
485        let here_node = here.node;
486        let child = unsafe { (*here_node).get_node().get_child(dir_child) };
487        if !child.is_null() {
488            dir_child = dir_child.reverse();
489            let node = self.bottom_child_ref(child, dir_child);
490            self._insert(new_data, node, dir_child);
491        } else {
492            self._insert(new_data, here_node, dir_child);
493        }
494    }
495
496    // set child and both child's parent
497    #[inline(always)]
498    fn set_child2(
499        &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
500        parent: *const P::Target,
501    ) {
502        if !child.is_null() {
503            unsafe { (*child).get_node().parent = parent };
504        }
505        node.set_child(dir, child);
506    }
507
508    #[inline(always)]
509    fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
510        if !parent.is_null() {
511            let parent_node = unsafe { (*parent).get_node() };
512            if parent_node.left == data {
513                return AvlDirection::Left;
514            }
515            if parent_node.right == data {
516                return AvlDirection::Right;
517            }
518            panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
519        }
520        // this just follow zfs
521        AvlDirection::Left
522    }
523
524    #[inline(always)]
525    fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
526        let node = unsafe { (*data).get_node() };
527        let parent = node.get_parent();
528        if !parent.is_null() {
529            return self.parent_direction(data, parent);
530        }
531        // this just follow zfs
532        AvlDirection::Left
533    }
534
    /// Rebalances the subtree rooted at `data`, whose would-be balance `balance`
    /// is out of range.
    ///
    /// The heavy side is taken from the sign of `balance` (`Left` when negative).
    /// Depending on the balance of the heavy-side child, either a single rotation
    /// (the child becomes the subtree root) or a double rotation (the child's
    /// inner child becomes the subtree root) is performed. The former parent of
    /// `data`, or `self.root` when there was none, is re-pointed at the new
    /// subtree root.
    ///
    /// Returns `false` only after a single rotation that leaves the new subtree
    /// root with a non-zero balance; `remove` treats that as "the height of the
    /// subtree did not change" and stops adjusting ancestors.
    #[inline]
    fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
        // `dir` is the heavy side; the comments below describe the left-heavy
        // case, the right-heavy case is its mirror image.
        let dir = if balance < 0 { AvlDirection::Left } else { AvlDirection::Right };
        let node = unsafe { (*data).get_node() };

        let parent = node.get_parent();
        let dir_inverse = dir.reverse();
        // Balance value meaning "leans towards `dir`" (half of `balance`), and its opposite.
        let left_heavy = balance >> 1;
        let right_heavy = -left_heavy;

        let child = node.get_child(dir);
        let child_node = unsafe { (*child).get_node() };
        let mut child_balance = child_node.balance;

        // Slot of `parent` that currently holds `data`; captured before any link changes.
        let which_child = self.parent_direction(data, parent);

        // node is overly left heavy, the left child is balanced or also left heavy.
        if child_balance != right_heavy {
            child_balance += right_heavy;

            // The child's inner subtree becomes node's heavy-side subtree.
            let c_right = child_node.get_child(dir_inverse);
            self.set_child2(node, dir, c_right, data);
            // move node to be child's right child
            node.balance = -child_balance;

            node.parent = child;
            child_node.set_child(dir_inverse, data);
            // update the pointer into this subtree

            child_node.balance = child_balance;
            if !parent.is_null() {
                child_node.parent = parent;
                unsafe { (*parent).get_node() }.set_child(which_child, child);
            } else {
                child_node.parent = null();
                self.root = child;
            }
            return child_balance == 0;
        }
        // When node is left heavy, but child is right heavy we use
        // a different rotation.

        let g_child = child_node.get_child(dir_inverse);
        let g_child_node = unsafe { (*g_child).get_node() };
        let g_left = g_child_node.get_child(dir);
        let g_right = g_child_node.get_child(dir_inverse);

        // Hand the grandchild's two subtrees over to node and child.
        self.set_child2(node, dir, g_right, data);
        self.set_child2(child_node, dir_inverse, g_left, child);

        /*
         * move child to left child of gchild and
         * move node to right child of gchild and
         * fixup parent of all this to point to gchild
         */

        // The new balances of child and node depend on which way the grandchild leaned.
        let g_child_balance = g_child_node.balance;
        if g_child_balance == right_heavy {
            child_node.balance = left_heavy;
        } else {
            child_node.balance = 0;
        }
        child_node.parent = g_child;
        g_child_node.set_child(dir, child);

        if g_child_balance == left_heavy {
            node.balance = right_heavy;
        } else {
            node.balance = 0;
        }
        g_child_node.balance = 0;

        node.parent = g_child;
        g_child_node.set_child(dir_inverse, data);

        if !parent.is_null() {
            g_child_node.parent = parent;
            unsafe { (*parent).get_node() }.set_child(which_child, g_child);
        } else {
            g_child_node.parent = null();
            self.root = g_child;
        }
        true
    }
619
620    /*
621    fn replace(&mut self, old: *const P::Target, node: P) {
622        let old_node = unsafe { (*old).get_node() };
623        let new_ptr = node.into_raw();
624        let new_node = unsafe { (*new_ptr).get_node() };
625
626        let left = old_node.get_child(AvlDirection::Left);
627        if !left.is_null() {
628            self.set_child2(new_node, AvlDirection::Left, left, new_ptr);
629        }
630        let right = old_node.get_child(AvlDirection::Right);
631        if !right.is_null() {
632            self.set_child2(new_node, AvlDirection::Right, right, new_ptr);
633        }
634
635        new_node.balance = old_node.balance;
636        old_node.balance = 0;
637        let parent = old_node.get_parent();
638        if !parent.is_null() {
639            let dir = self.parent_direction(old, parent);
640            self.set_child2(unsafe { (*parent).get_node() }, dir, new_ptr, parent);
641            old_node.parent = null();
642        } else {
643            debug_assert_eq!(self.root, old);
644            self.root = new_ptr;
645        }
646    }
647    */
648
    /// Unlinks the element `del` from the tree and rebalances the ancestors.
    ///
    /// Returns immediately, without touching `del`, when the tree is empty.
    /// Otherwise `del`'s node is detached (null links, zero balance) before returning.
    ///
    /// # Safety
    ///
    /// `del` must be a valid pointer to a node in this tree.
    ///
    /// It does not drop the node data, only unlinks it.
    /// Caller is responsible for re-taking ownership (e.g. via from_raw) and dropping if needed.
    ///
    /// For Arc/Rc, use [Self::remove_ref()] instead.
    ///
    /// # Panics
    ///
    /// Panics if the parent/child links are inconsistent (see `parent_direction`),
    /// or if the root ends up null while the count is still positive.
    pub unsafe fn remove(&mut self, del: *const P::Target) {
        /*
         * Deletion is easiest with a node that has at most 1 child.
         * We swap a node with 2 children with a sequentially valued
         * neighbor node. That node will have at most 1 child. Note this
         * has no effect on the ordering of the remaining nodes.
         *
         * As an optimization, we choose the greater neighbor if the tree
         * is right heavy, otherwise the left neighbor. This reduces the
         * number of rotations needed.
         */
        if self.count == 0 {
            return;
        }
        // Fast path: `del` is the only element in the tree.
        if self.count == 1 && self.root == del {
            self.root = null();
            self.count = 0;
            unsafe { (*del).get_node().detach() };
            return;
        }
        let mut which_child: AvlDirection;

        // Fetch del's link node once and reuse this reference below
        // (see the aliasing notes in the swap branch).
        let del_node = unsafe { (*del).get_node() };

        // True when `del` has two children and must first trade places with a neighbor.
        let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();

        if node_swap_flag {
            // Pick the side to take the neighbor from: Right when del's balance
            // is 1, Left otherwise.
            let dir: AvlDirection = balance_to_child!(del_node.balance + 1);
            let child_temp = del_node.get_child(dir);

            // `child` is the neighbor that takes del's place: the node that
            // `bottom_child_ref` returns for del's `dir` child, searched in the
            // opposite direction.
            let dir_inverse: AvlDirection = dir.reverse();
            let child = self.bottom_child_ref(child_temp, dir_inverse);

            // Fix Miri UB: Avoid calling parent_direction2(child) if child's parent is del,
            // because that would create a aliasing &mut ref to del while we hold del_node.
            let dir_child_temp =
                if child == child_temp { dir } else { self.parent_direction2(child) };

            // Fix Miri UB: Do not call parent_direction2(del) as it creates a new &mut AvlNode
            // alias while we hold del_node. Use del_node to find parent direction.
            let parent = del_node.get_parent();
            let dir_child_del = if !parent.is_null() {
                self.parent_direction(del, parent)
            } else {
                AvlDirection::Left
            };

            // Exchange the links and balance of the two nodes; all pointers that
            // refer to either node are repaired below.
            let child_node = unsafe { (*child).get_node() };
            child_node.swap(del_node);

            // move 'node' to delete's spot in the tree
            if child_node.get_child(dir) == child {
                // `child` was del's direct child, so after the swap it would
                // point at itself; point it at `del` instead.
                child_node.set_child(dir, del);
            }

            // Re-parent the subtrees that now hang below `child`.
            let c_dir = child_node.get_child(dir);
            if c_dir == del {
                del_node.parent = child;
            } else if !c_dir.is_null() {
                unsafe { (*c_dir).get_node() }.parent = child;
            }

            let c_inv = child_node.get_child(dir_inverse);
            if c_inv == del {
                del_node.parent = child;
            } else if !c_inv.is_null() {
                unsafe { (*c_inv).get_node() }.parent = child;
            }

            // Hook `child` into del's former parent, or make it the root.
            let parent = child_node.get_parent();
            if !parent.is_null() {
                unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
            } else {
                self.root = child;
            }

            // Put tmp where node used to be (just temporary).
            // It always has a parent and at most 1 child.
            let parent = del_node.get_parent();
            unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
            if !del_node.right.is_null() {
                which_child = AvlDirection::Right;
            } else {
                which_child = AvlDirection::Left;
            }
            // Repair the parent link of del's remaining child, if it has one.
            let child = del_node.get_child(which_child);
            if !child.is_null() {
                unsafe { (*child).get_node() }.parent = del;
            }
            // From here on `which_child` is the slot of del's new parent that holds `del`.
            which_child = dir_child_temp;
        } else {
            // Fix Miri UB here as well
            let parent = del_node.get_parent();
            if !parent.is_null() {
                which_child = self.parent_direction(del, parent);
            } else {
                which_child = AvlDirection::Left;
            }
        }

        // Here we know "delete" is at least partially a leaf node. It can
        // be easily removed from the tree.
        let parent: *const P::Target = del_node.get_parent();

        // The single child of `del` (null when `del` is a leaf).
        let imm_data: *const P::Target =
            if !del_node.left.is_null() { del_node.left } else { del_node.right };

        // Connect parent directly to node (leaving out delete).
        if !imm_data.is_null() {
            let imm_node = unsafe { (*imm_data).get_node() };
            imm_node.parent = parent;
        }

        if !parent.is_null() {
            assert!(self.count > 0);
            self.count -= 1;

            let parent_node = unsafe { (*parent).get_node() };
            parent_node.set_child(which_child, imm_data);

            //Since the subtree is now shorter, begin adjusting parent balances
            //and performing any needed rotations.
            let mut node_data: *const P::Target = parent;
            let mut old_balance: i8;
            let mut new_balance: i8;
            loop {
                // Move up the tree and adjust the balance.
                // Capture the parent and which_child values for the next
                // iteration before any rotations occur.
                let node = unsafe { (*node_data).get_node() };
                old_balance = node.balance;
                new_balance = old_balance - avlchild_to_balance!(which_child);

                //If a node was in perfect balance but isn't anymore then
                //we can stop, since the height didn't change above this point
                //due to a deletion.
                if old_balance == 0 {
                    node.balance = new_balance;
                    break;
                }

                let parent = node.get_parent();
                which_child = self.parent_direction(node_data, parent);

                //If the new balance is zero, we don't need to rotate
                //else
                //need a rotation to fix the balance.
                //If the rotation doesn't change the height
                //of the sub-tree we have finished adjusting.
                if new_balance == 0 {
                    node.balance = new_balance;
                } else if !self.rotate(node_data, new_balance) {
                    break;
                }

                if !parent.is_null() {
                    node_data = parent;
                    continue;
                }
                break;
            }
        } else if !imm_data.is_null() {
            // `del` was the root with a single child: that child becomes the new root.
            assert!(self.count > 0);
            self.count -= 1;
            self.root = imm_data;
        }
        if self.root.is_null() && self.count > 0 {
            panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
        }
        // Leave the removed node with null links and a zero balance.
        del_node.detach();
    }
831
832    /// Removes a node from the tree by key.
833    ///
834    /// The `cmp_func` should compare the key `K` with the elements in the tree.
835    /// Returns `Some(P)` if an exact match was found and removed, `None` otherwise.
836    #[inline]
837    pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
838        let result = self.find(val, cmp_func);
839        self.remove_with(unsafe { result.detach() })
840    }
841
842    /// remove with a previous search result
843    ///
844    /// - If the result is exact match, return the removed element ownership
845    /// - If the result is not exact match, return None
846    ///
847    /// # Safety
848    ///
849    /// Once the tree structure changed, previous search result is not safe to use anymore.
850    ///
851    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
852    /// to avoid the borrowing issue.
853    #[inline]
854    pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
855        if result.is_exact() {
856            unsafe {
857                let p = result.node;
858                self.remove(p);
859                Some(P::from_raw(p))
860            }
861        } else {
862            None
863        }
864    }
865
866    /// Searches for an element in the tree.
867    ///
868    /// The `cmp_func` should compare the key `K` with the elements in the tree.
869    /// Returns an [`AvlSearchResult`] which indicates if an exact match was found,
870    /// or where a new element should be inserted.
871    #[inline]
872    pub fn find<'a, K>(
873        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
874    ) -> AvlSearchResult<'a, P> {
875        if self.root.is_null() {
876            return AvlSearchResult::default();
877        }
878        let mut node_data = self.root;
879        loop {
880            let diff = cmp_func(val, unsafe { &*node_data });
881            match diff {
882                Ordering::Equal => {
883                    return AvlSearchResult {
884                        node: node_data,
885                        direction: None,
886                        _phan: PhantomData,
887                    };
888                }
889                Ordering::Less => {
890                    let node = unsafe { (*node_data).get_node() };
891                    let left = node.get_child(AvlDirection::Left);
892                    if left.is_null() {
893                        return AvlSearchResult {
894                            node: node_data,
895                            direction: Some(AvlDirection::Left),
896                            _phan: PhantomData,
897                        };
898                    }
899                    node_data = left;
900                }
901                Ordering::Greater => {
902                    let node = unsafe { (*node_data).get_node() };
903                    let right = node.get_child(AvlDirection::Right);
904                    if right.is_null() {
905                        return AvlSearchResult {
906                            node: node_data,
907                            direction: Some(AvlDirection::Right),
908                            _phan: PhantomData,
909                        };
910                    }
911                    node_data = right;
912                }
913            }
914        }
915    }
916
917    // for range tree, val may overlap multiple range(node), ensure return the smallest
918    #[inline]
919    pub fn find_contained<'a, K>(
920        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
921    ) -> Option<&'a P::Target> {
922        if self.root.is_null() {
923            return None;
924        }
925        let mut node_data = self.root;
926        let mut result_node: *const P::Target = null();
927        loop {
928            let diff = cmp_func(val, unsafe { &*node_data });
929            match diff {
930                Ordering::Equal => {
931                    let node = unsafe { (*node_data).get_node() };
932                    let left = node.get_child(AvlDirection::Left);
933                    result_node = node_data;
934                    if left.is_null() {
935                        break;
936                    } else {
937                        node_data = left;
938                    }
939                }
940                Ordering::Less => {
941                    let node = unsafe { (*node_data).get_node() };
942                    let left = node.get_child(AvlDirection::Left);
943                    if left.is_null() {
944                        break;
945                    }
946                    node_data = left;
947                }
948                Ordering::Greater => {
949                    let node = unsafe { (*node_data).get_node() };
950                    let right = node.get_child(AvlDirection::Right);
951                    if right.is_null() {
952                        break;
953                    }
954                    node_data = right;
955                }
956            }
957        }
958        if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
959    }
960
961    // for slab, return any block larger or equal than search param
962    #[inline]
963    pub fn find_larger_eq<'a, K>(
964        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
965    ) -> AvlSearchResult<'a, P> {
966        if self.root.is_null() {
967            return AvlSearchResult::default();
968        }
969        let mut node_data = self.root;
970        loop {
971            let diff = cmp_func(val, unsafe { &*node_data });
972            match diff {
973                Ordering::Equal => {
974                    return AvlSearchResult {
975                        node: node_data,
976                        direction: None,
977                        _phan: PhantomData,
978                    };
979                }
980                Ordering::Less => {
981                    return AvlSearchResult {
982                        node: node_data,
983                        direction: None,
984                        _phan: PhantomData,
985                    };
986                }
987                Ordering::Greater => {
988                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
989                    if right.is_null() {
990                        return AvlSearchResult {
991                            node: null(),
992                            direction: None,
993                            _phan: PhantomData,
994                        };
995                    }
996                    node_data = right;
997                }
998            }
999        }
1000    }
1001
1002    /// For range tree
1003    #[inline]
1004    pub fn find_nearest<'a, K>(
1005        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1006    ) -> AvlSearchResult<'a, P> {
1007        if self.root.is_null() {
1008            return AvlSearchResult::default();
1009        }
1010
1011        let mut node_data = self.root;
1012        let mut nearest_node = null();
1013        loop {
1014            let diff = cmp_func(val, unsafe { &*node_data });
1015            match diff {
1016                Ordering::Equal => {
1017                    return AvlSearchResult {
1018                        node: node_data,
1019                        direction: None,
1020                        _phan: PhantomData,
1021                    };
1022                }
1023                Ordering::Less => {
1024                    nearest_node = node_data;
1025                    let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1026                    if left.is_null() {
1027                        break;
1028                    }
1029                    node_data = left;
1030                }
1031                Ordering::Greater => {
1032                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1033                    if right.is_null() {
1034                        break;
1035                    }
1036                    node_data = right;
1037                }
1038            }
1039        }
1040        AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData }
1041    }
1042
1043    #[inline(always)]
1044    fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1045        loop {
1046            let child = unsafe { (*data).get_node() }.get_child(dir);
1047            if !child.is_null() {
1048                data = child;
1049            } else {
1050                return data;
1051            }
1052        }
1053    }
1054
1055    pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1056        let mut node = self.first();
1057        while let Some(n) = node {
1058            cb(n);
1059            node = self.next(n);
1060        }
1061    }
1062
1063    #[inline]
1064    pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1065        if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1066            Some(unsafe { p.as_ref() })
1067        } else {
1068            None
1069        }
1070    }
1071
1072    #[inline]
1073    pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1074        if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1075            Some(unsafe { p.as_ref() })
1076        } else {
1077            None
1078        }
1079    }
1080
1081    #[inline]
1082    fn walk_dir(
1083        &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1084    ) -> Option<NonNull<P::Target>> {
1085        let dir_inverse = dir.reverse();
1086        let node = unsafe { (*data_ptr).get_node() };
1087        let temp = node.get_child(dir);
1088        if !temp.is_null() {
1089            unsafe {
1090                Some(NonNull::new_unchecked(
1091                    self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1092                ))
1093            }
1094        } else {
1095            let mut parent = node.parent;
1096            if parent.is_null() {
1097                return None;
1098            }
1099            loop {
1100                let pdir = self.parent_direction(data_ptr, parent);
1101                if pdir == dir_inverse {
1102                    return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1103                }
1104                data_ptr = parent;
1105                parent = unsafe { (*parent).get_node() }.parent;
1106                if parent.is_null() {
1107                    return None;
1108                }
1109            }
1110        }
1111    }
1112
    /// Checks the links between `data` and its direct children, panicking on
    /// the first violation.
    ///
    /// For each non-null child this asserts that the child is ordered correctly
    /// relative to `data` according to `cmp_func` (left must not be greater,
    /// right must not be less) and that the child's parent pointer is `data`.
    #[inline]
    fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
        let node = unsafe { (*data).get_node() };
        let left = node.left;
        if !left.is_null() {
            // Left child: ordered at or before `data`, and linked back to it.
            assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
            assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
        }
        let right = node.right;
        if !right.is_null() {
            // Right child: ordered at or after `data`, and linked back to it.
            assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
            assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
        }
    }
1127
1128    #[inline]
1129    pub fn nearest<'a>(
1130        &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1131    ) -> AvlSearchResult<'a, P> {
1132        if !current.node.is_null() {
1133            if current.direction.is_some() && current.direction != Some(direction) {
1134                return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1135            }
1136            if let Some(node) = self.walk_dir(current.node, direction) {
1137                return AvlSearchResult {
1138                    node: node.as_ptr(),
1139                    direction: None,
1140                    _phan: PhantomData,
1141                };
1142            }
1143        }
1144        AvlSearchResult::default()
1145    }
1146
1147    pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1148        let c = {
1149            #[cfg(feature = "std")]
1150            {
1151                ((self.get_count() + 10) as f32).log2() as usize
1152            }
1153            #[cfg(not(feature = "std"))]
1154            {
1155                100
1156            }
1157        };
1158        let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1159        if self.root.is_null() {
1160            assert_eq!(self.count, 0);
1161            return;
1162        }
1163        let mut data = self.root;
1164        let mut visited = 0;
1165        loop {
1166            if !data.is_null() {
1167                let left = {
1168                    let node = unsafe { (*data).get_node() };
1169                    node.get_child(AvlDirection::Left)
1170                };
1171                if !left.is_null() {
1172                    stack.push(data);
1173                    data = left;
1174                    continue;
1175                }
1176                visited += 1;
1177                self.validate_node(data, cmp_func);
1178                data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1179            } else if !stack.is_empty() {
1180                let _data = stack.pop().unwrap();
1181                self.validate_node(_data, cmp_func);
1182                visited += 1;
1183                let node = unsafe { (*_data).get_node() };
1184                data = node.get_child(AvlDirection::Right);
1185            } else {
1186                break;
1187            }
1188        }
1189        assert_eq!(visited, self.count);
1190    }
1191
1192    /// Adds a new element to the tree, takes the ownership of P.
1193    ///
1194    /// The `cmp_func` should compare two elements to determine their relative order.
1195    /// Returns `true` if the element was added, `false` if an equivalent element
1196    /// already exists (in which case the provided `node` is dropped).
1197    #[inline]
1198    pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1199        if self.count == 0 && self.root.is_null() {
1200            self.root = node.into_raw();
1201            self.count = 1;
1202            return true;
1203        }
1204
1205        let w = self.find(node.as_ref(), cmp_func);
1206        if w.direction.is_none() {
1207            // To prevent memory leak, we must drop the node.
1208            // But since we took ownership, we have to convert it back to P and drop it.
1209            drop(node);
1210            return false;
1211        }
1212
1213        // Safety: We need to decouple the lifetime of 'w' from 'self' to call 'insert'.
1214        // We extract the pointers and reconstruct the result.
1215        let w_node = w.node;
1216        let w_dir = w.direction;
1217
1218        let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1219
1220        self.insert(node, w_detached);
1221        true
1222    }
1223}
1224
1225impl<P, Tag> Drop for AvlTree<P, Tag>
1226where
1227    P: Pointer,
1228    P::Target: AvlItem<Tag>,
1229{
1230    fn drop(&mut self) {
1231        if mem::needs_drop::<P>() {
1232            for _ in self.drain() {}
1233        }
1234    }
1235}
1236
1237pub struct AvlDrain<'a, P: Pointer, Tag>
1238where
1239    P::Target: AvlItem<Tag>,
1240{
1241    tree: &'a mut AvlTree<P, Tag>,
1242    parent: *const P::Target,
1243    dir: Option<AvlDirection>,
1244}
1245
1246impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
1247where
1248    P::Target: AvlItem<Tag>,
1249{
1250    type Item = P;
1251
1252    fn next(&mut self) -> Option<Self::Item> {
1253        if self.tree.root.is_null() {
1254            return None;
1255        }
1256
1257        let mut node: *const P::Target;
1258        let parent: *const P::Target;
1259
1260        if self.dir.is_none() && self.parent.is_null() {
1261            // Initial call: find the leftmost node
1262            let mut curr = self.tree.root;
1263            while unsafe { !(*curr).get_node().left.is_null() } {
1264                curr = unsafe { (*curr).get_node().left };
1265            }
1266            node = curr;
1267        } else {
1268            parent = self.parent;
1269            if parent.is_null() {
1270                // Should not happen if root was nulled
1271                return None;
1272            }
1273
1274            let child_dir = self.dir.unwrap();
1275            // child_dir child of parent was just nulled in previous call?
1276            // NO, we null it in THIS call.
1277
1278            if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
1279                node = parent;
1280            } else {
1281                // Finished left, go to right sibling
1282                node = unsafe { (*parent).get_node().right };
1283                while unsafe { !(*node).get_node().left.is_null() } {
1284                    node = unsafe { (*node).get_node().left };
1285                }
1286            }
1287        }
1288
1289        // Goto check_right_side logic
1290        if unsafe { !(*node).get_node().right.is_null() } {
1291            // It has a right child, so we must yield that first (in post-order)
1292            // Note: in AVL, if left is null, right must be a leaf.
1293            node = unsafe { (*node).get_node().right };
1294        }
1295
1296        // Determine next state
1297        let next_parent = unsafe { (*node).get_node().parent };
1298        if next_parent.is_null() {
1299            self.tree.root = null();
1300            self.parent = null();
1301            self.dir = Some(AvlDirection::Left);
1302        } else {
1303            self.parent = next_parent;
1304            self.dir = Some(self.tree.parent_direction(node, next_parent));
1305            // Unlink from parent NOW
1306            unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
1307        }
1308
1309        self.tree.count -= 1;
1310        unsafe {
1311            (*node).get_node().detach();
1312            Some(P::from_raw(node))
1313        }
1314    }
1315}
1316
1317impl<T, Tag> AvlTree<Arc<T>, Tag>
1318where
1319    T: AvlItem<Tag>,
1320{
1321    pub fn remove_ref(&mut self, node: &Arc<T>) {
1322        let p = Arc::as_ptr(node);
1323        unsafe { self.remove(p) };
1324        unsafe { drop(Arc::from_raw(p)) };
1325    }
1326}
1327
1328impl<T, Tag> AvlTree<Rc<T>, Tag>
1329where
1330    T: AvlItem<Tag>,
1331{
1332    pub fn remove_ref(&mut self, node: &Rc<T>) {
1333        let p = Rc::as_ptr(node);
1334        unsafe { self.remove(p) };
1335        unsafe { drop(Rc::from_raw(p)) };
1336    }
1337}
1338
1339#[cfg(test)]
1340mod tests {
1341    use super::*;
1342    use core::cell::UnsafeCell;
1343    use rand::Rng;
1344    use std::time::Instant;
1345
    /// Test element: an `i64` key plus the embedded AVL link node.
    struct IntAvlNode {
        /// Key the tests order the tree by.
        pub value: i64,
        /// Intrusive tree links; wrapped in `UnsafeCell` so `get_node` can
        /// hand out `&mut` from `&self`.
        pub node: UnsafeCell<AvlNode<Self, ()>>,
    }
1350
    /// Debug output: the value followed by the pretty-printed link node.
    impl fmt::Debug for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{} {:#?}", self.value, self.node)
        }
    }
1356
    /// Display output: just the value.
    impl fmt::Display for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{}", self.value)
        }
    }
1362
    /// Exposes the embedded link node to the tree.
    unsafe impl AvlItem<()> for IntAvlNode {
        fn get_node(&self) -> &mut AvlNode<Self, ()> {
            // The links live in an `UnsafeCell`, which is what permits
            // producing `&mut` from `&self` here.
            unsafe { &mut *self.node.get() }
        }
    }
1368
    /// Tree of boxed `IntAvlNode`s used by most tests below.
    type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;
1370
1371    fn new_intnode(i: i64) -> Box<IntAvlNode> {
1372        Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
1373    }
1374
1375    fn new_inttree() -> IntAvlTree {
1376        AvlTree::<Box<IntAvlNode>, ()>::new()
1377    }
1378
1379    fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
1380        a.value.cmp(&b.value)
1381    }
1382
1383    fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
1384        a.cmp(&b.value)
1385    }
1386
1387    impl AvlTree<Box<IntAvlNode>, ()> {
1388        fn remove_int(&mut self, i: i64) -> bool {
1389            if let Some(_node) = self.remove_by_key(&i, cmp_int) {
1390                // node is Box<IntAvlNode>, dropped automatically
1391                return true;
1392            }
1393            // else
1394            println!("not found {}", i);
1395            false
1396        }
1397
1398        fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
1399            self.add(node, cmp_int_node)
1400        }
1401
1402        fn validate_tree(&self) {
1403            self.validate(cmp_int_node);
1404        }
1405
1406        fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1407            self.find(&i, cmp_int)
1408        }
1409
1410        fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
1411            self.find(node, cmp_int_node)
1412        }
1413    }
1414
1415    #[test]
1416    fn int_avl_node() {
1417        let mut tree = new_inttree();
1418
1419        assert_eq!(tree.get_count(), 0);
1420        assert!(tree.first().is_none());
1421        assert!(tree.last().is_none());
1422
1423        let node1 = new_intnode(1);
1424        let node2 = new_intnode(2);
1425        let node3 = new_intnode(3);
1426
1427        let p1 = &*node1 as *const IntAvlNode;
1428        let p2 = &*node2 as *const IntAvlNode;
1429        let p3 = &*node3 as *const IntAvlNode;
1430
1431        tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
1432        tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);
1433
1434        assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
1435        // This is tricky as node1 is not in a tree, its parent is not set.
1436        // assert_eq!(tree.parent_direction2(p1), AvlDirection::Left);
1437        assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
1438    }
1439
1440    #[test]
1441    fn int_avl_tree_basic() {
1442        let mut tree = new_inttree();
1443
1444        let temp_node = new_intnode(0);
1445        let temp_node_val = Pointer::as_ref(&temp_node);
1446        assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
1447        assert_eq!(
1448            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
1449            false
1450        );
1451        assert_eq!(
1452            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
1453            false
1454        );
1455        drop(temp_node);
1456
1457        tree.add_int_node(new_intnode(0));
1458        let result = tree.find_int(0);
1459        assert!(result.get_node_ref().is_some());
1460        assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
1461        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1462
1463        let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
1464        assert!(rs.is_some());
1465        let found_value = rs.unwrap().value;
1466        assert_eq!(found_value, 0);
1467
1468        let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
1469        assert!(rs.is_none());
1470
1471        let result = tree.find_int(1);
1472        let left = tree.nearest(&result, AvlDirection::Left);
1473        assert_eq!(left.is_exact(), true);
1474        assert_eq!(left.get_nearest().unwrap().value, 0);
1475        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);
1476
1477        tree.add_int_node(new_intnode(2));
1478        let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
1479        assert!(rs.is_some());
1480        let found_value = rs.unwrap().value;
1481        assert_eq!(found_value, 2);
1482    }
1483
1484    #[test]
1485    fn int_avl_tree_order() {
1486        let max;
1487        #[cfg(miri)]
1488        {
1489            max = 2000;
1490        }
1491        #[cfg(not(miri))]
1492        {
1493            max = 200000;
1494        }
1495        let mut tree = new_inttree();
1496        assert!(tree.first().is_none());
1497        let start_ts = Instant::now();
1498        for i in 0..max {
1499            tree.add_int_node(new_intnode(i));
1500        }
1501        tree.validate_tree();
1502        assert_eq!(tree.get_count(), max as i64);
1503
1504        let mut count = 0;
1505        let mut current = tree.first();
1506        let last = tree.last();
1507        while let Some(c) = current {
1508            assert_eq!(c.value, count);
1509            count += 1;
1510            if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
1511                current = None;
1512            } else {
1513                current = tree.next(c);
1514            }
1515        }
1516        assert_eq!(count, max);
1517
1518        {
1519            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1520            assert!(rs.is_some());
1521            let found_value = rs.unwrap().value;
1522            println!("found larger_eq {}", found_value);
1523            assert!(found_value >= 5);
1524            tree.remove_int(5);
1525            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
1526            assert!(rs.is_some());
1527            assert!(rs.unwrap().value >= 6);
1528            tree.add_int_node(new_intnode(5));
1529        }
1530
1531        for i in 0..max {
1532            assert!(tree.remove_int(i));
1533        }
1534        assert_eq!(tree.get_count(), 0);
1535
1536        let end_ts = Instant::now();
1537        println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
1538    }
1539
1540    #[test]
1541    fn int_avl_tree_fixed1() {
1542        let mut tree = new_inttree();
1543        let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
1544        for i in arr.iter() {
1545            let node = new_intnode(*i);
1546            tree.add_int_node(node);
1547            let rs = tree.find_int(*i);
1548            assert!(rs.get_node_ref().is_some(), "add error {}", i);
1549        }
1550        assert_eq!(tree.get_count(), arr.len() as i64);
1551        for i in arr.iter() {
1552            assert!(tree.remove_int(*i));
1553        }
1554        assert_eq!(tree.get_count(), 0);
1555    }
1556
1557    #[test]
1558    fn int_avl_tree_fixed2() {
1559        let mut tree = new_inttree();
1560        tree.validate_tree();
1561        let node1 = new_intnode(536872960);
1562        {
1563            tree.add_int_node(node1);
1564            tree.validate_tree();
1565            tree.remove_int(536872960);
1566            tree.validate_tree();
1567            tree.add_int_node(new_intnode(536872960));
1568            tree.validate_tree();
1569        }
1570
1571        assert!(tree.find_int(536872960).get_node_ref().is_some());
1572        let node2 = new_intnode(12288);
1573        tree.add_int_node(node2);
1574        tree.validate_tree();
1575        tree.remove_int(536872960);
1576        tree.validate_tree();
1577        tree.add_int_node(new_intnode(536872960));
1578        tree.validate_tree();
1579        let node3 = new_intnode(22528);
1580        tree.add_int_node(node3);
1581        tree.validate_tree();
1582        tree.remove_int(12288);
1583        assert!(tree.find_int(12288).get_node_ref().is_none());
1584        tree.validate_tree();
1585        tree.remove_int(22528);
1586        assert!(tree.find_int(22528).get_node_ref().is_none());
1587        tree.validate_tree();
1588        tree.add_int_node(new_intnode(22528));
1589        tree.validate_tree();
1590    }
1591
1592    #[test]
1593    fn int_avl_tree_random() {
1594        let count = 1000;
1595        let mut test_list: Vec<i64> = Vec::with_capacity(count);
1596        let mut rng = rand::thread_rng();
1597        let mut tree = new_inttree();
1598        tree.validate_tree();
1599        for _ in 0..count {
1600            let node_value: i64 = rng.r#gen();
1601            if !test_list.contains(&node_value) {
1602                test_list.push(node_value);
1603                assert!(tree.add_int_node(new_intnode(node_value)))
1604            }
1605        }
1606        tree.validate_tree();
1607        test_list.sort();
1608        for index in 0..test_list.len() {
1609            let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
1610            let prev = tree.prev(node_ptr);
1611            let next = tree.next(node_ptr);
1612            if index == 0 {
1613                // first node
1614                assert!(prev.is_none());
1615                assert!(next.is_some());
1616                assert_eq!(next.unwrap().value, test_list[index + 1]);
1617            } else if index == test_list.len() - 1 {
1618                // last node
1619                assert!(prev.is_some());
1620                assert_eq!(prev.unwrap().value, test_list[index - 1]);
1621                assert!(next.is_none());
1622            } else {
1623                // middle node
1624                assert!(prev.is_some());
1625                assert_eq!(prev.unwrap().value, test_list[index - 1]);
1626                assert!(next.is_some());
1627                assert_eq!(next.unwrap().value, test_list[index + 1]);
1628            }
1629        }
1630        for index in 0..test_list.len() {
1631            assert!(tree.remove_int(test_list[index]));
1632        }
1633        tree.validate_tree();
1634        assert_eq!(0, tree.get_count());
1635    }
1636
1637    #[test]
1638    fn int_avl_tree_insert_here() {
1639        let mut tree = new_inttree();
1640        let node1 = new_intnode(10);
1641        tree.add_int_node(node1);
1642        // Insert 5 before 10
1643        let rs = tree.find_int(10);
1644        let here = unsafe { rs.detach() };
1645        unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
1646        tree.validate_tree();
1647        assert_eq!(tree.get_count(), 2);
1648        assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);
1649
1650        // Insert 15 after 10
1651        let rs = tree.find_int(10);
1652        let here = unsafe { rs.detach() };
1653        unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
1654        tree.validate_tree();
1655        assert_eq!(tree.get_count(), 3);
1656        assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);
1657
1658        // Insert 3 before 5 (which is left child of 10)
1659        let rs = tree.find_int(5);
1660        let here = unsafe { rs.detach() };
1661        unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
1662        tree.validate_tree();
1663        assert_eq!(tree.get_count(), 4);
1664
1665        // Insert 7 after 5
1666        let rs = tree.find_int(5);
1667        let here = unsafe { rs.detach() };
1668        unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
1669        tree.validate_tree();
1670        assert_eq!(tree.get_count(), 5);
1671    }
1672
1673    #[test]
1674    fn test_arc_avl_tree_get_exact() {
1675        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1676        // Manually constructing Arc node
1677        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
1678        tree.add(node.clone(), cmp_int_node);
1679
1680        // find returns AvlSearchResult<'a, Arc<IntAvlNode>>
1681        let result_search = tree.find(&100, cmp_int);
1682
1683        // This should invoke the specialized get_exact for Arc<T>
1684        let exact = result_search.get_exact();
1685        assert!(exact.is_some());
1686        let exact_arc = exact.unwrap();
1687        assert_eq!(exact_arc.value, 100);
1688        assert!(Arc::ptr_eq(&node, &exact_arc));
1689        // Check ref count: 1 (original) + 1 (in tree) + 1 (exact_arc) = 3
1690        assert_eq!(Arc::strong_count(&node), 3);
1691    }
1692
1693    #[test]
1694    fn test_arc_avl_tree_remove_ref() {
1695        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1696        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
1697        tree.add(node.clone(), cmp_int_node);
1698        assert_eq!(tree.get_count(), 1);
1699        assert_eq!(Arc::strong_count(&node), 2);
1700
1701        tree.remove_ref(&node);
1702        assert_eq!(tree.get_count(), 0);
1703        assert_eq!(Arc::strong_count(&node), 1);
1704    }
1705
1706    #[test]
1707    fn test_arc_avl_tree_remove_with() {
1708        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
1709        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
1710        tree.add(node.clone(), cmp_int_node);
1711
1712        let removed = tree.remove_by_key(&300, cmp_int);
1713        assert!(removed.is_some());
1714        let removed_arc = removed.unwrap();
1715        assert_eq!(removed_arc.value, 300);
1716        assert_eq!(tree.get_count(), 0);
1717        // count: 1 (node) + 1 (removed_arc) = 2. Tree dropped its count.
1718        assert_eq!(Arc::strong_count(&node), 2);
1719
1720        drop(removed_arc);
1721        assert_eq!(Arc::strong_count(&node), 1);
1722    }
1723
1724    #[test]
1725    fn test_avl_drain() {
1726        let mut tree = new_inttree();
1727        for i in 0..100 {
1728            tree.add_int_node(new_intnode(i));
1729        }
1730        assert_eq!(tree.get_count(), 100);
1731
1732        let mut count = 0;
1733        for node in tree.drain() {
1734            assert!(node.value >= 0 && node.value < 100);
1735            count += 1;
1736        }
1737        assert_eq!(count, 100);
1738        assert_eq!(tree.get_count(), 0);
1739        assert!(tree.first().is_none());
1740    }
1741}