fm_rs/
tool.rs

1//! Tool calling support for `FoundationModels`.
2//!
3//! Tools allow the model to call external functions during generation.
4//! Implement the [`Tool`] trait to define custom tools.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashSet;
9
10use crate::error::Result;
11
12/// A tool that can be invoked by the model.
13///
14/// # Example
15///
16/// ```rust
17/// use fm_rs::{Tool, ToolOutput};
18/// use serde_json::{json, Value};
19///
20/// struct WeatherTool;
21///
22/// impl Tool for WeatherTool {
23///     fn name(&self) -> &str {
24///         "checkWeather"
25///     }
26///
27///     fn description(&self) -> &str {
28///         "Check current weather conditions"
29///     }
30///
31///     fn arguments_schema(&self) -> Value {
32///         json!({
33///             "type": "object",
34///             "properties": {
35///                 "location": {
36///                     "type": "string",
37///                     "description": "The city and country"
38///                 }
39///             },
40///             "required": ["location"]
41///         })
42///     }
43///
44///     fn call(&self, arguments: Value) -> fm_rs::Result<ToolOutput> {
45///         let location = arguments["location"].as_str().unwrap_or("Unknown");
46///         Ok(ToolOutput::new(format!("Weather in {location}: Sunny, 72°F")))
47///     }
48/// }
49/// ```
50pub trait Tool: Send + Sync {
51    /// Returns the name of the tool.
52    fn name(&self) -> &str;
53
54    /// Returns a description of what the tool does.
55    fn description(&self) -> &str;
56
57    /// Returns the JSON schema for the tool's arguments.
58    fn arguments_schema(&self) -> Value;
59
60    /// Invokes the tool with the given arguments.
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if the tool invocation fails.
65    fn call(&self, arguments: Value) -> Result<ToolOutput>;
66}
67
68/// Output returned by a tool invocation.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ToolOutput {
71    /// The content returned by the tool.
72    pub content: String,
73}
74
75impl ToolOutput {
76    /// Creates a new tool output with the given content.
77    pub fn new(content: impl Into<String>) -> Self {
78        Self {
79            content: content.into(),
80        }
81    }
82
83    /// Creates a tool output from a JSON-serializable value.
84    pub fn from_json<T: Serialize>(value: &T) -> Result<Self> {
85        let content = serde_json::to_string(value)?;
86        Ok(Self { content })
87    }
88}
89
90/// Internal representation of a tool for serialization to Swift.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub(crate) struct ToolDefinition {
93    pub name: String,
94    pub description: String,
95    #[serde(rename = "argumentsSchema")]
96    pub arguments_schema: Value,
97}
98
99/// Serializes a list of tools to JSON for FFI.
100///
101/// # Errors
102///
103/// Returns an error if JSON serialization fails.
104pub(crate) fn tools_to_json(tools: &[&dyn Tool]) -> crate::error::Result<String> {
105    let mut seen = HashSet::new();
106    let mut definitions = Vec::with_capacity(tools.len());
107
108    for tool in tools {
109        let name = tool.name().trim();
110        if name.is_empty() {
111            return Err(crate::error::Error::InvalidInput(
112                "Tool name cannot be empty".to_string(),
113            ));
114        }
115        if !seen.insert(name.to_string()) {
116            return Err(crate::error::Error::InvalidInput(format!(
117                "Duplicate tool name: {name}"
118            )));
119        }
120
121        let schema = tool.arguments_schema();
122        let schema_obj = schema.as_object().ok_or_else(|| {
123            crate::error::Error::InvalidInput(format!(
124                "Tool '{name}' arguments schema must be a JSON object"
125            ))
126        })?;
127        if let Some(Value::String(ty)) = schema_obj.get("type") {
128            if ty != "object" {
129                return Err(crate::error::Error::InvalidInput(format!(
130                    "Tool '{name}' arguments schema must have type \"object\""
131                )));
132            }
133        }
134
135        definitions.push(ToolDefinition {
136            name: name.to_string(),
137            description: tool.description().to_string(),
138            arguments_schema: schema,
139        });
140    }
141
142    serde_json::to_string(&definitions)
143        .map_err(|e| crate::error::Error::InvalidInput(format!("Failed to serialize tools: {e}")))
144}
145
146/// Result of invoking a tool, serialized for FFI.
147#[derive(Debug, Serialize, Deserialize)]
148pub(crate) struct ToolResult {
149    pub success: bool,
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub content: Option<String>,
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub error: Option<String>,
154}
155
156impl ToolResult {
157    pub fn success(output: ToolOutput) -> Self {
158        Self {
159            success: true,
160            content: Some(output.content),
161            error: None,
162        }
163    }
164
165    pub fn error(message: impl Into<String>) -> Self {
166        Self {
167            success: false,
168            content: None,
169            error: Some(message.into()),
170        }
171    }
172
173    pub fn to_json(&self) -> String {
174        serde_json::to_string(self).unwrap_or_else(|_| {
175            r#"{"success":false,"error":"Failed to serialize result"}"#.to_string()
176        })
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Well-formed tool used by the happy-path tests.
    struct TestTool;

    impl Tool for TestTool {
        fn name(&self) -> &'static str {
            "test_tool"
        }

        fn description(&self) -> &'static str {
            "A test tool"
        }

        fn arguments_schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "input": {"type": "string"}
                }
            })
        }

        fn call(&self, arguments: Value) -> Result<ToolOutput> {
            let input = arguments["input"].as_str().unwrap_or("default");
            Ok(ToolOutput::new(format!("Processed: {input}")))
        }
    }

    /// Tool with a caller-chosen name and schema, used to exercise validation.
    struct StubTool {
        name: &'static str,
        schema: Value,
    }

    impl StubTool {
        fn new(name: &'static str, schema: Value) -> Self {
            Self { name, schema }
        }
    }

    impl Tool for StubTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &'static str {
            "stub"
        }

        fn arguments_schema(&self) -> Value {
            self.schema.clone()
        }

        fn call(&self, _arguments: Value) -> Result<ToolOutput> {
            Ok(ToolOutput::new("ok"))
        }
    }

    /// Runs `tools_to_json` and parses its output back into a JSON value.
    fn serialize(tools: &[&dyn Tool]) -> Value {
        let json = tools_to_json(tools).expect("serialization should succeed");
        serde_json::from_str(&json).expect("output should be valid JSON")
    }

    /// Runs `tools_to_json`, expects a failure and returns its message.
    fn serialize_err(tools: &[&dyn Tool]) -> String {
        tools_to_json(tools)
            .expect_err("expected a validation error")
            .to_string()
    }

    #[test]
    fn test_tool_definition() {
        let tool = TestTool;
        let def = ToolDefinition {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            arguments_schema: tool.arguments_schema(),
        };

        assert_eq!(def.name, "test_tool");
        assert_eq!(def.description, "A test tool");
        assert_eq!(def.arguments_schema["type"], "object");
    }

    #[test]
    fn test_tools_to_json() {
        let tool = TestTool;
        let tools: Vec<&dyn Tool> = vec![&tool];
        let parsed = serialize(&tools);

        assert_eq!(parsed.as_array().map(Vec::len), Some(1));
        assert_eq!(parsed[0]["name"], "test_tool");
        assert_eq!(parsed[0]["description"], "A test tool");
        // The schema must be emitted under the camelCase key expected by Swift.
        assert_eq!(parsed[0]["argumentsSchema"]["type"], "object");
    }

    #[test]
    fn test_tools_to_json_empty_list() {
        let tools: Vec<&dyn Tool> = Vec::new();
        let json = tools_to_json(&tools).expect("empty list should serialize");
        assert_eq!(json, "[]");
    }

    #[test]
    fn test_tools_to_json_trims_name() {
        let tool = StubTool::new("  padded  ", json!({"type": "object"}));
        let tools: Vec<&dyn Tool> = vec![&tool];
        let parsed = serialize(&tools);
        assert_eq!(parsed[0]["name"], "padded");
    }

    #[test]
    fn test_tools_to_json_rejects_blank_name() {
        let empty = StubTool::new("", json!({"type": "object"}));
        let blank = StubTool::new("   ", json!({"type": "object"}));

        for tool in [&empty, &blank] {
            let tools: Vec<&dyn Tool> = vec![tool];
            assert!(serialize_err(&tools).contains("Tool name cannot be empty"));
        }
    }

    #[test]
    fn test_tools_to_json_duplicate_names() {
        let first = StubTool::new("duplicate", json!({"type": "object"}));
        let second = StubTool::new("duplicate", json!({"type": "object"}));
        let tools: Vec<&dyn Tool> = vec![&first, &second];
        assert!(serialize_err(&tools).contains("Duplicate tool name"));
    }

    #[test]
    fn test_tools_to_json_duplicate_names_after_trim() {
        let first = StubTool::new("duplicate", json!({"type": "object"}));
        let second = StubTool::new(" duplicate ", json!({"type": "object"}));
        let tools: Vec<&dyn Tool> = vec![&first, &second];
        assert!(serialize_err(&tools).contains("Duplicate tool name: duplicate"));
    }

    #[test]
    fn test_tools_to_json_requires_object_schema() {
        let tool = StubTool::new("bad", json!("not-an-object"));
        let tools: Vec<&dyn Tool> = vec![&tool];
        assert!(serialize_err(&tools).contains("arguments schema must be a JSON object"));
    }

    #[test]
    fn test_tools_to_json_requires_object_type() {
        let tool = StubTool::new("wrong_type", json!({"type": "string"}));
        let tools: Vec<&dyn Tool> = vec![&tool];
        assert!(serialize_err(&tools).contains("must have type \"object\""));
    }

    #[test]
    fn test_tools_to_json_allows_missing_type() {
        let tool = StubTool::new("untyped", json!({}));
        let tools: Vec<&dyn Tool> = vec![&tool];
        let parsed = serialize(&tools);
        assert_eq!(parsed[0]["argumentsSchema"], json!({}));
    }

    #[test]
    fn test_tool_call() {
        let tool = TestTool;

        let output = tool
            .call(json!({"input": "abc"}))
            .expect("call should succeed");
        assert_eq!(output.content, "Processed: abc");

        let fallback = tool.call(json!({})).expect("call should succeed");
        assert_eq!(fallback.content, "Processed: default");
    }

    #[test]
    fn test_tool_output() {
        let output = ToolOutput::new("Hello, World!");
        assert_eq!(output.content, "Hello, World!");
    }

    #[test]
    fn test_tool_output_from_json() {
        let output = ToolOutput::from_json(&json!({"a": 1})).expect("value is serializable");
        assert_eq!(output.content, r#"{"a":1}"#);
    }

    #[test]
    fn test_tool_result_json() {
        let result = ToolResult::success(ToolOutput::new("OK"));
        assert_eq!(result.to_json(), r#"{"success":true,"content":"OK"}"#);
    }

    #[test]
    fn test_tool_result_error_json() {
        let result = ToolResult::error("boom");
        assert_eq!(result.to_json(), r#"{"success":false,"error":"boom"}"#);
    }
}