1use std::collections::HashMap;
15
16use crate::api_types::{FunctionCallResult, ToolCall, ToolDefinition};
17use crate::grammar::{compile_json_schema, Grammar, Rule, Symbol};
18
/// Failures that can occur while extracting, validating, or constraining a
/// model-emitted tool call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallError {
    /// The model output did not contain a tool call that could be parsed.
    NoToolCallFound,
    /// The parsed call names a function absent from the supplied tool list.
    UnknownTool {
        /// Function name exactly as it appeared in the model output.
        name: String,
    },
    /// The call's argument string failed a JSON-syntax or shape/required-field check.
    MalformedArguments {
        /// Human-readable description of what was wrong.
        reason: String,
    },
    /// A tool's parameter schema could not be compiled into a grammar.
    GrammarCompileError {
        /// The schema compiler's error message.
        reason: String,
    },
    /// A constraint grammar was requested for zero tools.
    EmptyToolList,
}

impl std::fmt::Display for ToolCallError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoToolCallFound => f.write_str("no tool call found in model output"),
            Self::UnknownTool { name } => write!(f, "unknown tool: '{name}'"),
            Self::MalformedArguments { reason } => write!(f, "malformed tool arguments: {reason}"),
            Self::GrammarCompileError { reason } => write!(f, "grammar compile error: {reason}"),
            Self::EmptyToolList => f.write_str("tool list is empty"),
        }
    }
}

