1use anyhow::Result;
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::path::{Path, PathBuf};
13use std::time::{SystemTime, UNIX_EPOCH};
14
15use super::file_collector::{CollectorConfig, FileCollector};
16use crate::models::ProjectContext;
17
/// Tracks the project files found under a root directory and builds a
/// [`ProjectContext`] from them.
///
/// Change detection compares a hash of the sorted relative paths of the
/// collected files. Changes to the set of collected paths trigger a reload;
/// edits to the contents of existing files do not.
#[derive(Debug, Clone)]
pub struct ContextManager {
    /// Directory that files are collected from; reported file paths are relative to it.
    root_path: PathBuf,
    /// Hash of the sorted relative file paths captured at the last reload;
    /// `None` until the first reload.
    last_file_hash: Option<u64>,
    /// Seconds since the Unix epoch at the last reload; `None` until the first reload.
    last_load_time: Option<u64>,
    /// Files returned by `FileCollector` at the last reload.
    /// NOTE(review): presumably paths under `root_path` (the impl strips that
    /// prefix and skips entries where it fails) — confirm against FileCollector.
    cached_files: Vec<PathBuf>,
    /// Settings handed to `FileCollector` on every directory scan.
    collector_config: CollectorConfig,
}
32
33impl ContextManager {
34 pub fn new(root_path: impl AsRef<Path>) -> Self {
36 Self {
37 root_path: root_path.as_ref().to_path_buf(),
38 last_file_hash: None,
39 last_load_time: None,
40 cached_files: Vec::new(),
41 collector_config: CollectorConfig {
42 max_file_size: 1024 * 1024, max_files: 100,
44 priority_extensions: vec![
45 "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "cpp", "c", "h", "hpp",
46 "cs", "rb", "php", "swift", "kt", "scala", "r", "sql", "sh", "yaml", "yml",
47 "toml", "json", "xml", "html", "css", "scss", "md", "txt",
48 ],
49 ignore_patterns: vec![
50 "*.log", "*.tmp", "*.cache", "*.pyc", "*.pyo", "*.pyd", "*.so", "*.dylib",
51 "*.dll", "*.exe", "*.o", "*.a", "*.lib", "*.png", "*.jpg", "*.jpeg", "*.gif",
52 "*.bmp", "*.ico", "*.svg", "*.pdf", "*.zip", "*.tar", "*.gz", "*.rar", "*.7z",
53 ],
54 },
55 }
56 }
57
58 pub async fn needs_reload(&self) -> bool {
60 match self.compute_file_hash().await {
61 Ok(current_hash) => {
62 if let Some(last_hash) = self.last_file_hash {
63 current_hash != last_hash
64 } else {
65 true
67 }
68 }
69 Err(_) => false, }
71 }
72
73 pub async fn reload_if_needed(&mut self) -> Result<bool> {
75 if self.needs_reload().await {
76 self.reload().await?;
77 Ok(true)
78 } else {
79 Ok(false)
80 }
81 }
82
83 pub async fn reload(&mut self) -> Result<()> {
85 let collector = FileCollector::new(self.collector_config.clone());
87 let files = collector.collect_files(&self.root_path).await?;
88
89 let hash = self.compute_hash_from_files(&files)?;
91
92 self.cached_files = files;
94 self.last_file_hash = Some(hash);
95 self.last_load_time = Some(
96 SystemTime::now()
97 .duration_since(UNIX_EPOCH)
98 .unwrap_or_default()
99 .as_secs(),
100 );
101
102 Ok(())
103 }
104
105 pub fn build_context(&self) -> ProjectContext {
112 let mut context = ProjectContext::new(self.root_path.to_string_lossy().to_string());
113 context.project_type = detect_project_type(&self.root_path);
114
115 for file_path in &self.cached_files {
117 if let Ok(rel_path) = file_path.strip_prefix(&self.root_path) {
118 if let Some(path_str) = rel_path.to_str() {
119 context.add_file(path_str.to_string(), String::new());
121 }
122 }
123 }
124
125 context
126 }
127
128 pub fn get_file_list(&self) -> Vec<String> {
130 self.cached_files
131 .iter()
132 .filter_map(|p| {
133 p.strip_prefix(&self.root_path)
134 .ok()
135 .and_then(|p| p.to_str())
136 .map(|s| s.to_string())
137 })
138 .collect()
139 }
140
141 pub fn total_files(&self) -> usize {
143 self.cached_files.len()
144 }
145
146 async fn compute_file_hash(&self) -> Result<u64> {
149 let collector = FileCollector::new(self.collector_config.clone());
151 let current_files = collector.collect_files(&self.root_path).await?;
152
153 self.compute_hash_from_files(¤t_files)
154 }
155
156 fn compute_hash_from_files(&self, files: &[PathBuf]) -> Result<u64> {
158 let mut hasher = DefaultHasher::new();
159
160 let mut file_paths: Vec<_> = files
162 .iter()
163 .filter_map(|p| {
164 p.strip_prefix(&self.root_path)
165 .ok()
166 .and_then(|p| p.to_str())
167 })
168 .collect();
169 file_paths.sort();
170
171 for path in file_paths {
172 path.hash(&mut hasher);
173 }
174
175 Ok(hasher.finish())
176 }
177}
178
/// Guesses the project's primary language from well-known manifest files
/// directly inside `root_path`.
///
/// Markers are checked in a fixed priority order and the first match wins. A
/// polyglot repository, for example one with both `Cargo.toml` and
/// `package.json`, is therefore reported as the earliest language in the list.
/// Only the top level of `root_path` is inspected. Returns `None` when no
/// marker file exists.
fn detect_project_type(root_path: &Path) -> Option<String> {
    // (marker files, project type), in priority order.
    const MARKERS: &[(&[&str], &str)] = &[
        (&["Cargo.toml"], "Rust"),
        (&["package.json"], "JavaScript/TypeScript"),
        // `pyproject.toml` is the standard (PEP 518/621) manifest for modern
        // Python projects, many of which ship neither requirements.txt nor setup.py.
        (&["requirements.txt", "setup.py", "pyproject.toml"], "Python"),
        (&["go.mod"], "Go"),
        // Gradle builds may use the Kotlin DSL script instead of the Groovy one.
        (&["pom.xml", "build.gradle", "build.gradle.kts"], "Java"),
        (&["Gemfile"], "Ruby"),
        (&["composer.json"], "PHP"),
    ];

    MARKERS
        .iter()
        .find(|(markers, _)| markers.iter().any(|m| root_path.join(m).exists()))
        .map(|(_, project_type)| (*project_type).to_owned())
}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_context_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let manager = ContextManager::new(temp_dir.path());

        assert_eq!(manager.root_path, temp_dir.path());
        assert_eq!(manager.total_files(), 0);
        // Nothing has been loaded yet, so a reload is always required.
        assert!(manager.needs_reload().await);
    }

    #[tokio::test]
    async fn test_file_tree_change_detection() {
        let temp_dir = TempDir::new().unwrap();
        let mut manager = ContextManager::new(temp_dir.path());

        manager.reload().await.unwrap();
        let initial_hash = manager.last_file_hash;

        // Immediately after a reload the tree is unchanged.
        assert!(!manager.needs_reload().await);

        // Adding a file changes the collected path set.
        let test_file = temp_dir.path().join("test.py");
        fs::write(&test_file, "print('test')").unwrap();

        assert!(manager.needs_reload().await);

        manager.reload().await.unwrap();
        assert_ne!(manager.last_file_hash, initial_hash);
    }

    #[tokio::test]
    async fn test_reload_if_needed_only_reloads_on_change() {
        let temp_dir = TempDir::new().unwrap();
        let mut manager = ContextManager::new(temp_dir.path());

        // The first call always loads; a second call on an unchanged tree is a no-op.
        assert!(manager.reload_if_needed().await.unwrap());
        assert!(!manager.reload_if_needed().await.unwrap());

        fs::write(temp_dir.path().join("new.py"), "x = 1").unwrap();

        assert!(manager.reload_if_needed().await.unwrap());
        assert_eq!(manager.total_files(), 1);
        assert!(!manager.reload_if_needed().await.unwrap());
    }

    #[tokio::test]
    async fn test_file_list_matches_cache() {
        let temp_dir = TempDir::new().unwrap();
        fs::write(temp_dir.path().join("a.rs"), "fn main() {}").unwrap();
        fs::write(temp_dir.path().join("b.rs"), "fn helper() {}").unwrap();

        let mut manager = ContextManager::new(temp_dir.path());
        manager.reload().await.unwrap();

        // The collector's traversal order is unspecified, so compare sorted.
        let mut files = manager.get_file_list();
        files.sort();
        assert_eq!(files, vec!["a.rs".to_string(), "b.rs".to_string()]);
        assert_eq!(manager.total_files(), 2);
    }

    #[tokio::test]
    async fn test_project_context_building() {
        let temp_dir = TempDir::new().unwrap();

        fs::write(temp_dir.path().join("main.py"), "print('hello')").unwrap();
        fs::write(temp_dir.path().join("lib.py"), "def helper(): pass").unwrap();
        fs::write(temp_dir.path().join("requirements.txt"), "requests\n").unwrap();

        let mut manager = ContextManager::new(temp_dir.path());
        manager.reload().await.unwrap();

        let context = manager.build_context();
        assert_eq!(context.root_path, temp_dir.path().to_string_lossy().to_string());
        assert_eq!(context.project_type, Some("Python".to_string()));
        assert_eq!(context.files.len(), 3);
    }
}