async_walk/
walk.rs

1use crate::error::{Error, Result};
2use crossbeam_channel::{unbounded, Receiver, Sender};
3use futures::future::BoxFuture;
4use futures::stream::{Stream, StreamExt};
5use std::collections::HashSet;
6use std::fs::Metadata;
7
8#[cfg(not(target_os = "windows"))]
9use std::os::unix::fs::MetadataExt;
10#[cfg(target_os = "windows")]
11use std::os::windows::fs::MetadataExt;
12use std::path::PathBuf;
13use std::pin::Pin;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::sync::Arc;
16
17use std::task::{Context, Poll};
18use tokio::fs::{metadata, read_dir, DirEntry};
19use tokio::spawn;
20
21pub type FilterResult = ::std::result::Result<bool, Box<dyn ::std::error::Error + Send + Sync>>;
22
23// TODO: Explore making this type cleaner
24pub type Filter = Box<dyn Fn(&DirEntry) -> BoxFuture<FilterResult> + Sync + Send>;
25
26// use type alias for readability, represents an ino on linux or file index on windows
27type PathId = u64;
28
29enum Entry {
30    File(DirEntry),
31    Dir(DirEntry, PathId, u64),
32    Symlink(DirEntry, PathId, u64),
33    Root(Metadata, PathId),
34}
35
36pub struct Walk {
37    queue: Vec<(PathBuf, u64)>,
38    ready_entries: Vec<Result<DirEntry>>,
39    receiver: Receiver<Result<Entry>>,
40    sender: Sender<Result<Entry>>,
41    follow_symlinks: bool,
42    counter: Arc<AtomicUsize>,
43    concurrency_limit: Option<usize>,
44    visited: Option<HashSet<u64>>,
45    max_depth: Option<u64>,
46    filter: Option<Arc<Filter>>,
47}
48
49impl Walk {
50    pub fn new(
51        root: PathBuf,
52        follow_symlinks: bool,
53        concurrency_limit: Option<usize>,
54        max_level: Option<u64>,
55        filter: Option<Filter>,
56    ) -> Self {
57        let (tx, rx) = unbounded();
58        let visited = match follow_symlinks {
59            true => Some(HashSet::new()),
60            false => None,
61        };
62        Walk {
63            queue: vec![(root, 0)],
64            ready_entries: vec![],
65            receiver: rx,
66            sender: tx,
67            follow_symlinks: follow_symlinks,
68            counter: Arc::new(AtomicUsize::new(0)),
69            concurrency_limit: concurrency_limit,
70            visited: visited,
71            max_depth: max_level,
72            filter: filter.map(|f| Arc::new(f)),
73        }
74    }
75}
76
/// Returns an identifier for the file described by `info`, used to recognise a
/// directory that is reached a second time (e.g. through a symlink cycle).
///
/// Inode numbers (Unix) and file indices (Windows) are only unique within a single
/// filesystem/volume — root directories of separate ext4 filesystems, for example,
/// typically share inode 2 — so the device/volume id is folded in as well. The pair
/// is hashed down to a `u64`; two distinct pairs colliding is possible in principle
/// but astronomically unlikely.
///
/// # Panics
/// On Windows, if `info` lacks a file index or volume serial number, which is the
/// case for metadata obtained from a `DirEntry` rather than `fs::metadata`.
fn unique_id(info: &Metadata) -> u64 {
    #[cfg(target_os = "windows")]
    let key = (
        u64::from(
            info.volume_serial_number()
                .expect("metadata must come from fs::metadata"),
        ),
        info.file_index()
            .expect("metadata must come from fs::metadata"),
    );

    #[cfg(not(target_os = "windows"))]
    let key = (info.dev(), info.ino());

    // DefaultHasher::new() is deterministic within a process, so equal keys always
    // map to equal ids for the lifetime of the walk.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    std::hash::Hash::hash(&key, &mut hasher);
    std::hash::Hasher::finish(&hasher)
}

