congee/
tree.rs

1use std::{marker::PhantomData, ptr::NonNull, sync::Arc};
2
3use crossbeam_epoch::Guard;
4
5use crate::{
6    Allocator, DefaultAllocator,
7    base_node::{BaseNode, Node, Prefix},
8    error::{ArtError, OOMError},
9    lock::ReadGuard,
10    node_4::Node4,
11    node_256::Node256,
12    node_ptr::{ChildIsPayload, ChildIsSubNode, NodePtr, PtrType},
13    range_scan::RangeScan,
14    utils::Backoff,
15};
16
/// Raw interface to the ART tree.
///
/// Keys are fixed-size byte arrays of `K_LEN` bytes and payloads are plain `usize` values.
/// NOTE(review): the previous comment said "`Art` is a wrapper around the `RawArt`"; neither
/// name appears in this file, so the safe wrapper it refers to presumably lives elsewhere in
/// the crate — confirm and name it here.
pub(crate) struct RawCongee<
    const K_LEN: usize,
    A: Allocator + Clone + Send + 'static = DefaultAllocator,
> {
    /// Root of the tree; always a `Node256` allocated in `new`.
    pub(crate) root: NonNull<Node256>,
    /// Invoked with every remaining `(key, payload)` pair when the tree is dropped.
    drain_callback: Arc<dyn Fn([u8; K_LEN], usize)>,
    /// Allocator handed to node construction and destruction.
    allocator: A,
    /// Marker tying the tree to its key type without storing a key.
    _pt_key: PhantomData<[u8; K_LEN]>,
}
28
// SAFETY: NOTE(review): these impls are asserted by hand because the struct holds a `NonNull`
// root and an `Arc<dyn Fn(..)>`, neither of which is automatically `Send`/`Sync`. The bounds
// here do not require `A: Sync`, nor that the `drain_callback` closure is `Send + Sync` —
// confirm that every caller only supplies thread-safe allocators and callbacks.
unsafe impl<const K_LEN: usize, A: Allocator + Clone + Send> Send for RawCongee<K_LEN, A> {}
unsafe impl<const K_LEN: usize, A: Allocator + Clone + Send> Sync for RawCongee<K_LEN, A> {}
31
32impl<const K_LEN: usize> Default for RawCongee<K_LEN> {
33    fn default() -> Self {
34        Self::new(DefaultAllocator {}, Arc::new(|_: [u8; K_LEN], _: usize| {}))
35    }
36}
37
/// Hooks invoked while the tree is walked depth-first (see `RawCongee::recursive_dfs`).
///
/// Every method has an empty default body, so implementors override only what they need.
/// `_tree_level` is 0 for the root and grows by one per recursion step.
pub(crate) trait CongeeVisitor<const K_LEN: usize> {
    /// Called once for every stored payload, together with its reconstructed key.
    fn visit_payload(&mut self, _key: [u8; K_LEN], _payload: usize) {}
    /// Called for an inner node before it is read-locked and its children are walked.
    fn pre_visit_sub_node(&mut self, _node: NonNull<BaseNode>, _tree_level: usize) {}
    /// Called for an inner node after all its children were walked and its version check passed.
    fn post_visit_sub_node(&mut self, _node: NonNull<BaseNode>, _tree_level: usize) {}
}
43
/// Visitor used by `RawCongee`'s `Drop` impl: reports every payload to the drain callback and
/// frees each inner node once its children have been visited.
struct DropVisitor<const K_LEN: usize, A: Allocator + Clone + Send> {
    /// Allocator passed to `BaseNode::drop_node` for every freed node.
    allocator: A,
    /// Callback receiving each `(key, payload)` pair still stored in the tree.
    drain_callback: Arc<dyn Fn([u8; K_LEN], usize)>,
}
48
49impl<const K_LEN: usize, A: Allocator + Clone + Send> CongeeVisitor<K_LEN>
50    for DropVisitor<K_LEN, A>
51{
52    fn visit_payload(&mut self, key: [u8; K_LEN], payload: usize) {
53        (self.drain_callback)(key, payload);
54    }
55
56    fn post_visit_sub_node(&mut self, node: NonNull<BaseNode>, _tree_level: usize) {
57        unsafe {
58            BaseNode::drop_node(node, self.allocator.clone());
59        }
60    }
61}
62
/// Visitor that records the key of every payload it is shown, in visiting order.
struct LeafNodeKeyVisitor<const K_LEN: usize> {
    /// Keys collected so far.
    keys: Vec<[u8; K_LEN]>,
}

