Skip to main content

cortexai_tools/
math.rs

1//! Math tools with full expression evaluation
2
3use async_trait::async_trait;
4use cortexai_core::{errors::ToolError, ExecutionContext, Tool, ToolSchema};
5use serde_json::json;
6
/// Calculator tool that evaluates complete mathematical expressions.
pub struct CalculatorTool {
    /// Whether scientific functions (sin, cos, log, etc.) are meant to be
    /// enabled.
    // No code visible in this file reads the flag, which is why the
    // dead-code lint is silenced on it.
    #[allow(dead_code)]
    scientific: bool,
}

impl CalculatorTool {
    /// Shared constructor behind the public ones.
    fn with_scientific(scientific: bool) -> Self {
        CalculatorTool { scientific }
    }

    /// Creates a calculator with the scientific flag switched on.
    pub fn new() -> Self {
        Self::with_scientific(true)
    }

    /// Creates a basic calculator, i.e. one with the scientific flag
    /// switched off.
    pub fn basic() -> Self {
        Self::with_scientific(false)
    }
}

impl Default for CalculatorTool {
    /// Same as [`CalculatorTool::new`].
    fn default() -> Self {
        CalculatorTool::new()
    }
}
30
31#[async_trait]
32impl Tool for CalculatorTool {
33    fn schema(&self) -> ToolSchema {
34        ToolSchema::new("calculator", "Perform mathematical calculations")
35            .with_parameters(json!({
36                "type": "object",
37                "properties": {
38                    "expression": {
39                        "type": "string",
40                        "description": "Mathematical expression to evaluate. Supports: +, -, *, /, ^, %, parentheses, and functions like sin, cos, tan, sqrt, log, ln, abs, floor, ceil, round, min, max. Constants: pi, e"
41                    }
42                },
43                "required": ["expression"]
44            }))
45    }
46
47    async fn execute(
48        &self,
49        _context: &ExecutionContext,
50        arguments: serde_json::Value,
51    ) -> Result<serde_json::Value, ToolError> {
52        let expression = arguments["expression"]
53            .as_str()
54            .ok_or_else(|| ToolError::InvalidArguments("Missing 'expression' field".to_string()))?;
55
56        // Use fasteval for full expression parsing.
57        // A callback namespace provides 'pi', 'e' as constants and 'sqrt' as
58        // a function (fasteval's built-in functions don't include sqrt).
59        let mut ns = fasteval::StringToCallbackNamespace::new();
60        ns.insert("pi".to_string(), Box::new(|_| std::f64::consts::PI));
61        ns.insert("e".to_string(), Box::new(|_| std::f64::consts::E));
62        ns.insert(
63            "sqrt".to_string(),
64            Box::new(|args: Vec<f64>| args.first().map(|x| x.sqrt()).unwrap_or(f64::NAN)),
65        );
66        let result = fasteval::ez_eval(expression, &mut ns)
67            .map_err(|e| ToolError::ExecutionFailed(format!("Math error: {}", e)))?;
68
69        // Check for invalid results
70        if result.is_nan() {
71            return Err(ToolError::ExecutionFailed(
72                "Result is not a number (NaN)".to_string(),
73            ));
74        }
75        if result.is_infinite() {
76            return Err(ToolError::ExecutionFailed("Result is infinite".to_string()));
77        }
78
79        Ok(json!({
80            "expression": expression,
81            "result": result,
82            "formatted": format_number(result)
83        }))
84    }
85}
86
87/// Format a number nicely for display
88fn format_number(n: f64) -> String {
89    if n.fract() == 0.0 && n.abs() < 1e15 {
90        format!("{:.0}", n)
91    } else if n.abs() < 0.0001 || n.abs() >= 1e10 {
92        format!("{:.6e}", n)
93    } else {
94        format!("{:.6}", n)
95            .trim_end_matches('0')
96            .trim_end_matches('.')
97            .to_string()
98    }
99}
100
/// Unit converter tool covering length, weight and temperature units.
pub struct UnitConverterTool;

impl Default for UnitConverterTool {
    /// Same as [`UnitConverterTool::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl UnitConverterTool {
    /// Creates a converter. The type holds no state.
    pub fn new() -> Self {
        Self
    }

    /// Returns the length of one `unit` in meters, or `None` when `unit` is
    /// not a recognized length unit. Expects a lowercase name.
    fn length_in_meters(unit: &str) -> Option<f64> {
        let factor = match unit {
            "m" | "meter" | "meters" => 1.0,
            "km" | "kilometer" | "kilometers" => 1000.0,
            "cm" | "centimeter" | "centimeters" => 0.01,
            "mm" | "millimeter" | "millimeters" => 0.001,
            "mi" | "mile" | "miles" => 1609.344,
            "ft" | "foot" | "feet" => 0.3048,
            "in" | "inch" | "inches" => 0.0254,
            "yd" | "yard" | "yards" => 0.9144,
            _ => return None,
        };
        Some(factor)
    }

