llm_tool/rust_tool.rs
1//! Strongly-typed Rust tool trait and type-erasure machinery.
2
3use alloc::{borrow::Cow, boxed::Box, format, string::ToString};
4use core::{future::Future, pin::Pin};
5
6use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
7
8/// Convenience type for tools that take no parameters.
9#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
10pub struct EmptyParams {}
11
12/// A custom tool implemented entirely in Rust with strongly-typed parameters.
13///
14/// Define your parameters as a struct deriving [`serde::Deserialize`] and
15/// `JsonSchema`, then implement this trait to provide the tool's logic.
16/// The JSON Schema sent to the model is derived automatically from the
17/// params struct — doc comments on fields become parameter descriptions.
18///
19/// Tools are async: for I/O-bound work (HTTP, filesystem, subprocess) the
20/// runtime stays unblocked. Sync tools just don't `.await` anything — the
21/// compiler optimizes the state machine to an immediate return.
22///
23/// # Example
24///
25/// ```rust
26/// use llm_tool::{JsonSchema, RustTool, ToolContext, ToolError, ToolOutput};
27/// use serde::Deserialize;
28///
29/// #[derive(Deserialize, JsonSchema)]
30/// struct FlashParams {
31/// /// Target device identifier.
32/// device_id: String,
33/// /// Path to the firmware image.
34/// image_path: String,
35/// }
36///
37/// struct FlashDevice;
38///
39/// impl RustTool for FlashDevice {
40/// type Params = FlashParams;
41/// const NAME: &'static str = "flash_device";
42/// const DESCRIPTION: &'static str = "Flashes firmware to a connected device.";
43///
44/// async fn call(
45/// &self,
46/// params: Self::Params,
47/// _ctx: &ToolContext,
48/// ) -> Result<ToolOutput, ToolError> {
49/// Ok(format!("Flashed {} to {}", params.image_path, params.device_id).into())
50/// }
51/// }
52/// ```
53pub trait RustTool: Send + Sync {
54 /// The strongly-typed parameters struct.
55 ///
56 /// Derive [`serde::Deserialize`] and `JsonSchema` on your params struct.
57 /// `JsonSchema` auto-generates the parameter schema sent to the model;
58 /// `Deserialize` parses the model's JSON arguments into your struct.
59 type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
60
61 /// Unique tool name (e.g. `"flash_device"`).
62 const NAME: &'static str;
63
64 /// Human-readable description shown to the model.
65 const DESCRIPTION: &'static str;
66
67 /// Return the tool description used in [`ToolDefinition`].
68 ///
69 /// The default returns [`Self::DESCRIPTION`] (the static string from a
70 /// doc comment or template body). When using
71 /// `#[llm_tool(template = "...", context = ...)]`, the generated
72 /// implementation overrides this to render the template with runtime
73 /// variables on each call. Templates are parsed once via `LazyLock`.
74 fn description(&self) -> Cow<'static, str> {
75 Cow::Borrowed(Self::DESCRIPTION)
76 }
77
78 /// Execute the tool with typed parameters and an execution context.
79 ///
80 /// Async to support I/O-bound tools (HTTP, filesystem, subprocess).
81 /// Sync tools just compute and return — the async wrapper is zero-cost.
82 ///
83 /// The `ctx` parameter provides access to conversation metadata and a
84 /// shared key-value state store. Tools that don't need context can simply
85 /// ignore it with `_ctx`.
86 ///
87 /// # Errors
88 ///
89 /// Returns `Err(ToolError)` if the tool execution fails.
90 fn call(
91 &self,
92 params: Self::Params,
93 ctx: &ToolContext,
94 ) -> impl Future<Output = Result<ToolOutput, ToolError>> + Send;
95}
96
97/// Build a [`ToolDefinition`] from any [`RustTool`] implementor.
98///
99/// The generated schema is sanitized for broad compatibility with LLM
100/// SDKs that expect `"type"` to always be a single string
101/// (not the array form `["string", "null"]` that schemars emits for
102/// `Option<T>` fields).
103///
104/// # Errors
105///
106/// Returns `Err` if the JSON schema serialization fails.
107pub fn definition_of<T: RustTool>(tool: &T) -> Result<ToolDefinition, ToolError> {
108 let schema = schemars::schema_for!(T::Params);
109 let mut parameter_schema = serde_json::to_value(schema).map_err(|e| {
110 ToolError::new(format!(
111 "Failed to serialize schema for tool '{}': {e}",
112 T::NAME
113 ))
114 })?;
115 sanitize_schema_types(&mut parameter_schema);
116 Ok(ToolDefinition {
117 name: T::NAME.to_string(),
118 description: tool.description().into_owned(),
119 parameter_schema,
120 })
121}
122
123/// Recursively sanitize JSON Schema `"type"` fields for Go genai compatibility.
124///
125/// `schemars` emits `"type": ["string", "null"]` for `Option<String>` fields
126/// (JSON Schema draft 7 nullable syntax). The Go genai SDK's `Schema.Type`
127/// is a single `genai.Type` enum, so it can't unmarshal an array.
128///
129/// This function walks the schema tree and replaces any array `type` with the
130/// first non-`"null"` element. For example:
131/// - `["string", "null"]` → `"string"`
132/// - `["integer", "null"]` → `"integer"`
133fn sanitize_schema_types(value: &mut serde_json::Value) {
134 match value {
135 serde_json::Value::Object(map) => {
136 // If "type" is an array (e.g. ["string", "null"]), pick the first
137 // non-"null" element and replace with it as a scalar type.
138 let replacement = match map.get("type") {
139 Some(serde_json::Value::Array(arr)) => {
140 let non_null = arr.iter().find(|v| v.as_str() != Some("null")).cloned();
141 non_null.or_else(|| arr.first().cloned())
142 }
143 _ => None,
144 };
145 if let Some(val) = replacement {
146 map.insert("type".to_string(), val);
147 }
148 for val in map.values_mut() {
149 sanitize_schema_types(val);
150 }
151 }
152 serde_json::Value::Array(arr) => {
153 for item in arr {
154 sanitize_schema_types(item);
155 }
156 }
157 _ => {}
158 }
159}
160
161/// Type-erased future returned by [`ErasedTool::call_erased`].
162type BoxToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;
163
164/// Type-erased wrapper enabling heterogeneous tool storage.
165///
166/// Boxes the future from [`RustTool::call`] so we can store different tool
167/// types in the same `HashMap<String, Box<dyn ErasedTool>>`.
168pub(crate) trait ErasedTool: Send + Sync {
169 /// Deserialize `args` and call the handler, returning a boxed future.
170 fn call_erased<'a>(
171 &'a self,
172 args: serde_json::Value,
173 ctx: &'a ToolContext,
174 ) -> BoxToolFuture<'a>;
175}
176
177impl<T: RustTool> ErasedTool for T {
178 fn call_erased<'a>(
179 &'a self,
180 args: serde_json::Value,
181 ctx: &'a ToolContext,
182 ) -> BoxToolFuture<'a> {
183 Box::pin(async move {
184 let params: T::Params = serde_json::from_value(args).map_err(|e| {
185 ToolError::new(format!("Failed to deserialize tool parameters: {e}"))
186 })?;
187 self.call(params, ctx).await
188 })
189 }
190}