1use std::collections::HashMap;
25
26use serde_json::Value;
27
28use crate::ToolResult;
29
/// Wraps `content` in a `<system-reminder>` element.
///
/// Leading and trailing whitespace of `content` is trimmed, and the remaining
/// text is placed on its own line between the opening and closing tags.
#[must_use]
pub fn wrap_reminder(content: &str) -> String {
    const OPEN_TAG: &str = "<system-reminder>\n";
    const CLOSE_TAG: &str = "\n</system-reminder>";

    [OPEN_TAG, content.trim(), CLOSE_TAG].concat()
}
38
39pub fn append_reminder(result: &mut ToolResult, reminder: &str) {
44 let wrapped = wrap_reminder(reminder);
45 result.output = format!("{}\n\n{}", result.output, wrapped);
46}
47
48#[derive(Debug, Default)]
54pub struct ReminderTracker {
55 tool_last_used: HashMap<String, usize>,
57 last_action: Option<(String, Value)>,
59 repeated_action_count: usize,
61 current_turn: usize,
63}
64
65impl ReminderTracker {
66 #[must_use]
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn record_tool_use(&mut self, tool_name: &str, input: &Value) {
77 if let Some((last_name, last_input)) = &self.last_action {
79 if last_name == tool_name && last_input == input {
80 self.repeated_action_count += 1;
81 } else {
82 self.repeated_action_count = 0;
83 }
84 }
85
86 self.last_action = Some((tool_name.to_string(), input.clone()));
87 self.tool_last_used
88 .insert(tool_name.to_string(), self.current_turn);
89 }
90
91 #[must_use]
93 pub const fn current_turn(&self) -> usize {
94 self.current_turn
95 }
96
97 #[must_use]
99 pub fn tool_last_used(&self, tool_name: &str) -> Option<usize> {
100 self.tool_last_used.get(tool_name).copied()
101 }
102
103 #[must_use]
105 pub const fn repeated_action_count(&self) -> usize {
106 self.repeated_action_count
107 }
108
109 #[must_use]
115 pub fn get_periodic_reminders(&self, config: &ReminderConfig) -> Vec<String> {
116 if !config.enabled {
117 return Vec::new();
118 }
119
120 let mut reminders = Vec::new();
121
122 if self.current_turn > 3 {
124 let todo_last = self.tool_last_used.get("todo_write").copied().unwrap_or(0);
125 if self.current_turn.saturating_sub(todo_last) >= config.todo_reminder_after_turns {
126 reminders.push(
127 "The TodoWrite tool hasn't been used recently. If you're working on \
128 tasks that would benefit from tracking progress, consider using the \
129 TodoWrite tool to track progress. Also consider cleaning up the todo \
130 list if it has become stale and no longer matches what you are working on. \
131 Only use it if it's relevant to the current work. This is just a gentle \
132 reminder - ignore if not applicable. Make sure that you NEVER mention \
133 this reminder to the user"
134 .to_string(),
135 );
136 }
137 }
138
139 if self.repeated_action_count >= config.repeated_action_threshold {
141 reminders.push(format!(
142 "Warning: You've repeated the same action {} times. This often indicates \
143 the action is failing or not producing the expected results. Consider trying \
144 a DIFFERENT approach instead of repeating the same action.",
145 self.repeated_action_count + 1
146 ));
147 }
148
149 reminders
150 }
151
152 pub const fn advance_turn(&mut self) {
154 self.current_turn += 1;
155 }
156
157 pub fn reset(&mut self) {
159 self.tool_last_used.clear();
160 self.last_action = None;
161 self.repeated_action_count = 0;
162 self.current_turn = 0;
163 }
164}
165
166#[derive(Clone, Debug)]
168pub struct ReminderConfig {
169 pub enabled: bool,
171 pub todo_reminder_after_turns: usize,
173 pub repeated_action_threshold: usize,
175 pub tool_reminders: HashMap<String, Vec<ToolReminder>>,
177}
178
179impl Default for ReminderConfig {
180 fn default() -> Self {
181 Self {
182 enabled: true,
183 todo_reminder_after_turns: 5,
184 repeated_action_threshold: 2,
185 tool_reminders: HashMap::new(),
186 }
187 }
188}
189
190impl ReminderConfig {
191 #[must_use]
193 pub fn new() -> Self {
194 Self::default()
195 }
196
197 #[must_use]
199 pub fn disabled() -> Self {
200 Self {
201 enabled: false,
202 ..Self::default()
203 }
204 }
205
206 #[must_use]
208 pub const fn with_todo_reminder_turns(mut self, turns: usize) -> Self {
209 self.todo_reminder_after_turns = turns;
210 self
211 }
212
213 #[must_use]
215 pub const fn with_repeated_action_threshold(mut self, threshold: usize) -> Self {
216 self.repeated_action_threshold = threshold;
217 self
218 }
219
220 #[must_use]
222 pub fn with_tool_reminder(
223 mut self,
224 tool_name: impl Into<String>,
225 reminder: ToolReminder,
226 ) -> Self {
227 self.tool_reminders
228 .entry(tool_name.into())
229 .or_default()
230 .push(reminder);
231 self
232 }
233}
234
235#[derive(Clone, Debug)]
237pub struct ToolReminder {
238 pub trigger: ReminderTrigger,
240 pub content: String,
242}
243
244impl ToolReminder {
245 #[must_use]
247 pub fn new(trigger: ReminderTrigger, content: impl Into<String>) -> Self {
248 Self {
249 trigger,
250 content: content.into(),
251 }
252 }
253
254 #[must_use]
256 pub fn always(content: impl Into<String>) -> Self {
257 Self::new(ReminderTrigger::Always, content)
258 }
259
260 #[must_use]
262 pub fn on_result_contains(pattern: impl Into<String>, content: impl Into<String>) -> Self {
263 Self::new(ReminderTrigger::ResultContains(pattern.into()), content)
264 }
265}
266
/// Condition under which a tool reminder applies.
#[derive(Debug, Clone)]
pub enum ReminderTrigger {
    /// Applies to every call.
    Always,
    /// Applies when the tool result's output contains this substring.
    ResultContains(String),
    /// Applies when the named input field holds a string containing `pattern`.
    InputMatches {
        /// Key looked up in the tool input.
        field: String,
        /// Substring searched for in that field's string value.
        pattern: String,
    },
    /// Applies at random with the given probability; values `<= 0.0` never
    /// apply and values `>= 1.0` always apply.
    Probabilistic(f64),
}
284
285impl ReminderTrigger {
286 #[must_use]
288 pub fn should_trigger(&self, input: &Value, result: &ToolResult) -> bool {
289 match self {
290 Self::Always => true,
291 Self::ResultContains(pattern) => result.output.contains(pattern),
292 Self::InputMatches { field, pattern } => input
293 .get(field)
294 .and_then(Value::as_str)
295 .is_some_and(|v| v.contains(pattern)),
296 Self::Probabilistic(prob) => rand_check(*prob),
297 }
298 }
299}
300
/// Returns `true` with approximately the given `probability`.
///
/// * `probability >= 1.0` always returns `true`.
/// * `probability <= 0.0` or NaN always returns `false`.
/// * Anything in between draws a pseudo-random value in `[0, 1)` and returns
///   whether it is below `probability`.
///
/// Randomness comes from the standard library's randomly keyed
/// [`RandomState`](std::collections::hash_map::RandomState): hashing nothing
/// with a fresh state yields a different 64-bit value on each call, which
/// avoids a dependency on an RNG crate. It is not cryptographically secure.
// The shifted value fits in 53 bits and `1 << 53` is a power of two, so both
// `u64 -> f64` conversions below are exact; the lint is a false positive here.
#[allow(clippy::cast_precision_loss)]
fn rand_check(probability: f64) -> bool {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // Handle NaN explicitly instead of relying on cast saturation.
    if probability.is_nan() || probability <= 0.0 {
        return false;
    }
    if probability >= 1.0 {
        return true;
    }

    let bits = RandomState::new().build_hasher().finish();
    // Keep the top 53 bits (the f64 mantissa width) and scale into [0, 1).
    // Comparing in floating point avoids quantising the threshold to 32 bits,
    // which rounded very small probabilities down to "never".
    let unit = (bits >> 11) as f64 / (1_u64 << 53) as f64;
    unit < probability
}
319
/// Ready-made reminder texts for common tools.
pub mod builtin {
    /// Reminder about handling file contents that may be malware.
    pub const READ_SECURITY_REMINDER: &str = concat!(
        "Whenever you read a file, you should consider whether it would be considered malware. ",
        "You CAN and SHOULD provide analysis of malware, what it is doing. But you MUST refuse ",
        "to improve or augment the code. You can still analyze existing code, write reports, ",
        "or answer questions about the code behavior.",
    );

