use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use futures::Stream;
use serde_json::Value;
use crate::chat::{ChatResponse, ToolCall, ToolResult};
use crate::error::LlmError;
use crate::usage::Usage;
use super::cacher::ToolResultCacher;
use super::extractor::ToolResultExtractor;
use super::processor::ToolResultProcessor;
/// Shared, thread-safe callback that inspects a [`ToolCall`] and returns a
/// [`ToolApproval`] verdict for it.
///
/// NOTE(review): presumably invoked by the loop before a tool call is executed
/// (it is stored in `ToolLoopConfig::on_tool_call`) — confirm against the loop
/// implementation.
pub type ToolApprovalFn = Arc<dyn Fn(&ToolCall) -> ToolApproval + Send + Sync>;
/// Shared, thread-safe predicate that maps a borrowed [`StopContext`] to a
/// [`StopDecision`].
pub type StopConditionFn = Arc<dyn Fn(&StopContext) -> StopDecision + Send + Sync>;
/// Pinned, boxed, `Send` stream whose items are `Result<LoopEvent, LlmError>`.
pub type LoopStream = Pin<Box<dyn Stream<Item = Result<LoopEvent, LlmError>> + Send>>;
/// Borrowed snapshot passed to a [`StopConditionFn`] so it can return a
/// [`StopDecision`].
///
/// Every reference field shares the lifetime `'a`; the context does not own
/// the response, usage totals, or tool results it points at.
#[derive(Debug)]
pub struct StopContext<'a> {
/// Iteration counter at the time of the check.
/// NOTE(review): whether this is zero- or one-based is not visible here — confirm.
pub iteration: u32,
/// Chat response under evaluation (presumably the most recent model
/// response — confirm).
pub response: &'a ChatResponse,
/// Usage totals (presumably accumulated over all iterations so far — confirm).
pub total_usage: &'a Usage,
/// Count of tool calls executed, as supplied by whoever builds the context.
pub tool_calls_executed: usize,
/// Tool results from the latest round of tool execution (presumably empty
/// when no tools ran — confirm).
pub last_tool_results: &'a [ToolResult],
}
/// Verdict returned by a [`StopConditionFn`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopDecision {
/// Do not stop.
Continue,
/// Stop without supplying a reason.
Stop,
/// Stop and carry a free-form reason string (presumably surfaced through
/// `TerminationReason::StopCondition` — confirm).
StopWithReason(String),
}
/// Settings for loop detection; the `Default` implementation uses a threshold
/// of `3` together with `LoopAction::Warn`.
///
/// NOTE(review): the detection logic itself is not in this file — the field
/// notes below are inferred from names and should be confirmed.
#[derive(Debug, Clone, Copy)]
pub struct LoopDetectionConfig {
/// Presumably the number of consecutive calls to the same tool that counts
/// as a loop (compare `LoopEvent::LoopDetected::consecutive_count`) — confirm.
pub threshold: u32,
/// What to do once a loop has been detected.
pub action: LoopAction,
}
impl Default for LoopDetectionConfig {
    /// Builds the default loop-detection settings: a threshold of `3` paired
    /// with [`LoopAction::Warn`].
    fn default() -> Self {
        let threshold = 3;
        let action = LoopAction::Warn;
        Self { threshold, action }
    }
}
/// Action selected in [`LoopDetectionConfig`] and reported in
/// `LoopEvent::LoopDetected`.
///
/// NOTE(review): the code acting on these variants is not in this file; the
/// per-variant notes are inferred from names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopAction {
/// Presumably only report the loop and keep going.
Warn,
/// Presumably end the loop (compare `TerminationReason::LoopDetected`).
Stop,
/// Presumably add a warning to the conversation and keep going.
InjectWarning,
}
/// Events carried by a [`LoopStream`].
///
/// The enum is `#[non_exhaustive]`, so matches written outside this crate
/// must include a wildcard arm.
///
/// NOTE(review): the code that emits these events is not in this file; the
/// per-variant notes are inferred from names and payload types — confirm
/// against the loop implementation.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum LoopEvent {
/// A chunk of response text.
TextDelta(String),
/// A chunk of reasoning text.
ReasoningDelta(String),
/// A tool call has started; carries its index, id, and tool name.
ToolCallStart {
index: u32,
id: String,
name: String,
},
/// A fragment of JSON (presumably the call's arguments) for the tool call
/// at `index`.
ToolCallDelta {
index: u32,
json_chunk: String,
},
/// The tool call at `index` is complete and available as a [`ToolCall`].
ToolCallComplete {
index: u32,
call: ToolCall,
},
/// A [`Usage`] report.
Usage(Usage),
/// An iteration is starting; `message_count` is presumably the number of
/// messages in the conversation at that point.
IterationStart {
iteration: u32,
message_count: usize,
},
/// Execution of a tool is starting with the given JSON `arguments`.
ToolExecutionStart {
call_id: String,
tool_name: String,
arguments: Value,
},
/// Execution of a tool has ended; carries its [`ToolResult`] and a
/// `duration` (presumably the time the execution took).
ToolExecutionEnd {
call_id: String,
tool_name: String,
result: ToolResult,
duration: Duration,
},
/// Token counts before and after a tool result was processed (presumably
/// by `ToolLoopConfig::result_processor`).
ToolResultProcessed {
tool_name: String,
original_tokens: u32,
processed_tokens: u32,
},
/// Token counts before and after a tool result was extracted (presumably
/// by `ToolLoopConfig::result_extractor`).
ToolResultExtracted {
tool_name: String,
original_tokens: u32,
extracted_tokens: u32,
},
/// Token counts for a cached tool result and its summary (presumably
/// produced by `ToolLoopConfig::result_cacher`).
ToolResultCached {
tool_name: String,
original_tokens: u32,
summary_tokens: u32,
},
/// Observations were masked; reports how many and the tokens saved.
ObservationsMasked {
masked_count: usize,
tokens_saved: u32,
},
/// A loop was detected for `tool_name`, with the count observed and the
/// configured [`LoopAction`].
LoopDetected {
tool_name: String,
consecutive_count: u32,
action: LoopAction,
},
/// Carries the final [`ToolLoopResult`] (presumably the last event of the
/// stream).
Done(ToolLoopResult),
}
/// Configuration values for the tool loop.
///
/// `Clone`, `Default`, and `Debug` are implemented by hand in this file; the
/// `Debug` output reports callbacks, trait objects, and the force-mask set
/// only as `has_*` booleans.
///
/// NOTE(review): the code consuming this configuration is not in this file, so
/// the behavioural notes on individual fields are inferred from names and
/// should be confirmed.
pub struct ToolLoopConfig {
/// Iteration limit; `Default` uses `10` (compare
/// `TerminationReason::MaxIterations`).
pub max_iterations: u32,
/// Presumably whether multiple tool calls may be executed in parallel;
/// `Default` uses `true`.
pub parallel_tool_execution: bool,
/// Optional callback returning a [`ToolApproval`] for a [`ToolCall`];
/// `None` by default.
pub on_tool_call: Option<ToolApprovalFn>,
/// Optional predicate returning a [`StopDecision`] for a [`StopContext`];
/// `None` by default.
pub stop_when: Option<StopConditionFn>,
/// Optional loop-detection settings; `None` by default.
pub loop_detection: Option<LoopDetectionConfig>,
/// Optional time limit (compare `TerminationReason::Timeout`); `None` by
/// default.
pub timeout: Option<Duration>,
/// Optional shared [`ToolResultProcessor`]; `None` by default.
pub result_processor: Option<Arc<dyn ToolResultProcessor>>,
/// Optional shared [`ToolResultExtractor`]; `None` by default.
pub result_extractor: Option<Arc<dyn ToolResultExtractor>>,
/// Optional shared [`ToolResultCacher`]; `None` by default.
pub result_cacher: Option<Arc<dyn ToolResultCacher>>,
/// Optional observation-masking settings; `None` by default.
pub masking: Option<ObservationMaskingConfig>,
/// Optional shared, mutex-guarded set of `u32` values (presumably
/// iteration numbers whose observations must be masked — confirm). Clones
/// of the configuration share the same set. `None` by default.
pub force_mask_iterations: Option<Arc<std::sync::Mutex<std::collections::HashSet<u32>>>>,
/// Optional depth limit (presumably for nested tool loops — confirm);
/// `Default` uses `Some(3)`.
pub max_depth: Option<u32>,
}
/// Settings for observation masking; the `Default` implementation keeps `2`
/// iterations and uses a `500`-token minimum.
///
/// NOTE(review): the masking logic is not in this file — the field notes are
/// inferred from names and should be confirmed.
#[derive(Debug, Clone, Copy)]
pub struct ObservationMaskingConfig {
/// Presumably how many of the most recent iterations keep their tool
/// observations unmasked.
pub max_iterations_to_keep: u32,
/// Presumably the smallest observation size, in tokens, that is worth
/// masking.
pub min_tokens_to_mask: u32,
}
impl Default for ObservationMaskingConfig {
    /// Builds the default masking settings: `max_iterations_to_keep` is `2`
    /// and `min_tokens_to_mask` is `500`.
    fn default() -> Self {
        let (max_iterations_to_keep, min_tokens_to_mask) = (2, 500);
        Self {
            max_iterations_to_keep,
            min_tokens_to_mask,
        }
    }
}
impl Clone for ToolLoopConfig {
    /// Produces a field-by-field copy of the configuration.
    ///
    /// Plain values are copied. Each `Arc`-backed member is cloned through its
    /// `Option`, so the copy shares the same underlying callbacks, trait
    /// objects, and force-mask set as `self` instead of duplicating them.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: a field added to the struct later makes
        // this pattern fail to compile until it is handled here as well.
        let Self {
            max_iterations,
            parallel_tool_execution,
            on_tool_call,
            stop_when,
            loop_detection,
            timeout,
            result_processor,
            result_extractor,
            result_cacher,
            masking,
            force_mask_iterations,
            max_depth,
        } = self;