impl<const K_LEN: usize> CongeeVisitor<K_LEN> for LeafNodeKeyVisitor<K_LEN> {
    /// Appends `key` to the collected list; the payload itself is ignored.
    fn visit_payload(&mut self, key: [u8; K_LEN], _payload: usize) {
        self.keys.push(key);
    }
}
72
/// Visitor that counts how many payloads it is shown.
struct ValueCountVisitor<const K_LEN: usize> {
    /// Number of payloads visited so far.
    value_count: usize,
}

impl<const K_LEN: usize> CongeeVisitor<K_LEN> for ValueCountVisitor<K_LEN> {
    /// Bumps the counter by one; key and payload are ignored.
    fn visit_payload(&mut self, _key: [u8; K_LEN], _payload: usize) {
        self.value_count += 1;
    }
}
82
83impl<const K_LEN: usize, A: Allocator + Clone + Send> Drop for RawCongee<K_LEN, A> {
84    fn drop(&mut self) {
85        let mut visitor = DropVisitor::<K_LEN, A> {
86            allocator: self.allocator.clone(),
87            drain_callback: self.drain_callback.clone(),
88        };
89        self.dfs_visitor_slow(&mut visitor).unwrap();
90
91        // see this: https://github.com/XiangpengHao/congee/issues/20
92        for _ in 0..128 {
93            crossbeam_epoch::pin().flush();
94        }
95    }
96}
97
98impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
99    pub fn new(allocator: A, drain_callback: Arc<dyn Fn([u8; K_LEN], usize)>) -> Self {
100        let root = BaseNode::make_node::<Node256>(&[], &allocator)
101            .expect("Can't allocate memory for root node!");
102        RawCongee {
103            root: root.into_non_null(),
104            drain_callback,
105            allocator,
106            _pt_key: PhantomData,
107        }
108    }
109}
110
111impl<const K_LEN: usize, A: Allocator + Clone + Send> RawCongee<K_LEN, A> {
112    pub(crate) fn is_empty(&self, _guard: &Guard) -> bool {
113        loop {
114            if let Ok(node) = BaseNode::read_lock_root(self.root) {
115                let is_empty = node.as_ref().meta.count == 0;
116                if node.check_version().is_ok() {
117                    return is_empty;
118                }
119            }
120        }
121    }
122
123    #[inline]
124    pub(crate) fn get(&self, key: &[u8; K_LEN], _guard: &Guard) -> Option<usize> {
125        'outer: loop {
126            let mut level = 0;
127
128            let mut node = if let Ok(v) = BaseNode::read_lock_root(self.root) {
129                v
130            } else {
131                continue;
132            };
133
134            loop {
135                level = Self::check_prefix(node.as_ref(), key, level)?;
136
137                let child_node = node
138                    .as_ref()
139                    .get_child(unsafe { *key.get_unchecked(level) });
140                if node.check_version().is_err() {
141                    continue 'outer;
142                }
143
144                let child_node = child_node?;
145
146                match child_node.downcast::<K_LEN>(level) {
147                    PtrType::Payload(tid) => {
148                        return Some(tid);
149                    }
150                    PtrType::SubNode(sub_node) => {
151                        level += 1;
152
153                        node = if let Ok(n) = BaseNode::read_lock(sub_node) {
154                            n
155                        } else {
156                            continue 'outer;
157                        };
158                    }
159                }
160            }
161        }
162    }
163
164    pub(crate) fn keys(&self) -> Vec<[u8; K_LEN]> {
165        loop {
166            let mut visitor = LeafNodeKeyVisitor::<K_LEN> { keys: Vec::new() };
167            if self.dfs_visitor_slow(&mut visitor).is_ok() {
168                return visitor.keys;
169            }
170        }
171    }
172
173    fn is_last_level<'a>(current_level: usize) -> Result<ChildIsPayload<'a>, ChildIsSubNode<'a>> {
174        if current_level == (K_LEN - 1) {
175            Ok(ChildIsPayload::new())
176        } else {
177            Err(ChildIsSubNode::new())
178        }
179    }
180
181    /// Depth-First Search visitor implemented recursively, use with caution
182    pub(crate) fn dfs_visitor_slow<V: CongeeVisitor<K_LEN>>(
183        &self,
184        visitor: &mut V,
185    ) -> Result<(), ArtError> {
186        let first = VisitingEntry::SubNode(unsafe {
187            std::mem::transmute::<NonNull<Node256>, NonNull<BaseNode>>(self.root)
188        });
189        Self::recursive_dfs(first, 0, visitor)?;
190        Ok(())
191    }
192
193    fn recursive_dfs<V: CongeeVisitor<K_LEN>>(
194        node: VisitingEntry<K_LEN>,
195        tree_level: usize,
196        visitor: &mut V,
197    ) -> Result<(), ArtError> {
198        match node {
199            VisitingEntry::Payload((k, v)) => {
200                visitor.visit_payload(k, v);
201            }
202            VisitingEntry::SubNode(node_ptr) => {
203                visitor.pre_visit_sub_node(node_ptr, tree_level);
204                let node_lock = BaseNode::read_lock(node_ptr)?;
205                let children = node_lock.as_ref().get_children(0, 255);
206                for (k, child_ptr) in children {
207                    let next = match child_ptr.downcast::<K_LEN>(node_lock.as_ref().prefix().len())
208                    {
209                        PtrType::Payload(tid) => {
210                            let mut key: [u8; K_LEN] = [0; K_LEN];
211                            let prefix = node_lock.as_ref().prefix();
212                            key[0..prefix.len()].copy_from_slice(prefix);
213                            key[prefix.len()] = k;
214                            VisitingEntry::Payload((key, tid))
215                        }
216                        PtrType::SubNode(sub_node) => VisitingEntry::SubNode(sub_node),
217                    };
218
219                    Self::recursive_dfs(next, tree_level + 1, visitor)?;
220                }
221                node_lock.check_version()?;
222                visitor.post_visit_sub_node(node_ptr, tree_level);
223            }
224        }
225        Ok(())
226    }
227
228    /// Returns the number of values in the tree.
229    pub(crate) fn value_count(&self, _guard: &Guard) -> usize {
230        loop {
231            let mut visitor = ValueCountVisitor::<K_LEN> { value_count: 0 };
232            if self.dfs_visitor_slow(&mut visitor).is_ok() {
233                return visitor.value_count;
234            }
235        }
236    }
237
    /// Single attempt at inserting or updating the value stored under `k`.
    ///
    /// `tid_func` produces the value to store: it receives `None` when the key is new and
    /// `Some(old)` when the key already holds `old`. It can be invoked before the write is
    /// known to succeed, and the callers in this file retry on error, so it may run more
    /// than once per logical insert.
    ///
    /// Returns `Ok(None)` when a new key was inserted and `Ok(Some(old))` when the key
    /// already existed (whether or not the stored value changed).
    ///
    /// # Errors
    ///
    /// Returns an `ArtError` when a lock/version check fails or a node allocation fails;
    /// the tree is expected to be retried from the root by the caller.
    #[inline]
    fn insert_inner<F>(
        &self,
        k: &[u8; K_LEN],
        tid_func: &mut F,
        guard: &Guard,
    ) -> Result<Option<usize>, ArtError>
    where
        F: FnMut(Option<usize>) -> usize,
    {
        // `parent_node`/`parent_key` trail one step behind `node`/`node_key` during descent.
        let mut parent_node = None;
        let mut node = BaseNode::read_lock_root(self.root)?;
        let mut parent_key: u8;
        let mut node_key: u8 = 0;
        let mut level = 0usize;

        loop {
            parent_key = node_key;

            // Compare the node's prefix against the key; `next_level` is advanced past every
            // matching byte, and a mismatch yields the node's diverging byte.
            let mut next_level = level;
            let res = self.check_prefix_not_match(node.as_ref(), k, &mut next_level);
            match res {
                None => {
                    // Prefix fully matched: look up the child for the next key byte.
                    level = next_level;
                    node_key = k[level];

                    let next_node = node.as_ref().get_child(node_key);

                    // Validate the read before acting on it.
                    node.check_version()?;

                    let next_node = if let Some(n) = next_node {
                        n
                    } else {
                        // No child under this byte: build the new entry to hang here.
                        let new_leaf = {
                            match Self::is_last_level(level) {
                                // Final key byte: the payload is stored directly.
                                Ok(_is_last_level) => NodePtr::from_payload(tid_func(None)),
                                // Otherwise create a Node4 whose prefix is the key minus its
                                // last byte and which holds the payload under that last byte.
                                Err(_is_sub_node) => {
                                    let new_prefix = k;
                                    let mut n4 = BaseNode::make_node::<Node4>(
                                        &new_prefix[..k.len() - 1],
                                        &self.allocator,
                                    )?;
                                    n4.as_mut().insert(
                                        k[k.len() - 1],
                                        NodePtr::from_payload(tid_func(None)),
                                    );
                                    n4.into_note_ptr()
                                }
                            }
                        };

                        if let Err(e) = BaseNode::insert_and_unlock(
                            node,
                            (parent_key, parent_node),
                            (node_key, new_leaf),
                            &self.allocator,
                            guard,
                        ) {
                            // The insert did not happen, so a freshly built sub-node must be
                            // freed here; a bare payload needs no cleanup.
                            match new_leaf.downcast::<K_LEN>(level) {
                                PtrType::Payload(_) => {}
                                // SAFETY: NOTE(review): assumes the failed insert never
                                // published `sub_node`, so nothing else can reference it —
                                // confirm against `insert_and_unlock`.
                                PtrType::SubNode(sub_node) => unsafe {
                                    BaseNode::drop_node(sub_node, self.allocator.clone());
                                },
                            }
                            return Err(e);
                        }

                        return Ok(None);
                    };

                    // A child exists, so the parent is no longer needed for this step.
                    if let Some(p) = parent_node {
                        p.unlock()?;
                    }

                    match next_node.downcast::<K_LEN>(level) {
                        PtrType::Payload(old) => {
                            // At this point, the level must point to the last u8 of the key,
                            // meaning that we are updating an existing value.
                            let new = tid_func(Some(old));
                            if old == new {
                                // Nothing to write; just make sure `old` was a valid read.
                                node.check_version()?;
                                return Ok(Some(old));
                            }

                            // Upgrade to a write lock; on failure only the error is kept.
                            let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;

                            write_n
                                .as_mut()
                                .change(node_key, NodePtr::from_payload(new));
                            return Ok(Some(old));
                        }
                        PtrType::SubNode(sub_node) => {
                            // Descend one level, keeping the current node as the new parent.
                            parent_node = Some(node);
                            node = BaseNode::read_lock(sub_node)?;
                            level += 1;
                        }
                    }
                }

                Some(no_match_key) => {
                    // The key diverges inside this node's prefix: splice a new node between
                    // the parent and this node.
                    // NOTE(review): `unwrap` panics if a mismatch is ever reported at the
                    // root (no parent) — presumably impossible because the root is created
                    // with an empty prefix; confirm.
                    let mut write_p = parent_node.unwrap().upgrade().map_err(|(_n, v)| v)?;
                    let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;

                    // 1) Create the new node that becomes the parent of `node`; its prefix is
                    //    the part of `node`'s prefix that matched the key.
                    let mut new_middle_node = BaseNode::make_node::<Node4>(
                        write_n.as_ref().prefix()[0..next_level].as_ref(),
                        &self.allocator,
                    )?;

                    // 2) Add the new key's entry and the existing node as its children.
                    if next_level == (K_LEN - 1) {
                        // The mismatch is on the last key byte: store the payload directly.
                        new_middle_node
                            .as_mut()
                            .insert(k[next_level], NodePtr::from_payload(tid_func(None)));
                    } else {
                        // Otherwise create a Node4 holding the payload under the last byte.
                        // NOTE(review): if this allocation fails, `new_middle_node` is not
                        // explicitly freed on the error path — confirm whether its handle
                        // releases the allocation on drop.
                        let mut single_new_node =
                            BaseNode::make_node::<Node4>(&k[..k.len() - 1], &self.allocator)?;

                        single_new_node
                            .as_mut()
                            .insert(k[k.len() - 1], NodePtr::from_payload(tid_func(None)));
                        new_middle_node
                            .as_mut()
                            .insert(k[next_level], single_new_node.into_note_ptr());
                    }

                    // The existing node hangs under the byte where its prefix diverged.
                    new_middle_node
                        .as_mut()
                        .insert(no_match_key, NodePtr::from_node(write_n.as_mut()));

                    // 3) Point the parent at the new middle node; the write guards go out of
                    //    scope on return.
                    write_p
                        .as_mut()
                        .change(parent_key, new_middle_node.into_note_ptr());

                    return Ok(None);
                }
            }
        }
    }
