use std::sync::Arc;
use async_trait::async_trait;
use crate::traits::guard::{Guard, GuardResult};
use crate::traits::tool::ErasedTool;
use crate::traits::tracker::Tracker;
use crate::types::action::Action;
use crate::types::agent_state::AgentState;
use crate::types::tool_call::ToolCall;
/// A tool call requested by the model that has not been executed yet.
///
/// This is an owned copy of a [`ToolCall`]'s `id`, `name` and `arguments`, so it
/// can be moved into a spawned task independently of the value it came from.
#[derive(Debug, Clone)]
pub struct PendingToolCall {
    /// Identifier copied into the matching [`ToolResult`].
    pub id: String,
    /// Name of the tool to invoke; compared against each tool's `name()`.
    pub name: String,
    /// JSON arguments handed to the tool.
    pub arguments: serde_json::Value,
}

impl From<&ToolCall> for PendingToolCall {
    /// Copies the id, tool name and arguments out of a borrowed [`ToolCall`].
    fn from(source: &ToolCall) -> Self {
        let id = source.id.to_owned();
        let name = source.name.to_owned();
        let arguments = source.arguments.to_owned();
        PendingToolCall {
            id,
            name,
            arguments,
        }
    }
}
/// The outcome of running one pending tool call.
///
/// Failures are reported in-band: `output` holds either the tool's serialized
/// output or an error message, rather than this being a `Result`.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// The `id` of the call this result answers.
    pub id: String,
    /// Serialized tool output, or a human-readable error string.
    pub output: String,
}
/// A policy for executing a batch of pending tool calls.
///
/// The implementations in this module return one [`ToolResult`] per input call,
/// in the same order as `calls`. They report failures through
/// `ToolResult::output` instead of returning an error.
#[async_trait]
pub trait ExecutionStrategy: Send + Sync {
    /// Executes `calls` against `tools`, consulting `guards` before each call.
    ///
    /// `state` is available to strategies that adapt to the agent's state.
    async fn execute_batch(
        &self,
        calls: Vec<PendingToolCall>,
        tools: &[Arc<dyn ErasedTool>],
        guards: &[Arc<dyn Guard>],
        state: &AgentState,
    ) -> Vec<ToolResult>;
}
pub struct SequentialStrategy;
#[async_trait]
impl ExecutionStrategy for SequentialStrategy {
async fn execute_batch(
&self,
calls: Vec<PendingToolCall>,
tools: &[Arc<dyn ErasedTool>],
guards: &[Arc<dyn Guard>],
_state: &AgentState,
) -> Vec<ToolResult> {
let mut results = Vec::with_capacity(calls.len());
for call in calls {
let output = execute_single(&call, tools, guards).await;
results.push(ToolResult {
id: call.id,
output,
});
}
results
}
}
/// Executes tool calls concurrently on Tokio tasks, limited by a semaphore.
pub struct ParallelStrategy {
    /// Upper bound on how many calls may run at the same time.
    pub max_concurrency: usize,
}

impl ParallelStrategy {
    /// Creates a strategy that allows up to `max_concurrency` simultaneous calls.
    ///
    /// A value of `0` is raised to `1` so that at least one call can run.
    #[must_use]
    pub fn new(max_concurrency: usize) -> Self {
        let limit = std::cmp::max(max_concurrency, 1);
        Self {
            max_concurrency: limit,
        }
    }
}
#[async_trait]
impl ExecutionStrategy for ParallelStrategy {
    /// Runs every call on its own Tokio task and returns the results in the same
    /// order as `calls`.
    ///
    /// At most `max_concurrency` calls run at once, and never more permits than
    /// there are calls. A task that does not complete normally (for example,
    /// because it panicked) yields an `"Error: task panicked: ..."` result for its
    /// call instead of aborting the whole batch. `state` is unused.
    async fn execute_batch(
        &self,
        calls: Vec<PendingToolCall>,
        tools: &[Arc<dyn ErasedTool>],
        guards: &[Arc<dyn Guard>],
        _state: &AgentState,
    ) -> Vec<ToolResult> {
        use tokio::sync::Semaphore;

        // `max_concurrency` is a public field, so callers can bypass the clamp in
        // `ParallelStrategy::new`. With zero permits every task would wait forever,
        // and `Semaphore::new` panics for very large counts. Permits beyond the
        // number of calls are never used, so clamp to `[1, calls.len()]`.
        let permits = self.max_concurrency.min(calls.len()).max(1);
        let semaphore = Arc::new(Semaphore::new(permits));

        // Spawned tasks must be `'static`, so the borrowed slices are copied into
        // shared vectors. Cloning an `Arc` only bumps a reference count.
        let tools = Arc::new(tools.to_vec());
        let guards = Arc::new(guards.to_vec());

        let mut pending = Vec::with_capacity(calls.len());
        for call in calls {
            let sem = Arc::clone(&semaphore);
            let tools = Arc::clone(&tools);
            let guards = Arc::clone(&guards);
            // Keep the id outside the task so a result can still be produced if
            // the task fails.
            let id = call.id.clone();
            let handle = tokio::spawn(async move {
                let _permit = sem
                    .acquire()
                    .await
                    .expect("semaphore is never closed");
                let output = execute_single(&call, &tools, &guards).await;
                ToolResult {
                    id: call.id,
                    output,
                }
            });
            pending.push((id, handle));
        }

        // Awaiting the handles in spawn order keeps results aligned with `calls`.
        let mut results = Vec::with_capacity(pending.len());
        for (id, handle) in pending {
            let result = match handle.await {
                Ok(result) => result,
                Err(e) => ToolResult {
                    id,
                    output: format!("Error: task panicked: {e}"),
                },
            };
            results.push(result);
        }
        results
    }
}
/// Chooses between sequential and parallel execution for each batch, based on
/// the concurrency recommended by a [`Tracker`].
pub struct AdaptiveStrategy {
    tracker: Arc<dyn Tracker>,
}

