Skip to main content

embed_collections/
avl.rs

1//! An intrusive AVL tree implementation.
2//!
//! The algorithm originates from OpenZFS.
4//!
5//! # Examples
6//!
7//! ## Using `Box` to search and remove by key
8//!
9//! ```rust
10//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
11//! use core::cell::UnsafeCell;
12//! use core::cmp::Ordering;
13//!
14//! struct MyNode {
15//!     value: i32,
16//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
17//! }
18//!
19//! unsafe impl AvlItem<()> for MyNode {
20//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
21//!         unsafe { &mut *self.avl_node.get() }
22//!     }
23//! }
24//!
25//! let mut tree = AvlTree::<Box<MyNode>, ()>::new();
26//! tree.add(Box::new(MyNode { value: 10, avl_node: UnsafeCell::new(Default::default()) }), |a, b| a.value.cmp(&b.value));
27//!
28//! // Search and remove
29//! if let Some(node) = tree.remove_by_key(&10, |key, node| key.cmp(&node.value)) {
30//!     assert_eq!(node.value, 10);
31//! }
32//! ```
33//!
34//! ## Using `Arc` for multiple ownership
35//!
//! `remove_ref` is only available for `Arc` and `Rc`.
37//!
38//! ```rust
39//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
40//! use core::cell::UnsafeCell;
41//! use std::sync::Arc;
42//!
43//! struct MyNode {
44//!     value: i32,
45//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
46//! }
47//!
48//! unsafe impl AvlItem<()> for MyNode {
49//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
50//!         unsafe { &mut *self.avl_node.get() }
51//!     }
52//! }
53//!
54//! let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
55//! let node = Arc::new(MyNode { value: 42, avl_node: UnsafeCell::new(Default::default()) });
56//!
57//! tree.add(node.clone(), |a, b| a.value.cmp(&b.value));
58//! assert_eq!(tree.get_count(), 1);
59//!
60//! // Remove by reference (detach from avl tree)
61//! tree.remove_ref(&node);
62//! assert_eq!(tree.get_count(), 0);
63//! ```
64//!
65
66use crate::Pointer;
67use alloc::rc::Rc;
68use alloc::sync::Arc;
69use alloc::vec::Vec;
70use core::marker::PhantomData;
71use core::{
72    cmp::{Ordering, PartialEq},
73    fmt, mem,
74    ptr::{NonNull, null},
75};
76
/// A trait that exposes the [`AvlNode`] embedded in an item for the tree selected by `Tag`.
///
/// The tag is used to distinguish different `AvlNode`s within the same item,
/// allowing an item to belong to multiple trees simultaneously.
/// When an item only ever belongs to one tree, you can use `()` as the tag.
///
/// # Safety
///
/// Implementors must ensure `get_node` returns a valid reference to the `AvlNode`
/// embedded within `Self`. Users must use `UnsafeCell` to hold the `AvlNode` to support
/// the interior mutability required by tree operations.
pub unsafe trait AvlItem<Tag>: Sized {
    /// Returns a mutable reference to this item's embedded link fields.
    ///
    /// Note that a `&mut` is handed out from `&self`; the module-level examples
    /// implement this through `UnsafeCell::get`.
    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
}
91
/// Selects one of the two child slots of a node.
///
/// The discriminants are fixed: `Left` is 0 and `Right` is 1.
#[derive(PartialEq, Debug, Copy, Clone)]
pub enum AvlDirection {
    Left = 0,
    Right = 1,
}

