use std::borrow::Cow;
use std::sync::Arc;
use lellm_core::{Message, ToolCall, ToolError, ToolErrorKind, ToolResult};
use super::super::event::AgentEvent;
use super::super::retry::RetryPolicy;
use super::{ToolCatalog, ToolFn, ToolSnapshot};
use tokio::sync::mpsc::Sender;
/// Scheduling class of a tool, consulted when a batch of calls is split into
/// execution groups.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParallelSafety {
/// Batch execution spawns a separate task for every call to such a tool.
Safe,
/// Calls are grouped by the tool's registered category and each group runs
/// its calls one after another. A tool with this safety but no category is
/// placed in the exclusive group by batch execution.
CategoryExclusive,
/// All such calls in a batch form a single group that runs one call at a
/// time. Batch execution also uses this class for names missing from the
/// snapshot.
Exclusive,
}
/// Name of a tool category, used as the grouping key for
/// category-exclusive tools.
///
/// The name is a `Cow<'static, str>`, so the built-in constants carry a
/// borrowed literal while custom categories may own their string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Built-in category named `"file_io"`.
    pub const FILE_IO: Self = Self::builtin("file_io");
    /// Built-in category named `"network"`.
    pub const NETWORK: Self = Self::builtin("network");
    /// Built-in category named `"database"`.
    pub const DATABASE: Self = Self::builtin("database");

    /// Wraps a static string literal without allocating.
    const fn builtin(label: &'static str) -> Self {
        ToolCategory(Cow::Borrowed(label))
    }

    /// Creates a category from any value convertible into
    /// `Cow<'static, str>` (for example a string literal or a `String`).
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
/// A registered tool: its definition, its scheduling metadata and the
/// callable that runs it.
#[derive(Clone)]
pub struct ToolRegistration {
/// Definition supplied when the tool was registered.
pub(crate) definition: lellm_core::ToolDefinition,
/// Scheduling class consulted by batch execution.
pub(crate) safety: ParallelSafety,
/// Grouping key for category-exclusive tools; `None` for the other
/// constructors in this file.
pub(crate) category: Option<ToolCategory>,
/// Type-erased handler invoked with the call's JSON arguments.
pub(crate) func: ToolFn,
}
impl ToolRegistration {
pub fn definition(&self) -> &lellm_core::ToolDefinition {
&self.definition
}
pub fn safety(&self) -> &ParallelSafety {
&self.safety
}
pub fn category(&self) -> Option<&ToolCategory> {
self.category.as_ref()
}
pub fn safe<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::Safe,
category: None,
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
pub fn safe_fn<T, F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
where
T: for<'de> serde::Deserialize<'de> + Send + 'static,
F: Fn(T) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
let f = Arc::new(f);
Self::safe(def, move |value| {
let f = Arc::clone(&f);
let result = serde_json::from_value::<T>(value.clone());
Box::pin(async move {
match result {
Ok(parsed) => f(parsed).await,
Err(e) => Err(ToolError::invalid_input(format!(
"invalid tool arguments: {e}"
))),
}
})
})
}
pub fn category_exclusive<F, Fut>(
def: lellm_core::ToolDefinition,
category: ToolCategory,
f: F,
) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::CategoryExclusive,
category: Some(category),
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
pub fn exclusive<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
where
F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ToolResult> + Send + 'static,
{
Self {
definition: def,
safety: ParallelSafety::Exclusive,
category: None,
func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
}
}
}
/// Outcome of running a batch of tool calls.
#[derive(Debug)]
pub struct BatchExecutionResult {
/// Tool-result messages, ordered by the position of their call in the input
/// slice.
pub results: Vec<Message>,
/// `true` when at least one group task failed to join. A failure inside an
/// individually spawned `Safe` call is turned into an error message by the
/// parallel runner and does not set this flag.
pub panicked: bool,
}
/// Executes tool calls against a snapshot, delegating the invocation to a
/// `RetryPolicy`.
#[derive(Clone)]
pub struct ToolExecutor {
/// Source of tool snapshots.
catalog: Arc<dyn ToolCatalog>,
/// Policy through which every tool invocation is routed.
retry_policy: RetryPolicy,
}
impl ToolExecutor {
    /// Creates an executor over `catalog` with `RetryPolicy::default()`.
    pub fn new(catalog: Arc<dyn ToolCatalog>) -> Self {
        Self::with_retry_policy(catalog, RetryPolicy::default())
    }

    /// Alias for [`ToolExecutor::new`].
    pub fn with_catalog(catalog: Arc<dyn ToolCatalog>) -> Self {
        Self::new(catalog)
    }

    /// Creates an executor over `catalog` that uses `policy` for every call.
    pub fn with_retry_policy(catalog: Arc<dyn ToolCatalog>, policy: RetryPolicy) -> Self {
        Self {
            catalog,
            retry_policy: policy,
        }
    }

