Skip to main content

a3s_code_core/context/
fs_provider.rs

1//! File System Context Provider
2//!
3//! Provides simple file-based RAG (Retrieval-Augmented Generation):
4//! - Index files in a directory with glob pattern filtering
5//! - Simple keyword-based search with relevance scoring
6//! - Support for file size limits and exclusion patterns
7
8use crate::context::{ContextItem, ContextProvider, ContextQuery, ContextResult, ContextType};
9use async_trait::async_trait;
10use ignore::WalkBuilder;
11use std::collections::HashMap;
12use std::fs;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
/// Configuration for [`FileSystemContextProvider`].
///
/// Start from [`FileSystemContextConfig::new`], which fills in the defaults,
/// then chain the `with_*` methods to override individual settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileSystemContextConfig {
    /// Root directory to index
    pub root_path: PathBuf,
    /// Include patterns (glob syntax: ["**/*.rs", "**/*.md"]).
    ///
    /// An empty list admits every file.
    pub include_patterns: Vec<String>,
    /// Exclude patterns (glob syntax: ["**/target/**", "**/node_modules/**"])
    pub exclude_patterns: Vec<String>,
    /// Maximum file size in bytes; larger files are not indexed
    /// (default: [`FileSystemContextConfig::DEFAULT_MAX_FILE_SIZE`])
    pub max_file_size: usize,
    /// Whether to enable cache (default: true)
    pub enable_cache: bool,
}

impl FileSystemContextConfig {
    /// Default upper bound on indexed file size: 1 MiB.
    pub const DEFAULT_MAX_FILE_SIZE: usize = 1024 * 1024;

    /// Create a config rooted at `root_path` with default settings.
    ///
    /// Defaults: Rust and Markdown files are included; `target`,
    /// `node_modules` and `.git` trees are excluded; files above
    /// [`Self::DEFAULT_MAX_FILE_SIZE`] are skipped; caching is on.
    #[must_use]
    pub fn new(root_path: impl Into<PathBuf>) -> Self {
        Self {
            root_path: root_path.into(),
            include_patterns: vec!["**/*.rs".to_string(), "**/*.md".to_string()],
            exclude_patterns: vec![
                "**/target/**".to_string(),
                "**/node_modules/**".to_string(),
                "**/.git/**".to_string(),
            ],
            max_file_size: Self::DEFAULT_MAX_FILE_SIZE,
            enable_cache: true,
        }
    }

