par_dfs/async/bfs.rs
1use super::{Node, NodeStream, StreamQueue};
2
3use futures::stream::{FuturesOrdered, Stream, StreamExt};
4use futures::FutureExt;
5use pin_project::pin_project;
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11#[allow(clippy::module_name_repetitions)]
12#[derive(Default)]
13#[pin_project]
14/// Asynchronous breadth-first stream for types implementing the [`Node`] trait.
15///
16/// ### Example
17/// ```
18/// use futures::StreamExt;
19/// use par_dfs::r#async::{Node, Bfs, NodeStream};
20///
21/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
22/// struct WordNode(String);
23///
24/// #[async_trait::async_trait]
25/// impl Node for WordNode {
26/// type Error = std::convert::Infallible;
27///
28/// async fn children(
29/// self: std::sync::Arc<Self>,
30/// _depth: usize
31/// ) -> Result<NodeStream<Self, Self::Error>, Self::Error> {
32/// let len = self.0.len();
33/// let nodes: Vec<String> = if len > 1 {
34/// let mid = len/2;
35/// vec![self.0[..mid].into(), self.0[mid..].into()]
36/// } else {
37/// assert!(len == 1);
38/// vec![self.0.clone()]
39/// };
40/// let nodes = nodes.into_iter()
41/// .map(Self)
42/// .map(Result::Ok);
43/// let stream = futures::stream::iter(nodes);
44/// Ok(Box::pin(stream.boxed()))
45/// }
46/// }
47///
48/// let result = tokio_test::block_on(async {
49/// let word = "Hello World";
50/// let root = WordNode(word.into());
51/// let limit = (word.len() as f32).log2().ceil() as usize;
52/// let bfs = Bfs::<WordNode>::new(root, limit, true);
53/// let output = bfs
54/// .collect::<Vec<_>>()
55/// .await
56/// .into_iter()
57/// .collect::<Result<Vec<_>, _>>()
58/// .unwrap();
59/// output[output.len()-word.len()..]
60/// .into_iter().map(|s| s.0.as_str()).collect::<String>()
61/// });
62/// assert_eq!(result, "Hello World");
63/// ```
64///
65/// [`Node`]: trait@crate::async::Node
66pub struct Bfs<N>
67where
68 N: Node,
69{
70 #[pin]
71 current_stream: Option<(usize, NodeStream<N, N::Error>)>,
72 child_streams_futs: StreamQueue<N, N::Error>,
73 max_depth: Option<usize>,
74 allow_circles: bool,
75 visited: HashSet<N>,
76}
77
78impl<N> Bfs<N>
79where
80 N: Node + Send + Unpin + Clone + 'static,
81 N::Error: Send + 'static,
82{
83 #[inline]
84 /// Creates a new [`Bfs`] stream.
85 ///
86 /// The BFS will be performed from the `root` node up to depth `max_depth`.
87 ///
88 /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
89 ///
90 /// [`Bfs`]: struct@crate::async::Bfs
91 pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
92 where
93 R: Into<N>,
94 D: Into<Option<usize>>,
95 {
96 let root = root.into();
97 let max_depth = max_depth.into();
98 let mut child_streams_futs: StreamQueue<N, N::Error> = FuturesOrdered::new();
99 let depth = 1;
100 let child_stream_fut = Arc::new(root.clone())
101 .children(depth)
102 .map(move |stream| (depth, stream));
103 child_streams_futs.push_back(Box::pin(child_stream_fut));
104
105 Self {
106 current_stream: None,
107 child_streams_futs,
108 max_depth,
109 visited: HashSet::from_iter([root]),
110 allow_circles,
111 }
112 }
113}
114
115impl<N> Stream for Bfs<N>
116where
117 N: Node + Send + Clone + Unpin + 'static,
118 N::Error: Send + 'static,
119{
120 type Item = Result<N, N::Error>;
121
122 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123 let mut this = self.project();
124
125 // println!("------- poll");
126 // println!("has current stream: {:?}", this.current_stream.is_some());
127
128 loop {
129 let mut current_stream = this.current_stream.as_mut().as_pin_mut();
130 let next_item = match current_stream.as_deref_mut() {
131 Some((depth, stream)) => {
132 let next_item = stream.as_mut().poll_next(cx);
133 Some(next_item.map(|node| (depth, node)))
134 }
135 None => None,
136 };
137
138 // println!("next item: {:?}", next_item);
139 match next_item {
140 // stream item is ready but failure success
141 Some(Poll::Ready((_, Some(Err(err))))) => {
142 return Poll::Ready(Some(Err(err)));
143 }
144 // stream item is ready and success
145 Some(Poll::Ready((depth, Some(Ok(node))))) => {
146 if *this.allow_circles || !this.visited.contains(&node) {
147 if !*this.allow_circles {
148 this.visited.insert(node.clone());
149 }
150
151 if let Some(max_depth) = this.max_depth {
152 if depth >= max_depth {
153 return Poll::Ready(Some(Ok(node)));
154 }
155 }
156
157 // add child stream future to be polled
158 let arc_node = Arc::new(node.clone());
159 let next_depth = *depth + 1;
160 let child_stream_fut = arc_node
161 .children(next_depth)
162 .map(move |stream| (next_depth, stream));
163 this.child_streams_futs
164 .push_back(Box::pin(child_stream_fut));
165
166 return Poll::Ready(Some(Ok(node)));
167 }
168 }
169 // stream item is pending
170 Some(Poll::Pending) => {
171 return Poll::Pending;
172 }
173 // no current stream or completed
174 Some(Poll::Ready((_, None))) | None => {
175 // proceed to poll the next stream
176 }
177 }
178
179 // poll the next stream
180 // println!("child stream futs: {:?}", this.child_streams_futs.len());
181 match this.child_streams_futs.poll_next_unpin(cx) {
182 Poll::Ready(Some((depth, stream))) => {
183 // println!(
184 // "child stream fut depth {} completed: {:?}",
185 // depth,
186 // stream.is_ok()
187 // );
188 let stream = match stream {
189 Ok(stream) => stream.boxed(),
190 Err(err) => futures::stream::iter([Err(err)]).boxed(),
191 };
192 this.current_stream.set(Some((depth, Box::pin(stream))));
193 }
194 // when there are no more child stream futures,
195 // we are done
196 Poll::Ready(None) => {
197 // println!("no more child streams");
198 return Poll::Ready(None);
199 }
200 // still waiting for the next stream
201 Poll::Pending => {
202 // println!("child stream is still pending");
203 return Poll::Pending;
204 }
205 }
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use futures::StreamExt;
    use pretty_assertions::assert_eq;
    use std::cmp::Ordering;
    use tokio::time::{sleep, Duration};

    /// Drains the stream bound to `$stream` and propagates the first error with `?`.
    /// Evaluates to the depth (field `.0`) of every yielded node, in yield order.
    macro_rules! depths {
        ($stream:ident) => {{
            let items = $stream.collect::<Vec<_>>().await;
            let nodes = items.into_iter().collect::<Result<Vec<_>, _>>()?;
            nodes.into_iter().map(|node| node.0).collect::<Vec<_>>()
        }};
    }

    /// Generates `test_<name>_unordered`.
    ///
    /// Every item is delayed by 100ms, and up to 8 delays run concurrently through
    /// `buffer_unordered`. The depths are compared with `test::assert_eq_vec!`
    /// (presumably order-insensitive; confirm in `crate::utils::test`).
    macro_rules! test_depths_unordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ unordered >] () -> Result<()> {
                    let (bfs, expected_depths) = $values;
                    let delayed = bfs
                        .map(|item| async move {
                            sleep(Duration::from_millis(100)).await;
                            item
                        })
                        .buffer_unordered(8);
                    let depths = depths!(delayed);
                    assert!(test::is_monotonic(&depths, Ordering::Greater));
                    test::assert_eq_vec!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Generates `test_<name>_ordered`.
    ///
    /// Every item is delayed by 100ms, and up to 8 delays run concurrently through
    /// `buffered`, which keeps the original order. The depths must therefore equal
    /// the expected sequence exactly.
    macro_rules! test_depths_ordered {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[tokio::test(flavor = "multi_thread")]
                async fn [< test_ $name _ ordered >] () -> Result<()> {
                    let (bfs, expected_depths) = $values;
                    let delayed = bfs
                        .map(|item| async move {
                            sleep(Duration::from_millis(100)).await;
                            item
                        })
                        .buffered(8);
                    let depths = depths!(delayed);
                    assert!(test::is_monotonic(&depths, Ordering::Greater));
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Expands each listed generator macro for the same test name and
    /// `(stream, expected_depths)` expression.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($generator:ident,)*) => {
            $(
                $generator!($name: $values);
            )*
        }
    }

    test_depths!(
        bfs:
        (
            Bfs::<test::Node>::new(0, 3, true),
            [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );

    test_depths!(
        bfs_no_circles:
        (
            Bfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_ordered,
        test_depths_unordered,
    );
}