use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use futures::stream::{FuturesUnordered, StreamExt};
use serde_json::Value;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::warn;
use crate::approval::{ApprovalDecision, ApprovalHandler, AutoApprove};
use crate::message::Content;
use crate::mode::{AgentMode, ModeDecision};
use crate::steering::ToolRunTracker;
use crate::tool::{Tool, ToolClass, ToolContext};
/// One tool invocation to execute.
///
/// `id` is echoed back as the id of the resulting `tool_result`, `name` is the
/// key used for the policy check and the registry lookup, and `input` is handed
/// unchanged to `Tool::execute`.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Identifier copied onto the `tool_result` produced for this call.
    pub id: String,
    /// Tool name; matched against `ToolPolicy::is_allowed` and the registry.
    pub name: String,
    /// JSON arguments passed to the tool (and to the approval handler).
    pub input: Value,
}
/// Lookup table of available tools, keyed by each tool's `name()`.
pub struct ToolRegistry {
    // At most one tool per name; `ToolRegistry::new` keeps the later of two
    // tools that share a name.
    tools: HashMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
    /// Builds a registry keyed by each tool's `name()`.
    ///
    /// Name collisions are resolved in favour of the tool that appears later
    /// in `tools`; every collision is reported through `tracing::warn!`.
    pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
        let mut by_name: HashMap<String, Arc<dyn Tool>> = HashMap::with_capacity(tools.len());
        for tool in tools {
            let key = tool.name().to_owned();
            let collided = by_name.insert(key.clone(), tool).is_some();
            if collided {
                warn!(
                    tool = %key,
                    "duplicate tool name in registry; later registration overrode earlier"
                );
            }
        }
        Self { tools: by_name }
    }

    /// Looks up a tool by name, returning a new `Arc` handle to it.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(name).map(Arc::clone)
    }

    /// Iterates over every registered tool, in unspecified (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = &Arc<dyn Tool>> {
        self.tools.values()
    }

    /// Number of distinct tool names registered.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// `true` when the registry holds no tools.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Name-based gate deciding whether a tool may be invoked at all.
///
/// The executor holds its policy as `Arc<dyn ToolPolicy>`, hence the
/// `Send + Sync` supertraits.
pub trait ToolPolicy: Send + Sync {
    /// Returns `true` if the tool called `tool_name` may run.
    fn is_allowed(&self, tool_name: &str) -> bool;
}

/// Policy that permits every tool name.
pub struct AllowAll;

impl ToolPolicy for AllowAll {
    fn is_allowed(&self, _: &str) -> bool {
        true
    }
}
/// Per-tool concurrency setting consumed by `ConcurrencyConfig::new`.
///
/// `enabled` marks the tool as promoted: a `Mutating` tool is then routed to
/// the concurrent-mutation pool instead of the serial one. `per_tool_cap`, when
/// set, additionally bounds how many calls of this tool may run at once. The
/// `Default` value is the same as [`ToolConcurrency::off`].
#[derive(Clone, Debug, Default)]
pub struct ToolConcurrency {
    enabled: bool,
    per_tool_cap: Option<usize>,
}

impl ToolConcurrency {
    /// Concurrency enabled, no per-tool cap.
    pub fn on() -> Self {
        Self {
            enabled: true,
            ..Self::default()
        }
    }

    /// Concurrency disabled, no per-tool cap.
    pub fn off() -> Self {
        Self::default()
    }

    /// Returns this setting with a per-tool cap of `n` concurrent calls.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use]
    pub fn max(self, n: usize) -> Self {
        assert!(n > 0, "ToolConcurrency::max requires n > 0");
        Self {
            per_tool_cap: Some(n),
            ..self
        }
    }

    /// Whether the tool is promoted to concurrent execution.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// The per-tool cap, if one was set via [`ToolConcurrency::max`].
    pub fn per_tool_cap(&self) -> Option<usize> {
        self.per_tool_cap
    }
}
/// Concurrency limits for a `ToolExecutor`: class-wide semaphores plus
/// optional per-tool semaphores.
///
/// Note that the derived `Clone` shares every semaphore (they are `Arc`s),
/// whereas `ConcurrencyConfig::fork` shares only `read` and `serial_mut` and
/// builds fresh `concurrent_mut` and `per_tool` semaphores.
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Width `read` was created with.
    read_cap: usize,
    /// Width of `concurrent_mut`; `fork` uses it to size the child's pool.
    mut_cap: usize,
    /// Per-tool caps by tool name; `fork` rebuilds `per_tool` from these.
    per_tool_caps: HashMap<String, usize>,
    /// `Mutating` tools routed to `concurrent_mut` instead of `serial_mut`.
    promoted: HashSet<String>,
    /// Pool for `ReadOnly` tools.
    read: Arc<Semaphore>,
    /// Width-1 pool for non-promoted `Mutating` tools.
    serial_mut: Arc<Semaphore>,
    /// Pool for promoted `Mutating` tools and for recursive tools.
    concurrent_mut: Arc<Semaphore>,
    /// One semaphore per capped tool; acquired before the class semaphore.
    per_tool: HashMap<String, Arc<Semaphore>>,
}
impl Default for ConcurrencyConfig {
fn default() -> Self {
Self {
read_cap: 20,
mut_cap: 10,
per_tool_caps: HashMap::new(),
promoted: HashSet::new(),
read: Arc::new(Semaphore::new(20)),
serial_mut: Arc::new(Semaphore::new(1)),
concurrent_mut: Arc::new(Semaphore::new(10)),
per_tool: HashMap::new(),
}
}
}
impl ConcurrencyConfig {
    /// Builds a configuration from the two class caps plus per-tool settings.
    ///
    /// * `max_concurrent_reads`: width of the pool used by `ReadOnly` tools.
    /// * `max_concurrent_mutations`: width of the pool used by promoted
    ///   `Mutating` tools. Non-promoted mutators always use a separate
    ///   width-1 pool.
    /// * `tool_settings`: `(tool name, setting)` pairs. When a name repeats,
    ///   the last entry wins for both the promotion flag and the cap.
    ///
    /// # Panics
    ///
    /// Panics if either class cap is zero.
    pub fn new(
        max_concurrent_reads: usize,
        max_concurrent_mutations: usize,
        tool_settings: impl IntoIterator<Item = (String, ToolConcurrency)>,
    ) -> Self {
        assert!(
            max_concurrent_reads > 0,
            "max_concurrent_reads requires n > 0"
        );
        assert!(
            max_concurrent_mutations > 0,
            "max_concurrent_mutations requires n > 0"
        );

        let mut caps: HashMap<String, usize> = HashMap::new();
        let mut promoted: HashSet<String> = HashSet::new();
        for (name, setting) in tool_settings {
            // Each entry rewrites both collections (insert *or* remove), so a
            // later entry for the same name fully replaces an earlier one.
            if setting.is_enabled() {
                promoted.insert(name.clone());
            } else {
                promoted.remove(&name);
            }
            if let Some(cap) = setting.per_tool_cap() {
                caps.insert(name, cap);
            } else {
                caps.remove(&name);
            }
        }
        let per_tool = Self::fresh_per_tool(&caps);

        Self {
            read_cap: max_concurrent_reads,
            mut_cap: max_concurrent_mutations,
            per_tool_caps: caps,
            promoted,
            read: Arc::new(Semaphore::new(max_concurrent_reads)),
            serial_mut: Arc::new(Semaphore::new(1)),
            concurrent_mut: Arc::new(Semaphore::new(max_concurrent_mutations)),
            per_tool,
        }
    }

