1use async_trait::async_trait;
4use rustant_core::error::ToolError;
5use rustant_core::types::{RiskLevel, ToolOutput};
6use std::time::Duration;
7
8use crate::registry::Tool;
9
10pub struct EchoTool;
12
13#[async_trait]
14impl Tool for EchoTool {
15 fn name(&self) -> &str {
16 "echo"
17 }
18
19 fn description(&self) -> &str {
20 "Echoes the input text back. Useful for testing and confirming values."
21 }
22
23 fn parameters_schema(&self) -> serde_json::Value {
24 serde_json::json!({
25 "type": "object",
26 "properties": {
27 "text": {
28 "type": "string",
29 "description": "The text to echo back"
30 }
31 },
32 "required": ["text"]
33 })
34 }
35
36 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
37 let text = args["text"]
38 .as_str()
39 .ok_or_else(|| ToolError::InvalidArguments {
40 name: "echo".to_string(),
41 reason: "missing required 'text' parameter".to_string(),
42 })?;
43 Ok(ToolOutput::text(text.to_string()))
44 }
45
46 fn risk_level(&self) -> RiskLevel {
47 RiskLevel::ReadOnly
48 }
49
50 fn timeout(&self) -> Duration {
51 Duration::from_secs(5)
52 }
53}
54
55pub struct DateTimeTool;
57
58#[async_trait]
59impl Tool for DateTimeTool {
60 fn name(&self) -> &str {
61 "datetime"
62 }
63
64 fn description(&self) -> &str {
65 "Returns the current date and time in the specified format (default: RFC 3339)."
66 }
67
68 fn parameters_schema(&self) -> serde_json::Value {
69 serde_json::json!({
70 "type": "object",
71 "properties": {
72 "format": {
73 "type": "string",
74 "description": "strftime format string (default: RFC 3339)",
75 "default": "%Y-%m-%dT%H:%M:%S%z"
76 }
77 }
78 })
79 }
80
81 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
82 let now = chrono::Utc::now();
83 let formatted = if let Some(fmt) = args.get("format").and_then(|f| f.as_str()) {
84 now.format(fmt).to_string()
85 } else {
86 now.to_rfc3339()
87 };
88 Ok(ToolOutput::text(formatted))
89 }
90
91 fn risk_level(&self) -> RiskLevel {
92 RiskLevel::ReadOnly
93 }
94
95 fn timeout(&self) -> Duration {
96 Duration::from_secs(5)
97 }
98}
99
100pub struct CalculatorTool;
102
103#[async_trait]
104impl Tool for CalculatorTool {
105 fn name(&self) -> &str {
106 "calculator"
107 }
108
109 fn description(&self) -> &str {
110 "Evaluates a simple arithmetic expression. Supports +, -, *, /, and parentheses."
111 }
112
113 fn parameters_schema(&self) -> serde_json::Value {
114 serde_json::json!({
115 "type": "object",
116 "properties": {
117 "expression": {
118 "type": "string",
119 "description": "The arithmetic expression to evaluate, e.g. '2 + 3 * (4 - 1)'"
120 }
121 },
122 "required": ["expression"]
123 })
124 }
125
126 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
127 let expr = args["expression"]
128 .as_str()
129 .ok_or_else(|| ToolError::InvalidArguments {
130 name: "calculator".to_string(),
131 reason: "missing required 'expression' parameter".to_string(),
132 })?;
133
134 match eval_expression(expr) {
135 Ok(result) => {
136 let formatted = if result.fract() == 0.0 && result.abs() < i64::MAX as f64 {
138 format!("{}", result as i64)
139 } else {
140 format!("{}", result)
141 };
142 Ok(ToolOutput::text(formatted))
143 }
144 Err(e) => Err(ToolError::ExecutionFailed {
145 name: "calculator".to_string(),
146 message: e,
147 }),
148 }
149 }
150
151 fn risk_level(&self) -> RiskLevel {
152 RiskLevel::ReadOnly
153 }
154
155 fn timeout(&self) -> Duration {
156 Duration::from_secs(5)
157 }
158}
159
160fn eval_expression(input: &str) -> Result<f64, String> {
164 let tokens = tokenize(input)?;
165 let mut pos = 0;
166 let result = parse_expr(&tokens, &mut pos)?;
167 if pos < tokens.len() {
168 return Err(format!(
169 "Unexpected token at position {}: {:?}",
170 pos, tokens[pos]
171 ));
172 }
173 Ok(result)
174}
175
/// Lexical unit produced by `tokenize` and consumed by the recursive-descent parser.
#[derive(Clone, Debug)]
enum Token {
    /// Numeric literal; a unary minus directly before a literal is folded into it.
    Number(f64),
    /// `+`
    Plus,
    /// Binary `-`.
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `(`
    LParen,
    /// `)`
    RParen,
}
186
187fn tokenize(input: &str) -> Result<Vec<Token>, String> {
188 let mut tokens = Vec::new();
189 let mut chars = input.chars().peekable();
190
191 while let Some(&ch) = chars.peek() {
192 match ch {
193 ' ' | '\t' | '\n' => {
194 chars.next();
195 }
196 '0'..='9' | '.' => {
197 let mut num_str = String::new();
198 while let Some(&c) = chars.peek() {
199 if c.is_ascii_digit() || c == '.' {
200 num_str.push(c);
201 chars.next();
202 } else {
203 break;
204 }
205 }
206 let num: f64 = num_str
207 .parse()
208 .map_err(|_| format!("Invalid number: {}", num_str))?;
209 tokens.push(Token::Number(num));
210 }
211 '+' => {
212 tokens.push(Token::Plus);
213 chars.next();
214 }
215 '-' => {
216 let is_unary = tokens.is_empty()
218 || matches!(
219 tokens.last(),
220 Some(
221 Token::Plus | Token::Minus | Token::Star | Token::Slash | Token::LParen
222 )
223 );
224 chars.next();
225 if is_unary {
226 while let Some(&c) = chars.peek() {
229 if c == ' ' || c == '\t' {
230 chars.next();
231 } else {
232 break;
233 }
234 }
235 if let Some(&c) = chars.peek() {
236 if c.is_ascii_digit() || c == '.' {
237 let mut num_str = String::new();
238 while let Some(&c) = chars.peek() {
239 if c.is_ascii_digit() || c == '.' {
240 num_str.push(c);
241 chars.next();
242 } else {
243 break;
244 }
245 }
246 let num: f64 = num_str
247 .parse()
248 .map_err(|_| format!("Invalid number: {}", num_str))?;
249 tokens.push(Token::Number(-num));
250 } else if c == '(' {
251 tokens.push(Token::Number(-1.0));
254 tokens.push(Token::Star);
255 } else {
256 return Err(format!("Unexpected character after unary minus: {}", c));
257 }
258 } else {
259 return Err("Unexpected end of expression after minus".to_string());
260 }
261 } else {
262 tokens.push(Token::Minus);
263 }
264 }
265 '*' => {
266 tokens.push(Token::Star);
267 chars.next();
268 }
269 '/' => {
270 tokens.push(Token::Slash);
271 chars.next();
272 }
273 '(' => {
274 tokens.push(Token::LParen);
275 chars.next();
276 }
277 ')' => {
278 tokens.push(Token::RParen);
279 chars.next();
280 }
281 _ => {
282 return Err(format!("Unexpected character: '{}'", ch));
283 }
284 }
285 }
286
287 Ok(tokens)
288}
289
290fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
292 let mut left = parse_term(tokens, pos)?;
293 while *pos < tokens.len() {
294 match tokens[*pos] {
295 Token::Plus => {
296 *pos += 1;
297 let right = parse_term(tokens, pos)?;
298 left += right;
299 }
300 Token::Minus => {
301 *pos += 1;
302 let right = parse_term(tokens, pos)?;
303 left -= right;
304 }
305 _ => break,
306 }
307 }
308 Ok(left)
309}
310
311fn parse_term(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
313 let mut left = parse_factor(tokens, pos)?;
314 while *pos < tokens.len() {
315 match tokens[*pos] {
316 Token::Star => {
317 *pos += 1;
318 let right = parse_factor(tokens, pos)?;
319 left *= right;
320 }
321 Token::Slash => {
322 *pos += 1;
323 let right = parse_factor(tokens, pos)?;
324 if right == 0.0 {
325 return Err("Division by zero".to_string());
326 }
327 left /= right;
328 }
329 _ => break,
330 }
331 }
332 Ok(left)
333}
334
335fn parse_factor(tokens: &[Token], pos: &mut usize) -> Result<f64, String> {
337 if *pos >= tokens.len() {
338 return Err("Unexpected end of expression".to_string());
339 }
340 match &tokens[*pos] {
341 Token::Number(n) => {
342 let val = *n;
343 *pos += 1;
344 Ok(val)
345 }
346 Token::LParen => {
347 *pos += 1; let val = parse_expr(tokens, pos)?;
349 if *pos >= tokens.len() {
350 return Err("Missing closing parenthesis".to_string());
351 }
352 match &tokens[*pos] {
353 Token::RParen => {
354 *pos += 1;
355 Ok(val)
356 }
357 _ => Err("Expected closing parenthesis".to_string()),
358 }
359 }
360 other => Err(format!("Unexpected token: {:?}", other)),
361 }
362}
363
// Unit tests for the built-in tools (echo, datetime, calculator) and for the
// `eval_expression` evaluator they share. Async tool tests run under `#[tokio::test]`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- EchoTool ----

    #[tokio::test]
    async fn test_echo_tool_basic() {
        let tool = EchoTool;
        let result = tool
            .execute(serde_json::json!({"text": "hello world"}))
            .await
            .unwrap();
        assert_eq!(result.content, "hello world");
    }

    // An empty string is a valid `text` value, not a missing argument.
    #[tokio::test]
    async fn test_echo_tool_empty_string() {
        let tool = EchoTool;
        let result = tool.execute(serde_json::json!({"text": ""})).await.unwrap();
        assert_eq!(result.content, "");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_param() {
        let tool = EchoTool;
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_echo_tool_properties() {
        let tool = EchoTool;
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.risk_level(), RiskLevel::ReadOnly);
        assert!(tool.parameters_schema().is_object());
    }

    // ---- DateTimeTool ----

    // The default (RFC 3339) output separates date and time with 'T'.
    #[tokio::test]
    async fn test_datetime_tool_default_format() {
        let tool = DateTimeTool;
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains('T'));
    }

    // "%Y-%m-%d" renders as YYYY-MM-DD: ten characters with '-' separators.
    #[tokio::test]
    async fn test_datetime_tool_custom_format() {
        let tool = DateTimeTool;
        let result = tool
            .execute(serde_json::json!({"format": "%Y-%m-%d"}))
            .await
            .unwrap();
        assert_eq!(result.content.len(), 10);
        assert!(result.content.contains('-'));
    }

    #[test]
    fn test_datetime_tool_properties() {
        let tool = DateTimeTool;
        assert_eq!(tool.name(), "datetime");
        assert_eq!(tool.risk_level(), RiskLevel::ReadOnly);
    }

    // ---- CalculatorTool (end to end through `execute`) ----
    // Integral results are expected without a trailing ".0".

    #[tokio::test]
    async fn test_calculator_simple_addition() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "2 + 3"}))
            .await
            .unwrap();
        assert_eq!(result.content, "5");
    }

    #[tokio::test]
    async fn test_calculator_multiplication() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "4 * 5"}))
            .await
            .unwrap();
        assert_eq!(result.content, "20");
    }

    #[tokio::test]
    async fn test_calculator_operator_precedence() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "2 + 3 * 4"}))
            .await
            .unwrap();
        assert_eq!(result.content, "14");
    }

    #[tokio::test]
    async fn test_calculator_parentheses() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "(2 + 3) * 4"}))
            .await
            .unwrap();
        assert_eq!(result.content, "20");
    }

    #[tokio::test]
    async fn test_calculator_nested_parentheses() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "((1 + 2) * (3 + 4))"}))
            .await
            .unwrap();
        assert_eq!(result.content, "21");
    }

    // Non-integral results keep their fractional part.
    #[tokio::test]
    async fn test_calculator_division() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "10 / 4"}))
            .await
            .unwrap();
        assert_eq!(result.content, "2.5");
    }

    #[tokio::test]
    async fn test_calculator_division_by_zero() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "5 / 0"}))
            .await;
        assert!(result.is_err());
    }

    // A leading '-' is a unary minus on the literal.
    #[tokio::test]
    async fn test_calculator_negative_numbers() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "-3 + 5"}))
            .await
            .unwrap();
        assert_eq!(result.content, "2");
    }

    // 3.5 * 2 is whole, so it is printed as "7", not "7.0".
    #[tokio::test]
    async fn test_calculator_decimal_numbers() {
        let tool = CalculatorTool;
        let result = tool
            .execute(serde_json::json!({"expression": "3.5 * 2"}))
            .await
            .unwrap();
        assert_eq!(result.content, "7");
    }

    #[tokio::test]
    async fn test_calculator_missing_param() {
        let tool = CalculatorTool;
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_calculator_invalid_expression() {
        let tool = CalculatorTool;
        let result = tool.execute(serde_json::json!({"expression": "abc"})).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_calculator_tool_properties() {
        let tool = CalculatorTool;
        assert_eq!(tool.name(), "calculator");
        assert_eq!(tool.risk_level(), RiskLevel::ReadOnly);
    }

    // ---- eval_expression (direct) ----

    #[test]
    fn test_eval_simple() {
        assert_eq!(eval_expression("1 + 1").unwrap(), 2.0);
        assert_eq!(eval_expression("10 - 3").unwrap(), 7.0);
        assert_eq!(eval_expression("6 * 7").unwrap(), 42.0);
        assert_eq!(eval_expression("15 / 3").unwrap(), 5.0);
    }

    #[test]
    fn test_eval_precedence() {
        assert_eq!(eval_expression("2 + 3 * 4").unwrap(), 14.0);
        assert_eq!(eval_expression("2 * 3 + 4").unwrap(), 10.0);
    }

    #[test]
    fn test_eval_parentheses() {
        assert_eq!(eval_expression("(2 + 3) * 4").unwrap(), 20.0);
        assert_eq!(eval_expression("2 * (3 + 4)").unwrap(), 14.0);
    }

    #[test]
    fn test_eval_unary_minus() {
        assert_eq!(eval_expression("-5").unwrap(), -5.0);
        assert_eq!(eval_expression("-5 + 10").unwrap(), 5.0);
    }

    // Empty input, a dangling operator, an unclosed parenthesis and division by zero
    // must all be reported as errors.
    #[test]
    fn test_eval_errors() {
        assert!(eval_expression("").is_err());
        assert!(eval_expression("1 +").is_err());
        assert!(eval_expression("(1 + 2").is_err());
        assert!(eval_expression("1 / 0").is_err());
    }
}