// par_dfs/sync/bfs.rs

1use super::queue;
2use super::{ExtendQueue, FastNode, Node, Queue};
3use std::iter::Iterator;
4
5#[allow(clippy::module_name_repetitions)]
6#[derive(Debug, Clone)]
7/// Synchronous breadth-first iterator for types implementing the [`Node`] trait.
8///
9/// ### Example
10/// ```
11/// use par_dfs::sync::{Node, Bfs, NodeIter};
12///
13/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
14/// struct WordNode(String);
15///
16/// impl Node for WordNode {
17///     type Error = std::convert::Infallible;
18///
19///     fn children(&self, _depth: usize) -> NodeIter<Self, Self::Error> {
20///         let len = self.0.len();
21///         let nodes: Vec<String> = if len > 1 {
22///             let mid = len/2;
23///             vec![self.0[..mid].into(), self.0[mid..].into()]
24///         } else {
25///             assert!(len == 1);
26///             vec![self.0.clone()]
27///         };
28///         let nodes = nodes.into_iter()
29///             .map(Self)
30///             .map(Result::Ok);
31///         Ok(Box::new(nodes))
32///     }
33/// }
34///
35/// let word = "Hello World";
36/// let root = WordNode(word.into());
37/// let limit = (word.len() as f32).log2().ceil() as usize;
38///
39/// let bfs = Bfs::<WordNode>::new(root, limit, true);
40/// let output = bfs.collect::<Result<Vec<_>, _>>().unwrap();
41/// let result = output[output.len()-word.len()..]
42///     .into_iter().map(|s| s.0.as_str()).collect::<String>();
43/// assert_eq!(result, "Hello World");
44/// ```
45///
46/// [`Node`]: trait@crate::sync::Node
47pub struct Bfs<N>
48where
49    N: Node,
50{
51    queue: queue::Queue<N, N::Error>,
52    max_depth: Option<usize>,
53}
54
55impl<N> Bfs<N>
56where
57    N: Node,
58{
59    #[inline]
60    /// Creates a new [`Bfs`] iterator.
61    ///
62    /// The BFS will be performed from the `root` node up to depth `max_depth`.
63    ///
64    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
65    ///
66    /// [`Bfs`]: struct@crate::sync::Bfs
67    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
68    where
69        R: Into<N>,
70        D: Into<Option<usize>>,
71    {
72        let mut queue = queue::Queue::new(allow_circles);
73        let root = root.into();
74        let max_depth = max_depth.into();
75
76        let depth = 1;
77        match root.children(depth) {
78            Ok(children) => queue.add_all(depth, children),
79            Err(err) => queue.add(0, Err(err)),
80        }
81
82        Self { queue, max_depth }
83    }
84}
85
86impl<N> Iterator for Bfs<N>
87where
88    N: Node,
89{
90    type Item = Result<N, N::Error>;
91
92    #[inline]
93    fn next(&mut self) -> Option<Self::Item> {
94        match self.queue.pop_front() {
95            // next node failed
96            Some((_, Err(err))) => Some(Err(err)),
97            // next node succeeded
98            Some((depth, Ok(node))) => {
99                if let Some(max_depth) = self.max_depth {
100                    if depth >= max_depth {
101                        return Some(Ok(node));
102                    }
103                }
104                match node.children(depth + 1) {
105                    Ok(children) => {
106                        self.queue.add_all(depth + 1, children);
107                    }
108                    Err(err) => self.queue.add(depth + 1, Err(err)),
109                };
110                Some(Ok(node))
111            }
112            // no next node
113            None => None,
114        }
115    }
116}
117
118#[allow(clippy::module_name_repetitions)]
119#[derive(Debug, Clone)]
120/// Synchronous, fast breadth-first iterator for types implementing the [`FastNode`] trait.
121///
122/// ### Example
123/// ```
124/// use par_dfs::sync::{FastNode, FastBfs, ExtendQueue, NodeIter};
125///
126/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
127/// struct WordNode(String);
128///
129/// impl FastNode for WordNode {
130///     type Error = std::convert::Infallible;
131///
132///     fn add_children<E>(
133///         &self, _depth: usize, queue: &mut E
134///     ) -> Result<(), Self::Error>
135///     where
136///         E: ExtendQueue<Self, Self::Error>,
137///     {
138///         let len = self.0.len();
139///         if len > 1 {
140///             let mid = len/2;
141///             queue.add(Ok(Self(self.0[..mid].into())));
142///             queue.add(Ok(Self(self.0[mid..].into())));
143///         } else {
144///             assert!(len == 1);
145///             queue.add(Ok(Self(self.0.clone())));
146///         }
147///         Ok(())
148///     }
149/// }
150///
151/// let word = "Hello World";
152/// let root = WordNode(word.into());
153/// let limit = (word.len() as f32).log2().ceil() as usize;
154///
155/// let bfs = FastBfs::<WordNode>::new(root, limit, true);
156/// let output = bfs.collect::<Result<Vec<_>, _>>().unwrap();
157/// let result = output[output.len()-word.len()..]
158///     .into_iter().map(|s| s.0.as_str()).collect::<String>();
159/// assert_eq!(result, "Hello World");
160/// ```
161///
162/// [`FastNode`]: trait@crate::sync::FastNode
163pub struct FastBfs<N>
164where
165    N: FastNode,
166{
167    queue: queue::Queue<N, N::Error>,
168    max_depth: Option<usize>,
169}
170
171impl<N> FastBfs<N>
172where
173    N: FastNode,
174{
175    #[inline]
176    /// Creates a new [`FastBfs`] iterator.
177    ///
178    /// The BFS will be performed from the `root` node up to depth `max_depth`.
179    ///
180    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
181    ///
182    /// [`FastBfs`]: struct@crate::sync::FastBfs
183    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
184    where
185        R: Into<N>,
186        D: Into<Option<usize>>,
187    {
188        let mut queue = queue::Queue::new(allow_circles);
189        let root: N = root.into();
190        let max_depth = max_depth.into();
191        let depth = 1;
192        let mut depth_queue = queue::QueueWrapper::new(depth, &mut queue);
193        if let Err(err) = root.add_children(depth, &mut depth_queue) {
194            depth_queue.add(Err(err));
195        }
196        Self { queue, max_depth }
197    }
198}
199
200impl<N> Iterator for FastBfs<N>
201where
202    N: FastNode,
203{
204    type Item = Result<N, N::Error>;
205
206    #[inline]
207    fn next(&mut self) -> Option<Self::Item> {
208        match self.queue.pop_front() {
209            // next node failed
210            Some((_, Err(err))) => Some(Err(err)),
211            // next node succeeded
212            Some((depth, Ok(node))) => {
213                if let Some(max_depth) = self.max_depth {
214                    if depth >= max_depth {
215                        return Some(Ok(node));
216                    }
217                }
218                let next_depth = depth + 1;
219                let mut depth_queue = queue::QueueWrapper::new(next_depth, &mut self.queue);
220                if let Err(err) = node.add_children(next_depth, &mut depth_queue) {
221                    depth_queue.add(Err(err));
222                }
223                Some(Ok(node))
224            }
225            // no next node
226            None => None,
227        }
228    }
229}
230
// Parallel iteration support for `Bfs` and `FastBfs`, generated by the
// crate's `parallel_iterator!` macro; only compiled with the `rayon` feature.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
mod par {
    use crate::sync::par::parallel_iterator;
    use crate::sync::{Bfs, Node};
    use crate::sync::{FastBfs, FastNode};

