adk_agent/
tool_call_markup.rs1use adk_core::{Content, Part};
19
20pub fn normalize_content(content: &mut Content) {
22 let parts = std::mem::take(&mut content.parts);
23 let mut normalized = Vec::new();
24
25 for part in parts {
26 match part {
27 Part::Text { text } => {
28 normalized.extend(convert_text_to_parts(text));
29 }
30 other => normalized.push(other),
31 }
32 }
33
34 content.parts = normalized;
35}
36
37pub fn normalize_option_content(content: &mut Option<Content>) {
39 if let Some(content) = content {
40 normalize_content(content);
41 }
42}
43
44fn convert_text_to_parts(text: String) -> Vec<Part> {
46 const TOOL_CALL_START: &str = "<tool_call>";
47 const TOOL_CALL_END: &str = "</tool_call>";
48
49 if !text.contains(TOOL_CALL_START) {
50 return vec![Part::Text { text }];
51 }
52
53 let mut parts = Vec::new();
54 let mut remainder = text.as_str();
55
56 while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57 let (before, after_start_tag) = remainder.split_at(start_idx);
58 if !before.is_empty() {
59 parts.push(Part::Text { text: before.to_string() });
60 }
61
62 let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63 if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64 let block = &after_start[..end_idx];
65 if let Some(call_part) = parse_tool_call_block(block) {
66 parts.push(call_part);
67 } else {
68 parts.push(Part::Text {
70 text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71 });
72 }
73 remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74 } else {
75 parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77 remainder = "";
78 break;
79 }
80 }
81
82 if !remainder.is_empty() {
83 parts.push(Part::Text { text: remainder.to_string() });
84 }
85
86 if parts.is_empty() { vec![Part::Text { text }] } else { parts }
87}
88
89fn parse_tool_call_block(block: &str) -> Option<Part> {
91 let trimmed = block.trim();
92 if trimmed.is_empty() {
93 return None;
94 }
95
96 let mut lines = trimmed.lines();
97 let name_line = lines.next()?.trim();
98 if name_line.is_empty() {
99 return None;
100 }
101
102 let remainder = lines.collect::<Vec<_>>().join("\n");
103 let mut slice = remainder.as_str();
104 let mut args_map = serde_json::Map::new();
105 let mut found_arg = false;
106
107 loop {
108 slice = slice.trim_start();
109 if slice.is_empty() {
110 break;
111 }
112
113 let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
114 rest
115 } else {
116 break;
117 };
118
119 let key_end = rest.find("</arg_key>")?;
120 let key = rest[..key_end].trim().to_string();
121 let mut after_key = &rest[key_end + "</arg_key>".len()..];
122
123 after_key = after_key.trim_start();
124 let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
125 rest
126 } else {
127 break;
128 };
129
130 let value_end = rest.find("</arg_value>")?;
131 let value_text = rest[..value_end].trim();
132 let value = parse_arg_value(value_text);
133 args_map.insert(key, value);
134 slice = &rest[value_end + "</arg_value>".len()..];
135 found_arg = true;
136 }
137
138 if !found_arg {
139 return None;
140 }
141
142 Some(Part::FunctionCall {
143 name: name_line.to_string(),
144 args: serde_json::Value::Object(args_map),
145 id: None,
146 })
147}
148
149fn parse_arg_value(raw: &str) -> serde_json::Value {
151 let trimmed = raw.trim();
152 if trimmed.is_empty() {
153 return serde_json::Value::String(String::new());
154 }
155
156 serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Concatenates the text of every `Part::Text` in `parts`, in order.
    fn collect_text(parts: &[Part]) -> String {
        parts
            .iter()
            .filter_map(|part| match part {
                Part::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts("Hello world".to_string());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_simple_tool_call() {
        let text = r#"<tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Tokyo</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "get_weather");
            assert_eq!(args["city"], "Tokyo");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let text = r#"Let me check the weather. <tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>
</tool_call> Done!"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "get_weather"));
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let text = r#"<tool_call>
calculator
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>5</arg_value>
<arg_key>b</arg_key>
<arg_value>3</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "calculator");
            assert_eq!(args["operation"], "add");
            // Numeric values are decoded as JSON numbers, not strings.
            assert_eq!(args["a"], 5);
            assert_eq!(args["b"], 3);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_json_arg_value() {
        let text = r#"<tool_call>
process
<arg_key>config</arg_key>
<arg_value>{"enabled": true, "count": 42}</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert!(args["config"]["enabled"].as_bool().unwrap());
            assert_eq!(args["config"]["count"], 42);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_malformed_block_is_kept_as_text() {
        let text = "Before <tool_call>not a tool call</tool_call> after".to_string();
        let parts = convert_text_to_parts(text.clone());
        assert!(parts.iter().all(|part| matches!(part, Part::Text { .. })));
        assert_eq!(collect_text(&parts), text);
    }

    #[test]
    fn test_unterminated_block_is_kept_as_text() {
        let text = "Before <tool_call>\nget_weather\n<arg_key>city</arg_key>".to_string();
        let parts = convert_text_to_parts(text.clone());
        assert!(parts.iter().all(|part| matches!(part, Part::Text { .. })));
        assert_eq!(collect_text(&parts), text);
    }

    #[test]
    fn test_normalize_content() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text {
                text: r#"<tool_call>
test_tool
<arg_key>param</arg_key>
<arg_value>value</arg_value>
</tool_call>"#
                    .to_string(),
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert!(
            matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "test_tool")
        );
    }

    #[test]
    fn test_normalize_content_keeps_non_text_parts() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::FunctionCall {
                name: "existing".to_string(),
                args: serde_json::Value::Object(serde_json::Map::new()),
                id: None,
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert!(
            matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "existing")
        );
    }

    #[test]
    fn test_normalize_option_content_none() {
        let mut content: Option<Content> = None;
        normalize_option_content(&mut content);
        assert!(content.is_none());
    }
}