/// Marker impl so `ToolCallError` can be used as a `dyn std::error::Error`.
impl std::error::Error for ToolCallError {}
53
54pub fn new_tool_call_id() -> String {
61 crate::api_types::generate_tool_call_id()
62}
63
64pub fn make_tool_call(id: String, name: String, arguments: String) -> ToolCall {
71 ToolCall {
72 id,
73 tool_type: "function".to_string(),
74 function: FunctionCallResult { name, arguments },
75 }
76}
77
78pub fn select_tool(output: &str, tools: &[ToolDefinition]) -> Result<ToolCall, ToolCallError> {
97 let call_id = new_tool_call_id();
98
99 let tool_call = crate::api_types::parse_tool_call(output, &call_id)
101 .ok_or(ToolCallError::NoToolCallFound)?;
102
103 if !tools.is_empty() {
105 let known = tools
106 .iter()
107 .any(|t| t.function.name == tool_call.function.name);
108 if !known {
109 return Err(ToolCallError::UnknownTool {
110 name: tool_call.function.name.clone(),
111 });
112 }
113 }
114
115 let _parsed: serde_json::Value =
117 serde_json::from_str(&tool_call.function.arguments).map_err(|e| {
118 ToolCallError::MalformedArguments {
119 reason: e.to_string(),
120 }
121 })?;
122
123 Ok(tool_call)
124}
125
126pub fn build_tool_constraint(tools: &[ToolDefinition]) -> Result<Grammar, ToolCallError> {
146 if tools.is_empty() {
147 return Err(ToolCallError::EmptyToolList);
148 }
149
150 let mut args_grammars: Vec<Grammar> = Vec::with_capacity(tools.len());
152 for tool in tools {
153 let g = compile_json_schema(&tool.function.parameters).map_err(|e| {
154 ToolCallError::GrammarCompileError {
155 reason: format!("{e}"),
156 }
157 })?;
158 args_grammars.push(g);
159 }
160
161 merge_tool_grammars(tools, args_grammars)
162}
163
164fn merge_tool_grammars(
171 tools: &[ToolDefinition],
172 args_grammars: Vec<Grammar>,
173) -> Result<Grammar, ToolCallError> {
174 let mut merged = Grammar::new(0);
176 let root_nt = merged.alloc_nt("tool_call_root"); debug_assert_eq!(root_nt, 0, "root_nt must be 0 to match start");
178
179 let mut next_nt: usize = 1;
181
182 for (tool_idx, (tool, arg_grammar)) in tools.iter().zip(args_grammars.iter()).enumerate() {
183 let arg_nt_count = arg_grammar
186 .rules
187 .iter()
188 .flat_map(|r| {
189 std::iter::once(r.lhs).chain(r.rhs.iter().filter_map(|s| s.non_terminal_id()))
190 })
191 .max()
192 .map(|m| m + 1)
193 .unwrap_or(0);
194
195 let nt_offset = next_nt;
196
197 for nt_j in 0..arg_nt_count {
199 merged.alloc_nt(format!("t{tool_idx}_nt{nt_j}"));
200 }
201 next_nt += arg_nt_count;
202
203 for rule in &arg_grammar.rules {
205 let new_lhs = rule.lhs + nt_offset;
206 let new_rhs: Vec<Symbol> = rule
207 .rhs
208 .iter()
209 .map(|sym| match sym {
210 Symbol::NonTerminal(id) => Symbol::NonTerminal(id + nt_offset),
211 Symbol::Terminal(bytes) => Symbol::Terminal(bytes.clone()),
212 })
213 .collect();
214 merged.add_rule(Rule::new(new_lhs, new_rhs));
215 }
216
217 let args_start = arg_grammar.start + nt_offset;
219
220 let prefix = format!(
222 "<tool_call>{{\"name\":\"{}\",\"arguments\":",
223 tool.function.name
224 );
225 let suffix = "}</tool_call>".to_string();
226
227 merged.add_rule(Rule::new(
228 root_nt,
229 vec![
230 Symbol::Terminal(prefix.into_bytes()),
231 Symbol::NonTerminal(args_start),
232 Symbol::Terminal(suffix.into_bytes()),
233 ],
234 ));
235 }
236
237 Ok(merged)
238}
239
240pub struct ToolRegistry<'a> {
247 map: HashMap<&'a str, &'a ToolDefinition>,
248}
249
250impl<'a> ToolRegistry<'a> {
251 pub fn new(tools: &'a [ToolDefinition]) -> Self {
253 let map = tools
254 .iter()
255 .map(|t| (t.function.name.as_str(), t))
256 .collect();
257 Self { map }
258 }
259
260 pub fn get(&self, name: &str) -> Option<&ToolDefinition> {
262 self.map.get(name).copied()
263 }
264
265 pub fn names(&self) -> impl Iterator<Item = &str> {
267 self.map.keys().copied()
268 }
269
270 pub fn len(&self) -> usize {
272 self.map.len()
273 }
274
275 pub fn is_empty(&self) -> bool {
277 self.map.is_empty()
278 }
279}
280
281pub fn validate_tool_arguments(
290 arguments: &str,
291 tool: &ToolDefinition,
292) -> Result<serde_json::Value, ToolCallError> {
293 let parsed: serde_json::Value =
294 serde_json::from_str(arguments).map_err(|e| ToolCallError::MalformedArguments {
295 reason: e.to_string(),
296 })?;
297
298 if !parsed.is_object() {
299 return Err(ToolCallError::MalformedArguments {
300 reason: "tool arguments must be a JSON object".to_string(),
301 });
302 }
303
304 if let Some(required) = tool.function.parameters.get("required") {
306 if let Some(req_arr) = required.as_array() {
307 let obj = parsed.as_object().expect("parsed is_object checked above");
308 for req_field in req_arr {
309 if let Some(field_name) = req_field.as_str() {
310 if !obj.contains_key(field_name) {
311 return Err(ToolCallError::MalformedArguments {
312 reason: format!("missing required field '{field_name}'"),
313 });
314 }
315 }
316 }
317 }
318 }
319
320 Ok(parsed)
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fixture: weather tool whose schema requires `location`.
    fn weather_tool() -> ToolDefinition {
        ToolDefinition::function(
            "get_weather",
            Some("Get current weather".to_string()),
            json!({
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string"}
                },
                "required": ["location"]
            }),
        )
    }

    /// Fixture: calculator tool whose schema requires `expression`.
    fn calc_tool() -> ToolDefinition {
        ToolDefinition::function(
            "calculate",
            Some("Perform a calculation".to_string()),
            json!({
                "type": "object",
                "properties": {
                    "expression": {"type": "string"}
                },
                "required": ["expression"]
            }),
        )
    }

    #[test]
    fn tool_call_id_has_call_prefix() {
        let id = new_tool_call_id();
        assert!(id.starts_with("call_"), "id={id}");
    }

    #[test]
    fn tool_call_ids_are_generated_repeatedly() {
        let ids: Vec<_> = (0..5).map(|_| new_tool_call_id()).collect();
        for id in &ids {
            assert!(id.starts_with("call_"));
        }
    }

    #[test]
    fn make_tool_call_round_trips_fields() {
        let tc = make_tool_call(
            "call_abc123".to_string(),
            "get_weather".to_string(),
            r#"{"location":"Paris"}"#.to_string(),
        );
        assert_eq!(tc.id, "call_abc123");
        assert_eq!(tc.tool_type, "function");
        assert_eq!(tc.function.name, "get_weather");
        assert_eq!(tc.function.arguments, r#"{"location":"Paris"}"#);
    }

    #[test]
    fn select_tool_parses_xml_wrapper() {
        let output =
            r#"<tool_call>{"name":"get_weather","arguments":{"location":"Tokyo"}}</tool_call>"#;
        let tools = vec![weather_tool()];
        let tc = select_tool(output, &tools).expect("should parse");
        assert_eq!(tc.function.name, "get_weather");
        let args: serde_json::Value =
            serde_json::from_str(&tc.function.arguments).expect("valid json");
        assert_eq!(args["location"], "Tokyo");
    }

    #[test]
    fn select_tool_no_tag_returns_not_found() {
        let output = "I will now get the weather for Paris.";
        let tools = vec![weather_tool()];
        assert!(matches!(
            select_tool(output, &tools),
            Err(ToolCallError::NoToolCallFound)
        ));
    }

    #[test]
    fn select_tool_unknown_name_returns_error() {
        let output = r#"<tool_call>{"name":"unknown_fn","arguments":{}}</tool_call>"#;
        let tools = vec![weather_tool()];
        assert!(matches!(
            select_tool(output, &tools),
            Err(ToolCallError::UnknownTool { .. })
        ));
    }

    #[test]
    fn select_tool_empty_tools_skips_name_check() {
        let output = r#"<tool_call>{"name":"any_function","arguments":{}}</tool_call>"#;
        let tc = select_tool(output, &[]).expect("should accept any tool");
        assert_eq!(tc.function.name, "any_function");
    }

    #[test]
    fn validate_tool_args_all_required_present() {
        let tool = weather_tool();
        let args = r#"{"location":"Berlin","unit":"celsius"}"#;
        assert!(validate_tool_arguments(args, &tool).is_ok());
    }

    #[test]
    fn validate_tool_args_returns_parsed_object() {
        let tool = weather_tool();
        let parsed =
            validate_tool_arguments(r#"{"location":"Oslo"}"#, &tool).expect("valid arguments");
        assert_eq!(parsed, json!({"location": "Oslo"}));
    }

    #[test]
    fn validate_tool_args_missing_required_returns_error() {
        let tool = weather_tool();
        let args = r#"{"unit":"fahrenheit"}"#;
        assert!(matches!(
            validate_tool_arguments(args, &tool),
            Err(ToolCallError::MalformedArguments { .. })
        ));
    }

    #[test]
    fn validate_tool_args_invalid_json_returns_error() {
        let tool = weather_tool();
        assert!(matches!(
            validate_tool_arguments("{bad json}", &tool),
            Err(ToolCallError::MalformedArguments { .. })
        ));
    }

    #[test]
    fn validate_tool_args_non_object_returns_error() {
        let tool = weather_tool();
        // Valid JSON that is not an object must still be rejected.
        for args in [r#"["Berlin"]"#, r#""Berlin""#, "42", "null"] {
            assert!(
                matches!(
                    validate_tool_arguments(args, &tool),
                    Err(ToolCallError::MalformedArguments { .. })
                ),
                "args={args}"
            );
        }
    }

    #[test]
    fn build_tool_constraint_empty_tools_returns_error() {
        assert!(matches!(
            build_tool_constraint(&[]),
            Err(ToolCallError::EmptyToolList)
        ));
    }

    #[test]
    fn build_tool_constraint_single_tool_returns_grammar() {
        let tools = vec![weather_tool()];
        let g = build_tool_constraint(&tools).expect("should build grammar");
        assert!(!g.rules.is_empty(), "grammar must have rules");
    }

    #[test]
    fn build_tool_constraint_multi_tool_root_has_one_rule_per_tool() {
        let tools = vec![weather_tool(), calc_tool()];
        let g = build_tool_constraint(&tools).expect("should build grammar");
        let root_rules: Vec<_> = g.rules.iter().filter(|r| r.lhs == g.start).collect();
        assert_eq!(root_rules.len(), 2, "one rule per tool in root NT");
    }

    #[test]
    fn tool_registry_lookup_by_name() {
        let tools = vec![weather_tool(), calc_tool()];
        let reg = ToolRegistry::new(&tools);
        assert!(reg.get("get_weather").is_some());
        assert!(reg.get("calculate").is_some());
        assert!(reg.get("missing").is_none());
    }

    #[test]
    fn tool_registry_names_lists_every_tool() {
        let tools = vec![weather_tool(), calc_tool()];
        let reg = ToolRegistry::new(&tools);
        // HashMap order is unspecified, so sort before comparing.
        let mut names: Vec<&str> = reg.names().collect();
        names.sort_unstable();
        assert_eq!(names, ["calculate", "get_weather"]);
    }

    #[test]
    fn tool_registry_len_and_is_empty() {
        let tools = vec![weather_tool()];
        let reg = ToolRegistry::new(&tools);
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        let empty: Vec<ToolDefinition> = vec![];
        let er = ToolRegistry::new(&empty);
        assert!(er.is_empty());
    }

    #[test]
    fn tool_call_error_display_not_empty() {
        let errors = [
            ToolCallError::NoToolCallFound,
            ToolCallError::UnknownTool { name: "foo".into() },
            ToolCallError::MalformedArguments {
                reason: "bad".into(),
            },
            ToolCallError::GrammarCompileError {
                reason: "oops".into(),
            },
            ToolCallError::EmptyToolList,
        ];
        for e in &errors {
            assert!(!e.to_string().is_empty(), "error {e:?} has empty Display");
        }
    }

    #[test]
    fn tool_call_error_display_messages() {
        assert_eq!(
            ToolCallError::NoToolCallFound.to_string(),
            "no tool call found in model output"
        );
        assert_eq!(
            ToolCallError::UnknownTool { name: "foo".into() }.to_string(),
            "unknown tool: 'foo'"
        );
        assert_eq!(
            ToolCallError::MalformedArguments {
                reason: "bad".into()
            }
            .to_string(),
            "malformed tool arguments: bad"
        );
        assert_eq!(
            ToolCallError::GrammarCompileError {
                reason: "oops".into()
            }
            .to_string(),
            "grammar compile error: oops"
        );
        assert_eq!(ToolCallError::EmptyToolList.to_string(), "tool list is empty");
    }
}