    /// Returns the weight of one `unit` in grams, or `None` when `unit` is
    /// not a recognized weight unit. Expects a lowercase name.
    fn weight_in_grams(unit: &str) -> Option<f64> {
        let factor = match unit {
            "g" | "gram" | "grams" => 1.0,
            "kg" | "kilogram" | "kilograms" => 1000.0,
            "mg" | "milligram" | "milligrams" => 0.001,
            // Exact definitions of the international avoirdupois pound and
            // ounce (1 lb = 16 oz = 453.59237 g), matching the exactness of
            // the length factors above.
            "lb" | "pound" | "pounds" => 453.59237,
            "oz" | "ounce" | "ounces" => 28.349523125,
            _ => return None,
        };
        Some(factor)
    }

    /// Converts `value` from the unit `from` into the unit `to`.
    ///
    /// Unit names are matched case-insensitively and surrounding whitespace
    /// is ignored. Both units must belong to the same family (length,
    /// weight or temperature).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the units are unknown or belong
    /// to different families.
    fn convert(&self, value: f64, from: &str, to: &str) -> Result<f64, String> {
        // Normalize unit names
        let from = from.trim().to_lowercase();
        let to = to.trim().to_lowercase();
        let (from, to) = (from.as_str(), to.as_str());

        // Temperature scales are offset, not just scaled, so they cannot use
        // the factor tables.
        if matches!(
            from,
            "c" | "celsius" | "f" | "fahrenheit" | "k" | "kelvin"
        ) {
            return self.convert_temperature(value, from, to);
        }

        // Try length first, then weight; both units must resolve within the
        // same family.
        let factors = Self::length_in_meters(from)
            .zip(Self::length_in_meters(to))
            .or_else(|| Self::weight_in_grams(from).zip(Self::weight_in_grams(to)));

        match factors {
            Some((from_factor, to_factor)) => Ok(value * from_factor / to_factor),
            None => Err(format!("Cannot convert from '{}' to '{}'", from, to)),
        }
    }

    /// Converts a temperature between Celsius, Fahrenheit and Kelvin.
    ///
    /// `from` and `to` must already be lowercase. The value is converted to
    /// Celsius first and from there to the target scale.
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending unit when either `from` or
    /// `to` is not a temperature unit.
    fn convert_temperature(&self, value: f64, from: &str, to: &str) -> Result<f64, String> {
        // Convert to Celsius first
        let celsius = match from {
            "c" | "celsius" => value,
            "f" | "fahrenheit" => (value - 32.0) * 5.0 / 9.0,
            "k" | "kelvin" => value - 273.15,
            _ => return Err(format!("Unknown temperature unit: {}", from)),
        };

        // Convert from Celsius to target
        match to {
            "c" | "celsius" => Ok(celsius),
            "f" | "fahrenheit" => Ok(celsius * 9.0 / 5.0 + 32.0),
            "k" | "kelvin" => Ok(celsius + 273.15),
            _ => Err(format!("Unknown temperature unit: {}", to)),
        }
    }
}
216
217#[async_trait]
218impl Tool for UnitConverterTool {
219    fn schema(&self) -> ToolSchema {
220        ToolSchema::new("unit_converter", "Convert between units of measurement")
221            .with_parameters(json!({
222                "type": "object",
223                "properties": {
224                    "value": {
225                        "type": "number",
226                        "description": "Value to convert"
227                    },
228                    "from": {
229                        "type": "string",
230                        "description": "Source unit (e.g., 'km', 'miles', 'kg', 'pounds', 'celsius', 'fahrenheit')"
231                    },
232                    "to": {
233                        "type": "string",
234                        "description": "Target unit"
235                    }
236                },
237                "required": ["value", "from", "to"]
238            }))
239    }
240
241    async fn execute(
242        &self,
243        _context: &ExecutionContext,
244        arguments: serde_json::Value,
245    ) -> Result<serde_json::Value, ToolError> {
246        let value = arguments["value"]
247            .as_f64()
248            .ok_or_else(|| ToolError::InvalidArguments("Missing 'value' field".to_string()))?;
249        let from = arguments["from"]
250            .as_str()
251            .ok_or_else(|| ToolError::InvalidArguments("Missing 'from' field".to_string()))?;
252        let to = arguments["to"]
253            .as_str()
254            .ok_or_else(|| ToolError::InvalidArguments("Missing 'to' field".to_string()))?;
255
256        let result = self
257            .convert(value, from, to)
258            .map_err(ToolError::ExecutionFailed)?;
259
260        Ok(json!({
261            "value": value,
262            "from": from,
263            "to": to,
264            "result": result,
265            "formatted": format!("{} {} = {} {}", format_number(value), from, format_number(result), to)
266        }))
267    }
268}
269
/// Tool that computes basic descriptive statistics over a list of numbers.
pub struct StatisticsTool;

impl StatisticsTool {
    /// Creates the tool. The type holds no state.
    pub fn new() -> Self {
        StatisticsTool
    }
}