    /// Derives the configuration for a forked (sub-agent) executor.
    ///
    /// The child shares the parent's `read` and `serial_mut` pools, so those
    /// limits hold across parent and child together. It receives brand-new
    /// `concurrent_mut` and per-tool pools of the same widths. Recursive tools
    /// hold a `concurrent_mut` permit (and any per-tool permit) while they
    /// run, so sharing those pools could presumably let a child block on
    /// permits its own parent call is holding. Confirm before changing which
    /// pools are shared.
    #[must_use]
    pub fn fork(&self) -> Self {
        let per_tool = Self::fresh_per_tool(&self.per_tool_caps);
        Self {
            read_cap: self.read_cap,
            mut_cap: self.mut_cap,
            per_tool_caps: self.per_tool_caps.clone(),
            promoted: self.promoted.clone(),
            read: Arc::clone(&self.read),
            serial_mut: Arc::clone(&self.serial_mut),
            concurrent_mut: Arc::new(Semaphore::new(self.mut_cap)),
            per_tool,
        }
    }

    /// Creates one new semaphore per capped tool, sized to that tool's cap.
    fn fresh_per_tool(caps: &HashMap<String, usize>) -> HashMap<String, Arc<Semaphore>> {
        caps.iter()
            .map(|(name, &cap)| (name.clone(), Arc::new(Semaphore::new(cap))))
            .collect()
    }
}
/// Executes tool calls: policy and approval checks, optional mode gating,
/// and class-partitioned concurrent dispatch for batches.
pub struct ToolExecutor {
    /// Tools available by name; shared (not copied) with forked executors.
    registry: Arc<ToolRegistry>,
    /// Name-based allow/deny check, consulted before the registry lookup.
    policy: Arc<dyn ToolPolicy>,
    /// Consulted for every call that passes the policy check and lookup.
    approval: Arc<dyn ApprovalHandler>,
    /// Semaphores bounding concurrent execution of batched calls.
    concurrency: ConcurrencyConfig,
}
impl ToolExecutor {
    /// Creates an executor that auto-approves every call and uses
    /// `ConcurrencyConfig::default()`.
    pub fn new(registry: Arc<ToolRegistry>, policy: Arc<dyn ToolPolicy>) -> Self {
        Self::with_approval_and_concurrency(
            registry,
            policy,
            Arc::new(AutoApprove),
            ConcurrencyConfig::default(),
        )
    }

    /// Creates an executor with a custom approval handler and
    /// `ConcurrencyConfig::default()`.
    pub fn with_approval(
        registry: Arc<ToolRegistry>,
        policy: Arc<dyn ToolPolicy>,
        approval: Arc<dyn ApprovalHandler>,
    ) -> Self {
        Self::with_approval_and_concurrency(
            registry,
            policy,
            approval,
            ConcurrencyConfig::default(),
        )
    }

    /// Creates an executor with every collaborator supplied explicitly.
    pub fn with_approval_and_concurrency(
        registry: Arc<ToolRegistry>,
        policy: Arc<dyn ToolPolicy>,
        approval: Arc<dyn ApprovalHandler>,
        concurrency: ConcurrencyConfig,
    ) -> Self {
        Self {
            registry,
            policy,
            approval,
            concurrency,
        }
    }

    /// Returns the shared tool registry.
    pub fn registry(&self) -> &Arc<ToolRegistry> {
        &self.registry
    }

    /// Returns a new `Arc` handle to this executor's policy.
    pub(crate) fn policy_arc_for_fork(&self) -> Arc<dyn ToolPolicy> {
        Arc::clone(&self.policy)
    }

    /// Forks an executor for a sub-agent that keeps this executor's policy and
    /// approval handler. Equivalent to `fork_for_subagent_with(None, None)`.
    #[must_use]
    pub fn fork_for_subagent(&self) -> Arc<Self> {
        self.fork_for_subagent_with(None, None)
    }

    /// Forks an executor for a sub-agent.
    ///
    /// The registry is shared. `policy_override` / `approval_override`, when
    /// `Some`, replace the parent's policy / approval handler. Concurrency
    /// limits come from `ConcurrencyConfig::fork`, which shares some pools
    /// with the parent and creates others fresh.
    #[must_use]
    pub fn fork_for_subagent_with(
        &self,
        policy_override: Option<Arc<dyn ToolPolicy>>,
        approval_override: Option<Arc<dyn ApprovalHandler>>,
    ) -> Arc<Self> {
        Arc::new(Self {
            registry: Arc::clone(&self.registry),
            policy: policy_override.unwrap_or_else(|| Arc::clone(&self.policy)),
            approval: approval_override.unwrap_or_else(|| Arc::clone(&self.approval)),
            concurrency: self.concurrency.fork(),
        })
    }

    /// Runs a single call to completion and always returns a `tool_result`.
    ///
    /// Checks run in this order: policy, registry lookup, approval (raced
    /// against `ctx.cancel`), then `Tool::execute`. Every failure, including a
    /// tool `Err`, becomes an `is_error: true` result rather than a Rust error.
    /// This method acquires no concurrency permits itself; the batch path
    /// acquires them in `dispatch_one` before calling it.
    pub async fn execute_one(&self, call: ToolCall, ctx: &ToolContext) -> Content {
        if !self.policy.is_allowed(&call.name) {
            return Content::tool_result(
                &call.id,
                format!("Error: tool '{}' is not allowed by policy", call.name),
                true,
            );
        }
        let Some(tool) = self.registry.get(&call.name) else {
            return Content::tool_result(
                &call.id,
                format!("Error: tool '{}' not found", call.name),
                true,
            );
        };
        let class = tool.class();
        let decision = tokio::select! {
            // Poll cancellation first, so an already-cancelled context returns
            // immediately instead of starting approval.
            biased;
            _ = ctx.cancel.cancelled() => {
                return Content::tool_result(
                    &call.id,
                    "Error: cancelled while awaiting approval",
                    true,
                );
            }
            d = self.approval.approve(&call.name, &call.input, class) => d,
        };
        if let ApprovalDecision::Deny(reason) = decision {
            return Content::tool_result(
                &call.id,
                format!("Error: approval denied — {reason}"),
                true,
            );
        }
        match tool.execute(call.input, ctx).await {
            Ok(output) => Content::tool_result(&call.id, output.content(), output.is_error()),
            Err(e) => Content::tool_result(&call.id, format!("Error: {e}"), true),
        }
    }