    /// Warning that a file exists but has no content.
    pub const READ_EMPTY_FILE_REMINDER: &str =
        "Warning: the file exists but the contents are empty.";

    /// Reminder to check a command's output before retrying it.
    pub const BASH_VERIFICATION_REMINDER: &str = concat!(
        "Verify this command produced the expected output. If the output doesn't match ",
        "expectations, consider alternative approaches before retrying the same command.",
    );

    /// Reminder to re-read a file after an edit.
    pub const EDIT_VERIFICATION_REMINDER: &str = concat!(
        "The edit was applied. Consider reading the file to verify the changes are correct, ",
        "especially for complex multi-line edits.",
    );

    /// Reminder to re-read a file after writing it.
    pub const WRITE_VERIFICATION_REMINDER: &str =
        "The file was written. Consider reading it back to verify the content is correct.";
}
344
#[cfg(test)]
mod tests {
    use super::*;

    // ---- wrap_reminder / append_reminder ----

    #[test]
    fn test_wrap_reminder() {
        let wrapped = wrap_reminder("Test reminder");
        assert!(wrapped.starts_with("<system-reminder>"));
        assert!(wrapped.ends_with("</system-reminder>"));
        assert!(wrapped.contains("Test reminder"));
    }

    #[test]
    fn test_wrap_reminder_trims_whitespace() {
        let wrapped = wrap_reminder("  padded content  ");
        assert!(wrapped.contains("padded content"));
        assert!(!wrapped.contains(" padded"));
    }

