Skip to main content

matrixcode_core/prompt/
context.rs

1//! Runtime Context Injection
2//!
3//! Provides dynamic context injection for:
4//! - User context (CLAUDE.md files, preferences)
5//! - System context (git status, date, workspace info)
6//! - Environment context (platform, tools available)
7
8use std::path::{Path, PathBuf};
9use chrono::Local;
10use serde::{Deserialize, Serialize};
11use crate::lsp::LspServerInfo;
12
13/// User context from CLAUDE.md and preferences
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UserContext {
16    /// Content from CLAUDE.md file
17    pub claude_md_content: Option<String>,
18    /// User preferences (from accumulated memory)
19    pub preferences: Vec<String>,
20    /// Current date/time
21    pub current_date: String,
22    /// User's preferred language
23    pub language: Option<String>,
24}
25
26impl Default for UserContext {
27    fn default() -> Self {
28        Self {
29            claude_md_content: None,
30            preferences: Vec::new(),
31            current_date: Local::now().format("%Y-%m-%d").to_string(),
32            language: None,
33        }
34    }
35}
36
37/// System context (git, workspace, tools)
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SystemContext {
40    /// Git branch name
41    pub git_branch: Option<String>,
42    /// Git status (clean/dirty)
43    pub git_status: Option<String>,
44    /// Current working directory
45    pub working_directory: Option<String>,
46    /// Project root directory
47    pub project_root: Option<String>,
48    /// Detected project type
49    pub project_type: Option<ProjectType>,
50    /// Available tools
51    pub available_tools: Vec<String>,
52    /// Platform info
53    pub platform: String,
54    /// LSP servers status (dynamic injection)
55    #[serde(default)]
56    pub lsp_servers: Vec<LspServerInfo>,
57}
58
59impl Default for SystemContext {
60    fn default() -> Self {
61        Self {
62            git_branch: None,
63            git_status: None,
64            working_directory: None,
65            project_root: None,
66            project_type: None,
67            available_tools: Vec::new(),
68            platform: std::env::consts::OS.to_string(),
69            lsp_servers: Vec::new(),
70        }
71    }
72}
73
74/// Detected project type
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
76pub enum ProjectType {
77    Rust,
78    NodeJs,
79    Python,
80    Go,
81    Java,
82    Cpp,
83    Mixed,
84    Unknown,
85}
86
87impl ProjectType {
88    /// Detect project type from directory contents
89    pub fn detect<P: AsRef<Path>>(dir: P) -> Self {
90        let dir = dir.as_ref();
91        
92        let has_cargo = dir.join("Cargo.toml").exists();
93        let has_package_json = dir.join("package.json").exists();
94        let has_pyproject = dir.join("pyproject.toml").exists() || dir.join("setup.py").exists();
95        let has_go_mod = dir.join("go.mod").exists();
96        let has_pom = dir.join("pom.xml").exists() || dir.join("build.gradle").exists();
97        let has_cmake = dir.join("CMakeLists.txt").exists() || dir.join("Makefile").exists();
98        
99        let types = [has_cargo, has_package_json, has_pyproject, has_go_mod, has_pom, has_cmake];
100        let count = types.iter().filter(|&&x| x).count();
101        
102        if count > 1 {
103            return ProjectType::Mixed;
104        }
105        
106        if has_cargo {
107            ProjectType::Rust
108        } else if has_package_json {
109            ProjectType::NodeJs
110        } else if has_pyproject {
111            ProjectType::Python
112        } else if has_go_mod {
113            ProjectType::Go
114        } else if has_pom {
115            ProjectType::Java
116        } else if has_cmake {
117            ProjectType::Cpp
118        } else {
119            ProjectType::Unknown
120        }
121    }
122    
123    pub fn as_str(&self) -> &'static str {
124        match self {
125            ProjectType::Rust => "rust",
126            ProjectType::NodeJs => "nodejs",
127            ProjectType::Python => "python",
128            ProjectType::Go => "go",
129            ProjectType::Java => "java",
130            ProjectType::Cpp => "cpp",
131            ProjectType::Mixed => "mixed",
132            ProjectType::Unknown => "unknown",
133        }
134    }
135}
136
137/// Context injector for runtime context
138pub struct ContextInjector {
139    /// Working directory
140    working_dir: PathBuf,
141    /// Cache of user context
142    user_context_cache: Option<UserContext>,
143    /// Cache of system context
144    system_context_cache: Option<SystemContext>,
145    /// Whether to refresh cache
146    dirty: bool,
147}
148
149impl ContextInjector {
150    /// Create a new context injector
151    pub fn new<P: Into<PathBuf>>(working_dir: P) -> Self {
152        Self {
153            working_dir: working_dir.into(),
154            user_context_cache: None,
155            system_context_cache: None,
156            dirty: true,
157        }
158    }
159
160    /// Mark cache as dirty (need refresh)
161    pub fn invalidate(&mut self) {
162        self.dirty = true;
163    }
164
165    /// Set LSP servers info (for dynamic injection)
166    /// This updates the cached system context with LSP server information
167    pub fn set_lsp_servers(&mut self, servers: Vec<LspServerInfo>) {
168        if let Some(ref mut ctx) = self.system_context_cache {
169            ctx.lsp_servers = servers;
170        } else {
171            let mut ctx = SystemContext::default();
172            ctx.lsp_servers = servers;
173            self.system_context_cache = Some(ctx);
174        }
175    }
176
177    /// Get user context
178    pub fn get_user_context(&mut self) -> &UserContext {
179        if self.dirty || self.user_context_cache.is_none() {
180            self.user_context_cache = Some(self.collect_user_context());
181            self.dirty = false;
182        }
183        self.user_context_cache.as_ref().unwrap()
184    }
185
186    /// Get system context
187    pub fn get_system_context(&mut self) -> &SystemContext {
188        if self.dirty || self.system_context_cache.is_none() {
189            self.system_context_cache = Some(self.collect_system_context());
190            self.dirty = false;  // Reset dirty flag after refresh
191        }
192        self.system_context_cache.as_ref().unwrap()
193    }
194
195    /// Collect user context
196    fn collect_user_context(&self) -> UserContext {
197        let mut ctx = UserContext::default();
198        
199        // Read CLAUDE.md
200        let claude_md_path = self.working_dir.join("CLAUDE.md");
201        if claude_md_path.exists() {
202            if let Ok(content) = std::fs::read_to_string(&claude_md_path) {
203                ctx.claude_md_content = Some(content);
204            }
205        }
206        
207        // Also check parent directories
208        if ctx.claude_md_content.is_none() {
209            if let Some(parent) = self.working_dir.parent() {
210                let parent_claude_md = parent.join("CLAUDE.md");
211                if parent_claude_md.exists() {
212                    if let Ok(content) = std::fs::read_to_string(&parent_claude_md) {
213                        ctx.claude_md_content = Some(content);
214                    }
215                }
216            }
217        }
218        
219        ctx
220    }
221
222    /// Collect system context
223    fn collect_system_context(&self) -> SystemContext {
224        let mut ctx = SystemContext::default();
225        
226        // Git info
227        if let Ok(output) = std::process::Command::new("git")
228            .args(["branch", "--show-current"])
229            .current_dir(&self.working_dir)
230            .output()
231        {
232            if output.status.success() {
233                ctx.git_branch = Some(String::from_utf8_lossy(&output.stdout).trim().to_string());
234            }
235        }
236        
237        if let Ok(output) = std::process::Command::new("git")
238            .args(["status", "--porcelain"])
239            .current_dir(&self.working_dir)
240            .output()
241        {
242            if output.status.success() {
243                let status = String::from_utf8_lossy(&output.stdout);
244                ctx.git_status = if status.trim().is_empty() {
245                    Some("clean".to_string())
246                } else {
247                    Some(format!("dirty ({} changes)", status.lines().count()))
248                };
249            }
250        }
251        
252        // Working directory
253        ctx.working_directory = self.working_dir.to_str().map(|s| s.to_string());
254        ctx.project_root = ctx.working_directory.clone();
255        
256        // Project type
257        ctx.project_type = Some(ProjectType::detect(&self.working_dir));
258        
259        // Available tools (check common tools)
260        let tools = ["git", "cargo", "npm", "python", "go", "docker"];
261        for tool in tools {
262            if Self::tool_available(tool) {
263                ctx.available_tools.push(tool.to_string());
264            }
265        }
266        
267        ctx
268    }
269
270    /// Check if a tool is available
271    fn tool_available(tool: &str) -> bool {
272        #[cfg(unix)]
273        {
274            std::process::Command::new("which")
275                .arg(tool)
276                .output()
277                .map(|o| o.status.success())
278                .unwrap_or(false)
279        }
280        
281        #[cfg(windows)]
282        {
283            std::process::Command::new("where")
284                .arg(tool)
285                .output()
286                .map(|o| o.status.success())
287                .unwrap_or(false)
288        }
289    }
290
291    /// Render user context as prompt section
292    pub fn render_user_context(&mut self) -> String {
293        let ctx = self.get_user_context();
294        let mut parts = Vec::new();
295        
296        // Date
297        parts.push(format!("<currentDate>\n{}\n</currentDate>", ctx.current_date));
298        
299        // CLAUDE.md content
300        if let Some(ref claude_md) = ctx.claude_md_content {
301            parts.push(format!("<userPreferences>\n{}\n</userPreferences>", claude_md));
302        }
303        
304        parts.join("\n\n")
305    }
306
307    /// Render system context as prompt section
308    pub fn render_system_context(&mut self) -> String {
309        let ctx = self.get_system_context();
310        let mut parts = Vec::new();
311        
312        // Working directory
313        if let Some(ref dir) = ctx.working_directory {
314            parts.push(format!("<workingDirectory>\n{}\n</workingDirectory>", dir));
315        }
316        
317        // Git info
318        if let Some(ref branch) = ctx.git_branch {
319            let git_info = if let Some(ref status) = ctx.git_status {
320                format!("Branch: {}\nStatus: {}", branch, status)
321            } else {
322                format!("Branch: {}", branch)
323            };
324            parts.push(format!("<gitContext>\n{}\n</gitContext>", git_info));
325        }
326        
327        // Project type
328        if let Some(ref pt) = ctx.project_type {
329            parts.push(format!("<projectType>\n{}\n</projectType>", pt.as_str()));
330        }
331        
332        // Available tools
333        if !ctx.available_tools.is_empty() {
334            parts.push(format!("<availableTools>\n{}\n</availableTools>", ctx.available_tools.join(", ")));
335        }
336        
337        // LSP servers (dynamic injection)
338        if !ctx.lsp_servers.is_empty() {
339            let servers_info = ctx.lsp_servers.iter()
340                .map(|s| {
341                    let status = s.status.label();
342                    format!("{}: {} [{}]", s.language, s.name, status)
343                })
344                .collect::<Vec<_>>()
345                .join("\n");
346            parts.push(format!("<lspServers>\n{}\n</lspServers>", servers_info));
347        }
348        
349        parts.join("\n\n")
350    }
351
352    /// Render full context (for injection into prompt)
353    pub fn render_full_context(&mut self) -> String {
354        let user = self.render_user_context();
355        let system = self.render_system_context();
356        
357        format!(
358            "<context>\n{}\n\n{}\n</context>",
359            user, system
360        )
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Create `<tmp>/project` so the parent-directory CLAUDE.md fallback only ever sees the
    /// empty temp root, never a stray file in the process cwd or the system temp directory.
    fn isolated_project_dir(root: &tempfile::TempDir) -> PathBuf {
        let project = root.path().join("project");
        std::fs::create_dir(&project).unwrap();
        project
    }

    #[test]
    fn test_user_context_default() {
        let ctx = UserContext::default();
        assert!(ctx.claude_md_content.is_none());
        assert!(!ctx.current_date.is_empty());
    }

    #[test]
    fn test_system_context_default() {
        let ctx = SystemContext::default();
        assert!(ctx.git_branch.is_none());
        assert!(!ctx.platform.is_empty());
    }

    #[test]
    fn test_project_type_detect_rust() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Rust);
    }

    #[test]
    fn test_project_type_detect_mixed() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();
        std::fs::write(temp_dir.path().join("package.json"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Mixed);
    }

    #[test]
    fn test_project_type_detect_unknown() {
        let temp_dir = tempfile::tempdir().unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Unknown);
    }

    #[test]
    fn test_context_invalidator() {
        let temp_dir = tempfile::tempdir().unwrap();
        let project = isolated_project_dir(&temp_dir);
        let mut injector = ContextInjector::new(project.clone());

        assert!(injector.get_user_context().claude_md_content.is_none());

        std::fs::write(project.join("CLAUDE.md"), "prefer tabs").unwrap();
        // Still served from cache until invalidated.
        assert!(injector.get_user_context().claude_md_content.is_none());

        injector.invalidate();
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("prefer tabs")
        );
    }

    #[test]
    fn test_render_user_context() {
        let temp_dir = tempfile::tempdir().unwrap();
        let project = isolated_project_dir(&temp_dir);
        std::fs::write(project.join("CLAUDE.md"), "prefer tabs").unwrap();

        let mut injector = ContextInjector::new(project);
        let rendered = injector.render_user_context();

        assert!(rendered.contains("<currentDate>"));
        assert!(rendered.contains("<userPreferences>\nprefer tabs\n</userPreferences>"));
    }

    #[test]
    fn test_render_system_context() {
        let mut injector = ContextInjector::new(std::env::current_dir().unwrap());
        let rendered = injector.render_system_context();

        assert!(rendered.contains("<workingDirectory>") || rendered.contains("<projectType>"));
    }
}