Skip to main content

matrixcode_core/prompt/
preprocess.rs

1//! Pre-processing Hook for Skills/Workflows Trigger Detection
2//!
3//! This module implements the **backend-side** trigger detection that was
4//! previously described in the prompt. By moving this logic to code:
5//! - Eliminates ambiguity in pattern matching
6//! - Provides deterministic behavior
7//! - Reduces prompt token cost (~100 lines removed from prompt)
8//! - Enables easier testing and debugging
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use regex::Regex;
13
/// Outcome of running trigger detection on a single user message.
///
/// Exactly one variant is produced per message: a skill match, a workflow
/// match, or an instruction to carry on with normal processing.
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessResult {
    /// A skill pattern matched; carries the skill's id and the confidence
    /// weight reported for the match.
    SkillTriggered { skill_id: String, confidence: f32 },
    /// A workflow keyword matched; carries the workflow's id and whatever
    /// inputs could be extracted from the message (possibly none).
    WorkflowTriggered { workflow_id: String, inputs: HashMap<String, String> },
    /// Nothing matched; the message should be handled normally.
    Continue,
}
30
/// Category of trigger that was detected.
///
/// NOTE(review): nothing visible in this file constructs or reads this enum.
/// The `*Keyword` variants presumably distinguish keyword-based matches from
/// explicit triggers — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TriggerType {
    Skill,
    Workflow,
    SkillKeyword,
    WorkflowKeyword,
}
39
40/// Skill trigger pattern
41#[derive(Debug, Clone)]
42pub struct SkillPattern {
43    /// Skill identifier (e.g., "code-review", "refactor")
44    pub skill_id: String,
45    /// Primary trigger patterns (regex or keyword)
46    pub patterns: Vec<String>,
47    /// Compiled regex patterns
48    pub compiled: Vec<Regex>,
49    /// Confidence weight (0.0 - 1.0)
50    pub weight: f32,
51}
52
53impl SkillPattern {
54    pub fn new(skill_id: impl Into<String>, patterns: Vec<&str>, weight: f32) -> Self {
55        let patterns: Vec<String> = patterns.into_iter().map(|s| s.to_string()).collect();
56        let compiled = patterns.iter()
57            .filter_map(|p| Regex::new(&format!("(?i){}", p)).ok())
58            .collect();
59        
60        Self {
61            skill_id: skill_id.into(),
62            patterns,
63            compiled,
64            weight,
65        }
66    }
67
68    /// Check if user message matches this skill
69    pub fn matches(&self, message: &str) -> Option<f32> {
70        for regex in &self.compiled {
71            if regex.is_match(message) {
72                return Some(self.weight);
73            }
74        }
75        None
76    }
77}
78
79/// Workflow trigger configuration
80#[derive(Debug, Clone)]
81pub struct WorkflowTrigger {
82    /// Workflow identifier
83    pub workflow_id: String,
84    /// Trigger keywords
85    pub keywords: Vec<String>,
86    /// Required inputs that can be extracted from message
87    pub extractable_inputs: Vec<String>,
88}
89
90impl WorkflowTrigger {
91    pub fn new(workflow_id: impl Into<String>, keywords: Vec<&str>, inputs: Vec<&str>) -> Self {
92        Self {
93            workflow_id: workflow_id.into(),
94            keywords: keywords.into_iter().map(|s| s.to_string()).collect(),
95            extractable_inputs: inputs.into_iter().map(|s| s.to_string()).collect(),
96        }
97    }
98
99    /// Check if message triggers this workflow
100    pub fn matches(&self, message: &str) -> bool {
101        let msg_lower = message.to_lowercase();
102        self.keywords.iter().any(|k| msg_lower.contains(&k.to_lowercase()))
103    }
104
105    /// Extract inputs from message (simple extraction)
106    pub fn extract_inputs(&self, message: &str) -> HashMap<String, String> {
107        let mut inputs = HashMap::new();
108        
109        // Simple topic extraction for common patterns
110        if self.extractable_inputs.contains(&"topic".to_string()) {
111            // Pattern: "generate article about X" or "X article"
112            let patterns = [
113                r"(?i)(?:generate|create|write).*(?:article|post|content).*?about\s+(.+?)(?:\.|$)",
114                r"(?i)(?:article|post|content)\s+about\s+(.+?)(?:\.|$)",
115            ];
116            
117            for pattern in patterns {
118                if let Ok(re) = Regex::new(pattern) {
119                    if let Some(caps) = re.captures(message) {
120                        if let Some(topic) = caps.get(1) {
121                            inputs.insert("topic".to_string(), topic.as_str().trim().to_string());
122                            break;
123                        }
124                    }
125                }
126            }
127        }
128        
129        inputs
130    }
131}
132
133/// Pre-processing hook for trigger detection
134pub struct PreProcessHook {
135    /// Skill patterns
136    skills: Vec<SkillPattern>,
137    /// Workflow triggers
138    workflows: Vec<WorkflowTrigger>,
139    /// Minimum confidence threshold
140    confidence_threshold: f32,
141}
142
143impl Default for PreProcessHook {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149impl PreProcessHook {
150    /// Create with default patterns (from claude-code-analysis)
151    pub fn new() -> Self {
152        Self {
153            skills: Self::default_skill_patterns(),
154            workflows: Self::default_workflow_triggers(),
155            confidence_threshold: 0.7,
156        }
157    }
158
159    /// Default skill patterns based on analysis
160    fn default_skill_patterns() -> Vec<SkillPattern> {
161        vec![
162            // Code review skill
163            SkillPattern::new("code-review", vec![
164                r"/review",
165                r"审查.*代码",
166                r"检查.*代码",
167                r"code\s*review",
168                r"review.*code",
169            ], 0.9),
170
171            // Refactor skill
172            SkillPattern::new("refactor", vec![
173                r"/refactor",
174                r"重构.*代码",
175                r"优化.*结构",
176                r"refactor",
177            ], 0.9),
178
179            // Debug skill
180            SkillPattern::new("debug", vec![
181                r"/debug",
182                r"调试.*问题",
183                r"排查.*问题",
184                r"debug",
185                r"调试",
186            ], 0.9),
187
188            // Planning skill
189            SkillPattern::new("planning", vec![
190                r"/plan",
191                r"规划.*方案",
192                r"设计.*方案",
193                r"plan",
194            ], 0.9),
195
196            // Security review skill
197            SkillPattern::new("security-review", vec![
198                r"/security",
199                r"安全.*审查",
200                r"安全.*检查",
201                r"security\s*review",
202            ], 0.9),
203
204            // Demo skill
205            SkillPattern::new("demo", vec![
206                r"/demo",
207                r"演示",
208                r"demo",
209            ], 0.8),
210
211            // Git commit skill
212            SkillPattern::new("git-commit", vec![
213                r"/commit",
214                r"提交.*代码",
215                r"commit",
216            ], 0.8),
217        ]
218    }
219
220    /// Default workflow triggers based on analysis
221    fn default_workflow_triggers() -> Vec<WorkflowTrigger> {
222        vec![
223            // Image article workflow
224            WorkflowTrigger::new(
225                "image-article",
226                vec!["generate article", "生成文章", "create article", "图片文章"],
227                vec!["topic"],
228            ),
229
230            // Analysis workflow
231            WorkflowTrigger::new(
232                "code-analysis",
233                vec!["analyze code", "分析代码", "代码分析", "code analysis"],
234                vec!["target"],
235            ),
236
237            // Test workflow
238            WorkflowTrigger::new(
239                "test-runner",
240                vec!["run tests", "运行测试", "执行测试", "test suite"],
241                vec!["test_path"],
242            ),
243        ]
244    }
245
246    /// Process user message and detect triggers
247    pub fn process(&self, message: &str) -> ProcessResult {
248        // Step 1: Check for skill triggers
249        for skill in &self.skills {
250            if let Some(confidence) = skill.matches(message) {
251                if confidence >= self.confidence_threshold {
252                    return ProcessResult::SkillTriggered {
253                        skill_id: skill.skill_id.clone(),
254                        confidence,
255                    };
256                }
257            }
258        }
259
260        // Step 2: Check for workflow triggers
261        for workflow in &self.workflows {
262            if workflow.matches(message) {
263                let inputs = workflow.extract_inputs(message);
264                return ProcessResult::WorkflowTriggered {
265                    workflow_id: workflow.workflow_id.clone(),
266                    inputs,
267                };
268            }
269        }
270
271        // Step 3: Continue normal processing
272        ProcessResult::Continue
273    }
274
275    /// Add a custom skill pattern
276    pub fn add_skill(&mut self, skill: SkillPattern) {
277        self.skills.push(skill);
278    }
279
280    /// Add a custom workflow trigger
281    pub fn add_workflow(&mut self, workflow: WorkflowTrigger) {
282        self.workflows.push(workflow);
283    }
284
285    /// Set confidence threshold
286    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
287        self.confidence_threshold = threshold;
288        self
289    }
290
291    /// Check if message contains skill-like patterns (for heuristics)
292    pub fn has_skill_intent(&self, message: &str) -> bool {
293        let msg_lower = message.to_lowercase();
294        
295        // Check for common skill indicators
296        let skill_indicators = [
297            "review", "refactor", "debug", "plan", "security",
298            "审查", "重构", "调试", "规划", "安全",
299        ];
300        
301        skill_indicators.iter().any(|ind| msg_lower.contains(ind))
302    }
303
304    /// Check if message contains workflow-like patterns (multiple steps)
305    pub fn has_workflow_intent(&self, message: &str) -> bool {
306        let msg_lower = message.to_lowercase();
307        
308        // Check for multi-step indicators
309        let workflow_indicators = [
310            "generate", "create", "analyze", "process", "batch",
311            "生成", "创建", "分析", "处理", "批量",
312            "and then", "then", "after", "然后", "接着",
313        ];
314        
315        // Count how many indicators are present
316        let count = workflow_indicators.iter()
317            .filter(|ind| msg_lower.contains(*ind))
318            .count();
319        
320        count >= 2
321    }
322
323    /// Get all registered skills
324    pub fn list_skills(&self) -> Vec<&str> {
325        self.skills.iter().map(|s| s.skill_id.as_str()).collect()
326    }
327
328    /// Get all registered workflows
329    pub fn list_workflows(&self) -> Vec<&str> {
330        self.workflows.iter().map(|w| w.workflow_id.as_str()).collect()
331    }
332}
333
334/// Global preprocessor instance
335static GLOBAL_PREPROCESSOR: std::sync::OnceLock<Arc<PreProcessHook>> = std::sync::OnceLock::new();
336
337/// Get the global preprocessor
338pub fn global_preprocessor() -> Arc<PreProcessHook> {
339    GLOBAL_PREPROCESSOR.get_or_init(|| Arc::new(PreProcessHook::new())).clone()
340}
341
342/// Process message with global preprocessor
343pub fn preprocess(message: &str) -> ProcessResult {
344    global_preprocessor().process(message)
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Id of the skill triggered by `message`, or `None` if the result is
    /// not a skill trigger. Lets tests use `assert_eq!` so a failure shows
    /// which skill actually fired.
    fn triggered_skill(hook: &PreProcessHook, message: &str) -> Option<String> {
        match hook.process(message) {
            ProcessResult::SkillTriggered { skill_id, .. } => Some(skill_id),
            _ => None,
        }
    }

    /// Id of the workflow triggered by `message`, or `None` if the result is
    /// not a workflow trigger.
    fn triggered_workflow(hook: &PreProcessHook, message: &str) -> Option<String> {
        match hook.process(message) {
            ProcessResult::WorkflowTriggered { workflow_id, .. } => Some(workflow_id),
            _ => None,
        }
    }

    #[test]
    fn test_skill_trigger_slash_command() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_skill(&hook, "/review this code").as_deref(),
            Some("code-review")
        );
        assert_eq!(
            triggered_skill(&hook, "/refactor the module").as_deref(),
            Some("refactor")
        );
    }

    #[test]
    fn test_skill_trigger_chinese() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_skill(&hook, "审查这段代码").as_deref(),
            Some("code-review")
        );
        assert_eq!(triggered_skill(&hook, "调试这个bug").as_deref(), Some("debug"));
    }

    #[test]
    fn test_workflow_trigger() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_workflow(&hook, "generate article about Rust performance").as_deref(),
            Some("image-article")
        );
    }

    #[test]
    fn test_continue_normal() {
        let hook = PreProcessHook::new();

        assert_eq!(hook.process("What is the weather today?"), ProcessResult::Continue);
        assert_eq!(hook.process("Help me write a function"), ProcessResult::Continue);
    }

    #[test]
    fn test_confidence_threshold() {
        let hook = PreProcessHook::new().with_confidence_threshold(0.85);

        // Should still work for high-confidence matches (0.9 > 0.85)
        assert_eq!(triggered_skill(&hook, "/review").as_deref(), Some("code-review"));

        // A skill weighted below the threshold (demo is 0.8) must be rejected,
        // and with no workflow keyword present the message falls through.
        assert_eq!(hook.process("/demo"), ProcessResult::Continue);
    }

    #[test]
    fn test_custom_skill() {
        let mut hook = PreProcessHook::new();
        hook.add_skill(SkillPattern::new("custom", vec!["/custom", "custom skill"], 0.9));

        assert_eq!(triggered_skill(&hook, "/custom task").as_deref(), Some("custom"));
    }

    #[test]
    fn test_extract_inputs() {
        let hook = PreProcessHook::new();

        let result = hook.process("generate article about Rust async programming");
        if let ProcessResult::WorkflowTriggered { inputs, .. } = result {
            assert!(inputs.contains_key("topic"));
            assert!(inputs["topic"].to_lowercase().contains("rust"));
        } else {
            panic!("Expected WorkflowTriggered");
        }
    }

    #[test]
    fn test_has_skill_intent() {
        let hook = PreProcessHook::new();

        assert!(hook.has_skill_intent("Please review my code"));
        assert!(hook.has_skill_intent("审查代码"));
        assert!(!hook.has_skill_intent("What's the time?"));
    }

    #[test]
    fn test_has_workflow_intent() {
        let hook = PreProcessHook::new();

        assert!(hook.has_workflow_intent("Analyze the code and then generate a report"));
        assert!(hook.has_workflow_intent("分析代码,然后生成报告"));
        assert!(!hook.has_workflow_intent("Just a simple question"));
    }

    #[test]
    fn test_list_skills() {
        let hook = PreProcessHook::new();
        let skills = hook.list_skills();

        assert!(skills.contains(&"code-review"));
        assert!(skills.contains(&"refactor"));
        assert!(skills.contains(&"debug"));
    }

    #[test]
    fn test_list_workflows() {
        let hook = PreProcessHook::new();
        let workflows = hook.list_workflows();

        assert!(workflows.contains(&"image-article"));
        assert!(workflows.contains(&"code-analysis"));
    }
}