ntree_rs/traversal/traverse/
async.rs1use crate::{
4 traversal::{macros_async, Traverse},
5 Asynchronous, Node, Synchronous,
6};
7use async_recursion::async_recursion;
8use futures::future::join_all;
9use std::marker::PhantomData;
10
11impl<'a, T> From<Traverse<'a, T, Synchronous>> for Traverse<'a, T, Asynchronous>
12where
13 T: Sync + Send,
14{
15 fn from(value: Traverse<'a, T, Synchronous>) -> Self {
16 Traverse::new_async(value.node)
17 }
18}
19
20impl<'a, T> Traverse<'a, T, Asynchronous> {
21 pub fn into_sync(self) -> Traverse<'a, T, Synchronous> {
23 self.into()
24 }
25}
26
27impl<'a, T: Sync + Send + 'a> Traverse<'a, T, Asynchronous> {
28 pub(crate) fn new_async(node: &'a Node<T>) -> Self {
29 Self {
30 node,
31 strategy: PhantomData,
32 }
33 }
34
35 macros_async::for_each!(&Node<T>, iter);
36 macros_async::map!(&Node<T>, iter);
37 macros_async::reduce!(&Node<T>, iter);
38 macros_async::cascade!(&Node<T>, iter);
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node;
    use std::sync::Mutex;

    #[tokio::test]
    async fn test_for_each() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        // The closure only borrows `result`, so a plain `Mutex` is sufficient: no `Arc`,
        // and no per-invocation `Arc::clone`, is needed.
        let result = Mutex::new(Vec::new());
        root.traverse()
            .into_async()
            .for_each(|n| result.lock().unwrap().push(n.value))
            .await;

        let got = result.lock().unwrap();
        // Five nodes in the tree: each must be visited exactly once.
        assert_eq!(got.len(), 5);
        assert!(got.contains(&40));
        assert!(got.contains(&20));
        assert!(got.contains(&50));
        assert!(got.contains(&30));
        // Children are processed before their parent, so the root comes last.
        assert_eq!(got.last(), Some(&10));
    }

    #[tokio::test]
    async fn test_map() {
        let original = node!(1, node!(2, node!(4)), node!(3, node!(5)));

        let copy = original.clone();
        let new_root = copy.traverse().into_async().map(|n| n.value % 2 == 0).await;
        // Mapping must not mutate the source tree.
        assert_eq!(original, copy);

        let want = node!(false, node!(true, node!(true)), node!(false, node!(false)));
        assert_eq!(new_root.take(), want);
    }

    #[tokio::test]
    async fn test_reduce() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let result = Mutex::new(Vec::new());
        let sum = root
            .traverse()
            .into_async()
            .reduce(|n, results| {
                result.lock().unwrap().push(n.value);
                n.value + results.iter().sum::<i32>()
            })
            .await;

        assert_eq!(sum, 150);

        let got = result.lock().unwrap();
        // The reducer runs once per node.
        assert_eq!(got.len(), 5);
        assert!(got.contains(&40));
        assert!(got.contains(&20));
        assert!(got.contains(&50));
        assert!(got.contains(&30));
        // The root is reduced only after all of its children.
        assert_eq!(got.last(), Some(&10));
    }

    #[tokio::test]
    async fn test_cascade() {
        let root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let result = Mutex::new(Vec::new());
        root.traverse()
            .into_async()
            .cascade(0, |n, parent_value| {
                let next = n.value + parent_value;
                result.lock().unwrap().push(next);
                next
            })
            .await;

        let got = result.lock().unwrap();
        // The cascade closure runs once per node.
        assert_eq!(got.len(), 5);
        // The root is processed first, seeded with the initial value 0.
        assert_eq!(got.first(), Some(&10));
        assert!(got.contains(&30));
        assert!(got.contains(&40));
        assert!(got.contains(&70));
        assert!(got.contains(&90));
    }
}