impl AdaptiveStrategy {
    /// Creates an adaptive strategy driven by `tracker`.
    #[must_use]
    pub fn new(tracker: Arc<dyn Tracker>) -> Self {
        AdaptiveStrategy { tracker }
    }
}
#[async_trait]
impl ExecutionStrategy for AdaptiveStrategy {
    /// Asks the tracker for a concurrency level given `state`.
    ///
    /// A recommendation of 0 or 1 runs the batch with [`SequentialStrategy`].
    /// Anything higher runs it with a [`ParallelStrategy`] using that limit.
    async fn execute_batch(
        &self,
        calls: Vec<PendingToolCall>,
        tools: &[Arc<dyn ErasedTool>],
        guards: &[Arc<dyn Guard>],
        state: &AgentState,
    ) -> Vec<ToolResult> {
        let recommended = self.tracker.recommended_concurrency(state);
        match recommended {
            0 | 1 => {
                SequentialStrategy
                    .execute_batch(calls, tools, guards, state)
                    .await
            }
            limit => {
                ParallelStrategy::new(limit)
                    .execute_batch(calls, tools, guards, state)
                    .await
            }
        }
    }
}
async fn execute_single(
call: &PendingToolCall,
tools: &[Arc<dyn ErasedTool>],
guards: &[Arc<dyn Guard>],
) -> String {
let action = Action::ToolCall {
name: call.name.clone(),
arguments: call.arguments.clone(),
};
for guard in guards {
let guard_name = guard.name().to_string();
let action_ref = &action;
let guard_span = tracing::info_span!(
target: "traitclaw::guard",
"guard.check",
guard.name = guard_name.as_str(),
guard.result = tracing::field::Empty,
);
let _g = guard_span.enter();
let result =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| guard.check(action_ref)));
match result {
Ok(GuardResult::Allow) => {
guard_span.record("guard.result", "allow");
}
Ok(GuardResult::Deny { reason, .. }) => {
guard_span.record("guard.result", "deny");
return format!("Error: Action blocked by guard: {reason}");
}
Ok(GuardResult::Sanitize { warning, .. }) => {
guard_span.record("guard.result", "sanitize");
tracing::info!(
target: "traitclaw::guard",
guard = guard_name.as_str(),
"Guard sanitized: {warning}"
);
}
Err(_) => {
guard_span.record("guard.result", "panic");
tracing::warn!(
target: "traitclaw::guard",
guard = guard_name.as_str(),
"Guard panicked — denying action for safety"
);
return format!("Error: Action blocked — guard '{guard_name}' panicked");
}
}
}
let tool_span = tracing::info_span!(
target: "traitclaw::tool",
"tool.call",
tool.name = call.name.as_str(),
tool.success = tracing::field::Empty,
);
let _t = tool_span.enter();
if let Some(tool) = tools.iter().find(|t| t.name() == call.name) {
match tool.execute_json(call.arguments.clone()).await {
Ok(output) => {
tool_span.record("tool.success", true);
serde_json::to_string(&output)
.unwrap_or_else(|e| format!("Error serializing output: {e}"))
}
Err(e) => {
tool_span.record("tool.success", false);
format!("Error executing tool: {e}")
}
}
} else {
tool_span.record("tool.success", false);
let available: Vec<_> = tools.iter().map(|t| t.name().to_string()).collect();
format!(
"Error: Tool '{}' not found. Available: {}",
call.name,
available.join(", ")
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::guard::NoopGuard;

    /// Minimal tool that ignores its arguments and always returns `"result"`.
    struct AddTool;

    #[async_trait]
    impl ErasedTool for AddTool {
        fn name(&self) -> &'static str {
            "add"
        }
        fn description(&self) -> &'static str {
            "Adds two numbers"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: "add".into(),
                description: "add".into(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _args: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, crate::Error> {
            Ok(serde_json::json!("result"))
        }
    }

    /// Builds `n` calls to the `add` tool with ids `call-0` .. `call-{n-1}`.
    fn make_calls(n: usize) -> Vec<PendingToolCall> {
        (0..n)
            .map(|i| PendingToolCall {
                id: format!("call-{i}"),
                name: "add".into(),
                arguments: serde_json::json!({}),
            })
            .collect()
    }

    /// Agent state shared by every test; the strategies under test ignore it.
    fn test_state() -> AgentState {
        AgentState::new(crate::types::model_info::ModelTier::Small, 4096)
    }

    #[tokio::test]
    async fn test_sequential_executes_in_order() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];
        let state = test_state();
        let results = strategy
            .execute_batch(make_calls(3), &tools, &guards, &state)
            .await;
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "call-0");
        assert_eq!(results[1].id, "call-1");
        assert_eq!(results[2].id, "call-2");
        for r in &results {
            // The tool's JSON string output is serialized, quotes included.
            assert_eq!(r.output, "\"result\"", "unexpected: {}", r.output);
        }
    }

    #[tokio::test]
    async fn test_parallel_executes_concurrently() {
        let strategy = ParallelStrategy::new(4);
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];
        let state = test_state();
        let results = strategy
            .execute_batch(make_calls(5), &tools, &guards, &state)
            .await;
        assert_eq!(results.len(), 5);
        for (i, r) in results.iter().enumerate() {
            // Results must come back in the same order as the calls.
            assert_eq!(r.id, format!("call-{i}"));
            assert!(!r.output.starts_with("Error"), "unexpected: {}", r.output);
        }
    }

    #[test]
    fn test_parallel_new_clamps_zero_concurrency() {
        assert_eq!(ParallelStrategy::new(0).max_concurrency, 1);
        assert_eq!(ParallelStrategy::new(3).max_concurrency, 3);
    }

    #[tokio::test]
    async fn test_guard_blocks_propagate() {
        use crate::traits::guard::{Guard, GuardResult};
        struct DenyGuard;
        impl Guard for DenyGuard {
            fn name(&self) -> &'static str {
                "deny"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                GuardResult::Deny {
                    reason: "blocked".into(),
                    severity: crate::traits::guard::GuardSeverity::High,
                }
            }
        }
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(DenyGuard)];
        let state = test_state();
        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;
        assert_eq!(results.len(), 1);
        assert!(results[0].output.contains("blocked"));
    }

    #[tokio::test]
    async fn test_guard_panic_defaults_to_deny() {
        use crate::traits::guard::{Guard, GuardResult};
        struct PanicGuard;
        impl Guard for PanicGuard {
            fn name(&self) -> &'static str {
                "panic_guard"
            }
            fn check(&self, _action: &Action) -> GuardResult {
                panic!("intentional panic in guard");
            }
        }
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(PanicGuard)];
        let state = test_state();
        let results = strategy
            .execute_batch(make_calls(1), &tools, &guards, &state)
            .await;
        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("panicked"),
            "Expected deny on panic, got: {}",
            results[0].output
        );
    }

    #[tokio::test]
    async fn test_tool_not_found_returns_error() {
        let strategy = SequentialStrategy;
        let tools: Vec<Arc<dyn ErasedTool>> = vec![];
        let guards: Vec<Arc<dyn Guard>> = vec![Arc::new(NoopGuard)];
        let state = test_state();
        let calls = vec![PendingToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            arguments: serde_json::json!({}),
        }];
        let results = strategy.execute_batch(calls, &tools, &guards, &state).await;
        assert_eq!(results.len(), 1);
        assert!(
            results[0].output.contains("not found"),
            "Expected 'not found', got: {}",
            results[0].output
        );
    }
}