foundation_models/
tool.rs1use core::ffi::{c_char, c_void};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::Arc;
8
9use serde::de::DeserializeOwned;
10use serde_json::json;
11
12use crate::content::GeneratedContent;
13use crate::error::FMError;
14use crate::ffi;
15use crate::prompt::{Prompt, ToPrompt};
16use crate::schema::GenerationSchema;
17
18fn swift_dup_string(value: &str) -> *mut c_char {
19 let c_string = CString::new(value).expect("bridge strings must not contain interior NUL bytes");
20 unsafe { ffi::fm_string_dup(c_string.as_ptr()) }
21}
22
23pub struct Tool {
25 spec: ToolSpec,
26 handler: Arc<dyn Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync>,
27}
28
29impl Tool {
30 #[must_use]
32 pub fn new<F>(
33 name: impl Into<String>,
34 description: impl Into<String>,
35 parameters: GenerationSchema,
36 handler: F,
37 ) -> Self
38 where
39 F: Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync + 'static,
40 {
41 Self {
42 spec: ToolSpec {
43 name: name.into(),
44 description: description.into(),
45 parameters,
46 includes_schema_in_instructions: true,
47 },
48 handler: Arc::new(handler),
49 }
50 }
51
52 #[must_use]
54 pub fn json<Args, Output, F>(
55 name: impl Into<String>,
56 description: impl Into<String>,
57 parameters: GenerationSchema,
58 handler: F,
59 ) -> Self
60 where
61 Args: DeserializeOwned + Send + 'static,
62 Output: ToPrompt,
63 F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
64 {
65 Self::new(name, description, parameters, move |arguments| {
66 let decoded = arguments.value::<Args>()?;
67 let output = handler(decoded)?;
68 Ok(ToolOutput::from_prompt(output.to_prompt()?))
69 })
70 }
71
72 #[must_use]
74 pub fn with_schema_in_instructions(mut self, includes: bool) -> Self {
75 self.spec.includes_schema_in_instructions = includes;
76 self
77 }
78}
79
80impl core::fmt::Debug for Tool {
81 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
82 f.debug_struct("Tool").field("spec", &self.spec).finish()
83 }
84}
85
86#[derive(Debug, Clone, PartialEq, Eq)]
88pub struct ToolSpec {
89 pub name: String,
90 pub description: String,
91 pub parameters: GenerationSchema,
92 pub includes_schema_in_instructions: bool,
93}
94
95#[derive(Debug, Clone, PartialEq)]
97pub struct ToolOutput {
98 prompt: Prompt,
99}
100
101impl ToolOutput {
102 #[must_use]
104 pub fn text(text: impl Into<String>) -> Self {
105 Self {
106 prompt: Prompt::from(text.into()),
107 }
108 }
109
110 #[must_use]
112 pub fn structured(content: GeneratedContent) -> Self {
113 Self {
114 prompt: Prompt::from(content),
115 }
116 }
117
118 #[must_use]
120 pub const fn from_prompt(prompt: Prompt) -> Self {
121 Self { prompt }
122 }
123
124 #[must_use]
125 pub fn prompt(&self) -> &Prompt {
126 &self.prompt
127 }
128
129 pub(crate) fn to_bridge_json(&self) -> Result<String, FMError> {
130 serde_json::to_string(&json!({ "prompt": self.prompt.to_bridge_value() })).map_err(
131 |error| {
132 FMError::InvalidArgument(format!("tool output is not JSON-serializable: {error}"))
133 },
134 )
135 }
136}
137
138impl From<String> for ToolOutput {
139 fn from(value: String) -> Self {
140 Self::text(value)
141 }
142}
143
144impl From<&str> for ToolOutput {
145 fn from(value: &str) -> Self {
146 Self::text(value)
147 }
148}
149
150impl From<GeneratedContent> for ToolOutput {
151 fn from(value: GeneratedContent) -> Self {
152 Self::structured(value)
153 }
154}
155
156impl From<Prompt> for ToolOutput {
157 fn from(value: Prompt) -> Self {
158 Self::from_prompt(value)
159 }
160}
161
162pub(crate) struct ToolRegistry {
163 tools: HashMap<String, Tool>,
164}
165
166impl ToolRegistry {
167 pub(crate) fn new(tools: Vec<Tool>) -> Self {
168 Self {
169 tools: tools
170 .into_iter()
171 .map(|tool| (tool.spec.name.clone(), tool))
172 .collect(),
173 }
174 }
175
176 pub(crate) fn specs_json(&self) -> Result<String, FMError> {
177 let specs = self
178 .tools
179 .values()
180 .map(|tool| {
181 json!({
182 "name": tool.spec.name,
183 "description": tool.spec.description,
184 "parametersJSON": tool.spec.parameters.json_schema(),
185 "includesSchemaInInstructions": tool.spec.includes_schema_in_instructions,
186 })
187 })
188 .collect::<Vec<_>>();
189 serde_json::to_string(&specs).map_err(|error| {
190 FMError::InvalidArgument(format!("tool specs are not JSON-serializable: {error}"))
191 })
192 }
193
194 fn invoke(&self, tool_name: &str, arguments: GeneratedContent) -> Result<ToolOutput, FMError> {
195 let tool = self.tools.get(tool_name).ok_or_else(|| {
196 FMError::ToolCallFailed(format!("tool `{tool_name}` is not registered"))
197 })?;
198 (tool.handler)(arguments)
199 }
200}
201
202pub(crate) unsafe extern "C" fn tool_callback_trampoline(
203 context: *mut c_void,
204 tool_name: *const c_char,
205 arguments_json: *const c_char,
206 output_json_out: *mut *mut c_char,
207 error_out: *mut *mut c_char,
208) -> i32 {
209 let registry = &*(context.cast::<ToolRegistry>());
210 let result = catch_unwind(AssertUnwindSafe(|| {
211 let tool_name = CStr::from_ptr(tool_name).to_string_lossy().into_owned();
212 let arguments_json = CStr::from_ptr(arguments_json)
213 .to_string_lossy()
214 .into_owned();
215 let arguments = GeneratedContent::from_json_str(&arguments_json)?;
216 let output = registry.invoke(&tool_name, arguments)?;
217 output.to_bridge_json()
218 }));
219
220 match result {
221 Ok(Ok(output_json)) => {
222 *output_json_out = swift_dup_string(&output_json);
223 ffi::status::OK
224 }
225 Ok(Err(error)) => {
226 *error_out = swift_dup_string(error.message());
227 error.code()
228 }
229 Err(_) => {
230 *error_out = swift_dup_string("tool callback panicked");
231 ffi::status::TOOL_CALL_FAILED
232 }
233 }
234}