use crate::agent::extension::{Cancel, Extension, ToolOutput};
use crate::agent::provider::{Provider, StopReason, StreamEvent, ToolDef};
use crate::agent::types::{AgentMessage, PendingMessageQueue, Role, ToolCall, ToolExecutionMode};
use futures::future::join_all;
/// Builds the tool definitions advertised to the model from every extension's tools.
///
/// Extensions (and their tools) are visited in order. When several tools share a
/// name, the first one encountered wins and later duplicates are skipped.
pub fn collect_tool_defs(extensions: &[Box<dyn Extension>]) -> Vec<ToolDef> {
    let mut advertised: Vec<ToolDef> = Vec::new();
    for tool in extensions.iter().flat_map(|ext| ext.tools()) {
        let tool_name = tool.name();
        let already_listed = advertised.iter().any(|def| def.name == tool_name);
        if already_listed {
            continue;
        }
        advertised.push(ToolDef {
            name: tool_name.to_string(),
            description: tool.description().to_string(),
            parameters: tool.parameters(),
        });
    }
    advertised
}
/// Notifications produced by [`run_agent_loop`] through its `emit` callback.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because some variants are never constructed
// (e.g. `ToolCallArgsUpdate` is not emitted by `run_agent_loop`) — confirm
// before removing.
#[allow(dead_code)]
pub enum AgentEvent {
    /// The run has started; emitted before anything else.
    AgentStart,
    /// A turn has started.
    TurnStart,
    /// A chunk of assistant text.
    TextDelta { delta: String },
    /// A chunk of reasoning text reported by the provider.
    ThinkingDelta { delta: String },
    /// A tool call is about to be dispatched (emitted before extension checks run).
    ToolCall {
        id: String,
        name: String,
        args: serde_json::Value,
    },
    /// Updated arguments for an already announced tool call; not emitted by `run_agent_loop`.
    ToolCallArgsUpdate { id: String, args: serde_json::Value },
    /// Final result of a tool call. `compact` is the optional compact form
    /// supplied by the tool.
    ToolResult {
        id: String,
        name: String,
        content: String,
        compact: Option<String>,
        is_error: bool,
    },
    /// Intermediate output forwarded from a running tool.
    ToolProgress { content: String, is_error: bool },
    /// The run is stopping early; `reason` is a human-readable explanation.
    Aborted { reason: String },
    /// A queued user message (steering or follow-up) was injected into the conversation.
    UserMessage { content: String },
    /// The current turn has finished.
    TurnEnd,
    /// The run has finished; `messages` is the list of messages it produced.
    AgentEnd { messages: Vec<AgentMessage> },
}
/// Rewrites the context right before it is sent to the model. The result is
/// used only for that one provider call; the stored context is not modified.
pub type TransformFn =
    Box<dyn Fn(&[AgentMessage]) -> Vec<AgentMessage> + Send + Sync>;

/// Receives the messages produced so far in the run and may return a
/// [`TurnUpdate`] describing how to reshape the model context.
pub type PrepareNextTurnFn =
    Box<dyn Fn(&[AgentMessage]) -> Option<TurnUpdate> + Send + Sync>;

/// Receives the messages produced so far after a text-only model reply;
/// returning `true` ends the run.
pub type ShouldStopFn =
    Box<dyn Fn(&[AgentMessage]) -> bool + Send + Sync>;

/// Changes requested by a [`PrepareNextTurnFn`].
pub struct TurnUpdate {
    /// When `Some`, replaces the whole model context used for subsequent calls.
    pub context: Option<Vec<AgentMessage>>,
}
/// Everything [`run_agent_loop`] needs besides the provider and the prompts.
pub struct LoopConfig<'a> {
    /// Model identifier passed to every `Provider::stream` call.
    pub model: String,
    /// System prompt passed to every `Provider::stream` call.
    pub system_prompt: String,
    /// Tool definitions advertised to the model.
    pub tools: Vec<ToolDef>,
    /// Executable tools, looked up by name when the model calls one.
    pub agent_tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
    /// Hooks consulted before each tool call (may block it) and after a
    /// successful call (may replace its content).
    pub extensions: &'a [Box<dyn Extension>],
    /// Default execution mode for a batch of tool calls. A batch containing any
    /// tool whose own mode is `Sequential` is always run sequentially.
    pub tool_execution: ToolExecutionMode,
    /// Messages drained into the conversation before each model call.
    pub steering_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
    /// Messages drained after a turn finishes; if any arrive, another turn runs.
    pub follow_up_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
    /// Optional rewrite of the context just before it is sent to the model.
    pub transform_context: Option<TransformFn>,
    /// Optional hook that may replace the model context between calls.
    pub prepare_next_turn: Option<PrepareNextTurnFn>,
    /// Optional predicate checked after a text-only reply to end the run early.
    pub should_stop_after_turn: Option<ShouldStopFn>,
}
/// Returns the first tool in `tools` whose name is exactly `name`, if any.
fn find_tool<'a>(
    tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
    name: &str,
) -> Option<&'a dyn crate::agent::extension::AgentTool> {
    for candidate in tools {
        if candidate.name() == name {
            return Some(candidate.as_ref());
        }
    }
    None
}
/// Upper bound on model calls during a single `run_agent_loop` run (the counter
/// is never reset, so follow-up turns share the same budget).
const MAX_TOOL_ITERATIONS: usize = 25;

