1use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
5use regex::Regex;
6use rustpython_parser::{
7 Mode,
8 ast::{Constant, Expr, Mod},
9 parse,
10};
11use serde_json::{Number, Value, json};
12use std::sync::OnceLock;
13
14static PYTHONIC_REGEX: OnceLock<Regex> = OnceLock::new();
15
16fn get_pythonic_regex() -> &'static Regex {
19 PYTHONIC_REGEX.get_or_init(|| {
20 let pattern = r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s?)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s*)?\)\s*)+\]";
22 Regex::new(pattern).expect("Failed to compile pythonic regex pattern")
23 })
24}
25
/// Removes every `<|python_start|>` and `<|python_end|>` marker from `message`.
///
/// All other text, including surrounding whitespace, is returned unchanged.
fn strip_text(message: &str) -> String {
    let without_start = message.replace("<|python_start|>", "");
    without_start.replace("<|python_end|>", "")
}
32
33fn get_regex_matches(message: &str) -> Vec<String> {
34 let re = get_pythonic_regex();
35 let mut matches = Vec::new();
36 for cap in re.find_iter(message) {
37 matches.push(cap.as_str().to_string());
38 }
39 matches
40}
41
42pub fn parse_tool_calls(src: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
43 let ast = parse(src, Mode::Expression, "<input>")?;
44
45 let body = match ast {
56 Mod::Expression(mod_expr) => mod_expr.body,
57 _ => return Ok(vec![]),
58 };
59
60 let elts = match *body {
61 Expr::List(expr_list) => expr_list.elts,
62 _ => return Ok(vec![]),
63 };
64
65 let mut res = Vec::with_capacity(elts.len());
66 for (idx, elt) in elts.iter().enumerate() {
67 let (func, keywords) = match elt {
68 Expr::Call(call) => (&call.func, &call.keywords),
69 _ => continue,
70 };
71
72 let name = match func.as_ref() {
73 Expr::Name(name) => name.id.clone(),
74 _ => continue,
75 };
76
77 let mut obj = serde_json::Map::new();
78 for keyword in keywords.iter() {
79 let Some(arg_ident) = keyword.arg.as_ref() else {
80 tracing::debug!(
81 "Skipping **kwargs in pythonic tool call for function {}",
82 name
83 );
84 continue;
85 };
86
87 match const_expr(&keyword.value) {
88 Ok(value) => {
89 obj.insert(arg_ident.to_string(), value);
90 }
91 Err(e) => {
92 tracing::debug!("Skipping non-constant argument {}: {}", arg_ident, e);
93 }
94 }
95 }
96
97 res.push(ToolCallResponse {
98 id: format!("call-{}", idx + 1),
99 tp: ToolCallType::Function,
100 function: CalledFunction {
101 name: name.to_string(),
102 arguments: serde_json::to_string(&Value::Object(obj))?,
104 },
105 });
106 }
107 Ok(res)
108}
109
110fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
111 match e {
112 Expr::Constant(constant) => Ok(match &constant.value {
113 Constant::Bool(b) => json!(b),
114 Constant::None => Value::Null,
115 Constant::Int(i) => {
116 use num_traits::ToPrimitive;
118 if let Some(v) = i.to_i64() {
119 Value::Number(Number::from(v))
120 } else if let Some(v) = i.to_u64() {
121 Value::Number(Number::from(v))
122 } else {
123 Value::String(i.to_string())
124 }
125 }
126 Constant::Float(f) => json!(f),
127 Constant::Str(s) => json!(s),
128 _ => return Err("unsupported constant type".into()),
129 }),
130 Expr::List(expr_list) => {
132 let list_values: Result<Vec<Value>, Box<dyn std::error::Error>> =
133 expr_list.elts.iter().map(|e| const_expr(e)).collect();
134 Ok(json!(list_values?))
135 }
136 Expr::Dict(expr_dict) => {
138 let mut dict_map = std::collections::HashMap::new();
139 for (key_expr, value_expr) in expr_dict.keys.iter().zip(expr_dict.values.iter()) {
140 let key = match key_expr {
143 Some(k) => match const_expr(k)? {
144 Value::String(s) => s,
145 other => other.to_string(),
146 },
147 None => {
148 return Err(
149 "dictionary unpacking (**kwargs) not supported in constants".into()
150 );
151 }
152 };
153 let value = const_expr(value_expr)?;
154 dict_map.insert(key, value);
155 }
156 Ok(json!(dict_map))
157 }
158 _ => Err("only constant values, lists, and dicts are allowed".into()),
159 }
160}
161
162pub fn try_tool_call_parse_pythonic(
163 message: &str,
164) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
165 let stripped = strip_text(message).trim().to_string();
166
167 if stripped.is_empty() {
169 return Ok((vec![], Some(String::new())));
170 }
171
172 let matches = get_regex_matches(&stripped);
173 if matches.is_empty() {
174 return Ok((vec![], Some(stripped)));
175 }
176
177 let tool_response = parse_tool_calls(&matches[0]);
178
179 let normal_text = stripped
181 .split(&matches[0])
182 .next()
183 .unwrap() .trim()
185 .to_string();
186
187 Ok((tool_response?, Some(normal_text)))
188}
189
/// Reports whether `chunk` could contain the start of a pythonic tool call.
///
/// This is a cheap heuristic: it is `true` whenever the chunk contains a `[`
/// anywhere, and `false` for empty or whitespace-only chunks. Surrounding
/// whitespace never affects the answer because `[` is not whitespace.
pub fn detect_tool_call_start_pythonic(chunk: &str) -> bool {
    chunk.contains('[')
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a parsed call into its function name and decoded JSON arguments.
    fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
        let decoded =
            serde_json::from_str::<serde_json::Value>(&call.function.arguments).unwrap();
        (call.function.name, decoded)
    }

    /// Checks one call's name and the listed argument values.
    fn assert_call(
        call: &ToolCallResponse,
        expected_name: &str,
        expected_args: &[(&str, serde_json::Value)],
    ) {
        let (name, args) = extract_name_and_args(call.clone());
        assert_eq!(name, expected_name);
        for (key, expected) in expected_args {
            assert_eq!(&args[*key], expected);
        }
    }

    /// Checks the result of parsing `[foo(a=1, b=2), bar(x=3)]`.
    fn assert_foo_and_bar(calls: &[ToolCallResponse]) {
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(&calls[0], "foo", &[("a", json!(1)), ("b", json!(2))]);
        assert_call(&calls[1], "bar", &[("x", json!(3))]);
    }

    #[test]
    fn test_strip_text() {
        let cases = [
            ("Hello, world!", "Hello, world!"),
            (
                "<|python_start|>foo(a=1, b=2)<|python_end|>",
                "foo(a=1, b=2)",
            ),
            ("<|python_start|>foo(a=1, b=2)", "foo(a=1, b=2)"),
            ("foo(a=1, b=2)<|python_end|>", "foo(a=1, b=2)"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_text(input), expected);
        }
    }

    #[test]
    fn test_get_regex_matches_simple_case() {
        let found = get_regex_matches("[foo(a=1, b=2), bar(x=3)]");
        assert_eq!(found, vec!["[foo(a=1, b=2), bar(x=3)]"]);
    }

    #[test]
    fn test_get_regex_matches_text_before_and_after() {
        let found = get_regex_matches("Hey yo ! [foo(a=1, b=2), bar(x= 3)] Hey yo");
        assert_eq!(found, vec!["[foo(a=1, b=2), bar(x= 3)]"]);
    }

    #[test]
    fn test_get_regex_matches_new_line_in_arg_and_value() {
        let found = get_regex_matches("Hey \n yo ! [foo(a=1,b=2), \n bar(x=3)] Hey yo");
        assert_eq!(found, vec!["[foo(a=1,b=2), \n bar(x=3)]"]);
    }

    #[test]
    fn test_get_regex_matches_no_call() {
        assert!(get_regex_matches("Hey yo !").is_empty());
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_basic() {
        let (calls, content) =
            try_tool_call_parse_pythonic("[foo(a=1, b=2), bar(x=3)]").unwrap();
        assert_eq!(content.as_deref(), Some(""));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text() {
        let (calls, content) =
            try_tool_call_parse_pythonic("Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        assert_eq!(content.as_deref(), Some("Hey yo !"));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
        let (calls, content) =
            try_tool_call_parse_pythonic("Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        assert_eq!(content.as_deref(), Some("Hey \n yo !"));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_no_calls() {
        let (calls, content) = try_tool_call_parse_pythonic("Hey \n yo !").unwrap();
        assert_eq!(content.as_deref(), Some("Hey \n yo !"));
        assert!(calls.is_empty());
        assert_eq!(calls.len(), 0);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_python_tags() {
        let (calls, content) = try_tool_call_parse_pythonic(
            "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>",
        )
        .unwrap();
        assert_eq!(content.as_deref(), Some(""));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
        let (calls, _) =
            try_tool_call_parse_pythonic("[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]").unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(
            &calls[0],
            "foo",
            &[("a", json!([1, 2, 3])), ("b", json!(2))],
        );
        assert_call(&calls[1], "bar", &[("x", json!([3, 4, 5]))]);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
        let (calls, _) = try_tool_call_parse_pythonic(
            "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]",
        )
        .unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(
            &calls[0],
            "foo",
            &[("a", json!({"a": 1, "b": 2})), ("b", json!(2))],
        );
        assert_call(
            &calls[1],
            "bar",
            &[("x", json!({"x": 3, "y": {"e": "f"}}))],
        );
    }
}
366
#[cfg(test)]
mod detect_parser_tests {
    use super::*;

    /// A chunk that is exactly a call list is detected.
    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token() {
        assert!(detect_tool_call_start_pythonic(
            "[foo(a=1, b=2), bar(x=3)]"
        ));
    }

    /// A bare call without the opening bracket is not detected.
    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_without_tool_call_start_token() {
        assert!(!detect_tool_call_start_pythonic("foo(a=1, b=2)"));
    }

    /// Leading prose does not prevent detection.
    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token_in_middle() {
        assert!(detect_tool_call_start_pythonic(
            "information: [foo(a=1, b=2), bar(x=3)]"
        ));
    }

    /// Any `[` triggers detection, even when no well-formed call list follows.
    #[test]
    fn test_detect_tool_call_start_pythonic_false_positive() {
        assert!(detect_tool_call_start_pythonic(
            "Hey [ There is one tool call here . foo(a=1, b=2)"
        ));
    }
}