impl AvlDirection {
    /// Returns the opposite side: `Left` becomes `Right` and vice versa.
    #[inline(always)]
    fn reverse(self) -> AvlDirection {
        if self == AvlDirection::Left { AvlDirection::Right } else { AvlDirection::Left }
    }
}
107
// Converts a child direction into its signed balance contribution:
// `Left` yields -1 and `Right` yields 1.
macro_rules! avlchild_to_balance {
    ( $side: expr ) => {
        if let AvlDirection::Left = $side { -1 } else { 1 }
    };
}
116
/// The link fields embedded in every item stored in an AVL tree.
///
/// All links are raw pointers to the containing item type `T`; a null pointer
/// means "no such neighbour". `Tag` only selects which tree the node belongs to.
pub struct AvlNode<T: Sized, Tag> {
    /// Left child, or null when there is none.
    pub left: *const T,
    /// Right child, or null when there is none.
    pub right: *const T,
    /// Parent item, or null when there is none.
    pub parent: *const T,
    /// Balance factor; a left child contributes -1 and a right child +1
    /// when the tree updates it.
    pub balance: i8,
    // Zero-sized marker that ties the node to `Tag` without storing a value.
    _phan: PhantomData<fn(&Tag)>,
}
124
// NOTE(review): `Send` is asserted for every `T`, even though the node stores raw
// `*const T` links. Presumably thread-safety is meant to be decided by the item and
// owning pointer type that embed the node — confirm.
unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
126
127impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
128    #[inline(always)]
129    pub fn detach(&mut self) {
130        self.left = null();
131        self.right = null();
132        self.parent = null();
133        self.balance = 0;
134    }
135
136    #[inline(always)]
137    fn get_child(&self, dir: AvlDirection) -> *const T {
138        match dir {
139            AvlDirection::Left => self.left,
140            AvlDirection::Right => self.right,
141        }
142    }
143
144    #[inline(always)]
145    fn set_child(&mut self, dir: AvlDirection, child: *const T) {
146        match dir {
147            AvlDirection::Left => self.left = child,
148            AvlDirection::Right => self.right = child,
149        }
150    }
151
152    #[inline(always)]
153    fn get_parent(&self) -> *const T {
154        return self.parent;
155    }
156
157    // Swap two node but not there value
158    #[inline(always)]
159    pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
160        mem::swap(&mut self.left, &mut other.left);
161        mem::swap(&mut self.right, &mut other.right);
162        mem::swap(&mut self.parent, &mut other.parent);
163        mem::swap(&mut self.balance, &mut other.balance);
164    }
165}
166
167impl<T, Tag> Default for AvlNode<T, Tag> {
168    fn default() -> Self {
169        Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
170    }
171}
172
173#[allow(unused_must_use)]
174impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
175    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176        write!(f, "(")?;
177
178        if !self.left.is_null() {
179            write!(f, "left: {:p}", self.left)?;
180        } else {
181            write!(f, "left: none ")?;
182        }
183
184        if !self.right.is_null() {
185            write!(f, "right: {:p}", self.right)?;
186        } else {
187            write!(f, "right: none ")?;
188        }
189        write!(f, ")")
190    }
191}
192
/// Comparison callback that orders a lookup key `K` relative to a stored item `T`.
///
/// `find` descends to the left child on `Ordering::Less`, to the right child on
/// `Ordering::Greater`, and stops on `Ordering::Equal`.
pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
194
/// An intrusive AVL tree (balanced binary search tree).
///
/// Elements in the tree must implement the [`AvlItem`] trait.
/// The tree supports various pointer types (`Box`, `Arc`, `Rc`, etc.) through the [`Pointer`] trait.
pub struct AvlTree<P, Tag>
where
    P: Pointer,
    P::Target: AvlItem<Tag>,
{
    /// Raw pointer to the root item; null while the tree is empty.
    pub root: *const P::Target,
    /// Number of items currently linked into the tree.
    count: i64,
    // Zero-sized marker tying the tree to `P` and `Tag` without storing either.
    _phan: PhantomData<fn(P, &Tag)>,
}
208
/// Result of a search operation in an [`AvlTree`].
///
/// An `AvlSearchResult` identifies either:
/// 1. An exact match: `direction` is `None` and `node` points to the matching item.
/// 2. An insertion point: `direction` is `Some(dir)` and `node` points to the parent
///    where a new node should be attached as the `dir` child. For an empty tree
///    `node` is null.
///
/// The lifetime `'a` ties the search result to the tree's borrow, ensuring safety.
/// However, this lifetime often prevents further mutable operations on the tree
/// (e.g., adding a node while holding the search result). Use [`detach`](Self::detach)
/// to decouple the result from the tree's lifetime when necessary.
pub struct AvlSearchResult<'a, P: Pointer> {
    /// The matching node or the parent for insertion.
    pub node: *const P::Target,
    /// `None` if exact match found, or `Some(direction)` indicating insertion point.
    pub direction: Option<AvlDirection>,
    // Zero-sized marker that carries the borrow lifetime `'a`.
    _phan: PhantomData<&'a P::Target>,
}
227
228impl<P: Pointer> Default for AvlSearchResult<'_, P> {
229    fn default() -> Self {
230        AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
231    }
232}
233
234impl<'a, P: Pointer> AvlSearchResult<'a, P> {
235    /// Returns a reference to the matching node if the search was an exact match.
236    #[inline(always)]
237    pub fn get_node_ref(&self) -> Option<&'a P::Target> {
238        if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
239    }
240
241    /// Returns `true` if the search result is an exact match.
242    #[inline(always)]
243    pub fn is_exact(&self) -> bool {
244        self.direction.is_none() && !self.node.is_null()
245    }
246
247    /// De-couple the lifetime of the search result from the tree.
248    ///
249    /// This method is essential for performing mutable operations on the tree
250    /// using search results. In Rust, a search result typically borrows the tree
251    /// immutably. If you need to modify the tree (e.g., call `insert` or `remove`)
252    /// based on that result, the borrow checker would normally prevent it.
253    ///
254    /// `detach` effectively "erases" the lifetime `'a`, returning a result with
255    /// an unbounded lifetime `'b`.
256    ///
257    /// # Examples
258    ///
259    /// Used in `RangeTree::add`:
260    /// ```ignore
261    /// let result = self.root.find(&rs_key, range_tree_segment_cmp);
262    /// // result is AvlSearchResult<'a, ...> and borrows self.root
263    ///
264    /// let detached = unsafe { result.detach() };
265    /// // detached has no lifetime bound to self.root
266    ///
267    /// self.space += size; // Mutable operation on self permitted
268    /// self.merge_seg(start, end, detached); // Mutation on tree permitted
269    /// ```
270    ///
271    /// # Safety
272    /// This is an unsafe operation. The compiler no longer protects the validity
273    /// of the internal pointer via lifetimes. You must ensure that the tree
274    /// structure is not modified in a way that invalidates `node` (e.g., the
275    /// parent node being removed) before using the detached result.
276    #[inline(always)]
277    pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
278        AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
279    }
280
281    /// Return the nearest node in the search result
282    #[inline(always)]
283    pub fn get_nearest(&self) -> Option<&P::Target> {
284        if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
285    }
286}
287
288impl<'a, T> AvlSearchResult<'a, Arc<T>> {
289    /// Returns the matching Arc node if this is an exact match.
290    pub fn get_exact(&self) -> Option<Arc<T>> {
291        if self.is_exact() {
292            unsafe {
293                Arc::increment_strong_count(self.node);
294                Some(Arc::from_raw(self.node))
295            }
296        } else {
297            None
298        }
299    }
300}
301
302impl<'a, T> AvlSearchResult<'a, Rc<T>> {
303    /// Returns the matching Rc node if this is an exact match.
304    pub fn get_exact(&self) -> Option<Rc<T>> {
305        if self.is_exact() {
306            unsafe {
307                Rc::increment_strong_count(self.node);
308                Some(Rc::from_raw(self.node))
309            }
310        } else {
311            None
312        }
313    }
314}
315
// Evaluates to a null pointer when `$tree` has no root; otherwise to the result of
// `$tree.bottom_child_ref($tree.root, $dir)`, which is used by `first`/`last` to
// obtain the outermost node on the `$dir` side.
macro_rules! return_end {
    ($tree: expr, $dir: expr) => {{ if $tree.root.is_null() { null() } else { $tree.bottom_child_ref($tree.root, $dir) } }};
}
319
// Maps a shifted balance value to a direction: 0 and 1 select `Left`,
// every other value selects `Right`. `remove` passes `balance + 1`.
macro_rules! balance_to_child {
    ($shifted: expr) => {
        if matches!($shifted, 0 | 1) { AvlDirection::Left } else { AvlDirection::Right }
    };
}
328
329impl<P, Tag> AvlTree<P, Tag>
330where
331    P: Pointer,
332    P::Target: AvlItem<Tag>,
333{
334    /// Creates a new, empty `AvlTree`.
335    pub fn new() -> Self {
336        AvlTree { count: 0, root: null(), _phan: Default::default() }
337    }
338
339    /// Returns an iterator that removes all elements from the tree in post-order.
340    ///
341    /// This is an optimized, non-recursive, and stack-less traversal that preserves
342    /// tree invariants during destruction.
343    #[inline]
344    pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
345        AvlDrain { tree: self, parent: null(), dir: None }
346    }
347
    /// Returns the number of items currently linked into the tree.
    pub fn get_count(&self) -> i64 {
        self.count
    }
