langchain_rust/document_loaders/
dir_loader.rs1use async_recursion::async_recursion;
2use std::sync::Arc;
3use std::{fmt, path::Path, pin::Pin};
4use tokio::fs;
5
6use super::LoaderError;
7
/// A caller-supplied predicate over filesystem paths.
///
/// The closure lives behind an [`Arc`], so cloning a `PathFilter` is a
/// reference-count bump and every clone calls the very same closure. The
/// directory loaders in this module treat a `true` result as "exclude this
/// path".
#[derive(Clone)]
pub struct PathFilter(Arc<dyn Fn(&Path) -> bool + Send + Sync>);

impl PathFilter {
    /// Wraps `f` so it can be stored in options and shared between threads.
    pub fn new<F: Fn(&Path) -> bool + Send + Sync + 'static>(f: F) -> Self {
        Self(Arc::new(f))
    }
}

impl fmt::Debug for PathFilter {
    // A boxed closure has no meaningful debug representation, so a fixed
    // placeholder is printed instead.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Filter")
    }
}
30
31#[derive(Debug, Clone, Default)]
32pub struct DirLoaderOptions {
33 pub glob: Option<String>,
34 pub suffixes: Option<Vec<String>>,
35 pub path_filter: Option<PathFilter>,
36}
37
38#[async_recursion]
40pub async fn list_files_in_path(
41 dir_path: &Path,
42 files: &mut Vec<String>,
43 opts: &DirLoaderOptions,
44) -> Result<Pin<Box<()>>, LoaderError> {
45 if dir_path.is_file() {
46 files.push(dir_path.to_string_lossy().to_string());
47 return Ok(Box::pin(()));
48 }
49 if !dir_path.is_dir() {
50 return Err(LoaderError::OtherError(format!(
51 "Path is not a directory: {:?}",
52 dir_path
53 )));
54 }
55 let mut reader = fs::read_dir(dir_path).await.unwrap();
56 while let Some(entry) = reader.next_entry().await.unwrap() {
57 let path = entry.path();
58 if path.is_file() {
59 files.push(path.to_string_lossy().to_string());
60 } else if path.is_dir() {
61 if opts
62 .path_filter
63 .as_ref()
64 .map_or(false, |f| f.0(path.as_path()))
65 {
66 continue;
67 }
68
69 list_files_in_path(&path, files, opts).await.unwrap();
70 }
71 }
72 Ok(Box::pin(()))
73}
74
75pub async fn find_files_with_extension(folder_path: &str, opts: &DirLoaderOptions) -> Vec<String> {
77 let mut matching_files = Vec::new();
78 let folder_path = Path::new(folder_path);
79 let mut all_files: Vec<String> = Vec::new();
80
81 list_files_in_path(folder_path, &mut all_files, &opts.clone())
82 .await
83 .unwrap();
84
85 for file_name in all_files {
86 let path_str = file_name.clone();
87
88 if let Some(suffixes) = &opts.suffixes {
90 let mut has_suffix = false;
91 for suffix in suffixes {
92 if path_str.ends_with(suffix) {
93 has_suffix = true;
94 break;
95 }
96 }
97 if !has_suffix {
98 continue;
99 }
100 }
101
102 if opts
103 .path_filter
104 .as_ref()
105 .map_or(false, |f| f.0(&Path::new(&file_name)))
106 {
107 continue; }
109
110 if let Some(glob_pattern) = &opts.glob {
112 let glob = glob::Pattern::new(glob_pattern).unwrap();
113 if !glob.matches(&path_str) {
114 continue;
115 }
116 }
117
118 matching_files.push(path_str);
119 }
120
121 matching_files
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::path::PathBuf;

    /// Creates an empty scratch directory under the system temp dir.
    ///
    /// The process id is part of the name so that two test processes running
    /// at the same time (e.g. parallel CI jobs on one machine) do not delete
    /// each other's fixtures.
    fn fresh_temp_dir(name: &str) -> PathBuf {
        let dir = env::temp_dir().join(format!("{}_{}", name, std::process::id()));
        if dir.exists() {
            std::fs::remove_dir_all(&dir).expect("Failed to remove existing directory");
        }
        std::fs::create_dir_all(&dir).expect("Failed to create temporary directory");
        dir
    }

    /// Renders `path` in the same string form the loader functions produce.
    fn path_string(path: &Path) -> String {
        path.to_string_lossy().to_string()
    }

    #[tokio::test]
    async fn test_find_files_with_extension() {
        let temp_dir = fresh_temp_dir("dir_loader_test_dir");
        let file_paths = [
            temp_dir.join("file1.txt"),
            temp_dir.join("file2.txt"),
            temp_dir.join("file3.md"),
            temp_dir.join("file4.txt"),
        ];
        for path in &file_paths {
            std::fs::write(path, "Hello, world!").expect("Failed to write file");
        }

        let found_files = find_files_with_extension(
            temp_dir.to_str().unwrap(),
            &DirLoaderOptions {
                glob: None,
                suffixes: Some(vec![".txt".to_string()]),
                path_filter: None,
            },
        )
        .await;

        assert_eq!(found_files.len(), 3);
        assert!(found_files.iter().all(|file| file.ends_with(".txt")));
        for name in ["file1.txt", "file2.txt", "file4.txt"] {
            assert!(found_files.contains(&path_string(&temp_dir.join(name))));
        }

        fs::remove_dir_all(&temp_dir)
            .await
            .expect("Failed to remove temporary directory");
    }

    #[tokio::test]
    async fn test_find_files_recurses_and_applies_glob() {
        let temp_dir = fresh_temp_dir("dir_loader_test_nested_dir");
        let sub = temp_dir.join("sub");
        let deeper = sub.join("deeper");
        std::fs::create_dir_all(&deeper).expect("Failed to create nested directories");
        for path in [
            temp_dir.join("a.txt"),
            sub.join("b.txt"),
            sub.join("c.md"),
            deeper.join("e.txt"),
        ] {
            std::fs::write(&path, "Hello, world!").expect("Failed to write file");
        }
        let root = temp_dir.to_str().unwrap();

        // Suffix filtering must see files at every nesting level.
        let mut txt_files = find_files_with_extension(
            root,
            &DirLoaderOptions {
                suffixes: Some(vec![".txt".to_string()]),
                ..Default::default()
            },
        )
        .await;
        txt_files.sort();
        let mut expected = vec![
            path_string(&temp_dir.join("a.txt")),
            path_string(&sub.join("b.txt")),
            path_string(&deeper.join("e.txt")),
        ];
        expected.sort();
        assert_eq!(txt_files, expected);

        // The glob is matched against the whole path string.
        let globbed = find_files_with_extension(
            root,
            &DirLoaderOptions {
                glob: Some("*b.txt".to_string()),
                ..Default::default()
            },
        )
        .await;
        assert_eq!(globbed, vec![path_string(&sub.join("b.txt"))]);

        fs::remove_dir_all(&temp_dir)
            .await
            .expect("Failed to remove temporary directory");
    }
}