Skip to main content

a3s_code_core/context/
fs_provider.rs

1//! File System Context Provider
2//!
3//! Provides simple file-based RAG (Retrieval-Augmented Generation):
4//! - Index files in a directory with glob pattern filtering
5//! - Simple keyword-based search with relevance scoring
6//! - Support for file size limits and exclusion patterns
7
8use crate::context::{ContextItem, ContextProvider, ContextQuery, ContextResult, ContextType};
9use async_trait::async_trait;
10use ignore::WalkBuilder;
11use std::collections::HashMap;
12use std::fs;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
/// Settings that control which files a [`FileSystemContextProvider`] indexes.
///
/// Build one with [`FileSystemContextConfig::new`] and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct FileSystemContextConfig {
    /// Directory whose contents are walked and indexed.
    pub root_path: PathBuf,
    /// Glob patterns a file path must match to be indexed (e.g. `"**/*.rs"`).
    /// An empty list admits every file.
    pub include_patterns: Vec<String>,
    /// Glob patterns that veto a file even if it matched an include pattern
    /// (e.g. `"**/target/**"`).
    pub exclude_patterns: Vec<String>,
    /// Files larger than this many bytes are skipped (`new` uses 1 MiB).
    pub max_file_size: usize,
    /// When `true`, the indexed files are kept in memory between queries
    /// (`new` turns this on).
    pub enable_cache: bool,
}
31
32impl FileSystemContextConfig {
33    /// Create a new config with default settings
34    pub fn new(root_path: impl Into<PathBuf>) -> Self {
35        Self {
36            root_path: root_path.into(),
37            include_patterns: vec!["**/*.rs".to_string(), "**/*.md".to_string()],
38            exclude_patterns: vec![
39                "**/target/**".to_string(),
40                "**/node_modules/**".to_string(),
41                "**/.git/**".to_string(),
42            ],
43            max_file_size: 1024 * 1024, // 1MB
44            enable_cache: true,
45        }
46    }
47
48    /// Set include patterns
49    pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
50        self.include_patterns = patterns;
51        self
52    }
53
54    /// Set exclude patterns
55    pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
56        self.exclude_patterns = patterns;
57        self
58    }
59
60    /// Set max file size
61    pub fn with_max_file_size(mut self, size: usize) -> Self {
62        self.max_file_size = size;
63        self
64    }
65
66    /// Enable/disable cache
67    pub fn with_cache(mut self, enable: bool) -> Self {
68        self.enable_cache = enable;
69        self
70    }
71}
72
/// One file that passed the provider's filters, loaded fully into memory.
#[derive(Debug, Clone)]
struct IndexedFile {
    /// Path as yielded by the directory walk (rooted at `root_path`).
    path: PathBuf,
    /// Entire file contents, decoded as UTF-8.
    content: String,
    /// File length in bytes, as reported by the filesystem metadata at index time.
    size: usize,
}
80
81/// File system context provider
82pub struct FileSystemContextProvider {
83    config: FileSystemContextConfig,
84    /// Cached indexed files
85    cache: Arc<RwLock<HashMap<PathBuf, IndexedFile>>>,
86}
87
88impl FileSystemContextProvider {
89    /// Create a new file system context provider
90    pub fn new(config: FileSystemContextConfig) -> Self {
91        Self {
92            config,
93            cache: Arc::new(RwLock::new(HashMap::new())),
94        }
95    }
96
97    /// Index files in the root directory
98    async fn index_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
99        let mut files = Vec::new();
100
101        let walker = WalkBuilder::new(&self.config.root_path)
102            .hidden(false)
103            .git_ignore(true)
104            .build();
105
106        for entry in walker {
107            let entry = entry.map_err(|e| anyhow::anyhow!("Walk error: {}", e))?;
108            let path = entry.path();
109
110            if !path.is_file() {
111                continue;
112            }
113
114            let metadata =
115                fs::metadata(path).map_err(|e| anyhow::anyhow!("Metadata error: {}", e))?;
116            if metadata.len() > self.config.max_file_size as u64 {
117                continue;
118            }
119
120            if !self.matches_include_patterns(path) {
121                continue;
122            }
123
124            if self.matches_exclude_patterns(path) {
125                continue;
126            }
127
128            let content =
129                fs::read_to_string(path).map_err(|e| anyhow::anyhow!("Read error: {}", e))?;
130
131            files.push(IndexedFile {
132                path: path.to_path_buf(),
133                content,
134                size: metadata.len() as usize,
135            });
136        }
137
138        Ok(files)
139    }
140
141    fn matches_include_patterns(&self, path: &Path) -> bool {
142        if self.config.include_patterns.is_empty() {
143            return true;
144        }
145
146        let path_str = path.to_string_lossy();
147        self.config.include_patterns.iter().any(|pattern| {
148            glob::Pattern::new(pattern)
149                .map(|p| p.matches(&path_str))
150                .unwrap_or(false)
151        })
152    }
153
154    fn matches_exclude_patterns(&self, path: &Path) -> bool {
155        let path_str = path.to_string_lossy();
156        self.config.exclude_patterns.iter().any(|pattern| {
157            glob::Pattern::new(pattern)
158                .map(|p| p.matches(&path_str))
159                .unwrap_or(false)
160        })
161    }
162
163    async fn search_simple(
164        &self,
165        query: &str,
166        files: &[IndexedFile],
167        max_results: usize,
168    ) -> Vec<(IndexedFile, f32)> {
169        let query_lower = query.to_lowercase();
170        let keywords: Vec<&str> = query_lower.split_whitespace().collect();
171
172        let mut results: Vec<(IndexedFile, f32)> = files
173            .iter()
174            .filter_map(|file| {
175                let content_lower = file.content.to_lowercase();
176                let mut score = 0.0;
177                for keyword in &keywords {
178                    let count = content_lower.matches(keyword).count();
179                    score += count as f32;
180                }
181
182                if score > 0.0 {
183                    let normalized_score = score / (file.content.len() as f32).sqrt();
184                    Some((file.clone(), normalized_score))
185                } else {
186                    None
187                }
188            })
189            .collect();
190
191        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
192        results.truncate(max_results);
193        results
194    }
195
196    async fn update_cache(&self, files: Vec<IndexedFile>) {
197        if !self.config.enable_cache {
198            return;
199        }
200
201        let mut cache = self.cache.write().await;
202        cache.clear();
203        for file in files {
204            cache.insert(file.path.clone(), file);
205        }
206    }
207
208    async fn get_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
209        if self.config.enable_cache {
210            let cache = self.cache.read().await;
211            if !cache.is_empty() {
212                return Ok(cache.values().cloned().collect());
213            }
214        }
215
216        let files = self.index_files().await?;
217        self.update_cache(files.clone()).await;
218        Ok(files)
219    }
220}
221
222#[async_trait]
223impl ContextProvider for FileSystemContextProvider {
224    fn name(&self) -> &str {
225        "filesystem"
226    }
227
228    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
229        let files = self.get_files().await?;
230        let results = self
231            .search_simple(&query.query, &files, query.max_results)
232            .await;
233
234        let items: Vec<ContextItem> = results
235            .into_iter()
236            .map(|(file, score)| {
237                let content = match query.depth {
238                    crate::context::ContextDepth::Abstract => {
239                        file.content.chars().take(500).collect::<String>()
240                    }
241                    crate::context::ContextDepth::Overview => {
242                        file.content.chars().take(2000).collect::<String>()
243                    }
244                    crate::context::ContextDepth::Full => file.content.clone(),
245                };
246
247                let token_count = content.split_whitespace().count();
248
249                ContextItem {
250                    id: file.path.to_string_lossy().to_string(),
251                    context_type: ContextType::Resource,
252                    content,
253                    token_count,
254                    relevance: score,
255                    source: Some(format!("file:{}", file.path.display())),
256                    metadata: {
257                        let mut meta = HashMap::new();
258                        meta.insert(
259                            "path".to_string(),
260                            serde_json::Value::String(file.path.to_string_lossy().to_string()),
261                        );
262                        meta.insert(
263                            "size".to_string(),
264                            serde_json::Value::Number(file.size.into()),
265                        );
266                        meta
267                    },
268                }
269            })
270            .collect();
271
272        let total_tokens: usize = items.iter().map(|item| item.token_count).sum();
273        let truncated = items.len() < files.len();
274
275        Ok(ContextResult {
276            items,
277            total_tokens,
278            provider: self.name().to_string(),
279            truncated,
280        })
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    /// Lay out three indexable files: `test1.rs`, `test2.md` and `subdir/test4.rs`.
    fn create_test_files(dir: &Path) -> anyhow::Result<()> {
        let mut file1 = File::create(dir.join("test1.rs"))?;
        writeln!(file1, "fn main() {{\n    println!(\"Hello, world!\");\n}}")?;

        let mut file2 = File::create(dir.join("test2.md"))?;
        writeln!(
            file2,
            "# Test Document\n\nThis is a test document about Rust programming."
        )?;

        fs::create_dir(dir.join("subdir"))?;
        let mut file4 = File::create(dir.join("subdir/test4.rs"))?;
        writeln!(file4, "fn test() {{\n    // Test function\n}}")?;

        Ok(())
    }

    #[tokio::test]
    async fn test_index_files() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path());
        let provider = FileSystemContextProvider::new(config);

        let files = provider.index_files().await.unwrap();
        // All three fixture files match the default `**/*.rs` / `**/*.md` globs.
        assert_eq!(files.len(), 3);
    }

    #[tokio::test]
    async fn test_index_respects_include_patterns() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path())
            .with_include_patterns(vec!["**/*.md".to_string()]);
        let provider = FileSystemContextProvider::new(config);

        let files = provider.index_files().await.unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path.file_name().unwrap(), "test2.md");
    }

    #[tokio::test]
    async fn test_index_respects_exclude_patterns() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path())
            .with_exclude_patterns(vec!["**/subdir/**".to_string()]);
        let provider = FileSystemContextProvider::new(config);

        let files = provider.index_files().await.unwrap();
        assert_eq!(files.len(), 2);
        assert!(files
            .iter()
            .all(|f| f.path.file_name().unwrap() != "test4.rs"));
    }

    #[tokio::test]
    async fn test_index_skips_oversized_files() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        // Every fixture file is non-empty, so a zero-byte cap admits nothing.
        let config = FileSystemContextConfig::new(temp_dir.path()).with_max_file_size(0);
        let provider = FileSystemContextProvider::new(config);

        let files = provider.index_files().await.unwrap();
        assert!(files.is_empty());
    }

    #[tokio::test]
    async fn test_search_simple() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path());
        let provider = FileSystemContextProvider::new(config);

        let query = ContextQuery::new("Rust programming");
        let result = provider.query(&query).await.unwrap();

        assert!(!result.items.is_empty());
        assert!(result
            .items
            .iter()
            .any(|item| item.content.contains("Rust")));
    }
}