Skip to main content

axon/
tool_executor.rs

1//! Native tool executors — Calculator and DateTimeTool.
2//!
3//! Tools are pure functions that execute locally (no LLM call).
4//! When a `use_tool` step references a known tool, the runner
5//! intercepts it and calls the executor directly.
6//!
7//! Supported tools:
8//!   - Calculator: safe arithmetic expression evaluator
9//!   - DateTimeTool: current date/time/timestamp queries
10
11use std::time::{SystemTime, UNIX_EPOCH};
12
/// Result of a tool execution.
///
/// When `success` is `false`, `output` holds a human-readable error message
/// instead of the tool's answer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolResult {
    /// Whether the tool produced an answer.
    pub success: bool,
    /// The tool's answer, or an error message when `success` is `false`.
    pub output: String,
    /// Name of the executor that produced this result (e.g. `"Calculator"`).
    pub tool_name: String,
}
20
21/// Dispatch a tool call by name. Returns `None` if the tool is not a native executor.
22pub fn dispatch(tool_name: &str, argument: &str) -> Option<ToolResult> {
23    match tool_name {
24        "Calculator" => Some(calculator_execute(argument)),
25        "DateTimeTool" => Some(datetime_execute(argument)),
26        _ => None, // Not a native tool — fall through to LLM
27    }
28}
29
30// ── Calculator ──────────────────────────────────────────────────────────────
31
32/// Safe arithmetic expression evaluator.
33///
34/// Supports: +, -, *, /, % (mod), ** (power), parentheses,
35/// constants (pi, e), and functions (sqrt, abs, round, sin, cos, tan,
36/// log, ln, ceil, floor, pow, min, max).
37pub fn calculator_execute(expression: &str) -> ToolResult {
38    let expr = expression.trim();
39    if expr.is_empty() {
40        return ToolResult {
41            success: false,
42            output: "Empty expression".to_string(),
43            tool_name: "Calculator".to_string(),
44        };
45    }
46
47    match eval_expr(expr) {
48        Ok(val) => {
49            // Format: remove trailing zeros for clean output
50            let formatted = if val.fract() == 0.0 && val.abs() < 1e15 {
51                format!("{}", val as i64)
52            } else {
53                format!("{}", val)
54            };
55            ToolResult {
56                success: true,
57                output: formatted,
58                tool_name: "Calculator".to_string(),
59            }
60        }
61        Err(e) => ToolResult {
62            success: false,
63            output: format!("Calculator error: {e}"),
64            tool_name: "Calculator".to_string(),
65        },
66    }
67}
68
69// ── Calculator parser (recursive descent) ───────────────────────────────────
70
/// Recursive-descent parser that evaluates calculator expressions on the fly.
///
/// Works on the raw bytes of the input. `pos` is the byte offset of the next
/// unread byte and is what error messages report. Every recursive path
/// (parentheses, function arguments, sign prefixes, `**` exponents) passes
/// through `parse_unary`, which enforces `MAX_DEPTH` so hostile input cannot
/// overflow the stack.
struct CalcParser<'a> {
    input: &'a [u8],
    pos: usize,
    /// Current nesting depth, maintained by `parse_unary`.
    depth: usize,
}

impl<'a> CalcParser<'a> {
    /// Deepest nesting accepted before parsing aborts with an error.
    ///
    /// Evaluation recurses once per nesting level. A stack overflow aborts
    /// the whole process rather than returning an error, so inputs such as
    /// `((((…))))` or `----…1` must be cut off early.
    const MAX_DEPTH: usize = 128;

    fn new(input: &'a str) -> Self {
        Self {
            input: input.as_bytes(),
            pos: 0,
            depth: 0,
        }
    }

    /// Advance `pos` past any ASCII whitespace.
    fn skip_ws(&mut self) {
        while self.pos < self.input.len() && self.input[self.pos].is_ascii_whitespace() {
            self.pos += 1;
        }
    }

