Skip to main content

rustant_tools/
utils.rs

1//! Utility tools — simple built-in tools for echo, datetime, and calculation.
2
3use async_trait::async_trait;
4use rustant_core::error::ToolError;
5use rustant_core::types::{RiskLevel, ToolOutput};
6use std::time::Duration;
7
8use crate::registry::Tool;
9
10/// Echo tool — returns the input text unchanged.
11pub struct EchoTool;
12
13#[async_trait]
14impl Tool for EchoTool {
15    fn name(&self) -> &str {
16        "echo"
17    }
18
19    fn description(&self) -> &str {
20        "Echoes the input text back. Useful for testing and confirming values."
21    }
22
23    fn parameters_schema(&self) -> serde_json::Value {
24        serde_json::json!({
25            "type": "object",
26            "properties": {
27                "text": {
28                    "type": "string",
29                    "description": "The text to echo back"
30                }
31            },
32            "required": ["text"]
33        })
34    }
35
36    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
37        let text = args["text"]
38            .as_str()
39            .ok_or_else(|| ToolError::InvalidArguments {
40                name: "echo".to_string(),
41                reason: "missing required 'text' parameter".to_string(),
42            })?;
43        Ok(ToolOutput::text(text.to_string()))
44    }
45
46    fn risk_level(&self) -> RiskLevel {
47        RiskLevel::ReadOnly
48    }
49
50    fn timeout(&self) -> Duration {
51        Duration::from_secs(5)
52    }
53}
54
55/// DateTime tool — returns the current date and time.
56pub struct DateTimeTool;
57
58#[async_trait]
59impl Tool for DateTimeTool {
60    fn name(&self) -> &str {
61        "datetime"
62    }
63
64    fn description(&self) -> &str {
65        "Returns the current date and time in the specified format (default: RFC 3339)."
66    }
67
68    fn parameters_schema(&self) -> serde_json::Value {
69        serde_json::json!({
70            "type": "object",
71            "properties": {
72                "format": {
73                    "type": "string",
74                    "description": "strftime format string (default: RFC 3339)",
75                    "default": "%Y-%m-%dT%H:%M:%S%z"
76                }
77            }
78        })
79    }
80
81    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
82        let now = chrono::Utc::now();
83        let formatted = if let Some(fmt) = args.get("format").and_then(|f| f.as_str()) {
84            now.format(fmt).to_string()
85        } else {
86            now.to_rfc3339()
87        };
88        Ok(ToolOutput::text(formatted))
89    }
90
91    fn risk_level(&self) -> RiskLevel {
92        RiskLevel::ReadOnly
93    }
94
95    fn timeout(&self) -> Duration {
96        Duration::from_secs(5)
97    }
98}
99
100/// Calculator tool — evaluates simple arithmetic expressions.
101pub struct CalculatorTool;
102
103#[async_trait]
104impl Tool for CalculatorTool {
105    fn name(&self) -> &str {
106        "calculator"
107    }
108
109    fn description(&self) -> &str {
110        "Evaluates a simple arithmetic expression. Supports +, -, *, /, and parentheses."
111    }
112
113    fn parameters_schema(&self) -> serde_json::Value {
114        serde_json::json!({
115            "type": "object",
116            "properties": {
117                "expression": {
118                    "type": "string",
119                    "description": "The arithmetic expression to evaluate, e.g. '2 + 3 * (4 - 1)'"
120                }
121            },
122            "required": ["expression"]
123        })
124    }
125
126    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
127        let expr = args["expression"]
128            .as_str()
129            .ok_or_else(|| ToolError::InvalidArguments {
130                name: "calculator".to_string(),
131                reason: "missing required 'expression' parameter".to_string(),
132            })?;
133
134        match eval_expression(expr) {
135            Ok(result) => {
136                // Format nicely: if integer result, show without decimals
137                let formatted = if result.fract() == 0.0 && result.abs() < i64::MAX as f64 {
138                    format!("{}", result as i64)
139                } else {
140                    format!("{}", result)
141                };
142                Ok(ToolOutput::text(formatted))
143            }
144            Err(e) => Err(ToolError::ExecutionFailed {
145                name: "calculator".to_string(),
146                message: e,
147            }),
148        }
149    }
150
151    fn risk_level(&self) -> RiskLevel {
152        RiskLevel::ReadOnly
153    }
154
155    fn timeout(&self) -> Duration {
156        Duration::from_secs(5)
157    }
158}
159
160// --- Simple expression evaluator (recursive descent parser) ---
161
162/// Evaluate an arithmetic expression string.
163fn eval_expression(input: &str) -> Result<f64, String> {
164    let tokens = tokenize(input)?;
165    let mut pos = 0;
166    let result = parse_expr(&tokens, &mut pos)?;
167    if pos < tokens.len() {
168        return Err(format!(
169            "Unexpected token at position {}: {:?}",
170            pos, tokens[pos]
171        ));
172    }
173    Ok(result)
174}
175
/// Lexical token of the arithmetic grammar.
#[derive(Debug, Clone)]
enum Token {
    /// Unsigned numeric literal; signs are handled by the parser.
    Number(f64),
    Plus,
    Minus,
    Star,
    Slash,
    LParen,
    RParen,
}

