Skip to main content

mofa_foundation/react/
tools.rs

1//! ReAct 内置工具实现
2//!
3//! 提供常用工具的实现示例
4
5use super::core::ReActTool;
6use async_trait::async_trait;
7use serde::Deserialize;
8use serde_json::Value;
9
10/// 计算器工具
11///
12/// 支持基本的数学表达式计算
13pub struct CalculatorTool;
14
15#[async_trait]
16impl ReActTool for CalculatorTool {
17    fn name(&self) -> &str {
18        "calculator"
19    }
20
21    fn description(&self) -> &str {
22        "Perform mathematical calculations. Input should be a mathematical expression like '2 + 2' or '(10 * 5) / 2'"
23    }
24
25    fn parameters_schema(&self) -> Option<Value> {
26        Some(serde_json::json!({
27            "type": "object",
28            "properties": {
29                "expression": {
30                    "type": "string",
31                    "description": "The mathematical expression to evaluate"
32                }
33            },
34            "required": ["expression"]
35        }))
36    }
37
38    async fn execute(&self, input: &str) -> Result<String, String> {
39        // 尝试解析 JSON 输入
40        let expression = if let Ok(json) = serde_json::from_str::<Value>(input) {
41            json.get("expression")
42                .and_then(|v| v.as_str())
43                .unwrap_or(input)
44                .to_string()
45        } else {
46            input.to_string()
47        };
48
49        // 简单的表达式计算 (仅支持基本运算)
50        match evaluate_expression(&expression) {
51            Ok(result) => Ok(format!("{}", result)),
52            Err(e) => Err(format!("Calculation error: {}", e)),
53        }
54    }
55}
56
/// Evaluates a simple arithmetic expression and returns its value.
///
/// Supported syntax:
/// - binary `+`, `-`, `*`, `/` (left-associative; `*` and `/` bind tighter than `+`/`-`);
/// - parentheses;
/// - unary `+` / `-`, including before a parenthesised group (`-(2 + 3)`) and after another
///   operator with or without spaces (`2 * -3`, `2 - -3`);
/// - numeric literals accepted by `str::parse::<f64>`, including exponents (`2.5e-1`).
///
/// Split positions are byte offsets taken from `char_indices`, so non-ASCII input produces
/// an `Err` instead of a slicing panic.
///
/// # Errors
///
/// Returns `Err` with a human-readable message for malformed input or division by zero.
fn evaluate_expression(expr: &str) -> Result<f64, String> {
    let expr = expr.trim();

    // Strip one pair of parentheses enclosing the whole expression. The balance check
    // rejects inputs like "(1 + 2) * (3 + 4)", where the outer characters belong to
    // different groups.
    if let Some(inner) = expr.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
        if is_balanced(inner) {
            return evaluate_expression(inner);
        }
    }

    // Locate the rightmost top-level operator of each precedence class; splitting at the
    // rightmost one makes the operators left-associative.
    let mut depth = 0i32;
    let mut last_add_sub: Option<(usize, char)> = None;
    let mut last_mul_div: Option<(usize, char)> = None;
    // Last non-whitespace character seen so far (`None` at the start of the expression).
    let mut prev: Option<char> = None;

    for (pos, c) in expr.char_indices() {
        match c {
            '(' => depth += 1,
            ')' => depth -= 1,
            '+' | '-' if depth == 0 => {
                // The sign is unary (or an exponent sign, as in `1e-3`) when it starts the
                // expression or follows an operator, `(`, or `e`/`E`. Whitespace is skipped
                // when looking back, so "2 * -3" is handled like "2*-3".
                let is_binary = match prev {
                    Some(p) => !matches!(p, '+' | '-' | '*' | '/' | '(' | 'e' | 'E'),
                    None => false,
                };
                if is_binary {
                    last_add_sub = Some((pos, c));
                }
            }
            '*' | '/' if depth == 0 => last_mul_div = Some((pos, c)),
            _ => {}
        }
        if !c.is_whitespace() {
            prev = Some(c);
        }
    }

    // Lowest precedence first: this operator is the root of the expression tree.
    if let Some((pos, op)) = last_add_sub {
        // Operators are single-byte ASCII, so `pos + 1` starts the right operand.
        let left = evaluate_expression(&expr[..pos])?;
        let right = evaluate_expression(&expr[pos + 1..])?;
        return Ok(if op == '+' { left + right } else { left - right });
    }

    if let Some((pos, op)) = last_mul_div {
        let left = evaluate_expression(&expr[..pos])?;
        let right = evaluate_expression(&expr[pos + 1..])?;
        return if op == '*' {
            Ok(left * right)
        } else if right == 0.0 {
            Err("Division by zero".to_string())
        } else {
            Ok(left / right)
        };
    }

    // No top-level binary operator: either a plain number, or a unary sign applied to a
    // non-literal operand such as `-(2 + 3)`.
    if let Ok(value) = expr.parse::<f64>() {
        return Ok(value);
    }
    if let Some(operand) = expr.strip_prefix('-') {
        return evaluate_expression(operand).map(|value| -value);
    }
    if let Some(operand) = expr.strip_prefix('+') {
        return evaluate_expression(operand);
    }
    Err(format!("Invalid expression: {}", expr))
}

