autoagents_core/readers/
simple_directory_reader.rs1use std::collections::HashSet;
2use std::ffi::OsStr;
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use serde_json::json;
7use walkdir::WalkDir;
8
9use crate::document::Document;
10
11#[derive(Debug, thiserror::Error)]
12pub enum ReaderError {
13 #[error("Root path does not exist: {0}")]
14 MissingPath(PathBuf),
15
16 #[error("Failed to read file {path:?}: {source}")]
17 Io {
18 path: PathBuf,
19 source: std::io::Error,
20 },
21
22 #[error("File {0:?} is not valid UTF-8")]
23 Utf8(PathBuf),
24}
25
/// Reader that loads the files under a root path as UTF-8 text documents.
///
/// Construct with `new`, optionally refine with `with_extensions` / `recursive`,
/// then call `load_data`.
#[derive(Debug, Clone)]
pub struct SimpleDirectoryReader {
    /// Path the traversal starts from.
    root: PathBuf,
    /// Whether subdirectories are descended into.
    recursive: bool,
    /// Optional allow-list of extensions, compared against `Path::extension()`
    /// (which excludes the dot); `None` accepts every file.
    extensions: Option<HashSet<String>>,
}
32
33impl SimpleDirectoryReader {
34 pub fn new(root: impl Into<PathBuf>) -> Self {
35 Self {
36 root: root.into(),
37 recursive: true,
38 extensions: None,
39 }
40 }
41
42 pub fn with_extensions<I, S>(mut self, extensions: I) -> Self
44 where
45 I: IntoIterator<Item = S>,
46 S: Into<String>,
47 {
48 self.extensions = Some(extensions.into_iter().map(|ext| ext.into()).collect());
49 self
50 }
51
52 pub fn recursive(mut self, recursive: bool) -> Self {
53 self.recursive = recursive;
54 self
55 }
56
57 pub fn load_data(&self) -> Result<Vec<Document>, ReaderError> {
58 if !self.root.exists() {
59 return Err(ReaderError::MissingPath(self.root.clone()));
60 }
61
62 let mut docs = Vec::new();
63 let walker = if self.recursive {
64 WalkDir::new(&self.root)
65 } else {
66 WalkDir::new(&self.root).max_depth(1)
67 };
68
69 for entry in walker {
70 let entry = match entry {
71 Ok(e) => e,
72 Err(err) => {
73 return Err(ReaderError::Io {
74 path: self.root.clone(),
75 source: std::io::Error::other(err),
76 });
77 }
78 };
79
80 if entry.file_type().is_dir() {
81 continue;
82 }
83
84 if let Some(exts) = &self.extensions {
85 if let Some(ext) = entry.path().extension().and_then(OsStr::to_str) {
86 if !exts.contains(ext) {
87 continue;
88 }
89 } else {
90 continue;
91 }
92 }
93
94 let content = match fs::read_to_string(entry.path()) {
95 Ok(content) => content,
96 Err(err) if err.kind() == std::io::ErrorKind::InvalidData => {
97 return Err(ReaderError::Utf8(entry.path().to_path_buf()));
98 }
99 Err(source) => {
100 return Err(ReaderError::Io {
101 path: entry.path().to_path_buf(),
102 source,
103 });
104 }
105 };
106
107 let relative = path_relative_to(entry.path(), &self.root)
108 .unwrap_or_else(|| entry.file_name().to_string_lossy().to_string());
109
110 let metadata = json!({
111 "source": relative,
112 "absolute_path": entry.path().to_string_lossy(),
113 "extension": entry.path().extension().and_then(OsStr::to_str).unwrap_or_default(),
114 });
115
116 docs.push(Document::with_metadata(content, metadata));
117 }
118
119 Ok(docs)
120 }
121}
122
/// Expresses `path` relative to `base` as a lossily converted string.
///
/// Returns `None` when `path` does not start with `base`. The comparison is
/// component-wise (see [`Path::strip_prefix`]), so `path == base` gives `Some("")`.
fn path_relative_to(path: &Path, base: &Path) -> Option<String> {
    let relative = path.strip_prefix(base).ok()?;
    Some(relative.to_string_lossy().into_owned())
}