/// Splits `input` into tokens.
///
/// Spaces, tabs, newlines and carriage returns are skipped. A number is a
/// run of ASCII digits and `.` characters that must parse as an `f64`.
/// Every `-` becomes a plain `Token::Minus`; whether it is a binary or a
/// prefix operator is decided by the parser, which keeps operator precedence
/// in one place.
///
/// # Errors
///
/// Returns a message for malformed numbers and for any character that is not
/// part of the grammar.
fn tokenize(input: &str) -> Result<Vec<Token>, String> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();

    while let Some(&ch) = chars.peek() {
        let symbol = match ch {
            ' ' | '\t' | '\n' | '\r' => {
                chars.next();
                continue;
            }
            '0'..='9' | '.' => {
                let mut num_str = String::new();
                while let Some(&c) = chars.peek() {
                    if c.is_ascii_digit() || c == '.' {
                        num_str.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                let num: f64 = num_str
                    .parse()
                    .map_err(|_| format!("Invalid number: {}", num_str))?;
                tokens.push(Token::Number(num));
                continue;
            }
            '+' => Token::Plus,
            '-' => Token::Minus,
            '*' => Token::Star,
            '/' => Token::Slash,
            '(' => Token::LParen,
            ')' => Token::RParen,
            _ => {
                return Err(format!("Unexpected character: '{}'", ch));
            }
        };
        // Single-character operator or parenthesis: record it and move on.
        tokens.push(symbol);
        chars.next();
    }

    Ok(tokens)
}

// expr = term (('+' | '-') term)*
/// Parses a sum/difference of terms starting at `*pos`, advancing `*pos`
/// past everything it consumes.
fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
    let mut left = parse_term(tokens, pos)?;
    loop {
        match tokens.get(*pos) {
            Some(Token::Plus) => {
                *pos += 1;
                left += parse_term(tokens, pos)?;
            }
            Some(Token::Minus) => {
                *pos += 1;
                left -= parse_term(tokens, pos)?;
            }
            _ => break,
        }
    }
    Ok(left)
}

// term = factor (('*' | '/') factor)*
/// Parses a product/quotient of factors starting at `*pos`.
///
/// Division by a divisor equal to zero is rejected with an error instead of
/// producing an infinite or NaN result.
fn parse_term(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
    let mut left = parse_factor(tokens, pos)?;
    loop {
        match tokens.get(*pos) {
            Some(Token::Star) => {
                *pos += 1;
                left *= parse_factor(tokens, pos)?;
            }
            Some(Token::Slash) => {
                *pos += 1;
                let right = parse_factor(tokens, pos)?;
                if right == 0.0 {
                    return Err("Division by zero".to_string());
                }
                left /= right;
            }
            _ => break,
        }
    }
    Ok(left)
}