    /// Return the next non-whitespace byte without consuming it.
    fn peek(&mut self) -> Option<u8> {
        self.skip_ws();
        self.input.get(self.pos).copied()
    }

    /// Consume `expected` if it is the next non-whitespace byte.
    fn consume(&mut self, expected: u8) -> bool {
        self.skip_ws();
        if self.pos < self.input.len() && self.input[self.pos] == expected {
            self.pos += 1;
            true
        } else {
            false
        }
    }

    /// Consume the two-byte power operator `**` if it comes next.
    fn consume_pow(&mut self) -> bool {
        self.skip_ws();
        if self.input[self.pos..].starts_with(b"**") {
            self.pos += 2;
            true
        } else {
            false
        }
    }

    /// expr = term (('+' | '-') term)*
    fn parse_expr(&mut self) -> Result<f64, String> {
        let mut result = self.parse_term()?;
        loop {
            if self.consume(b'+') {
                result += self.parse_term()?;
            } else if self.consume(b'-') {
                result -= self.parse_term()?;
            } else {
                break;
            }
        }
        Ok(result)
    }

    /// term = unary (('*' | '/' | '%') unary)*
    ///
    /// `parse_power` has already consumed any `**` that follows an operand,
    /// so a `*` seen here is always multiplication.
    fn parse_term(&mut self) -> Result<f64, String> {
        let mut result = self.parse_unary()?;
        loop {
            if self.consume(b'*') {
                result *= self.parse_unary()?;
            } else if self.consume(b'/') {
                let divisor = self.parse_unary()?;
                if divisor == 0.0 {
                    return Err("Division by zero".to_string());
                }
                result /= divisor;
            } else if self.consume(b'%') {
                let modulus = self.parse_unary()?;
                if modulus == 0.0 {
                    return Err("Modulo by zero".to_string());
                }
                result %= modulus;
            } else {
                break;
            }
        }
        Ok(result)
    }

    /// unary = ('-' | '+') unary | power
    ///
    /// Prefix signs bind more loosely than `**`, as in Python and standard
    /// notation: `-2 ** 2` is `-(2 ** 2)` = -4.
    ///
    /// Every recursive cycle in the grammar passes through here, so this is
    /// where the `MAX_DEPTH` limit is enforced.
    fn parse_unary(&mut self) -> Result<f64, String> {
        if self.depth >= Self::MAX_DEPTH {
            return Err(format!(
                "Expression is nested too deeply (limit {})",
                Self::MAX_DEPTH
            ));
        }
        self.depth += 1;
        let result = if self.consume(b'-') {
            self.parse_unary().map(|v| -v)
        } else if self.consume(b'+') {
            self.parse_unary()
        } else {
            self.parse_power()
        };
        // Restore the depth on both the success and the error path.
        self.depth -= 1;
        result
    }

    /// power = atom ('**' unary)?
    ///
    /// The exponent is parsed as `unary`. This makes `**` right-associative
    /// (`2 ** 3 ** 2` = 2^9) and allows signed exponents (`2 ** -1`).
    fn parse_power(&mut self) -> Result<f64, String> {
        let base = self.parse_atom()?;
        if self.consume_pow() {
            let exponent = self.parse_unary()?;
            Ok(base.powf(exponent))
        } else {
            Ok(base)
        }
    }

    /// atom = number | '(' expr ')' | function '(' args ')' | constant
    fn parse_atom(&mut self) -> Result<f64, String> {
        self.skip_ws();

        // Parenthesized expression
        if self.consume(b'(') {
            let val = self.parse_expr()?;
            if !self.consume(b')') {
                return Err("Missing closing parenthesis".to_string());
            }
            return Ok(val);
        }

        // Number
        if self.pos < self.input.len()
            && (self.input[self.pos].is_ascii_digit() || self.input[self.pos] == b'.')
        {
            return self.parse_number();
        }

        // Identifier (function or constant)
        if self.pos < self.input.len() && self.input[self.pos].is_ascii_alphabetic() {
            let name = self.parse_ident();
            return self.resolve_ident(&name);
        }

        Err(format!(
            "Unexpected character at position {}",
            self.pos
        ))
    }