    /// Replaces the retry policy used by subsequent calls.
    pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
        self.retry_policy = policy;
    }

    /// Returns a clone of the current retry policy.
    pub fn retry_policy(&self) -> RetryPolicy {
        self.retry_policy.clone()
    }

    /// Asks the catalog for its current snapshot.
    pub async fn snapshot(&self) -> Arc<ToolSnapshot> {
        self.catalog.snapshot().await
    }

    /// Looks `call.name` up in `snapshot` and runs the tool through the
    /// retry policy.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::not_found` when the snapshot has no tool with that
    /// name; otherwise returns whatever the retry policy yields.
    pub async fn execute_with_snapshot(
        &self,
        call: &ToolCall,
        snapshot: &ToolSnapshot,
    ) -> ToolResult {
        let registration = snapshot
            .get(&call.name)
            .ok_or_else(|| ToolError::not_found(format!("unknown tool: {}", call.name)))?;
        self.retry_policy
            .execute_with_retry(&registration.func, &call.arguments)
            .await
    }

    /// Same lookup as [`ToolExecutor::execute_with_snapshot`], but uses the
    /// retry policy's emitting variant, passing it `tx` and the call id.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::not_found` when the snapshot has no tool with that
    /// name; otherwise returns whatever the retry policy yields.
    pub async fn execute_with_emission(
        &self,
        call: &ToolCall,
        snapshot: &ToolSnapshot,
        tx: &Sender<AgentEvent>,
    ) -> ToolResult {
        let registration = snapshot
            .get(&call.name)
            .ok_or_else(|| ToolError::not_found(format!("unknown tool: {}", call.name)))?;
        self.retry_policy
            .execute_with_retry_and_emission(&registration.func, &call.arguments, tx, &call.id)
            .await
    }
}
/// Runs a batch of tool calls, grouping them by the parallel-safety class of
/// the tool each call names.
///
/// Grouping:
/// * `Safe` calls form one group in which every call is spawned as its own
///   task.
/// * `CategoryExclusive` calls form one group per category; the calls inside
///   a group run one after another.
/// * `Exclusive` calls, category-exclusive tools without a category, and
///   names missing from `snapshot` form one group that runs one call at a
///   time.
///
/// Every group is spawned as its own task before any of them is awaited, so
/// groups make progress concurrently with each other.
///
/// # Returns
///
/// `results` holds one tool-result message per call, ordered by the call's
/// position in `calls`. `panicked` is `true` when a group task failed to
/// join; each call of that group then gets an `Internal` error message.
/// NOTE(review): a failure inside an individually spawned `Safe` call is
/// turned into an error message by the parallel runner and does not set
/// `panicked` — confirm this asymmetry is intended.
pub async fn execute_batch_with(
    calls: &[ToolCall],
    snapshot: &ToolSnapshot,
    retry_policy: &RetryPolicy,
) -> BatchExecutionResult {
    if calls.is_empty() {
        return BatchExecutionResult {
            results: Vec::new(),
            panicked: false,
        };
    }

    let mut safe_calls: Vec<(usize, ToolCall)> = Vec::new();
    let mut category_calls: std::collections::HashMap<ToolCategory, Vec<(usize, ToolCall)>> =
        std::collections::HashMap::new();
    let mut exclusive_calls: Vec<(usize, ToolCall)> = Vec::new();

    // Classify each call with a single snapshot lookup.
    for (idx, call) in calls.iter().enumerate() {
        let indexed = (idx, call.clone());
        match snapshot.get(&call.name) {
            Some(entry) => match (&entry.safety, entry.category.as_ref()) {
                (ParallelSafety::Safe, _) => safe_calls.push(indexed),
                (ParallelSafety::CategoryExclusive, Some(category)) => category_calls
                    .entry(category.clone())
                    .or_default()
                    .push(indexed),
                // A category-exclusive tool without a category has no group
                // key, so it is treated as fully exclusive.
                (ParallelSafety::CategoryExclusive, None) | (ParallelSafety::Exclusive, _) => {
                    exclusive_calls.push(indexed)
                }
            },
            // Unknown tools go to the exclusive group; the runner reports
            // them as not found.
            None => exclusive_calls.push(indexed),
        }
    }

    // At most one safe group, one exclusive group and one group per category.
    let group_count = category_calls.len() + 2;
    let mut group_handles: Vec<tokio::task::JoinHandle<Vec<(usize, Message)>>> =
        Vec::with_capacity(group_count);
    // Indices handled by each group, kept so a failed group can still be
    // reported per call.
    let mut group_indices: Vec<Vec<usize>> = Vec::with_capacity(group_count);

    // `clone_for_spawn` already returns an `Arc`; share it directly instead of
    // wrapping it in a second `Arc`.
    let tools = snapshot.clone_for_spawn();

    if !safe_calls.is_empty() {
        let tools = Arc::clone(&tools);
        let policy = retry_policy.clone();
        group_indices.push(safe_calls.iter().map(|(i, _)| *i).collect());
        group_handles.push(tokio::spawn(async move {
            run_parallel_indexed_with(&tools, &policy, safe_calls).await
        }));
    }

    for group_calls in category_calls.into_values() {
        let tools = Arc::clone(&tools);
        let policy = retry_policy.clone();
        group_indices.push(group_calls.iter().map(|(i, _)| *i).collect());
        group_handles.push(tokio::spawn(async move {
            run_serial_indexed_with(&tools, &policy, group_calls).await
        }));
    }

    if !exclusive_calls.is_empty() {
        let tools = Arc::clone(&tools);
        let policy = retry_policy.clone();
        group_indices.push(exclusive_calls.iter().map(|(i, _)| *i).collect());
        group_handles.push(tokio::spawn(async move {
            run_serial_indexed_with(&tools, &policy, exclusive_calls).await
        }));
    }

    let mut results: Vec<Option<Message>> = vec![None; calls.len()];
    let mut panicked = false;

    let joined = futures_util::future::join_all(group_handles).await;
    for (outcome, indices) in joined.into_iter().zip(group_indices) {
        match outcome {
            Ok(indexed_messages) => {
                for (idx, msg) in indexed_messages {
                    results[idx] = Some(msg);
                }
            }
            Err(join_err) => {
                // The group task did not finish, so none of its results are
                // available; report the failure for every call it owned.
                panicked = true;
                for idx in indices {
                    let call = &calls[idx];
                    results[idx] = Some(Message::tool_result(
                        call,
                        &Err(ToolError {
                            kind: ToolErrorKind::Internal,
                            message: format!("tool group task panicked: {join_err}"),
                        }),
                    ));
                }
            }
        }
    }

    // Every call index belongs to exactly one group, so each slot is filled.
    BatchExecutionResult {
        results: results.into_iter().flatten().collect(),
        panicked,
    }
}
impl ToolSnapshot {
    /// Returns another handle to the snapshot's shared tool map, suitable for
    /// moving into a spawned task.
    pub fn clone_for_spawn(&self) -> Arc<indexmap::IndexMap<String, ToolRegistration>> {
        Arc::clone(&self.tools)
    }
}
/// Spawns one task per call, waits for all of them, and returns each call's
/// tool-result message paired with the index it came in with.
///
/// A name missing from `tools` yields a `not_found` error message. A task
/// that fails to join yields an `Internal` error message for that call only;
/// the other calls are unaffected.
async fn run_parallel_indexed_with(
    tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
    retry_policy: &RetryPolicy,
    calls: Vec<(usize, ToolCall)>,
) -> Vec<(usize, Message)> {
    let mut handles = Vec::with_capacity(calls.len());
    for (idx, call) in &calls {
        // Each task owns its own handles and a copy of the call; the original
        // `calls` stay here for the join-failure path below.
        let tools = Arc::clone(tools);
        let policy = retry_policy.clone();
        let call = call.clone();
        let idx = *idx;
        handles.push(tokio::spawn(async move {
            let outcome = match tools.get(&call.name) {
                Some(entry) => policy.execute_with_retry(&entry.func, &call.arguments).await,
                None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
            };
            (idx, Message::tool_result(&call, &outcome))
        }));
    }

    let joined = futures_util::future::join_all(handles).await;

    // `join_all` keeps input order, so the n-th outcome belongs to the n-th call.
    let mut messages = Vec::with_capacity(calls.len());
    for (outcome, (idx, call)) in joined.into_iter().zip(calls) {
        let message = match outcome {
            Ok((_, message)) => message,
            Err(join_err) => Message::tool_result(
                &call,
                &Err(ToolError {
                    kind: ToolErrorKind::Internal,
                    message: format!("tool '{}' task panicked: {join_err}", call.name),
                }),
            ),
        };
        messages.push((idx, message));
    }
    messages
}
/// Runs the calls one after another, in the given order, and returns each
/// call's tool-result message paired with the index it came in with.
///
/// A name missing from `tools` yields a `not_found` error message; the
/// remaining calls still run.
async fn run_serial_indexed_with(
    tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
    retry_policy: &RetryPolicy,
    calls: Vec<(usize, ToolCall)>,
) -> Vec<(usize, Message)> {
    let mut messages = Vec::with_capacity(calls.len());
    for (idx, call) in calls {
        let outcome = if let Some(entry) = tools.get(&call.name) {
            retry_policy
                .execute_with_retry(&entry.func, &call.arguments)
                .await
        } else {
            Err(ToolError::not_found(format!("unknown tool: {}", call.name)))
        };
        let message = Message::tool_result(&call, &outcome);
        messages.push((idx, message));
    }
    messages
}