use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::core::tools::{BaseTool, Tool, ToolError};
/// Input for the calculator tool.
// `JsonSchema` derive copies these `///` doc comments into the generated
// JSON Schema as `description` fields, so they double as tool-argument help.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CalculatorInput {
    /// Arithmetic expression to evaluate, for example `2 + 3`.
    pub expression: String,
}
/// Result produced by the calculator tool for one expression.
#[derive(Debug, Serialize)]
pub struct CalculatorOutput {
    /// Numeric value the expression evaluated to.
    pub result: f64,
    /// The expression that was evaluated, as supplied in the request.
    pub expression: String,
}
/// A stateless calculator tool; it carries no configuration or data.
#[derive(Default)]
pub struct Calculator;

impl Calculator {
    /// Creates a calculator. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl Tool for Calculator {
    type Input = CalculatorInput;
    type Output = CalculatorOutput;

    /// Evaluates `input.expression` and returns the numeric value together
    /// with the untouched expression text.
    ///
    /// # Errors
    /// Propagates the [`ToolError`] produced when the expression cannot be
    /// evaluated.
    async fn invoke(&self, input: Self::Input) -> Result<Self::Output, ToolError> {
        let expression = input.expression;
        // The borrow of `expression` ends when evaluation returns, so the
        // closure below can take ownership of it for the output.
        self.evaluate_expression(&expression)
            .map(|result| CalculatorOutput { result, expression })
    }
}
#[async_trait]
impl BaseTool for Calculator {
fn name(&self) -> &str {
"calculator"
}
fn description(&self) -> &str {
"计算数学表达式。支持基本运算(加减乘除)、幂运算、平方根、三角函数等。
示例:
- '2 + 2' → 4
- 'sqrt(16)' → 4
- '3.14 * 10' → 31.4
- 'sin(1.57)' → 接近 1
- 'pow(2, 10)' → 1024
输入格式: JSON 对象,包含 expression 字段
例如: {\"expression\": \"2 + 3\"}"
}
async fn run(&self, input: String) -> Result<String, ToolError> {
let parsed: CalculatorInput = serde_json::from_str(&input)
.map_err(|e| ToolError::InvalidInput(format!("JSON 解析失败: {}", e)))?;
let output = self.invoke(parsed).await?;
Ok(format!("{} = {}", output.expression, output.result))
}
fn args_schema(&self) -> Option<serde_json::Value> {
use schemars::schema_for;
serde_json::to_value(schema_for!(CalculatorInput)).ok()
}
}
impl Calculator {
    /// Evaluates an arithmetic expression to an `f64`.
    ///
    /// Supported syntax:
    /// - numbers in `f64` literal form (`2`, `.5`, `1e-3`);
    /// - binary `+ - * /` with the usual precedence and left associativity;
    /// - `^` or `**` for powers: right-associative and binding tighter than a
    ///   leading unary minus (`-2^2 == -4`, `2^3^2 == 512`);
    /// - unary `+`/`-` and parentheses;
    /// - constants `pi`, `e`, `tau`, `inf`/`infinity`, `nan`;
    /// - one-argument functions `sqrt cbrt abs exp ln log10 log2 sin cos tan
    ///   asin acos atan sinh cosh tanh floor ceil round`, two-argument
    ///   `pow atan2 hypot`, `log(x)` (natural log) or `log(x, base)`, and
    ///   `min`/`max` taking one or more arguments.
    ///
    /// Names are case-insensitive; any Unicode whitespace separates tokens.
    ///
    /// # Errors
    /// Returns [`ToolError::ExecutionFailed`] on a syntax error, an unknown
    /// name, a wrong argument count, division by zero (`"除数不能为0"`), a
    /// result with no real value (e.g. `sqrt(-1)`), or nesting deeper than
    /// the parser allows.
    fn evaluate_expression(&self, expr: &str) -> Result<f64, ToolError> {
        let expr = expr.trim();
        // A bare literal goes straight through `f64::from_str`, which also
        // accepts spellings such as `inf`, `NaN` and `+1e3`.
        if let Ok(num) = expr.parse::<f64>() {
            return Ok(num);
        }
        expression_parser::evaluate(expr).map_err(ToolError::ExecutionFailed)
    }
}

/// Small recursive-descent evaluator backing `Calculator::evaluate_expression`.
///
/// Errors are returned as ready-to-display (Chinese) messages.
mod expression_parser {
    /// Maximum nesting of parentheses, unary signs, `^` chains and function
    /// calls. Bounding recursion keeps hostile input like `((((...))))` from
    /// overflowing the stack.
    const MAX_DEPTH: usize = 64;

