adk_agent/
tool_call_markup.rs

1//! XML-based tool call markup parsing.
2//!
3//! Some models (especially smaller ones or those without native function calling)
4//! output tool calls using XML-like markup:
5//!
6//! ```text
7//! <tool_call>
8//! function_name
9//! <arg_key>param1</arg_key>
10//! <arg_value>value1</arg_value>
11//! <arg_key>param2</arg_key>
12//! <arg_value>value2</arg_value>
13//! </tool_call>
14//! ```
15//!
16//! This module provides utilities to parse such markup into proper `Part::FunctionCall`.
17
18use adk_core::{Content, Part};
19
20/// Normalize content by converting tool call markup in text parts to FunctionCall parts.
21pub fn normalize_content(content: &mut Content) {
22    let parts = std::mem::take(&mut content.parts);
23    let mut normalized = Vec::new();
24
25    for part in parts {
26        match part {
27            Part::Text { text } => {
28                normalized.extend(convert_text_to_parts(text));
29            }
30            other => normalized.push(other),
31        }
32    }
33
34    content.parts = normalized;
35}
36
37/// Normalize `Option<Content>` by converting tool call markup.
38pub fn normalize_option_content(content: &mut Option<Content>) {
39    if let Some(content) = content {
40        normalize_content(content);
41    }
42}
43
44/// Convert text containing tool call markup to a list of parts.
45fn convert_text_to_parts(text: String) -> Vec<Part> {
46    const TOOL_CALL_START: &str = "<tool_call>";
47    const TOOL_CALL_END: &str = "</tool_call>";
48
49    if !text.contains(TOOL_CALL_START) {
50        return vec![Part::Text { text }];
51    }
52
53    let mut parts = Vec::new();
54    let mut remainder = text.as_str();
55
56    while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57        let (before, after_start_tag) = remainder.split_at(start_idx);
58        if !before.is_empty() {
59            parts.push(Part::Text { text: before.to_string() });
60        }
61
62        let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63        if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64            let block = &after_start[..end_idx];
65            if let Some(call_part) = parse_tool_call_block(block) {
66                parts.push(call_part);
67            } else {
68                // Failed to parse - keep as text
69                parts.push(Part::Text {
70                    text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71                });
72            }
73            remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74        } else {
75            // Unclosed tag - keep remainder as text
76            parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77            remainder = "";
78            break;
79        }
80    }
81
82    if !remainder.is_empty() {
83        parts.push(Part::Text { text: remainder.to_string() });
84    }
85
86    if parts.is_empty() { vec![Part::Text { text }] } else { parts }
87}
88
89/// Parse a tool call block into a FunctionCall part.
90fn parse_tool_call_block(block: &str) -> Option<Part> {
91    let trimmed = block.trim();
92    if trimmed.is_empty() {
93        return None;
94    }
95
96    let mut lines = trimmed.lines();
97    let name_line = lines.next()?.trim();
98    if name_line.is_empty() {
99        return None;
100    }
101
102    let remainder = lines.collect::<Vec<_>>().join("\n");
103    let mut slice = remainder.as_str();
104    let mut args_map = serde_json::Map::new();
105    let mut found_arg = false;
106
107    loop {
108        slice = slice.trim_start();
109        if slice.is_empty() {
110            break;
111        }
112
113        let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
114            rest
115        } else {
116            break;
117        };
118
119        let key_end = rest.find("</arg_key>")?;
120        let key = rest[..key_end].trim().to_string();
121        let mut after_key = &rest[key_end + "</arg_key>".len()..];
122
123        after_key = after_key.trim_start();
124        let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
125            rest
126        } else {
127            break;
128        };
129
130        let value_end = rest.find("</arg_value>")?;
131        let value_text = rest[..value_end].trim();
132        let value = parse_arg_value(value_text);
133        args_map.insert(key, value);
134        slice = &rest[value_end + "</arg_value>".len()..];
135        found_arg = true;
136    }
137
138    if !found_arg {
139        return None;
140    }
141
142    Some(Part::FunctionCall {
143        name: name_line.to_string(),
144        args: serde_json::Value::Object(args_map),
145        id: None,
146    })
147}
148
149/// Parse an argument value, attempting JSON parsing first.
150fn parse_arg_value(raw: &str) -> serde_json::Value {
151    let trimmed = raw.trim();
152    if trimmed.is_empty() {
153        return serde_json::Value::String(String::new());
154    }
155
156    serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwrap a `Part::FunctionCall` into its name and arguments, failing the test for
    /// any other variant.
    fn expect_call(part: &Part) -> (&str, &serde_json::Value) {
        if let Part::FunctionCall { name, args, .. } = part {
            (name.as_str(), args)
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts(String::from("Hello world"));
        match parts.as_slice() {
            [Part::Text { text }] => assert_eq!(text, "Hello world"),
            _ => panic!("expected exactly one text part"),
        }
    }

    #[test]
    fn test_simple_tool_call() {
        let text = concat!(
            "<tool_call>\n",
            "get_weather\n",
            "<arg_key>city</arg_key>\n",
            "<arg_value>Tokyo</arg_value>\n",
            "</tool_call>",
        );

        let parts = convert_text_to_parts(text.to_string());
        assert_eq!(parts.len(), 1);

        let (name, args) = expect_call(&parts[0]);
        assert_eq!(name, "get_weather");
        assert_eq!(args["city"], "Tokyo");
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let text = concat!(
            "Let me check the weather. <tool_call>\n",
            "get_weather\n",
            "<arg_key>city</arg_key>\n",
            "<arg_value>Paris</arg_value>\n",
            "</tool_call> Done!",
        );

        let parts = convert_text_to_parts(text.to_string());
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert_eq!(expect_call(&parts[1]).0, "get_weather");
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let text = concat!(
            "<tool_call>\n",
            "calculator\n",
            "<arg_key>operation</arg_key>\n",
            "<arg_value>add</arg_value>\n",
            "<arg_key>a</arg_key>\n",
            "<arg_value>5</arg_value>\n",
            "<arg_key>b</arg_key>\n",
            "<arg_value>3</arg_value>\n",
            "</tool_call>",
        );

        let parts = convert_text_to_parts(text.to_string());
        assert_eq!(parts.len(), 1);

        let (name, args) = expect_call(&parts[0]);
        assert_eq!(name, "calculator");
        assert_eq!(args["operation"], "add");
        // `5` and `3` are valid JSON, so they are decoded as numbers rather than strings.
        assert_eq!(args["a"], 5);
        assert_eq!(args["b"], 3);
    }

    #[test]
    fn test_json_arg_value() {
        let text = concat!(
            "<tool_call>\n",
            "process\n",
            "<arg_key>config</arg_key>\n",
            "<arg_value>{\"enabled\": true, \"count\": 42}</arg_value>\n",
            "</tool_call>",
        );

        let parts = convert_text_to_parts(text.to_string());
        assert_eq!(parts.len(), 1);

        let (_, args) = expect_call(&parts[0]);
        let config = &args["config"];
        assert!(config["enabled"].as_bool().unwrap());
        assert_eq!(config["count"], 42);
    }

    #[test]
    fn test_normalize_content() {
        let markup = concat!(
            "<tool_call>\n",
            "test_tool\n",
            "<arg_key>param</arg_key>\n",
            "<arg_value>value</arg_value>\n",
            "</tool_call>",
        );
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text { text: markup.to_string() }],
        };

        normalize_content(&mut content);

        assert_eq!(content.parts.len(), 1);
        assert_eq!(expect_call(&content.parts[0]).0, "test_tool");
    }
}