    /// Parse a decimal literal with an optional `e`/`E` exponent (`2.5e-3`).
    fn parse_number(&mut self) -> Result<f64, String> {
        let start = self.pos;
        while self.pos < self.input.len()
            && (self.input[self.pos].is_ascii_digit() || self.input[self.pos] == b'.')
        {
            self.pos += 1;
        }
        // Handle scientific notation: 1e10, 2.5e-3
        if self.pos < self.input.len()
            && (self.input[self.pos] == b'e' || self.input[self.pos] == b'E')
        {
            self.pos += 1;
            if self.pos < self.input.len()
                && (self.input[self.pos] == b'+' || self.input[self.pos] == b'-')
            {
                self.pos += 1;
            }
            while self.pos < self.input.len() && self.input[self.pos].is_ascii_digit() {
                self.pos += 1;
            }
        }
        let s = std::str::from_utf8(&self.input[start..self.pos])
            .map_err(|_| "Invalid UTF-8 in number")?;
        s.parse::<f64>()
            .map_err(|_| format!("Invalid number: '{s}'"))
    }

    /// Read an identifier: an ASCII letter followed by letters, digits or `_`.
    fn parse_ident(&mut self) -> String {
        let start = self.pos;
        while self.pos < self.input.len()
            && (self.input[self.pos].is_ascii_alphanumeric() || self.input[self.pos] == b'_')
        {
            self.pos += 1;
        }
        String::from_utf8_lossy(&self.input[start..self.pos]).to_string()
    }

    /// Evaluate identifier `name`, either as a constant or as a call.
    ///
    /// Constants are checked first. Otherwise `name` must be followed by a
    /// parenthesised argument list and match a known function/arity pair.
    /// `log` is base 10.
    fn resolve_ident(&mut self, name: &str) -> Result<f64, String> {
        // Constants
        match name {
            "pi" | "PI" => return Ok(std::f64::consts::PI),
            "e" | "E" => return Ok(std::f64::consts::E),
            "tau" | "TAU" => return Ok(std::f64::consts::TAU),
            "inf" => return Ok(f64::INFINITY),
            _ => {}
        }

        // Functions
        self.skip_ws();
        if !self.consume(b'(') {
            return Err(format!("Unknown identifier: '{name}'"));
        }

        let args = self.parse_args()?;

        if !self.consume(b')') {
            return Err(format!("Missing ')' after {name}(...)"));
        }

        match (name, args.len()) {
            ("sqrt", 1) => Ok(args[0].sqrt()),
            ("abs", 1) => Ok(args[0].abs()),
            ("round", 1) => Ok(args[0].round()),
            ("ceil", 1) => Ok(args[0].ceil()),
            ("floor", 1) => Ok(args[0].floor()),
            ("sin", 1) => Ok(args[0].sin()),
            ("cos", 1) => Ok(args[0].cos()),
            ("tan", 1) => Ok(args[0].tan()),
            ("asin", 1) => Ok(args[0].asin()),
            ("acos", 1) => Ok(args[0].acos()),
            ("atan", 1) => Ok(args[0].atan()),
            ("log", 1) | ("log10", 1) => Ok(args[0].log10()),
            ("ln", 1) => Ok(args[0].ln()),
            ("log2", 1) => Ok(args[0].log2()),
            ("exp", 1) => Ok(args[0].exp()),
            ("pow", 2) => Ok(args[0].powf(args[1])),
            ("min", 2) => Ok(args[0].min(args[1])),
            ("max", 2) => Ok(args[0].max(args[1])),
            ("atan2", 2) => Ok(args[0].atan2(args[1])),
            _ => Err(format!("Unknown function: '{name}' with {} args", args.len())),
        }
    }

