// par_dfs/async/dfs.rs

1use super::{Node, Stack, StreamQueue};
2
3use futures::stream::{FuturesOrdered, Stream, StreamExt};
4use futures::FutureExt;
5use pin_project::pin_project;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11#[derive(Default)]
12#[pin_project]
13/// Asynchronous depth-first stream for types implementing the [`Node`] trait.
14///
15/// ### Example
16/// ```
17/// use futures::StreamExt;
18/// use par_dfs::r#async::{Node, Dfs, NodeStream};
19///
20/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
21/// struct WordNode(String);
22///
23/// #[async_trait::async_trait]
24/// impl Node for WordNode {
25///     type Error = std::convert::Infallible;
26///
27///     async fn children(
28///         self: std::sync::Arc<Self>,
29///         _depth: usize
30///     ) -> Result<NodeStream<Self, Self::Error>, Self::Error> {
31///         let len = self.0.len();
32///         let nodes: Vec<String> = if len < 2 {
33///             vec![]
34///         } else {
35///             let mid = len/2;
36///             vec![self.0[..mid].into(), self.0[mid..].into()]
37///         };
38///         let nodes = nodes.into_iter()
39///             .map(Self)
40///             .map(Result::Ok);
41///         let stream = futures::stream::iter(nodes);
42///         Ok(Box::pin(stream.boxed()))
43///     }
44/// }
45///
46/// let result = tokio_test::block_on(async {
47///     let root = WordNode("Hello World".into());
48///     let dfs = Dfs::<WordNode>::new(root, None, true);
49///     let output = dfs
50///         .collect::<Vec<_>>()
51///         .await
52///         .into_iter()
53///         .collect::<Result<Vec<_>, _>>()
54///         .unwrap();
55///     output.into_iter()
56///         .filter_map(|s| if s.0.len() == 1 { Some(s.0) } else { None })
57///         .collect::<String>()
58/// });
59/// assert_eq!(result, "Hello World");
60/// ```
61///
62/// [`Node`]: trait@crate::async::Node
63pub struct Dfs<N>
64where
65    N: Node,
66{
67    stack: Stack<N, N::Error>,
68    child_streams_futs: StreamQueue<N, N::Error>,
69    max_depth: Option<usize>,
70    allow_circles: bool,
71    visited: HashSet<N>,
72}
73
74impl<N> Dfs<N>
75where
76    N: Node + Send + Unpin + Clone + 'static,
77    N::Error: Send + 'static,
78{
79    #[inline]
80    /// Creates a new [`Dfs`] stream.
81    ///
82    /// The DFS will be performed from the `root` node up to depth `max_depth`.
83    ///
84    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
85    ///
86    /// [`Dfs`]: struct@crate::async::Dfs
87    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
88    where
89        R: Into<N>,
90        D: Into<Option<usize>>,
91    {
92        let root = root.into();
93        let max_depth = max_depth.into();
94        let mut child_streams_futs: StreamQueue<N, N::Error> = FuturesOrdered::new();
95        let depth = 1;
96        let child_stream_fut = Arc::new(root.clone())
97            .children(depth)
98            .map(move |stream| (depth, stream));
99        child_streams_futs.push_front(Box::pin(child_stream_fut));
100
101        Self {
102            stack: vec![],
103            child_streams_futs,
104            max_depth,
105            visited: HashSet::from_iter([root]),
106            allow_circles,
107        }
108    }
109}
110
111impl<N> Stream for Dfs<N>
112where
113    N: Node + Send + Clone + Unpin + 'static,
114    N::Error: Send + 'static,
115{
116    type Item = Result<N, N::Error>;
117
118    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        let this = self.project();
120
121        // println!("------- poll");
122        // println!("stack size: {:?}", this.stack.len());
123
124        // we first poll for the newest child stream in dfs
125        // println!("child stream futs: {:?}", this.child_streams_futs.len());
126        match this.child_streams_futs.poll_next_unpin(cx) {
127            Poll::Ready(Some((depth, stream))) => {
128                // println!(
129                //     "child stream fut depth {} completed: {:?}",
130                //     depth,
131                //     stream.is_ok()
132                // );
133                let stream = match stream {
134                    Ok(stream) => stream.boxed(),
135                    Err(err) => futures::stream::iter([Err(err)]).boxed(),
136                };
137                this.stack.push((depth, Box::pin(stream)));
138                // println!("stack size: {}", this.stack.len());
139            }
140            // when there is no child stream future,
141            // continue to poll the current stream
142            Poll::Ready(None) => {
143                // println!("no child stream to wait for");
144            }
145            // still waiting for the new child stream
146            Poll::Pending => {
147                // println!("child stream is still pending");
148                return Poll::Pending;
149            }
150        }
151
152        // at this point, the last element in the stack is the current level
153        loop {
154            let next_item = match this.stack.last_mut() {
155                Some((depth, current_stream)) => {
156                    let next_item = current_stream.as_mut().poll_next(cx);
157                    Some(next_item.map(|node| (depth, node)))
158                }
159                None => None,
160            };
161
162            // println!("next item: {:?}", next_item);
163            match next_item {
164                // stream item is ready but failure success
165                Some(Poll::Ready((_, Some(Err(err))))) => {
166                    return Poll::Ready(Some(Err(err)));
167                }
168                // stream item is ready and success
169                Some(Poll::Ready((depth, Some(Ok(node))))) => {
170                    if *this.allow_circles || !this.visited.contains(&node) {
171                        if !*this.allow_circles {
172                            this.visited.insert(node.clone());
173                        }
174
175                        if let Some(max_depth) = this.max_depth {
176                            if depth >= max_depth {
177                                return Poll::Ready(Some(Ok(node)));
178                            }
179                        }
180
181                        // add child stream future to be polled
182                        let arc_node = Arc::new(node.clone());
183                        let next_depth = *depth + 1;
184                        let child_stream_fut = arc_node
185                            .children(next_depth)
186                            .map(move |stream| (next_depth, stream));
187                        this.child_streams_futs
188                            .push_front(Box::pin(child_stream_fut));
189
190                        return Poll::Ready(Some(Ok(node)));
191                    }
192                }
193                // stream completed for this level completed
194                Some(Poll::Ready((_, None))) => {
195                    this.stack.pop();
196                    // println!("pop stack to size: {}", this.stack.len());
197                    // try again in the next round
198                    // returning Poll::Pending here is bad because the runtime can not know when to poll
199                    // us again to make progress since we never passed the cx to poll of the next
200                    // level stream
201                }
202                // stream item is pending
203                Some(Poll::Pending) => {
204                    return Poll::Pending;
205                }
206                // stack is empty and we are done
207                None => {
208                    return Poll::Ready(None);
209                }
210            }
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use futures::StreamExt;
    use pretty_assertions::assert_eq;
    use tokio::time::{sleep, Duration};

    // Drains `$stream`, bails out with `?` on the first `Err` item, and returns
    // the depth (field `.0`) of every node in the order it was received.
    macro_rules! depths {
        ($stream:ident) => {{
            let results = $stream.collect::<Vec<_>>().await;
            let nodes = results.into_iter().collect::<Result<Vec<_>, _>>()?;
            nodes.into_iter().map(|item| item.0).collect::<Vec<_>>()
        }};
    }

    // Generates the test `test_<name>_<suffix>`. Every item of the traversal is
    // delayed by 100ms, and up to 8 delays run concurrently through the
    // `$buffer` combinator (`buffered` or `buffer_unordered`). The resulting
    // depth sequence must equal the expected one.
    macro_rules! test_depths_buffered {
        ($name:ident, $suffix:ident, $buffer:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ $suffix >] () -> Result<()> {
                    let (iter, expected_depths) = $values;
                    let iter = iter
                        .map(|node| async move {
                            sleep(Duration::from_millis(100)).await;
                            node
                        })
                        .$buffer(8);
                    let depths = depths!(iter);
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    macro_rules! test_depths_unordered {
        ($name:ident: $values:expr) => {
            test_depths_buffered!($name, unordered, buffer_unordered: $values);
        };
    }

    macro_rules! test_depths_ordered {
        ($name:ident: $values:expr) => {
            test_depths_buffered!($name, ordered, buffered: $values);
        };
    }

    // Applies every listed generator macro to the same `name: (stream, expected)` case.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($macro:ident,)*) => {
            $(
                $macro!($name: $values);
            )*
        }
    }

    // With circles allowed, each node has two children (see the expected
    // depths) and the tree is fully expanded down to depth 3.
    test_depths!(
        dfs:
        (
            Dfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );

    // With circles disallowed, only one node per depth survives.
    // NOTE(review): presumably test nodes at the same depth compare equal — confirm.
    test_depths!(
        dfs_no_circles:
        (
            Dfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );
}