arena_voxel_tree/
parallel_tree.rs

1//! [`MTArena`] is defined here.
2//! * This library is inspired by [r3bl-org work](https://github.com/r3bl-org/)
3
4use std::{
5    fmt::Debug,
6    marker::{Send, Sync},
7    sync::{Arc, RwLock},
8    thread::{spawn, JoinHandle},
9};
10
11use crate::arena_node::{Arena, Node};
12use crate::arena_types::{ResultUidList, ShareableArena, WalkerFn};
13use crate::utils::ReadGuarded;
14
15/// [`MTArena`] is built on top of [`Arena`] but with support for sharing the arena between threads.
16/// Also supports tree walking on a separate thread w/ a lambda that's supplied.
17///
18/// 1. [Wikipedia definition of memory
19///    arena](https://en.wikipedia.org/wiki/Region-based_memory_management)
20/// 2. You can learn more about how this library was built from this [developerlife.com
21///    article](https://developerlife.com/2022/02/24/rust-non-binary-tree/).
22#[derive(Debug)]
23pub struct MTArena<T>
24where
25    T: Debug + Send + Sync + Clone + 'static,
26{
27    arena_arc: ShareableArena<T>,
28}
29
30impl<T> MTArena<T>
31where
32    T: Debug + Send + Sync + Clone + 'static,
33{
34    pub fn new() -> Self {
35        MTArena {
36            arena_arc: Arc::new(RwLock::new(Arena::new())),
37        }
38    }
39
40    pub fn get_arena_arc(&self) -> ShareableArena<T> {
41        self.arena_arc.clone()
42    }
43
44    /// `walker_fn` is a closure that captures variables. It is wrapped in an `Arc` to be able to
45    /// clone that and share it across threads. More info:
46    /// 1. SO thread: <https://stackoverflow.com/a/36213377/2085356>
47    /// 2. Scoped threads: <https://docs.rs/crossbeam/0.3.0/crossbeam/struct.Scope.html>
48    pub fn tree_walk_parallel(
49        &self,
50        node_id: usize,
51        walker_fn: Arc<WalkerFn<T>>,
52        traversal_kind: TraversalKind,
53    ) -> JoinHandle<ResultUidList> {
54        let arena_arc = self.get_arena_arc();
55        let walker_fn_arc = walker_fn;
56
57        spawn(move || {
58            let read_guard: ReadGuarded<'_, Arena<T>> = arena_arc.read().unwrap();
59            let return_value = match traversal_kind {
60                TraversalKind::DepthFirst => read_guard.tree_walk_dfs(node_id),
61                TraversalKind::BreadthFirst => read_guard.tree_walk_bfs(node_id),
62            };
63
64            // While walking the tree, in a separate thread, call the `walker_fn` for each
65            // node.
66            if let Some(result_list) = return_value.clone() {
67                result_list.into_iter().for_each(|uid| {
68                    let node_arc_opt = read_guard.get_node_arc(uid);
69                    if let Some(node_arc) = node_arc_opt {
70                        let node_ref: ReadGuarded<'_, Node<T>> = node_arc.read().unwrap();
71                        walker_fn_arc(uid, node_ref.payload.clone());
72                    }
73                });
74            }
75
76            return_value
77        })
78    }
79}
80
/// Order in which `MTArena::tree_walk_parallel` visits the nodes of a tree.
///
/// This is a small fieldless enum, so it is `Copy`: callers can reuse one value across
/// several walks and compare kinds directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalKind {
    /// Depth-first traversal (delegated to the arena's `tree_walk_dfs`).
    DepthFirst,
    /// Breadth-first traversal (delegated to the arena's `tree_walk_bfs`).
    BreadthFirst,
}
86
87impl<T> Default for MTArena<T>
88where
89    T: Debug + Send + Sync + Clone + 'static,
90{
91    fn default() -> Self {
92        Self::new()
93    }
94}