Skip to main content

llm_tool/
rust_tool.rs

1//! Strongly-typed Rust tool trait and type-erasure machinery.
2
3use std::{future::Future, pin::Pin};
4
5use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
6
7/// Convenience type for tools that take no parameters.
8#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
9pub struct EmptyParams {}
10
11/// A custom tool implemented entirely in Rust with strongly-typed parameters.
12///
13/// Define your parameters as a struct deriving [`serde::Deserialize`] and
14/// `JsonSchema`, then implement this trait to provide the tool's logic.
15/// The JSON Schema sent to the model is derived automatically from the
16/// params struct — doc comments on fields become parameter descriptions.
17///
18/// Tools are async: for I/O-bound work (HTTP, filesystem, subprocess) the
19/// runtime stays unblocked. Sync tools just don't `.await` anything — the
20/// compiler optimizes the state machine to an immediate return.
21///
22/// # Example
23///
24/// ```rust
25/// use llm_tool::{JsonSchema, RustTool, ToolContext, ToolError, ToolOutput};
26/// use serde::Deserialize;
27///
28/// #[derive(Deserialize, JsonSchema)]
29/// struct FlashParams {
30///     /// Target device identifier.
31///     device_id: String,
32///     /// Path to the firmware image.
33///     image_path: String,
34/// }
35///
36/// struct FlashDevice;
37///
38/// impl RustTool for FlashDevice {
39///     type Params = FlashParams;
40///     const NAME: &'static str = "flash_device";
41///     const DESCRIPTION: &'static str = "Flashes firmware to a connected device.";
42///
43///     async fn call(
44///         &self,
45///         params: Self::Params,
46///         _ctx: &ToolContext,
47///     ) -> Result<ToolOutput, ToolError> {
48///         Ok(format!(
49///             "Flashed {} to {}",
50///             params.image_path, params.device_id
51///         ).into())
52///     }
53/// }
54/// ```
55#[allow(async_fn_in_trait)]
56pub trait RustTool: Send + Sync {
57    /// The strongly-typed parameters struct.
58    ///
59    /// Derive [`serde::Deserialize`] and `JsonSchema` on your params struct.
60    /// `JsonSchema` auto-generates the parameter schema sent to the model;
61    /// `Deserialize` parses the model's JSON arguments into your struct.
62    type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
63
64    /// Unique tool name (e.g. `"flash_device"`).
65    const NAME: &'static str;
66
67    /// Human-readable description shown to the model.
68    const DESCRIPTION: &'static str;
69
70    /// Execute the tool with typed parameters and an execution context.
71    ///
72    /// Async to support I/O-bound tools (HTTP, filesystem, subprocess).
73    /// Sync tools just compute and return — the async wrapper is zero-cost.
74    ///
75    /// The `ctx` parameter provides access to conversation metadata and a
76    /// shared key-value state store. Tools that don't need context can simply
77    /// ignore it with `_ctx`.
78    ///
79    /// # Errors
80    ///
81    /// Returns `Err(ToolError)` if the tool execution fails.
82    fn call(
83        &self,
84        params: Self::Params,
85        ctx: &ToolContext,
86    ) -> impl std::future::Future<Output = Result<ToolOutput, ToolError>> + Send;
87}
88
89/// Build a [`ToolDefinition`] from any [`RustTool`] implementor.
90///
91/// The generated schema is sanitized to be compatible with the Go-based
92/// localharness, which expects `"type"` to always be a single string
93/// (not the array form `["string", "null"]` that schemars emits for
94/// `Option<T>` fields).
95///
96/// # Errors
97///
98/// Returns `Err` if the JSON schema serialization fails.
99pub fn definition_of<T: RustTool>(_tool: &T) -> Result<ToolDefinition, ToolError> {
100    let schema = schemars::schema_for!(T::Params);
101    let mut parameter_schema = serde_json::to_value(schema).map_err(|e| {
102        ToolError::new(format!(
103            "Failed to serialize schema for tool '{}': {e}",
104            T::NAME
105        ))
106    })?;
107    sanitize_schema_types(&mut parameter_schema);
108    Ok(ToolDefinition {
109        name: T::NAME.to_string(),
110        description: T::DESCRIPTION.to_string(),
111        parameter_schema,
112    })
113}
114
115/// Recursively sanitize JSON Schema `"type"` fields for Go genai compatibility.
116///
117/// `schemars` emits `"type": ["string", "null"]` for `Option<String>` fields
118/// (JSON Schema draft 7 nullable syntax). The Go genai SDK's `Schema.Type`
119/// is a single `genai.Type` enum, so it can't unmarshal an array.
120///
121/// This function walks the schema tree and replaces any array `type` with the
122/// first non-`"null"` element. For example:
123/// - `["string", "null"]` → `"string"`
124/// - `["integer", "null"]` → `"integer"`
125fn sanitize_schema_types(value: &mut serde_json::Value) {
126    match value {
127        serde_json::Value::Object(map) => {
128            // If "type" is an array (e.g. ["string", "null"]), pick the first
129            // non-"null" element and replace with it as a scalar type.
130            let replacement = match map.get("type") {
131                Some(serde_json::Value::Array(arr)) => {
132                    let non_null = arr.iter().find(|v| v.as_str() != Some("null")).cloned();
133                    non_null.or_else(|| arr.first().cloned())
134                }
135                _ => None,
136            };
137            if let Some(val) = replacement {
138                map.insert("type".to_string(), val);
139            }
140            for val in map.values_mut() {
141                sanitize_schema_types(val);
142            }
143        }
144        serde_json::Value::Array(arr) => {
145            for item in arr {
146                sanitize_schema_types(item);
147            }
148        }
149        _ => {}
150    }
151}
152
153/// Type-erased future returned by [`ErasedTool::call_erased`].
154type BoxToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;
155
156/// Type-erased wrapper enabling heterogeneous tool storage.
157///
158/// Boxes the future from [`RustTool::call`] so we can store different tool
159/// types in the same `HashMap<String, Box<dyn ErasedTool>>`.
160pub(crate) trait ErasedTool: Send + Sync {
161    /// Return this tool's metadata.
162    fn definition(&self) -> Result<ToolDefinition, ToolError>;
163    /// Deserialize `args` and call the handler, returning a boxed future.
164    fn call_erased<'a>(
165        &'a self,
166        args: serde_json::Value,
167        ctx: &'a ToolContext,
168    ) -> BoxToolFuture<'a>;
169}
170
171impl<T: RustTool> ErasedTool for T {
172    fn definition(&self) -> Result<ToolDefinition, ToolError> {
173        definition_of(self)
174    }
175
176    fn call_erased<'a>(
177        &'a self,
178        args: serde_json::Value,
179        ctx: &'a ToolContext,
180    ) -> BoxToolFuture<'a> {
181        Box::pin(async move {
182            let params: T::Params = serde_json::from_value(args).map_err(|e| {
183                ToolError::new(format!("Failed to deserialize tool parameters: {e}"))
184            })?;
185            self.call(params, ctx).await
186        })
187    }
188}