Skip to main content

claude_agent/tools/
traits.rs

1//! Tool trait definitions.
2
3use async_trait::async_trait;
4use schemars::JsonSchema;
5use serde::de::DeserializeOwned;
6
7use super::context::ExecutionContext;
8use crate::types::{ToolDefinition, ToolResult};
9
10/// Core tool trait for all tool implementations.
11#[async_trait]
12pub trait Tool: Send + Sync {
13    fn name(&self) -> &str;
14    fn description(&self) -> &str;
15    fn input_schema(&self) -> serde_json::Value;
16    async fn execute(&self, input: serde_json::Value, context: &ExecutionContext) -> ToolResult;
17
18    fn definition(&self) -> ToolDefinition {
19        ToolDefinition::new(self.name(), self.description(), self.input_schema())
20    }
21}
22
23/// Schema-based tool trait with automatic JSON schema generation.
24///
25/// Provides a higher-level abstraction over `Tool` with typed inputs
26/// and automatic schema derivation via schemars.
27#[async_trait]
28pub trait SchemaTool: Send + Sync {
29    type Input: JsonSchema + DeserializeOwned + Send;
30    const NAME: &'static str;
31    const DESCRIPTION: &'static str;
32    const STRICT: bool = false;
33
34    async fn handle(&self, input: Self::Input, context: &ExecutionContext) -> ToolResult;
35
36    /// Override to provide a dynamic description instead of the static DESCRIPTION constant.
37    ///
38    /// When this returns `Some(desc)`, the blanket `Tool` impl uses it in `definition()`.
39    /// The default returns `None`, which falls back to `DESCRIPTION`.
40    fn custom_description(&self) -> Option<String> {
41        None
42    }
43
44    fn input_schema() -> serde_json::Value {
45        let schema = schemars::schema_for!(Self::Input);
46        let mut value =
47            serde_json::to_value(schema).unwrap_or_else(|_| serde_json::json!({"type": "object"}));
48
49        if let Some(obj) = value.as_object_mut() {
50            if !obj.contains_key("properties") {
51                obj.insert(
52                    "properties".to_string(),
53                    serde_json::Value::Object(serde_json::Map::new()),
54                );
55            }
56            if !obj.contains_key("additionalProperties") {
57                obj.insert(
58                    "additionalProperties".to_string(),
59                    serde_json::Value::Bool(!Self::STRICT),
60                );
61            }
62        }
63
64        value
65    }
66}
67
68#[async_trait]
69impl<T: SchemaTool + 'static> Tool for T {
70    fn name(&self) -> &str {
71        T::NAME
72    }
73
74    fn description(&self) -> &str {
75        T::DESCRIPTION
76    }
77
78    fn input_schema(&self) -> serde_json::Value {
79        T::input_schema()
80    }
81
82    fn definition(&self) -> ToolDefinition {
83        let desc = self
84            .custom_description()
85            .unwrap_or_else(|| T::DESCRIPTION.to_string());
86        let mut definition = ToolDefinition::new(T::NAME, &desc, T::input_schema());
87        if T::STRICT {
88            definition = definition.strict(true);
89        }
90        definition
91    }
92
93    async fn execute(&self, input: serde_json::Value, context: &ExecutionContext) -> ToolResult {
94        match serde_json::from_value::<T::Input>(input) {
95            Ok(typed) => SchemaTool::handle(self, typed, context).await,
96            Err(e) => ToolResult::error(format!("Invalid input: {}", e)),
97        }
98    }
99}