1use async_trait::async_trait;
4use cortexai_core::{errors::ToolError, ExecutionContext, Tool, ToolSchema};
5use serde_json::json;
6
/// Tool that evaluates mathematical expressions supplied as strings.
#[derive(Debug, Clone)]
pub struct CalculatorTool {
    /// `true` when built with [`CalculatorTool::new`], `false` when built with
    /// [`CalculatorTool::basic`].
    // NOTE(review): `execute` never reads this flag, which is why the
    // `dead_code` lint is silenced; `basic()` currently evaluates the same
    // expressions as `new()`.
    #[allow(dead_code)]
    scientific: bool,
}
13
14impl Default for CalculatorTool {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl CalculatorTool {
21 pub fn new() -> Self {
22 Self { scientific: true }
23 }
24
25 pub fn basic() -> Self {
27 Self { scientific: false }
28 }
29}
30
31#[async_trait]
32impl Tool for CalculatorTool {
33 fn schema(&self) -> ToolSchema {
34 ToolSchema::new("calculator", "Perform mathematical calculations")
35 .with_parameters(json!({
36 "type": "object",
37 "properties": {
38 "expression": {
39 "type": "string",
40 "description": "Mathematical expression to evaluate. Supports: +, -, *, /, ^, %, parentheses, and functions like sin, cos, tan, sqrt, log, ln, abs, floor, ceil, round, min, max. Constants: pi, e"
41 }
42 },
43 "required": ["expression"]
44 }))
45 }
46
47 async fn execute(
48 &self,
49 _context: &ExecutionContext,
50 arguments: serde_json::Value,
51 ) -> Result<serde_json::Value, ToolError> {
52 let expression = arguments["expression"]
53 .as_str()
54 .ok_or_else(|| ToolError::InvalidArguments("Missing 'expression' field".to_string()))?;
55
56 let mut ns = fasteval::StringToCallbackNamespace::new();
60 ns.insert("pi".to_string(), Box::new(|_| std::f64::consts::PI));
61 ns.insert("e".to_string(), Box::new(|_| std::f64::consts::E));
62 ns.insert(
63 "sqrt".to_string(),
64 Box::new(|args: Vec<f64>| args.first().map(|x| x.sqrt()).unwrap_or(f64::NAN)),
65 );
66 let result = fasteval::ez_eval(expression, &mut ns)
67 .map_err(|e| ToolError::ExecutionFailed(format!("Math error: {}", e)))?;
68
69 if result.is_nan() {
71 return Err(ToolError::ExecutionFailed(
72 "Result is not a number (NaN)".to_string(),
73 ));
74 }
75 if result.is_infinite() {
76 return Err(ToolError::ExecutionFailed("Result is infinite".to_string()));
77 }
78
79 Ok(json!({
80 "expression": expression,
81 "result": result,
82 "formatted": format_number(result)
83 }))
84 }
85}
86
/// Renders `n` as a human-friendly string.
///
/// - Integral values with magnitude below `1e15` print without a fractional
///   part (`4`, `-1024`).
/// - Magnitudes below `0.0001` or at/above `1e10` use scientific notation
///   with six fractional digits (`1.000000e-5`).
/// - Everything else prints with up to six fractional digits, trailing zeros
///   and a dangling `.` removed (`0.5`, `1.234568`).
///
/// Negative zero is printed as `0`.
fn format_number(n: f64) -> String {
    // `-0.0 == 0.0` is true, so this only rewrites negative zero (which would
    // otherwise render as "-0"); every other value, including NaN, is kept.
    let n = if n == 0.0 { 0.0 } else { n };

    let magnitude = n.abs();
    if n.fract() == 0.0 && magnitude < 1e15 {
        format!("{:.0}", n)
    } else if magnitude < 0.0001 || magnitude >= 1e10 {
        format!("{:.6e}", n)
    } else {
        format!("{:.6}", n)
            .trim_end_matches('0')
            .trim_end_matches('.')
            .to_string()
    }
}
100
/// Stateless tool that converts lengths, weights and temperatures.
#[derive(Debug, Clone)]
pub struct UnitConverterTool;
103
104impl Default for UnitConverterTool {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl UnitConverterTool {
111 pub fn new() -> Self {
112 Self
113 }
114
115 fn convert(&self, value: f64, from: &str, to: &str) -> Result<f64, String> {
116 let from = from.to_lowercase();
118 let to = to.to_lowercase();
119
120 let length_to_meters: std::collections::HashMap<&str, f64> = [
122 ("m", 1.0),
123 ("meter", 1.0),
124 ("meters", 1.0),
125 ("km", 1000.0),
126 ("kilometer", 1000.0),
127 ("kilometers", 1000.0),
128 ("cm", 0.01),
129 ("centimeter", 0.01),
130 ("centimeters", 0.01),
131 ("mm", 0.001),
132 ("millimeter", 0.001),
133 ("millimeters", 0.001),
134 ("mi", 1609.344),
135 ("mile", 1609.344),
136 ("miles", 1609.344),
137 ("ft", 0.3048),
138 ("foot", 0.3048),
139 ("feet", 0.3048),
140 ("in", 0.0254),
141 ("inch", 0.0254),
142 ("inches", 0.0254),
143 ("yd", 0.9144),
144 ("yard", 0.9144),
145 ("yards", 0.9144),
146 ]
147 .into_iter()
148 .collect();
149
150 let weight_to_grams: std::collections::HashMap<&str, f64> = [
152 ("g", 1.0),
153 ("gram", 1.0),
154 ("grams", 1.0),
155 ("kg", 1000.0),
156 ("kilogram", 1000.0),
157 ("kilograms", 1000.0),
158 ("mg", 0.001),
159 ("milligram", 0.001),
160 ("milligrams", 0.001),
161 ("lb", 453.592),
162 ("pound", 453.592),
163 ("pounds", 453.592),
164 ("oz", 28.3495),
165 ("ounce", 28.3495),
166 ("ounces", 28.3495),
167 ]
168 .into_iter()
169 .collect();
170
171 if matches!(
173 from.as_str(),
174 "c" | "celsius" | "f" | "fahrenheit" | "k" | "kelvin"
175 ) {
176 return self.convert_temperature(value, &from, &to);
177 }
178
179 if let (Some(&from_factor), Some(&to_factor)) = (
181 length_to_meters.get(from.as_str()),
182 length_to_meters.get(to.as_str()),
183 ) {
184 return Ok(value * from_factor / to_factor);
185 }
186
187 if let (Some(&from_factor), Some(&to_factor)) = (
189 weight_to_grams.get(from.as_str()),
190 weight_to_grams.get(to.as_str()),
191 ) {
192 return Ok(value * from_factor / to_factor);
193 }
194
195 Err(format!("Cannot convert from '{}' to '{}'", from, to))
196 }
197
198 fn convert_temperature(&self, value: f64, from: &str, to: &str) -> Result<f64, String> {
199 let celsius = match from {
201 "c" | "celsius" => value,
202 "f" | "fahrenheit" => (value - 32.0) * 5.0 / 9.0,
203 "k" | "kelvin" => value - 273.15,
204 _ => return Err(format!("Unknown temperature unit: {}", from)),
205 };
206
207 match to {
209 "c" | "celsius" => Ok(celsius),
210 "f" | "fahrenheit" => Ok(celsius * 9.0 / 5.0 + 32.0),
211 "k" | "kelvin" => Ok(celsius + 273.15),
212 _ => Err(format!("Unknown temperature unit: {}", to)),
213 }
214 }
215}
216
217#[async_trait]
218impl Tool for UnitConverterTool {
219 fn schema(&self) -> ToolSchema {
220 ToolSchema::new("unit_converter", "Convert between units of measurement")
221 .with_parameters(json!({
222 "type": "object",
223 "properties": {
224 "value": {
225 "type": "number",
226 "description": "Value to convert"
227 },
228 "from": {
229 "type": "string",
230 "description": "Source unit (e.g., 'km', 'miles', 'kg', 'pounds', 'celsius', 'fahrenheit')"
231 },
232 "to": {
233 "type": "string",
234 "description": "Target unit"
235 }
236 },
237 "required": ["value", "from", "to"]
238 }))
239 }
240
241 async fn execute(
242 &self,
243 _context: &ExecutionContext,
244 arguments: serde_json::Value,
245 ) -> Result<serde_json::Value, ToolError> {
246 let value = arguments["value"]
247 .as_f64()
248 .ok_or_else(|| ToolError::InvalidArguments("Missing 'value' field".to_string()))?;
249 let from = arguments["from"]
250 .as_str()
251 .ok_or_else(|| ToolError::InvalidArguments("Missing 'from' field".to_string()))?;
252 let to = arguments["to"]
253 .as_str()
254 .ok_or_else(|| ToolError::InvalidArguments("Missing 'to' field".to_string()))?;
255
256 let result = self
257 .convert(value, from, to)
258 .map_err(ToolError::ExecutionFailed)?;
259
260 Ok(json!({
261 "value": value,
262 "from": from,
263 "to": to,
264 "result": result,
265 "formatted": format!("{} {} = {} {}", format_number(value), from, format_number(result), to)
266 }))
267 }
268}
269
/// Stateless tool that computes descriptive statistics for a list of numbers.
#[derive(Debug, Clone)]
pub struct StatisticsTool;
272
273impl Default for StatisticsTool {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279impl StatisticsTool {
280 pub fn new() -> Self {
281 Self
282 }
283}
284
285#[async_trait]
286impl Tool for StatisticsTool {
287 fn schema(&self) -> ToolSchema {
288 ToolSchema::new("statistics", "Calculate statistics for a list of numbers").with_parameters(
289 json!({
290 "type": "object",
291 "properties": {
292 "numbers": {
293 "type": "array",
294 "items": { "type": "number" },
295 "description": "List of numbers to analyze"
296 }
297 },
298 "required": ["numbers"]
299 }),
300 )
301 }
302
303 async fn execute(
304 &self,
305 _context: &ExecutionContext,
306 arguments: serde_json::Value,
307 ) -> Result<serde_json::Value, ToolError> {
308 let numbers: Vec<f64> = arguments["numbers"]
309 .as_array()
310 .ok_or_else(|| ToolError::InvalidArguments("Missing 'numbers' array".to_string()))?
311 .iter()
312 .filter_map(|v| v.as_f64())
313 .collect();
314
315 if numbers.is_empty() {
316 return Err(ToolError::InvalidArguments(
317 "Numbers array is empty".to_string(),
318 ));
319 }
320
321 let count = numbers.len();
322 let sum: f64 = numbers.iter().sum();
323 let mean = sum / count as f64;
324
325 let mut sorted = numbers.clone();
326 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
327
328 let median = if count.is_multiple_of(2) {
329 (sorted[count / 2 - 1] + sorted[count / 2]) / 2.0
330 } else {
331 sorted[count / 2]
332 };
333
334 let min = sorted.first().copied().unwrap();
335 let max = sorted.last().copied().unwrap();
336
337 let variance: f64 = numbers.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count as f64;
338 let std_dev = variance.sqrt();
339
340 Ok(json!({
341 "count": count,
342 "sum": sum,
343 "mean": mean,
344 "median": median,
345 "min": min,
346 "max": max,
347 "range": max - min,
348 "variance": variance,
349 "std_dev": std_dev
350 }))
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use cortexai_core::types::AgentId;
    use std::f64::consts::{E, PI};

    /// Execution context passed to every tool invocation in these tests.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(AgentId::new("test-agent"))
    }

    /// Asserts that the JSON number `actual` lies strictly within `tolerance`
    /// of `expected`.
    fn assert_approx(actual: &serde_json::Value, expected: f64, tolerance: f64) {
        let got = actual.as_f64().unwrap();
        assert!(
            (got - expected).abs() < tolerance,
            "expected {} (tolerance {}), got {}",
            expected,
            tolerance,
            got
        );
    }

    #[tokio::test]
    async fn test_calculator_basic() {
        let calculator = CalculatorTool::new();
        let ctx = test_ctx();

        let cases = [("2 + 2", 4.0), ("10 * 5 + 3", 53.0)];
        for (expression, expected) in cases {
            let output = calculator
                .execute(&ctx, json!({ "expression": expression }))
                .await
                .unwrap();
            assert_eq!(output["result"], expected);
        }
    }

    #[tokio::test]
    async fn test_calculator_scientific() {
        let calculator = CalculatorTool::new();
        let ctx = test_ctx();

        let exact_cases = [("sqrt(16)", 4.0), ("2^10", 1024.0)];
        for (expression, expected) in exact_cases {
            let output = calculator
                .execute(&ctx, json!({ "expression": expression }))
                .await
                .unwrap();
            assert_eq!(output["result"], expected);
        }

        let sine = calculator
            .execute(&ctx, json!({ "expression": "sin(0)" }))
            .await
            .unwrap();
        assert_approx(&sine["result"], 0.0, 0.0001);
    }

    #[tokio::test]
    async fn test_calculator_constants() {
        let calculator = CalculatorTool::new();
        let ctx = test_ctx();

        for (expression, expected) in [("pi", PI), ("e", E)] {
            let output = calculator
                .execute(&ctx, json!({ "expression": expression }))
                .await
                .unwrap();
            assert_approx(&output["result"], expected, 0.0001);
        }
    }

    #[tokio::test]
    async fn test_unit_converter() {
        let converter = UnitConverterTool::new();
        let ctx = test_ctx();

        let km_to_m = converter
            .execute(&ctx, json!({ "value": 1.0, "from": "km", "to": "m" }))
            .await
            .unwrap();
        assert_eq!(km_to_m["result"], 1000.0);

        let freezing = converter
            .execute(
                &ctx,
                json!({ "value": 32.0, "from": "fahrenheit", "to": "celsius" }),
            )
            .await
            .unwrap();
        assert_approx(&freezing["result"], 0.0, 0.01);
    }

    #[tokio::test]
    async fn test_statistics() {
        let stats = StatisticsTool::new();
        let ctx = test_ctx();

        let summary = stats
            .execute(&ctx, json!({ "numbers": [1, 2, 3, 4, 5] }))
            .await
            .unwrap();

        assert_eq!(summary["count"], 5);

        let expected_fields = [
            ("sum", 15.0),
            ("mean", 3.0),
            ("median", 3.0),
            ("min", 1.0),
            ("max", 5.0),
        ];
        for (field, expected) in expected_fields {
            assert_eq!(summary[field], expected, "field `{}`", field);
        }
    }
}