forge_guardrails/core/
workflow.rs1use crate::core::tool_spec::ToolSpec;
2use crate::error::{ToolError, ToolResolutionError};
3use futures_util::future::BoxFuture;
4use indexmap::IndexMap;
5use serde_json::Value;
6use std::collections::HashSet;
7use std::fmt;
8use std::sync::Arc;
9
10pub type ToolCallable = Arc<
12 dyn Fn(IndexMap<String, Value>) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync,
13>;
14
15pub trait IntoToolCallable {
17 fn into_callable(self) -> ToolCallable;
19}
20
21impl IntoToolCallable for ToolCallable {
22 fn into_callable(self) -> Self {
23 self
24 }
25}
26
27impl<F> IntoToolCallable for F
29where
30 F: Fn(Vec<String>) -> Result<String, ToolResolutionError> + Send + Sync + 'static,
31{
32 fn into_callable(self) -> ToolCallable {
33 let func_arc = Arc::new(self);
34 Arc::new(move |args| {
35 let func = func_arc.clone();
36 Box::pin(async move {
37 let mut vec_args = Vec::new();
38 for (k, v) in args {
39 let val_str = match v {
40 Value::String(s) => s,
41 other => other.to_string(),
42 };
43 vec_args.push(format!("{}={}", k, val_str));
44 }
45 (*func)(vec_args)
46 .map(Value::String)
47 .map_err(ToolError::Resolution)
48 })
49 })
50 }
51}
52
53pub use crate::core::tool_spec::ParamModel;
55
/// A prerequisite declared on a tool.
///
/// [`Workflow::new`] only checks that the referenced tool exists in the
/// workflow's tool map. How the prerequisite is enforced at call time is not
/// visible in this module. The notes on each variant are presumptions to
/// confirm against the enforcing code.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PrerequisiteSpec {
    /// Refers to a prerequisite tool by name only (presumably: that tool must
    /// have been called before, with any arguments — TODO confirm).
    NameOnly(String),
    /// Refers to a prerequisite tool together with an argument name
    /// (presumably: the earlier call must have used the same value for
    /// `match_arg` as the current call — TODO confirm).
    ArgMatched {
        /// Name of the prerequisite tool.
        tool: String,
        /// Name of the argument to match on.
        match_arg: String,
    },
}
69
70#[derive(Clone)]
72pub struct ToolDef {
73 pub spec: ToolSpec,
75 pub callable: ToolCallable,
77 pub prerequisites: Vec<PrerequisiteSpec>,
79}
80
81impl ToolDef {
82 pub fn new<C>(spec: ToolSpec, callable: C) -> Self
84 where
85 C: IntoToolCallable,
86 {
87 Self {
88 spec,
89 callable: callable.into_callable(),
90 prerequisites: Vec::new(),
91 }
92 }
93
94 pub fn with_prerequisites(mut self, prereqs: Vec<PrerequisiteSpec>) -> Self {
96 self.prerequisites = prereqs;
97 self
98 }
99
100 pub fn name(&self) -> &str {
102 &self.spec.name
103 }
104}
105
106impl fmt::Debug for ToolDef {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("ToolDef")
109 .field("name", &self.spec.name)
110 .field("prerequisites", &self.prerequisites)
111 .finish()
112 }
113}
114
115pub struct Workflow {
117 pub name: String,
119 pub description: String,
121 pub tools: IndexMap<String, ToolDef>,
123 pub required_steps: Vec<String>,
125 pub terminal_tools: HashSet<String>,
127 pub system_prompt_template: String,
129}
130
131impl Workflow {
132 pub fn new(
141 name: impl Into<String>,
142 description: impl Into<String>,
143 tools: IndexMap<String, ToolDef>,
144 required_steps: Vec<String>,
145 terminal_tool: TerminalToolInput,
146 system_prompt_template: impl Into<String>,
147 ) -> Result<Self, String> {
148 let name = name.into();
149 let description = description.into();
150 let system_prompt_template = system_prompt_template.into();
151
152 for (key, def) in &tools {
154 if key != &def.spec.name {
155 return Err(format!(
156 "Tool dict key '{}' does not match tool definition name '{}'",
157 key, def.spec.name
158 ));
159 }
160 }
161
162 let tool_names: HashSet<&str> = tools.keys().map(|s| s.as_str()).collect();
163
164 for step in &required_steps {
166 if !tool_names.contains(step.as_str()) {
167 return Err(format!("Required step '{}' not found in tools", step));
168 }
169 }
170
171 let terminal_set: HashSet<String> = match terminal_tool {
173 TerminalToolInput::Single(s) => {
174 let mut set = HashSet::new();
175 set.insert(s);
176 set
177 }
178 TerminalToolInput::Multiple(v) => v.into_iter().collect(),
179 };
180
181 for t in &terminal_set {
183 if !tool_names.contains(t.as_str()) {
184 return Err(format!("Terminal tool '{}' not found in tools", t));
185 }
186 }
187
188 let required_set: HashSet<&str> = required_steps.iter().map(|s| s.as_str()).collect();
190 for t in &terminal_set {
191 if required_set.contains(t.as_str()) {
192 return Err(format!(
193 "Terminal tool '{}' cannot also be a required step",
194 t
195 ));
196 }
197 }
198
199 for (_, def) in &tools {
201 for prereq in &def.prerequisites {
202 let prereq_tool = match prereq {
203 PrerequisiteSpec::NameOnly(name) => name.as_str(),
204 PrerequisiteSpec::ArgMatched { tool, .. } => tool.as_str(),
205 };
206 if !tool_names.contains(prereq_tool) {
207 return Err(format!(
208 "Prerequisite references tool '{}' which is not in the tools map",
209 prereq_tool
210 ));
211 }
212 }
213 }
214
215 Ok(Self {
216 name,
217 description,
218 tools,
219 required_steps,
220 terminal_tools: terminal_set,
221 system_prompt_template,
222 })
223 }
224
225 pub fn build_system_prompt(&self, vars: &IndexMap<String, String>) -> String {
227 let mut result = self.system_prompt_template.clone();
228 for (key, value) in vars {
229 let pattern = format!("{{{}}}", key);
230 result = result.replace(&pattern, value);
231 }
232 result
233 }
234
235 pub fn get_tool_specs(&self) -> Vec<&ToolSpec> {
237 self.tools.values().map(|def| &def.spec).collect()
238 }
239
240 pub fn get_callable(&self, tool_name: &str) -> Result<ToolCallable, String> {
244 match self.tools.get(tool_name) {
245 Some(def) => Ok(def.callable.clone()),
246 None => Err(format!("Tool '{}' not found", tool_name)),
247 }
248 }
249}
250
251impl fmt::Debug for Workflow {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 f.debug_struct("Workflow")
254 .field("name", &self.name)
255 .field("required_steps", &self.required_steps)
256 .field("terminal_tools", &self.terminal_tools)
257 .finish()
258 }
259}
260
/// The terminal tool name(s) passed to `Workflow::new`.
///
/// A `String` converts into [`TerminalToolInput::Single`] and a `Vec<String>`
/// into [`TerminalToolInput::Multiple`] via `From`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminalToolInput {
    /// A single terminal tool name.
    Single(String),
    /// Several terminal tool names. `Workflow::new` collects them into a set,
    /// so duplicates collapse.
    Multiple(Vec<String>),
}
269
270impl From<String> for TerminalToolInput {
271 fn from(s: String) -> Self {
272 Self::Single(s)
273 }
274}
275
276impl From<Vec<String>> for TerminalToolInput {
277 fn from(v: Vec<String>) -> Self {
278 Self::Multiple(v)
279 }
280}