Skip to main content

matrixcode_core/prompt/
hooks.rs

//! Session start hooks: dynamic prompt injection.
//!
//! This module provides builders for content that is injected at session start:
//! - SessionStart hook: injects skill rules and red flags before the user message
//! - TodoReminder: reminds about pending tasks from TodoWrite
//! - DiagnosticsInjection: injects real-time LSP diagnostics
//!
//! ## Injection Order
//!
//! ```text
//! Session Start:
//! ├── Core system prompt (static)
//! ├── Environment info (startup)
//! ├── SessionStart hook content ← This module
//! ├── Deferred tools (MCP)
//! ├── Todo reminder (if pending)
//! └── Diagnostics (real-time)
//! ```
19
20use std::collections::HashMap;
21
22/// SessionStart hook content builder
23pub struct SessionStartHook {
24    /// Skills with mandatory invocation
25    mandatory_skills: Vec<String>,
26    /// Whether to include red flags
27    include_red_flags: bool,
28    /// Whether to include skill priority rules
29    include_skill_priority: bool,
30    /// Custom hook content
31    custom_content: Option<String>,
32}
33
34impl Default for SessionStartHook {
35    fn default() -> Self {
36        Self {
37            mandatory_skills: Vec::new(),
38            include_red_flags: true,
39            include_skill_priority: true,
40            custom_content: None,
41        }
42    }
43}
44
45impl SessionStartHook {
46    /// Create a new SessionStart hook
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Add a mandatory skill
52    pub fn add_mandatory_skill(mut self, skill_name: impl Into<String>) -> Self {
53        self.mandatory_skills.push(skill_name.into());
54        self
55    }
56
57    /// Set whether to include red flags
58    pub fn with_red_flags(mut self, include: bool) -> Self {
59        self.include_red_flags = include;
60        self
61    }
62
63    /// Set custom hook content
64    pub fn with_custom_content(mut self, content: impl Into<String>) -> Self {
65        self.custom_content = Some(content.into());
66        self
67    }
68
69    /// Build the hook content
70    pub fn build(&self) -> String {
71        let mut parts = Vec::new();
72
73        // 1. Mandatory skills warning
74        if !self.mandatory_skills.is_empty() {
75            parts.push(self.build_mandatory_skills_warning());
76        }
77
78        // 2. Red flags table
79        if self.include_red_flags {
80            parts.push(RED_FLAGS_SECTION.to_string());
81        }
82
83        // 3. Skill priority rules
84        if self.include_skill_priority {
85            parts.push(SKILL_PRIORITY_SECTION.to_string());
86        }
87
88        // 4. Custom content
89        if let Some(ref content) = self.custom_content {
90            parts.push(content.clone());
91        }
92
93        parts.join("\n\n")
94    }
95
96    fn build_mandatory_skills_warning(&self) -> String {
97        let skills_list = self.mandatory_skills.join(", ");
98        format!(
99            "<EXTREMELY-IMPORTANT>
100The following skills are **MANDATORY** and must be invoked before any response:
101{}
102
103If you think there's even a 1% chance these skills apply, you **MUST** invoke them first.
104This is not negotiable. This is not optional.
105</EXTREMELY-IMPORTANT>",
106            skills_list
107        )
108    }
109}
110
111/// Red flags section content
112const RED_FLAGS_SECTION: &str = "
113## Red Flags - STOP and reconsider
114
115These thoughts mean STOP — you're rationalizing:
116
117| Thought | Reality |
118|---------|---------|
119| \"This is just a simple question\" | Questions are tasks. Check for skills. |
120| \"I need more context first\" | Skill check comes BEFORE clarifying questions. |
121| \"Let me explore the codebase first\" | Skills tell you HOW to explore. Check first. |
122| \"I can check git/files quickly\" | Files lack conversation context. Check for skills. |
123| \"Let me gather information first\" | Skills tell you HOW to gather information. |
124| \"This doesn't need a formal skill\" | If a skill exists, use it. |
125| \"I remember this skill\" | Skills evolve. Read current version. |
126| \"This doesn't count as a task\" | Action = task. Check for skills. |
127| \"The skill is overkill\" | Simple things become complex. Use it. |
128| \"I'll just do this one thing first\" | Check BEFORE doing anything. |
129| \"This feels productive\" | Undisciplined action wastes time. Skills prevent this. |
130| \"I know what that means\" | Knowing the concept ≠ using the skill. Invoke it. |
131
132**When you catch yourself thinking these, STOP and invoke the relevant skill.**
133";
134
135/// Skill priority section content
136const SKILL_PRIORITY_SECTION: &str = "
137## Skill Priority
138
139When multiple skills could apply, use this order:
140
1411. **Process skills first** (brainstorming, debugging, planning)
142   - These determine HOW to approach the task
143   - Must be invoked first
144
1452. **Implementation skills second** (code-review, frontend-design)
146   - These guide execution
147   - After process skills
148
149Examples:
150- \"Let's build X\" → brainstorming → frontend-design
151- \"Fix this bug\" → debugging → domain-specific skills
152";
153
154// ============================================================================
155// Todo Reminder System
156// ============================================================================
157
/// Builder for the `<todo-reminder>` block.
///
/// Holds the pending tasks, an optional in-progress task, and a per-task
/// count of reminders already issued, capped by `max_reminders`.
pub struct TodoReminder {
    /// Tasks that have not been started yet.
    pending_tasks: Vec<String>,
    /// The task currently being worked on, if any.
    in_progress: Option<String>,
    /// How many times each task has been reminded, keyed by task text.
    reminder_count: HashMap<String, usize>,
    /// Upper bound on reminders per task.
    max_reminders: usize,
}