    /// Parse a comma-separated (possibly empty) argument list.
    ///
    /// Stops before the closing `)`, which the caller consumes.
    fn parse_args(&mut self) -> Result<Vec<f64>, String> {
        let mut args = Vec::new();
        self.skip_ws();
        if self.peek() == Some(b')') {
            return Ok(args);
        }
        args.push(self.parse_expr()?);
        while self.consume(b',') {
            args.push(self.parse_expr()?);
        }
        Ok(args)
    }
}
308
309fn eval_expr(expr: &str) -> Result<f64, String> {
310    let mut parser = CalcParser::new(expr);
311    let result = parser.parse_expr()?;
312    parser.skip_ws();
313    if parser.pos < parser.input.len() {
314        return Err(format!(
315            "Unexpected trailing characters at position {}",
316            parser.pos
317        ));
318    }
319    if result.is_nan() {
320        return Err("Result is NaN".to_string());
321    }
322    Ok(result)
323}
324
325// ── DateTimeTool ────────────────────────────────────────────────────────────
326
327/// Current date/time queries using system time (UTC).
328///
329/// Supported queries: now, today, timestamp, year, month, day, weekday, iso,
330/// hour, minute, second, date, time.
331pub fn datetime_execute(query: &str) -> ToolResult {
332    let query = query.trim().to_lowercase();
333
334    let now = SystemTime::now()
335        .duration_since(UNIX_EPOCH)
336        .unwrap_or_default();
337
338    let secs = now.as_secs();
339    let (year, month, day, hour, min, sec, weekday) = unix_to_utc(secs);
340
341    let output = match query.as_str() {
342        "now" | "iso" | "datetime" => format!(
343            "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
344            year, month, day, hour, min, sec
345        ),
346        "today" | "date" => format!("{:04}-{:02}-{:02}", year, month, day),
347        "time" => format!("{:02}:{:02}:{:02}Z", hour, min, sec),
348        "timestamp" | "unix" | "epoch" => format!("{}", secs),
349        "year" => format!("{}", year),
350        "month" => format!("{}", month),
351        "day" => format!("{}", day),
352        "hour" => format!("{}", hour),
353        "minute" => format!("{}", min),
354        "second" => format!("{}", sec),
355        "weekday" => weekday_name(weekday).to_string(),
356        _ => format!(
357            "Unknown query '{}'. Supported: now, today, timestamp, year, month, day, weekday, iso, time, hour, minute, second",
358            query
359        ),
360    };
361
362    ToolResult {
363        success: true,
364        output,
365        tool_name: "DateTimeTool".to_string(),
366    }
367}
368
/// Split a UNIX timestamp into UTC calendar fields.
///
/// Returns `(year, month, day, hour, minute, second, weekday)`. `month` runs
/// 1–12, `day` runs 1–31, and `weekday` counts from Sunday = 0 to
/// Saturday = 6. The date part uses Howard Hinnant's `civil_from_days`
/// algorithm, which works in 400-year eras whose years begin on 1 March.
fn unix_to_utc(secs: u64) -> (i32, u32, u32, u32, u32, u32, u32) {
    const SECS_PER_DAY: u64 = 86_400;
    // Days in one 400-year Gregorian cycle.
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 (the algorithm's day zero) to 1970-01-01.
    const EPOCH_OFFSET_DAYS: i64 = 719_468;

    let whole_days = (secs / SECS_PER_DAY) as i64;
    let secs_of_day = secs % SECS_PER_DAY;
    let hour = (secs_of_day / 3_600) as u32;
    let minute = (secs_of_day % 3_600 / 60) as u32;
    let second = (secs_of_day % 60) as u32;

    // 1970-01-01 was a Thursday, i.e. index 4 when Sunday = 0.
    let weekday = (whole_days + 4).rem_euclid(7) as u32;

    let shifted = whole_days + EPOCH_OFFSET_DAYS;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u32;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    // Day index within a March-based year.
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March … 9 = December, 10 = January, 11 = February.
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };

    // January and February belong to the next civil year.
    let march_year = (i64::from(year_of_era) + era * 400) as i32;
    let year = if month <= 2 { march_year + 1 } else { march_year };

    (year, month, day, hour, minute, second, weekday)
}
396
/// Map a weekday index (Sunday = 0 … Saturday = 6) to its English name.
///
/// Any other value yields `"Unknown"`.
fn weekday_name(weekday: u32) -> &'static str {
    const NAMES: [&str; 7] = [
        "Sunday",
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
    ];

    usize::try_from(weekday)
        .ok()
        .and_then(|idx| NAMES.get(idx))
        .copied()
        .unwrap_or("Unknown")
}
409
410// ── Tests ───────────────────────────────────────────────────────────────────
411
#[cfg(test)]
mod tests {
    use super::*;