    /// One lexical token of an expression.
    #[derive(Debug, PartialEq)]
    enum Token {
        Num(f64),
        /// Function or constant name, already lower-cased.
        Ident(String),
        Plus,
        Minus,
        Star,
        Slash,
        /// Exponentiation, written `^` or `**`.
        Caret,
        LParen,
        RParen,
        Comma,
    }

    impl Token {
        /// Renders the token roughly as the user typed it, for error messages.
        fn describe(&self) -> String {
            match self {
                Token::Num(value) => value.to_string(),
                Token::Ident(name) => name.clone(),
                Token::Plus => "+".to_string(),
                Token::Minus => "-".to_string(),
                Token::Star => "*".to_string(),
                Token::Slash => "/".to_string(),
                Token::Caret => "^".to_string(),
                Token::LParen => "(".to_string(),
                Token::RParen => ")".to_string(),
                Token::Comma => ",".to_string(),
            }
        }
    }

    /// Evaluates `src`, which the caller has already trimmed.
    pub(super) fn evaluate(src: &str) -> Result<f64, String> {
        let tokens = tokenize(src)?;
        if tokens.is_empty() {
            // Same message the original single-operator evaluator produced.
            return Err(format!("无法解析表达式: {}", src));
        }
        let mut parser = Parser {
            src,
            tokens: &tokens,
            pos: 0,
            depth: 0,
        };
        let value = parser.parse_sum()?;
        if let Some(extra) = parser.peek() {
            return Err(parser.syntax_error(&format!("多余的符号 '{}'", extra.describe())));
        }
        Ok(value)
    }

    /// Splits `src` into tokens.
    fn tokenize(src: &str) -> Result<Vec<Token>, String> {
        let bytes = src.as_bytes();
        let mut tokens = Vec::new();
        // Invariant: `pos` only advances over whole characters (ASCII bytes
        // or `len_utf8` of a decoded char), so it is always a char boundary.
        let mut pos = 0;
        while pos < bytes.len() {
            let byte = bytes[pos];
            if byte.is_ascii_digit() || byte == b'.' {
                let end = scan_number_end(bytes, pos);
                let value = src[pos..end]
                    .parse::<f64>()
                    .map_err(|e| format!("解析失败: {}", e))?;
                tokens.push(Token::Num(value));
                pos = end;
            } else if byte.is_ascii_alphabetic() || byte == b'_' {
                let mut end = pos + 1;
                while matches!(
                    bytes.get(end),
                    Some(b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
                ) {
                    end += 1;
                }
                tokens.push(Token::Ident(src[pos..end].to_ascii_lowercase()));
                pos = end;
            } else if let Some((token, width)) = operator_at(bytes, pos) {
                tokens.push(token);
                pos += width;
            } else {
                // Skip any Unicode whitespace (e.g. the full-width space
                // U+3000); reject every other character.
                let ch = src[pos..]
                    .chars()
                    .next()
                    .expect("`pos` is always a char boundary inside `src`");
                if !ch.is_whitespace() {
                    return Err(format!("无法解析表达式: {} (无法识别的字符 '{}')", src, ch));
                }
                pos += ch.len_utf8();
            }
        }
        Ok(tokens)
    }

    /// Returns the byte index just past the numeric literal starting at `start`.
    fn scan_number_end(bytes: &[u8], start: usize) -> usize {
        let mut end = start;
        while matches!(bytes.get(end), Some(b'0'..=b'9' | b'.')) {
            end += 1;
        }
        // Take an exponent only when at least one digit follows it, so `2e`
        // is not silently read as `2`; the stray `e` becomes an error later.
        if matches!(bytes.get(end), Some(b'e' | b'E')) {
            let mut exp = end + 1;
            if matches!(bytes.get(exp), Some(b'+' | b'-')) {
                exp += 1;
            }
            if matches!(bytes.get(exp), Some(b'0'..=b'9')) {
                end = exp;
                while matches!(bytes.get(end), Some(b'0'..=b'9')) {
                    end += 1;
                }
            }
        }
        end
    }

    /// Recognizes an operator or punctuation token at `pos`, with its byte width.
    fn operator_at(bytes: &[u8], pos: usize) -> Option<(Token, usize)> {
        let token = match bytes[pos] {
            b'+' => Token::Plus,
            b'-' => Token::Minus,
            b'*' if bytes.get(pos + 1) == Some(&b'*') => return Some((Token::Caret, 2)),
            b'*' => Token::Star,
            b'/' => Token::Slash,
            b'^' => Token::Caret,
            b'(' => Token::LParen,
            b')' => Token::RParen,
            b',' => Token::Comma,
            _ => return None,
        };
        Some((token, 1))
    }

    /// Cursor over the token stream plus the current recursion depth.
    struct Parser<'a> {
        src: &'a str,
        tokens: &'a [Token],
        pos: usize,
        depth: usize,
    }

