par_dfs/sync/
dfs.rs

1use super::queue;
2use super::{ExtendQueue, FastNode, Node, Queue};
3use std::iter::Iterator;
4
5#[allow(clippy::module_name_repetitions)]
6#[derive(Debug, Clone)]
7/// Synchronous depth-first iterator for types implementing the [`Node`] trait.
8///
9/// ### Example
10/// ```
11/// use par_dfs::sync::{Node, Dfs, NodeIter};
12///
13/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
14/// struct WordNode(String);
15///
16/// impl Node for WordNode {
17///     type Error = std::convert::Infallible;
18///
19///     fn children(&self, _depth: usize) -> NodeIter<Self, Self::Error> {
20///         let len = self.0.len();
21///         let nodes: Vec<String> = if len < 2 {
22///             vec![]
23///         } else {
24///             let mid = len/2;
25///             vec![self.0[mid..].into(), self.0[..mid].into()]
26///         };
27///         let nodes = nodes.into_iter()
28///             .map(Self)
29///             .map(Result::Ok);
30///         Ok(Box::new(nodes))
31///     }
32/// }
33///
34/// let root = WordNode("Hello World".into());
35/// let dfs = Dfs::<WordNode>::new(root, None, true);
36/// let output = dfs.collect::<Result<Vec<_>, _>>().unwrap();
37/// let result = output.into_iter()
38///     .filter_map(|s| if s.0.len() == 1 { Some(s.0) } else { None })
39///     .collect::<String>();
40/// assert_eq!(result, "Hello World");
41/// ```
42///
43/// [`Node`]: trait@crate::sync::Node
44pub struct Dfs<N>
45where
46    N: Node,
47{
48    queue: queue::Queue<N, N::Error>,
49    max_depth: Option<usize>,
50}
51
52impl<N> Dfs<N>
53where
54    N: Node,
55{
56    #[inline]
57    /// Creates a new [`Dfs`] iterator.
58    ///
59    /// The DFS will be performed from the `root` node up to depth `max_depth`.
60    ///
61    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
62    ///
63    /// [`Dfs`]: struct@crate::sync::Dfs
64    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
65    where
66        R: Into<N>,
67        D: Into<Option<usize>>,
68    {
69        let mut queue = queue::Queue::new(allow_circles);
70        let root = root.into();
71        let max_depth = max_depth.into();
72        let depth = 1;
73        match root.children(depth) {
74            Ok(children) => queue.add_all(depth, children),
75            Err(err) => queue.add(depth, Err(err)),
76        }
77        Self { queue, max_depth }
78    }
79}
80
81impl<N> Iterator for Dfs<N>
82where
83    N: Node,
84{
85    type Item = Result<N, N::Error>;
86
87    #[inline]
88    fn next(&mut self) -> Option<Self::Item> {
89        match self.queue.pop_back() {
90            // next node failed
91            Some((_, Err(err))) => Some(Err(err)),
92            // next node succeeded
93            Some((depth, Ok(node))) => {
94                if let Some(max_depth) = self.max_depth {
95                    if depth >= max_depth {
96                        return Some(Ok(node));
97                    }
98                }
99
100                match node.children(depth + 1) {
101                    Ok(children) => {
102                        self.queue.add_all(depth + 1, children);
103                    }
104                    Err(err) => self.queue.add(depth + 1, Err(err)),
105                };
106                Some(Ok(node))
107            }
108            // no next node
109            None => None,
110        }
111    }
112}
113
114#[allow(clippy::module_name_repetitions)]
115#[derive(Debug, Clone)]
116/// Synchronous, fast depth-first iterator for types implementing the [`FastNode`] trait.
117///
118/// ### Example
119/// ```
120/// use par_dfs::sync::{FastNode, FastDfs, ExtendQueue, NodeIter};
121///
122/// #[derive(PartialEq, Eq, Hash, Clone, Debug)]
123/// struct WordNode(String);
124///
125/// impl FastNode for WordNode {
126///     type Error = std::convert::Infallible;
127///
128///     fn add_children<E>(
129///         &self, _depth: usize, queue: &mut E
130///     ) -> Result<(), Self::Error>
131///     where
132///         E: ExtendQueue<Self, Self::Error>,
133///     {
134///         let len = self.0.len();
135///         if len > 1 {
136///             let mid = len/2;
137///             queue.add(Ok(Self(self.0[mid..].into())));
138///             queue.add(Ok(Self(self.0[..mid].into())));
139///         }
140///         Ok(())
141///     }
142/// }
143///
144/// let root = WordNode("Hello World".into());
145/// let dfs = FastDfs::<WordNode>::new(root, None, true);
146/// let output = dfs.collect::<Result<Vec<_>, _>>().unwrap();
147/// let result = output.into_iter()
148///     .filter_map(|s| if s.0.len() == 1 { Some(s.0) } else { None })
149///     .collect::<String>();
150/// assert_eq!(result, "Hello World");
151/// ```
152///
153/// [`FastNode`]: trait@crate::sync::FastNode
154pub struct FastDfs<N>
155where
156    N: FastNode,
157{
158    queue: queue::Queue<N, N::Error>,
159    max_depth: Option<usize>,
160}
161
162impl<N> FastDfs<N>
163where
164    N: FastNode,
165{
166    #[inline]
167    /// Creates a new [`FastDfs`] iterator.
168    ///
169    /// The DFS will be performed from the `root` node up to depth `max_depth`.
170    ///
171    /// When `allow_circles`, visited nodes will not be tracked, which can lead to cycles.
172    ///
173    /// [`FastDfs`]: struct@crate::sync::FastDfs
174    pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
175    where
176        R: Into<N>,
177        D: Into<Option<usize>>,
178    {
179        let mut queue = queue::Queue::new(allow_circles);
180        let root: N = root.into();
181        let max_depth = max_depth.into();
182        let depth = 1;
183        let mut depth_queue = queue::QueueWrapper::new(depth, &mut queue);
184        if let Err(err) = root.add_children(depth, &mut depth_queue) {
185            depth_queue.add(Err(err));
186        }
187        Self { queue, max_depth }
188    }
189}
190
191impl<N> Iterator for FastDfs<N>
192where
193    N: FastNode,
194{
195    type Item = Result<N, N::Error>;
196
197    #[inline]
198    fn next(&mut self) -> Option<Self::Item> {
199        match self.queue.pop_back() {
200            // next node failed
201            Some((_, Err(err))) => Some(Err(err)),
202            // next node succeeded
203            Some((depth, Ok(node))) => {
204                if let Some(max_depth) = self.max_depth {
205                    if depth >= max_depth {
206                        return Some(Ok(node));
207                    }
208                }
209                let next_depth = depth + 1;
210                let mut depth_queue = queue::QueueWrapper::new(next_depth, &mut self.queue);
211                if let Err(err) = node.add_children(next_depth, &mut depth_queue) {
212                    depth_queue.add(Err(err));
213                }
214                Some(Ok(node))
215            }
216            // no next node
217            None => None,
218        }
219    }
220}
221
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
mod par {
    //! Rayon parallel-iterator glue for the synchronous DFS iterators.
    use crate::sync::{par::parallel_iterator, Dfs, FastDfs, FastNode, Node};

