// autoagents_core/readers/simple_directory_reader.rs

1use std::collections::HashSet;
2use std::ffi::OsStr;
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use serde_json::json;
7use walkdir::WalkDir;
8
9use crate::document::Document;
10
11#[derive(Debug, thiserror::Error)]
12pub enum ReaderError {
13    #[error("Root path does not exist: {0}")]
14    MissingPath(PathBuf),
15
16    #[error("Failed to read file {path:?}: {source}")]
17    Io {
18        path: PathBuf,
19        source: std::io::Error,
20    },
21
22    #[error("File {0:?} is not valid UTF-8")]
23    Utf8(PathBuf),
24}
25
/// Loads every matching file beneath a root path as a text [`Document`].
///
/// Configure with the builder-style methods, then call `load_data`.
#[derive(Clone, Debug)]
pub struct SimpleDirectoryReader {
    /// Path the walk starts from.
    root: PathBuf,
    /// When `false`, the walk is limited to a depth of one below `root`.
    recursive: bool,
    /// Allowed file extensions; `None` accepts every file.
    extensions: Option<HashSet<String>>,
}
32
33impl SimpleDirectoryReader {
34    pub fn new(root: impl Into<PathBuf>) -> Self {
35        Self {
36            root: root.into(),
37            recursive: true,
38            extensions: None,
39        }
40    }
41
42    /// Limit the reader to a specific set of extensions (without dots).
43    pub fn with_extensions<I, S>(mut self, extensions: I) -> Self
44    where
45        I: IntoIterator<Item = S>,
46        S: Into<String>,
47    {
48        self.extensions = Some(extensions.into_iter().map(|ext| ext.into()).collect());
49        self
50    }
51
52    pub fn recursive(mut self, recursive: bool) -> Self {
53        self.recursive = recursive;
54        self
55    }
56
57    pub fn load_data(&self) -> Result<Vec<Document>, ReaderError> {
58        if !self.root.exists() {
59            return Err(ReaderError::MissingPath(self.root.clone()));
60        }
61
62        let mut docs = Vec::new();
63        let walker = if self.recursive {
64            WalkDir::new(&self.root)
65        } else {
66            WalkDir::new(&self.root).max_depth(1)
67        };
68
69        for entry in walker {
70            let entry = match entry {
71                Ok(e) => e,
72                Err(err) => {
73                    return Err(ReaderError::Io {
74                        path: self.root.clone(),
75                        source: std::io::Error::other(err),
76                    });
77                }
78            };
79
80            if entry.file_type().is_dir() {
81                continue;
82            }
83
84            if let Some(exts) = &self.extensions {
85                if let Some(ext) = entry.path().extension().and_then(OsStr::to_str) {
86                    if !exts.contains(ext) {
87                        continue;
88                    }
89                } else {
90                    continue;
91                }
92            }
93
94            let content = match fs::read_to_string(entry.path()) {
95                Ok(content) => content,
96                Err(err) if err.kind() == std::io::ErrorKind::InvalidData => {
97                    return Err(ReaderError::Utf8(entry.path().to_path_buf()));
98                }
99                Err(source) => {
100                    return Err(ReaderError::Io {
101                        path: entry.path().to_path_buf(),
102                        source,
103                    });
104                }
105            };
106
107            let relative = path_relative_to(entry.path(), &self.root)
108                .unwrap_or_else(|| entry.file_name().to_string_lossy().to_string());
109
110            let metadata = json!({
111                "source": relative,
112                "absolute_path": entry.path().to_string_lossy(),
113                "extension": entry.path().extension().and_then(OsStr::to_str).unwrap_or_default(),
114            });
115
116            docs.push(Document::with_metadata(content, metadata));
117        }
118
119        Ok(docs)
120    }
121}
122
/// Returns `path` expressed relative to `base`, lossily converted to a
/// `String`.
///
/// Returns `None` when `path` does not start with `base`, and also when the
/// two are identical: the remainder would be empty, which is useless as a
/// label, so the caller can fall back to something better such as the file
/// name.
fn path_relative_to(path: &Path, base: &Path) -> Option<String> {
    let remainder = path.strip_prefix(base).ok()?;
    if remainder.as_os_str().is_empty() {
        return None;
    }
    Some(remainder.to_string_lossy().into_owned())
}