langchain_rust/document_loaders/dir_loader.rs

1use async_recursion::async_recursion;
2use std::sync::Arc;
3use std::{fmt, path::Path, pin::Pin};
4use tokio::fs;
5
6use super::LoaderError;
7
/// A shareable, thread-safe predicate over filesystem paths.
///
/// The predicate answers "should this path be skipped?": a directory for which
/// it returns `true` is not descended into by `list_files_in_path`, and a file
/// for which it returns `true` is dropped by `find_files_with_extension`.
pub struct PathFilter(Arc<dyn Fn(&Path) -> bool + Send + Sync>);

impl PathFilter {
    /// Box the given closure behind an [`Arc`] so the filter can be cloned cheaply.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(&Path) -> bool + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(&Path) -> bool + Send + Sync> = Arc::new(f);
        Self(shared)
    }
}
18
19impl fmt::Debug for PathFilter {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        write!(f, "Filter")
22    }
23}
24
25impl Clone for PathFilter {
26    fn clone(&self) -> Self {
27        PathFilter(Arc::clone(&self.0))
28    }
29}
30
31#[derive(Debug, Clone, Default)]
32pub struct DirLoaderOptions {
33    pub glob: Option<String>,
34    pub suffixes: Option<Vec<String>>,
35    pub path_filter: Option<PathFilter>,
36}
37
38/// Recursively list all files in a directory
39#[async_recursion]
40pub async fn list_files_in_path(
41    dir_path: &Path,
42    files: &mut Vec<String>,
43    opts: &DirLoaderOptions,
44) -> Result<Pin<Box<()>>, LoaderError> {
45    if dir_path.is_file() {
46        files.push(dir_path.to_string_lossy().to_string());
47        return Ok(Box::pin(()));
48    }
49    if !dir_path.is_dir() {
50        return Err(LoaderError::OtherError(format!(
51            "Path is not a directory: {:?}",
52            dir_path
53        )));
54    }
55    let mut reader = fs::read_dir(dir_path).await.unwrap();
56    while let Some(entry) = reader.next_entry().await.unwrap() {
57        let path = entry.path();
58        if path.is_file() {
59            files.push(path.to_string_lossy().to_string());
60        } else if path.is_dir() {
61            if opts
62                .path_filter
63                .as_ref()
64                .map_or(false, |f| f.0(path.as_path()))
65            {
66                continue;
67            }
68
69            list_files_in_path(&path, files, opts).await.unwrap();
70        }
71    }
72    Ok(Box::pin(()))
73}
74
75/// Find files in a directory that match the given options
76pub async fn find_files_with_extension(folder_path: &str, opts: &DirLoaderOptions) -> Vec<String> {
77    let mut matching_files = Vec::new();
78    let folder_path = Path::new(folder_path);
79    let mut all_files: Vec<String> = Vec::new();
80
81    list_files_in_path(folder_path, &mut all_files, &opts.clone())
82        .await
83        .unwrap();
84
85    for file_name in all_files {
86        let path_str = file_name.clone();
87
88        // check if the file has the required extension
89        if let Some(suffixes) = &opts.suffixes {
90            let mut has_suffix = false;
91            for suffix in suffixes {
92                if path_str.ends_with(suffix) {
93                    has_suffix = true;
94                    break;
95                }
96            }
97            if !has_suffix {
98                continue;
99            }
100        }
101
102        if opts
103            .path_filter
104            .as_ref()
105            .map_or(false, |f| f.0(&Path::new(&file_name)))
106        {
107            continue; // Skip this path if the filter returns true
108        }
109
110        // check if the file matches the glob pattern
111        if let Some(glob_pattern) = &opts.glob {
112            let glob = glob::Pattern::new(glob_pattern).unwrap();
113            if !glob.matches(&path_str) {
114                continue;
115            }
116        }
117
118        matching_files.push(path_str);
119    }
120
121    matching_files
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Temporary directory that is removed on drop, so cleanup also happens when
    /// an assertion fails. The process id is part of the name so concurrent test
    /// runs do not clobber each other's fixtures.
    struct TempDir(PathBuf);

    impl TempDir {
        fn new(name: &str) -> Self {
            let path = env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            if path.exists() {
                std::fs::remove_dir_all(&path).expect("Failed to remove existing directory");
            }
            std::fs::create_dir_all(&path).expect("Failed to create temporary directory");
            TempDir(path)
        }

        /// Create `relative` (plus any missing parent directories) with fixed content
        /// and return its full path.
        fn write(&self, relative: impl AsRef<Path>) -> PathBuf {
            let path = self.0.join(relative);
            if let Some(parent) = path.parent() {
                std::fs::create_dir_all(parent).expect("Failed to create parent directory");
            }
            std::fs::write(&path, "Hello, world!").expect("Failed to write file");
            path
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn as_string(path: &Path) -> String {
        path.to_string_lossy().to_string()
    }

    #[tokio::test]
    async fn test_find_files_with_extension() {
        let temp_dir = TempDir::new("dir_loader_test_dir");
        for name in ["file1.txt", "file2.txt", "file3.md", "file4.txt"] {
            temp_dir.write(name);
        }

        // Call the function to find files with the ".txt" extension
        let found_files = find_files_with_extension(
            temp_dir.0.to_str().unwrap(),
            &DirLoaderOptions {
                glob: None,
                suffixes: Some(vec![".txt".to_string()]),
                path_filter: None,
            },
        )
        .await;

        // Expecting to find exactly the 3 files with the ".txt" extension
        assert_eq!(found_files.len(), 3);
        for file in &found_files {
            assert!(file.ends_with(".txt"));
        }
        for name in ["file1.txt", "file2.txt", "file4.txt"] {
            assert!(found_files.contains(&as_string(&temp_dir.0.join(name))));
        }
    }

    #[tokio::test]
    async fn test_recursion_glob_and_path_filter() {
        let temp_dir = TempDir::new("dir_loader_test_nested");
        let top = temp_dir.write("a.txt");
        temp_dir.write("b.md");
        let nested = temp_dir.write(Path::new("sub").join("c.txt"));
        temp_dir.write(Path::new("skip").join("d.txt"));

        let mut found_files = find_files_with_extension(
            temp_dir.0.to_str().unwrap(),
            &DirLoaderOptions {
                glob: Some("*.txt".to_string()),
                suffixes: None,
                path_filter: Some(PathFilter::new(|p: &Path| p.ends_with("skip"))),
            },
        )
        .await;
        found_files.sort();

        // b.md fails the glob; skip/ is never traversed because of the filter.
        let mut expected = vec![as_string(&top), as_string(&nested)];
        expected.sort();
        assert_eq!(found_files, expected);
    }

    #[tokio::test]
    async fn test_list_files_in_missing_path_is_error() {
        let temp_dir = TempDir::new("dir_loader_test_missing");
        let missing = temp_dir.0.join("does_not_exist");
        let mut files = Vec::new();
        let result = list_files_in_path(&missing, &mut files, &DirLoaderOptions::default()).await;
        assert!(result.is_err());
        assert!(files.is_empty());
    }
}