Skip to main content

a3s_code_core/context/
fs_provider.rs

//! File System Context Provider
//!
//! Provides simple file-based RAG (Retrieval-Augmented Generation):
//! - Indexes files under a root directory, filtered by include/exclude glob patterns
//!   and honouring `.gitignore` rules
//! - Ranks files with simple keyword-based relevance scoring
//! - Skips files larger than a configurable size limit

8use crate::context::{ContextItem, ContextProvider, ContextQuery, ContextResult, ContextType};
9use async_trait::async_trait;
10use ignore::WalkBuilder;
11use std::collections::HashMap;
12use std::fs;
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
/// Settings that control which files the file system context provider indexes.
///
/// Build one with `FileSystemContextConfig::new` and refine it with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct FileSystemContextConfig {
    /// Directory whose contents are walked and indexed.
    pub root_path: PathBuf,
    /// Glob patterns a file must match to be indexed, e.g. `["**/*.rs", "**/*.md"]`.
    /// An empty list accepts every file.
    pub include_patterns: Vec<String>,
    /// Glob patterns that drop a file from the index even when an include pattern
    /// matched it, e.g. `["**/target/**", "**/node_modules/**"]`.
    pub exclude_patterns: Vec<String>,
    /// Files larger than this many bytes are skipped (1 MiB when built via `new`).
    pub max_file_size: usize,
    /// Keep indexed files in memory between queries (`true` when built via `new`).
    pub enable_cache: bool,
}
31
32impl FileSystemContextConfig {
33    /// Create a new config with default settings
34    pub fn new(root_path: impl Into<PathBuf>) -> Self {
35        Self {
36            root_path: root_path.into(),
37            include_patterns: vec!["**/*.rs".to_string(), "**/*.md".to_string()],
38            exclude_patterns: vec![
39                "**/target/**".to_string(),
40                "**/node_modules/**".to_string(),
41                "**/.git/**".to_string(),
42            ],
43            max_file_size: 1024 * 1024, // 1MB
44            enable_cache: true,
45        }
46    }
47
48    /// Set include patterns
49    pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
50        self.include_patterns = patterns;
51        self
52    }
53
54    /// Set exclude patterns
55    pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
56        self.exclude_patterns = patterns;
57        self
58    }
59
60    /// Set max file size
61    pub fn with_max_file_size(mut self, size: usize) -> Self {
62        self.max_file_size = size;
63        self
64    }
65
66    /// Enable/disable cache
67    pub fn with_cache(mut self, enable: bool) -> Self {
68        self.enable_cache = enable;
69        self
70    }
71}
72
/// One file loaded into the in-memory index.
#[derive(Debug, Clone)]
struct IndexedFile {
    /// Location of the file as produced by the directory walk.
    path: PathBuf,
    /// Complete text of the file.
    content: String,
    /// Size of the file in bytes.
    size: usize,
}
80
81/// File system context provider
82pub struct FileSystemContextProvider {
83    config: FileSystemContextConfig,
84    /// Cached indexed files
85    cache: Arc<RwLock<HashMap<PathBuf, IndexedFile>>>,
86}
87
88impl FileSystemContextProvider {
89    /// Create a new file system context provider
90    pub fn new(config: FileSystemContextConfig) -> Self {
91        Self {
92            config,
93            cache: Arc::new(RwLock::new(HashMap::new())),
94        }
95    }
96
97    /// Index files in the root directory
98    async fn index_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
99        let mut files = Vec::new();
100
101        let walker = WalkBuilder::new(&self.config.root_path)
102            .hidden(false)
103            .git_ignore(true)
104            .build();
105
106        for entry in walker {
107            let entry = entry.map_err(|e| anyhow::anyhow!("Walk error: {}", e))?;
108            let path = entry.path();
109
110            if !path.is_file() {
111                continue;
112            }
113
114            let metadata = fs::metadata(path).map_err(|e| anyhow::anyhow!("Metadata error: {}", e))?;
115            if metadata.len() > self.config.max_file_size as u64 {
116                continue;
117            }
118
119            if !self.matches_include_patterns(path) {
120                continue;
121            }
122
123            if self.matches_exclude_patterns(path) {
124                continue;
125            }
126
127            let content = fs::read_to_string(path).map_err(|e| anyhow::anyhow!("Read error: {}", e))?;
128
129            files.push(IndexedFile {
130                path: path.to_path_buf(),
131                content,
132                size: metadata.len() as usize,
133            });
134        }
135
136        Ok(files)
137    }
138
139    fn matches_include_patterns(&self, path: &Path) -> bool {
140        if self.config.include_patterns.is_empty() {
141            return true;
142        }
143
144        let path_str = path.to_string_lossy();
145        self.config.include_patterns.iter().any(|pattern| {
146            glob::Pattern::new(pattern)
147                .map(|p| p.matches(&path_str))
148                .unwrap_or(false)
149        })
150    }
151
152    fn matches_exclude_patterns(&self, path: &Path) -> bool {
153        let path_str = path.to_string_lossy();
154        self.config.exclude_patterns.iter().any(|pattern| {
155            glob::Pattern::new(pattern)
156                .map(|p| p.matches(&path_str))
157                .unwrap_or(false)
158        })
159    }
160
161    async fn search_simple(&self, query: &str, files: &[IndexedFile], max_results: usize) -> Vec<(IndexedFile, f32)> {
162        let query_lower = query.to_lowercase();
163        let keywords: Vec<&str> = query_lower.split_whitespace().collect();
164
165        let mut results: Vec<(IndexedFile, f32)> = files
166            .iter()
167            .filter_map(|file| {
168                let content_lower = file.content.to_lowercase();
169                let mut score = 0.0;
170                for keyword in &keywords {
171                    let count = content_lower.matches(keyword).count();
172                    score += count as f32;
173                }
174
175                if score > 0.0 {
176                    let normalized_score = score / (file.content.len() as f32).sqrt();
177                    Some((file.clone(), normalized_score))
178                } else {
179                    None
180                }
181            })
182            .collect();
183
184        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
185        results.truncate(max_results);
186        results
187    }
188
189    async fn update_cache(&self, files: Vec<IndexedFile>) {
190        if !self.config.enable_cache {
191            return;
192        }
193
194        let mut cache = self.cache.write().await;
195        cache.clear();
196        for file in files {
197            cache.insert(file.path.clone(), file);
198        }
199    }
200
201    async fn get_files(&self) -> anyhow::Result<Vec<IndexedFile>> {
202        if self.config.enable_cache {
203            let cache = self.cache.read().await;
204            if !cache.is_empty() {
205                return Ok(cache.values().cloned().collect());
206            }
207        }
208
209        let files = self.index_files().await?;
210        self.update_cache(files.clone()).await;
211        Ok(files)
212    }
213}
214
215#[async_trait]
216impl ContextProvider for FileSystemContextProvider {
217    fn name(&self) -> &str {
218        "filesystem"
219    }
220
221    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
222        let files = self.get_files().await?;
223        let results = self.search_simple(&query.query, &files, query.max_results).await;
224
225        let items: Vec<ContextItem> = results
226            .into_iter()
227            .map(|(file, score)| {
228                let content = match query.depth {
229                    crate::context::ContextDepth::Abstract => {
230                        file.content.chars().take(500).collect::<String>()
231                    }
232                    crate::context::ContextDepth::Overview => {
233                        file.content.chars().take(2000).collect::<String>()
234                    }
235                    crate::context::ContextDepth::Full => file.content.clone(),
236                };
237
238                let token_count = content.split_whitespace().count();
239
240                ContextItem {
241                    id: file.path.to_string_lossy().to_string(),
242                    context_type: ContextType::Resource,
243                    content,
244                    token_count,
245                    relevance: score,
246                    source: Some(format!("file:{}", file.path.display())),
247                    metadata: {
248                        let mut meta = HashMap::new();
249                        meta.insert("path".to_string(), serde_json::Value::String(file.path.to_string_lossy().to_string()));
250                        meta.insert("size".to_string(), serde_json::Value::Number(file.size.into()));
251                        meta
252                    },
253                }
254            })
255            .collect();
256
257        let total_tokens: usize = items.iter().map(|item| item.token_count).sum();
258        let truncated = items.len() < files.len();
259
260        Ok(ContextResult {
261            items,
262            total_tokens,
263            provider: self.name().to_string(),
264            truncated,
265        })
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::TempDir;

    /// Create three indexable files: `test1.rs`, `test2.md`, and `subdir/test4.rs`.
    fn create_test_files(dir: &Path) -> anyhow::Result<()> {
        let mut file1 = File::create(dir.join("test1.rs"))?;
        writeln!(file1, "fn main() {{\n    println!(\"Hello, world!\");\n}}")?;

        let mut file2 = File::create(dir.join("test2.md"))?;
        writeln!(file2, "# Test Document\n\nThis is a test document about Rust programming.")?;

        fs::create_dir(dir.join("subdir"))?;
        let mut nested = File::create(dir.join("subdir").join("test4.rs"))?;
        writeln!(nested, "fn test() {{\n    // Test function\n}}")?;

        Ok(())
    }

    /// Paths of `files` relative to `root`, normalised to `/` separators and sorted.
    fn relative_names(root: &Path, files: &[IndexedFile]) -> Vec<String> {
        let mut names: Vec<String> = files
            .iter()
            .map(|file| {
                file.path
                    .strip_prefix(root)
                    .expect("indexed files live under the root")
                    .to_string_lossy()
                    .replace('\\', "/")
            })
            .collect();
        names.sort();
        names
    }

    #[tokio::test]
    async fn test_index_files() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path());
        let provider = FileSystemContextProvider::new(config);

        let files = provider.index_files().await.unwrap();
        // Every created file matches the default patterns, including the nested one.
        assert_eq!(
            relative_names(temp_dir.path(), &files),
            vec!["subdir/test4.rs", "test1.rs", "test2.md"]
        );
    }

    #[tokio::test]
    async fn test_filters_by_pattern_and_size() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();
        fs::write(temp_dir.path().join("notes.txt"), "Rust notes").unwrap();
        fs::create_dir(temp_dir.path().join("node_modules")).unwrap();
        fs::write(temp_dir.path().join("node_modules").join("dep.rs"), "fn dep() {}").unwrap();

        let provider =
            FileSystemContextProvider::new(FileSystemContextConfig::new(temp_dir.path()));
        let names = relative_names(temp_dir.path(), &provider.index_files().await.unwrap());
        // `notes.txt` fails the include patterns; `node_modules` is excluded by default.
        assert_eq!(names, vec!["subdir/test4.rs", "test1.rs", "test2.md"]);

        // A size limit below every file empties the index.
        let tiny = FileSystemContextConfig::new(temp_dir.path()).with_max_file_size(4);
        let files = FileSystemContextProvider::new(tiny).index_files().await.unwrap();
        assert!(files.is_empty());
    }

    #[tokio::test]
    async fn test_search_simple() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let config = FileSystemContextConfig::new(temp_dir.path());
        let provider = FileSystemContextProvider::new(config);

        let query = ContextQuery::new("Rust programming");
        let result = provider.query(&query).await.unwrap();

        assert!(!result.items.is_empty());
        assert!(result.items.iter().any(|item| item.content.contains("Rust")));
        assert!(result.items.iter().all(|item| item.relevance > 0.0));
    }

    #[tokio::test]
    async fn test_query_respects_max_results() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let provider =
            FileSystemContextProvider::new(FileSystemContextConfig::new(temp_dir.path()));

        // "fn" occurs in both Rust files, but only one result is allowed.
        let mut query = ContextQuery::new("fn");
        query.max_results = 1;
        let result = provider.query(&query).await.unwrap();
        assert_eq!(result.items.len(), 1);
    }

    #[tokio::test]
    async fn test_blank_query_returns_nothing() {
        let temp_dir = TempDir::new().unwrap();
        create_test_files(temp_dir.path()).unwrap();

        let provider =
            FileSystemContextProvider::new(FileSystemContextConfig::new(temp_dir.path()));
        let result = provider.query(&ContextQuery::new("   ")).await.unwrap();

        assert!(result.items.is_empty());
        assert_eq!(result.total_tokens, 0);
    }
}