Skip to main content

alith_core/
tool.rs

1use async_trait::async_trait;
2use schemars::{JsonSchema, schema::RootSchema, schema_for};
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5
6pub use alith_interface::requests::completion::{ToolChoice, ToolDefinition};
7
8#[async_trait]
9pub trait Tool: Send + Sync {
10    fn name(&self) -> &str {
11        "default-tool"
12    }
13
14    fn version(&self) -> &str {
15        "0.0.0"
16    }
17
18    fn description(&self) -> &str {
19        "A default tool"
20    }
21
22    fn author(&self) -> &str {
23        "Anonymous"
24    }
25
26    fn definition(&self) -> ToolDefinition;
27
28    fn validate_input(&self, input: &str) -> Result<(), ToolError> {
29        if input.trim().is_empty() {
30            Err(ToolError::InvalidInput)
31        } else {
32            Ok(())
33        }
34    }
35
36    async fn run(&self, input: &str) -> Result<String, ToolError>;
37}
38
39#[async_trait]
40pub trait StructureTool: Send + Sync {
41    type Input: for<'a> Deserialize<'a> + JsonSchema + Send + Sync;
42    type Output: Serialize;
43
44    fn name(&self) -> &str {
45        "default-tool"
46    }
47
48    fn version(&self) -> &str {
49        "0.0.0"
50    }
51
52    fn description(&self) -> &str {
53        "A default tool description"
54    }
55
56    fn author(&self) -> &str {
57        "Anonymous"
58    }
59
60    fn schema(&self) -> RootSchema {
61        schema_for!(Self::Input)
62    }
63
64    fn definition(&self) -> ToolDefinition {
65        ToolDefinition {
66            name: self.name().to_owned(),
67            description: self.description().to_owned(),
68            parameters: json!(self.schema()),
69        }
70    }
71
72    async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError>;
73
74    async fn run(&self, input: &str) -> Result<String, ToolError> {
75        match serde_json::from_str(input) {
76            Ok(input) => {
77                let output = self.run_with_args(input).await?;
78                serde_json::to_string(&output).map_err(ToolError::JsonError)
79            }
80            Err(e) => Err(ToolError::JsonError(e)),
81        }
82    }
83}
84
85#[async_trait]
86impl<T: StructureTool> Tool for T {
87    fn name(&self) -> &str {
88        self.name()
89    }
90
91    fn version(&self) -> &str {
92        self.version()
93    }
94
95    fn description(&self) -> &str {
96        self.description()
97    }
98
99    fn author(&self) -> &str {
100        self.author()
101    }
102
103    fn definition(&self) -> ToolDefinition {
104        self.definition()
105    }
106
107    async fn run(&self, input: &str) -> Result<String, ToolError> {
108        match serde_json::from_str(input) {
109            Ok(input) => {
110                let output = self.run_with_args(input).await?;
111                serde_json::to_string(&output).map_err(ToolError::JsonError)
112            }
113            Err(e) => Err(ToolError::JsonError(e)),
114        }
115    }
116}
117
118#[derive(Debug, thiserror::Error)]
119#[error("Tool error")]
120pub enum ToolError {
121    #[error("NormalError: {0}")]
122    NormalError(Box<dyn std::error::Error + Send + Sync + 'static>),
123    #[error("Invalid input provided to the tool")]
124    InvalidInput,
125    #[error("The tool produced invalid output")]
126    InvalidOutput,
127    #[error("The tool is not available or not configured properly")]
128    InvalidTool,
129    #[error("An unknown error occurred: {0}")]
130    Unknown(String),
131    #[error("JsonError: {0}")]
132    JsonError(#[from] serde_json::Error),
133}
134
#[cfg(test)]
mod tests {
    use super::{StructureTool, Tool, ToolError};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Minimal [`StructureTool`] used to exercise the blanket [`Tool`] impl.
    pub struct DummyTool;

    /// Arguments accepted by [`DummyTool`].
    #[derive(JsonSchema, Serialize, Deserialize)]
    pub struct DummyInput {
        pub x: usize,
        pub y: usize,
    }

    #[async_trait]
    impl StructureTool for DummyTool {
        type Input = DummyInput;
        type Output = String;

        fn name(&self) -> &str {
            "dummy"
        }

        async fn run_with_args(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
            Ok(format!("x: {}, y: {}", input.x, input.y))
        }
    }

    #[tokio::test]
    async fn test_dummy_tool() {
        let tool: Box<dyn Tool> = Box::new(DummyTool);
        let output = tool
            .run(&json!({ "x": 1, "y": 2 }).to_string())
            .await
            .unwrap();
        assert_eq!(tool.name(), "dummy");
        // `Output = String`, so the result is the JSON encoding of that string.
        assert_eq!(output, "\"x: 1, y: 2\"");
        // Compare structurally instead of as text so the assertion does not
        // depend on JSON object key ordering (e.g. serde_json `preserve_order`).
        assert_eq!(
            tool.definition().parameters,
            json!({
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "DummyInput",
                "type": "object",
                "required": ["x", "y"],
                "properties": {
                    "x": { "format": "uint", "minimum": 0.0, "type": "integer" },
                    "y": { "format": "uint", "minimum": 0.0, "type": "integer" }
                }
            })
        );
    }

    #[tokio::test]
    async fn test_malformed_input_is_json_error() {
        let tool: Box<dyn Tool> = Box::new(DummyTool);
        let err = tool.run("not json").await.unwrap_err();
        assert!(matches!(err, ToolError::JsonError(_)));
    }

    #[test]
    fn test_default_validate_input() {
        let tool = DummyTool;
        assert!(matches!(tool.validate_input("   "), Err(ToolError::InvalidInput)));
        assert!(tool.validate_input("{}").is_ok());
    }
}