    // Calculator tests

    #[test]
    fn calc_basic_arithmetic() {
        assert_eq!(eval_expr("2 + 3").unwrap(), 5.0);
        assert_eq!(eval_expr("10 - 4").unwrap(), 6.0);
        assert_eq!(eval_expr("3 * 7").unwrap(), 21.0);
        assert_eq!(eval_expr("20 / 4").unwrap(), 5.0);
    }

    #[test]
    fn calc_operator_precedence() {
        assert_eq!(eval_expr("2 + 3 * 4").unwrap(), 14.0);
        assert_eq!(eval_expr("(2 + 3) * 4").unwrap(), 20.0);
    }

    #[test]
    fn calc_power() {
        assert_eq!(eval_expr("2 ** 10").unwrap(), 1024.0);
        assert_eq!(eval_expr("3 ** 2").unwrap(), 9.0);
    }

    #[test]
    fn calc_power_is_right_associative() {
        assert_eq!(eval_expr("2 ** 3 ** 2").unwrap(), 512.0);
    }

    #[test]
    fn calc_modulo() {
        assert_eq!(eval_expr("17 % 5").unwrap(), 2.0);
    }

    #[test]
    fn calc_unary_minus() {
        assert_eq!(eval_expr("-5").unwrap(), -5.0);
        assert_eq!(eval_expr("-3 + 7").unwrap(), 4.0);
        assert_eq!(eval_expr("-(2 + 3)").unwrap(), -5.0);
    }

    #[test]
    fn calc_constants() {
        assert!((eval_expr("pi").unwrap() - std::f64::consts::PI).abs() < 1e-10);
        assert!((eval_expr("e").unwrap() - std::f64::consts::E).abs() < 1e-10);
    }

    #[test]
    fn calc_functions() {
        assert_eq!(eval_expr("sqrt(16)").unwrap(), 4.0);
        assert_eq!(eval_expr("abs(-5)").unwrap(), 5.0);
        assert_eq!(eval_expr("round(3.7)").unwrap(), 4.0);
        assert_eq!(eval_expr("ceil(3.2)").unwrap(), 4.0);
        assert_eq!(eval_expr("floor(3.8)").unwrap(), 3.0);
        assert_eq!(eval_expr("pow(2, 8)").unwrap(), 256.0);
        assert_eq!(eval_expr("min(3, 7)").unwrap(), 3.0);
        assert_eq!(eval_expr("max(3, 7)").unwrap(), 7.0);
    }

