Skip to main content

matrixcode_core/prompt/
preprocess.rs

//! Pre-processing hook for skill/workflow trigger detection.
//!
//! This module implements the **backend-side** trigger detection that was
//! previously described in the prompt. Moving this logic into code:
//! - eliminates ambiguity in pattern matching
//! - provides deterministic behavior
//! - reduces prompt token cost (~100 lines removed from the prompt)
//! - enables easier testing and debugging
9
10use regex::Regex;
11use std::collections::HashMap;
12use std::sync::Arc;
13
/// Outcome of running a user message through trigger detection.
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessResult {
    /// A skill pattern matched the message.
    SkillTriggered {
        /// Identifier of the skill that matched.
        skill_id: String,
        /// Confidence reported for the match.
        confidence: f32,
    },
    /// A workflow keyword matched the message.
    WorkflowTriggered {
        /// Identifier of the workflow that matched.
        workflow_id: String,
        /// Inputs extracted from the message, keyed by input name.
        inputs: HashMap<String, String>,
    },
    /// Nothing matched; the message should go through normal processing.
    Continue,
}
27
/// Category of trigger that detection can report.
///
/// NOTE(review): nothing in this file constructs or inspects these variants,
/// so the per-variant descriptions below are inferred from the names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TriggerType {
    /// A skill trigger.
    Skill,
    /// A workflow trigger.
    Workflow,
    /// Presumably a skill triggered through a keyword -- confirm with callers.
    SkillKeyword,
    /// Presumably a workflow triggered through a keyword -- confirm with callers.
    WorkflowKeyword,
}
36
37/// Skill trigger pattern
38#[derive(Debug, Clone)]
39pub struct SkillPattern {
40    /// Skill identifier (e.g., "code-review", "refactor")
41    pub skill_id: String,
42    /// Primary trigger patterns (regex or keyword)
43    pub patterns: Vec<String>,
44    /// Compiled regex patterns
45    pub compiled: Vec<Regex>,
46    /// Confidence weight (0.0 - 1.0)
47    pub weight: f32,
48}
49
50impl SkillPattern {
51    pub fn new(skill_id: impl Into<String>, patterns: Vec<&str>, weight: f32) -> Self {
52        let patterns: Vec<String> = patterns.into_iter().map(|s| s.to_string()).collect();
53        let compiled = patterns
54            .iter()
55            .filter_map(|p| Regex::new(&format!("(?i){}", p)).ok())
56            .collect();
57
58        Self {
59            skill_id: skill_id.into(),
60            patterns,
61            compiled,
62            weight,
63        }
64    }
65
66    /// Check if user message matches this skill
67    pub fn matches(&self, message: &str) -> Option<f32> {
68        for regex in &self.compiled {
69            if regex.is_match(message) {
70                return Some(self.weight);
71            }
72        }
73        None
74    }
75}
76
/// Configuration describing how a workflow is triggered.
#[derive(Debug, Clone)]
pub struct WorkflowTrigger {
    /// Identifier of the workflow.
    pub workflow_id: String,
    /// Keywords whose presence in a message triggers the workflow.
    pub keywords: Vec<String>,
    /// Names of inputs that may be extracted from the triggering message.
    pub extractable_inputs: Vec<String>,
}
87
88impl WorkflowTrigger {
89    pub fn new(workflow_id: impl Into<String>, keywords: Vec<&str>, inputs: Vec<&str>) -> Self {
90        Self {
91            workflow_id: workflow_id.into(),
92            keywords: keywords.into_iter().map(|s| s.to_string()).collect(),
93            extractable_inputs: inputs.into_iter().map(|s| s.to_string()).collect(),
94        }
95    }
96
97    /// Check if message triggers this workflow
98    pub fn matches(&self, message: &str) -> bool {
99        let msg_lower = message.to_lowercase();
100        self.keywords
101            .iter()
102            .any(|k| msg_lower.contains(&k.to_lowercase()))
103    }
104
105    /// Extract inputs from message (simple extraction)
106    pub fn extract_inputs(&self, message: &str) -> HashMap<String, String> {
107        let mut inputs = HashMap::new();
108
109        // Simple topic extraction for common patterns
110        if self.extractable_inputs.contains(&"topic".to_string()) {
111            // Pattern: "generate article about X" or "X article"
112            let patterns = [
113                r"(?i)(?:generate|create|write).*(?:article|post|content).*?about\s+(.+?)(?:\.|$)",
114                r"(?i)(?:article|post|content)\s+about\s+(.+?)(?:\.|$)",
115            ];
116
117            for pattern in patterns {
118                if let Ok(re) = Regex::new(pattern) {
119                    if let Some(caps) = re.captures(message) {
120                        if let Some(topic) = caps.get(1) {
121                            inputs.insert("topic".to_string(), topic.as_str().trim().to_string());
122                            break;
123                        }
124                    }
125                }
126            }
127        }
128
129        inputs
130    }
131}
132
133/// Pre-processing hook for trigger detection
134pub struct PreProcessHook {
135    /// Skill patterns
136    skills: Vec<SkillPattern>,
137    /// Workflow triggers
138    workflows: Vec<WorkflowTrigger>,
139    /// Minimum confidence threshold
140    confidence_threshold: f32,
141}
142
143impl Default for PreProcessHook {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149impl PreProcessHook {
150    /// Create with default patterns (from claude-code-analysis)
151    pub fn new() -> Self {
152        Self {
153            skills: Self::default_skill_patterns(),
154            workflows: Self::default_workflow_triggers(),
155            confidence_threshold: 0.7,
156        }
157    }
158
159    /// Default skill patterns based on analysis
160    fn default_skill_patterns() -> Vec<SkillPattern> {
161        vec![
162            // Code review skill
163            SkillPattern::new(
164                "code-review",
165                vec![
166                    r"/review",
167                    r"审查.*代码",
168                    r"检查.*代码",
169                    r"code\s*review",
170                    r"review.*code",
171                ],
172                0.9,
173            ),
174            // Refactor skill
175            SkillPattern::new(
176                "refactor",
177                vec![r"/refactor", r"重构.*代码", r"优化.*结构", r"refactor"],
178                0.9,
179            ),
180            // Debug skill
181            SkillPattern::new(
182                "debug",
183                vec![r"/debug", r"调试.*问题", r"排查.*问题", r"debug", r"调试"],
184                0.9,
185            ),
186            // Planning skill
187            SkillPattern::new(
188                "planning",
189                vec![r"/plan", r"规划.*方案", r"设计.*方案", r"plan"],
190                0.9,
191            ),
192            // Security review skill
193            SkillPattern::new(
194                "security-review",
195                vec![
196                    r"/security",
197                    r"安全.*审查",
198                    r"安全.*检查",
199                    r"security\s*review",
200                ],
201                0.9,
202            ),
203            // Demo skill
204            SkillPattern::new("demo", vec![r"/demo", r"演示", r"demo"], 0.8),
205            // Git commit skill
206            SkillPattern::new(
207                "git-commit",
208                vec![r"/commit", r"提交.*代码", r"commit"],
209                0.8,
210            ),
211        ]
212    }
213
214    /// Default workflow triggers based on analysis
215    fn default_workflow_triggers() -> Vec<WorkflowTrigger> {
216        vec![
217            // Image article workflow
218            WorkflowTrigger::new(
219                "image-article",
220                vec!["generate article", "生成文章", "create article", "图片文章"],
221                vec!["topic"],
222            ),
223            // Analysis workflow
224            WorkflowTrigger::new(
225                "code-analysis",
226                vec!["analyze code", "分析代码", "代码分析", "code analysis"],
227                vec!["target"],
228            ),
229            // Test workflow
230            WorkflowTrigger::new(
231                "test-runner",
232                vec!["run tests", "运行测试", "执行测试", "test suite"],
233                vec!["test_path"],
234            ),
235        ]
236    }
237
238    /// Process user message and detect triggers
239    pub fn process(&self, message: &str) -> ProcessResult {
240        // Step 1: Check for skill triggers
241        for skill in &self.skills {
242            if let Some(confidence) = skill.matches(message) {
243                if confidence >= self.confidence_threshold {
244                    return ProcessResult::SkillTriggered {
245                        skill_id: skill.skill_id.clone(),
246                        confidence,
247                    };
248                }
249            }
250        }
251
252        // Step 2: Check for workflow triggers
253        for workflow in &self.workflows {
254            if workflow.matches(message) {
255                let inputs = workflow.extract_inputs(message);
256                return ProcessResult::WorkflowTriggered {
257                    workflow_id: workflow.workflow_id.clone(),
258                    inputs,
259                };
260            }
261        }
262
263        // Step 3: Continue normal processing
264        ProcessResult::Continue
265    }
266
267    /// Add a custom skill pattern
268    pub fn add_skill(&mut self, skill: SkillPattern) {
269        self.skills.push(skill);
270    }
271
272    /// Add a custom workflow trigger
273    pub fn add_workflow(&mut self, workflow: WorkflowTrigger) {
274        self.workflows.push(workflow);
275    }
276
277    /// Set confidence threshold
278    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
279        self.confidence_threshold = threshold;
280        self
281    }
282
283    /// Check if message contains skill-like patterns (for heuristics)
284    pub fn has_skill_intent(&self, message: &str) -> bool {
285        let msg_lower = message.to_lowercase();
286
287        // Check for common skill indicators
288        let skill_indicators = [
289            "review", "refactor", "debug", "plan", "security", "审查", "重构", "调试", "规划",
290            "安全",
291        ];
292
293        skill_indicators.iter().any(|ind| msg_lower.contains(ind))
294    }
295
296    /// Check if message contains workflow-like patterns (multiple steps)
297    pub fn has_workflow_intent(&self, message: &str) -> bool {
298        let msg_lower = message.to_lowercase();
299
300        // Check for multi-step indicators
301        let workflow_indicators = [
302            "generate", "create", "analyze", "process", "batch", "生成", "创建", "分析", "处理",
303            "批量", "and then", "then", "after", "然后", "接着",
304        ];
305
306        // Count how many indicators are present
307        let count = workflow_indicators
308            .iter()
309            .filter(|ind| msg_lower.contains(*ind))
310            .count();
311
312        count >= 2
313    }
314
315    /// Get all registered skills
316    pub fn list_skills(&self) -> Vec<&str> {
317        self.skills.iter().map(|s| s.skill_id.as_str()).collect()
318    }
319
320    /// Get all registered workflows
321    pub fn list_workflows(&self) -> Vec<&str> {
322        self.workflows
323            .iter()
324            .map(|w| w.workflow_id.as_str())
325            .collect()
326    }
327}
328
329/// Global preprocessor instance
330static GLOBAL_PREPROCESSOR: std::sync::OnceLock<Arc<PreProcessHook>> = std::sync::OnceLock::new();
331
332/// Get the global preprocessor
333pub fn global_preprocessor() -> Arc<PreProcessHook> {
334    GLOBAL_PREPROCESSOR
335        .get_or_init(|| Arc::new(PreProcessHook::new()))
336        .clone()
337}
338
339/// Process message with global preprocessor
340pub fn preprocess(message: &str) -> ProcessResult {
341    global_preprocessor().process(message)
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Id of the skill that `message` triggers on `hook`, if any.
    fn triggered_skill(hook: &PreProcessHook, message: &str) -> Option<String> {
        match hook.process(message) {
            ProcessResult::SkillTriggered { skill_id, .. } => Some(skill_id),
            _ => None,
        }
    }

    /// Id of the workflow that `message` triggers on `hook`, if any.
    fn triggered_workflow(hook: &PreProcessHook, message: &str) -> Option<String> {
        match hook.process(message) {
            ProcessResult::WorkflowTriggered { workflow_id, .. } => Some(workflow_id),
            _ => None,
        }
    }

    #[test]
    fn test_skill_trigger_slash_command() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_skill(&hook, "/review this code").as_deref(),
            Some("code-review")
        );
        assert_eq!(
            triggered_skill(&hook, "/refactor the module").as_deref(),
            Some("refactor")
        );
    }

    #[test]
    fn test_skill_trigger_chinese() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_skill(&hook, "审查这段代码").as_deref(),
            Some("code-review")
        );
        assert_eq!(
            triggered_skill(&hook, "调试这个bug").as_deref(),
            Some("debug")
        );
    }

    #[test]
    fn test_workflow_trigger() {
        let hook = PreProcessHook::new();

        assert_eq!(
            triggered_workflow(&hook, "generate article about Rust performance").as_deref(),
            Some("image-article")
        );
    }

    #[test]
    fn test_continue_normal() {
        let hook = PreProcessHook::new();

        for message in ["What is the weather today?", "Help me write a function"] {
            assert_eq!(hook.process(message), ProcessResult::Continue);
        }
    }

    #[test]
    fn test_confidence_threshold() {
        let hook = PreProcessHook::new().with_confidence_threshold(0.85);

        // A 0.9-weight skill still clears the raised 0.85 threshold.
        assert!(triggered_skill(&hook, "/review").is_some());
    }

    #[test]
    fn test_custom_skill() {
        let mut hook = PreProcessHook::new();
        let custom = SkillPattern::new("custom", vec!["/custom", "custom skill"], 0.9);
        hook.add_skill(custom);

        assert_eq!(
            triggered_skill(&hook, "/custom task").as_deref(),
            Some("custom")
        );
    }

    #[test]
    fn test_extract_inputs() {
        let hook = PreProcessHook::new();

        match hook.process("generate article about Rust async programming") {
            ProcessResult::WorkflowTriggered { inputs, .. } => {
                assert!(inputs.contains_key("topic"));
                assert!(inputs["topic"].to_lowercase().contains("rust"));
            }
            _ => panic!("Expected WorkflowTriggered"),
        }
    }

    #[test]
    fn test_has_skill_intent() {
        let hook = PreProcessHook::new();

        assert!(hook.has_skill_intent("Please review my code"));
        assert!(hook.has_skill_intent("审查代码"));
        assert!(!hook.has_skill_intent("What's the time?"));
    }

    #[test]
    fn test_has_workflow_intent() {
        let hook = PreProcessHook::new();

        assert!(hook.has_workflow_intent("Analyze the code and then generate a report"));
        assert!(hook.has_workflow_intent("分析代码,然后生成报告"));
        assert!(!hook.has_workflow_intent("Just a simple question"));
    }

    #[test]
    fn test_list_skills() {
        let hook = PreProcessHook::new();
        let skills = hook.list_skills();

        for expected in ["code-review", "refactor", "debug"] {
            assert!(skills.contains(&expected));
        }
    }

    #[test]
    fn test_list_workflows() {
        let hook = PreProcessHook::new();
        let workflows = hook.list_workflows();

        for expected in ["image-article", "code-analysis"] {
            assert!(workflows.contains(&expected));
        }
    }
}