87async fn handle_entry(
88    entry: Result<DirEntry>,
89    follow_symlinks: bool,
90    depth: u64,
91    filter: Option<Arc<Filter>>,
92) -> Result<Option<Entry>> {
93    let entry = entry?;
94    if let Some(filter) = filter {
95        let include = filter(&entry)
96            .await
97            .map_err(|e| Error::Filter(entry.path(), e))?;
98        if !include {
99            return Ok(None);
100        }
101    }
102    let file_type = entry
103        .file_type()
104        .await
105        .map_err(|e| Error::Io(entry.path(), e))?;
106    if file_type.is_dir() {
107        let unique_id = if follow_symlinks {
108            #[cfg(not(target_os = "windows"))]
109            let info = entry
110                .metadata()
111                .await
112                .map_err(|e| Error::Io(entry.path(), e))?;
113
114            // we can't use entry.metadata() on windows since it doesn't include the file index when called from DirEntry
115            #[cfg(target_os = "windows")]
116            let info = metadata(entry.path())
117                .await
118                .map_err(|e| Error::Io(entry.path(), e))?;
119
120            unique_id(&info)
121        } else {
122            0 // pass 0 since this will never be used
123        };
124        Ok(Some(Entry::Dir(entry, unique_id, depth)))
125    } else if file_type.is_symlink() && follow_symlinks {
126        // follow the symlink to get its type
127        let info = metadata(entry.path())
128            .await
129            .map_err(|e| Error::Io(entry.path(), e))?;
130        if info.is_dir() {
131            Ok(Some(Entry::Symlink(entry, unique_id(&info), depth)))
132        } else {
133            Ok(Some(Entry::File(entry)))
134        }
135    } else {
136        Ok(Some(Entry::File(entry)))
137    }
138}
139
140impl Stream for Walk {
141    type Item = Result<DirEntry>;
142
143    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
144        let walk = self.get_mut();
145        while !walk.queue.is_empty() {
146            // Check how many tasks are currently executing break if we have are at the limit
147            let counter = walk.counter.clone();
148            if let Some(limit) = walk.concurrency_limit {
149                if counter.load(Ordering::Relaxed) == limit {
150                    break;
151                }
152            }
153            // Guaranteed to not be none since we check if the queue is empty in the loop
154            let (p, depth) = walk.queue.pop().unwrap();
155            let sender = walk.sender.clone();
156            let filter = walk.filter.clone();
157            let follow_symlinks = walk.follow_symlinks;
158            // TODO: Double check ordering is what we want
159            counter.fetch_add(1, Ordering::Relaxed);
160            let waker = cx.waker().clone();
161            spawn(async move {
162                // TODO: If we add a include_root_dir option, remove follow_symlinks condition
163                // We have a special case for the root directory so we store it for symlink cycles
164                if depth == 0 && follow_symlinks {
165                    // use fs::metadata so we follow the symlink
166                    match metadata(&p).await {
167                        Ok(info) => {
168                            let id = unique_id(&info);
169                            let _ = sender.send(Ok(Entry::Root(info, id)));
170                        }
171                        Err(e) => {
172                            let _ = sender.send(Err(Error::Io(p.clone(), e)));
173                        }
174                    }
175                }
176                match read_dir(&p).await {
177                    Ok(entries) => {
178                        entries
179                            .map(|res| res.map_err(|e| Error::Io(p.clone(), e)))
180                            .for_each(|entry| async {
181                                let sender = sender.clone();
182                                let waker = waker.clone();
183                                let filter = filter.clone();
184                                match handle_entry(entry, follow_symlinks, depth + 1, filter).await
185                                {
186                                    Ok(entry) => {
187                                        if let Some(entry) = entry {
188                                            let _ = sender.send(Ok(entry));
189                                        }
190                                    }
191                                    Err(e) => {
192                                        let _ = sender.send(Err(e));
193                                    }
194                                };
195                                // Wake each time we send, since a result will be ready
196                                waker.wake();
197                            })
198                            .await;
199                    }
200                    Err(e) => {
201                        let _ = sender.send(Err(Error::Io(p, e)));
202                    }
203                };
204                // decrement counter since this task has finished
205                counter.fetch_sub(1, Ordering::Relaxed);
206                // Wake after decrementing counter since we might have been at the concurrency limit
207                waker.wake();
208            });
209        }
210        while let Ok(entry) = walk.receiver.try_recv() {
211            match entry {
212                Ok(entry) => match entry {
213                    Entry::Root(_, id) => {
214                        if walk.follow_symlinks {
215                            walk.visited.as_mut().unwrap().insert(id);
216                        }
217                    }
218                    Entry::File(entry) => {
219                        walk.ready_entries.push(Ok(entry));
220                    }
221                    Entry::Dir(entry, unique_id, depth) => {
222                        if walk
223                            .max_depth
224                            .map(|max_depth| depth < max_depth)
225                            .unwrap_or(true)
226                        {
227                            walk.queue.push((entry.path(), depth));
228                        }
229
230                        if walk.follow_symlinks {
231                            walk.visited
232                                .as_mut()
233                                .expect("BUG: This should always be Some")
234                                .insert(unique_id);
235                        }
236                        walk.ready_entries.push(Ok(entry));
237                    }
238                    Entry::Symlink(entry, link, depth) => {
239                        // Guaranteed to be Some this this is a symlink entry
240                        if walk
241                            .visited
242                            .as_ref()
243                            .expect("BUG: This should always be Some")
244                            .contains(&link)
245                        {
246                            walk.ready_entries
247                                .push(Err(Error::SymlinkCycle(entry.path())));
248                        } else {
249                            walk.queue.push((entry.path(), depth));
250                            walk.ready_entries.push(Ok(entry));
251                        }
252                    }
253                },
254                Err(e) => {
255                    walk.ready_entries.push(Err(e));
256                }
257            }
258        }
259
260        if let Some(entry) = walk.ready_entries.pop() {
261            Poll::Ready(Some(entry))
262        } else if walk.queue.is_empty() && walk.counter.load(Ordering::Relaxed) == 0 {
263            // We are done when ready entries is empty, there is nothing in the queue and no ongoing async tasks
264            Poll::Ready(None)
265        } else {
266            Poll::Pending
267        }
268    }
269}
270
271//pub struct WalkBuilder {}
272
273pub struct WalkBuilder {
274    root: PathBuf,
275    follow_symlinks: bool,
276    concurrency_limit: Option<usize>,
277    max_depth: Option<u64>,
278    filter: Option<Filter>,
279}
280
281impl WalkBuilder {
282    pub fn new(root: impl Into<PathBuf>) -> Self {
283        Self {
284            root: root.into(),
285            follow_symlinks: false,
286            concurrency_limit: None,
287            max_depth: None,
288            filter: None,
289        }
290    }
291
292    pub fn follow_symlinks(mut self, follow_symlinks: bool) -> Self {
293        self.follow_symlinks = follow_symlinks;
294        self
295    }
296
297    pub fn concurrency_limit(mut self, concurrency_limit: usize) -> Self {
298        self.concurrency_limit = Some(concurrency_limit);
299        self
300    }
301
302    pub fn max_depth<'a>(mut self, max_depth: u64) -> Self {
303        self.max_depth = Some(max_depth);
304        self
305    }
306
307    // TODO: Support just passing in a closure
308    pub fn filter(mut self, filter: Filter) -> Self {
309        self.filter = Some(filter);
310        self
311    }
312
313    pub fn build(self) -> Walk {
314        Walk::new(
315            self.root,
316            self.follow_symlinks,
317            self.concurrency_limit,
318            self.max_depth,
319            self.filter,
320        )
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;
    use tempfile::{tempdir, tempdir_in, NamedTempFile};
    use tokio::fs::os::unix::symlink;

    /// Drains `walk`, keeping the path of every successful entry and recording
    /// every error as `None`.
    async fn collect_paths(walk: Walk) -> Vec<Option<PathBuf>> {
        walk.map(|item| item.ok().map(|entry| entry.path()))
            .collect::<Vec<Option<PathBuf>>>()
            .await
    }

    /// Same as `collect_paths`, but panics as soon as the walk yields an error.
    async fn collect_paths_strict(walk: Walk) -> Vec<Option<PathBuf>> {
        walk.map(|item| {
            item.as_ref().unwrap();
            item.ok().map(|entry| entry.path())
        })
        .collect::<Vec<Option<PathBuf>>>()
        .await
    }

    /// Whether `path` appears among the successfully walked entries.
    fn found(entries: &[Option<PathBuf>], path: &std::path::Path) -> bool {
        entries.contains(&Some(path.to_path_buf()))
    }

    #[tokio::test(core_threads = 4)]
    async fn test_single_level() {
        let root = tempdir().unwrap();
        let first = NamedTempFile::new_in(root.path()).unwrap();
        let second = NamedTempFile::new_in(root.path()).unwrap();
        let subdir = tempdir_in(root.path()).unwrap();

        let entries = collect_paths(WalkBuilder::new(root.path()).build()).await;
        drop(root);

        assert_eq!(entries.len(), 3);
        assert!(found(&entries, first.path()));
        assert!(found(&entries, second.path()));
        assert!(found(&entries, subdir.path()));
    }

    #[tokio::test(core_threads = 4)]
    async fn test_multi_level() {
        let root = tempdir().unwrap();
        let top_file = NamedTempFile::new_in(root.path()).unwrap();
        let subdir = tempdir_in(root.path()).unwrap();
        let nested_file = NamedTempFile::new_in(subdir.path()).unwrap();

        let entries = collect_paths(WalkBuilder::new(root.path()).build()).await;
        drop(root);

        assert_eq!(entries.len(), 3);
        assert!(found(&entries, top_file.path()));
        assert!(found(&entries, nested_file.path()));
        assert!(found(&entries, subdir.path()));
    }

    #[tokio::test(core_threads = 4)]
    async fn test_max_depth() {
        let root = tempdir().unwrap();
        let top_file = NamedTempFile::new_in(root.path()).unwrap();
        let level1 = tempdir_in(root.path()).unwrap();
        let level2 = tempdir_in(level1.path()).unwrap();
        let deep_file = NamedTempFile::new_in(level2.path()).unwrap();

        let walk = WalkBuilder::new(root.path().to_path_buf())
            .max_depth(2)
            .build();
        let entries = collect_paths(walk).await;
        drop(top_file);
        drop(level1);
        drop(level2);

        assert_eq!(entries.len(), 3);
        assert!(!found(&entries, deep_file.path()));
    }

    #[tokio::test(core_threads = 4)]
    async fn test_follow_symlinks() {
        let root = tempdir().unwrap();
        let target = tempdir().unwrap();
        let link = root.path().join("link");
        symlink(target.path(), &link).await.unwrap();
        let file = NamedTempFile::new_in(&link).unwrap();

        let walk = WalkBuilder::new(root.path().to_path_buf())
            .follow_symlinks(true)
            .build();
        let entries = collect_paths_strict(walk).await;
        drop(root);
        drop(target);

        assert_eq!(entries.len(), 2);
        assert!(found(&entries, file.path()));
    }

    #[tokio::test(core_threads = 4)]
    async fn test_does_not_follow_symlinks() {
        let root = tempdir().unwrap();
        let target = tempdir().unwrap();
        let file = NamedTempFile::new_in(target.path()).unwrap();
        symlink(&target, root.path().join("link"))
            .await
            .unwrap();

        let walk = WalkBuilder::new(root.path()).max_depth(2).build();
        let entries = collect_paths_strict(walk).await;
        drop(root);
        drop(target);

        assert_eq!(entries.len(), 1);
        assert!(!found(&entries, file.path()));
    }

    #[tokio::test(core_threads = 4)]
    async fn test_symlink_cycle() {
        let root = tempdir().unwrap();
        let link = root.path().join("link");
        symlink(root.path(), &link).await.unwrap();
        let file = NamedTempFile::new_in(&link).unwrap();

        let walk = WalkBuilder::new(root.path()).follow_symlinks(true).build();
        let results = walk.collect::<Vec<Result<DirEntry>>>().await;
        let cycle_reported = results.iter().any(|res| match res {
            Err(Error::SymlinkCycle(p)) => p == &link,
            _ => false,
        });
        drop(file);
        drop(root);

        assert_eq!(results.len(), 2);
        assert!(cycle_reported);
    }

    #[tokio::test(core_threads = 4)]
    async fn test_filter() {
        let root = tempdir().unwrap();
        let file = NamedTempFile::new_in(root.path()).unwrap();
        let dir = tempdir_in(root.path()).unwrap();
        let nested = NamedTempFile::new_in(dir.path()).unwrap();
        let excluded = dir.path().to_path_buf();
        let filter: Filter = Box::new(move |entry| {
            let keep = entry.path() != excluded;
            async move { FilterResult::Ok(keep) }.boxed()
        });

        let walk = WalkBuilder::new(root.path()).filter(filter).build();
        let entries = collect_paths_strict(walk).await;

        assert_eq!(entries.len(), 1);
        assert!(!found(&entries, nested.path()));
        assert!(!found(&entries, dir.path()));
        assert!(found(&entries, file.path()));
        drop(root);
    }

    #[tokio::test(core_threads = 4)]
    async fn test_filter_error() {
        let root = tempdir().unwrap();
        let file = NamedTempFile::new_in(root.path()).unwrap();

        let filter: Filter =
            Box::new(move |_entry| async move { FilterResult::Err("Error!!".into()) }.boxed());
        let walk = WalkBuilder::new(root.path()).filter(filter).build();
        let results = walk.collect::<Vec<Result<DirEntry>>>().await;
        let filter_failed = results.iter().any(|res| match res {
            Err(Error::Filter(_, _)) => true,
            _ => false,
        });

        assert_eq!(results.len(), 1);
        assert!(filter_failed);
        drop(root);
        drop(file);
    }
}