1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::generate_schema;
7use ai_agents_core::{Tool, ToolResult};
8
/// Stateless tool that evaluates statistics and scalar math operations
/// requested through JSON arguments (see `MathInput`).
#[derive(Debug, Clone, Copy)]
pub struct MathTool;
10
11impl MathTool {
12 pub fn new() -> Self {
13 Self
14 }
15}
16
17impl Default for MathTool {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23#[derive(Debug, Deserialize, JsonSchema)]
24struct MathInput {
25 operation: String,
27 #[serde(default)]
29 values: Option<Vec<f64>>,
30 #[serde(default)]
32 value: Option<f64>,
33 #[serde(default)]
35 decimals: Option<i32>,
36 #[serde(default)]
38 min: Option<f64>,
39 #[serde(default)]
41 max: Option<f64>,
42 #[serde(default)]
44 base: Option<f64>,
45 #[serde(default)]
47 exponent: Option<f64>,
48 #[serde(default)]
50 total: Option<f64>,
51 #[serde(default)]
53 step: Option<f64>,
54}
55
56#[derive(Debug, Serialize, Deserialize)]
57struct StatOutput {
58 result: f64,
59 count: usize,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
63struct StdevOutput {
64 stdev: f64,
65 variance: f64,
66 mean: f64,
67 count: usize,
68}
69
70#[derive(Debug, Serialize, Deserialize)]
71struct ModeOutput {
72 mode: Vec<f64>,
73 frequency: usize,
74}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct SingleOutput {
78 result: f64,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct ClampOutput {
83 result: f64,
84 clamped: bool,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct RangeOutput {
89 range: Vec<f64>,
90 count: usize,
91}
92
93#[derive(Debug, Serialize, Deserialize)]
94struct MinMaxOutput {
95 min: f64,
96 max: f64,
97 range: f64,
98}
99
100#[async_trait]
101impl Tool for MathTool {
102 fn id(&self) -> &str {
103 "math"
104 }
105
106 fn name(&self) -> &str {
107 "Advanced Math"
108 }
109
110 fn description(&self) -> &str {
111 "Advanced math operations: mean (average), median, mode, stdev (standard deviation), variance, sum, min, max, minmax (both), abs, round, floor, ceil, clamp, percentage, sqrt, pow, log, log10, range, count."
112 }
113
114 fn input_schema(&self) -> Value {
115 generate_schema::<MathInput>()
116 }
117
118 async fn execute(&self, args: Value) -> ToolResult {
119 let input: MathInput = match serde_json::from_value(args) {
120 Ok(input) => input,
121 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
122 };
123
124 match input.operation.to_lowercase().as_str() {
125 "mean" | "average" | "avg" => self.handle_mean(&input),
126 "median" => self.handle_median(&input),
127 "mode" => self.handle_mode(&input),
128 "stdev" | "std" => self.handle_stdev(&input),
129 "variance" | "var" => self.handle_variance(&input),
130 "sum" => self.handle_sum(&input),
131 "min" => self.handle_min(&input),
132 "max" => self.handle_max(&input),
133 "minmax" => self.handle_minmax(&input),
134 "abs" => self.handle_abs(&input),
135 "round" => self.handle_round(&input),
136 "floor" => self.handle_floor(&input),
137 "ceil" => self.handle_ceil(&input),
138 "clamp" => self.handle_clamp(&input),
139 "percentage" | "percent" => self.handle_percentage(&input),
140 "sqrt" => self.handle_sqrt(&input),
141 "pow" | "power" => self.handle_pow(&input),
142 "log" => self.handle_log(&input),
143 "log10" => self.handle_log10(&input),
144 "range" => self.handle_range(&input),
145 "count" => self.handle_count(&input),
146 _ => ToolResult::error(format!(
147 "Unknown operation: {}. Valid: mean, median, mode, stdev, variance, sum, min, max, minmax, abs, round, floor, ceil, clamp, percentage, sqrt, pow, log, log10, range, count",
148 input.operation
149 )),
150 }
151 }
152}
153
154impl MathTool {
155 fn get_values(&self, input: &MathInput) -> Result<Vec<f64>, ToolResult> {
156 input
157 .values
158 .clone()
159 .ok_or_else(|| ToolResult::error("'values' array is required"))
160 }
161
162 fn handle_mean(&self, input: &MathInput) -> ToolResult {
163 let values = match self.get_values(input) {
164 Ok(v) => v,
165 Err(e) => return e,
166 };
167 if values.is_empty() {
168 return ToolResult::error("values array cannot be empty");
169 }
170 let mean = values.iter().sum::<f64>() / values.len() as f64;
171 let output = StatOutput {
172 result: mean,
173 count: values.len(),
174 };
175 self.to_result(&output)
176 }
177
178 fn handle_median(&self, input: &MathInput) -> ToolResult {
179 let values = match self.get_values(input) {
180 Ok(v) => v,
181 Err(e) => return e,
182 };
183 if values.is_empty() {
184 return ToolResult::error("values array cannot be empty");
185 }
186
187 let mut sorted = values.clone();
188 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190 let mid = sorted.len() / 2;
191 let median = if sorted.len() % 2 == 0 {
192 (sorted[mid - 1] + sorted[mid]) / 2.0
193 } else {
194 sorted[mid]
195 };
196
197 let output = StatOutput {
198 result: median,
199 count: values.len(),
200 };
201 self.to_result(&output)
202 }
203
204 fn handle_mode(&self, input: &MathInput) -> ToolResult {
205 let values = match self.get_values(input) {
206 Ok(v) => v,
207 Err(e) => return e,
208 };
209 if values.is_empty() {
210 return ToolResult::error("values array cannot be empty");
211 }
212
213 use std::collections::HashMap;
214 let mut counts: HashMap<String, usize> = HashMap::new();
215
216 for v in &values {
217 let key = format!("{:.10}", v);
218 *counts.entry(key).or_insert(0) += 1;
219 }
220
221 let max_count = *counts.values().max().unwrap_or(&0);
222 let modes: Vec<f64> = counts
223 .iter()
224 .filter(|&(_, &c)| c == max_count)
225 .filter_map(|(k, _)| k.parse().ok())
226 .collect();
227
228 let output = ModeOutput {
229 mode: modes,
230 frequency: max_count,
231 };
232 self.to_result(&output)
233 }
234
235 fn handle_stdev(&self, input: &MathInput) -> ToolResult {
236 let values = match self.get_values(input) {
237 Ok(v) => v,
238 Err(e) => return e,
239 };
240 if values.len() < 2 {
241 return ToolResult::error("stdev requires at least 2 values");
242 }
243
244 let mean = values.iter().sum::<f64>() / values.len() as f64;
245 let variance =
246 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
247 let stdev = variance.sqrt();
248
249 let output = StdevOutput {
250 stdev,
251 variance,
252 mean,
253 count: values.len(),
254 };
255 self.to_result(&output)
256 }
257
258 fn handle_variance(&self, input: &MathInput) -> ToolResult {
259 let values = match self.get_values(input) {
260 Ok(v) => v,
261 Err(e) => return e,
262 };
263 if values.len() < 2 {
264 return ToolResult::error("variance requires at least 2 values");
265 }
266
267 let mean = values.iter().sum::<f64>() / values.len() as f64;
268 let variance =
269 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
270
271 let output = StatOutput {
272 result: variance,
273 count: values.len(),
274 };
275 self.to_result(&output)
276 }
277
278 fn handle_sum(&self, input: &MathInput) -> ToolResult {
279 let values = match self.get_values(input) {
280 Ok(v) => v,
281 Err(e) => return e,
282 };
283
284 let sum: f64 = values.iter().sum();
285 let output = StatOutput {
286 result: sum,
287 count: values.len(),
288 };
289 self.to_result(&output)
290 }
291
292 fn handle_min(&self, input: &MathInput) -> ToolResult {
293 let values = match self.get_values(input) {
294 Ok(v) => v,
295 Err(e) => return e,
296 };
297 if values.is_empty() {
298 return ToolResult::error("values array cannot be empty");
299 }
300
301 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
302 let output = SingleOutput { result: min };
303 self.to_result(&output)
304 }
305
306 fn handle_max(&self, input: &MathInput) -> ToolResult {
307 let values = match self.get_values(input) {
308 Ok(v) => v,
309 Err(e) => return e,
310 };
311 if values.is_empty() {
312 return ToolResult::error("values array cannot be empty");
313 }
314
315 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
316 let output = SingleOutput { result: max };
317 self.to_result(&output)
318 }
319
320 fn handle_minmax(&self, input: &MathInput) -> ToolResult {
321 let values = match self.get_values(input) {
322 Ok(v) => v,
323 Err(e) => return e,
324 };
325 if values.is_empty() {
326 return ToolResult::error("values array cannot be empty");
327 }
328
329 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
330 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
331 let output = MinMaxOutput {
332 min,
333 max,
334 range: max - min,
335 };
336 self.to_result(&output)
337 }
338
339 fn handle_abs(&self, input: &MathInput) -> ToolResult {
340 let value = match input.value {
341 Some(v) => v,
342 None => return ToolResult::error("'value' is required for abs operation"),
343 };
344 let output = SingleOutput {
345 result: value.abs(),
346 };
347 self.to_result(&output)
348 }
349
350 fn handle_round(&self, input: &MathInput) -> ToolResult {
351 let value = match input.value {
352 Some(v) => v,
353 None => return ToolResult::error("'value' is required for round operation"),
354 };
355 let decimals = input.decimals.unwrap_or(0);
356 let multiplier = 10_f64.powi(decimals);
357 let rounded = (value * multiplier).round() / multiplier;
358 let output = SingleOutput { result: rounded };
359 self.to_result(&output)
360 }
361
362 fn handle_floor(&self, input: &MathInput) -> ToolResult {
363 let value = match input.value {
364 Some(v) => v,
365 None => return ToolResult::error("'value' is required for floor operation"),
366 };
367 let output = SingleOutput {
368 result: value.floor(),
369 };
370 self.to_result(&output)
371 }
372
373 fn handle_ceil(&self, input: &MathInput) -> ToolResult {
374 let value = match input.value {
375 Some(v) => v,
376 None => return ToolResult::error("'value' is required for ceil operation"),
377 };
378 let output = SingleOutput {
379 result: value.ceil(),
380 };
381 self.to_result(&output)
382 }
383
384 fn handle_clamp(&self, input: &MathInput) -> ToolResult {
385 let value = match input.value {
386 Some(v) => v,
387 None => return ToolResult::error("'value' is required for clamp operation"),
388 };
389 let min = match input.min {
390 Some(m) => m,
391 None => return ToolResult::error("'min' is required for clamp operation"),
392 };
393 let max = match input.max {
394 Some(m) => m,
395 None => return ToolResult::error("'max' is required for clamp operation"),
396 };
397
398 let clamped_value = value.max(min).min(max);
399 let output = ClampOutput {
400 result: clamped_value,
401 clamped: value != clamped_value,
402 };
403 self.to_result(&output)
404 }
405
406 fn handle_percentage(&self, input: &MathInput) -> ToolResult {
407 let value = match input.value {
408 Some(v) => v,
409 None => return ToolResult::error("'value' is required for percentage operation"),
410 };
411 let total = match input.total {
412 Some(t) => t,
413 None => return ToolResult::error("'total' is required for percentage operation"),
414 };
415 if total == 0.0 {
416 return ToolResult::error("total cannot be zero");
417 }
418
419 let percentage = (value / total) * 100.0;
420 let output = SingleOutput { result: percentage };
421 self.to_result(&output)
422 }
423
424 fn handle_sqrt(&self, input: &MathInput) -> ToolResult {
425 let value = match input.value {
426 Some(v) => v,
427 None => return ToolResult::error("'value' is required for sqrt operation"),
428 };
429 if value < 0.0 {
430 return ToolResult::error("cannot calculate sqrt of negative number");
431 }
432 let output = SingleOutput {
433 result: value.sqrt(),
434 };
435 self.to_result(&output)
436 }
437
438 fn handle_pow(&self, input: &MathInput) -> ToolResult {
439 let base = input.value.or(input.base);
440 let base = match base {
441 Some(b) => b,
442 None => return ToolResult::error("'value' or 'base' is required for pow operation"),
443 };
444 let exponent = match input.exponent {
445 Some(e) => e,
446 None => return ToolResult::error("'exponent' is required for pow operation"),
447 };
448 let output = SingleOutput {
449 result: base.powf(exponent),
450 };
451 self.to_result(&output)
452 }
453
454 fn handle_log(&self, input: &MathInput) -> ToolResult {
455 let value = match input.value {
456 Some(v) => v,
457 None => return ToolResult::error("'value' is required for log operation"),
458 };
459 if value <= 0.0 {
460 return ToolResult::error("cannot calculate log of non-positive number");
461 }
462 let result = match input.base {
463 Some(b) if b > 0.0 && b != 1.0 => value.log(b),
464 Some(_) => return ToolResult::error("log base must be positive and not equal to 1"),
465 None => value.ln(),
466 };
467 let output = SingleOutput { result };
468 self.to_result(&output)
469 }
470
471 fn handle_log10(&self, input: &MathInput) -> ToolResult {
472 let value = match input.value {
473 Some(v) => v,
474 None => return ToolResult::error("'value' is required for log10 operation"),
475 };
476 if value <= 0.0 {
477 return ToolResult::error("cannot calculate log of non-positive number");
478 }
479 let output = SingleOutput {
480 result: value.log10(),
481 };
482 self.to_result(&output)
483 }
484
485 fn handle_range(&self, input: &MathInput) -> ToolResult {
486 let min = input.min.unwrap_or(0.0);
487 let max = match input.max {
488 Some(m) => m,
489 None => return ToolResult::error("'max' is required for range operation"),
490 };
491 let step = input.step.unwrap_or(1.0);
492
493 if step == 0.0 {
494 return ToolResult::error("step cannot be zero");
495 }
496 if (max > min && step < 0.0) || (max < min && step > 0.0) {
497 return ToolResult::error("step direction doesn't match min/max range");
498 }
499
500 let mut values = Vec::new();
501 let mut current = min;
502
503 if step > 0.0 {
504 while current < max {
505 values.push(current);
506 current += step;
507 }
508 } else {
509 while current > max {
510 values.push(current);
511 current += step;
512 }
513 }
514
515 let output = RangeOutput {
516 count: values.len(),
517 range: values,
518 };
519 self.to_result(&output)
520 }
521
522 fn handle_count(&self, input: &MathInput) -> ToolResult {
523 let values = match self.get_values(input) {
524 Ok(v) => v,
525 Err(e) => return e,
526 };
527 let output = StatOutput {
528 result: values.len() as f64,
529 count: values.len(),
530 };
531 self.to_result(&output)
532 }
533
534 fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
535 match serde_json::to_string(output) {
536 Ok(json) => ToolResult::ok(json),
537 Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
538 }
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes a fresh `MathTool` against the given JSON arguments.
    async fn run(args: Value) -> ToolResult {
        MathTool::new().execute(args).await
    }

    /// Asserts `actual` is within machine epsilon of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[tokio::test]
    async fn test_mean() {
        let result = run(serde_json::json!({
            "operation": "mean",
            "values": [1, 2, 3, 4, 5]
        }))
        .await;
        assert!(result.success);
        let output: StatOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 3.0);
    }

    #[tokio::test]
    async fn test_median_odd() {
        let result = run(serde_json::json!({
            "operation": "median",
            "values": [1, 3, 2, 5, 4]
        }))
        .await;
        assert!(result.success);
        let output: StatOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 3.0);
    }

    #[tokio::test]
    async fn test_median_even() {
        let result = run(serde_json::json!({
            "operation": "median",
            "values": [1, 2, 3, 4]
        }))
        .await;
        assert!(result.success);
        let output: StatOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 2.5);
    }

    #[tokio::test]
    async fn test_stdev() {
        let result = run(serde_json::json!({
            "operation": "stdev",
            "values": [2, 4, 4, 4, 5, 5, 7, 9]
        }))
        .await;
        assert!(result.success);
        let output: StdevOutput = serde_json::from_str(&result.output).unwrap();
        // Sample standard deviation of the data set is ~2.138.
        assert!((output.stdev - 2.138).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_sum() {
        let result = run(serde_json::json!({
            "operation": "sum",
            "values": [1, 2, 3, 4, 5]
        }))
        .await;
        assert!(result.success);
        let output: StatOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 15.0);
    }

    #[tokio::test]
    async fn test_minmax() {
        let result = run(serde_json::json!({
            "operation": "minmax",
            "values": [3, 1, 4, 1, 5, 9, 2, 6]
        }))
        .await;
        assert!(result.success);
        let output: MinMaxOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.min, 1.0);
        assert_close(output.max, 9.0);
        assert_close(output.range, 8.0);
    }

    #[tokio::test]
    async fn test_round() {
        let result = run(serde_json::json!({
            "operation": "round",
            "value": 3.14159,
            "decimals": 2
        }))
        .await;
        assert!(result.success);
        let output: SingleOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 3.14);
    }

    #[tokio::test]
    async fn test_clamp() {
        let result = run(serde_json::json!({
            "operation": "clamp",
            "value": 15,
            "min": 0,
            "max": 10
        }))
        .await;
        assert!(result.success);
        let output: ClampOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 10.0);
        assert!(output.clamped);
    }

    #[tokio::test]
    async fn test_percentage() {
        let result = run(serde_json::json!({
            "operation": "percentage",
            "value": 25,
            "total": 100
        }))
        .await;
        assert!(result.success);
        let output: SingleOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 25.0);
    }

    #[tokio::test]
    async fn test_sqrt() {
        let result = run(serde_json::json!({
            "operation": "sqrt",
            "value": 16
        }))
        .await;
        assert!(result.success);
        let output: SingleOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 4.0);
    }

    #[tokio::test]
    async fn test_pow() {
        let result = run(serde_json::json!({
            "operation": "pow",
            "value": 2,
            "exponent": 10
        }))
        .await;
        assert!(result.success);
        let output: SingleOutput = serde_json::from_str(&result.output).unwrap();
        assert_close(output.result, 1024.0);
    }

    #[tokio::test]
    async fn test_range() {
        let result = run(serde_json::json!({
            "operation": "range",
            "min": 0,
            "max": 5,
            "step": 1
        }))
        .await;
        assert!(result.success);
        let output: RangeOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.range, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[tokio::test]
    async fn test_invalid_operation() {
        let result = run(serde_json::json!({ "operation": "invalid" })).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_empty_values() {
        let result = run(serde_json::json!({
            "operation": "mean",
            "values": []
        }))
        .await;
        assert!(!result.success);
    }
}