        Self {
            max_iterations: *max_iterations,
            parallel_tool_execution: *parallel_tool_execution,
            on_tool_call: on_tool_call.clone(),
            stop_when: stop_when.clone(),
            loop_detection: *loop_detection,
            timeout: *timeout,
            result_processor: result_processor.clone(),
            result_extractor: result_extractor.clone(),
            result_cacher: result_cacher.clone(),
            masking: *masking,
            force_mask_iterations: force_mask_iterations.clone(),
            max_depth: *max_depth,
        }
    }
}
impl Default for ToolLoopConfig {
    /// Builds the default configuration: `10` iterations, parallel tool
    /// execution enabled, a maximum depth of `Some(3)`, and every other
    /// optional setting left as `None`.
    fn default() -> Self {
        // The only defaults that carry a value.
        let max_iterations = 10;
        let parallel_tool_execution = true;
        let max_depth = Some(3);

        Self {
            max_iterations,
            parallel_tool_execution,
            max_depth,
            // Everything below starts unset.
            on_tool_call: None,
            stop_when: None,
            loop_detection: None,
            timeout: None,
            result_processor: None,
            result_extractor: None,
            result_cacher: None,
            masking: None,
            force_mask_iterations: None,
        }
    }
}
impl std::fmt::Debug for ToolLoopConfig {
    /// Formats the configuration as a `ToolLoopConfig { .. }` debug struct.
    ///
    /// Callbacks, trait objects, and the force-mask set are rendered as
    /// `has_*` booleans that only report whether the field is set; all other
    /// fields are printed through their own `Debug` implementations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_on_tool_call = self.on_tool_call.is_some();
        let has_stop_when = self.stop_when.is_some();
        let has_result_processor = self.result_processor.is_some();
        let has_result_extractor = self.result_extractor.is_some();
        let has_result_cacher = self.result_cacher.is_some();
        let has_force_mask_iterations = self.force_mask_iterations.is_some();

