1use async_trait::async_trait;
2use schemars::{JsonSchema, schema::RootSchema, schema_for};
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5
6pub use alith_interface::requests::completion::{ToolChoice, ToolDefinition};
7
8#[async_trait]
9pub trait Tool: Send + Sync {
10 fn name(&self) -> &str {
11 "default-tool"
12 }
13
14 fn version(&self) -> &str {
15 "0.0.0"
16 }
17
18 fn description(&self) -> &str {
19 "A default tool"
20 }
21
22 fn author(&self) -> &str {
23 "Anonymous"
24 }
25
26 fn definition(&self) -> ToolDefinition;
27
28 fn validate_input(&self, input: &str) -> Result<(), ToolError> {
29 if input.trim().is_empty() {
30 Err(ToolError::InvalidInput)
31 } else {
32 Ok(())
33 }
34 }
35
36 async fn run(&self, input: &str) -> Result<String, ToolError>;
37}
38
39#[async_trait]
40pub trait StructureTool: Send + Sync {
41 type Input: for<'a> Deserialize<'a> + JsonSchema + Send + Sync;
42 type Output: Serialize;
43
44 fn name(&self) -> &str {
45 "default-tool"
46 }
47
48 fn version(&self) -> &str {
49 "0.0.0"
50 }
51
52 fn description(&self) -> &str {
53 "A default tool description"
54 }
55
56 fn author(&self) -> &str {
57 "Anonymous"
58 }
59
60 fn schema(&self) -> RootSchema {
61 schema_for!(Self::Input)
62 }
63
64 fn definition(&self) -> ToolDefinition {
65 ToolDefinition {
66 name: self.name().to_owned(),
67 description: self.description().to_owned(),
68 parameters: json!(self.schema()),
69 }
70 }
71
72 async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError>;
73
74 async fn run(&self, input: &str) -> Result<String, ToolError> {
75 match serde_json::from_str(input) {
76 Ok(input) => {
77 let output = self.run_with_args(input).await?;
78 serde_json::to_string(&output).map_err(ToolError::JsonError)
79 }
80 Err(e) => Err(ToolError::JsonError(e)),
81 }
82 }
83}
84
85#[async_trait]
86impl<T: StructureTool> Tool for T {
87 fn name(&self) -> &str {
88 self.name()
89 }
90
91 fn version(&self) -> &str {
92 self.version()
93 }
94
95 fn description(&self) -> &str {
96 self.description()
97 }
98
99 fn author(&self) -> &str {
100 self.author()
101 }
102
103 fn definition(&self) -> ToolDefinition {
104 self.definition()
105 }
106
107 async fn run(&self, input: &str) -> Result<String, ToolError> {
108 match serde_json::from_str(input) {
109 Ok(input) => {
110 let output = self.run_with_args(input).await?;
111 serde_json::to_string(&output).map_err(ToolError::JsonError)
112 }
113 Err(e) => Err(ToolError::JsonError(e)),
114 }
115 }
116}
117
118#[derive(Debug, thiserror::Error)]
119#[error("Tool error")]
120pub enum ToolError {
121 #[error("NormalError: {0}")]
122 NormalError(Box<dyn std::error::Error + Send + Sync + 'static>),
123 #[error("Invalid input provided to the tool")]
124 InvalidInput,
125 #[error("The tool produced invalid output")]
126 InvalidOutput,
127 #[error("The tool is not available or not configured properly")]
128 InvalidTool,
129 #[error("An unknown error occurred: {0}")]
130 Unknown(String),
131 #[error("JsonError: {0}")]
132 JsonError(#[from] serde_json::Error),
133}
134
#[cfg(test)]
mod tests {
    use super::{StructureTool, Tool, ToolError};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Minimal `StructureTool` used to exercise the blanket `Tool` impl.
    pub struct DummyTool;

    /// Arguments accepted by [`DummyTool`].
    #[derive(JsonSchema, Serialize, Deserialize)]
    pub struct DummyInput {
        pub x: usize,
        pub y: usize,
    }

    #[async_trait]
    impl StructureTool for DummyTool {
        type Input = DummyInput;
        type Output = String;

        fn name(&self) -> &str {
            "dummy"
        }

        async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
            Ok(format!("x: {}, y: {}", input.x, input.y))
        }
    }

    /// A well-formed call through `dyn Tool` yields JSON-encoded output and
    /// exposes the overridden name and the schema derived from the input type.
    #[tokio::test]
    async fn test_dummy_tool() {
        let tool: Box<dyn Tool> = Box::new(DummyTool);
        let input = json!({ "x": 1, "y": 2 }).to_string();
        let output = tool.run(&input).await.unwrap();
        assert_eq!(tool.name(), "dummy");
        // The output is a JSON string literal, hence the embedded quotes.
        assert_eq!(output, "\"x: 1, y: 2\"");
        assert_eq!(
            tool.definition().parameters.to_string(),
            "{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"properties\":{\"x\":{\"format\":\"uint\",\"minimum\":0.0,\"type\":\"integer\"},\"y\":{\"format\":\"uint\",\"minimum\":0.0,\"type\":\"integer\"}},\"required\":[\"x\",\"y\"],\"title\":\"DummyInput\",\"type\":\"object\"}"
        );
    }

    /// Input that is not valid JSON surfaces as `ToolError::JsonError`.
    #[tokio::test]
    async fn test_dummy_tool_rejects_malformed_json() {
        let tool: Box<dyn Tool> = Box::new(DummyTool);
        let result = tool.run("not json").await;
        assert!(matches!(result, Err(ToolError::JsonError(_))));
    }
}