// forge_guardrails/core/workflow.rs

use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;

use futures_util::future::BoxFuture;
use indexmap::IndexMap;
use serde_json::Value;

use crate::core::tool_spec::ToolSpec;
use crate::error::{ToolError, ToolResolutionError};

10/// Callable signature for tool implementations: takes JSON arguments, returns a future yielding a JSON value.
11pub type ToolCallable = Arc<
12    dyn Fn(IndexMap<String, Value>) -> BoxFuture<'static, Result<Value, ToolError>> + Send + Sync,
13>;
14
15/// Trait to automatically convert sync or async tools into a standard ToolCallable.
16pub trait IntoToolCallable {
17    /// Converts this type into a boxed `ToolCallable` future wrapper.
18    fn into_callable(self) -> ToolCallable;
19}
20
21impl IntoToolCallable for ToolCallable {
22    fn into_callable(self) -> Self {
23        self
24    }
25}
26
27// Convert the old sync signature Fn(Vec<String>) -> Result<String, ToolResolutionError>
28impl<F> IntoToolCallable for F
29where
30    F: Fn(Vec<String>) -> Result<String, ToolResolutionError> + Send + Sync + 'static,
31{
32    fn into_callable(self) -> ToolCallable {
33        let func_arc = Arc::new(self);
34        Arc::new(move |args| {
35            let func = func_arc.clone();
36            Box::pin(async move {
37                let mut vec_args = Vec::new();
38                for (k, v) in args {
39                    let val_str = match v {
40                        Value::String(s) => s,
41                        other => other.to_string(),
42                    };
43                    vec_args.push(format!("{}={}", k, val_str));
44                }
45                (*func)(vec_args)
46                    .map(Value::String)
47                    .map_err(ToolError::Resolution)
48            })
49        })
50    }
51}
52
53/// Re-export ParamModel from the tool_spec module.
54pub use crate::core::tool_spec::ParamModel;
55
/// A prerequisite attached to a [`ToolDef`].
///
/// It is satisfied either by any call to a named tool, or only by a call whose arguments
/// match a criterion. The code that checks prerequisites is not part of this file.
#[derive(Debug, Clone, PartialEq)]
pub enum PrerequisiteSpec {
    /// Satisfied by any call to the named tool, regardless of its arguments.
    NameOnly(String),
    /// Satisfied only by a call to `tool` whose arguments match `match_arg`.
    ArgMatched {
        /// Name of the tool that must have been called.
        tool: String,
        /// Argument-matching criterion.
        ///
        /// NOTE(review): its exact format is interpreted by the prerequisite checker, which
        /// is not in this file; confirm there.
        match_arg: String,
    },
}