    #[test]
    fn calc_trig() {
        assert!((eval_expr("sin(0)").unwrap()).abs() < 1e-10);
        assert!((eval_expr("cos(0)").unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn calc_logarithm() {
        assert!((eval_expr("log(100)").unwrap() - 2.0).abs() < 1e-10);
        assert!((eval_expr("ln(e)").unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn calc_nested() {
        assert_eq!(eval_expr("sqrt(pow(3, 2) + pow(4, 2))").unwrap(), 5.0);
    }

    #[test]
    fn calc_scientific_notation() {
        assert_eq!(eval_expr("1e3").unwrap(), 1000.0);
        assert_eq!(eval_expr("2.5e2").unwrap(), 250.0);
    }

    #[test]
    fn calc_division_by_zero() {
        assert!(eval_expr("1 / 0").is_err());
    }

    #[test]
    fn calc_error_paths() {
        assert!(eval_expr("5 % 0").is_err()); // modulo by zero
        assert!(eval_expr("sqrt(-1)").is_err()); // NaN result
        assert!(eval_expr("foo").is_err()); // unknown identifier
        assert!(eval_expr("2 3").is_err()); // trailing input
        assert!(eval_expr("(1 + 2").is_err()); // unbalanced parenthesis
        assert!(eval_expr("max(1)").is_err()); // wrong arity
    }

    #[test]
    fn calc_empty_expression() {
        let r = calculator_execute("");
        assert!(!r.success);
    }

    #[test]
    fn calc_invalid_expression() {
        assert!(eval_expr("2 +").is_err());
    }

    #[test]
    fn calc_integer_output() {
        let r = calculator_execute("2 + 3");
        assert!(r.success);
        assert_eq!(r.output, "5");
    }

    #[test]
    fn calc_negative_zero_prints_as_zero() {
        let r = calculator_execute("-0");
        assert!(r.success);
        assert_eq!(r.output, "0");
    }

    #[test]
    fn calc_float_output() {
        let r = calculator_execute("1 / 3");
        assert!(r.success);
        assert!(r.output.starts_with("0.333"));
    }

    // DateTimeTool tests

    #[test]
    fn datetime_now_iso_format() {
        let r = datetime_execute("now");
        assert!(r.success);
        assert!(r.output.contains('T'));
        assert!(r.output.ends_with('Z'));
    }

    #[test]
    fn datetime_today() {
        let r = datetime_execute("today");
        assert!(r.success);
        assert_eq!(r.output.len(), 10); // YYYY-MM-DD
        assert!(r.output.contains('-'));
    }

    #[test]
    fn datetime_timestamp() {
        let r = datetime_execute("timestamp");
        assert!(r.success);
        let ts: u64 = r.output.parse().expect("should be a number");
        assert!(ts > 1700000000); // After ~2023
    }

    #[test]
    fn datetime_year() {
        let r = datetime_execute("year");
        assert!(r.success);
        let y: i32 = r.output.parse().expect("should be a number");
        assert!(y >= 2024);
    }

    #[test]
    fn datetime_weekday() {
        let r = datetime_execute("weekday");
        assert!(r.success);
        let valid = ["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"];
        assert!(valid.contains(&r.output.as_str()));
    }

    #[test]
    fn datetime_unknown_query() {
        let r = datetime_execute("foobar");
        assert!(r.success);
        assert!(r.output.contains("Unknown query"));
    }

    // Calendar conversion against fixed, independently known instants.

    #[test]
    fn unix_to_utc_known_instants() {
        assert_eq!(unix_to_utc(0), (1970, 1, 1, 0, 0, 0, 4)); // Thursday
        assert_eq!(unix_to_utc(951_782_400), (2000, 2, 29, 0, 0, 0, 2)); // leap day
        assert_eq!(unix_to_utc(1_700_000_000), (2023, 11, 14, 22, 13, 20, 2));
        assert_eq!(unix_to_utc(1_735_689_599), (2024, 12, 31, 23, 59, 59, 2));
    }

    #[test]
    fn weekday_name_bounds() {
        assert_eq!(weekday_name(0), "Sunday");
        assert_eq!(weekday_name(6), "Saturday");
        assert_eq!(weekday_name(7), "Unknown");
    }

    // Dispatch tests

    #[test]
    fn dispatch_calculator() {
        let r = dispatch("Calculator", "2 + 2");
        assert!(r.is_some());
        let r = r.unwrap();
        assert!(r.success);
        assert_eq!(r.output, "4");
    }

    #[test]
    fn dispatch_datetime() {
        let r = dispatch("DateTimeTool", "now");
        assert!(r.is_some());
        assert!(r.unwrap().success);
    }

    #[test]
    fn dispatch_unknown_tool() {
        assert!(dispatch("WebSearch", "query").is_none());
    }
}