use crate::context::{
self, CompactionStrategy, ContextConfig, DefaultCompaction, ExecutionLimits, ExecutionTracker,
};
use crate::provider::{ModelConfig, StreamConfig, StreamEvent, StreamProvider, ToolDefinition};
use crate::types::*;
use std::sync::Arc;
/// Converts the agent-level message history into the provider `Message`s sent
/// with each request.
pub type ConvertToLlmFn = Box<dyn Fn(&[AgentMessage]) -> Vec<Message> + Send + Sync>;
/// Rewrites the agent message list before conversion. It receives an owned
/// copy of the context messages, so the stored context is not changed.
pub type TransformContextFn = Box<dyn Fn(Vec<AgentMessage>) -> Vec<AgentMessage> + Send + Sync>;
/// Returns messages queued by the host. This is used to poll both steering and
/// follow-up messages; an empty vector means "nothing queued".
pub type GetMessagesFn = Box<dyn Fn() -> Vec<AgentMessage> + Send + Sync>;
/// Called before each turn with the current context messages and the
/// zero-based turn number; returning `false` stops the loop.
pub type BeforeTurnFn = Arc<dyn Fn(&[AgentMessage], usize) -> bool + Send + Sync>;
/// Called after each turn with the current context messages and the token
/// usage reported on that turn's assistant message.
pub type AfterTurnFn = Arc<dyn Fn(&[AgentMessage], &Usage) + Send + Sync>;
/// Called with the assistant message's error text (or "Unknown error" when it
/// carries none) when a response ends with `StopReason::Error`.
pub type OnErrorFn = Arc<dyn Fn(&str) + Send + Sync>;
use tokio::sync::mpsc;
use tracing::warn;
/// Configuration shared by `agent_loop` and `agent_loop_continue`.
pub struct AgentLoopConfig {
    /// Backend used to stream each assistant response.
    pub provider: Arc<dyn StreamProvider>,
    /// Model name sent in every `StreamConfig`. It is also stamped on the
    /// placeholder and error assistant messages the loop synthesizes.
    pub model: String,
    /// Credential forwarded verbatim in every `StreamConfig`.
    pub api_key: String,
    /// Forwarded verbatim to the provider in `StreamConfig`.
    pub thinking_level: ThinkingLevel,
    /// Forwarded verbatim to the provider in `StreamConfig`.
    pub max_tokens: Option<u32>,
    /// Forwarded verbatim to the provider in `StreamConfig`.
    pub temperature: Option<f32>,
    /// Forwarded verbatim to the provider in `StreamConfig`.
    pub model_config: Option<ModelConfig>,
    /// Converts agent messages into provider messages. When `None`, messages
    /// whose `as_llm()` is `None` are dropped and the rest are cloned as-is.
    pub convert_to_llm: Option<ConvertToLlmFn>,
    /// Applied to a copy of the context messages before conversion; the
    /// stored context is left unchanged.
    pub transform_context: Option<TransformContextFn>,
    /// Polled for steering messages in several places:
    /// - once when a loop run starts;
    /// - during tool execution: after each call (sequential), after each
    ///   batch (batched), or once after all calls (parallel);
    /// - after each turn that did not end in an error or abort.
    ///
    /// A non-empty result becomes input to the next turn. Under sequential or
    /// batched execution it also causes the remaining tool calls to be skipped.
    pub get_steering_messages: Option<GetMessagesFn>,
    /// Polled when a run of turns would otherwise finish. A non-empty result
    /// starts another run with those messages as input.
    pub get_follow_up_messages: Option<GetMessagesFn>,
    /// When set, the context messages are compacted before every turn.
    pub context_config: Option<ContextConfig>,
    /// Compaction strategy used when `context_config` is set. When `None`,
    /// `DefaultCompaction` is used.
    pub compaction_strategy: Option<Arc<dyn CompactionStrategy>>,
    /// Limits checked before every turn. When one is hit, an
    /// `[Agent stopped: ...]` user message is appended and the loop stops.
    pub execution_limits: Option<ExecutionLimits>,
    /// Forwarded verbatim to the provider in `StreamConfig`.
    pub cache_config: CacheConfig,
    /// How the tool calls of a single assistant message are executed.
    pub tool_execution: ToolExecutionStrategy,
    /// Governs retries of provider errors that report themselves retryable.
    pub retry_config: crate::retry::RetryConfig,
    /// Called before every turn; returning `false` ends the loop.
    pub before_turn: Option<BeforeTurnFn>,
    /// Called at the end of every turn, including turns that ended in an
    /// error or abort.
    pub after_turn: Option<AfterTurnFn>,
    /// Called with the error text when a response ends with `StopReason::Error`.
    pub on_error: Option<OnErrorFn>,
    /// Filters run by `agent_loop` over the text of the new user prompts.
    /// A `Reject` aborts the run; `Warn` texts are appended to the last user
    /// prompt.
    pub input_filters: Vec<Arc<dyn InputFilter>>,
    /// Sleep inserted before every turn except the first of each
    /// `agent_loop` / `agent_loop_continue` call.
    pub turn_delay: Option<std::time::Duration>,
}
/// Default agent-to-provider conversion.
///
/// Keeps only the messages that wrap a provider `Message` (those for which
/// `as_llm()` yields one), cloning each in its original order.
fn default_convert_to_llm(messages: &[AgentMessage]) -> Vec<Message> {
    let mut llm_messages = Vec::new();
    for message in messages {
        if let Some(llm) = message.as_llm() {
            llm_messages.push(llm.clone());
        }
    }
    llm_messages
}
pub async fn agent_loop(
prompts: Vec<AgentMessage>,
context: &mut AgentContext,
config: &AgentLoopConfig,
tx: mpsc::UnboundedSender<AgentEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> Vec<AgentMessage> {
tx.send(AgentEvent::AgentStart).ok();
let prompts = if !config.input_filters.is_empty() {
let user_text: String = prompts
.iter()
.filter_map(|m| {
if let AgentMessage::Llm(Message::User { content, .. }) = m {
Some(
content
.iter()
.filter_map(|c| {
if let Content::Text { text } = c {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n"),
)
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
let mut warnings: Vec<String> = Vec::new();
for filter in &config.input_filters {
match filter.filter(&user_text) {
FilterResult::Pass => {}
FilterResult::Warn(w) => warnings.push(w),
FilterResult::Reject(reason) => {
tx.send(AgentEvent::InputRejected {
reason: reason.clone(),
})
.ok();
tx.send(AgentEvent::AgentEnd { messages: vec![] }).ok();
return vec![];
}
}
}
if !warnings.is_empty() {
let warning_text = warnings
.iter()
.map(|w| format!("[Warning: {}]", w))
.collect::<Vec<_>>()
.join("\n");
let mut modified = prompts;
for msg in modified.iter_mut().rev() {
if let AgentMessage::Llm(Message::User { content, .. }) = msg {
content.push(Content::Text { text: warning_text });
break;
}
}
modified
} else {
prompts
}
} else {
prompts
};
let mut new_messages: Vec<AgentMessage> = prompts.clone();
for prompt in &prompts {
context.messages.push(prompt.clone());
}
tx.send(AgentEvent::TurnStart).ok();
for prompt in &prompts {
tx.send(AgentEvent::MessageStart {
message: prompt.clone(),
})
.ok();
tx.send(AgentEvent::MessageEnd {
message: prompt.clone(),
})
.ok();
}
run_loop(context, &mut new_messages, config, &tx, &cancel).await;
tx.send(AgentEvent::AgentEnd {
messages: new_messages.clone(),
})
.ok();
new_messages
}
/// Resumes the agent loop on an existing context without adding new prompts.
///
/// Emits `AgentStart` and `TurnStart`, runs the turn loop, then emits
/// `AgentEnd` with the messages produced by this call, and returns them.
///
/// # Panics
/// Panics if `context.messages` is empty, or if the last message has the
/// role `"assistant"`.
pub async fn agent_loop_continue(
    context: &mut AgentContext,
    config: &AgentLoopConfig,
    tx: mpsc::UnboundedSender<AgentEvent>,
    cancel: tokio_util::sync::CancellationToken,
) -> Vec<AgentMessage> {
    let Some(last) = context.messages.last() else {
        panic!("Cannot continue: no messages in context");
    };
    assert!(
        last.role() != "assistant",
        "Cannot continue from assistant message"
    );

    for event in [AgentEvent::AgentStart, AgentEvent::TurnStart] {
        tx.send(event).ok();
    }
    let mut produced = Vec::new();
    run_loop(context, &mut produced, config, &tx, &cancel).await;
    tx.send(AgentEvent::AgentEnd {
        messages: produced.clone(),
    })
    .ok();
    produced
}
/// Drives turns until the conversation settles, is cancelled, hits an
/// execution limit, or is stopped by a hook.
///
/// Each inner-loop iteration is one turn:
/// 1. Inject `pending` messages.
/// 2. Check execution limits and `before_turn`.
/// 3. Optionally sleep `turn_delay`.
/// 4. Compact the context.
/// 5. Stream one assistant response.
/// 6. Execute its tool calls.
/// 7. Emit `TurnEnd`.
///
/// The inner loop ends when a turn made no tool calls and no steering
/// messages are queued. The outer loop then polls `get_follow_up_messages`
/// and, if any are returned, starts another run of turns with them.
///
/// Every message appended to `context.messages` is also pushed to
/// `new_messages`. Compaction rewrites `context.messages` but never touches
/// `new_messages`. The caller emits the `TurnStart` for the very first turn.
/// Cancellation, a hit execution limit, or `before_turn` returning `false`
/// return without emitting `TurnEnd`. An `Error`/`Aborted` response emits
/// `TurnEnd` (with no tool results) and then returns.
async fn run_loop(
    context: &mut AgentContext,
    new_messages: &mut Vec<AgentMessage>,
    config: &AgentLoopConfig,
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
) {
    // Suppresses the TurnStart event for the first turn (the caller sent it).
    let mut first_turn = true;
    // Number of turns already started; passed to `before_turn` and used to
    // skip `turn_delay` on the first turn.
    let mut turn_number: usize = 0;
    let mut tracker = config
        .execution_limits
        .as_ref()
        .map(|limits| ExecutionTracker::new(limits.clone()));
    // Messages injected at the start of the next turn (steering/follow-ups).
    let mut pending: Vec<AgentMessage> = config
        .get_steering_messages
        .as_ref()
        .map(|f| f())
        .unwrap_or_default();
    loop {
        if cancel.is_cancelled() {
            return;
        }
        let mut steering_after_tools: Option<Vec<AgentMessage>> = None;
        // Inner loop: one iteration per turn.
        loop {
            if cancel.is_cancelled() {
                return;
            }
            if !first_turn {
                tx.send(AgentEvent::TurnStart).ok();
            } else {
                first_turn = false;
            }
            // Inject queued messages into the context before calling the model.
            if !pending.is_empty() {
                for msg in pending.drain(..) {
                    tx.send(AgentEvent::MessageStart {
                        message: msg.clone(),
                    })
                    .ok();
                    tx.send(AgentEvent::MessageEnd {
                        message: msg.clone(),
                    })
                    .ok();
                    context.messages.push(msg.clone());
                    new_messages.push(msg);
                }
            }
            // When a limit is hit, record why as a user-visible message and stop.
            if let Some(ref tracker) = tracker {
                if let Some(reason) = tracker.check_limits() {
                    warn!("Execution limit reached: {}", reason);
                    let limit_msg = AgentMessage::Llm(Message::User {
                        content: vec![Content::Text {
                            text: format!("[Agent stopped: {}]", reason),
                        }],
                        timestamp: now_ms(),
                    });
                    tx.send(AgentEvent::MessageStart {
                        message: limit_msg.clone(),
                    })
                    .ok();
                    tx.send(AgentEvent::MessageEnd {
                        message: limit_msg.clone(),
                    })
                    .ok();
                    context.messages.push(limit_msg.clone());
                    new_messages.push(limit_msg);
                    return;
                }
            }
            if let Some(ref before_turn) = config.before_turn {
                if !before_turn(&context.messages, turn_number) {
                    return;
                }
            }
            if turn_number > 0 {
                if let Some(delay) = config.turn_delay {
                    tokio::time::sleep(delay).await;
                }
            }
            turn_number += 1;
            if let Some(ref ctx_config) = config.context_config {
                let strategy: &dyn CompactionStrategy = config
                    .compaction_strategy
                    .as_deref()
                    .unwrap_or(&DefaultCompaction);
                context.messages =
                    strategy.compact(std::mem::take(&mut context.messages), ctx_config);
            }
            let message = stream_assistant_response(context, config, tx, cancel).await;
            let agent_msg: AgentMessage = message.clone().into();
            context.messages.push(agent_msg.clone());
            new_messages.push(agent_msg.clone());
            // A failed or aborted response ends the whole run; its tool calls
            // (if any) are not executed.
            if let Message::Assistant {
                ref stop_reason,
                ref error_message,
                ref usage,
                ..
            } = message
            {
                if *stop_reason == StopReason::Error || *stop_reason == StopReason::Aborted {
                    if *stop_reason == StopReason::Error {
                        if let Some(ref on_error) = config.on_error {
                            let err_str = error_message.as_deref().unwrap_or("Unknown error");
                            on_error(err_str);
                        }
                    }
                    if let Some(ref after_turn) = config.after_turn {
                        after_turn(&context.messages, usage);
                    }
                    tx.send(AgentEvent::TurnEnd {
                        message: agent_msg,
                        tool_results: vec![],
                    })
                    .ok();
                    return;
                }
            }
            // (id, name, arguments) for each tool call in the response, in order.
            let tool_calls: Vec<_> = match &message {
                Message::Assistant { content, .. } => content
                    .iter()
                    .filter_map(|c| match c {
                        Content::ToolCall {
                            id,
                            name,
                            arguments,
                            ..
                        } => Some((id.clone(), name.clone(), arguments.clone())),
                        _ => None,
                    })
                    .collect(),
                _ => vec![],
            };
            let has_tool_calls = !tool_calls.is_empty();
            let mut tool_results: Vec<Message> = Vec::new();
            if has_tool_calls {
                let execution = execute_tool_calls(
                    &context.tools,
                    &tool_calls,
                    tx,
                    cancel,
                    config.get_steering_messages.as_ref(),
                    &config.tool_execution,
                )
                .await;
                tool_results = execution.tool_results;
                steering_after_tools = execution.steering_messages;
                for result in &tool_results {
                    let am: AgentMessage = result.clone().into();
                    context.messages.push(am.clone());
                    new_messages.push(am);
                }
            }
            // Prefer provider-reported usage; otherwise fall back to the
            // context module's per-message token estimate.
            if let Some(ref mut tracker) = tracker {
                let turn_tokens = match &message {
                    Message::Assistant { usage, .. } => {
                        (usage.input + usage.output + usage.cache_read + usage.cache_write) as usize
                    }
                    _ => context::message_tokens(&agent_msg),
                };
                tracker.record_turn(turn_tokens);
            }
            if let Some(ref after_turn) = config.after_turn {
                let usage = match &message {
                    Message::Assistant { usage, .. } => usage.clone(),
                    _ => Usage::default(),
                };
                after_turn(&context.messages, &usage);
            }
            tx.send(AgentEvent::TurnEnd {
                message: agent_msg,
                tool_results,
            })
            .ok();
            // Steering seen during tool execution takes priority; otherwise
            // poll the steering source once more at the end of the turn.
            if let Some(steering) = steering_after_tools.take() {
                if !steering.is_empty() {
                    pending = steering;
                    continue;
                }
            }
            pending = config
                .get_steering_messages
                .as_ref()
                .map(|f| f())
                .unwrap_or_default();
            // Tool results always need another model turn to be consumed.
            if !has_tool_calls && pending.is_empty() {
                break;
            }
        }
        // The run settled; check whether the host queued follow-up work.
        let follow_ups = config
            .get_follow_up_messages
            .as_ref()
            .map(|f| f())
            .unwrap_or_default();
        if !follow_ups.is_empty() {
            pending = follow_ups;
            continue;
        }
        break;
    }
}
async fn stream_assistant_response(
context: &AgentContext,
config: &AgentLoopConfig,
tx: &mpsc::UnboundedSender<AgentEvent>,
cancel: &tokio_util::sync::CancellationToken,
) -> Message {
let messages = if let Some(transform) = &config.transform_context {
transform(context.messages.clone())
} else {
context.messages.clone()
};
let convert = config.convert_to_llm.as_ref();
let llm_messages = match convert {
Some(f) => f(&messages),
None => default_convert_to_llm(&messages),
};
let tool_defs: Vec<ToolDefinition> = context
.tools
.iter()
.map(|t| ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
})
.collect();
let retry = &config.retry_config;
let mut attempt = 0;
let result = loop {
let stream_config = StreamConfig {
model: config.model.clone(),
system_prompt: context.system_prompt.clone(),
messages: llm_messages.clone(),
tools: tool_defs.clone(),
thinking_level: config.thinking_level,
api_key: config.api_key.clone(),
max_tokens: config.max_tokens,
temperature: config.temperature,
model_config: config.model_config.clone(),
cache_config: config.cache_config.clone(),
};
let (stream_tx, mut stream_rx) = mpsc::unbounded_channel();
let provider_cancel = cancel.clone();
let event_tx = tx.clone();
let model_for_events = config.model.clone();
let forward_handle = tokio::spawn(async move {
let mut partial_message: Option<AgentMessage> = None;
while let Some(event) = stream_rx.recv().await {
match &event {
StreamEvent::Start => {
let placeholder = AgentMessage::Llm(Message::Assistant {
content: Vec::new(),
stop_reason: StopReason::Stop,
model: model_for_events.clone(),
provider: String::new(),
usage: Usage::default(),
timestamp: now_ms(),
error_message: None,
});
partial_message = Some(placeholder.clone());
event_tx
.send(AgentEvent::MessageStart {
message: placeholder,
})
.ok();
}
StreamEvent::TextDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
event_tx
.send(AgentEvent::MessageUpdate {
message: msg.clone(),
delta: StreamDelta::Text {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::ThinkingDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
event_tx
.send(AgentEvent::MessageUpdate {
message: msg.clone(),
delta: StreamDelta::Thinking {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::ToolCallDelta { delta, .. } => {
if let Some(ref msg) = partial_message {
event_tx
.send(AgentEvent::MessageUpdate {
message: msg.clone(),
delta: StreamDelta::ToolCallDelta {
delta: delta.clone(),
},
})
.ok();
}
}
StreamEvent::Done { message } => {
let am: AgentMessage = message.clone().into();
partial_message = Some(am.clone());
event_tx.send(AgentEvent::MessageEnd { message: am }).ok();
}
StreamEvent::Error { message } => {
let am: AgentMessage = message.clone().into();
if partial_message.is_none() {
event_tx
.send(AgentEvent::MessageStart {
message: am.clone(),
})
.ok();
}
partial_message = Some(am.clone());
event_tx.send(AgentEvent::MessageEnd { message: am }).ok();
}
_ => {}
}
}
});
let result = config
.provider
.stream(stream_config, stream_tx, provider_cancel)
.await;
match &result {
Err(e) if e.is_retryable() && attempt < retry.max_retries && !cancel.is_cancelled() => {
forward_handle.abort();
attempt += 1;
let delay = e
.retry_after()
.unwrap_or_else(|| retry.delay_for_attempt(attempt));
crate::retry::log_retry(attempt, retry.max_retries, &delay, e);
tokio::time::sleep(delay).await;
continue;
}
_ => {
let _ = forward_handle.await;
break result;
}
}
};
match result {
Ok(msg) => msg,
Err(e) => {
warn!("Provider error: {}", e);
Message::Assistant {
content: vec![Content::Text {
text: String::new(),
}],
stop_reason: StopReason::Error,
model: config.model.clone(),
provider: "unknown".into(),
usage: Usage::default(),
timestamp: now_ms(),
error_message: Some(e.to_string()),
}
}
}
}
/// Outcome of executing the tool calls requested by one assistant message.
struct ToolExecutionResult {
    /// One `ToolResult` message per requested call, in request order. This
    /// includes "skipped" results for calls that were not run because
    /// steering messages arrived first.
    tool_results: Vec<Message>,
    /// Steering messages returned by a poll made during or after execution.
    /// `None` when no poll returned a non-empty list.
    steering_messages: Option<Vec<AgentMessage>>,
}
/// Executes `tool_calls` according to `strategy`.
///
/// - `Sequential`: one call at a time. Steering is polled after each call.
/// - `Parallel`: all calls concurrently. Steering is polled once at the end.
/// - `Batched { size }`: chunks of `size` calls run concurrently, one chunk
///   after another. Steering is polled after each chunk.
///
/// Under `Sequential` and `Batched`, a non-empty steering poll stops
/// execution, and every not-yet-run call gets a "skipped" tool result.
/// A batch `size` of 0 is treated as 1, because `slice::chunks` panics on a
/// zero chunk size.
async fn execute_tool_calls(
    tools: &[Box<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
    strategy: &ToolExecutionStrategy,
) -> ToolExecutionResult {
    match strategy {
        ToolExecutionStrategy::Sequential => {
            execute_sequential(tools, tool_calls, tx, cancel, get_steering).await
        }
        ToolExecutionStrategy::Parallel => {
            execute_batch(tools, tool_calls, tx, cancel, get_steering).await
        }
        ToolExecutionStrategy::Batched { size } => {
            // Guard against `chunks(0)`, which panics.
            let batch_size = (*size).max(1);
            let mut results: Vec<Message> = Vec::with_capacity(tool_calls.len());
            let mut steering_messages: Option<Vec<AgentMessage>> = None;
            // Count of calls already run; always <= tool_calls.len().
            let mut executed = 0;
            for batch in tool_calls.chunks(batch_size) {
                // Steering is polled here, between batches, not inside execute_batch.
                let batch_result = execute_batch(tools, batch, tx, cancel, None).await;
                results.extend(batch_result.tool_results);
                executed += batch.len();
                if let Some(get_steering_fn) = get_steering {
                    let steering = get_steering_fn();
                    if !steering.is_empty() {
                        steering_messages = Some(steering);
                        for (skip_id, skip_name, _) in &tool_calls[executed..] {
                            results.push(skip_tool_call(skip_id, skip_name, tx));
                        }
                        break;
                    }
                }
            }
            ToolExecutionResult {
                tool_results: results,
                steering_messages,
            }
        }
    }
}
/// Runs `tool_calls` one after another.
///
/// After each call, the steering source (if any) is polled. The first
/// non-empty poll is returned as the steering messages, and every call not
/// yet run is answered with a "skipped" result instead of being executed.
async fn execute_sequential(
    tools: &[Box<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
) -> ToolExecutionResult {
    let mut tool_results = Vec::new();
    let mut steering_messages = None;
    let mut queue = tool_calls.iter();
    while let Some((id, name, args)) = queue.next() {
        let (message, _) = execute_single_tool(tools, id, name, args, tx, cancel).await;
        tool_results.push(message);
        let Some(poll_steering) = get_steering else {
            continue;
        };
        let queued = poll_steering();
        if queued.is_empty() {
            continue;
        }
        steering_messages = Some(queued);
        // Whatever is still in the queue gets a skip notice instead of running.
        for (skip_id, skip_name, _) in queue.by_ref() {
            tool_results.push(skip_tool_call(skip_id, skip_name, tx));
        }
        break;
    }
    ToolExecutionResult {
        tool_results,
        steering_messages,
    }
}
/// Runs all `tool_calls` concurrently and collects the results in request
/// order.
///
/// Steering is polled once, after every call has finished. An empty poll is
/// reported as `None`.
async fn execute_batch(
    tools: &[Box<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
) -> ToolExecutionResult {
    use futures::future::join_all;
    let runs = tool_calls
        .iter()
        .map(|(id, name, args)| execute_single_tool(tools, id, name, args, tx, cancel));
    let tool_results: Vec<Message> = join_all(runs)
        .await
        .into_iter()
        .map(|(message, _is_error)| message)
        .collect();
    let steering_messages = get_steering
        .map(|poll| poll())
        .filter(|queued| !queued.is_empty());
    ToolExecutionResult {
        tool_results,
        steering_messages,
    }
}
/// Executes one tool call and reports it to `tx`.
///
/// Events are emitted in this order:
/// - `ToolExecutionStart`;
/// - any `ToolExecutionUpdate` / `ProgressMessage` events the tool sends
///   through its callbacks;
/// - `ToolExecutionEnd`;
/// - `MessageStart` and `MessageEnd` for the resulting tool-result message.
///
/// An unknown tool name or a tool error yields an error result whose text is
/// the failure message.
///
/// # Returns
/// The tool-result message and whether it is an error.
async fn execute_single_tool(
    tools: &[Box<dyn AgentTool>],
    id: &str,
    name: &str,
    args: &serde_json::Value,
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
) -> (Message, bool) {
    let matching_tool = tools.iter().find(|candidate| candidate.name() == name);
    let _ = tx.send(AgentEvent::ToolExecutionStart {
        tool_call_id: id.to_string(),
        tool_name: name.to_string(),
        args: args.clone(),
    });

    // Each callback owns its own sender clone and copies of the identifiers.
    let update_tx = tx.clone();
    let (update_id, update_name) = (id.to_string(), name.to_string());
    let on_update: ToolUpdateFn = Arc::new(move |partial: ToolResult| {
        let _ = update_tx.send(AgentEvent::ToolExecutionUpdate {
            tool_call_id: update_id.clone(),
            tool_name: update_name.clone(),
            partial_result: partial,
        });
    });
    let progress_tx = tx.clone();
    let (progress_id, progress_name) = (id.to_string(), name.to_string());
    let on_progress: ProgressFn = Arc::new(move |text: String| {
        let _ = progress_tx.send(AgentEvent::ProgressMessage {
            tool_call_id: progress_id.clone(),
            tool_name: progress_name.clone(),
            text,
        });
    });
    let ctx = ToolContext {
        tool_call_id: id.to_string(),
        tool_name: name.to_string(),
        cancel: cancel.child_token(),
        on_update: Some(on_update),
        on_progress: Some(on_progress),
    };

    let failure = |text: String| ToolResult {
        content: vec![Content::Text { text }],
        details: serde_json::Value::Null,
    };
    let (result, is_error) = match matching_tool {
        None => (failure(format!("Tool {} not found", name)), true),
        Some(tool) => match tool.execute(args.clone(), ctx).await {
            Ok(output) => (output, false),
            Err(err) => (failure(err.to_string()), true),
        },
    };

    let _ = tx.send(AgentEvent::ToolExecutionEnd {
        tool_call_id: id.to_string(),
        tool_name: name.to_string(),
        result: result.clone(),
        is_error,
    });
    let message = Message::ToolResult {
        tool_call_id: id.to_string(),
        tool_name: name.to_string(),
        content: result.content,
        is_error,
        timestamp: now_ms(),
    };
    let _ = tx.send(AgentEvent::MessageStart {
        message: message.clone().into(),
    });
    let _ = tx.send(AgentEvent::MessageEnd {
        message: message.clone().into(),
    });
    (message, is_error)
}
/// Records a tool call as skipped because steering messages arrived first.
///
/// Emits the same event sequence as a real execution: `ToolExecutionStart`,
/// `ToolExecutionEnd`, `MessageStart`, `MessageEnd`. The call is marked
/// `is_error: true`, with null args and details, so listeners still see
/// every requested call resolved.
///
/// # Returns
/// The tool-result message to append to the conversation.
fn skip_tool_call(
    tool_call_id: &str,
    tool_name: &str,
    tx: &mpsc::UnboundedSender<AgentEvent>,
) -> Message {
    let notice = ToolResult {
        content: vec![Content::Text {
            text: String::from("Skipped due to queued user message."),
        }],
        details: serde_json::Value::Null,
    };
    let _ = tx.send(AgentEvent::ToolExecutionStart {
        tool_call_id: tool_call_id.to_owned(),
        tool_name: tool_name.to_owned(),
        args: serde_json::Value::Null,
    });
    let _ = tx.send(AgentEvent::ToolExecutionEnd {
        tool_call_id: tool_call_id.to_owned(),
        tool_name: tool_name.to_owned(),
        result: notice.clone(),
        is_error: true,
    });
    let skipped = Message::ToolResult {
        tool_call_id: tool_call_id.to_owned(),
        tool_name: tool_name.to_owned(),
        content: notice.content,
        is_error: true,
        timestamp: now_ms(),
    };
    for event in [
        AgentEvent::MessageStart {
            message: skipped.clone().into(),
        },
        AgentEvent::MessageEnd {
            message: skipped.clone().into(),
        },
    ] {
        let _ = tx.send(event);
    }
    skipped
}