Skip to main content

oxibonsai_runtime/
tool_calling.rs

1//! High-level tool-calling orchestration for OxiBonsai.
2//!
3//! This module sits on top of the low-level `api_types` helpers and provides a
4//! complete tool-use pipeline:
5//!
6//! 1. **Schema → grammar**: `build_tool_constraint` compiles a list of
7//!    [`ToolDefinition`]s into a BNF [`Grammar`] that constrains generation to
8//!    valid JSON tool invocations.
9//! 2. **Output → call**: `select_tool` parses raw model output and extracts the
10//!    first [`ToolCall`] it finds, matching against a provided registry.
11//! 3. **Convenience constructors**: `make_tool_call` and `new_tool_call_id`
12//!    expose the low-level helpers under module-level names.
13
14use std::collections::HashMap;
15
16use crate::api_types::{FunctionCallResult, ToolCall, ToolDefinition};
17use crate::grammar::{compile_json_schema, Grammar, Rule, Symbol};
18
19// ── Error type ────────────────────────────────────────────────────────────────
20
/// Failure modes of the tool-calling layer.
///
/// Every variant is cheap to clone and comparable, so callers and tests can
/// match on or compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallError {
    /// No tool call could be extracted from the model output.
    NoToolCallFound,
    /// A tool call was extracted, but its function name is not registered.
    UnknownTool {
        /// The function name that failed the lookup.
        name: String,
    },
    /// The argument payload of a tool call was rejected.
    MalformedArguments {
        /// Human-readable explanation of what was wrong with the payload.
        reason: String,
    },
    /// A tool definition could not be turned into a grammar.
    GrammarCompileError {
        /// Human-readable explanation reported by the grammar compiler.
        reason: String,
    },
    /// An empty tool list was supplied where at least one tool is required.
    EmptyToolList,
}

