ntree_rs/traversal/traverse_mut/
async.rs1use crate::{
4 traversal::{macros_async, TraverseMut},
5 Asynchronous, Node, Synchronous,
6};
7use async_recursion::async_recursion;
8use futures::future::join_all;
9use std::marker::PhantomData;
10
11impl<'a, T> TraverseMut<'a, T, Asynchronous> {
12 pub fn into_sync(self) -> TraverseMut<'a, T, Synchronous> {
14 self.into()
15 }
16}
17
18impl<'a, T: Sync + Send + 'a> TraverseMut<'a, T, Asynchronous> {
19 pub(crate) fn new_async(node: &'a mut Node<T>) -> Self {
20 Self {
21 node,
22 strategy: PhantomData,
23 }
24 }
25
26 macros_async::for_each!(&mut Node<T>, iter_mut);
27 macros_async::map!(&mut Node<T>, iter_mut);
28 macros_async::reduce!(&mut Node<T>, iter_mut);
29 macros_async::cascade!(&mut Node<T>, iter_mut);
30}
31
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node;
    use std::sync::{Arc, Mutex};

    /// `for_each` mutates every node exactly once and visits the root last.
    #[tokio::test]
    async fn test_for_each() {
        let mut root = node!(10_i32, node!(20, node!(40)), node!(30, node!(50)));

        let result = Arc::new(Mutex::new(Vec::new()));
        root.traverse_mut()
            .into_async()
            .for_each(|n| {
                n.value = n.value.saturating_add(1);
                // Borrow the shared buffer; no need to bump the Arc refcount per call.
                result.lock().unwrap().push(n.value);
            })
            .await;

        // Copy the values out so the lock is released before asserting.
        let mut got = result.lock().unwrap().clone();
        // Order among descendants is not asserted, only that the root comes last.
        assert_eq!(got.last(), Some(&11));
        // Sorted exact comparison catches missing or duplicated visits.
        got.sort_unstable();
        assert_eq!(got, vec![11, 21, 31, 41, 51]);
    }

    /// `map` mutates the original tree in place and builds a mirrored tree of results.
    #[tokio::test]
    async fn test_map() {
        let mut original = node!(1, node!(2, node!(4)), node!(3, node!(5)));
        let new_root = original
            .traverse_mut()
            .into_async()
            .map(|n| {
                n.value += 1;
                n.value % 2 == 0
            })
            .await;

        let want = node!(2, node!(3, node!(5)), node!(4, node!(6)));
        assert_eq!(original, want);

        let want = node!(true, node!(false, node!(false)), node!(true, node!(true)));
        assert_eq!(new_root.take(), want);
    }

    /// `reduce` folds children's results into the parent and visits the root last.
    #[tokio::test]
    async fn test_reduce() {
        let mut root = node!(10_i32, node!(20, node!(40)), node!(30, node!(50)));

        let result = Arc::new(Mutex::new(Vec::new()));
        let sum = root
            .traverse_mut()
            .into_async()
            .reduce(|n, results| {
                n.value = n.value.saturating_add(1);
                result.lock().unwrap().push(n.value);
                n.value + results.iter().sum::<i32>()
            })
            .await;

        // 11 + 21 + 41 + 31 + 51
        assert_eq!(sum, 155);

        let mut got = result.lock().unwrap().clone();
        assert_eq!(got.last(), Some(&11));
        got.sort_unstable();
        assert_eq!(got, vec![11, 21, 31, 41, 51]);
    }

    /// `cascade` passes each node's output down to its children, visiting the root first.
    #[tokio::test]
    async fn test_cascade() {
        let mut root = node!(10, node!(20, node!(40)), node!(30, node!(50)));

        let result = Arc::new(Mutex::new(Vec::new()));
        root.traverse_mut()
            .into_async()
            .cascade(0, |n, parent_value| {
                let next = n.value + parent_value;
                result.lock().unwrap().push(next);
                n.value = *parent_value;
                next
            })
            .await;

        assert_eq!(root.value, 0);
        assert_eq!(root.children[0].value, 10);
        assert_eq!(root.children[1].value, 10);
        assert_eq!(root.children[0].children[0].value, 30);
        assert_eq!(root.children[1].children[0].value, 40);

        let mut got = result.lock().unwrap().clone();
        assert_eq!(got.first(), Some(&10));
        got.sort_unstable();
        assert_eq!(got, vec![10, 30, 40, 70, 90]);
    }
}