351
352    pub fn first(&self) -> Option<&P::Target> {
353        unsafe { return_end!(self, AvlDirection::Left).as_ref() }
354    }
355
356    #[inline]
357    pub fn last(&self) -> Option<&P::Target> {
358        unsafe { return_end!(self, AvlDirection::Right).as_ref() }
359    }
360
361    /// Inserts a new node into the tree at the location specified by a search result.
362    ///
363    /// This is typically used after a [`find`](Self::find) operation didn't find an exact match.
364    ///
365    /// # Safety
366    ///
367    /// Once the tree structure changed, previous search result is not safe to use anymore.
368    ///
369    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
370    /// to avoid the borrowing issue.
371    ///
372    /// # Panics
373    /// Panics if the search result is an exact match (i.e. node already exists).
374    ///
375    /// # Examples
376    ///
377    /// ```rust
378    /// use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
379    /// use core::cell::UnsafeCell;
380    /// use std::sync::Arc;
381    ///
382    /// struct MyNode {
383    ///     value: i32,
384    ///     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
385    /// }
386    ///
387    /// unsafe impl AvlItem<()> for MyNode {
388    ///     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
389    ///         unsafe { &mut *self.avl_node.get() }
390    ///     }
391    /// }
392    ///
393    /// let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
394    /// let key = 42;
395    /// let result = tree.find(&key, |k, n| k.cmp(&n.value));
396    ///
397    /// if !result.is_exact() {
398    ///     let new_node = Arc::new(MyNode {
399    ///         value: key,
400    ///         avl_node: UnsafeCell::new(Default::default()),
401    ///     });
402    ///     tree.insert(new_node, unsafe{result.detach()});
403    /// }
404    /// ```
405    #[inline]
406    pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
407        debug_assert!(w.direction.is_some());
408        self._insert(new_data, w.node, w.direction.unwrap());
409    }
410
411    pub fn _insert(
412        &mut self,
413        new_data: P,
414        here: *const P::Target, // parent
415        mut which_child: AvlDirection,
416    ) {
417        let mut new_balance: i8;
418        let new_ptr = new_data.into_raw();
419
420        if here.is_null() {
421            if self.count > 0 {
422                panic!("insert into a tree size {} with empty where.node", self.count);
423            }
424            self.root = new_ptr;
425            self.count += 1;
426            return;
427        }
428
429        let parent = unsafe { &*here };
430        let node = unsafe { (*new_ptr).get_node() };
431        let parent_node = parent.get_node();
432        node.parent = here;
433        parent_node.set_child(which_child, new_ptr);
434        self.count += 1;
435
436        /*
437         * Now, back up the tree modifying the balance of all nodes above the
438         * insertion point. If we get to a highly unbalanced ancestor, we
439         * need to do a rotation.  If we back out of the tree we are done.
440         * If we brought any subtree into perfect balance (0), we are also done.
441         */
442        let mut data: *const P::Target = here;
443        loop {
444            let node = unsafe { (*data).get_node() };
445            let old_balance = node.balance;
446            new_balance = old_balance + avlchild_to_balance!(which_child);
447            if new_balance == 0 {
448                node.balance = 0;
449                return;
450            }
451            if old_balance != 0 {
452                self.rotate(data, new_balance);
453                return;
454            }
455            node.balance = new_balance;
456            let parent_ptr = node.get_parent();
457            if parent_ptr.is_null() {
458                return;
459            }
460            which_child = self.parent_direction(data, parent_ptr);
461            data = parent_ptr;
462        }
463    }
464
465    /// Insert "new_data" in "tree" in the given "direction" either after or
466    /// before AvlDirection::After, AvlDirection::Before) the data "here".
467    ///
468    /// Insertions can only be done at empty leaf points in the tree, therefore
469    /// if the given child of the node is already present we move to either
470    /// the AVL_PREV or AVL_NEXT and reverse the insertion direction. Since
471    /// every other node in the tree is a leaf, this always works.
472    ///
473    /// # Safety
474    ///
475    /// Once the tree structure changed, previous search result is not safe to use anymore.
476    ///
477    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
478    /// to avoid the borrowing issue.
479    pub unsafe fn insert_here(
480        &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
481    ) {
482        let mut dir_child = direction;
483        assert_eq!(here.node.is_null(), false);
484        let here_node = here.node;
485        let child = unsafe { (*here_node).get_node().get_child(dir_child) };
486        if !child.is_null() {
487            dir_child = dir_child.reverse();
488            let node = self.bottom_child_ref(child, dir_child);
489            self._insert(new_data, node, dir_child);
490        } else {
491            self._insert(new_data, here_node, dir_child);
492        }
493    }
494
495    // set child and both child's parent
496    #[inline(always)]
497    fn set_child2(
498        &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
499        parent: *const P::Target,
500    ) {
501        if !child.is_null() {
502            unsafe { (*child).get_node().parent = parent };
503        }
504        node.set_child(dir, child);
505    }
506
507    #[inline(always)]
508    fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
509        if !parent.is_null() {
510            let parent_node = unsafe { (*parent).get_node() };
511            if parent_node.left == data {
512                return AvlDirection::Left;
513            }
514            if parent_node.right == data {
515                return AvlDirection::Right;
516            }
517            panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
518        }
519        // this just follow zfs
520        AvlDirection::Left
521    }
522
523    #[inline(always)]
524    fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
525        let node = unsafe { (*data).get_node() };
526        let parent = node.get_parent();
527        if !parent.is_null() {
528            return self.parent_direction(data, parent);
529        }
530        // this just follow zfs
531        AvlDirection::Left
532    }
533
    // Rebalances the subtree rooted at `data` with a single or double rotation.
    //
    // `balance` is the out-of-range balance the caller computed for `data`
    // (presumably -2 or +2, since callers pass `old_balance ± 1` for a non-zero
    // `old_balance`); its sign selects the heavy side. The link from the parent of
    // `data` (or `self.root` when there is none) is redirected to the new subtree root.
    //
    // Returns `true` when the new subtree root ends up with balance 0, `false`
    // otherwise. `remove` uses a `false` result as "the subtree height did not
    // change" and stops walking up the tree.
    #[inline]
    fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
        // `dir` is the heavy side. The comments below describe the left-heavy case;
        // the right-heavy case is the mirror image.
        let dir: AvlDirection;
        if balance < 0 {
            dir = AvlDirection::Left;
        } else {
            dir = AvlDirection::Right;
        }
        let node = unsafe { (*data).get_node() };

        let parent = node.get_parent();
        let dir_inverse = dir.reverse();
        // Halving keeps the sign: for a balance of -2/+2 this is -1/+1, i.e. the
        // balance value meaning "leans towards `dir`".
        let left_heavy = balance >> 1;
        let right_heavy = -left_heavy;

        let child = node.get_child(dir);
        let child_node = unsafe { (*child).get_node() };
        let mut child_balance = child_node.balance;

        // Captured before any link is changed.
        let which_child = self.parent_direction(data, parent);

        // node is overly left heavy, the left child is balanced or also left heavy.
        if child_balance != right_heavy {
            child_balance += right_heavy;

            let c_right = child_node.get_child(dir_inverse);
            self.set_child2(node, dir, c_right, data);
            // move node to be child's right child
            node.balance = -child_balance;

            node.parent = child;
            child_node.set_child(dir_inverse, data);
            // update the pointer into this subtree

            child_node.balance = child_balance;
            if !parent.is_null() {
                child_node.parent = parent;
                unsafe { (*parent).get_node() }.set_child(which_child, child);
            } else {
                child_node.parent = null();
                self.root = child;
            }
            return child_balance == 0;
        }
        // When node is left heavy, but child is right heavy we use
        // a different rotation.

        let g_child = child_node.get_child(dir_inverse);
        let g_child_node = unsafe { (*g_child).get_node() };
        let g_left = g_child_node.get_child(dir);
        let g_right = g_child_node.get_child(dir_inverse);

        // Hand the grandchild's two subtrees over to `node` and `child`.
        self.set_child2(node, dir, g_right, data);
        self.set_child2(child_node, dir_inverse, g_left, child);

        /*
         * move child to left child of gchild and
         * move node to right child of gchild and
         * fixup parent of all this to point to gchild
         */

        // The new balances of `child` and `node` depend on which way the
        // grandchild was leaning before the rotation.
        let g_child_balance = g_child_node.balance;
        if g_child_balance == right_heavy {
            child_node.balance = left_heavy;
        } else {
            child_node.balance = 0;
        }
        child_node.parent = g_child;
        g_child_node.set_child(dir, child);

        if g_child_balance == left_heavy {
            node.balance = right_heavy;
        } else {
            node.balance = 0;
        }
        g_child_node.balance = 0;

        node.parent = g_child;
        g_child_node.set_child(dir_inverse, data);

        if !parent.is_null() {
            g_child_node.parent = parent;
            unsafe { (*parent).get_node() }.set_child(which_child, g_child);
        } else {
            g_child_node.parent = null();
            self.root = g_child;
        }
        return true;
    }