380
381    #[inline]
382    pub(crate) fn insert(
383        &self,
384        k: &[u8; K_LEN],
385        tid: usize,
386        guard: &Guard,
387    ) -> Result<Option<usize>, OOMError> {
388        let backoff = Backoff::new();
389        loop {
390            match self.insert_inner(k, &mut |_| tid, guard) {
391                Ok(v) => return Ok(v),
392                Err(e) => match e {
393                    ArtError::Locked | ArtError::VersionNotMatch => {
394                        backoff.spin();
395                        continue;
396                    }
397                    ArtError::Oom => return Err(OOMError::new()),
398                },
399            }
400        }
401    }
402
403    #[inline]
404    pub(crate) fn compute_or_insert<F>(
405        &self,
406        k: &[u8; K_LEN],
407        insert_func: &mut F,
408        guard: &Guard,
409    ) -> Result<Option<usize>, OOMError>
410    where
411        F: FnMut(Option<usize>) -> usize,
412    {
413        let backoff = Backoff::new();
414        loop {
415            match self.insert_inner(k, insert_func, guard) {
416                Ok(v) => return Ok(v),
417                Err(e) => match e {
418                    ArtError::Locked | ArtError::VersionNotMatch => {
419                        backoff.spin();
420                        continue;
421                    }
422                    ArtError::Oom => return Err(OOMError::new()),
423                },
424            }
425        }
426    }
427
428    fn check_prefix(node: &BaseNode, key: &[u8; K_LEN], mut level: usize) -> Option<usize> {
429        let node_prefix = node.prefix();
430        let key_prefix = key;
431
432        for (n, k) in node_prefix.iter().zip(key_prefix).skip(level) {
433            if n != k {
434                return None;
435            }
436            level += 1;
437        }
438        debug_assert!(level == node_prefix.len());
439        Some(level)
440    }
441
442    #[inline]
443    fn check_prefix_not_match(
444        &self,
445        n: &BaseNode,
446        key: &[u8; K_LEN],
447        level: &mut usize,
448    ) -> Option<u8> {
449        let n_prefix = n.prefix();
450        if !n_prefix.is_empty() {
451            let p_iter = n_prefix.iter().skip(*level);
452            for (i, v) in p_iter.enumerate() {
453                if *v != key[*level] {
454                    let no_matching_key = *v;
455
456                    let mut prefix = Prefix::default();
457                    for (j, v) in prefix.iter_mut().enumerate().take(n_prefix.len() - i - 1) {
458                        *v = n_prefix[j + 1 + i];
459                    }
460
461                    return Some(no_matching_key);
462                }
463                *level += 1;
464            }
465        }
466
467        None
468    }
469
470    #[inline]
471    pub(crate) fn range(
472        &self,
473        start: &[u8; K_LEN],
474        end: &[u8; K_LEN],
475        result: &mut [([u8; K_LEN], usize)],
476        _guard: &Guard,
477    ) -> usize {
478        let mut range_scan = RangeScan::new(start, end, result, self.root);
479
480        if !range_scan.is_valid_key_pair() {
481            return 0;
482        }
483
484        let backoff = Backoff::new();
485        loop {
486            let scanned = range_scan.scan();
487            match scanned {
488                Ok(n) => {
489                    return n;
490                }
491                Err(_) => {
492                    backoff.spin();
493                }
494            }
495        }
496    }
497
    /// Single attempt at updating or removing the value stored under `k`.
    ///
    /// When the key is present, `remapping_function` is called with the current value:
    /// `Some(new)` replaces the value, `None` removes the entry. The function can run before
    /// the write is known to succeed, and the caller in this file retries on error, so it
    /// may be called more than once.
    ///
    /// Returns `Ok(None)` when the key is absent, otherwise `Ok(Some((old, new)))` where
    /// `new` is `None` if the entry was removed.
    ///
    /// # Errors
    ///
    /// Returns an `ArtError` when a lock or version check fails; the caller is expected to
    /// retry from the root.
    #[inline]
    fn compute_if_present_inner<F>(
        &self,
        k: &[u8; K_LEN],
        remapping_function: &mut F,
        guard: &Guard,
    ) -> Result<Option<(usize, Option<usize>)>, ArtError>
    where
        F: FnMut(usize) -> Option<usize>,
    {
        // Read guard of the node above `node`, plus the byte under which `node` hangs in it.
        let mut parent: Option<(ReadGuard, u8)> = None;
        let mut node_key: u8;
        let mut level = 0;
        let mut node = BaseNode::read_lock_root(self.root)?;

        loop {
            // A diverging prefix means the key is not in the tree.
            level = if let Some(v) = Self::check_prefix(node.as_ref(), k, level) {
                v
            } else {
                return Ok(None);
            };

            node_key = k[level];

            let child_node = node.as_ref().get_child(node_key);
            // Validate the read before acting on it.
            node.check_version()?;

            let child_node = match child_node {
                Some(n) => n,
                None => return Ok(None),
            };

            match child_node.downcast::<K_LEN>(level) {
                PtrType::Payload(tid) => {
                    let new_v = remapping_function(tid);

                    match new_v {
                        Some(new_v) => {
                            if new_v == tid {
                                // The value is unchanged, so return early without writing.
                                return Ok(Some((tid, Some(tid))));
                            }
                            // Upgrade to a write lock; on failure only the error is kept.
                            let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;
                            write_n
                                .as_mut()
                                .change(k[level], NodePtr::from_payload(new_v));

                            return Ok(Some((tid, Some(new_v))));
                        }
                        None => {
                            // The new value is `None`, so this entry has to be deleted.
                            // Reaching a leaf is expected to imply a parent, because the root
                            // cannot be a leaf.
                            // NOTE(review): in release builds the `unwrap` below is the only
                            // guard; with K_LEN == 1 the root would hold payloads directly —
                            // confirm that configuration is unsupported.
                            debug_assert!(parent.is_some()); // reaching leaf means we must have parent, bcs root can't be leaf
                            if node.as_ref().value_count() == 1 {
                                // Removing the last value empties this node: unlink the whole
                                // node from its parent instead.
                                let (parent_node, parent_key) = parent.unwrap();
                                let mut write_p = parent_node.upgrade().map_err(|(_n, v)| v)?;

                                let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;

                                write_p.as_mut().remove(parent_key);

                                write_n.mark_obsolete();
                                let allocator = self.allocator.clone();
                                // Freeing is deferred through the epoch guard. The write
                                // guard moves into the closure and is forgotten there, so it
                                // is never released normally.
                                // SAFETY: NOTE(review): assumes no reader can still reach the
                                // unlinked node once the deferred closure runs — confirm.
                                guard.defer(move || unsafe {
                                    let ptr = NonNull::from(write_n.as_mut());
                                    std::mem::forget(write_n);
                                    BaseNode::drop_node(ptr, allocator);
                                });
                            } else {
                                // Other values remain: remove just this entry.
                                let mut write_n = node.upgrade().map_err(|(_n, v)| v)?;

                                write_n.as_mut().remove(node_key);
                            }
                            return Ok(Some((tid, None)));
                        }
                    }
                }
                PtrType::SubNode(sub_node) => {
                    // Descend one level, remembering the current node as the parent.
                    level += 1;
                    parent = Some((node, node_key));
                    node = BaseNode::read_lock(sub_node)?;
                }
            }
        }
    }
582
583    #[inline]
584    pub(crate) fn compute_if_present<F>(
585        &self,
586        k: &[u8; K_LEN],
587        remapping_function: &mut F,
588        guard: &Guard,
589    ) -> Option<(usize, Option<usize>)>
590    where
591        F: FnMut(usize) -> Option<usize>,
592    {
593        let backoff = Backoff::new();
594        loop {
595            match self.compute_if_present_inner(k, &mut *remapping_function, guard) {
596                Ok(n) => return n,
597                Err(_) => backoff.spin(),
598            }
599        }
600    }
601}
602
/// Work item processed by `RawCongee::recursive_dfs`.
enum VisitingEntry<const K_LEN: usize> {
    /// An inner node whose children still have to be walked.
    SubNode(NonNull<BaseNode>),
    /// A fully reconstructed key together with the payload stored under it.
    Payload(([u8; K_LEN], usize)),
}