Skip to main content

ai_agents_tools/builtin/math.rs

1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::generate_schema;
7use ai_agents_core::{Tool, ToolResult};
8
/// Built-in agent tool offering descriptive statistics and scalar math operations.
///
/// The type carries no state: every invocation is fully described by its JSON arguments.
#[derive(Default)]
pub struct MathTool;

impl MathTool {
    /// Creates the tool; equivalent to `MathTool::default()`.
    pub fn new() -> Self {
        MathTool
    }
}
22
23#[derive(Debug, Deserialize, JsonSchema)]
24struct MathInput {
25    /// Operation: mean, median, mode, stdev, variance, sum, min, max, abs, round, floor, ceil, clamp, percentage, sqrt, pow, log, range, count
26    operation: String,
27    /// Array of numbers (for statistical operations)
28    #[serde(default)]
29    values: Option<Vec<f64>>,
30    /// Single number input
31    #[serde(default)]
32    value: Option<f64>,
33    /// Decimal places (for round)
34    #[serde(default)]
35    decimals: Option<i32>,
36    /// Minimum value (for clamp/range)
37    #[serde(default)]
38    min: Option<f64>,
39    /// Maximum value (for clamp/range)
40    #[serde(default)]
41    max: Option<f64>,
42    /// Base for pow/log
43    #[serde(default)]
44    base: Option<f64>,
45    /// Exponent for pow
46    #[serde(default)]
47    exponent: Option<f64>,
48    /// Total for percentage calculation
49    #[serde(default)]
50    total: Option<f64>,
51    /// Step for range
52    #[serde(default)]
53    step: Option<f64>,
54}
55
56#[derive(Debug, Serialize, Deserialize)]
57struct StatOutput {
58    result: f64,
59    count: usize,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
63struct StdevOutput {
64    stdev: f64,
65    variance: f64,
66    mean: f64,
67    count: usize,
68}
69
70#[derive(Debug, Serialize, Deserialize)]
71struct ModeOutput {
72    mode: Vec<f64>,
73    frequency: usize,
74}
75
76#[derive(Debug, Serialize, Deserialize)]
77struct SingleOutput {
78    result: f64,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct ClampOutput {
83    result: f64,
84    clamped: bool,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct RangeOutput {
89    range: Vec<f64>,
90    count: usize,
91}
92
93#[derive(Debug, Serialize, Deserialize)]
94struct MinMaxOutput {
95    min: f64,
96    max: f64,
97    range: f64,
98}
99
100#[async_trait]
101impl Tool for MathTool {
102    fn id(&self) -> &str {
103        "math"
104    }
105
106    fn name(&self) -> &str {
107        "Advanced Math"
108    }
109
110    fn description(&self) -> &str {
111        "Advanced math operations: mean (average), median, mode, stdev (standard deviation), variance, sum, min, max, minmax (both), abs, round, floor, ceil, clamp, percentage, sqrt, pow, log, log10, range, count."
112    }
113
114    fn input_schema(&self) -> Value {
115        generate_schema::<MathInput>()
116    }
117
118    async fn execute(&self, args: Value) -> ToolResult {
119        let input: MathInput = match serde_json::from_value(args) {
120            Ok(input) => input,
121            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
122        };
123
124        match input.operation.to_lowercase().as_str() {
125            "mean" | "average" | "avg" => self.handle_mean(&input),
126            "median" => self.handle_median(&input),
127            "mode" => self.handle_mode(&input),
128            "stdev" | "std" => self.handle_stdev(&input),
129            "variance" | "var" => self.handle_variance(&input),
130            "sum" => self.handle_sum(&input),
131            "min" => self.handle_min(&input),
132            "max" => self.handle_max(&input),
133            "minmax" => self.handle_minmax(&input),
134            "abs" => self.handle_abs(&input),
135            "round" => self.handle_round(&input),
136            "floor" => self.handle_floor(&input),
137            "ceil" => self.handle_ceil(&input),
138            "clamp" => self.handle_clamp(&input),
139            "percentage" | "percent" => self.handle_percentage(&input),
140            "sqrt" => self.handle_sqrt(&input),
141            "pow" | "power" => self.handle_pow(&input),
142            "log" => self.handle_log(&input),
143            "log10" => self.handle_log10(&input),
144            "range" => self.handle_range(&input),
145            "count" => self.handle_count(&input),
146            _ => ToolResult::error(format!(
147                "Unknown operation: {}. Valid: mean, median, mode, stdev, variance, sum, min, max, minmax, abs, round, floor, ceil, clamp, percentage, sqrt, pow, log, log10, range, count",
148                input.operation
149            )),
150        }
151    }
152}
153
154impl MathTool {
155    fn get_values(&self, input: &MathInput) -> Result<Vec<f64>, ToolResult> {
156        input
157            .values
158            .clone()
159            .ok_or_else(|| ToolResult::error("'values' array is required"))
160    }
161
162    fn handle_mean(&self, input: &MathInput) -> ToolResult {
163        let values = match self.get_values(input) {
164            Ok(v) => v,
165            Err(e) => return e,
166        };
167        if values.is_empty() {
168            return ToolResult::error("values array cannot be empty");
169        }
170        let mean = values.iter().sum::<f64>() / values.len() as f64;
171        let output = StatOutput {
172            result: mean,
173            count: values.len(),
174        };
175        self.to_result(&output)
176    }
177
178    fn handle_median(&self, input: &MathInput) -> ToolResult {
179        let values = match self.get_values(input) {
180            Ok(v) => v,
181            Err(e) => return e,
182        };
183        if values.is_empty() {
184            return ToolResult::error("values array cannot be empty");
185        }
186
187        let mut sorted = values.clone();
188        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190        let mid = sorted.len() / 2;
191        let median = if sorted.len() % 2 == 0 {
192            (sorted[mid - 1] + sorted[mid]) / 2.0
193        } else {
194            sorted[mid]
195        };
196
197        let output = StatOutput {
198            result: median,
199            count: values.len(),
200        };
201        self.to_result(&output)
202    }
203
204    fn handle_mode(&self, input: &MathInput) -> ToolResult {
205        let values = match self.get_values(input) {
206            Ok(v) => v,
207            Err(e) => return e,
208        };
209        if values.is_empty() {
210            return ToolResult::error("values array cannot be empty");
211        }
212
213        use std::collections::HashMap;
214        let mut counts: HashMap<String, usize> = HashMap::new();
215
216        for v in &values {
217            let key = format!("{:.10}", v);
218            *counts.entry(key).or_insert(0) += 1;
219        }
220
221        let max_count = *counts.values().max().unwrap_or(&0);
222        let modes: Vec<f64> = counts
223            .iter()
224            .filter(|&(_, &c)| c == max_count)
225            .filter_map(|(k, _)| k.parse().ok())
226            .collect();
227
228        let output = ModeOutput {
229            mode: modes,
230            frequency: max_count,
231        };
232        self.to_result(&output)
233    }
234
235    fn handle_stdev(&self, input: &MathInput) -> ToolResult {
236        let values = match self.get_values(input) {
237            Ok(v) => v,
238            Err(e) => return e,
239        };
240        if values.len() < 2 {
241            return ToolResult::error("stdev requires at least 2 values");
242        }
243
244        let mean = values.iter().sum::<f64>() / values.len() as f64;
245        let variance =
246            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
247        let stdev = variance.sqrt();
248
249        let output = StdevOutput {
250            stdev,
251            variance,
252            mean,
253            count: values.len(),
254        };
255        self.to_result(&output)
256    }
257
258    fn handle_variance(&self, input: &MathInput) -> ToolResult {
259        let values = match self.get_values(input) {
260            Ok(v) => v,
261            Err(e) => return e,
262        };
263        if values.len() < 2 {
264            return ToolResult::error("variance requires at least 2 values");
265        }
266
267        let mean = values.iter().sum::<f64>() / values.len() as f64;
268        let variance =
269            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
270
271        let output = StatOutput {
272            result: variance,
273            count: values.len(),
274        };
275        self.to_result(&output)
276    }
277
278    fn handle_sum(&self, input: &MathInput) -> ToolResult {
279        let values = match self.get_values(input) {
280            Ok(v) => v,
281            Err(e) => return e,
282        };
283
284        let sum: f64 = values.iter().sum();
285        let output = StatOutput {
286            result: sum,
287            count: values.len(),
288        };
289        self.to_result(&output)
290    }
291
292    fn handle_min(&self, input: &MathInput) -> ToolResult {
293        let values = match self.get_values(input) {
294            Ok(v) => v,
295            Err(e) => return e,
296        };
297        if values.is_empty() {
298            return ToolResult::error("values array cannot be empty");
299        }
300
301        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
302        let output = SingleOutput { result: min };
303        self.to_result(&output)
304    }
305
306    fn handle_max(&self, input: &MathInput) -> ToolResult {
307        let values = match self.get_values(input) {
308            Ok(v) => v,
309            Err(e) => return e,
310        };
311        if values.is_empty() {
312            return ToolResult::error("values array cannot be empty");
313        }
314
315        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
316        let output = SingleOutput { result: max };
317        self.to_result(&output)
318    }
319
320    fn handle_minmax(&self, input: &MathInput) -> ToolResult {
321        let values = match self.get_values(input) {
322            Ok(v) => v,
323            Err(e) => return e,
324        };
325        if values.is_empty() {
326            return ToolResult::error("values array cannot be empty");
327        }
328
329        let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
330        let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
331        let output = MinMaxOutput {
332            min,
333            max,
334            range: max - min,
335        };
336        self.to_result(&output)
337    }
338
339    fn handle_abs(&self, input: &MathInput) -> ToolResult {
340        let value = match input.value {
341            Some(v) => v,
342            None => return ToolResult::error("'value' is required for abs operation"),
343        };
344        let output = SingleOutput {
345            result: value.abs(),
346        };
347        self.to_result(&output)
348    }
349
350    fn handle_round(&self, input: &MathInput) -> ToolResult {
351        let value = match input.value {
352            Some(v) => v,
353            None => return ToolResult::error("'value' is required for round operation"),
354        };
355        let decimals = input.decimals.unwrap_or(0);
356        let multiplier = 10_f64.powi(decimals);
357        let rounded = (value * multiplier).round() / multiplier;
358        let output = SingleOutput { result: rounded };
359        self.to_result(&output)
360    }
361
362    fn handle_floor(&self, input: &MathInput) -> ToolResult {
363        let value = match input.value {
364            Some(v) => v,
365            None => return ToolResult::error("'value' is required for floor operation"),
366        };
367        let output = SingleOutput {
368            result: value.floor(),
369        };
370        self.to_result(&output)
371    }
372
373    fn handle_ceil(&self, input: &MathInput) -> ToolResult {
374        let value = match input.value {
375            Some(v) => v,
376            None => return ToolResult::error("'value' is required for ceil operation"),
377        };
378        let output = SingleOutput {
379            result: value.ceil(),
380        };
381        self.to_result(&output)
382    }
383
384    fn handle_clamp(&self, input: &MathInput) -> ToolResult {
385        let value = match input.value {
386            Some(v) => v,
387            None => return ToolResult::error("'value' is required for clamp operation"),
388        };
389        let min = match input.min {
390            Some(m) => m,
391            None => return ToolResult::error("'min' is required for clamp operation"),
392        };
393        let max = match input.max {
394            Some(m) => m,
395            None => return ToolResult::error("'max' is required for clamp operation"),
396        };
397
398        let clamped_value = value.max(min).min(max);
399        let output = ClampOutput {
400            result: clamped_value,
401            clamped: value != clamped_value,
402        };
403        self.to_result(&output)
404    }
405
406    fn handle_percentage(&self, input: &MathInput) -> ToolResult {
407        let value = match input.value {
408            Some(v) => v,
409            None => return ToolResult::error("'value' is required for percentage operation"),
410        };
411        let total = match input.total {
412            Some(t) => t,
413            None => return ToolResult::error("'total' is required for percentage operation"),
414        };
415        if total == 0.0 {
416            return ToolResult::error("total cannot be zero");
417        }
418
419        let percentage = (value / total) * 100.0;
420        let output = SingleOutput { result: percentage };
421        self.to_result(&output)
422    }
423
424    fn handle_sqrt(&self, input: &MathInput) -> ToolResult {
425        let value = match input.value {
426            Some(v) => v,
427            None => return ToolResult::error("'value' is required for sqrt operation"),
428        };
429        if value < 0.0 {
430            return ToolResult::error("cannot calculate sqrt of negative number");
431        }
432        let output = SingleOutput {
433            result: value.sqrt(),
434        };
435        self.to_result(&output)
436    }
437
438    fn handle_pow(&self, input: &MathInput) -> ToolResult {
439        let base = input.value.or(input.base);
440        let base = match base {
441            Some(b) => b,
442            None => return ToolResult::error("'value' or 'base' is required for pow operation"),
443        };
444        let exponent = match input.exponent {
445            Some(e) => e,
446            None => return ToolResult::error("'exponent' is required for pow operation"),
447        };
448        let output = SingleOutput {
449            result: base.powf(exponent),
450        };
451        self.to_result(&output)
452    }
453
454    fn handle_log(&self, input: &MathInput) -> ToolResult {
455        let value = match input.value {
456            Some(v) => v,
457            None => return ToolResult::error("'value' is required for log operation"),
458        };
459        if value <= 0.0 {
460            return ToolResult::error("cannot calculate log of non-positive number");
461        }
462        let result = match input.base {
463            Some(b) if b > 0.0 && b != 1.0 => value.log(b),
464            Some(_) => return ToolResult::error("log base must be positive and not equal to 1"),
465            None => value.ln(),
466        };
467        let output = SingleOutput { result };
468        self.to_result(&output)
469    }
470
471    fn handle_log10(&self, input: &MathInput) -> ToolResult {
472        let value = match input.value {
473            Some(v) => v,
474            None => return ToolResult::error("'value' is required for log10 operation"),
475        };
476        if value <= 0.0 {
477            return ToolResult::error("cannot calculate log of non-positive number");
478        }
479        let output = SingleOutput {
480            result: value.log10(),
481        };
482        self.to_result(&output)
483    }
484
485    fn handle_range(&self, input: &MathInput) -> ToolResult {
486        let min = input.min.unwrap_or(0.0);
487        let max = match input.max {
488            Some(m) => m,
489            None => return ToolResult::error("'max' is required for range operation"),
490        };
491        let step = input.step.unwrap_or(1.0);
492
493        if step == 0.0 {
494            return ToolResult::error("step cannot be zero");
495        }
496        if (max > min && step < 0.0) || (max < min && step > 0.0) {
497            return ToolResult::error("step direction doesn't match min/max range");
498        }
499
500        let mut values = Vec::new();
501        let mut current = min;
502
503        if step > 0.0 {
504            while current < max {
505                values.push(current);
506                current += step;
507            }
508        } else {
509            while current > max {
510                values.push(current);
511                current += step;
512            }
513        }
514
515        let output = RangeOutput {
516            count: values.len(),
517            range: values,
518        };
519        self.to_result(&output)
520    }
521
522    fn handle_count(&self, input: &MathInput) -> ToolResult {
523        let values = match self.get_values(input) {
524            Ok(v) => v,
525            Err(e) => return e,
526        };
527        let output = StatOutput {
528            result: values.len() as f64,
529            count: values.len(),
530        };
531        self.to_result(&output)
532    }
533
534    fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
535        match serde_json::to_string(output) {
536            Ok(json) => ToolResult::ok(json),
537            Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
538        }
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Executes a fresh tool once with the given JSON arguments.
    async fn run(args: Value) -> ToolResult {
        MathTool::new().execute(args).await
    }

    /// Executes an operation that must succeed and decodes its JSON payload as `T`.
    async fn run_ok<T: for<'de> Deserialize<'de>>(args: Value) -> T {
        let result = run(args).await;
        assert!(result.success);
        serde_json::from_str(&result.output).unwrap()
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} (tolerance {}), got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[tokio::test]
    async fn test_mean() {
        let out: StatOutput =
            run_ok(serde_json::json!({ "operation": "mean", "values": [1, 2, 3, 4, 5] })).await;
        assert_close(out.result, 3.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_median_odd() {
        let out: StatOutput =
            run_ok(serde_json::json!({ "operation": "median", "values": [1, 3, 2, 5, 4] })).await;
        assert_close(out.result, 3.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_median_even() {
        let out: StatOutput =
            run_ok(serde_json::json!({ "operation": "median", "values": [1, 2, 3, 4] })).await;
        assert_close(out.result, 2.5, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_stdev() {
        let out: StdevOutput = run_ok(serde_json::json!({
            "operation": "stdev",
            "values": [2, 4, 4, 4, 5, 5, 7, 9]
        }))
        .await;
        assert_close(out.stdev, 2.138, 0.01);
    }

    #[tokio::test]
    async fn test_sum() {
        let out: StatOutput =
            run_ok(serde_json::json!({ "operation": "sum", "values": [1, 2, 3, 4, 5] })).await;
        assert_close(out.result, 15.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_minmax() {
        let out: MinMaxOutput = run_ok(serde_json::json!({
            "operation": "minmax",
            "values": [3, 1, 4, 1, 5, 9, 2, 6]
        }))
        .await;
        assert_close(out.min, 1.0, f64::EPSILON);
        assert_close(out.max, 9.0, f64::EPSILON);
        assert_close(out.range, 8.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_round() {
        let out: SingleOutput = run_ok(serde_json::json!({
            "operation": "round",
            "value": 3.14159,
            "decimals": 2
        }))
        .await;
        assert_close(out.result, 3.14, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_clamp() {
        let out: ClampOutput = run_ok(serde_json::json!({
            "operation": "clamp",
            "value": 15,
            "min": 0,
            "max": 10
        }))
        .await;
        assert_close(out.result, 10.0, f64::EPSILON);
        assert!(out.clamped);
    }

    #[tokio::test]
    async fn test_percentage() {
        let out: SingleOutput = run_ok(serde_json::json!({
            "operation": "percentage",
            "value": 25,
            "total": 100
        }))
        .await;
        assert_close(out.result, 25.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_sqrt() {
        let out: SingleOutput =
            run_ok(serde_json::json!({ "operation": "sqrt", "value": 16 })).await;
        assert_close(out.result, 4.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_pow() {
        let out: SingleOutput =
            run_ok(serde_json::json!({ "operation": "pow", "value": 2, "exponent": 10 })).await;
        assert_close(out.result, 1024.0, f64::EPSILON);
    }

    #[tokio::test]
    async fn test_range() {
        let out: RangeOutput = run_ok(serde_json::json!({
            "operation": "range",
            "min": 0,
            "max": 5,
            "step": 1
        }))
        .await;
        assert_eq!(out.range, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[tokio::test]
    async fn test_invalid_operation() {
        let result = run(serde_json::json!({ "operation": "invalid" })).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_empty_values() {
        let result = run(serde_json::json!({ "operation": "mean", "values": [] })).await;
        assert!(!result.success);
    }
}