use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use crate::chat::{ChatResponse, ToolCall, ToolResult};
use crate::usage::Usage;
/// Shared approval hook: given a [`ToolCall`], returns the [`ToolApproval`] verdict for it.
///
/// Stored behind `Arc` with `Send + Sync` bounds so one hook can be shared across threads.
pub type ToolApprovalFn = Arc<dyn Fn(&ToolCall) -> ToolApproval + Send + Sync>;
/// Shared observer that receives each [`ToolLoopEvent`] by value.
pub type ToolLoopEventFn = Arc<dyn Fn(ToolLoopEvent) + Send + Sync>;
/// Shared predicate that inspects a borrowed [`StopContext`] and returns a [`StopDecision`].
pub type StopConditionFn = Arc<dyn Fn(&StopContext<'_>) -> StopDecision + Send + Sync>;
/// Read-only view of the tool loop's state handed to a [`StopConditionFn`].
///
/// The reference fields borrow from the loop for `'a`; the counters are plain copies.
#[derive(Debug)]
pub struct StopContext<'a> {
    /// Iteration number at the time of the check (0- or 1-based is decided by the
    /// loop driver, which is not visible here — TODO confirm).
    pub iteration: u32,
    /// Model response being evaluated (presumably the latest one).
    pub response: &'a ChatResponse,
    /// Usage accumulated so far (presumably summed across iterations).
    pub total_usage: &'a Usage,
    /// Count of tool calls executed so far.
    pub tool_calls_executed: usize,
    /// Results from the most recent batch of tool executions (presumably; may be empty).
    pub last_tool_results: &'a [ToolResult],
}
/// Verdict returned by a [`StopConditionFn`].
///
/// NOTE(review): `Stop` / `StopWithReason` presumably map to
/// `TerminationReason::StopCondition { reason: None / Some(..) }` — confirm in the loop driver.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StopDecision {
    /// Keep running the loop.
    Continue,
    /// End the loop without giving a reason.
    Stop,
    /// End the loop and attach a human-readable reason.
    StopWithReason(String),
}
/// Settings for detecting a tool that keeps being called repeatedly.
#[derive(Clone, Debug)]
pub struct LoopDetectionConfig {
    /// Repeat count at which detection fires (default `3`); presumably compared
    /// against the consecutive-call count reported in `ToolLoopEvent::LoopDetected`.
    pub threshold: u32,
    /// Reaction once detection fires (default [`LoopAction::Warn`]).
    pub action: LoopAction,
}

impl Default for LoopDetectionConfig {
    /// A threshold of `3` paired with [`LoopAction::Warn`].
    fn default() -> Self {
        let threshold = 3;
        let action = LoopAction::Warn;
        LoopDetectionConfig { threshold, action }
    }
}

