Skip to main content

llm_tool/
rust_tool.rs

1//! Strongly-typed Rust tool trait and type-erasure machinery.
2
3use std::{borrow::Cow, future::Future, pin::Pin};
4
5use super::types::{ToolContext, ToolDefinition, ToolError, ToolOutput};
6
7/// Convenience type for tools that take no parameters.
8#[derive(Debug, Clone, serde::Deserialize, schemars::JsonSchema)]
9pub struct EmptyParams {}
10
11/// A custom tool implemented entirely in Rust with strongly-typed parameters.
12///
13/// Define your parameters as a struct deriving [`serde::Deserialize`] and
14/// `JsonSchema`, then implement this trait to provide the tool's logic.
15/// The JSON Schema sent to the model is derived automatically from the
16/// params struct — doc comments on fields become parameter descriptions.
17///
18/// Tools are async: for I/O-bound work (HTTP, filesystem, subprocess) the
19/// runtime stays unblocked. Sync tools just don't `.await` anything — the
20/// compiler optimizes the state machine to an immediate return.
21///
22/// # Example
23///
24/// ```rust
25/// use llm_tool::{JsonSchema, RustTool, ToolContext, ToolError, ToolOutput};
26/// use serde::Deserialize;
27///
28/// #[derive(Deserialize, JsonSchema)]
29/// struct FlashParams {
30///     /// Target device identifier.
31///     device_id: String,
32///     /// Path to the firmware image.
33///     image_path: String,
34/// }
35///
36/// struct FlashDevice;
37///
38/// impl RustTool for FlashDevice {
39///     type Params = FlashParams;
40///     const NAME: &'static str = "flash_device";
41///     const DESCRIPTION: &'static str = "Flashes firmware to a connected device.";
42///
43///     async fn call(
44///         &self,
45///         params: Self::Params,
46///         _ctx: &ToolContext,
47///     ) -> Result<ToolOutput, ToolError> {
48///         Ok(format!("Flashed {} to {}", params.image_path, params.device_id).into())
49///     }
50/// }
51/// ```
52pub trait RustTool: Send + Sync {
53    /// The strongly-typed parameters struct.
54    ///
55    /// Derive [`serde::Deserialize`] and `JsonSchema` on your params struct.
56    /// `JsonSchema` auto-generates the parameter schema sent to the model;
57    /// `Deserialize` parses the model's JSON arguments into your struct.
58    type Params: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
59
60    /// Unique tool name (e.g. `"flash_device"`).
61    const NAME: &'static str;
62
63    /// Human-readable description shown to the model.
64    const DESCRIPTION: &'static str;
65
66    /// Return the tool description used in [`ToolDefinition`].
67    ///
68    /// The default returns [`Self::DESCRIPTION`] (the static string from a
69    /// doc comment or template body). When using
70    /// `#[llm_tool(template = "...", context = ...)]`, the generated
71    /// implementation overrides this to render the template with runtime
72    /// variables on each call. Templates are parsed once via `LazyLock`.
73    fn description(&self) -> Cow<'static, str> {
74        Cow::Borrowed(Self::DESCRIPTION)
75    }
76
77    /// Execute the tool with typed parameters and an execution context.
78    ///
79    /// Async to support I/O-bound tools (HTTP, filesystem, subprocess).
80    /// Sync tools just compute and return — the async wrapper is zero-cost.
81    ///
82    /// The `ctx` parameter provides access to conversation metadata and a
83    /// shared key-value state store. Tools that don't need context can simply
84    /// ignore it with `_ctx`.
85    ///
86    /// # Errors
87    ///
88    /// Returns `Err(ToolError)` if the tool execution fails.
89    fn call(
90        &self,
91        params: Self::Params,
92        ctx: &ToolContext,
93    ) -> impl std::future::Future<Output = Result<ToolOutput, ToolError>> + Send;
94}
95
96/// Build a [`ToolDefinition`] from any [`RustTool`] implementor.
97///
98/// The generated schema is sanitized to be compatible with the Go-based
99/// localharness, which expects `"type"` to always be a single string
100/// (not the array form `["string", "null"]` that schemars emits for
101/// `Option<T>` fields).
102///
103/// # Errors
104///
105/// Returns `Err` if the JSON schema serialization fails.
106pub fn definition_of<T: RustTool>(tool: &T) -> Result<ToolDefinition, ToolError> {
107    let schema = schemars::schema_for!(T::Params);
108    let mut parameter_schema = serde_json::to_value(schema).map_err(|e| {
109        ToolError::new(format!(
110            "Failed to serialize schema for tool '{}': {e}",
111            T::NAME
112        ))
113    })?;
114    sanitize_schema_types(&mut parameter_schema);
115    Ok(ToolDefinition {
116        name: T::NAME.to_string(),
117        description: tool.description().into_owned(),
118        parameter_schema,
119    })
120}
121
122/// Recursively sanitize JSON Schema `"type"` fields for Go genai compatibility.
123///
124/// `schemars` emits `"type": ["string", "null"]` for `Option<String>` fields
125/// (JSON Schema draft 7 nullable syntax). The Go genai SDK's `Schema.Type`
126/// is a single `genai.Type` enum, so it can't unmarshal an array.
127///
128/// This function walks the schema tree and replaces any array `type` with the
129/// first non-`"null"` element. For example:
130/// - `["string", "null"]` → `"string"`
131/// - `["integer", "null"]` → `"integer"`
132fn sanitize_schema_types(value: &mut serde_json::Value) {
133    match value {
134        serde_json::Value::Object(map) => {
135            // If "type" is an array (e.g. ["string", "null"]), pick the first
136            // non-"null" element and replace with it as a scalar type.
137            let replacement = match map.get("type") {
138                Some(serde_json::Value::Array(arr)) => {
139                    let non_null = arr.iter().find(|v| v.as_str() != Some("null")).cloned();
140                    non_null.or_else(|| arr.first().cloned())
141                }
142                _ => None,
143            };
144            if let Some(val) = replacement {
145                map.insert("type".to_string(), val);
146            }
147            for val in map.values_mut() {
148                sanitize_schema_types(val);
149            }
150        }
151        serde_json::Value::Array(arr) => {
152            for item in arr {
153                sanitize_schema_types(item);
154            }
155        }
156        _ => {}
157    }
158}
159
160/// Type-erased future returned by [`ErasedTool::call_erased`].
161type BoxToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, ToolError>> + Send + 'a>>;
162
163/// Type-erased wrapper enabling heterogeneous tool storage.
164///
165/// Boxes the future from [`RustTool::call`] so we can store different tool
166/// types in the same `HashMap<String, Box<dyn ErasedTool>>`.
167pub(crate) trait ErasedTool: Send + Sync {
168    /// Deserialize `args` and call the handler, returning a boxed future.
169    fn call_erased<'a>(
170        &'a self,
171        args: serde_json::Value,
172        ctx: &'a ToolContext,
173    ) -> BoxToolFuture<'a>;
174}
175
176impl<T: RustTool> ErasedTool for T {
177    fn call_erased<'a>(
178        &'a self,
179        args: serde_json::Value,
180        ctx: &'a ToolContext,
181    ) -> BoxToolFuture<'a> {
182        Box::pin(async move {
183            let params: T::Params = serde_json::from_value(args).map_err(|e| {
184                ToolError::new(format!("Failed to deserialize tool parameters: {e}"))
185            })?;
186            self.call(params, ctx).await
187        })
188    }
189}