// ntree_rs/traversal/traverse_mut/async.rs

1//! Asynchronous traversal implementation.
2
3use crate::{
4    traversal::{macros_async, TraverseMut},
5    Asynchronous, Node, Synchronous,
6};
7use async_recursion::async_recursion;
8use futures::future::join_all;
9use std::marker::PhantomData;
10
11impl<'a, T> TraverseMut<'a, T, Asynchronous> {
12    /// Converts the asynchronous traverse into a synchronous one.
13    pub fn into_sync(self) -> TraverseMut<'a, T, Synchronous> {
14        self.into()
15    }
16}
17
18impl<'a, T: Sync + Send + 'a> TraverseMut<'a, T, Asynchronous> {
19    pub(crate) fn new_async(node: &'a mut Node<T>) -> Self {
20        Self {
21            node,
22            strategy: PhantomData,
23        }
24    }
25
26    macros_async::for_each!(&mut Node<T>, iter_mut);
27    macros_async::map!(&mut Node<T>, iter_mut);
28    macros_async::reduce!(&mut Node<T>, iter_mut);
29    macros_async::cascade!(&mut Node<T>, iter_mut);
30}
31
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node;
    use std::sync::Mutex;

    // `for_each` must visit every node exactly once. The root must be the last
    // node visited, since children are processed before their parent.
    #[tokio::test]
    async fn test_for_each() {
        let mut root = node!(10_i32, node!(20, node!(40)), node!(30, node!(50)));

        // The closure only borrows the collector, so a plain `Mutex` is enough.
        // No `Arc` is needed, and no refcount bump happens on every call.
        let result = Mutex::new(Vec::new());
        root.traverse_mut()
            .into_async()
            .for_each(|n| {
                n.value = n.value.saturating_add(1);
                result.lock().unwrap().push(n.value);
            })
            .await;

        let got = result.lock().unwrap();
        assert_eq!(got.len(), 5);
        // The relative order of siblings is not fixed under async execution,
        // so only membership is checked for non-root nodes.
        assert!(got.contains(&41));
        assert!(got.contains(&51));
        assert!(got.contains(&21));
        assert!(got.contains(&31));
        assert_eq!(got.last(), Some(&11));
    }

    // `map` mutates the original tree in place and builds a new tree, with the
    // same shape, from the closure's return values.
    #[tokio::test]
    async fn test_map() {
        let mut original = node!(1, node!(2, node!(4)), node!(3, node!(5)));
        let new_root = original
            .traverse_mut()
            .into_async()
            .map(|n| {
                n.value += 1;
                n.value % 2 == 0
            })
            .await;

        let want = node!(2, node!(3, node!(5)), node!(4, node!(6)));
        assert_eq!(original, want);

        let want = node!(true, node!(false, node!(false)), node!(true, node!(true)));
        assert_eq!(new_root.take(), want);
    }

    // `reduce` folds each node's children results into the node's own result.
    // The root is reduced last.
    #[tokio::test]
    async fn test_reduce() {
        let mut root = node!(10_i32, node!(20, node!(40)), node!(30, node!(50)));

        let result = Mutex::new(Vec::new());
        let sum = root
            .traverse_mut()
            .into_async()
            .reduce(|n, results| {
                n.value = n.value.saturating_add(1);
                result.lock().unwrap().push(n.value);
                n.value + results.iter().sum::<i32>()
            })
            .await;

        // 11 + 21 + 41 + 31 + 51
        assert_eq!(sum, 155);

        let got = result.lock().unwrap();
        assert_eq!(got.len(), 5);
        assert!(got.contains(&41));
        assert!(got.contains(&21));
        assert!(got.contains(&51));
        assert!(got.contains(&31));
        assert_eq!(got.last(), Some(&11));
    }

    // `cascade` passes each node's return value down to its children.
    // The root is visited first, with the initial value 0.
    #[tokio::test]
    async fn test_cascade() {
        let mut root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let result = Mutex::new(Vec::new());
        root.traverse_mut()
            .into_async()
            .cascade(0, |n, parent_value| {
                let next = n.value + parent_value;
                result.lock().unwrap().push(next);
                n.value = *parent_value;
                next
            })
            .await;

        // Each node now holds the value its parent passed down to it.
        assert_eq!(root.value, 0);
        assert_eq!(root.children[0].value, 10);
        assert_eq!(root.children[1].value, 10);
        assert_eq!(root.children[0].children[0].value, 30);
        assert_eq!(root.children[1].children[0].value, 40);

        // Pushed values per node:
        // - root: 10 + 0 = 10
        // - 20: 20 + 10 = 30
        // - 30: 30 + 10 = 40
        // - 40: 40 + 30 = 70
        // - 50: 50 + 40 = 90
        let got = result.lock().unwrap();
        assert_eq!(got.len(), 5);
        assert_eq!(got.first(), Some(&10));
        assert!(got.contains(&30));
        assert!(got.contains(&40));
        assert!(got.contains(&70));
        assert!(got.contains(&90));
    }
}