/// How the tool loop reacts when loop detection fires.
///
/// NOTE(review): variant semantics live in the loop driver, not shown here; the
/// descriptions below are inferred from names — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoopAction {
    /// Presumably reports the loop (e.g. via an event) and keeps going.
    Warn,
    /// Presumably ends the loop (cf. `TerminationReason::LoopDetected`).
    Stop,
    /// Presumably inserts a warning into the conversation and keeps going.
    InjectWarning,
}
/// Progress notification delivered to a [`ToolLoopEventFn`].
#[derive(Clone, Debug)]
pub enum ToolLoopEvent {
    /// A loop iteration is beginning.
    IterationStart {
        /// Iteration number.
        iteration: u32,
        /// Number of messages at this point (presumably the conversation length).
        message_count: usize,
    },
    /// A tool call is about to be executed.
    ToolExecutionStart {
        /// Identifier of the tool call.
        call_id: String,
        /// Name of the tool being invoked.
        tool_name: String,
        /// JSON arguments passed to the tool.
        arguments: Value,
    },
    /// A tool call has finished executing.
    ToolExecutionEnd {
        /// Identifier of the tool call.
        call_id: String,
        /// Name of the tool that ran.
        tool_name: String,
        /// Result produced by the tool.
        result: ToolResult,
        /// Time the execution took.
        duration: Duration,
    },
    /// A model response arrived.
    LlmResponseReceived {
        /// Iteration in which the response arrived.
        iteration: u32,
        /// Whether the response requested any tool calls.
        has_tool_calls: bool,
        /// Length of the response text (unit — bytes vs chars — not visible here; TODO confirm).
        text_length: usize,
    },
    /// Loop detection fired for a tool.
    LoopDetected {
        /// Tool that was called repeatedly.
        tool_name: String,
        /// How many consecutive calls were counted.
        consecutive_count: u32,
        /// Configured reaction to the detection.
        action: LoopAction,
    },
}
/// Configuration for a tool-calling loop.
///
/// Holds plain settings plus optional shared callbacks. `Debug` is implemented by hand
/// because the callback fields are `dyn Fn` trait objects.
pub struct ToolLoopConfig {
    /// Upper bound on loop iterations (default `10`); presumably reported as
    /// `TerminationReason::MaxIterations` when reached.
    pub max_iterations: u32,
    /// Whether tool calls may be executed in parallel (default `true`).
    pub parallel_tool_execution: bool,
    /// Optional approval hook consulted for tool calls.
    pub on_tool_call: Option<ToolApprovalFn>,
    /// Optional observer for [`ToolLoopEvent`]s.
    pub on_event: Option<ToolLoopEventFn>,
    /// Optional custom stop predicate.
    pub stop_when: Option<StopConditionFn>,
    /// Optional loop-detection settings; disabled (`None`) by default.
    pub loop_detection: Option<LoopDetectionConfig>,
    /// Optional time limit; presumably reported as `TerminationReason::Timeout` when hit.
    pub timeout: Option<Duration>,
    /// Optional depth limit (default `Some(3)`); what "depth" counts is not visible
    /// here — presumably nesting of tool loops. TODO confirm.
    pub max_depth: Option<u32>,
}
impl Default for ToolLoopConfig {
fn default() -> Self {
Self {
max_iterations: 10,
parallel_tool_execution: true,
on_tool_call: None,
on_event: None,
stop_when: None,
loop_detection: None,
timeout: None,
max_depth: Some(3),
}
}
}
impl std::fmt::Debug for ToolLoopConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The callback fields are `dyn Fn` trait objects, which have no `Debug` impl,
        // so only their presence is reported through the `has_*` booleans.
        let mut out = f.debug_struct("ToolLoopConfig");
        out.field("max_iterations", &self.max_iterations);
        out.field("parallel_tool_execution", &self.parallel_tool_execution);
        out.field("has_on_tool_call", &self.on_tool_call.is_some());
        out.field("has_on_event", &self.on_event.is_some());
        out.field("has_stop_when", &self.stop_when.is_some());
        out.field("loop_detection", &self.loop_detection);
        out.field("timeout", &self.timeout);
        out.field("max_depth", &self.max_depth);
        out.finish()
    }
}
/// Verdict produced by a [`ToolApprovalFn`] for one tool call.
#[derive(Clone, Debug)]
pub enum ToolApproval {
    /// Let the call run as requested.
    Approve,
    /// Refuse the call; the string carries the explanation.
    Deny(String),
    /// Run the call with these JSON arguments instead (presumably replacing the
    /// original arguments wholesale — confirm in the loop driver).
    Modify(Value),
}
/// Outcome of a finished tool loop.
#[derive(Debug)]
pub struct ToolLoopResult {
    /// Final model response (presumably the one received when the loop ended).
    pub response: ChatResponse,
    /// Number of iterations that ran.
    pub iterations: u32,
    /// Usage accumulated over the run (presumably summed across iterations).
    pub total_usage: Usage,
    /// Why the loop stopped.
    pub termination_reason: TerminationReason,
}
/// Why a tool loop stopped running.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TerminationReason {
    /// The loop finished normally (presumably: the model stopped requesting tools).
    Complete,
    /// A stop condition ended the loop, optionally with an explanation.
    StopCondition {
        /// Explanation supplied by the stop condition, if any.
        reason: Option<String>,
    },
    /// The iteration cap was reached.
    MaxIterations {
        /// The cap that was hit.
        limit: u32,
    },
    /// Loop detection ended the run.
    LoopDetected {
        /// Tool that was being called repeatedly.
        tool_name: String,
        /// Repeat count observed.
        count: u32,
    },
    /// The time limit elapsed.
    Timeout {
        /// The limit that was exceeded.
        limit: Duration,
    },
}