Skip to main content

rig_compose/
tool.rs

//! [`Tool`] — the only side-effectful interface available to skills and agents.
//!
//! A [`Tool`] is a typed, named, async function with a JSON-Schema-compatible
//! signature. Today the trait is implemented by [`LocalTool`] (a closure over a
//! Rust async fn); a second transport — a remote MCP server behind the `mcp`
//! feature — is planned for a later phase. Skills cannot tell the difference.
7
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::{Map, Value};
15
16use crate::registry::KernelError;
17
18/// Stable, registry-unique identifier for a tool (e.g. `"grammar.query"`,
19/// `"memory.lookup"`, `"sampler.expand"`).
20pub type ToolName = String;
21
22/// Lightweight description of a tool's I/O contract. The `args_schema` and
23/// `result_schema` are JSON-Schema fragments; the LLM-facing rendering layer
24/// uses them to generate `rig` / MCP tool definitions automatically. We do
25/// **not** validate against them at the kernel — validation is the tool's
26/// responsibility — but downstream MCP exporters need them.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ToolSchema {
29    pub name: ToolName,
30    pub description: String,
31    pub args_schema: Value,
32    pub result_schema: Value,
33}
34
35/// Configuration for bounding large tool results before they enter a model
36/// turn, trace record, or MCP response cache.
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38pub struct ToolResultEnvelopeConfig {
39    /// Maximum characters retained for any string value.
40    pub max_string_chars: usize,
41    /// Maximum items retained for any array value.
42    pub max_array_items: usize,
43}
44
45impl Default for ToolResultEnvelopeConfig {
46    fn default() -> Self {
47        Self {
48            max_string_chars: 4_000,
49            max_array_items: 64,
50        }
51    }
52}
53
54impl ToolResultEnvelopeConfig {
55    /// Build a config with a string character limit and otherwise default
56    /// limits.
57    #[must_use]
58    pub fn new(max_string_chars: usize) -> Self {
59        Self {
60            max_string_chars,
61            ..Self::default()
62        }
63    }
64
65    /// Set the maximum retained array items.
66    #[must_use]
67    pub fn with_max_array_items(mut self, max_array_items: usize) -> Self {
68        self.max_array_items = max_array_items;
69        self
70    }
71}
72
73/// Tool result plus deterministic truncation metadata.
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct ToolResultEnvelope {
76    /// Possibly bounded result payload.
77    pub payload: Value,
78    /// Whether any value was truncated or omitted.
79    pub truncated: bool,
80    /// Total string characters omitted while bounding the payload.
81    pub omitted_chars: usize,
82    /// Total array items omitted while bounding the payload.
83    pub omitted_items: usize,
84    /// Stable follow-up token describing the first omitted segment.
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub page_token: Option<String>,
87}
88
89impl ToolResultEnvelope {
90    /// Bound `payload` according to `config` and return truncation metadata.
91    #[must_use]
92    pub fn bound(payload: Value, config: &ToolResultEnvelopeConfig) -> Self {
93        let mut state = ToolResultEnvelopeState::default();
94        let payload = bound_value(payload, config, &mut state);
95        Self {
96            payload,
97            truncated: state.omitted_chars > 0 || state.omitted_items > 0,
98            omitted_chars: state.omitted_chars,
99            omitted_items: state.omitted_items,
100            page_token: state.page_token,
101        }
102    }
103}
104
105/// Bound `payload` with the default [`ToolResultEnvelopeConfig`].
106#[must_use]
107pub fn bound_tool_result(payload: Value) -> ToolResultEnvelope {
108    ToolResultEnvelope::bound(payload, &ToolResultEnvelopeConfig::default())
109}
110
111#[derive(Default)]
112struct ToolResultEnvelopeState {
113    omitted_chars: usize,
114    omitted_items: usize,
115    page_token: Option<String>,
116}
117
118fn bound_value(
119    value: Value,
120    config: &ToolResultEnvelopeConfig,
121    state: &mut ToolResultEnvelopeState,
122) -> Value {
123    match value {
124        Value::String(text) => bound_string(text, config, state),
125        Value::Array(items) => bound_array(items, config, state),
126        Value::Object(fields) => bound_object(fields, config, state),
127        scalar => scalar,
128    }
129}
130
131fn bound_string(
132    text: String,
133    config: &ToolResultEnvelopeConfig,
134    state: &mut ToolResultEnvelopeState,
135) -> Value {
136    let total_chars = text.chars().count();
137    if total_chars <= config.max_string_chars {
138        return Value::String(text);
139    }
140    state.omitted_chars = state
141        .omitted_chars
142        .saturating_add(total_chars.saturating_sub(config.max_string_chars));
143    if state.page_token.is_none() {
144        state.page_token = Some(format!("chars:{}", config.max_string_chars));
145    }
146    Value::String(text.chars().take(config.max_string_chars).collect())
147}
148
149fn bound_array(
150    items: Vec<Value>,
151    config: &ToolResultEnvelopeConfig,
152    state: &mut ToolResultEnvelopeState,
153) -> Value {
154    let total_items = items.len();
155    if total_items > config.max_array_items {
156        state.omitted_items = state
157            .omitted_items
158            .saturating_add(total_items.saturating_sub(config.max_array_items));
159        if state.page_token.is_none() {
160            state.page_token = Some(format!("items:{}", config.max_array_items));
161        }
162    }
163    Value::Array(
164        items
165            .into_iter()
166            .take(config.max_array_items)
167            .map(|item| bound_value(item, config, state))
168            .collect(),
169    )
170}
171
172fn bound_object(
173    fields: Map<String, Value>,
174    config: &ToolResultEnvelopeConfig,
175    state: &mut ToolResultEnvelopeState,
176) -> Value {
177    Value::Object(
178        fields
179            .into_iter()
180            .map(|(key, value)| (key, bound_value(value, config, state)))
181            .collect(),
182    )
183}
184
185/// A composable, side-effectful capability.
186///
187/// Implementations MUST be cheap to clone (typically `Arc`-wrapped state) so
188/// the same tool instance can be referenced from multiple agents'
189/// [`super::registry::ToolRegistry`] slices.
190#[async_trait]
191pub trait Tool: Send + Sync {
192    /// Return this tool's JSON-Schema-compatible contract.
193    fn schema(&self) -> ToolSchema;
194
195    /// Return this tool's registry name.
196    fn name(&self) -> ToolName {
197        self.schema().name
198    }
199
200    /// Invoke the tool with JSON arguments.
201    async fn invoke(&self, args: Value) -> Result<Value, KernelError>;
202}
203
204/// Adapter that turns any `async Fn(Value) -> Result<Value, KernelError>`
205/// into a [`Tool`]. Hosts can use this to surface existing async functions
206/// to the kernel without writing a dedicated tool type.
207pub struct LocalTool {
208    schema: ToolSchema,
209    #[allow(clippy::type_complexity)]
210    f: Arc<
211        dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, KernelError>> + Send>>
212            + Send
213            + Sync,
214    >,
215}
216
217impl LocalTool {
218    pub fn new<F, Fut>(schema: ToolSchema, f: F) -> Self
219    where
220        F: Fn(Value) -> Fut + Send + Sync + 'static,
221        Fut: Future<Output = Result<Value, KernelError>> + Send + 'static,
222    {
223        Self {
224            schema,
225            f: Arc::new(move |v| Box::pin(f(v))),
226        }
227    }
228}
229
230#[async_trait]
231impl Tool for LocalTool {
232    fn schema(&self) -> ToolSchema {
233        self.schema.clone()
234    }
235
236    fn name(&self) -> ToolName {
237        self.schema.name.clone()
238    }
239
240    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
241        (self.f)(args).await
242    }
243}
244
245#[cfg(test)]
246mod tests {
247    use crate::*;
248    use serde_json::json;
249
250    #[tokio::test]
251    async fn local_tool_roundtrip() {
252        let schema = ToolSchema {
253            name: "test.echo".into(),
254            description: "echoes the input".into(),
255            args_schema: json!({"type": "object"}),
256            result_schema: json!({"type": "object"}),
257        };
258        let tool = LocalTool::new(schema, |v| async move { Ok(v) });
259        let out = tool.invoke(json!({"hello": "world"})).await.unwrap();
260        assert_eq!(out, json!({"hello": "world"}));
261        assert_eq!(tool.name(), "test.echo");
262    }
263
264    #[test]
265    fn tool_result_envelope_bounds_large_strings() {
266        let envelope =
267            ToolResultEnvelope::bound(json!({"body": "abcdef"}), &ToolResultEnvelopeConfig::new(3));
268
269        assert_eq!(envelope.payload, json!({"body": "abc"}));
270        assert!(envelope.truncated);
271        assert_eq!(envelope.omitted_chars, 3);
272        assert_eq!(envelope.page_token.as_deref(), Some("chars:3"));
273    }
274
275    #[test]
276    fn tool_result_envelope_bounds_arrays() {
277        let envelope = ToolResultEnvelope::bound(
278            json!({"rows": [1, 2, 3, 4]}),
279            &ToolResultEnvelopeConfig::new(100).with_max_array_items(2),
280        );
281
282        assert_eq!(envelope.payload, json!({"rows": [1, 2]}));
283        assert!(envelope.truncated);
284        assert_eq!(envelope.omitted_items, 2);
285        assert_eq!(envelope.page_token.as_deref(), Some("items:2"));
286    }
287
288    #[test]
289    fn tool_result_envelope_leaves_small_payloads_unchanged() {
290        let payload = json!({"ok": true, "rows": ["a"]});
291        let envelope = ToolResultEnvelope::bound(
292            payload.clone(),
293            &ToolResultEnvelopeConfig::new(100).with_max_array_items(10),
294        );
295
296        assert_eq!(envelope.payload, payload);
297        assert!(!envelope.truncated);
298        assert_eq!(envelope.omitted_chars, 0);
299        assert_eq!(envelope.omitted_items, 0);
300        assert_eq!(envelope.page_token, None);
301    }
302}