use async_trait::async_trait;
use serde_json::Value;
use crate::tool::{Capability, Tool, ToolDefinition};
use crate::tool_error::ToolError;
/// A [`Tool`] that evaluates mathematical expressions using the `meval` crate.
///
/// The tool is stateless apart from its static [`ToolDefinition`], so a single
/// instance can be reused for any number of evaluations.
pub struct CalculatorTool {
    /// Name, description and JSON parameter schema reported by `definition()`.
    definition: ToolDefinition,
}
impl CalculatorTool {
    /// Creates a calculator tool.
    ///
    /// Its definition advertises a single required string argument,
    /// `expression`.
    pub fn new() -> Self {
        const NAME: &str = "calculator";
        const DESCRIPTION: &str =
            "Evaluate mathematical expressions. Supports arithmetic operators (+, -, *, /, ^), \
             functions (sqrt, sin, cos, tan, log, exp, abs), and constants (pi, e).";
        const PARAMETERS_SCHEMA: &str = r#"{
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "Mathematical expression to evaluate, e.g. '2 + 3 * 4' or 'sqrt(16)'"
}
},
"required": ["expression"]
}"#;

        let definition = ToolDefinition::new(NAME, DESCRIPTION, PARAMETERS_SCHEMA);
        Self { definition }
    }
}
impl Default for CalculatorTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for CalculatorTool {
fn definition(&self) -> &ToolDefinition {
&self.definition
}
fn capabilities(&self) -> Vec<Capability> {
vec![Capability::PureComputation] }
fn validate(&self, args: &Value) -> Result<(), ToolError> {
let expr = args
.get("expression")
.and_then(|e| e.as_str())
.ok_or_else(|| {
ToolError::invalid_args("calculator", "Missing required field 'expression'")
})?;
if expr.len() > 1000 {
return Err(ToolError::invalid_args(
"calculator",
"Expression too long (max 1000 characters)",
));
}
Ok(())
}
async fn execute(&self, args: Value) -> Result<Value, ToolError> {
let expr = args["expression"]
.as_str()
.ok_or_else(|| ToolError::invalid_args("calculator", "Missing 'expression' field"))?;
let result = meval::eval_str(expr).map_err(|e| {
ToolError::execution_failed(
"calculator",
format!("Failed to evaluate expression: {}", e),
)
})?;
if result.is_nan() {
return Err(ToolError::execution_failed(
"calculator",
"Result is not a number (NaN)",
));
}
if result.is_infinite() {
return Err(ToolError::execution_failed(
"calculator",
"Result is infinite (division by zero or overflow)",
));
}
Ok(serde_json::json!({
"expression": expr,
"result": result
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `expression` through a fresh calculator's `execute`.
    async fn eval(expression: &str) -> Result<Value, ToolError> {
        CalculatorTool::new()
            .execute(serde_json::json!({ "expression": expression }))
            .await
    }

    #[tokio::test]
    async fn test_basic_arithmetic() {
        let result = eval("2 + 3 * 4").await.unwrap();
        assert_eq!(result["result"], 14.0);
    }

    #[tokio::test]
    async fn test_with_parentheses() {
        let result = eval("(2 + 3) * 4").await.unwrap();
        assert_eq!(result["result"], 20.0);
    }

    #[tokio::test]
    async fn test_power_operator() {
        let result = eval("2^10").await.unwrap();
        assert_eq!(result["result"], 1024.0);
    }

    #[tokio::test]
    async fn test_functions() {
        let result = eval("sqrt(16)").await.unwrap();
        assert_eq!(result["result"], 4.0);
    }

    #[tokio::test]
    async fn test_constants() {
        let result = eval("pi").await.unwrap();
        let pi = result["result"].as_f64().unwrap();
        assert!((pi - std::f64::consts::PI).abs() < 0.0001);

        let result = eval("e").await.unwrap();
        let e = result["result"].as_f64().unwrap();
        assert!((e - std::f64::consts::E).abs() < 0.0001);
    }

    #[tokio::test]
    async fn test_result_echoes_expression() {
        let result = eval("1 + 1").await.unwrap();
        assert_eq!(result["expression"], "1 + 1");
        assert_eq!(result["result"], 2.0);
    }

    #[tokio::test]
    async fn test_invalid_expression() {
        let result = eval("invalid ++ syntax").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_division_by_zero() {
        let result = eval("1/0").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_nan_result() {
        let result = eval("sqrt(-1)").await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    // `validate` is synchronous, so these tests need no async runtime.

    #[test]
    fn test_missing_expression() {
        let result = CalculatorTool::new().validate(&serde_json::json!({}));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_non_string_expression() {
        let result = CalculatorTool::new().validate(&serde_json::json!({ "expression": 42 }));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_expression_at_limit_is_accepted() {
        let expr = "1".repeat(1000);
        let result = CalculatorTool::new().validate(&serde_json::json!({ "expression": expr }));
        assert!(result.is_ok());
    }

    #[test]
    fn test_expression_too_long() {
        let long_expr = "1+".repeat(600);
        let result =
            CalculatorTool::new().validate(&serde_json::json!({ "expression": long_expr }));
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }
}