1use super::queue;
2use super::{ExtendQueue, FastNode, Node, Queue};
3use std::iter::Iterator;
4
5#[allow(clippy::module_name_repetitions)]
6#[derive(Debug, Clone)]
7pub struct Dfs<N>
45where
46 N: Node,
47{
48 queue: queue::Queue<N, N::Error>,
49 max_depth: Option<usize>,
50}
51
52impl<N> Dfs<N>
53where
54 N: Node,
55{
56 #[inline]
57 pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
65 where
66 R: Into<N>,
67 D: Into<Option<usize>>,
68 {
69 let mut queue = queue::Queue::new(allow_circles);
70 let root = root.into();
71 let max_depth = max_depth.into();
72 let depth = 1;
73 match root.children(depth) {
74 Ok(children) => queue.add_all(depth, children),
75 Err(err) => queue.add(depth, Err(err)),
76 }
77 Self { queue, max_depth }
78 }
79}
80
81impl<N> Iterator for Dfs<N>
82where
83 N: Node,
84{
85 type Item = Result<N, N::Error>;
86
87 #[inline]
88 fn next(&mut self) -> Option<Self::Item> {
89 match self.queue.pop_back() {
90 Some((_, Err(err))) => Some(Err(err)),
92 Some((depth, Ok(node))) => {
94 if let Some(max_depth) = self.max_depth {
95 if depth >= max_depth {
96 return Some(Ok(node));
97 }
98 }
99
100 match node.children(depth + 1) {
101 Ok(children) => {
102 self.queue.add_all(depth + 1, children);
103 }
104 Err(err) => self.queue.add(depth + 1, Err(err)),
105 };
106 Some(Ok(node))
107 }
108 None => None,
110 }
111 }
112}
113
114#[allow(clippy::module_name_repetitions)]
115#[derive(Debug, Clone)]
116pub struct FastDfs<N>
155where
156 N: FastNode,
157{
158 queue: queue::Queue<N, N::Error>,
159 max_depth: Option<usize>,
160}
161
162impl<N> FastDfs<N>
163where
164 N: FastNode,
165{
166 #[inline]
167 pub fn new<R, D>(root: R, max_depth: D, allow_circles: bool) -> Self
175 where
176 R: Into<N>,
177 D: Into<Option<usize>>,
178 {
179 let mut queue = queue::Queue::new(allow_circles);
180 let root: N = root.into();
181 let max_depth = max_depth.into();
182 let depth = 1;
183 let mut depth_queue = queue::QueueWrapper::new(depth, &mut queue);
184 if let Err(err) = root.add_children(depth, &mut depth_queue) {
185 depth_queue.add(Err(err));
186 }
187 Self { queue, max_depth }
188 }
189}
190
191impl<N> Iterator for FastDfs<N>
192where
193 N: FastNode,
194{
195 type Item = Result<N, N::Error>;
196
197 #[inline]
198 fn next(&mut self) -> Option<Self::Item> {
199 match self.queue.pop_back() {
200 Some((_, Err(err))) => Some(Err(err)),
202 Some((depth, Ok(node))) => {
204 if let Some(max_depth) = self.max_depth {
205 if depth >= max_depth {
206 return Some(Ok(node));
207 }
208 }
209 let next_depth = depth + 1;
210 let mut depth_queue = queue::QueueWrapper::new(next_depth, &mut self.queue);
211 if let Err(err) = node.add_children(next_depth, &mut depth_queue) {
212 depth_queue.add(Err(err));
213 }
214 Some(Ok(node))
215 }
216 None => None,
218 }
219 }
220}
221
/// Rayon integration for [`Dfs`] and [`FastDfs`], compiled only with the `rayon` feature.
///
/// The actual implementations are generated by the crate's `parallel_iterator!` macro, once
/// per iterator type together with the node trait it is bounded by.
/// NOTE(review): presumably the macro implements rayon's parallel-iterator traits; the tests
/// in this file rely on `into_par_iter()` being available for both types — confirm in
/// `crate::sync::par`.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
mod par {
    use crate::sync::{par::parallel_iterator, Dfs, FastDfs, FastNode, Node};

    parallel_iterator!(Dfs<Node>);
    parallel_iterator!(FastDfs<FastNode>);
}

// Re-export everything the macro generated so it is reachable from this module.
#[cfg(feature = "rayon")]
pub use par::*;
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test;
    use anyhow::Result;
    use pretty_assertions::assert_eq;

    #[cfg(feature = "rayon")]
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    /// Drains a fallible node iterator (propagating the first error with `?`) and evaluates
    /// to a `Vec` holding field `0` of every yielded node, which these tests treat as the
    /// node's depth.
    macro_rules! depths {
        ($nodes:ident) => {{
            $nodes
                .collect::<Result<Vec<_>, _>>()?
                .into_iter()
                .map(|node| node.0)
                .collect::<Vec<_>>()
        }};
    }

    /// Generates `test_<name>_serial`: runs the iterator sequentially and compares the
    /// yielded depths with the expected ones using an order-sensitive `assert_eq!`.
    macro_rules! test_depths_serial {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[test]
                fn [< test_ $name _ serial >] () -> Result<()> {
                    let (traversal, expected_depths) = $values;
                    let depths = depths!(traversal);
                    assert_eq!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Generates `test_<name>_parallel` (only with the `rayon` feature): converts the
    /// iterator with `into_par_iter()` and compares the yielded depths with the expected
    /// ones through `test::assert_eq_vec!`.
    /// NOTE(review): presumably `assert_eq_vec!` ignores ordering, since parallel iteration
    /// does not guarantee it — confirm in `crate::utils::test`.
    macro_rules! test_depths_parallel {
        ($name:ident: $values:expr) => {
            paste::item! {
                #[cfg(feature = "rayon")]
                #[test]
                fn [< test_ $name _ parallel >] () -> Result<()> {
                    let (traversal, expected_depths) = $values;
                    let par_traversal = traversal.into_par_iter();
                    let depths = depths!(par_traversal);
                    test::assert_eq_vec!(depths, expected_depths);
                    Ok(())
                }
            }
        };
    }

    /// Expands one `(iterator, expected depths)` case with every listed generator macro.
    macro_rules! test_depths {
        ($name:ident: $values:expr, $($macro:ident,)*) => {
            $(
                $macro!($name: $values);
            )*
        }
    }

    test_depths!(
        dfs:
        (
            Dfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );

    test_depths!(
        fast_dfs:
        (
            FastDfs::<test::Node>::new(0, 3, true),
            [1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );

    // NOTE(review): unlike `dfs_no_circles` below, no parallel variant is generated for
    // this case — confirm whether that is intentional.
    test_depths!(
        fast_dfs_no_circles:
        (
            FastDfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_serial,
    );

    test_depths!(
        dfs_no_circles:
        (
            Dfs::<test::Node>::new(0, 3, false),
            [1, 2, 3]
        ),
        test_depths_serial,
        test_depths_parallel,
    );
}