    /// Executes a batch with no steering tracker and no agent mode.
    ///
    /// Results are returned in the same order as `calls`.
    pub async fn execute_batch(&self, calls: Vec<ToolCall>, ctx: &ToolContext) -> Vec<Content> {
        self.execute_batch_with_tracker(calls, ctx, None, None)
            .await
    }

    /// Batch dispatcher with an optional steering tracker and agent mode.
    ///
    /// Calls are split into consecutive runs that share a `RoutingClass`
    /// (every `SerialMut` call forms a run of its own). Runs execute strictly
    /// one after another; calls inside a run execute concurrently, bounded by
    /// the class and per-tool semaphores. Cancellation is checked before each
    /// run: if the context is cancelled, every call not yet dispatched gets a
    /// "cancelled before execution" result. The returned vector has one entry
    /// per input call, in input order.
    pub(crate) async fn execute_batch_with_tracker(
        &self,
        calls: Vec<ToolCall>,
        ctx: &ToolContext,
        tracker: Option<ToolRunTracker>,
        mode: Option<Arc<dyn AgentMode>>,
    ) -> Vec<Content> {
        if calls.is_empty() {
            return Vec::new();
        }
        if ctx.cancel.is_cancelled() {
            return all_cancelled_before_execution(calls);
        }
        let control = DispatchControl { tracker, mode };
        let n = calls.len();
        // Routing is decided once, up front, for the whole batch.
        let routings: Vec<RoutingClass> = calls.iter().map(|c| self.routing_class(c)).collect();
        // Each call is `take`n out exactly once, either by dispatch or by the
        // cancellation tail-fill; results land in the slot with the same index.
        let mut calls: Vec<Option<ToolCall>> = calls.into_iter().map(Some).collect();
        let mut slots: Vec<Option<Content>> = (0..n).map(|_| None).collect();
        let mut i = 0;
        while i < n {
            if ctx.cancel.is_cancelled() {
                fill_cancelled_tail(&mut calls, &mut slots, i);
                break;
            }
            let j = same_class_run_end(&routings, i);
            // Awaiting the whole run here is what keeps runs from overlapping:
            // later calls observe the effects of earlier runs.
            self.dispatch_run(
                &mut calls,
                &mut slots,
                &routings,
                i..j,
                ctx,
                control.clone(),
            )
            .await;
            i = j;
        }
        slots
            .into_iter()
            .map(|o| o.expect("every slot filled by dispatch or cancel short-circuit"))
            .collect()
    }

    /// Dispatches every call in `range` concurrently and writes each result
    /// into the slot at that call's index. Returns once all of them are done.
    async fn dispatch_run(
        &self,
        calls: &mut [Option<ToolCall>],
        slots: &mut [Option<Content>],
        routings: &[RoutingClass],
        range: std::ops::Range<usize>,
        ctx: &ToolContext,
        control: DispatchControl,
    ) {
        let mut futs = FuturesUnordered::new();
        for k in range {
            let call = calls[k]
                .take()
                .expect("each slot taken exactly once during dispatch");
            futs.push(self.dispatch_one(k, call, routings[k], ctx, control.clone()));
        }
        // Completion order is arbitrary; the returned index restores input order.
        while let Some((idx, content)) = futs.next().await {
            slots[idx] = Some(content);
        }
    }

    /// Chooses the pool a call is dispatched through.
    ///
    /// Policy-denied and unknown tools short-circuit. Recursive tools always
    /// use the concurrent-mutation pool. Otherwise `ReadOnly` tools use the read
    /// pool, promoted `Mutating` tools the concurrent-mutation pool, and all
    /// other `Mutating` tools the serial pool.
    fn routing_class(&self, call: &ToolCall) -> RoutingClass {
        if !self.policy.is_allowed(&call.name) {
            return RoutingClass::ShortCircuit;
        }
        let Some(tool) = self.registry.get(&call.name) else {
            return RoutingClass::ShortCircuit;
        };
        if tool.is_recursive() {
            return RoutingClass::ConcurrentMut;
        }
        match tool.class() {
            ToolClass::ReadOnly => RoutingClass::ReadOnly,
            ToolClass::Mutating if self.concurrency.promoted.contains(&call.name) => {
                RoutingClass::ConcurrentMut
            }
            ToolClass::Mutating => RoutingClass::SerialMut,
        }
    }

    /// Runs one call of a batch and returns `(idx, result)`.
    ///
    /// Short-circuited calls answer immediately. Every other call is registered
    /// with the tracker (if any) under a child cancellation token, waits for
    /// its per-tool and class permits, passes the mode gate, and then runs via
    /// `execute_one`. The tracker's `mark_done` is called on every exit path
    /// after registration.
    async fn dispatch_one(
        &self,
        idx: usize,
        call: ToolCall,
        routing: RoutingClass,
        ctx: &ToolContext,
        control: DispatchControl,
    ) -> (usize, Content) {
        if matches!(routing, RoutingClass::ShortCircuit) {
            return (idx, self.short_circuit_result(&call));
        }
        let call_id = call.id.clone();
        // The child token is cancelled whenever the batch context is. It is
        // registered *before* waiting on permits, presumably so the tracker can
        // cancel this one call while it is still queued. Confirm against
        // `ToolRunTracker`.
        let child_cancel = ctx.cancel.child_token();
        if let Some(tracker) = &control.tracker {
            tracker.register(&call_id, child_cancel.clone());
        }
        let child_ctx = ctx.with_cancel(child_cancel);
        let class_sem = self.class_semaphore_for(routing);
        let per_tool_sem = self.concurrency.per_tool.get(&call.name).cloned();
        // `_permits` is a named binding (not `_`), so both permits stay held
        // until this function returns, i.e. across mode gating and execution.
        let _permits = match acquire_admission(per_tool_sem, class_sem, &child_ctx).await {
            Some(permits) => permits,
            None => {
                if let Some(tracker) = &control.tracker {
                    tracker.mark_done(&call_id);
                }
                return (idx, cancelled_before_execution(&call_id));
            }
        };
        if let Some(denial) = self.mode_denial(&call, control.mode.as_deref()) {
            if let Some(tracker) = &control.tracker {
                tracker.mark_done(&call_id);
            }
            return (idx, denial);
        }
        let content = self.execute_one(call, &child_ctx).await;
        if let Some(tracker) = &control.tracker {
            tracker.mark_done(&call_id);
        }
        (idx, content)
    }