impl Default for TodoReminder {
    /// Empty reminder with a limit of two reminders per task.
    fn default() -> Self {
        TodoReminder {
            max_reminders: 2,
            in_progress: None,
            pending_tasks: Vec::new(),
            reminder_count: HashMap::new(),
        }
    }
}

impl TodoReminder {
    /// Create an empty reminder with the default limit.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replace the list of pending tasks.
    pub fn set_pending_tasks(self, tasks: Vec<String>) -> Self {
        Self {
            pending_tasks: tasks,
            ..self
        }
    }

    /// Set the task that is currently in progress.
    pub fn set_in_progress(self, task: impl Into<String>) -> Self {
        Self {
            in_progress: Some(task.into()),
            ..self
        }
    }

    /// Set the maximum number of reminders per task.
    pub fn with_max_reminders(self, max: usize) -> Self {
        Self {
            max_reminders: max,
            ..self
        }
    }

    /// Return `true` while `task` has been reminded fewer than
    /// `max_reminders` times; a task with no recorded count counts as zero.
    pub fn should_remind(&self, task: &str) -> bool {
        let already_sent = match self.reminder_count.get(task) {
            Some(&n) => n,
            None => 0,
        };
        already_sent < self.max_reminders
    }

    /// Record one more reminder for `task`.
    pub fn increment_reminder(&mut self, task: &str) {
        let counter = self.reminder_count.entry(task.to_owned()).or_default();
        *counter += 1;
    }

    /// Render the reminder block.
    ///
    /// The in-progress task, when set, is always listed. Pending tasks are
    /// listed only while `should_remind` holds for them. Returns `None` when
    /// there is nothing left to list.
    pub fn build(&self) -> Option<String> {
        // Zero or one line for the in-progress task.
        let mut body: Vec<String> = self
            .in_progress
            .iter()
            .map(|task| format!("⏳ **In Progress**: {}", task))
            .collect();

        // Peek first so the heading is emitted only if a task survives the filter.
        let mut still_due = self
            .pending_tasks
            .iter()
            .filter(|task| self.should_remind(task.as_str()))
            .peekable();

        if still_due.peek().is_some() {
            body.push(String::from("\n📋 **Pending Tasks**:"));
            body.extend(still_due.map(|task| format!("  - {}", task)));
        }

        if body.is_empty() {
            None
        } else {
            Some(format!(
                "<todo-reminder>\n{}\n</todo-reminder>",
                body.join("\n")
            ))
        }
    }
}
253
254// ============================================================================
255// Diagnostics Injection
256// ============================================================================
257
/// A single diagnostic to surface in the injected prompt.
#[derive(Debug, Clone)]
pub struct DiagnosticEntry {
    /// File path
    pub file: String,
    /// Line number
    pub line: usize,
    /// Severity label; `"error"`, `"warning"` and `"info"` get dedicated markers
    pub severity: String,
    /// Message
    pub message: String,
    /// Source (rustc, rust-analyzer, etc.)
    pub source: String,
}

