Skip to main content

heartbit_core/tool/
mod.rs

1//! Tool trait and built-in tool implementations (filesystem, web, MCP, A2A, etc.).
2
3#![allow(missing_docs)]
4#[cfg(feature = "a2a")]
5pub mod a2a;
6pub mod builtins;
7pub mod handoff;
8pub mod mcp;
9pub mod mcp_presets;
10pub mod mcp_server;
11
12use std::future::Future;
13use std::pin::Pin;
14
15use crate::error::Error;
16use crate::llm::types::ToolDefinition;
17
/// Output of a tool execution.
///
/// Built through [`ToolOutput::success`] or [`ToolOutput::error`]; both
/// fields are public, so it can also be constructed directly.
#[derive(Debug, Clone)]
pub struct ToolOutput {
    /// Text produced by the tool.
    pub content: String,
    /// Set by [`ToolOutput::error`], cleared by [`ToolOutput::success`].
    pub is_error: bool,
}

impl ToolOutput {
    /// Wrap `content` in an output whose `is_error` flag is `false`.
    pub fn success(content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            content,
            is_error: false,
        }
    }

    /// Wrap `content` in an output whose `is_error` flag is `true`.
    pub fn error(content: impl Into<String>) -> Self {
        // Same construction as `success`, with only the flag flipped.
        Self {
            is_error: true,
            ..Self::success(content)
        }
    }

    /// Cap `content` at `max_bytes`, never splitting a UTF-8 character.
    ///
    /// Content that already fits is returned untouched, and so is any
    /// content when `max_bytes` is 0 (zero means "no limit").
    ///
    /// When a cut happens, the content is shortened to the nearest char
    /// boundary at or below `max_bytes` and a
    /// `[truncated: N bytes omitted]` notice is appended so the LLM can tell
    /// that data is missing. The notice is added on top of the kept bytes,
    /// so the final length may be somewhat larger than `max_bytes`.
    ///
    /// The `is_error` flag is carried over unchanged.
    pub fn truncated(mut self, max_bytes: usize) -> Self {
        if max_bytes == 0 || self.content.len() <= max_bytes {
            return self;
        }

        // Walk down from `max_bytes` to the first index that does not land
        // inside a multi-byte character. Index 0 is always a boundary, so
        // the search always finds something.
        let keep = (0..=max_bytes)
            .rev()
            .find(|&idx| self.content.is_char_boundary(idx))
            .unwrap_or(0);

        let dropped = self.content.len() - keep;
        let notice = format!("\n\n[truncated: {dropped} bytes omitted]");
        self.content.truncate(keep);
        self.content.push_str(&notice);
        self
    }
}
65
66/// Trait for tools that agents can invoke.
67///
68/// Uses `Pin<Box<dyn Future>>` return type for dyn-compatibility,
69/// allowing tools to be stored as `Arc<dyn Tool>`.
70///
71/// # Example
72///
73/// Implementing a simple synchronous tool that echoes its input:
74///
75/// ```rust
76/// use std::future::Future;
77/// use std::pin::Pin;
78/// use heartbit_core::{Tool, ToolOutput};
79/// use heartbit_core::llm::types::ToolDefinition;
80///
81/// struct EchoTool;
82///
83/// impl Tool for EchoTool {
84///     fn definition(&self) -> ToolDefinition {
85///         ToolDefinition {
86///             name: "echo".into(),
87///             description: "Echo back the input string.".into(),
88///             input_schema: serde_json::json!({
89///                 "type": "object",
90///                 "properties": { "text": { "type": "string" } },
91///                 "required": ["text"]
92///             }),
93///         }
94///     }
95///
96///     fn execute(
97///         &self,
98///         _ctx: &heartbit_core::ExecutionContext,
99///         input: serde_json::Value,
100///     ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, heartbit_core::Error>> + Send + '_>> {
101///         Box::pin(async move {
102///             let text = input.get("text").and_then(|v| v.as_str()).unwrap_or("");
103///             Ok(ToolOutput::success(text.to_string()))
104///         })
105///     }
106/// }
107/// ```
108pub trait Tool: Send + Sync {
109    fn definition(&self) -> ToolDefinition;
110
111    fn execute(
112        &self,
113        ctx: &crate::ExecutionContext,
114        input: serde_json::Value,
115    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>>;
116}
117
118/// Validate tool input against the tool's declared JSON Schema.
119///
120/// Returns `Ok(())` if valid, `Err(error_message)` if the input
121/// does not conform. The error message is suitable for sending back
122/// to the LLM so it can self-correct.
123pub fn validate_tool_input(
124    schema: &serde_json::Value,
125    input: &serde_json::Value,
126) -> Result<(), String> {
127    let validator = match jsonschema::validator_for(schema) {
128        Ok(v) => v,
129        Err(e) => {
130            // If the schema itself is invalid, skip validation rather than
131            // rejecting every call. Log a warning for the operator.
132            tracing::warn!(error = %e, "invalid tool schema, skipping validation");
133            return Ok(());
134        }
135    };
136
137    let errors: Vec<String> = validator
138        .iter_errors(input)
139        .map(|e| e.to_string())
140        .collect();
141    if errors.is_empty() {
142        Ok(())
143    } else {
144        Err(format!("Input validation failed: {}", errors.join("; ")))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema shared by the validation tests: an object with a single
    /// required string property named `query`.
    fn query_schema() -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            },
            "required": ["query"]
        })
    }

    /// Tail of the truncation notice for a given number of omitted bytes.
    fn omitted_notice(bytes: usize) -> String {
        format!("[truncated: {bytes} bytes omitted]")
    }

    #[test]
    fn tool_output_success() {
        let output = ToolOutput::success("result data");
        assert_eq!(output.content, "result data");
        assert!(!output.is_error);
    }

    #[test]
    fn tool_output_error() {
        let output = ToolOutput::error("something failed");
        assert_eq!(output.content, "something failed");
        assert!(output.is_error);
    }

    #[test]
    fn tool_output_truncated_noop_when_within_limit() {
        let truncated = ToolOutput::success("short text").truncated(100);
        assert_eq!(truncated.content, "short text");
        assert!(!truncated.is_error);
    }

    #[test]
    fn tool_output_truncated_cuts_long_content() {
        let truncated = ToolOutput::success("a".repeat(1000)).truncated(100);
        // Exactly 100 bytes survive: 100 `a`s are a prefix, 101 are not.
        assert!(truncated.content.starts_with("a".repeat(100).as_str()));
        assert!(!truncated.content.starts_with("a".repeat(101).as_str()));
        // The notice reports the number of bytes that were dropped.
        assert!(truncated.content.ends_with(omitted_notice(900).as_str()));
        assert!(!truncated.is_error); // preserves is_error flag
    }

    #[test]
    fn tool_output_truncated_preserves_utf8() {
        // "é" is 2 bytes in UTF-8, so byte 5 falls in the middle of the
        // third character.
        let truncated = ToolOutput::success("ééééé").truncated(5); // 10 bytes
        // The cut backs off to byte 4: two whole characters kept, six bytes
        // dropped.
        assert!(truncated.content.starts_with("éé"));
        assert!(!truncated.content.starts_with("ééé"));
        assert!(truncated.content.ends_with(omitted_notice(6).as_str()));
    }

    #[test]
    fn tool_output_truncated_exact_boundary_noop() {
        let truncated = ToolOutput::success("hello").truncated(5); // 5 bytes
        assert_eq!(truncated.content, "hello");
    }

    #[test]
    fn tool_output_truncated_zero_is_noop() {
        let truncated = ToolOutput::success("some content").truncated(0);
        assert_eq!(truncated.content, "some content"); // unchanged
    }

    #[test]
    fn tool_output_truncated_error_also_truncates() {
        let truncated = ToolOutput::error("e".repeat(200)).truncated(50);
        assert!(truncated.content.starts_with("e".repeat(50).as_str()));
        assert!(!truncated.content.starts_with("e".repeat(51).as_str()));
        assert!(truncated.content.ends_with(omitted_notice(150).as_str()));
        assert!(truncated.is_error); // preserves error flag
    }

    #[test]
    fn validate_accepts_valid_input() {
        let input = json!({"query": "test"});
        assert!(validate_tool_input(&query_schema(), &input).is_ok());
    }

    #[test]
    fn validate_rejects_missing_required() {
        let input = json!({});
        let err = validate_tool_input(&query_schema(), &input).unwrap_err();
        assert!(err.contains("validation failed"), "got: {err}");
    }

    #[test]
    fn validate_rejects_wrong_type() {
        let input = json!({"query": 42});
        let err = validate_tool_input(&query_schema(), &input).unwrap_err();
        assert!(err.contains("validation failed"), "got: {err}");
    }

    #[test]
    fn validate_accepts_any_for_minimal_schema() {
        let schema = json!({"type": "object"});
        let input = json!({});
        assert!(validate_tool_input(&schema, &input).is_ok());
    }

    #[test]
    fn validate_skips_on_invalid_schema() {
        // An invalid schema should not block tool execution
        let schema = json!({"type": "not-a-real-type"});
        let input = json!({"anything": true});
        // Should not fail even though schema is invalid — skips validation
        assert!(validate_tool_input(&schema, &input).is_ok());
    }

    #[test]
    fn validate_accepts_extra_properties() {
        // Extra properties are allowed by default in JSON Schema
        let input = json!({"query": "test", "extra": true});
        assert!(validate_tool_input(&query_schema(), &input).is_ok());
    }
}