use crate::core::tool_spec::ToolSpec;
use crate::error::{ToolError, ToolResolutionError};
use futures_util::future::BoxFuture;
use indexmap::IndexMap;
use serde_json::Value;
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
/// Shared handle to a tool implementation.
///
/// The wrapped function receives the call arguments as a map from string keys
/// to JSON values and returns a boxed `'static` future resolving to either a
/// JSON value or a [`ToolError`]. The `Send + Sync` bounds allow the handle to
/// be shared across threads.
pub type ToolCallable = Arc<
dyn Fn(IndexMap<String, Value>) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync,
>;
/// Conversion into a [`ToolCallable`].
///
/// Used as the bound on the `callable` parameter of `ToolDef::new`, so that
/// constructor accepts any type implementing this trait.
pub trait IntoToolCallable {
/// Consumes `self` and returns the corresponding [`ToolCallable`].
fn into_callable(self) -> ToolCallable;
}
/// Identity conversion: an existing [`ToolCallable`] is handed back unchanged.
impl IntoToolCallable for ToolCallable {
    fn into_callable(self) -> ToolCallable {
        self
    }
}
/// Adapts a synchronous, string-based function into a [`ToolCallable`].
///
/// Every entry of the incoming argument map is rendered as a `key=value`
/// string: string values are used as-is, any other value is rendered with
/// `to_string()`. The rendered strings are passed to the wrapped function in
/// map iteration order. A successful `String` result is wrapped in
/// [`Value::String`]; a [`ToolResolutionError`] is wrapped in
/// [`ToolError::Resolution`].
impl<F> IntoToolCallable for F
where
    F: Fn(Vec<String>) -> Result<String, ToolResolutionError> + Send + Sync + 'static,
{
    fn into_callable(self) -> ToolCallable {
        let shared = Arc::new(self);
        Arc::new(move |args| {
            // Each invocation gets its own handle so the future can be `'static`.
            let handler = Arc::clone(&shared);
            Box::pin(async move {
                let rendered: Vec<String> = args
                    .into_iter()
                    .map(|(key, value)| {
                        let text = match value {
                            Value::String(raw) => raw,
                            other => other.to_string(),
                        };
                        format!("{}={}", key, text)
                    })
                    .collect();
                match (*handler)(rendered) {
                    Ok(output) => Ok(Value::String(output)),
                    Err(err) => Err(ToolError::Resolution(err)),
                }
            })
        })
    }
}
pub use crate::core::tool_spec::ParamModel;
/// A prerequisite that a tool declares on another tool.
///
/// `Workflow::new` requires the tool named by either variant to be present in
/// the workflow's tool map. How prerequisites are enforced at run time is not
/// defined in this file.
///
/// `Eq` and `Hash` are derived (every field is a `String`) so specs can be
/// deduplicated or used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PrerequisiteSpec {
    /// Refers to the prerequisite tool by name only.
    NameOnly(String),
    /// Refers to the prerequisite tool by name together with an argument name.
    ArgMatched {
        /// Name of the prerequisite tool.
        tool: String,
        /// NOTE(review): presumably the argument whose value must match between
        /// the prerequisite call and the dependent call; confirm against the
        /// code that enforces prerequisites.
        match_arg: String,
    },
}
/// A tool definition: its specification, its implementation, and the
/// prerequisites it declares.
#[derive(Clone)]
pub struct ToolDef {
/// Specification of the tool; `spec.name` is what `ToolDef::name` returns.
pub spec: ToolSpec,
/// Callable that implements the tool.
pub callable: ToolCallable,
/// Prerequisites declared for this tool. Empty after `ToolDef::new` until
/// `with_prerequisites` is called.
pub prerequisites: Vec<PrerequisiteSpec>,
}
impl ToolDef {
    /// Creates a tool definition with an empty prerequisite list.
    ///
    /// `callable` may be any type implementing [`IntoToolCallable`]; it is
    /// converted into a [`ToolCallable`] right away.
    pub fn new<C>(spec: ToolSpec, callable: C) -> Self
    where
        C: IntoToolCallable,
    {
        let callable = callable.into_callable();
        let prerequisites = Vec::new();
        Self {
            spec,
            callable,
            prerequisites,
        }
    }

    /// Replaces the prerequisite list with `prereqs` and returns the updated
    /// definition (builder style).
    pub fn with_prerequisites(self, prereqs: Vec<PrerequisiteSpec>) -> Self {
        let mut updated = self;
        updated.prerequisites = prereqs;
        updated
    }