/// Builder for the `<new-diagnostics>` block.
pub struct DiagnosticsInjection {
    /// Diagnostic entries, rendered in insertion order.
    diagnostics: Vec<DiagnosticEntry>,
    /// Maximum number of entries rendered by `build`.
    max_entries: usize,
}

impl Default for DiagnosticsInjection {
    /// No diagnostics and a limit of 20 rendered entries.
    fn default() -> Self {
        Self {
            diagnostics: Vec::new(),
            max_entries: 20,
        }
    }
}

impl DiagnosticsInjection {
    /// Create an empty injection with the default entry limit.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append one diagnostic.
    pub fn add_diagnostic(mut self, entry: DiagnosticEntry) -> Self {
        self.diagnostics.push(entry);
        self
    }

    /// Replace all diagnostics.
    pub fn set_diagnostics(mut self, entries: Vec<DiagnosticEntry>) -> Self {
        self.diagnostics = entries;
        self
    }

    /// Set the maximum number of entries rendered by `build`.
    pub fn with_max_entries(mut self, max: usize) -> Self {
        self.max_entries = max;
        self
    }

    /// Render the diagnostics block.
    ///
    /// Returns `None` when there are no diagnostics. At most `max_entries`
    /// entries are listed, in insertion order. If any were left out, a final
    /// line states how many, so the reader knows the list is incomplete.
    pub fn build(&self) -> Option<String> {
        if self.diagnostics.is_empty() {
            return None;
        }

        let shown = self.diagnostics.len().min(self.max_entries);
        let omitted = self.diagnostics.len() - shown;

        // Header (3) + entries + optional truncation note + footer.
        let mut lines = Vec::with_capacity(shown + 5);
        lines.push("<new-diagnostics>".to_string());
        lines.push("The following new diagnostic issues were detected:".to_string());
        lines.push("\n".to_string());

        lines.extend(self.diagnostics.iter().take(shown).map(|diag| {
            format!(
                "{} {}:{} {} [{}]",
                Self::severity_marker(&diag.severity),
                diag.file,
                diag.line,
                diag.message,
                diag.source
            )
        }));

        // Previously the overflow was dropped silently.
        if omitted > 0 {
            lines.push(format!("... and {} more diagnostic(s) not shown", omitted));
        }

        lines.push("\n</new-diagnostics>".to_string());

        Some(lines.join("\n"))
    }

    /// Check whether any diagnostic has severity `"error"`.
    pub fn has_errors(&self) -> bool {
        self.diagnostics.iter().any(|d| d.severity == "error")
    }

    /// Check whether any diagnostic has severity `"warning"`.
    pub fn has_warnings(&self) -> bool {
        self.diagnostics.iter().any(|d| d.severity == "warning")
    }

