Skip to main content

llm_tool/
rust_tool.rs

//! Strongly-typed Rust tool trait and type-erasure machinery.

3use alloc::{borrow::Cow, boxed::Box, format, string::ToString};
4use core::{future::Future, pin::Pin};
5
6use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
7
8/// Convenience type for tools that take no parameters.
9#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
10pub struct EmptyParams {}
11
12/// A custom tool implemented entirely in Rust with strongly-typed parameters.
13///
14/// Define your parameters as a struct deriving [`serde::Deserialize`] and
15/// `JsonSchema`, then implement this trait to provide the tool's logic.
16/// The JSON Schema sent to the model is derived automatically from the
17/// params struct — doc comments on fields become parameter descriptions.
18///
19/// Tools are async: for I/O-bound work (HTTP, filesystem, subprocess) the
20/// runtime stays unblocked. Sync tools just don't `.await` anything — the
21/// compiler optimizes the state machine to an immediate return.
22///
23/// # Example
24///
25/// ```rust
26/// use llm_tool::{JsonSchema, RustTool, ToolContext, ToolError, ToolOutput};
27/// use serde::Deserialize;
28///
29/// #[derive(Deserialize, JsonSchema)]
30/// struct FlashParams {
31///     /// Target device identifier.
32///     device_id: String,
33///     /// Path to the firmware image.
34///     image_path: String,
35/// }
36///
37/// struct FlashDevice;
38///
39/// impl RustTool for FlashDevice {
40///     type Params = FlashParams;
41///     const NAME: &'static str = "flash_device";
42///     const DESCRIPTION: &'static str = "Flashes firmware to a connected device.";
43///
44///     async fn call(
45///         &self,
46///         params: Self::Params,
47///         _ctx: &ToolContext,
48///     ) -> Result<ToolOutput, ToolError> {
49///         Ok(format!("Flashed {} to {}", params.image_path, params.device_id).into())
50///     }
51/// }
52/// ```
53pub trait RustTool: Send + Sync {
54    /// The strongly-typed parameters struct.
55    ///
56    /// Derive [`serde::Deserialize`] and `JsonSchema` on your params struct.
57    /// `JsonSchema` auto-generates the parameter schema sent to the model;
58    /// `Deserialize` parses the model's JSON arguments into your struct.
59    type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
60
61    /// Unique tool name (e.g. `"flash_device"`).
62    const NAME: &'static str;
63
64    /// Human-readable description shown to the model.
65    const DESCRIPTION: &'static str;
66
67    /// Return the tool description used in [`ToolDefinition`].
68    ///
69    /// The default returns [`Self::DESCRIPTION`] (the static string from a
70    /// doc comment or template body). When using
71    /// `#[llm_tool(template = "...", context = ...)]`, the generated
72    /// implementation overrides this to render the template with runtime
73    /// variables on each call. Templates are parsed once via `LazyLock`.
74    fn description(&self) -> Cow<'static, str> {
75        Cow::Borrowed(Self::DESCRIPTION)
76    }
77
78    /// Execute the tool with typed parameters and an execution context.
79    ///
80    /// Async to support I/O-bound tools (HTTP, filesystem, subprocess).
81    /// Sync tools just compute and return — the async wrapper is zero-cost.
82    ///
83    /// The `ctx` parameter provides access to conversation metadata and a
84    /// shared key-value state store. Tools that don't need context can simply
85    /// ignore it with `_ctx`.
86    ///
87    /// # Errors
88    ///
89    /// Returns `Err(ToolError)` if the tool execution fails.
90    fn call(
91        &self,
92        params: Self::Params,
93        ctx: &ToolContext,
94    ) -> impl Future<Output = Result<ToolOutput, ToolError>> + Send;
95}
96
97/// Build a [`ToolDefinition`] from any [`RustTool`] implementor.
98///
99/// The generated schema is sanitized for broad compatibility with LLM
100/// SDKs that expect `"type"` to always be a single string
101/// (not the array form `["string", "null"]` that schemars emits for
102/// `Option<T>` fields).
103///
104/// # Errors
105///
106/// Returns `Err` if the JSON schema serialization fails.
107pub fn definition_of<T: RustTool>(tool: &T) -> Result<ToolDefinition, ToolError> {
108    let schema = schemars::schema_for!(T::Params);
109    let mut parameter_schema = serde_json::to_value(schema).map_err(|e| {
110        ToolError::new(format!(
111            "Failed to serialize schema for tool '{}': {e}",
112            T::NAME
113        ))
114    })?;
115    sanitize_schema_types(&mut parameter_schema);
116    Ok(ToolDefinition {
117        name: T::NAME.to_string(),
118        description: tool.description().into_owned(),
119        parameter_schema,
120    })
121}
122
123/// Recursively sanitize JSON Schema `"type"` fields for Go genai compatibility.
124///
125/// `schemars` emits `"type": ["string", "null"]` for `Option<String>` fields
126/// (JSON Schema draft 7 nullable syntax). The Go genai SDK's `Schema.Type`
127/// is a single `genai.Type` enum, so it can't unmarshal an array.
128///
129/// This function walks the schema tree and replaces any array `type` with the
130/// first non-`"null"` element. For example:
131/// - `["string", "null"]` → `"string"`
132/// - `["integer", "null"]` → `"integer"`
133fn sanitize_schema_types(value: &mut serde_json::Value) {
134    match value {
135        serde_json::Value::Object(map) => {
136            // If "type" is an array (e.g. ["string", "null"]), pick the first
137            // non-"null" element and replace with it as a scalar type.
138            let replacement = match map.get("type") {
139                Some(serde_json::Value::Array(arr)) => {
140                    let non_null = arr.iter().find(|v| v.as_str() != Some("null")).cloned();
141                    non_null.or_else(|| arr.first().cloned())
142                }
143                _ => None,
144            };
145            if let Some(val) = replacement {
146                map.insert("type".to_string(), val);
147            }
148            for val in map.values_mut() {
149                sanitize_schema_types(val);
150            }
151        }
152        serde_json::Value::Array(arr) => {
153            for item in arr {
154                sanitize_schema_types(item);
155            }
156        }
157        _ => {}
158    }
159}
160
161/// Type-erased future returned by [`ErasedTool::call_erased`].
162type BoxToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;
163
164/// Type-erased wrapper enabling heterogeneous tool storage.
165///
166/// Boxes the future from [`RustTool::call`] so we can store different tool
167/// types in the same `HashMap<String, Box<dyn ErasedTool>>`.
168pub(crate) trait ErasedTool: Send + Sync {
169    /// Deserialize `args` and call the handler, returning a boxed future.
170    fn call_erased<'a>(
171        &'a self,
172        args: serde_json::Value,
173        ctx: &'a ToolContext,
174    ) -> BoxToolFuture<'a>;
175}
176
177impl<T: RustTool> ErasedTool for T {
178    fn call_erased<'a>(
179        &'a self,
180        args: serde_json::Value,
181        ctx: &'a ToolContext,
182    ) -> BoxToolFuture<'a> {
183        Box::pin(async move {
184            let params: T::Params = serde_json::from_value(args).map_err(|e| {
185                ToolError::new(format!("Failed to deserialize tool parameters: {e}"))
186            })?;
187            self.call(params, ctx).await
188        })
189    }
190}