llm_tool/rust_tool.rs
1//! Strongly-typed Rust tool trait and type-erasure machinery.
2
3use std::{future::Future, pin::Pin};
4
5use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
6
7/// Convenience type for tools that take no parameters.
8#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
9pub struct EmptyParams {}
10
11/// A custom tool implemented entirely in Rust with strongly-typed parameters.
12///
13/// Define your parameters as a struct deriving [`serde::Deserialize`] and
14/// `JsonSchema`, then implement this trait to provide the tool's logic.
15/// The JSON Schema sent to the model is derived automatically from the
16/// params struct — doc comments on fields become parameter descriptions.
17///
18/// Tools are async: for I/O-bound work (HTTP, filesystem, subprocess) the
19/// runtime stays unblocked. Sync tools just don't `.await` anything — the
20/// compiler optimizes the state machine to an immediate return.
21///
22/// # Example
23///
24/// ```rust
25/// use llm_tool::{JsonSchema, RustTool, ToolContext, ToolError, ToolOutput};
26/// use serde::Deserialize;
27///
28/// #[derive(Deserialize, JsonSchema)]
29/// struct FlashParams {
30/// /// Target device identifier.
31/// device_id: String,
32/// /// Path to the firmware image.
33/// image_path: String,
34/// }
35///
36/// struct FlashDevice;
37///
38/// impl RustTool for FlashDevice {
39/// type Params = FlashParams;
40/// const NAME: &'static str = "flash_device";
41/// const DESCRIPTION: &'static str = "Flashes firmware to a connected device.";
42///
43/// async fn call(
44/// &self,
45/// params: Self::Params,
46/// _ctx: &ToolContext,
47/// ) -> Result<ToolOutput, ToolError> {
48/// Ok(format!(
49/// "Flashed {} to {}",
50/// params.image_path, params.device_id
51/// ).into())
52/// }
53/// }
54/// ```
55#[allow(async_fn_in_trait)]
56pub trait RustTool: Send + Sync {
57 /// The strongly-typed parameters struct.
58 ///
59 /// Derive [`serde::Deserialize`] and `JsonSchema` on your params struct.
60 /// `JsonSchema` auto-generates the parameter schema sent to the model;
61 /// `Deserialize` parses the model's JSON arguments into your struct.
62 type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
63
64 /// Unique tool name (e.g. `"flash_device"`).
65 const NAME: &'static str;
66
67 /// Human-readable description shown to the model.
68 const DESCRIPTION: &'static str;
69
70 /// Execute the tool with typed parameters and an execution context.
71 ///
72 /// Async to support I/O-bound tools (HTTP, filesystem, subprocess).
73 /// Sync tools just compute and return — the async wrapper is zero-cost.
74 ///
75 /// The `ctx` parameter provides access to conversation metadata and a
76 /// shared key-value state store. Tools that don't need context can simply
77 /// ignore it with `_ctx`.
78 ///
79 /// # Errors
80 ///
81 /// Returns `Err(ToolError)` if the tool execution fails.
82 fn call(
83 &self,
84 params: Self::Params,
85 ctx: &ToolContext,
86 ) -> impl std::future::Future<Output = Result<ToolOutput, ToolError>> + Send;
87}
88
89/// Build a [`ToolDefinition`] from any [`RustTool`] implementor.
90///
91/// The generated schema is sanitized to be compatible with the Go-based
92/// localharness, which expects `"type"` to always be a single string
93/// (not the array form `["string", "null"]` that schemars emits for
94/// `Option<T>` fields).
95///
96/// # Errors
97///
98/// Returns `Err` if the JSON schema serialization fails.
99pub fn definition_of<T: RustTool>(_tool: &T) -> Result<ToolDefinition, ToolError> {
100 let schema = schemars::schema_for!(T::Params);
101 let mut parameter_schema = serde_json::to_value(schema).map_err(|e| {
102 ToolError::new(format!(
103 "Failed to serialize schema for tool '{}': {e}",
104 T::NAME
105 ))
106 })?;
107 sanitize_schema_types(&mut parameter_schema);
108 Ok(ToolDefinition {
109 name: T::NAME.to_string(),
110 description: T::DESCRIPTION.to_string(),
111 parameter_schema,
112 })
113}
114
115/// Recursively sanitize JSON Schema `"type"` fields for Go genai compatibility.
116///
117/// `schemars` emits `"type": ["string", "null"]` for `Option<String>` fields
118/// (JSON Schema draft 7 nullable syntax). The Go genai SDK's `Schema.Type`
119/// is a single `genai.Type` enum, so it can't unmarshal an array.
120///
121/// This function walks the schema tree and replaces any array `type` with the
122/// first non-`"null"` element. For example:
123/// - `["string", "null"]` → `"string"`
124/// - `["integer", "null"]` → `"integer"`
125fn sanitize_schema_types(value: &mut serde_json::Value) {
126 match value {
127 serde_json::Value::Object(map) => {
128 // If "type" is an array (e.g. ["string", "null"]), pick the first
129 // non-"null" element and replace with it as a scalar type.
130 let replacement = match map.get("type") {
131 Some(serde_json::Value::Array(arr)) => {
132 let non_null = arr.iter().find(|v| v.as_str() != Some("null")).cloned();
133 non_null.or_else(|| arr.first().cloned())
134 }
135 _ => None,
136 };
137 if let Some(val) = replacement {
138 map.insert("type".to_string(), val);
139 }
140 for val in map.values_mut() {
141 sanitize_schema_types(val);
142 }
143 }
144 serde_json::Value::Array(arr) => {
145 for item in arr {
146 sanitize_schema_types(item);
147 }
148 }
149 _ => {}
150 }
151}
152
153/// Type-erased future returned by [`ErasedTool::call_erased`].
154type BoxToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;
155
156/// Type-erased wrapper enabling heterogeneous tool storage.
157///
158/// Boxes the future from [`RustTool::call`] so we can store different tool
159/// types in the same `HashMap<String, Box<dyn ErasedTool>>`.
160pub(crate) trait ErasedTool: Send + Sync {
161 /// Return this tool's metadata.
162 fn definition(&self) -> Result<ToolDefinition, ToolError>;
163 /// Deserialize `args` and call the handler, returning a boxed future.
164 fn call_erased<'a>(
165 &'a self,
166 args: serde_json::Value,
167 ctx: &'a ToolContext,
168 ) -> BoxToolFuture<'a>;
169}
170
171impl<T: RustTool> ErasedTool for T {
172 fn definition(&self) -> Result<ToolDefinition, ToolError> {
173 definition_of(self)
174 }
175
176 fn call_erased<'a>(
177 &'a self,
178 args: serde_json::Value,
179 ctx: &'a ToolContext,
180 ) -> BoxToolFuture<'a> {
181 Box::pin(async move {
182 let params: T::Params = serde_json::from_value(args).map_err(|e| {
183 ToolError::new(format!("Failed to deserialize tool parameters: {e}"))
184 })?;
185 self.call(params, ctx).await
186 })
187 }
188}