    // One expansion of the shared glue per DFS flavour.
    parallel_iterator!(Dfs<Node>);
    parallel_iterator!(FastDfs<FastNode>);
}

#[cfg(feature = "rayon")]
pub use par::*;
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use pretty_assertions::assert_eq;

    #[cfg(feature = "rayon")]
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    /// Drains `$items` (an iterator of `Result`s), returning early via `?` on
    /// the first error, and keeps only the first tuple field (the depth) of
    /// every collected item.
    macro_rules! collect_depths {
        ($items:ident) => {{
            let collected: Vec<_> = $items.collect::<Result<Vec<_>, _>>()?;
            collected
                .into_iter()
                .map(|entry| entry.0)
                .collect::<Vec<_>>()
        }};
    }

    /// Generates `test_<name>_serial`, which checks the exact sequential depth order.
    macro_rules! check_serial {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[test]
                fn [< test_ $name _ serial >] () -> Result<()> {
                    let (iter, expected) = $values;
                    let actual = collect_depths!(iter);
                    assert_eq!(actual, expected);
                    Ok(())
                }
            }
        };
    }

    /// Generates `test_<name>_parallel` (only with the `rayon` feature), which
    /// compares depths using the crate's `assert_eq_vec!` helper.
    macro_rules! check_parallel {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[cfg(feature = "rayon")]
                #[test]
                fn [< test_ $name _ parallel >] () -> Result<()> {
                    let (iter, expected) = $values;
                    let iter = iter.into_par_iter();
                    let actual = collect_depths!(iter);
                    test::assert_eq_vec!(actual, expected);
                    Ok(())
                }
            }
        };
    }

    /// Runs every listed checker macro against the same `(iterator, depths)` case.
    macro_rules! depth_tests {
        ($name:ident: $values:expr, $($check:ident,)*) => {
            $( $check!($name: $values); )*
        };
    }

    depth_tests!(
        dfs:
        (
            Dfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        check_serial,
        check_parallel,
    );

    depth_tests!(
        dfs_no_circles:
        (
            Dfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        check_serial,
        check_parallel,
    );

    depth_tests!(
        fast_dfs:
        (
            FastDfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        check_serial,
        check_parallel,
    );

    depth_tests!(
        fast_dfs_no_circles:
        (
            FastDfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        check_serial,
    );
}