1use crate::conversation::message::{Message, ToolRequest};
2use crate::tool_inspection::{InspectionAction, InspectionResult, ToolInspector};
3use anyhow::Result;
4use async_trait::async_trait;
5use rmcp::model::CallToolRequestParam;
6use serde_json::Value;
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11struct InternalToolCall {
12 name: String,
13 parameters: Value,
14}
15
16impl InternalToolCall {
17 fn matches(&self, other: &InternalToolCall) -> bool {
18 self.name == other.name && self.parameters == other.parameters
19 }
20
21 fn from_tool_call(tool_call: &CallToolRequestParam) -> Self {
22 let name = tool_call.name.to_string();
23 let parameters = tool_call
24 .arguments
25 .as_ref()
26 .map(|obj| Value::Object(obj.clone()))
27 .unwrap_or(Value::Null);
28 Self { name, parameters }
29 }
30}
31
32#[derive(Debug)]
33pub struct RepetitionInspector {
34 max_repetitions: Option<u32>,
35 last_call: Option<InternalToolCall>,
36 repeat_count: u32,
37 call_counts: HashMap<String, u32>,
38}
39
40impl RepetitionInspector {
41 pub fn new(max_repetitions: Option<u32>) -> Self {
42 Self {
43 max_repetitions,
44 last_call: None,
45 repeat_count: 0,
46 call_counts: HashMap::new(),
47 }
48 }
49
50 pub fn check_tool_call(&mut self, tool_call: CallToolRequestParam) -> bool {
51 let internal_call = InternalToolCall::from_tool_call(&tool_call);
52 let total_calls = self
53 .call_counts
54 .entry(internal_call.name.clone())
55 .or_insert(0);
56 *total_calls += 1;
57
58 if self.max_repetitions.is_none() {
59 self.last_call = Some(internal_call);
60 self.repeat_count = 1;
61 return true;
62 }
63
64 if let Some(last) = &self.last_call {
65 if last.matches(&internal_call) {
66 self.repeat_count += 1;
67 if self.repeat_count > self.max_repetitions.unwrap() {
68 return false;
69 }
70 } else {
71 self.repeat_count = 1;
72 }
73 } else {
74 self.repeat_count = 1;
75 }
76
77 self.last_call = Some(internal_call);
78 true
79 }
80
81 pub fn reset(&mut self) {
82 self.last_call = None;
83 self.repeat_count = 0;
84 self.call_counts.clear();
85 }
86}
87
88#[async_trait]
89impl ToolInspector for RepetitionInspector {
90 fn name(&self) -> &'static str {
91 "repetition"
92 }
93
94 fn as_any(&self) -> &dyn std::any::Any {
95 self
96 }
97
98 async fn inspect(
99 &self,
100 tool_requests: &[ToolRequest],
101 _messages: &[Message],
102 ) -> Result<Vec<InspectionResult>> {
103 let mut results = Vec::new();
104
105 for tool_request in tool_requests {
107 if let Ok(tool_call) = &tool_request.tool_call {
108 let mut temp_inspector = RepetitionInspector::new(self.max_repetitions);
110 temp_inspector.last_call = self.last_call.clone();
111 temp_inspector.repeat_count = self.repeat_count;
112 temp_inspector.call_counts = self.call_counts.clone();
113
114 if !temp_inspector.check_tool_call(tool_call.clone()) {
115 results.push(InspectionResult {
116 tool_request_id: tool_request.id.clone(),
117 action: InspectionAction::Deny,
118 reason: format!(
119 "Tool '{}' has exceeded maximum repetitions",
120 tool_call.name
121 ),
122 confidence: 1.0,
123 inspector_name: "repetition".to_string(),
124 finding_id: Some("REP-001".to_string()),
125 });
126 }
127 }
128 }
129
130 Ok(results)
131 }
132}