1use core::ffi::{c_char, c_void};
4use std::collections::HashMap;
5use std::ffi::{CStr, CString};
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::Arc;
8
9use serde::de::DeserializeOwned;
10use serde_json::json;
11
12use crate::content::{FromGeneratedContent, GeneratedContent};
13use crate::error::FMError;
14use crate::ffi;
15use crate::prompt::{Prompt, ToPrompt, ToolDefinition};
16use crate::schema::{Generable, GenerationSchema};
17
18fn swift_dup_string(value: &str) -> *mut c_char {
19 let c_string = CString::new(value).expect("bridge strings must not contain interior NUL bytes");
20 unsafe { ffi::fm_string_dup(c_string.as_ptr()) }
21}
22
23pub struct Tool {
25 spec: ToolSpec,
26 handler: Arc<dyn Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync>,
27}
28
29impl Tool {
30 #[must_use]
32 pub fn new<F>(
33 name: impl Into<String>,
34 description: impl Into<String>,
35 parameters: GenerationSchema,
36 handler: F,
37 ) -> Self
38 where
39 F: Fn(GeneratedContent) -> Result<ToolOutput, FMError> + Send + Sync + 'static,
40 {
41 Self {
42 spec: ToolSpec {
43 name: name.into(),
44 description: description.into(),
45 parameters,
46 includes_schema_in_instructions: true,
47 },
48 handler: Arc::new(handler),
49 }
50 }
51
52 #[must_use]
54 pub fn json<Args, Output, F>(
55 name: impl Into<String>,
56 description: impl Into<String>,
57 parameters: GenerationSchema,
58 handler: F,
59 ) -> Self
60 where
61 Args: DeserializeOwned + Send + 'static,
62 Output: ToPrompt,
63 F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
64 {
65 Self::new(name, description, parameters, move |arguments| {
66 let decoded = arguments.value::<Args>()?;
67 let output = handler(decoded)?;
68 Ok(ToolOutput::from_prompt(output.to_prompt()?))
69 })
70 }
71
72 pub fn generable<Args, Output, F>(
78 name: impl Into<String>,
79 description: impl Into<String>,
80 handler: F,
81 ) -> Result<Self, FMError>
82 where
83 Args: FromGeneratedContent + Generable + Send + 'static,
84 Output: ToPrompt,
85 F: Fn(Args) -> Result<Output, FMError> + Send + Sync + 'static,
86 {
87 Ok(Self::new(
88 name,
89 description,
90 Args::generation_schema()?,
91 move |arguments| {
92 let decoded = Args::from_generated_content(&arguments)?;
93 let output = handler(decoded)?;
94 Ok(ToolOutput::from_prompt(output.to_prompt()?))
95 },
96 ))
97 }
98
99 #[must_use]
101 pub const fn spec(&self) -> &ToolSpec {
102 &self.spec
103 }
104
105 #[must_use]
107 pub fn definition(&self) -> ToolDefinition {
108 self.spec.definition()
109 }
110
111 #[must_use]
113 pub fn with_schema_in_instructions(mut self, includes: bool) -> Self {
114 self.spec.includes_schema_in_instructions = includes;
115 self
116 }
117}
118
119impl core::fmt::Debug for Tool {
120 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121 f.debug_struct("Tool").field("spec", &self.spec).finish()
122 }
123}
124
125#[derive(Debug, Clone, PartialEq, Eq)]
127pub struct ToolSpec {
128 pub name: String,
129 pub description: String,
130 pub parameters: GenerationSchema,
131 pub includes_schema_in_instructions: bool,
132}
133
134impl ToolSpec {
135 #[must_use]
137 pub fn definition(&self) -> ToolDefinition {
138 ToolDefinition::new(
139 self.name.clone(),
140 self.description.clone(),
141 self.parameters.clone(),
142 )
143 }
144}
145
146#[derive(Debug, Clone, PartialEq)]
148pub struct ToolOutput {
149 prompt: Prompt,
150}
151
152impl ToolOutput {
153 #[must_use]
155 pub fn text(text: impl Into<String>) -> Self {
156 Self {
157 prompt: Prompt::from(text.into()),
158 }
159 }
160
161 #[must_use]
163 pub fn structured(content: GeneratedContent) -> Self {
164 Self {
165 prompt: Prompt::from(content),
166 }
167 }
168
169 #[must_use]
171 pub const fn from_prompt(prompt: Prompt) -> Self {
172 Self { prompt }
173 }
174
175 #[must_use]
176 pub fn prompt(&self) -> &Prompt {
177 &self.prompt
178 }
179
180 pub(crate) fn to_bridge_json(&self) -> Result<String, FMError> {
181 serde_json::to_string(&json!({ "prompt": self.prompt.to_bridge_value() })).map_err(
182 |error| {
183 FMError::InvalidArgument(format!("tool output is not JSON-serializable: {error}"))
184 },
185 )
186 }
187}
188
189impl From<String> for ToolOutput {
190 fn from(value: String) -> Self {
191 Self::text(value)
192 }
193}
194
195impl From<&str> for ToolOutput {
196 fn from(value: &str) -> Self {
197 Self::text(value)
198 }
199}
200
201impl From<GeneratedContent> for ToolOutput {
202 fn from(value: GeneratedContent) -> Self {
203 Self::structured(value)
204 }
205}
206
207impl From<Prompt> for ToolOutput {
208 fn from(value: Prompt) -> Self {
209 Self::from_prompt(value)
210 }
211}
212
213pub(crate) struct ToolRegistry {
214 tools: HashMap<String, Tool>,
215}
216
217impl ToolRegistry {
218 pub(crate) fn new(tools: Vec<Tool>) -> Self {
219 Self {
220 tools: tools
221 .into_iter()
222 .map(|tool| (tool.spec.name.clone(), tool))
223 .collect(),
224 }
225 }
226
227 pub(crate) fn specs_json(&self) -> Result<String, FMError> {
228 let specs = self
229 .tools
230 .values()
231 .map(|tool| {
232 json!({
233 "name": tool.spec.name,
234 "description": tool.spec.description,
235 "parametersJSON": tool.spec.parameters.json_schema(),
236 "includesSchemaInInstructions": tool.spec.includes_schema_in_instructions,
237 })
238 })
239 .collect::<Vec<_>>();
240 serde_json::to_string(&specs).map_err(|error| {
241 FMError::InvalidArgument(format!("tool specs are not JSON-serializable: {error}"))
242 })
243 }
244
245 fn invoke(&self, tool_name: &str, arguments: GeneratedContent) -> Result<ToolOutput, FMError> {
246 let tool = self.tools.get(tool_name).ok_or_else(|| {
247 FMError::ToolCallFailed(format!("tool `{tool_name}` is not registered"))
248 })?;
249 (tool.handler)(arguments)
250 }
251}
252
253pub(crate) unsafe extern "C" fn tool_callback_trampoline(
254 context: *mut c_void,
255 tool_name: *const c_char,
256 arguments_json: *const c_char,
257 output_json_out: *mut *mut c_char,
258 error_out: *mut *mut c_char,
259) -> i32 {
260 let registry = &*(context.cast::<ToolRegistry>());
261 let result = catch_unwind(AssertUnwindSafe(|| {
262 let tool_name = CStr::from_ptr(tool_name).to_string_lossy().into_owned();
263 let arguments_json = CStr::from_ptr(arguments_json)
264 .to_string_lossy()
265 .into_owned();
266 let arguments = GeneratedContent::from_json_str(&arguments_json)?;
267 let output = registry.invoke(&tool_name, arguments)?;
268 output.to_bridge_json()
269 }));
270
271 match result {
272 Ok(Ok(output_json)) => {
273 *output_json_out = swift_dup_string(&output_json);
274 ffi::status::OK
275 }
276 Ok(Err(error)) => {
277 *error_out = swift_dup_string(error.message());
278 error.code()
279 }
280 Err(_) => {
281 *error_out = swift_dup_string("tool callback panicked");
282 ffi::status::TOOL_CALL_FAILED
283 }
284 }
285}