use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use lellm_core::{Message, ToolCall, ToolDefinition, ToolError, ToolErrorKind, ToolResult};
use super::super::event::AgentEvent;
use super::super::retry::RetryPolicy;
use super::ToolFn;
use tokio::sync::mpsc::Sender;
/// Scheduling constraint attached to a registered tool.
///
/// `ToolExecutor::execute_batch` uses this value to decide which calls of a
/// batch are fanned out concurrently and which are run one after another.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParallelSafety {
/// Every call is spawned as its own task and all of them are awaited
/// together.
Safe,
/// Calls whose tools share a `ToolCategory` run one at a time, in batch
/// order; each category is driven by its own task. A tool with this safety
/// but without a category is scheduled like `Exclusive`.
CategoryExclusive,
/// Calls run one at a time, in batch order, inside a single task. This is
/// also what `ToolExecutor::safety_for` reports for an unregistered name.
Exclusive,
}
/// Label that groups tools for `ParallelSafety::CategoryExclusive` scheduling.
///
/// Two categories are equal (and hash identically) when their text is equal,
/// regardless of whether the text is borrowed or owned.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Built-in category labelled `"file_io"`.
    pub const FILE_IO: Self = ToolCategory(Cow::Borrowed("file_io"));
    /// Built-in category labelled `"network"`.
    pub const NETWORK: Self = ToolCategory(Cow::Borrowed("network"));
    /// Built-in category labelled `"database"`.
    pub const DATABASE: Self = ToolCategory(Cow::Borrowed("database"));

    /// Creates a category from any text convertible into `Cow<'static, str>`,
    /// e.g. a `&'static str` (no allocation) or an owned `String`.
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
/// A tool as stored by `ToolExecutor`: its definition, its scheduling
/// constraints and the callable that implements it.
#[derive(Clone)]
pub struct ToolRegistration {
/// Definition handed out by `ToolExecutor::definitions`.
pub definition: ToolDefinition,
/// How `ToolExecutor::execute_batch` may schedule calls to this tool.
pub safety: ParallelSafety,
/// Serialisation group; only consulted when `safety` is
/// `ParallelSafety::CategoryExclusive`.
pub category: Option<ToolCategory>,
/// Callable that runs the tool; the executor passes it to the retry policy
/// together with the call's arguments.
pub func: ToolFn,
}
impl ToolRegistration {
pub fn safe<F, Fut>(def: ToolDefinition, f: F) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::Safe,
category: None,
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
pub fn category_exclusive<F, Fut>(def: ToolDefinition, category: ToolCategory, f: F) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::CategoryExclusive,
category: Some(category),
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
pub fn exclusive<F, Fut>(def: ToolDefinition, f: F) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::Exclusive,
category: None,
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
}
/// Outcome of `ToolExecutor::execute_batch`.
#[derive(Debug)]
pub struct BatchExecutionResult {
/// One tool-result message per submitted call, in the order the calls were
/// submitted.
pub results: Vec<Message>,
/// `true` when joining one of the group tasks spawned by `execute_batch`
/// failed. Failures of the per-call tasks are reported as `Internal` error
/// messages inside `results` and do not set this flag.
pub panicked: bool,
}
/// Registry of tools plus the retry policy through which they are invoked.
///
/// Clones share the same tool map through an `Arc`; as a consequence
/// `register` only works while the map is not shared.
#[derive(Clone)]
pub struct ToolExecutor {
/// Registered tools keyed by the name passed to `register`.
tools: Arc<HashMap<String, ToolRegistration>>,
/// Policy used by `execute` and `execute_with_emission` to run tool
/// functions.
retry_policy: RetryPolicy,
}
impl Default for ToolExecutor {
    /// Builds an executor with an empty tool map and `RetryPolicy::default()`.
    fn default() -> Self {
        let tools = Arc::new(HashMap::new());
        let retry_policy = RetryPolicy::default();
        Self {
            tools,
            retry_policy,
        }
    }
}
impl ToolExecutor {
pub fn new() -> Self {
Self::default()
}
pub fn with_retry_policy(policy: RetryPolicy) -> Self {
Self {
retry_policy: policy,
..Default::default()
}
}
/// Replaces the retry policy used for subsequent tool executions on this
/// executor. Clones made earlier keep the policy they were cloned with.
pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
    // Swap the new policy in and discard the previous one.
    drop(std::mem::replace(&mut self.retry_policy, policy));
}
pub fn has_tools(&self) -> bool {
!self.tools.is_empty()
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools.values().map(|t| t.definition.clone()).collect()
}
pub fn register(&mut self, name: &str, reg: ToolRegistration) {
Arc::get_mut(&mut self.tools)
.expect("ToolExecutor already cloned, cannot register more tools")
.insert(name.to_string(), reg);
}
/// Returns the parallel-safety level of the tool registered as `name`.
///
/// Unregistered names are reported as `ParallelSafety::Exclusive`, the most
/// restrictive level.
pub fn safety_for(&self, name: &str) -> ParallelSafety {
    match self.tools.get(name) {
        Some(registration) => registration.safety.clone(),
        None => ParallelSafety::Exclusive,
    }
}
/// Returns the category of the tool registered as `name`, or `None` when the
/// tool is unknown or has no category.
fn category_for(&self, name: &str) -> Option<ToolCategory> {
    self.tools.get(name)?.category.clone()
}
/// Runs a single tool call through the retry policy.
///
/// # Errors
///
/// Returns a `ToolErrorKind::NotFound` error when no tool is registered
/// under `call.name`; otherwise returns whatever the retry policy yields.
pub async fn execute(&self, call: &ToolCall) -> ToolResult {
    if let Some(entry) = self.tools.get(&call.name) {
        self.retry_policy
            .execute_with_retry(&entry.func, &call.arguments)
            .await
    } else {
        Err(ToolError {
            kind: ToolErrorKind::NotFound,
            message: format!("unknown tool: {}", call.name),
        })
    }
}
/// Like `execute`, but also hands the event channel `tx` and the call id to
/// the retry policy.
///
/// NOTE(review): presumably the policy uses `tx` to report retry progress as
/// `AgentEvent`s; confirm against `RetryPolicy`.
///
/// # Errors
///
/// Returns a `ToolErrorKind::NotFound` error when no tool is registered
/// under `call.name`; otherwise returns whatever the retry policy yields.
pub async fn execute_with_emission(
    &self,
    call: &ToolCall,
    tx: &Sender<AgentEvent>,
) -> ToolResult {
    if let Some(entry) = self.tools.get(&call.name) {
        self.retry_policy
            .execute_with_retry_and_emission(&entry.func, &call.arguments, tx, &call.id)
            .await
    } else {
        Err(ToolError {
            kind: ToolErrorKind::NotFound,
            message: format!("unknown tool: {}", call.name),
        })
    }
}
pub async fn execute_batch(&self, calls: &[ToolCall]) -> BatchExecutionResult {
if calls.is_empty() {
return BatchExecutionResult {
results: Vec::new(),
panicked: false,
};
}
let mut safe_calls: Vec<(usize, ToolCall)> = Vec::new();
let mut category_calls: HashMap<ToolCategory, Vec<(usize, ToolCall)>> = HashMap::new();
let mut exclusive_calls: Vec<(usize, ToolCall)> = Vec::new();
for (idx, call) in calls.iter().enumerate() {
match self.safety_for(&call.name) {
ParallelSafety::Safe => safe_calls.push((idx, call.clone())),
ParallelSafety::CategoryExclusive => {
if let Some(cat) = self.category_for(&call.name) {
category_calls
.entry(cat)
.or_default()
.push((idx, call.clone()));
} else {
exclusive_calls.push((idx, call.clone()));
}
}
ParallelSafety::Exclusive => exclusive_calls.push((idx, call.clone())),
}
}
let mut group_handles: Vec<tokio::task::JoinHandle<Vec<(usize, Message)>>> = Vec::new();
let mut group_indices: Vec<Vec<usize>> = Vec::new();
let executor = Arc::new(self.clone());
if !safe_calls.is_empty() {
let exe = Arc::clone(&executor);
let indices: Vec<usize> = safe_calls.iter().map(|(i, _)| *i).collect();
group_handles.push(tokio::spawn(async move {
exe.run_parallel_indexed(safe_calls).await
}));
group_indices.push(indices);
}
for group_calls in category_calls.into_values() {
let exe = Arc::clone(&executor);
let indices: Vec<usize> = group_calls.iter().map(|(i, _)| *i).collect();
group_handles.push(tokio::spawn(async move {
exe.run_serial_indexed(group_calls).await
}));
group_indices.push(indices);
}
if !exclusive_calls.is_empty() {
let exe = Arc::clone(&executor);
let indices: Vec<usize> = exclusive_calls.iter().map(|(i, _)| *i).collect();
group_handles.push(tokio::spawn(async move {
exe.run_serial_indexed(exclusive_calls).await
}));
group_indices.push(indices);
}
let mut results: Vec<Option<Message>> = vec![None; calls.len()];
let mut panicked = false;
let all_handles = futures_util::future::join_all(group_handles).await;
for (handle_result, indices) in all_handles.into_iter().zip(group_indices.into_iter()) {
match handle_result {
Ok(indexed_messages) => {
for (idx, msg) in indexed_messages {
results[idx] = Some(msg);
}
}
Err(join_err) => {
panicked = true;
for idx in indices {
let call = &calls[idx];
results[idx] = Some(Message::tool_result(
call,
&Err(ToolError {
kind: ToolErrorKind::Internal,
message: format!("tool group task panicked: {join_err}"),
}),
));
}
}
}
}
BatchExecutionResult {
results: results.into_iter().flatten().collect(),
panicked,
}
}
async fn run_parallel_indexed(
&self,
calls: Vec<(usize, ToolCall)>,
) -> Vec<(usize, Message)> {
let handles: Vec<_> = calls
.iter()
.map(|(idx, call)| {
let exe = self.clone();
let call = call.clone();
let idx = *idx;
tokio::spawn(async move {
let result = exe.execute(&call).await;
(idx, Message::tool_result(&call, &result))
})
})
.collect();
let raw = futures_util::future::join_all(handles).await;
raw.into_iter()
.zip(calls.into_iter())
.map(|(h, (idx, call))| match h {
Ok((_, msg)) => (idx, msg),
Err(join_err) => (
idx,
Message::tool_result(
&call,
&Err(ToolError {
kind: ToolErrorKind::Internal,
message: format!("tool '{}' task panicked: {join_err}", call.name),
}),
),
),
})
.collect()
}
async fn run_serial_indexed(
&self,
calls: Vec<(usize, ToolCall)>,
) -> Vec<(usize, Message)> {
let mut results = Vec::with_capacity(calls.len());
for (idx, call) in calls {
let exe = self.clone();
let call_clone = call.clone();
let name = call_clone.name.clone();
let exec_result =
match tokio::spawn(async move { exe.execute(&call_clone).await }).await {
Ok(tool_result) => tool_result,
Err(join_err) => Err(ToolError {
kind: ToolErrorKind::Internal,
message: format!("tool '{name}' panicked: {join_err}"),
}),
};
results.push((idx, Message::tool_result(&call, &exec_result)));
}
results
}
}