// foundation_models/tool.rs
//! Tool calling support.
2
3use core::ffi::{c_char, c_void};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::Arc;
8
9use serde::de::DeserializeOwned;
10use serde_json::json;
11
12use crate::content::GeneratedContent;
13use crate::error::FMError;
14use crate::ffi;
15use crate::prompt::{Prompt, ToPrompt};
16use crate::schema::GenerationSchema;
17
18fn swift_dup_string(value: &str) -> *mut c_char {
19    let c_string = CString::new(value).expect("bridge strings must not contain interior NUL bytes");
20    unsafe { ffi::fm_string_dup(c_string.as_ptr()) }
21}
22
23/// One tool exposed to the system language model.
24pub struct Tool {
25    spec: ToolSpec,
26    handler: Arc<dyn Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync>,
27}
28
29impl Tool {
30    /// Create a tool from a dynamic `GeneratedContent` handler.
31    #[must_use]
32    pub fn new<F>(
33        name: impl Into<String>,
34        description: impl Into<String>,
35        parameters: GenerationSchema,
36        handler: F,
37    ) -> Self
38    where
39        F: Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync + 'static,
40    {
41        Self {
42            spec: ToolSpec {
43                name: name.into(),
44                description: description.into(),
45                parameters,
46                includes_schema_in_instructions: true,
47            },
48            handler: Arc::new(handler),
49        }
50    }
51
52    /// Create a tool whose handler receives decoded JSON arguments.
53    #[must_use]
54    pub fn json<Args, Output, F>(
55        name: impl Into<String>,
56        description: impl Into<String>,
57        parameters: GenerationSchema,
58        handler: F,
59    ) -> Self
60    where
61        Args: DeserializeOwned + Send + 'static,
62        Output: ToPrompt,
63        F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
64    {
65        Self::new(name, description, parameters, move |arguments| {
66            let decoded = arguments.value::<Args>()?;
67            let output = handler(decoded)?;
68            Ok(ToolOutput::from_prompt(output.to_prompt()?))
69        })
70    }
71
72    /// Control whether the schema is included in the model's tool instructions.
73    #[must_use]
74    pub fn with_schema_in_instructions(mut self, includes: bool) -> Self {
75        self.spec.includes_schema_in_instructions = includes;
76        self
77    }
78}
79
80impl core::fmt::Debug for Tool {
81    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82        f.debug_struct("Tool").field("spec", &self.spec).finish()
83    }
84}
85
86/// Public metadata for one tool.
87#[derive(Debug, Clone, PartialEq, Eq)]
88pub struct ToolSpec {
89    pub name: String,
90    pub description: String,
91    pub parameters: GenerationSchema,
92    pub includes_schema_in_instructions: bool,
93}
94
95/// A tool output converted into a prompt representation.
96#[derive(Debug, Clone, PartialEq)]
97pub struct ToolOutput {
98    prompt: Prompt,
99}
100
101impl ToolOutput {
102    /// Return a tool output as plain text.
103    #[must_use]
104    pub fn text(text: impl Into<String>) -> Self {
105        Self {
106            prompt: Prompt::from(text.into()),
107        }
108    }
109
110    /// Return a tool output as structured content.
111    #[must_use]
112    pub fn structured(content: GeneratedContent) -> Self {
113        Self {
114            prompt: Prompt::from(content),
115        }
116    }
117
118    /// Return a prebuilt prompt output.
119    #[must_use]
120    pub const fn from_prompt(prompt: Prompt) -> Self {
121        Self { prompt }
122    }
123
124    #[must_use]
125    pub fn prompt(&self) -> &Prompt {
126        &self.prompt
127    }
128
129    pub(crate) fn to_bridge_json(&self) -> Result<String, FMError> {
130        serde_json::to_string(&json!({ "prompt": self.prompt.to_bridge_value() })).map_err(
131            |error| {
132                FMError::InvalidArgument(format!("tool output is not JSON-serializable: {error}"))
133            },
134        )
135    }
136}
137
138impl From<String> for ToolOutput {
139    fn from(value: String) -> Self {
140        Self::text(value)
141    }
142}
143
144impl From<&str> for ToolOutput {
145    fn from(value: &str) -> Self {
146        Self::text(value)
147    }
148}
149
150impl From<GeneratedContent> for ToolOutput {
151    fn from(value: GeneratedContent) -> Self {
152        Self::structured(value)
153    }
154}
155
156impl From<Prompt> for ToolOutput {
157    fn from(value: Prompt) -> Self {
158        Self::from_prompt(value)
159    }
160}
161
162pub(crate) struct ToolRegistry {
163    tools: HashMap<String, Tool>,
164}
165
166impl ToolRegistry {
167    pub(crate) fn new(tools: Vec<Tool>) -> Self {
168        Self {
169            tools: tools
170                .into_iter()
171                .map(|tool| (tool.spec.name.clone(), tool))
172                .collect(),
173        }
174    }
175
176    pub(crate) fn specs_json(&self) -> Result<String, FMError> {
177        let specs = self
178            .tools
179            .values()
180            .map(|tool| {
181                json!({
182                    "name": tool.spec.name,
183                    "description": tool.spec.description,
184                    "parametersJSON": tool.spec.parameters.json_schema(),
185                    "includesSchemaInInstructions": tool.spec.includes_schema_in_instructions,
186                })
187            })
188            .collect::<Vec<_>>();
189        serde_json::to_string(&specs).map_err(|error| {
190            FMError::InvalidArgument(format!("tool specs are not JSON-serializable: {error}"))
191        })
192    }
193
194    fn invoke(&self, tool_name: &str, arguments: GeneratedContent) -> Result<ToolOutput, FMError> {
195        let tool = self.tools.get(tool_name).ok_or_else(|| {
196            FMError::ToolCallFailed(format!("tool `{tool_name}` is not registered"))
197        })?;
198        (tool.handler)(arguments)
199    }
200}
201
202pub(crate) unsafe extern "C" fn tool_callback_trampoline(
203    context: *mut c_void,
204    tool_name: *const c_char,
205    arguments_json: *const c_char,
206    output_json_out: *mut *mut c_char,
207    error_out: *mut *mut c_char,
208) -> i32 {
209    let registry = &*(context.cast::<ToolRegistry>());
210    let result = catch_unwind(AssertUnwindSafe(|| {
211        let tool_name = CStr::from_ptr(tool_name).to_string_lossy().into_owned();
212        let arguments_json = CStr::from_ptr(arguments_json)
213            .to_string_lossy()
214            .into_owned();
215        let arguments = GeneratedContent::from_json_str(&arguments_json)?;
216        let output = registry.invoke(&tool_name, arguments)?;
217        output.to_bridge_json()
218    }));
219
220    match result {
221        Ok(Ok(output_json)) => {
222            *output_json_out = swift_dup_string(&output_json);
223            ffi::status::OK
224        }
225        Ok(Err(error)) => {
226            *error_out = swift_dup_string(error.message());
227            error.code()
228        }
229        Err(_) => {
230            *error_out = swift_dup_string("tool callback panicked");
231            ffi::status::TOOL_CALL_FAILED
232        }
233    }
234}