    /// Returns an error result if `mode` refuses this call, `None` otherwise.
    ///
    /// The mode's own `tool_gate` decides first. If it has no opinion
    /// (`None`), `Mutating` tools are denied when the mode does not allow
    /// mutating tools. A panic anywhere in the mode's hooks is caught and
    /// reported as a denial instead of unwinding through the batch.
    fn mode_denial(&self, call: &ToolCall, mode: Option<&dyn AgentMode>) -> Option<Content> {
        let mode = mode?;
        let tool = self.registry.get(&call.name)?;
        let class = tool.class();
        let decision = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            mode.tool_gate(&call.name, class).unwrap_or_else(|| {
                if class == ToolClass::Mutating && !mode.allows_mutating_tools() {
                    ModeDecision::Deny {
                        reason: format!("mode '{}' denies mutating tools", mode.name()).into(),
                    }
                } else {
                    ModeDecision::Allow
                }
            })
        }));
        match decision {
            Ok(ModeDecision::Allow) => None,
            Err(_) => Some(Content::tool_result(
                &call.id,
                format!("Error: mode gate panicked for tool '{}'", call.name),
                true,
            )),
            Ok(ModeDecision::Deny { reason }) => Some(Content::tool_result(
                &call.id,
                format!("Error: mode denied tool '{}' — {reason}", call.name),
                true,
            )),
        }
    }

    /// Error result for a `ShortCircuit` call. The messages mirror those
    /// produced by `execute_one` for the same policy / lookup failures.
    fn short_circuit_result(&self, call: &ToolCall) -> Content {
        let body = if !self.policy.is_allowed(&call.name) {
            format!("Error: tool '{}' is not allowed by policy", call.name)
        } else {
            format!("Error: tool '{}' not found", call.name)
        };
        Content::tool_result(&call.id, body, true)
    }

    /// Returns the class-wide semaphore for a routing class.
    ///
    /// # Panics
    ///
    /// Panics on `ShortCircuit`, which `dispatch_one` handles before calling this.
    fn class_semaphore_for(&self, routing: RoutingClass) -> Arc<Semaphore> {
        match routing {
            RoutingClass::ReadOnly => Arc::clone(&self.concurrency.read),
            RoutingClass::ConcurrentMut => Arc::clone(&self.concurrency.concurrent_mut),
            RoutingClass::SerialMut => Arc::clone(&self.concurrency.serial_mut),
            RoutingClass::ShortCircuit => {
                unreachable!("ShortCircuit takes the early-return path before reaching here")
            }
        }
    }
}
/// Optional per-batch hooks, cloned into every dispatched call.
#[derive(Clone)]
struct DispatchControl {
    /// When present, receives `register` / `mark_done` for every call that is
    /// not short-circuited.
    tracker: Option<ToolRunTracker>,
    /// When present, its tool gate is consulted after permits are acquired.
    mode: Option<Arc<dyn AgentMode>>,
}
/// Result of racing a semaphore acquisition against cancellation.
enum AcquireOutcome {
    /// The permit was obtained; it is released when dropped.
    Permit(OwnedSemaphorePermit),
    /// The context's cancellation token fired first.
    Cancelled,
}
/// Acquires the permits one call needs: the optional per-tool permit first,
/// then the class permit.
///
/// Because the per-tool permit comes first, a call that is queued behind its
/// per-tool cap does not yet occupy a slot in its class pool.
///
/// Returns `None` if `ctx.cancel` fires while waiting on either semaphore. Any
/// permit already obtained is dropped, and therefore released, on that path.
async fn acquire_admission(
    per_tool_sem: Option<Arc<Semaphore>>,
    class_sem: Arc<Semaphore>,
    ctx: &ToolContext,
) -> Option<(Option<OwnedSemaphorePermit>, OwnedSemaphorePermit)> {
    let per_tool_permit = if let Some(sem) = per_tool_sem {
        let AcquireOutcome::Permit(permit) = acquire_or_cancel(sem, ctx).await else {
            return None;
        };
        Some(permit)
    } else {
        None
    };
    let AcquireOutcome::Permit(class_permit) = acquire_or_cancel(class_sem, ctx).await else {
        return None;
    };
    Some((per_tool_permit, class_permit))
}
/// Which concurrency pool, if any, a call is dispatched through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RoutingClass {
    /// `ToolClass::ReadOnly` tools; uses the `read` pool.
    ReadOnly,
    /// Promoted `Mutating` tools and recursive tools; uses `concurrent_mut`.
    ConcurrentMut,
    /// Non-promoted `Mutating` tools; uses the width-1 `serial_mut` pool, and
    /// each such call forms a run of its own.
    SerialMut,
    /// Policy-denied or unknown tools; answered immediately without permits.
    ShortCircuit,
}
/// Waits for a permit from `sem` unless `ctx.cancel` fires first.
///
/// The `select!` is `biased` toward cancellation, so an already-cancelled
/// context yields `Cancelled` even when a permit is immediately available.
///
/// # Panics
///
/// Panics if `sem` has been closed. The executor never closes its semaphores.
async fn acquire_or_cancel(sem: Arc<Semaphore>, ctx: &ToolContext) -> AcquireOutcome {
    let cancelled = ctx.cancel.cancelled();
    let acquire = sem.acquire_owned();
    tokio::select! {
        biased;
        () = cancelled => AcquireOutcome::Cancelled,
        acquired = acquire => {
            let permit = acquired
                .expect("semaphore not closed — executor never closes its own semaphores");
            AcquireOutcome::Permit(permit)
        }
    }
}
/// Error `tool_result` for a call that was cancelled before it started running.
fn cancelled_before_execution(call_id: &str) -> Content {
    let message = String::from("Error: cancelled before execution");
    Content::tool_result(call_id, message, true)
}
/// Produces a "cancelled before execution" result for every call, in order.
fn all_cancelled_before_execution(calls: Vec<ToolCall>) -> Vec<Content> {
    let mut results = Vec::with_capacity(calls.len());
    for pending in &calls {
        results.push(cancelled_before_execution(&pending.id));
    }
    results
}
/// Fills `slots[start..]` with "cancelled before execution" results for every
/// call that has not been dispatched yet. Entries already taken are skipped.
fn fill_cancelled_tail(
    calls: &mut [Option<ToolCall>],
    slots: &mut [Option<Content>],
    start: usize,
) {
    for (k, entry) in calls.iter_mut().enumerate().skip(start) {
        if let Some(call) = entry.take() {
            slots[k] = Some(cancelled_before_execution(&call.id));
        }
    }
}
/// Returns the exclusive end index of the run that begins at `start`.
///
/// A run is the maximal stretch of consecutive calls with the same routing
/// class. The exception is `SerialMut`, where every call is a run of its own,
/// so adjacent serial mutators execute one at a time in input order.
///
/// # Panics
///
/// Panics if `start >= routings.len()`.
fn same_class_run_end(routings: &[RoutingClass], start: usize) -> usize {
    let head = routings[start];
    if head == RoutingClass::SerialMut {
        return start + 1;
    }
    let run_len = routings[start..]
        .iter()
        .take_while(|&&class| class == head)
        .count();
    start + run_len
}
// Unit tests for registry lookup, policy/approval short-circuits, batch
// partitioning and result ordering, cancellation, and the concurrency pools.
//
// Several tests measure elapsed wall-clock time or the peak number of
// simultaneously running probe calls, so they rely on real
// `tokio::time::sleep` delays.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolError;
    use crate::tool::{ToolClass, ToolOutput};
    use async_trait::async_trait;
    use serde_json::json;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::Duration;

    /// Read-only tool that returns `input["msg"]`, or `""` when it is absent.
    struct Echo;
    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "echo"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        fn class(&self) -> ToolClass {
            ToolClass::ReadOnly
        }
        async fn execute(&self, input: Value, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text(input["msg"].as_str().unwrap_or("")))
        }
    }

    /// Executor with no tools; used only to fill `ToolContext::executor`.
    fn empty_executor() -> Arc<ToolExecutor> {
        Arc::new(ToolExecutor::new(
            Arc::new(ToolRegistry::new(vec![])),
            Arc::new(AllowAll),
        ))
    }

    /// Fresh, uncancelled context rooted at `/tmp`.
    fn ctx() -> ToolContext {
        ToolContext {
            working_dir: PathBuf::from("/tmp"),
            cancel: tokio_util::sync::CancellationToken::new(),
            depth: 0,
            max_depth: 1,
            executor: empty_executor(),
        }
    }

    /// Builds a call with the fixed id `"id"`.
    fn call(name: &str, input: Value) -> ToolCall {
        ToolCall {
            id: "id".into(),
            name: name.into(),
            input,
        }
    }

    /// Happy path: an allowed, registered tool runs and its output is returned.
    #[tokio::test]
    async fn allow_all_runs_tool() {
        let reg = Arc::new(ToolRegistry::new(vec![Arc::new(Echo)]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let res = exec
            .execute_one(call("echo", json!({"msg": "hi"})), &ctx())
            .await;
        let Content::ToolResult {
            content, is_error, ..
        } = res
        else {
            panic!("expected tool_result");
        };
        assert!(!is_error);
        assert_eq!(content, "hi");
    }

    /// An unknown tool name yields an error result, not a panic.
    #[tokio::test]
    async fn missing_tool_returns_error_result() {
        let reg = Arc::new(ToolRegistry::new(vec![]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let res = exec.execute_one(call("ghost", json!({})), &ctx()).await;
        let Content::ToolResult {
            content, is_error, ..
        } = res
        else {
            panic!("expected tool_result");
        };
        assert!(is_error);
        assert!(content.contains("not found"));
    }

    /// Policy that rejects exactly one tool name.
    struct DenyNamed(&'static str);
    impl ToolPolicy for DenyNamed {
        fn is_allowed(&self, name: &str) -> bool {
            name != self.0
        }
    }

    /// A policy rejection yields an error result before the tool runs.
    #[tokio::test]
    async fn policy_denial_returns_error_result() {
        let reg = Arc::new(ToolRegistry::new(vec![Arc::new(Echo)]));
        let exec = ToolExecutor::new(reg, Arc::new(DenyNamed("echo")));
        let res = exec
            .execute_one(call("echo", json!({"msg": "hi"})), &ctx())
            .await;
        let Content::ToolResult {
            content, is_error, ..
        } = res
        else {
            panic!("expected tool_result");
        };
        assert!(is_error);
        assert!(content.contains("not allowed"));
    }

    /// Read-only tool that sleeps `input["delay_ms"]` ms (default 0) and then
    /// returns its label.
    struct SlowRO {
        label: String,
    }
    #[async_trait]
    impl Tool for SlowRO {
        fn name(&self) -> &str {
            &self.label
        }
        fn description(&self) -> &str {
            "slow"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        fn class(&self) -> ToolClass {
            ToolClass::ReadOnly
        }
        async fn execute(&self, input: Value, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
            let delay_ms = input["delay_ms"].as_u64().unwrap_or(0);
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            Ok(ToolOutput::text(self.label.clone()))
        }
    }

    /// Tool that returns its label. It does not override `class()`, so it
    /// relies on the trait default (presumably `Mutating`, as the partition
    /// test below assumes).
    struct OrderingMut {
        label: String,
    }
    #[async_trait]
    impl Tool for OrderingMut {
        fn name(&self) -> &str {
            &self.label
        }
        fn description(&self) -> &str {
            "mut"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text(self.label.clone()))
        }
    }

    /// Returns the text of a `tool_result`; panics on any other content.
    fn extract_text(c: &Content) -> &str {
        match c {
            Content::ToolResult { content, .. } => content.as_str(),
            _ => panic!("expected tool_result"),
        }
    }

    /// Results come back in input order even though `b` finishes before `a`.
    /// The elapsed bound guards against gross slowdowns.
    #[tokio::test]
    async fn batch_preserves_order_despite_parallel_ro() {
        let reg = Arc::new(ToolRegistry::new(vec![
            Arc::new(SlowRO { label: "a".into() }),
            Arc::new(SlowRO { label: "b".into() }),
        ]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "a".into(),
                input: json!({"delay_ms": 50}),
            },
            ToolCall {
                id: "2".into(),
                name: "b".into(),
                input: json!({"delay_ms": 0}),
            },
        ];
        let start = std::time::Instant::now();
        let results = exec.execute_batch(calls, &ctx()).await;
        let elapsed = start.elapsed();
        assert_eq!(results.len(), 2);
        assert_eq!(extract_text(&results[0]), "a");
        assert_eq!(extract_text(&results[1]), "b");
        assert!(
            elapsed < Duration::from_millis(150),
            "unexpected slowdown: {elapsed:?}"
        );
    }

    /// A RO, RO, Mut, RO batch is split into runs and still reports results in
    /// input order.
    #[tokio::test]
    async fn batch_partitions_ro_and_mut_runs() {
        let reg = Arc::new(ToolRegistry::new(vec![
            Arc::new(SlowRO { label: "a".into() }),
            Arc::new(SlowRO { label: "b".into() }),
            Arc::new(OrderingMut { label: "m".into() }),
            Arc::new(SlowRO { label: "c".into() }),
        ]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "a".into(),
                input: json!({"delay_ms": 10}),
            },
            ToolCall {
                id: "2".into(),
                name: "b".into(),
                input: json!({"delay_ms": 10}),
            },
            ToolCall {
                id: "3".into(),
                name: "m".into(),
                input: json!({}),
            },
            ToolCall {
                id: "4".into(),
                name: "c".into(),
                input: json!({"delay_ms": 10}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(results.len(), 4);
        assert_eq!(extract_text(&results[0]), "a");
        assert_eq!(extract_text(&results[1]), "b");
        assert_eq!(extract_text(&results[2]), "m");
        assert_eq!(extract_text(&results[3]), "c");
    }

    /// Sets the shared flag when executed. Uses the default `class()`.
    struct FlagSetter(Arc<AtomicBool>, &'static str);
    #[async_trait]
    impl Tool for FlagSetter {
        fn name(&self) -> &str {
            self.1
        }
        fn description(&self) -> &str {
            "flag"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            self.0.store(true, Ordering::SeqCst);
            Ok(ToolOutput::text("ran"))
        }
    }

    /// With the context cancelled before the batch starts, every call gets a
    /// "cancelled before execution" result and no tool runs.
    #[tokio::test]
    async fn batch_stops_dispatching_after_cancel() {
        let m1_ran = Arc::new(AtomicBool::new(false));
        let m2_ran = Arc::new(AtomicBool::new(false));
        let reg = Arc::new(ToolRegistry::new(vec![
            Arc::new(FlagSetter(Arc::clone(&m1_ran), "m1")),
            Arc::new(FlagSetter(Arc::clone(&m2_ran), "m2")),
        ]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let cancel = tokio_util::sync::CancellationToken::new();
        let ctx = ToolContext {
            working_dir: PathBuf::from("/tmp"),
            cancel: cancel.clone(),
            depth: 0,
            max_depth: 1,
            executor: empty_executor(),
        };
        cancel.cancel();
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "m1".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "m2".into(),
                input: json!({}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx).await;
        assert_eq!(results.len(), 2, "result count must match input count");
        for r in &results {
            let Content::ToolResult {
                content, is_error, ..
            } = r
            else {
                panic!("expected tool_result");
            };
            assert!(*is_error, "cancelled-before-execution should be is_error");
            assert!(
                content.contains("cancelled before execution"),
                "got: {content}"
            );
        }
        assert!(
            !m1_ran.load(Ordering::SeqCst),
            "m1 must not have run after cancel"
        );
        assert!(
            !m2_ran.load(Ordering::SeqCst),
            "m2 must not have run after cancel"
        );
    }

    /// Approval handler that denies every call with the given reason.
    struct AlwaysDeny(&'static str);
    #[async_trait]
    impl ApprovalHandler for AlwaysDeny {
        async fn approve(&self, _: &str, _: &Value, _: ToolClass) -> ApprovalDecision {
            ApprovalDecision::Deny(self.0.to_string())
        }
    }

    /// Approval handler that allows, but only after a 10 s sleep; used to race
    /// approval against cancellation.
    struct SlowApproval;
    #[async_trait]
    impl ApprovalHandler for SlowApproval {
        async fn approve(&self, _: &str, _: &Value, _: ToolClass) -> ApprovalDecision {
            tokio::time::sleep(Duration::from_secs(10)).await;
            ApprovalDecision::Allow
        }
    }

    /// A denied approval yields an error result that carries the reason, and
    /// the tool never executes.
    #[tokio::test]
    async fn approval_deny_emits_error_tool_result_and_skips_execution() {
        let ran = Arc::new(AtomicBool::new(false));
        let ran_clone = Arc::clone(&ran);
        // Records whether `execute` was ever called.
        struct ObservingTool(Arc<AtomicBool>);
        #[async_trait]
        impl Tool for ObservingTool {
            fn name(&self) -> &str {
                "observe"
            }
            fn description(&self) -> &str {
                "observes whether it ran"
            }
            fn input_schema(&self) -> Value {
                json!({})
            }
            async fn execute(&self, _: Value, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
                self.0.store(true, Ordering::SeqCst);
                Ok(ToolOutput::text("ran"))
            }
        }
        let reg = Arc::new(ToolRegistry::new(vec![Arc::new(ObservingTool(ran_clone))]));
        let exec = ToolExecutor::with_approval(
            reg,
            Arc::new(AllowAll),
            Arc::new(AlwaysDeny("blocked by user")),
        );
        let res = exec.execute_one(call("observe", json!({})), &ctx()).await;
        let Content::ToolResult {
            content, is_error, ..
        } = res
        else {
            panic!("expected tool_result");
        };
        assert!(is_error, "denied call should yield is_error: true");
        assert!(
            content.contains("approval denied"),
            "content should mark approval denial, got: {content}"
        );
        assert!(
            content.contains("blocked by user"),
            "content should preserve the deny reason, got: {content}"
        );
        assert!(
            !ran.load(Ordering::SeqCst),
            "tool must NOT have executed after approval denial"
        );
    }

    /// Cancelling 50 ms into a 10 s approval returns promptly with a
    /// cancellation error.
    #[tokio::test]
    async fn approval_cancel_during_approve_short_circuits() {
        let reg = Arc::new(ToolRegistry::new(vec![Arc::new(Echo)]));
        let exec = ToolExecutor::with_approval(reg, Arc::new(AllowAll), Arc::new(SlowApproval));
        let cancel = tokio_util::sync::CancellationToken::new();
        let ctx = ToolContext {
            working_dir: PathBuf::from("/tmp"),
            cancel: cancel.clone(),
            depth: 0,
            max_depth: 1,
            executor: empty_executor(),
        };
        let cancel_clone = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });
        let started = std::time::Instant::now();
        let res = exec
            .execute_one(call("echo", json!({"msg": "x"})), &ctx)
            .await;
        let elapsed = started.elapsed();
        let Content::ToolResult {
            content, is_error, ..
        } = res
        else {
            panic!("expected tool_result");
        };
        assert!(is_error, "cancel during approval should yield is_error");
        assert!(
            content.contains("cancelled"),
            "content should mention cancellation, got: {content}"
        );
        assert!(
            elapsed < Duration::from_secs(1),
            "cancel should win the race against approve(); took {elapsed:?}"
        );
    }

    /// Tool that counts how many of its calls are running at once (`active`)
    /// and records the highest such count (`max_seen`), sleeping `delay_ms`
    /// while counted.
    struct ConcurrencyProbe {
        label: String,
        class: ToolClass,
        delay_ms: u64,
        active: Arc<AtomicUsize>,
        max_seen: Arc<AtomicUsize>,
    }
    #[async_trait]
    impl Tool for ConcurrencyProbe {
        fn name(&self) -> &str {
            &self.label
        }
        fn description(&self) -> &str {
            "concurrency probe"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        fn class(&self) -> ToolClass {
            self.class
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            let cur = self.active.fetch_add(1, Ordering::SeqCst) + 1;
            let mut prev = self.max_seen.load(Ordering::SeqCst);
            // CAS loop: raise `max_seen` to `cur` unless another call already
            // recorded a higher value.
            while cur > prev {
                match self
                    .max_seen
                    .compare_exchange(prev, cur, Ordering::SeqCst, Ordering::SeqCst)
                {
                    Ok(_) => break,
                    Err(actual) => prev = actual,
                }
            }
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            self.active.fetch_sub(1, Ordering::SeqCst);
            Ok(ToolOutput::text(self.label.clone()))
        }
    }

    /// Builds a `ConcurrencyProbe` that shares the given counters.
    fn make_probe(
        label: &str,
        class: ToolClass,
        delay_ms: u64,
        active: &Arc<AtomicUsize>,
        max_seen: &Arc<AtomicUsize>,
    ) -> Arc<ConcurrencyProbe> {
        Arc::new(ConcurrencyProbe {
            label: label.into(),
            class,
            delay_ms,
            active: Arc::clone(active),
            max_seen: Arc::clone(max_seen),
        })
    }

    /// Non-promoted mutators never overlap.
    #[tokio::test]
    async fn default_mutator_serializes_via_width_one_semaphore() {
        let active = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let probe = make_probe("mut", ToolClass::Mutating, 30, &active, &max_seen);
        let reg = Arc::new(ToolRegistry::new(vec![probe]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls: Vec<ToolCall> = (0..3)
            .map(|i| ToolCall {
                id: format!("{i}"),
                name: "mut".into(),
                input: json!({}),
            })
            .collect();
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(results.len(), 3);
        assert_eq!(
            max_seen.load(Ordering::SeqCst),
            1,
            "default-Mutating routes to the width-1 serial pool"
        );
    }

    /// A promoted mutator reaches, but does not exceed, the mutation class cap.
    #[tokio::test]
    async fn promoted_mutator_runs_in_parallel_up_to_class_cap() {
        let active = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let probe = make_probe("pmut", ToolClass::Mutating, 60, &active, &max_seen);
        let reg = Arc::new(ToolRegistry::new(vec![probe]));
        let cfg = ConcurrencyConfig::new(20, 3, vec![("pmut".to_string(), ToolConcurrency::on())]);
        let exec = ToolExecutor::with_approval_and_concurrency(
            reg,
            Arc::new(AllowAll),
            Arc::new(AutoApprove),
            cfg,
        );
        let calls: Vec<ToolCall> = (0..5)
            .map(|i| ToolCall {
                id: format!("{i}"),
                name: "pmut".into(),
                input: json!({}),
            })
            .collect();
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(results.len(), 5);
        let observed = max_seen.load(Ordering::SeqCst);
        assert_eq!(
            observed, 3,
            "promoted Mutating fills concurrent_mut up to class cap (got {observed})"
        );
    }

    /// A per-tool cap below the class cap is the binding limit.
    #[tokio::test]
    async fn per_tool_cap_binds_below_class_cap() {
        let active = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let probe = make_probe("ptcap", ToolClass::Mutating, 60, &active, &max_seen);
        let reg = Arc::new(ToolRegistry::new(vec![probe]));
        let cfg = ConcurrencyConfig::new(
            20,
            10,
            vec![("ptcap".to_string(), ToolConcurrency::on().max(2))],
        );
        let exec = ToolExecutor::with_approval_and_concurrency(
            reg,
            Arc::new(AllowAll),
            Arc::new(AutoApprove),
            cfg,
        );
        let calls: Vec<ToolCall> = (0..5)
            .map(|i| ToolCall {
                id: format!("{i}"),
                name: "ptcap".into(),
                input: json!({}),
            })
            .collect();
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(results.len(), 5);
        let observed = max_seen.load(Ordering::SeqCst);
        assert_eq!(
            observed, 2,
            "per-tool cap=2 binds below class cap=10 (got {observed})"
        );
    }

    /// The first call holds the only per-tool permit for 500 ms. The second is
    /// still queued when the context is cancelled at 50 ms, so it must come
    /// back as cancelled without ever executing.
    #[tokio::test]
    async fn cancel_during_permit_acquire_short_circuits() {
        let active = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let probe = make_probe("slow", ToolClass::Mutating, 500, &active, &max_seen);
        let reg = Arc::new(ToolRegistry::new(vec![probe]));
        let cfg = ConcurrencyConfig::new(
            20,
            10,
            vec![("slow".to_string(), ToolConcurrency::on().max(1))],
        );
        let exec = ToolExecutor::with_approval_and_concurrency(
            reg,
            Arc::new(AllowAll),
            Arc::new(AutoApprove),
            cfg,
        );
        let cancel = tokio_util::sync::CancellationToken::new();
        let ctx_local = ToolContext {
            working_dir: PathBuf::from("/tmp"),
            cancel: cancel.clone(),
            depth: 0,
            max_depth: 1,
            executor: empty_executor(),
        };
        let cancel_clone = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "slow".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "slow".into(),
                input: json!({}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx_local).await;
        assert_eq!(results.len(), 2);
        let r2 = &results[1];
        let Content::ToolResult {
            content, is_error, ..
        } = r2
        else {
            panic!("expected tool_result");
        };
        assert!(*is_error, "second call must be is_error after cancel");
        assert!(
            content.contains("cancelled before execution"),
            "second call must short-circuit, got: {content}"
        );
        assert_eq!(
            max_seen.load(Ordering::SeqCst),
            1,
            "only the first probe should have entered execute()"
        );
    }

    /// Results stay in input order when a later RO call finishes first.
    #[tokio::test]
    async fn result_order_preserved_with_parallel_completion() {
        let active = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let slow = make_probe("slow", ToolClass::ReadOnly, 100, &active, &max_seen);
        let fast = make_probe("fast", ToolClass::ReadOnly, 0, &active, &max_seen);
        let reg = Arc::new(ToolRegistry::new(vec![slow, fast]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "slow".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "fast".into(),
                input: json!({}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(extract_text(&results[0]), "slow");
        assert_eq!(extract_text(&results[1]), "fast");
    }

    /// Each routing class is bounded by its own pool: 5 RO (cap 5),
    /// 3 promoted mutators (cap 2), 2 serial mutators (width 1).
    #[tokio::test]
    async fn mixed_class_batch_independent_pools() {
        let ro_active = Arc::new(AtomicUsize::new(0));
        let ro_max = Arc::new(AtomicUsize::new(0));
        let pmut_active = Arc::new(AtomicUsize::new(0));
        let pmut_max = Arc::new(AtomicUsize::new(0));
        let smut_active = Arc::new(AtomicUsize::new(0));
        let smut_max = Arc::new(AtomicUsize::new(0));
        let ro = make_probe("ro", ToolClass::ReadOnly, 80, &ro_active, &ro_max);
        let pmut = make_probe("pmut", ToolClass::Mutating, 80, &pmut_active, &pmut_max);
        let smut = make_probe("smut", ToolClass::Mutating, 80, &smut_active, &smut_max);
        let reg = Arc::new(ToolRegistry::new(vec![ro, pmut, smut]));
        let cfg = ConcurrencyConfig::new(5, 2, vec![("pmut".to_string(), ToolConcurrency::on())]);
        let exec = ToolExecutor::with_approval_and_concurrency(
            reg,
            Arc::new(AllowAll),
            Arc::new(AutoApprove),
            cfg,
        );
        let mut calls: Vec<ToolCall> = Vec::new();
        for i in 0..5 {
            calls.push(ToolCall {
                id: format!("ro{i}"),
                name: "ro".into(),
                input: json!({}),
            });
        }
        for i in 0..3 {
            calls.push(ToolCall {
                id: format!("pmut{i}"),
                name: "pmut".into(),
                input: json!({}),
            });
        }
        for i in 0..2 {
            calls.push(ToolCall {
                id: format!("smut{i}"),
                name: "smut".into(),
                input: json!({}),
            });
        }
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(results.len(), 10);
        assert_eq!(ro_max.load(Ordering::SeqCst), 5, "RO pool cap not hit");
        assert_eq!(
            pmut_max.load(Ordering::SeqCst),
            2,
            "promoted-mutator pool capped at 2"
        );
        assert_eq!(
            smut_max.load(Ordering::SeqCst),
            1,
            "default-mutator pool capped at 1"
        );
    }

    /// Sleeps `delay_ms`, then sets the flag. Uses the default `class()`.
    struct DelayedFlagSetter {
        flag: Arc<AtomicBool>,
        delay_ms: u64,
    }
    #[async_trait]
    impl Tool for DelayedFlagSetter {
        fn name(&self) -> &str {
            "set_flag"
        }
        fn description(&self) -> &str {
            "set"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            self.flag.store(true, Ordering::SeqCst);
            Ok(ToolOutput::text("set"))
        }
    }

    /// Read-only tool reporting the flag as `"true"` or `"false"`.
    struct FlagReader(Arc<AtomicBool>);
    #[async_trait]
    impl Tool for FlagReader {
        fn name(&self) -> &str {
            "read_flag"
        }
        fn description(&self) -> &str {
            "read"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        fn class(&self) -> ToolClass {
            ToolClass::ReadOnly
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            let v = self.0.load(Ordering::SeqCst);
            Ok(ToolOutput::text(if v { "true" } else { "false" }))
        }
    }

    /// A RO call after a serial mutator starts only once the mutation is done.
    #[tokio::test]
    async fn default_mutator_acts_as_barrier_against_subsequent_ro() {
        let flag = Arc::new(AtomicBool::new(false));
        let setter = Arc::new(DelayedFlagSetter {
            flag: Arc::clone(&flag),
            delay_ms: 100,
        });
        let reader = Arc::new(FlagReader(Arc::clone(&flag)));
        let reg = Arc::new(ToolRegistry::new(vec![setter, reader]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "set_flag".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "read_flag".into(),
                input: json!({}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(extract_text(&results[0]), "set");
        assert_eq!(
            extract_text(&results[1]),
            "true",
            "ReadOnly after default-Mutating must observe the mutation \
             — partition between class boundaries is the load-bearing invariant"
        );
    }

    /// Sleeps 10 ms, then appends its label to the shared log. Uses the
    /// default `class()`.
    struct OrderRecorder {
        label: &'static str,
        log: Arc<std::sync::Mutex<Vec<&'static str>>>,
    }
    #[async_trait]
    impl Tool for OrderRecorder {
        fn name(&self) -> &str {
            self.label
        }
        fn description(&self) -> &str {
            "order recorder"
        }
        fn input_schema(&self) -> Value {
            json!({})
        }
        async fn execute(
            &self,
            _input: Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(Duration::from_millis(10)).await;
            self.log.lock().unwrap().push(self.label);
            Ok(ToolOutput::text(self.label))
        }
    }

    /// Adjacent serial mutators apply their side effects in input order.
    #[tokio::test]
    async fn adjacent_serial_mutators_execute_in_input_order() {
        let log = Arc::new(std::sync::Mutex::new(Vec::<&'static str>::new()));
        let reg = Arc::new(ToolRegistry::new(vec![
            Arc::new(OrderRecorder {
                label: "first",
                log: Arc::clone(&log),
            }),
            Arc::new(OrderRecorder {
                label: "second",
                log: Arc::clone(&log),
            }),
            Arc::new(OrderRecorder {
                label: "third",
                log: Arc::clone(&log),
            }),
        ]));
        let exec = ToolExecutor::new(reg, Arc::new(AllowAll));
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "first".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "second".into(),
                input: json!({}),
            },
            ToolCall {
                id: "3".into(),
                name: "third".into(),
                input: json!({}),
            },
        ];
        let _results = exec.execute_batch(calls, &ctx()).await;
        let observed = log.lock().unwrap().clone();
        assert_eq!(
            observed,
            vec!["first", "second", "third"],
            "adjacent SerialMut calls must execute side effects in input order"
        );
    }

    /// A RO call after a promoted mutator also waits for the mutation, because
    /// runs of different routing classes never overlap.
    #[tokio::test]
    async fn promoted_mutator_acts_as_barrier_against_subsequent_ro() {
        let flag = Arc::new(AtomicBool::new(false));
        let setter = Arc::new(DelayedFlagSetter {
            flag: Arc::clone(&flag),
            delay_ms: 100,
        });
        let reader = Arc::new(FlagReader(Arc::clone(&flag)));
        let reg = Arc::new(ToolRegistry::new(vec![setter, reader]));
        let cfg = ConcurrencyConfig::new(20, 10, vec![("set_flag".into(), ToolConcurrency::on())]);
        let exec = ToolExecutor::with_approval_and_concurrency(
            reg,
            Arc::new(AllowAll),
            Arc::new(AutoApprove),
            cfg,
        );
        let calls = vec![
            ToolCall {
                id: "1".into(),
                name: "set_flag".into(),
                input: json!({}),
            },
            ToolCall {
                id: "2".into(),
                name: "read_flag".into(),
                input: json!({}),
            },
        ];
        let results = exec.execute_batch(calls, &ctx()).await;
        assert_eq!(
            extract_text(&results[1]),
            "true",
            "ReadOnly after promoted-Mutating must observe the mutation"
        );
    }

    /// A zero read cap is rejected at construction time.
    #[test]
    #[should_panic(expected = "max_concurrent_reads requires n > 0")]
    fn concurrency_config_panics_on_zero_read_cap() {
        let _ = ConcurrencyConfig::new(0, 10, std::iter::empty());
    }

    /// A zero mutation cap is rejected at construction time.
    #[test]
    #[should_panic(expected = "max_concurrent_mutations requires n > 0")]
    fn concurrency_config_panics_on_zero_mutation_cap() {
        let _ = ConcurrencyConfig::new(20, 0, std::iter::empty());
    }

    /// A zero per-tool cap is rejected by the builder.
    #[test]
    #[should_panic(expected = "ToolConcurrency::max requires n > 0")]
    fn tool_concurrency_max_panics_on_zero() {
        let _ = ToolConcurrency::on().max(0);
    }
}