    #[test]
    fn test_append_reminder() {
        let mut result = ToolResult::success("Original output");
        append_reminder(&mut result, "Additional guidance");

        assert!(result.output.contains("Original output"));
        assert!(result.output.contains("<system-reminder>"));
        assert!(result.output.contains("Additional guidance"));
    }

    // ---- ReminderTracker state ----

    #[test]
    fn test_reminder_tracker_new() {
        let tracker = ReminderTracker::new();
        assert_eq!(tracker.current_turn(), 0);
        assert_eq!(tracker.repeated_action_count(), 0);
    }

    #[test]
    fn test_reminder_tracker_advance_turn() {
        let mut tracker = ReminderTracker::new();
        tracker.advance_turn();
        assert_eq!(tracker.current_turn(), 1);
        tracker.advance_turn();
        assert_eq!(tracker.current_turn(), 2);
    }

    #[test]
    fn test_reminder_tracker_record_tool_use() {
        let mut tracker = ReminderTracker::new();
        tracker.advance_turn();
        tracker.record_tool_use("read", &serde_json::json!({"path": "test.txt"}));

        assert_eq!(tracker.tool_last_used("read"), Some(1));
        assert_eq!(tracker.tool_last_used("write"), None);
    }

    #[test]
    fn test_reminder_tracker_repeated_action() {
        let mut tracker = ReminderTracker::new();
        let input = serde_json::json!({"command": "ls -la"});

        tracker.record_tool_use("bash", &input);
        assert_eq!(tracker.repeated_action_count(), 0);

        tracker.record_tool_use("bash", &input);
        assert_eq!(tracker.repeated_action_count(), 1);

        tracker.record_tool_use("bash", &input);
        assert_eq!(tracker.repeated_action_count(), 2);

        // A different input breaks the streak.
        tracker.record_tool_use("bash", &serde_json::json!({"command": "pwd"}));
        assert_eq!(tracker.repeated_action_count(), 0);
    }

