// aster/tool_monitor.rs

use crate::conversation::message::{Message, ToolRequest};
use crate::tool_inspection::{InspectionAction, InspectionResult, ToolInspector};
use anyhow::Result;
use async_trait::async_trait;
use rmcp::model::CallToolRequestParam;
use serde_json::Value;
use std::collections::HashMap;

9// Helper struct for internal tracking
10#[derive(Debug, Clone)]
11struct InternalToolCall {
12    name: String,
13    parameters: Value,
14}
15
16impl InternalToolCall {
17    fn matches(&self, other: &InternalToolCall) -> bool {
18        self.name == other.name && self.parameters == other.parameters
19    }
20
21    fn from_tool_call(tool_call: &CallToolRequestParam) -> Self {
22        let name = tool_call.name.to_string();
23        let parameters = tool_call
24            .arguments
25            .as_ref()
26            .map(|obj| Value::Object(obj.clone()))
27            .unwrap_or(Value::Null);
28        Self { name, parameters }
29    }
30}
31
32#[derive(Debug)]
33pub struct RepetitionInspector {
34    max_repetitions: Option<u32>,
35    last_call: Option<InternalToolCall>,
36    repeat_count: u32,
37    call_counts: HashMap<String, u32>,
38}
39
40impl RepetitionInspector {
41    pub fn new(max_repetitions: Option<u32>) -> Self {
42        Self {
43            max_repetitions,
44            last_call: None,
45            repeat_count: 0,
46            call_counts: HashMap::new(),
47        }
48    }
49
50    pub fn check_tool_call(&mut self, tool_call: CallToolRequestParam) -> bool {
51        let internal_call = InternalToolCall::from_tool_call(&tool_call);
52        let total_calls = self
53            .call_counts
54            .entry(internal_call.name.clone())
55            .or_insert(0);
56        *total_calls += 1;
57
58        if self.max_repetitions.is_none() {
59            self.last_call = Some(internal_call);
60            self.repeat_count = 1;
61            return true;
62        }
63
64        if let Some(last) = &self.last_call {
65            if last.matches(&internal_call) {
66                self.repeat_count += 1;
67                if self.repeat_count > self.max_repetitions.unwrap() {
68                    return false;
69                }
70            } else {
71                self.repeat_count = 1;
72            }
73        } else {
74            self.repeat_count = 1;
75        }
76
77        self.last_call = Some(internal_call);
78        true
79    }
80
81    pub fn reset(&mut self) {
82        self.last_call = None;
83        self.repeat_count = 0;
84        self.call_counts.clear();
85    }
86}
87
88#[async_trait]
89impl ToolInspector for RepetitionInspector {
90    fn name(&self) -> &'static str {
91        "repetition"
92    }
93
94    fn as_any(&self) -> &dyn std::any::Any {
95        self
96    }
97
98    async fn inspect(
99        &self,
100        tool_requests: &[ToolRequest],
101        _messages: &[Message],
102    ) -> Result<Vec<InspectionResult>> {
103        let mut results = Vec::new();
104
105        // Check repetition limits for each tool request
106        for tool_request in tool_requests {
107            if let Ok(tool_call) = &tool_request.tool_call {
108                // Create a temporary clone to check without modifying state
109                let mut temp_inspector = RepetitionInspector::new(self.max_repetitions);
110                temp_inspector.last_call = self.last_call.clone();
111                temp_inspector.repeat_count = self.repeat_count;
112                temp_inspector.call_counts = self.call_counts.clone();
113
114                if !temp_inspector.check_tool_call(tool_call.clone()) {
115                    results.push(InspectionResult {
116                        tool_request_id: tool_request.id.clone(),
117                        action: InspectionAction::Deny,
118                        reason: format!(
119                            "Tool '{}' has exceeded maximum repetitions",
120                            tool_call.name
121                        ),
122                        confidence: 1.0,
123                        inspector_name: "repetition".to_string(),
124                        finding_id: Some("REP-001".to_string()),
125                    });
126                }
127            }
128        }
129
130        Ok(results)
131    }
132}