// factor = '-'* (NUMBER | '(' expr ')')
/// Parses a single operand: an optionally negated number or parenthesized
/// expression.
///
/// Prefix minus is applied here, to the operand itself, so it binds tighter
/// than `*` and `/`: `10 / -(2)` is `10 / (-2)`. Rewriting it as `-1 * (...)`
/// at the token level would instead evaluate `(10 / -1) * 2`.
fn parse_factor(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
    // Consume a run of prefix minus signs iteratively (rather than by
    // recursion) and remember only whether their count is odd.
    let mut negate = false;
    while let Some(Token::Minus) = tokens.get(*pos) {
        negate = !negate;
        *pos += 1;
    }

    let value = match tokens.get(*pos) {
        None => return Err("Unexpected end of expression".to_string()),
        Some(Token::Number(n)) => {
            *pos += 1;
            *n
        }
        Some(Token::LParen) => {
            *pos += 1; // consume '('
            let inner = parse_expr(tokens, pos)?;
            match tokens.get(*pos) {
                Some(Token::RParen) => {
                    *pos += 1;
                    inner
                }
                Some(_) => return Err("Expected closing parenthesis".to_string()),
                None => return Err("Missing closing parenthesis".to_string()),
            }
        }
        Some(other) => return Err(format!("Unexpected token: {:?}", other)),
    };

    Ok(if negate { -value } else { value })
}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Invokes the echo tool with the given JSON arguments.
    async fn echo(args: serde_json::Value) -> Result<ToolOutput, ToolError> {
        EchoTool.execute(args).await
    }

    /// Invokes the calculator tool with `expression` as its only argument.
    async fn calculate(expression: &str) -> Result<ToolOutput, ToolError> {
        CalculatorTool
            .execute(serde_json::json!({ "expression": expression }))
            .await
    }

    // --- EchoTool tests ---

    #[tokio::test]
    async fn test_echo_tool_basic() {
        let output = echo(serde_json::json!({ "text": "hello world" }))
            .await
            .unwrap();
        assert_eq!(output.content, "hello world");
    }

    #[tokio::test]
    async fn test_echo_tool_empty_string() {
        let output = echo(serde_json::json!({ "text": "" })).await.unwrap();
        assert_eq!(output.content, "");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_param() {
        let outcome = echo(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_echo_tool_properties() {
        let echo_tool = EchoTool;
        assert_eq!(echo_tool.name(), "echo");
        assert_eq!(echo_tool.risk_level(), RiskLevel::ReadOnly);
        assert!(echo_tool.parameters_schema().is_object());
    }

    // --- DateTimeTool tests ---

    #[tokio::test]
    async fn test_datetime_tool_default_format() {
        let clock = DateTimeTool;
        let output = clock.execute(serde_json::json!({})).await.unwrap();
        // The default rendering separates date and time with a 'T'.
        assert!(output.content.contains('T'));
    }

    #[tokio::test]
    async fn test_datetime_tool_custom_format() {
        let clock = DateTimeTool;
        let args = serde_json::json!({ "format": "%Y-%m-%d" });
        let output = clock.execute(args).await.unwrap();
        // YYYY-MM-DD is ten characters long and uses '-' as the separator.
        assert_eq!(output.content.len(), 10);
        assert!(output.content.contains('-'));
    }

    #[test]
    fn test_datetime_tool_properties() {
        let clock = DateTimeTool;
        assert_eq!(clock.name(), "datetime");
        assert_eq!(clock.risk_level(), RiskLevel::ReadOnly);
    }

    // --- CalculatorTool tests ---

    #[tokio::test]
    async fn test_calculator_simple_addition() {
        let output = calculate("2 + 3").await.unwrap();
        assert_eq!(output.content, "5");
    }

    #[tokio::test]
    async fn test_calculator_multiplication() {
        let output = calculate("4 * 5").await.unwrap();
        assert_eq!(output.content, "20");
    }

    #[tokio::test]
    async fn test_calculator_operator_precedence() {
        let output = calculate("2 + 3 * 4").await.unwrap();
        assert_eq!(output.content, "14");
    }

    #[tokio::test]
    async fn test_calculator_parentheses() {
        let output = calculate("(2 + 3) * 4").await.unwrap();
        assert_eq!(output.content, "20");
    }

    #[tokio::test]
    async fn test_calculator_nested_parentheses() {
        let output = calculate("((1 + 2) * (3 + 4))").await.unwrap();
        assert_eq!(output.content, "21");
    }

    #[tokio::test]
    async fn test_calculator_division() {
        let output = calculate("10 / 4").await.unwrap();
        assert_eq!(output.content, "2.5");
    }

    #[tokio::test]
    async fn test_calculator_division_by_zero() {
        let outcome = calculate("5 / 0").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_calculator_negative_numbers() {
        let output = calculate("-3 + 5").await.unwrap();
        assert_eq!(output.content, "2");
    }

    #[tokio::test]
    async fn test_calculator_decimal_numbers() {
        let output = calculate("3.5 * 2").await.unwrap();
        assert_eq!(output.content, "7");
    }

    #[tokio::test]
    async fn test_calculator_missing_param() {
        let outcome = CalculatorTool.execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_calculator_invalid_expression() {
        let outcome = calculate("abc").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_calculator_tool_properties() {
        let calculator = CalculatorTool;
        assert_eq!(calculator.name(), "calculator");
        assert_eq!(calculator.risk_level(), RiskLevel::ReadOnly);
    }

    // --- Expression evaluator tests ---

    /// Asserts that every `(expression, expected)` pair evaluates as stated.
    fn assert_evaluates(cases: &[(&str, f64)]) {
        for &(expression, expected) in cases {
            assert_eq!(
                eval_expression(expression).unwrap(),
                expected,
                "expression: {:?}",
                expression
            );
        }
    }

    #[test]
    fn test_eval_simple() {
        assert_evaluates(&[
            ("1 + 1", 2.0),
            ("10 - 3", 7.0),
            ("6 * 7", 42.0),
            ("15 / 3", 5.0),
        ]);
    }

    #[test]
    fn test_eval_precedence() {
        assert_evaluates(&[("2 + 3 * 4", 14.0), ("2 * 3 + 4", 10.0)]);
    }

    #[test]
    fn test_eval_parentheses() {
        assert_evaluates(&[("(2 + 3) * 4", 20.0), ("2 * (3 + 4)", 14.0)]);
    }

    #[test]
    fn test_eval_unary_minus() {
        assert_evaluates(&[("-5", -5.0), ("-5 + 10", 5.0)]);
    }

    #[test]
    fn test_eval_errors() {
        for &malformed in &["", "1 +", "(1 + 2", "1 / 0"] {
            assert!(
                eval_expression(malformed).is_err(),
                "expression: {:?}",
                malformed
            );
        }
    }
}