70/// Binds a tool spec to a callable with optional prerequisites.
71#[derive(Clone)]
72pub struct ToolDef {
73    /// The tool schema/specification.
74    pub spec: ToolSpec,
75    /// The asynchronous function pointer executing the tool.
76    pub callable: ToolCallable,
77    /// Optional dependencies/prerequisites for this tool.
78    pub prerequisites: Vec<PrerequisiteSpec>,
79}
80
81impl ToolDef {
82    /// Creates a new `ToolDef` linking a spec to a callable.
83    pub fn new<C>(spec: ToolSpec, callable: C) -> Self
84    where
85        C: IntoToolCallable,
86    {
87        Self {
88            spec,
89            callable: callable.into_callable(),
90            prerequisites: Vec::new(),
91        }
92    }
93
94    /// Appends prerequisites to the tool definition.
95    pub fn with_prerequisites(mut self, prereqs: Vec<PrerequisiteSpec>) -> Self {
96        self.prerequisites = prereqs;
97        self
98    }
99
100    /// Returns the name of the tool.
101    pub fn name(&self) -> &str {
102        &self.spec.name
103    }
104}
105
106impl fmt::Debug for ToolDef {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("ToolDef")
109            .field("name", &self.spec.name)
110            .field("prerequisites", &self.prerequisites)
111            .finish()
112    }
113}
114
115/// A validated workflow definition.
116pub struct Workflow {
117    /// Name of the workflow.
118    pub name: String,
119    /// Description of the workflow.
120    pub description: String,
121    /// Map of tool names to their definition.
122    pub tools: IndexMap<String, ToolDef>,
123    /// List of step names that must be completed.
124    pub required_steps: Vec<String>,
125    /// Set of tools designated as terminal (success terminates loop).
126    pub terminal_tools: HashSet<String>,
127    /// System prompt template containing variable interpolation placeholders.
128    pub system_prompt_template: String,
129}
130
131impl Workflow {
132    /// Construct and validate a Workflow.
133    ///
134    /// Validates:
135    /// - Each tool dict key matches the tool definition's name.
136    /// - Every required step exists in the tools map.
137    /// - Every terminal tool exists in the tools map.
138    /// - No terminal tool is also a required step.
139    /// - Every prerequisite references a tool that exists in the tools map.
140    pub fn new(
141        name: impl Into<String>,
142        description: impl Into<String>,
143        tools: IndexMap<String, ToolDef>,
144        required_steps: Vec<String>,
145        terminal_tool: TerminalToolInput,
146        system_prompt_template: impl Into<String>,
147    ) -> Result<Self, String> {
148        let name = name.into();
149        let description = description.into();
150        let system_prompt_template = system_prompt_template.into();
151
152        // Validate tool dict keys match tool def names.
153        for (key, def) in &tools {
154            if key != &def.spec.name {
155                return Err(format!(
156                    "Tool dict key '{}' does not match tool definition name '{}'",
157                    key, def.spec.name
158                ));
159            }
160        }
161
162        let tool_names: HashSet<&str> = tools.keys().map(|s| s.as_str()).collect();
163
164        // Validate required steps exist in tools.
165        for step in &required_steps {
166            if !tool_names.contains(step.as_str()) {
167                return Err(format!("Required step '{}' not found in tools", step));
168            }
169        }
170
171        // Normalize terminal_tool to a HashSet.
172        let terminal_set: HashSet<String> = match terminal_tool {
173            TerminalToolInput::Single(s) => {
174                let mut set = HashSet::new();
175                set.insert(s);
176                set
177            }
178            TerminalToolInput::Multiple(v) => v.into_iter().collect(),
179        };
180
181        // Validate terminal tools exist in tools.
182        for t in &terminal_set {
183            if !tool_names.contains(t.as_str()) {
184                return Err(format!("Terminal tool '{}' not found in tools", t));
185            }
186        }
187
188        // Validate terminal tools are not also required steps.
189        let required_set: HashSet<&str> = required_steps.iter().map(|s| s.as_str()).collect();
190        for t in &terminal_set {
191            if required_set.contains(t.as_str()) {
192                return Err(format!(
193                    "Terminal tool '{}' cannot also be a required step",
194                    t
195                ));
196            }
197        }
198
199        // Validate prerequisites reference existing tools.
200        for (_, def) in &tools {
201            for prereq in &def.prerequisites {
202                let prereq_tool = match prereq {
203                    PrerequisiteSpec::NameOnly(name) => name.as_str(),
204                    PrerequisiteSpec::ArgMatched { tool, .. } => tool.as_str(),
205                };
206                if !tool_names.contains(prereq_tool) {
207                    return Err(format!(
208                        "Prerequisite references tool '{}' which is not in the tools map",
209                        prereq_tool
210                    ));
211                }
212            }
213        }
214
215        Ok(Self {
216            name,
217            description,
218            tools,
219            required_steps,
220            terminal_tools: terminal_set,
221            system_prompt_template,
222        })
223    }
224
225    /// Render the system prompt template with the provided variables.
226    pub fn build_system_prompt(&self, vars: &IndexMap<String, String>) -> String {
227        let mut result = self.system_prompt_template.clone();
228        for (key, value) in vars {
229            let pattern = format!("{{{}}}", key);
230            result = result.replace(&pattern, value);
231        }
232        result
233    }
234
235    /// Get all tool specs in insertion order.
236    pub fn get_tool_specs(&self) -> Vec<&ToolSpec> {
237        self.tools.values().map(|def| &def.spec).collect()
238    }
239
240    /// Get the callable for a tool by name.
241    ///
242    /// Returns an error if the tool name is not found.
243    pub fn get_callable(&self, tool_name: &str) -> Result<ToolCallable, String> {
244        match self.tools.get(tool_name) {
245            Some(def) => Ok(def.callable.clone()),
246            None => Err(format!("Tool '{}' not found", tool_name)),
247        }
248    }
249}
250
251impl fmt::Debug for Workflow {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        f.debug_struct("Workflow")
254            .field("name", &self.name)
255            .field("required_steps", &self.required_steps)
256            .field("terminal_tools", &self.terminal_tools)
257            .finish()
258    }
259}
260
/// How terminal tools are passed to [`Workflow::new`]: one name, or a list of names.
///
/// Either form can be obtained with `.into()` from a `String` or a `Vec<String>`.
#[derive(Debug, Clone)]
pub enum TerminalToolInput {
    /// Exactly one terminal tool name.
    Single(String),
    /// Any number of terminal tool names.
    Multiple(Vec<String>),
}

/// Wraps a single tool name as [`TerminalToolInput::Single`].
impl From<String> for TerminalToolInput {
    fn from(name: String) -> Self {
        TerminalToolInput::Single(name)
    }
}

/// Wraps a list of tool names as [`TerminalToolInput::Multiple`].
impl From<Vec<String>> for TerminalToolInput {
    fn from(names: Vec<String>) -> Self {
        TerminalToolInput::Multiple(names)
    }
}