/// Returns `true` when every `)` in `s` closes an earlier `(` and no `(` is left open.
///
/// Characters other than parentheses are ignored.
fn is_balanced(s: &str) -> bool {
    let mut depth = 0i32;
    for c in s.chars() {
        match c {
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                // A `)` with nothing to close: the string can never become balanced.
                if depth < 0 {
                    return false;
                }
            }
            _ => {}
        }
    }
    depth == 0
}
143
144/// 字符串工具
145///
146/// 提供字符串处理功能
147pub struct StringTool;
148
149#[async_trait]
150impl ReActTool for StringTool {
151    fn name(&self) -> &str {
152        "string"
153    }
154
155    fn description(&self) -> &str {
156        "Perform string operations. Operations: 'length', 'upper', 'lower', 'reverse', 'count'"
157    }
158
159    fn parameters_schema(&self) -> Option<Value> {
160        Some(serde_json::json!({
161            "type": "object",
162            "properties": {
163                "operation": {
164                    "type": "string",
165                    "enum": ["length", "upper", "lower", "reverse", "count"],
166                    "description": "The string operation to perform"
167                },
168                "text": {
169                    "type": "string",
170                    "description": "The text to operate on"
171                },
172                "pattern": {
173                    "type": "string",
174                    "description": "Pattern for count operation (optional)"
175                }
176            },
177            "required": ["operation", "text"]
178        }))
179    }
180
181    async fn execute(&self, input: &str) -> Result<String, String> {
182        #[derive(Deserialize)]
183        struct StringInput {
184            operation: String,
185            text: String,
186            pattern: Option<String>,
187        }
188
189        // 尝试解析 JSON
190        let params: StringInput = if let Ok(p) = serde_json::from_str(input) {
191            p
192        } else {
193            // 简单格式: operation:text
194            let parts: Vec<&str> = input.splitn(2, ':').collect();
195            if parts.len() < 2 {
196                return Err("Invalid input format. Use JSON or 'operation:text'".to_string());
197            }
198            StringInput {
199                operation: parts[0].trim().to_string(),
200                text: parts[1].trim().to_string(),
201                pattern: None,
202            }
203        };
204
205        match params.operation.as_str() {
206            "length" => Ok(params.text.len().to_string()),
207            "upper" => Ok(params.text.to_uppercase()),
208            "lower" => Ok(params.text.to_lowercase()),
209            "reverse" => Ok(params.text.chars().rev().collect()),
210            "count" => {
211                let pattern = params.pattern.as_deref().unwrap_or(" ");
212                Ok(params.text.matches(pattern).count().to_string())
213            }
214            _ => Err(format!("Unknown operation: {}", params.operation)),
215        }
216    }
217}
218
219/// JSON 工具
220///
221/// 提供 JSON 解析和查询功能
222pub struct JsonTool;
223
224#[async_trait]
225impl ReActTool for JsonTool {
226    fn name(&self) -> &str {
227        "json"
228    }
229
230    fn description(&self) -> &str {
231        "Parse and query JSON data. Operations: 'parse', 'get', 'keys', 'stringify'"
232    }
233
234    fn parameters_schema(&self) -> Option<Value> {
235        Some(serde_json::json!({
236            "type": "object",
237            "properties": {
238                "operation": {
239                    "type": "string",
240                    "enum": ["parse", "get", "keys", "stringify"],
241                    "description": "The JSON operation to perform"
242                },
243                "data": {
244                    "type": "string",
245                    "description": "The JSON data to operate on"
246                },
247                "path": {
248                    "type": "string",
249                    "description": "JSON path for 'get' operation (e.g., 'user.name')"
250                }
251            },
252            "required": ["operation", "data"]
253        }))
254    }
255
256    async fn execute(&self, input: &str) -> Result<String, String> {
257        #[derive(Deserialize)]
258        struct JsonInput {
259            operation: String,
260            data: String,
261            path: Option<String>,
262        }
263
264        let params: JsonInput =
265            serde_json::from_str(input).map_err(|e| format!("Invalid JSON input: {}", e))?;
266
267        let json: Value =
268            serde_json::from_str(&params.data).map_err(|e| format!("Invalid JSON data: {}", e))?;
269
270        match params.operation.as_str() {
271            "parse" => Ok(format!("Parsed successfully: {}", json)),
272            "get" => {
273                let path = params.path.ok_or("Path required for 'get' operation")?;
274                let mut current = &json;
275                for key in path.split('.') {
276                    current = current
277                        .get(key)
278                        .ok_or_else(|| format!("Key '{}' not found", key))?;
279                }
280                Ok(current.to_string())
281            }
282            "keys" => {
283                if let Some(obj) = json.as_object() {
284                    let keys: Vec<&str> = obj.keys().map(|s| s.as_str()).collect();
285                    Ok(format!("{:?}", keys))
286                } else {
287                    Err("Not a JSON object".to_string())
288                }
289            }
290            "stringify" => {
291                serde_json::to_string_pretty(&json).map_err(|e| format!("Stringify error: {}", e))
292            }
293            _ => Err(format!("Unknown operation: {}", params.operation)),
294        }
295    }
296}
297
298/// 日期时间工具
299///
300/// 提供日期和时间相关功能
301pub struct DateTimeTool;
302
303#[async_trait]
304impl ReActTool for DateTimeTool {
305    fn name(&self) -> &str {
306        "datetime"
307    }
308
309    fn description(&self) -> &str {
310        "Get current date/time information. Operations: 'now', 'timestamp', 'format'"
311    }
312
313    async fn execute(&self, input: &str) -> Result<String, String> {
314        let operation = input.trim().to_lowercase();
315
316        use std::time::{SystemTime, UNIX_EPOCH};
317
318        let now = SystemTime::now()
319            .duration_since(UNIX_EPOCH)
320            .map_err(|e| e.to_string())?;
321
322        match operation.as_str() {
323            "now" | "current" => {
324                let secs = now.as_secs();
325                // 简单的 UTC 时间格式化
326                let days_since_epoch = secs / 86400;
327                let time_of_day = secs % 86400;
328                let hours = time_of_day / 3600;
329                let minutes = (time_of_day % 3600) / 60;
330                let seconds = time_of_day % 60;
331
332                // 简化的日期计算 (从 1970-01-01 开始)
333                let (year, month, day) = days_to_date(days_since_epoch);
334
335                Ok(format!(
336                    "{:04}-{:02}-{:02} {:02}:{:02}:{:02} UTC",
337                    year, month, day, hours, minutes, seconds
338                ))
339            }
340            "timestamp" | "unix" => Ok(now.as_secs().to_string()),
341            "millis" | "milliseconds" => Ok(now.as_millis().to_string()),
342            _ => Err(format!(
343                "Unknown operation: {}. Use 'now', 'timestamp', or 'millis'",
344                operation
345            )),
346        }
347    }
348}
349
/// Converts a number of days since the Unix epoch (1970-01-01) into a Gregorian
/// `(year, month, day)` triple, where `month` and `day` are 1-based.
///
/// Walks forward a year at a time and then a month at a time, so the cost grows
/// linearly with the number of years since 1970.
fn days_to_date(days: u64) -> (u64, u64, u64) {
    let year_length = |y: i64| -> i64 {
        if is_leap_year(y) {
            366
        } else {
            365
        }
    };

    // Days still to be consumed, converted to i64 for the arithmetic below.
    let mut rest = days as i64;
    let mut year = 1970i64;
    while rest >= year_length(year) {
        rest -= year_length(year);
        year += 1;
    }

    let february = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1u64;
    for &length in month_lengths.iter() {
        if rest < length {
            break;
        }
        rest -= length;
        month += 1;
    }

    (year as u64, month, rest as u64 + 1)
}

