Skip to main content

foundation_models/
tool.rs

1//! Tool calling support.
2
3use core::ffi::{c_char, c_void};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::Arc;
8
9use serde::de::DeserializeOwned;
10use serde_json::json;
11
12use crate::content::{FromGeneratedContent, GeneratedContent};
13use crate::error::FMError;
14use crate::ffi;
15use crate::prompt::{Prompt, ToPrompt, ToolDefinition};
16use crate::schema::{Generable, GenerationSchema};
17
18fn swift_dup_string(value: &str) -> *mut c_char {
19    let c_string = CString::new(value).expect("bridge strings must not contain interior NUL bytes");
20    unsafe { ffi::fm_string_dup(c_string.as_ptr()) }
21}
22
23/// One tool exposed to the system language model.
24pub struct Tool {
25    spec: ToolSpec,
26    handler: Arc<dyn Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync>,
27}
28
29impl Tool {
30    /// Create a tool from a dynamic `GeneratedContent` handler.
31    #[must_use]
32    pub fn new<F>(
33        name: impl Into<String>,
34        description: impl Into<String>,
35        parameters: GenerationSchema,
36        handler: F,
37    ) -> Self
38    where
39        F: Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync + 'static,
40    {
41        Self {
42            spec: ToolSpec {
43                name: name.into(),
44                description: description.into(),
45                parameters,
46                includes_schema_in_instructions: true,
47            },
48            handler: Arc::new(handler),
49        }
50    }
51
52    /// Create a tool whose handler receives decoded JSON arguments.
53    #[must_use]
54    pub fn json<Args, Output, F>(
55        name: impl Into<String>,
56        description: impl Into<String>,
57        parameters: GenerationSchema,
58        handler: F,
59    ) -> Self
60    where
61        Args: DeserializeOwned + Send + 'static,
62        Output: ToPrompt,
63        F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
64    {
65        Self::new(name, description, parameters, move |arguments| {
66            let decoded = arguments.value::<Args>()?;
67            let output = handler(decoded)?;
68            Ok(ToolOutput::from_prompt(output.to_prompt()?))
69        })
70    }
71
72    /// Create a tool whose argument schema is inferred from a [`Generable`] type.
73    ///
74    /// # Errors
75    ///
76    /// Returns an [`FMError`] if `Args` cannot produce a generation schema.
77    pub fn generable<Args, Output, F>(
78        name: impl Into<String>,
79        description: impl Into<String>,
80        handler: F,
81    ) -> Result<Self, FMError>
82    where
83        Args: FromGeneratedContent + Generable + Send + 'static,
84        Output: ToPrompt,
85        F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
86    {
87        Ok(Self::new(
88            name,
89            description,
90            Args::generation_schema()?,
91            move |arguments| {
92                let decoded = Args::from_generated_content(&arguments)?;
93                let output = handler(decoded)?;
94                Ok(ToolOutput::from_prompt(output.to_prompt()?))
95            },
96        ))
97    }
98
99    /// Tool metadata as exposed to FoundationModels.
100    #[must_use]
101    pub const fn spec(&self) -> &ToolSpec {
102        &self.spec
103    }
104
105    /// Convert this tool into a transcript tool definition.
106    #[must_use]
107    pub fn definition(&self) -> ToolDefinition {
108        self.spec.definition()
109    }
110
111    /// Control whether the schema is included in the model's tool instructions.
112    #[must_use]
113    pub fn with_schema_in_instructions(mut self, includes: bool) -> Self {
114        self.spec.includes_schema_in_instructions = includes;
115        self
116    }
117}
118
119impl core::fmt::Debug for Tool {
120    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121        f.debug_struct("Tool").field("spec", &self.spec).finish()
122    }
123}
124
125/// Public metadata for one tool.
126#[derive(Debug, Clone, PartialEq, Eq)]
127pub struct ToolSpec {
128    pub name: String,
129    pub description: String,
130    pub parameters: GenerationSchema,
131    pub includes_schema_in_instructions: bool,
132}
133
134impl ToolSpec {
135    /// Convert this tool specification into a transcript tool definition.
136    #[must_use]
137    pub fn definition(&self) -> ToolDefinition {
138        ToolDefinition::new(
139            self.name.clone(),
140            self.description.clone(),
141            self.parameters.clone(),
142        )
143    }
144}
145
146/// A tool output converted into a prompt representation.
147#[derive(Debug, Clone, PartialEq)]
148pub struct ToolOutput {
149    prompt: Prompt,
150}
151
152impl ToolOutput {
153    /// Return a tool output as plain text.
154    #[must_use]
155    pub fn text(text: impl Into<String>) -> Self {
156        Self {
157            prompt: Prompt::from(text.into()),
158        }
159    }
160
161    /// Return a tool output as structured content.
162    #[must_use]
163    pub fn structured(content: GeneratedContent) -> Self {
164        Self {
165            prompt: Prompt::from(content),
166        }
167    }
168
169    /// Return a prebuilt prompt output.
170    #[must_use]
171    pub const fn from_prompt(prompt: Prompt) -> Self {
172        Self { prompt }
173    }
174
175    #[must_use]
176    pub fn prompt(&self) -> &Prompt {
177        &self.prompt
178    }
179
180    pub(crate) fn to_bridge_json(&self) -> Result<String, FMError> {
181        serde_json::to_string(&json!({ "prompt": self.prompt.to_bridge_value() })).map_err(
182            |error| {
183                FMError::InvalidArgument(format!("tool output is not JSON-serializable: {error}"))
184            },
185        )
186    }
187}
188
189impl From<String> for ToolOutput {
190    fn from(value: String) -> Self {
191        Self::text(value)
192    }
193}
194
195impl From<&str> for ToolOutput {
196    fn from(value: &str) -> Self {
197        Self::text(value)
198    }
199}
200
201impl From<GeneratedContent> for ToolOutput {
202    fn from(value: GeneratedContent) -> Self {
203        Self::structured(value)
204    }
205}
206
207impl From<Prompt> for ToolOutput {
208    fn from(value: Prompt) -> Self {
209        Self::from_prompt(value)
210    }
211}
212
213pub(crate) struct ToolRegistry {
214    tools: HashMap<String, Tool>,
215}
216
217impl ToolRegistry {
218    pub(crate) fn new(tools: Vec<Tool>) -> Self {
219        Self {
220            tools: tools
221                .into_iter()
222                .map(|tool| (tool.spec.name.clone(), tool))
223                .collect(),
224        }
225    }
226
227    pub(crate) fn specs_json(&self) -> Result<String, FMError> {
228        let specs = self
229            .tools
230            .values()
231            .map(|tool| {
232                json!({
233                    "name": tool.spec.name,
234                    "description": tool.spec.description,
235                    "parametersJSON": tool.spec.parameters.json_schema(),
236                    "includesSchemaInInstructions": tool.spec.includes_schema_in_instructions,
237                })
238            })
239            .collect::<Vec<_>>();
240        serde_json::to_string(&specs).map_err(|error| {
241            FMError::InvalidArgument(format!("tool specs are not JSON-serializable: {error}"))
242        })
243    }
244
245    fn invoke(&self, tool_name: &str, arguments: GeneratedContent) -> Result<ToolOutput, FMError> {
246        let tool = self.tools.get(tool_name).ok_or_else(|| {
247            FMError::ToolCallFailed(format!("tool `{tool_name}` is not registered"))
248        })?;
249        (tool.handler)(arguments)
250    }
251}
252
253pub(crate) unsafe extern "C" fn tool_callback_trampoline(
254    context: *mut c_void,
255    tool_name: *const c_char,
256    arguments_json: *const c_char,
257    output_json_out: *mut *mut c_char,
258    error_out: *mut *mut c_char,
259) -> i32 {
260    let registry = &*(context.cast::<ToolRegistry>());
261    let result = catch_unwind(AssertUnwindSafe(|| {
262        let tool_name = CStr::from_ptr(tool_name).to_string_lossy().into_owned();
263        let arguments_json = CStr::from_ptr(arguments_json)
264            .to_string_lossy()
265            .into_owned();
266        let arguments = GeneratedContent::from_json_str(&arguments_json)?;
267        let output = registry.invoke(&tool_name, arguments)?;
268        output.to_bridge_json()
269    }));
270
271    match result {
272        Ok(Ok(output_json)) => {
273            *output_json_out = swift_dup_string(&output_json);
274            ffi::status::OK
275        }
276        Ok(Err(error)) => {
277            *error_out = swift_dup_string(error.message());
278            error.code()
279        }
280        Err(_) => {
281            *error_out = swift_dup_string("tool callback panicked");
282            ffi::status::TOOL_CALL_FAILED
283        }
284    }
285}