hehe_tools/traits/
tool.rs

1use crate::error::Result;
2use async_trait::async_trait;
3use hehe_core::{Context, Metadata, ToolDefinition};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7#[derive(Clone, Debug, Default, Serialize, Deserialize)]
8pub struct ToolOutput {
9    pub content: String,
10    #[serde(default, skip_serializing_if = "Vec::is_empty")]
11    pub artifacts: Vec<Artifact>,
12    #[serde(default, skip_serializing_if = "Metadata::is_empty")]
13    pub metadata: Metadata,
14    #[serde(default)]
15    pub is_error: bool,
16}
17
18impl ToolOutput {
19    pub fn text(content: impl Into<String>) -> Self {
20        Self {
21            content: content.into(),
22            artifacts: vec![],
23            metadata: Metadata::new(),
24            is_error: false,
25        }
26    }
27
28    pub fn json<T: Serialize>(value: &T) -> Result<Self> {
29        Ok(Self {
30            content: serde_json::to_string_pretty(value)?,
31            artifacts: vec![],
32            metadata: Metadata::new(),
33            is_error: false,
34        })
35    }
36
37    pub fn error(message: impl Into<String>) -> Self {
38        Self {
39            content: message.into(),
40            artifacts: vec![],
41            metadata: Metadata::new(),
42            is_error: true,
43        }
44    }
45
46    pub fn with_artifact(mut self, artifact: Artifact) -> Self {
47        self.artifacts.push(artifact);
48        self
49    }
50
51    pub fn with_metadata<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
52        self.metadata.insert(key, value);
53        self
54    }
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58pub struct Artifact {
59    pub name: String,
60    pub content_type: String,
61    pub data: ArtifactData,
62}
63
64#[derive(Clone, Debug, Serialize, Deserialize)]
65#[serde(tag = "type", rename_all = "snake_case")]
66pub enum ArtifactData {
67    Text { text: String },
68    Base64 { data: String },
69    File { path: String },
70}
71
72impl Artifact {
73    pub fn text(name: impl Into<String>, content: impl Into<String>) -> Self {
74        Self {
75            name: name.into(),
76            content_type: "text/plain".to_string(),
77            data: ArtifactData::Text {
78                text: content.into(),
79            },
80        }
81    }
82
83    pub fn file(name: impl Into<String>, path: impl Into<String>) -> Self {
84        Self {
85            name: name.into(),
86            content_type: "application/octet-stream".to_string(),
87            data: ArtifactData::File { path: path.into() },
88        }
89    }
90
91    pub fn with_content_type(mut self, content_type: impl Into<String>) -> Self {
92        self.content_type = content_type.into();
93        self
94    }
95}
96
97#[async_trait]
98pub trait Tool: Send + Sync {
99    fn definition(&self) -> &ToolDefinition;
100
101    async fn execute(&self, ctx: &Context, input: Value) -> Result<ToolOutput>;
102
103    fn validate_input(&self, _input: &Value) -> Result<()> {
104        Ok(())
105    }
106
107    fn name(&self) -> &str {
108        &self.definition().name
109    }
110
111    fn is_dangerous(&self) -> bool {
112        self.definition().dangerous
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// `text` keeps the content verbatim and does not set the error flag.
    #[test]
    fn test_tool_output_text() {
        let plain = ToolOutput::text("Hello, world!");

        assert_eq!(plain.content, "Hello, world!");
        assert!(!plain.is_error);
    }

    /// `json` renders the value so that its keys appear in the content.
    #[test]
    fn test_tool_output_json() {
        let payload = serde_json::json!({"key": "value"});

        let rendered = ToolOutput::json(&payload).unwrap();
        let body = &rendered.content;

        assert!(body.contains("key"));
    }

    /// `error` sets the error flag and keeps the message as content.
    #[test]
    fn test_tool_output_error() {
        let failure = ToolOutput::error("Something went wrong");

        assert!(failure.is_error);
        assert_eq!(failure.content, "Something went wrong");
    }

    /// `with_content_type` overrides the default content type of `text`.
    #[test]
    fn test_artifact() {
        let base = Artifact::text("readme", "# Hello");
        let markdown = base.with_content_type("text/markdown");

        assert_eq!(markdown.name, "readme");
        assert_eq!(markdown.content_type, "text/markdown");
    }
}