use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::Value;
#[derive(Debug)]
pub struct ToolCallSummary<'a> {
pub id: &'a str,
pub name: &'a str,
pub arguments: &'a Value,
}
/// Callback that assigns an `i32` priority to a tool call; stored as
/// `Arc<PriorityFn>` inside [`ToolExecutionPolicy::Priority`].
///
/// The `Send + Sync` bounds let the shared callback be used from multiple threads.
/// NOTE(review): whether higher or lower values run first is decided by the
/// scheduler, which is not in this file — confirm before relying on it.
pub type PriorityFn = dyn Fn(&ToolCallSummary<'_>) -> i32 + Send + Sync;
/// Boxed, pinned, `Send` future returned by [`ToolExecutionStrategy::partition`].
///
/// It resolves to a list of index groups (`Vec<Vec<usize>>`). NOTE(review):
/// presumably each inner `Vec` holds indices into the `tool_calls` slice given
/// to `partition`, and the groups are executed in order — confirm against the
/// executor. The future may borrow data for at most `'a`.
pub type ToolExecutionStrategyFuture<'a> =
    Pin<Box<dyn Future<Output = Vec<Vec<usize>>> + Send + 'a>>;
/// User-supplied strategy that decides how a batch of tool calls is grouped;
/// plugged in through [`ToolExecutionPolicy::Custom`].
///
/// The `Send + Sync` supertraits allow an `Arc<dyn ToolExecutionStrategy>` to
/// be shared across threads.
pub trait ToolExecutionStrategy: Send + Sync {
    /// Computes groups of indices for `tool_calls`.
    ///
    /// Under lifetime elision the returned future borrows only `self`, never
    /// `tool_calls`. Implementations must therefore copy whatever they need
    /// from the slice (indices, names, ...) before building the future.
    fn partition(&self, tool_calls: &[ToolCallSummary<'_>]) -> ToolExecutionStrategyFuture<'_>;
}
/// Policy selecting how a batch of tool calls is scheduled.
///
/// Cloning is cheap: the `Priority` and `Custom` payloads are `Arc`s, so a
/// clone shares the same callback/strategy (a reference-count bump) instead of
/// copying it. `Clone` is derived because the previous hand-written impl did
/// exactly what the derive generates.
#[derive(Clone, Default)]
pub enum ToolExecutionPolicy {
    /// The default policy.
    /// NOTE(review): presumably runs all calls at once; the scheduler that
    /// interprets this variant lives outside this file.
    #[default]
    Concurrent,
    /// NOTE(review): presumably runs calls one after another.
    Sequential,
    /// Ranks each call by the `i32` its callback returns.
    /// NOTE(review): the direction of the ordering is up to the scheduler.
    Priority(Arc<PriorityFn>),
    /// Delegates grouping to a user-supplied [`ToolExecutionStrategy`].
    Custom(Arc<dyn ToolExecutionStrategy>),
}
/// Renders the variant name only. Closures and trait objects have no `Debug`
/// impl, so the `Priority`/`Custom` payloads are shown as `...`.
impl std::fmt::Debug for ToolExecutionPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Concurrent => "Concurrent",
            Self::Sequential => "Sequential",
            Self::Priority(_) => "Priority(...)",
            Self::Custom(_) => "Custom(...)",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy used to exercise the `Custom` variant: a single group
    /// containing every call index.
    struct SingleGroup;

    impl ToolExecutionStrategy for SingleGroup {
        fn partition(
            &self,
            tool_calls: &[ToolCallSummary<'_>],
        ) -> ToolExecutionStrategyFuture<'_> {
            // The future may only borrow `self`, so copy the indices out first.
            let group: Vec<usize> = (0..tool_calls.len()).collect();
            Box::pin(async move { vec![group] })
        }
    }

    #[test]
    fn default_is_concurrent() {
        assert!(matches!(
            ToolExecutionPolicy::default(),
            ToolExecutionPolicy::Concurrent
        ));
    }

    #[test]
    fn debug_formatting() {
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Concurrent),
            "Concurrent"
        );
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Sequential),
            "Sequential"
        );
        let pf: Arc<PriorityFn> = Arc::new(|_| 0);
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Priority(pf)),
            "Priority(...)"
        );
        let strategy: Arc<dyn ToolExecutionStrategy> = Arc::new(SingleGroup);
        assert_eq!(
            format!("{:?}", ToolExecutionPolicy::Custom(strategy)),
            "Custom(...)"
        );
    }

    #[test]
    fn clone_shares_priority_callback() {
        let pf: Arc<PriorityFn> = Arc::new(|_| 7);
        let policy = ToolExecutionPolicy::Priority(Arc::clone(&pf));
        let cloned = policy.clone();
        // `pf`, `policy` and `cloned` all point at the same closure.
        assert_eq!(Arc::strong_count(&pf), 3);

        let args = serde_json::json!({});
        let summary = ToolCallSummary {
            id: "call_1",
            name: "bash",
            arguments: &args,
        };
        match cloned {
            ToolExecutionPolicy::Priority(f) => assert_eq!(f(&summary), 7),
            other => panic!("clone changed the variant: {other:?}"),
        }
    }

    #[test]
    fn tool_call_summary_debug() {
        let args = serde_json::json!({"cmd": "ls"});
        let summary = ToolCallSummary {
            id: "call_1",
            name: "bash",
            arguments: &args,
        };
        let debug = format!("{summary:?}");
        assert!(debug.contains("bash"));
        assert!(debug.contains("call_1"));
    }
}