adk_agent/
tool_call_markup.rs1use adk_core::{Content, Part};
19
20pub fn normalize_content(content: &mut Content) {
22 let parts = std::mem::take(&mut content.parts);
23 let mut normalized = Vec::new();
24
25 for part in parts {
26 match part {
27 Part::Text { text } => {
28 normalized.extend(convert_text_to_parts(text));
29 }
30 other => normalized.push(other),
31 }
32 }
33
34 content.parts = normalized;
35}
36
37pub fn normalize_option_content(content: &mut Option<Content>) {
39 if let Some(content) = content {
40 normalize_content(content);
41 }
42}
43
44fn convert_text_to_parts(text: String) -> Vec<Part> {
46 const TOOL_CALL_START: &str = "<tool_call>";
47 const TOOL_CALL_END: &str = "</tool_call>";
48
49 if !text.contains(TOOL_CALL_START) {
50 return vec![Part::Text { text }];
51 }
52
53 let mut parts = Vec::new();
54 let mut remainder = text.as_str();
55
56 while let Some(start_idx) = remainder.find(TOOL_CALL_START) {
57 let (before, after_start_tag) = remainder.split_at(start_idx);
58 if !before.is_empty() {
59 parts.push(Part::Text { text: before.to_string() });
60 }
61
62 let after_start = &after_start_tag[TOOL_CALL_START.len()..];
63 if let Some(end_idx) = after_start.find(TOOL_CALL_END) {
64 let block = &after_start[..end_idx];
65 if let Some(call_part) = parse_tool_call_block(block) {
66 parts.push(call_part);
67 } else {
68 parts.push(Part::Text {
70 text: format!("{}{}{}", TOOL_CALL_START, block, TOOL_CALL_END),
71 });
72 }
73 remainder = &after_start[end_idx + TOOL_CALL_END.len()..];
74 } else {
75 parts.push(Part::Text { text: format!("{}{}", TOOL_CALL_START, after_start) });
77 remainder = "";
78 break;
79 }
80 }
81
82 if !remainder.is_empty() {
83 parts.push(Part::Text { text: remainder.to_string() });
84 }
85
86 if parts.is_empty() { vec![Part::Text { text }] } else { parts }
87}
88
89fn parse_tool_call_block(block: &str) -> Option<Part> {
91 let trimmed = block.trim();
92 if trimmed.is_empty() {
93 return None;
94 }
95
96 let mut lines = trimmed.lines();
97 let name_line = lines.next()?.trim();
98 if name_line.is_empty() {
99 return None;
100 }
101
102 let remainder = lines.collect::<Vec<_>>().join("\n");
103 let mut slice = remainder.as_str();
104 let mut args_map = serde_json::Map::new();
105 let mut found_arg = false;
106
107 loop {
108 slice = slice.trim_start();
109 if slice.is_empty() {
110 break;
111 }
112
113 let rest = if let Some(rest) = slice.strip_prefix("<arg_key>") {
114 rest
115 } else {
116 break;
117 };
118
119 let key_end = rest.find("</arg_key>")?;
120 let key = rest[..key_end].trim().to_string();
121 let mut after_key = &rest[key_end + "</arg_key>".len()..];
122
123 after_key = after_key.trim_start();
124 let rest = if let Some(rest) = after_key.strip_prefix("<arg_value>") {
125 rest
126 } else {
127 break;
128 };
129
130 let value_end = rest.find("</arg_value>")?;
131 let value_text = rest[..value_end].trim();
132 let value = parse_arg_value(value_text);
133 args_map.insert(key, value);
134 slice = &rest[value_end + "</arg_value>".len()..];
135 found_arg = true;
136 }
137
138 if !found_arg {
139 return None;
140 }
141
142 Some(Part::FunctionCall {
143 name: name_line.to_string(),
144 args: serde_json::Value::Object(args_map),
145 id: None,
146 thought_signature: None,
147 })
148}
149
150fn parse_arg_value(raw: &str) -> serde_json::Value {
152 let trimmed = raw.trim();
153 if trimmed.is_empty() {
154 return serde_json::Value::String(String::new());
155 }
156
157 serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string()))
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_markup() {
        let parts = convert_text_to_parts("Hello world".to_string());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_simple_tool_call() {
        let text = r#"<tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Tokyo</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "get_weather");
            assert_eq!(args["city"], "Tokyo");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_tool_call_with_surrounding_text() {
        let text = r#"Let me check the weather. <tool_call>
get_weather
<arg_key>city</arg_key>
<arg_value>Paris</arg_value>
</tool_call> Done!"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 3);
        assert!(matches!(&parts[0], Part::Text { text } if text.contains("Let me check")));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "get_weather"));
        assert!(matches!(&parts[2], Part::Text { text } if text.contains("Done")));
    }

    #[test]
    fn test_multiple_args() {
        let text = r#"<tool_call>
calculator
<arg_key>operation</arg_key>
<arg_value>add</arg_value>
<arg_key>a</arg_key>
<arg_value>5</arg_value>
<arg_key>b</arg_key>
<arg_value>3</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { name, args, .. } = &parts[0] {
            assert_eq!(name, "calculator");
            assert_eq!(args["operation"], "add");
            // Numeric payloads are parsed as JSON numbers, not strings.
            assert_eq!(args["a"], 5);
            assert_eq!(args["b"], 3);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_json_arg_value() {
        let text = r#"<tool_call>
process
<arg_key>config</arg_key>
<arg_value>{"enabled": true, "count": 42}</arg_value>
</tool_call>"#
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert!(args["config"]["enabled"].as_bool().unwrap());
            assert_eq!(args["config"]["count"], 42);
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_empty_arg_value_is_empty_string() {
        let text =
            "<tool_call>\nping\n<arg_key>note</arg_key>\n<arg_value></arg_value>\n</tool_call>"
                .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 1);

        if let Part::FunctionCall { args, .. } = &parts[0] {
            assert_eq!(args["note"], "");
        } else {
            panic!("Expected FunctionCall");
        }
    }

    #[test]
    fn test_back_to_back_tool_calls() {
        let text = "<tool_call>\nfirst\n<arg_key>x</arg_key>\n<arg_value>1</arg_value>\n</tool_call>\
<tool_call>\nsecond\n<arg_key>y</arg_key>\n<arg_value>2</arg_value>\n</tool_call>"
            .to_string();

        let parts = convert_text_to_parts(text);
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], Part::FunctionCall { name, .. } if name == "first"));
        assert!(matches!(&parts[1], Part::FunctionCall { name, .. } if name == "second"));
    }

    #[test]
    fn test_unterminated_tool_call_kept_as_text() {
        let parts = convert_text_to_parts("before <tool_call>\nget_weather".to_string());
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], Part::Text { text } if text == "before "));
        assert!(matches!(&parts[1], Part::Text { text } if text == "<tool_call>\nget_weather"));
    }

    #[test]
    fn test_block_without_args_kept_verbatim() {
        let original = "<tool_call>\nnot a call\n</tool_call>".to_string();
        let parts = convert_text_to_parts(original.clone());
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], Part::Text { text } if text == &original));
    }

    #[test]
    fn test_normalize_content() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![Part::Text {
                text: r#"<tool_call>
test_tool
<arg_key>param</arg_key>
<arg_value>value</arg_value>
</tool_call>"#
                    .to_string(),
            }],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 1);
        assert!(
            matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "test_tool")
        );
    }

    #[test]
    fn test_normalize_content_preserves_other_parts() {
        let mut content = Content {
            role: "model".to_string(),
            parts: vec![
                Part::FunctionCall {
                    name: "existing".to_string(),
                    args: serde_json::Value::Object(serde_json::Map::new()),
                    id: None,
                    thought_signature: None,
                },
                Part::Text { text: "plain".to_string() },
            ],
        };

        normalize_content(&mut content);
        assert_eq!(content.parts.len(), 2);
        assert!(matches!(&content.parts[0], Part::FunctionCall { name, .. } if name == "existing"));
        assert!(matches!(&content.parts[1], Part::Text { text } if text == "plain"));
    }

    #[test]
    fn test_normalize_option_content_none() {
        let mut content: Option<Content> = None;
        normalize_option_content(&mut content);
        assert!(content.is_none());
    }
}