use serde_json::Value;
use crate::chat::{ChatMessage, ChatResponse, ToolCall};
use crate::usage::Usage;
use super::config::{
LoopAction, LoopDetectionConfig, LoopEvent, TerminationReason, ToolLoopConfig, ToolLoopResult,
};
/// Borrowed view of one tool-loop iteration, consumed by `handle_loop_detection`.
pub(crate) struct IterationSnapshot<'a> {
/// Chat response for this iteration; cloned into the result when the loop is stopped.
pub response: &'a ChatResponse,
/// Tool calls handed to loop detection for this iteration.
pub call_refs: &'a [&'a ToolCall],
/// Iteration counter, copied into the result when the loop is stopped.
pub iterations: u32,
/// Usage total, cloned into the result when the loop is stopped.
pub total_usage: &'a Usage,
/// Loop configuration; its `loop_detection` setting decides whether detection runs.
pub config: &'a ToolLoopConfig,
}
/// Running state used to spot the same set of tool calls being issued back to back.
#[derive(Debug, Default)]
pub(crate) struct LoopDetectionState {
/// Hash of the most recently observed call set; `None` before the first update or after a reset.
last_hash: Option<u64>,
/// Display name recorded when the current run of identical call sets started.
last_tool_name: String,
/// How many consecutive updates have produced `last_hash`.
consecutive_count: u32,
}
impl LoopDetectionState {
    /// Test-only convenience wrapper around [`Self::update_refs`] that accepts a slice of
    /// owned calls instead of a slice of references.
    #[cfg(test)]
    pub(crate) fn update(&mut self, calls: &[ToolCall], threshold: u32) -> Option<(String, u32)> {
        let refs: Vec<&ToolCall> = calls.iter().collect();
        self.update_refs(&refs, threshold)
    }

    /// Records one iteration's tool calls and reports whether loop detection fires.
    ///
    /// The call set is reduced to a hash. When it matches the previous update the
    /// consecutive counter is incremented; otherwise a new run is started with a count
    /// of 1 and nothing is reported.
    ///
    /// Returns `Some((tool_name, consecutive_count))` whenever the counter reaches a
    /// positive multiple of `threshold` (for example 3, 6, 9, … with `threshold == 3`),
    /// and `None` otherwise.
    ///
    /// A `threshold` of 0 never fires. The state is still tracked, but no report is
    /// produced, because a zero threshold has no meaningful multiple and evaluating
    /// `count % 0` would panic.
    pub(crate) fn update_refs(
        &mut self,
        calls: &[&ToolCall],
        threshold: u32,
    ) -> Option<(String, u32)> {
        let (hash, tool_name) = compute_tool_calls_hash(calls);

        // A different call set starts a fresh run; the first occurrence never fires.
        if self.last_hash != Some(hash) {
            self.last_hash = Some(hash);
            self.last_tool_name = tool_name;
            self.consecutive_count = 1;
            return None;
        }

        self.consecutive_count += 1;

        // The zero check must come first: it short-circuits the remainder below, which
        // would otherwise panic with a division by zero.
        let fires = threshold != 0
            && self.consecutive_count >= threshold
            && self.consecutive_count % threshold == 0;
        fires.then(|| (self.last_tool_name.clone(), self.consecutive_count))
    }

