// agentroot_core/providers/csv.rs
use crate::db::hash_content;
4use crate::error::{AgentRootError, Result};
5use crate::providers::{ProviderConfig, SourceItem, SourceProvider};
6use async_trait::async_trait;
7use csv::ReaderBuilder;
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11use walkdir::WalkDir;
12
/// Source provider that exposes every data row of a CSV file as its own item.
///
/// The provider is a stateless unit struct: all behaviour is driven by the
/// configuration passed to each call, so values are free to copy around.
#[derive(Debug, Clone, Copy)]
pub struct CSVProvider;
15
16impl Default for CSVProvider {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl CSVProvider {
23 pub fn new() -> Self {
25 Self
26 }
27
28 fn parse_csv_file(&self, path: &Path, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
30 let file_content = fs::read_to_string(path).map_err(|e| {
31 AgentRootError::Io(std::io::Error::new(
32 e.kind(),
33 format!("Failed to read CSV file {:?}: {}", path, e),
34 ))
35 })?;
36
37 let delimiter = config
38 .options
39 .get("delimiter")
40 .and_then(|s| s.chars().next())
41 .unwrap_or(',');
42
43 let has_headers = config
44 .options
45 .get("has_headers")
46 .map(|s| s == "true")
47 .unwrap_or(true);
48
49 let mut reader = ReaderBuilder::new()
50 .delimiter(delimiter as u8)
51 .has_headers(has_headers)
52 .from_reader(file_content.as_bytes());
53
54 let headers = if has_headers {
55 reader.headers()?.clone()
56 } else {
57 csv::StringRecord::new()
58 };
59
60 let filename = path
61 .file_name()
62 .and_then(|s| s.to_str())
63 .unwrap_or("unknown.csv");
64
65 let mut items = Vec::new();
66 for (row_num, result) in reader.records().enumerate() {
67 let record = result.map_err(|e| {
68 AgentRootError::Parse(format!("Failed to parse CSV row {}: {}", row_num + 1, e))
69 })?;
70
71 let row_content = if has_headers && !headers.is_empty() {
72 let mut parts = Vec::new();
73 for (idx, field) in record.iter().enumerate() {
74 let header = headers.get(idx).unwrap_or("unknown");
75 parts.push(format!("{}: {}", header, field));
76 }
77 parts.join("\n")
78 } else {
79 record
80 .iter()
81 .enumerate()
82 .map(|(idx, field)| format!("column_{}: {}", idx, field))
83 .collect::<Vec<_>>()
84 .join("\n")
85 };
86
87 let title = format!("{} - Row {}", filename, row_num + 1);
88
89 let uri = format!("csv://{}/row_{}", path.display(), row_num + 1);
90 let hash = hash_content(&row_content);
91
92 let mut metadata = HashMap::new();
93 metadata.insert("file".to_string(), filename.to_string());
94 metadata.insert("row_number".to_string(), (row_num + 1).to_string());
95 metadata.insert("column_count".to_string(), record.len().to_string());
96
97 if has_headers {
98 for (idx, field) in record.iter().enumerate() {
99 if let Some(header) = headers.get(idx) {
100 metadata.insert(header.to_string(), field.to_string());
101 }
102 }
103 }
104
105 items.push(SourceItem {
106 uri,
107 title,
108 content: row_content,
109 hash,
110 source_type: "csv".to_string(),
111 metadata,
112 });
113 }
114
115 Ok(items)
116 }
117
118 fn scan_directory(&self, base_path: &Path, pattern: &str) -> Result<Vec<PathBuf>> {
120 let glob_pattern = glob::Pattern::new(pattern)?;
121 let mut csv_files = Vec::new();
122
123 for entry in WalkDir::new(base_path)
124 .follow_links(true)
125 .into_iter()
126 .filter_entry(|e| {
127 let name = e.file_name().to_string_lossy();
128 !name.starts_with('.')
129 && !matches!(
130 name.as_ref(),
131 "node_modules" | ".git" | ".cache" | "target" | "dist" | "build"
132 )
133 })
134 {
135 let entry = entry?;
136 if !entry.file_type().is_file() {
137 continue;
138 }
139
140 let path = entry.path();
141 if let Some(ext) = path.extension() {
142 if ext.eq_ignore_ascii_case("csv") {
143 if let Ok(relative) = path.strip_prefix(base_path) {
144 let relative_str = relative.to_string_lossy();
145 if glob_pattern.matches(&relative_str) {
146 csv_files.push(path.to_path_buf());
147 }
148 }
149 }
150 }
151 }
152
153 Ok(csv_files)
154 }
155}
156
157#[async_trait]
158impl SourceProvider for CSVProvider {
159 fn provider_type(&self) -> &'static str {
160 "csv"
161 }
162
163 async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
164 let base_path = Path::new(&config.base_path);
165
166 if base_path.is_file() {
167 if base_path
168 .extension()
169 .map(|e| e.eq_ignore_ascii_case("csv"))
170 .unwrap_or(false)
171 {
172 return self.parse_csv_file(base_path, config);
173 } else {
174 return Err(AgentRootError::Parse(format!(
175 "File {:?} is not a CSV file",
176 base_path
177 )));
178 }
179 }
180
181 if !base_path.exists() {
182 return Err(AgentRootError::Io(std::io::Error::new(
183 std::io::ErrorKind::NotFound,
184 format!("Path not found: {:?}", base_path),
185 )));
186 }
187
188 let csv_files = self.scan_directory(base_path, &config.pattern)?;
189 let mut all_items = Vec::new();
190
191 for csv_file in csv_files {
192 match self.parse_csv_file(&csv_file, config) {
193 Ok(items) => all_items.extend(items),
194 Err(e) => {
195 tracing::warn!("Failed to parse CSV file {:?}: {}", csv_file, e);
196 }
197 }
198 }
199
200 Ok(all_items)
201 }
202
203 async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
204 if !uri.starts_with("csv://") {
205 return Err(AgentRootError::Parse(format!(
206 "Invalid CSV URI: {}. Expected format: csv://path/to/file.csv/row_N",
207 uri
208 )));
209 }
210
211 let uri_path = &uri[6..];
212 let parts: Vec<&str> = uri_path.rsplitn(2, '/').collect();
213 if parts.len() != 2 || !parts[0].starts_with("row_") {
214 return Err(AgentRootError::Parse(format!(
215 "Invalid CSV URI format: {}. Expected: csv://path/to/file.csv/row_N",
216 uri
217 )));
218 }
219
220 let row_str = &parts[0][4..];
221 let row_num: usize = row_str
222 .parse()
223 .map_err(|_| AgentRootError::Parse(format!("Invalid row number in URI: {}", uri)))?;
224
225 let file_path = Path::new(parts[1]);
226 let config =
227 ProviderConfig::new(file_path.to_string_lossy().to_string(), "**/*".to_string());
228
229 let all_items = self.parse_csv_file(file_path, &config)?;
230
231 all_items
232 .into_iter()
233 .find(|item| item.uri == uri)
234 .ok_or_else(|| {
235 AgentRootError::Parse(format!(
236 "Row {} not found in CSV file {:?}",
237 row_num, file_path
238 ))
239 })
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `body` to `test.csv` inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the path so the file stays
    /// on disk for as long as the caller holds on to it.
    fn write_fixture(body: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.csv");
        fs::write(&file, body).unwrap();
        (dir, file)
    }

    /// Builds the configuration the parsing tests share for `file`.
    fn config_for(file: &Path) -> ProviderConfig {
        ProviderConfig::new(file.to_string_lossy().to_string(), "**/*.csv".to_string())
    }

    #[test]
    fn test_provider_type() {
        assert_eq!(CSVProvider::new().provider_type(), "csv");
    }

    #[tokio::test]
    async fn test_parse_csv_with_headers() {
        let (_guard, csv_path) = write_fixture("name,age,city\nAlice,30,NYC\nBob,25,LA\n");
        let cfg = config_for(&csv_path);

        let rows = CSVProvider::new().parse_csv_file(&csv_path, &cfg).unwrap();

        assert_eq!(rows.len(), 2);
        let first = &rows[0];
        assert!(first.content.contains("name: Alice"));
        assert!(first.content.contains("age: 30"));
        assert!(first.metadata.get("name").unwrap() == "Alice");
    }

    #[tokio::test]
    async fn test_parse_csv_custom_delimiter() {
        let (_guard, csv_path) = write_fixture("name;age;city\nAlice;30;NYC\n");
        let mut cfg = config_for(&csv_path);
        cfg.options
            .insert("delimiter".to_string(), ";".to_string());

        let rows = CSVProvider::new().parse_csv_file(&csv_path, &cfg).unwrap();

        assert_eq!(rows.len(), 1);
        assert!(rows[0].content.contains("name: Alice"));
    }

    #[tokio::test]
    async fn test_fetch_item_by_uri() {
        let (_guard, csv_path) = write_fixture("name,age\nAlice,30\nBob,25\n");
        let uri = format!("csv://{}/row_1", csv_path.display());

        let fetched = CSVProvider::new().fetch_item(&uri).await.unwrap();

        assert!(fetched.content.contains("Alice"));
        assert_eq!(fetched.metadata.get("row_number").unwrap(), "1");
    }
}