    /// Replace the include patterns.
    ///
    /// The builder consumes `self`; dropping the return value discards the
    /// change, hence `#[must_use]`.
    #[must_use]
    pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
        self.include_patterns = patterns;
        self
    }

    /// Replace the exclude patterns.
    #[must_use]
    pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
        self.exclude_patterns = patterns;
        self
    }

    /// Set the maximum size, in bytes, of a file that may be indexed.
    #[must_use]
    pub fn with_max_file_size(mut self, size: usize) -> Self {
        self.max_file_size = size;
        self
    }

    /// Enable or disable the in-memory cache of indexed files.
    #[must_use]
    pub fn with_cache(mut self, enable: bool) -> Self {
        self.enable_cache = enable;
        self
    }
}
72
/// Indexed file entry: one file captured during an indexing pass.
#[derive(Debug, Clone)]
struct IndexedFile {
    /// Path of the file as yielded by the directory walker.
    path: PathBuf,
    /// Entire file contents, read as UTF-8 text.
    content: String,
    /// File length in bytes, taken from filesystem metadata at index time.
    size: usize,
}
80
81/// File system context provider
82pub struct FileSystemContextProvider {
83    config: FileSystemContextConfig,
84    /// Cached indexed files
85    cache: Arc<RwLock<HashMap<PathBuf, IndexedFile>>>,
86}
87
88impl FileSystemContextProvider {
89    /// Create a new file system context provider
90    pub fn new(config: FileSystemContextConfig) -> Self {
91        Self {
92            config,
93            cache: Arc::new(RwLock::new(HashMap::new())),
94        }
95    }
96
97    /// Index files in the root directory
98    async fn index_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
99        let mut files = Vec::new();
100
101        let walker = WalkBuilder::new(&self.config.root_path)
102            .hidden(false)
103            .git_ignore(true)
104            .build();
105
106        for entry in walker {
107            let entry = entry.map_err(|e| anyhow::anyhow!("Walk error: {}", e))?;
108            let path = entry.path();
109
110            if !path.is_file() {
111                continue;
112            }
113
114            let metadata =
115                fs::metadata(path).map_err(|e| anyhow::anyhow!("Metadata error: {}", e))?;
116            if metadata.len() > self.config.max_file_size as u64 {
117                continue;
118            }
119
120            if !self.matches_include_patterns(path) {
121                continue;
122            }
123
124            if self.matches_exclude_patterns(path) {
125                continue;
126            }
127
128            let content =
129                fs::read_to_string(path).map_err(|e| anyhow::anyhow!("Read error: {}", e))?;
130
131            files.push(IndexedFile {
132                path: path.to_path_buf(),
133                content,
134                size: metadata.len() as usize,
135            });
136        }
137
138        Ok(files)
139    }
140
141    fn matches_include_patterns(&self, path: &Path) -> bool {
142        if self.config.include_patterns.is_empty() {
143            return true;
144        }
145
146        // Normalize to forward slashes for consistent cross-platform glob matching
147        let path_str = path.to_string_lossy().replace('\\', "/");
148        self.config.include_patterns.iter().any(|pattern| {
149            glob::Pattern::new(pattern)
150                .map(|p| p.matches(&path_str))
151                .unwrap_or(false)
152        })
153    }
154
155    fn matches_exclude_patterns(&self, path: &Path) -> bool {
156        // Normalize to forward slashes for consistent cross-platform glob matching
157        let path_str = path.to_string_lossy().replace('\\', "/");
158        self.config.exclude_patterns.iter().any(|pattern| {
159            glob::Pattern::new(pattern)
160                .map(|p| p.matches(&path_str))
161                .unwrap_or(false)
162        })
163    }
164
165    async fn search_simple(
166        &self,
167        query: &str,
168        files: &[IndexedFile],
169        max_results: usize,
170    ) -> Vec<(IndexedFile, f32)> {
171        let query_lower = query.to_lowercase();
172        let keywords: Vec<&str> = query_lower.split_whitespace().collect();
173
174        let mut results: Vec<(IndexedFile, f32)> = files
175            .iter()
176            .filter_map(|file| {
177                let content_lower = file.content.to_lowercase();
178                let mut score = 0.0;
179                for keyword in &keywords {
180                    let count = content_lower.matches(keyword).count();
181                    score += count as f32;
182                }
183
184                if score > 0.0 {
185                    let normalized_score = score / (file.content.len() as f32).sqrt();
186                    Some((file.clone(), normalized_score))
187                } else {
188                    None
189                }
190            })
191            .collect();
192
193        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
194        results.truncate(max_results);
195        results
196    }
197
198    async fn update_cache(&self, files: Vec<IndexedFile>) {
199        if !self.config.enable_cache {
200            return;
201        }
202
203        let mut cache = self.cache.write().await;
204        cache.clear();
205        for file in files {
206            cache.insert(file.path.clone(), file);
207        }
208    }
209
210    async fn get_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
211        if self.config.enable_cache {
212            let cache = self.cache.read().await;
213            if !cache.is_empty() {
214                return Ok(cache.values().cloned().collect());
215            }
216        }
217
218        let files = self.index_files().await?;
219        self.update_cache(files.clone()).await;
220        Ok(files)
221    }
222}
223
224#[async_trait]
225impl ContextProvider for FileSystemContextProvider {
226    fn name(&self) -> &str {
227        "filesystem"
228    }
229
230    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
231        let files = self.get_files().await?;
232        let results = self
233            .search_simple(&query.query, &files, query.max_results)
234            .await;
235
236        let items: Vec<ContextItem> = results
237            .into_iter()
238            .map(|(file, score)| {
239                let content = match query.depth {
240                    crate::context::ContextDepth::Abstract => {
241                        file.content.chars().take(500).collect::<String>()
242                    }
243                    crate::context::ContextDepth::Overview => {
244                        file.content.chars().take(2000).collect::<String>()
245                    }
246                    crate::context::ContextDepth::Full => file.content.clone(),
247                };
248
249                let token_count = content.split_whitespace().count();
250
251                ContextItem {
252                    id: file.path.to_string_lossy().to_string(),
253                    context_type: ContextType::Resource,
254                    content,
255                    token_count,
256                    relevance: score,
257                    source: Some(format!("file:{}", file.path.display())),
258                    metadata: {
259                        let mut meta = HashMap::new();
260                        meta.insert(
261                            "path".to_string(),
262                            serde_json::Value::String(file.path.to_string_lossy().to_string()),
263                        );
264                        meta.insert(
265                            "size".to_string(),
266                            serde_json::Value::Number(file.size.into()),
267                        );
268                        meta
269                    },
270                }
271            })
272            .collect();
273
274        let total_tokens: usize = items.iter().map(|item| item.token_count).sum();
275        let truncated = items.len() < files.len();
276
277        Ok(ContextResult {
278            items,
279            total_tokens,
280            provider: self.name().to_string(),
281            truncated,
282        })
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    /// Populate `dir` with two Rust files (one in `subdir/`) and one
    /// Markdown file that mentions "Rust programming".
    fn create_test_files(dir: &Path) -> anyhow::Result<()> {
        let mut file1 = File::create(dir.join("test1.rs"))?;
        writeln!(file1, "fn main() {{\n    println!(\"Hello, world!\");\n}}")?;

        let mut file2 = File::create(dir.join("test2.md"))?;
        writeln!(
            file2,
            "# Test Document\n\nThis is a test document about Rust programming."
        )?;

        fs::create_dir(dir.join("subdir"))?;
        let mut file4 = File::create(dir.join("subdir/test4.rs"))?;
        writeln!(file4, "fn test() {{\n    // Test function\n}}")?;

        Ok(())
    }

    /// Build a provider over `dir` using the given config.
    fn provider_with(config: FileSystemContextConfig) -> FileSystemContextProvider {
        FileSystemContextProvider::new(config)
    }

    #[tokio::test]
    async fn test_index_files() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let provider = provider_with(FileSystemContextConfig::new(temp_dir.path()));

        let files = provider.index_files().await.unwrap();
        assert!(files.len() >= 2);
    }

    #[tokio::test]
    async fn test_search_simple() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let provider = provider_with(FileSystemContextConfig::new(temp_dir.path()));

        let query = ContextQuery::new("Rust programming");
        let result = provider.query(&query).await.unwrap();

        assert!(!result.items.is_empty());
        assert!(result
            .items
            .iter()
            .any(|item| item.content.contains("Rust")));
    }

    /// Files whose extension is not covered by the include patterns must be
    /// left out of the index.
    #[tokio::test]
    async fn test_index_skips_unlisted_extensions() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();
        fs::write(temp_dir.path().join("notes.txt"), "Rust programming notes").unwrap();

        let provider = provider_with(FileSystemContextConfig::new(temp_dir.path()));
        let files = provider.index_files().await.unwrap();

        assert!(!files.is_empty());
        assert!(files.iter().all(|file| matches!(
            file.path.extension().and_then(|ext| ext.to_str()),
            Some("rs") | Some("md")
        )));
    }

    /// Paths matching an exclude pattern must be left out of the index.
    #[tokio::test]
    async fn test_index_respects_exclude_patterns() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path())
            .with_exclude_patterns(vec!["**/subdir/**".to_string()]);
        let provider = provider_with(config);
        let files = provider.index_files().await.unwrap();

        assert!(!files.is_empty());
        assert!(files.iter().all(|file| !file
            .path
            .components()
            .any(|part| part.as_os_str() == "subdir")));
    }

    /// Files larger than `max_file_size` must be left out of the index.
    #[tokio::test]
    async fn test_index_respects_max_file_size() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        // Every fixture file is longer than one byte.
        let config = FileSystemContextConfig::new(temp_dir.path()).with_max_file_size(1);
        let provider = provider_with(config);
        let files = provider.index_files().await.unwrap();

        assert!(files.is_empty());
    }
}