623
624    /*
625    fn replace(&mut self, old: *const P::Target, node: P) {
626        let old_node = unsafe { (*old).get_node() };
627        let new_ptr = node.into_raw();
628        let new_node = unsafe { (*new_ptr).get_node() };
629
630        let left = old_node.get_child(AvlDirection::Left);
631        if !left.is_null() {
632            self.set_child2(new_node, AvlDirection::Left, left, new_ptr);
633        }
634        let right = old_node.get_child(AvlDirection::Right);
635        if !right.is_null() {
636            self.set_child2(new_node, AvlDirection::Right, right, new_ptr);
637        }
638
639        new_node.balance = old_node.balance;
640        old_node.balance = 0;
641        let parent = old_node.get_parent();
642        if !parent.is_null() {
643            let dir = self.parent_direction(old, parent);
644            self.set_child2(unsafe { (*parent).get_node() }, dir, new_ptr, parent);
645            old_node.parent = null();
646        } else {
647            debug_assert_eq!(self.root, old);
648            self.root = new_ptr;
649        }
650    }
651    */
652
    /// Unlinks the node `del` from the tree and rebalances its ancestors.
    ///
    /// Does nothing when the tree is empty. On return the link fields of `del`
    /// have been reset with `detach()`.
    ///
    /// # Safety
    ///
    /// Requires `del` to be a valid pointer to a node in this tree.
    ///
    /// It does not drop the node data, only unlinks it.
    /// Caller is responsible for re-taking ownership (e.g. via from_raw) and dropping if needed.
    ///
    /// # Panics
    ///
    /// Panics if the tree is found in an inconsistent state (a node that is not a
    /// child of its recorded parent, or a null root while the count is non-zero).
    pub unsafe fn remove(&mut self, del: *const P::Target) {
        /*
         * Deletion is easiest with a node that has at most 1 child.
         * We swap a node with 2 children with a sequentially valued
         * neighbor node. That node will have at most 1 child. Note this
         * has no effect on the ordering of the remaining nodes.
         *
         * As an optimization, we choose the greater neighbor if the tree
         * is right heavy, otherwise the left neighbor. This reduces the
         * number of rotations needed.
         */
        if self.count == 0 {
            return;
        }
        // Fast path: `del` is the only node in the tree.
        if self.count == 1 && self.root == del {
            self.root = null();
            self.count = 0;
            unsafe { (*del).get_node().detach() };
            return;
        }
        let mut which_child: AvlDirection;
        let imm_data: *const P::Target;
        let parent: *const P::Target;
        // Obtain the link fields of `del` once and keep using this single reference
        // below, instead of re-deriving it through helpers (see the aliasing notes).
        let del_node = unsafe { (*del).get_node() };

        // True when `del` has two children and must first trade places with a neighbor.
        let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();

        if node_swap_flag {
            let dir: AvlDirection;
            let dir_child_temp: AvlDirection;
            let dir_child_del: AvlDirection;
            let dir_inverse: AvlDirection;
            // Balance -1 or 0 selects the left subtree, +1 the right subtree.
            dir = balance_to_child!(del_node.balance + 1);

            let child_temp = del_node.get_child(dir);

            dir_inverse = dir.reverse();
            // `child` is the neighbor that will take over `del`'s position.
            let child = self.bottom_child_ref(child_temp, dir_inverse);

            // Fix Miri UB: Avoid calling parent_direction2(child) if child's parent is del,
            // because that would create an aliasing &mut ref to del while we hold del_node.
            if child == child_temp {
                dir_child_temp = dir;
            } else {
                dir_child_temp = self.parent_direction2(child);
            }

            // Fix Miri UB: Do not call parent_direction2(del) as it creates a new &mut AvlNode
            // alias while we hold del_node. Use del_node to find parent direction.
            let parent = del_node.get_parent();
            if !parent.is_null() {
                dir_child_del = self.parent_direction(del, parent);
            } else {
                dir_child_del = AvlDirection::Left;
            }

            // Exchange the link fields of the two nodes; the fix-ups below repair
            // every pointer that still refers to the old positions.
            let child_node = unsafe { (*child).get_node() };
            child_node.swap(del_node);

            // move 'node' to delete's spot in the tree
            if child_node.get_child(dir) == child {
                // `child` was a direct child of `del`, so after the swap it points
                // at itself; make it point at `del` instead.
                child_node.set_child(dir, del);
            }

            // Re-parent both children of `child` to `child`.
            let c_dir = child_node.get_child(dir);
            if c_dir == del {
                del_node.parent = child;
            } else if !c_dir.is_null() {
                unsafe { (*c_dir).get_node() }.parent = child;
            }

            let c_inv = child_node.get_child(dir_inverse);
            if c_inv == del {
                del_node.parent = child;
            } else if !c_inv.is_null() {
                unsafe { (*c_inv).get_node() }.parent = child;
            }

            // Point the former parent of `del` (or the root) at `child`.
            let parent = child_node.get_parent();
            if !parent.is_null() {
                unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
            } else {
                self.root = child;
            }

            // Put tmp where node used to be (just temporary).
            // It always has a parent and at most 1 child.
            let parent = del_node.get_parent();
            unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
            if !del_node.right.is_null() {
                which_child = AvlDirection::Right;
            } else {
                which_child = AvlDirection::Left;
            }
            let child = del_node.get_child(which_child);
            if !child.is_null() {
                unsafe { (*child).get_node() }.parent = del;
            }
            // From here on `which_child` is the side of its parent that `del` occupies.
            which_child = dir_child_temp;
        } else {
            // Fix Miri UB here as well
            let parent = del_node.get_parent();
            if !parent.is_null() {
                which_child = self.parent_direction(del, parent);
            } else {
                which_child = AvlDirection::Left;
            }
        }

        // Here we know "delete" is at least partially a leaf node. It can
        // be easily removed from the tree.
        parent = del_node.get_parent();

        // `imm_data` is the only child of `del`, or null when it has none.
        if !del_node.left.is_null() {
            imm_data = del_node.left;
        } else {
            imm_data = del_node.right;
        }

        // Connect parent directly to node (leaving out delete).
        if !imm_data.is_null() {
            let imm_node = unsafe { (*imm_data).get_node() };
            imm_node.parent = parent;
        }

        if !parent.is_null() {
            assert!(self.count > 0);
            self.count -= 1;

            let parent_node = unsafe { (*parent).get_node() };
            parent_node.set_child(which_child, imm_data);

            // Since the subtree is now shorter, begin adjusting parent balances
            // and performing any needed rotations.
            let mut node_data: *const P::Target = parent;
            let mut old_balance: i8;
            let mut new_balance: i8;
            loop {
                // Move up the tree and adjust the balance.
                // Capture the parent and which_child values for the next
                // iteration before any rotations occur.
                let node = unsafe { (*node_data).get_node() };
                old_balance = node.balance;
                new_balance = old_balance - avlchild_to_balance!(which_child);

                // If a node was in perfect balance but isn't anymore then
                // we can stop, since the height didn't change above this point
                // due to a deletion.
                if old_balance == 0 {
                    node.balance = new_balance;
                    break;
                }

                let parent = node.get_parent();
                which_child = self.parent_direction(node_data, parent);

                // If the new balance is zero, we don't need to rotate;
                // otherwise a rotation is needed to fix the balance.
                // If the rotation doesn't change the height
                // of the sub-tree we have finished adjusting.
                if new_balance == 0 {
                    node.balance = new_balance;
                } else if !self.rotate(node_data, new_balance) {
                    break;
                }

                if !parent.is_null() {
                    node_data = parent;
                    continue;
                }
                break;
            }
        } else {
            // `del` was the root: its only child (if any) becomes the new root.
            if !imm_data.is_null() {
                assert!(self.count > 0);
                self.count -= 1;
                self.root = imm_data;
            }
        }
        // Sanity check: an empty root must go together with a zero count.
        if self.root.is_null() {
            if self.count > 0 {
                panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
            }
        }
        del_node.detach();
    }