/// Result of executing, or refusing to execute, one tool call.
struct ToolExecOutcome {
    /// Id of the tool call this outcome answers.
    id: String,
    /// Name of the tool that was called.
    name: String,
    /// Text handed back to the model as the tool result.
    content: String,
    /// Optional compact form supplied by the tool, forwarded in `AgentEvent::ToolResult`.
    compact: Option<String>,
    /// Whether the result is reported as an error.
    is_error: bool,
    /// Tool-requested stop: when every outcome of a batch sets this, the turn
    /// ends without another model call.
    terminate: bool,
}
/// Runs the agent loop for one request and returns the messages it produced.
///
/// The model context is `history` followed by `prompts`. Each model response is
/// streamed to `emit` as it arrives. If the response requests tools, they are
/// executed (see [`LoopConfig::tool_execution`]), their results are appended to
/// the context, and the model is called again. A turn ends when the model replies
/// without tool calls, or when every executed tool asks to terminate. After each
/// turn the follow-up queue is drained; if it yields messages, another turn runs.
///
/// Events: `AgentStart`, then a `TurnStart` for every turn (followed by `TurnEnd`
/// when the turn completes normally), and finally `AgentEnd` carrying the same
/// messages that are returned.
///
/// # Errors
///
/// Returns `Err` only when `provider.stream` itself fails (no `AgentEnd` is
/// emitted in that case). The following end the run early with `Ok` after
/// emitting `AgentEnd`:
/// - a `StreamEvent::Error`,
/// - a `StopReason::Error` reply,
/// - exceeding [`MAX_TOOL_ITERATIONS`] model calls.
pub async fn run_agent_loop(
    prompts: Vec<AgentMessage>,
    history: Vec<AgentMessage>,
    config: &LoopConfig<'_>,
    provider: &dyn Provider,
    emit: &mut (dyn FnMut(AgentEvent) + Send),
) -> anyhow::Result<Vec<AgentMessage>> {
    // `messages` is the full context sent to the model; `new_messages` holds only
    // what this run added and is what gets returned.
    let mut messages: Vec<AgentMessage> = history;
    messages.extend(prompts.iter().cloned());
    let mut new_messages: Vec<AgentMessage> = prompts;
    emit(AgentEvent::AgentStart);
    // Counts model calls across all turns, including follow-up turns.
    let mut iteration_count: usize = 0;
    loop {
        // Emitted once per turn (including turns started by follow-up messages) so
        // that every `TurnEnd` below has a matching `TurnStart`.
        emit(AgentEvent::TurnStart);
        loop {
            iteration_count += 1;
            if iteration_count > MAX_TOOL_ITERATIONS {
                emit(AgentEvent::Aborted {
                    reason: format!(
                        "Agent loop exceeded maximum iterations ({}). Last response may be incomplete.",
                        MAX_TOOL_ITERATIONS
                    ),
                });
                emit(AgentEvent::AgentEnd {
                    messages: new_messages.clone(),
                });
                return Ok(new_messages);
            }
            drain_steering(config, &mut messages, &mut new_messages, emit);

            // The transform only shapes what the model sees; `messages` is untouched.
            let transformed = config
                .transform_context
                .as_ref()
                .map(|transform| transform(messages.as_slice()));
            let llm_messages: &[AgentMessage] =
                transformed.as_deref().unwrap_or(messages.as_slice());
            let mut stream = provider
                .stream(
                    &config.model,
                    &config.system_prompt,
                    llm_messages,
                    &config.tools,
                )
                .await?;

            let mut response_text = String::new();
            let mut tool_calls: Vec<ToolCall> = Vec::new();
            let mut stop_reason = StopReason::EndTurn;
            while let Some(event) = futures::StreamExt::next(&mut stream).await {
                match event {
                    StreamEvent::TextDelta { text } => {
                        response_text.push_str(&text);
                        emit(AgentEvent::TextDelta { delta: text });
                    }
                    StreamEvent::ThinkingDelta { text } => {
                        emit(AgentEvent::ThinkingDelta { delta: text });
                    }
                    StreamEvent::ToolCall {
                        id,
                        name,
                        arguments,
                    } => {
                        // Arguments that are not valid JSON are kept verbatim as a string.
                        let args = serde_json::from_str::<serde_json::Value>(&arguments)
                            .unwrap_or_else(|_| serde_json::Value::String(arguments));
                        if let Some(existing) = tool_calls.iter_mut().find(|tc| tc.id == id) {
                            existing.arguments = args;
                        } else {
                            tool_calls.push(ToolCall {
                                id,
                                name,
                                arguments: args,
                            });
                        }
                    }
                    StreamEvent::Done {
                        text,
                        stop_reason: sr,
                        tool_calls: tcs,
                        ..
                    } => {
                        // Like the tool-call handling below, an empty final payload must
                        // not discard what was already streamed as deltas.
                        if !text.is_empty() {
                            // Providers that sent no deltas only deliver the text here.
                            if response_text.is_empty() {
                                emit(AgentEvent::TextDelta {
                                    delta: text.clone(),
                                });
                            }
                            response_text = text;
                        }
                        stop_reason = sr;
                        if !tcs.is_empty() {
                            tool_calls = tcs;
                        }
                    }
                    StreamEvent::Error { message } => {
                        emit(AgentEvent::Aborted {
                            reason: message.clone(),
                        });
                        emit(AgentEvent::ToolResult {
                            id: String::new(),
                            name: String::new(),
                            content: message.clone(),
                            compact: None,
                            is_error: true,
                        });
                        new_messages.push(AgentMessage::tool_result(String::new(), message, true));
                        emit(AgentEvent::AgentEnd {
                            messages: new_messages.clone(),
                        });
                        return Ok(new_messages);
                    }
                }
            }

            let assistant_msg = AgentMessage {
                id: uuid::Uuid::new_v4().to_string(),
                parent_id: None,
                role: Role::Assistant,
                content: response_text,
                tool_calls: tool_calls.clone(),
                tool_call_id: None,
                usage: None,
                is_error: false,
                timestamp: chrono::Utc::now().timestamp_millis(),
            };
            messages.push(assistant_msg.clone());
            new_messages.push(assistant_msg);
            if stop_reason == StopReason::Error {
                emit(AgentEvent::AgentEnd {
                    messages: new_messages.clone(),
                });
                return Ok(new_messages);
            }

            if tool_calls.is_empty() {
                // Text-only reply: this turn is complete.
                emit(AgentEvent::TurnEnd);
                apply_prepare_next_turn(config, &mut messages, &new_messages);
                if apply_should_stop_after_turn(config, &new_messages) {
                    emit(AgentEvent::AgentEnd {
                        messages: new_messages.clone(),
                    });
                    return Ok(new_messages);
                }
                break;
            }

            // A single sequential-only tool forces the whole batch to run sequentially.
            let has_sequential_tool = tool_calls.iter().any(|tc| {
                find_tool(config.agent_tools, &tc.name)
                    .is_some_and(|tool| tool.execution_mode() == ToolExecutionMode::Sequential)
            });
            let effective_mode = if has_sequential_tool {
                ToolExecutionMode::Sequential
            } else {
                config.tool_execution
            };
            let outcomes = match effective_mode {
                ToolExecutionMode::Parallel => {
                    execute_tool_calls_parallel(&tool_calls, config, emit).await
                }
                ToolExecutionMode::Sequential => {
                    execute_tool_calls_sequential(&tool_calls, config, emit).await
                }
            };
            let all_terminate = !outcomes.is_empty() && outcomes.iter().all(|o| o.terminate);
            for outcome in outcomes {
                let msg =
                    AgentMessage::tool_result(&outcome.id, &outcome.content, outcome.is_error);
                emit(AgentEvent::ToolResult {
                    id: outcome.id,
                    name: outcome.name,
                    content: outcome.content,
                    compact: outcome.compact,
                    is_error: outcome.is_error,
                });
                messages.push(msg.clone());
                new_messages.push(msg);
            }
            apply_prepare_next_turn(config, &mut messages, &new_messages);
            if all_terminate {
                emit(AgentEvent::TurnEnd);
                break;
            }
            // Otherwise feed the tool results back to the model.
        }
        if !drain_follow_up(config, &mut messages, &mut new_messages, emit) {
            break;
        }
    }
    emit(AgentEvent::AgentEnd {
        messages: new_messages.clone(),
    });
    Ok(new_messages)
}
fn drain_steering(
config: &LoopConfig<'_>,
messages: &mut Vec<AgentMessage>,
new_messages: &mut Vec<AgentMessage>,
emit: &mut (dyn FnMut(AgentEvent) + Send),
) -> bool {
let Some(queue) = config.steering_queue else {
return false;
};
let drained = queue.lock().unwrap().drain();
if drained.is_empty() {
return false;
}
for msg in drained {
emit(AgentEvent::UserMessage {
content: msg.content.clone(),
});
messages.push(msg.clone());
new_messages.push(msg);
}
true
}
fn drain_follow_up(
config: &LoopConfig<'_>,
messages: &mut Vec<AgentMessage>,
new_messages: &mut Vec<AgentMessage>,
emit: &mut (dyn FnMut(AgentEvent) + Send),
) -> bool {
let Some(queue) = config.follow_up_queue else {
return false;
};
let drained = queue.lock().unwrap().drain();
if drained.is_empty() {
return false;
}
for msg in drained {
emit(AgentEvent::UserMessage {
content: msg.content.clone(),
});
messages.push(msg.clone());
new_messages.push(msg);
}
true
}
fn apply_prepare_next_turn(
config: &LoopConfig<'_>,
messages: &mut Vec<AgentMessage>,
new_messages: &[AgentMessage],
) {
if let Some(ref prepare) = config.prepare_next_turn
&& let Some(update) = prepare(new_messages)
&& let Some(ctx) = update.context
{
*messages = ctx;
}
}
/// Asks the optional `should_stop_after_turn` predicate whether the run should end.
///
/// Returns `false` when no predicate is configured.
fn apply_should_stop_after_turn(config: &LoopConfig<'_>, new_messages: &[AgentMessage]) -> bool {
    match &config.should_stop_after_turn {
        Some(should_stop) => should_stop(new_messages),
        None => false,
    }
}
/// Executes tool calls one after another, in the order the model issued them.
///
/// A `ToolCall` event is emitted for each call just before it is handled. The
/// extensions' `before_tool_call` hooks run in order; the first one that returns
/// a reason blocks the call, which then yields an error outcome instead of running.
async fn execute_tool_calls_sequential(
    tool_calls: &[ToolCall],
    config: &LoopConfig<'_>,
    emit: &mut (dyn FnMut(AgentEvent) + Send),
) -> Vec<ToolExecOutcome> {
    let mut results = Vec::new();
    'calls: for call in tool_calls {
        emit(AgentEvent::ToolCall {
            id: call.id.clone(),
            name: call.name.clone(),
            args: call.arguments.clone(),
        });
        for extension in config.extensions {
            let Some(reason) = extension.before_tool_call(call).await else {
                continue;
            };
            results.push(ToolExecOutcome {
                id: call.id.clone(),
                name: call.name.clone(),
                content: format!("Tool execution blocked: {:?}", reason),
                compact: None,
                is_error: true,
                terminate: false,
            });
            continue 'calls;
        }
        let executed = execute_single_tool(call, config.agent_tools, config.extensions, None).await;
        results.push(executed);
    }
    results
}
/// Executes tool calls concurrently and returns their outcomes in call order.
///
/// A `ToolCall` event is emitted for every call, and the extensions'
/// `before_tool_call` hooks are awaited per call. All of this happens before any
/// tool runs. A call blocked by an extension yields an error outcome. The
/// remaining calls are driven concurrently with `join_all` on the current task.
///
/// Outcomes are returned in the same order as `tool_calls`, regardless of which
/// calls were blocked. This matches `execute_tool_calls_sequential`.
async fn execute_tool_calls_parallel(
    tool_calls: &[ToolCall],
    config: &LoopConfig<'_>,
    emit: &mut (dyn FnMut(AgentEvent) + Send),
) -> Vec<ToolExecOutcome> {
    // One slot per call; blocked calls fill theirs immediately, executed calls
    // are filled from `join_all` results via `pending_slots`.
    let mut slots: Vec<Option<ToolExecOutcome>> = Vec::with_capacity(tool_calls.len());
    let mut pending_slots: Vec<usize> = Vec::new();
    let mut pending = Vec::new();
    'calls: for tc in tool_calls {
        emit(AgentEvent::ToolCall {
            id: tc.id.clone(),
            name: tc.name.clone(),
            args: tc.arguments.clone(),
        });
        for ext in config.extensions {
            if let Some(reason) = ext.before_tool_call(tc).await {
                slots.push(Some(ToolExecOutcome {
                    id: tc.id.clone(),
                    name: tc.name.clone(),
                    content: format!("Tool execution blocked: {:?}", reason),
                    compact: None,
                    is_error: true,
                    terminate: false,
                }));
                continue 'calls;
            }
        }
        pending_slots.push(slots.len());
        slots.push(None);
        // Futures are lazy: nothing executes until `join_all` polls them below.
        pending.push(execute_single_tool(
            tc,
            config.agent_tools,
            config.extensions,
            None,
        ));
    }
    let results = join_all(pending).await;
    for (slot, outcome) in pending_slots.into_iter().zip(results) {
        slots[slot] = Some(outcome);
    }
    // Every slot has been filled at this point.
    slots.into_iter().flatten().collect()
}
async fn execute_single_tool(
tc: &ToolCall,
agent_tools: &[Box<dyn crate::agent::extension::AgentTool>],
extensions: &[Box<dyn Extension>],
progress_tx: Option<tokio::sync::mpsc::UnboundedSender<AgentEvent>>,
) -> ToolExecOutcome {
let cancel = Cancel::new();
if let Some(tool) = find_tool(agent_tools, &tc.name) {
let args = tool.prepare_arguments(tc.arguments.clone());
let on_update = progress_tx.as_ref().map(|_| {
let (tool_tx, mut tool_rx) = tokio::sync::mpsc::unbounded_channel::<ToolOutput>();
if let Some(ref tx) = progress_tx {
let tx = tx.clone();
tokio::spawn(async move {
while let Some(output) = tool_rx.recv().await {
let _ = tx.send(AgentEvent::ToolProgress {
content: output.content,
is_error: output.is_error,
});
}
});
}
tool_tx
});
match tool.execute(tc.id.clone(), args, cancel, on_update).await {
Ok(output) => {
let mut final_result = output.content.clone();
for ext in extensions {
if let Some(overridden) = ext.after_tool_call(tc, &final_result).await {
final_result = overridden;
}
}
ToolExecOutcome {
id: tc.id.clone(),
name: tc.name.clone(),
content: final_result,
compact: output.compact,
is_error: false,
terminate: output.terminate,
}
}
Err(e) => ToolExecOutcome {
id: tc.id.clone(),
name: tc.name.clone(),
content: format!("{:#}", e),
compact: None,
is_error: true,
terminate: false,
},
}
} else {
ToolExecOutcome {
id: tc.id.clone(),
name: tc.name.clone(),
content: format!("Tool '{}' not found", tc.name),
compact: None,
is_error: true,
terminate: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::extension::{AgentTool, BlockReason, Cancel, ToolOutput};
use crate::agent::provider::StreamEvent;
use crate::agent::types::{
AgentMessage, PendingMessageQueue, QueueMode, Role, ToolCall, ToolExecutionMode,
};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
struct MockProvider {
responses: Arc<std::sync::Mutex<Vec<MockResponse>>>,
sent_messages: Arc<std::sync::Mutex<Vec<Vec<AgentMessage>>>>,
}
struct MockResponse {
text: String,
tool_calls: Vec<ToolCall>,
stop_reason: StopReason,
thinking: String,
}
impl MockProvider {
fn new() -> Self {
Self {
responses: Arc::new(std::sync::Mutex::new(Vec::new())),
sent_messages: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
fn add_response(&self, text: &str) {
self.responses.lock().unwrap().push(MockResponse {
text: text.to_string(),
tool_calls: vec![],
stop_reason: StopReason::EndTurn,
thinking: String::new(),
});
}
fn add_tool_call_response(&self, text: &str, tool_calls: Vec<ToolCall>) {
self.responses.lock().unwrap().push(MockResponse {
text: text.to_string(),
tool_calls,
stop_reason: StopReason::ToolUse,
thinking: String::new(),
});
}
#[allow(dead_code)]
fn sent_message_count(&self) -> usize {
self.sent_messages.lock().unwrap().len()
}
#[allow(dead_code)]
fn last_sent_message_count(&self) -> usize {
let msgs = self.sent_messages.lock().unwrap();
msgs.last().map(|m| m.len()).unwrap_or(0)
}
}
#[async_trait]
impl Provider for MockProvider {
async fn stream(
&self,
_model: &str,
_system: &str,
messages: &[AgentMessage],
_tools: &[ToolDef],
) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
self.sent_messages.lock().unwrap().push(messages.to_vec());
let mut resp = self.responses.lock().unwrap();
let response = if resp.is_empty() {
MockResponse {
text: String::new(),
tool_calls: vec![],
stop_reason: StopReason::EndTurn,
thinking: String::new(),
}
} else {
resp.remove(0)
};
drop(resp);
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
if !response.thinking.is_empty() {
let _ = tx.send(StreamEvent::ThinkingDelta {
text: response.thinking.clone(),
});
}
if !response.text.is_empty() {
let _ = tx.send(StreamEvent::TextDelta {
text: response.text.clone(),
});
}
let _ = tx.send(StreamEvent::Done {
text: response.text,
usage: crate::agent::types::Usage::default(),
stop_reason: response.stop_reason,
tool_calls: response.tool_calls,
});
use futures::stream::unfold;
let stream = unfold(rx, |mut rx| async move {
rx.recv().await.map(|event| (event, rx))
});
Ok(Box::pin(stream))
}
}
struct MockTool {
name: String,
execution_mode: ToolExecutionMode,
execute_delay: std::time::Duration,
executed: Arc<std::sync::Mutex<Vec<String>>>,
terminate: bool,
}
impl MockTool {
fn new(name: &str) -> Self {
Self {
name: name.to_string(),
execution_mode: ToolExecutionMode::Parallel,
execute_delay: std::time::Duration::ZERO,
executed: Arc::new(std::sync::Mutex::new(Vec::new())),
terminate: false,
}
}
#[allow(dead_code)]
fn with_sequential(mut self) -> Self {
self.execution_mode = ToolExecutionMode::Sequential;
self
}
fn with_delay(mut self, delay: std::time::Duration) -> Self {
self.execute_delay = delay;
self
}
fn with_terminate(mut self) -> Self {
self.terminate = true;
self
}
}
#[async_trait]
impl AgentTool for MockTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"mock tool"
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({})
}
fn label(&self) -> &str {
&self.name
}
fn execution_mode(&self) -> ToolExecutionMode {
self.execution_mode
}
async fn execute(
&self,
tool_call_id: String,
_args: serde_json::Value,
_cancel: Cancel,
_on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
self.executed.lock().unwrap().push(tool_call_id.clone());
if self.execute_delay > std::time::Duration::ZERO {
tokio::time::sleep(self.execute_delay).await;
}
Ok(ToolOutput {
content: format!("executed: {}", tool_call_id),
compact: None,
is_error: false,
terminate: self.terminate,
})
}
}
#[derive(Debug, Clone)]
struct EventRecorder {
events: Arc<std::sync::Mutex<Vec<AgentEvent>>>,
}
impl EventRecorder {
fn new() -> Self {
Self {
events: Arc::new(std::sync::Mutex::new(Vec::new())),
}
}
fn record(&self, event: AgentEvent) {
self.events.lock().unwrap().push(event);
}
fn events(&self) -> Vec<AgentEvent> {
self.events.lock().unwrap().clone()
}
fn event_types(&self) -> Vec<String> {
self.events()
.iter()
.map(|e| match e {
AgentEvent::AgentStart => "agent_start".to_string(),
AgentEvent::TurnStart => "turn_start".to_string(),
AgentEvent::TextDelta { .. } => "text_delta".to_string(),
AgentEvent::ThinkingDelta { .. } => "thinking_delta".to_string(),
AgentEvent::ToolCall { .. } => "tool_call".to_string(),
AgentEvent::ToolCallArgsUpdate { .. } => "tool_call_args_update".to_string(),
AgentEvent::ToolResult { .. } => "tool_result".to_string(),
AgentEvent::ToolProgress { .. } => "tool_progress".to_string(),
AgentEvent::Aborted { .. } => "aborted".to_string(),
AgentEvent::UserMessage { .. } => "user_message".to_string(),
AgentEvent::TurnEnd => "turn_end".to_string(),
AgentEvent::AgentEnd { .. } => "agent_end".to_string(),
})
.collect()
}
fn text_deltas(&self) -> Vec<String> {
self.events()
.iter()
.filter_map(|e| {
if let AgentEvent::TextDelta { delta } = e {
Some(delta.clone())
} else {
None
}
})
.collect()
}
}
#[tokio::test]
async fn test_basic_text_response() {
let provider = MockProvider::new();
provider.add_response("Hello, world!");
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "You are helpful.".to_string(),
tools: vec![],
agent_tools: &[],
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let prompt = AgentMessage::user("Hi");
let result = run_agent_loop(vec![prompt], vec![], &config, &provider, &mut emit)
.await
.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].role, Role::User);
assert_eq!(result[1].role, Role::Assistant);
let types = recorder.event_types();
assert!(types.contains(&"agent_start".to_string()));
assert!(types.contains(&"text_delta".to_string()));
assert!(types.contains(&"turn_end".to_string()));
assert!(types.contains(&"agent_end".to_string()));
let texts = recorder.text_deltas();
assert!(texts.iter().any(|t| t == "Hello, world!"));
}
#[tokio::test]
async fn test_sequential_tool_execution() {
let tool = MockTool::new("echo");
let tool_executed = Arc::clone(&tool.executed);
let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(tool)];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![
ToolCall {
id: "call-1".to_string(),
name: "echo".to_string(),
arguments: serde_json::json!({}),
},
ToolCall {
id: "call-2".to_string(),
name: "echo".to_string(),
arguments: serde_json::json!({}),
},
],
);
provider.add_response("Done after tools.");
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Sequential,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let result = run_agent_loop(
vec![AgentMessage::user("run tools")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
assert_eq!(result.len(), 5);
let executed = tool_executed.lock().unwrap().clone();
assert_eq!(executed.len(), 2);
assert_eq!(executed[0], "call-1");
assert_eq!(executed[1], "call-2");
let types = recorder.event_types();
assert!(types.contains(&"tool_call".to_string()));
assert!(types.contains(&"tool_result".to_string()));
}
#[tokio::test]
async fn test_parallel_tool_execution() {
let fast_tool =
Arc::new(MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)));
let slow_tool =
Arc::new(MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)));
let _fast_executed = Arc::clone(&fast_tool.executed);
let _slow_executed = Arc::clone(&slow_tool.executed);
let start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let start_times_clone = Arc::clone(&start_times);
struct TrackingTool {
inner: MockTool,
start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>>,
}
#[async_trait]
impl AgentTool for TrackingTool {
fn name(&self) -> &str {
self.inner.name()
}
fn description(&self) -> &str {
"tracking"
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({})
}
fn label(&self) -> &str {
self.inner.name()
}
async fn execute(
&self,
tool_call_id: String,
args: serde_json::Value,
cancel: Cancel,
on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
self.start_times
.lock()
.unwrap()
.push((tool_call_id.clone(), std::time::Instant::now()));
self.inner
.execute(tool_call_id, args, cancel, on_update)
.await
}
}
let agent_tools: Vec<Box<dyn AgentTool>> = vec![
Box::new(TrackingTool {
inner: MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)),
start_times: Arc::clone(&start_times),
}),
Box::new(TrackingTool {
inner: MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)),
start_times: Arc::clone(&start_times_clone),
}),
];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![
ToolCall {
id: "slow-1".to_string(),
name: "slow".to_string(),
arguments: serde_json::json!({}),
},
ToolCall {
id: "fast-1".to_string(),
name: "fast".to_string(),
arguments: serde_json::json!({}),
},
],
);
provider.add_response("All tools done.");
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
run_agent_loop(
vec![AgentMessage::user("run tools")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
let times = start_times.lock().unwrap();
assert_eq!(times.len(), 2, "both tools should have started");
let names: Vec<&str> = times.iter().map(|(n, _)| n.as_str()).collect();
assert!(names.contains(&"slow-1"));
assert!(names.contains(&"fast-1"));
}
#[tokio::test]
async fn test_per_tool_sequential_mode() {
let executed = Arc::new(std::sync::Mutex::new(Vec::new()));
{
let _seq_exec = Arc::clone(&executed);
let _par_exec = Arc::clone(&executed);
struct SeqTool;
#[async_trait]
impl AgentTool for SeqTool {
fn name(&self) -> &str {
"sequential_tool"
}
fn description(&self) -> &str {
""
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({})
}
fn label(&self) -> &str {
"sequential_tool"
}
fn execution_mode(&self) -> ToolExecutionMode {
ToolExecutionMode::Sequential
}
async fn execute(
&self,
tool_call_id: String,
_args: serde_json::Value,
_cancel: Cancel,
_on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
tokio::time::sleep(std::time::Duration::from_millis(30)).await;
Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
}
}
struct ParTool {
executed: Arc<std::sync::Mutex<Vec<String>>>,
}
#[async_trait]
impl AgentTool for ParTool {
fn name(&self) -> &str {
"parallel_tool"
}
fn description(&self) -> &str {
""
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({})
}
fn label(&self) -> &str {
"parallel_tool"
}
async fn execute(
&self,
tool_call_id: String,
_args: serde_json::Value,
_cancel: Cancel,
_on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
self.executed.lock().unwrap().push(tool_call_id.clone());
Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
}
}
let agent_tools: Vec<Box<dyn AgentTool>> = vec![
Box::new(SeqTool),
Box::new(ParTool {
executed: Arc::clone(&executed),
}),
];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![
ToolCall {
id: "seq-1".to_string(),
name: "sequential_tool".to_string(),
arguments: serde_json::json!({}),
},
ToolCall {
id: "par-1".to_string(),
name: "parallel_tool".to_string(),
arguments: serde_json::json!({}),
},
],
);
provider.add_response("Done.");
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
run_agent_loop(
vec![AgentMessage::user("run")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
let exec_order = executed.lock().unwrap().clone();
assert_eq!(
exec_order.len(),
1,
"only parallel_tool records in executed"
);
}
}
#[tokio::test]
async fn test_terminate_stops_loop() {
let agent_tools: Vec<Box<dyn AgentTool>> =
vec![Box::new(MockTool::new("final").with_terminate())];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![ToolCall {
id: "final-1".to_string(),
name: "final".to_string(),
arguments: serde_json::json!({}),
}],
);
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let result = run_agent_loop(
vec![AgentMessage::user("final")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
assert_eq!(
result.len(),
3,
"should stop after terminate without second LLM call"
);
let types = recorder.event_types();
assert!(types.contains(&"turn_end".to_string()));
assert!(types.contains(&"agent_end".to_string()));
}
#[tokio::test]
async fn test_transform_context() {
let provider = MockProvider::new();
provider.add_response("Response");
let transform_called = Arc::new(std::sync::Mutex::new(false));
let transform_called_clone = Arc::clone(&transform_called);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &[],
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: Some(Box::new(move |msgs| {
*transform_called_clone.lock().unwrap() = true;
msgs.to_vec()
})),
prepare_next_turn: None,
should_stop_after_turn: None,
};
let mut emit = |_: AgentEvent| {};
run_agent_loop(
vec![AgentMessage::user("hi")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
assert!(
*transform_called.lock().unwrap(),
"transform_context should be called"
);
}
#[tokio::test]
async fn test_prepare_next_turn() {
let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![ToolCall {
id: "tool-1".to_string(),
name: "echo".to_string(),
arguments: serde_json::json!({}),
}],
);
provider.add_response("After prepare.");
let prepare_called = Arc::new(std::sync::Mutex::new(false));
let prepare_called_clone = Arc::clone(&prepare_called);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Sequential,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: Some(Box::new(move |_new_msgs| {
*prepare_called_clone.lock().unwrap() = true;
None })),
should_stop_after_turn: None,
};
let mut emit = |_: AgentEvent| {};
run_agent_loop(
vec![AgentMessage::user("run")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
assert!(
*prepare_called.lock().unwrap(),
"prepare_next_turn should be called"
);
}
#[tokio::test]
async fn test_should_stop_after_turn() {
let provider = MockProvider::new();
provider.add_response("First turn.");
let stop = Arc::new(std::sync::Mutex::new(true));
let stop_clone = Arc::clone(&stop);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &[],
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: Some(Box::new(move |_| *stop_clone.lock().unwrap())),
};
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
run_agent_loop(
vec![AgentMessage::user("hi")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
let types = recorder.event_types();
let agent_end_count = types.iter().filter(|t| *t == "agent_end").count();
assert_eq!(agent_end_count, 1, "should end exactly once");
}
#[tokio::test]
async fn test_steering_queue() {
let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![ToolCall {
id: "tool-1".to_string(),
name: "echo".to_string(),
arguments: serde_json::json!({}),
}],
);
provider.add_response("After tool.");
provider.add_response("After steering.");
let steering_queue = std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
steering_queue
.lock()
.unwrap()
.enqueue(AgentMessage::user("steer here"));
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Sequential,
steering_queue: Some(&steering_queue),
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let result = run_agent_loop(
vec![AgentMessage::user("run")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
let types = recorder.event_types();
let user_msg_count = types.iter().filter(|t| *t == "user_message").count();
assert!(
user_msg_count >= 1,
"steering should produce at least one user_message event, got {}",
user_msg_count
);
let user_messages: Vec<&AgentMessage> =
result.iter().filter(|m| m.role == Role::User).collect();
assert_eq!(
user_messages.len(),
2,
"should have original prompt + steering message"
);
}
#[tokio::test]
async fn test_follow_up_queue() {
let provider = MockProvider::new();
provider.add_response("First response.");
provider.add_response("Follow-up response.");
let follow_up_queue =
std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
follow_up_queue
.lock()
.unwrap()
.enqueue(AgentMessage::user("follow up"));
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &[],
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: Some(&follow_up_queue),
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let result = run_agent_loop(
vec![AgentMessage::user("first")],
vec![],
&config,
&provider,
&mut emit,
)
.await
.unwrap();
assert_eq!(
result.len(),
4,
"follow-up should add another user+assistant pair"
);
assert_eq!(
result[2].content, "follow up",
"third message should be the injected follow-up"
);
let types = recorder.event_types();
assert!(types.contains(&"user_message".to_string()));
}
/// Exercises both drain policies of `PendingMessageQueue`, plus `clear`.
///
/// The queue API used here is entirely synchronous (no `.await` anywhere), so
/// this is a plain `#[test]` rather than a `#[tokio::test]` that would build a
/// runtime for nothing.
#[test]
fn test_message_queue_modes() {
    // OneAtATime: each drain hands back exactly one message, oldest first.
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("msg1"));
    queue.enqueue(AgentMessage::user("msg2"));
    let batch1 = queue.drain();
    assert_eq!(batch1.len(), 1, "OneAtATime should drain 1");
    assert_eq!(batch1[0].content, "msg1");
    let batch2 = queue.drain();
    assert_eq!(batch2.len(), 1, "OneAtATime should drain 1 on second call");
    assert_eq!(batch2[0].content, "msg2");
    assert!(
        queue.drain().is_empty(),
        "should be empty after both drained"
    );

    // All: a single drain empties the queue.
    let mut queue = PendingMessageQueue::new(QueueMode::All);
    queue.enqueue(AgentMessage::user("a"));
    queue.enqueue(AgentMessage::user("b"));
    let all = queue.drain();
    assert_eq!(all.len(), 2, "All mode should drain both");
    assert!(queue.drain().is_empty(), "should be empty after drain");

    // clear: discards pending messages without draining them. Check the queue
    // is non-empty first so the final assertion is not vacuously true.
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("x"));
    assert!(!queue.is_empty(), "enqueue should make the queue non-empty");
    queue.clear();
    assert!(queue.is_empty());
}
#[tokio::test]
async fn test_prepare_arguments() {
struct PrepTool;
#[async_trait]
impl AgentTool for PrepTool {
fn name(&self) -> &str {
"prep_tool"
}
fn description(&self) -> &str {
""
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({})
}
fn label(&self) -> &str {
"prep_tool"
}
fn prepare_arguments(&self, args: serde_json::Value) -> serde_json::Value {
let mut m = serde_json::Map::new();
m.insert("prepared".to_string(), serde_json::json!(true));
if let Some(obj) = args.as_object() {
for (k, v) in obj {
m.insert(k.clone(), v.clone());
}
}
serde_json::Value::Object(m)
}
async fn execute(
&self,
_tool_call_id: String,
args: serde_json::Value,
_cancel: Cancel,
_on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
assert_eq!(args.get("prepared").and_then(|v| v.as_bool()), Some(true));
Ok(ToolOutput::ok("prepared ok"))
}
}
let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(PrepTool)];
let provider = MockProvider::new();
provider.add_tool_call_response(
"",
vec![ToolCall {
id: "tool-1".to_string(),
name: "prep_tool".to_string(),
arguments: serde_json::json!({"original": "value"}),
}],
);
provider.add_response("Done.");
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &agent_tools,
extensions: &[],
tool_execution: ToolExecutionMode::Sequential,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let mut emit = |_: AgentEvent| {};
let result = run_agent_loop(
vec![AgentMessage::user("prep")],
vec![],
&config,
&provider,
&mut emit,
)
.await;
assert!(
result.is_ok(),
"prepare_arguments should work without error"
);
}
/// An extension whose `before_tool_call` hook always vetoes must turn the
/// requested call into an error tool result that mentions the block. The loop
/// must still finish normally.
#[tokio::test]
async fn test_before_tool_call_blocks() {
    // Extension that rejects every tool call with a security reason.
    struct BlockingExt;
    #[async_trait]
    impl Extension for BlockingExt {
        fn name(&self) -> std::borrow::Cow<'static, str> {
            "blocker".into()
        }
        async fn before_tool_call(&self, _tc: &ToolCall) -> Option<BlockReason> {
            let reason = BlockReason::Security("blocked for test".into());
            Some(reason)
        }
    }

    let echo_call = ToolCall {
        id: "tool-1".to_string(),
        name: "echo".to_string(),
        arguments: serde_json::json!({}),
    };
    let provider = MockProvider::new();
    provider.add_tool_call_response("", vec![echo_call]);
    provider.add_response("After blocked tool.");

    let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
    let extensions: Vec<Box<dyn Extension>> = vec![Box::new(BlockingExt)];
    let recorder = EventRecorder::new();
    let mut emit = |event: AgentEvent| recorder.record(event);
    let config = LoopConfig {
        model: "test".to_string(),
        system_prompt: "".to_string(),
        tools: vec![],
        agent_tools: &agent_tools,
        extensions: &extensions,
        tool_execution: ToolExecutionMode::Sequential,
        steering_queue: None,
        follow_up_queue: None,
        transform_context: None,
        prepare_next_turn: None,
        should_stop_after_turn: None,
    };

    let messages = run_agent_loop(
        vec![AgentMessage::user("block test")],
        vec![],
        &config,
        &provider,
        &mut emit,
    )
    .await
    .unwrap();

    assert!(
        messages.len() >= 3,
        "blocked tool should still produce a result"
    );
    let blocked = messages
        .iter()
        .find(|m| m.role == Role::ToolResult)
        .expect("a tool result message should be present");
    assert!(blocked.is_error, "blocked tool result should be error");
    assert!(
        blocked.content.contains("blocked"),
        "blocked result should mention block reason"
    );
}
#[tokio::test]
async fn test_provider_error_aborts() {
struct ErrorProvider;
#[async_trait]
impl Provider for ErrorProvider {
async fn stream(
&self,
_model: &str,
_system: &str,
_messages: &[AgentMessage],
_tools: &[ToolDef],
) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
anyhow::bail!("provider error")
}
}
let recorder = EventRecorder::new();
let mut emit = |e: AgentEvent| recorder.record(e);
let config = LoopConfig {
model: "test".to_string(),
system_prompt: "".to_string(),
tools: vec![],
agent_tools: &[],
extensions: &[],
tool_execution: ToolExecutionMode::Parallel,
steering_queue: None,
follow_up_queue: None,
transform_context: None,
prepare_next_turn: None,
should_stop_after_turn: None,
};
let result = run_agent_loop(
vec![AgentMessage::user("hi")],
vec![],
&config,
&ErrorProvider,
&mut emit,
)
.await;
assert!(result.is_err(), "provider error should propagate");
}
/// A tool whose `execute` returns `Err` must not abort the loop. The failure
/// is reported back as an error tool result instead.
#[tokio::test]
async fn test_tool_execution_error() {
    // Tool that fails unconditionally.
    struct ErrorTool;
    #[async_trait]
    impl AgentTool for ErrorTool {
        fn name(&self) -> &str {
            "error_tool"
        }
        fn description(&self) -> &str {
            ""
        }
        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({})
        }
        fn label(&self) -> &str {
            "error_tool"
        }
        async fn execute(
            &self,
            _tool_call_id: String,
            _args: serde_json::Value,
            _cancel: Cancel,
            _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
        ) -> anyhow::Result<ToolOutput> {
            Err(anyhow::anyhow!("tool crashed"))
        }
    }

    let failing_call = ToolCall {
        id: "tool-1".to_string(),
        name: "error_tool".to_string(),
        arguments: serde_json::json!({}),
    };
    let provider = MockProvider::new();
    provider.add_tool_call_response("", vec![failing_call]);
    provider.add_response("After error.");

    let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(ErrorTool)];
    let recorder = EventRecorder::new();
    let mut emit = |event: AgentEvent| recorder.record(event);
    let config = LoopConfig {
        model: "test".to_string(),
        system_prompt: "".to_string(),
        tools: vec![],
        agent_tools: &agent_tools,
        extensions: &[],
        tool_execution: ToolExecutionMode::Sequential,
        steering_queue: None,
        follow_up_queue: None,
        transform_context: None,
        prepare_next_turn: None,
        should_stop_after_turn: None,
    };

    let messages = run_agent_loop(
        vec![AgentMessage::user("error test")],
        vec![],
        &config,
        &provider,
        &mut emit,
    )
    .await
    .unwrap();

    let failed = messages
        .iter()
        .find(|m| m.role == Role::ToolResult)
        .expect("a tool result message should be present");
    assert!(failed.is_error);
}
/// Requesting a tool name that no registered `AgentTool` provides yields an
/// error tool result saying it was "not found". The loop itself does not fail.
#[tokio::test]
async fn test_tool_not_found() {
    let missing_call = ToolCall {
        id: "tool-1".to_string(),
        name: "nonexistent".to_string(),
        arguments: serde_json::json!({}),
    };
    let provider = MockProvider::new();
    provider.add_tool_call_response("", vec![missing_call]);
    provider.add_response("After missing tool.");

    // Deliberately register no tools at all.
    let agent_tools: Vec<Box<dyn AgentTool>> = Vec::new();
    let recorder = EventRecorder::new();
    let mut emit = |event: AgentEvent| recorder.record(event);
    let config = LoopConfig {
        model: "test".to_string(),
        system_prompt: "".to_string(),
        tools: vec![],
        agent_tools: &agent_tools,
        extensions: &[],
        tool_execution: ToolExecutionMode::Sequential,
        steering_queue: None,
        follow_up_queue: None,
        transform_context: None,
        prepare_next_turn: None,
        should_stop_after_turn: None,
    };

    let messages = run_agent_loop(
        vec![AgentMessage::user("test")],
        vec![],
        &config,
        &provider,
        &mut emit,
    )
    .await
    .unwrap();

    let missing = messages
        .iter()
        .find(|m| m.role == Role::ToolResult)
        .expect("a tool result message should be present");
    assert!(missing.is_error);
    assert!(missing.content.contains("not found"));
}
}