dynamo_parsers/tool_calling/pythonic/
pythonic_parser.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
5use regex::Regex;
6use rustpython_parser::{
7    Mode,
8    ast::{Constant, Expr, Mod},
9    parse,
10};
11use serde_json::{Number, Value, json};
12use std::sync::OnceLock;
13
14static PYTHONIC_REGEX: OnceLock<Regex> = OnceLock::new();
15
16/// Get the compiled regex pattern for pythonic tool calls
17/// Initialize the regex pattern once, no need to compile it everytime
18fn get_pythonic_regex() -> &'static Regex {
19    PYTHONIC_REGEX.get_or_init(|| {
20        // Format Structure: [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
21        let pattern = r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s?)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*?,\s*)*([a-zA-Z]+\w*=.*?\s*)?\)\s*)+\]";
22        Regex::new(pattern).expect("Failed to compile pythonic regex pattern")
23    })
24}
25
/// Returns `message` with every `<|python_start|>` and `<|python_end|>` tag removed.
///
/// The start tag is stripped first and the end tag second; all other text is
/// left untouched (no trimming is performed here).
fn strip_text(message: &str) -> String {
    // Tags are removed in this order, one full pass per tag.
    const PYTHON_TAGS: [&str; 2] = ["<|python_start|>", "<|python_end|>"];

    PYTHON_TAGS
        .iter()
        .fold(message.to_owned(), |text, tag| text.replace(*tag, ""))
}
32
33fn get_regex_matches(message: &str) -> Vec<String> {
34    let re = get_pythonic_regex();
35    let mut matches = Vec::new();
36    for cap in re.find_iter(message) {
37        matches.push(cap.as_str().to_string());
38    }
39    matches
40}
41
42pub fn parse_tool_calls(src: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
43    let ast = parse(src, Mode::Expression, "<input>")?;
44
45    /*
46    AST: Expression(ModExpression {
47        range: (),
48        body: List(ExprList {
49            range: 0..25,
50            elts: [Call(...), Call(...)]
51            ctx: Load
52        })
53    })
54    */
55    let body = match ast {
56        Mod::Expression(mod_expr) => mod_expr.body,
57        _ => return Ok(vec![]),
58    };
59
60    let elts = match *body {
61        Expr::List(expr_list) => expr_list.elts,
62        _ => return Ok(vec![]),
63    };
64
65    let mut res = Vec::with_capacity(elts.len());
66    for (idx, elt) in elts.iter().enumerate() {
67        let (func, keywords) = match elt {
68            Expr::Call(call) => (&call.func, &call.keywords),
69            _ => continue,
70        };
71
72        let name = match func.as_ref() {
73            Expr::Name(name) => name.id.clone(),
74            _ => continue,
75        };
76
77        let mut obj = serde_json::Map::new();
78        for keyword in keywords.iter() {
79            let Some(arg_ident) = keyword.arg.as_ref() else {
80                tracing::debug!(
81                    "Skipping **kwargs in pythonic tool call for function {}",
82                    name
83                );
84                continue;
85            };
86
87            match const_expr(&keyword.value) {
88                Ok(value) => {
89                    obj.insert(arg_ident.to_string(), value);
90                }
91                Err(e) => {
92                    tracing::debug!("Skipping non-constant argument {}: {}", arg_ident, e);
93                }
94            }
95        }
96
97        res.push(ToolCallResponse {
98            id: format!("call-{}", idx + 1),
99            tp: ToolCallType::Function,
100            function: CalledFunction {
101                name: name.to_string(),
102                // Safety: `Value::Object` is always valid JSON, so serialization cannot fail
103                arguments: serde_json::to_string(&Value::Object(obj))?,
104            },
105        });
106    }
107    Ok(res)
108}
109
110fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
111    match e {
112        Expr::Constant(constant) => Ok(match &constant.value {
113            Constant::Bool(b) => json!(b),
114            Constant::None => Value::Null,
115            Constant::Int(i) => {
116                // Try to downcast to i64/u64; fallback to string if out of range
117                use num_traits::ToPrimitive;
118                if let Some(v) = i.to_i64() {
119                    Value::Number(Number::from(v))
120                } else if let Some(v) = i.to_u64() {
121                    Value::Number(Number::from(v))
122                } else {
123                    Value::String(i.to_string())
124                }
125            }
126            Constant::Float(f) => json!(f),
127            Constant::Str(s) => json!(s),
128            _ => return Err("unsupported constant type".into()),
129        }),
130        // Handle Python lists as expressions, not constants
131        Expr::List(expr_list) => {
132            let list_values: Result<Vec<Value>, Box<dyn std::error::Error>> =
133                expr_list.elts.iter().map(|e| const_expr(e)).collect();
134            Ok(json!(list_values?))
135        }
136        // Handle Python dictionaries as expressions, not constants
137        Expr::Dict(expr_dict) => {
138            let mut dict_map = std::collections::HashMap::new();
139            for (key_expr, value_expr) in expr_dict.keys.iter().zip(expr_dict.values.iter()) {
140                // Keys should be strings for JSON compatibility
141                // Handle the case where key_expr is Option<Expr>
142                let key = match key_expr {
143                    Some(k) => match const_expr(k)? {
144                        Value::String(s) => s,
145                        other => other.to_string(),
146                    },
147                    None => {
148                        return Err(
149                            "dictionary unpacking (**kwargs) not supported in constants".into()
150                        );
151                    }
152                };
153                let value = const_expr(value_expr)?;
154                dict_map.insert(key, value);
155            }
156            Ok(json!(dict_map))
157        }
158        _ => Err("only constant values, lists, and dicts are allowed".into()),
159    }
160}
161
162pub fn try_tool_call_parse_pythonic(
163    message: &str,
164) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
165    let stripped = strip_text(message).trim().to_string();
166
167    // Early exit if no content
168    if stripped.is_empty() {
169        return Ok((vec![], Some(String::new())));
170    }
171
172    let matches = get_regex_matches(&stripped);
173    if matches.is_empty() {
174        return Ok((vec![], Some(stripped)));
175    }
176
177    let tool_response = parse_tool_calls(&matches[0]);
178
179    // normal text is everything before the first match
180    let normal_text = stripped
181        .split(&matches[0])
182        .next()
183        .unwrap() // Safety: `split()` always returns at least one element (the string before the first delimiter, or the entire string if delimiter not found)
184        .trim()
185        .to_string();
186
187    Ok((tool_response?, Some(normal_text)))
188}
189
/// Reports whether `chunk` might contain the start of a pythonic tool call.
///
/// Heuristic: a pythonic tool-call list opens with `[`, so any chunk that
/// contains a `[` anywhere is treated as a potential start. Empty and
/// whitespace-only chunks yield `false`. False positives are expected for
/// ordinary text that happens to contain a bracket.
pub fn detect_tool_call_start_pythonic(chunk: &str) -> bool {
    // Whitespace can never contain '[', so empty or blank chunks fall out as
    // `false` without a separate check.
    chunk.contains('[')
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a tool call into its function name and JSON-decoded arguments.
    fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) {
        let args = serde_json::from_str::<serde_json::Value>(&call.function.arguments).unwrap();
        (call.function.name, args)
    }

    /// Checks the call's function name and the value of each listed argument.
    fn assert_call(
        call: &ToolCallResponse,
        expected_name: &str,
        expected_args: &[(&str, serde_json::Value)],
    ) {
        let (name, args) = extract_name_and_args(call.clone());
        assert_eq!(name, expected_name);
        for (key, expected) in expected_args {
            assert_eq!(&args[*key], expected);
        }
    }

    /// Checks that `calls` holds `foo` with `a=1, b=2` followed by `bar` with `x=3`.
    fn assert_foo_and_bar(calls: &[ToolCallResponse]) {
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(&calls[0], "foo", &[("a", json!(1)), ("b", json!(2))]);
        assert_call(&calls[1], "bar", &[("x", json!(3))]);
    }

    /// Checks that the regex finds exactly one tool-call list equal to `expected`.
    fn assert_single_match(message: &str, expected: &str) {
        let found = get_regex_matches(message);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0], expected);
    }

    #[test]
    fn test_strip_text() {
        let cases = [
            ("Hello, world!", "Hello, world!"),
            (
                "<|python_start|>foo(a=1, b=2)<|python_end|>",
                "foo(a=1, b=2)",
            ),
            ("<|python_start|>foo(a=1, b=2)", "foo(a=1, b=2)"),
            ("foo(a=1, b=2)<|python_end|>", "foo(a=1, b=2)"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(strip_text(input), expected);
        }
    }

    #[test]
    fn test_get_regex_matches_simple_case() {
        // The whole message is one tool-call list.
        assert_single_match("[foo(a=1, b=2), bar(x=3)]", "[foo(a=1, b=2), bar(x=3)]");
    }

    #[test]
    fn test_get_regex_matches_text_before_and_after() {
        // Space inside an argument value, plus surrounding prose.
        assert_single_match(
            "Hey yo ! [foo(a=1, b=2), bar(x= 3)] Hey yo",
            "[foo(a=1, b=2), bar(x= 3)]",
        );
    }

    #[test]
    fn test_get_regex_matches_new_line_in_arg_and_value() {
        // A newline between the two calls must not break the match.
        assert_single_match(
            "Hey \n yo ! [foo(a=1,b=2), \n bar(x=3)] Hey yo",
            "[foo(a=1,b=2), \n bar(x=3)]",
        );
    }

    #[test]
    fn test_get_regex_matches_no_call() {
        // Plain prose yields no matches.
        assert!(get_regex_matches("Hey yo !").is_empty());
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_basic() {
        let (calls, content) =
            try_tool_call_parse_pythonic("[foo(a=1, b=2), bar(x=3)]").unwrap();
        assert_eq!(content, Some("".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text() {
        let (calls, content) =
            try_tool_call_parse_pythonic("Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        // Only the text before the tool-call list is reported as content.
        assert_eq!(content, Some("Hey yo !".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
        let (calls, content) =
            try_tool_call_parse_pythonic("Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo").unwrap();
        assert_eq!(content, Some("Hey \n yo !".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_no_calls() {
        let (calls, content) = try_tool_call_parse_pythonic("Hey \n yo !").unwrap();
        assert_eq!(content, Some("Hey \n yo !".to_string()));
        assert!(calls.is_empty());
        assert_eq!(calls.len(), 0)
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_python_tags() {
        let (calls, content) = try_tool_call_parse_pythonic(
            "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>",
        )
        .unwrap();
        assert_eq!(content, Some("".to_string()));
        assert_foo_and_bar(&calls);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
        let (calls, _) =
            try_tool_call_parse_pythonic("[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]").unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(
            &calls[0],
            "foo",
            &[("a", json!([1, 2, 3])), ("b", json!(2))],
        );
        assert_call(&calls[1], "bar", &[("x", json!([3, 4, 5]))]);
    }

    #[test]
    fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
        let (calls, _) = try_tool_call_parse_pythonic(
            "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]",
        )
        .unwrap();
        assert!(!calls.is_empty());
        assert_eq!(calls.len(), 2);
        assert_call(
            &calls[0],
            "foo",
            &[("a", json!({"a": 1, "b": 2})), ("b", json!(2))],
        );
        assert_call(
            &calls[1],
            "bar",
            &[("x", json!({"x": 3, "y": {"e": "f"}}))],
        );
    }
}
366
#[cfg(test)]
mod detect_parser_tests {
    use super::*;

    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token() {
        // The chunk opens with '['.
        assert!(detect_tool_call_start_pythonic(
            r#"[foo(a=1, b=2), bar(x=3)]"#
        ));
    }

    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_without_tool_call_start_token() {
        // A bare call without the surrounding list has no '['.
        assert!(!detect_tool_call_start_pythonic(r#"foo(a=1, b=2)"#));
    }

    #[test]
    fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token_in_middle() {
        // The '[' may appear after leading prose.
        assert!(detect_tool_call_start_pythonic(
            r#"information: [foo(a=1, b=2), bar(x=3)]"#
        ));
    }

    #[test]
    fn test_detect_tool_call_start_pythonic_false_positive() {
        // Any '[' counts as a tool call start token, so this prose is a known false positive.
        assert!(detect_tool_call_start_pythonic(
            r#"Hey [ There is one tool call here . foo(a=1, b=2)"#
        ));
    }
}
399}