    impl<'a> Parser<'a> {
        fn peek(&self) -> Option<&'a Token> {
            let tokens: &'a [Token] = self.tokens;
            tokens.get(self.pos)
        }

        fn advance(&mut self) -> Option<&'a Token> {
            let token = self.peek()?;
            self.pos += 1;
            Some(token)
        }

        /// Consumes the next token if it equals `expected`.
        fn eat(&mut self, expected: &Token) -> bool {
            if self.peek() == Some(expected) {
                self.pos += 1;
                true
            } else {
                false
            }
        }

        fn expect(&mut self, expected: &Token) -> Result<(), String> {
            if self.eat(expected) {
                Ok(())
            } else {
                Err(self.syntax_error(&format!("缺少 '{}'", expected.describe())))
            }
        }

        fn syntax_error(&self, detail: &str) -> String {
            format!("无法解析表达式: {} ({})", self.src, detail)
        }

        /// sum := product (('+' | '-') product)*
        fn parse_sum(&mut self) -> Result<f64, String> {
            let mut acc = self.parse_product()?;
            loop {
                if self.eat(&Token::Plus) {
                    acc += self.parse_product()?;
                } else if self.eat(&Token::Minus) {
                    acc -= self.parse_product()?;
                } else {
                    return Ok(acc);
                }
            }
        }

        /// product := unary (('*' | '/') unary)*
        fn parse_product(&mut self) -> Result<f64, String> {
            let mut acc = self.parse_unary()?;
            loop {
                if self.eat(&Token::Star) {
                    acc *= self.parse_unary()?;
                } else if self.eat(&Token::Slash) {
                    let divisor = self.parse_unary()?;
                    if divisor == 0.0 {
                        return Err("除数不能为0".to_string());
                    }
                    acc /= divisor;
                } else {
                    return Ok(acc);
                }
            }
        }

        /// unary := ('+' | '-') unary | power
        ///
        /// Every recursive path in the grammar passes through here, so this is
        /// the single place where nesting depth is enforced.
        fn parse_unary(&mut self) -> Result<f64, String> {
            if self.depth >= MAX_DEPTH {
                return Err("表达式嵌套过深".to_string());
            }
            self.depth += 1;
            let result = if self.eat(&Token::Minus) {
                self.parse_unary().map(|v| -v)
            } else if self.eat(&Token::Plus) {
                self.parse_unary()
            } else {
                self.parse_power()
            };
            self.depth -= 1;
            result
        }

        /// power := primary ('^' unary)?
        ///
        /// The exponent is parsed as `unary`, which makes `^` right-associative
        /// and allows signed exponents such as `2^-1`.
        fn parse_power(&mut self) -> Result<f64, String> {
            let base = self.parse_primary()?;
            if !self.eat(&Token::Caret) {
                return Ok(base);
            }
            let exponent = self.parse_unary()?;
            let value = base.powf(exponent);
            if value.is_nan() && !base.is_nan() && !exponent.is_nan() {
                return Err(format!("数学域错误: {}^{} 没有实数结果", base, exponent));
            }
            Ok(value)
        }