    /// Map a severity label to its one-character marker; unknown labels get a bullet.
    fn severity_marker(severity: &str) -> &'static str {
        match severity {
            "error" => "✘",
            "warning" => "⚠",
            "info" => "ℹ",
            _ => "•",
        }
    }
}
363
364// ============================================================================
365// Combined Session Context
366// ============================================================================
367
368/// Combined session start context
369pub struct SessionStartContext {
370    /// Session start hook
371    pub hook: SessionStartHook,
372    /// Todo reminder
373    pub todo: TodoReminder,
374    /// Diagnostics
375    pub diagnostics: DiagnosticsInjection,
376}
377
378impl Default for SessionStartContext {
379    fn default() -> Self {
380        Self {
381            hook: SessionStartHook::new(),
382            todo: TodoReminder::new(),
383            diagnostics: DiagnosticsInjection::new(),
384        }
385    }
386}
387
388impl SessionStartContext {
389    /// Create a new session start context
390    pub fn new() -> Self {
391        Self::default()
392    }
393
394    /// Build all session start content
395    pub fn build(&self) -> String {
396        let mut parts = Vec::new();
397
398        // 1. Session start hook content
399        let hook_content = self.hook.build();
400        if !hook_content.is_empty() {
401            parts.push(format!(
402                "SessionStart hook additional context:\n{}",
403                hook_content
404            ));
405        }
406
407        // 2. Todo reminder
408        if let Some(todo_content) = self.todo.build() {
409            parts.push(todo_content);
410        }
411
412        // 3. Diagnostics
413        if let Some(diag_content) = self.diagnostics.build() {
414            parts.push(diag_content);
415        }
416
417        parts.join("\n\n")
418    }
419
420    /// Check if has any content to inject
421    pub fn has_content(&self) -> bool {
422        !self.hook.build().is_empty()
423            || self.todo.build().is_some()
424            || self.diagnostics.build().is_some()
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a diagnostic whose source is `rustc`.
    fn rustc_diag(file: &str, line: usize, severity: &str, message: &str) -> DiagnosticEntry {
        DiagnosticEntry {
            file: file.to_owned(),
            line,
            severity: severity.to_owned(),
            message: message.to_owned(),
            source: "rustc".to_owned(),
        }
    }

    /// A mandatory skill plus red flags yields both the warning and the table.
    #[test]
    fn test_session_start_hook_builds_content() {
        let rendered = SessionStartHook::new()
            .add_mandatory_skill("code-review")
            .with_red_flags(true)
            .build();

        assert!(rendered.contains("EXTREMELY-IMPORTANT"));
        assert!(rendered.contains("code-review"));
        assert!(rendered.contains("Red Flags"));
    }

    /// Disabling red flags removes the table from the output.
    #[test]
    fn test_session_start_hook_without_red_flags() {
        let rendered = SessionStartHook::new().with_red_flags(false).build();

        assert!(!rendered.contains("Red Flags"));
    }

    /// Pending tasks are listed under the "Pending Tasks" heading.
    #[test]
    fn test_todo_reminder_with_pending_tasks() {
        let tasks = vec![String::from("Task A"), String::from("Task B")];
        let rendered = TodoReminder::new().set_pending_tasks(tasks).build();

        assert!(rendered.is_some());
        let rendered = rendered.unwrap();
        assert!(rendered.contains("Pending Tasks"));
        assert!(rendered.contains("Task A"));
    }

    /// A reminder with no tasks produces nothing.
    #[test]
    fn test_todo_reminder_empty() {
        assert!(TodoReminder::new().build().is_none());
    }

    /// A task stops being remindable once its count reaches the limit.
    #[test]
    fn test_todo_reminder_max_limit() {
        let mut reminder = TodoReminder::new()
            .set_pending_tasks(vec![String::from("Task A")])
            .with_max_reminders(2);

        // The first two reminders are allowed.
        for _ in 0..2 {
            assert!(reminder.should_remind("Task A"));
            reminder.increment_reminder("Task A");
        }

        // The third is not.
        assert!(!reminder.should_remind("Task A"));
    }

    /// An error diagnostic is rendered with the error marker and its message.
    #[test]
    fn test_diagnostics_injection_with_errors() {
        let rendered = DiagnosticsInjection::new()
            .add_diagnostic(rustc_diag("src/main.rs", 42, "error", "missing semicolon"))
            .build();

        assert!(rendered.is_some());
        let rendered = rendered.unwrap();
        assert!(rendered.contains("new-diagnostics"));
        assert!(rendered.contains("✘"));
        assert!(rendered.contains("missing semicolon"));
    }

    /// `has_errors` and `has_warnings` reflect the severities present.
    #[test]
    fn test_diagnostics_has_errors() {
        let injection = DiagnosticsInjection::new()
            .add_diagnostic(rustc_diag("src/main.rs", 42, "error", "error"));

        assert!(injection.has_errors());
        assert!(!injection.has_warnings());
    }

    /// The combined context contains all three sections.
    #[test]
    fn test_session_start_context_combined() {
        let context = SessionStartContext {
            hook: SessionStartHook::new().add_mandatory_skill("test"),
            todo: TodoReminder::new().set_pending_tasks(vec![String::from("Task")]),
            diagnostics: DiagnosticsInjection::new()
                .add_diagnostic(rustc_diag("test.rs", 1, "warning", "test")),
        };

        let rendered = context.build();
        assert!(rendered.contains("SessionStart hook"));
        assert!(rendered.contains("todo-reminder"));
        assert!(rendered.contains("new-diagnostics"));
    }
}