849
850    /// Removes a node from the tree by key.
851    ///
852    /// The `cmp_func` should compare the key `K` with the elements in the tree.
853    /// Returns `Some(P)` if an exact match was found and removed, `None` otherwise.
854    #[inline]
855    pub fn remove_by_key<K>(&mut self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>) -> Option<P> {
856        let result = self.find(val, cmp_func);
857        self.remove_with(unsafe { result.detach() })
858    }
859
860    /// remove with a previous search result
861    ///
862    /// - If the result is exact match, return the removed element ownership
863    /// - If the result is not exact match, return None
864    ///
865    /// # Safety
866    ///
867    /// Once the tree structure changed, previous search result is not safe to use anymore.
868    ///
869    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
870    /// to avoid the borrowing issue.
871    #[inline]
872    pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
873        if result.is_exact() {
874            unsafe {
875                let p = result.node;
876                self.remove(p);
877                Some(P::from_raw(p))
878            }
879        } else {
880            None
881        }
882    }
883
884    /// Searches for an element in the tree.
885    ///
886    /// The `cmp_func` should compare the key `K` with the elements in the tree.
887    /// Returns an [`AvlSearchResult`] which indicates if an exact match was found,
888    /// or where a new element should be inserted.
889    #[inline]
890    pub fn find<'a, K>(
891        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
892    ) -> AvlSearchResult<'a, P> {
893        if self.root.is_null() {
894            return AvlSearchResult::default();
895        }
896        let mut node_data = self.root;
897        loop {
898            let diff = cmp_func(val, unsafe { &*node_data });
899            match diff {
900                Ordering::Equal => {
901                    return AvlSearchResult {
902                        node: node_data,
903                        direction: None,
904                        _phan: PhantomData,
905                    };
906                }
907                Ordering::Less => {
908                    let node = unsafe { (*node_data).get_node() };
909                    let left = node.get_child(AvlDirection::Left);
910                    if left.is_null() {
911                        return AvlSearchResult {
912                            node: node_data,
913                            direction: Some(AvlDirection::Left),
914                            _phan: PhantomData,
915                        };
916                    }
917                    node_data = left;
918                }
919                Ordering::Greater => {
920                    let node = unsafe { (*node_data).get_node() };
921                    let right = node.get_child(AvlDirection::Right);
922                    if right.is_null() {
923                        return AvlSearchResult {
924                            node: node_data,
925                            direction: Some(AvlDirection::Right),
926                            _phan: PhantomData,
927                        };
928                    }
929                    node_data = right;
930                }
931            }
932        }
933    }
934
935    // for range tree, val may overlap multiple range(node), ensure return the smallest
936    #[inline]
937    pub fn find_contained<'a, K>(
938        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
939    ) -> Option<&'a P::Target> {
940        if self.root.is_null() {
941            return None;
942        }
943        let mut node_data = self.root;
944        let mut result_node: *const P::Target = null();
945        loop {
946            let diff = cmp_func(val, unsafe { &*node_data });
947            match diff {
948                Ordering::Equal => {
949                    let node = unsafe { (*node_data).get_node() };
950                    let left = node.get_child(AvlDirection::Left);
951                    result_node = node_data;
952                    if left.is_null() {
953                        break;
954                    } else {
955                        node_data = left;
956                    }
957                }
958                Ordering::Less => {
959                    let node = unsafe { (*node_data).get_node() };
960                    let left = node.get_child(AvlDirection::Left);
961                    if left.is_null() {
962                        break;
963                    }
964                    node_data = left;
965                }
966                Ordering::Greater => {
967                    let node = unsafe { (*node_data).get_node() };
968                    let right = node.get_child(AvlDirection::Right);
969                    if right.is_null() {
970                        break;
971                    }
972                    node_data = right;
973                }
974            }
975        }
976        if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
977    }
978
979    // for slab, return any block larger or equal than search param
980    #[inline]
981    pub fn find_larger_eq<'a, K>(
982        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
983    ) -> AvlSearchResult<'a, P> {
984        if self.root.is_null() {
985            return AvlSearchResult::default();
986        }
987        let mut node_data = self.root;
988        loop {
989            let diff = cmp_func(val, unsafe { &*node_data });
990            match diff {
991                Ordering::Equal => {
992                    return AvlSearchResult {
993                        node: node_data,
994                        direction: None,
995                        _phan: PhantomData,
996                    };
997                }
998                Ordering::Less => {
999                    return AvlSearchResult {
1000                        node: node_data,
1001                        direction: None,
1002                        _phan: PhantomData,
1003                    };
1004                }
1005                Ordering::Greater => {
1006                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1007                    if right.is_null() {
1008                        return AvlSearchResult {
1009                            node: null(),
1010                            direction: None,
1011                            _phan: PhantomData,
1012                        };
1013                    }
1014                    node_data = right;
1015                }
1016            }
1017        }
1018    }
1019
1020    /// For range tree
1021    #[inline]
1022    pub fn find_nearest<'a, K>(
1023        &'a self, val: &K, cmp_func: AvlCmpFunc<K, P::Target>,
1024    ) -> AvlSearchResult<'a, P> {
1025        if self.root.is_null() {
1026            return AvlSearchResult::default();
1027        }
1028
1029        let mut node_data = self.root;
1030        let mut nearest_node = null();
1031        loop {
1032            let diff = cmp_func(val, unsafe { &*node_data });
1033            match diff {
1034                Ordering::Equal => {
1035                    return AvlSearchResult {
1036                        node: node_data,
1037                        direction: None,
1038                        _phan: PhantomData,
1039                    };
1040                }
1041                Ordering::Less => {
1042                    nearest_node = node_data;
1043                    let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1044                    if left.is_null() {
1045                        break;
1046                    }
1047                    node_data = left;
1048                }
1049                Ordering::Greater => {
1050                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1051                    if right.is_null() {
1052                        break;
1053                    }
1054                    node_data = right;
1055                }
1056            }
1057        }
1058        return AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData };
1059    }
1060
1061    #[inline(always)]
1062    fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1063        loop {
1064            let child = unsafe { (*data).get_node() }.get_child(dir);
1065            if !child.is_null() {
1066                data = child;
1067            } else {
1068                return data;
1069            }
1070        }
1071    }
1072
1073    pub fn walk<F: Fn(&P::Target)>(&self, cb: F) {
1074        let mut node = self.first();
1075        while let Some(n) = node {
1076            cb(n);
1077            node = self.next(n);
1078        }
1079    }
1080
1081    #[inline]
1082    pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1083        if let Some(p) = self.walk_dir(data, AvlDirection::Right) {
1084            Some(unsafe { p.as_ref() })
1085        } else {
1086            None
1087        }
1088    }
1089
1090    #[inline]
1091    pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1092        if let Some(p) = self.walk_dir(data, AvlDirection::Left) {
1093            Some(unsafe { p.as_ref() })
1094        } else {
1095            None
1096        }
1097    }
1098
1099    #[inline]
1100    fn walk_dir(
1101        &self, mut data_ptr: *const P::Target, dir: AvlDirection,
1102    ) -> Option<NonNull<P::Target>> {
1103        let dir_inverse = dir.reverse();
1104        let node = unsafe { (*data_ptr).get_node() };
1105        let temp = node.get_child(dir);
1106        if !temp.is_null() {
1107            unsafe {
1108                Some(NonNull::new_unchecked(
1109                    self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1110                ))
1111            }
1112        } else {
1113            let mut parent = node.parent;
1114            if parent.is_null() {
1115                return None;
1116            }
1117            loop {
1118                let pdir = self.parent_direction(data_ptr, parent);
1119                if pdir == dir_inverse {
1120                    return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1121                }
1122                data_ptr = parent;
1123                parent = unsafe { (*parent).get_node() }.parent;
1124                if parent.is_null() {
1125                    return None;
1126                }
1127            }
1128        }
1129    }
1130
1131    #[inline]
1132    fn validate_node(&self, data: *const P::Target, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1133        let node = unsafe { (*data).get_node() };
1134        let left = node.left;
1135        if !left.is_null() {
1136            assert!(cmp_func(unsafe { &*left }, unsafe { &*data }) != Ordering::Greater);
1137            assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1138        }
1139        let right = node.right;
1140        if !right.is_null() {
1141            assert!(cmp_func(unsafe { &*right }, unsafe { &*data }) != Ordering::Less);
1142            assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1143        }
1144    }
1145
1146    #[inline]
1147    pub fn nearest<'a>(
1148        &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1149    ) -> AvlSearchResult<'a, P> {
1150        if !current.node.is_null() {
1151            if current.direction.is_some() && current.direction != Some(direction) {
1152                return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1153            }
1154            if let Some(node) = self.walk_dir(current.node, direction) {
1155                return AvlSearchResult {
1156                    node: node.as_ptr(),
1157                    direction: None,
1158                    _phan: PhantomData,
1159                };
1160            }
1161        }
1162        return AvlSearchResult::default();
1163    }
1164
1165    pub fn validate(&self, cmp_func: AvlCmpFunc<P::Target, P::Target>) {
1166        let c = {
1167            #[cfg(feature = "std")]
1168            {
1169                ((self.get_count() + 10) as f32).log2() as usize
1170            }
1171            #[cfg(not(feature = "std"))]
1172            {
1173                100
1174            }
1175        };
1176        let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1177        if self.root.is_null() {
1178            assert_eq!(self.count, 0);
1179            return;
1180        }
1181        let mut data = self.root;
1182        let mut visited = 0;
1183        loop {
1184            if !data.is_null() {
1185                let left = {
1186                    let node = unsafe { (*data).get_node() };
1187                    node.get_child(AvlDirection::Left)
1188                };
1189                if !left.is_null() {
1190                    stack.push(data);
1191                    data = left;
1192                    continue;
1193                }
1194                visited += 1;
1195                self.validate_node(data, cmp_func);
1196                data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1197            } else if stack.len() > 0 {
1198                let _data = stack.pop().unwrap();
1199                self.validate_node(_data, cmp_func);
1200                visited += 1;
1201                let node = unsafe { (*_data).get_node() };
1202                data = node.get_child(AvlDirection::Right);
1203            } else {
1204                break;
1205            }
1206        }
1207        assert_eq!(visited, self.count);
1208    }
1209
1210    /// Adds a new element to the tree, takes the ownership of P.
1211    ///
1212    /// The `cmp_func` should compare two elements to determine their relative order.
1213    /// Returns `true` if the element was added, `false` if an equivalent element
1214    /// already exists (in which case the provided `node` is dropped).
1215    #[inline]
1216    pub fn add(&mut self, node: P, cmp_func: AvlCmpFunc<P::Target, P::Target>) -> bool {
1217        if self.count == 0 && self.root.is_null() {
1218            self.root = node.into_raw();
1219            self.count = 1;
1220            return true;
1221        }
1222
1223        let w = self.find(node.as_ref(), cmp_func);
1224        if w.direction.is_none() {
1225            // To prevent memory leak, we must drop the node.
1226            // But since we took ownership, we have to convert it back to P and drop it.
1227            drop(node);
1228            return false;
1229        }
1230
1231        // Safety: We need to decouple the lifetime of 'w' from 'self' to call 'insert'.
1232        // We extract the pointers and reconstruct the result.
1233        let w_node = w.node;
1234        let w_dir = w.direction;
1235        // Drop w explicitly to release the borrow on self (though not strictly needed with NLL, being explicit helps)
1236        drop(w);
1237
1238        let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
1239
1240        self.insert(node, w_detached);
1241        return true;
1242    }
1243}
1244
1245impl<P, Tag> Drop for AvlTree<P, Tag>
1246where
1247    P: Pointer,
1248    P::Target: AvlItem<Tag>,
1249{
1250    fn drop(&mut self) {
1251        if mem::needs_drop::<P>() {
1252            for _ in self.drain() {}
1253        }
1254    }
1255}
1256
/// Iterator state for draining an [`AvlTree`]: each call to `next` unlinks one element
/// from the tree and hands its ownership (`P`) to the caller.
pub struct AvlDrain<'a, P: Pointer, Tag>
where
    P::Target: AvlItem<Tag>,
{
    /// The tree being emptied; its `root` and `count` are updated as elements are yielded.
    tree: &'a mut AvlTree<P, Tag>,
    /// Parent of the element yielded by the previous call (null before the first call
    /// and after the root has been yielded).
    parent: *const P::Target,
    /// Which child of `parent` the previously yielded element was; `None` until the
    /// first element has been yielded.
    dir: Option<AvlDirection>,
}
1265
/// Yields the elements of the tree one at a time, children before their parent, unlinking
/// each one from the tree as it is returned.
impl<'a, P: Pointer, Tag> Iterator for AvlDrain<'a, P, Tag>
where
    P::Target: AvlItem<Tag>,
{
    type Item = P;

    /// Unlinks and returns the next element, or `None` once the root has been yielded
    /// (the tree's `root` is null from then on).
    fn next(&mut self) -> Option<Self::Item> {
        if self.tree.root.is_null() {
            return None;
        }

        let mut node: *const P::Target;
        let parent: *const P::Target;

        if self.dir.is_none() && self.parent.is_null() {
            // Initial call: find the leftmost node
            let mut curr = self.tree.root;
            while unsafe { !(*curr).get_node().left.is_null() } {
                curr = unsafe { (*curr).get_node().left };
            }
            node = curr;
        } else {
            parent = self.parent;
            if parent.is_null() {
                // Should not happen: a null `parent` with `dir` set is only stored together
                // with a null root, which is handled at the top of this function.
                return None;
            }

            let child_dir = self.dir.unwrap();
            // The `child_dir` slot of `parent` was already cleared at the end of the call
            // that yielded that child (see the unlink step below).

            // If the right child was just yielded, or there is no right child at all,
            // `parent` has no children left and is the next element to yield.
            if child_dir == AvlDirection::Right || unsafe { (*parent).get_node().right.is_null() } {
                node = parent;
            } else {
                // Finished left, go to right sibling and descend to its leftmost node
                node = unsafe { (*parent).get_node().right };
                while unsafe { !(*node).get_node().left.is_null() } {
                    node = unsafe { (*node).get_node().left };
                }
            }
        }

        // Check the right side of the chosen node (its left side is known to be empty here)
        if unsafe { !(*node).get_node().right.is_null() } {
            // It has a right child, so we must yield that first (in post-order)
            // Note: in AVL, if left is null, right must be a leaf.
            node = unsafe { (*node).get_node().right };
        }

        // Determine next state
        let next_parent = unsafe { (*node).get_node().parent };
        if next_parent.is_null() {
            // `node` has no parent, i.e. it is the root: the tree is now empty.
            self.tree.root = null();
            self.parent = null();
            self.dir = Some(AvlDirection::Left);
        } else {
            self.parent = next_parent;
            self.dir = Some(self.tree.parent_direction(node, next_parent));
            // Unlink from parent NOW, so the next call sees this slot as empty
            unsafe { (*next_parent).get_node().set_child(self.dir.unwrap(), null()) };
        }

        self.tree.count -= 1;
        unsafe {
            // Reset the element's embedded node, then rebuild the owning pointer.
            (*node).get_node().detach();
            Some(P::from_raw(node))
        }
    }
}
1336
1337impl<T, Tag> AvlTree<Arc<T>, Tag>
1338where
1339    T: AvlItem<Tag>,
1340{
1341    pub fn remove_ref(&mut self, node: &Arc<T>) {
1342        let p = Arc::as_ptr(node);
1343        unsafe { self.remove(p) };
1344        unsafe { drop(Arc::from_raw(p)) };
1345    }
1346}
1347
1348impl<T, Tag> AvlTree<Rc<T>, Tag>
1349where
1350    T: AvlItem<Tag>,
1351{
1352    pub fn remove_ref(&mut self, node: &Rc<T>) {
1353        let p = Rc::as_ptr(node);
1354        unsafe { self.remove(p) };
1355        unsafe { drop(Rc::from_raw(p)) };
1356    }
1357}
1358
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::UnsafeCell;
    use rand::Rng;
    use std::time::Instant;

    /// Test element: an `i64` payload plus the embedded AVL link node.
    struct IntAvlNode {
        pub value: i64,
        pub node: UnsafeCell<AvlNode<Self, ()>>,
    }

    impl fmt::Debug for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{} {:#?}", self.value, self.node)
        }
    }

    impl fmt::Display for IntAvlNode {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{}", self.value)
        }
    }

    unsafe impl AvlItem<()> for IntAvlNode {
        fn get_node(&self) -> &mut AvlNode<Self, ()> {
            unsafe { &mut *self.node.get() }
        }
    }

    type IntAvlTree = AvlTree<Box<IntAvlNode>, ()>;

    /// Allocates a boxed test element holding `i` with a default (unlinked) AVL node.
    fn new_intnode(i: i64) -> Box<IntAvlNode> {
        Box::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: i })
    }

    /// Creates an empty tree of boxed test elements.
    fn new_inttree() -> IntAvlTree {
        AvlTree::<Box<IntAvlNode>, ()>::new()
    }

    /// Element-vs-element comparison by `value`.
    fn cmp_int_node(a: &IntAvlNode, b: &IntAvlNode) -> Ordering {
        a.value.cmp(&b.value)
    }

    /// Key-vs-element comparison by `value`.
    fn cmp_int(a: &i64, b: &IntAvlNode) -> Ordering {
        a.cmp(&b.value)
    }

    // Convenience wrappers so the tests can work with plain integers.
    impl AvlTree<Box<IntAvlNode>, ()> {
        /// Removes the element with value `i`; returns whether it was present.
        fn remove_int(&mut self, i: i64) -> bool {
            if let Some(_node) = self.remove_by_key(&i, cmp_int) {
                // node is Box<IntAvlNode>, dropped automatically
                return true;
            }
            // else
            println!("not found {}", i);
            false
        }

        /// Adds `node`, ordering by `value`.
        fn add_int_node(&mut self, node: Box<IntAvlNode>) -> bool {
            self.add(node, cmp_int_node)
        }

        /// Runs the tree's invariant checks with the element comparator.
        fn validate_tree(&self) {
            self.validate(cmp_int_node);
        }

        /// Searches by integer key.
        fn find_int<'a>(&'a self, i: i64) -> AvlSearchResult<'a, Box<IntAvlNode>> {
            self.find(&i, cmp_int)
        }

        /// Searches using another element as the key.
        fn find_node<'a>(&'a self, node: &'a IntAvlNode) -> AvlSearchResult<'a, Box<IntAvlNode>> {
            self.find(node, cmp_int_node)
        }
    }

    /// Links three standalone nodes by hand and checks the reported parent directions.
    #[test]
    fn int_avl_node() {
        let mut tree = new_inttree();

        assert_eq!(tree.get_count(), 0);
        assert!(tree.first().is_none());
        assert!(tree.last().is_none());

        let node1 = new_intnode(1);
        let node2 = new_intnode(2);
        let node3 = new_intnode(3);

        let p1 = &*node1 as *const IntAvlNode;
        let p2 = &*node2 as *const IntAvlNode;
        let p3 = &*node3 as *const IntAvlNode;

        tree.set_child2(node1.get_node(), AvlDirection::Left, p2, p1);
        tree.set_child2(node2.get_node(), AvlDirection::Right, p3, p2);

        assert_eq!(tree.parent_direction2(p2), AvlDirection::Left);
        // This is tricky as node1 is not in a tree, its parent is not set.
        // assert_eq!(tree.parent_direction2(p1), AvlDirection::Left);
        assert_eq!(tree.parent_direction2(p3), AvlDirection::Right);
    }

    /// Exercises `find`, `nearest` and `find_larger_eq` on an empty, one- and two-element tree.
    #[test]
    fn int_avl_tree_basic() {
        let mut tree = new_inttree();

        // Empty tree: nothing is found and there are no neighbours.
        let temp_node = new_intnode(0);
        let temp_node_val = Pointer::as_ref(&temp_node);
        assert!(tree.find_node(temp_node_val).get_node_ref().is_none());
        assert_eq!(
            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Left).is_exact(),
            false
        );
        assert_eq!(
            tree.nearest(&tree.find_node(temp_node_val), AvlDirection::Right).is_exact(),
            false
        );
        drop(temp_node);

        // Single element 0: found exactly, but it has no neighbours.
        tree.add_int_node(new_intnode(0));
        let result = tree.find_int(0);
        assert!(result.get_node_ref().is_some());
        assert_eq!(tree.nearest(&result, AvlDirection::Left).is_exact(), false);
        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

        let rs = tree.find_larger_eq(&0, cmp_int).get_node_ref();
        assert!(rs.is_some());
        let found_value = rs.unwrap().value;
        assert_eq!(found_value, 0);

        let rs = tree.find_larger_eq(&2, cmp_int).get_node_ref();
        assert!(rs.is_none());

        // Key 1 is absent: its left neighbour is 0, there is no right neighbour.
        let result = tree.find_int(1);
        let left = tree.nearest(&result, AvlDirection::Left);
        assert_eq!(left.is_exact(), true);
        assert_eq!(left.get_nearest().unwrap().value, 0);
        assert_eq!(tree.nearest(&result, AvlDirection::Right).is_exact(), false);

        tree.add_int_node(new_intnode(2));
        let rs = tree.find_larger_eq(&1, cmp_int).get_node_ref();
        assert!(rs.is_some());
        let found_value = rs.unwrap().value;
        assert_eq!(found_value, 2);
    }

    /// Inserts an ascending sequence, checks in-order iteration, then removes everything.
    #[test]
    fn int_avl_tree_order() {
        // Smaller workload under miri.
        let max;
        #[cfg(miri)]
        {
            max = 2000;
        }
        #[cfg(not(miri))]
        {
            max = 200000;
        }
        let mut tree = new_inttree();
        assert!(tree.first().is_none());
        let start_ts = Instant::now();
        for i in 0..max {
            tree.add_int_node(new_intnode(i));
        }
        tree.validate_tree();
        assert_eq!(tree.get_count(), max as i64);

        // Walk from first to last and expect the values 0, 1, 2, ... in order.
        let mut count = 0;
        let mut current = tree.first();
        let last = tree.last();
        while let Some(c) = current {
            assert_eq!(c.value, count);
            count += 1;
            if c as *const _ == last.map(|n| n as *const _).unwrap_or(null()) {
                current = None;
            } else {
                current = tree.next(c);
            }
        }
        assert_eq!(count, max);

        {
            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
            assert!(rs.is_some());
            let found_value = rs.unwrap().value;
            println!("found larger_eq {}", found_value);
            assert!(found_value >= 5);
            tree.remove_int(5);
            let rs = tree.find_larger_eq(&5, cmp_int).get_node_ref();
            assert!(rs.is_some());
            assert!(rs.unwrap().value >= 6);
            tree.add_int_node(new_intnode(5));
        }

        for i in 0..max {
            assert!(tree.remove_int(i));
        }
        assert_eq!(tree.get_count(), 0);

        let end_ts = Instant::now();
        println!("duration {}", end_ts.duration_since(start_ts).as_secs_f64());
    }

    /// Fixed three-value sequence: add, find and remove each value.
    #[test]
    fn int_avl_tree_fixed1() {
        let mut tree = new_inttree();
        let arr = [4719789032060327248, 7936680652950253153, 5197008094511783121];
        for i in arr.iter() {
            let node = new_intnode(*i);
            tree.add_int_node(node);
            let rs = tree.find_int(*i);
            assert!(rs.get_node_ref().is_some(), "add error {}", i);
        }
        assert_eq!(tree.get_count(), arr.len() as i64);
        for i in arr.iter() {
            assert!(tree.remove_int(*i));
        }
        assert_eq!(tree.get_count(), 0);
    }

    /// Fixed add/remove interleaving, validating the tree after every step.
    #[test]
    fn int_avl_tree_fixed2() {
        let mut tree = new_inttree();
        tree.validate_tree();
        let node1 = new_intnode(536872960);
        {
            tree.add_int_node(node1);
            tree.validate_tree();
            tree.remove_int(536872960);
            tree.validate_tree();
            tree.add_int_node(new_intnode(536872960));
            tree.validate_tree();
        }

        assert!(tree.find_int(536872960).get_node_ref().is_some());
        let node2 = new_intnode(12288);
        tree.add_int_node(node2);
        tree.validate_tree();
        tree.remove_int(536872960);
        tree.validate_tree();
        tree.add_int_node(new_intnode(536872960));
        tree.validate_tree();
        let node3 = new_intnode(22528);
        tree.add_int_node(node3);
        tree.validate_tree();
        tree.remove_int(12288);
        assert!(tree.find_int(12288).get_node_ref().is_none());
        tree.validate_tree();
        tree.remove_int(22528);
        assert!(tree.find_int(22528).get_node_ref().is_none());
        tree.validate_tree();
        tree.add_int_node(new_intnode(22528));
        tree.validate_tree();
    }

    /// Random distinct values: checks `prev`/`next` against the sorted list, then removes all.
    #[test]
    fn int_avl_tree_random() {
        let count = 1000;
        let mut test_list: Vec<i64> = Vec::with_capacity(count);
        let mut rng = rand::thread_rng();
        let mut tree = new_inttree();
        tree.validate_tree();
        for _ in 0..count {
            let node_value: i64 = rng.r#gen();
            // Skip duplicates so every add is expected to succeed.
            if !test_list.contains(&node_value) {
                test_list.push(node_value);
                assert!(tree.add_int_node(new_intnode(node_value)))
            }
        }
        tree.validate_tree();
        test_list.sort();
        for index in 0..test_list.len() {
            let node_ptr = tree.find_int(test_list[index]).get_node_ref().unwrap();
            let prev = tree.prev(node_ptr);
            let next = tree.next(node_ptr);
            if index == 0 {
                // first node
                assert!(prev.is_none());
                assert!(next.is_some());
                assert_eq!(next.unwrap().value, test_list[index + 1]);
            } else if index == test_list.len() - 1 {
                // last node
                assert!(prev.is_some());
                assert_eq!(prev.unwrap().value, test_list[index - 1]);
                assert!(next.is_none());
            } else {
                // middle node
                assert!(prev.is_some());
                assert_eq!(prev.unwrap().value, test_list[index - 1]);
                assert!(next.is_some());
                assert_eq!(next.unwrap().value, test_list[index + 1]);
            }
        }
        for index in 0..test_list.len() {
            assert!(tree.remove_int(test_list[index]));
        }
        tree.validate_tree();
        assert_eq!(0, tree.get_count());
    }

    /// Inserts elements directly before/after an existing element via `insert_here`.
    #[test]
    fn int_avl_tree_insert_here() {
        let mut tree = new_inttree();
        let node1 = new_intnode(10);
        tree.add_int_node(node1);
        // Insert 5 before 10
        let rs = tree.find_int(10);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(5), here, AvlDirection::Left) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 2);
        assert_eq!(tree.find_int(5).get_node_ref().unwrap().value, 5);

        // Insert 15 after 10
        let rs = tree.find_int(10);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(15), here, AvlDirection::Right) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 3);
        assert_eq!(tree.find_int(15).get_node_ref().unwrap().value, 15);

        // Insert 3 before 5 (which is left child of 10)
        let rs = tree.find_int(5);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(3), here, AvlDirection::Left) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 4);

        // Insert 7 after 5
        let rs = tree.find_int(5);
        let here = unsafe { rs.detach() };
        unsafe { tree.insert_here(new_intnode(7), here, AvlDirection::Right) };
        tree.validate_tree();
        assert_eq!(tree.get_count(), 5);
    }

    /// `get_exact` on an `Arc` tree hands out a new strong reference to the stored element.
    #[test]
    fn test_arc_avl_tree_get_exact() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        // Manually constructing Arc node
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 100 });
        tree.add(node.clone(), cmp_int_node);

        // find returns AvlSearchResult<'a, Arc<IntAvlNode>>
        let result_search = tree.find(&100, cmp_int);

        // This should invoke the specialized get_exact for Arc<T>
        let exact = result_search.get_exact();
        assert!(exact.is_some());
        let exact_arc = exact.unwrap();
        assert_eq!(exact_arc.value, 100);
        assert!(Arc::ptr_eq(&node, &exact_arc));
        // Check ref count: 1 (original) + 1 (in tree) + 1 (exact_arc) = 3
        assert_eq!(Arc::strong_count(&node), 3);
    }

    /// `remove_ref` unlinks the element and releases the tree's strong reference.
    #[test]
    fn test_arc_avl_tree_remove_ref() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 200 });
        tree.add(node.clone(), cmp_int_node);
        assert_eq!(tree.get_count(), 1);
        assert_eq!(Arc::strong_count(&node), 2);

        tree.remove_ref(&node);
        assert_eq!(tree.get_count(), 0);
        assert_eq!(Arc::strong_count(&node), 1);
    }

    /// `remove_by_key` on an `Arc` tree transfers the tree's strong reference to the caller.
    #[test]
    fn test_arc_avl_tree_remove_with() {
        let mut tree = AvlTree::<Arc<IntAvlNode>, ()>::new();
        let node = Arc::new(IntAvlNode { node: UnsafeCell::new(AvlNode::default()), value: 300 });
        tree.add(node.clone(), cmp_int_node);

        let removed = tree.remove_by_key(&300, cmp_int);
        assert!(removed.is_some());
        let removed_arc = removed.unwrap();
        assert_eq!(removed_arc.value, 300);
        assert_eq!(tree.get_count(), 0);
        // count: 1 (node) + 1 (removed_arc) = 2. Tree dropped its count.
        assert_eq!(Arc::strong_count(&node), 2);

        drop(removed_arc);
        assert_eq!(Arc::strong_count(&node), 1);
    }

    /// Draining yields every element exactly once and leaves the tree empty.
    #[test]
    fn test_avl_drain() {
        let mut tree = new_inttree();
        for i in 0..100 {
            tree.add_int_node(new_intnode(i));
        }
        assert_eq!(tree.get_count(), 100);

        let mut count = 0;
        for node in tree.drain() {
            assert!(node.value >= 0 && node.value < 100);
            count += 1;
        }
        assert_eq!(count, 100);
        assert_eq!(tree.get_count(), 0);
        assert!(tree.first().is_none());
    }
}