par_dfs/async/
bfs.rs

1use super::{Node, NodeStream, StreamQueue};
2
3use futures::stream::{FuturesOrdered, Stream, StreamExt};
4use futures::FutureExt;
5use pin_project::pin_project;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11#[allow(clippy::module_name_repetitions)]
12#[derive(Default)]
13#[pin_project]
14/// Asynchronous breadth-first stream for types implementing the [`Node`] trait.
15///
16/// ### Example
17/// ```
18/// use futures::StreamExt;
19/// use par_dfs::r#async::{Node, Bfs, NodeStream};
20///
21/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
22/// struct WordNode(String);
23///
24/// #[async_trait::async_trait]
25/// impl Node for WordNode {
26///     type Error = std::convert::Infallible;
27///
28///     async fn children(
29///         self: std::sync::Arc<Self>,
30///         _depth: usize
31///     ) -> Result<NodeStream<Self, Self::Error>, Self::Error> {
32///         let len = self.0.len();
33///         let nodes: Vec<String> = if len > 1 {
34///             let mid = len/2;
35///             vec![self.0[..mid].into(), self.0[mid..].into()]
36///         } else {
37///             assert!(len == 1);
38///             vec![self.0.clone()]
39///         };
40///         let nodes = nodes.into_iter()
41///             .map(Self)
42///             .map(Result::Ok);
43///         let stream = futures::stream::iter(nodes);
44///         Ok(Box::pin(stream.boxed()))
45///     }
46/// }
47///
48/// let result = tokio_test::block_on(async {
49///     let word = "Hello World";
50///     let root = WordNode(word.into());
51///     let limit = (word.len() as f32).log2().ceil() as usize;
52///     let bfs = Bfs::<WordNode>::new(root, limit, true);
53///     let output = bfs
54///         .collect::<Vec<_>>()
55///         .await
56///         .into_iter()
57///         .collect::<Result<Vec<_>, _>>()
58///         .unwrap();
59///     output[output.len()-word.len()..]
60///         .into_iter().map(|s| s.0.as_str()).collect::<String>()
61/// });
62/// assert_eq!(result, "Hello World");
63/// ```
64///
65/// [`Node`]: trait@crate::async::Node
66pub struct Bfs<N>
67where
68    N: Node,
69{
70    #[pin]
71    current_stream: Option<(usize, NodeStream<N, N::Error>)>,
72    child_streams_futs: StreamQueue<N, N::Error>,
73    max_depth: Option<usize>,
74    allow_circles: bool,
75    visited: HashSet<N>,
76}
77
78impl<N> Bfs<N>
79where
80    N: Node + Send + Unpin + Clone + 'static,
81    N::Error: Send + 'static,
82{
83    #[inline]
84    /// Creates a new [`Bfs`] stream.
85    ///
86    /// The BFS will be performed from the `root` node up to depth `max_depth`.
87    ///
88    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
89    ///
90    /// [`Bfs`]: struct@crate::async::Bfs
91    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
92    where
93        R: Into<N>,
94        D: Into<Option<usize>>,
95    {
96        let root = root.into();
97        let max_depth = max_depth.into();
98        let mut child_streams_futs: StreamQueue<N, N::Error> = FuturesOrdered::new();
99        let depth = 1;
100        let child_stream_fut = Arc::new(root.clone())
101            .children(depth)
102            .map(move |stream| (depth, stream));
103        child_streams_futs.push_back(Box::pin(child_stream_fut));
104
105        Self {
106            current_stream: None,
107            child_streams_futs,
108            max_depth,
109            visited: HashSet::from_iter([root]),
110            allow_circles,
111        }
112    }
113}
114
115impl<N> Stream for Bfs<N>
116where
117    N: Node + Send + Clone + Unpin + 'static,
118    N::Error: Send + 'static,
119{
120    type Item = Result<N, N::Error>;
121
122    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123        let mut this = self.project();
124
125        // println!("------- poll");
126        // println!("has current stream: {:?}", this.current_stream.is_some());
127
128        loop {
129            let mut current_stream = this.current_stream.as_mut().as_pin_mut();
130            let next_item = match current_stream.as_deref_mut() {
131                Some((depth, stream)) => {
132                    let next_item = stream.as_mut().poll_next(cx);
133                    Some(next_item.map(|node| (depth, node)))
134                }
135                None => None,
136            };
137
138            // println!("next item: {:?}", next_item);
139            match next_item {
140                // stream item is ready but failure success
141                Some(Poll::Ready((_, Some(Err(err))))) => {
142                    return Poll::Ready(Some(Err(err)));
143                }
144                // stream item is ready and success
145                Some(Poll::Ready((depth, Some(Ok(node))))) => {
146                    if *this.allow_circles || !this.visited.contains(&node) {
147                        if !*this.allow_circles {
148                            this.visited.insert(node.clone());
149                        }
150
151                        if let Some(max_depth) = this.max_depth {
152                            if depth >= max_depth {
153                                return Poll::Ready(Some(Ok(node)));
154                            }
155                        }
156
157                        // add child stream future to be polled
158                        let arc_node = Arc::new(node.clone());
159                        let next_depth = *depth + 1;
160                        let child_stream_fut = arc_node
161                            .children(next_depth)
162                            .map(move |stream| (next_depth, stream));
163                        this.child_streams_futs
164                            .push_back(Box::pin(child_stream_fut));
165
166                        return Poll::Ready(Some(Ok(node)));
167                    }
168                }
169                // stream item is pending
170                Some(Poll::Pending) => {
171                    return Poll::Pending;
172                }
173                // no current stream or completed
174                Some(Poll::Ready((_, None))) | None => {
175                    // proceed to poll the next stream
176                }
177            }
178
179            // poll the next stream
180            // println!("child stream futs: {:?}", this.child_streams_futs.len());
181            match this.child_streams_futs.poll_next_unpin(cx) {
182                Poll::Ready(Some((depth, stream))) => {
183                    // println!(
184                    //     "child stream fut depth {} completed: {:?}",
185                    //     depth,
186                    //     stream.is_ok()
187                    // );
188                    let stream = match stream {
189                        Ok(stream) => stream.boxed(),
190                        Err(err) => futures::stream::iter([Err(err)]).boxed(),
191                    };
192                    this.current_stream.set(Some((depth, Box::pin(stream))));
193                }
194                // when there are no more child stream futures,
195                // we are done
196                Poll::Ready(None) => {
197                    // println!("no more child streams");
198                    return Poll::Ready(None);
199                }
200                // still waiting for the next stream
201                Poll::Pending => {
202                    // println!("child stream is still pending");
203                    return Poll::Pending;
204                }
205            }
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use futures::StreamExt;
    use pretty_assertions::assert_eq;
    use std::cmp::Ordering;
    use tokio::time::{sleep, Duration};

    /// Drains the given stream and evaluates to the first tuple field of every
    /// yielded node (used as its depth by the tests below).
    ///
    /// The whole stream is collected before errors are inspected; the first
    /// error is propagated out of the enclosing function with `?`.
    macro_rules! depths {
        ($stream:ident) => {{
            let items = $stream.collect::<Vec<_>>().await;
            let nodes = items.into_iter().collect::<Result<Vec<_>, _>>()?;
            nodes.into_iter().map(|node| node.0).collect::<Vec<_>>()
        }};
    }

    /// Generates `test_<name>_unordered`: every item is delayed by 100ms and
    /// up to 8 delays run concurrently via `buffer_unordered`, after which the
    /// depths are compared using `test::assert_eq_vec!`.
    macro_rules! test_depths_unordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ unordered >] () -> Result<()> {
                    let (bfs, expected_depths) = $values;
                    let delayed = bfs
                        .map(|node| async move {
                            sleep(Duration::from_millis(100)).await;
                            node
                        })
                        .buffer_unordered(8);
                    let depths = depths!(delayed);
                    assert!(test::is_monotonic(&depths, Ordering::Greater));
                    test::assert_eq_vec!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Generates `test_<name>_ordered`: every item is delayed by 100ms and up
    /// to 8 delays run concurrently via `buffered`, after which the depths
    /// must equal the expected sequence exactly.
    macro_rules! test_depths_ordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ ordered >] () -> Result<()> {
                    let (bfs, expected_depths) = $values;
                    let delayed = bfs
                        .map(|node| async move {
                            sleep(Duration::from_millis(100)).await;
                            node
                        })
                        .buffered(8);
                    let depths = depths!(delayed);
                    assert!(test::is_monotonic(&depths, Ordering::Greater));
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Expands each listed test-generating macro with the same name and values.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($macro:ident,)*) => {
            $(
                $macro!($name: $values);
            )*
        }
    }

    // circles allowed: duplicate nodes are yielded
    test_depths!(
        bfs:
        (
            Bfs::<test::Node>::new(0, 3, true),
            [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );

    // circles disallowed: already visited nodes are skipped
    test_depths!(
        bfs_no_circles:
        (
            Bfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );
}
303        test_depths_unordered,
304    );
305}