1use crate::lsp::LspServerInfo;
9use chrono::Local;
10use serde::{Deserialize, Serialize};
11use std::path::{Path, PathBuf};
12
/// User-scoped context injected into the prompt.
///
/// `ContextInjector` renders `current_date` and `claude_md_content`; the other fields
/// are carried along but not rendered by it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
    /// Contents of the `CLAUDE.md` found by the injector (working directory first, then
    /// its parent), if any was readable.
    pub claude_md_content: Option<String>,
    /// Free-form preference strings; not populated by `ContextInjector` in this module.
    pub preferences: Vec<String>,
    /// Current local date formatted as `YYYY-MM-DD` (filled in by `Default`).
    pub current_date: String,
    /// Optional language hint; not populated by `ContextInjector` in this module.
    pub language: Option<String>,
}
25
26impl Default for UserContext {
27 fn default() -> Self {
28 Self {
29 claude_md_content: None,
30 preferences: Vec::new(),
31 current_date: Local::now().format("%Y-%m-%d").to_string(),
32 language: None,
33 }
34 }
35}
36
/// System-level facts about the environment, rendered by `ContextInjector` as
/// `<workingDirectory>`, `<gitContext>`, `<projectType>`, `<availableTools>` and
/// `<lspServers>` sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemContext {
    /// Current branch name obtained from `git branch --show-current`, when available.
    pub git_branch: Option<String>,
    /// `"clean"` or `"dirty (N changes)"`, derived from `git status --porcelain`.
    pub git_status: Option<String>,
    /// The injector's working directory, as text.
    pub working_directory: Option<String>,
    /// Project root; currently set to the same value as `working_directory`.
    pub project_root: Option<String>,
    /// Detected project type (see `ProjectType::detect`).
    pub project_type: Option<ProjectType>,
    /// Names of candidate tools that were found on `PATH`.
    pub available_tools: Vec<String>,
    /// Operating system name; `Default` uses `std::env::consts::OS`.
    pub platform: String,
    /// LSP servers supplied by the caller; deserializes as empty when the field is absent.
    #[serde(default)]
    pub lsp_servers: Vec<LspServerInfo>,
}
58
59impl Default for SystemContext {
60 fn default() -> Self {
61 Self {
62 git_branch: None,
63 git_status: None,
64 working_directory: None,
65 project_root: None,
66 project_type: None,
67 available_tools: Vec::new(),
68 platform: std::env::consts::OS.to_string(),
69 lsp_servers: Vec::new(),
70 }
71 }
72}
73
/// Coarse project classification derived from marker files (see `ProjectType::detect`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProjectType {
    /// Cargo project.
    Rust,
    /// Project with a `package.json`.
    NodeJs,
    /// Python packaging files present.
    Python,
    /// Go module.
    Go,
    /// Maven or Gradle build.
    Java,
    /// CMake or Make build.
    Cpp,
    /// Markers from more than one ecosystem were found.
    Mixed,
    /// No known marker was found.
    Unknown,
}
86
87impl ProjectType {
88 pub fn detect<P: AsRef<Path>>(dir: P) -> Self {
90 let dir = dir.as_ref();
91
92 let has_cargo = dir.join("Cargo.toml").exists();
93 let has_package_json = dir.join("package.json").exists();
94 let has_pyproject = dir.join("pyproject.toml").exists() || dir.join("setup.py").exists();
95 let has_go_mod = dir.join("go.mod").exists();
96 let has_pom = dir.join("pom.xml").exists() || dir.join("build.gradle").exists();
97 let has_cmake = dir.join("CMakeLists.txt").exists() || dir.join("Makefile").exists();
98
99 let types = [
100 has_cargo,
101 has_package_json,
102 has_pyproject,
103 has_go_mod,
104 has_pom,
105 has_cmake,
106 ];
107 let count = types.iter().filter(|&&x| x).count();
108
109 if count > 1 {
110 return ProjectType::Mixed;
111 }
112
113 if has_cargo {
114 ProjectType::Rust
115 } else if has_package_json {
116 ProjectType::NodeJs
117 } else if has_pyproject {
118 ProjectType::Python
119 } else if has_go_mod {
120 ProjectType::Go
121 } else if has_pom {
122 ProjectType::Java
123 } else if has_cmake {
124 ProjectType::Cpp
125 } else {
126 ProjectType::Unknown
127 }
128 }
129
130 pub fn as_str(&self) -> &'static str {
131 match self {
132 ProjectType::Rust => "rust",
133 ProjectType::NodeJs => "nodejs",
134 ProjectType::Python => "python",
135 ProjectType::Go => "go",
136 ProjectType::Java => "java",
137 ProjectType::Cpp => "cpp",
138 ProjectType::Mixed => "mixed",
139 ProjectType::Unknown => "unknown",
140 }
141 }
142}
143
144pub struct ContextInjector {
146 working_dir: PathBuf,
148 user_context_cache: Option<UserContext>,
150 system_context_cache: Option<SystemContext>,
152 dirty: bool,
154}
155
156impl ContextInjector {
157 pub fn new<P: Into<PathBuf>>(working_dir: P) -> Self {
159 Self {
160 working_dir: working_dir.into(),
161 user_context_cache: None,
162 system_context_cache: None,
163 dirty: true,
164 }
165 }
166
167 pub fn invalidate(&mut self) {
169 self.dirty = true;
170 }
171
172 pub fn set_lsp_servers(&mut self, servers: Vec<LspServerInfo>) {
175 if let Some(ref mut ctx) = self.system_context_cache {
176 ctx.lsp_servers = servers;
177 } else {
178 let mut ctx = SystemContext::default();
179 ctx.lsp_servers = servers;
180 self.system_context_cache = Some(ctx);
181 }
182 }
183
184 pub fn get_user_context(&mut self) -> &UserContext {
186 if self.dirty || self.user_context_cache.is_none() {
187 self.user_context_cache = Some(self.collect_user_context());
188 self.dirty = false;
189 }
190 self.user_context_cache.as_ref().unwrap()
191 }
192
193 pub fn get_system_context(&mut self) -> &SystemContext {
195 if self.dirty || self.system_context_cache.is_none() {
196 self.system_context_cache = Some(self.collect_system_context());
197 self.dirty = false; }
199 self.system_context_cache.as_ref().unwrap()
200 }
201
202 fn collect_user_context(&self) -> UserContext {
204 let mut ctx = UserContext::default();
205
206 let claude_md_path = self.working_dir.join("CLAUDE.md");
208 if claude_md_path.exists() {
209 if let Ok(content) = std::fs::read_to_string(&claude_md_path) {
210 ctx.claude_md_content = Some(content);
211 }
212 }
213
214 if ctx.claude_md_content.is_none() {
216 if let Some(parent) = self.working_dir.parent() {
217 let parent_claude_md = parent.join("CLAUDE.md");
218 if parent_claude_md.exists() {
219 if let Ok(content) = std::fs::read_to_string(&parent_claude_md) {
220 ctx.claude_md_content = Some(content);
221 }
222 }
223 }
224 }
225
226 ctx
227 }
228
229 fn collect_system_context(&self) -> SystemContext {
231 let mut ctx = SystemContext::default();
232
233 if let Ok(output) = std::process::Command::new("git")
235 .args(["branch", "--show-current"])
236 .current_dir(&self.working_dir)
237 .output()
238 {
239 if output.status.success() {
240 ctx.git_branch = Some(String::from_utf8_lossy(&output.stdout).trim().to_string());
241 }
242 }
243
244 if let Ok(output) = std::process::Command::new("git")
245 .args(["status", "--porcelain"])
246 .current_dir(&self.working_dir)
247 .output()
248 {
249 if output.status.success() {
250 let status = String::from_utf8_lossy(&output.stdout);
251 ctx.git_status = if status.trim().is_empty() {
252 Some("clean".to_string())
253 } else {
254 Some(format!("dirty ({} changes)", status.lines().count()))
255 };
256 }
257 }
258
259 ctx.working_directory = self.working_dir.to_str().map(|s| s.to_string());
261 ctx.project_root = ctx.working_directory.clone();
262
263 ctx.project_type = Some(ProjectType::detect(&self.working_dir));
265
266 let tools = ["git", "cargo", "npm", "python", "go", "docker"];
268 for tool in tools {
269 if Self::tool_available(tool) {
270 ctx.available_tools.push(tool.to_string());
271 }
272 }
273
274 ctx
275 }
276
277 fn tool_available(tool: &str) -> bool {
279 #[cfg(unix)]
280 {
281 std::process::Command::new("which")
282 .arg(tool)
283 .output()
284 .map(|o| o.status.success())
285 .unwrap_or(false)
286 }
287
288 #[cfg(windows)]
289 {
290 std::process::Command::new("where")
291 .arg(tool)
292 .output()
293 .map(|o| o.status.success())
294 .unwrap_or(false)
295 }
296 }
297
298 pub fn render_user_context(&mut self) -> String {
300 let ctx = self.get_user_context();
301 let mut parts = Vec::new();
302
303 parts.push(format!(
305 "<currentDate>\n{}\n</currentDate>",
306 ctx.current_date
307 ));
308
309 if let Some(ref claude_md) = ctx.claude_md_content {
311 parts.push(format!(
312 "<userPreferences>\n{}\n</userPreferences>",
313 claude_md
314 ));
315 }
316
317 parts.join("\n\n")
318 }
319
320 pub fn render_system_context(&mut self) -> String {
322 let ctx = self.get_system_context();
323 let mut parts = Vec::new();
324
325 if let Some(ref dir) = ctx.working_directory {
327 parts.push(format!("<workingDirectory>\n{}\n</workingDirectory>", dir));
328 }
329
330 if let Some(ref branch) = ctx.git_branch {
332 let git_info = if let Some(ref status) = ctx.git_status {
333 format!("Branch: {}\nStatus: {}", branch, status)
334 } else {
335 format!("Branch: {}", branch)
336 };
337 parts.push(format!("<gitContext>\n{}\n</gitContext>", git_info));
338 }
339
340 if let Some(ref pt) = ctx.project_type {
342 parts.push(format!("<projectType>\n{}\n</projectType>", pt.as_str()));
343 }
344
345 if !ctx.available_tools.is_empty() {
347 parts.push(format!(
348 "<availableTools>\n{}\n</availableTools>",
349 ctx.available_tools.join(", ")
350 ));
351 }
352
353 if !ctx.lsp_servers.is_empty() {
355 let servers_info = ctx
356 .lsp_servers
357 .iter()
358 .map(|s| {
359 let status = s.status.label();
360 format!("{}: {} [{}]", s.language, s.name, status)
361 })
362 .collect::<Vec<_>>()
363 .join("\n");
364 parts.push(format!("<lspServers>\n{}\n</lspServers>", servers_info));
365 }
366
367 parts.join("\n\n")
368 }
369
370 pub fn render_full_context(&mut self) -> String {
372 let user = self.render_user_context();
373 let system = self.render_system_context();
374
375 format!("<context>\n{}\n\n{}\n</context>", user, system)
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh temporary directory containing an empty file for each name.
    fn temp_project(files: &[&str]) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        for file in files {
            std::fs::write(dir.path().join(file), "").unwrap();
        }
        dir
    }

    /// Builds an injector rooted at the process's current directory.
    fn injector_for_cwd() -> ContextInjector {
        ContextInjector::new(std::env::current_dir().unwrap())
    }

    #[test]
    fn test_user_context_default() {
        let UserContext {
            claude_md_content,
            current_date,
            ..
        } = UserContext::default();
        assert!(claude_md_content.is_none());
        assert!(!current_date.is_empty());
    }

    #[test]
    fn test_system_context_default() {
        let defaults = SystemContext::default();
        assert!(defaults.git_branch.is_none());
        assert!(!defaults.platform.is_empty());
    }

    #[test]
    fn test_project_type_detect_rust() {
        let project = temp_project(&["Cargo.toml"]);
        assert_eq!(ProjectType::detect(project.path()), ProjectType::Rust);
    }

    #[test]
    fn test_project_type_detect_mixed() {
        let project = temp_project(&["Cargo.toml", "package.json"]);
        assert_eq!(ProjectType::detect(project.path()), ProjectType::Mixed);
    }

    #[test]
    fn test_context_invalidator() {
        let mut injector = injector_for_cwd();
        // Exercises collect -> invalidate -> re-collect; only checks that nothing panics.
        let _ = injector.get_user_context();
        injector.invalidate();
        let _ = injector.get_user_context();
    }

    #[test]
    fn test_render_user_context() {
        let rendered = injector_for_cwd().render_user_context();
        assert!(rendered.contains("<currentDate>"));
    }

    #[test]
    fn test_render_system_context() {
        let rendered = injector_for_cwd().render_system_context();
        let has_expected_section =
            rendered.contains("<workingDirectory>") || rendered.contains("<projectType>");
        assert!(has_expected_section);
    }
}