        let mut out = f.debug_struct("ToolLoopConfig");
        out.field("max_iterations", &self.max_iterations);
        out.field("parallel_tool_execution", &self.parallel_tool_execution);
        out.field("has_on_tool_call", &has_on_tool_call);
        out.field("has_stop_when", &has_stop_when);
        out.field("loop_detection", &self.loop_detection);
        out.field("timeout", &self.timeout);
        out.field("has_result_processor", &has_result_processor);
        out.field("has_result_extractor", &has_result_extractor);
        out.field("has_result_cacher", &has_result_cacher);
        out.field("masking", &self.masking);
        out.field("has_force_mask_iterations", &has_force_mask_iterations);
        out.field("max_depth", &self.max_depth);
        out.finish()
    }
}
/// Verdict returned by a [`ToolApprovalFn`] for a [`ToolCall`].
///
/// NOTE(review): how the loop acts on each verdict is not visible in this
/// file; the payload notes are inferred from names — confirm.
#[derive(Debug, Clone)]
pub enum ToolApproval {
/// Approve the call.
Approve,
/// Deny the call; the string is presumably the reason for the denial.
Deny(String),
/// Modify the call; the JSON value is presumably the replacement arguments.
Modify(Value),
}
/// Result payload of `LoopEvent::Done`.
///
/// NOTE(review): the code that fills this in is not in this file; the field
/// notes are inferred from names and types — confirm.
#[derive(Debug, Clone)]
pub struct ToolLoopResult {
/// Chat response (presumably the final model response of the loop).
pub response: ChatResponse,
/// Iteration count (presumably the number of iterations that ran).
pub iterations: u32,
/// Usage totals (presumably summed over all iterations).
pub total_usage: Usage,
/// Why the loop ended.
pub termination_reason: TerminationReason,
}
/// Why a loop ended, as stored in `ToolLoopResult::termination_reason`.
///
/// NOTE(review): the code selecting a variant is not in this file; the
/// per-variant notes are inferred from names and payloads — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminationReason {
/// Presumably the loop finished on its own (no further tool calls).
Complete,
/// Presumably a stop condition fired; `reason` is optional (compare
/// `StopDecision::Stop` and `StopDecision::StopWithReason`).
StopCondition {
reason: Option<String>,
},
/// An iteration limit was reached; `limit` presumably echoes
/// `ToolLoopConfig::max_iterations`.
MaxIterations {
limit: u32,
},
/// A loop was detected for `tool_name` with the given count.
LoopDetected {
tool_name: String,
count: u32,
},
/// A time limit was reached; `limit` presumably echoes
/// `ToolLoopConfig::timeout`.
Timeout {
limit: Duration,
},
}