    /// Returns the tool's name as stored in its spec.
    pub fn name(&self) -> &str {
        let spec = &self.spec;
        &spec.name
    }
}
/// Manual `Debug` implementation that prints only the tool name and its
/// prerequisites; the callable and the rest of the spec are omitted.
impl fmt::Debug for ToolDef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ToolDef");
        out.field("name", &self.spec.name);
        out.field("prerequisites", &self.prerequisites);
        out.finish()
    }
}
/// A named collection of tools plus the step rules validated by
/// `Workflow::new`.
pub struct Workflow {
/// Name of the workflow.
pub name: String,
/// Description of the workflow.
pub description: String,
/// Tool definitions by key. `Workflow::new` requires every key to equal the
/// `spec.name` of its definition.
pub tools: IndexMap<String, ToolDef>,
/// Names of the required steps. `Workflow::new` requires each to be a key of
/// `tools`.
pub required_steps: Vec<String>,
/// Names of the terminal tools. `Workflow::new` requires each to be a key of
/// `tools` and not to appear in `required_steps`.
pub terminal_tools: HashSet<String>,
/// Template for the system prompt; `{key}` placeholders are filled in by
/// `build_system_prompt`.
pub system_prompt_template: String,
}
impl Workflow {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
tools: IndexMap<String, ToolDef>,
required_steps: Vec<String>,
terminal_tool: TerminalToolInput,
system_prompt_template: impl Into<String>,
) -> Result<Self, String> {
let name = name.into();
let description = description.into();
let system_prompt_template = system_prompt_template.into();
for (key, def) in &tools {
if key != &def.spec.name {
return Err(format!(
"Tool dict key '{}' does not match tool definition name '{}'",
key, def.spec.name
));
}
}
let tool_names: HashSet<&str> = tools.keys().map(|s| s.as_str()).collect();
for step in &required_steps {
if !tool_names.contains(step.as_str()) {
return Err(format!("Required step '{}' not found in tools", step));
}
}
let terminal_set: HashSet<String> = match terminal_tool {
TerminalToolInput::Single(s) => {
let mut set = HashSet::new();
set.insert(s);
set
}
TerminalToolInput::Multiple(v) => v.into_iter().collect(),
};
for t in &terminal_set {
if !tool_names.contains(t.as_str()) {
return Err(format!("Terminal tool '{}' not found in tools", t));
}
}
let required_set: HashSet<&str> = required_steps.iter().map(|s| s.as_str()).collect();
for t in &terminal_set {
if required_set.contains(t.as_str()) {
return Err(format!(
"Terminal tool '{}' cannot also be a required step",
t
));
}
}
for (_, def) in &tools {
for prereq in &def.prerequisites {
let prereq_tool = match prereq {
PrerequisiteSpec::NameOnly(name) => name.as_str(),
PrerequisiteSpec::ArgMatched { tool, .. } => tool.as_str(),
};
if !tool_names.contains(prereq_tool) {
return Err(format!(
"Prerequisite references tool '{}' which is not in the tools map",
prereq_tool
));
}
}
}
Ok(Self {
name,
description,
tools,
required_steps,
terminal_tools: terminal_set,
system_prompt_template,
})
}
pub fn build_system_prompt(&self, vars: &IndexMap<String, String>) -> String {
let mut result = self.system_prompt_template.clone();
for (key, value) in vars {
let pattern = format!("{{{}}}", key);
result = result.replace(&pattern, value);
}
result
}
pub fn get_tool_specs(&self) -> Vec<&ToolSpec> {
self.tools.values().map(|def| &def.spec).collect()
}
pub fn get_callable(&self, tool_name: &str) -> Result<ToolCallable, String> {
match self.tools.get(tool_name) {
Some(def) => Ok(def.callable.clone()),
None => Err(format!("Tool '{}' not found", tool_name)),
}
}
}
/// Manual `Debug` implementation that prints only the workflow name, the
/// required steps and the terminal tools; the remaining fields are omitted.
impl fmt::Debug for Workflow {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Workflow");
        out.field("name", &self.name);
        out.field("required_steps", &self.required_steps);
        out.field("terminal_tools", &self.terminal_tools);
        out.finish()
    }
}
/// Terminal-tool argument accepted by `Workflow::new`: either one tool name or
/// a list of tool names.
///
/// `PartialEq` and `Eq` are derived so values can be compared directly, in
/// line with `PrerequisiteSpec`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminalToolInput {
    /// Exactly one terminal tool name.
    Single(String),
    /// Several terminal tool names.
    Multiple(Vec<String>),
}
impl From<String> for TerminalToolInput {
fn from(s: String) -> Self {
Self::Single(s)
}
}
impl From<Vec<String>> for TerminalToolInput {
fn from(v: Vec<String>) -> Self {
Self::Multiple(v)
}
}