    /// Test-only helper that returns the state to its initial, empty condition.
    #[cfg(test)]
    pub(crate) fn reset(&mut self) {
        self.last_hash = None;
        self.last_tool_name.clear();
        self.consecutive_count = 0;
    }
}
/// Reduces a set of tool calls to a `(hash, display_name)` pair.
///
/// The hash covers every call's name and JSON arguments, in call order. The display
/// name is the single call's name, or all names joined with `+` when there are several.
/// An empty slice yields the fixed pair `(0, "")`.
fn compute_tool_calls_hash(calls: &[&ToolCall]) -> (u64, String) {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let Some((first, rest)) = calls.split_first() else {
        return (0, String::new());
    };

    let mut state = DefaultHasher::new();
    calls.iter().for_each(|call| {
        call.name.hash(&mut state);
        hash_json_value(&call.arguments, &mut state);
    });

    // Build "a+b+c" incrementally; with a single call this is just its name.
    let mut label = first.name.clone();
    for call in rest {
        label.push('+');
        label.push_str(&call.name);
    }

    (state.finish(), label)
}
/// Feeds a JSON value into `hasher` in a structure-aware way.
///
/// Every value is prefixed with a one-byte tag identifying its kind, followed by its
/// payload. Arrays and objects also hash their length before their contents. Object
/// entries are visited in sorted key order, so the result does not depend on the
/// order in which the map stores its keys.
fn hash_json_value<H: std::hash::Hasher>(value: &Value, hasher: &mut H) {
    use std::hash::Hash;

    // Kind tag first, so values of different kinds feed distinct prefixes.
    let tag: u8 = match value {
        Value::Null => 0,
        Value::Bool(_) => 1,
        Value::Number(_) => 2,
        Value::String(_) => 3,
        Value::Array(_) => 4,
        Value::Object(_) => 5,
    };
    tag.hash(hasher);

    match value {
        Value::Null => {}
        Value::Bool(flag) => flag.hash(hasher),
        // Numbers are hashed through their textual form.
        Value::Number(number) => number.to_string().hash(hasher),
        Value::String(text) => text.hash(hasher),
        Value::Array(items) => {
            items.len().hash(hasher);
            for element in items {
                hash_json_value(element, hasher);
            }
        }
        Value::Object(map) => {
            map.len().hash(hasher);
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by(|left, right| left.0.cmp(right.0));
            for (key, child) in entries {
                key.hash(hasher);
                hash_json_value(child, hasher);
            }
        }
    }
}
/// Test-only textual signature of a set of tool calls.
///
/// Returns `(names, arguments)`: the call names joined with `+` and each call's
/// JSON-serialized arguments joined with `|`. An argument value that fails to
/// serialize contributes an empty string. An empty slice yields two empty strings,
/// and a single call yields its name and serialized arguments with no separators.
#[cfg(test)]
pub(crate) fn compute_tool_calls_signature(calls: &[ToolCall]) -> (String, String) {
    let names: Vec<&str> = calls.iter().map(|call| call.name.as_str()).collect();
    let serialized: Vec<String> = calls
        .iter()
        .map(|call| serde_json::to_string(&call.arguments).unwrap_or_default())
        .collect();

    // Joining zero or one element adds no separator, so the empty and single-call
    // cases fall out of the general path.
    (names.join("+"), serialized.join("|"))
}
/// Outcome of a single loop-detection check, as produced by `check_loop_detection_refs`.
pub(crate) enum LoopCheckResult {
/// Nothing to act on: detection is disabled, did not fire, or is configured to only warn.
Continue,
/// Detection fired with the `Stop` action; carries the repeated tool name and repeat count.
Stop { tool_name: String, count: u32 },
/// Detection fired with the `InjectWarning` action; carries the repeated tool name and repeat count.
InjectWarning { tool_name: String, count: u32 },
}
/// Updates the detection state with this iteration's calls and decides what to do.
///
/// When `config` is `None` the state is left untouched and `Continue` is returned.
/// Otherwise the state is updated with `calls`. If detection fires, a
/// `LoopEvent::LoopDetected` is appended to `events` and the configured action is
/// mapped to a result: `Warn` continues, while `Stop` and `InjectWarning` are
/// reported to the caller together with the tool name and repeat count.
pub(crate) fn check_loop_detection_refs(
    state: &mut LoopDetectionState,
    calls: &[&ToolCall],
    config: Option<&LoopDetectionConfig>,
    events: &mut Vec<LoopEvent>,
) -> LoopCheckResult {
    let Some(detection) = config else {
        return LoopCheckResult::Continue;
    };
    let Some((tool_name, count)) = state.update_refs(calls, detection.threshold) else {
        return LoopCheckResult::Continue;
    };

    // The event is recorded for every action, including `Warn`.
    let action = detection.action;
    events.push(LoopEvent::LoopDetected {
        tool_name: tool_name.clone(),
        consecutive_count: count,
        action,
    });

    match action {
        LoopAction::Warn => LoopCheckResult::Continue,
        LoopAction::Stop => LoopCheckResult::Stop { tool_name, count },
        LoopAction::InjectWarning => LoopCheckResult::InjectWarning { tool_name, count },
    }
}
/// Builds the system message that tells the model it is repeating itself.
///
/// `tool_name` and `count` are interpolated into a fixed warning text.
pub(crate) fn create_loop_warning_message(tool_name: &str, count: u32) -> ChatMessage {
    let warning = format!(
        "Warning: You have called the tool '{tool_name}' with identical arguments {count} times in a row. \
         This appears to be a loop. Please try a different approach or tool."
    );
    ChatMessage::system(warning)
}
/// Runs loop detection for the current iteration and applies its outcome.
///
/// Returns `Some(ToolLoopResult)` with `TerminationReason::LoopDetected` when the
/// check asks to stop; the result is filled from the snapshot's response, iteration
/// count and usage. When the check asks for a warning, a system message is appended
/// to `messages` and `None` is returned. In all other cases `None` is returned and
/// `messages` is left unchanged. Any detection event is recorded in `events`.
pub(crate) fn handle_loop_detection(
    state: &mut LoopDetectionState,
    snap: &IterationSnapshot<'_>,
    messages: &mut Vec<ChatMessage>,
    events: &mut Vec<LoopEvent>,
) -> Option<ToolLoopResult> {
    let detection_config = snap.config.loop_detection.as_ref();
    let outcome = check_loop_detection_refs(state, snap.call_refs, detection_config, events);

    match outcome {
        LoopCheckResult::Continue => None,
        LoopCheckResult::InjectWarning { tool_name, count } => {
            let warning = create_loop_warning_message(&tool_name, count);
            messages.push(warning);
            None
        }
        LoopCheckResult::Stop { tool_name, count } => {
            let result = ToolLoopResult {
                response: snap.response.clone(),
                iterations: snap.iterations,
                total_usage: snap.total_usage.clone(),
                termination_reason: TerminationReason::LoopDetected { tool_name, count },
            };
            Some(result)
        }
    }
}