Skip to main content

adk_agent/
tool_call_markup.rs

1//! XML-based tool call markup parsing.
2//!
3//! Some models (especially smaller ones or those without native function calling)
4//! output tool calls using XML-like markup:
5//!
6//! ```text
7//! <tool_call>
8//! function_name
9//! <arg_key>param1</arg_key>
10//! <arg_value>value1</arg_value>
11//! <arg_key>param2</arg_key>
12//! <arg_value>value2</arg_value>
13//! </tool_call>
14//! ```
15//!
//! This module converts such markup into structured `Part::FunctionCall` parts; blocks that cannot be parsed are left in place as plain text.
17
18use adk_core::{Content, Part};
19
20/// Normalize content by converting tool call markup in text parts to FunctionCall parts.
21pub fn normalize_content(content: &mut Content) {
22    let parts = std::mem::take(&mut content.parts);
23    let mut normalized = Vec::new();
24
25    for part in parts {
26        match part {
27            Part::Text { text } => {
28                normalized.extend(convert_text_to_parts(text));
29            }
30            other => normalized.push(other),
31        }
32    }
33
34    content.parts = normalized;
35}
36
37/// Normalize `Option<Content>` by converting tool call markup.
38pub fn normalize_option_content(content: &mut Option<Content>) {
39    if let Some(content) = content {
40        normalize_content(content);
41    }
42}
43
44/// Convert text containing tool call markup to a list of parts.
45fn convert_text_to_parts(text: String) -> Vec<Part> {
46    const TOOL_CALL_START: &str = "<tool_call>";
47    const TOOL_CALL_END: &str = "</tool_call>";
48
49    if !text.contains(TOOL_CALL_START) {
50        return vec![Part::Text { text }];
51    }
52
53    let mut parts = Vec::new();
54    let mut remainder = text.as_str();
55
56    while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57        let (before, after_start_tag) = remainder.split_at(start_idx);
58        if !before.is_empty() {
59            parts.push(Part::Text { text: before.to_string() });
60        }
61
62        let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63        if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64            let block = &after_start[..end_idx];
65            if let Some(call_part) = parse_tool_call_block(block) {
66                parts.push(call_part);
67            } else {
68                // Failed to parse - keep as text
69                parts.push(Part::Text {
70                    text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71                });
72            }
73            remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74        } else {
75            // Unclosed tag - keep remainder as text
76            parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77            remainder = "";
78            break;
79        }
80    }
81
82    if !remainder.is_empty() {
83        parts.push(Part::Text { text: remainder.to_string() });
84    }
85
86    if parts.is_empty() { vec![Part::Text { text }] } else { parts }
87}
88
89/// Parse a tool call block into a FunctionCall part.
90fn parse_tool_call_block(block: &str) -> Option<Part> {
91    let trimmed = block.trim();
92    if trimmed.is_empty() {
93        return None;
94    }
95
96    let mut lines = trimmed.lines();
97    let name_line = lines.next()?.trim();
98    if name_line.is_empty() {
99        return None;
100    }
101
102    let remainder = lines.collect::<Vec<_>>().join("\n");
103    let mut slice = remainder.as_str();
104    let mut args_map = serde_json::Map::new();
105    let mut found_arg = false;
106
107    loop {
108        slice = slice.trim_start();
109        if slice.is_empty() {
110            break;
111        }
112
113        let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
114            rest
115        } else {
116            break;
117        };
118
119        let key_end = rest.find("</arg_key>")?;
120        let key = rest[..key_end].trim().to_string();
121        let mut after_key = &rest[key_end + "</arg_key>".len()..];
122
123        after_key = after_key.trim_start();
124        let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
125            rest
126        } else {
127            break;
128        };
129
130        let value_end = rest.find("</arg_value>")?;
131        let value_text = rest[..value_end].trim();
132        let value = parse_arg_value(value_text);
133        args_map.insert(key, value);
134        slice = &rest[value_end + "</arg_value>".len()..];
135        found_arg = true;
136    }
137
138    if !found_arg {
139        return None;
140    }
141
142    Some(Part::FunctionCall {
143        name: name_line.to_string(),
144        args: serde_json::Value::Object(args_map),
145        id: None,
146        thought_signature: None,
147    })
148}
149
150/// Parse an argument value, attempting JSON parsing first.
151fn parse_arg_value(raw: &str) -> serde_json::Value {
152    let trimmed = raw.trim();
153    if trimmed.is_empty() {
154        return serde_json::Value::String(String::new());
155    }
156
157    serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts("Hello world".to_string());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_simple_tool_call() {
        let text = r#"<tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Tokyo</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "get_weather");
            assert_eq!(args["city"], "Tokyo");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let text = r#"Let me check the weather. <tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>
</tool_call> Done!"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "get_weather"));
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let text = r#"<tool_call>
calculator
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>5</arg_value>
<arg_key>b</arg_key>
<arg_value>3</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "calculator");
            // `add` is not valid JSON, so it stays a string...
            assert_eq!(args["operation"], "add");
            // ...while `5` and `3` are valid JSON and decode to numbers.
            assert_eq!(args["a"], 5);
            assert_eq!(args["b"], 3);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_json_arg_value() {
        let text = r#"<tool_call>
process
<arg_key>config</arg_key>
<arg_value>{"enabled": true, "count": 42}</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert!(args["config"]["enabled"].as_bool().unwrap());
            assert_eq!(args["config"]["count"], 42);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_empty_arg_value_is_empty_string() {
        let text = "<tool_call>\nsearch\n<arg_key>query</arg_key>\n<arg_value></arg_value>\n</tool_call>"
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert_eq!(args["query"], "");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_multiple_tool_calls() {
        let text = "<tool_call>\nfirst\n<arg_key>x</arg_key>\n<arg_value>1</arg_value>\n</tool_call>\
<tool_call>\nsecond\n<arg_key>y</arg_key>\n<arg_value>2</arg_value>\n</tool_call>"
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], Part::FunctionCall { name, .. } if name == "first"));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "second"));
    }

    #[test]
    fn test_unclosed_tool_call_kept_as_text() {
        let parts = convert_text_to_parts("Hi <tool_call>\nget_weather".to_string());
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hi "));
        assert!(matches!(&parts[1], Part::Text { text } if text == "<tool_call>\nget_weather"));
    }

    #[test]
    fn test_block_without_args_kept_as_text() {
        let original = "<tool_call>\nget_time\n</tool_call>".to_string();
        let parts = convert_text_to_parts(original.clone());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == &original));
    }

    #[test]
    fn test_unclosed_arg_key_kept_as_text() {
        let original = "<tool_call>\nget_weather\n<arg_key>city\n</tool_call>".to_string();
        let parts = convert_text_to_parts(original.clone());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == &original));
    }

    #[test]
    fn test_normalize_content() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text {
                text: r#"<tool_call>
test_tool
<arg_key>param</arg_key>
<arg_value>value</arg_value>
</tool_call>"#
                    .to_string(),
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert!(
            matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "test_tool")
        );
    }

    #[test]
    fn test_normalize_option_content() {
        let mut none: Option<Content> = None;
        normalize_option_content(&mut none);
        assert!(none.is_none());

        let mut some = Some(Content {
            role: "model".to_string(),
            parts: vec![Part::Text { text: "plain".to_string() }],
        });
        normalize_option_content(&mut some);
        let content = some.expect("content should still be present");
        assert_eq!(content.parts.len(), 1);
        assert!(matches!(&content.parts[0], Part::Text { text } if text == "plain"));
    }
}