1use std::path::{Path, PathBuf};
9use chrono::Local;
10use serde::{Deserialize, Serialize};
11use crate::lsp::LspServerInfo;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UserContext {
16 pub claude_md_content: Option<String>,
18 pub preferences: Vec<String>,
20 pub current_date: String,
22 pub language: Option<String>,
24}
25
26impl Default for UserContext {
27 fn default() -> Self {
28 Self {
29 claude_md_content: None,
30 preferences: Vec::new(),
31 current_date: Local::now().format("%Y-%m-%d").to_string(),
32 language: None,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SystemContext {
40 pub git_branch: Option<String>,
42 pub git_status: Option<String>,
44 pub working_directory: Option<String>,
46 pub project_root: Option<String>,
48 pub project_type: Option<ProjectType>,
50 pub available_tools: Vec<String>,
52 pub platform: String,
54 #[serde(default)]
56 pub lsp_servers: Vec<LspServerInfo>,
57}
58
59impl Default for SystemContext {
60 fn default() -> Self {
61 Self {
62 git_branch: None,
63 git_status: None,
64 working_directory: None,
65 project_root: None,
66 project_type: None,
67 available_tools: Vec::new(),
68 platform: std::env::consts::OS.to_string(),
69 lsp_servers: Vec::new(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
76pub enum ProjectType {
77 Rust,
78 NodeJs,
79 Python,
80 Go,
81 Java,
82 Cpp,
83 Mixed,
84 Unknown,
85}
86
87impl ProjectType {
88 pub fn detect<P: AsRef<Path>>(dir: P) -> Self {
90 let dir = dir.as_ref();
91
92 let has_cargo = dir.join("Cargo.toml").exists();
93 let has_package_json = dir.join("package.json").exists();
94 let has_pyproject = dir.join("pyproject.toml").exists() || dir.join("setup.py").exists();
95 let has_go_mod = dir.join("go.mod").exists();
96 let has_pom = dir.join("pom.xml").exists() || dir.join("build.gradle").exists();
97 let has_cmake = dir.join("CMakeLists.txt").exists() || dir.join("Makefile").exists();
98
99 let types = [has_cargo, has_package_json, has_pyproject, has_go_mod, has_pom, has_cmake];
100 let count = types.iter().filter(|&&x| x).count();
101
102 if count > 1 {
103 return ProjectType::Mixed;
104 }
105
106 if has_cargo {
107 ProjectType::Rust
108 } else if has_package_json {
109 ProjectType::NodeJs
110 } else if has_pyproject {
111 ProjectType::Python
112 } else if has_go_mod {
113 ProjectType::Go
114 } else if has_pom {
115 ProjectType::Java
116 } else if has_cmake {
117 ProjectType::Cpp
118 } else {
119 ProjectType::Unknown
120 }
121 }
122
123 pub fn as_str(&self) -> &'static str {
124 match self {
125 ProjectType::Rust => "rust",
126 ProjectType::NodeJs => "nodejs",
127 ProjectType::Python => "python",
128 ProjectType::Go => "go",
129 ProjectType::Java => "java",
130 ProjectType::Cpp => "cpp",
131 ProjectType::Mixed => "mixed",
132 ProjectType::Unknown => "unknown",
133 }
134 }
135}
136
137pub struct ContextInjector {
139 working_dir: PathBuf,
141 user_context_cache: Option<UserContext>,
143 system_context_cache: Option<SystemContext>,
145 dirty: bool,
147}
148
149impl ContextInjector {
150 pub fn new<P: Into<PathBuf>>(working_dir: P) -> Self {
152 Self {
153 working_dir: working_dir.into(),
154 user_context_cache: None,
155 system_context_cache: None,
156 dirty: true,
157 }
158 }
159
160 pub fn invalidate(&mut self) {
162 self.dirty = true;
163 }
164
165 pub fn set_lsp_servers(&mut self, servers: Vec<LspServerInfo>) {
168 if let Some(ref mut ctx) = self.system_context_cache {
169 ctx.lsp_servers = servers;
170 } else {
171 let mut ctx = SystemContext::default();
172 ctx.lsp_servers = servers;
173 self.system_context_cache = Some(ctx);
174 }
175 }
176
177 pub fn get_user_context(&mut self) -> &UserContext {
179 if self.dirty || self.user_context_cache.is_none() {
180 self.user_context_cache = Some(self.collect_user_context());
181 self.dirty = false;
182 }
183 self.user_context_cache.as_ref().unwrap()
184 }
185
186 pub fn get_system_context(&mut self) -> &SystemContext {
188 if self.dirty || self.system_context_cache.is_none() {
189 self.system_context_cache = Some(self.collect_system_context());
190 self.dirty = false; }
192 self.system_context_cache.as_ref().unwrap()
193 }
194
195 fn collect_user_context(&self) -> UserContext {
197 let mut ctx = UserContext::default();
198
199 let claude_md_path = self.working_dir.join("CLAUDE.md");
201 if claude_md_path.exists() {
202 if let Ok(content) = std::fs::read_to_string(&claude_md_path) {
203 ctx.claude_md_content = Some(content);
204 }
205 }
206
207 if ctx.claude_md_content.is_none() {
209 if let Some(parent) = self.working_dir.parent() {
210 let parent_claude_md = parent.join("CLAUDE.md");
211 if parent_claude_md.exists() {
212 if let Ok(content) = std::fs::read_to_string(&parent_claude_md) {
213 ctx.claude_md_content = Some(content);
214 }
215 }
216 }
217 }
218
219 ctx
220 }
221
222 fn collect_system_context(&self) -> SystemContext {
224 let mut ctx = SystemContext::default();
225
226 if let Ok(output) = std::process::Command::new("git")
228 .args(["branch", "--show-current"])
229 .current_dir(&self.working_dir)
230 .output()
231 {
232 if output.status.success() {
233 ctx.git_branch = Some(String::from_utf8_lossy(&output.stdout).trim().to_string());
234 }
235 }
236
237 if let Ok(output) = std::process::Command::new("git")
238 .args(["status", "--porcelain"])
239 .current_dir(&self.working_dir)
240 .output()
241 {
242 if output.status.success() {
243 let status = String::from_utf8_lossy(&output.stdout);
244 ctx.git_status = if status.trim().is_empty() {
245 Some("clean".to_string())
246 } else {
247 Some(format!("dirty ({} changes)", status.lines().count()))
248 };
249 }
250 }
251
252 ctx.working_directory = self.working_dir.to_str().map(|s| s.to_string());
254 ctx.project_root = ctx.working_directory.clone();
255
256 ctx.project_type = Some(ProjectType::detect(&self.working_dir));
258
259 let tools = ["git", "cargo", "npm", "python", "go", "docker"];
261 for tool in tools {
262 if Self::tool_available(tool) {
263 ctx.available_tools.push(tool.to_string());
264 }
265 }
266
267 ctx
268 }
269
270 fn tool_available(tool: &str) -> bool {
272 #[cfg(unix)]
273 {
274 std::process::Command::new("which")
275 .arg(tool)
276 .output()
277 .map(|o| o.status.success())
278 .unwrap_or(false)
279 }
280
281 #[cfg(windows)]
282 {
283 std::process::Command::new("where")
284 .arg(tool)
285 .output()
286 .map(|o| o.status.success())
287 .unwrap_or(false)
288 }
289 }
290
291 pub fn render_user_context(&mut self) -> String {
293 let ctx = self.get_user_context();
294 let mut parts = Vec::new();
295
296 parts.push(format!("<currentDate>\n{}\n</currentDate>", ctx.current_date));
298
299 if let Some(ref claude_md) = ctx.claude_md_content {
301 parts.push(format!("<userPreferences>\n{}\n</userPreferences>", claude_md));
302 }
303
304 parts.join("\n\n")
305 }
306
307 pub fn render_system_context(&mut self) -> String {
309 let ctx = self.get_system_context();
310 let mut parts = Vec::new();
311
312 if let Some(ref dir) = ctx.working_directory {
314 parts.push(format!("<workingDirectory>\n{}\n</workingDirectory>", dir));
315 }
316
317 if let Some(ref branch) = ctx.git_branch {
319 let git_info = if let Some(ref status) = ctx.git_status {
320 format!("Branch: {}\nStatus: {}", branch, status)
321 } else {
322 format!("Branch: {}", branch)
323 };
324 parts.push(format!("<gitContext>\n{}\n</gitContext>", git_info));
325 }
326
327 if let Some(ref pt) = ctx.project_type {
329 parts.push(format!("<projectType>\n{}\n</projectType>", pt.as_str()));
330 }
331
332 if !ctx.available_tools.is_empty() {
334 parts.push(format!("<availableTools>\n{}\n</availableTools>", ctx.available_tools.join(", ")));
335 }
336
337 if !ctx.lsp_servers.is_empty() {
339 let servers_info = ctx.lsp_servers.iter()
340 .map(|s| {
341 let status = s.status.label();
342 format!("{}: {} [{}]", s.language, s.name, status)
343 })
344 .collect::<Vec<_>>()
345 .join("\n");
346 parts.push(format!("<lspServers>\n{}\n</lspServers>", servers_info));
347 }
348
349 parts.join("\n\n")
350 }
351
352 pub fn render_full_context(&mut self) -> String {
354 let user = self.render_user_context();
355 let system = self.render_system_context();
356
357 format!(
358 "<context>\n{}\n\n{}\n</context>",
359 user, system
360 )
361 }
362}
363
#[cfg(test)]
mod tests {
    // `super::*` already brings `Path`/`PathBuf` into scope; the former explicit
    // `use std::path::PathBuf;` was unused.
    use super::*;

    #[test]
    fn test_user_context_default() {
        let ctx = UserContext::default();
        assert!(ctx.claude_md_content.is_none());
        assert!(!ctx.current_date.is_empty());
    }

    #[test]
    fn test_system_context_default() {
        let ctx = SystemContext::default();
        assert!(ctx.git_branch.is_none());
        assert!(!ctx.platform.is_empty());
    }

    #[test]
    fn test_project_type_detect_rust() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Rust);
    }

    #[test]
    fn test_project_type_detect_python_setup_py() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("setup.py"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Python);
    }

    #[test]
    fn test_project_type_detect_unknown() {
        let temp_dir = tempfile::tempdir().unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Unknown);
    }

    #[test]
    fn test_project_type_detect_mixed() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();
        std::fs::write(temp_dir.path().join("package.json"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Mixed);
    }

    #[test]
    fn test_project_type_as_str() {
        assert_eq!(ProjectType::Rust.as_str(), "rust");
        assert_eq!(ProjectType::NodeJs.as_str(), "nodejs");
        assert_eq!(ProjectType::Unknown.as_str(), "unknown");
    }

    #[test]
    fn test_context_invalidator() {
        let temp_dir = tempfile::tempdir().unwrap();
        let claude_md = temp_dir.path().join("CLAUDE.md");
        std::fs::write(&claude_md, "first").unwrap();

        let mut injector = ContextInjector::new(temp_dir.path());
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("first")
        );

        // Without invalidation the cached value keeps being served.
        std::fs::write(&claude_md, "second").unwrap();
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("first")
        );

        // After invalidation the file is re-read.
        injector.invalidate();
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("second")
        );
    }

    #[test]
    fn test_render_user_context() {
        let mut injector = ContextInjector::new(std::env::current_dir().unwrap());
        let rendered = injector.render_user_context();

        assert!(rendered.contains("<currentDate>"));
    }

    #[test]
    fn test_render_user_context_includes_claude_md() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("CLAUDE.md"), "hello").unwrap();

        let mut injector = ContextInjector::new(temp_dir.path());
        let rendered = injector.render_user_context();

        assert!(rendered.contains("<userPreferences>\nhello\n</userPreferences>"));
    }

    #[test]
    fn test_render_system_context() {
        let mut injector = ContextInjector::new(std::env::current_dir().unwrap());
        let rendered = injector.render_system_context();

        assert!(rendered.contains("<workingDirectory>") || rendered.contains("<projectType>"));
    }
}