agentroot_core/providers/
csv.rs

1//! CSV Provider for indexing CSV files row-by-row
2
3use crate::db::hash_content;
4use crate::error::{AgentRootError, Result};
5use crate::providers::{ProviderConfig, SourceItem, SourceProvider};
6use async_trait::async_trait;
7use csv::ReaderBuilder;
8use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11use walkdir::WalkDir;
12
/// Provider for indexing CSV files.
///
/// The provider is a stateless unit struct: all behaviour is driven by the
/// configuration or URI passed to each call, so it is trivially `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct CSVProvider;
15
16impl Default for CSVProvider {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl CSVProvider {
23    /// Create a new CSVProvider
24    pub fn new() -> Self {
25        Self
26    }
27
28    /// Parse CSV file and return rows as items
29    fn parse_csv_file(&self, path: &Path, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
30        let file_content = fs::read_to_string(path).map_err(|e| {
31            AgentRootError::Io(std::io::Error::new(
32                e.kind(),
33                format!("Failed to read CSV file {:?}: {}", path, e),
34            ))
35        })?;
36
37        let delimiter = config
38            .options
39            .get("delimiter")
40            .and_then(|s| s.chars().next())
41            .unwrap_or(',');
42
43        let has_headers = config
44            .options
45            .get("has_headers")
46            .map(|s| s == "true")
47            .unwrap_or(true);
48
49        let mut reader = ReaderBuilder::new()
50            .delimiter(delimiter as u8)
51            .has_headers(has_headers)
52            .from_reader(file_content.as_bytes());
53
54        let headers = if has_headers {
55            reader.headers()?.clone()
56        } else {
57            csv::StringRecord::new()
58        };
59
60        let filename = path
61            .file_name()
62            .and_then(|s| s.to_str())
63            .unwrap_or("unknown.csv");
64
65        let mut items = Vec::new();
66        for (row_num, result) in reader.records().enumerate() {
67            let record = result.map_err(|e| {
68                AgentRootError::Parse(format!("Failed to parse CSV row {}: {}", row_num + 1, e))
69            })?;
70
71            let row_content = if has_headers && !headers.is_empty() {
72                let mut parts = Vec::new();
73                for (idx, field) in record.iter().enumerate() {
74                    let header = headers.get(idx).unwrap_or("unknown");
75                    parts.push(format!("{}: {}", header, field));
76                }
77                parts.join("\n")
78            } else {
79                record
80                    .iter()
81                    .enumerate()
82                    .map(|(idx, field)| format!("column_{}: {}", idx, field))
83                    .collect::<Vec<_>>()
84                    .join("\n")
85            };
86
87            let title = format!("{} - Row {}", filename, row_num + 1);
88
89            let uri = format!("csv://{}/row_{}", path.display(), row_num + 1);
90            let hash = hash_content(&row_content);
91
92            let mut metadata = HashMap::new();
93            metadata.insert("file".to_string(), filename.to_string());
94            metadata.insert("row_number".to_string(), (row_num + 1).to_string());
95            metadata.insert("column_count".to_string(), record.len().to_string());
96
97            if has_headers {
98                for (idx, field) in record.iter().enumerate() {
99                    if let Some(header) = headers.get(idx) {
100                        metadata.insert(header.to_string(), field.to_string());
101                    }
102                }
103            }
104
105            items.push(SourceItem {
106                uri,
107                title,
108                content: row_content,
109                hash,
110                source_type: "csv".to_string(),
111                metadata,
112            });
113        }
114
115        Ok(items)
116    }
117
118    /// Scan directory for CSV files matching pattern
119    fn scan_directory(&self, base_path: &Path, pattern: &str) -> Result<Vec<PathBuf>> {
120        let glob_pattern = glob::Pattern::new(pattern)?;
121        let mut csv_files = Vec::new();
122
123        for entry in WalkDir::new(base_path)
124            .follow_links(true)
125            .into_iter()
126            .filter_entry(|e| {
127                let name = e.file_name().to_string_lossy();
128                !name.starts_with('.')
129                    && !matches!(
130                        name.as_ref(),
131                        "node_modules" | ".git" | ".cache" | "target" | "dist" | "build"
132                    )
133            })
134        {
135            let entry = entry?;
136            if !entry.file_type().is_file() {
137                continue;
138            }
139
140            let path = entry.path();
141            if let Some(ext) = path.extension() {
142                if ext.eq_ignore_ascii_case("csv") {
143                    if let Ok(relative) = path.strip_prefix(base_path) {
144                        let relative_str = relative.to_string_lossy();
145                        if glob_pattern.matches(&relative_str) {
146                            csv_files.push(path.to_path_buf());
147                        }
148                    }
149                }
150            }
151        }
152
153        Ok(csv_files)
154    }
155}
156
157#[async_trait]
158impl SourceProvider for CSVProvider {
159    fn provider_type(&self) -> &'static str {
160        "csv"
161    }
162
163    async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
164        let base_path = Path::new(&config.base_path);
165
166        if base_path.is_file() {
167            if base_path
168                .extension()
169                .map(|e| e.eq_ignore_ascii_case("csv"))
170                .unwrap_or(false)
171            {
172                return self.parse_csv_file(base_path, config);
173            } else {
174                return Err(AgentRootError::Parse(format!(
175                    "File {:?} is not a CSV file",
176                    base_path
177                )));
178            }
179        }
180
181        if !base_path.exists() {
182            return Err(AgentRootError::Io(std::io::Error::new(
183                std::io::ErrorKind::NotFound,
184                format!("Path not found: {:?}", base_path),
185            )));
186        }
187
188        let csv_files = self.scan_directory(base_path, &config.pattern)?;
189        let mut all_items = Vec::new();
190
191        for csv_file in csv_files {
192            match self.parse_csv_file(&csv_file, config) {
193                Ok(items) => all_items.extend(items),
194                Err(e) => {
195                    tracing::warn!("Failed to parse CSV file {:?}: {}", csv_file, e);
196                }
197            }
198        }
199
200        Ok(all_items)
201    }
202
203    async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
204        if !uri.starts_with("csv://") {
205            return Err(AgentRootError::Parse(format!(
206                "Invalid CSV URI: {}. Expected format: csv://path/to/file.csv/row_N",
207                uri
208            )));
209        }
210
211        let uri_path = &uri[6..];
212        let parts: Vec<&str> = uri_path.rsplitn(2, '/').collect();
213        if parts.len() != 2 || !parts[0].starts_with("row_") {
214            return Err(AgentRootError::Parse(format!(
215                "Invalid CSV URI format: {}. Expected: csv://path/to/file.csv/row_N",
216                uri
217            )));
218        }
219
220        let row_str = &parts[0][4..];
221        let row_num: usize = row_str
222            .parse()
223            .map_err(|_| AgentRootError::Parse(format!("Invalid row number in URI: {}", uri)))?;
224
225        let file_path = Path::new(parts[1]);
226        let config =
227            ProviderConfig::new(file_path.to_string_lossy().to_string(), "**/*".to_string());
228
229        let all_items = self.parse_csv_file(file_path, &config)?;
230
231        all_items
232            .into_iter()
233            .find(|item| item.uri == uri)
234            .ok_or_else(|| {
235                AgentRootError::Parse(format!(
236                    "Row {} not found in CSV file {:?}",
237                    row_num, file_path
238                ))
239            })
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Write `contents` to `test.csv` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the
    /// file is needed; dropping it removes the directory.
    fn write_temp_csv(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let csv_path = temp_dir.path().join("test.csv");
        fs::write(&csv_path, contents).unwrap();
        (temp_dir, csv_path)
    }

    /// Build a provider config whose base path is the given CSV file.
    fn config_for(csv_path: &Path) -> ProviderConfig {
        ProviderConfig::new(
            csv_path.to_string_lossy().to_string(),
            "**/*.csv".to_string(),
        )
    }

    #[test]
    fn test_provider_type() {
        let provider = CSVProvider::new();
        assert_eq!(provider.provider_type(), "csv");
    }

    // Parsing is synchronous, so no async runtime is needed here.
    #[test]
    fn test_parse_csv_with_headers() {
        let provider = CSVProvider::new();
        let (_temp_dir, csv_path) = write_temp_csv("name,age,city\nAlice,30,NYC\nBob,25,LA\n");

        let items = provider
            .parse_csv_file(&csv_path, &config_for(&csv_path))
            .unwrap();

        assert_eq!(items.len(), 2);
        assert!(items[0].content.contains("name: Alice"));
        assert!(items[0].content.contains("age: 30"));
        assert_eq!(items[0].metadata.get("name").unwrap(), "Alice");
    }

    #[test]
    fn test_parse_csv_custom_delimiter() {
        let provider = CSVProvider::new();
        let (_temp_dir, csv_path) = write_temp_csv("name;age;city\nAlice;30;NYC\n");

        let mut config = config_for(&csv_path);
        config
            .options
            .insert("delimiter".to_string(), ";".to_string());

        let items = provider.parse_csv_file(&csv_path, &config).unwrap();

        assert_eq!(items.len(), 1);
        assert!(items[0].content.contains("name: Alice"));
    }

    #[test]
    fn test_parse_csv_without_headers() {
        let provider = CSVProvider::new();
        let (_temp_dir, csv_path) = write_temp_csv("Alice,30\nBob,25\n");

        let mut config = config_for(&csv_path);
        config
            .options
            .insert("has_headers".to_string(), "false".to_string());

        let items = provider.parse_csv_file(&csv_path, &config).unwrap();

        // With headers disabled the first line is data, not column names.
        assert_eq!(items.len(), 2);
        assert!(items[0].content.contains("column_0: Alice"));
        assert!(items[0].content.contains("column_1: 30"));
        assert_eq!(items[0].metadata.get("row_number").unwrap(), "1");
    }

    #[tokio::test]
    async fn test_fetch_item_by_uri() {
        let provider = CSVProvider::new();
        let (_temp_dir, csv_path) = write_temp_csv("name,age\nAlice,30\nBob,25\n");

        let uri = format!("csv://{}/row_1", csv_path.display());
        let item = provider.fetch_item(&uri).await.unwrap();

        assert!(item.content.contains("Alice"));
        assert_eq!(item.metadata.get("row_number").unwrap(), "1");
    }

    #[tokio::test]
    async fn test_fetch_item_rejects_malformed_uri() {
        let provider = CSVProvider::new();

        // Wrong scheme.
        assert!(provider
            .fetch_item("file:///tmp/test.csv/row_1")
            .await
            .is_err());
        // No `/row_N` segment at all.
        assert!(provider.fetch_item("csv://test.csv").await.is_err());
        // Last segment is not a `row_` segment.
        assert!(provider
            .fetch_item("csv:///tmp/test.csv/line_1")
            .await
            .is_err());
        // Row number is not an integer.
        assert!(provider
            .fetch_item("csv:///tmp/test.csv/row_abc")
            .await
            .is_err());
    }
}