        /// primary := number | '(' sum ')' | name | name '(' args ')'
        fn parse_primary(&mut self) -> Result<f64, String> {
            match self.advance() {
                Some(Token::Num(value)) => Ok(*value),
                Some(Token::LParen) => {
                    let value = self.parse_sum()?;
                    self.expect(&Token::RParen)?;
                    Ok(value)
                }
                Some(Token::Ident(name)) => {
                    if self.eat(&Token::LParen) {
                        let args = self.parse_call_args()?;
                        apply_function(name, &args)
                    } else {
                        constant(name)
                    }
                }
                Some(other) => {
                    Err(self.syntax_error(&format!("意外的符号 '{}'", other.describe())))
                }
                None => Err(self.syntax_error("表达式不完整")),
            }
        }

        /// Parses a comma-separated argument list; the `(` is already consumed.
        fn parse_call_args(&mut self) -> Result<Vec<f64>, String> {
            let mut args = Vec::new();
            if self.eat(&Token::RParen) {
                return Ok(args);
            }
            loop {
                args.push(self.parse_sum()?);
                if !self.eat(&Token::Comma) {
                    self.expect(&Token::RParen)?;
                    return Ok(args);
                }
            }
        }
    }

    /// Resolves a named constant.
    fn constant(name: &str) -> Result<f64, String> {
        match name {
            "pi" => Ok(std::f64::consts::PI),
            "e" => Ok(std::f64::consts::E),
            "tau" => Ok(std::f64::consts::TAU),
            "inf" | "infinity" => Ok(f64::INFINITY),
            "nan" => Ok(f64::NAN),
            _ => Err(format!("未知的常量: {}", name)),
        }
    }

    /// Applies the named function to already-evaluated arguments.
    fn apply_function(name: &str, args: &[f64]) -> Result<f64, String> {
        let arity_error = |expected: &str| {
            format!("函数 {} 的参数个数应为 {}, 实际为 {}", name, expected, args.len())
        };
        let unary = |f: fn(f64) -> f64| match args {
            [x] => Ok(f(*x)),
            _ => Err(arity_error("1")),
        };
        let binary = |f: fn(f64, f64) -> f64| match args {
            [x, y] => Ok(f(*x, *y)),
            _ => Err(arity_error("2")),
        };
        let value = match name {
            "sqrt" => unary(f64::sqrt),
            "cbrt" => unary(f64::cbrt),
            "abs" => unary(f64::abs),
            "exp" => unary(f64::exp),
            "ln" => unary(f64::ln),
            "log10" => unary(f64::log10),
            "log2" => unary(f64::log2),
            "sin" => unary(f64::sin),
            "cos" => unary(f64::cos),
            "tan" => unary(f64::tan),
            "asin" => unary(f64::asin),
            "acos" => unary(f64::acos),
            "atan" => unary(f64::atan),
            "sinh" => unary(f64::sinh),
            "cosh" => unary(f64::cosh),
            "tanh" => unary(f64::tanh),
            "floor" => unary(f64::floor),
            "ceil" => unary(f64::ceil),
            "round" => unary(f64::round),
            "pow" => binary(f64::powf),
            "atan2" => binary(f64::atan2),
            "hypot" => binary(f64::hypot),
            "log" => match args {
                [x] => Ok(x.ln()),
                [x, base] => Ok(x.log(*base)),
                _ => Err(arity_error("1 或 2")),
            },
            "min" | "max" => match args.split_first() {
                Some((&first, rest)) => Ok(rest.iter().fold(first, |acc, &x| {
                    if name == "min" {
                        acc.min(x)
                    } else {
                        acc.max(x)
                    }
                })),
                None => Err(arity_error("至少 1")),
            },
            _ => Err(format!("未知函数: {}", name)),
        }?;
        // A NaN from non-NaN arguments means the input was outside the
        // function's real domain (sqrt(-1), asin(2), ...): report it instead
        // of returning NaN silently.
        if value.is_nan() && !args.iter().any(|a| a.is_nan()) {
            return Err(format!("数学域错误: {} 的参数超出定义域", name));
        }
        Ok(value)
    }
}