use lellm_core::{ChatRequest, ChatResponse, LlmError, Message, ToolCall, ToolError};
use lellm_provider::ResolvedModel;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use super::LoopState;
use super::context::{ContextBudget, estimate_text};
use super::event::AgentEvent;
use super::fallback::{FallbackAction, FallbackContext, FallbackStrategy};
use super::retry::RetryPolicy;
use super::runtime::ResolvedRound;
use super::tools::{ToolExecutor, ToolRegistration, ToolSnapshot};
/// Tool registrations keyed by tool name; this is the shared map handed to spawned tool tasks.
type ToolMap = indexmap::IndexMap<String, ToolRegistration>;
/// Runs `op` until it succeeds, the error is not retryable, or the fallback
/// strategy decides to abort.
///
/// # Parameters
/// - `fallback`: strategy consulted after every retryable failure.
/// - `can_retry`: predicate deciding whether an error is worth handing to the
///   fallback strategy at all.
/// - `op`: factory producing a fresh future for each attempt.
/// - `iteration`: forwarded to the strategy as `FallbackContext::iterations`.
/// - `messages`: conversation so far; copied into the context on each failure.
///
/// # Errors
/// Returns the last error from `op` when `can_retry` rejects it or when the
/// strategy answers [`FallbackAction::Abort`].
pub async fn execute_with_fallback<T, F, Fut>(
    fallback: &Arc<dyn FallbackStrategy>,
    can_retry: impl Fn(&LlmError) -> bool,
    mut op: F,
    iteration: usize,
    messages: &[Message],
) -> Result<T, LlmError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, LlmError>>,
{
    // Attempts are counted from one so the first failure reports `attempt = 1`.
    let mut attempt: usize = 1;
    loop {
        let failure = match op().await {
            Ok(value) => return Ok(value),
            Err(failure) => failure,
        };

        // Non-retryable errors bypass the fallback strategy entirely.
        if !can_retry(&failure) {
            return Err(failure);
        }

        tracing::warn!(
            attempt = attempt,
            error = %failure,
            "provider operation failed, fallback handling"
        );

        let ctx = FallbackContext {
            error: &failure,
            attempt,
            iterations: iteration,
            conversation: messages.to_vec().into(),
        };
        match fallback.handle(&ctx).await {
            FallbackAction::Retry => {
                attempt += 1;
            }
            FallbackAction::Abort => {
                return Err(failure);
            }
        }
    }
}
/// Sends `event` on the agent event channel.
///
/// Returns `true` when the event was accepted and `false` when the send
/// failed; callers in this module treat `false` as a cancellation signal.
pub async fn emit(tx: &Sender<AgentEvent>, event: AgentEvent) -> bool {
    match tx.send(event).await {
        Ok(()) => true,
        Err(_) => false,
    }
}
pub(super) async fn emit_and_execute_tools_with(
tx: &Sender<AgentEvent>,
snapshot: &ToolSnapshot,
retry_policy: &RetryPolicy,
tool_calls: &[ToolCall],
) -> Option<Vec<Message>> {
if tool_calls.is_empty() {
return Some(Vec::new());
}
for tc in tool_calls {
if !emit(
tx,
AgentEvent::ToolStart {
tool_call_id: tc.id.clone(),
name: tc.name.clone(),
},
)
.await
{
return None;
}
}
let mut safe_calls: Vec<(usize, ToolCall)> = Vec::new();
let mut category_calls: std::collections::HashMap<
super::tools::ToolCategory,
Vec<(usize, ToolCall)>,
> = std::collections::HashMap::new();
let mut exclusive_calls: Vec<(usize, ToolCall)> = Vec::new();
for (idx, call) in tool_calls.iter().enumerate() {
let safety = snapshot
.get(&call.name)
.map(|t| t.safety.clone())
.unwrap_or(super::tools::ParallelSafety::Exclusive);
match safety {
super::tools::ParallelSafety::Safe => safe_calls.push((idx, call.clone())),
super::tools::ParallelSafety::CategoryExclusive => {
if let Some(cat) = snapshot.get(&call.name).and_then(|t| t.category.clone()) {
category_calls
.entry(cat)
.or_default()
.push((idx, call.clone()));
} else {
exclusive_calls.push((idx, call.clone()));
}
}
super::tools::ParallelSafety::Exclusive => exclusive_calls.push((idx, call.clone())),
}
}
let snapshot_arc: Arc<ToolMap> = snapshot.clone_for_spawn();
let retry_policy = retry_policy.clone();
let tx = tx.clone();
let mut group_handles: Vec<tokio::task::JoinHandle<Vec<(usize, Message)>>> = Vec::new();
if !safe_calls.is_empty() {
let s = Arc::clone(&snapshot_arc);
let rp = retry_policy.clone();
let tx_clone = tx.clone();
group_handles.push(tokio::spawn(async move {
let handles: Vec<_> = safe_calls
.iter()
.map(|(idx, call)| {
let tools = Arc::clone(&s);
let rp = rp.clone();
let call = call.clone();
let idx = *idx;
let tx = tx_clone.clone();
tokio::spawn(async move {
let result = match tools.get(&call.name) {
Some(entry) => {
rp.execute_with_retry(&entry.func, &call.arguments).await
}
None => {
Err(ToolError::not_found(format!("unknown tool: {}", call.name)))
}
};
let _ = emit(
&tx,
AgentEvent::ToolEnd {
tool_call_id: call.id.clone(),
result: result.clone(),
},
)
.await;
(idx, Message::tool_result(&call, &result))
})
})
.collect();
let raw = futures_util::future::join_all(handles).await;
raw.into_iter()
.map(|h| match h {
Ok((idx, msg)) => (idx, msg),
Err(join_err) => {
panic!("tool task panicked: {join_err}");
}
})
.collect()
}));
}
for group_calls in category_calls.into_values() {
let s = Arc::clone(&snapshot_arc);
let rp = retry_policy.clone();
let tx_clone = tx.clone();
group_handles.push(tokio::spawn(async move {
let mut results = Vec::with_capacity(group_calls.len());
for (idx, call) in group_calls {
let result = match s.get(&call.name) {
Some(entry) => rp.execute_with_retry(&entry.func, &call.arguments).await,
None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
};
let _ = emit(
&tx_clone,
AgentEvent::ToolEnd {
tool_call_id: call.id.clone(),
result: result.clone(),
},
)
.await;
results.push((idx, Message::tool_result(&call, &result)));
}
results
}));
}
if !exclusive_calls.is_empty() {
let s = Arc::clone(&snapshot_arc);
let rp = retry_policy.clone();
let tx_clone = tx.clone();
group_handles.push(tokio::spawn(async move {
let mut results = Vec::with_capacity(exclusive_calls.len());
for (idx, call) in exclusive_calls {
let result = match s.get(&call.name) {
Some(entry) => rp.execute_with_retry(&entry.func, &call.arguments).await,
None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
};
let _ = emit(
&tx_clone,
AgentEvent::ToolEnd {
tool_call_id: call.id.clone(),
result: result.clone(),
},
)
.await;
results.push((idx, Message::tool_result(&call, &result)));
}
results
}));
}
let mut results: Vec<Option<Message>> = vec![None; tool_calls.len()];
let all_handles = futures_util::future::join_all(group_handles).await;
for handle_result in all_handles {
if let Ok(indexed_messages) = handle_result {
for (idx, msg) in indexed_messages {
results[idx] = Some(msg);
}
}
}
Some(results.into_iter().flatten().collect())
}
/// Assembles a [`ChatResponse`] from whatever has been buffered so far.
///
/// The content holds, in order, a thinking block (only when `thinking_buffer`
/// is non-empty; `redacted_buffer` is attached to it) and a text block (only
/// when `text_buffer` is non-empty). If both are absent, a single empty text
/// block is used so the response never has empty content. Token usage is the
/// default value and the raw payload is JSON `null`.
pub fn build_partial_response(
    text_buffer: String,
    thinking_buffer: String,
    redacted_buffer: Option<String>,
) -> ChatResponse {
    let thinking_block = (!thinking_buffer.is_empty()).then(|| {
        lellm_core::ContentBlock::Thinking(lellm_core::ThinkingBlock {
            thinking: thinking_buffer,
            redacted: redacted_buffer,
        })
    });
    let text_block = (!text_buffer.is_empty()).then(|| {
        lellm_core::ContentBlock::Text(lellm_core::TextBlock {
            text: text_buffer,
            cache_control: None,
        })
    });

    let mut content: Vec<lellm_core::ContentBlock> =
        thinking_block.into_iter().chain(text_block).collect();
    if content.is_empty() {
        // Placeholder so the response always carries at least one block.
        content.push(lellm_core::ContentBlock::Text(lellm_core::TextBlock {
            text: String::new(),
            cache_control: None,
        }));
    }

    ChatResponse::new(
        content,
        lellm_core::TokenUsage::default(),
        serde_json::json!(null),
    )
}
/// Outcome of processing one streamed provider round.
#[must_use]
pub(super) enum StreamIterResult {
/// The response contained tool calls and they were executed; `response` is the assistant turn that requested them.
Continue { response: ChatResponse },
/// The response contained no tool calls and the `LoopEnd` event was delivered.
Complete { response: ChatResponse },
/// An event could not be delivered to the receiver; `response` is `None` when this happened before the response was assembled.
Cancelled { response: Option<ChatResponse> },
/// The estimated output tokens of this round exceeded the limit; `response` holds the content buffered so far.
OutputBudgetExceeded { response: ChatResponse },
/// The estimated reasoning tokens of this round exceeded the limit; `response` holds the content buffered so far.
ReasoningBudgetExceeded { response: ChatResponse },
}
async fn process_stream_iteration(
tx: &Sender<AgentEvent>,
executor: &ToolExecutor,
state: &mut LoopState,
stream: &mut lellm_provider::ProviderStream,
text_buffer: &mut String,
thinking_buffer: &mut String,
redacted_buffer: &mut Option<String>,
budget: &ContextBudget,
max_output_tokens: u32,
max_reasoning_tokens: Option<u32>,
stream_thinking: bool,
round: ResolvedRound,
) -> Result<StreamIterResult, LlmError> {
use futures_util::StreamExt;
let mut round_output_tokens: usize = 0;
let mut round_reasoning_tokens: usize = 0;
while let Some(result) = stream.next().await {
let ev = match result {
Ok(ev) => ev,
Err(e) => return Err(e),
};
match &ev {
lellm_provider::ProviderEvent::Token { token } => {
round_output_tokens += estimate_text(token);
if (round_output_tokens as u32) > max_output_tokens {
tracing::warn!(
round_output_tokens,
max_output_tokens,
"single-round output budget exceeded, cutting stream"
);
let response = build_partial_response(
text_buffer.clone(),
thinking_buffer.clone(),
redacted_buffer.clone(),
);
return Ok(StreamIterResult::OutputBudgetExceeded { response });
}
text_buffer.push_str(token);
}
lellm_provider::ProviderEvent::ThinkingDelta { thinking, redacted } => {
round_reasoning_tokens += estimate_text(thinking)
+ redacted.as_ref().map(|r| estimate_text(r)).unwrap_or(0);
if let Some(limit) = max_reasoning_tokens {
if (round_reasoning_tokens as u32) > limit {
tracing::warn!(
round_reasoning_tokens,
max_reasoning_tokens = limit,
"single-round reasoning budget exceeded, cutting stream"
);
let response = build_partial_response(
text_buffer.clone(),
thinking_buffer.clone(),
redacted_buffer.clone(),
);
return Ok(StreamIterResult::ReasoningBudgetExceeded { response });
}
}
thinking_buffer.push_str(thinking);
if let Some(r) = redacted {
if let Some(ref mut prev) = *redacted_buffer {
prev.push_str(r);
} else {
*redacted_buffer = Some(r.clone());
}
}
}
lellm_provider::ProviderEvent::Start { .. }
| lellm_provider::ProviderEvent::ResponseComplete { .. } => {}
}
if matches!(&ev, lellm_provider::ProviderEvent::ThinkingDelta { .. }) && !stream_thinking {
} else if !emit(tx, AgentEvent::Provider(ev.clone())).await {
return Ok(StreamIterResult::Cancelled { response: None });
}
if let lellm_provider::ProviderEvent::ResponseComplete { tool_calls, usage } = ev {
let pending_tool_calls = tool_calls;
let usage_val = usage.unwrap_or_default();
let mut content: Vec<lellm_core::ContentBlock> = Vec::new();
if !thinking_buffer.is_empty() {
content.push(lellm_core::ContentBlock::Thinking(
lellm_core::ThinkingBlock {
thinking: thinking_buffer.clone(),
redacted: redacted_buffer.clone(),
},
));
}
if !text_buffer.is_empty() {
content.push(lellm_core::ContentBlock::Text(lellm_core::TextBlock {
text: text_buffer.clone(),
cache_control: None,
}));
}
content.extend(
pending_tool_calls
.iter()
.map(|tc| lellm_core::ContentBlock::ToolCall(tc.clone())),
);
let response = ChatResponse::new(content, usage_val, serde_json::json!(null));
if !pending_tool_calls.is_empty() {
state.push_assistant(response.content.clone());
state.add_output_from_content(&response.content);
state.add_tool_calls(pending_tool_calls.len());
let results = emit_and_execute_tools_with(
tx,
&round.snapshot,
&executor.retry_policy(),
&pending_tool_calls,
)
.await;
if results.is_none() {
return Ok(StreamIterResult::Cancelled {
response: Some(response),
});
}
state.push_tool_results(results.unwrap(), budget);
tracing::debug!(
iteration = state.iterations,
tool_calls = pending_tool_calls.len(),
"tool-use stream iteration"
);
return Ok(StreamIterResult::Continue { response });
} else {
state.add_output_from_content(&response.content);
if !emit(
tx,
AgentEvent::LoopEnd {
result: state.finish_complete(response.clone()),
},
)
.await
{
return Ok(StreamIterResult::Cancelled {
response: Some(response),
});
}
return Ok(StreamIterResult::Complete { response });
}
}
}
Err(LlmError::UnexpectedEof)
}
/// Return value of `do_stream_iteration`: the round's outcome plus whether the provider stream was opened.
pub(super) struct StreamIterationResult {
/// On success, the iteration outcome paired with the loop state after the round; otherwise the error.
pub(super) result: Result<(StreamIterResult, LoopState), LlmError>,
/// `false` when opening the provider stream failed, `true` once a stream was obtained.
pub(super) stream_started: bool,
}
/// Opens a provider stream for `req` and processes it as one loop iteration.
///
/// If the stream cannot be opened, the error is returned with
/// `stream_started = false`. Otherwise the stream is handed to
/// `process_stream_iteration` with fresh, empty buffers and the result is
/// returned with `stream_started = true`; on success the outcome is paired
/// with the loop state as it stands after the round.
///
/// The reasoning-token limit is read from `req.max_reasoning_tokens`.
pub(super) async fn do_stream_iteration(
    model: ResolvedModel,
    tx: Sender<AgentEvent>,
    executor: ToolExecutor,
    state: LoopState,
    req: ChatRequest,
    budget: ContextBudget,
    max_output_tokens: u32,
    stream_thinking: bool,
    round: ResolvedRound,
) -> StreamIterationResult {
    let reasoning_limit = req.max_reasoning_tokens;

    let opened = model.provider.stream(&req).await;
    let mut provider_stream = match opened {
        Err(open_err) => {
            return StreamIterationResult {
                result: Err(open_err),
                stream_started: false,
            };
        }
        Ok(stream) => stream,
    };

    // Per-round accumulators; they start empty for every iteration.
    let mut text = String::new();
    let mut thinking = String::new();
    let mut redacted: Option<String> = None;
    let mut loop_state = state;

    let outcome = process_stream_iteration(
        &tx,
        &executor,
        &mut loop_state,
        &mut provider_stream,
        &mut text,
        &mut thinking,
        &mut redacted,
        &budget,
        max_output_tokens,
        reasoning_limit,
        stream_thinking,
        round,
    )
    .await;

    let result = match outcome {
        Ok(iter_result) => Ok((iter_result, loop_state)),
        Err(err) => Err(err),
    };
    StreamIterationResult {
        result,
        stream_started: true,
    }
}