// autoagents_core/readers/simple_directory_reader.rs

use std::collections::HashSet;
use std::ffi::OsStr;
use std::fs;
use std::path::{Path, PathBuf};

use serde_json::json;
use walkdir::WalkDir;

use crate::document::Document;
11#[derive(Debug, thiserror::Error)]
12pub enum ReaderError {
13    #[error("Root path does not exist: {0}")]
14    MissingPath(PathBuf),
15
16    #[error("Failed to read file {path:?}: {source}")]
17    Io {
18        path: PathBuf,
19        source: std::io::Error,
20    },
21
22    #[error("File {0:?} is not valid UTF-8")]
23    Utf8(PathBuf),
24}
/// Reads the text files found under a root path and turns each one into a
/// `Document`.
#[derive(Clone, Debug)]
pub struct SimpleDirectoryReader {
    /// Path the traversal starts from.
    root: PathBuf,
    /// When `false`, the walk is limited to the root's direct children.
    recursive: bool,
    /// Allowed file extensions (no leading dot); `None` accepts every file.
    extensions: Option<HashSet<String>>,
}
33impl SimpleDirectoryReader {
34    pub fn new(root: impl Into<PathBuf>) -> Self {
35        Self {
36            root: root.into(),
37            recursive: true,
38            extensions: None,
39        }
40    }
41
42    /// Limit the reader to a specific set of extensions (without dots).
43    pub fn with_extensions<I, S>(mut self, extensions: I) -> Self
44    where
45        I: IntoIterator<Item = S>,
46        S: Into<String>,
47    {
48        self.extensions = Some(extensions.into_iter().map(|ext| ext.into()).collect());
49        self
50    }
51
52    pub fn recursive(mut self, recursive: bool) -> Self {
53        self.recursive = recursive;
54        self
55    }
56
57    pub fn load_data(&self) -> Result<Vec<Document>, ReaderError> {
58        if !self.root.exists() {
59            return Err(ReaderError::MissingPath(self.root.clone()));
60        }
61
62        let mut docs = Vec::new();
63        let walker = if self.recursive {
64            WalkDir::new(&self.root)
65        } else {
66            WalkDir::new(&self.root).max_depth(1)
67        };
68
69        for entry in walker {
70            let entry = match entry {
71                Ok(e) => e,
72                Err(err) => {
73                    return Err(ReaderError::Io {
74                        path: self.root.clone(),
75                        source: std::io::Error::other(err),
76                    });
77                }
78            };
79
80            if entry.file_type().is_dir() {
81                continue;
82            }
83
84            if let Some(exts) = &self.extensions {
85                if let Some(ext) = entry.path().extension().and_then(OsStr::to_str) {
86                    if !exts.contains(ext) {
87                        continue;
88                    }
89                } else {
90                    continue;
91                }
92            }
93
94            let content = match fs::read_to_string(entry.path()) {
95                Ok(content) => content,
96                Err(err) if err.kind() == std::io::ErrorKind::InvalidData => {
97                    return Err(ReaderError::Utf8(entry.path().to_path_buf()));
98                }
99                Err(source) => {
100                    return Err(ReaderError::Io {
101                        path: entry.path().to_path_buf(),
102                        source,
103                    });
104                }
105            };
106
107            let relative = path_relative_to(entry.path(), &self.root)
108                .unwrap_or_else(|| entry.file_name().to_string_lossy().to_string());
109
110            let metadata = json!({
111                "source": relative,
112                "absolute_path": entry.path().to_string_lossy(),
113                "extension": entry.path().extension().and_then(OsStr::to_str).unwrap_or_default(),
114            });
115
116            docs.push(Document::with_metadata(content, metadata));
117        }
118
119        Ok(docs)
120    }
121}
/// Returns `path` expressed relative to `base` as a (lossily converted)
/// string, or `None` when `path` does not start with `base`.
///
/// When `path` and `base` are equal the result is `Some` of an empty string.
fn path_relative_to(path: &Path, base: &Path) -> Option<String> {
    path.strip_prefix(base)
        .ok()
        .map(|relative| relative.to_string_lossy().into_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Scratch directory under the system temp dir.
    ///
    /// It is wiped before use, so leftovers from an earlier aborted run cannot
    /// influence a test, and removed again on drop, so cleanup also happens
    /// when an assertion fails.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(name);
            fs::remove_dir_all(&path).ok();
            fs::create_dir_all(&path).unwrap();
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test result.
            fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn test_missing_path_error() {
        let reader = SimpleDirectoryReader::new("/nonexistent/path/xyz123");
        let result = reader.load_data();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ReaderError::MissingPath(_)));
    }

    #[test]
    fn test_empty_directory() {
        let dir = ScratchDir::new("autoagents_test_empty_dir");
        let reader = SimpleDirectoryReader::new(dir.path());
        let docs = reader.load_data().unwrap();
        assert!(docs.is_empty());
    }

    #[test]
    fn test_single_file_load() {
        let dir = ScratchDir::new("autoagents_test_single_file");
        fs::write(dir.path().join("test.txt"), "hello world").unwrap();

        let reader = SimpleDirectoryReader::new(dir.path());
        let docs = reader.load_data().unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "hello world");
        assert_eq!(docs[0].metadata["extension"], "txt");
    }

    #[test]
    fn test_extension_filter() {
        let dir = ScratchDir::new("autoagents_test_ext_filter");
        fs::write(dir.path().join("file.txt"), "text").unwrap();
        fs::write(dir.path().join("file.md"), "markdown").unwrap();

        let reader = SimpleDirectoryReader::new(dir.path()).with_extensions(["txt"]);
        let docs = reader.load_data().unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "text");
    }

    #[test]
    fn test_non_recursive_mode() {
        let dir = ScratchDir::new("autoagents_test_nonrecursive");
        let sub = dir.path().join("sub");
        fs::create_dir_all(&sub).unwrap();
        fs::write(dir.path().join("top.txt"), "top").unwrap();
        fs::write(sub.join("nested.txt"), "nested").unwrap();

        let reader = SimpleDirectoryReader::new(dir.path()).recursive(false);
        let docs = reader.load_data().unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "top");
    }

    #[test]
    fn test_relative_path_metadata() {
        let dir = ScratchDir::new("autoagents_test_relpath");
        fs::write(dir.path().join("file.txt"), "content").unwrap();

        let reader = SimpleDirectoryReader::new(dir.path());
        let docs = reader.load_data().unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].metadata["source"], "file.txt");
    }

    #[test]
    fn test_recursive_mode() {
        let dir = ScratchDir::new("autoagents_test_recursive");
        let sub = dir.path().join("sub");
        fs::create_dir_all(&sub).unwrap();
        fs::write(dir.path().join("top.txt"), "top").unwrap();
        fs::write(sub.join("nested.txt"), "nested").unwrap();

        let reader = SimpleDirectoryReader::new(dir.path());
        let docs = reader.load_data().unwrap();
        assert_eq!(docs.len(), 2);
    }

    #[test]
    fn test_path_relative_to_fn() {
        let result = path_relative_to(Path::new("/a/b/c.txt"), Path::new("/a/b"));
        assert_eq!(result, Some("c.txt".to_string()));
    }

    #[test]
    fn test_path_relative_to_fn_no_prefix() {
        let result = path_relative_to(Path::new("/x/y.txt"), Path::new("/a/b"));
        assert_eq!(result, None);
    }
}