impl Default for StatisticsTool {
    /// Same as [`StatisticsTool::new`].
    fn default() -> Self {
        StatisticsTool::new()
    }
}
284
285#[async_trait]
286impl Tool for StatisticsTool {
287    fn schema(&self) -> ToolSchema {
288        ToolSchema::new("statistics", "Calculate statistics for a list of numbers").with_parameters(
289            json!({
290                "type": "object",
291                "properties": {
292                    "numbers": {
293                        "type": "array",
294                        "items": { "type": "number" },
295                        "description": "List of numbers to analyze"
296                    }
297                },
298                "required": ["numbers"]
299            }),
300        )
301    }
302
303    async fn execute(
304        &self,
305        _context: &ExecutionContext,
306        arguments: serde_json::Value,
307    ) -> Result<serde_json::Value, ToolError> {
308        let numbers: Vec<f64> = arguments["numbers"]
309            .as_array()
310            .ok_or_else(|| ToolError::InvalidArguments("Missing 'numbers' array".to_string()))?
311            .iter()
312            .filter_map(|v| v.as_f64())
313            .collect();
314
315        if numbers.is_empty() {
316            return Err(ToolError::InvalidArguments(
317                "Numbers array is empty".to_string(),
318            ));
319        }
320
321        let count = numbers.len();
322        let sum: f64 = numbers.iter().sum();
323        let mean = sum / count as f64;
324
325        let mut sorted = numbers.clone();
326        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
327
328        let median = if count.is_multiple_of(2) {
329            (sorted[count / 2 - 1] + sorted[count / 2]) / 2.0
330        } else {
331            sorted[count / 2]
332        };
333
334        let min = sorted.first().copied().unwrap();
335        let max = sorted.last().copied().unwrap();
336
337        let variance: f64 = numbers.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
338        let std_dev = variance.sqrt();
339
340        Ok(json!({
341            "count": count,
342            "sum": sum,
343            "mean": mean,
344            "median": median,
345            "min": min,
346            "max": max,
347            "range": max - min,
348            "variance": variance,
349            "std_dev": std_dev
350        }))
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use cortexai_core::types::AgentId;
    use std::f64::consts::{E, PI};

    /// Builds the execution context used by every test, bound to a fixed
    /// agent id.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(AgentId::new("test-agent"))
    }

    /// Plain arithmetic, including operator precedence.
    #[tokio::test]
    async fn test_calculator_basic() {
        let ctx = test_ctx();
        let calc = CalculatorTool::new();

        for (expression, expected) in [("2 + 2", 4.0), ("10 * 5 + 3", 53.0)] {
            let output = calc
                .execute(&ctx, json!({ "expression": expression }))
                .await
                .unwrap();
            assert_eq!(output["result"], expected);
        }
    }

    /// Functions and the power operator.
    #[tokio::test]
    async fn test_calculator_scientific() {
        let ctx = test_ctx();
        let calc = CalculatorTool::new();

        // Exact results.
        for (expression, expected) in [("sqrt(16)", 4.0), ("2^10", 1024.0)] {
            let output = calc
                .execute(&ctx, json!({ "expression": expression }))
                .await
                .unwrap();
            assert_eq!(output["result"], expected);
        }

        // Trigonometry is compared with a tolerance.
        let output = calc
            .execute(&ctx, json!({ "expression": "sin(0)" }))
            .await
            .unwrap();
        let sine = output["result"].as_f64().unwrap();
        assert!((sine - 0.0).abs() < 0.0001);
    }

    /// The named constants `pi` and `e`.
    #[tokio::test]
    async fn test_calculator_constants() {
        let ctx = test_ctx();
        let calc = CalculatorTool::new();

        for (name, expected) in [("pi", PI), ("e", E)] {
            let output = calc
                .execute(&ctx, json!({ "expression": name }))
                .await
                .unwrap();
            let actual = output["result"].as_f64().unwrap();
            assert!((actual - expected).abs() < 0.0001);
        }
    }

    /// One factor-based conversion and one temperature conversion.
    #[tokio::test]
    async fn test_unit_converter() {
        let ctx = test_ctx();
        let converter = UnitConverterTool::new();

        let length = converter
            .execute(&ctx, json!({ "value": 1.0, "from": "km", "to": "m" }))
            .await
            .unwrap();
        assert_eq!(length["result"], 1000.0);

        let temperature_args = json!({ "value": 32.0, "from": "fahrenheit", "to": "celsius" });
        let temperature = converter.execute(&ctx, temperature_args).await.unwrap();
        let celsius = temperature["result"].as_f64().unwrap();
        assert!((celsius - 0.0).abs() < 0.01);
    }

    /// Summary statistics of a small odd-length list.
    #[tokio::test]
    async fn test_statistics() {
        let ctx = test_ctx();
        let stats = StatisticsTool::new();

        let summary = stats
            .execute(&ctx, json!({ "numbers": [1, 2, 3, 4, 5] }))
            .await
            .unwrap();

        assert_eq!(summary["count"], 5);
        for (key, expected) in [
            ("sum", 15.0),
            ("mean", 3.0),
            ("median", 3.0),
            ("min", 1.0),
            ("max", 5.0),
        ] {
            assert_eq!(summary[key], expected);
        }
    }
}