Skip to main content

matrixcode_core/prompt/
context.rs

//! Runtime Context Injection
//!
//! Provides dynamic context injection for:
//! - User context (CLAUDE.md files, preferences, current date)
//! - System context (git branch and status, working directory, project type)
//! - Environment details carried in the system context (platform, available tools, LSP servers)
7
8use crate::lsp::LspServerInfo;
9use chrono::Local;
10use serde::{Deserialize, Serialize};
11use std::path::{Path, PathBuf};
12
13/// User context from CLAUDE.md and preferences
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UserContext {
16    /// Content from CLAUDE.md file
17    pub claude_md_content: Option<String>,
18    /// User preferences (from accumulated memory)
19    pub preferences: Vec<String>,
20    /// Current date/time
21    pub current_date: String,
22    /// User's preferred language
23    pub language: Option<String>,
24}
25
26impl Default for UserContext {
27    fn default() -> Self {
28        Self {
29            claude_md_content: None,
30            preferences: Vec::new(),
31            current_date: Local::now().format("%Y-%m-%d").to_string(),
32            language: None,
33        }
34    }
35}
36
37/// System context (git, workspace, tools)
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SystemContext {
40    /// Git branch name
41    pub git_branch: Option<String>,
42    /// Git status (clean/dirty)
43    pub git_status: Option<String>,
44    /// Current working directory
45    pub working_directory: Option<String>,
46    /// Project root directory
47    pub project_root: Option<String>,
48    /// Detected project type
49    pub project_type: Option<ProjectType>,
50    /// Available tools
51    pub available_tools: Vec<String>,
52    /// Platform info
53    pub platform: String,
54    /// LSP servers status (dynamic injection)
55    #[serde(default)]
56    pub lsp_servers: Vec<LspServerInfo>,
57}
58
59impl Default for SystemContext {
60    fn default() -> Self {
61        Self {
62            git_branch: None,
63            git_status: None,
64            working_directory: None,
65            project_root: None,
66            project_type: None,
67            available_tools: Vec::new(),
68            platform: std::env::consts::OS.to_string(),
69            lsp_servers: Vec::new(),
70        }
71    }
72}
73
74/// Detected project type
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
76pub enum ProjectType {
77    Rust,
78    NodeJs,
79    Python,
80    Go,
81    Java,
82    Cpp,
83    Mixed,
84    Unknown,
85}
86
87impl ProjectType {
88    /// Detect project type from directory contents
89    pub fn detect<P: AsRef<Path>>(dir: P) -> Self {
90        let dir = dir.as_ref();
91
92        let has_cargo = dir.join("Cargo.toml").exists();
93        let has_package_json = dir.join("package.json").exists();
94        let has_pyproject = dir.join("pyproject.toml").exists() || dir.join("setup.py").exists();
95        let has_go_mod = dir.join("go.mod").exists();
96        let has_pom = dir.join("pom.xml").exists() || dir.join("build.gradle").exists();
97        let has_cmake = dir.join("CMakeLists.txt").exists() || dir.join("Makefile").exists();
98
99        let types = [
100            has_cargo,
101            has_package_json,
102            has_pyproject,
103            has_go_mod,
104            has_pom,
105            has_cmake,
106        ];
107        let count = types.iter().filter(|&&x| x).count();
108
109        if count > 1 {
110            return ProjectType::Mixed;
111        }
112
113        if has_cargo {
114            ProjectType::Rust
115        } else if has_package_json {
116            ProjectType::NodeJs
117        } else if has_pyproject {
118            ProjectType::Python
119        } else if has_go_mod {
120            ProjectType::Go
121        } else if has_pom {
122            ProjectType::Java
123        } else if has_cmake {
124            ProjectType::Cpp
125        } else {
126            ProjectType::Unknown
127        }
128    }
129
130    pub fn as_str(&self) -> &'static str {
131        match self {
132            ProjectType::Rust => "rust",
133            ProjectType::NodeJs => "nodejs",
134            ProjectType::Python => "python",
135            ProjectType::Go => "go",
136            ProjectType::Java => "java",
137            ProjectType::Cpp => "cpp",
138            ProjectType::Mixed => "mixed",
139            ProjectType::Unknown => "unknown",
140        }
141    }
142}
143
144/// Context injector for runtime context
145pub struct ContextInjector {
146    /// Working directory
147    working_dir: PathBuf,
148    /// Cache of user context
149    user_context_cache: Option<UserContext>,
150    /// Cache of system context
151    system_context_cache: Option<SystemContext>,
152    /// Whether to refresh cache
153    dirty: bool,
154}
155
156impl ContextInjector {
157    /// Create a new context injector
158    pub fn new<P: Into<PathBuf>>(working_dir: P) -> Self {
159        Self {
160            working_dir: working_dir.into(),
161            user_context_cache: None,
162            system_context_cache: None,
163            dirty: true,
164        }
165    }
166
167    /// Mark cache as dirty (need refresh)
168    pub fn invalidate(&mut self) {
169        self.dirty = true;
170    }
171
172    /// Set LSP servers info (for dynamic injection)
173    /// This updates the cached system context with LSP server information
174    pub fn set_lsp_servers(&mut self, servers: Vec<LspServerInfo>) {
175        if let Some(ref mut ctx) = self.system_context_cache {
176            ctx.lsp_servers = servers;
177        } else {
178            let mut ctx = SystemContext::default();
179            ctx.lsp_servers = servers;
180            self.system_context_cache = Some(ctx);
181        }
182    }
183
184    /// Get user context
185    pub fn get_user_context(&mut self) -> &UserContext {
186        if self.dirty || self.user_context_cache.is_none() {
187            self.user_context_cache = Some(self.collect_user_context());
188            self.dirty = false;
189        }
190        self.user_context_cache.as_ref().unwrap()
191    }
192
193    /// Get system context
194    pub fn get_system_context(&mut self) -> &SystemContext {
195        if self.dirty || self.system_context_cache.is_none() {
196            self.system_context_cache = Some(self.collect_system_context());
197            self.dirty = false; // Reset dirty flag after refresh
198        }
199        self.system_context_cache.as_ref().unwrap()
200    }
201
202    /// Collect user context
203    fn collect_user_context(&self) -> UserContext {
204        let mut ctx = UserContext::default();
205
206        // Read CLAUDE.md
207        let claude_md_path = self.working_dir.join("CLAUDE.md");
208        if claude_md_path.exists() {
209            if let Ok(content) = std::fs::read_to_string(&claude_md_path) {
210                ctx.claude_md_content = Some(content);
211            }
212        }
213
214        // Also check parent directories
215        if ctx.claude_md_content.is_none() {
216            if let Some(parent) = self.working_dir.parent() {
217                let parent_claude_md = parent.join("CLAUDE.md");
218                if parent_claude_md.exists() {
219                    if let Ok(content) = std::fs::read_to_string(&parent_claude_md) {
220                        ctx.claude_md_content = Some(content);
221                    }
222                }
223            }
224        }
225
226        ctx
227    }
228
229    /// Collect system context
230    fn collect_system_context(&self) -> SystemContext {
231        let mut ctx = SystemContext::default();
232
233        // Git info
234        if let Ok(output) = std::process::Command::new("git")
235            .args(["branch", "--show-current"])
236            .current_dir(&self.working_dir)
237            .output()
238        {
239            if output.status.success() {
240                ctx.git_branch = Some(String::from_utf8_lossy(&output.stdout).trim().to_string());
241            }
242        }
243
244        if let Ok(output) = std::process::Command::new("git")
245            .args(["status", "--porcelain"])
246            .current_dir(&self.working_dir)
247            .output()
248        {
249            if output.status.success() {
250                let status = String::from_utf8_lossy(&output.stdout);
251                ctx.git_status = if status.trim().is_empty() {
252                    Some("clean".to_string())
253                } else {
254                    Some(format!("dirty ({} changes)", status.lines().count()))
255                };
256            }
257        }
258
259        // Working directory
260        ctx.working_directory = self.working_dir.to_str().map(|s| s.to_string());
261        ctx.project_root = ctx.working_directory.clone();
262
263        // Project type
264        ctx.project_type = Some(ProjectType::detect(&self.working_dir));
265
266        // Available tools (check common tools)
267        let tools = ["git", "cargo", "npm", "python", "go", "docker"];
268        for tool in tools {
269            if Self::tool_available(tool) {
270                ctx.available_tools.push(tool.to_string());
271            }
272        }
273
274        ctx
275    }
276
277    /// Check if a tool is available
278    fn tool_available(tool: &str) -> bool {
279        #[cfg(unix)]
280        {
281            std::process::Command::new("which")
282                .arg(tool)
283                .output()
284                .map(|o| o.status.success())
285                .unwrap_or(false)
286        }
287
288        #[cfg(windows)]
289        {
290            std::process::Command::new("where")
291                .arg(tool)
292                .output()
293                .map(|o| o.status.success())
294                .unwrap_or(false)
295        }
296    }
297
298    /// Render user context as prompt section
299    pub fn render_user_context(&mut self) -> String {
300        let ctx = self.get_user_context();
301        let mut parts = Vec::new();
302
303        // Date
304        parts.push(format!(
305            "<currentDate>\n{}\n</currentDate>",
306            ctx.current_date
307        ));
308
309        // CLAUDE.md content
310        if let Some(ref claude_md) = ctx.claude_md_content {
311            parts.push(format!(
312                "<userPreferences>\n{}\n</userPreferences>",
313                claude_md
314            ));
315        }
316
317        parts.join("\n\n")
318    }
319
320    /// Render system context as prompt section
321    pub fn render_system_context(&mut self) -> String {
322        let ctx = self.get_system_context();
323        let mut parts = Vec::new();
324
325        // Working directory
326        if let Some(ref dir) = ctx.working_directory {
327            parts.push(format!("<workingDirectory>\n{}\n</workingDirectory>", dir));
328        }
329
330        // Git info
331        if let Some(ref branch) = ctx.git_branch {
332            let git_info = if let Some(ref status) = ctx.git_status {
333                format!("Branch: {}\nStatus: {}", branch, status)
334            } else {
335                format!("Branch: {}", branch)
336            };
337            parts.push(format!("<gitContext>\n{}\n</gitContext>", git_info));
338        }
339
340        // Project type
341        if let Some(ref pt) = ctx.project_type {
342            parts.push(format!("<projectType>\n{}\n</projectType>", pt.as_str()));
343        }
344
345        // Available tools
346        if !ctx.available_tools.is_empty() {
347            parts.push(format!(
348                "<availableTools>\n{}\n</availableTools>",
349                ctx.available_tools.join(", ")
350            ));
351        }
352
353        // LSP servers (dynamic injection)
354        if !ctx.lsp_servers.is_empty() {
355            let servers_info = ctx
356                .lsp_servers
357                .iter()
358                .map(|s| {
359                    let status = s.status.label();
360                    format!("{}: {} [{}]", s.language, s.name, status)
361                })
362                .collect::<Vec<_>>()
363                .join("\n");
364            parts.push(format!("<lspServers>\n{}\n</lspServers>", servers_info));
365        }
366
367        parts.join("\n\n")
368    }
369
370    /// Render full context (for injection into prompt)
371    pub fn render_full_context(&mut self) -> String {
372        let user = self.render_user_context();
373        let system = self.render_system_context();
374
375        format!("<context>\n{}\n\n{}\n</context>", user, system)
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Creates `<tmp>/project` and returns (guard, project path). Nesting one level down
    /// means the parent-directory CLAUDE.md lookup only ever sees the temporary root.
    fn project_dir() -> (tempfile::TempDir, PathBuf) {
        let root = tempfile::tempdir().unwrap();
        let project = root.path().join("project");
        std::fs::create_dir(&project).unwrap();
        (root, project)
    }

    #[test]
    fn test_user_context_default() {
        let ctx = UserContext::default();
        assert!(ctx.claude_md_content.is_none());
        assert!(!ctx.current_date.is_empty());
    }

    #[test]
    fn test_system_context_default() {
        let ctx = SystemContext::default();
        assert!(ctx.git_branch.is_none());
        assert!(!ctx.platform.is_empty());
    }

    #[test]
    fn test_project_type_detect_rust() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Rust);
    }

    #[test]
    fn test_project_type_detect_mixed() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();
        std::fs::write(temp_dir.path().join("package.json"), "").unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Mixed);
    }

    #[test]
    fn test_project_type_detect_unknown() {
        let temp_dir = tempfile::tempdir().unwrap();

        assert_eq!(ProjectType::detect(temp_dir.path()), ProjectType::Unknown);
    }

    #[test]
    fn test_context_invalidator() {
        let (_root, project) = project_dir();
        let mut injector = ContextInjector::new(project.clone());

        // First access collects: there is no CLAUDE.md yet.
        assert!(injector.get_user_context().claude_md_content.is_none());

        // A newly written file is not seen while the cache is still valid...
        std::fs::write(project.join("CLAUDE.md"), "be concise").unwrap();
        assert!(injector.get_user_context().claude_md_content.is_none());

        // ...but is picked up after invalidation.
        injector.invalidate();
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("be concise")
        );
    }

    #[test]
    fn test_user_context_falls_back_to_parent_claude_md() {
        let (root, project) = project_dir();
        std::fs::write(root.path().join("CLAUDE.md"), "parent rules").unwrap();

        let mut injector = ContextInjector::new(project);
        assert_eq!(
            injector.get_user_context().claude_md_content.as_deref(),
            Some("parent rules")
        );
    }

    #[test]
    fn test_render_user_context() {
        let (_root, project) = project_dir();
        std::fs::write(project.join("CLAUDE.md"), "prefer tabs").unwrap();

        let mut injector = ContextInjector::new(project);
        let rendered = injector.render_user_context();

        assert!(rendered.contains("<currentDate>"));
        assert!(rendered.contains("<userPreferences>\nprefer tabs\n</userPreferences>"));
    }

    #[test]
    fn test_render_system_context() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::fs::write(temp_dir.path().join("Cargo.toml"), "").unwrap();

        let mut injector = ContextInjector::new(temp_dir.path());
        let rendered = injector.render_system_context();

        assert!(rendered.contains("<workingDirectory>"));
        assert!(rendered.contains("<projectType>\nrust\n</projectType>"));
    }

    #[test]
    fn test_render_full_context_wraps_sections() {
        let (_root, project) = project_dir();
        let mut injector = ContextInjector::new(project);
        let rendered = injector.render_full_context();

        assert!(rendered.starts_with("<context>\n"));
        assert!(rendered.ends_with("\n</context>"));
        assert!(rendered.contains("<currentDate>"));
        assert!(rendered.contains("<workingDirectory>"));
    }
}