impl std::fmt::Display for ToolCallError {
    /// Render a short, lowercase, single-line description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant messages need no formatting machinery.
            Self::NoToolCallFound => f.write_str("no tool call found in model output"),
            Self::EmptyToolList => f.write_str("tool list is empty"),
            // Messages carrying a payload.
            Self::UnknownTool { name } => write!(f, "unknown tool: '{name}'"),
            Self::MalformedArguments { reason } => {
                write!(f, "malformed tool arguments: {reason}")
            }
            Self::GrammarCompileError { reason } => {
                write!(f, "grammar compile error: {reason}")
            }
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
impl std::error::Error for ToolCallError {}
53
54// ── ID generation ─────────────────────────────────────────────────────────────
55
56/// Generate a unique tool-call identifier with the `call_` prefix.
57///
58/// Delegates to [`crate::api_types::generate_tool_call_id`] and is exposed
59/// here for ergonomic use alongside the rest of the tool-calling API.
60pub fn new_tool_call_id() -> String {
61    crate::api_types::generate_tool_call_id()
62}
63
64// ── Tool call construction ────────────────────────────────────────────────────
65
66/// Construct a [`ToolCall`] from its constituent parts.
67///
68/// `id` should be produced by [`new_tool_call_id`]. The `arguments` string must
69/// be a JSON object serialised to a `String` (the OpenAI wire format).
70pub fn make_tool_call(id: String, name: String, arguments: String) -> ToolCall {
71    ToolCall {
72        id,
73        tool_type: "function".to_string(),
74        function: FunctionCallResult { name, arguments },
75    }
76}
77
78// ── Tool selection ────────────────────────────────────────────────────────────
79
80/// Parse raw model output and extract the first tool call.
81///
82/// The parser looks for the `<tool_call>…</tool_call>` pattern emitted by the
83/// model and validates:
84///
85/// 1. That a `name` field is present.
86/// 2. That the name appears in `tools` (when `tools` is non-empty).
87/// 3. That the `arguments` value, if present, is a valid JSON object.
88///
89/// On success the returned [`ToolCall`] carries a freshly generated ID.
90///
91/// # Errors
92///
93/// - [`ToolCallError::NoToolCallFound`] — no `<tool_call>` tag found.
94/// - [`ToolCallError::UnknownTool`]    — name not in `tools` registry.
95/// - [`ToolCallError::MalformedArguments`] — argument payload is not valid JSON.
96pub fn select_tool(output: &str, tools: &[ToolDefinition]) -> Result<ToolCall, ToolCallError> {
97    let call_id = new_tool_call_id();
98
99    // Use the low-level parser from api_types.
100    let tool_call = crate::api_types::parse_tool_call(output, &call_id)
101        .ok_or(ToolCallError::NoToolCallFound)?;
102
103    // Validate the name against the registered tools (if any).
104    if !tools.is_empty() {
105        let known = tools
106            .iter()
107            .any(|t| t.function.name == tool_call.function.name);
108        if !known {
109            return Err(ToolCallError::UnknownTool {
110                name: tool_call.function.name.clone(),
111            });
112        }
113    }
114
115    // Validate that the arguments string is valid JSON.
116    let _parsed: serde_json::Value =
117        serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
118            ToolCallError::MalformedArguments {
119                reason: e.to_string(),
120            }
121        })?;
122
123    Ok(tool_call)
124}
125
126// ── Grammar constraint construction ──────────────────────────────────────────
127
128/// Compile a list of tool definitions into a BNF grammar that constrains model
129/// output to valid JSON tool invocations.
130///
131/// The generated grammar produces outputs of the form:
132///
133/// ```text
134/// <tool_call>{"name": "<fn_name>", "arguments": <ARGS_SCHEMA>}</tool_call>
135/// ```
136///
137/// where `<ARGS_SCHEMA>` is constrained by the JSON Schema of each function's
138/// `parameters` field. When multiple tools are provided the grammar accepts any
139/// one of them (union of alternatives).
140///
141/// # Errors
142///
143/// Returns [`ToolCallError::EmptyToolList`] when `tools` is empty, or
144/// [`ToolCallError::GrammarCompileError`] if any schema fails to compile.
145pub fn build_tool_constraint(tools: &[ToolDefinition]) -> Result<Grammar, ToolCallError> {
146    if tools.is_empty() {
147        return Err(ToolCallError::EmptyToolList);
148    }
149
150    // Compile one grammar per tool, then merge into a union.
151    let mut args_grammars: Vec<Grammar> = Vec::with_capacity(tools.len());
152    for tool in tools {
153        let g = compile_json_schema(&tool.function.parameters).map_err(|e| {
154            ToolCallError::GrammarCompileError {
155                reason: format!("{e}"),
156            }
157        })?;
158        args_grammars.push(g);
159    }
160
161    merge_tool_grammars(tools, args_grammars)
162}
163
164/// Merge per-tool parameter grammars into a single root grammar that accepts
165/// any valid `<tool_call>…</tool_call>` invocation.
166///
167/// All NT IDs from each arg grammar are remapped to fresh IDs in the merged
168/// grammar so there are no collisions. The root NT (id=0) has one rule per tool;
169/// each rule is a terminal prefix + the remapped arg-grammar start NT + suffix.
170fn merge_tool_grammars(
171    tools: &[ToolDefinition],
172    args_grammars: Vec<Grammar>,
173) -> Result<Grammar, ToolCallError> {
174    // Root NT gets id=0. Grammar::new(0) sets start=0.
175    let mut merged = Grammar::new(0);
176    let root_nt = merged.alloc_nt("tool_call_root"); // id=0
177    debug_assert_eq!(root_nt, 0, "root_nt must be 0 to match start");
178
179    // next_nt tracks how many NTs we have allocated so far (root = 1).
180    let mut next_nt: usize = 1;
181
182    for (tool_idx, (tool, arg_grammar)) in tools.iter().zip(args_grammars.iter()).enumerate() {
183        // Determine the NT count of arg_grammar by finding the maximum NT id
184        // referenced across all rules (lhs and rhs), then +1.
185        let arg_nt_count = arg_grammar
186            .rules
187            .iter()
188            .flat_map(|r| {
189                std::iter::once(r.lhs).chain(r.rhs.iter().filter_map(|s| s.non_terminal_id()))
190            })
191            .max()
192            .map(|m| m + 1)
193            .unwrap_or(0);
194
195        let nt_offset = next_nt;
196
197        // Allocate arg_nt_count fresh NTs in the merged grammar.
198        for nt_j in 0..arg_nt_count {
199            merged.alloc_nt(format!("t{tool_idx}_nt{nt_j}"));
200        }
201        next_nt += arg_nt_count;
202
203        // Copy rules with remapped NT IDs.
204        for rule in &arg_grammar.rules {
205            let new_lhs = rule.lhs + nt_offset;
206            let new_rhs: Vec<Symbol> = rule
207                .rhs
208                .iter()
209                .map(|sym| match sym {
210                    Symbol::NonTerminal(id) => Symbol::NonTerminal(id + nt_offset),
211                    Symbol::Terminal(bytes) => Symbol::Terminal(bytes.clone()),
212                })
213                .collect();
214            merged.add_rule(Rule::new(new_lhs, new_rhs));
215        }
216
217        // The start NT of the arg grammar, offset into merged scope.
218        let args_start = arg_grammar.start + nt_offset;
219
220        // Root rule: root → Terminal(prefix) NonTerminal(args_start) Terminal(suffix)
221        let prefix = format!(
222            "<tool_call>{{\"name\":\"{}\",\"arguments\":",
223            tool.function.name
224        );
225        let suffix = "}</tool_call>".to_string();
226
227        merged.add_rule(Rule::new(
228            root_nt,
229            vec![
230                Symbol::Terminal(prefix.into_bytes()),
231                Symbol::NonTerminal(args_start),
232                Symbol::Terminal(suffix.into_bytes()),
233            ],
234        ));
235    }
236
237    Ok(merged)
238}
239
240// ── Tool registry helper ──────────────────────────────────────────────────────
241
242/// A lightweight registry of tools keyed by function name for O(1) lookup.
243///
244/// Build it once from a `&[ToolDefinition]` slice; query it with
245/// [`ToolRegistry::get`].
246pub struct ToolRegistry<'a> {
247    map: HashMap<&'a str, &'a ToolDefinition>,
248}
249
250impl<'a> ToolRegistry<'a> {
251    /// Build a registry from a slice of tool definitions.
252    pub fn new(tools: &'a [ToolDefinition]) -> Self {
253        let map = tools
254            .iter()
255            .map(|t| (t.function.name.as_str(), t))
256            .collect();
257        Self { map }
258    }
259
260    /// Look up a tool by name.
261    pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
262        self.map.get(name).copied()
263    }
264
265    /// Return all registered tool names.
266    pub fn names(&self) -> impl Iterator<Item = &str> {
267        self.map.keys().copied()
268    }
269
270    /// Number of registered tools.
271    pub fn len(&self) -> usize {
272        self.map.len()
273    }
274
275    /// `true` if the registry contains no tools.
276    pub fn is_empty(&self) -> bool {
277        self.map.is_empty()
278    }
279}
280
281// ── Argument validation ───────────────────────────────────────────────────────
282
283/// Validate that a JSON arguments string satisfies a tool's parameter schema.
284///
285/// This is a structural check: confirms the arguments parse as a JSON object
286/// and that every required property listed in the schema is present.
287///
288/// Returns `Ok(serde_json::Value)` on success (the parsed arguments object).
289pub fn validate_tool_arguments(
290    arguments: &str,
291    tool: &ToolDefinition,
292) -> Result<serde_json::Value, ToolCallError> {
293    let parsed: serde_json::Value =
294        serde_json::from_str(arguments).map_err(|e| ToolCallError::MalformedArguments {
295            reason: e.to_string(),
296        })?;
297
298    if !parsed.is_object() {
299        return Err(ToolCallError::MalformedArguments {
300            reason: "tool arguments must be a JSON object".to_string(),
301        });
302    }
303
304    // Validate required properties if defined in the schema.
305    if let Some(required) = tool.function.parameters.get("required") {
306        if let Some(req_arr) = required.as_array() {
307            let obj = parsed.as_object().expect("parsed is_object checked above");
308            for req_field in req_arr {
309                if let Some(field_name) = req_field.as_str() {
310                    if !obj.contains_key(field_name) {
311                        return Err(ToolCallError::MalformedArguments {
312                            reason: format!("missing required field '{field_name}'"),
313                        });
314                    }
315                }
316            }
317        }
318    }
319
320    Ok(parsed)
321}
322
// Unit tests for the tool-calling layer. Each section below exercises one
// public item of the parent module, using two small fixture tools.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture: a `get_weather` tool whose schema has two string properties
    /// (`location`, `unit`) and requires only `location`.
    fn weather_tool() -> ToolDefinition {
        ToolDefinition::function(
            "get_weather",
            Some("Get current weather".to_string()),
            json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string"}
                },
                "required": ["location"]
            }),
        )
    }

    /// Fixture: a `calculate` tool with a single required string property
    /// (`expression`).
    fn calc_tool() -> ToolDefinition {
        ToolDefinition::function(
            "calculate",
            Some("Perform a calculation".to_string()),
            json!({
                "type": "object",
                "properties": {
                    "expression": {"type": "string"}
                },
                "required": ["expression"]
            }),
        )
    }

    // ── new_tool_call_id ──────────────────────────────────────────────────────

    // A generated ID starts with the `call_` prefix.
    #[test]
    fn tool_call_id_has_call_prefix() {
        let id = new_tool_call_id();
        assert!(id.starts_with("call_"), "id={id}");
    }

    // Repeated generation keeps producing prefixed IDs. Note that this only
    // checks the prefix; it does not assert that the IDs differ.
    #[test]
    fn tool_call_ids_are_generated_repeatedly() {
        let ids: Vec<_> = (0..5).map(|_| new_tool_call_id()).collect();
        for id in &ids {
            assert!(id.starts_with("call_"));
        }
    }

    // ── make_tool_call ────────────────────────────────────────────────────────

    // Every argument ends up in the matching field and the type is "function".
    #[test]
    fn make_tool_call_round_trips_fields() {
        let tc = make_tool_call(
            "call_abc123".to_string(),
            "get_weather".to_string(),
            r#"{"location":"Paris"}"#.to_string(),
        );
        assert_eq!(tc.id, "call_abc123");
        assert_eq!(tc.tool_type, "function");
        assert_eq!(tc.function.name, "get_weather");
        assert_eq!(tc.function.arguments, r#"{"location":"Paris"}"#);
    }

    // ── select_tool ───────────────────────────────────────────────────────────

    // Happy path: a `<tool_call>` wrapper naming a registered tool is parsed
    // and its arguments survive as valid JSON.
    #[test]
    fn select_tool_parses_xml_wrapper() {
        let output =
            r#"<tool_call>{"name":"get_weather","arguments":{"location":"Tokyo"}}</tool_call>"#;
        let tools = vec![weather_tool()];
        let tc = select_tool(output, &tools).expect("should parse");
        assert_eq!(tc.function.name, "get_weather");
        let args: serde_json::Value =
            serde_json::from_str(&tc.function.arguments).expect("valid json");
        assert_eq!(args["location"], "Tokyo");
    }

    // Plain prose without any tool-call wrapper yields `NoToolCallFound`.
    #[test]
    fn select_tool_no_tag_returns_not_found() {
        let output = "I will now get the weather for Paris.";
        let tools = vec![weather_tool()];
        assert!(matches!(
            select_tool(output, &tools),
            Err(ToolCallError::NoToolCallFound)
        ));
    }

    // A well-formed call to a function that is not registered is rejected.
    #[test]
    fn select_tool_unknown_name_returns_error() {
        let output = r#"<tool_call>{"name":"unknown_fn","arguments":{}}</tool_call>"#;
        let tools = vec![weather_tool()];
        assert!(matches!(
            select_tool(output, &tools),
            Err(ToolCallError::UnknownTool { .. })
        ));
    }

    // With an empty tool slice the name check is skipped entirely.
    #[test]
    fn select_tool_empty_tools_skips_name_check() {
        let output = r#"<tool_call>{"name":"any_function","arguments":{}}</tool_call>"#;
        let tc = select_tool(output, &[]).expect("should accept any tool");
        assert_eq!(tc.function.name, "any_function");
    }

    // ── validate_tool_arguments ───────────────────────────────────────────────

    // Required field present (plus an optional one) passes validation.
    #[test]
    fn validate_tool_args_all_required_present() {
        let tool = weather_tool();
        let args = r#"{"location":"Berlin","unit":"celsius"}"#;
        assert!(validate_tool_arguments(args, &tool).is_ok());
    }

    // Omitting the required `location` field is reported as malformed.
    #[test]
    fn validate_tool_args_missing_required_returns_error() {
        let tool = weather_tool();
        let args = r#"{"unit":"fahrenheit"}"#;
        assert!(matches!(
            validate_tool_arguments(args, &tool),
            Err(ToolCallError::MalformedArguments { .. })
        ));
    }

    // Text that is not JSON at all is reported as malformed.
    #[test]
    fn validate_tool_args_invalid_json_returns_error() {
        let tool = weather_tool();
        assert!(matches!(
            validate_tool_arguments("{bad json}", &tool),
            Err(ToolCallError::MalformedArguments { .. })
        ));
    }

    // ── build_tool_constraint ─────────────────────────────────────────────────

    // There is nothing to constrain against when no tools are given.
    #[test]
    fn build_tool_constraint_empty_tools_returns_error() {
        assert!(matches!(
            build_tool_constraint(&[]),
            Err(ToolCallError::EmptyToolList)
        ));
    }

    // A single tool produces a grammar with at least one rule.
    #[test]
    fn build_tool_constraint_single_tool_returns_grammar() {
        let tools = vec![weather_tool()];
        let g = build_tool_constraint(&tools).expect("should build grammar");
        assert!(!g.rules.is_empty(), "grammar must have rules");
    }

    // The start NT carries exactly one alternative per tool.
    #[test]
    fn build_tool_constraint_multi_tool_root_has_one_rule_per_tool() {
        let tools = vec![weather_tool(), calc_tool()];
        let g = build_tool_constraint(&tools).expect("should build grammar");
        let root_rules: Vec<_> = g.rules.iter().filter(|r| r.lhs == g.start).collect();
        assert_eq!(root_rules.len(), 2, "one rule per tool in root NT");
    }

    // ── ToolRegistry ──────────────────────────────────────────────────────────

    // Registered names resolve; an unregistered name does not.
    #[test]
    fn tool_registry_lookup_by_name() {
        let tools = vec![weather_tool(), calc_tool()];
        let reg = ToolRegistry::new(&tools);
        assert!(reg.get("get_weather").is_some());
        assert!(reg.get("calculate").is_some());
        assert!(reg.get("missing").is_none());
    }

    // `len` / `is_empty` agree with the slice the registry was built from.
    #[test]
    fn tool_registry_len_and_is_empty() {
        let tools = vec![weather_tool()];
        let reg = ToolRegistry::new(&tools);
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        let empty: Vec<ToolDefinition> = vec![];
        let er = ToolRegistry::new(&empty);
        assert!(er.is_empty());
    }

    // ── ToolCallError display ─────────────────────────────────────────────────

    // Every variant renders to a non-empty message.
    #[test]
    fn tool_call_error_display_not_empty() {
        let errors = [
            ToolCallError::NoToolCallFound,
            ToolCallError::UnknownTool { name: "foo".into() },
            ToolCallError::MalformedArguments {
                reason: "bad".into(),
            },
            ToolCallError::GrammarCompileError {
                reason: "oops".into(),
            },
            ToolCallError::EmptyToolList,
        ];
        for e in &errors {
            assert!(!e.to_string().is_empty(), "error {e:?} has empty Display");
        }
    }
}
523}