use crate::agent::AgentContext;
use crate::error::Error;
use crate::ids::ThreadId;
use crate::llm::{Message, Role, ToolCall};
use crate::memory::{Episode, ToolResult};
use tracing::debug;
pub(crate) async fn dispatch_tool_calls(
ctx: &AgentContext,
thread: &ThreadId,
tool_calls: &[ToolCall],
driver: &'static str,
) -> Result<(), Error> {
for call in tool_calls {
let tool_ctx = crate::tool::ToolCtx {
pubsub: ctx.pubsub.clone(),
kv: ctx.kv.clone(),
jobs: ctx.jobs.clone(),
};
let outcome = ctx
.tools
.invoke(&call.name, call.args.clone(), tool_ctx)
.await;
let (result_for_log, tool_msg_content) = match &outcome {
Ok(v) => (ToolResult::Ok { value: v.clone() }, v.to_string()),
Err(e) => (
ToolResult::Err {
message: e.to_string(),
},
format!("error: {e}"),
),
};
ctx.episodic
.record(
ctx.run_id,
Episode::ToolCall {
name: call.name.clone(),
args: call.args.clone(),
result: result_for_log,
},
)
.await?;
ctx.short_term
.append(
thread.clone(),
Message {
role: Role::Tool,
content: tool_msg_content,
tool_calls: vec![],
tool_call_id: Some(call.id.clone()),
},
)
.await?;
if let Err(e) = outcome {
if !e.retryable() {
return Err(Error::Tool(e));
}
}
}
debug!(n = tool_calls.len(), driver, "dispatched tools");
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentContext;
    use crate::error::ToolError;
    use crate::ids::{RunId, ThreadId};
    use crate::memory::EpisodicMemory;
    use crate::test_utils::{
        noop_bus, FakeLlmClient, FakeToolInvoker, InMemoryEpisodic, InMemoryLongTerm,
        InMemoryShortTerm,
    };
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Builds an `AgentContext` wired to in-memory fakes around `tools`. Also returns
    /// the concrete episodic store so tests can replay what was recorded.
    fn make_ctx(tools: Arc<FakeToolInvoker>) -> (AgentContext, Arc<InMemoryEpisodic>) {
        let (pubsub, request_reply, kv, jobs) = noop_bus();
        let store = Arc::new(InMemoryEpisodic::default());
        let ctx = AgentContext {
            llm: Arc::new(FakeLlmClient::default()),
            short_term: Arc::new(InMemoryShortTerm::default()),
            long_term: Arc::new(InMemoryLongTerm::default()),
            episodic: store.clone(),
            pubsub,
            kv,
            request_reply,
            jobs,
            tools,
            run_id: RunId::new(),
            cancel: CancellationToken::new(),
            agent_name: "dispatch-test".into(),
        };
        (ctx, store)
    }

    /// A tool call with the given id and tool name and an empty JSON object as args.
    fn call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            args: serde_json::json!({}),
        }
    }

    /// Counts the `Episode::ToolCall` entries among `episodes`.
    fn tool_episode_count<'a>(episodes: impl IntoIterator<Item = &'a Episode>) -> usize {
        episodes
            .into_iter()
            .filter(|episode| matches!(episode, Episode::ToolCall { .. }))
            .count()
    }

    /// Keeps only the `Role::Tool` messages from `history`, preserving their order.
    fn tool_messages<'a>(history: impl IntoIterator<Item = &'a Message>) -> Vec<&'a Message> {
        history
            .into_iter()
            .filter(|msg| msg.role == Role::Tool)
            .collect()
    }

    #[tokio::test]
    async fn records_one_episode_and_one_short_term_message_per_successful_call() {
        let invoker = FakeToolInvoker::new()
            .with_tool("ok_a", "", |_| Ok(serde_json::json!("a")))
            .with_tool("ok_b", "", |_| Ok(serde_json::json!("b")))
            .with_tool("ok_c", "", |_| Ok(serde_json::json!("c")));
        let (ctx, store) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-1");
        let calls = [("c-1", "ok_a"), ("c-2", "ok_b"), ("c-3", "ok_c")]
            .map(|(id, name)| call(id, name));

        dispatch_tool_calls(&ctx, &thread, &calls, "blocking")
            .await
            .expect("all-Ok dispatch must succeed");

        let recorded = store.replay(ctx.run_id).await.unwrap();
        assert_eq!(tool_episode_count(recorded.iter()), 3);

        let history = ctx.short_term.load(thread, 1024).await.unwrap();
        let replies = tool_messages(history.iter());
        assert_eq!(replies.len(), 3);
        let reply_ids: Vec<&str> = replies
            .iter()
            .filter_map(|msg| msg.tool_call_id.as_deref())
            .collect();
        assert_eq!(reply_ids, vec!["c-1", "c-2", "c-3"]);
    }

    #[tokio::test]
    async fn non_retryable_tool_error_aborts_after_persisting_that_call() {
        let invoker = FakeToolInvoker::new()
            .with_tool("ok", "", |_| Ok(serde_json::json!("ok-value")))
            .with_tool("boom", "", |_| Err(ToolError::Permanent("nope".into())))
            .with_tool("never_runs", "", |_| Ok(serde_json::json!("x")));
        let (ctx, store) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-abort");
        let calls = [("c-1", "ok"), ("c-2", "boom"), ("c-3", "never_runs")]
            .map(|(id, name)| call(id, name));

        let failure = dispatch_tool_calls(&ctx, &thread, &calls, "blocking")
            .await
            .expect_err("non-retryable must abort");
        assert!(
            matches!(&failure, Error::Tool(ToolError::Permanent(m)) if m == "nope"),
            "expected Tool(Permanent), got {failure:?}"
        );

        let recorded = store.replay(ctx.run_id).await.unwrap();
        assert_eq!(
            tool_episode_count(recorded.iter()),
            2,
            "only ok + boom should have been recorded; never_runs aborts"
        );

        let history = ctx.short_term.load(thread, 1024).await.unwrap();
        let replies = tool_messages(history.iter());
        assert_eq!(replies.len(), 2);
        assert!(replies[1].content.contains("error: permanent: nope"));
    }

    #[tokio::test]
    async fn retryable_tool_error_is_swallowed_and_loop_continues() {
        let invoker = FakeToolInvoker::new()
            .with_tool("flaky", "", |_| {
                Err(ToolError::Retryable {
                    message: "transient".into(),
                    retry_after_secs: 1,
                })
            })
            .with_tool("after", "", |_| Ok(serde_json::json!("after-value")));
        let (ctx, store) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-retry");
        let calls = [("c-1", "flaky"), ("c-2", "after")].map(|(id, name)| call(id, name));

        dispatch_tool_calls(&ctx, &thread, &calls, "streaming")
            .await
            .expect("retryable error must not abort the loop");

        let recorded = store.replay(ctx.run_id).await.unwrap();
        assert_eq!(tool_episode_count(recorded.iter()), 2);

        let history = ctx.short_term.load(thread, 1024).await.unwrap();
        let replies = tool_messages(history.iter());
        assert_eq!(replies.len(), 2);
        assert!(replies[0].content.contains("error: retryable"));
        assert_eq!(replies[1].content, "\"after-value\"");
    }
}