swiftide_agents/tools/control.rs

//! Control tools manage control flow during an agent's lifecycle.
2use anyhow::Result;
3use async_trait::async_trait;
4use schemars::{Schema, schema_for};
5use std::borrow::Cow;
6use swiftide_core::{
7    AgentContext, ToolFeedback,
8    chat_completion::{Tool, ToolCall, ToolOutput, ToolSpec, errors::ToolError},
9};
10
/// `Stop` is the default tool agents call to end their run.
///
/// It takes no arguments: invoking it always yields [`ToolOutput::stop()`]. The model is told
/// to call it once it has completed its task, or has determined that it cannot complete it.
#[derive(Clone, Debug, Default)]
pub struct Stop {}
14
15#[async_trait]
16impl Tool for Stop {
17    async fn invoke(
18        &self,
19        _agent_context: &dyn AgentContext,
20        _tool_call: &ToolCall,
21    ) -> Result<ToolOutput, ToolError> {
22        Ok(ToolOutput::stop())
23    }
24
25    fn name(&self) -> Cow<'_, str> {
26        "stop".into()
27    }
28
29    fn tool_spec(&self) -> ToolSpec {
30        ToolSpec::builder()
31            .name("stop")
32            .description("When you have completed, or cannot complete, your task, call this")
33            .build()
34            .unwrap()
35    }
36}
37
38impl From<Stop> for Box<dyn Tool> {
39    fn from(val: Stop) -> Self {
40        Box::new(val)
41    }
42}
43
44/// `StopWithArgs` is an alternative stop tool that takes arguments
45#[derive(Clone, Debug)]
46pub struct StopWithArgs {
47    parameters_schema: Option<Schema>,
48    expects_output_field: bool,
49}
50
51impl Default for StopWithArgs {
52    fn default() -> Self {
53        Self {
54            parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)),
55            expects_output_field: true,
56        }
57    }
58}
59
60impl StopWithArgs {
61    /// Create a new `StopWithArgs` tool with a custom parameters schema.
62    ///
63    /// When providing a custom schema the full argument payload will be forwarded to the
64    /// stop output without requiring an `output` field wrapper.
65    pub fn with_parameters_schema(schema: Schema) -> Self {
66        Self {
67            parameters_schema: Some(schema),
68            expects_output_field: false,
69        }
70    }
71
72    fn parameters_schema(&self) -> Schema {
73        self.parameters_schema
74            .clone()
75            .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec))
76    }
77}
78
79#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
80struct DefaultStopWithArgsSpec {
81    pub output: String,
82}
83
84#[async_trait]
85impl Tool for StopWithArgs {
86    async fn invoke(
87        &self,
88        _agent_context: &dyn AgentContext,
89        tool_call: &ToolCall,
90    ) -> Result<ToolOutput, ToolError> {
91        let raw_args = tool_call
92            .args()
93            .ok_or_else(|| ToolError::missing_arguments("arguments"))?;
94
95        let json: serde_json::Value = serde_json::from_str(raw_args)?;
96
97        let output = if self.expects_output_field {
98            json.get("output")
99                .cloned()
100                .ok_or_else(|| ToolError::missing_arguments("output"))?
101        } else {
102            json
103        };
104
105        Ok(ToolOutput::stop_with_args(output))
106    }
107
108    fn name(&self) -> Cow<'_, str> {
109        "stop".into()
110    }
111
112    fn tool_spec(&self) -> ToolSpec {
113        let schema = self.parameters_schema();
114
115        ToolSpec::builder()
116            .name("stop")
117            .description("When you have completed, your task, call this with your expected output")
118            .parameters_schema(schema)
119            .build()
120            .unwrap()
121    }
122}
123
124impl From<StopWithArgs> for Box<dyn Tool> {
125    fn from(val: StopWithArgs) -> Self {
126        Box::new(val)
127    }
128}
129
130#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
131struct AgentFailedArgsSpec {
132    pub reason: String,
133}
134
135/// A utility tool that can be used to let an agent decide it failed
136///
137/// This will _NOT_ have the agent return an error, instead, look at the stop reason of the agent.
138#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
139pub struct AgentCanFail {
140    parameters_schema: Option<Schema>,
141    expects_reason_field: bool,
142}
143
144impl Default for AgentCanFail {
145    fn default() -> Self {
146        Self {
147            parameters_schema: Some(schema_for!(AgentFailedArgsSpec)),
148            expects_reason_field: true,
149        }
150    }
151}
152
153impl AgentCanFail {
154    /// Create a new `AgentCanFail` tool with a custom parameters schema.
155    ///
156    /// When providing a custom schema the full argument payload will be forwarded to the failure
157    /// reason without requiring a `reason` field wrapper.
158    pub fn with_parameters_schema(schema: Schema) -> Self {
159        Self {
160            parameters_schema: Some(schema),
161            expects_reason_field: false,
162        }
163    }
164
165    fn parameters_schema(&self) -> Schema {
166        self.parameters_schema
167            .clone()
168            .unwrap_or_else(|| schema_for!(AgentFailedArgsSpec))
169    }
170}
171
172#[async_trait]
173impl Tool for AgentCanFail {
174    async fn invoke(
175        &self,
176        _agent_context: &dyn AgentContext,
177        tool_call: &ToolCall,
178    ) -> Result<ToolOutput, ToolError> {
179        let raw_args = tool_call.args().ok_or_else(|| {
180            if self.expects_reason_field {
181                ToolError::missing_arguments("reason")
182            } else {
183                ToolError::missing_arguments("arguments")
184            }
185        })?;
186
187        let reason = if self.expects_reason_field {
188            let args: AgentFailedArgsSpec = serde_json::from_str(raw_args)?;
189            args.reason
190        } else {
191            let json: serde_json::Value = serde_json::from_str(raw_args)?;
192            json.to_string()
193        };
194
195        Ok(ToolOutput::agent_failed(reason))
196    }
197
198    fn name(&self) -> Cow<'_, str> {
199        "task_failed".into()
200    }
201
202    fn tool_spec(&self) -> ToolSpec {
203        let schema = self.parameters_schema();
204
205        ToolSpec::builder()
206            .name("task_failed")
207            .description("If you cannot complete your task, or have otherwise failed, call this with your reason for failure")
208            .parameters_schema(schema)
209            .build()
210            .unwrap()
211    }
212}
213
214impl From<AgentCanFail> for Box<dyn Tool> {
215    fn from(val: AgentCanFail) -> Self {
216        Box::new(val)
217    }
218}
219
220#[derive(Clone)]
221/// Wraps a tool and requires approval before it can be used
222pub struct ApprovalRequired(pub Box<dyn Tool>);
223
224impl ApprovalRequired {
225    /// Creates a new `ApprovalRequired` tool
226    pub fn new(tool: impl Tool + 'static) -> Self {
227        Self(Box::new(tool))
228    }
229}
230
231#[async_trait]
232impl Tool for ApprovalRequired {
233    async fn invoke(
234        &self,
235        context: &dyn AgentContext,
236        tool_call: &ToolCall,
237    ) -> Result<ToolOutput, ToolError> {
238        if let Some(feedback) = context.has_received_feedback(tool_call).await {
239            match feedback {
240                ToolFeedback::Approved { .. } => return self.0.invoke(context, tool_call).await,
241                ToolFeedback::Refused { .. } => {
242                    return Ok(ToolOutput::text("This tool call was refused"));
243                }
244            }
245        }
246
247        Ok(ToolOutput::FeedbackRequired(None))
248    }
249
250    fn name(&self) -> Cow<'_, str> {
251        self.0.name()
252    }
253
254    fn tool_spec(&self) -> ToolSpec {
255        self.0.tool_spec()
256    }
257}
258
259impl From<ApprovalRequired> for Box<dyn Tool> {
260    fn from(val: ApprovalRequired) -> Self {
261        Box::new(val)
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::schema_for;
    use serde_json::json;

    /// Builds a `ToolCall` named `name` with id `"1"` and, optionally, raw JSON arguments.
    fn make_call(name: &str, args: Option<&str>) -> ToolCall {
        let mut builder = ToolCall::builder().name(name).id("1").to_owned();
        if let Some(raw) = args {
            builder.args(raw.to_string());
        }
        builder.build().unwrap()
    }

    // ----- Stop -----

    #[tokio::test]
    async fn test_stop_tool() {
        let tool = Stop::default();
        let context = ();
        let call = make_call("stop", None);

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop());
    }

    // ----- StopWithArgs -----

    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomStopArgs {
        value: i32,
    }

    #[tokio::test]
    async fn test_stop_with_args_tool() {
        let tool = StopWithArgs::default();
        let context = ();
        let call = make_call("stop", Some(r#"{"output":"expected result"}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop_with_args("expected result"));
    }

    #[test]
    fn test_stop_with_args_custom_schema_in_spec() {
        let expected = schema_for!(CustomStopArgs);
        let spec = StopWithArgs::with_parameters_schema(expected.clone()).tool_spec();

        assert_eq!(spec.parameters_schema, Some(expected));
    }

    #[tokio::test]
    async fn test_stop_with_args_custom_schema_forwards_payload() {
        let tool = StopWithArgs::with_parameters_schema(schema_for!(CustomStopArgs));
        let context = ();
        let call = make_call("stop", Some(r#"{"value":42}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::stop_with_args(json!({"value": 42})));
    }

    #[test]
    fn test_stop_with_args_default_schema_matches_previous() {
        let spec = StopWithArgs::default().tool_spec();

        assert_eq!(
            spec.parameters_schema,
            Some(schema_for!(DefaultStopWithArgsSpec))
        );
    }

    // ----- AgentCanFail -----

    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomFailArgs {
        code: i32,
        message: String,
    }

    #[tokio::test]
    async fn test_agent_can_fail_tool() {
        let tool = AgentCanFail::default();
        let context = ();
        let call = make_call("task_failed", Some(r#"{"reason":"something went wrong"}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        assert_eq!(output, ToolOutput::agent_failed("something went wrong"));
    }

    #[test]
    fn test_agent_can_fail_custom_schema_in_spec() {
        let expected = schema_for!(CustomFailArgs);
        let spec = AgentCanFail::with_parameters_schema(expected.clone()).tool_spec();

        assert_eq!(spec.parameters_schema, Some(expected));
    }

    #[tokio::test]
    async fn test_agent_can_fail_custom_schema_forwards_payload() {
        let tool = AgentCanFail::with_parameters_schema(schema_for!(CustomFailArgs));
        let context = ();
        let call = make_call("task_failed", Some(r#"{"code":7,"message":"error"}"#));

        let output = tool.invoke(&context, &call).await.unwrap();

        let expected_reason = json!({"code":7,"message":"error"}).to_string();
        assert_eq!(output, ToolOutput::agent_failed(expected_reason));
    }

    #[test]
    fn test_agent_can_fail_default_schema_matches_previous() {
        let spec = AgentCanFail::default().tool_spec();

        assert_eq!(spec.parameters_schema, Some(schema_for!(AgentFailedArgsSpec)));
    }

    // ----- ApprovalRequired -----

    #[tokio::test]
    async fn test_approval_required_feedback_required() {
        let tool = ApprovalRequired::new(Stop::default());
        let context = ();
        let call = make_call("stop", None);

        let output = tool.invoke(&context, &call).await.unwrap();

        // The unit `()` context always reports feedback for the call, and here it is an
        // approval, so the wrapped `Stop` tool runs and its output comes back unchanged.
        assert_eq!(output, ToolOutput::Stop(None));
    }
}