Skip to main content

embed_collections/avl/
mod.rs

1//! An intrusive AVL tree implementation.
2//!
3//! The algorithm origin from open-zfs
4//!
5//! The ordering of avl tree is defined by [AvlItem] trait. There're two scenario:
6//! - borrow field as key from the node struct
7//! - impl Ord for struct itself
8//!
9//! Refer to the example of [AvlItem] document.
10//!
11//! # Examples
12//!
13//! ## Using `Box` to search and remove by key
14//!
15//! ```rust
16//! use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
17//! use core::cell::UnsafeCell;
18//! use core::cmp::Ordering;
19//! extern crate alloc;
20//! use alloc::sync::Arc;
21//!
22//! struct MyNode {
23//!     value: i32,
24//!     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
25//! }
26//!
27//! unsafe impl AvlItem<()> for MyNode {
28//!
29//!     type Key = i32;
30//!
31//!     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
32//!         unsafe { &mut *self.avl_node.get() }
33//!     }
34//!
35//!     fn borrow_key(&self) -> &Self::Key {
36//!         &self.value
37//!     }
38//! }
39//!
40//! // A boxed example
41//!
42//! let mut tree = AvlTree::<Box<MyNode>, ()>::new();
43//! tree.add(Box::new(MyNode { value: 10, avl_node: UnsafeCell::new(Default::default()) }));
44//!
45//! // Search and remove
46//! if let Some(node) = tree.remove_by_key(&10) {
47//!     assert_eq!(node.value, 10);
48//! }
49//! assert!(tree.is_empty());
50//!
51//! // Using `Arc` for multiple ownership
52//! // remove_ref only available to `Arc` and `Rc`
53//!
54//! let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
55//! let node = Arc::new(MyNode { value: 42, avl_node: UnsafeCell::new(Default::default()) });
56//!
57//! tree.add(node.clone());
58//! assert_eq!(tree.len(), 1);
59//!
60//! // Remove by reference (detach from avl tree)
61//! tree.remove_ref(&node);
62//! assert_eq!(tree.len(), 0);
63//! ```
64
65mod iter;
66pub use iter::*;
67#[cfg(test)]
68mod tests;
69
70use crate::Pointer;
71use alloc::rc::Rc;
72use alloc::sync::Arc;
73use alloc::vec::Vec;
74use core::marker::PhantomData;
75use core::{
76    cmp::{Ordering, PartialEq},
77    fmt, mem,
78    ptr::{NonNull, null},
79};
80
81/// A trait to return internal mutable AvlNode for specified list.
82///
83/// The tag is used to distinguish different AvlNodes within the same item,
84/// allowing an item to belong to multiple lists simultaneously.
85/// For only one ownership, you can use `()`.
86///
87/// # Safety
88///
89/// Implementors must ensure `get_node` returns a valid reference to the `AvlNode`
90/// embedded within `Self`. Users must use `UnsafeCell` to hold `AvlNode` to support
91/// interior mutability required by list operations.
92///
93/// # Example (struct field as key)
94///
95/// ```
96/// use embed_collections::avl::{AvlItem, AvlNode};
97/// use core::cell::UnsafeCell;
98///
99/// pub struct IntAvlNode {
100///     pub value: i64,
101///     pub node: UnsafeCell<AvlNode<Self, ()>>,
102/// }
103/// unsafe impl AvlItem<()> for IntAvlNode {
104///     type Key = i64;
105///
106///     fn get_node(&self) -> &mut AvlNode<Self, ()> {
107///         unsafe { &mut *self.node.get() }
108///     }
109///
110///     fn borrow_key(&self) -> &Self::Key {
111///         &self.value
112///     }
113/// }
114/// ```
115///
116/// # Example (impl Ord for struct)
117///
118/// ```
119/// use embed_collections::avl::{AvlItem, AvlNode};
120/// use core::cell::{Cell, UnsafeCell};
121/// use core::cmp::Ordering;
122///
123/// pub struct RangeSeg {
124///     node: UnsafeCell<AvlNode<Self, ()>>,
125///     pub start: Cell<u64>,
126///     pub end: Cell<u64>,
127/// }
128///
129/// impl PartialOrd for RangeSeg {
130///     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
131///         Some(Ord::cmp(self, other))
132///     }
133/// }
134/// impl PartialEq for RangeSeg {
135///     fn eq(&self, other: &Self) -> bool {
136///         Ord::cmp(self, other) == Ordering::Equal
137///     }
138/// }
139/// impl Eq for RangeSeg {}
140///
141/// impl Ord for RangeSeg {
142///     fn cmp(&self, other: &Self) -> Ordering {
143///         if self.end.get() <= other.start.get() {
144///             Ordering::Less
145///         } else if self.start.get() >= other.end.get() {
146///             Ordering::Greater
147///         } else {
148///             // intersection
149///             Ordering::Equal
150///         }
151///     }
152/// }
153/// unsafe impl Send for RangeSeg {}
154/// unsafe impl AvlItem<()> for RangeSeg {
155///     type Key = Self;
156///
157///     fn get_node(&self) -> &mut AvlNode<Self, ()> {
158///         unsafe { &mut *self.node.get() }
159///     }
160///
161///     fn borrow_key(&self) -> &Self::Key {
162///         self
163///     }
164/// }
165/// ```
166pub unsafe trait AvlItem<Tag>: Sized {
167    type Key: Ord;
168
169    fn get_node(&self) -> &mut AvlNode<Self, Tag>;
170
171    #[inline(always)]
172    fn cmp(&self, other: &Self) -> Ordering {
173        self.borrow_key().cmp(other.borrow_key())
174    }
175
176    #[inline(always)]
177    fn cmp_key(&self, key: &Self::Key) -> Ordering {
178        key.cmp(self.borrow_key())
179    }
180
181    fn borrow_key(&self) -> &Self::Key;
182}
183
184#[derive(PartialEq, Debug, Copy, Clone)]
185pub enum AvlDirection {
186    Left = 0,
187    Right = 1,
188}
189
190impl AvlDirection {
191    #[inline(always)]
192    fn reverse(self) -> AvlDirection {
193        match self {
194            AvlDirection::Left => AvlDirection::Right,
195            AvlDirection::Right => AvlDirection::Left,
196        }
197    }
198}
199
200macro_rules! avlchild_to_balance {
201    ( $dir: expr ) => {
202        match $dir {
203            AvlDirection::Left => -1,
204            AvlDirection::Right => 1,
205        }
206    };
207}
208
209macro_rules! as_avlitem {
210    ($P: tt, $Tag: tt, $name: ident) => {
211        <$P::Target as AvlItem<$Tag>>::$name
212    };
213}
214
215pub struct AvlNode<T: Sized, Tag> {
216    pub left: *const T,
217    pub right: *const T,
218    pub parent: *const T,
219    pub balance: i8,
220    _phan: PhantomData<fn(&Tag)>,
221}
222
223unsafe impl<T, Tag> Send for AvlNode<T, Tag> {}
224
225impl<T: AvlItem<Tag>, Tag> AvlNode<T, Tag> {
226    #[inline(always)]
227    pub fn detach(&mut self) {
228        self.left = null();
229        self.right = null();
230        self.parent = null();
231        self.balance = 0;
232    }
233
234    #[inline(always)]
235    fn get_child(&self, dir: AvlDirection) -> *const T {
236        match dir {
237            AvlDirection::Left => self.left,
238            AvlDirection::Right => self.right,
239        }
240    }
241
242    #[inline(always)]
243    fn set_child(&mut self, dir: AvlDirection, child: *const T) {
244        match dir {
245            AvlDirection::Left => self.left = child,
246            AvlDirection::Right => self.right = child,
247        }
248    }
249
250    #[inline(always)]
251    fn get_parent(&self) -> *const T {
252        self.parent
253    }
254
255    // Swap two node but not there value
256    #[inline(always)]
257    pub fn swap(&mut self, other: &mut AvlNode<T, Tag>) {
258        mem::swap(&mut self.left, &mut other.left);
259        mem::swap(&mut self.right, &mut other.right);
260        mem::swap(&mut self.parent, &mut other.parent);
261        mem::swap(&mut self.balance, &mut other.balance);
262    }
263}
264
265impl<T, Tag> Default for AvlNode<T, Tag> {
266    fn default() -> Self {
267        Self { left: null(), right: null(), parent: null(), balance: 0, _phan: Default::default() }
268    }
269}
270
271#[allow(unused_must_use)]
272impl<T: AvlItem<Tag>, Tag> fmt::Debug for AvlNode<T, Tag> {
273    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
274        write!(f, "(")?;
275
276        if !self.left.is_null() {
277            write!(f, "left: {:p}", self.left)?;
278        } else {
279            write!(f, "left: none ")?;
280        }
281
282        if !self.right.is_null() {
283            write!(f, "right: {:p}", self.right)?;
284        } else {
285            write!(f, "right: none ")?;
286        }
287        write!(f, ")")
288    }
289}
290
291pub type AvlCmpFunc<K, T> = fn(&K, &T) -> Ordering;
292
293/// An intrusive AVL tree (balanced binary search tree).
294///
295/// Elements in the tree must implement the [`AvlItem`] trait.
296/// The tree supports various pointer types (`Box`, `Arc`, `Rc`, etc.) through the [`Pointer`] trait.
297pub struct AvlTree<P, Tag>
298where
299    P: Pointer,
300    P::Target: AvlItem<Tag>,
301{
302    root: *const P::Target,
303    count: usize,
304    _phan: PhantomData<fn(P, &Tag)>,
305}
306
307unsafe impl<P, Tag> Send for AvlTree<P, Tag>
308where
309    P: Pointer + Send,
310    P::Target: AvlItem<Tag>,
311{
312}
313
314/// Result of a search operation in an [`AvlTree`].
315///
316/// An `AvlSearchResult` identifies either:
317/// 1. An exact match: `direction` is `None` and `node` points to the matching item.
318/// 2. An insertion point: `direction` is `Some(dir)` and `node` points to the parent
319///    where a new node should be attached as the `dir` child.
320///
321/// The lifetime `'a` ties the search result to the tree's borrow, ensuring safety.
322/// However, this lifetime often prevents further mutable operations on the tree
323/// (e.g., adding a node while holding the search result). Use [`detach`](Self::detach)
324/// to de-couple the result from the tree's lifetime when necessary.
325pub struct AvlSearchResult<'a, P: Pointer> {
326    /// The matching node or the parent for insertion.
327    pub node: *const P::Target,
328    /// `None` if exact match found, or `Some(direction)` indicating insertion point.
329    pub direction: Option<AvlDirection>,
330    _phan: PhantomData<&'a P::Target>,
331}
332
333impl<P: Pointer> Default for AvlSearchResult<'_, P> {
334    fn default() -> Self {
335        AvlSearchResult { node: null(), direction: Some(AvlDirection::Left), _phan: PhantomData }
336    }
337}
338
339impl<'a, P: Pointer> AvlSearchResult<'a, P> {
340    /// Returns a reference to the matching node if the search was an exact match.
341    #[inline(always)]
342    pub fn get_node_ref(&self) -> Option<&'a P::Target> {
343        if self.is_exact() { unsafe { self.node.as_ref() } } else { None }
344    }
345
346    /// Returns `true` if the search result is an exact match.
347    #[inline(always)]
348    pub fn is_exact(&self) -> bool {
349        self.direction.is_none() && !self.node.is_null()
350    }
351
352    /// De-couple the lifetime of the search result from the tree.
353    ///
354    /// This method is essential for performing mutable operations on the tree
355    /// using search results. In Rust, a search result typically borrows the tree
356    /// immutably. If you need to modify the tree (e.g., call `insert` or `remove`)
357    /// based on that result, the borrow checker would normally prevent it.
358    ///
359    /// `detach` effectively "erases" the lifetime `'a`, returning a result with
360    /// an unbounded lifetime `'b`.
361    ///
362    /// # Examples
363    ///
364    /// Used in `RangeTree::add`:
365    /// ```ignore
366    /// let result = self.root.find(&rs_key);
367    /// // result is AvlSearchResult<'a, ...> and borrows self.root
368    ///
369    /// let detached = unsafe { result.detach() };
370    /// // detached has no lifetime bound to self.root
371    ///
372    /// self.space += size; // Mutable operation on self permitted
373    /// self.merge_seg(start, end, detached); // Mutation on tree permitted
374    /// ```
375    ///
376    /// # Safety
377    /// This is an unsafe operation. The compiler no longer protects the validity
378    /// of the internal pointer via lifetimes. You must ensure that the tree
379    /// structure is not modified in a way that invalidates `node` (e.g., the
380    /// parent node being removed) before using the detached result.
381    #[inline(always)]
382    pub unsafe fn detach<'b>(&'a self) -> AvlSearchResult<'b, P> {
383        AvlSearchResult { node: self.node, direction: self.direction, _phan: PhantomData }
384    }
385
386    /// Return the nearest node in the search result
387    #[inline(always)]
388    pub fn get_nearest(&self) -> Option<&P::Target> {
389        if self.node.is_null() { None } else { unsafe { self.node.as_ref() } }
390    }
391}
392
393impl<'a, T> AvlSearchResult<'a, Arc<T>> {
394    /// Returns the matching Arc node if this is an exact match.
395    pub fn get_exact(&self) -> Option<Arc<T>> {
396        if self.is_exact() {
397            unsafe {
398                Arc::increment_strong_count(self.node);
399                Some(Arc::from_raw(self.node))
400            }
401        } else {
402            None
403        }
404    }
405}
406
407impl<'a, T> AvlSearchResult<'a, Rc<T>> {
408    /// Returns the matching Rc node if this is an exact match.
409    pub fn get_exact(&self) -> Option<Rc<T>> {
410        if self.is_exact() {
411            unsafe {
412                Rc::increment_strong_count(self.node);
413                Some(Rc::from_raw(self.node))
414            }
415        } else {
416            None
417        }
418    }
419}
420
421macro_rules! return_end {
422    ($tree: expr, $dir: expr) => {{ if $tree.root.is_null() { null() } else { $tree.bottom_child_ref($tree.root, $dir) } }};
423}
424
425macro_rules! balance_to_child {
426    ($balance: expr) => {
427        match $balance {
428            0 | 1 => AvlDirection::Left,
429            _ => AvlDirection::Right,
430        }
431    };
432}
433
434impl<P, Tag> AvlTree<P, Tag>
435where
436    P: Pointer,
437    P::Target: AvlItem<Tag>,
438{
439    /// Creates a new, empty `AvlTree`.
440    #[inline]
441    pub fn new() -> Self {
442        AvlTree { count: 0, root: null(), _phan: Default::default() }
443    }
444
445    /// Returns an iterator that removes all elements from the tree in post-order.
446    ///
447    /// This is an optimized, non-recursive, and stack-less traversal that preserves
448    /// tree invariants during destruction.
449    #[inline]
450    pub fn drain(&mut self) -> AvlDrain<'_, P, Tag> {
451        AvlDrain::new(self)
452    }
453
454    #[inline]
455    pub fn is_empty(&self) -> bool {
456        self.count == 0
457    }
458
459    #[inline]
460    pub fn len(&self) -> usize {
461        self.count
462    }
463
464    /// Adds a new element to the tree, takes the ownership of P.
465    ///
466    /// The `cmp_func` should compare two elements to determine their relative order.
467    /// Returns `true` if the element was added, `false` if an equivalent element
468    /// already exists (in which case the provided `node` is dropped).
469    #[inline]
470    pub fn add(&mut self, node: P) -> bool {
471        if self.count == 0 && self.root.is_null() {
472            self.root = node.into_raw();
473            self.count = 1;
474            return true;
475        }
476
477        let w = self._find(node.as_ref(), as_avlitem!(P, Tag, cmp));
478        if w.direction.is_none() {
479            // To prevent memory leak, we must drop the node.
480            // But since we took ownership, we have to convert it back to P and drop it.
481            drop(node);
482            return false;
483        }
484
485        // Safety: We need to decouple the lifetime of 'w' from 'self' to call 'insert'.
486        // We extract the pointers and reconstruct the result.
487        let w_node = w.node;
488        let w_dir = w.direction;
489
490        let w_detached = AvlSearchResult { node: w_node, direction: w_dir, _phan: PhantomData };
491
492        self.insert(node, w_detached);
493        true
494    }
495
496    /// Inserts a new node into the tree at the location specified by a search result.
497    ///
498    /// This is typically used after a [`find`](Self::find) operation didn't find an exact match.
499    ///
500    /// # Safety
501    ///
502    /// Once the tree structure changed, previous search result is not safe to use anymore.
503    ///
504    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
505    /// to avoid the borrowing issue.
506    ///
507    /// # Panics
508    /// Panics if the search result is an exact match (i.e. node already exists).
509    ///
510    /// # Examples
511    ///
512    /// ```rust
513    /// use embed_collections::avl::{AvlTree, AvlItem, AvlNode};
514    /// use core::cell::UnsafeCell;
515    /// extern crate alloc;
516    /// use alloc::sync::Arc;
517    ///
518    /// struct MyNode {
519    ///     value: i32,
520    ///     avl_node: UnsafeCell<AvlNode<MyNode, ()>>,
521    /// }
522    ///
523    /// unsafe impl AvlItem<()> for MyNode {
524    ///     type Key = i32;
525    ///
526    ///     fn get_node(&self) -> &mut AvlNode<MyNode, ()> {
527    ///         unsafe { &mut *self.avl_node.get() }
528    ///     }
529    ///
530    ///     fn borrow_key(&self) -> &Self::Key {
531    ///         &self.value
532    ///     }
533    /// }
534    ///
535    /// let mut tree = AvlTree::<Arc<MyNode>, ()>::new();
536    /// let key = 42;
537    /// let result = tree.find(&key);
538    ///
539    /// if !result.is_exact() {
540    ///     let new_node = Arc::new(MyNode {
541    ///         value: key,
542    ///         avl_node: UnsafeCell::new(Default::default()),
543    ///     });
544    ///     tree.insert(new_node, unsafe{result.detach()});
545    /// }
546    /// ```
547    #[inline]
548    pub fn insert(&mut self, new_data: P, w: AvlSearchResult<'_, P>) {
549        debug_assert!(w.direction.is_some());
550        self._insert(new_data, w.node, w.direction.unwrap());
551    }
552
553    #[allow(clippy::not_unsafe_ptr_arg_deref)]
554    pub fn _insert(
555        &mut self,
556        new_data: P,
557        here: *const P::Target, // parent
558        mut which_child: AvlDirection,
559    ) {
560        let mut new_balance: i8;
561        let new_ptr = new_data.into_raw();
562
563        if here.is_null() {
564            if self.count > 0 {
565                panic!("insert into a tree size {} with empty where.node", self.count);
566            }
567            self.root = new_ptr;
568            self.count += 1;
569            return;
570        }
571
572        let parent = unsafe { &*here };
573        let node = unsafe { (*new_ptr).get_node() };
574        let parent_node = parent.get_node();
575        node.parent = here;
576        parent_node.set_child(which_child, new_ptr);
577        self.count += 1;
578
579        /*
580         * Now, back up the tree modifying the balance of all nodes above the
581         * insertion point. If we get to a highly unbalanced ancestor, we
582         * need to do a rotation.  If we back out of the tree we are done.
583         * If we brought any subtree into perfect balance (0), we are also done.
584         */
585        let mut data: *const P::Target = here;
586        loop {
587            let node = unsafe { (*data).get_node() };
588            let old_balance = node.balance;
589            new_balance = old_balance + avlchild_to_balance!(which_child);
590            if new_balance == 0 {
591                node.balance = 0;
592                return;
593            }
594            if old_balance != 0 {
595                self.rotate(data, new_balance);
596                return;
597            }
598            node.balance = new_balance;
599            let parent_ptr = node.get_parent();
600            if parent_ptr.is_null() {
601                return;
602            }
603            which_child = self.parent_direction(data, parent_ptr);
604            data = parent_ptr;
605        }
606    }
607
608    /// Insert "new_data" in "tree" in the given "direction" either after or
609    /// before AvlDirection::After, AvlDirection::Before) the data "here".
610    ///
611    /// Insertions can only be done at empty leaf points in the tree, therefore
612    /// if the given child of the node is already present we move to either
613    /// the AVL_PREV or AVL_NEXT and reverse the insertion direction. Since
614    /// every other node in the tree is a leaf, this always works.
615    ///
616    /// # Safety
617    ///
618    /// Once the tree structure changed, previous search result is not safe to use anymore.
619    ///
620    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
621    /// to avoid the borrowing issue.
622    pub unsafe fn insert_here(
623        &mut self, new_data: P, here: AvlSearchResult<P>, direction: AvlDirection,
624    ) {
625        let mut dir_child = direction;
626        assert!(!here.node.is_null());
627        let here_node = here.node;
628        let child = unsafe { (*here_node).get_node().get_child(dir_child) };
629        if !child.is_null() {
630            dir_child = dir_child.reverse();
631            let node = self.bottom_child_ref(child, dir_child);
632            self._insert(new_data, node, dir_child);
633        } else {
634            self._insert(new_data, here_node, dir_child);
635        }
636    }
637
638    // set child and both child's parent
639    #[inline(always)]
640    fn set_child2(
641        &mut self, node: &mut AvlNode<P::Target, Tag>, dir: AvlDirection, child: *const P::Target,
642        parent: *const P::Target,
643    ) {
644        if !child.is_null() {
645            unsafe { (*child).get_node().parent = parent };
646        }
647        node.set_child(dir, child);
648    }
649
650    #[inline(always)]
651    fn parent_direction(&self, data: *const P::Target, parent: *const P::Target) -> AvlDirection {
652        if !parent.is_null() {
653            let parent_node = unsafe { (*parent).get_node() };
654            if parent_node.left == data {
655                return AvlDirection::Left;
656            }
657            if parent_node.right == data {
658                return AvlDirection::Right;
659            }
660            panic!("invalid avl tree, node {:p}, parent {:p}", data, parent);
661        }
662        // this just follow zfs
663        AvlDirection::Left
664    }
665
666    #[inline(always)]
667    fn parent_direction2(&self, data: *const P::Target) -> AvlDirection {
668        let node = unsafe { (*data).get_node() };
669        let parent = node.get_parent();
670        if !parent.is_null() {
671            return self.parent_direction(data, parent);
672        }
673        // this just follow zfs
674        AvlDirection::Left
675    }
676
677    #[inline]
678    fn rotate(&mut self, data: *const P::Target, balance: i8) -> bool {
679        let dir = if balance < 0 { AvlDirection::Left } else { AvlDirection::Right };
680        let node = unsafe { (*data).get_node() };
681
682        let parent = node.get_parent();
683        let dir_inverse = dir.reverse();
684        let left_heavy = balance >> 1;
685        let right_heavy = -left_heavy;
686
687        let child = node.get_child(dir);
688        let child_node = unsafe { (*child).get_node() };
689        let mut child_balance = child_node.balance;
690
691        let which_child = self.parent_direction(data, parent);
692
693        // node is overly left heavy, the left child is balanced or also left heavy.
694        if child_balance != right_heavy {
695            child_balance += right_heavy;
696
697            let c_right = child_node.get_child(dir_inverse);
698            self.set_child2(node, dir, c_right, data);
699            // move node to be child's right child
700            node.balance = -child_balance;
701
702            node.parent = child;
703            child_node.set_child(dir_inverse, data);
704            // update the pointer into this subtree
705
706            child_node.balance = child_balance;
707            if !parent.is_null() {
708                child_node.parent = parent;
709                unsafe { (*parent).get_node() }.set_child(which_child, child);
710            } else {
711                child_node.parent = null();
712                self.root = child;
713            }
714            return child_balance == 0;
715        }
716        // When node is left heavy, but child is right heavy we use
717        // a different rotation.
718
719        let g_child = child_node.get_child(dir_inverse);
720        let g_child_node = unsafe { (*g_child).get_node() };
721        let g_left = g_child_node.get_child(dir);
722        let g_right = g_child_node.get_child(dir_inverse);
723
724        self.set_child2(node, dir, g_right, data);
725        self.set_child2(child_node, dir_inverse, g_left, child);
726
727        /*
728         * move child to left child of gchild and
729         * move node to right child of gchild and
730         * fixup parent of all this to point to gchild
731         */
732
733        let g_child_balance = g_child_node.balance;
734        if g_child_balance == right_heavy {
735            child_node.balance = left_heavy;
736        } else {
737            child_node.balance = 0;
738        }
739        child_node.parent = g_child;
740        g_child_node.set_child(dir, child);
741
742        if g_child_balance == left_heavy {
743            node.balance = right_heavy;
744        } else {
745            node.balance = 0;
746        }
747        g_child_node.balance = 0;
748
749        node.parent = g_child;
750        g_child_node.set_child(dir_inverse, data);
751
752        if !parent.is_null() {
753            g_child_node.parent = parent;
754            unsafe { (*parent).get_node() }.set_child(which_child, g_child);
755        } else {
756            g_child_node.parent = null();
757            self.root = g_child;
758        }
759        true
760    }
761
762    /*
763    fn replace(&mut self, old: *const P::Target, node: P) {
764        let old_node = unsafe { (*old).get_node() };
765        let new_ptr = node.into_raw();
766        let new_node = unsafe { (*new_ptr).get_node() };
767
768        let left = old_node.get_child(AvlDirection::Left);
769        if !left.is_null() {
770            self.set_child2(new_node, AvlDirection::Left, left, new_ptr);
771        }
772        let right = old_node.get_child(AvlDirection::Right);
773        if !right.is_null() {
774            self.set_child2(new_node, AvlDirection::Right, right, new_ptr);
775        }
776
777        new_node.balance = old_node.balance;
778        old_node.balance = 0;
779        let parent = old_node.get_parent();
780        if !parent.is_null() {
781            let dir = self.parent_direction(old, parent);
782            self.set_child2(unsafe { (*parent).get_node() }, dir, new_ptr, parent);
783            old_node.parent = null();
784        } else {
785            debug_assert_eq!(self.root, old);
786            self.root = new_ptr;
787        }
788    }
789    */
790
791    /// Requires `del` to be a valid pointer to a node in this tree.
792    ///
793    /// # Safety
794    ///
795    /// It does not drop the node data, only unlinks it.
796    /// Caller is responsible for re-taking ownership (e.g. via from_raw) and dropping if needed.
797    ///
798    /// For Arc/Rc, use [Self::remove_ref()] instead.
799    ///
800    pub unsafe fn remove(&mut self, del: *const P::Target) {
801        /*
802         * Deletion is easiest with a node that has at most 1 child.
803         * We swap a node with 2 children with a sequentially valued
804         * neighbor node. That node will have at most 1 child. Note this
805         * has no effect on the ordering of the remaining nodes.
806         *
807         * As an optimization, we choose the greater neighbor if the tree
808         * is right heavy, otherwise the left neighbor. This reduces the
809         * number of rotations needed.
810         */
811        if self.count == 0 {
812            return;
813        }
814        if self.count == 1 && self.root == del {
815            self.root = null();
816            self.count = 0;
817            unsafe { (*del).get_node().detach() };
818            return;
819        }
820        let mut which_child: AvlDirection;
821
822        // Use reference directly to get node, avoiding unsafe dereference of raw pointer
823        let del_node = unsafe { (*del).get_node() };
824
825        let node_swap_flag = !del_node.left.is_null() && !del_node.right.is_null();
826
827        if node_swap_flag {
828            let dir: AvlDirection = balance_to_child!(del_node.balance + 1);
829            let child_temp = del_node.get_child(dir);
830
831            let dir_inverse: AvlDirection = dir.reverse();
832            let child = self.bottom_child_ref(child_temp, dir_inverse);
833
834            // Fix Miri UB: Avoid calling parent_direction2(child) if child's parent is del,
835            // because that would create a aliasing &mut ref to del while we hold del_node.
836            let dir_child_temp =
837                if child == child_temp { dir } else { self.parent_direction2(child) };
838
839            // Fix Miri UB: Do not call parent_direction2(del) as it creates a new &mut AvlNode
840            // alias while we hold del_node. Use del_node to find parent direction.
841            let parent = del_node.get_parent();
842            let dir_child_del = if !parent.is_null() {
843                self.parent_direction(del, parent)
844            } else {
845                AvlDirection::Left
846            };
847
848            let child_node = unsafe { (*child).get_node() };
849            child_node.swap(del_node);
850
851            // move 'node' to delete's spot in the tree
852            if child_node.get_child(dir) == child {
853                // if node(d) left child is node(c)
854                child_node.set_child(dir, del);
855            }
856
857            let c_dir = child_node.get_child(dir);
858            if c_dir == del {
859                del_node.parent = child;
860            } else if !c_dir.is_null() {
861                unsafe { (*c_dir).get_node() }.parent = child;
862            }
863
864            let c_inv = child_node.get_child(dir_inverse);
865            if c_inv == del {
866                del_node.parent = child;
867            } else if !c_inv.is_null() {
868                unsafe { (*c_inv).get_node() }.parent = child;
869            }
870
871            let parent = child_node.get_parent();
872            if !parent.is_null() {
873                unsafe { (*parent).get_node() }.set_child(dir_child_del, child);
874            } else {
875                self.root = child;
876            }
877
878            // Put tmp where node used to be (just temporary).
879            // It always has a parent and at most 1 child.
880            let parent = del_node.get_parent();
881            unsafe { (*parent).get_node() }.set_child(dir_child_temp, del);
882            if !del_node.right.is_null() {
883                which_child = AvlDirection::Right;
884            } else {
885                which_child = AvlDirection::Left;
886            }
887            let child = del_node.get_child(which_child);
888            if !child.is_null() {
889                unsafe { (*child).get_node() }.parent = del;
890            }
891            which_child = dir_child_temp;
892        } else {
893            // Fix Miri UB here as well
894            let parent = del_node.get_parent();
895            if !parent.is_null() {
896                which_child = self.parent_direction(del, parent);
897            } else {
898                which_child = AvlDirection::Left;
899            }
900        }
901
902        // Here we know "delete" is at least partially a leaf node. It can
903        // be easily removed from the tree.
904        let parent: *const P::Target = del_node.get_parent();
905
906        let imm_data: *const P::Target =
907            if !del_node.left.is_null() { del_node.left } else { del_node.right };
908
909        // Connect parent directly to node (leaving out delete).
910        if !imm_data.is_null() {
911            let imm_node = unsafe { (*imm_data).get_node() };
912            imm_node.parent = parent;
913        }
914
915        if !parent.is_null() {
916            assert!(self.count > 0);
917            self.count -= 1;
918
919            let parent_node = unsafe { (*parent).get_node() };
920            parent_node.set_child(which_child, imm_data);
921
922            //Since the subtree is now shorter, begin adjusting parent balances
923            //and performing any needed rotations.
924            let mut node_data: *const P::Target = parent;
925            let mut old_balance: i8;
926            let mut new_balance: i8;
927            loop {
928                // Move up the tree and adjust the balance.
929                // Capture the parent and which_child values for the next
930                // iteration before any rotations occur.
931                let node = unsafe { (*node_data).get_node() };
932                old_balance = node.balance;
933                new_balance = old_balance - avlchild_to_balance!(which_child);
934
935                //If a node was in perfect balance but isn't anymore then
936                //we can stop, since the height didn't change above this point
937                //due to a deletion.
938                if old_balance == 0 {
939                    node.balance = new_balance;
940                    break;
941                }
942
943                let parent = node.get_parent();
944                which_child = self.parent_direction(node_data, parent);
945
946                //If the new balance is zero, we don't need to rotate
947                //else
948                //need a rotation to fix the balance.
949                //If the rotation doesn't change the height
950                //of the sub-tree we have finished adjusting.
951                if new_balance == 0 {
952                    node.balance = new_balance;
953                } else if !self.rotate(node_data, new_balance) {
954                    break;
955                }
956
957                if !parent.is_null() {
958                    node_data = parent;
959                    continue;
960                }
961                break;
962            }
963        } else if !imm_data.is_null() {
964            debug_assert!(self.count > 0);
965            self.count -= 1;
966            self.root = imm_data;
967        }
968        if self.root.is_null() && self.count > 0 {
969            panic!("AvlTree {} nodes left after remove but tree.root == nil", self.count);
970        }
971        del_node.detach();
972    }
973
974    /// Removes a node from the tree by key.
975    ///
976    /// The `cmp_func` should compare the key `K` with the elements in the tree.
977    /// Returns `Some(P)` if an exact match was found and removed, `None` otherwise.
978    #[inline]
979    pub fn remove_by_key(&mut self, val: &'_ as_avlitem!(P, Tag, Key)) -> Option<P> {
980        let result = self.find(val);
981        self.remove_with(unsafe { result.detach() })
982    }
983
984    /// remove with a previous search result
985    ///
986    /// - If the result is exact match, return the removed element ownership
987    /// - If the result is not exact match, return None
988    ///
989    /// # Safety
990    ///
991    /// Once the tree structure changed, previous search result is not safe to use anymore.
992    ///
993    /// You should [detach()](AvlSearchResult::detach) the result before calling insert,
994    /// to avoid the borrowing issue.
995    #[inline]
996    pub fn remove_with(&mut self, result: AvlSearchResult<'_, P>) -> Option<P> {
997        if result.is_exact() {
998            unsafe {
999                let p = result.node;
1000                self.remove(p);
1001                Some(P::from_raw(p))
1002            }
1003        } else {
1004            None
1005        }
1006    }
1007
1008    /// Searches for an element in the tree.
1009    ///
1010    /// Returns an [`AvlSearchResult`] which indicates if an exact match was found,
1011    /// or where a new element should be inserted.
1012    #[inline]
1013    pub fn find<'a>(&'a self, val: &'_ as_avlitem!(P, Tag, Key)) -> AvlSearchResult<'a, P> {
1014        self._find::<as_avlitem!(P, Tag, Key)>(val, |_val, other| _val.cmp(other.borrow_key()))
1015    }
1016
1017    #[inline]
1018    fn _find<'a, K>(
1019        &'a self, val: &'_ K, cmp_func: AvlCmpFunc<K, P::Target>,
1020    ) -> AvlSearchResult<'a, P> {
1021        if self.root.is_null() {
1022            return AvlSearchResult::default();
1023        }
1024        let mut node_data = self.root;
1025        loop {
1026            let diff = cmp_func(val, unsafe { &*node_data });
1027            match diff {
1028                Ordering::Equal => {
1029                    return AvlSearchResult {
1030                        node: node_data,
1031                        direction: None,
1032                        _phan: PhantomData,
1033                    };
1034                }
1035                Ordering::Less => {
1036                    let node = unsafe { (*node_data).get_node() };
1037                    let left = node.get_child(AvlDirection::Left);
1038                    if left.is_null() {
1039                        return AvlSearchResult {
1040                            node: node_data,
1041                            direction: Some(AvlDirection::Left),
1042                            _phan: PhantomData,
1043                        };
1044                    }
1045                    node_data = left;
1046                }
1047                Ordering::Greater => {
1048                    let node = unsafe { (*node_data).get_node() };
1049                    let right = node.get_child(AvlDirection::Right);
1050                    if right.is_null() {
1051                        return AvlSearchResult {
1052                            node: node_data,
1053                            direction: Some(AvlDirection::Right),
1054                            _phan: PhantomData,
1055                        };
1056                    }
1057                    node_data = right;
1058                }
1059            }
1060        }
1061    }
1062
1063    // for range tree, val may overlap multiple range(node), ensure return the smallest
1064    #[inline]
1065    pub fn find_contained<'a>(
1066        &'a self, val: &'_ <P::Target as AvlItem<Tag>>::Key,
1067    ) -> Option<&'a P::Target> {
1068        if self.root.is_null() {
1069            return None;
1070        }
1071        let mut node_data = self.root;
1072        let mut result_node: *const P::Target = null();
1073        loop {
1074            let diff = unsafe { &*node_data }.cmp_key(val);
1075            match diff {
1076                Ordering::Equal => {
1077                    let node = unsafe { (*node_data).get_node() };
1078                    let left = node.get_child(AvlDirection::Left);
1079                    result_node = node_data;
1080                    if left.is_null() {
1081                        break;
1082                    } else {
1083                        node_data = left;
1084                    }
1085                }
1086                Ordering::Less => {
1087                    let node = unsafe { (*node_data).get_node() };
1088                    let left = node.get_child(AvlDirection::Left);
1089                    if left.is_null() {
1090                        break;
1091                    }
1092                    node_data = left;
1093                }
1094                Ordering::Greater => {
1095                    let node = unsafe { (*node_data).get_node() };
1096                    let right = node.get_child(AvlDirection::Right);
1097                    if right.is_null() {
1098                        break;
1099                    }
1100                    node_data = right;
1101                }
1102            }
1103        }
1104        if result_node.is_null() { None } else { unsafe { result_node.as_ref() } }
1105    }
1106
1107    // for slab, return any block larger or equal than search param
1108    #[inline]
1109    pub fn find_larger_eq<'a>(
1110        &'a self, val: &'_ <P::Target as AvlItem<Tag>>::Key,
1111    ) -> AvlSearchResult<'a, P> {
1112        if self.root.is_null() {
1113            return AvlSearchResult::default();
1114        }
1115        let mut node_data = self.root;
1116        loop {
1117            let diff = unsafe { &*node_data }.cmp_key(val);
1118            match diff {
1119                Ordering::Equal => {
1120                    return AvlSearchResult {
1121                        node: node_data,
1122                        direction: None,
1123                        _phan: PhantomData,
1124                    };
1125                }
1126                Ordering::Less => {
1127                    return AvlSearchResult {
1128                        node: node_data,
1129                        direction: None,
1130                        _phan: PhantomData,
1131                    };
1132                }
1133                Ordering::Greater => {
1134                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1135                    if right.is_null() {
1136                        return AvlSearchResult {
1137                            node: null(),
1138                            direction: None,
1139                            _phan: PhantomData,
1140                        };
1141                    }
1142                    node_data = right;
1143                }
1144            }
1145        }
1146    }
1147
1148    /// For range tree
1149    #[inline]
1150    pub fn find_nearest<'a, K>(
1151        &'a self, val: &'_ <P::Target as AvlItem<Tag>>::Key,
1152    ) -> AvlSearchResult<'a, P> {
1153        if self.root.is_null() {
1154            return AvlSearchResult::default();
1155        }
1156
1157        let mut node_data = self.root;
1158        let mut nearest_node = null();
1159        loop {
1160            let diff = unsafe { &*node_data }.cmp_key(val);
1161            match diff {
1162                Ordering::Equal => {
1163                    return AvlSearchResult {
1164                        node: node_data,
1165                        direction: None,
1166                        _phan: PhantomData,
1167                    };
1168                }
1169                Ordering::Less => {
1170                    nearest_node = node_data;
1171                    let left = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Left);
1172                    if left.is_null() {
1173                        break;
1174                    }
1175                    node_data = left;
1176                }
1177                Ordering::Greater => {
1178                    let right = unsafe { (*node_data).get_node() }.get_child(AvlDirection::Right);
1179                    if right.is_null() {
1180                        break;
1181                    }
1182                    node_data = right;
1183                }
1184            }
1185        }
1186        AvlSearchResult { node: nearest_node, direction: None, _phan: PhantomData }
1187    }
1188
1189    #[inline(always)]
1190    fn bottom_child_ref(&self, mut data: *const P::Target, dir: AvlDirection) -> *const P::Target {
1191        loop {
1192            let child = unsafe { (*data).get_node() }.get_child(dir);
1193            if !child.is_null() {
1194                data = child;
1195            } else {
1196                return data;
1197            }
1198        }
1199    }
1200
1201    /// return a iterator to get the reference
1202    ///
1203    /// NOTE: If you use the [Iterator] interface (for iteration), you only get &P::Target.
1204    ///
1205    /// you can use [AvlIter::next_ref()] to get &P.
1206    #[inline]
1207    pub fn iter(&self) -> AvlIter<'_, P, Tag> {
1208        let first_p = if !self.root.is_null() {
1209            Some(unsafe {
1210                NonNull::new_unchecked(
1211                    self.bottom_child_ref(self.root, AvlDirection::Left) as *mut P::Target
1212                )
1213            })
1214        } else {
1215            None
1216        };
1217        AvlIter::new(self, first_p, AvlDirection::Right)
1218    }
1219
1220    /// return a reversed iterator to get the reference.
1221    ///
1222    ///
1223    /// NOTE: If you use the [Iterator] interface (for iteration), you only get &P::Target.
1224    ///
1225    /// you can use [AvlIter::next_ref()] to get &P.
1226    #[inline]
1227    pub fn iter_rev(&self) -> AvlIter<'_, P, Tag> {
1228        let last_p = if !self.root.is_null() {
1229            Some(unsafe {
1230                NonNull::new_unchecked(
1231                    self.bottom_child_ref(self.root, AvlDirection::Right) as *mut P::Target
1232                )
1233            })
1234        } else {
1235            None
1236        };
1237        AvlIter::new(self, last_p, AvlDirection::Left)
1238    }
1239
1240    #[inline]
1241    pub fn next<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1242        if let Some(p) = self.walk_dir(data.into(), AvlDirection::Right) {
1243            Some(unsafe { p.as_ref() })
1244        } else {
1245            None
1246        }
1247    }
1248
1249    #[inline]
1250    pub fn prev<'a>(&'a self, data: &'a P::Target) -> Option<&'a P::Target> {
1251        if let Some(p) = self.walk_dir(data.into(), AvlDirection::Left) {
1252            Some(unsafe { p.as_ref() })
1253        } else {
1254            None
1255        }
1256    }
1257
1258    #[inline]
1259    fn walk_dir(
1260        &self, mut data_ptr: NonNull<P::Target>, dir: AvlDirection,
1261    ) -> Option<NonNull<P::Target>> {
1262        let dir_inverse = dir.reverse();
1263        let node = unsafe { data_ptr.as_ref().get_node() };
1264        let temp = node.get_child(dir);
1265        if !temp.is_null() {
1266            unsafe {
1267                Some(NonNull::new_unchecked(
1268                    self.bottom_child_ref(temp, dir_inverse) as *mut P::Target
1269                ))
1270            }
1271        } else {
1272            let mut parent = node.parent;
1273            if parent.is_null() {
1274                return None;
1275            }
1276            loop {
1277                let pdir = self.parent_direction(data_ptr.as_ptr(), parent);
1278                if pdir == dir_inverse {
1279                    return Some(unsafe { NonNull::new_unchecked(parent as *mut P::Target) });
1280                }
1281                let data_ptr_raw = parent as *mut P::Target;
1282                parent = unsafe { (*parent).get_node() }.parent;
1283                if parent.is_null() {
1284                    return None;
1285                }
1286                unsafe {
1287                    data_ptr = NonNull::new_unchecked(data_ptr_raw);
1288                }
1289            }
1290        }
1291    }
1292
1293    #[inline]
1294    fn validate_node(&self, data: *const P::Target) {
1295        let node = unsafe { (*data).get_node() };
1296        let left = node.left;
1297        if !left.is_null() {
1298            assert!(unsafe { &*left }.cmp(unsafe { &*data }) != Ordering::Greater);
1299            assert_eq!(unsafe { (*left).get_node() }.get_parent(), data);
1300        }
1301        let right = node.right;
1302        if !right.is_null() {
1303            assert!(unsafe { &*right }.cmp(unsafe { &*data }) != Ordering::Less);
1304            assert_eq!(unsafe { (*right).get_node() }.get_parent(), data);
1305        }
1306    }
1307
1308    #[inline]
1309    pub fn first(&self) -> Option<&P::Target> {
1310        unsafe { return_end!(self, AvlDirection::Left).as_ref() }
1311    }
1312
1313    #[inline]
1314    pub fn last(&self) -> Option<&P::Target> {
1315        unsafe { return_end!(self, AvlDirection::Right).as_ref() }
1316    }
1317
1318    #[inline]
1319    pub fn nearest<'a>(
1320        &'a self, current: &AvlSearchResult<'a, P>, direction: AvlDirection,
1321    ) -> AvlSearchResult<'a, P> {
1322        if !current.node.is_null() {
1323            if current.direction.is_some() && current.direction != Some(direction) {
1324                return AvlSearchResult { node: current.node, direction: None, _phan: PhantomData };
1325            }
1326            if let Some(node) = self.walk_dir(
1327                unsafe { NonNull::new_unchecked(current.node as *mut P::Target) },
1328                direction,
1329            ) {
1330                return AvlSearchResult {
1331                    node: node.as_ptr(),
1332                    direction: None,
1333                    _phan: PhantomData,
1334                };
1335            }
1336        }
1337        AvlSearchResult::default()
1338    }
1339
1340    pub fn validate(&self) {
1341        let c = {
1342            #[cfg(feature = "std")]
1343            {
1344                ((self.len() + 10) as f32).log2() as usize
1345            }
1346            #[cfg(not(feature = "std"))]
1347            {
1348                100
1349            }
1350        };
1351        let mut stack: Vec<*const P::Target> = Vec::with_capacity(c);
1352        if self.root.is_null() {
1353            assert_eq!(self.count, 0);
1354            return;
1355        }
1356        let mut data = self.root;
1357        let mut visited = 0;
1358        loop {
1359            if !data.is_null() {
1360                let left = {
1361                    let node = unsafe { (*data).get_node() };
1362                    node.get_child(AvlDirection::Left)
1363                };
1364                if !left.is_null() {
1365                    stack.push(data);
1366                    data = left;
1367                    continue;
1368                }
1369                visited += 1;
1370                self.validate_node(data);
1371                data = unsafe { (*data).get_node() }.get_child(AvlDirection::Right);
1372            } else if !stack.is_empty() {
1373                let _data = stack.pop().unwrap();
1374                self.validate_node(_data);
1375                visited += 1;
1376                let node = unsafe { (*_data).get_node() };
1377                data = node.get_child(AvlDirection::Right);
1378            } else {
1379                break;
1380            }
1381        }
1382        assert_eq!(visited, self.count);
1383    }
1384}
1385
1386impl<P, Tag> Drop for AvlTree<P, Tag>
1387where
1388    P: Pointer,
1389    P::Target: AvlItem<Tag>,
1390{
1391    fn drop(&mut self) {
1392        if mem::needs_drop::<P>() {
1393            for _ in self.drain() {}
1394        }
1395    }
1396}
1397
1398impl<T, Tag> AvlTree<Arc<T>, Tag>
1399where
1400    T: AvlItem<Tag>,
1401{
1402    pub fn remove_ref(&mut self, node: &Arc<T>) {
1403        let p = Arc::as_ptr(node);
1404        unsafe { self.remove(p) };
1405        unsafe { drop(Arc::from_raw(p)) };
1406    }
1407}
1408
1409impl<T, Tag> AvlTree<Rc<T>, Tag>
1410where
1411    T: AvlItem<Tag>,
1412{
1413    pub fn remove_ref(&mut self, node: &Rc<T>) {
1414        let p = Rc::as_ptr(node);
1415        unsafe { self.remove(p) };
1416        unsafe { drop(Rc::from_raw(p)) };
1417    }
1418}