    // One invocation per iterator type, paired with the node trait it is generic over.
    parallel_iterator!(Bfs<Node>);
    parallel_iterator!(FastBfs<FastNode>);
}

// Re-export whatever public items the macro expansions define.
#[cfg(feature = "rayon")]
pub use par::*;
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use pretty_assertions::assert_eq;
    use std::cmp::Ordering;

    #[cfg(feature = "rayon")]
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    /// Collects a (serial or parallel) iterator of node results into the list
    /// of node depths (field `.0`), returning early from the enclosing test on
    /// the first error.
    macro_rules! depths {
        ($iter:ident) => {{
            let nodes = $iter.collect::<Result<Vec<_>, _>>()?;
            nodes.into_iter().map(|node| node.0).collect::<Vec<_>>()
        }};
    }

    /// Generates `test_<name>_serial`: the depth sequence must pass
    /// `test::is_monotonic` and match the expected depths exactly.
    macro_rules! test_depths_serial {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[test]
                fn [< test_ $name _ serial >] () -> Result<()> {
                    let (it, expected) = $values;
                    let got = depths!(it);
                    assert!(test::is_monotonic(&got, Ordering::Greater));
                    assert_eq!(got, expected);
                    Ok(())
                }
            }
        };
    }

    /// Generates `test_<name>_parallel` (only with the `rayon` feature): the
    /// depths from the parallel iterator are checked with `test::assert_eq_vec!`.
    macro_rules! test_depths_parallel {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[cfg(feature = "rayon")]
                #[test]
                fn [< test_ $name _ parallel >] () -> Result<()> {
                    let (it, expected) = $values;
                    let par_it = it.into_par_iter();
                    let got = depths!(par_it);
                    test::assert_eq_vec!(got, expected);
                    Ok(())
                }
            }
        };
    }

    /// Expands every listed generator macro for the same `(iterator, depths)` case.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($macro:ident,)*) => {
            $(
                $macro!($name: $values);
            )*
        }
    }

    test_depths!(
        bfs:
        (
            Bfs::<test::Node>::new(0, 3, true),
            [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );

    test_depths!(
        fast_bfs:
        (
            FastBfs::<test::Node>::new(0, 3, true),
            [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );

    test_depths!(
        fast_bfs_no_circles:
        (
            FastBfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_serial,
    );

    test_depths!(
        bfs_no_circles:
        (
            Bfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );
}