adk_agent/
tool_call_markup.rs1use adk_core::{Content, Part};
19
20pub fn normalize_content(content: &mut Content) {
22 let parts = std::mem::take(&mut content.parts);
23 let mut normalized = Vec::new();
24
25 for part in parts {
26 match part {
27 Part::Text { text } => {
28 normalized.extend(convert_text_to_parts(text));
29 }
30 other => normalized.push(other),
31 }
32 }
33
34 content.parts = normalized;
35}
36
37pub fn normalize_option_content(content: &mut Option<Content>) {
39 if let Some(ref mut content) = content {
40 normalize_content(content);
41 }
42}
43
44fn convert_text_to_parts(text: String) -> Vec<Part> {
46 const TOOL_CALL_START: &str = "<tool_call>";
47 const TOOL_CALL_END: &str = "</tool_call>";
48
49 if !text.contains(TOOL_CALL_START) {
50 return vec![Part::Text { text }];
51 }
52
53 let mut parts = Vec::new();
54 let mut remainder = text.as_str();
55
56 while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57 let (before, after_start_tag) = remainder.split_at(start_idx);
58 if !before.is_empty() {
59 parts.push(Part::Text { text: before.to_string() });
60 }
61
62 let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63 if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64 let block = &after_start[..end_idx];
65 if let Some(call_part) = parse_tool_call_block(block) {
66 parts.push(call_part);
67 } else {
68 parts.push(Part::Text {
70 text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71 });
72 }
73 remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74 } else {
75 parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77 remainder = "";
78 break;
79 }
80 }
81
82 if !remainder.is_empty() {
83 parts.push(Part::Text { text: remainder.to_string() });
84 }
85
86 if parts.is_empty() {
87 vec![Part::Text { text }]
88 } else {
89 parts
90 }
91}
92
93fn parse_tool_call_block(block: &str) -> Option<Part> {
95 let trimmed = block.trim();
96 if trimmed.is_empty() {
97 return None;
98 }
99
100 let mut lines = trimmed.lines();
101 let name_line = lines.next()?.trim();
102 if name_line.is_empty() {
103 return None;
104 }
105
106 let remainder = lines.collect::<Vec<_>>().join("\n");
107 let mut slice = remainder.as_str();
108 let mut args_map = serde_json::Map::new();
109 let mut found_arg = false;
110
111 loop {
112 slice = slice.trim_start();
113 if slice.is_empty() {
114 break;
115 }
116
117 let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
118 rest
119 } else {
120 break;
121 };
122
123 let key_end = rest.find("</arg_key>")?;
124 let key = rest[..key_end].trim().to_string();
125 let mut after_key = &rest[key_end + "</arg_key>".len()..];
126
127 after_key = after_key.trim_start();
128 let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
129 rest
130 } else {
131 break;
132 };
133
134 let value_end = rest.find("</arg_value>")?;
135 let value_text = rest[..value_end].trim();
136 let value = parse_arg_value(value_text);
137 args_map.insert(key, value);
138 slice = &rest[value_end + "</arg_value>".len()..];
139 found_arg = true;
140 }
141
142 if !found_arg {
143 return None;
144 }
145
146 Some(Part::FunctionCall {
147 name: name_line.to_string(),
148 args: serde_json::Value::Object(args_map),
149 id: None,
150 })
151}
152
153fn parse_arg_value(raw: &str) -> serde_json::Value {
155 let trimmed = raw.trim();
156 if trimmed.is_empty() {
157 return serde_json::Value::String(String::new());
158 }
159
160 serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the name and arguments of a `Part::FunctionCall`; panics on any other part.
    fn expect_call(part: &Part) -> (&str, &serde_json::Value) {
        match part {
            Part::FunctionCall { name, args, .. } => (name.as_str(), args),
            _ => panic!("Expected FunctionCall"),
        }
    }

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts(String::from("Hello world"));
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_simple_tool_call() {
        let input = String::from(
            r#"<tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Tokyo</arg_value>
</tool_call>"#,
        );

        let parts = convert_text_to_parts(input);
        assert_eq!(parts.len(), 1);

        let (name, args) = expect_call(&parts[0]);
        assert_eq!(name, "get_weather");
        assert_eq!(args["city"], "Tokyo");
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let input = String::from(
            r#"Let me check the weather. <tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>
</tool_call> Done!"#,
        );

        let parts = convert_text_to_parts(input);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert_eq!(expect_call(&parts[1]).0, "get_weather");
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let input = String::from(
            r#"<tool_call>
calculator
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>5</arg_value>
<arg_key>b</arg_key>
<arg_value>3</arg_value>
</tool_call>"#,
        );

        let parts = convert_text_to_parts(input);
        assert_eq!(parts.len(), 1);

        let (name, args) = expect_call(&parts[0]);
        assert_eq!(name, "calculator");
        assert_eq!(args["operation"], "add");
        // Numeric argument text is decoded as JSON numbers, not strings.
        assert_eq!(args["a"], 5);
        assert_eq!(args["b"], 3);
    }

    #[test]
    fn test_json_arg_value() {
        let input = String::from(
            r#"<tool_call>
process
<arg_key>config</arg_key>
<arg_value>{"enabled": true, "count": 42}</arg_value>
</tool_call>"#,
        );

        let parts = convert_text_to_parts(input);
        assert_eq!(parts.len(), 1);

        let (_, args) = expect_call(&parts[0]);
        assert!(args["config"]["enabled"].as_bool().unwrap());
        assert_eq!(args["config"]["count"], 42);
    }

    #[test]
    fn test_normalize_content() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text {
                text: r#"<tool_call>
test_tool
<arg_key>param</arg_key>
<arg_value>value</arg_value>
</tool_call>"#
                    .to_string(),
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert_eq!(expect_call(&content.parts[0]).0, "test_tool");
    }
}