Skip to main content

ai_agents_tools/builtin/
calculator.rs

1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::generate_schema;
7use ai_agents_core::{Tool, ToolResult};
8
/// Tool that evaluates mathematical expressions.
///
/// This Tool is for development purposes only.
/// It is not intended for production use.
///
/// The tool carries no state, so it is cheap to construct and to copy.
#[derive(Debug, Clone, Copy)]
pub struct CalculatorTool;

impl CalculatorTool {
    /// Creates a new calculator tool.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for CalculatorTool {
    /// Equivalent to [`CalculatorTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
25
26#[derive(Debug, Deserialize, JsonSchema)]
27struct CalculatorInput {
28    /// Mathematical expression to evaluate (e.g., '2 + 3 * 4')
29    expression: String,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33struct CalculatorOutput {
34    result: f64,
35    expression: String,
36}
37
38#[async_trait]
39impl Tool for CalculatorTool {
40    fn id(&self) -> &str {
41        "calculator"
42    }
43
44    fn name(&self) -> &str {
45        "Calculator"
46    }
47
48    fn description(&self) -> &str {
49        "Evaluates mathematical expressions. Supports +, -, *, /, ^ and parentheses."
50    }
51
52    fn input_schema(&self) -> Value {
53        generate_schema::<CalculatorInput>()
54    }
55
56    async fn execute(&self, args: Value) -> ToolResult {
57        let input: CalculatorInput = match serde_json::from_value(args) {
58            Ok(input) => input,
59            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
60        };
61
62        match evalexpr::eval(&input.expression) {
63            Ok(value) => {
64                let result = match value {
65                    evalexpr::Value::Float(f) => f,
66                    evalexpr::Value::Int(i) => i as f64,
67                    _ => return ToolResult::error("Expression must evaluate to a number"),
68                };
69
70                let output = CalculatorOutput {
71                    result,
72                    expression: input.expression,
73                };
74
75                match serde_json::to_string(&output) {
76                    Ok(json) => ToolResult::ok(json),
77                    Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
78                }
79            }
80            Err(e) => ToolResult::error(format!("Calculation error: {}", e)),
81        }
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `expression`, asserts that the call succeeded, and decodes the output.
    async fn eval_ok(expression: &str) -> CalculatorOutput {
        let calc = CalculatorTool::new();
        let result = calc
            .execute(serde_json::json!({ "expression": expression }))
            .await;
        assert!(result.success, "expected success for {:?}", expression);
        serde_json::from_str(&result.output).expect("calculator output should be valid JSON")
    }

    /// Evaluates `args` and returns whether the call succeeded.
    async fn succeeds(args: Value) -> bool {
        CalculatorTool::new().execute(args).await.success
    }

    #[tokio::test]
    async fn test_basic_operations() {
        assert_eq!(eval_ok("2 + 3").await.result, 5.0);
        assert_eq!(eval_ok("10 * 5").await.result, 50.0);
    }

    #[tokio::test]
    async fn test_operator_precedence() {
        assert_eq!(eval_ok("2 + 3 * 4").await.result, 14.0);
    }

    #[tokio::test]
    async fn test_parentheses() {
        assert_eq!(eval_ok("(2 + 3) * 4").await.result, 20.0);
    }

    #[tokio::test]
    async fn test_exponentiation() {
        // The tool description advertises `^`.
        assert_eq!(eval_ok("2 ^ 3").await.result, 8.0);
    }

    #[tokio::test]
    async fn test_output_echoes_expression() {
        assert_eq!(eval_ok("2 + 3").await.expression, "2 + 3");
    }

    #[tokio::test]
    async fn test_invalid_expression() {
        assert!(!succeeds(serde_json::json!({"expression": "2 +"})).await);
    }

    #[tokio::test]
    async fn test_non_numeric_result() {
        assert!(!succeeds(serde_json::json!({"expression": "true"})).await);
    }

    #[tokio::test]
    async fn test_invalid_input() {
        assert!(!succeeds(serde_json::json!({"wrong_field": "test"})).await);
    }
}