Skip to main content

llm_stack/tool/
helpers.rs

1//! Helper functions for creating tool handlers.
2
3use std::future::Future;
4use std::marker::PhantomData;
5
6use serde_json::Value;
7
8use super::{FnToolHandler, NoCtxToolHandler, ToolError, ToolOutput};
9use crate::provider::ToolDefinition;
10
11/// Creates a [`ToolHandler<()>`](super::ToolHandler) from a closure (no context).
12///
13/// The closure receives the tool's JSON arguments and returns a
14/// `Result<impl Into<ToolOutput>, ToolError>`. Returning `Result<String, ToolError>`
15/// also works via the `From<String>` impl on `ToolOutput`.
16///
17/// For tools that need shared context, use [`tool_fn_with_ctx`] instead.
18///
19/// # Example
20///
21/// ```rust
22/// use llm_stack::tool::tool_fn;
23/// use llm_stack::{JsonSchema, ToolDefinition};
24/// use serde_json::{json, Value};
25///
26/// let handler = tool_fn(
27///     ToolDefinition {
28///         name: "add".into(),
29///         description: "Add two numbers".into(),
30///         parameters: JsonSchema::new(json!({
31///             "type": "object",
32///             "properties": {
33///                 "a": { "type": "number" },
34///                 "b": { "type": "number" }
35///             },
36///             "required": ["a", "b"]
37///         })),
38///         retry: None,
39///     },
40///     |input: Value| async move {
41///         let a = input["a"].as_f64().unwrap_or(0.0);
42///         let b = input["b"].as_f64().unwrap_or(0.0);
43///         Ok(format!("{}", a + b))
44///     },
45/// );
46/// ```
47pub fn tool_fn<F, Fut, O>(definition: ToolDefinition, handler: F) -> NoCtxToolHandler<F>
48where
49    F: Fn(Value) -> Fut + Send + Sync + 'static,
50    Fut: Future<Output = Result<O, ToolError>> + Send + 'static,
51    O: Into<ToolOutput> + Send + 'static,
52{
53    NoCtxToolHandler {
54        definition,
55        handler,
56    }
57}
58
59/// Creates a [`ToolHandler<Ctx>`](super::ToolHandler) from a closure that receives context.
60///
61/// The closure receives the tool's JSON arguments and a reference to the
62/// context, and returns a `Result<impl Into<ToolOutput>, ToolError>`.
63///
64/// # Example
65///
66/// ```rust
67/// use llm_stack::tool::{tool_fn_with_ctx, ToolOutput};
68/// use llm_stack::{JsonSchema, ToolDefinition};
69/// use serde_json::{json, Value};
70///
71/// struct AppContext {
72///     db_url: String,
73/// }
74///
75/// let handler = tool_fn_with_ctx(
76///     ToolDefinition {
77///         name: "lookup_user".into(),
78///         description: "Look up a user by ID".into(),
79///         parameters: JsonSchema::new(json!({
80///             "type": "object",
81///             "properties": {
82///                 "user_id": { "type": "string" }
83///             },
84///             "required": ["user_id"]
85///         })),
86///         retry: None,
87///     },
88///     |input: Value, ctx: &AppContext| {
89///         let user_id = input["user_id"].as_str().unwrap_or("").to_string();
90///         let db_url = ctx.db_url.clone();
91///         async move {
92///             // Use db_url and user_id in async work...
93///             Ok(ToolOutput::new(format!("Found user {} in {}", user_id, db_url)))
94///         }
95///     },
96/// );
97/// ```
98pub fn tool_fn_with_ctx<Ctx, F, Fut, O>(
99    definition: ToolDefinition,
100    handler: F,
101) -> FnToolHandler<Ctx, F>
102where
103    Ctx: Send + Sync + 'static,
104    F: for<'c> Fn(Value, &'c Ctx) -> Fut + Send + Sync + 'static,
105    Fut: Future<Output = Result<O, ToolError>> + Send + 'static,
106    O: Into<ToolOutput> + Send + 'static,
107{
108    FnToolHandler {
109        definition,
110        handler,
111        _ctx: PhantomData,
112    }
113}