par_dfs/async/dfs.rs
1use super::{Node, Stack, StreamQueue};
2
3use futures::stream::{FuturesOrdered, Stream, StreamExt};
4use futures::FutureExt;
5use pin_project::pin_project;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11#[derive(Default)]
12#[pin_project]
13/// Asynchronous depth-first stream for types implementing the [`Node`] trait.
14///
15/// ### Example
16/// ```
17/// use futures::StreamExt;
18/// use par_dfs::r#async::{Node, Dfs, NodeStream};
19///
20/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
21/// struct WordNode(String);
22///
23/// #[async_trait::async_trait]
24/// impl Node for WordNode {
25/// type Error = std::convert::Infallible;
26///
27/// async fn children(
28/// self: std::sync::Arc<Self>,
29/// _depth: usize
30/// ) -> Result<NodeStream<Self, Self::Error>, Self::Error> {
31/// let len = self.0.len();
32/// let nodes: Vec<String> = if len < 2 {
33/// vec![]
34/// } else {
35/// let mid = len/2;
36/// vec![self.0[..mid].into(), self.0[mid..].into()]
37/// };
38/// let nodes = nodes.into_iter()
39/// .map(Self)
40/// .map(Result::Ok);
41/// let stream = futures::stream::iter(nodes);
42/// Ok(Box::pin(stream.boxed()))
43/// }
44/// }
45///
46/// let result = tokio_test::block_on(async {
47/// let root = WordNode("Hello World".into());
48/// let dfs = Dfs::<WordNode>::new(root, None, true);
49/// let output = dfs
50/// .collect::<Vec<_>>()
51/// .await
52/// .into_iter()
53/// .collect::<Result<Vec<_>, _>>()
54/// .unwrap();
55/// output.into_iter()
56/// .filter_map(|s| if s.0.len() == 1 { Some(s.0) } else { None })
57/// .collect::<String>()
58/// });
59/// assert_eq!(result, "Hello World");
60/// ```
61///
62/// [`Node`]: trait@crate::async::Node
63pub struct Dfs<N>
64where
65 N: Node,
66{
67 stack: Stack<N, N::Error>,
68 child_streams_futs: StreamQueue<N, N::Error>,
69 max_depth: Option<usize>,
70 allow_circles: bool,
71 visited: HashSet<N>,
72}
73
74impl<N> Dfs<N>
75where
76 N: Node + Send + Unpin + Clone + 'static,
77 N::Error: Send + 'static,
78{
79 #[inline]
80 /// Creates a new [`Dfs`] stream.
81 ///
82 /// The DFS will be performed from the `root` node up to depth `max_depth`.
83 ///
84 /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
85 ///
86 /// [`Dfs`]: struct@crate::async::Dfs
87 pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
88 where
89 R: Into<N>,
90 D: Into<Option<usize>>,
91 {
92 let root = root.into();
93 let max_depth = max_depth.into();
94 let mut child_streams_futs: StreamQueue<N, N::Error> = FuturesOrdered::new();
95 let depth = 1;
96 let child_stream_fut = Arc::new(root.clone())
97 .children(depth)
98 .map(move |stream| (depth, stream));
99 child_streams_futs.push_front(Box::pin(child_stream_fut));
100
101 Self {
102 stack: vec![],
103 child_streams_futs,
104 max_depth,
105 visited: HashSet::from_iter([root]),
106 allow_circles,
107 }
108 }
109}
110
111impl<N> Stream for Dfs<N>
112where
113 N: Node + Send + Clone + Unpin + 'static,
114 N::Error: Send + 'static,
115{
116 type Item = Result<N, N::Error>;
117
118 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 let this = self.project();
120
121 // println!("------- poll");
122 // println!("stack size: {:?}", this.stack.len());
123
124 // we first poll for the newest child stream in dfs
125 // println!("child stream futs: {:?}", this.child_streams_futs.len());
126 match this.child_streams_futs.poll_next_unpin(cx) {
127 Poll::Ready(Some((depth, stream))) => {
128 // println!(
129 // "child stream fut depth {} completed: {:?}",
130 // depth,
131 // stream.is_ok()
132 // );
133 let stream = match stream {
134 Ok(stream) => stream.boxed(),
135 Err(err) => futures::stream::iter([Err(err)]).boxed(),
136 };
137 this.stack.push((depth, Box::pin(stream)));
138 // println!("stack size: {}", this.stack.len());
139 }
140 // when there is no child stream future,
141 // continue to poll the current stream
142 Poll::Ready(None) => {
143 // println!("no child stream to wait for");
144 }
145 // still waiting for the new child stream
146 Poll::Pending => {
147 // println!("child stream is still pending");
148 return Poll::Pending;
149 }
150 }
151
152 // at this point, the last element in the stack is the current level
153 loop {
154 let next_item = match this.stack.last_mut() {
155 Some((depth, current_stream)) => {
156 let next_item = current_stream.as_mut().poll_next(cx);
157 Some(next_item.map(|node| (depth, node)))
158 }
159 None => None,
160 };
161
162 // println!("next item: {:?}", next_item);
163 match next_item {
164 // stream item is ready but failure success
165 Some(Poll::Ready((_, Some(Err(err))))) => {
166 return Poll::Ready(Some(Err(err)));
167 }
168 // stream item is ready and success
169 Some(Poll::Ready((depth, Some(Ok(node))))) => {
170 if *this.allow_circles || !this.visited.contains(&node) {
171 if !*this.allow_circles {
172 this.visited.insert(node.clone());
173 }
174
175 if let Some(max_depth) = this.max_depth {
176 if depth >= max_depth {
177 return Poll::Ready(Some(Ok(node)));
178 }
179 }
180
181 // add child stream future to be polled
182 let arc_node = Arc::new(node.clone());
183 let next_depth = *depth + 1;
184 let child_stream_fut = arc_node
185 .children(next_depth)
186 .map(move |stream| (next_depth, stream));
187 this.child_streams_futs
188 .push_front(Box::pin(child_stream_fut));
189
190 return Poll::Ready(Some(Ok(node)));
191 }
192 }
193 // stream completed for this level completed
194 Some(Poll::Ready((_, None))) => {
195 this.stack.pop();
196 // println!("pop stack to size: {}", this.stack.len());
197 // try again in the next round
198 // returning Poll::Pending here is bad because the runtime can not know when to poll
199 // us again to make progress since we never passed the cx to poll of the next
200 // level stream
201 }
202 // stream item is pending
203 Some(Poll::Pending) => {
204 return Poll::Pending;
205 }
206 // stack is empty and we are done
207 None => {
208 return Poll::Ready(None);
209 }
210 }
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use futures::StreamExt;
    use pretty_assertions::assert_eq;
    use tokio::time::{sleep, Duration};

    // Drains `$stream`, returns early from the enclosing test on the first `Err`
    // item, and otherwise evaluates to the `.0` field (the depth) of every item, in
    // the order the items were received.
    macro_rules! depths {
        ($stream:ident) => {{
            $stream
                // collect the entire stream
                .collect::<Vec<_>>()
                .await
                .into_iter()
                // fail on first error
                .collect::<Result<Vec<_>, _>>()?
                .into_iter()
                // get depth
                .map(|item| item.0)
                .collect::<Vec<_>>()
        }};
    }

    // Generates `test_<name>_unordered`: each item is delayed and the stream is
    // consumed through `buffer_unordered`. That combinator yields items in
    // completion order, which is not guaranteed to match the DFS emission order.
    // The depths are therefore compared as multisets; the exact order is pinned by
    // the `_ordered` variant.
    macro_rules! test_depths_unordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ unordered >] () -> Result<()> {
                    let (iter, expected_depths) = $values;
                    let iter = iter
                        .map(|node| async move {
                            sleep(Duration::from_millis(100)).await;
                            node
                        })
                        .buffer_unordered(8);
                    let mut depths = depths!(iter);
                    let mut expected_depths = expected_depths.to_vec();
                    depths.sort_unstable();
                    expected_depths.sort_unstable();
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    // Generates `test_<name>_ordered`: each item is delayed and the stream is
    // consumed through `buffered`, which preserves the stream order, so the exact
    // DFS emission order is asserted.
    macro_rules! test_depths_ordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ ordered >] () -> Result<()> {
                    let (iter, expected_depths) = $values;
                    let iter = iter
                        .map(|node| async move {
                            sleep(Duration::from_millis(100)).await;
                            node
                        })
                        .buffered(8);
                    let depths = depths!(iter);
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    // Instantiates every listed test-generating macro for the same
    // `(stream, expected_depths)` pair.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($macro:ident,)*) => {
            $(
                $macro!($name: $values);
            )*
        }
    }

    test_depths!(
        dfs:
        (
            Dfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );

    test_depths!(
        dfs_no_circles:
        (
            Dfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );
}