use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::StreamExt;
use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider};
use crate::stream::{ChatStream, StreamEvent};
use crate::usage::Usage;
use super::LoopDepth;
use super::ToolError;
use super::ToolRegistry;
use super::approval::approve_calls;
use super::config::{StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopEvent};
use super::execution::execute_with_events;
use super::loop_detection::{LoopDetectionState, handle_loop_detection};
use super::loop_sync::emit_event;
/// Runs a streaming tool-calling loop and returns a single [`ChatStream`]
/// that forwards every provider event across all loop iterations.
///
/// Each iteration opens one provider stream for the current `params`. When a
/// response ends with `StopReason::ToolUse` and at least one completed tool
/// call, the calls go through approval and `registry`. Their results are
/// appended to the conversation and the next iteration begins.
///
/// Tools run with a context derived via `ctx.with_depth(ctx.loop_depth() + 1)`.
/// If `config.max_depth` is set and `ctx.loop_depth()` is already at or above
/// it, the returned stream yields a single `LlmError::MaxDepthExceeded` and
/// nothing else.
///
/// Errors are delivered as stream items. After an error item the stream ends.
// Presumably the lint fires for `ctx`, which is only used to derive the
// nested context; the public signature keeps taking it by value.
#[allow(clippy::needless_pass_by_value)]
pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
    provider: Arc<dyn DynProvider>,
    registry: Arc<ToolRegistry<Ctx>>,
    params: ChatParams,
    config: ToolLoopConfig,
    ctx: Arc<Ctx>,
) -> ChatStream {
    let depth = ctx.loop_depth();

    // Refuse to start a loop that is already nested too deeply.
    if let Some(limit) = config.max_depth.filter(|&limit| depth >= limit) {
        return Box::pin(futures::stream::once(async move {
            Err(LlmError::MaxDepthExceeded {
                current: depth,
                limit,
            })
        }));
    }

    let nested_ctx = Arc::new(ctx.with_depth(depth + 1));
    let initial = ToolLoopStreamState::new(provider, registry, params, config, nested_ctx);

    let driver = futures::stream::unfold(initial, |mut state| async move {
        loop {
            // Take the current phase out, leaving `Done` as a placeholder
            // until the phase handler tells us where to go next.
            let outcome = match std::mem::replace(&mut state.phase, StreamPhase::Done) {
                StreamPhase::Done => return None,
                StreamPhase::StartIteration => phase_start_iteration(&mut state).await,
                StreamPhase::Streaming(inner) => phase_streaming(&mut state, inner).await,
                StreamPhase::ExecutingTools => {
                    PhaseResult::Continue(phase_executing_tools(&mut state).await)
                }
            };
            match outcome {
                PhaseResult::Yield(item, next) => {
                    state.phase = next;
                    return Some((item, state));
                }
                PhaseResult::Continue(next) => state.phase = next,
            }
        }
    });

    Box::pin(driver)
}
/// What a phase handler asks the stream driver in `tool_loop_stream` to do next.
enum PhaseResult {
    /// Produce no stream item; immediately run the given phase.
    Continue(StreamPhase),
    /// Hand this item to the stream consumer, then resume at the given phase.
    Yield(Result<StreamEvent, LlmError>, StreamPhase),
}
async fn phase_start_iteration<Ctx: LoopDepth + Send + Sync + 'static>(
state: &mut ToolLoopStreamState<Ctx>,
) -> PhaseResult {
if let Some(limit) = state.timeout_limit {
if state.start_time.elapsed() >= limit {
let err = LlmError::ToolExecution {
tool_name: String::new(),
source: Box::new(ToolError::new(format!(
"Tool loop exceeded timeout of {limit:?}",
))),
};
return PhaseResult::Yield(Err(err), StreamPhase::Done);
}
}
state.iterations += 1;
let iterations = state.iterations;
let msg_count = state.params.messages.len();
emit_event(&state.config, || ToolLoopEvent::IterationStart {
iteration: iterations,
message_count: msg_count,
});
if state.iterations > state.config.max_iterations {
let err = LlmError::ToolExecution {
tool_name: String::new(),
source: Box::new(ToolError::new(format!(
"Tool loop exceeded {} iterations",
state.config.max_iterations,
))),
};
return PhaseResult::Yield(Err(err), StreamPhase::Done);
}
match state.provider.stream_boxed(&state.params).await {
Ok(s) => {
state.current_tool_calls.clear();
state.current_text.clear();
PhaseResult::Continue(StreamPhase::Streaming(s))
}
Err(e) => PhaseResult::Yield(Err(e), StreamPhase::Done),
}
}
/// Pulls the next event from the current provider stream, folds it into the
/// loop state, and forwards it to the consumer.
///
/// - Text deltas, completed tool calls and usage are accumulated into `state`.
/// - Non-`Done` events are forwarded, and streaming continues.
/// - On `StreamEvent::Done`, `ToolLoopEvent::LlmResponseReceived` is emitted.
///   If `config.stop_when` asks to stop, `Done` is forwarded and the loop ends.
/// - With `StopReason::ToolUse` and at least one completed call, loop
///   detection runs. It either fails the loop with `LlmError::ToolExecution`,
///   or forwards `Done` and schedules tool execution.
/// - Any other `Done` is forwarded. Streaming continues until the provider
///   stream is exhausted.
/// - A provider error is forwarded and ends the loop.
async fn phase_streaming<Ctx: LoopDepth + Send + Sync + 'static>(
    state: &mut ToolLoopStreamState<Ctx>,
    mut stream: ChatStream,
) -> PhaseResult {
    let event = match stream.next().await {
        Some(Ok(event)) => event,
        Some(Err(e)) => return PhaseResult::Yield(Err(e), StreamPhase::Done),
        // Provider stream exhausted: nothing left to forward.
        None => return PhaseResult::Continue(StreamPhase::Done),
    };

    // Accumulate the pieces needed to rebuild the assistant turn later.
    let mut finished_with = None;
    match &event {
        StreamEvent::TextDelta(text) => state.current_text.push_str(text),
        StreamEvent::ToolCallComplete { call, .. } => {
            state.current_tool_calls.push(call.clone());
        }
        StreamEvent::Usage(usage) => state.total_usage += usage,
        StreamEvent::Done { stop_reason } => finished_with = Some(*stop_reason),
        _ => {}
    }

    // Anything other than `Done` is simply forwarded; keep reading.
    let stop_reason = match finished_with {
        Some(reason) => reason,
        None => return PhaseResult::Yield(Ok(event), StreamPhase::Streaming(stream)),
    };

    let iteration = state.iterations;
    let has_tool_calls = !state.current_tool_calls.is_empty();
    let text_length = state.current_text.len();
    emit_event(&state.config, || ToolLoopEvent::LlmResponseReceived {
        iteration,
        has_tool_calls,
        text_length,
    });

    // An optional user predicate may end the loop right after this response.
    if let Some(stop_fn) = state.config.stop_when.as_ref() {
        let snapshot = build_response_from_stream_state(state, stop_reason);
        let stop_ctx = StopContext {
            iteration: state.iterations,
            response: &snapshot,
            total_usage: &state.total_usage,
            tool_calls_executed: state.tool_calls_executed,
            last_tool_results: &state.last_tool_results,
        };
        match stop_fn(&stop_ctx) {
            StopDecision::Continue => {}
            StopDecision::Stop | StopDecision::StopWithReason(_) => {
                return PhaseResult::Yield(Ok(event), StreamPhase::Done);
            }
        }
    }

    // Only a tool-use stop with at least one completed call leads to tools.
    if stop_reason != StopReason::ToolUse || state.current_tool_calls.is_empty() {
        return PhaseResult::Yield(Ok(event), StreamPhase::Streaming(stream));
    }

    let snapshot = build_response_from_stream_state(state, stop_reason);
    let detection = handle_loop_detection(
        &mut state.loop_state,
        &state.current_tool_calls,
        state.config.loop_detection.as_ref(),
        &state.config,
        &mut state.params.messages,
        &snapshot,
        state.iterations,
        &state.total_usage,
    );

    match detection {
        None => PhaseResult::Yield(Ok(event), StreamPhase::ExecutingTools),
        Some(outcome) => {
            let err = match outcome.termination_reason {
                TerminationReason::LoopDetected {
                    ref tool_name,
                    count,
                } => LlmError::ToolExecution {
                    tool_name: tool_name.clone(),
                    source: Box::new(ToolError::new(format!(
                        "Tool loop detected: '{tool_name}' called {count} \
                         consecutive times with identical arguments"
                    ))),
                },
                _ => LlmError::ToolExecution {
                    tool_name: String::new(),
                    source: Box::new(ToolError::new("Unexpected termination")),
                },
            };
            PhaseResult::Yield(Err(err), StreamPhase::Done)
        }
    }
}
/// Snapshots the response streamed so far as a [`ChatResponse`].
///
/// Content is the accumulated text (omitted when empty) followed by a clone of
/// every completed tool call. `usage` is the loop's running total rather than
/// this response's usage alone. `model` is left empty and `metadata` is empty.
fn build_response_from_stream_state<Ctx: LoopDepth + Send + Sync + 'static>(
    state: &ToolLoopStreamState<Ctx>,
    stop_reason: StopReason,
) -> ChatResponse {
    let text_block = if state.current_text.is_empty() {
        None
    } else {
        Some(ContentBlock::Text(state.current_text.clone()))
    };
    let content = text_block
        .into_iter()
        .chain(
            state
                .current_tool_calls
                .iter()
                .cloned()
                .map(ContentBlock::ToolCall),
        )
        .collect();

    ChatResponse {
        content,
        usage: state.total_usage.clone(),
        stop_reason,
        model: String::new(),
        metadata: std::collections::HashMap::new(),
    }
}
/// Runs the tool calls collected from the last response and records the
/// outcome in the conversation.
///
/// `approve_calls` splits the calls into approved and denied sets, and both
/// are handed to `execute_with_events`. The returned results are added to
/// `tool_calls_executed` and copied into `last_tool_results`. The following
/// are then appended to `params.messages`:
/// - an assistant message with the accumulated text (if any) followed by the
///   tool calls;
/// - one tool-result message per result.
///
/// The text and tool-call buffers are left empty. The next phase is always
/// `StreamPhase::StartIteration`.
async fn phase_executing_tools<Ctx: LoopDepth + Send + Sync + 'static>(
    state: &mut ToolLoopStreamState<Ctx>,
) -> StreamPhase {
    let (approved, denied) = approve_calls(&state.current_tool_calls, &state.config);
    let parallel = state.config.parallel_tool_execution;
    let results = execute_with_events(
        &state.registry,
        &approved,
        denied,
        parallel,
        &state.config,
        &state.ctx,
    )
    .await;

    state.tool_calls_executed += results.len();
    state.last_tool_results.clone_from(&results);

    // Rebuild the assistant turn: optional text first, then each tool call.
    let text_block = if state.current_text.is_empty() {
        None
    } else {
        Some(ContentBlock::Text(std::mem::take(&mut state.current_text)))
    };
    let assistant_content: Vec<ContentBlock> = text_block
        .into_iter()
        .chain(state.current_tool_calls.drain(..).map(ContentBlock::ToolCall))
        .collect();

    let messages = &mut state.params.messages;
    messages.push(ChatMessage {
        role: crate::chat::ChatRole::Assistant,
        content: assistant_content,
    });
    for tool_result in results {
        messages.push(ChatMessage::tool_result_full(tool_result));
    }

    StreamPhase::StartIteration
}
/// Mutable state threaded through the `futures::stream::unfold` driver of
/// `tool_loop_stream`.
struct ToolLoopStreamState<Ctx: LoopDepth + Send + Sync + 'static> {
    // --- Inputs, fixed for the lifetime of the loop (except `params.messages`) ---
    /// Provider used to open one response stream per iteration.
    provider: Arc<dyn DynProvider>,
    /// Registry passed to `execute_with_events` when running tools.
    registry: Arc<ToolRegistry<Ctx>>,
    /// Request parameters. `messages` grows as assistant turns and tool
    /// results are appended.
    params: ChatParams,
    /// Loop configuration (limits, hooks, approval, loop detection).
    config: ToolLoopConfig,
    /// Context handed to tool execution.
    ctx: Arc<Ctx>,

    // --- Driver bookkeeping ---
    /// Phase the stream driver runs next.
    phase: StreamPhase,
    /// Number of iterations started so far.
    iterations: u32,
    /// Reference instant for the `timeout_limit` check.
    start_time: Instant,
    /// Copied from `config.timeout`.
    timeout_limit: Option<Duration>,
    /// State carried between iterations for loop detection.
    loop_state: LoopDetectionState,

    // --- Accumulated results ---
    /// Sum of every `StreamEvent::Usage` seen across all iterations.
    total_usage: Usage,
    /// Running total of tool results returned by `execute_with_events`.
    tool_calls_executed: usize,
    /// Results from the most recent tool execution round.
    last_tool_results: Vec<ToolResult>,

    // --- Buffers for the response currently being streamed ---
    /// Tool calls completed in the current response.
    current_tool_calls: Vec<ToolCall>,
    /// Text deltas concatenated for the current response.
    current_text: String,
}
/// The step the tool-loop stream driver performs on its next poll.
enum StreamPhase {
    /// Terminal state: the outer stream yields no further items.
    Done,
    /// Run pre-iteration checks and open a new provider stream.
    StartIteration,
    /// Forward events from this provider response stream.
    Streaming(ChatStream),
    /// Run the collected tool calls and append their results to the conversation.
    ExecutingTools,
}
impl<Ctx: LoopDepth + Send + Sync + 'static> ToolLoopStreamState<Ctx> {
    /// Creates the initial state.
    ///
    /// The loop begins in `StreamPhase::StartIteration` with zero iterations,
    /// zero usage and empty buffers. `start_time` is captured now, and
    /// `timeout_limit` is copied from `config.timeout`.
    fn new(
        provider: Arc<dyn DynProvider>,
        registry: Arc<ToolRegistry<Ctx>>,
        params: ChatParams,
        config: ToolLoopConfig,
        ctx: Arc<Ctx>,
    ) -> Self {
        Self {
            // Struct fields are evaluated in source order, so `config.timeout`
            // is read here before `config` is moved in below.
            timeout_limit: config.timeout,
            start_time: Instant::now(),
            phase: StreamPhase::StartIteration,
            iterations: 0,
            tool_calls_executed: 0,
            total_usage: Usage::default(),
            last_tool_results: Vec::new(),
            current_tool_calls: Vec::new(),
            current_text: String::new(),
            loop_state: LoopDetectionState::default(),
            provider,
            registry,
            params,
            config,
            ctx,
        }
    }
}