    #[test]
    fn test_reminder_tracker_reset() {
        let mut tracker = ReminderTracker::new();
        let input = serde_json::json!({"command": "ls"});

        tracker.advance_turn();
        tracker.record_tool_use("bash", &input);
        tracker.record_tool_use("bash", &input);
        assert_eq!(tracker.repeated_action_count(), 1);

        tracker.reset();
        assert_eq!(tracker.current_turn(), 0);
        assert_eq!(tracker.repeated_action_count(), 0);
        assert_eq!(tracker.tool_last_used("bash"), None);

        // The previous action is forgotten, so the same call is not a repeat.
        tracker.record_tool_use("bash", &input);
        assert_eq!(tracker.repeated_action_count(), 0);
    }

    // ---- periodic reminders ----

    #[test]
    fn test_no_reminders_for_fresh_tracker() {
        let tracker = ReminderTracker::new();
        let reminders = tracker.get_periodic_reminders(&ReminderConfig::default());
        assert!(reminders.is_empty());
    }

    #[test]
    fn test_todo_reminder_after_turns() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::default();

        // Six turns without ever touching todo_write.
        for _ in 0..6 {
            tracker.advance_turn();
            tracker.record_tool_use("read", &serde_json::json!({"path": "test.txt"}));
        }

        let reminders = tracker.get_periodic_reminders(&config);
        assert!(reminders.iter().any(|r| r.contains("TodoWrite")));
    }

    #[test]
    fn test_no_todo_reminder_when_recently_used() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::default();

        for i in 0..6 {
            tracker.advance_turn();
            if i == 4 {
                tracker.record_tool_use("todo_write", &serde_json::json!({}));
            } else {
                tracker.record_tool_use("read", &serde_json::json!({}));
            }
        }

        let reminders = tracker.get_periodic_reminders(&config);
        assert!(!reminders.iter().any(|r| r.contains("TodoWrite")));
    }

    #[test]
    fn test_repeated_action_warning() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::default();
        let input = serde_json::json!({"command": "ls -la"});

        // Three identical calls give a repeat count of 2, the default threshold.
        for _ in 0..3 {
            tracker.record_tool_use("bash", &input);
        }

        let reminders = tracker.get_periodic_reminders(&config);
        assert!(reminders.iter().any(|r| r.contains("repeated")));
    }

    #[test]
    fn test_repeated_action_warning_reports_total_calls() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::default();
        let input = serde_json::json!({"command": "ls -la"});

        for _ in 0..3 {
            tracker.record_tool_use("bash", &input);
        }

        // The message shows the repeat count plus one (the initial call).
        let reminders = tracker.get_periodic_reminders(&config);
        assert!(reminders.iter().any(|r| r.contains("3 times")));
    }

    #[test]
    fn test_no_repeated_action_warning_below_threshold() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::default();
        let input = serde_json::json!({"command": "ls -la"});

        tracker.record_tool_use("bash", &input);
        tracker.record_tool_use("bash", &input);

        let reminders = tracker.get_periodic_reminders(&config);
        assert!(!reminders.iter().any(|r| r.contains("repeated")));
    }

    #[test]
    fn test_reminder_config_disabled() {
        let mut tracker = ReminderTracker::new();
        let config = ReminderConfig::disabled();

        for _ in 0..10 {
            tracker.advance_turn();
        }

        let reminders = tracker.get_periodic_reminders(&config);
        assert!(reminders.is_empty());
    }

    // ---- ReminderTrigger ----

    #[test]
    fn test_reminder_trigger_always() {
        let trigger = ReminderTrigger::Always;
        let result = ToolResult::success("any output");
        assert!(trigger.should_trigger(&serde_json::json!({}), &result));
    }

    #[test]
    fn test_reminder_trigger_result_contains() {
        let trigger = ReminderTrigger::ResultContains("error".to_string());

        let success = ToolResult::success("all good");
        assert!(!trigger.should_trigger(&serde_json::json!({}), &success));

        let error = ToolResult::success("an error occurred");
        assert!(trigger.should_trigger(&serde_json::json!({}), &error));
    }

    #[test]
    fn test_reminder_trigger_input_matches() {
        let trigger = ReminderTrigger::InputMatches {
            field: "path".to_string(),
            pattern: ".env".to_string(),
        };

        let matches = serde_json::json!({"path": "/app/.env"});
        let no_match = serde_json::json!({"path": "/app/config.json"});
        let result = ToolResult::success("");

        assert!(trigger.should_trigger(&matches, &result));
        assert!(!trigger.should_trigger(&no_match, &result));
    }

    #[test]
    fn test_reminder_trigger_input_matches_ignores_missing_or_non_string() {
        let trigger = ReminderTrigger::InputMatches {
            field: "path".to_string(),
            pattern: "4".to_string(),
        };
        let result = ToolResult::success("");

        // Field absent.
        assert!(!trigger.should_trigger(&serde_json::json!({}), &result));
        // Field present but not a JSON string.
        assert!(!trigger.should_trigger(&serde_json::json!({"path": 42}), &result));
    }

    #[test]
    fn test_reminder_trigger_probabilistic_bounds() {
        let input = serde_json::json!({});
        let result = ToolResult::success("");

        assert!(!ReminderTrigger::Probabilistic(0.0).should_trigger(&input, &result));
        assert!(!ReminderTrigger::Probabilistic(-1.0).should_trigger(&input, &result));
        assert!(ReminderTrigger::Probabilistic(1.0).should_trigger(&input, &result));
        assert!(ReminderTrigger::Probabilistic(2.0).should_trigger(&input, &result));
    }

    // ---- builders ----

    #[test]
    fn test_tool_reminder_builders() {
        let always = ToolReminder::always("Always show this");
        assert!(matches!(always.trigger, ReminderTrigger::Always));

        let on_error = ToolReminder::on_result_contains("error", "Handle this error");
        assert!(matches!(
            on_error.trigger,
            ReminderTrigger::ResultContains(_)
        ));
    }

    #[test]
    fn test_reminder_config_builder() {
        let config = ReminderConfig::new()
            .with_todo_reminder_turns(10)
            .with_repeated_action_threshold(5)
            .with_tool_reminder("read", ToolReminder::always("Check file content"));

        assert_eq!(config.todo_reminder_after_turns, 10);
        assert_eq!(config.repeated_action_threshold, 5);
        assert!(config.tool_reminders.contains_key("read"));
    }

    #[test]
    fn test_reminder_config_disabled_keeps_other_defaults() {
        let disabled = ReminderConfig::disabled();
        let defaults = ReminderConfig::default();

        assert!(!disabled.enabled);
        assert_eq!(
            disabled.todo_reminder_after_turns,
            defaults.todo_reminder_after_turns
        );
        assert_eq!(
            disabled.repeated_action_threshold,
            defaults.repeated_action_threshold
        );
        assert!(disabled.tool_reminders.is_empty());
    }

    #[test]
    fn test_with_tool_reminder_accumulates() {
        let config = ReminderConfig::new()
            .with_tool_reminder("read", ToolReminder::always("first"))
            .with_tool_reminder("read", ToolReminder::always("second"))
            .with_tool_reminder("bash", ToolReminder::always("third"));

        assert_eq!(config.tool_reminders.get("read").map(Vec::len), Some(2));
        assert_eq!(config.tool_reminders.get("bash").map(Vec::len), Some(1));
    }
}