/// Returns `true` if `year` is a Gregorian leap year: divisible by 4, except century
/// years that are not divisible by 400.
fn is_leap_year(year: i64) -> bool {
    year % 400 == 0 || (year % 4 == 0 && year % 100 != 0)
}
386
387/// Echo 工具 (测试用)
388///
389/// 简单地回显输入
390pub struct EchoTool;
391
392#[async_trait]
393impl ReActTool for EchoTool {
394    fn name(&self) -> &str {
395        "echo"
396    }
397
398    fn description(&self) -> &str {
399        "Echo the input back. Useful for testing."
400    }
401
402    async fn execute(&self, input: &str) -> Result<String, String> {
403        Ok(format!("Echo: {}", input))
404    }
405}
406
407/// 工具注册表便捷函数
408pub mod prelude {
409    use super::*;
410    use std::sync::Arc;
411
412    /// 创建计算器工具
413    pub fn calculator() -> Arc<dyn ReActTool> {
414        Arc::new(CalculatorTool)
415    }
416
417    /// 创建字符串工具
418    pub fn string_tool() -> Arc<dyn ReActTool> {
419        Arc::new(StringTool)
420    }
421
422    /// 创建 JSON 工具
423    pub fn json_tool() -> Arc<dyn ReActTool> {
424        Arc::new(JsonTool)
425    }
426
427    /// 创建日期时间工具
428    pub fn datetime_tool() -> Arc<dyn ReActTool> {
429        Arc::new(DateTimeTool)
430    }
431
432    /// 创建 Echo 工具
433    pub fn echo_tool() -> Arc<dyn ReActTool> {
434        Arc::new(EchoTool)
435    }
436
437    /// 获取所有内置工具
438    pub fn all_builtin_tools() -> Vec<Arc<dyn ReActTool>> {
439        vec![
440            calculator(),
441            string_tool(),
442            json_tool(),
443            datetime_tool(),
444            echo_tool(),
445        ]
446    }
447}
448
449/// 自定义工具构建器
450///
451/// 方便创建简单的自定义工具
452pub struct CustomToolBuilder {
453    name: String,
454    description: String,
455    parameters_schema: Option<Value>,
456    handler: Option<Box<dyn Fn(&str) -> Result<String, String> + Send + Sync>>,
457}
458
459impl CustomToolBuilder {
460    pub fn new(name: impl Into<String>) -> Self {
461        Self {
462            name: name.into(),
463            description: String::new(),
464            parameters_schema: None,
465            handler: None,
466        }
467    }
468
469    pub fn description(mut self, desc: impl Into<String>) -> Self {
470        self.description = desc.into();
471        self
472    }
473
474    pub fn parameters(mut self, schema: Value) -> Self {
475        self.parameters_schema = Some(schema);
476        self
477    }
478
479    pub fn handler<F>(mut self, f: F) -> Self
480    where
481        F: Fn(&str) -> Result<String, String> + Send + Sync + 'static,
482    {
483        self.handler = Some(Box::new(f));
484        self
485    }
486
487    pub fn build(self) -> Option<CustomTool> {
488        Some(CustomTool {
489            name: self.name,
490            description: self.description,
491            parameters_schema: self.parameters_schema,
492            handler: self.handler?,
493        })
494    }
495}
496
497/// 自定义工具
498pub struct CustomTool {
499    name: String,
500    description: String,
501    parameters_schema: Option<Value>,
502    handler: Box<dyn Fn(&str) -> Result<String, String> + Send + Sync>,
503}
504
505#[async_trait]
506impl ReActTool for CustomTool {
507    fn name(&self) -> &str {
508        &self.name
509    }
510
511    fn description(&self) -> &str {
512        &self.description
513    }
514
515    fn parameters_schema(&self) -> Option<Value> {
516        self.parameters_schema.clone()
517    }
518
519    async fn execute(&self, input: &str) -> Result<String, String> {
520        (self.handler)(input)
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculator() {
        let cases = [
            ("2 + 2", 4.0),
            ("10 * 5", 50.0),
            ("(2 + 3) * 4", 20.0),
            ("100 / 4", 25.0),
        ];
        for &(expr, expected) in cases.iter() {
            let actual = evaluate_expression(expr).unwrap();
            assert_eq!(actual, expected, "expression {:?}", expr);
        }
    }

    #[test]
    fn test_date_calculation() {
        // Day 0 since the Unix epoch is 1970-01-01.
        assert_eq!(days_to_date(0), (1970, 1, 1));
        // Day 1 is 1970-01-02.
        assert_eq!(days_to_date(1), (1970, 1, 2));
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool;
        let output = echo.execute("hello").await.unwrap();
        assert_eq!(output, "Echo: hello");
    }

    #[tokio::test]
    async fn test_string_tool() {
        let strings = StringTool;
        let output = strings.execute("upper:hello").await.unwrap();
        assert_eq!(output, "HELLO");
    }
}