adk_agent/
tool_call_markup.rs

//! XML-based tool call markup parsing.
//!
//! Some models (especially smaller ones or those without native function calling)
//! emit tool calls as XML-like markup inside their text output:
//!
//! ```text
//! <tool_call>
//! function_name
//! <arg_key>param1</arg_key>
//! <arg_value>value1</arg_value>
//! <arg_key>param2</arg_key>
//! <arg_value>value2</arg_value>
//! </tool_call>
//! ```
//!
//! This module rewrites such markup found in text parts into `Part::FunctionCall` parts;
//! blocks that cannot be parsed are left in place as plain text.
17
18use adk_core::{Content, Part};
19
20/// Normalize content by converting tool call markup in text parts to FunctionCall parts.
21pub fn normalize_content(content: &mut Content) {
22    let parts = std::mem::take(&mut content.parts);
23    let mut normalized = Vec::new();
24
25    for part in parts {
26        match part {
27            Part::Text { text } => {
28                normalized.extend(convert_text_to_parts(text));
29            }
30            other => normalized.push(other),
31        }
32    }
33
34    content.parts = normalized;
35}
36
37/// Normalize Option<Content> by converting tool call markup.
38pub fn normalize_option_content(content: &mut Option<Content>) {
39    if let Some(ref mut content) = content {
40        normalize_content(content);
41    }
42}
43
44/// Convert text containing tool call markup to a list of parts.
45fn convert_text_to_parts(text: String) -> Vec<Part> {
46    const TOOL_CALL_START: &str = "<tool_call>";
47    const TOOL_CALL_END: &str = "</tool_call>";
48
49    if !text.contains(TOOL_CALL_START) {
50        return vec![Part::Text { text }];
51    }
52
53    let mut parts = Vec::new();
54    let mut remainder = text.as_str();
55
56    while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57        let (before, after_start_tag) = remainder.split_at(start_idx);
58        if !before.is_empty() {
59            parts.push(Part::Text { text: before.to_string() });
60        }
61
62        let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63        if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64            let block = &after_start[..end_idx];
65            if let Some(call_part) = parse_tool_call_block(block) {
66                parts.push(call_part);
67            } else {
68                // Failed to parse - keep as text
69                parts.push(Part::Text {
70                    text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71                });
72            }
73            remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74        } else {
75            // Unclosed tag - keep remainder as text
76            parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77            remainder = "";
78            break;
79        }
80    }
81
82    if !remainder.is_empty() {
83        parts.push(Part::Text { text: remainder.to_string() });
84    }
85
86    if parts.is_empty() {
87        vec![Part::Text { text }]
88    } else {
89        parts
90    }
91}
92
93/// Parse a tool call block into a FunctionCall part.
94fn parse_tool_call_block(block: &str) -> Option<Part> {
95    let trimmed = block.trim();
96    if trimmed.is_empty() {
97        return None;
98    }
99
100    let mut lines = trimmed.lines();
101    let name_line = lines.next()?.trim();
102    if name_line.is_empty() {
103        return None;
104    }
105
106    let remainder = lines.collect::<Vec<_>>().join("\n");
107    let mut slice = remainder.as_str();
108    let mut args_map = serde_json::Map::new();
109    let mut found_arg = false;
110
111    loop {
112        slice = slice.trim_start();
113        if slice.is_empty() {
114            break;
115        }
116
117        let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
118            rest
119        } else {
120            break;
121        };
122
123        let key_end = rest.find("</arg_key>")?;
124        let key = rest[..key_end].trim().to_string();
125        let mut after_key = &rest[key_end + "</arg_key>".len()..];
126
127        after_key = after_key.trim_start();
128        let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
129            rest
130        } else {
131            break;
132        };
133
134        let value_end = rest.find("</arg_value>")?;
135        let value_text = rest[..value_end].trim();
136        let value = parse_arg_value(value_text);
137        args_map.insert(key, value);
138        slice = &rest[value_end + "</arg_value>".len()..];
139        found_arg = true;
140    }
141
142    if !found_arg {
143        return None;
144    }
145
146    Some(Part::FunctionCall {
147        name: name_line.to_string(),
148        args: serde_json::Value::Object(args_map),
149        id: None,
150    })
151}
152
153/// Parse an argument value, attempting JSON parsing first.
154fn parse_arg_value(raw: &str) -> serde_json::Value {
155    let trimmed = raw.trim();
156    if trimmed.is_empty() {
157        return serde_json::Value::String(String::new());
158    }
159
160    serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts("Hello world".to_string());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_simple_tool_call() {
        let text = r#"<tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Tokyo</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "get_weather");
            assert_eq!(args["city"], "Tokyo");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let text = r#"Let me check the weather. <tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>
</tool_call> Done!"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "get_weather"));
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let text = r#"<tool_call>
calculator
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>5</arg_value>
<arg_key>b</arg_key>
<arg_value>3</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "calculator");
            assert_eq!(args["operation"], "add");
            // Bare numerals are valid JSON, so they arrive as numbers rather than strings.
            assert_eq!(args["a"], 5);
            assert_eq!(args["b"], 3);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_json_arg_value() {
        let text = r#"<tool_call>
process
<arg_key>config</arg_key>
<arg_value>{"enabled": true, "count": 42}</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert!(args["config"]["enabled"].as_bool().unwrap());
            assert_eq!(args["config"]["count"], 42);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_blank_arg_value_becomes_empty_string() {
        let text =
            "<tool_call>\nlog_event\n<arg_key>note</arg_key>\n<arg_value>   </arg_value>\n</tool_call>"
                .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "log_event");
            assert_eq!(args["note"], "");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_multiple_tool_calls() {
        let text = "<tool_call>\nfirst\n<arg_key>x</arg_key>\n<arg_value>1</arg_value>\n</tool_call>\n\
                    <tool_call>\nsecond\n<arg_key>y</arg_key>\n<arg_value>2</arg_value>\n</tool_call>"
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::FunctionCall { name, .. } if name == "first"));
        assert!(matches!(&parts[1], Part::Text { text } if text == "\n"));
        assert!(matches!(&parts[2], Part::FunctionCall { name, .. } if name == "second"));
    }

    #[test]
    fn test_unclosed_tool_call_is_kept_as_text() {
        let text = "Before <tool_call>\nget_weather\n<arg_key>city</arg_key>".to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Before "));
        assert!(matches!(
            &parts[1],
            Part::Text { text } if text == "<tool_call>\nget_weather\n<arg_key>city</arg_key>"
        ));
    }

    #[test]
    fn test_malformed_tool_call_is_kept_verbatim() {
        // `<arg_key>` is never closed, so the block cannot be parsed.
        let text =
            "<tool_call>\nget_weather\n<arg_key>city\n<arg_value>Tokyo</arg_value>\n</tool_call>"
                .to_string();

        let parts = convert_text_to_parts(text.clone());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text: kept } if *kept == text));
    }

    #[test]
    fn test_normalize_content() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text {
                text: r#"<tool_call>
test_tool
<arg_key>param</arg_key>
<arg_value>value</arg_value>
</tool_call>"#
                    .to_string(),
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert!(matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "test_tool"));
    }
}