// ntree_rs/traversal/traverse/async.rs

//! Asynchronous traversal implementation.
2
3use crate::{
4    traversal::{macros_async, Traverse},
5    Asynchronous, Node, Synchronous,
6};
7use async_recursion::async_recursion;
8use futures::future::join_all;
9use std::marker::PhantomData;
10
11impl<'a, T> From<Traverse<'a, T, Synchronous>> for Traverse<'a, T, Asynchronous>
12where
13    T: Sync + Send,
14{
15    fn from(value: Traverse<'a, T, Synchronous>) -> Self {
16        Traverse::new_async(value.node)
17    }
18}
19
20impl<'a, T> Traverse<'a, T, Asynchronous> {
21    /// Converts the asynchronous traverse into a synchronous one.
22    pub fn into_sync(self) -> Traverse<'a, T, Synchronous> {
23        self.into()
24    }
25}
26
27impl<'a, T: Sync + Send + 'a> Traverse<'a, T, Asynchronous> {
28    pub(crate) fn new_async(node: &'a Node<T>) -> Self {
29        Self {
30            node,
31            strategy: PhantomData,
32        }
33    }
34
35    macros_async::for_each!(&Node<T>, iter);
36    macros_async::map!(&Node<T>, iter);
37    macros_async::reduce!(&Node<T>, iter);
38    macros_async::cascade!(&Node<T>, iter);
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node;
    use std::sync::Mutex;

    #[tokio::test]
    async fn test_for_each() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        // The closure only borrows `visited`, so a plain `Mutex` is enough:
        // no `Arc`, and no per-call clone of it, is needed.
        let visited = Mutex::new(Vec::new());
        root.traverse()
            .into_async()
            .for_each(|n| visited.lock().unwrap().push(n.value))
            .await;

        let got = visited.lock().unwrap();
        // Every node must be visited exactly once...
        assert_eq!(got.len(), 5);
        assert!(got.contains(&40));
        assert!(got.contains(&20));
        assert!(got.contains(&50));
        assert!(got.contains(&30));
        // ...and the root only after all of its descendants.
        assert_eq!(got.last(), Some(&10));
    }

    #[tokio::test]
    async fn test_map() {
        let original = node!(1, node!(2, node!(4)), node!(3, node!(5)));

        let copy = original.clone();
        let new_root = copy.traverse().into_async().map(|n| n.value % 2 == 0).await;
        // Mapping must not modify the source tree.
        assert_eq!(original, copy);

        let want = node!(false, node!(true, node!(true)), node!(false, node!(false)));
        assert_eq!(new_root.take(), want);
    }

    #[tokio::test]
    async fn test_reduce() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let visited = Mutex::new(Vec::new());
        let sum = root
            .traverse()
            .into_async()
            .reduce(|n, results| {
                visited.lock().unwrap().push(n.value);
                n.value + results.iter().sum::<i32>()
            })
            .await;

        assert_eq!(sum, 150);

        let got = visited.lock().unwrap();
        // The reducer runs once per node, with the root reduced last.
        assert_eq!(got.len(), 5);
        assert!(got.contains(&40));
        assert!(got.contains(&20));
        assert!(got.contains(&50));
        assert!(got.contains(&30));
        assert_eq!(got.last(), Some(&10));
    }

    #[tokio::test]
    async fn test_cascade() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let visited = Mutex::new(Vec::new());
        root.traverse()
            .into_async()
            .cascade(0, |n, parent_value| {
                let next = n.value + parent_value;
                visited.lock().unwrap().push(next);
                next
            })
            .await;

        let got = visited.lock().unwrap();
        // One accumulated value per node; the root is processed first.
        assert_eq!(got.len(), 5);
        assert_eq!(got.first(), Some(&10));
        assert!(got.contains(&30));
        assert!(got.contains(&40));
        assert!(got.contains(&70));
        assert!(got.contains(&90));
    }
}